| author    | Dimitry Andric <dim@FreeBSD.org>                              | 2023-07-26 19:03:47 +0000 |
|-----------|---------------------------------------------------------------|---------------------------|
| committer | Dimitry Andric <dim@FreeBSD.org>                              | 2023-07-26 19:04:23 +0000 |
| commit    | 7fa27ce4a07f19b07799a767fc29416f3b625afb (patch)              |                           |
| tree      | 27825c83636c4de341eb09a74f49f5d38a15d165 /llvm/lib/Transforms |                           |
| parent    | e3b557809604d036af6e00c60f012c2025b59a5e (diff)               |                           |
| download  | src-7fa27ce4a07f19b07799a767fc29416f3b625afb.tar.gz src-7fa27ce4a07f19b07799a767fc29416f3b625afb.zip | |
Diffstat (limited to 'llvm/lib/Transforms')
228 files changed, 29185 insertions, 20475 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 473b41241b8a..34c8a380448e 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -18,6 +18,8 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -27,6 +29,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" @@ -64,7 +67,6 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { // shift amount. auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1, Value *&ShAmt) { - Value *SubAmt; unsigned Width = V->getType()->getScalarSizeInBits(); // fshl(ShVal0, ShVal1, ShAmt) @@ -72,8 +74,7 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { if (match(V, m_OneUse(m_c_Or( m_Shl(m_Value(ShVal0), m_Value(ShAmt)), m_LShr(m_Value(ShVal1), - m_Sub(m_SpecificInt(Width), m_Value(SubAmt))))))) { - if (ShAmt == SubAmt) // TODO: Use m_Specific + m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) { return Intrinsic::fshl; } @@ -81,9 +82,8 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt)) if (match(V, m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width), - m_Value(SubAmt))), - m_LShr(m_Value(ShVal1), m_Value(ShAmt)))))) { - if (ShAmt == SubAmt) // TODO: Use m_Specific + m_Value(ShAmt))), + m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) { return Intrinsic::fshr; } @@ -305,7 +305,7 @@ static bool tryToRecognizePopCount(Instruction &I) { Value *MulOp0; // Matching "(i * 0x01010101...) >> 24". if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) && - match(Op1, m_SpecificInt(MaskShift))) { + match(Op1, m_SpecificInt(MaskShift))) { Value *ShiftOp0; // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)". if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)), @@ -398,51 +398,6 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) { return true; } -/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids -/// pessimistic codegen that has to account for setting errno and can enable -/// vectorization. -static bool -foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) { - // Match a call to sqrt mathlib function. - auto *Call = dyn_cast<CallInst>(&I); - if (!Call) - return false; - - Module *M = Call->getModule(); - LibFunc Func; - if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func)) - return false; - - if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl) - return false; - - // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created - // (because NNAN or the operand arg must not be less than -0.0) and (2) we - // would not end up lowering to a libcall anyway (which could change the value - // of errno), then: - // (1) errno won't be set. 
- // (2) it is safe to convert this to an intrinsic call. - Type *Ty = Call->getType(); - Value *Arg = Call->getArgOperand(0); - if (TTI.haveFastSqrt(Ty) && - (Call->hasNoNaNs() || CannotBeOrderedLessThanZero(Arg, &TLI))) { - IRBuilder<> Builder(&I); - IRBuilderBase::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(Call->getFastMathFlags()); - - Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty); - Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt"); - I.replaceAllUsesWith(NewSqrt); - - // Explicitly erase the old call because a call with side effects is not - // trivially dead. - I.eraseFromParent(); - return true; - } - - return false; -} - // Check if this array of constants represents a cttz table. // Iterate over the elements from \p Table by trying to find/match all // the numbers from 0 to \p InputBits that should represent cttz results. @@ -613,7 +568,7 @@ struct LoadOps { LoadInst *RootInsert = nullptr; bool FoundRoot = false; uint64_t LoadSize = 0; - Value *Shift = nullptr; + const APInt *Shift = nullptr; Type *ZextType; AAMDNodes AATags; }; @@ -623,7 +578,7 @@ struct LoadOps { // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3) static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, AliasAnalysis &AA) { - Value *ShAmt2 = nullptr; + const APInt *ShAmt2 = nullptr; Value *X; Instruction *L1, *L2; @@ -631,7 +586,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, if (match(V, m_OneUse(m_c_Or( m_Value(X), m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))), - m_Value(ShAmt2)))))) || + m_APInt(ShAmt2)))))) || match(V, m_OneUse(m_Or(m_Value(X), m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) { if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot) @@ -642,11 +597,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, // Check if the pattern has loads LoadInst *LI1 = LOps.Root; - Value *ShAmt1 = LOps.Shift; + const APInt *ShAmt1 = LOps.Shift; if (LOps.FoundRoot == false && (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) || match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), - m_Value(ShAmt1)))))) { + m_APInt(ShAmt1)))))) { LI1 = dyn_cast<LoadInst>(L1); } LoadInst *LI2 = dyn_cast<LoadInst>(L2); @@ -721,12 +676,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, std::swap(ShAmt1, ShAmt2); // Find Shifts values. - const APInt *Temp; uint64_t Shift1 = 0, Shift2 = 0; - if (ShAmt1 && match(ShAmt1, m_APInt(Temp))) - Shift1 = Temp->getZExtValue(); - if (ShAmt2 && match(ShAmt2, m_APInt(Temp))) - Shift2 = Temp->getZExtValue(); + if (ShAmt1) + Shift1 = ShAmt1->getZExtValue(); + if (ShAmt2) + Shift2 = ShAmt2->getZExtValue(); // First load is always LI1. This is where we put the new load. // Use the merged load size available from LI1 for forward loads. @@ -768,7 +722,8 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, // pattern which suggests that the loads can be combined. The one and only use // of the loads is to form a wider load. static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL, - TargetTransformInfo &TTI, AliasAnalysis &AA) { + TargetTransformInfo &TTI, AliasAnalysis &AA, + const DominatorTree &DT) { // Only consider load chains of scalar values. 
if (isa<VectorType>(I.getType())) return false; @@ -793,17 +748,18 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL, if (!Allowed || !Fast) return false; - // Make sure the Load pointer of type GEP/non-GEP is above insert point - Instruction *Inst = dyn_cast<Instruction>(LI1->getPointerOperand()); - if (Inst && Inst->getParent() == LI1->getParent() && - !Inst->comesBefore(LOps.RootInsert)) - Inst->moveBefore(LOps.RootInsert); - - // New load can be generated + // Get the Index and Ptr for the new GEP. Value *Load1Ptr = LI1->getPointerOperand(); Builder.SetInsertPoint(LOps.RootInsert); - Value *NewPtr = Builder.CreateBitCast(Load1Ptr, WiderType->getPointerTo(AS)); - NewLoad = Builder.CreateAlignedLoad(WiderType, NewPtr, LI1->getAlign(), + if (!DT.dominates(Load1Ptr, LOps.RootInsert)) { + APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0); + Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets( + DL, Offset1, /* AllowNonInbounds */ true); + Load1Ptr = Builder.CreateGEP(Builder.getInt8Ty(), Load1Ptr, + Builder.getInt32(Offset1.getZExtValue())); + } + // Generate wider load. + NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(), LI1->isVolatile(), ""); NewLoad->takeName(LI1); // Set the New Load AATags Metadata. @@ -818,18 +774,254 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL, // Check if shift needed. We need to shift with the amount of load1 // shift if not zero. if (LOps.Shift) - NewOp = Builder.CreateShl(NewOp, LOps.Shift); + NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift)); I.replaceAllUsesWith(NewOp); return true; } +// Calculate GEP Stride and accumulated const ModOffset. Return Stride and +// ModOffset +static std::pair<APInt, APInt> +getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) { + unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType()); + std::optional<APInt> Stride; + APInt ModOffset(BW, 0); + // Return a minimum gep stride, greatest common divisor of consective gep + // index scales(c.f. Bézout's identity). + while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) { + MapVector<Value *, APInt> VarOffsets; + if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset)) + break; + + for (auto [V, Scale] : VarOffsets) { + // Only keep a power of two factor for non-inbounds + if (!GEP->isInBounds()) + Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero()); + + if (!Stride) + Stride = Scale; + else + Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale); + } + + PtrOp = GEP->getPointerOperand(); + } + + // Check whether pointer arrives back at Global Variable via at least one GEP. + // Even if it doesn't, we can check by alignment. + if (!isa<GlobalVariable>(PtrOp) || !Stride) + return {APInt(BW, 1), APInt(BW, 0)}; + + // In consideration of signed GEP indices, non-negligible offset become + // remainder of division by minimum GEP stride. + ModOffset = ModOffset.srem(*Stride); + if (ModOffset.isNegative()) + ModOffset += *Stride; + + return {*Stride, ModOffset}; +} + +/// If C is a constant patterned array and all valid loaded results for given +/// alignment are same to a constant, return that constant. +static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) { + auto *LI = dyn_cast<LoadInst>(&I); + if (!LI || LI->isVolatile()) + return false; + + // We can only fold the load if it is from a constant global with definitive + // initializer. Skip expensive logic if this is not the case. 
+ auto *PtrOp = LI->getPointerOperand(); + auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp)); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) + return false; + + // Bail for large initializers in excess of 4K to avoid too many scans. + Constant *C = GV->getInitializer(); + uint64_t GVSize = DL.getTypeAllocSize(C->getType()); + if (!GVSize || 4096 < GVSize) + return false; + + Type *LoadTy = LI->getType(); + unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType()); + auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL); + + // Any possible offset could be multiple of GEP stride. And any valid + // offset is multiple of load alignment, so checking only multiples of bigger + // one is sufficient to say results' equality. + if (auto LA = LI->getAlign(); + LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) { + ConstOffset = APInt(BW, 0); + Stride = APInt(BW, LA.value()); + } + + Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL); + if (!Ca) + return false; + + unsigned E = GVSize - DL.getTypeStoreSize(LoadTy); + for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride) + if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL)) + return false; + + I.replaceAllUsesWith(Ca); + + return true; +} + +/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids +/// pessimistic codegen that has to account for setting errno and can enable +/// vectorization. +static bool foldSqrt(CallInst *Call, TargetTransformInfo &TTI, + TargetLibraryInfo &TLI, AssumptionCache &AC, + DominatorTree &DT) { + Module *M = Call->getModule(); + + // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created + // (because NNAN or the operand arg must not be less than -0.0) and (2) we + // would not end up lowering to a libcall anyway (which could change the value + // of errno), then: + // (1) errno won't be set. + // (2) it is safe to convert this to an intrinsic call. + Type *Ty = Call->getType(); + Value *Arg = Call->getArgOperand(0); + if (TTI.haveFastSqrt(Ty) && + (Call->hasNoNaNs() || + cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, Call, + &DT))) { + IRBuilder<> Builder(Call); + IRBuilderBase::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(Call->getFastMathFlags()); + + Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty); + Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt"); + Call->replaceAllUsesWith(NewSqrt); + + // Explicitly erase the old call because a call with side effects is not + // trivially dead. + Call->eraseFromParent(); + return true; + } + + return false; +} + +/// Try to expand strcmp(P, "x") calls. 
+static bool expandStrcmp(CallInst *CI, DominatorTree &DT, bool &MadeCFGChange) { + Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + + // Trivial cases are optimized during inst combine + if (Str1P == Str2P) + return false; + + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); + + Value *NonConstantP = nullptr; + StringRef ConstantStr; + + if (!HasStr1 && HasStr2 && Str2.size() == 1) { + NonConstantP = Str1P; + ConstantStr = Str2; + } else if (!HasStr2 && HasStr1 && Str1.size() == 1) { + NonConstantP = Str2P; + ConstantStr = Str1; + } else { + return false; + } + + // Check if strcmp result is only used in a comparison with zero + if (!isOnlyUsedInZeroComparison(CI)) + return false; + + // For strcmp(P, "x") do the following transformation: + // + // (before) + // dst = strcmp(P, "x") + // + // (after) + // v0 = P[0] - 'x' + // [if v0 == 0] + // v1 = P[1] + // dst = phi(v0, v1) + // + + IRBuilder<> B(CI->getParent()); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + + Type *RetType = CI->getType(); + + B.SetInsertPoint(CI); + BasicBlock *InitialBB = B.GetInsertBlock(); + Value *Str1FirstCharacterValue = + B.CreateZExt(B.CreateLoad(B.getInt8Ty(), NonConstantP), RetType); + Value *Str2FirstCharacterValue = + ConstantInt::get(RetType, static_cast<unsigned char>(ConstantStr[0])); + Value *FirstCharacterSub = + B.CreateNSWSub(Str1FirstCharacterValue, Str2FirstCharacterValue); + Value *IsFirstCharacterSubZero = + B.CreateICmpEQ(FirstCharacterSub, ConstantInt::get(RetType, 0)); + Instruction *IsFirstCharacterSubZeroBBTerminator = SplitBlockAndInsertIfThen( + IsFirstCharacterSubZero, CI, /*Unreachable*/ false, + /*BranchWeights*/ nullptr, &DTU); + + B.SetInsertPoint(IsFirstCharacterSubZeroBBTerminator); + B.GetInsertBlock()->setName("strcmp_expand_sub_is_zero"); + BasicBlock *IsFirstCharacterSubZeroBB = B.GetInsertBlock(); + Value *Str1SecondCharacterValue = B.CreateZExt( + B.CreateLoad(B.getInt8Ty(), B.CreateConstInBoundsGEP1_64( + B.getInt8Ty(), NonConstantP, 1)), + RetType); + + B.SetInsertPoint(CI); + B.GetInsertBlock()->setName("strcmp_expand_sub_join"); + + PHINode *Result = B.CreatePHI(RetType, 2); + Result->addIncoming(FirstCharacterSub, InitialBB); + Result->addIncoming(Str1SecondCharacterValue, IsFirstCharacterSubZeroBB); + + CI->replaceAllUsesWith(Result); + CI->eraseFromParent(); + + MadeCFGChange = true; + + return true; +} + +static bool foldLibraryCalls(Instruction &I, TargetTransformInfo &TTI, + TargetLibraryInfo &TLI, DominatorTree &DT, + AssumptionCache &AC, bool &MadeCFGChange) { + CallInst *CI = dyn_cast<CallInst>(&I); + if (!CI) + return false; + + LibFunc Func; + Module *M = I.getModule(); + if (!TLI.getLibFunc(*CI, Func) || !isLibFuncEmittable(M, &TLI, Func)) + return false; + + switch (Func) { + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: + return foldSqrt(CI, TTI, TLI, AC, DT); + case LibFunc_strcmp: + return expandStrcmp(CI, DT, MadeCFGChange); + default: + break; + } + + return false; +} + /// This is the entry point for folds that could be implemented in regular /// InstCombine, but they are separated because they are not expected to /// occur frequently and/or have more than a constant-length pattern match. 
static bool foldUnusualPatterns(Function &F, DominatorTree &DT, TargetTransformInfo &TTI, - TargetLibraryInfo &TLI, AliasAnalysis &AA) { + TargetLibraryInfo &TLI, AliasAnalysis &AA, + AssumptionCache &AC, bool &MadeCFGChange) { bool MadeChange = false; for (BasicBlock &BB : F) { // Ignore unreachable basic blocks. @@ -849,11 +1041,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, MadeChange |= tryToRecognizePopCount(I); MadeChange |= tryToFPToSat(I, TTI); MadeChange |= tryToRecognizeTableBasedCttz(I); - MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA); + MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT); + MadeChange |= foldPatternedLoads(I, DL); // NOTE: This function introduces erasing of the instruction `I`, so it // needs to be called at the end of this sequence, otherwise we may make // bugs. - MadeChange |= foldSqrt(I, TTI, TLI); + MadeChange |= foldLibraryCalls(I, TTI, TLI, DT, AC, MadeCFGChange); } } @@ -869,12 +1062,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, /// handled in the callers of this function. static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI, TargetLibraryInfo &TLI, DominatorTree &DT, - AliasAnalysis &AA) { + AliasAnalysis &AA, bool &ChangedCFG) { bool MadeChange = false; const DataLayout &DL = F.getParent()->getDataLayout(); TruncInstCombine TIC(AC, TLI, DL, DT); MadeChange |= TIC.run(F); - MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA); + MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, ChangedCFG); return MadeChange; } @@ -885,12 +1078,21 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); - if (!runImpl(F, AC, TTI, TLI, DT, AA)) { + + bool MadeCFGChange = false; + + if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) { // No changes, all analyses are preserved. return PreservedAnalyses::all(); } + // Mark all the analyses that instcombine updates as preserved. 
PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); + + if (MadeCFGChange) + PA.preserve<DominatorTreeAnalysis>(); + else + PA.preserveSet<CFGAnalyses>(); + return PA; } diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp index bebaa6cb5969..bf823ac55497 100644 --- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp +++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp @@ -15,12 +15,12 @@ #include "llvm/Transforms/CFGuard.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Triple.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/TargetParser/Triple.h" using namespace llvm; diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index 81b43a2ab2c2..29978bef661c 100644 --- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -127,10 +127,16 @@ PreservedAnalyses CoroCleanupPass::run(Module &M, FunctionPassManager FPM; FPM.addPass(SimplifyCFGPass()); + PreservedAnalyses FuncPA; + FuncPA.preserveSet<CFGAnalyses>(); + Lowerer L(M); - for (auto &F : M) - if (L.lower(F)) + for (auto &F : M) { + if (L.lower(F)) { + FAM.invalidate(F, FuncPA); FPM.run(F, FAM); + } + } return PreservedAnalyses::none(); } diff --git a/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp b/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp index 974123fe36a1..3e71e58bb1de 100644 --- a/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp @@ -26,7 +26,7 @@ PreservedAnalyses CoroConditionalWrapper::run(Module &M, void CoroConditionalWrapper::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { OS << "coro-cond"; - OS << "("; + OS << '('; PM.printPipeline(OS, MapClassName2PassName); - OS << ")"; + OS << ')'; } diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp index f032c568449b..d78ab1c1ea28 100644 --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -12,6 +12,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/Support/ErrorHandling.h" @@ -46,7 +47,8 @@ struct Lowerer : coro::LowererBase { AAResults &AA); bool shouldElide(Function *F, DominatorTree &DT) const; void collectPostSplitCoroIds(Function *F); - bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT); + bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT, + OptimizationRemarkEmitter &ORE); bool hasEscapePath(const CoroBeginInst *, const SmallPtrSetImpl<BasicBlock *> &) const; }; @@ -299,7 +301,7 @@ void Lowerer::collectPostSplitCoroIds(Function *F) { } bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, - DominatorTree &DT) { + DominatorTree &DT, OptimizationRemarkEmitter &ORE) { CoroBegins.clear(); CoroAllocs.clear(); ResumeAddr.clear(); @@ -343,6 +345,24 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, replaceWithConstant(ResumeAddrConstant, ResumeAddr); bool ShouldElide = shouldElide(CoroId->getFunction(), DT); + if (!ShouldElide) + ORE.emit([&]() { + if (auto FrameSizeAndAlign = + 
getFrameLayout(cast<Function>(ResumeAddrConstant))) + return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) + << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) + << "' not elided in '" + << ore::NV("caller", CoroId->getFunction()->getName()) + << "' (frame_size=" + << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align=" + << ore::NV("align", FrameSizeAndAlign->second.value()) << ")"; + else + return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) + << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) + << "' not elided in '" + << ore::NV("caller", CoroId->getFunction()->getName()) + << "' (frame_size=unknown, align=unknown)"; + }); auto *DestroyAddrConstant = Resumers->getAggregateElement( ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex); @@ -363,6 +383,23 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, << "Elide " << CoroId->getCoroutine()->getName() << " in " << CoroId->getFunction()->getName() << "\n"; #endif + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "CoroElide", CoroId) + << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) + << "' elided in '" + << ore::NV("caller", CoroId->getFunction()->getName()) + << "' (frame_size=" + << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align=" + << ore::NV("align", FrameSizeAndAlign->second.value()) << ")"; + }); + } else { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) + << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) + << "' not elided in '" + << ore::NV("caller", CoroId->getFunction()->getName()) + << "' (frame_size=unknown, align=unknown)"; + }); } } @@ -387,10 +424,11 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { AAResults &AA = AM.getResult<AAManager>(F); DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); bool Changed = false; for (auto *CII : L.CoroIds) - Changed |= L.processCoroId(CII, AA, DT); + Changed |= L.processCoroId(CII, AA, DT, ORE); return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index e98c601648e0..1f373270f951 100644 --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -16,6 +16,7 @@ #include "CoroInternal.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" #include "llvm/Analysis/PtrUseVisitor.h" @@ -37,6 +38,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> +#include <deque> #include <optional> using namespace llvm; @@ -87,7 +89,7 @@ public: // crosses a suspend point. 
// namespace { -struct SuspendCrossingInfo { +class SuspendCrossingInfo { BlockToIndexMapping Mapping; struct BlockData { @@ -96,20 +98,30 @@ struct SuspendCrossingInfo { bool Suspend = false; bool End = false; bool KillLoop = false; + bool Changed = false; }; SmallVector<BlockData, SmallVectorThreshold> Block; - iterator_range<succ_iterator> successors(BlockData const &BD) const { + iterator_range<pred_iterator> predecessors(BlockData const &BD) const { BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); - return llvm::successors(BB); + return llvm::predecessors(BB); } BlockData &getBlockData(BasicBlock *BB) { return Block[Mapping.blockToIndex(BB)]; } + /// Compute the BlockData for the current function in one iteration. + /// Returns whether the BlockData changes in this iteration. + /// Initialize - Whether this is the first iteration, we can optimize + /// the initial case a little bit by manual loop switch. + template <bool Initialize = false> bool computeBlockData(); + +public: +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void dump() const; void dump(StringRef Label, BitVector const &BV) const; +#endif SuspendCrossingInfo(Function &F, coro::Shape &Shape); @@ -211,6 +223,72 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const { } #endif +template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() { + const size_t N = Mapping.size(); + bool Changed = false; + + for (size_t I = 0; I < N; ++I) { + auto &B = Block[I]; + + // We don't need to count the predecessors when initialization. + if constexpr (!Initialize) + // If all the predecessors of the current Block don't change, + // the BlockData for the current block must not change too. + if (all_of(predecessors(B), [this](BasicBlock *BB) { + return !Block[Mapping.blockToIndex(BB)].Changed; + })) { + B.Changed = false; + continue; + } + + // Saved Consumes and Kills bitsets so that it is easy to see + // if anything changed after propagation. + auto SavedConsumes = B.Consumes; + auto SavedKills = B.Kills; + + for (BasicBlock *PI : predecessors(B)) { + auto PrevNo = Mapping.blockToIndex(PI); + auto &P = Block[PrevNo]; + + // Propagate Kills and Consumes from predecessors into B. + B.Consumes |= P.Consumes; + B.Kills |= P.Kills; + + // If block P is a suspend block, it should propagate kills into block + // B for every block P consumes. + if (P.Suspend) + B.Kills |= P.Consumes; + } + + if (B.Suspend) { + // If block S is a suspend block, it should kill all of the blocks it + // consumes. + B.Kills |= B.Consumes; + } else if (B.End) { + // If block B is an end block, it should not propagate kills as the + // blocks following coro.end() are reached during initial invocation + // of the coroutine while all the data are still available on the + // stack or in the registers. + B.Kills.reset(); + } else { + // This is reached when B block it not Suspend nor coro.end and it + // need to make sure that it is not in the kill set. + B.KillLoop |= B.Kills[I]; + B.Kills.reset(I); + } + + if constexpr (!Initialize) { + B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes); + Changed |= B.Changed; + } + } + + if constexpr (Initialize) + return true; + + return Changed; +} + SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) : Mapping(F) { const size_t N = Mapping.size(); @@ -222,6 +300,7 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) B.Consumes.resize(N); B.Kills.resize(N); B.Consumes.set(I); + B.Changed = true; } // Mark all CoroEnd Blocks. 
We do not propagate Kills beyond coro.ends as @@ -246,73 +325,123 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) markSuspendBlock(Save); } - // Iterate propagating consumes and kills until they stop changing. - int Iteration = 0; - (void)Iteration; + computeBlockData</*Initialize=*/true>(); - bool Changed; - do { - LLVM_DEBUG(dbgs() << "iteration " << ++Iteration); - LLVM_DEBUG(dbgs() << "==============\n"); - - Changed = false; - for (size_t I = 0; I < N; ++I) { - auto &B = Block[I]; - for (BasicBlock *SI : successors(B)) { - - auto SuccNo = Mapping.blockToIndex(SI); - - // Saved Consumes and Kills bitsets so that it is easy to see - // if anything changed after propagation. - auto &S = Block[SuccNo]; - auto SavedConsumes = S.Consumes; - auto SavedKills = S.Kills; - - // Propagate Kills and Consumes from block B into its successor S. - S.Consumes |= B.Consumes; - S.Kills |= B.Kills; - - // If block B is a suspend block, it should propagate kills into the - // its successor for every block B consumes. - if (B.Suspend) { - S.Kills |= B.Consumes; - } - if (S.Suspend) { - // If block S is a suspend block, it should kill all of the blocks it - // consumes. - S.Kills |= S.Consumes; - } else if (S.End) { - // If block S is an end block, it should not propagate kills as the - // blocks following coro.end() are reached during initial invocation - // of the coroutine while all the data are still available on the - // stack or in the registers. - S.Kills.reset(); - } else { - // This is reached when S block it not Suspend nor coro.end and it - // need to make sure that it is not in the kill set. - S.KillLoop |= S.Kills[SuccNo]; - S.Kills.reset(SuccNo); - } + while (computeBlockData()) + ; + + LLVM_DEBUG(dump()); +} - // See if anything changed. - Changed |= (S.Kills != SavedKills) || (S.Consumes != SavedConsumes); +namespace { - if (S.Kills != SavedKills) { - LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName() - << "\n"); - LLVM_DEBUG(dump("S.Kills", S.Kills)); - LLVM_DEBUG(dump("SavedKills", SavedKills)); - } - if (S.Consumes != SavedConsumes) { - LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI << "\n"); - LLVM_DEBUG(dump("S.Consume", S.Consumes)); - LLVM_DEBUG(dump("SavedCons", SavedConsumes)); +// RematGraph is used to construct a DAG for rematerializable instructions +// When the constructor is invoked with a candidate instruction (which is +// materializable) it builds a DAG of materializable instructions from that +// point. +// Typically, for each instruction identified as re-materializable across a +// suspend point, a RematGraph will be created. +struct RematGraph { + // Each RematNode in the graph contains the edges to instructions providing + // operands in the current node. 
+ struct RematNode { + Instruction *Node; + SmallVector<RematNode *> Operands; + RematNode() = default; + RematNode(Instruction *V) : Node(V) {} + }; + + RematNode *EntryNode; + using RematNodeMap = + SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>; + RematNodeMap Remats; + const std::function<bool(Instruction &)> &MaterializableCallback; + SuspendCrossingInfo &Checker; + + RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback, + Instruction *I, SuspendCrossingInfo &Checker) + : MaterializableCallback(MaterializableCallback), Checker(Checker) { + std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I); + EntryNode = FirstNode.get(); + std::deque<std::unique_ptr<RematNode>> WorkList; + addNode(std::move(FirstNode), WorkList, cast<User>(I)); + while (WorkList.size()) { + std::unique_ptr<RematNode> N = std::move(WorkList.front()); + WorkList.pop_front(); + addNode(std::move(N), WorkList, cast<User>(I)); + } + } + + void addNode(std::unique_ptr<RematNode> NUPtr, + std::deque<std::unique_ptr<RematNode>> &WorkList, + User *FirstUse) { + RematNode *N = NUPtr.get(); + if (Remats.count(N->Node)) + return; + + // We haven't see this node yet - add to the list + Remats[N->Node] = std::move(NUPtr); + for (auto &Def : N->Node->operands()) { + Instruction *D = dyn_cast<Instruction>(Def.get()); + if (!D || !MaterializableCallback(*D) || + !Checker.isDefinitionAcrossSuspend(*D, FirstUse)) + continue; + + if (Remats.count(D)) { + // Already have this in the graph + N->Operands.push_back(Remats[D].get()); + continue; + } + + bool NoMatch = true; + for (auto &I : WorkList) { + if (I->Node == D) { + NoMatch = false; + N->Operands.push_back(I.get()); + break; } } + if (NoMatch) { + // Create a new node + std::unique_ptr<RematNode> ChildNode = std::make_unique<RematNode>(D); + N->Operands.push_back(ChildNode.get()); + WorkList.push_back(std::move(ChildNode)); + } } - } while (Changed); - LLVM_DEBUG(dump()); -} + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + void dump() const { + dbgs() << "Entry ("; + if (EntryNode->Node->getParent()->hasName()) + dbgs() << EntryNode->Node->getParent()->getName(); + else + EntryNode->Node->getParent()->printAsOperand(dbgs(), false); + dbgs() << ") : " << *EntryNode->Node << "\n"; + for (auto &E : Remats) { + dbgs() << *(E.first) << "\n"; + for (RematNode *U : E.second->Operands) + dbgs() << " " << *U->Node << "\n"; + } + } +#endif +}; +} // end anonymous namespace + +namespace llvm { + +template <> struct GraphTraits<RematGraph *> { + using NodeRef = RematGraph::RematNode *; + using ChildIteratorType = RematGraph::RematNode **; + + static NodeRef getEntryNode(RematGraph *G) { return G->EntryNode; } + static ChildIteratorType child_begin(NodeRef N) { + return N->Operands.begin(); + } + static ChildIteratorType child_end(NodeRef N) { return N->Operands.end(); } +}; + +} // end namespace llvm #undef DEBUG_TYPE // "coro-suspend-crossing" #define DEBUG_TYPE "coro-frame" @@ -425,6 +554,15 @@ static void dumpSpills(StringRef Title, const SpillInfo &Spills) { I->dump(); } } +static void dumpRemats( + StringRef Title, + const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> &RM) { + dbgs() << "------------- " << Title << "--------------\n"; + for (const auto &E : RM) { + E.second->dump(); + dbgs() << "--\n"; + } +} static void dumpAllocas(const SmallVectorImpl<AllocaInfo> &Allocas) { dbgs() << "------------- Allocas --------------\n"; @@ -637,10 +775,10 @@ void FrameTypeBuilder::addFieldForAllocas(const Function 
&F, return; } - // Because there are pathes from the lifetime.start to coro.end + // Because there are paths from the lifetime.start to coro.end // for each alloca, the liferanges for every alloca is overlaped // in the blocks who contain coro.end and the successor blocks. - // So we choose to skip there blocks when we calculates the liferange + // So we choose to skip there blocks when we calculate the liferange // for each alloca. It should be reasonable since there shouldn't be uses // in these blocks and the coroutine frame shouldn't be used outside the // coroutine body. @@ -820,7 +958,7 @@ void FrameTypeBuilder::finish(StructType *Ty) { static void cacheDIVar(FrameDataInfo &FrameData, DenseMap<Value *, DILocalVariable *> &DIVarCache) { for (auto *V : FrameData.getAllDefs()) { - if (DIVarCache.find(V) != DIVarCache.end()) + if (DIVarCache.contains(V)) continue; auto DDIs = FindDbgDeclareUses(V); @@ -852,18 +990,8 @@ static StringRef solveTypeName(Type *Ty) { return "__floating_type_"; } - if (auto *PtrTy = dyn_cast<PointerType>(Ty)) { - if (PtrTy->isOpaque()) - return "PointerType"; - Type *PointeeTy = PtrTy->getNonOpaquePointerElementType(); - auto Name = solveTypeName(PointeeTy); - if (Name == "UnknownType") - return "PointerType"; - SmallString<16> Buffer; - Twine(Name + "_Ptr").toStringRef(Buffer); - auto *MDName = MDString::get(Ty->getContext(), Buffer.str()); - return MDName->getString(); - } + if (Ty->isPointerTy()) + return "PointerType"; if (Ty->isStructTy()) { if (!cast<StructType>(Ty)->hasName()) @@ -1043,7 +1171,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, dwarf::DW_ATE_unsigned_char)}); for (auto *V : FrameData.getAllDefs()) { - if (DIVarCache.find(V) == DIVarCache.end()) + if (!DIVarCache.contains(V)) continue; auto Index = FrameData.getFieldIndex(V); @@ -1075,7 +1203,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, // fields confilicts with each other. unsigned UnknownTypeNum = 0; for (unsigned Index = 0; Index < FrameTy->getNumElements(); Index++) { - if (OffsetCache.find(Index) == OffsetCache.end()) + if (!OffsetCache.contains(Index)) continue; std::string Name; @@ -1090,7 +1218,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, AlignInBits = OffsetCache[Index].first * 8; OffsetInBits = OffsetCache[Index].second * 8; - if (NameCache.find(Index) != NameCache.end()) { + if (NameCache.contains(Index)) { Name = NameCache[Index].str(); DITy = TyCache[Index]; } else { @@ -1282,7 +1410,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, // function call or any of the memory intrinsics, we check whether this // instruction is prior to CoroBegin. To answer question 3, we track the offsets // of all aliases created for the alloca prior to CoroBegin but used after -// CoroBegin. llvm::Optional is used to be able to represent the case when the +// CoroBegin. std::optional is used to be able to represent the case when the // offset is unknown (e.g. when you have a PHINode that takes in different // offset values). We cannot handle unknown offsets and will assert. This is the // potential issue left out. 
An ideal solution would likely require a @@ -1586,11 +1714,12 @@ static void createFramePtr(coro::Shape &Shape) { static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { auto *CB = Shape.CoroBegin; LLVMContext &C = CB->getContext(); + Function *F = CB->getFunction(); IRBuilder<> Builder(C); StructType *FrameTy = Shape.FrameTy; Value *FramePtr = Shape.FramePtr; - DominatorTree DT(*CB->getFunction()); - SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache; + DominatorTree DT(*F); + SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap; // Create a GEP with the given index into the coroutine frame for the original // value Orig. Appends an extra 0 index for array-allocas, preserving the @@ -1723,6 +1852,21 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { SpillAlignment, E.first->getName() + Twine(".reload")); TinyPtrVector<DbgDeclareInst *> DIs = FindDbgDeclareUses(Def); + // Try best to find dbg.declare. If the spill is a temp, there may not + // be a direct dbg.declare. Walk up the load chain to find one from an + // alias. + if (F->getSubprogram()) { + auto *CurDef = Def; + while (DIs.empty() && isa<LoadInst>(CurDef)) { + auto *LdInst = cast<LoadInst>(CurDef); + // Only consider ptr to ptr same type load. + if (LdInst->getPointerOperandType() != LdInst->getType()) + break; + CurDef = LdInst->getPointerOperand(); + DIs = FindDbgDeclareUses(CurDef); + } + } + for (DbgDeclareInst *DDI : DIs) { bool AllowUnresolved = false; // This dbg.declare is preserved for all coro-split function @@ -1734,16 +1878,10 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { &*Builder.GetInsertPoint()); // This dbg.declare is for the main function entry point. It // will be deleted in all coro-split functions. - coro::salvageDebugInfo(DbgPtrAllocaCache, DDI, Shape.OptimizeFrame); + coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame); } } - // Salvage debug info on any dbg.addr that we see. We do not insert them - // into each block where we have a use though. - if (auto *DI = dyn_cast<DbgAddrIntrinsic>(U)) { - coro::salvageDebugInfo(DbgPtrAllocaCache, DI, Shape.OptimizeFrame); - } - // If we have a single edge PHINode, remove it and replace it with a // reload from the coroutine frame. (We already took care of multi edge // PHINodes by rewriting them in the rewritePHIs function). @@ -1813,11 +1951,13 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { DVI->replaceUsesOfWith(Alloca, G); for (Instruction *I : UsersToUpdate) { - // It is meaningless to remain the lifetime intrinsics refer for the + // It is meaningless to retain the lifetime intrinsics refer for the // member of coroutine frames and the meaningless lifetime intrinsics // are possible to block further optimizations. - if (I->isLifetimeStartOrEnd()) + if (I->isLifetimeStartOrEnd()) { + I->eraseFromParent(); continue; + } I->replaceUsesOfWith(Alloca, G); } @@ -2089,11 +2229,12 @@ static void rewritePHIs(Function &F) { rewritePHIs(*BB); } +/// Default materializable callback // Check for instructions that we can recreate on resume as opposed to spill // the result into a coroutine frame. 
-static bool materializable(Instruction &V) { - return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) || - isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V); +bool coro::defaultMaterializable(Instruction &V) { + return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) || + isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V)); } // Check for structural coroutine intrinsics that should not be spilled into @@ -2103,41 +2244,82 @@ static bool isCoroutineStructureIntrinsic(Instruction &I) { isa<CoroSuspendInst>(&I); } -// For every use of the value that is across suspend point, recreate that value -// after a suspend point. -static void rewriteMaterializableInstructions(IRBuilder<> &IRB, - const SpillInfo &Spills) { - for (const auto &E : Spills) { - Value *Def = E.first; - BasicBlock *CurrentBlock = nullptr; +// For each instruction identified as materializable across the suspend point, +// and its associated DAG of other rematerializable instructions, +// recreate the DAG of instructions after the suspend point. +static void rewriteMaterializableInstructions( + const SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> + &AllRemats) { + // This has to be done in 2 phases + // Do the remats and record the required defs to be replaced in the + // original use instructions + // Once all the remats are complete, replace the uses in the final + // instructions with the new defs + typedef struct { + Instruction *Use; + Instruction *Def; + Instruction *Remat; + } ProcessNode; + + SmallVector<ProcessNode> FinalInstructionsToProcess; + + for (const auto &E : AllRemats) { + Instruction *Use = E.first; Instruction *CurrentMaterialization = nullptr; - for (Instruction *U : E.second) { - // If we have not seen this block, materialize the value. - if (CurrentBlock != U->getParent()) { + RematGraph *RG = E.second.get(); + ReversePostOrderTraversal<RematGraph *> RPOT(RG); + SmallVector<Instruction *> InstructionsToProcess; + + // If the target use is actually a suspend instruction then we have to + // insert the remats into the end of the predecessor (there should only be + // one). This is so that suspend blocks always have the suspend instruction + // as the first instruction. + auto InsertPoint = &*Use->getParent()->getFirstInsertionPt(); + if (isa<AnyCoroSuspendInst>(Use)) { + BasicBlock *SuspendPredecessorBlock = + Use->getParent()->getSinglePredecessor(); + assert(SuspendPredecessorBlock && "malformed coro suspend instruction"); + InsertPoint = SuspendPredecessorBlock->getTerminator(); + } - bool IsInCoroSuspendBlock = isa<AnyCoroSuspendInst>(U); - CurrentBlock = U->getParent(); - auto *InsertBlock = IsInCoroSuspendBlock - ? CurrentBlock->getSinglePredecessor() - : CurrentBlock; - CurrentMaterialization = cast<Instruction>(Def)->clone(); - CurrentMaterialization->setName(Def->getName()); - CurrentMaterialization->insertBefore( - IsInCoroSuspendBlock ? InsertBlock->getTerminator() - : &*InsertBlock->getFirstInsertionPt()); - } - if (auto *PN = dyn_cast<PHINode>(U)) { - assert(PN->getNumIncomingValues() == 1 && - "unexpected number of incoming " - "values in the PHINode"); - PN->replaceAllUsesWith(CurrentMaterialization); - PN->eraseFromParent(); - continue; - } - // Replace all uses of Def in the current instruction with the - // CurrentMaterialization for the block. - U->replaceUsesOfWith(Def, CurrentMaterialization); + // Note: skip the first instruction as this is the actual use that we're + // rematerializing everything for. 
+ auto I = RPOT.begin(); + ++I; + for (; I != RPOT.end(); ++I) { + Instruction *D = (*I)->Node; + CurrentMaterialization = D->clone(); + CurrentMaterialization->setName(D->getName()); + CurrentMaterialization->insertBefore(InsertPoint); + InsertPoint = CurrentMaterialization; + + // Replace all uses of Def in the instructions being added as part of this + // rematerialization group + for (auto &I : InstructionsToProcess) + I->replaceUsesOfWith(D, CurrentMaterialization); + + // Don't replace the final use at this point as this can cause problems + // for other materializations. Instead, for any final use that uses a + // define that's being rematerialized, record the replace values + for (unsigned i = 0, E = Use->getNumOperands(); i != E; ++i) + if (Use->getOperand(i) == D) // Is this operand pointing to oldval? + FinalInstructionsToProcess.push_back( + {Use, D, CurrentMaterialization}); + + InstructionsToProcess.push_back(CurrentMaterialization); + } + } + + // Finally, replace the uses with the defines that we've just rematerialized + for (auto &R : FinalInstructionsToProcess) { + if (auto *PN = dyn_cast<PHINode>(R.Use)) { + assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming " + "values in the PHINode"); + PN->replaceAllUsesWith(R.Remat); + PN->eraseFromParent(); + continue; } + R.Use->replaceUsesOfWith(R.Def, R.Remat); } } @@ -2407,10 +2589,7 @@ static void eliminateSwiftErrorArgument(Function &F, Argument &Arg, IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg()); auto ArgTy = cast<PointerType>(Arg.getType()); - // swifterror arguments are required to have pointer-to-pointer type, - // so create a pointer-typed alloca with opaque pointers. - auto ValueTy = ArgTy->isOpaque() ? PointerType::getUnqual(F.getContext()) - : ArgTy->getNonOpaquePointerElementType(); + auto ValueTy = PointerType::getUnqual(F.getContext()); // Reduce to the alloca case: @@ -2523,6 +2702,9 @@ static void sinkSpillUsesAfterCoroBegin(Function &F, /// hence minimizing the amount of data we end up putting on the frame. static void sinkLifetimeStartMarkers(Function &F, coro::Shape &Shape, SuspendCrossingInfo &Checker) { + if (F.hasOptNone()) + return; + DominatorTree DT(F); // Collect all possible basic blocks which may dominate all uses of allocas. 
@@ -2635,7 +2817,7 @@ static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape, } void coro::salvageDebugInfo( - SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> &DbgPtrAllocaCache, + SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, DbgVariableIntrinsic *DVI, bool OptimizeFrame) { Function *F = DVI->getFunction(); IRBuilder<> Builder(F->getContext()); @@ -2652,7 +2834,7 @@ void coro::salvageDebugInfo( while (auto *Inst = dyn_cast_or_null<Instruction>(Storage)) { if (auto *LdInst = dyn_cast<LoadInst>(Inst)) { - Storage = LdInst->getOperand(0); + Storage = LdInst->getPointerOperand(); // FIXME: This is a heuristic that works around the fact that // LLVM IR debug intrinsics cannot yet distinguish between // memory and value locations: Because a dbg.declare(alloca) is @@ -2662,7 +2844,7 @@ void coro::salvageDebugInfo( if (!SkipOutermostLoad) Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore); } else if (auto *StInst = dyn_cast<StoreInst>(Inst)) { - Storage = StInst->getOperand(0); + Storage = StInst->getValueOperand(); } else { SmallVector<uint64_t, 16> Ops; SmallVector<Value *, 0> AdditionalValues; @@ -2682,38 +2864,44 @@ void coro::salvageDebugInfo( if (!Storage) return; - // Store a pointer to the coroutine frame object in an alloca so it - // is available throughout the function when producing unoptimized - // code. Extending the lifetime this way is correct because the - // variable has been declared by a dbg.declare intrinsic. - // - // Avoid to create the alloca would be eliminated by optimization - // passes and the corresponding dbg.declares would be invalid. - if (!OptimizeFrame) - if (auto *Arg = dyn_cast<llvm::Argument>(Storage)) { - auto &Cached = DbgPtrAllocaCache[Storage]; - if (!Cached) { - Cached = Builder.CreateAlloca(Storage->getType(), 0, nullptr, - Arg->getName() + ".debug"); - Builder.CreateStore(Storage, Cached); - } - Storage = Cached; - // FIXME: LLVM lacks nuanced semantics to differentiate between - // memory and direct locations at the IR level. The backend will - // turn a dbg.declare(alloca, ..., DIExpression()) into a memory - // location. Thus, if there are deref and offset operations in the - // expression, we need to add a DW_OP_deref at the *start* of the - // expression to first load the contents of the alloca before - // adjusting it with the expression. - Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore); + auto *StorageAsArg = dyn_cast<Argument>(Storage); + const bool IsSwiftAsyncArg = + StorageAsArg && StorageAsArg->hasAttribute(Attribute::SwiftAsync); + + // Swift async arguments are described by an entry value of the ABI-defined + // register containing the coroutine context. + if (IsSwiftAsyncArg && !Expr->isEntryValue()) + Expr = DIExpression::prepend(Expr, DIExpression::EntryValue); + + // If the coroutine frame is an Argument, store it in an alloca to improve + // its availability (e.g. registers may be clobbered). + // Avoid this if optimizations are enabled (they would remove the alloca) or + // if the value is guaranteed to be available through other means (e.g. swift + // ABI guarantees). + if (StorageAsArg && !OptimizeFrame && !IsSwiftAsyncArg) { + auto &Cached = ArgToAllocaMap[StorageAsArg]; + if (!Cached) { + Cached = Builder.CreateAlloca(Storage->getType(), 0, nullptr, + Storage->getName() + ".debug"); + Builder.CreateStore(Storage, Cached); } + Storage = Cached; + // FIXME: LLVM lacks nuanced semantics to differentiate between + // memory and direct locations at the IR level. 
The backend will + // turn a dbg.declare(alloca, ..., DIExpression()) into a memory + // location. Thus, if there are deref and offset operations in the + // expression, we need to add a DW_OP_deref at the *start* of the + // expression to first load the contents of the alloca before + // adjusting it with the expression. + Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore); + } DVI->replaceVariableLocationOp(OriginalStorage, Storage); DVI->setExpression(Expr); // We only hoist dbg.declare today since it doesn't make sense to hoist - // dbg.value or dbg.addr since they do not have the same function wide - // guarantees that dbg.declare does. - if (!isa<DbgValueInst>(DVI) && !isa<DbgAddrIntrinsic>(DVI)) { + // dbg.value since it does not have the same function wide guarantees that + // dbg.declare does. + if (isa<DbgDeclareInst>(DVI)) { Instruction *InsertPt = nullptr; if (auto *I = dyn_cast<Instruction>(Storage)) InsertPt = I->getInsertionPointAfterDef(); @@ -2724,7 +2912,71 @@ void coro::salvageDebugInfo( } } -void coro::buildCoroutineFrame(Function &F, Shape &Shape) { +static void doRematerializations( + Function &F, SuspendCrossingInfo &Checker, + const std::function<bool(Instruction &)> &MaterializableCallback) { + if (F.hasOptNone()) + return; + + SpillInfo Spills; + + // See if there are materializable instructions across suspend points + // We record these as the starting point to also identify materializable + // defs of uses in these operations + for (Instruction &I : instructions(F)) { + if (!MaterializableCallback(I)) + continue; + for (User *U : I.users()) + if (Checker.isDefinitionAcrossSuspend(I, U)) + Spills[&I].push_back(cast<Instruction>(U)); + } + + // Process each of the identified rematerializable instructions + // and add predecessor instructions that can also be rematerialized. + // This is actually a graph of instructions since we could potentially + // have multiple uses of a def in the set of predecessor instructions. + // The approach here is to maintain a graph of instructions for each bottom + // level instruction - where we have a unique set of instructions (nodes) + // and edges between them. We then walk the graph in reverse post-dominator + // order to insert them past the suspend point, but ensure that ordering is + // correct. We also rely on CSE removing duplicate defs for remats of + // different instructions with a def in common (rather than maintaining more + // complex graphs for each suspend point) + + // We can do this by adding new nodes to the list for each suspend + // point. Then using standard GraphTraits to give a reverse post-order + // traversal when we insert the nodes after the suspend + SmallMapVector<Instruction *, std::unique_ptr<RematGraph>, 8> AllRemats; + for (auto &E : Spills) { + for (Instruction *U : E.second) { + // Don't process a user twice (this can happen if the instruction uses + // more than one rematerializable def) + if (AllRemats.count(U)) + continue; + + // Constructor creates the whole RematGraph for the given Use + auto RematUPtr = + std::make_unique<RematGraph>(MaterializableCallback, U, Checker); + + LLVM_DEBUG(dbgs() << "***** Next remat group *****\n"; + ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get()); + for (auto I = RPOT.begin(); I != RPOT.end(); + ++I) { (*I)->Node->dump(); } dbgs() + << "\n";); + + AllRemats[U] = std::move(RematUPtr); + } + } + + // Rewrite materializable instructions to be materialized at the use + // point. 
+ LLVM_DEBUG(dumpRemats("Materializations", AllRemats)); + rewriteMaterializableInstructions(AllRemats); +} + +void coro::buildCoroutineFrame( + Function &F, Shape &Shape, + const std::function<bool(Instruction &)> &MaterializableCallback) { // Don't eliminate swifterror in async functions that won't be split. if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty()) eliminateSwiftError(F, Shape); @@ -2775,35 +3027,11 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { // Build suspend crossing info. SuspendCrossingInfo Checker(F, Shape); - IRBuilder<> Builder(F.getContext()); + doRematerializations(F, Checker, MaterializableCallback); + FrameDataInfo FrameData; SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas; SmallVector<Instruction*, 4> DeadInstructions; - - { - SpillInfo Spills; - for (int Repeat = 0; Repeat < 4; ++Repeat) { - // See if there are materializable instructions across suspend points. - // FIXME: We can use a worklist to track the possible materialize - // instructions instead of iterating the whole function again and again. - for (Instruction &I : instructions(F)) - if (materializable(I)) { - for (User *U : I.users()) - if (Checker.isDefinitionAcrossSuspend(I, U)) - Spills[&I].push_back(cast<Instruction>(U)); - } - - if (Spills.empty()) - break; - - // Rewrite materializable instructions to be materialized at the use - // point. - LLVM_DEBUG(dumpSpills("Materializations", Spills)); - rewriteMaterializableInstructions(Builder, Spills); - Spills.clear(); - } - } - if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon && Shape.ABI != coro::ABI::RetconOnce) sinkLifetimeStartMarkers(F, Shape, Checker); diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h index 032361c22045..067fb6bba47e 100644 --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -25,10 +25,13 @@ bool declaresIntrinsics(const Module &M, const std::initializer_list<StringRef>); void replaceCoroFree(CoroIdInst *CoroId, bool Elide); -/// Recover a dbg.declare prepared by the frontend and emit an alloca -/// holding a pointer to the coroutine frame. +/// Attempts to rewrite the location operand of debug intrinsics in terms of +/// the coroutine frame pointer, folding pointer offsets into the DIExpression +/// of the intrinsic. +/// If the frame pointer is an Argument, store it into an alloca if +/// OptimizeFrame is false. void salvageDebugInfo( - SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> &DbgPtrAllocaCache, + SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, DbgVariableIntrinsic *DVI, bool OptimizeFrame); // Keeps data and helper functions for lowering coroutine intrinsics. @@ -124,7 +127,6 @@ struct LLVM_LIBRARY_VISIBILITY Shape { }; struct AsyncLoweringStorage { - FunctionType *AsyncFuncTy; Value *Context; CallingConv::ID AsyncCC; unsigned ContextArgNo; @@ -261,7 +263,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape { void buildFrom(Function &F); }; -void buildCoroutineFrame(Function &F, Shape &Shape); +bool defaultMaterializable(Instruction &V); +void buildCoroutineFrame( + Function &F, Shape &Shape, + const std::function<bool(Instruction &)> &MaterializableCallback); CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, ArrayRef<Value *> Arguments, IRBuilder<> &); } // End namespace coro. 
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 1171878f749a..39e909bf3316 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -31,6 +31,7 @@ #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/BinaryFormat/Dwarf.h" #include "llvm/IR/Argument.h" @@ -299,6 +300,26 @@ static void markCoroutineAsDone(IRBuilder<> &Builder, const coro::Shape &Shape, auto *NullPtr = ConstantPointerNull::get(cast<PointerType>( Shape.FrameTy->getTypeAtIndex(coro::Shape::SwitchFieldIndex::Resume))); Builder.CreateStore(NullPtr, GepIndex); + + // If the coroutine don't have unwind coro end, we could omit the store to + // the final suspend point since we could infer the coroutine is suspended + // at the final suspend point by the nullness of ResumeFnAddr. + // However, we can't skip it if the coroutine have unwind coro end. Since + // the coroutine reaches unwind coro end is considered suspended at the + // final suspend point (the ResumeFnAddr is null) but in fact the coroutine + // didn't complete yet. We need the IndexVal for the final suspend point + // to make the states clear. + if (Shape.SwitchLowering.HasUnwindCoroEnd && + Shape.SwitchLowering.HasFinalSuspend) { + assert(cast<CoroSuspendInst>(Shape.CoroSuspends.back())->isFinal() && + "The final suspend should only live in the last position of " + "CoroSuspends."); + ConstantInt *IndexVal = Shape.getIndex(Shape.CoroSuspends.size() - 1); + auto *FinalIndex = Builder.CreateStructGEP( + Shape.FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr"); + + Builder.CreateStore(IndexVal, FinalIndex); + } } /// Replace an unwind call to llvm.coro.end. @@ -396,17 +417,7 @@ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) { // The coroutine should be marked done if it reaches the final suspend // point. markCoroutineAsDone(Builder, Shape, FramePtr); - } - - // If the coroutine don't have unwind coro end, we could omit the store to - // the final suspend point since we could infer the coroutine is suspended - // at the final suspend point by the nullness of ResumeFnAddr. - // However, we can't skip it if the coroutine have unwind coro end. Since - // the coroutine reaches unwind coro end is considered suspended at the - // final suspend point (the ResumeFnAddr is null) but in fact the coroutine - // didn't complete yet. We need the IndexVal for the final suspend point - // to make the states clear. - if (!S->isFinal() || Shape.SwitchLowering.HasUnwindCoroEnd) { + } else { auto *GepIndex = Builder.CreateStructGEP( FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr"); Builder.CreateStore(IndexVal, GepIndex); @@ -565,7 +576,7 @@ void CoroCloner::replaceRetconOrAsyncSuspendUses() { if (NewS->use_empty()) return; // Otherwise, we need to create an aggregate. 
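
The markCoroutineAsDone change above relies on how switch-based lowering encodes state: a coroutine is normally considered "done" when its resume function pointer in the frame is null, so the index store at the final suspend can be skipped; but if there is an unwind llvm.coro.end, the final index must still be written, since an unwound coroutine also has a null resume pointer without having completed. A rough standalone sketch of that encoding, with an invented FakeFrame type standing in for the real frame layout:

// Rough sketch (invented FakeFrame, not the real coroutine frame) of the
// state encoding discussed in markCoroutineAsDone: "done" is inferred from a
// null resume pointer, and the suspend index is only needed to disambiguate
// states when an unwind coro.end exists.
#include <cassert>
#include <cstdio>

struct FakeFrame {
  void (*ResumeFn)(FakeFrame *); // null <=> parked at the final suspend point
  int SuspendIndex;              // which suspend point the coroutine is at
};

static bool coroDone(const FakeFrame &F) { return F.ResumeFn == nullptr; }

static void resumeBody(FakeFrame *F) {
  std::printf("resumed at suspend %d\n", F->SuspendIndex);
}

int main() {
  FakeFrame F{resumeBody, 0};
  assert(!coroDone(F));
  F.ResumeFn(&F);
  // Reaching the final suspend point clears the resume pointer. With an
  // unwind coro.end, the index must still be written, because a null pointer
  // alone cannot distinguish "reached the final suspend" from "unwound early".
  F.ResumeFn = nullptr;
  F.SuspendIndex = 1; // index of the final suspend point (illustrative)
  assert(coroDone(F));
  std::printf("done=%d index=%d\n", coroDone(F), F.SuspendIndex);
  return 0;
}
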
- Value *Agg = UndefValue::get(NewS->getType()); + Value *Agg = PoisonValue::get(NewS->getType()); for (size_t I = 0, E = Args.size(); I != E; ++I) Agg = Builder.CreateInsertValue(Agg, Args[I], I); @@ -623,20 +634,13 @@ static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape, return; Value *CachedSlot = nullptr; auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * { - if (CachedSlot) { - assert(cast<PointerType>(CachedSlot->getType()) - ->isOpaqueOrPointeeTypeMatches(ValueTy) && - "multiple swifterror slots in function with different types"); + if (CachedSlot) return CachedSlot; - } // Check if the function has a swifterror argument. for (auto &Arg : F.args()) { if (Arg.isSwiftError()) { CachedSlot = &Arg; - assert(cast<PointerType>(Arg.getType()) - ->isOpaqueOrPointeeTypeMatches(ValueTy) && - "swifterror argument does not have expected type"); return &Arg; } } @@ -679,19 +683,26 @@ static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape, } } +/// Returns all DbgVariableIntrinsic in F. +static SmallVector<DbgVariableIntrinsic *, 8> +collectDbgVariableIntrinsics(Function &F) { + SmallVector<DbgVariableIntrinsic *, 8> Intrinsics; + for (auto &I : instructions(F)) + if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) + Intrinsics.push_back(DVI); + return Intrinsics; +} + void CoroCloner::replaceSwiftErrorOps() { ::replaceSwiftErrorOps(*NewF, Shape, &VMap); } void CoroCloner::salvageDebugInfo() { - SmallVector<DbgVariableIntrinsic *, 8> Worklist; - SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache; - for (auto &BB : *NewF) - for (auto &I : BB) - if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) - Worklist.push_back(DVI); + SmallVector<DbgVariableIntrinsic *, 8> Worklist = + collectDbgVariableIntrinsics(*NewF); + SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap; for (DbgVariableIntrinsic *DVI : Worklist) - coro::salvageDebugInfo(DbgPtrAllocaCache, DVI, Shape.OptimizeFrame); + coro::salvageDebugInfo(ArgToAllocaMap, DVI, Shape.OptimizeFrame); // Remove all salvaged dbg.declare intrinsics that became // either unreachable or stale due to the CoroSplit transformation. @@ -886,7 +897,7 @@ void CoroCloner::create() { // frame. SmallVector<Instruction *> DummyArgs; for (Argument &A : OrigF.args()) { - DummyArgs.push_back(new FreezeInst(UndefValue::get(A.getType()))); + DummyArgs.push_back(new FreezeInst(PoisonValue::get(A.getType()))); VMap[&A] = DummyArgs.back(); } @@ -1044,7 +1055,7 @@ void CoroCloner::create() { // All uses of the arguments should have been resolved by this point, // so we can safely remove the dummy values. for (Instruction *DummyArg : DummyArgs) { - DummyArg->replaceAllUsesWith(UndefValue::get(DummyArg->getType())); + DummyArg->replaceAllUsesWith(PoisonValue::get(DummyArg->getType())); DummyArg->deleteValue(); } @@ -1231,8 +1242,11 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock, // instruction. Suspend instruction represented by a switch, track the PHI // values and select the correct case successor when possible. static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { + // There is nothing to simplify. 
+ if (isa<ReturnInst>(InitialInst)) + return false; + DenseMap<Value *, Value *> ResolvedValues; - BasicBlock *UnconditionalSucc = nullptr; assert(InitialInst->getModule()); const DataLayout &DL = InitialInst->getModule()->getDataLayout(); @@ -1262,39 +1276,35 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { Instruction *I = InitialInst; while (I->isTerminator() || isa<CmpInst>(I)) { if (isa<ReturnInst>(I)) { - if (I != InitialInst) { - // If InitialInst is an unconditional branch, - // remove PHI values that come from basic block of InitialInst - if (UnconditionalSucc) - UnconditionalSucc->removePredecessor(InitialInst->getParent(), true); - ReplaceInstWithInst(InitialInst, I->clone()); - } + ReplaceInstWithInst(InitialInst, I->clone()); return true; } + if (auto *BR = dyn_cast<BranchInst>(I)) { - if (BR->isUnconditional()) { - BasicBlock *Succ = BR->getSuccessor(0); - if (I == InitialInst) - UnconditionalSucc = Succ; - scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues); - I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime()); - continue; + unsigned SuccIndex = 0; + if (BR->isConditional()) { + // Handle the case the condition of the conditional branch is constant. + // e.g., + // + // br i1 false, label %cleanup, label %CoroEnd + // + // It is possible during the transformation. We could continue the + // simplifying in this case. + ConstantInt *Cond = TryResolveConstant(BR->getCondition()); + if (!Cond) + return false; + + SuccIndex = Cond->isOne() ? 0 : 1; } - BasicBlock *BB = BR->getParent(); - // Handle the case the condition of the conditional branch is constant. - // e.g., - // - // br i1 false, label %cleanup, label %CoroEnd - // - // It is possible during the transformation. We could continue the - // simplifying in this case. - if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) { - // Handle this branch in next iteration. - I = BB->getTerminator(); - continue; - } - } else if (auto *CondCmp = dyn_cast<CmpInst>(I)) { + BasicBlock *Succ = BR->getSuccessor(SuccIndex); + scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues); + I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime()); + + continue; + } + + if (auto *CondCmp = dyn_cast<CmpInst>(I)) { // If the case number of suspended switch instruction is reduced to // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator. auto *BR = dyn_cast<BranchInst>( @@ -1318,13 +1328,14 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { if (!ConstResult) return false; - CondCmp->replaceAllUsesWith(ConstResult); - CondCmp->eraseFromParent(); + ResolvedValues[BR->getCondition()] = ConstResult; // Handle this branch in next iteration. 
I = BR; continue; - } else if (auto *SI = dyn_cast<SwitchInst>(I)) { + } + + if (auto *SI = dyn_cast<SwitchInst>(I)) { ConstantInt *Cond = TryResolveConstant(SI->getCondition()); if (!Cond) return false; @@ -1337,6 +1348,7 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { return false; } + return false; } @@ -1889,7 +1901,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, if (ReturnPHIs.size() == 1) { RetV = CastedContinuation; } else { - RetV = UndefValue::get(RetTy); + RetV = PoisonValue::get(RetTy); RetV = Builder.CreateInsertValue(RetV, CastedContinuation, 0); for (size_t I = 1, E = ReturnPHIs.size(); I != E; ++I) RetV = Builder.CreateInsertValue(RetV, ReturnPHIs[I], I); @@ -1929,10 +1941,10 @@ namespace { }; } -static coro::Shape splitCoroutine(Function &F, - SmallVectorImpl<Function *> &Clones, - TargetTransformInfo &TTI, - bool OptimizeFrame) { +static coro::Shape +splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, + TargetTransformInfo &TTI, bool OptimizeFrame, + std::function<bool(Instruction &)> MaterializableCallback) { PrettyStackTraceFunction prettyStackTrace(F); // The suspend-crossing algorithm in buildCoroutineFrame get tripped @@ -1944,7 +1956,7 @@ static coro::Shape splitCoroutine(Function &F, return Shape; simplifySuspendPoints(Shape); - buildCoroutineFrame(F, Shape); + buildCoroutineFrame(F, Shape, MaterializableCallback); replaceFrameSizeAndAlignment(Shape); // If there are no suspend points, no split required, just remove @@ -1970,25 +1982,12 @@ static coro::Shape splitCoroutine(Function &F, // This invalidates SwiftErrorOps in the Shape. replaceSwiftErrorOps(F, Shape, nullptr); - // Finally, salvage the llvm.dbg.{declare,addr} in our original function that - // point into the coroutine frame. We only do this for the current function - // since the Cloner salvaged debug info for us in the new coroutine funclets. - SmallVector<DbgVariableIntrinsic *, 8> Worklist; - SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache; - for (auto &BB : F) { - for (auto &I : BB) { - if (auto *DDI = dyn_cast<DbgDeclareInst>(&I)) { - Worklist.push_back(DDI); - continue; - } - if (auto *DDI = dyn_cast<DbgAddrIntrinsic>(&I)) { - Worklist.push_back(DDI); - continue; - } - } - } - for (auto *DDI : Worklist) - coro::salvageDebugInfo(DbgPtrAllocaCache, DDI, Shape.OptimizeFrame); + // Salvage debug intrinsics that point into the coroutine frame in the + // original function. The Cloner has already salvaged debug info in the new + // coroutine funclets. 
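
The reworked simplifyTerminatorLeadingToRet above drops the ConstantFoldTerminator call and instead resolves a conditional branch directly when its condition folds to a constant, then keeps walking successors until it either reaches a return or hits something it cannot resolve. A self-contained sketch of that walk over an invented, much simpler block representation (the real code also tracks PHI values via ResolvedValues):

// Self-contained sketch (invented Block/Term types, not LLVM IR) of walking
// terminators toward a return: unconditional branches are followed directly,
// conditional branches only when their condition is a known constant.
#include <cstdio>
#include <optional>

struct Block;
struct Term {
  enum Kind { Ret, Br, CondBr } K;
  Block *Succ0 = nullptr, *Succ1 = nullptr;
  std::optional<bool> FoldedCond; // set when the condition folds to a constant
};
struct Block {
  Term T;
};

// Returns true if every step from B is resolvable and the walk ends in a
// return.
static bool leadsToRet(const Block *B) {
  for (;;) {
    switch (B->T.K) {
    case Term::Ret:
      return true;
    case Term::Br:
      B = B->T.Succ0;
      break;
    case Term::CondBr:
      if (!B->T.FoldedCond)
        return false; // condition not resolvable, give up
      B = *B->T.FoldedCond ? B->T.Succ0 : B->T.Succ1;
      break;
    }
  }
}

int main() {
  Block RetBB{{Term::Ret}};
  Block Mid{{Term::Br, &RetBB}};
  Block Entry{{Term::CondBr, &Mid, &RetBB}}; // like "br i1 false, %Mid, %RetBB"
  Entry.T.FoldedCond = false;                // the condition folded to false
  std::printf("leads to ret: %d\n", leadsToRet(&Entry)); // prints 1
  return 0;
}
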
+ SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap; + for (auto *DDI : collectDbgVariableIntrinsics(F)) + coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame); return Shape; } @@ -2104,6 +2103,10 @@ static void addPrepareFunction(const Module &M, Fns.push_back(PrepareFn); } +CoroSplitPass::CoroSplitPass(bool OptimizeFrame) + : MaterializableCallback(coro::defaultMaterializable), + OptimizeFrame(OptimizeFrame) {} + PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { @@ -2142,10 +2145,19 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, F.setSplittedCoroutine(); SmallVector<Function *, 4> Clones; - const coro::Shape Shape = splitCoroutine( - F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame); + auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + const coro::Shape Shape = + splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F), + OptimizeFrame, MaterializableCallback); updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "CoroSplit", &F) + << "Split '" << ore::NV("function", F.getName()) + << "' (frame_size=" << ore::NV("frame_size", Shape.FrameSize) + << ", align=" << ore::NV("align", Shape.FrameAlign.value()) << ")"; + }); + if (!Shape.CoroSuspends.empty()) { // Run the CGSCC pipeline on the original and newly split functions. UR.CWorklist.insert(&C); diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index ce4262e593b6..cde74c5e693b 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -596,20 +596,6 @@ static void checkAsyncFuncPointer(const Instruction *I, Value *V) { auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts()); if (!AsyncFuncPtrAddr) fail(I, "llvm.coro.id.async async function pointer not a global", V); - - if (AsyncFuncPtrAddr->getType()->isOpaquePointerTy()) - return; - - auto *StructTy = cast<StructType>( - AsyncFuncPtrAddr->getType()->getNonOpaquePointerElementType()); - if (StructTy->isOpaque() || !StructTy->isPacked() || - StructTy->getNumElements() != 2 || - !StructTy->getElementType(0)->isIntegerTy(32) || - !StructTy->getElementType(1)->isIntegerTy(32)) - fail(I, - "llvm.coro.id.async async function pointer argument's type is not " - "<{i32, i32}>", - V); } void CoroIdAsyncInst::checkWellFormed() const { @@ -625,19 +611,15 @@ void CoroIdAsyncInst::checkWellFormed() const { static void checkAsyncContextProjectFunction(const Instruction *I, Function *F) { auto *FunTy = cast<FunctionType>(F->getValueType()); - Type *Int8Ty = Type::getInt8Ty(F->getContext()); - auto *RetPtrTy = dyn_cast<PointerType>(FunTy->getReturnType()); - if (!RetPtrTy || !RetPtrTy->isOpaqueOrPointeeTypeMatches(Int8Ty)) + if (!FunTy->getReturnType()->isPointerTy()) fail(I, "llvm.coro.suspend.async resume function projection function must " - "return an i8* type", + "return a ptr type", F); - if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy() || - !cast<PointerType>(FunTy->getParamType(0)) - ->isOpaqueOrPointeeTypeMatches(Int8Ty)) + if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy()) fail(I, "llvm.coro.suspend.async resume function projection function must " - "take one i8* type as parameter", + "take one ptr type as parameter", F); } diff --git a/llvm/lib/Transforms/IPO/AlwaysInliner.cpp 
b/llvm/lib/Transforms/IPO/AlwaysInliner.cpp index 09286482edff..cc375f9badcd 100644 --- a/llvm/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/llvm/lib/Transforms/IPO/AlwaysInliner.cpp @@ -28,16 +28,13 @@ using namespace llvm; #define DEBUG_TYPE "inline" -PreservedAnalyses AlwaysInlinerPass::run(Module &M, - ModuleAnalysisManager &MAM) { - // Add inline assumptions during code generation. - FunctionAnalysisManager &FAM = - MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { - return FAM.getResult<AssumptionAnalysis>(F); - }; - auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M); +namespace { +bool AlwaysInlineImpl( + Module &M, bool InsertLifetime, ProfileSummaryInfo &PSI, + function_ref<AssumptionCache &(Function &)> GetAssumptionCache, + function_ref<AAResults &(Function &)> GetAAR, + function_ref<BlockFrequencyInfo &(Function &)> GetBFI) { SmallSetVector<CallBase *, 16> Calls; bool Changed = false; SmallVector<Function *, 16> InlinedFunctions; @@ -65,14 +62,12 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, DebugLoc DLoc = CB->getDebugLoc(); BasicBlock *Block = CB->getParent(); - InlineFunctionInfo IFI( - /*cg=*/nullptr, GetAssumptionCache, &PSI, - &FAM.getResult<BlockFrequencyAnalysis>(*Caller), - &FAM.getResult<BlockFrequencyAnalysis>(F)); + InlineFunctionInfo IFI(GetAssumptionCache, &PSI, + GetBFI ? &GetBFI(*Caller) : nullptr, + GetBFI ? &GetBFI(F) : nullptr); - InlineResult Res = - InlineFunction(*CB, IFI, /*MergeAttributes=*/true, - &FAM.getResult<AAManager>(F), InsertLifetime); + InlineResult Res = InlineFunction(*CB, IFI, /*MergeAttributes=*/true, + &GetAAR(F), InsertLifetime); if (!Res.isSuccess()) { ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, @@ -127,48 +122,52 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, } } - return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); + return Changed; } -namespace { - -/// Inliner pass which only handles "always inline" functions. -/// -/// Unlike the \c AlwaysInlinerPass, this uses the more heavyweight \c Inliner -/// base class to provide several facilities such as array alloca merging. -class AlwaysInlinerLegacyPass : public LegacyInlinerBase { +struct AlwaysInlinerLegacyPass : public ModulePass { + bool InsertLifetime; -public: - AlwaysInlinerLegacyPass() : LegacyInlinerBase(ID, /*InsertLifetime*/ true) { - initializeAlwaysInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); - } + AlwaysInlinerLegacyPass() + : AlwaysInlinerLegacyPass(/*InsertLifetime*/ true) {} AlwaysInlinerLegacyPass(bool InsertLifetime) - : LegacyInlinerBase(ID, InsertLifetime) { + : ModulePass(ID), InsertLifetime(InsertLifetime) { initializeAlwaysInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); } /// Main run interface method. We override here to avoid calling skipSCC(). 
- bool runOnSCC(CallGraphSCC &SCC) override { return inlineCalls(SCC); } + bool runOnModule(Module &M) override { + + auto &PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + auto GetAAR = [&](Function &F) -> AAResults & { + return getAnalysis<AAResultsWrapperPass>(F).getAAResults(); + }; + auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { + return getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + }; + + return AlwaysInlineImpl(M, InsertLifetime, PSI, GetAssumptionCache, GetAAR, + /*GetBFI*/ nullptr); + } static char ID; // Pass identification, replacement for typeid - InlineCost getInlineCost(CallBase &CB) override; - - using llvm::Pass::doFinalization; - bool doFinalization(CallGraph &CG) override { - return removeDeadFunctions(CG, /*AlwaysInlineOnly=*/true); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); } }; -} + +} // namespace char AlwaysInlinerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(AlwaysInlinerLegacyPass, "always-inline", "Inliner for always_inline functions", false, false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(AlwaysInlinerLegacyPass, "always-inline", "Inliner for always_inline functions", false, false) @@ -176,46 +175,23 @@ Pass *llvm::createAlwaysInlinerLegacyPass(bool InsertLifetime) { return new AlwaysInlinerLegacyPass(InsertLifetime); } -/// Get the inline cost for the always-inliner. -/// -/// The always inliner *only* handles functions which are marked with the -/// attribute to force inlining. As such, it is dramatically simpler and avoids -/// using the powerful (but expensive) inline cost analysis. Instead it uses -/// a very simple and boring direct walk of the instructions looking for -/// impossible-to-inline constructs. -/// -/// Note, it would be possible to go to some lengths to cache the information -/// computed here, but as we only expect to do this for relatively few and -/// small functions which have the explicit attribute to force inlining, it is -/// likely not worth it in practice. -InlineCost AlwaysInlinerLegacyPass::getInlineCost(CallBase &CB) { - Function *Callee = CB.getCalledFunction(); - - // Only inline direct calls to functions with always-inline attributes - // that are viable for inlining. - if (!Callee) - return InlineCost::getNever("indirect call"); - - // When callee coroutine function is inlined into caller coroutine function - // before coro-split pass, - // coro-early pass can not handle this quiet well. - // So we won't inline the coroutine function if it have not been unsplited - if (Callee->isPresplitCoroutine()) - return InlineCost::getNever("unsplited coroutine call"); - - // FIXME: We shouldn't even get here for declarations. 
- if (Callee->isDeclaration()) - return InlineCost::getNever("no definition"); - - if (!CB.hasFnAttr(Attribute::AlwaysInline)) - return InlineCost::getNever("no alwaysinline attribute"); - - if (Callee->hasFnAttribute(Attribute::AlwaysInline) && CB.isNoInline()) - return InlineCost::getNever("noinline call site attribute"); - - auto IsViable = isInlineViable(*Callee); - if (!IsViable.isSuccess()) - return InlineCost::getNever(IsViable.getFailureReason()); - - return InlineCost::getAlways("always inliner"); +PreservedAnalyses AlwaysInlinerPass::run(Module &M, + ModuleAnalysisManager &MAM) { + FunctionAnalysisManager &FAM = + MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + auto GetBFI = [&](Function &F) -> BlockFrequencyInfo & { + return FAM.getResult<BlockFrequencyAnalysis>(F); + }; + auto GetAAR = [&](Function &F) -> AAResults & { + return FAM.getResult<AAManager>(F); + }; + auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M); + + bool Changed = AlwaysInlineImpl(M, InsertLifetime, PSI, GetAssumptionCache, + GetAAR, GetBFI); + + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp b/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp index 6cc04544cabc..40cc00d2c78c 100644 --- a/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp +++ b/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp @@ -17,8 +17,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" using namespace llvm; @@ -64,36 +62,8 @@ static bool convertAnnotation2Metadata(Module &M) { return true; } -namespace { -struct Annotation2MetadataLegacy : public ModulePass { - static char ID; - - Annotation2MetadataLegacy() : ModulePass(ID) { - initializeAnnotation2MetadataLegacyPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { return convertAnnotation2Metadata(M); } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } -}; - -} // end anonymous namespace - -char Annotation2MetadataLegacy::ID = 0; - -INITIALIZE_PASS_BEGIN(Annotation2MetadataLegacy, DEBUG_TYPE, - "Annotation2Metadata", false, false) -INITIALIZE_PASS_END(Annotation2MetadataLegacy, DEBUG_TYPE, - "Annotation2Metadata", false, false) - -ModulePass *llvm::createAnnotation2MetadataLegacyPass() { - return new Annotation2MetadataLegacy(); -} - PreservedAnalyses Annotation2MetadataPass::run(Module &M, ModuleAnalysisManager &AM) { - convertAnnotation2Metadata(M); - return PreservedAnalyses::all(); + return convertAnnotation2Metadata(M) ? 
PreservedAnalyses::none() + : PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index dd1a3b78a378..824da6395f2e 100644 --- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -67,6 +67,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> #include <cassert> @@ -97,49 +98,11 @@ using OffsetAndArgPart = std::pair<int64_t, ArgPart>; static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL, Value *Ptr, Type *ResElemTy, int64_t Offset) { - // For non-opaque pointers, try to create a "nice" GEP if possible, otherwise - // fall back to an i8 GEP to a specific offset. - unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace(); - APInt OrigOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset); - if (!Ptr->getType()->isOpaquePointerTy()) { - Type *OrigElemTy = Ptr->getType()->getNonOpaquePointerElementType(); - if (OrigOffset == 0 && OrigElemTy == ResElemTy) - return Ptr; - - if (OrigElemTy->isSized()) { - APInt TmpOffset = OrigOffset; - Type *TmpTy = OrigElemTy; - SmallVector<APInt> IntIndices = - DL.getGEPIndicesForOffset(TmpTy, TmpOffset); - if (TmpOffset == 0) { - // Try to add trailing zero indices to reach the right type. - while (TmpTy != ResElemTy) { - Type *NextTy = GetElementPtrInst::getTypeAtIndex(TmpTy, (uint64_t)0); - if (!NextTy) - break; - - IntIndices.push_back(APInt::getZero( - isa<StructType>(TmpTy) ? 32 : OrigOffset.getBitWidth())); - TmpTy = NextTy; - } - - SmallVector<Value *> Indices; - for (const APInt &Index : IntIndices) - Indices.push_back(IRB.getInt(Index)); - - if (OrigOffset != 0 || TmpTy == ResElemTy) { - Ptr = IRB.CreateGEP(OrigElemTy, Ptr, Indices); - return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace)); - } - } - } + if (Offset != 0) { + APInt APOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset); + Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(APOffset)); } - - if (OrigOffset != 0) { - Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy(AddrSpace)); - Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(OrigOffset)); - } - return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace)); + return Ptr; } /// DoPromotion - This method actually performs the promotion of the specified @@ -220,6 +183,8 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, // pass in the loaded pointers. SmallVector<Value *, 16> Args; const DataLayout &DL = F->getParent()->getDataLayout(); + SmallVector<WeakTrackingVH, 16> DeadArgs; + while (!F->use_empty()) { CallBase &CB = cast<CallBase>(*F->user_back()); assert(CB.getCalledFunction() == F); @@ -246,15 +211,25 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, if (Pair.second.MustExecInstr) { LI->setAAMetadata(Pair.second.MustExecInstr->getAAMetadata()); LI->copyMetadata(*Pair.second.MustExecInstr, - {LLVMContext::MD_range, LLVMContext::MD_nonnull, - LLVMContext::MD_dereferenceable, + {LLVMContext::MD_dereferenceable, LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_align, LLVMContext::MD_noundef, + LLVMContext::MD_noundef, LLVMContext::MD_nontemporal}); + // Only transfer poison-generating metadata if we also have + // !noundef. + // TODO: Without !noundef, we could merge this metadata across + // all promoted loads. 
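
The ArgumentPromotion hunk above only transfers poison-generating metadata (!range, !nonnull, !align) onto the promoted load when the must-executed access also carried !noundef: roughly, with !noundef a value violating that metadata would already have been immediate UB at the original load, so restating it is safe. A small illustration of just that rule, with an invented MD struct standing in for the metadata kinds:

// Small illustration (invented MD struct, not real LLVM metadata) of the rule
// in the comment above: poison-generating metadata is restated on the new
// load only together with !noundef; other metadata is copied unconditionally.
#include <cstdio>

struct MD {
  bool Noundef = false, Range = false, Nonnull = false, Dereferenceable = false;
};

static MD transferForPromotedLoad(const MD &Orig) {
  MD New;
  New.Dereferenceable = Orig.Dereferenceable; // never makes the value poison
  if (Orig.Noundef) {
    // With !noundef, a value violating !range/!nonnull would have been UB at
    // the original load already, so the claims are safe to restate.
    New.Noundef = true;
    New.Range = Orig.Range;
    New.Nonnull = Orig.Nonnull;
  }
  return New;
}

int main() {
  MD WithNoundef{true, true, true, true};
  MD WithoutNoundef{false, true, true, true};
  MD A = transferForPromotedLoad(WithNoundef);
  MD B = transferForPromotedLoad(WithoutNoundef);
  std::printf("with !noundef:    range=%d nonnull=%d\n", A.Range, A.Nonnull);
  std::printf("without !noundef: range=%d nonnull=%d\n", B.Range, B.Nonnull);
  return 0;
}
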
+ if (LI->hasMetadata(LLVMContext::MD_noundef)) + LI->copyMetadata(*Pair.second.MustExecInstr, + {LLVMContext::MD_range, LLVMContext::MD_nonnull, + LLVMContext::MD_align}); } Args.push_back(LI); ArgAttrVec.push_back(AttributeSet()); } + } else { + assert(ArgsToPromote.count(&*I) && I->use_empty()); + DeadArgs.emplace_back(AI->get()); } } @@ -297,6 +272,8 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, CB.eraseFromParent(); } + RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadArgs); + // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the // function empty. @@ -766,6 +743,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM, // Check to see which arguments are promotable. If an argument is promotable, // add it to ArgsToPromote. DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote; + unsigned NumArgsAfterPromote = F->getFunctionType()->getNumParams(); for (Argument *PtrArg : PointerArgs) { // Replace sret attribute with noalias. This reduces register pressure by // avoiding a register copy. @@ -789,6 +767,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM, Types.push_back(Pair.second.Ty); if (areTypesABICompatible(Types, *F, TTI)) { + NumArgsAfterPromote += ArgParts.size() - 1; ArgsToPromote.insert({PtrArg, std::move(ArgParts)}); } } @@ -798,6 +777,9 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM, if (ArgsToPromote.empty()) return nullptr; + if (NumArgsAfterPromote > TTI.getMaxNumArgs()) + return nullptr; + return doPromotion(F, FAM, ArgsToPromote); } diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp index b9134ce26e80..847d07a49dee 100644 --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -15,16 +15,17 @@ #include "llvm/Transforms/IPO/Attributor.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MustExecute.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantFold.h" @@ -35,14 +36,15 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/GraphWriter.h" +#include "llvm/Support/ModRef.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -98,11 +100,6 @@ static cl::opt<unsigned, true> MaxInitializationChainLengthX( cl::location(MaxInitializationChainLength), cl::init(1024)); unsigned llvm::MaxInitializationChainLength; -static cl::opt<bool> VerifyMaxFixpointIterations( - "attributor-max-iterations-verify", cl::Hidden, - cl::desc("Verify that max-iterations is a tight bound for a fixpoint"), - cl::init(false)); - static cl::opt<bool> 
AnnotateDeclarationCallSites( "attributor-annotate-decl-cs", cl::Hidden, cl::desc("Annotate call sites of function declarations."), cl::init(false)); @@ -188,6 +185,11 @@ ChangeStatus &llvm::operator&=(ChangeStatus &L, ChangeStatus R) { } ///} +bool AA::isGPU(const Module &M) { + Triple T(M.getTargetTriple()); + return T.isAMDGPU() || T.isNVPTX(); +} + bool AA::isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA) { // We are looking for volatile instructions or non-relaxed atomics. @@ -202,9 +204,10 @@ bool AA::isNoSyncInst(Attributor &A, const Instruction &I, if (AANoSync::isNoSyncIntrinsic(&I)) return true; - const auto &NoSyncAA = A.getAAFor<AANoSync>( - QueryingAA, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL); - return NoSyncAA.isAssumedNoSync(); + bool IsKnownNoSync; + return AA::hasAssumedIRAttr<Attribute::NoSync>( + A, &QueryingAA, IRPosition::callsite_function(*CB), + DepClassTy::OPTIONAL, IsKnownNoSync); } if (!I.mayReadOrWriteMemory()) @@ -218,12 +221,12 @@ bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA, // TODO: See the AAInstanceInfo class comment. if (!ForAnalysisOnly) return false; - auto &InstanceInfoAA = A.getAAFor<AAInstanceInfo>( + auto *InstanceInfoAA = A.getAAFor<AAInstanceInfo>( QueryingAA, IRPosition::value(V), DepClassTy::OPTIONAL); - return InstanceInfoAA.isAssumedUniqueForAnalysis(); + return InstanceInfoAA && InstanceInfoAA->isAssumedUniqueForAnalysis(); } -Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty, +Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty, const TargetLibraryInfo *TLI, const DataLayout &DL, AA::RangeTy *RangePtr) { @@ -234,17 +237,31 @@ Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty, auto *GV = dyn_cast<GlobalVariable>(&Obj); if (!GV) return nullptr; - if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer())) - return nullptr; - if (!GV->hasInitializer()) - return UndefValue::get(&Ty); + + bool UsedAssumedInformation = false; + Constant *Initializer = nullptr; + if (A.hasGlobalVariableSimplificationCallback(*GV)) { + auto AssumedGV = A.getAssumedInitializerFromCallBack( + *GV, /* const AbstractAttribute *AA */ nullptr, UsedAssumedInformation); + Initializer = *AssumedGV; + if (!Initializer) + return nullptr; + } else { + if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer())) + return nullptr; + if (!GV->hasInitializer()) + return UndefValue::get(&Ty); + + if (!Initializer) + Initializer = GV->getInitializer(); + } if (RangePtr && !RangePtr->offsetOrSizeAreUnknown()) { APInt Offset = APInt(64, RangePtr->Offset); - return ConstantFoldLoadFromConst(GV->getInitializer(), &Ty, Offset, DL); + return ConstantFoldLoadFromConst(Initializer, &Ty, Offset, DL); } - return ConstantFoldLoadFromUniformValue(GV->getInitializer(), &Ty); + return ConstantFoldLoadFromUniformValue(Initializer, &Ty); } bool AA::isValidInScope(const Value &V, const Function *Scope) { @@ -396,6 +413,18 @@ static bool getPotentialCopiesOfMemoryValue( NullOnly = false; }; + auto AdjustWrittenValueType = [&](const AAPointerInfo::Access &Acc, + Value &V) { + Value *AdjV = AA::getWithType(V, *I.getType()); + if (!AdjV) { + LLVM_DEBUG(dbgs() << "Underlying object written but stored value " + "cannot be converted to read type: " + << *Acc.getRemoteInst() << " : " << *I.getType() + << "\n";); + } + return AdjV; + }; + auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) { if ((IsLoad && !Acc.isWriteOrAssumption()) || 
(!IsLoad && !Acc.isRead())) return true; @@ -417,7 +446,10 @@ static bool getPotentialCopiesOfMemoryValue( if (IsLoad) { assert(isa<LoadInst>(I) && "Expected load or store instruction only!"); if (!Acc.isWrittenValueUnknown()) { - NewCopies.push_back(Acc.getWrittenValue()); + Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue()); + if (!V) + return false; + NewCopies.push_back(V); NewCopyOrigins.push_back(Acc.getRemoteInst()); return true; } @@ -428,7 +460,10 @@ static bool getPotentialCopiesOfMemoryValue( << *Acc.getRemoteInst() << "\n";); return false; } - NewCopies.push_back(SI->getValueOperand()); + Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand()); + if (!V) + return false; + NewCopies.push_back(V); NewCopyOrigins.push_back(SI); } else { assert(isa<StoreInst>(I) && "Expected load or store instruction only!"); @@ -449,10 +484,13 @@ static bool getPotentialCopiesOfMemoryValue( bool HasBeenWrittenTo = false; AA::RangeTy Range; - auto &PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj), + auto *PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj), DepClassTy::NONE); - if (!PI.forallInterferingAccesses(A, QueryingAA, I, CheckAccess, - HasBeenWrittenTo, Range)) { + if (!PI || + !PI->forallInterferingAccesses(A, QueryingAA, I, + /* FindInterferingWrites */ IsLoad, + /* FindInterferingReads */ !IsLoad, + CheckAccess, HasBeenWrittenTo, Range)) { LLVM_DEBUG( dbgs() << "Failed to verify all interfering accesses for underlying object: " @@ -463,7 +501,7 @@ static bool getPotentialCopiesOfMemoryValue( if (IsLoad && !HasBeenWrittenTo && !Range.isUnassigned()) { const DataLayout &DL = A.getDataLayout(); Value *InitialValue = - AA::getInitialValueForObj(Obj, *I.getType(), TLI, DL, &Range); + AA::getInitialValueForObj(A, Obj, *I.getType(), TLI, DL, &Range); if (!InitialValue) { LLVM_DEBUG(dbgs() << "Could not determine required initial value of " "underlying object, abort!\n"); @@ -480,14 +518,14 @@ static bool getPotentialCopiesOfMemoryValue( NewCopyOrigins.push_back(nullptr); } - PIs.push_back(&PI); + PIs.push_back(PI); return true; }; - const auto &AAUO = A.getAAFor<AAUnderlyingObjects>( + const auto *AAUO = A.getAAFor<AAUnderlyingObjects>( QueryingAA, IRPosition::value(Ptr), DepClassTy::OPTIONAL); - if (!AAUO.forallUnderlyingObjects(Pred)) { + if (!AAUO || !AAUO->forallUnderlyingObjects(Pred)) { LLVM_DEBUG( dbgs() << "Underlying objects stored into could not be determined\n";); return false; @@ -530,27 +568,37 @@ bool AA::getPotentialCopiesOfStoredValue( static bool isAssumedReadOnlyOrReadNone(Attributor &A, const IRPosition &IRP, const AbstractAttribute &QueryingAA, bool RequireReadNone, bool &IsKnown) { + if (RequireReadNone) { + if (AA::hasAssumedIRAttr<Attribute::ReadNone>( + A, &QueryingAA, IRP, DepClassTy::OPTIONAL, IsKnown, + /* IgnoreSubsumingPositions */ true)) + return true; + } else if (AA::hasAssumedIRAttr<Attribute::ReadOnly>( + A, &QueryingAA, IRP, DepClassTy::OPTIONAL, IsKnown, + /* IgnoreSubsumingPositions */ true)) + return true; IRPosition::Kind Kind = IRP.getPositionKind(); if (Kind == IRPosition::IRP_FUNCTION || Kind == IRPosition::IRP_CALL_SITE) { - const auto &MemLocAA = + const auto *MemLocAA = A.getAAFor<AAMemoryLocation>(QueryingAA, IRP, DepClassTy::NONE); - if (MemLocAA.isAssumedReadNone()) { - IsKnown = MemLocAA.isKnownReadNone(); + if (MemLocAA && MemLocAA->isAssumedReadNone()) { + IsKnown = MemLocAA->isKnownReadNone(); if (!IsKnown) - A.recordDependence(MemLocAA, QueryingAA, DepClassTy::OPTIONAL); + 
A.recordDependence(*MemLocAA, QueryingAA, DepClassTy::OPTIONAL); return true; } } - const auto &MemBehaviorAA = + const auto *MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(QueryingAA, IRP, DepClassTy::NONE); - if (MemBehaviorAA.isAssumedReadNone() || - (!RequireReadNone && MemBehaviorAA.isAssumedReadOnly())) { - IsKnown = RequireReadNone ? MemBehaviorAA.isKnownReadNone() - : MemBehaviorAA.isKnownReadOnly(); + if (MemBehaviorAA && + (MemBehaviorAA->isAssumedReadNone() || + (!RequireReadNone && MemBehaviorAA->isAssumedReadOnly()))) { + IsKnown = RequireReadNone ? MemBehaviorAA->isKnownReadNone() + : MemBehaviorAA->isKnownReadOnly(); if (!IsKnown) - A.recordDependence(MemBehaviorAA, QueryingAA, DepClassTy::OPTIONAL); + A.recordDependence(*MemBehaviorAA, QueryingAA, DepClassTy::OPTIONAL); return true; } @@ -574,7 +622,7 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, const AbstractAttribute &QueryingAA, const AA::InstExclusionSetTy *ExclusionSet, std::function<bool(const Function &F)> GoBackwardsCB) { - LLVM_DEBUG({ + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, { dbgs() << "[AA] isPotentiallyReachable @" << ToFn.getName() << " from " << FromI << " [GBCB: " << bool(GoBackwardsCB) << "][#ExS: " << (ExclusionSet ? std::to_string(ExclusionSet->size()) : "none") @@ -584,6 +632,19 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, dbgs() << *ES << "\n"; }); + // We know kernels (generally) cannot be called from within the module. Thus, + // for reachability we would need to step back from a kernel which would allow + // us to reach anything anyway. Even if a kernel is invoked from another + // kernel, values like allocas and shared memory are not accessible. We + // implicitly check for this situation to avoid costly lookups. + if (GoBackwardsCB && &ToFn != FromI.getFunction() && + !GoBackwardsCB(*FromI.getFunction()) && ToFn.hasFnAttribute("kernel") && + FromI.getFunction()->hasFnAttribute("kernel")) { + LLVM_DEBUG(dbgs() << "[AA] assume kernel cannot be reached from within the " + "module; success\n";); + return false; + } + // If we can go arbitrarily backwards we will eventually reach an entry point // that can reach ToI. Only if a set of blocks through which we cannot go is // provided, or once we track internal functions not accessible from the @@ -611,10 +672,10 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, return true; LLVM_DEBUG(dbgs() << "[AA] check " << *ToI << " from " << *CurFromI << " intraprocedurally\n"); - const auto &ReachabilityAA = A.getAAFor<AAIntraFnReachability>( + const auto *ReachabilityAA = A.getAAFor<AAIntraFnReachability>( QueryingAA, IRPosition::function(ToFn), DepClassTy::OPTIONAL); - bool Result = - ReachabilityAA.isAssumedReachable(A, *CurFromI, *ToI, ExclusionSet); + bool Result = !ReachabilityAA || ReachabilityAA->isAssumedReachable( + A, *CurFromI, *ToI, ExclusionSet); LLVM_DEBUG(dbgs() << "[AA] " << *CurFromI << " " << (Result ? 
"can potentially " : "cannot ") << "reach " << *ToI << " [Intra]\n"); @@ -624,11 +685,11 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, bool Result = true; if (!ToFn.isDeclaration() && ToI) { - const auto &ToReachabilityAA = A.getAAFor<AAIntraFnReachability>( + const auto *ToReachabilityAA = A.getAAFor<AAIntraFnReachability>( QueryingAA, IRPosition::function(ToFn), DepClassTy::OPTIONAL); const Instruction &EntryI = ToFn.getEntryBlock().front(); - Result = - ToReachabilityAA.isAssumedReachable(A, EntryI, *ToI, ExclusionSet); + Result = !ToReachabilityAA || ToReachabilityAA->isAssumedReachable( + A, EntryI, *ToI, ExclusionSet); LLVM_DEBUG(dbgs() << "[AA] Entry " << EntryI << " of @" << ToFn.getName() << " " << (Result ? "can potentially " : "cannot ") << "reach @" << *ToI << " [ToFn]\n"); @@ -637,10 +698,10 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, if (Result) { // The entry of the ToFn can reach the instruction ToI. If the current // instruction is already known to reach the ToFn. - const auto &FnReachabilityAA = A.getAAFor<AAInterFnReachability>( + const auto *FnReachabilityAA = A.getAAFor<AAInterFnReachability>( QueryingAA, IRPosition::function(*FromFn), DepClassTy::OPTIONAL); - Result = FnReachabilityAA.instructionCanReach(A, *CurFromI, ToFn, - ExclusionSet); + Result = !FnReachabilityAA || FnReachabilityAA->instructionCanReach( + A, *CurFromI, ToFn, ExclusionSet); LLVM_DEBUG(dbgs() << "[AA] " << *CurFromI << " in @" << FromFn->getName() << " " << (Result ? "can potentially " : "cannot ") << "reach @" << ToFn.getName() << " [FromFn]\n"); @@ -649,11 +710,11 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, } // TODO: Check assumed nounwind. - const auto &ReachabilityAA = A.getAAFor<AAIntraFnReachability>( + const auto *ReachabilityAA = A.getAAFor<AAIntraFnReachability>( QueryingAA, IRPosition::function(*FromFn), DepClassTy::OPTIONAL); auto ReturnInstCB = [&](Instruction &Ret) { - bool Result = - ReachabilityAA.isAssumedReachable(A, *CurFromI, Ret, ExclusionSet); + bool Result = !ReachabilityAA || ReachabilityAA->isAssumedReachable( + A, *CurFromI, Ret, ExclusionSet); LLVM_DEBUG(dbgs() << "[AA][Ret] " << *CurFromI << " " << (Result ? "can potentially " : "cannot ") << "reach " << Ret << " [Intra]\n"); @@ -743,14 +804,15 @@ bool AA::isAssumedThreadLocalObject(Attributor &A, Value &Obj, << "' is thread local; stack objects are thread local.\n"); return true; } - const auto &NoCaptureAA = A.getAAFor<AANoCapture>( - QueryingAA, IRPosition::value(Obj), DepClassTy::OPTIONAL); + bool IsKnownNoCapture; + bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, &QueryingAA, IRPosition::value(Obj), DepClassTy::OPTIONAL, + IsKnownNoCapture); LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj << "' is " - << (NoCaptureAA.isAssumedNoCapture() ? "" : "not") - << " thread local; " - << (NoCaptureAA.isAssumedNoCapture() ? "non-" : "") + << (IsAssumedNoCapture ? "" : "not") << " thread local; " + << (IsAssumedNoCapture ? 
"non-" : "") << "captured stack object.\n"); - return NoCaptureAA.isAssumedNoCapture(); + return IsAssumedNoCapture; } if (auto *GV = dyn_cast<GlobalVariable>(&Obj)) { if (GV->isConstant()) { @@ -831,9 +893,9 @@ bool AA::isPotentiallyAffectedByBarrier(Attributor &A, return false; }; - const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>( + const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>( QueryingAA, IRPosition::value(*Ptr), DepClassTy::OPTIONAL); - if (!UnderlyingObjsAA.forallUnderlyingObjects(Pred)) + if (!UnderlyingObjsAA || !UnderlyingObjsAA->forallUnderlyingObjects(Pred)) return true; } return false; @@ -848,38 +910,42 @@ static bool isEqualOrWorse(const Attribute &New, const Attribute &Old) { } /// Return true if the information provided by \p Attr was added to the -/// attribute list \p Attrs. This is only the case if it was not already present -/// in \p Attrs at the position describe by \p PK and \p AttrIdx. +/// attribute set \p AttrSet. This is only the case if it was not already +/// present in \p AttrSet. static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr, - AttributeList &Attrs, int AttrIdx, - bool ForceReplace = false) { + AttributeSet AttrSet, bool ForceReplace, + AttrBuilder &AB) { if (Attr.isEnumAttribute()) { Attribute::AttrKind Kind = Attr.getKindAsEnum(); - if (Attrs.hasAttributeAtIndex(AttrIdx, Kind)) - if (!ForceReplace && - isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind))) - return false; - Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr); + if (AttrSet.hasAttribute(Kind)) + return false; + AB.addAttribute(Kind); return true; } if (Attr.isStringAttribute()) { StringRef Kind = Attr.getKindAsString(); - if (Attrs.hasAttributeAtIndex(AttrIdx, Kind)) - if (!ForceReplace && - isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind))) + if (AttrSet.hasAttribute(Kind)) { + if (!ForceReplace) return false; - Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr); + } + AB.addAttribute(Kind, Attr.getValueAsString()); return true; } if (Attr.isIntAttribute()) { Attribute::AttrKind Kind = Attr.getKindAsEnum(); - if (Attrs.hasAttributeAtIndex(AttrIdx, Kind)) - if (!ForceReplace && - isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind))) + if (!ForceReplace && Kind == Attribute::Memory) { + MemoryEffects ME = Attr.getMemoryEffects() & AttrSet.getMemoryEffects(); + if (ME == AttrSet.getMemoryEffects()) return false; - Attrs = Attrs.removeAttributeAtIndex(Ctx, AttrIdx, Kind); - Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr); + AB.addMemoryAttr(ME); + return true; + } + if (AttrSet.hasAttribute(Kind)) { + if (!ForceReplace && isEqualOrWorse(Attr, AttrSet.getAttribute(Kind))) + return false; + } + AB.addAttribute(Attr); return true; } @@ -933,7 +999,7 @@ Argument *IRPosition::getAssociatedArgument() const { // If no callbacks were found, or none used the underlying call site operand // exclusively, use the direct callee argument if available. 
- const Function *Callee = CB.getCalledFunction(); + auto *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand()); if (Callee && Callee->arg_size() > unsigned(ArgNo)) return Callee->getArg(ArgNo); @@ -955,63 +1021,168 @@ ChangeStatus AbstractAttribute::update(Attributor &A) { return HasChanged; } +bool Attributor::getAttrsFromAssumes(const IRPosition &IRP, + Attribute::AttrKind AK, + SmallVectorImpl<Attribute> &Attrs) { + assert(IRP.getPositionKind() != IRPosition::IRP_INVALID && + "Did expect a valid position!"); + MustBeExecutedContextExplorer *Explorer = + getInfoCache().getMustBeExecutedContextExplorer(); + if (!Explorer) + return false; + + Value &AssociatedValue = IRP.getAssociatedValue(); + + const Assume2KnowledgeMap &A2K = + getInfoCache().getKnowledgeMap().lookup({&AssociatedValue, AK}); + + // Check if we found any potential assume use, if not we don't need to create + // explorer iterators. + if (A2K.empty()) + return false; + + LLVMContext &Ctx = AssociatedValue.getContext(); + unsigned AttrsSize = Attrs.size(); + auto EIt = Explorer->begin(IRP.getCtxI()), + EEnd = Explorer->end(IRP.getCtxI()); + for (const auto &It : A2K) + if (Explorer->findInContextOf(It.first, EIt, EEnd)) + Attrs.push_back(Attribute::get(Ctx, AK, It.second.Max)); + return AttrsSize != Attrs.size(); +} + +template <typename DescTy> ChangeStatus -IRAttributeManifest::manifestAttrs(Attributor &A, const IRPosition &IRP, - const ArrayRef<Attribute> &DeducedAttrs, - bool ForceReplace) { - Function *ScopeFn = IRP.getAnchorScope(); - IRPosition::Kind PK = IRP.getPositionKind(); - - // In the following some generic code that will manifest attributes in - // DeducedAttrs if they improve the current IR. Due to the different - // annotation positions we use the underlying AttributeList interface. 
- - AttributeList Attrs; - switch (PK) { - case IRPosition::IRP_INVALID: +Attributor::updateAttrMap(const IRPosition &IRP, + const ArrayRef<DescTy> &AttrDescs, + function_ref<bool(const DescTy &, AttributeSet, + AttributeMask &, AttrBuilder &)> + CB) { + if (AttrDescs.empty()) + return ChangeStatus::UNCHANGED; + switch (IRP.getPositionKind()) { case IRPosition::IRP_FLOAT: + case IRPosition::IRP_INVALID: return ChangeStatus::UNCHANGED; - case IRPosition::IRP_ARGUMENT: - case IRPosition::IRP_FUNCTION: - case IRPosition::IRP_RETURNED: - Attrs = ScopeFn->getAttributes(); - break; - case IRPosition::IRP_CALL_SITE: - case IRPosition::IRP_CALL_SITE_RETURNED: - case IRPosition::IRP_CALL_SITE_ARGUMENT: - Attrs = cast<CallBase>(IRP.getAnchorValue()).getAttributes(); + default: break; - } + }; + + AttributeList AL; + Value *AttrListAnchor = IRP.getAttrListAnchor(); + auto It = AttrsMap.find(AttrListAnchor); + if (It == AttrsMap.end()) + AL = IRP.getAttrList(); + else + AL = It->getSecond(); - ChangeStatus HasChanged = ChangeStatus::UNCHANGED; LLVMContext &Ctx = IRP.getAnchorValue().getContext(); - for (const Attribute &Attr : DeducedAttrs) { - if (!addIfNotExistent(Ctx, Attr, Attrs, IRP.getAttrIdx(), ForceReplace)) - continue; + auto AttrIdx = IRP.getAttrIdx(); + AttributeSet AS = AL.getAttributes(AttrIdx); + AttributeMask AM; + AttrBuilder AB(Ctx); - HasChanged = ChangeStatus::CHANGED; - } + ChangeStatus HasChanged = ChangeStatus::UNCHANGED; + for (const DescTy &AttrDesc : AttrDescs) + if (CB(AttrDesc, AS, AM, AB)) + HasChanged = ChangeStatus::CHANGED; if (HasChanged == ChangeStatus::UNCHANGED) - return HasChanged; + return ChangeStatus::UNCHANGED; - switch (PK) { - case IRPosition::IRP_ARGUMENT: - case IRPosition::IRP_FUNCTION: - case IRPosition::IRP_RETURNED: - ScopeFn->setAttributes(Attrs); - break; - case IRPosition::IRP_CALL_SITE: - case IRPosition::IRP_CALL_SITE_RETURNED: - case IRPosition::IRP_CALL_SITE_ARGUMENT: - cast<CallBase>(IRP.getAnchorValue()).setAttributes(Attrs); - break; - case IRPosition::IRP_INVALID: - case IRPosition::IRP_FLOAT: - break; + AL = AL.removeAttributesAtIndex(Ctx, AttrIdx, AM); + AL = AL.addAttributesAtIndex(Ctx, AttrIdx, AB); + AttrsMap[AttrListAnchor] = AL; + return ChangeStatus::CHANGED; +} + +bool Attributor::hasAttr(const IRPosition &IRP, + ArrayRef<Attribute::AttrKind> AttrKinds, + bool IgnoreSubsumingPositions, + Attribute::AttrKind ImpliedAttributeKind) { + bool Implied = false; + bool HasAttr = false; + auto HasAttrCB = [&](const Attribute::AttrKind &Kind, AttributeSet AttrSet, + AttributeMask &, AttrBuilder &) { + if (AttrSet.hasAttribute(Kind)) { + Implied |= Kind != ImpliedAttributeKind; + HasAttr = true; + } + return false; + }; + for (const IRPosition &EquivIRP : SubsumingPositionIterator(IRP)) { + updateAttrMap<Attribute::AttrKind>(EquivIRP, AttrKinds, HasAttrCB); + if (HasAttr) + break; + // The first position returned by the SubsumingPositionIterator is + // always the position itself. If we ignore subsuming positions we + // are done after the first iteration. + if (IgnoreSubsumingPositions) + break; + Implied = true; + } + if (!HasAttr) { + Implied = true; + SmallVector<Attribute> Attrs; + for (Attribute::AttrKind AK : AttrKinds) + if (getAttrsFromAssumes(IRP, AK, Attrs)) { + HasAttr = true; + break; + } } - return HasChanged; + // Check if we should manifest the implied attribute kind at the IRP. 
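
The new Attributor attribute plumbing above funnels queries and modifications through updateAttrMap, which caches pending AttributeList changes per attribute-list anchor and lets a per-attribute callback report whether anything changed before the list is written back. A plain-C++ sketch of that cache-plus-callback pattern, with invented names and std::set standing in for AttributeSet/AttrBuilder:

// Plain-C++ sketch (invented names; std::set stands in for AttributeSet and
// AttrBuilder) of the cache-plus-callback pattern in updateAttrMap: pending
// attribute changes are kept per anchor and only written back when a callback
// reports a change.
#include <cstdio>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

using AttrSet = std::set<std::string>;
static std::map<const void *, AttrSet> Cache; // anchor -> pending attributes

static bool updateAttrs(
    const void *Anchor, const AttrSet &Current,
    const std::vector<std::string> &Descs,
    const std::function<bool(const std::string &, AttrSet &)> &CB) {
  auto It = Cache.find(Anchor);
  AttrSet Work = (It == Cache.end()) ? Current : It->second;
  bool Changed = false;
  for (const auto &D : Descs)
    Changed |= CB(D, Work);
  if (Changed)
    Cache[Anchor] = Work; // defer the real write-back to manifest time
  return Changed;
}

int main() {
  int Anchor = 0;
  AttrSet Existing = {"nounwind"};
  auto AddIfNotExistent = [](const std::string &A, AttrSet &S) {
    return S.insert(A).second; // true only if the attribute was newly added
  };
  bool C1 = updateAttrs(&Anchor, Existing, {"nosync", "nounwind"}, AddIfNotExistent);
  bool C2 = updateAttrs(&Anchor, Existing, {"nosync"}, AddIfNotExistent);
  std::printf("first=%d second=%d cached=%zu\n", C1, C2, Cache[&Anchor].size());
  return 0;
}
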
+ if (ImpliedAttributeKind != Attribute::None && HasAttr && Implied) + manifestAttrs(IRP, {Attribute::get(IRP.getAnchorValue().getContext(), + ImpliedAttributeKind)}); + return HasAttr; +} + +void Attributor::getAttrs(const IRPosition &IRP, + ArrayRef<Attribute::AttrKind> AttrKinds, + SmallVectorImpl<Attribute> &Attrs, + bool IgnoreSubsumingPositions) { + auto CollectAttrCB = [&](const Attribute::AttrKind &Kind, + AttributeSet AttrSet, AttributeMask &, + AttrBuilder &) { + if (AttrSet.hasAttribute(Kind)) + Attrs.push_back(AttrSet.getAttribute(Kind)); + return false; + }; + for (const IRPosition &EquivIRP : SubsumingPositionIterator(IRP)) { + updateAttrMap<Attribute::AttrKind>(EquivIRP, AttrKinds, CollectAttrCB); + // The first position returned by the SubsumingPositionIterator is + // always the position itself. If we ignore subsuming positions we + // are done after the first iteration. + if (IgnoreSubsumingPositions) + break; + } + for (Attribute::AttrKind AK : AttrKinds) + getAttrsFromAssumes(IRP, AK, Attrs); +} + +ChangeStatus +Attributor::removeAttrs(const IRPosition &IRP, + const ArrayRef<Attribute::AttrKind> &AttrKinds) { + auto RemoveAttrCB = [&](const Attribute::AttrKind &Kind, AttributeSet AttrSet, + AttributeMask &AM, AttrBuilder &) { + if (!AttrSet.hasAttribute(Kind)) + return false; + AM.addAttribute(Kind); + return true; + }; + return updateAttrMap<Attribute::AttrKind>(IRP, AttrKinds, RemoveAttrCB); +} + +ChangeStatus Attributor::manifestAttrs(const IRPosition &IRP, + const ArrayRef<Attribute> &Attrs, + bool ForceReplace) { + LLVMContext &Ctx = IRP.getAnchorValue().getContext(); + auto AddAttrCB = [&](const Attribute &Attr, AttributeSet AttrSet, + AttributeMask &, AttrBuilder &AB) { + return addIfNotExistent(Ctx, Attr, AttrSet, ForceReplace, AB); + }; + return updateAttrMap<Attribute>(IRP, Attrs, AddAttrCB); } const IRPosition IRPosition::EmptyKey(DenseMapInfo<void *>::getEmptyKey()); @@ -1021,7 +1192,7 @@ const IRPosition SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { IRPositions.emplace_back(IRP); - // Helper to determine if operand bundles on a call site are benin or + // Helper to determine if operand bundles on a call site are benign or // potentially problematic. We handle only llvm.assume for now. auto CanIgnoreOperandBundles = [](const CallBase &CB) { return (isa<IntrinsicInst>(CB) && @@ -1043,7 +1214,7 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { // TODO: We need to look at the operand bundles similar to the redirection // in CallBase. if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) - if (const Function *Callee = CB->getCalledFunction()) + if (auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand())) IRPositions.emplace_back(IRPosition::function(*Callee)); return; case IRPosition::IRP_CALL_SITE_RETURNED: @@ -1051,7 +1222,8 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { // TODO: We need to look at the operand bundles similar to the redirection // in CallBase. 
if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) { - if (const Function *Callee = CB->getCalledFunction()) { + if (auto *Callee = + dyn_cast_if_present<Function>(CB->getCalledOperand())) { IRPositions.emplace_back(IRPosition::returned(*Callee)); IRPositions.emplace_back(IRPosition::function(*Callee)); for (const Argument &Arg : Callee->args()) @@ -1071,7 +1243,7 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { // TODO: We need to look at the operand bundles similar to the redirection // in CallBase. if (!CB->hasOperandBundles() || CanIgnoreOperandBundles(*CB)) { - const Function *Callee = CB->getCalledFunction(); + auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand()); if (Callee) { if (Argument *Arg = IRP.getAssociatedArgument()) IRPositions.emplace_back(IRPosition::argument(*Arg)); @@ -1084,85 +1256,6 @@ SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { } } -bool IRPosition::hasAttr(ArrayRef<Attribute::AttrKind> AKs, - bool IgnoreSubsumingPositions, Attributor *A) const { - SmallVector<Attribute, 4> Attrs; - for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this)) { - for (Attribute::AttrKind AK : AKs) - if (EquivIRP.getAttrsFromIRAttr(AK, Attrs)) - return true; - // The first position returned by the SubsumingPositionIterator is - // always the position itself. If we ignore subsuming positions we - // are done after the first iteration. - if (IgnoreSubsumingPositions) - break; - } - if (A) - for (Attribute::AttrKind AK : AKs) - if (getAttrsFromAssumes(AK, Attrs, *A)) - return true; - return false; -} - -void IRPosition::getAttrs(ArrayRef<Attribute::AttrKind> AKs, - SmallVectorImpl<Attribute> &Attrs, - bool IgnoreSubsumingPositions, Attributor *A) const { - for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this)) { - for (Attribute::AttrKind AK : AKs) - EquivIRP.getAttrsFromIRAttr(AK, Attrs); - // The first position returned by the SubsumingPositionIterator is - // always the position itself. If we ignore subsuming positions we - // are done after the first iteration. - if (IgnoreSubsumingPositions) - break; - } - if (A) - for (Attribute::AttrKind AK : AKs) - getAttrsFromAssumes(AK, Attrs, *A); -} - -bool IRPosition::getAttrsFromIRAttr(Attribute::AttrKind AK, - SmallVectorImpl<Attribute> &Attrs) const { - if (getPositionKind() == IRP_INVALID || getPositionKind() == IRP_FLOAT) - return false; - - AttributeList AttrList; - if (const auto *CB = dyn_cast<CallBase>(&getAnchorValue())) - AttrList = CB->getAttributes(); - else - AttrList = getAssociatedFunction()->getAttributes(); - - bool HasAttr = AttrList.hasAttributeAtIndex(getAttrIdx(), AK); - if (HasAttr) - Attrs.push_back(AttrList.getAttributeAtIndex(getAttrIdx(), AK)); - return HasAttr; -} - -bool IRPosition::getAttrsFromAssumes(Attribute::AttrKind AK, - SmallVectorImpl<Attribute> &Attrs, - Attributor &A) const { - assert(getPositionKind() != IRP_INVALID && "Did expect a valid position!"); - Value &AssociatedValue = getAssociatedValue(); - - const Assume2KnowledgeMap &A2K = - A.getInfoCache().getKnowledgeMap().lookup({&AssociatedValue, AK}); - - // Check if we found any potential assume use, if not we don't need to create - // explorer iterators. 
- if (A2K.empty()) - return false; - - LLVMContext &Ctx = AssociatedValue.getContext(); - unsigned AttrsSize = Attrs.size(); - MustBeExecutedContextExplorer &Explorer = - A.getInfoCache().getMustBeExecutedContextExplorer(); - auto EIt = Explorer.begin(getCtxI()), EEnd = Explorer.end(getCtxI()); - for (const auto &It : A2K) - if (Explorer.findInContextOf(It.first, EIt, EEnd)) - Attrs.push_back(Attribute::get(Ctx, AK, It.second.Max)); - return AttrsSize != Attrs.size(); -} - void IRPosition::verify() { #ifdef EXPENSIVE_CHECKS switch (getPositionKind()) { @@ -1285,35 +1378,67 @@ std::optional<Value *> Attributor::getAssumedSimplified( } bool Attributor::getAssumedSimplifiedValues( - const IRPosition &IRP, const AbstractAttribute *AA, + const IRPosition &InitialIRP, const AbstractAttribute *AA, SmallVectorImpl<AA::ValueAndContext> &Values, AA::ValueScope S, - bool &UsedAssumedInformation) { - // First check all callbacks provided by outside AAs. If any of them returns - // a non-null value that is different from the associated value, or - // std::nullopt, we assume it's simplified. - const auto &SimplificationCBs = SimplificationCallbacks.lookup(IRP); - for (const auto &CB : SimplificationCBs) { - std::optional<Value *> CBResult = CB(IRP, AA, UsedAssumedInformation); - if (!CBResult.has_value()) - continue; - Value *V = *CBResult; - if (!V) - return false; - if ((S & AA::ValueScope::Interprocedural) || - AA::isValidInScope(*V, IRP.getAnchorScope())) - Values.push_back(AA::ValueAndContext{*V, nullptr}); - else - return false; - } - if (!SimplificationCBs.empty()) - return true; + bool &UsedAssumedInformation, bool RecurseForSelectAndPHI) { + SmallPtrSet<Value *, 8> Seen; + SmallVector<IRPosition, 8> Worklist; + Worklist.push_back(InitialIRP); + while (!Worklist.empty()) { + const IRPosition &IRP = Worklist.pop_back_val(); + + // First check all callbacks provided by outside AAs. If any of them returns + // a non-null value that is different from the associated value, or + // std::nullopt, we assume it's simplified. + int NV = Values.size(); + const auto &SimplificationCBs = SimplificationCallbacks.lookup(IRP); + for (const auto &CB : SimplificationCBs) { + std::optional<Value *> CBResult = CB(IRP, AA, UsedAssumedInformation); + if (!CBResult.has_value()) + continue; + Value *V = *CBResult; + if (!V) + return false; + if ((S & AA::ValueScope::Interprocedural) || + AA::isValidInScope(*V, IRP.getAnchorScope())) + Values.push_back(AA::ValueAndContext{*V, nullptr}); + else + return false; + } + if (SimplificationCBs.empty()) { + // If no high-level/outside simplification occurred, use + // AAPotentialValues. + const auto *PotentialValuesAA = + getOrCreateAAFor<AAPotentialValues>(IRP, AA, DepClassTy::OPTIONAL); + if (PotentialValuesAA && PotentialValuesAA->getAssumedSimplifiedValues(*this, Values, S)) { + UsedAssumedInformation |= !PotentialValuesAA->isAtFixpoint(); + } else if (IRP.getPositionKind() != IRPosition::IRP_RETURNED) { + Values.push_back({IRP.getAssociatedValue(), IRP.getCtxI()}); + } else { + // TODO: We could visit all returns and add the operands. + return false; + } + } - // If no high-level/outside simplification occurred, use AAPotentialValues. 
- const auto &PotentialValuesAA = - getOrCreateAAFor<AAPotentialValues>(IRP, AA, DepClassTy::OPTIONAL); - if (!PotentialValuesAA.getAssumedSimplifiedValues(*this, Values, S)) - return false; - UsedAssumedInformation |= !PotentialValuesAA.isAtFixpoint(); + if (!RecurseForSelectAndPHI) + break; + + for (int I = NV, E = Values.size(); I < E; ++I) { + Value *V = Values[I].getValue(); + if (!isa<PHINode>(V) && !isa<SelectInst>(V)) + continue; + if (!Seen.insert(V).second) + continue; + // Move the last element to this slot. + Values[I] = Values[E - 1]; + // Eliminate the last slot, adjust the indices. + Values.pop_back(); + --E; + --I; + // Add a new value (select or phi) to the worklist. + Worklist.push_back(IRPosition::value(*V)); + } + } return true; } @@ -1325,7 +1450,8 @@ std::optional<Value *> Attributor::translateArgumentToCallSiteContent( if (*V == nullptr || isa<Constant>(*V)) return V; if (auto *Arg = dyn_cast<Argument>(*V)) - if (CB.getCalledFunction() == Arg->getParent()) + if (CB.getCalledOperand() == Arg->getParent() && + CB.arg_size() > Arg->getArgNo()) if (!Arg->hasPointeeInMemoryValueAttr()) return getAssumedSimplified( IRPosition::callsite_argument(CB, Arg->getArgNo()), AA, @@ -1346,6 +1472,8 @@ bool Attributor::isAssumedDead(const AbstractAttribute &AA, const AAIsDead *FnLivenessAA, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, DepClassTy DepClass) { + if (!Configuration.UseLiveness) + return false; const IRPosition &IRP = AA.getIRPosition(); if (!Functions.count(IRP.getAnchorScope())) return false; @@ -1358,6 +1486,8 @@ bool Attributor::isAssumedDead(const Use &U, const AAIsDead *FnLivenessAA, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, DepClassTy DepClass) { + if (!Configuration.UseLiveness) + return false; Instruction *UserI = dyn_cast<Instruction>(U.getUser()); if (!UserI) return isAssumedDead(IRPosition::value(*U.get()), QueryingAA, FnLivenessAA, @@ -1384,12 +1514,12 @@ bool Attributor::isAssumedDead(const Use &U, } else if (StoreInst *SI = dyn_cast<StoreInst>(UserI)) { if (!CheckBBLivenessOnly && SI->getPointerOperand() != U.get()) { const IRPosition IRP = IRPosition::inst(*SI); - const AAIsDead &IsDeadAA = + const AAIsDead *IsDeadAA = getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE); - if (IsDeadAA.isRemovableStore()) { + if (IsDeadAA && IsDeadAA->isRemovableStore()) { if (QueryingAA) - recordDependence(IsDeadAA, *QueryingAA, DepClass); - if (!IsDeadAA.isKnown(AAIsDead::IS_REMOVABLE)) + recordDependence(*IsDeadAA, *QueryingAA, DepClass); + if (!IsDeadAA->isKnown(AAIsDead::IS_REMOVABLE)) UsedAssumedInformation = true; return true; } @@ -1406,6 +1536,8 @@ bool Attributor::isAssumedDead(const Instruction &I, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, DepClassTy DepClass, bool CheckForDeadStore) { + if (!Configuration.UseLiveness) + return false; const IRPosition::CallBaseContext *CBCtx = QueryingAA ? QueryingAA->getCallBaseContext() : nullptr; @@ -1414,11 +1546,11 @@ bool Attributor::isAssumedDead(const Instruction &I, const Function &F = *I.getFunction(); if (!FnLivenessAA || FnLivenessAA->getAnchorScope() != &F) - FnLivenessAA = &getOrCreateAAFor<AAIsDead>(IRPosition::function(F, CBCtx), - QueryingAA, DepClassTy::NONE); + FnLivenessAA = getOrCreateAAFor<AAIsDead>(IRPosition::function(F, CBCtx), + QueryingAA, DepClassTy::NONE); // Don't use recursive reasoning. - if (QueryingAA == FnLivenessAA) + if (!FnLivenessAA || QueryingAA == FnLivenessAA) return false; // If we have a context instruction and a liveness AA we use it. 
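A minimal usage sketch of the reworked getAssumedSimplifiedValues interface shown above (illustrative only, not part of this commit; `A`, `V`, and `ProcessCandidate` are placeholder names for an Attributor, a Value, and a caller-side consumer). With RecurseForSelectAndPHI set, the worklist loop above expands select and PHI operands before reporting values:

  // Sketch, not from the commit: query assumed-simplified values for V.
  bool UsedAssumedInformation = false;
  SmallVector<AA::ValueAndContext> Values;
  if (A.getAssumedSimplifiedValues(IRPosition::value(V), /* AA */ nullptr,
                                   Values, AA::ValueScope::Intraprocedural,
                                   UsedAssumedInformation,
                                   /* RecurseForSelectAndPHI */ true))
    for (const AA::ValueAndContext &VAC : Values)
      ProcessCandidate(*VAC.getValue()); // PHIs/selects already looked through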
@@ -1435,25 +1567,25 @@ bool Attributor::isAssumedDead(const Instruction &I, return false; const IRPosition IRP = IRPosition::inst(I, CBCtx); - const AAIsDead &IsDeadAA = + const AAIsDead *IsDeadAA = getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE); // Don't use recursive reasoning. - if (QueryingAA == &IsDeadAA) + if (!IsDeadAA || QueryingAA == IsDeadAA) return false; - if (IsDeadAA.isAssumedDead()) { + if (IsDeadAA->isAssumedDead()) { if (QueryingAA) - recordDependence(IsDeadAA, *QueryingAA, DepClass); - if (!IsDeadAA.isKnownDead()) + recordDependence(*IsDeadAA, *QueryingAA, DepClass); + if (!IsDeadAA->isKnownDead()) UsedAssumedInformation = true; return true; } - if (CheckForDeadStore && isa<StoreInst>(I) && IsDeadAA.isRemovableStore()) { + if (CheckForDeadStore && isa<StoreInst>(I) && IsDeadAA->isRemovableStore()) { if (QueryingAA) - recordDependence(IsDeadAA, *QueryingAA, DepClass); - if (!IsDeadAA.isKnownDead()) + recordDependence(*IsDeadAA, *QueryingAA, DepClass); + if (!IsDeadAA->isKnownDead()) UsedAssumedInformation = true; return true; } @@ -1466,6 +1598,8 @@ bool Attributor::isAssumedDead(const IRPosition &IRP, const AAIsDead *FnLivenessAA, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, DepClassTy DepClass) { + if (!Configuration.UseLiveness) + return false; // Don't check liveness for constants, e.g. functions, used as (floating) // values since the context instruction and such is here meaningless. if (IRP.getPositionKind() == IRPosition::IRP_FLOAT && @@ -1486,14 +1620,14 @@ bool Attributor::isAssumedDead(const IRPosition &IRP, // If we haven't succeeded we query the specific liveness info for the IRP. const AAIsDead *IsDeadAA; if (IRP.getPositionKind() == IRPosition::IRP_CALL_SITE) - IsDeadAA = &getOrCreateAAFor<AAIsDead>( + IsDeadAA = getOrCreateAAFor<AAIsDead>( IRPosition::callsite_returned(cast<CallBase>(IRP.getAssociatedValue())), QueryingAA, DepClassTy::NONE); else - IsDeadAA = &getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE); + IsDeadAA = getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE); // Don't use recursive reasoning. - if (QueryingAA == IsDeadAA) + if (!IsDeadAA || QueryingAA == IsDeadAA) return false; if (IsDeadAA->isAssumedDead()) { @@ -1511,13 +1645,15 @@ bool Attributor::isAssumedDead(const BasicBlock &BB, const AbstractAttribute *QueryingAA, const AAIsDead *FnLivenessAA, DepClassTy DepClass) { + if (!Configuration.UseLiveness) + return false; const Function &F = *BB.getParent(); if (!FnLivenessAA || FnLivenessAA->getAnchorScope() != &F) - FnLivenessAA = &getOrCreateAAFor<AAIsDead>(IRPosition::function(F), - QueryingAA, DepClassTy::NONE); + FnLivenessAA = getOrCreateAAFor<AAIsDead>(IRPosition::function(F), + QueryingAA, DepClassTy::NONE); // Don't use recursive reasoning. - if (QueryingAA == FnLivenessAA) + if (!FnLivenessAA || QueryingAA == FnLivenessAA) return false; if (FnLivenessAA->isAssumedDead(&BB)) { @@ -1570,8 +1706,8 @@ bool Attributor::checkForAllUses( const Function *ScopeFn = IRP.getAnchorScope(); const auto *LivenessAA = - ScopeFn ? &getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*ScopeFn), - DepClassTy::NONE) + ScopeFn ? 
getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*ScopeFn), + DepClassTy::NONE) : nullptr; while (!Worklist.empty()) { @@ -1777,49 +1913,26 @@ bool Attributor::shouldPropagateCallBaseContext(const IRPosition &IRP) { return EnableCallSiteSpecific; } -bool Attributor::checkForAllReturnedValuesAndReturnInsts( - function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> Pred, - const AbstractAttribute &QueryingAA) { +bool Attributor::checkForAllReturnedValues(function_ref<bool(Value &)> Pred, + const AbstractAttribute &QueryingAA, + AA::ValueScope S, + bool RecurseForSelectAndPHI) { const IRPosition &IRP = QueryingAA.getIRPosition(); - // Since we need to provide return instructions we have to have an exact - // definition. const Function *AssociatedFunction = IRP.getAssociatedFunction(); if (!AssociatedFunction) return false; - // If this is a call site query we use the call site specific return values - // and liveness information. - // TODO: use the function scope once we have call site AAReturnedValues. - const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); - const auto &AARetVal = - getAAFor<AAReturnedValues>(QueryingAA, QueryIRP, DepClassTy::REQUIRED); - if (!AARetVal.getState().isValidState()) - return false; - - return AARetVal.checkForAllReturnedValuesAndReturnInsts(Pred); -} - -bool Attributor::checkForAllReturnedValues( - function_ref<bool(Value &)> Pred, const AbstractAttribute &QueryingAA) { - - const IRPosition &IRP = QueryingAA.getIRPosition(); - const Function *AssociatedFunction = IRP.getAssociatedFunction(); - if (!AssociatedFunction) - return false; - - // TODO: use the function scope once we have call site AAReturnedValues. - const IRPosition &QueryIRP = IRPosition::function( - *AssociatedFunction, QueryingAA.getCallBaseContext()); - const auto &AARetVal = - getAAFor<AAReturnedValues>(QueryingAA, QueryIRP, DepClassTy::REQUIRED); - if (!AARetVal.getState().isValidState()) + bool UsedAssumedInformation = false; + SmallVector<AA::ValueAndContext> Values; + if (!getAssumedSimplifiedValues( + IRPosition::returned(*AssociatedFunction), &QueryingAA, Values, S, + UsedAssumedInformation, RecurseForSelectAndPHI)) return false; - return AARetVal.checkForAllReturnedValuesAndReturnInsts( - [&](Value &RV, const SmallSetVector<ReturnInst *, 4> &) { - return Pred(RV); - }); + return llvm::all_of(Values, [&](const AA::ValueAndContext &VAC) { + return Pred(*VAC.getValue()); + }); } static bool checkForAllInstructionsImpl( @@ -1863,12 +1976,11 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, if (!Fn || Fn->isDeclaration()) return false; - // TODO: use the function scope once we have call site AAReturnedValues. const IRPosition &QueryIRP = IRPosition::function(*Fn); const auto *LivenessAA = - (CheckBBLivenessOnly || CheckPotentiallyDead) + CheckPotentiallyDead ? 
nullptr - : &(getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE)); + : (getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE)); auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*Fn); if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, &QueryingAA, @@ -1895,21 +2007,21 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, bool Attributor::checkForAllReadWriteInstructions( function_ref<bool(Instruction &)> Pred, AbstractAttribute &QueryingAA, bool &UsedAssumedInformation) { + TimeTraceScope TS("checkForAllReadWriteInstructions"); const Function *AssociatedFunction = QueryingAA.getIRPosition().getAssociatedFunction(); if (!AssociatedFunction) return false; - // TODO: use the function scope once we have call site AAReturnedValues. const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); - const auto &LivenessAA = + const auto *LivenessAA = getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE); for (Instruction *I : InfoCache.getReadOrWriteInstsForFunction(*AssociatedFunction)) { // Skip dead instructions. - if (isAssumedDead(IRPosition::inst(*I), &QueryingAA, &LivenessAA, + if (isAssumedDead(IRPosition::inst(*I), &QueryingAA, LivenessAA, UsedAssumedInformation)) continue; @@ -1954,11 +2066,9 @@ void Attributor::runTillFixpoint() { dbgs() << "[Attributor] InvalidAA: " << *InvalidAA << " has " << InvalidAA->Deps.size() << " required & optional dependences\n"); - while (!InvalidAA->Deps.empty()) { - const auto &Dep = InvalidAA->Deps.back(); - InvalidAA->Deps.pop_back(); - AbstractAttribute *DepAA = cast<AbstractAttribute>(Dep.getPointer()); - if (Dep.getInt() == unsigned(DepClassTy::OPTIONAL)) { + for (auto &DepIt : InvalidAA->Deps) { + AbstractAttribute *DepAA = cast<AbstractAttribute>(DepIt.getPointer()); + if (DepIt.getInt() == unsigned(DepClassTy::OPTIONAL)) { DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, dbgs() << " - recompute: " << *DepAA); Worklist.insert(DepAA); @@ -1973,16 +2083,16 @@ void Attributor::runTillFixpoint() { else ChangedAAs.push_back(DepAA); } + InvalidAA->Deps.clear(); } // Add all abstract attributes that are potentially dependent on one that // changed to the work list. 
- for (AbstractAttribute *ChangedAA : ChangedAAs) - while (!ChangedAA->Deps.empty()) { - Worklist.insert( - cast<AbstractAttribute>(ChangedAA->Deps.back().getPointer())); - ChangedAA->Deps.pop_back(); - } + for (AbstractAttribute *ChangedAA : ChangedAAs) { + for (auto &DepIt : ChangedAA->Deps) + Worklist.insert(cast<AbstractAttribute>(DepIt.getPointer())); + ChangedAA->Deps.clear(); + } LLVM_DEBUG(dbgs() << "[Attributor] #Iteration: " << IterationCounter << ", Worklist+Dependent size: " << Worklist.size() @@ -2019,8 +2129,7 @@ void Attributor::runTillFixpoint() { QueryAAsAwaitingUpdate.end()); QueryAAsAwaitingUpdate.clear(); - } while (!Worklist.empty() && - (IterationCounter++ < MaxIterations || VerifyMaxFixpointIterations)); + } while (!Worklist.empty() && (IterationCounter++ < MaxIterations)); if (IterationCounter > MaxIterations && !Functions.empty()) { auto Remark = [&](OptimizationRemarkMissed ORM) { @@ -2053,11 +2162,9 @@ void Attributor::runTillFixpoint() { NumAttributesTimedOut++; } - while (!ChangedAA->Deps.empty()) { - ChangedAAs.push_back( - cast<AbstractAttribute>(ChangedAA->Deps.back().getPointer())); - ChangedAA->Deps.pop_back(); - } + for (auto &DepIt : ChangedAA->Deps) + ChangedAAs.push_back(cast<AbstractAttribute>(DepIt.getPointer())); + ChangedAA->Deps.clear(); } LLVM_DEBUG({ @@ -2065,13 +2172,6 @@ void Attributor::runTillFixpoint() { dbgs() << "\n[Attributor] Finalized " << Visited.size() << " abstract attributes.\n"; }); - - if (VerifyMaxFixpointIterations && IterationCounter != MaxIterations) { - errs() << "\n[Attributor] Fixpoint iteration done after: " - << IterationCounter << "/" << MaxIterations << " iterations\n"; - llvm_unreachable("The fixpoint was not reached with exactly the number of " - "specified iterations!"); - } } void Attributor::registerForUpdate(AbstractAttribute &AA) { @@ -2141,17 +2241,31 @@ ChangeStatus Attributor::manifestAttributes() { (void)NumFinalAAs; if (NumFinalAAs != DG.SyntheticRoot.Deps.size()) { - for (unsigned u = NumFinalAAs; u < DG.SyntheticRoot.Deps.size(); ++u) + auto DepIt = DG.SyntheticRoot.Deps.begin(); + for (unsigned u = 0; u < NumFinalAAs; ++u) + ++DepIt; + for (unsigned u = NumFinalAAs; u < DG.SyntheticRoot.Deps.size(); + ++u, ++DepIt) { errs() << "Unexpected abstract attribute: " - << cast<AbstractAttribute>(DG.SyntheticRoot.Deps[u].getPointer()) - << " :: " - << cast<AbstractAttribute>(DG.SyntheticRoot.Deps[u].getPointer()) + << cast<AbstractAttribute>(DepIt->getPointer()) << " :: " + << cast<AbstractAttribute>(DepIt->getPointer()) ->getIRPosition() .getAssociatedValue() << "\n"; + } llvm_unreachable("Expected the final number of abstract attributes to " "remain unchanged!"); } + + for (auto &It : AttrsMap) { + AttributeList &AL = It.getSecond(); + const IRPosition &IRP = + isa<Function>(It.getFirst()) + ? 
IRPosition::function(*cast<Function>(It.getFirst())) + : IRPosition::callsite_function(*cast<CallBase>(It.getFirst())); + IRP.setAttrList(AL); + } + return ManifestChange; } @@ -2271,9 +2385,9 @@ ChangeStatus Attributor::cleanupIR() { if (CB->isArgOperand(U)) { unsigned Idx = CB->getArgOperandNo(U); CB->removeParamAttr(Idx, Attribute::NoUndef); - Function *Fn = CB->getCalledFunction(); - if (Fn && Fn->arg_size() > Idx) - Fn->removeParamAttr(Idx, Attribute::NoUndef); + auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand()); + if (Callee && Callee->arg_size() > Idx) + Callee->removeParamAttr(Idx, Attribute::NoUndef); } } if (isa<Constant>(NewV) && isa<BranchInst>(U->getUser())) { @@ -2484,9 +2598,9 @@ ChangeStatus Attributor::run() { } ChangeStatus Attributor::updateAA(AbstractAttribute &AA) { - TimeTraceScope TimeScope( - AA.getName() + std::to_string(AA.getIRPosition().getPositionKind()) + - "::updateAA"); + TimeTraceScope TimeScope("updateAA", [&]() { + return AA.getName() + std::to_string(AA.getIRPosition().getPositionKind()); + }); assert(Phase == AttributorPhase::UPDATE && "We can update AA only in the update stage!"); @@ -2672,7 +2786,10 @@ bool Attributor::isValidFunctionSignatureRewrite( ACS.getInstruction()->getType() != ACS.getCalledFunction()->getReturnType()) return false; - if (ACS.getCalledOperand()->getType() != Fn->getType()) + if (cast<CallBase>(ACS.getInstruction())->getCalledOperand()->getType() != + Fn->getType()) + return false; + if (ACS.getNumArgOperands() != Fn->arg_size()) return false; // Forbid must-tail calls for now. return !ACS.isCallbackCall() && !ACS.getInstruction()->isMustTailCall(); @@ -2698,7 +2815,8 @@ bool Attributor::isValidFunctionSignatureRewrite( // Avoid callbacks for now. bool UsedAssumedInformation = false; if (!checkForAllCallSites(CallSiteCanBeChanged, *Fn, true, nullptr, - UsedAssumedInformation)) { + UsedAssumedInformation, + /* CheckPotentiallyDead */ true)) { LLVM_DEBUG(dbgs() << "[Attributor] Cannot rewrite all call sites\n"); return false; } @@ -3041,7 +3159,8 @@ void InformationCache::initializeInformationCache(const Function &CF, AddToAssumeUsesMap(*Assume->getArgOperand(0)); } else if (cast<CallInst>(I).isMustTailCall()) { FI.ContainsMustTailCall = true; - if (const Function *Callee = cast<CallInst>(I).getCalledFunction()) + if (auto *Callee = dyn_cast_if_present<Function>( + cast<CallInst>(I).getCalledOperand())) getFunctionInfo(*Callee).CalledViaMustTail = true; } [[fallthrough]]; @@ -3077,10 +3196,6 @@ void InformationCache::initializeInformationCache(const Function &CF, InlineableFunctions.insert(&F); } -AAResults *InformationCache::getAAResultsForFunction(const Function &F) { - return AG.getAnalysis<AAManager>(F); -} - InformationCache::FunctionInfo::~FunctionInfo() { // The instruction vectors are allocated using a BumpPtrAllocator, we need to // manually destroy them. 
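The hunks above (and several further down) replace calls to CB->getCalledFunction() with a lookup through the called operand. A small illustrative sketch of that pattern, using a hypothetical helper name not taken from the commit:

  // Sketch only: resolve a callee through the called operand, as in the hunks
  // above. Returns nullptr for indirect calls or when no Function is present.
  static const Function *getCalleeOrNull(const CallBase &CB) {
    return dyn_cast_if_present<Function>(CB.getCalledOperand());
  }

The resolved callee may still have a signature that differs from the call site, which is why the surrounding hunks add guards such as CB.arg_size() > Arg->getArgNo().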
@@ -3111,11 +3226,21 @@ void Attributor::rememberDependences() { DI.DepClass == DepClassTy::OPTIONAL) && "Expected required or optional dependence (1 bit)!"); auto &DepAAs = const_cast<AbstractAttribute &>(*DI.FromAA).Deps; - DepAAs.push_back(AbstractAttribute::DepTy( + DepAAs.insert(AbstractAttribute::DepTy( const_cast<AbstractAttribute *>(DI.ToAA), unsigned(DI.DepClass))); } } +template <Attribute::AttrKind AK, typename AAType> +void Attributor::checkAndQueryIRAttr(const IRPosition &IRP, + AttributeSet Attrs) { + bool IsKnown; + if (!Attrs.hasAttribute(AK)) + if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE, + IsKnown)) + getOrCreateAAFor<AAType>(IRP); +} + void Attributor::identifyDefaultAbstractAttributes(Function &F) { if (!VisitedFunctions.insert(&F).second) return; @@ -3134,89 +3259,114 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { } IRPosition FPos = IRPosition::function(F); + bool IsIPOAmendable = isFunctionIPOAmendable(F); + auto Attrs = F.getAttributes(); + auto FnAttrs = Attrs.getFnAttrs(); // Check for dead BasicBlocks in every function. // We need dead instruction detection because we do not want to deal with // broken IR in which SSA rules do not apply. getOrCreateAAFor<AAIsDead>(FPos); - // Every function might be "will-return". - getOrCreateAAFor<AAWillReturn>(FPos); - - // Every function might contain instructions that cause "undefined behavior". + // Every function might contain instructions that cause "undefined + // behavior". getOrCreateAAFor<AAUndefinedBehavior>(FPos); - // Every function can be nounwind. - getOrCreateAAFor<AANoUnwind>(FPos); + // Every function might be applicable for Heap-To-Stack conversion. + if (EnableHeapToStack) + getOrCreateAAFor<AAHeapToStack>(FPos); - // Every function might be marked "nosync" - getOrCreateAAFor<AANoSync>(FPos); + // Every function might be "must-progress". + checkAndQueryIRAttr<Attribute::MustProgress, AAMustProgress>(FPos, FnAttrs); // Every function might be "no-free". - getOrCreateAAFor<AANoFree>(FPos); + checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(FPos, FnAttrs); - // Every function might be "no-return". - getOrCreateAAFor<AANoReturn>(FPos); + // Every function might be "will-return". + checkAndQueryIRAttr<Attribute::WillReturn, AAWillReturn>(FPos, FnAttrs); - // Every function might be "no-recurse". - getOrCreateAAFor<AANoRecurse>(FPos); + // Everything that is visible from the outside (=function, argument, return + // positions), cannot be changed if the function is not IPO amendable. We can + // however analyse the code inside. + if (IsIPOAmendable) { - // Every function might be "readnone/readonly/writeonly/...". - getOrCreateAAFor<AAMemoryBehavior>(FPos); + // Every function can be nounwind. + checkAndQueryIRAttr<Attribute::NoUnwind, AANoUnwind>(FPos, FnAttrs); - // Every function can be "readnone/argmemonly/inaccessiblememonly/...". - getOrCreateAAFor<AAMemoryLocation>(FPos); + // Every function might be marked "nosync" + checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs); - // Every function can track active assumptions. - getOrCreateAAFor<AAAssumptionInfo>(FPos); + // Every function might be "no-return". + checkAndQueryIRAttr<Attribute::NoReturn, AANoReturn>(FPos, FnAttrs); - // Every function might be applicable for Heap-To-Stack conversion. - if (EnableHeapToStack) - getOrCreateAAFor<AAHeapToStack>(FPos); + // Every function might be "no-recurse". 
+ checkAndQueryIRAttr<Attribute::NoRecurse, AANoRecurse>(FPos, FnAttrs); - // Return attributes are only appropriate if the return type is non void. - Type *ReturnType = F.getReturnType(); - if (!ReturnType->isVoidTy()) { - // Argument attribute "returned" --- Create only one per function even - // though it is an argument attribute. - getOrCreateAAFor<AAReturnedValues>(FPos); + // Every function can be "non-convergent". + if (Attrs.hasFnAttr(Attribute::Convergent)) + getOrCreateAAFor<AANonConvergent>(FPos); - IRPosition RetPos = IRPosition::returned(F); + // Every function might be "readnone/readonly/writeonly/...". + getOrCreateAAFor<AAMemoryBehavior>(FPos); - // Every returned value might be dead. - getOrCreateAAFor<AAIsDead>(RetPos); + // Every function can be "readnone/argmemonly/inaccessiblememonly/...". + getOrCreateAAFor<AAMemoryLocation>(FPos); - // Every function might be simplified. - bool UsedAssumedInformation = false; - getAssumedSimplified(RetPos, nullptr, UsedAssumedInformation, - AA::Intraprocedural); + // Every function can track active assumptions. + getOrCreateAAFor<AAAssumptionInfo>(FPos); - // Every returned value might be marked noundef. - getOrCreateAAFor<AANoUndef>(RetPos); + // Return attributes are only appropriate if the return type is non void. + Type *ReturnType = F.getReturnType(); + if (!ReturnType->isVoidTy()) { + IRPosition RetPos = IRPosition::returned(F); + AttributeSet RetAttrs = Attrs.getRetAttrs(); - if (ReturnType->isPointerTy()) { + // Every returned value might be dead. + getOrCreateAAFor<AAIsDead>(RetPos); - // Every function with pointer return type might be marked align. - getOrCreateAAFor<AAAlign>(RetPos); + // Every function might be simplified. + bool UsedAssumedInformation = false; + getAssumedSimplified(RetPos, nullptr, UsedAssumedInformation, + AA::Intraprocedural); + + // Every returned value might be marked noundef. + checkAndQueryIRAttr<Attribute::NoUndef, AANoUndef>(RetPos, RetAttrs); + + if (ReturnType->isPointerTy()) { - // Every function with pointer return type might be marked nonnull. - getOrCreateAAFor<AANonNull>(RetPos); + // Every function with pointer return type might be marked align. + getOrCreateAAFor<AAAlign>(RetPos); - // Every function with pointer return type might be marked noalias. - getOrCreateAAFor<AANoAlias>(RetPos); + // Every function with pointer return type might be marked nonnull. + checkAndQueryIRAttr<Attribute::NonNull, AANonNull>(RetPos, RetAttrs); - // Every function with pointer return type might be marked - // dereferenceable. - getOrCreateAAFor<AADereferenceable>(RetPos); + // Every function with pointer return type might be marked noalias. + checkAndQueryIRAttr<Attribute::NoAlias, AANoAlias>(RetPos, RetAttrs); + + // Every function with pointer return type might be marked + // dereferenceable. + getOrCreateAAFor<AADereferenceable>(RetPos); + } else if (AttributeFuncs::isNoFPClassCompatibleType(ReturnType)) { + getOrCreateAAFor<AANoFPClass>(RetPos); + } } } for (Argument &Arg : F.args()) { IRPosition ArgPos = IRPosition::argument(Arg); + auto ArgNo = Arg.getArgNo(); + AttributeSet ArgAttrs = Attrs.getParamAttrs(ArgNo); + + if (!IsIPOAmendable) { + if (Arg.getType()->isPointerTy()) + // Every argument with pointer type might be marked nofree. + checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(ArgPos, ArgAttrs); + continue; + } - // Every argument might be simplified. We have to go through the Attributor - // interface though as outside AAs can register custom simplification - // callbacks. 
+ // Every argument might be simplified. We have to go through the + // Attributor interface though as outside AAs can register custom + // simplification callbacks. bool UsedAssumedInformation = false; getAssumedSimplified(ArgPos, /* AA */ nullptr, UsedAssumedInformation, AA::Intraprocedural); @@ -3225,14 +3375,14 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { getOrCreateAAFor<AAIsDead>(ArgPos); // Every argument might be marked noundef. - getOrCreateAAFor<AANoUndef>(ArgPos); + checkAndQueryIRAttr<Attribute::NoUndef, AANoUndef>(ArgPos, ArgAttrs); if (Arg.getType()->isPointerTy()) { // Every argument with pointer type might be marked nonnull. - getOrCreateAAFor<AANonNull>(ArgPos); + checkAndQueryIRAttr<Attribute::NonNull, AANonNull>(ArgPos, ArgAttrs); // Every argument with pointer type might be marked noalias. - getOrCreateAAFor<AANoAlias>(ArgPos); + checkAndQueryIRAttr<Attribute::NoAlias, AANoAlias>(ArgPos, ArgAttrs); // Every argument with pointer type might be marked dereferenceable. getOrCreateAAFor<AADereferenceable>(ArgPos); @@ -3241,17 +3391,20 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { getOrCreateAAFor<AAAlign>(ArgPos); // Every argument with pointer type might be marked nocapture. - getOrCreateAAFor<AANoCapture>(ArgPos); + checkAndQueryIRAttr<Attribute::NoCapture, AANoCapture>(ArgPos, ArgAttrs); // Every argument with pointer type might be marked // "readnone/readonly/writeonly/..." getOrCreateAAFor<AAMemoryBehavior>(ArgPos); // Every argument with pointer type might be marked nofree. - getOrCreateAAFor<AANoFree>(ArgPos); + checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(ArgPos, ArgAttrs); - // Every argument with pointer type might be privatizable (or promotable) + // Every argument with pointer type might be privatizable (or + // promotable) getOrCreateAAFor<AAPrivatizablePtr>(ArgPos); + } else if (AttributeFuncs::isNoFPClassCompatibleType(Arg.getType())) { + getOrCreateAAFor<AANoFPClass>(ArgPos); } } @@ -3264,7 +3417,7 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // users. The return value might be dead if there are no live users. getOrCreateAAFor<AAIsDead>(CBInstPos); - Function *Callee = CB.getCalledFunction(); + Function *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand()); // TODO: Even if the callee is not known now we might be able to simplify // the call/callee. if (!Callee) @@ -3280,16 +3433,20 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { return true; if (!Callee->getReturnType()->isVoidTy() && !CB.use_empty()) { - IRPosition CBRetPos = IRPosition::callsite_returned(CB); bool UsedAssumedInformation = false; getAssumedSimplified(CBRetPos, nullptr, UsedAssumedInformation, AA::Intraprocedural); + + if (AttributeFuncs::isNoFPClassCompatibleType(Callee->getReturnType())) + getOrCreateAAFor<AANoFPClass>(CBInstPos); } + const AttributeList &CBAttrs = CBFnPos.getAttrList(); for (int I = 0, E = CB.arg_size(); I < E; ++I) { IRPosition CBArgPos = IRPosition::callsite_argument(CB, I); + AttributeSet CBArgAttrs = CBAttrs.getParamAttrs(I); // Every call site argument might be dead. getOrCreateAAFor<AAIsDead>(CBArgPos); @@ -3302,19 +3459,26 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { AA::Intraprocedural); // Every call site argument might be marked "noundef". 
- getOrCreateAAFor<AANoUndef>(CBArgPos); + checkAndQueryIRAttr<Attribute::NoUndef, AANoUndef>(CBArgPos, CBArgAttrs); + + Type *ArgTy = CB.getArgOperand(I)->getType(); + + if (!ArgTy->isPointerTy()) { + if (AttributeFuncs::isNoFPClassCompatibleType(ArgTy)) + getOrCreateAAFor<AANoFPClass>(CBArgPos); - if (!CB.getArgOperand(I)->getType()->isPointerTy()) continue; + } // Call site argument attribute "non-null". - getOrCreateAAFor<AANonNull>(CBArgPos); + checkAndQueryIRAttr<Attribute::NonNull, AANonNull>(CBArgPos, CBArgAttrs); // Call site argument attribute "nocapture". - getOrCreateAAFor<AANoCapture>(CBArgPos); + checkAndQueryIRAttr<Attribute::NoCapture, AANoCapture>(CBArgPos, + CBArgAttrs); // Call site argument attribute "no-alias". - getOrCreateAAFor<AANoAlias>(CBArgPos); + checkAndQueryIRAttr<Attribute::NoAlias, AANoAlias>(CBArgPos, CBArgAttrs); // Call site argument attribute "dereferenceable". getOrCreateAAFor<AADereferenceable>(CBArgPos); @@ -3324,10 +3488,11 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Call site argument attribute // "readnone/readonly/writeonly/..." - getOrCreateAAFor<AAMemoryBehavior>(CBArgPos); + if (!CBAttrs.hasParamAttr(I, Attribute::ReadNone)) + getOrCreateAAFor<AAMemoryBehavior>(CBArgPos); // Call site argument attribute "nofree". - getOrCreateAAFor<AANoFree>(CBArgPos); + checkAndQueryIRAttr<Attribute::NoFree, AANoFree>(CBArgPos, CBArgAttrs); } return true; }; @@ -3344,18 +3509,21 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { assert(Success && "Expected the check call to be successful!"); auto LoadStorePred = [&](Instruction &I) -> bool { - if (isa<LoadInst>(I)) { - getOrCreateAAFor<AAAlign>( - IRPosition::value(*cast<LoadInst>(I).getPointerOperand())); + if (auto *LI = dyn_cast<LoadInst>(&I)) { + getOrCreateAAFor<AAAlign>(IRPosition::value(*LI->getPointerOperand())); if (SimplifyAllLoads) getAssumedSimplified(IRPosition::value(I), nullptr, UsedAssumedInformation, AA::Intraprocedural); + getOrCreateAAFor<AAAddressSpace>( + IRPosition::value(*LI->getPointerOperand())); } else { auto &SI = cast<StoreInst>(I); getOrCreateAAFor<AAIsDead>(IRPosition::inst(I)); getAssumedSimplified(IRPosition::value(*SI.getValueOperand()), nullptr, UsedAssumedInformation, AA::Intraprocedural); getOrCreateAAFor<AAAlign>(IRPosition::value(*SI.getPointerOperand())); + getOrCreateAAFor<AAAddressSpace>( + IRPosition::value(*SI.getPointerOperand())); } return true; }; @@ -3461,7 +3629,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, return OS; } -void AbstractAttribute::print(raw_ostream &OS) const { +void AbstractAttribute::print(Attributor *A, raw_ostream &OS) const { OS << "["; OS << getName(); OS << "] for CtxI "; @@ -3473,7 +3641,7 @@ void AbstractAttribute::print(raw_ostream &OS) const { } else OS << "<<null inst>>"; - OS << " at position " << getIRPosition() << " with state " << getAsStr() + OS << " at position " << getIRPosition() << " with state " << getAsStr(A) << '\n'; } @@ -3679,11 +3847,11 @@ template <> struct GraphTraits<AADepGraphNode *> { using EdgeRef = PointerIntPair<AADepGraphNode *, 1>; static NodeRef getEntryNode(AADepGraphNode *DGN) { return DGN; } - static NodeRef DepGetVal(DepTy &DT) { return DT.getPointer(); } + static NodeRef DepGetVal(const DepTy &DT) { return DT.getPointer(); } using ChildIteratorType = - mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>; - using ChildEdgeIteratorType = TinyPtrVector<DepTy>::iterator; + mapped_iterator<AADepGraphNode::DepSetTy::iterator, 
decltype(&DepGetVal)>; + using ChildEdgeIteratorType = AADepGraphNode::DepSetTy::iterator; static ChildIteratorType child_begin(NodeRef N) { return N->child_begin(); } @@ -3695,7 +3863,7 @@ struct GraphTraits<AADepGraph *> : public GraphTraits<AADepGraphNode *> { static NodeRef getEntryNode(AADepGraph *DG) { return DG->GetEntryNode(); } using nodes_iterator = - mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>; + mapped_iterator<AADepGraphNode::DepSetTy::iterator, decltype(&DepGetVal)>; static nodes_iterator nodes_begin(AADepGraph *DG) { return DG->begin(); } @@ -3715,98 +3883,3 @@ template <> struct DOTGraphTraits<AADepGraph *> : public DefaultDOTGraphTraits { }; } // end namespace llvm - -namespace { - -struct AttributorLegacyPass : public ModulePass { - static char ID; - - AttributorLegacyPass() : ModulePass(ID) { - initializeAttributorLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - AnalysisGetter AG; - SetVector<Function *> Functions; - for (Function &F : M) - Functions.insert(&F); - - CallGraphUpdater CGUpdater; - BumpPtrAllocator Allocator; - InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr); - return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, - /* DeleteFns*/ true, - /* IsModulePass */ true); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - // FIXME: Think about passes we will preserve and add them here. - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; - -struct AttributorCGSCCLegacyPass : public CallGraphSCCPass { - static char ID; - - AttributorCGSCCLegacyPass() : CallGraphSCCPass(ID) { - initializeAttributorCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnSCC(CallGraphSCC &SCC) override { - if (skipSCC(SCC)) - return false; - - SetVector<Function *> Functions; - for (CallGraphNode *CGN : SCC) - if (Function *Fn = CGN->getFunction()) - if (!Fn->isDeclaration()) - Functions.insert(Fn); - - if (Functions.empty()) - return false; - - AnalysisGetter AG; - CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph()); - CallGraphUpdater CGUpdater; - CGUpdater.initialize(CG, SCC); - Module &M = *Functions.back()->getParent(); - BumpPtrAllocator Allocator; - InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions); - return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, - /* DeleteFns */ false, - /* IsModulePass */ false); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - // FIXME: Think about passes we will preserve and add them here. 
- AU.addRequired<TargetLibraryInfoWrapperPass>(); - CallGraphSCCPass::getAnalysisUsage(AU); - } -}; - -} // end anonymous namespace - -Pass *llvm::createAttributorLegacyPass() { return new AttributorLegacyPass(); } -Pass *llvm::createAttributorCGSCCLegacyPass() { - return new AttributorCGSCCLegacyPass(); -} - -char AttributorLegacyPass::ID = 0; -char AttributorCGSCCLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(AttributorLegacyPass, "attributor", - "Deduce and propagate attributes", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(AttributorLegacyPass, "attributor", - "Deduce and propagate attributes", false, false) -INITIALIZE_PASS_BEGIN(AttributorCGSCCLegacyPass, "attributor-cgscc", - "Deduce and propagate attributes (CGSCC pass)", false, - false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(AttributorCGSCCLegacyPass, "attributor-cgscc", - "Deduce and propagate attributes (CGSCC pass)", false, - false) diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 001ef55ba472..3a9a89d61355 100644 --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AssumptionCache.h" @@ -38,6 +39,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Assumptions.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -52,6 +54,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/NoFolder.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" @@ -156,10 +159,11 @@ PIPE_OPERATOR(AAIsDead) PIPE_OPERATOR(AANoUnwind) PIPE_OPERATOR(AANoSync) PIPE_OPERATOR(AANoRecurse) +PIPE_OPERATOR(AANonConvergent) PIPE_OPERATOR(AAWillReturn) PIPE_OPERATOR(AANoReturn) -PIPE_OPERATOR(AAReturnedValues) PIPE_OPERATOR(AANonNull) +PIPE_OPERATOR(AAMustProgress) PIPE_OPERATOR(AANoAlias) PIPE_OPERATOR(AADereferenceable) PIPE_OPERATOR(AAAlign) @@ -177,11 +181,13 @@ PIPE_OPERATOR(AAUndefinedBehavior) PIPE_OPERATOR(AAPotentialConstantValues) PIPE_OPERATOR(AAPotentialValues) PIPE_OPERATOR(AANoUndef) +PIPE_OPERATOR(AANoFPClass) PIPE_OPERATOR(AACallEdges) PIPE_OPERATOR(AAInterFnReachability) PIPE_OPERATOR(AAPointerInfo) PIPE_OPERATOR(AAAssumptionInfo) PIPE_OPERATOR(AAUnderlyingObjects) +PIPE_OPERATOR(AAAddressSpace) #undef PIPE_OPERATOR @@ -196,6 +202,19 @@ ChangeStatus clampStateAndIndicateChange<DerefState>(DerefState &S, } // namespace llvm +static bool mayBeInCycle(const CycleInfo *CI, const Instruction *I, + bool HeaderOnly, Cycle **CPtr = nullptr) { + if (!CI) + return true; + auto *BB = I->getParent(); + auto *C = CI->getCycle(BB); + if (!C) + return false; + if (CPtr) + *CPtr = C; + return !HeaderOnly || BB == C->getHeader(); +} + /// Checks if a type could have padding bytes. static bool isDenselyPacked(Type *Ty, const DataLayout &DL) { // There is no size information, so be conservative. 
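The new mayBeInCycle helper introduced just above generalizes the per-PHI cycle-header check that a later hunk removes from AAPointerInfoFloating::updateImpl. A hedged sketch of how it can be queried (`CI` and `I` are assumed to be a CycleInfo pointer and an Instruction already in scope):

  // Sketch only: conservative cycle query via the helper added above.
  Cycle *C = nullptr;
  if (mayBeInCycle(CI, &I, /* HeaderOnly */ false, &C)) {
    // I may execute repeatedly; when CI was available, C is the innermost
    // cycle containing I's block (a null CycleInfo is treated conservatively).
  }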
@@ -317,12 +336,14 @@ stripAndAccumulateOffsets(Attributor &A, const AbstractAttribute &QueryingAA, auto AttributorAnalysis = [&](Value &V, APInt &ROffset) -> bool { const IRPosition &Pos = IRPosition::value(V); // Only track dependence if we are going to use the assumed info. - const AAValueConstantRange &ValueConstantRangeAA = + const AAValueConstantRange *ValueConstantRangeAA = A.getAAFor<AAValueConstantRange>(QueryingAA, Pos, UseAssumed ? DepClassTy::OPTIONAL : DepClassTy::NONE); - ConstantRange Range = UseAssumed ? ValueConstantRangeAA.getAssumed() - : ValueConstantRangeAA.getKnown(); + if (!ValueConstantRangeAA) + return false; + ConstantRange Range = UseAssumed ? ValueConstantRangeAA->getAssumed() + : ValueConstantRangeAA->getKnown(); if (Range.isFullSet()) return false; @@ -355,7 +376,9 @@ getMinimalBaseOfPointer(Attributor &A, const AbstractAttribute &QueryingAA, /// Clamp the information known for all returned values of a function /// (identified by \p QueryingAA) into \p S. -template <typename AAType, typename StateType = typename AAType::StateType> +template <typename AAType, typename StateType = typename AAType::StateType, + Attribute::AttrKind IRAttributeKind = Attribute::None, + bool RecurseForSelectAndPHI = true> static void clampReturnedValueStates( Attributor &A, const AAType &QueryingAA, StateType &S, const IRPosition::CallBaseContext *CBContext = nullptr) { @@ -376,11 +399,20 @@ static void clampReturnedValueStates( // Callback for each possibly returned value. auto CheckReturnValue = [&](Value &RV) -> bool { const IRPosition &RVPos = IRPosition::value(RV, CBContext); - const AAType &AA = + // If possible, use the hasAssumedIRAttr interface. + if (IRAttributeKind != Attribute::None) { + bool IsKnown; + return AA::hasAssumedIRAttr<IRAttributeKind>( + A, &QueryingAA, RVPos, DepClassTy::REQUIRED, IsKnown); + } + + const AAType *AA = A.getAAFor<AAType>(QueryingAA, RVPos, DepClassTy::REQUIRED); - LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr() - << " @ " << RVPos << "\n"); - const StateType &AAS = AA.getState(); + if (!AA) + return false; + LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV + << " AA: " << AA->getAsStr(&A) << " @ " << RVPos << "\n"); + const StateType &AAS = AA->getState(); if (!T) T = StateType::getBestState(AAS); *T &= AAS; @@ -389,7 +421,9 @@ static void clampReturnedValueStates( return T->isValidState(); }; - if (!A.checkForAllReturnedValues(CheckReturnValue, QueryingAA)) + if (!A.checkForAllReturnedValues(CheckReturnValue, QueryingAA, + AA::ValueScope::Intraprocedural, + RecurseForSelectAndPHI)) S.indicatePessimisticFixpoint(); else if (T) S ^= *T; @@ -399,7 +433,9 @@ namespace { /// Helper class for generic deduction: return value -> returned position. template <typename AAType, typename BaseType, typename StateType = typename BaseType::StateType, - bool PropagateCallBaseContext = false> + bool PropagateCallBaseContext = false, + Attribute::AttrKind IRAttributeKind = Attribute::None, + bool RecurseForSelectAndPHI = true> struct AAReturnedFromReturnedValues : public BaseType { AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -407,7 +443,7 @@ struct AAReturnedFromReturnedValues : public BaseType { /// See AbstractAttribute::updateImpl(...). 
ChangeStatus updateImpl(Attributor &A) override { StateType S(StateType::getBestState(this->getState())); - clampReturnedValueStates<AAType, StateType>( + clampReturnedValueStates<AAType, StateType, IRAttributeKind, RecurseForSelectAndPHI>( A, *this, S, PropagateCallBaseContext ? this->getCallBaseContext() : nullptr); // TODO: If we know we visited all returned values, thus no are assumed @@ -418,7 +454,8 @@ struct AAReturnedFromReturnedValues : public BaseType { /// Clamp the information known at all call sites for a given argument /// (identified by \p QueryingAA) into \p S. -template <typename AAType, typename StateType = typename AAType::StateType> +template <typename AAType, typename StateType = typename AAType::StateType, + Attribute::AttrKind IRAttributeKind = Attribute::None> static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, StateType &S) { LLVM_DEBUG(dbgs() << "[Attributor] Clamp call site argument states for " @@ -442,11 +479,21 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID) return false; - const AAType &AA = + // If possible, use the hasAssumedIRAttr interface. + if (IRAttributeKind != Attribute::None) { + bool IsKnown; + return AA::hasAssumedIRAttr<IRAttributeKind>( + A, &QueryingAA, ACSArgPos, DepClassTy::REQUIRED, IsKnown); + } + + const AAType *AA = A.getAAFor<AAType>(QueryingAA, ACSArgPos, DepClassTy::REQUIRED); + if (!AA) + return false; LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction() - << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n"); - const StateType &AAS = AA.getState(); + << " AA: " << AA->getAsStr(&A) << " @" << ACSArgPos + << "\n"); + const StateType &AAS = AA->getState(); if (!T) T = StateType::getBestState(AAS); *T &= AAS; @@ -466,7 +513,8 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, /// This function is the bridge between argument position and the call base /// context. template <typename AAType, typename BaseType, - typename StateType = typename AAType::StateType> + typename StateType = typename AAType::StateType, + Attribute::AttrKind IRAttributeKind = Attribute::None> bool getArgumentStateFromCallBaseContext(Attributor &A, BaseType &QueryingAttribute, IRPosition &Pos, StateType &State) { @@ -478,12 +526,21 @@ bool getArgumentStateFromCallBaseContext(Attributor &A, int ArgNo = Pos.getCallSiteArgNo(); assert(ArgNo >= 0 && "Invalid Arg No!"); + const IRPosition CBArgPos = IRPosition::callsite_argument(*CBContext, ArgNo); + + // If possible, use the hasAssumedIRAttr interface. + if (IRAttributeKind != Attribute::None) { + bool IsKnown; + return AA::hasAssumedIRAttr<IRAttributeKind>( + A, &QueryingAttribute, CBArgPos, DepClassTy::REQUIRED, IsKnown); + } - const auto &AA = A.getAAFor<AAType>( - QueryingAttribute, IRPosition::callsite_argument(*CBContext, ArgNo), - DepClassTy::REQUIRED); + const auto *AA = + A.getAAFor<AAType>(QueryingAttribute, CBArgPos, DepClassTy::REQUIRED); + if (!AA) + return false; const StateType &CBArgumentState = - static_cast<const StateType &>(AA.getState()); + static_cast<const StateType &>(AA->getState()); LLVM_DEBUG(dbgs() << "[Attributor] Briding Call site context to argument" << "Position:" << Pos << "CB Arg state:" << CBArgumentState @@ -497,7 +554,8 @@ bool getArgumentStateFromCallBaseContext(Attributor &A, /// Helper class for generic deduction: call site argument -> argument position. 
template <typename AAType, typename BaseType, typename StateType = typename AAType::StateType, - bool BridgeCallBaseContext = false> + bool BridgeCallBaseContext = false, + Attribute::AttrKind IRAttributeKind = Attribute::None> struct AAArgumentFromCallSiteArguments : public BaseType { AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -508,12 +566,14 @@ struct AAArgumentFromCallSiteArguments : public BaseType { if (BridgeCallBaseContext) { bool Success = - getArgumentStateFromCallBaseContext<AAType, BaseType, StateType>( + getArgumentStateFromCallBaseContext<AAType, BaseType, StateType, + IRAttributeKind>( A, *this, this->getIRPosition(), S); if (Success) return clampStateAndIndicateChange<StateType>(this->getState(), S); } - clampCallSiteArgumentStates<AAType, StateType>(A, *this, S); + clampCallSiteArgumentStates<AAType, StateType, IRAttributeKind>(A, *this, + S); // TODO: If we know we visited all incoming values, thus no are assumed // dead, we can take the known information from the state T. @@ -524,7 +584,8 @@ struct AAArgumentFromCallSiteArguments : public BaseType { /// Helper class for generic replication: function returned -> cs returned. template <typename AAType, typename BaseType, typename StateType = typename BaseType::StateType, - bool IntroduceCallBaseContext = false> + bool IntroduceCallBaseContext = false, + Attribute::AttrKind IRAttributeKind = Attribute::None> struct AACallSiteReturnedFromReturned : public BaseType { AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -549,8 +610,20 @@ struct AACallSiteReturnedFromReturned : public BaseType { IRPosition FnPos = IRPosition::returned( *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr); - const AAType &AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(S, AA.getState()); + + // If possible, use the hasAssumedIRAttr interface. + if (IRAttributeKind != Attribute::None) { + bool IsKnown; + if (!AA::hasAssumedIRAttr<IRAttributeKind>(A, this, FnPos, + DepClassTy::REQUIRED, IsKnown)) + return S.indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; + } + + const AAType *AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED); + if (!AA) + return S.indicatePessimisticFixpoint(); + return clampStateAndIndicateChange(S, AA->getState()); } }; @@ -585,16 +658,17 @@ static void followUsesInContext(AAType &AA, Attributor &A, template <class AAType, typename StateType = typename AAType::StateType> static void followUsesInMBEC(AAType &AA, Attributor &A, StateType &S, Instruction &CtxI) { + MustBeExecutedContextExplorer *Explorer = + A.getInfoCache().getMustBeExecutedContextExplorer(); + if (!Explorer) + return; // Container for (transitive) uses of the associated value. 
SetVector<const Use *> Uses; for (const Use &U : AA.getIRPosition().getAssociatedValue().uses()) Uses.insert(&U); - MustBeExecutedContextExplorer &Explorer = - A.getInfoCache().getMustBeExecutedContextExplorer(); - - followUsesInContext<AAType>(AA, A, Explorer, &CtxI, Uses, S); + followUsesInContext<AAType>(AA, A, *Explorer, &CtxI, Uses, S); if (S.isAtFixpoint()) return; @@ -639,7 +713,7 @@ static void followUsesInMBEC(AAType &AA, Attributor &A, StateType &S, // } // } - Explorer.checkForAllContext(&CtxI, Pred); + Explorer->checkForAllContext(&CtxI, Pred); for (const BranchInst *Br : BrInsts) { StateType ParentState; @@ -651,7 +725,7 @@ static void followUsesInMBEC(AAType &AA, Attributor &A, StateType &S, StateType ChildState; size_t BeforeSize = Uses.size(); - followUsesInContext(AA, A, Explorer, &BB->front(), Uses, ChildState); + followUsesInContext(AA, A, *Explorer, &BB->front(), Uses, ChildState); // Erase uses which only appear in the child. for (auto It = Uses.begin() + BeforeSize; It != Uses.end();) @@ -855,7 +929,7 @@ protected: for (unsigned Index : LocalList->getSecond()) { for (auto &R : AccessList[Index]) { Range &= R; - if (Range.offsetOrSizeAreUnknown()) + if (Range.offsetAndSizeAreUnknown()) break; } } @@ -887,10 +961,8 @@ ChangeStatus AA::PointerInfo::State::addAccess( } auto AddToBins = [&](const AAPointerInfo::RangeList &ToAdd) { - LLVM_DEBUG( - if (ToAdd.size()) - dbgs() << "[AAPointerInfo] Inserting access in new offset bins\n"; - ); + LLVM_DEBUG(if (ToAdd.size()) dbgs() + << "[AAPointerInfo] Inserting access in new offset bins\n";); for (auto Key : ToAdd) { LLVM_DEBUG(dbgs() << " key " << Key << "\n"); @@ -923,10 +995,8 @@ ChangeStatus AA::PointerInfo::State::addAccess( // from the offset bins. AAPointerInfo::RangeList ToRemove; AAPointerInfo::RangeList::set_difference(ExistingRanges, NewRanges, ToRemove); - LLVM_DEBUG( - if (ToRemove.size()) - dbgs() << "[AAPointerInfo] Removing access from old offset bins\n"; - ); + LLVM_DEBUG(if (ToRemove.size()) dbgs() + << "[AAPointerInfo] Removing access from old offset bins\n";); for (auto Key : ToRemove) { LLVM_DEBUG(dbgs() << " key " << Key << "\n"); @@ -1011,7 +1081,7 @@ struct AAPointerInfoImpl AAPointerInfoImpl(const IRPosition &IRP, Attributor &A) : BaseTy(IRP) {} /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return std::string("PointerInfo ") + (isValidState() ? 
(std::string("#") + std::to_string(OffsetBins.size()) + " bins") @@ -1032,6 +1102,7 @@ struct AAPointerInfoImpl bool forallInterferingAccesses( Attributor &A, const AbstractAttribute &QueryingAA, Instruction &I, + bool FindInterferingWrites, bool FindInterferingReads, function_ref<bool(const Access &, bool)> UserCB, bool &HasBeenWrittenTo, AA::RangeTy &Range) const override { HasBeenWrittenTo = false; @@ -1040,15 +1111,27 @@ struct AAPointerInfoImpl SmallVector<std::pair<const Access *, bool>, 8> InterferingAccesses; Function &Scope = *I.getFunction(); - const auto &NoSyncAA = A.getAAFor<AANoSync>( - QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL); + bool IsKnownNoSync; + bool IsAssumedNoSync = AA::hasAssumedIRAttr<Attribute::NoSync>( + A, &QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL, + IsKnownNoSync); const auto *ExecDomainAA = A.lookupAAFor<AAExecutionDomain>( - IRPosition::function(Scope), &QueryingAA, DepClassTy::OPTIONAL); - bool AllInSameNoSyncFn = NoSyncAA.isAssumedNoSync(); + IRPosition::function(Scope), &QueryingAA, DepClassTy::NONE); + bool AllInSameNoSyncFn = IsAssumedNoSync; bool InstIsExecutedByInitialThreadOnly = ExecDomainAA && ExecDomainAA->isExecutedByInitialThreadOnly(I); + + // If the function is not ending in aligned barriers, we need the stores to + // be in aligned barriers. The load being in one is not sufficient since the + // store might be executed by a thread that disappears after, causing the + // aligned barrier guarding the load to unblock and the load to read a value + // that has no CFG path to the load. bool InstIsExecutedInAlignedRegion = - ExecDomainAA && ExecDomainAA->isExecutedInAlignedRegion(A, I); + FindInterferingReads && ExecDomainAA && + ExecDomainAA->isExecutedInAlignedRegion(A, I); + + if (InstIsExecutedInAlignedRegion || InstIsExecutedByInitialThreadOnly) + A.recordDependence(*ExecDomainAA, QueryingAA, DepClassTy::OPTIONAL); InformationCache &InfoCache = A.getInfoCache(); bool IsThreadLocalObj = @@ -1063,14 +1146,25 @@ struct AAPointerInfoImpl auto CanIgnoreThreadingForInst = [&](const Instruction &I) -> bool { if (IsThreadLocalObj || AllInSameNoSyncFn) return true; - if (!ExecDomainAA) + const auto *FnExecDomainAA = + I.getFunction() == &Scope + ? ExecDomainAA + : A.lookupAAFor<AAExecutionDomain>( + IRPosition::function(*I.getFunction()), &QueryingAA, + DepClassTy::NONE); + if (!FnExecDomainAA) return false; if (InstIsExecutedInAlignedRegion || - ExecDomainAA->isExecutedInAlignedRegion(A, I)) + (FindInterferingWrites && + FnExecDomainAA->isExecutedInAlignedRegion(A, I))) { + A.recordDependence(*FnExecDomainAA, QueryingAA, DepClassTy::OPTIONAL); return true; + } if (InstIsExecutedByInitialThreadOnly && - ExecDomainAA->isExecutedByInitialThreadOnly(I)) + FnExecDomainAA->isExecutedByInitialThreadOnly(I)) { + A.recordDependence(*FnExecDomainAA, QueryingAA, DepClassTy::OPTIONAL); return true; + } return false; }; @@ -1084,13 +1178,13 @@ struct AAPointerInfoImpl }; // TODO: Use inter-procedural reachability and dominance. 
- const auto &NoRecurseAA = A.getAAFor<AANoRecurse>( - QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL); + bool IsKnownNoRecurse; + AA::hasAssumedIRAttr<Attribute::NoRecurse>( + A, this, IRPosition::function(Scope), DepClassTy::OPTIONAL, + IsKnownNoRecurse); - const bool FindInterferingWrites = I.mayReadFromMemory(); - const bool FindInterferingReads = I.mayWriteToMemory(); const bool UseDominanceReasoning = - FindInterferingWrites && NoRecurseAA.isKnownNoRecurse(); + FindInterferingWrites && IsKnownNoRecurse; const DominatorTree *DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(Scope); @@ -1098,8 +1192,7 @@ struct AAPointerInfoImpl // outlive a GPU kernel. This is true for shared, constant, and local // globals on AMD and NVIDIA GPUs. auto HasKernelLifetime = [&](Value *V, Module &M) { - Triple T(M.getTargetTriple()); - if (!(T.isAMDGPU() || T.isNVPTX())) + if (!AA::isGPU(M)) return false; switch (AA::GPUAddressSpace(V->getType()->getPointerAddressSpace())) { case AA::GPUAddressSpace::Shared: @@ -1122,9 +1215,10 @@ struct AAPointerInfoImpl // If the alloca containing function is not recursive the alloca // must be dead in the callee. const Function *AIFn = AI->getFunction(); - const auto &NoRecurseAA = A.getAAFor<AANoRecurse>( - *this, IRPosition::function(*AIFn), DepClassTy::OPTIONAL); - if (NoRecurseAA.isAssumedNoRecurse()) { + bool IsKnownNoRecurse; + if (AA::hasAssumedIRAttr<Attribute::NoRecurse>( + A, this, IRPosition::function(*AIFn), DepClassTy::OPTIONAL, + IsKnownNoRecurse)) { IsLiveInCalleeCB = [AIFn](const Function &Fn) { return AIFn != &Fn; }; } } else if (auto *GV = dyn_cast<GlobalValue>(&getAssociatedValue())) { @@ -1220,7 +1314,7 @@ struct AAPointerInfoImpl if (!WriteChecked && HasBeenWrittenTo && Acc.getRemoteInst()->getFunction() != &Scope) { - const auto &FnReachabilityAA = A.getAAFor<AAInterFnReachability>( + const auto *FnReachabilityAA = A.getAAFor<AAInterFnReachability>( QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL); // Without going backwards in the call tree, can we reach the access @@ -1228,7 +1322,8 @@ struct AAPointerInfoImpl // itself either. bool Inserted = ExclusionSet.insert(&I).second; - if (!FnReachabilityAA.instructionCanReach( + if (!FnReachabilityAA || + !FnReachabilityAA->instructionCanReach( A, *LeastDominatingWriteInst, *Acc.getRemoteInst()->getFunction(), &ExclusionSet)) WriteChecked = true; @@ -1337,7 +1432,10 @@ struct AAPointerInfoImpl O << " --> " << *Acc.getRemoteInst() << "\n"; if (!Acc.isWrittenValueYetUndetermined()) { - if (Acc.getWrittenValue()) + if (isa_and_nonnull<Function>(Acc.getWrittenValue())) + O << " - c: func " << Acc.getWrittenValue()->getName() + << "\n"; + else if (Acc.getWrittenValue()) O << " - c: " << *Acc.getWrittenValue() << "\n"; else O << " - c: <unknown>\n"; @@ -1450,22 +1548,22 @@ bool AAPointerInfoFloating::collectConstantsForGEP(Attributor &A, // combination of elements, picked one each from these sets, is separately // added to the original set of offsets, thus resulting in more offsets. for (const auto &VI : VariableOffsets) { - auto &PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>( + auto *PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>( *this, IRPosition::value(*VI.first), DepClassTy::OPTIONAL); - if (!PotentialConstantsAA.isValidState()) { + if (!PotentialConstantsAA || !PotentialConstantsAA->isValidState()) { UsrOI.setUnknown(); return true; } // UndefValue is treated as a zero, which leaves Union as is. 
- if (PotentialConstantsAA.undefIsContained()) + if (PotentialConstantsAA->undefIsContained()) continue; // We need at least one constant in every set to compute an actual offset. // Otherwise, we end up pessimizing AAPointerInfo by respecting offsets that // don't actually exist. In other words, the absence of constant values // implies that the operation can be assumed dead for now. - auto &AssumedSet = PotentialConstantsAA.getAssumedSet(); + auto &AssumedSet = PotentialConstantsAA->getAssumedSet(); if (AssumedSet.empty()) return false; @@ -1602,16 +1700,6 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { return true; } - auto mayBeInCycleHeader = [](const CycleInfo *CI, const Instruction *I) { - if (!CI) - return true; - auto *BB = I->getParent(); - auto *C = CI->getCycle(BB); - if (!C) - return false; - return BB == C->getHeader(); - }; - // Check if the PHI operand is not dependent on the PHI itself. Every // recurrence is a cyclic net of PHIs in the data flow, and has an // equivalent Cycle in the control flow. One of those PHIs must be in the @@ -1619,7 +1707,7 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { // Cycles reported by CycleInfo. It is sufficient to check the PHIs in // every Cycle header; if such a node is marked unknown, this will // eventually propagate through the whole net of PHIs in the recurrence. - if (mayBeInCycleHeader(CI, cast<Instruction>(Usr))) { + if (mayBeInCycle(CI, cast<Instruction>(Usr), /* HeaderOnly */ true)) { auto BaseOI = It->getSecond(); BaseOI.addToAll(Offset.getZExtValue()); if (IsFirstPHIUser || BaseOI == UsrOI) { @@ -1681,6 +1769,8 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { return false; } else { auto PredIt = pred_begin(IntrBB); + if (PredIt == pred_end(IntrBB)) + return false; if ((*PredIt) != BB) return false; if (++PredIt != pred_end(IntrBB)) @@ -1780,11 +1870,14 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { return true; if (CB->isArgOperand(&U)) { unsigned ArgNo = CB->getArgOperandNo(&U); - const auto &CSArgPI = A.getAAFor<AAPointerInfo>( + const auto *CSArgPI = A.getAAFor<AAPointerInfo>( *this, IRPosition::callsite_argument(*CB, ArgNo), DepClassTy::REQUIRED); - Changed = translateAndAddState(A, CSArgPI, OffsetInfoMap[CurPtr], *CB) | - Changed; + if (!CSArgPI) + return false; + Changed = + translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB) | + Changed; return isValidState(); } LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB @@ -1845,13 +1938,6 @@ struct AAPointerInfoArgument final : AAPointerInfoFloating { AAPointerInfoArgument(const IRPosition &IRP, Attributor &A) : AAPointerInfoFloating(IRP, A) {} - /// See AbstractAttribute::initialize(...). 
- void initialize(Attributor &A) override { - AAPointerInfoFloating::initialize(A); - if (getAnchorScope()->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { AAPointerInfoImpl::trackPointerInfoStatistics(getIRPosition()); @@ -1900,19 +1986,18 @@ struct AAPointerInfoCallSiteArgument final : AAPointerInfoFloating { Argument *Arg = getAssociatedArgument(); if (Arg) { const IRPosition &ArgPos = IRPosition::argument(*Arg); - auto &ArgAA = + auto *ArgAA = A.getAAFor<AAPointerInfo>(*this, ArgPos, DepClassTy::REQUIRED); - if (ArgAA.getState().isValidState()) - return translateAndAddStateFromCallee(A, ArgAA, + if (ArgAA && ArgAA->getState().isValidState()) + return translateAndAddStateFromCallee(A, *ArgAA, *cast<CallBase>(getCtxI())); if (!Arg->getParent()->isDeclaration()) return indicatePessimisticFixpoint(); } - const auto &NoCaptureAA = - A.getAAFor<AANoCapture>(*this, getIRPosition(), DepClassTy::OPTIONAL); - - if (!NoCaptureAA.isAssumedNoCapture()) + bool IsKnownNoCapture; + if (!AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnownNoCapture)) return indicatePessimisticFixpoint(); bool IsKnown = false; @@ -1948,7 +2033,15 @@ namespace { struct AANoUnwindImpl : AANoUnwind { AANoUnwindImpl(const IRPosition &IRP, Attributor &A) : AANoUnwind(IRP, A) {} - const std::string getAsStr() const override { + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::NoUnwind>( + A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown)); + (void)IsKnown; + } + + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "nounwind" : "may-unwind"; } @@ -1960,13 +2053,14 @@ struct AANoUnwindImpl : AANoUnwind { (unsigned)Instruction::CatchSwitch, (unsigned)Instruction::Resume}; auto CheckForNoUnwind = [&](Instruction &I) { - if (!I.mayThrow()) + if (!I.mayThrow(/* IncludePhaseOneUnwind */ true)) return true; if (const auto *CB = dyn_cast<CallBase>(&I)) { - const auto &NoUnwindAA = A.getAAFor<AANoUnwind>( - *this, IRPosition::callsite_function(*CB), DepClassTy::REQUIRED); - return NoUnwindAA.isAssumedNoUnwind(); + bool IsKnownNoUnwind; + return AA::hasAssumedIRAttr<Attribute::NoUnwind>( + A, this, IRPosition::callsite_function(*CB), DepClassTy::REQUIRED, + IsKnownNoUnwind); } return false; }; @@ -1993,14 +2087,6 @@ struct AANoUnwindCallSite final : AANoUnwindImpl { AANoUnwindCallSite(const IRPosition &IRP, Attributor &A) : AANoUnwindImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoUnwindImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -2009,263 +2095,15 @@ struct AANoUnwindCallSite final : AANoUnwindImpl { // redirecting requests to the callee argument. 
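// --- Illustrative sketch (editorial, not part of this patch) ---------------
// The call-site attributes in this patch all delegate to the callee's
// function position through the same helper, as in the body that follows. A
// minimal sketch of that shape for nounwind (names as used in the patch):
//
//   Function *F = getAssociatedFunction();
//   bool IsKnownNoUnwind;
//   if (AA::hasAssumedIRAttr<Attribute::NoUnwind>(
//           A, this, IRPosition::function(*F), DepClassTy::REQUIRED,
//           IsKnownNoUnwind))
//     return ChangeStatus::UNCHANGED;
//   return indicatePessimisticFixpoint();
// ---------------------------------------------------------------------------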
Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AANoUnwind>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); - } - - /// See AbstractAttribute::trackStatistics() - void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); } -}; -} // namespace - -/// --------------------- Function Return Values ------------------------------- - -namespace { -/// "Attribute" that collects all potential returned values and the return -/// instructions that they arise from. -/// -/// If there is a unique returned value R, the manifest method will: -/// - mark R with the "returned" attribute, if R is an argument. -class AAReturnedValuesImpl : public AAReturnedValues, public AbstractState { - - /// Mapping of values potentially returned by the associated function to the - /// return instructions that might return them. - MapVector<Value *, SmallSetVector<ReturnInst *, 4>> ReturnedValues; - - /// State flags - /// - ///{ - bool IsFixed = false; - bool IsValidState = true; - ///} - -public: - AAReturnedValuesImpl(const IRPosition &IRP, Attributor &A) - : AAReturnedValues(IRP, A) {} - - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - // Reset the state. - IsFixed = false; - IsValidState = true; - ReturnedValues.clear(); - - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) { - indicatePessimisticFixpoint(); - return; - } - assert(!F->getReturnType()->isVoidTy() && - "Did not expect a void return type!"); - - // The map from instruction opcodes to those instructions in the function. - auto &OpcodeInstMap = A.getInfoCache().getOpcodeInstMapForFunction(*F); - - // Look through all arguments, if one is marked as returned we are done. - for (Argument &Arg : F->args()) { - if (Arg.hasReturnedAttr()) { - auto &ReturnInstSet = ReturnedValues[&Arg]; - if (auto *Insts = OpcodeInstMap.lookup(Instruction::Ret)) - for (Instruction *RI : *Insts) - ReturnInstSet.insert(cast<ReturnInst>(RI)); - - indicateOptimisticFixpoint(); - return; - } - } - - if (!A.isFunctionIPOAmendable(*F)) - indicatePessimisticFixpoint(); - } - - /// See AbstractAttribute::manifest(...). - ChangeStatus manifest(Attributor &A) override; - - /// See AbstractAttribute::getState(...). - AbstractState &getState() override { return *this; } - - /// See AbstractAttribute::getState(...). - const AbstractState &getState() const override { return *this; } - - /// See AbstractAttribute::updateImpl(Attributor &A). - ChangeStatus updateImpl(Attributor &A) override; - - llvm::iterator_range<iterator> returned_values() override { - return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end()); - } - - llvm::iterator_range<const_iterator> returned_values() const override { - return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end()); - } - - /// Return the number of potential return values, -1 if unknown. - size_t getNumReturnValues() const override { - return isValidState() ? ReturnedValues.size() : -1; - } - - /// Return an assumed unique return value if a single candidate is found. If - /// there cannot be one, return a nullptr. If it is not clear yet, return - /// std::nullopt. - std::optional<Value *> getAssumedUniqueReturnValue(Attributor &A) const; - - /// See AbstractState::checkForAllReturnedValues(...). 
- bool checkForAllReturnedValuesAndReturnInsts( - function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> Pred) - const override; - - /// Pretty print the attribute similar to the IR representation. - const std::string getAsStr() const override; - - /// See AbstractState::isAtFixpoint(). - bool isAtFixpoint() const override { return IsFixed; } - - /// See AbstractState::isValidState(). - bool isValidState() const override { return IsValidState; } - - /// See AbstractState::indicateOptimisticFixpoint(...). - ChangeStatus indicateOptimisticFixpoint() override { - IsFixed = true; - return ChangeStatus::UNCHANGED; - } - - ChangeStatus indicatePessimisticFixpoint() override { - IsFixed = true; - IsValidState = false; - return ChangeStatus::CHANGED; - } -}; - -ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) { - ChangeStatus Changed = ChangeStatus::UNCHANGED; - - // Bookkeeping. - assert(isValidState()); - STATS_DECLTRACK(KnownReturnValues, FunctionReturn, - "Number of function with known return values"); - - // Check if we have an assumed unique return value that we could manifest. - std::optional<Value *> UniqueRV = getAssumedUniqueReturnValue(A); - - if (!UniqueRV || !*UniqueRV) - return Changed; - - // Bookkeeping. - STATS_DECLTRACK(UniqueReturnValue, FunctionReturn, - "Number of function with unique return"); - // If the assumed unique return value is an argument, annotate it. - if (auto *UniqueRVArg = dyn_cast<Argument>(*UniqueRV)) { - if (UniqueRVArg->getType()->canLosslesslyBitCastTo( - getAssociatedFunction()->getReturnType())) { - getIRPosition() = IRPosition::argument(*UniqueRVArg); - Changed = IRAttribute::manifest(A); - } - } - return Changed; -} - -const std::string AAReturnedValuesImpl::getAsStr() const { - return (isAtFixpoint() ? "returns(#" : "may-return(#") + - (isValidState() ? std::to_string(getNumReturnValues()) : "?") + ")"; -} - -std::optional<Value *> -AAReturnedValuesImpl::getAssumedUniqueReturnValue(Attributor &A) const { - // If checkForAllReturnedValues provides a unique value, ignoring potential - // undef values that can also be present, it is assumed to be the actual - // return value and forwarded to the caller of this method. If there are - // multiple, a nullptr is returned indicating there cannot be a unique - // returned value. - std::optional<Value *> UniqueRV; - Type *Ty = getAssociatedFunction()->getReturnType(); - - auto Pred = [&](Value &RV) -> bool { - UniqueRV = AA::combineOptionalValuesInAAValueLatice(UniqueRV, &RV, Ty); - return UniqueRV != std::optional<Value *>(nullptr); - }; - - if (!A.checkForAllReturnedValues(Pred, *this)) - UniqueRV = nullptr; - - return UniqueRV; -} - -bool AAReturnedValuesImpl::checkForAllReturnedValuesAndReturnInsts( - function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> Pred) - const { - if (!isValidState()) - return false; - - // Check all returned values but ignore call sites as long as we have not - // encountered an overdefined one during an update. 
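// --- Illustrative sketch (editorial, not part of this patch) ---------------
// The removed getAssumedUniqueReturnValue above folds return-value candidates
// into a three-way lattice: std::nullopt (nothing seen yet), a concrete value
// (unique so far), or nullptr (provably not unique). A self-contained model
// in plain C++, ignoring the undef special-casing that
// AA::combineOptionalValuesInAAValueLatice performs (names illustrative):
//
//   #include <optional>
//
//   template <typename T>
//   std::optional<T *> foldUnique(std::optional<T *> Acc, T *Next) {
//     if (!Acc)
//       return Next;        // first candidate
//     if (*Acc == Next)
//       return Acc;         // same candidate again, still unique
//     return nullptr;       // two distinct candidates -> no unique value
//   }
// ---------------------------------------------------------------------------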
- for (const auto &It : ReturnedValues) { - Value *RV = It.first; - if (!Pred(*RV, It.second)) - return false; - } - - return true; -} - -ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) { - ChangeStatus Changed = ChangeStatus::UNCHANGED; - - SmallVector<AA::ValueAndContext> Values; - bool UsedAssumedInformation = false; - auto ReturnInstCB = [&](Instruction &I) { - ReturnInst &Ret = cast<ReturnInst>(I); - Values.clear(); - if (!A.getAssumedSimplifiedValues(IRPosition::value(*Ret.getReturnValue()), - *this, Values, AA::Intraprocedural, - UsedAssumedInformation)) - Values.push_back({*Ret.getReturnValue(), Ret}); - - for (auto &VAC : Values) { - assert(AA::isValidInScope(*VAC.getValue(), Ret.getFunction()) && - "Assumed returned value should be valid in function scope!"); - if (ReturnedValues[VAC.getValue()].insert(&Ret)) - Changed = ChangeStatus::CHANGED; - } - return true; - }; - - // Discover returned values from all live returned instructions in the - // associated function. - if (!A.checkForAllInstructions(ReturnInstCB, *this, {Instruction::Ret}, - UsedAssumedInformation)) - return indicatePessimisticFixpoint(); - return Changed; -} - -struct AAReturnedValuesFunction final : public AAReturnedValuesImpl { - AAReturnedValuesFunction(const IRPosition &IRP, Attributor &A) - : AAReturnedValuesImpl(IRP, A) {} - - /// See AbstractAttribute::trackStatistics() - void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(returned) } -}; - -/// Returned values information for a call sites. -struct AAReturnedValuesCallSite final : AAReturnedValuesImpl { - AAReturnedValuesCallSite(const IRPosition &IRP, Attributor &A) - : AAReturnedValuesImpl(IRP, A) {} - - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites instead of - // redirecting requests to the callee. - llvm_unreachable("Abstract attributes for returned values are not " - "supported for call sites yet!"); - } - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { + bool IsKnownNoUnwind; + if (AA::hasAssumedIRAttr<Attribute::NoUnwind>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoUnwind)) + return ChangeStatus::UNCHANGED; return indicatePessimisticFixpoint(); } /// See AbstractAttribute::trackStatistics() - void trackStatistics() const override {} + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); } }; } // namespace @@ -2334,7 +2172,15 @@ namespace { struct AANoSyncImpl : AANoSync { AANoSyncImpl(const IRPosition &IRP, Attributor &A) : AANoSync(IRP, A) {} - const std::string getAsStr() const override { + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::NoSync>(A, nullptr, getIRPosition(), + DepClassTy::NONE, IsKnown)); + (void)IsKnown; + } + + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "nosync" : "may-sync"; } @@ -2381,14 +2227,6 @@ struct AANoSyncCallSite final : AANoSyncImpl { AANoSyncCallSite(const IRPosition &IRP, Attributor &A) : AANoSyncImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). 
- void initialize(Attributor &A) override { - AANoSyncImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -2397,8 +2235,11 @@ struct AANoSyncCallSite final : AANoSyncImpl { // redirecting requests to the callee argument. Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AANoSync>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + bool IsKnownNoSycn; + if (AA::hasAssumedIRAttr<Attribute::NoSync>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoSycn)) + return ChangeStatus::UNCHANGED; + return indicatePessimisticFixpoint(); } /// See AbstractAttribute::trackStatistics() @@ -2412,16 +2253,21 @@ namespace { struct AANoFreeImpl : public AANoFree { AANoFreeImpl(const IRPosition &IRP, Attributor &A) : AANoFree(IRP, A) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::NoFree>(A, nullptr, getIRPosition(), + DepClassTy::NONE, IsKnown)); + (void)IsKnown; + } + /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { auto CheckForNoFree = [&](Instruction &I) { - const auto &CB = cast<CallBase>(I); - if (CB.hasFnAttr(Attribute::NoFree)) - return true; - - const auto &NoFreeAA = A.getAAFor<AANoFree>( - *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); - return NoFreeAA.isAssumedNoFree(); + bool IsKnown; + return AA::hasAssumedIRAttr<Attribute::NoFree>( + A, this, IRPosition::callsite_function(cast<CallBase>(I)), + DepClassTy::REQUIRED, IsKnown); }; bool UsedAssumedInformation = false; @@ -2432,7 +2278,7 @@ struct AANoFreeImpl : public AANoFree { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "nofree" : "may-free"; } }; @@ -2450,14 +2296,6 @@ struct AANoFreeCallSite final : AANoFreeImpl { AANoFreeCallSite(const IRPosition &IRP, Attributor &A) : AANoFreeImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoFreeImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -2466,8 +2304,11 @@ struct AANoFreeCallSite final : AANoFreeImpl { // redirecting requests to the callee argument. 
Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AANoFree>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + bool IsKnown; + if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, FnPos, + DepClassTy::REQUIRED, IsKnown)) + return ChangeStatus::UNCHANGED; + return indicatePessimisticFixpoint(); } /// See AbstractAttribute::trackStatistics() @@ -2486,9 +2327,10 @@ struct AANoFreeFloating : AANoFreeImpl { ChangeStatus updateImpl(Attributor &A) override { const IRPosition &IRP = getIRPosition(); - const auto &NoFreeAA = A.getAAFor<AANoFree>( - *this, IRPosition::function_scope(IRP), DepClassTy::OPTIONAL); - if (NoFreeAA.isAssumedNoFree()) + bool IsKnown; + if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, + IRPosition::function_scope(IRP), + DepClassTy::OPTIONAL, IsKnown)) return ChangeStatus::UNCHANGED; Value &AssociatedValue = getIRPosition().getAssociatedValue(); @@ -2501,10 +2343,10 @@ struct AANoFreeFloating : AANoFreeImpl { return true; unsigned ArgNo = CB->getArgOperandNo(&U); - const auto &NoFreeArg = A.getAAFor<AANoFree>( - *this, IRPosition::callsite_argument(*CB, ArgNo), - DepClassTy::REQUIRED); - return NoFreeArg.isAssumedNoFree(); + bool IsKnown; + return AA::hasAssumedIRAttr<Attribute::NoFree>( + A, this, IRPosition::callsite_argument(*CB, ArgNo), + DepClassTy::REQUIRED, IsKnown); } if (isa<GetElementPtrInst>(UserI) || isa<BitCastInst>(UserI) || @@ -2550,8 +2392,11 @@ struct AANoFreeCallSiteArgument final : AANoFreeFloating { if (!Arg) return indicatePessimisticFixpoint(); const IRPosition &ArgPos = IRPosition::argument(*Arg); - auto &ArgAA = A.getAAFor<AANoFree>(*this, ArgPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), ArgAA.getState()); + bool IsKnown; + if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, ArgPos, + DepClassTy::REQUIRED, IsKnown)) + return ChangeStatus::UNCHANGED; + return indicatePessimisticFixpoint(); } /// See AbstractAttribute::trackStatistics() @@ -2593,6 +2438,39 @@ struct AANoFreeCallSiteReturned final : AANoFreeFloating { } // namespace /// ------------------------ NonNull Argument Attribute ------------------------ + +bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP, + Attribute::AttrKind ImpliedAttributeKind, + bool IgnoreSubsumingPositions) { + SmallVector<Attribute::AttrKind, 2> AttrKinds; + AttrKinds.push_back(Attribute::NonNull); + if (!NullPointerIsDefined(IRP.getAnchorScope(), + IRP.getAssociatedType()->getPointerAddressSpace())) + AttrKinds.push_back(Attribute::Dereferenceable); + if (A.hasAttr(IRP, AttrKinds, IgnoreSubsumingPositions, Attribute::NonNull)) + return true; + + if (IRP.getPositionKind() == IRP_RETURNED) + return false; + + DominatorTree *DT = nullptr; + AssumptionCache *AC = nullptr; + InformationCache &InfoCache = A.getInfoCache(); + if (const Function *Fn = IRP.getAnchorScope()) { + if (!Fn->isDeclaration()) { + DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*Fn); + AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*Fn); + } + } + + if (!isKnownNonZero(&IRP.getAssociatedValue(), A.getDataLayout(), 0, AC, + IRP.getCtxI(), DT)) + return false; + A.manifestAttrs(IRP, {Attribute::get(IRP.getAnchorValue().getContext(), + Attribute::NonNull)}); + return true; +} + namespace { static int64_t getKnownNonNullAndDerefBytesForUse( Attributor &A, const AbstractAttribute &QueryingAA, Value &AssociatedValue, @@ -2641,10 +2519,13 @@ static 
int64_t getKnownNonNullAndDerefBytesForUse( IRPosition IRP = IRPosition::callsite_argument(*CB, ArgNo); // As long as we only use known information there is no need to track // dependences here. - auto &DerefAA = + bool IsKnownNonNull; + AA::hasAssumedIRAttr<Attribute::NonNull>(A, &QueryingAA, IRP, + DepClassTy::NONE, IsKnownNonNull); + IsNonNull |= IsKnownNonNull; + auto *DerefAA = A.getAAFor<AADereferenceable>(QueryingAA, IRP, DepClassTy::NONE); - IsNonNull |= DerefAA.isKnownNonNull(); - return DerefAA.getKnownDereferenceableBytes(); + return DerefAA ? DerefAA->getKnownDereferenceableBytes() : 0; } std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I); @@ -2673,43 +2554,16 @@ static int64_t getKnownNonNullAndDerefBytesForUse( } struct AANonNullImpl : AANonNull { - AANonNullImpl(const IRPosition &IRP, Attributor &A) - : AANonNull(IRP, A), - NullIsDefined(NullPointerIsDefined( - getAnchorScope(), - getAssociatedValue().getType()->getPointerAddressSpace())) {} + AANonNullImpl(const IRPosition &IRP, Attributor &A) : AANonNull(IRP, A) {} /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { Value &V = *getAssociatedValue().stripPointerCasts(); - if (!NullIsDefined && - hasAttr({Attribute::NonNull, Attribute::Dereferenceable}, - /* IgnoreSubsumingPositions */ false, &A)) { - indicateOptimisticFixpoint(); - return; - } - if (isa<ConstantPointerNull>(V)) { indicatePessimisticFixpoint(); return; } - AANonNull::initialize(A); - - bool CanBeNull, CanBeFreed; - if (V.getPointerDereferenceableBytes(A.getDataLayout(), CanBeNull, - CanBeFreed)) { - if (!CanBeNull) { - indicateOptimisticFixpoint(); - return; - } - } - - if (isa<GlobalValue>(V)) { - indicatePessimisticFixpoint(); - return; - } - if (Instruction *CtxI = getCtxI()) followUsesInMBEC(*this, A, getState(), *CtxI); } @@ -2726,13 +2580,9 @@ struct AANonNullImpl : AANonNull { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "nonnull" : "may-null"; } - - /// Flag to determine if the underlying value can be null and still allow - /// valid accesses. - const bool NullIsDefined; }; /// NonNull attribute for a floating value. @@ -2742,48 +2592,39 @@ struct AANonNullFloating : public AANonNullImpl { /// See AbstractAttribute::updateImpl(...). 
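// --- Illustrative sketch (editorial, not part of this patch) ---------------
// AANonNull::isImpliedByIR above falls back to ValueTracking: if the value is
// already provably non-zero at its context instruction, the nonnull attribute
// is materialized directly and no abstract attribute needs to be created. A
// minimal sketch of that step, assuming DT/AC were obtained from the anchor
// function as in the hunk above:
//
//   if (isKnownNonZero(&IRP.getAssociatedValue(), A.getDataLayout(), 0, AC,
//                      IRP.getCtxI(), DT)) {
//     A.manifestAttrs(IRP,
//                     {Attribute::get(IRP.getAnchorValue().getContext(),
//                                     Attribute::NonNull)});
//     // The attribute now lives in the IR; later queries see it as "known".
//   }
// ---------------------------------------------------------------------------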
ChangeStatus updateImpl(Attributor &A) override { - const DataLayout &DL = A.getDataLayout(); + auto CheckIRP = [&](const IRPosition &IRP) { + bool IsKnownNonNull; + return AA::hasAssumedIRAttr<Attribute::NonNull>( + A, *this, IRP, DepClassTy::OPTIONAL, IsKnownNonNull); + }; bool Stripped; bool UsedAssumedInformation = false; + Value *AssociatedValue = &getAssociatedValue(); SmallVector<AA::ValueAndContext> Values; if (!A.getAssumedSimplifiedValues(getIRPosition(), *this, Values, - AA::AnyScope, UsedAssumedInformation)) { - Values.push_back({getAssociatedValue(), getCtxI()}); + AA::AnyScope, UsedAssumedInformation)) Stripped = false; - } else { - Stripped = Values.size() != 1 || - Values.front().getValue() != &getAssociatedValue(); - } - - DominatorTree *DT = nullptr; - AssumptionCache *AC = nullptr; - InformationCache &InfoCache = A.getInfoCache(); - if (const Function *Fn = getAnchorScope()) { - DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*Fn); - AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*Fn); + else + Stripped = + Values.size() != 1 || Values.front().getValue() != AssociatedValue; + + if (!Stripped) { + // If we haven't stripped anything we might still be able to use a + // different AA, but only if the IRP changes. Effectively when we + // interpret this not as a call site value but as a floating/argument + // value. + const IRPosition AVIRP = IRPosition::value(*AssociatedValue); + if (AVIRP == getIRPosition() || !CheckIRP(AVIRP)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; } - AANonNull::StateType T; - auto VisitValueCB = [&](Value &V, const Instruction *CtxI) -> bool { - const auto &AA = A.getAAFor<AANonNull>(*this, IRPosition::value(V), - DepClassTy::REQUIRED); - if (!Stripped && this == &AA) { - if (!isKnownNonZero(&V, DL, 0, AC, CtxI, DT)) - T.indicatePessimisticFixpoint(); - } else { - // Use abstract attribute information. - const AANonNull::StateType &NS = AA.getState(); - T ^= NS; - } - return T.isValidState(); - }; - for (const auto &VAC : Values) - if (!VisitValueCB(*VAC.getValue(), VAC.getCtxI())) + if (!CheckIRP(IRPosition::value(*VAC.getValue()))) return indicatePessimisticFixpoint(); - return clampStateAndIndicateChange(getState(), T); + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -2792,12 +2633,14 @@ struct AANonNullFloating : public AANonNullImpl { /// NonNull attribute for function return value. struct AANonNullReturned final - : AAReturnedFromReturnedValues<AANonNull, AANonNull> { + : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType, + false, AANonNull::IRAttributeKind> { AANonNullReturned(const IRPosition &IRP, Attributor &A) - : AAReturnedFromReturnedValues<AANonNull, AANonNull>(IRP, A) {} + : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType, + false, Attribute::NonNull>(IRP, A) {} /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "nonnull" : "may-null"; } @@ -2807,9 +2650,13 @@ struct AANonNullReturned final /// NonNull attribute for function argument. 
struct AANonNullArgument final - : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl> { + : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl, + AANonNull::StateType, false, + AANonNull::IRAttributeKind> { AANonNullArgument(const IRPosition &IRP, Attributor &A) - : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl>(IRP, A) {} + : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl, + AANonNull::StateType, false, + AANonNull::IRAttributeKind>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nonnull) } @@ -2825,23 +2672,118 @@ struct AANonNullCallSiteArgument final : AANonNullFloating { /// NonNull attribute for a call site return position. struct AANonNullCallSiteReturned final - : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl> { + : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl, + AANonNull::StateType, false, + AANonNull::IRAttributeKind> { AANonNullCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl>(IRP, A) {} + : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl, + AANonNull::StateType, false, + AANonNull::IRAttributeKind>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) } }; } // namespace +/// ------------------------ Must-Progress Attributes -------------------------- +namespace { +struct AAMustProgressImpl : public AAMustProgress { + AAMustProgressImpl(const IRPosition &IRP, Attributor &A) + : AAMustProgress(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::MustProgress>( + A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown)); + (void)IsKnown; + } + + /// See AbstractAttribute::getAsStr() + const std::string getAsStr(Attributor *A) const override { + return getAssumed() ? "mustprogress" : "may-not-progress"; + } +}; + +struct AAMustProgressFunction final : AAMustProgressImpl { + AAMustProgressFunction(const IRPosition &IRP, Attributor &A) + : AAMustProgressImpl(IRP, A) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + bool IsKnown; + if (AA::hasAssumedIRAttr<Attribute::WillReturn>( + A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnown)) { + if (IsKnown) + return indicateOptimisticFixpoint(); + return ChangeStatus::UNCHANGED; + } + + auto CheckForMustProgress = [&](AbstractCallSite ACS) { + IRPosition IPos = IRPosition::callsite_function(*ACS.getInstruction()); + bool IsKnownMustProgress; + return AA::hasAssumedIRAttr<Attribute::MustProgress>( + A, this, IPos, DepClassTy::REQUIRED, IsKnownMustProgress, + /* IgnoreSubsumingPositions */ true); + }; + + bool AllCallSitesKnown = true; + if (!A.checkForAllCallSites(CheckForMustProgress, *this, + /* RequireAllCallSites */ true, + AllCallSitesKnown)) + return indicatePessimisticFixpoint(); + + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FN_ATTR(mustprogress) + } +}; + +/// MustProgress attribute deduction for a call sites. +struct AAMustProgressCallSite final : AAMustProgressImpl { + AAMustProgressCallSite(const IRPosition &IRP, Attributor &A) + : AAMustProgressImpl(IRP, A) {} + + /// See AbstractAttribute::updateImpl(...). 
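// --- Illustrative sketch (editorial, not part of this patch) ---------------
// AAMustProgressFunction::updateImpl above first piggy-backs on willreturn
// (a function that will return certainly makes progress) before falling back
// to inspecting its call sites. A minimal sketch of that shortcut:
//
//   bool IsKnown;
//   if (AA::hasAssumedIRAttr<Attribute::WillReturn>(
//           A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnown)) {
//     if (IsKnown)
//       return indicateOptimisticFixpoint(); // mustprogress is a fact now
//     return ChangeStatus::UNCHANGED;        // keep assuming mustprogress
//   }
// ---------------------------------------------------------------------------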
+ ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + const IRPosition &FnPos = IRPosition::function(*getAnchorScope()); + bool IsKnownMustProgress; + if (!AA::hasAssumedIRAttr<Attribute::MustProgress>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnownMustProgress)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CS_ATTR(mustprogress); + } +}; +} // namespace + /// ------------------------ No-Recurse Attributes ---------------------------- namespace { struct AANoRecurseImpl : public AANoRecurse { AANoRecurseImpl(const IRPosition &IRP, Attributor &A) : AANoRecurse(IRP, A) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::NoRecurse>( + A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown)); + (void)IsKnown; + } + /// See AbstractAttribute::getAsStr() - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "norecurse" : "may-recurse"; } }; @@ -2855,10 +2797,13 @@ struct AANoRecurseFunction final : AANoRecurseImpl { // If all live call sites are known to be no-recurse, we are as well. auto CallSitePred = [&](AbstractCallSite ACS) { - const auto &NoRecurseAA = A.getAAFor<AANoRecurse>( - *this, IRPosition::function(*ACS.getInstruction()->getFunction()), - DepClassTy::NONE); - return NoRecurseAA.isKnownNoRecurse(); + bool IsKnownNoRecurse; + if (!AA::hasAssumedIRAttr<Attribute::NoRecurse>( + A, this, + IRPosition::function(*ACS.getInstruction()->getFunction()), + DepClassTy::NONE, IsKnownNoRecurse)) + return false; + return IsKnownNoRecurse; }; bool UsedAssumedInformation = false; if (A.checkForAllCallSites(CallSitePred, *this, true, @@ -2873,10 +2818,10 @@ struct AANoRecurseFunction final : AANoRecurseImpl { return ChangeStatus::UNCHANGED; } - const AAInterFnReachability &EdgeReachability = + const AAInterFnReachability *EdgeReachability = A.getAAFor<AAInterFnReachability>(*this, getIRPosition(), DepClassTy::REQUIRED); - if (EdgeReachability.canReach(A, *getAnchorScope())) + if (EdgeReachability && EdgeReachability->canReach(A, *getAnchorScope())) return indicatePessimisticFixpoint(); return ChangeStatus::UNCHANGED; } @@ -2889,14 +2834,6 @@ struct AANoRecurseCallSite final : AANoRecurseImpl { AANoRecurseCallSite(const IRPosition &IRP, Attributor &A) : AANoRecurseImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoRecurseImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -2905,8 +2842,11 @@ struct AANoRecurseCallSite final : AANoRecurseImpl { // redirecting requests to the callee argument. 
Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AANoRecurse>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + bool IsKnownNoRecurse; + if (!AA::hasAssumedIRAttr<Attribute::NoRecurse>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoRecurse)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -2914,6 +2854,62 @@ struct AANoRecurseCallSite final : AANoRecurseImpl { }; } // namespace +/// ------------------------ No-Convergent Attribute -------------------------- + +namespace { +struct AANonConvergentImpl : public AANonConvergent { + AANonConvergentImpl(const IRPosition &IRP, Attributor &A) + : AANonConvergent(IRP, A) {} + + /// See AbstractAttribute::getAsStr() + const std::string getAsStr(Attributor *A) const override { + return getAssumed() ? "non-convergent" : "may-be-convergent"; + } +}; + +struct AANonConvergentFunction final : AANonConvergentImpl { + AANonConvergentFunction(const IRPosition &IRP, Attributor &A) + : AANonConvergentImpl(IRP, A) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // If all function calls are known to not be convergent, we are not + // convergent. + auto CalleeIsNotConvergent = [&](Instruction &Inst) { + CallBase &CB = cast<CallBase>(Inst); + auto *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand()); + if (!Callee || Callee->isIntrinsic()) { + return false; + } + if (Callee->isDeclaration()) { + return !Callee->hasFnAttribute(Attribute::Convergent); + } + const auto *ConvergentAA = A.getAAFor<AANonConvergent>( + *this, IRPosition::function(*Callee), DepClassTy::REQUIRED); + return ConvergentAA && ConvergentAA->isAssumedNotConvergent(); + }; + + bool UsedAssumedInformation = false; + if (!A.checkForAllCallLikeInstructions(CalleeIsNotConvergent, *this, + UsedAssumedInformation)) { + return indicatePessimisticFixpoint(); + } + return ChangeStatus::UNCHANGED; + } + + ChangeStatus manifest(Attributor &A) override { + if (isKnownNotConvergent() && + A.hasAttr(getIRPosition(), Attribute::Convergent)) { + A.removeAttrs(getIRPosition(), {Attribute::Convergent}); + return ChangeStatus::CHANGED; + } + return ChangeStatus::UNCHANGED; + } + + void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(convergent) } +}; +} // namespace + /// -------------------- Undefined-Behavior Attributes ------------------------ namespace { @@ -3009,7 +3005,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // Check nonnull and noundef argument attribute violation for each // callsite. CallBase &CB = cast<CallBase>(I); - Function *Callee = CB.getCalledFunction(); + auto *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand()); if (!Callee) return true; for (unsigned idx = 0; idx < CB.arg_size(); idx++) { @@ -3030,9 +3026,10 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // (3) Simplified to null pointer where known to be nonnull. // The argument is a poison value and violate noundef attribute. 
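// --- Illustrative worked example (editorial, not part of this patch) -------
// The checks that follow flag a call argument as known UB in two situations,
// e.g. for a callee declared as `declare void @f(i32 noundef, ptr nonnull)`:
//   - the first argument simplifies to undef while the callee position is
//     known noundef, or
//   - the second argument simplifies to a null pointer while the callee
//     position is known nonnull.
// In both cases the argument is effectively poison, so the instruction is
// recorded in KnownUBInsts.
// ---------------------------------------------------------------------------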
IRPosition CalleeArgumentIRP = IRPosition::callsite_argument(CB, idx); - auto &NoUndefAA = - A.getAAFor<AANoUndef>(*this, CalleeArgumentIRP, DepClassTy::NONE); - if (!NoUndefAA.isKnownNoUndef()) + bool IsKnownNoUndef; + AA::hasAssumedIRAttr<Attribute::NoUndef>( + A, this, CalleeArgumentIRP, DepClassTy::NONE, IsKnownNoUndef); + if (!IsKnownNoUndef) continue; bool UsedAssumedInformation = false; std::optional<Value *> SimplifiedVal = @@ -3049,9 +3046,10 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { if (!ArgVal->getType()->isPointerTy() || !isa<ConstantPointerNull>(**SimplifiedVal)) continue; - auto &NonNullAA = - A.getAAFor<AANonNull>(*this, CalleeArgumentIRP, DepClassTy::NONE); - if (NonNullAA.isKnownNonNull()) + bool IsKnownNonNull; + AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, CalleeArgumentIRP, DepClassTy::NONE, IsKnownNonNull); + if (IsKnownNonNull) KnownUBInsts.insert(&I); } return true; @@ -3081,9 +3079,11 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // position has nonnull attribute (because the returned value is // poison). if (isa<ConstantPointerNull>(*SimplifiedRetValue)) { - auto &NonNullAA = A.getAAFor<AANonNull>( - *this, IRPosition::returned(*getAnchorScope()), DepClassTy::NONE); - if (NonNullAA.isKnownNonNull()) + bool IsKnownNonNull; + AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, IRPosition::returned(*getAnchorScope()), DepClassTy::NONE, + IsKnownNonNull); + if (IsKnownNonNull) KnownUBInsts.insert(&I); } @@ -3108,9 +3108,10 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { if (!getAnchorScope()->getReturnType()->isVoidTy()) { const IRPosition &ReturnIRP = IRPosition::returned(*getAnchorScope()); if (!A.isAssumedDead(ReturnIRP, this, nullptr, UsedAssumedInformation)) { - auto &RetPosNoUndefAA = - A.getAAFor<AANoUndef>(*this, ReturnIRP, DepClassTy::NONE); - if (RetPosNoUndefAA.isKnownNoUndef()) + bool IsKnownNoUndef; + AA::hasAssumedIRAttr<Attribute::NoUndef>( + A, this, ReturnIRP, DepClassTy::NONE, IsKnownNoUndef); + if (IsKnownNoUndef) A.checkForAllInstructions(InspectReturnInstForUB, *this, {Instruction::Ret}, UsedAssumedInformation, /* CheckBBLivenessOnly */ true); @@ -3161,7 +3162,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { } /// See AbstractAttribute::getAsStr() - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "undefined-behavior" : "no-ub"; } @@ -3284,20 +3285,15 @@ struct AAWillReturnImpl : public AAWillReturn { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - AAWillReturn::initialize(A); - - if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ true)) { - indicateOptimisticFixpoint(); - return; - } + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::WillReturn>( + A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown)); + (void)IsKnown; } /// Check for `mustprogress` and `readonly` as they imply `willreturn`. bool isImpliedByMustprogressAndReadonly(Attributor &A, bool KnownOnly) { - // Check for `mustprogress` in the scope and the associated function which - // might be different if this is a call site. 
- if ((!getAnchorScope() || !getAnchorScope()->mustProgress()) && - (!getAssociatedFunction() || !getAssociatedFunction()->mustProgress())) + if (!A.hasAttr(getIRPosition(), {Attribute::MustProgress})) return false; bool IsKnown; @@ -3313,15 +3309,17 @@ struct AAWillReturnImpl : public AAWillReturn { auto CheckForWillReturn = [&](Instruction &I) { IRPosition IPos = IRPosition::callsite_function(cast<CallBase>(I)); - const auto &WillReturnAA = - A.getAAFor<AAWillReturn>(*this, IPos, DepClassTy::REQUIRED); - if (WillReturnAA.isKnownWillReturn()) - return true; - if (!WillReturnAA.isAssumedWillReturn()) + bool IsKnown; + if (AA::hasAssumedIRAttr<Attribute::WillReturn>( + A, this, IPos, DepClassTy::REQUIRED, IsKnown)) { + if (IsKnown) + return true; + } else { return false; - const auto &NoRecurseAA = - A.getAAFor<AANoRecurse>(*this, IPos, DepClassTy::REQUIRED); - return NoRecurseAA.isAssumedNoRecurse(); + } + bool IsKnownNoRecurse; + return AA::hasAssumedIRAttr<Attribute::NoRecurse>( + A, this, IPos, DepClassTy::REQUIRED, IsKnownNoRecurse); }; bool UsedAssumedInformation = false; @@ -3333,7 +3331,7 @@ struct AAWillReturnImpl : public AAWillReturn { } /// See AbstractAttribute::getAsStr() - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "willreturn" : "may-noreturn"; } }; @@ -3347,7 +3345,8 @@ struct AAWillReturnFunction final : AAWillReturnImpl { AAWillReturnImpl::initialize(A); Function *F = getAnchorScope(); - if (!F || F->isDeclaration() || mayContainUnboundedCycle(*F, A)) + assert(F && "Did expect an anchor function"); + if (F->isDeclaration() || mayContainUnboundedCycle(*F, A)) indicatePessimisticFixpoint(); } @@ -3360,14 +3359,6 @@ struct AAWillReturnCallSite final : AAWillReturnImpl { AAWillReturnCallSite(const IRPosition &IRP, Attributor &A) : AAWillReturnImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AAWillReturnImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || !A.isFunctionIPOAmendable(*F)) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ false)) @@ -3379,8 +3370,11 @@ struct AAWillReturnCallSite final : AAWillReturnImpl { // redirecting requests to the callee argument. Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AAWillReturn>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + bool IsKnown; + if (AA::hasAssumedIRAttr<Attribute::WillReturn>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnown)) + return ChangeStatus::UNCHANGED; + return indicatePessimisticFixpoint(); } /// See AbstractAttribute::trackStatistics() @@ -3414,22 +3408,18 @@ template <typename ToTy> struct ReachabilityQueryInfo { /// Constructor replacement to ensure unique and stable sets are used for the /// cache. 
ReachabilityQueryInfo(Attributor &A, const Instruction &From, const ToTy &To, - const AA::InstExclusionSetTy *ES) + const AA::InstExclusionSetTy *ES, bool MakeUnique) : From(&From), To(&To), ExclusionSet(ES) { - if (ExclusionSet && !ExclusionSet->empty()) { - ExclusionSet = - A.getInfoCache().getOrCreateUniqueBlockExecutionSet(ExclusionSet); - } else { + if (!ES || ES->empty()) { ExclusionSet = nullptr; + } else if (MakeUnique) { + ExclusionSet = A.getInfoCache().getOrCreateUniqueBlockExecutionSet(ES); } } ReachabilityQueryInfo(const ReachabilityQueryInfo &RQI) - : From(RQI.From), To(RQI.To), ExclusionSet(RQI.ExclusionSet) { - assert(RQI.Result == Reachable::No && - "Didn't expect to copy an explored RQI!"); - } + : From(RQI.From), To(RQI.To), ExclusionSet(RQI.ExclusionSet) {} }; namespace llvm { @@ -3482,8 +3472,7 @@ template <typename BaseTy, typename ToTy> struct CachedReachabilityAA : public BaseTy { using RQITy = ReachabilityQueryInfo<ToTy>; - CachedReachabilityAA<BaseTy, ToTy>(const IRPosition &IRP, Attributor &A) - : BaseTy(IRP, A) {} + CachedReachabilityAA(const IRPosition &IRP, Attributor &A) : BaseTy(IRP, A) {} /// See AbstractAttribute::isQueryAA. bool isQueryAA() const override { return true; } @@ -3492,7 +3481,8 @@ struct CachedReachabilityAA : public BaseTy { ChangeStatus updateImpl(Attributor &A) override { ChangeStatus Changed = ChangeStatus::UNCHANGED; InUpdate = true; - for (RQITy *RQI : QueryVector) { + for (unsigned u = 0, e = QueryVector.size(); u < e; ++u) { + RQITy *RQI = QueryVector[u]; if (RQI->Result == RQITy::Reachable::No && isReachableImpl(A, *RQI)) Changed = ChangeStatus::CHANGED; } @@ -3503,39 +3493,78 @@ struct CachedReachabilityAA : public BaseTy { virtual bool isReachableImpl(Attributor &A, RQITy &RQI) = 0; bool rememberResult(Attributor &A, typename RQITy::Reachable Result, - RQITy &RQI) { - if (Result == RQITy::Reachable::No) { - if (!InUpdate) - A.registerForUpdate(*this); - return false; - } - assert(RQI.Result == RQITy::Reachable::No && "Already reachable?"); + RQITy &RQI, bool UsedExclusionSet) { RQI.Result = Result; - return true; + + // Remove the temporary RQI from the cache. + if (!InUpdate) + QueryCache.erase(&RQI); + + // Insert a plain RQI (w/o exclusion set) if that makes sense. Two options: + // 1) If it is reachable, it doesn't matter if we have an exclusion set for + // this query. 2) We did not use the exclusion set, potentially because + // there is none. + if (Result == RQITy::Reachable::Yes || !UsedExclusionSet) { + RQITy PlainRQI(RQI.From, RQI.To); + if (!QueryCache.count(&PlainRQI)) { + RQITy *RQIPtr = new (A.Allocator) RQITy(RQI.From, RQI.To); + RQIPtr->Result = Result; + QueryVector.push_back(RQIPtr); + QueryCache.insert(RQIPtr); + } + } + + // Check if we need to insert a new permanent RQI with the exclusion set. 
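// --- Illustrative sketch (editorial, not part of this patch) ---------------
// The caching discipline implemented here rests on two monotonicity facts
// about exclusion sets (excluding instructions can only remove paths):
//
//   // (a) used on lookup: "unreachable even ignoring the set" stays
//   //     unreachable once any exclusion set is applied
//   if (PlainResult == Reachable::No)
//     ResultWithAnySet = Reachable::No;
//
//   // (b) used on store: a "Yes" found with a set, or any result computed
//   //     without consulting the set, is also valid for the plain query
//   if (Result == Reachable::Yes || !UsedExclusionSet)
//     cachePlain(From, To, Result);
//
// `PlainResult`, `ResultWithAnySet` and `cachePlain` are illustrative
// shorthands for the QueryCache manipulation in the surrounding code.
// ---------------------------------------------------------------------------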
+ if (!InUpdate && Result != RQITy::Reachable::Yes && UsedExclusionSet) { + assert((!RQI.ExclusionSet || !RQI.ExclusionSet->empty()) && + "Did not expect empty set!"); + RQITy *RQIPtr = new (A.Allocator) + RQITy(A, *RQI.From, *RQI.To, RQI.ExclusionSet, true); + assert(RQIPtr->Result == RQITy::Reachable::No && "Already reachable?"); + RQIPtr->Result = Result; + assert(!QueryCache.count(RQIPtr)); + QueryVector.push_back(RQIPtr); + QueryCache.insert(RQIPtr); + } + + if (Result == RQITy::Reachable::No && !InUpdate) + A.registerForUpdate(*this); + return Result == RQITy::Reachable::Yes; } - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { // TODO: Return the number of reachable queries. return "#queries(" + std::to_string(QueryVector.size()) + ")"; } - RQITy *checkQueryCache(Attributor &A, RQITy &StackRQI, - typename RQITy::Reachable &Result) { + bool checkQueryCache(Attributor &A, RQITy &StackRQI, + typename RQITy::Reachable &Result) { if (!this->getState().isValidState()) { Result = RQITy::Reachable::Yes; - return nullptr; + return true; + } + + // If we have an exclusion set we might be able to find our answer by + // ignoring it first. + if (StackRQI.ExclusionSet) { + RQITy PlainRQI(StackRQI.From, StackRQI.To); + auto It = QueryCache.find(&PlainRQI); + if (It != QueryCache.end() && (*It)->Result == RQITy::Reachable::No) { + Result = RQITy::Reachable::No; + return true; + } } auto It = QueryCache.find(&StackRQI); if (It != QueryCache.end()) { Result = (*It)->Result; - return nullptr; + return true; } - RQITy *RQIPtr = new (A.Allocator) RQITy(StackRQI); - QueryVector.push_back(RQIPtr); - QueryCache.insert(RQIPtr); - return RQIPtr; + // Insert a temporary for recursive queries. We will replace it with a + // permanent entry later. + QueryCache.insert(&StackRQI); + return false; } private: @@ -3546,8 +3575,9 @@ private: struct AAIntraFnReachabilityFunction final : public CachedReachabilityAA<AAIntraFnReachability, Instruction> { + using Base = CachedReachabilityAA<AAIntraFnReachability, Instruction>; AAIntraFnReachabilityFunction(const IRPosition &IRP, Attributor &A) - : CachedReachabilityAA<AAIntraFnReachability, Instruction>(IRP, A) {} + : Base(IRP, A) {} bool isAssumedReachable( Attributor &A, const Instruction &From, const Instruction &To, @@ -3556,23 +3586,39 @@ struct AAIntraFnReachabilityFunction final if (&From == &To) return true; - RQITy StackRQI(A, From, To, ExclusionSet); + RQITy StackRQI(A, From, To, ExclusionSet, false); typename RQITy::Reachable Result; - if (RQITy *RQIPtr = NonConstThis->checkQueryCache(A, StackRQI, Result)) { - return NonConstThis->isReachableImpl(A, *RQIPtr); - } + if (!NonConstThis->checkQueryCache(A, StackRQI, Result)) + return NonConstThis->isReachableImpl(A, StackRQI); return Result == RQITy::Reachable::Yes; } + ChangeStatus updateImpl(Attributor &A) override { + // We only depend on liveness. DeadEdges is all we care about, check if any + // of them changed. 
+ auto *LivenessAA = + A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (LivenessAA && llvm::all_of(DeadEdges, [&](const auto &DeadEdge) { + return LivenessAA->isEdgeDead(DeadEdge.first, DeadEdge.second); + })) { + return ChangeStatus::UNCHANGED; + } + DeadEdges.clear(); + return Base::updateImpl(A); + } + bool isReachableImpl(Attributor &A, RQITy &RQI) override { const Instruction *Origin = RQI.From; + bool UsedExclusionSet = false; - auto WillReachInBlock = [=](const Instruction &From, const Instruction &To, + auto WillReachInBlock = [&](const Instruction &From, const Instruction &To, const AA::InstExclusionSetTy *ExclusionSet) { const Instruction *IP = &From; while (IP && IP != &To) { - if (ExclusionSet && IP != Origin && ExclusionSet->count(IP)) + if (ExclusionSet && IP != Origin && ExclusionSet->count(IP)) { + UsedExclusionSet = true; break; + } IP = IP->getNextNode(); } return IP == &To; @@ -3587,7 +3633,12 @@ struct AAIntraFnReachabilityFunction final // possible. if (FromBB == ToBB && WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::Yes, RQI); + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet); + + // Check if reaching the ToBB block is sufficient or if even that would not + // ensure reaching the target. In the latter case we are done. + if (!WillReachInBlock(ToBB->front(), *RQI.To, RQI.ExclusionSet)) + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks; if (RQI.ExclusionSet) @@ -3598,40 +3649,80 @@ struct AAIntraFnReachabilityFunction final if (ExclusionBlocks.count(FromBB) && !WillReachInBlock(*RQI.From, *FromBB->getTerminator(), RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::No, RQI); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); SmallPtrSet<const BasicBlock *, 16> Visited; SmallVector<const BasicBlock *, 16> Worklist; Worklist.push_back(FromBB); - auto &LivenessAA = + DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> LocalDeadEdges; + auto *LivenessAA = A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); while (!Worklist.empty()) { const BasicBlock *BB = Worklist.pop_back_val(); if (!Visited.insert(BB).second) continue; for (const BasicBlock *SuccBB : successors(BB)) { - if (LivenessAA.isEdgeDead(BB, SuccBB)) + if (LivenessAA && LivenessAA->isEdgeDead(BB, SuccBB)) { + LocalDeadEdges.insert({BB, SuccBB}); continue; - if (SuccBB == ToBB && - WillReachInBlock(SuccBB->front(), *RQI.To, RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::Yes, RQI); - if (ExclusionBlocks.count(SuccBB)) + } + // We checked before if we just need to reach the ToBB block. + if (SuccBB == ToBB) + return rememberResult(A, RQITy::Reachable::Yes, RQI, + UsedExclusionSet); + if (ExclusionBlocks.count(SuccBB)) { + UsedExclusionSet = true; continue; + } Worklist.push_back(SuccBB); } } - return rememberResult(A, RQITy::Reachable::No, RQI); + DeadEdges.insert(LocalDeadEdges.begin(), LocalDeadEdges.end()); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); } /// See AbstractAttribute::trackStatistics() void trackStatistics() const override {} + +private: + // Set of assumed dead edges we used in the last query. If any changes we + // update the state. 
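// --- Illustrative sketch (editorial, not part of this patch) ---------------
// isReachableImpl above is a forward CFG walk that prunes assumed-dead edges
// and stops at exclusion blocks. A self-contained model of the same walk over
// a generic successor graph, in plain C++ (all names illustrative, liveness
// pruning omitted):
//
//   #include <set>
//   #include <vector>
//
//   using Node = int;
//   using Graph = std::vector<std::vector<Node>>; // adjacency list
//
//   bool isReachable(const Graph &G, Node From, Node To,
//                    const std::set<Node> &Exclusion) {
//     std::set<Node> Visited;
//     std::vector<Node> Worklist{From};
//     while (!Worklist.empty()) {
//       Node N = Worklist.back();
//       Worklist.pop_back();
//       if (!Visited.insert(N).second)
//         continue;
//       for (Node Succ : G[N]) {
//         if (Succ == To)
//           return true;
//         if (Exclusion.count(Succ)) // blocked: do not walk through it
//           continue;
//         Worklist.push_back(Succ);
//       }
//     }
//     return false;
//   }
// ---------------------------------------------------------------------------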
+ DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> DeadEdges; }; } // namespace /// ------------------------ NoAlias Argument Attribute ------------------------ +bool AANoAlias::isImpliedByIR(Attributor &A, const IRPosition &IRP, + Attribute::AttrKind ImpliedAttributeKind, + bool IgnoreSubsumingPositions) { + assert(ImpliedAttributeKind == Attribute::NoAlias && + "Unexpected attribute kind"); + Value *Val = &IRP.getAssociatedValue(); + if (IRP.getPositionKind() != IRP_CALL_SITE_ARGUMENT) { + if (isa<AllocaInst>(Val)) + return true; + } else { + IgnoreSubsumingPositions = true; + } + + if (isa<UndefValue>(Val)) + return true; + + if (isa<ConstantPointerNull>(Val) && + !NullPointerIsDefined(IRP.getAnchorScope(), + Val->getType()->getPointerAddressSpace())) + return true; + + if (A.hasAttr(IRP, {Attribute::ByVal, Attribute::NoAlias}, + IgnoreSubsumingPositions, Attribute::NoAlias)) + return true; + + return false; +} + namespace { struct AANoAliasImpl : AANoAlias { AANoAliasImpl(const IRPosition &IRP, Attributor &A) : AANoAlias(IRP, A) { @@ -3639,7 +3730,7 @@ struct AANoAliasImpl : AANoAlias { "Noalias is a pointer attribute"); } - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "noalias" : "may-alias"; } }; @@ -3649,39 +3740,6 @@ struct AANoAliasFloating final : AANoAliasImpl { AANoAliasFloating(const IRPosition &IRP, Attributor &A) : AANoAliasImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoAliasImpl::initialize(A); - Value *Val = &getAssociatedValue(); - do { - CastInst *CI = dyn_cast<CastInst>(Val); - if (!CI) - break; - Value *Base = CI->getOperand(0); - if (!Base->hasOneUse()) - break; - Val = Base; - } while (true); - - if (!Val->getType()->isPointerTy()) { - indicatePessimisticFixpoint(); - return; - } - - if (isa<AllocaInst>(Val)) - indicateOptimisticFixpoint(); - else if (isa<ConstantPointerNull>(Val) && - !NullPointerIsDefined(getAnchorScope(), - Val->getType()->getPointerAddressSpace())) - indicateOptimisticFixpoint(); - else if (Val != &getAssociatedValue()) { - const auto &ValNoAliasAA = A.getAAFor<AANoAlias>( - *this, IRPosition::value(*Val), DepClassTy::OPTIONAL); - if (ValNoAliasAA.isKnownNoAlias()) - indicateOptimisticFixpoint(); - } - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Implement this. @@ -3696,18 +3754,14 @@ struct AANoAliasFloating final : AANoAliasImpl { /// NoAlias attribute for an argument. struct AANoAliasArgument final - : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl> { - using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl>; + : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl, + AANoAlias::StateType, false, + Attribute::NoAlias> { + using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl, + AANoAlias::StateType, false, + Attribute::NoAlias>; AANoAliasArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - Base::initialize(A); - // See callsite argument attribute and callee argument attribute. - if (hasAttr({Attribute::ByVal})) - indicateOptimisticFixpoint(); - } - /// See AbstractAttribute::update(...). 
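// --- Illustrative sketch (editorial, not part of this patch) ---------------
// AANoAlias::isImpliedByIR above replaces the per-AA initialize() special
// cases that this patch removes: noalias is implied directly by the IR for
//   - alloca results (outside the call-site-argument position),
//   - undef values,
//   - null pointers in address spaces where null is not defined, and
//   - positions already carrying byval or noalias.
// A minimal sketch of the last check, as used in the hunk above:
//
//   if (A.hasAttr(IRP, {Attribute::ByVal, Attribute::NoAlias},
//                 IgnoreSubsumingPositions, Attribute::NoAlias))
//     return true;
// ---------------------------------------------------------------------------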
ChangeStatus updateImpl(Attributor &A) override { // We have to make sure no-alias on the argument does not break @@ -3716,10 +3770,10 @@ struct AANoAliasArgument final // function, otherwise we give up for now. // If the function is no-sync, no-alias cannot break synchronization. - const auto &NoSyncAA = - A.getAAFor<AANoSync>(*this, IRPosition::function_scope(getIRPosition()), - DepClassTy::OPTIONAL); - if (NoSyncAA.isAssumedNoSync()) + bool IsKnownNoSycn; + if (AA::hasAssumedIRAttr<Attribute::NoSync>( + A, this, IRPosition::function_scope(getIRPosition()), + DepClassTy::OPTIONAL, IsKnownNoSycn)) return Base::updateImpl(A); // If the argument is read-only, no-alias cannot break synchronization. @@ -3752,19 +3806,6 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { AANoAliasCallSiteArgument(const IRPosition &IRP, Attributor &A) : AANoAliasImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - // See callsite argument attribute and callee argument attribute. - const auto &CB = cast<CallBase>(getAnchorValue()); - if (CB.paramHasAttr(getCallSiteArgNo(), Attribute::NoAlias)) - indicateOptimisticFixpoint(); - Value &Val = getAssociatedValue(); - if (isa<ConstantPointerNull>(Val) && - !NullPointerIsDefined(getAnchorScope(), - Val.getType()->getPointerAddressSpace())) - indicateOptimisticFixpoint(); - } - /// Determine if the underlying value may alias with the call site argument /// \p OtherArgNo of \p ICS (= the underlying call site). bool mayAliasWithArgument(Attributor &A, AAResults *&AAR, @@ -3779,27 +3820,29 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { if (!ArgOp->getType()->isPtrOrPtrVectorTy()) return false; - auto &CBArgMemBehaviorAA = A.getAAFor<AAMemoryBehavior>( + auto *CBArgMemBehaviorAA = A.getAAFor<AAMemoryBehavior>( *this, IRPosition::callsite_argument(CB, OtherArgNo), DepClassTy::NONE); // If the argument is readnone, there is no read-write aliasing. - if (CBArgMemBehaviorAA.isAssumedReadNone()) { - A.recordDependence(CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL); + if (CBArgMemBehaviorAA && CBArgMemBehaviorAA->isAssumedReadNone()) { + A.recordDependence(*CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL); return false; } // If the argument is readonly and the underlying value is readonly, there // is no read-write aliasing. bool IsReadOnly = MemBehaviorAA.isAssumedReadOnly(); - if (CBArgMemBehaviorAA.isAssumedReadOnly() && IsReadOnly) { + if (CBArgMemBehaviorAA && CBArgMemBehaviorAA->isAssumedReadOnly() && + IsReadOnly) { A.recordDependence(MemBehaviorAA, *this, DepClassTy::OPTIONAL); - A.recordDependence(CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL); + A.recordDependence(*CBArgMemBehaviorAA, *this, DepClassTy::OPTIONAL); return false; } // We have to utilize actual alias analysis queries so we need the object. if (!AAR) - AAR = A.getInfoCache().getAAResultsForFunction(*getAnchorScope()); + AAR = A.getInfoCache().getAnalysisResultForFunction<AAManager>( + *getAnchorScope()); // Try to rule it out at the call site. 
bool IsAliasing = !AAR || !AAR->isNoAlias(&getAssociatedValue(), ArgOp); @@ -3811,10 +3854,8 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { return IsAliasing; } - bool - isKnownNoAliasDueToNoAliasPreservation(Attributor &A, AAResults *&AAR, - const AAMemoryBehavior &MemBehaviorAA, - const AANoAlias &NoAliasAA) { + bool isKnownNoAliasDueToNoAliasPreservation( + Attributor &A, AAResults *&AAR, const AAMemoryBehavior &MemBehaviorAA) { // We can deduce "noalias" if the following conditions hold. // (i) Associated value is assumed to be noalias in the definition. // (ii) Associated value is assumed to be no-capture in all the uses @@ -3822,24 +3863,14 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { // (iii) There is no other pointer argument which could alias with the // value. - bool AssociatedValueIsNoAliasAtDef = NoAliasAA.isAssumedNoAlias(); - if (!AssociatedValueIsNoAliasAtDef) { - LLVM_DEBUG(dbgs() << "[AANoAlias] " << getAssociatedValue() - << " is not no-alias at the definition\n"); - return false; - } - auto IsDereferenceableOrNull = [&](Value *O, const DataLayout &DL) { - const auto &DerefAA = A.getAAFor<AADereferenceable>( + const auto *DerefAA = A.getAAFor<AADereferenceable>( *this, IRPosition::value(*O), DepClassTy::OPTIONAL); - return DerefAA.getAssumedDereferenceableBytes(); + return DerefAA ? DerefAA->getAssumedDereferenceableBytes() : 0; }; - A.recordDependence(NoAliasAA, *this, DepClassTy::OPTIONAL); - const IRPosition &VIRP = IRPosition::value(getAssociatedValue()); const Function *ScopeFn = VIRP.getAnchorScope(); - auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, VIRP, DepClassTy::NONE); // Check whether the value is captured in the scope using AANoCapture. // Look at CFG and check only uses possibly executed before this // callsite. @@ -3859,11 +3890,10 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { unsigned ArgNo = CB->getArgOperandNo(&U); - const auto &NoCaptureAA = A.getAAFor<AANoCapture>( - *this, IRPosition::callsite_argument(*CB, ArgNo), - DepClassTy::OPTIONAL); - - if (NoCaptureAA.isAssumedNoCapture()) + bool IsKnownNoCapture; + if (AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, IRPosition::callsite_argument(*CB, ArgNo), + DepClassTy::OPTIONAL, IsKnownNoCapture)) return true; } } @@ -3891,7 +3921,12 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { llvm_unreachable("unknown UseCaptureKind"); }; - if (!NoCaptureAA.isAssumedNoCaptureMaybeReturned()) { + bool IsKnownNoCapture; + const AANoCapture *NoCaptureAA = nullptr; + bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, VIRP, DepClassTy::NONE, IsKnownNoCapture, false, &NoCaptureAA); + if (!IsAssumedNoCapture && + (!NoCaptureAA || !NoCaptureAA->isAssumedNoCaptureMaybeReturned())) { if (!A.checkForAllUses(UsePred, *this, getAssociatedValue())) { LLVM_DEBUG( dbgs() << "[AANoAliasCSArg] " << getAssociatedValue() @@ -3899,7 +3934,8 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { return false; } } - A.recordDependence(NoCaptureAA, *this, DepClassTy::OPTIONAL); + if (NoCaptureAA) + A.recordDependence(*NoCaptureAA, *this, DepClassTy::OPTIONAL); // Check there is no other pointer argument which could alias with the // value passed at this call site. @@ -3916,20 +3952,25 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { ChangeStatus updateImpl(Attributor &A) override { // If the argument is readnone we are done as there are no accesses via the // argument. 
- auto &MemBehaviorAA = + auto *MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(*this, getIRPosition(), DepClassTy::NONE); - if (MemBehaviorAA.isAssumedReadNone()) { - A.recordDependence(MemBehaviorAA, *this, DepClassTy::OPTIONAL); + if (MemBehaviorAA && MemBehaviorAA->isAssumedReadNone()) { + A.recordDependence(*MemBehaviorAA, *this, DepClassTy::OPTIONAL); return ChangeStatus::UNCHANGED; } + bool IsKnownNoAlias; const IRPosition &VIRP = IRPosition::value(getAssociatedValue()); - const auto &NoAliasAA = - A.getAAFor<AANoAlias>(*this, VIRP, DepClassTy::NONE); + if (!AA::hasAssumedIRAttr<Attribute::NoAlias>( + A, this, VIRP, DepClassTy::REQUIRED, IsKnownNoAlias)) { + LLVM_DEBUG(dbgs() << "[AANoAlias] " << getAssociatedValue() + << " is not no-alias at the definition\n"); + return indicatePessimisticFixpoint(); + } AAResults *AAR = nullptr; - if (isKnownNoAliasDueToNoAliasPreservation(A, AAR, MemBehaviorAA, - NoAliasAA)) { + if (MemBehaviorAA && + isKnownNoAliasDueToNoAliasPreservation(A, AAR, *MemBehaviorAA)) { LLVM_DEBUG( dbgs() << "[AANoAlias] No-Alias deduced via no-alias preservation\n"); return ChangeStatus::UNCHANGED; @@ -3947,14 +3988,6 @@ struct AANoAliasReturned final : AANoAliasImpl { AANoAliasReturned(const IRPosition &IRP, Attributor &A) : AANoAliasImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoAliasImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { @@ -3969,14 +4002,18 @@ struct AANoAliasReturned final : AANoAliasImpl { return false; const IRPosition &RVPos = IRPosition::value(RV); - const auto &NoAliasAA = - A.getAAFor<AANoAlias>(*this, RVPos, DepClassTy::REQUIRED); - if (!NoAliasAA.isAssumedNoAlias()) + bool IsKnownNoAlias; + if (!AA::hasAssumedIRAttr<Attribute::NoAlias>( + A, this, RVPos, DepClassTy::REQUIRED, IsKnownNoAlias)) return false; - const auto &NoCaptureAA = - A.getAAFor<AANoCapture>(*this, RVPos, DepClassTy::REQUIRED); - return NoCaptureAA.isAssumedNoCaptureMaybeReturned(); + bool IsKnownNoCapture; + const AANoCapture *NoCaptureAA = nullptr; + bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, RVPos, DepClassTy::REQUIRED, IsKnownNoCapture, false, + &NoCaptureAA); + return IsAssumedNoCapture || + (NoCaptureAA && NoCaptureAA->isAssumedNoCaptureMaybeReturned()); }; if (!A.checkForAllReturnedValues(CheckReturnValue, *this)) @@ -3994,14 +4031,6 @@ struct AANoAliasCallSiteReturned final : AANoAliasImpl { AANoAliasCallSiteReturned(const IRPosition &IRP, Attributor &A) : AANoAliasImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoAliasImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -4010,8 +4039,11 @@ struct AANoAliasCallSiteReturned final : AANoAliasImpl { // redirecting requests to the callee argument. 
Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::returned(*F); - auto &FnAA = A.getAAFor<AANoAlias>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + bool IsKnownNoAlias; + if (!AA::hasAssumedIRAttr<Attribute::NoAlias>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoAlias)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -4025,13 +4057,6 @@ namespace { struct AAIsDeadValueImpl : public AAIsDead { AAIsDeadValueImpl(const IRPosition &IRP, Attributor &A) : AAIsDead(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - if (auto *Scope = getAnchorScope()) - if (!A.isRunOn(*Scope)) - indicatePessimisticFixpoint(); - } - /// See AAIsDead::isAssumedDead(). bool isAssumedDead() const override { return isAssumed(IS_DEAD); } @@ -4055,7 +4080,7 @@ struct AAIsDeadValueImpl : public AAIsDead { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return isAssumedDead() ? "assumed-dead" : "assumed-live"; } @@ -4097,12 +4122,11 @@ struct AAIsDeadValueImpl : public AAIsDead { return false; const IRPosition &CallIRP = IRPosition::callsite_function(*CB); - const auto &NoUnwindAA = - A.getAndUpdateAAFor<AANoUnwind>(*this, CallIRP, DepClassTy::NONE); - if (!NoUnwindAA.isAssumedNoUnwind()) + + bool IsKnownNoUnwind; + if (!AA::hasAssumedIRAttr<Attribute::NoUnwind>( + A, this, CallIRP, DepClassTy::OPTIONAL, IsKnownNoUnwind)) return false; - if (!NoUnwindAA.isKnownNoUnwind()) - A.recordDependence(NoUnwindAA, *this, DepClassTy::OPTIONAL); bool IsKnown; return AA::isAssumedReadOnly(A, CallIRP, *this, IsKnown); @@ -4124,13 +4148,22 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { Instruction *I = dyn_cast<Instruction>(&getAssociatedValue()); if (!isAssumedSideEffectFree(A, I)) { - if (!isa_and_nonnull<StoreInst>(I)) + if (!isa_and_nonnull<StoreInst>(I) && !isa_and_nonnull<FenceInst>(I)) indicatePessimisticFixpoint(); else removeAssumedBits(HAS_NO_EFFECT); } } + bool isDeadFence(Attributor &A, FenceInst &FI) { + const auto *ExecDomainAA = A.lookupAAFor<AAExecutionDomain>( + IRPosition::function(*FI.getFunction()), *this, DepClassTy::NONE); + if (!ExecDomainAA || !ExecDomainAA->isNoOpFence(FI)) + return false; + A.recordDependence(*ExecDomainAA, *this, DepClassTy::OPTIONAL); + return true; + } + bool isDeadStore(Attributor &A, StoreInst &SI, SmallSetVector<Instruction *, 8> *AssumeOnlyInst = nullptr) { // Lang ref now states volatile store is not UB/dead, let's skip them. @@ -4161,12 +4194,14 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { return true; if (auto *LI = dyn_cast<LoadInst>(V)) { if (llvm::all_of(LI->uses(), [&](const Use &U) { - return InfoCache.isOnlyUsedByAssume( - cast<Instruction>(*U.getUser())) || - A.isAssumedDead(U, this, nullptr, UsedAssumedInformation); + auto &UserI = cast<Instruction>(*U.getUser()); + if (InfoCache.isOnlyUsedByAssume(UserI)) { + if (AssumeOnlyInst) + AssumeOnlyInst->insert(&UserI); + return true; + } + return A.isAssumedDead(U, this, nullptr, UsedAssumedInformation); })) { - if (AssumeOnlyInst) - AssumeOnlyInst->insert(LI); return true; } } @@ -4177,12 +4212,15 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { } /// See AbstractAttribute::getAsStr(). 
- const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { Instruction *I = dyn_cast<Instruction>(&getAssociatedValue()); if (isa_and_nonnull<StoreInst>(I)) if (isValidState()) return "assumed-dead-store"; - return AAIsDeadValueImpl::getAsStr(); + if (isa_and_nonnull<FenceInst>(I)) + if (isValidState()) + return "assumed-dead-fence"; + return AAIsDeadValueImpl::getAsStr(A); } /// See AbstractAttribute::updateImpl(...). @@ -4191,6 +4229,9 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { if (auto *SI = dyn_cast_or_null<StoreInst>(I)) { if (!isDeadStore(A, *SI)) return indicatePessimisticFixpoint(); + } else if (auto *FI = dyn_cast_or_null<FenceInst>(I)) { + if (!isDeadFence(A, *FI)) + return indicatePessimisticFixpoint(); } else { if (!isAssumedSideEffectFree(A, I)) return indicatePessimisticFixpoint(); @@ -4226,6 +4267,11 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { } return ChangeStatus::CHANGED; } + if (auto *FI = dyn_cast<FenceInst>(I)) { + assert(isDeadFence(A, *FI)); + A.deleteAfterManifest(*FI); + return ChangeStatus::CHANGED; + } if (isAssumedSideEffectFree(A, I) && !isa<InvokeInst>(I)) { A.deleteAfterManifest(*I); return ChangeStatus::CHANGED; @@ -4248,13 +4294,6 @@ struct AAIsDeadArgument : public AAIsDeadFloating { AAIsDeadArgument(const IRPosition &IRP, Attributor &A) : AAIsDeadFloating(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AAIsDeadFloating::initialize(A); - if (!A.isFunctionIPOAmendable(*getAnchorScope())) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { Argument &Arg = *getAssociatedArgument(); @@ -4293,8 +4332,10 @@ struct AAIsDeadCallSiteArgument : public AAIsDeadValueImpl { if (!Arg) return indicatePessimisticFixpoint(); const IRPosition &ArgPos = IRPosition::argument(*Arg); - auto &ArgAA = A.getAAFor<AAIsDead>(*this, ArgPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), ArgAA.getState()); + auto *ArgAA = A.getAAFor<AAIsDead>(*this, ArgPos, DepClassTy::REQUIRED); + if (!ArgAA) + return indicatePessimisticFixpoint(); + return clampStateAndIndicateChange(getState(), ArgAA->getState()); } /// See AbstractAttribute::manifest(...). @@ -4355,7 +4396,7 @@ struct AAIsDeadCallSiteReturned : public AAIsDeadFloating { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return isAssumedDead() ? "assumed-dead" : (getAssumed() ? "assumed-dead-users" : "assumed-live"); @@ -4416,10 +4457,7 @@ struct AAIsDeadFunction : public AAIsDead { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { Function *F = getAnchorScope(); - if (!F || F->isDeclaration() || !A.isRunOn(*F)) { - indicatePessimisticFixpoint(); - return; - } + assert(F && "Did expect an anchor function"); if (!isAssumedDeadInternalFunction(A)) { ToBeExploredFrom.insert(&F->getEntryBlock().front()); assumeLive(A, F->getEntryBlock()); @@ -4435,7 +4473,7 @@ struct AAIsDeadFunction : public AAIsDead { } /// See AbstractAttribute::getAsStr(). 
- const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return "Live[#BB " + std::to_string(AssumedLiveBlocks.size()) + "/" + std::to_string(getAnchorScope()->size()) + "][#TBEP " + std::to_string(ToBeExploredFrom.size()) + "][#KDE " + @@ -4465,9 +4503,10 @@ struct AAIsDeadFunction : public AAIsDead { auto *CB = dyn_cast<CallBase>(DeadEndI); if (!CB) continue; - const auto &NoReturnAA = A.getAndUpdateAAFor<AANoReturn>( - *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL); - bool MayReturn = !NoReturnAA.isAssumedNoReturn(); + bool IsKnownNoReturn; + bool MayReturn = !AA::hasAssumedIRAttr<Attribute::NoReturn>( + A, this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL, + IsKnownNoReturn); if (MayReturn && (!Invoke2CallAllowed || !isa<InvokeInst>(CB))) continue; @@ -4564,7 +4603,7 @@ struct AAIsDeadFunction : public AAIsDead { // functions. It can however cause dead functions to be treated as live. for (const Instruction &I : BB) if (const auto *CB = dyn_cast<CallBase>(&I)) - if (const Function *F = CB->getCalledFunction()) + if (auto *F = dyn_cast_if_present<Function>(CB->getCalledOperand())) if (F->hasLocalLinkage()) A.markLiveInternalFunction(*F); return true; @@ -4590,10 +4629,10 @@ identifyAliveSuccessors(Attributor &A, const CallBase &CB, SmallVectorImpl<const Instruction *> &AliveSuccessors) { const IRPosition &IPos = IRPosition::callsite_function(CB); - const auto &NoReturnAA = - A.getAndUpdateAAFor<AANoReturn>(AA, IPos, DepClassTy::OPTIONAL); - if (NoReturnAA.isAssumedNoReturn()) - return !NoReturnAA.isKnownNoReturn(); + bool IsKnownNoReturn; + if (AA::hasAssumedIRAttr<Attribute::NoReturn>( + A, &AA, IPos, DepClassTy::OPTIONAL, IsKnownNoReturn)) + return !IsKnownNoReturn; if (CB.isTerminator()) AliveSuccessors.push_back(&CB.getSuccessor(0)->front()); else @@ -4615,10 +4654,11 @@ identifyAliveSuccessors(Attributor &A, const InvokeInst &II, AliveSuccessors.push_back(&II.getUnwindDest()->front()); } else { const IRPosition &IPos = IRPosition::callsite_function(II); - const auto &AANoUnw = - A.getAndUpdateAAFor<AANoUnwind>(AA, IPos, DepClassTy::OPTIONAL); - if (AANoUnw.isAssumedNoUnwind()) { - UsedAssumedInformation |= !AANoUnw.isKnownNoUnwind(); + + bool IsKnownNoUnwind; + if (AA::hasAssumedIRAttr<Attribute::NoUnwind>( + A, &AA, IPos, DepClassTy::OPTIONAL, IsKnownNoUnwind)) { + UsedAssumedInformation |= !IsKnownNoUnwind; } else { AliveSuccessors.push_back(&II.getUnwindDest()->front()); } @@ -4829,25 +4869,21 @@ struct AADereferenceableImpl : AADereferenceable { void initialize(Attributor &A) override { Value &V = *getAssociatedValue().stripPointerCasts(); SmallVector<Attribute, 4> Attrs; - getAttrs({Attribute::Dereferenceable, Attribute::DereferenceableOrNull}, - Attrs, /* IgnoreSubsumingPositions */ false, &A); + A.getAttrs(getIRPosition(), + {Attribute::Dereferenceable, Attribute::DereferenceableOrNull}, + Attrs, /* IgnoreSubsumingPositions */ false); for (const Attribute &Attr : Attrs) takeKnownDerefBytesMaximum(Attr.getValueAsInt()); - const IRPosition &IRP = this->getIRPosition(); - NonNullAA = &A.getAAFor<AANonNull>(*this, IRP, DepClassTy::NONE); + // Ensure we initialize the non-null AA (if necessary). 
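A simplified model of the invoke handling in identifyAliveSuccessors above, assuming the noreturn/nounwind facts have already been queried; the real code also accounts for callees that may catch asynchronous exceptions and for UB-raising calls:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instructions.h"

static void aliveInvokeSuccessors(
    const llvm::InvokeInst &II, bool AssumedNoReturn, bool AssumedNoUnwind,
    llvm::SmallVectorImpl<const llvm::BasicBlock *> &Alive) {
  if (!AssumedNoReturn)   // the callee may return normally
    Alive.push_back(II.getNormalDest());
  if (!AssumedNoUnwind)   // the callee may unwind into the landing pad
    Alive.push_back(II.getUnwindDest());
}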
+ bool IsKnownNonNull; + AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnownNonNull); bool CanBeNull, CanBeFreed; takeKnownDerefBytesMaximum(V.getPointerDereferenceableBytes( A.getDataLayout(), CanBeNull, CanBeFreed)); - bool IsFnInterface = IRP.isFnInterfaceKind(); - Function *FnScope = IRP.getAnchorScope(); - if (IsFnInterface && (!FnScope || !A.isFunctionIPOAmendable(*FnScope))) { - indicatePessimisticFixpoint(); - return; - } - if (Instruction *CtxI = getCtxI()) followUsesInMBEC(*this, A, getState(), *CtxI); } @@ -4894,17 +4930,24 @@ struct AADereferenceableImpl : AADereferenceable { /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { ChangeStatus Change = AADereferenceable::manifest(A); - if (isAssumedNonNull() && hasAttr(Attribute::DereferenceableOrNull)) { - removeAttrs({Attribute::DereferenceableOrNull}); + bool IsKnownNonNull; + bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, getIRPosition(), DepClassTy::NONE, IsKnownNonNull); + if (IsAssumedNonNull && + A.hasAttr(getIRPosition(), Attribute::DereferenceableOrNull)) { + A.removeAttrs(getIRPosition(), {Attribute::DereferenceableOrNull}); return ChangeStatus::CHANGED; } return Change; } - void getDeducedAttributes(LLVMContext &Ctx, + void getDeducedAttributes(Attributor &A, LLVMContext &Ctx, SmallVectorImpl<Attribute> &Attrs) const override { // TODO: Add *_globally support - if (isAssumedNonNull()) + bool IsKnownNonNull; + bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, getIRPosition(), DepClassTy::NONE, IsKnownNonNull); + if (IsAssumedNonNull) Attrs.emplace_back(Attribute::getWithDereferenceableBytes( Ctx, getAssumedDereferenceableBytes())); else @@ -4913,14 +4956,20 @@ struct AADereferenceableImpl : AADereferenceable { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { if (!getAssumedDereferenceableBytes()) return "unknown-dereferenceable"; + bool IsKnownNonNull; + bool IsAssumedNonNull = false; + if (A) + IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( + *A, this, getIRPosition(), DepClassTy::NONE, IsKnownNonNull); return std::string("dereferenceable") + - (isAssumedNonNull() ? "" : "_or_null") + + (IsAssumedNonNull ? "" : "_or_null") + (isAssumedGlobal() ? "_globally" : "") + "<" + std::to_string(getKnownDereferenceableBytes()) + "-" + - std::to_string(getAssumedDereferenceableBytes()) + ">"; + std::to_string(getAssumedDereferenceableBytes()) + ">" + + (!A ? " [non-null is unknown]" : ""); } }; @@ -4931,7 +4980,6 @@ struct AADereferenceableFloating : AADereferenceableImpl { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - bool Stripped; bool UsedAssumedInformation = false; SmallVector<AA::ValueAndContext> Values; @@ -4955,10 +5003,10 @@ struct AADereferenceableFloating : AADereferenceableImpl { A, *this, &V, DL, Offset, /* GetMinOffset */ false, /* AllowNonInbounds */ true); - const auto &AA = A.getAAFor<AADereferenceable>( + const auto *AA = A.getAAFor<AADereferenceable>( *this, IRPosition::value(*Base), DepClassTy::REQUIRED); int64_t DerefBytes = 0; - if (!Stripped && this == &AA) { + if (!AA || (!Stripped && this == AA)) { // Use IR information if we did not strip anything. // TODO: track globally. 
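The manifest and getDeducedAttributes changes above encode one rule: the deduced byte count is emitted as dereferenceable(N) only when the pointer is also assumed non-null, and as dereferenceable_or_null(N) otherwise. A minimal sketch of that choice, with the byte count and the non-null flag supplied by the surrounding analysis:

#include "llvm/IR/Attributes.h"

static llvm::Attribute deducedDerefAttr(llvm::LLVMContext &Ctx,
                                        uint64_t AssumedBytes,
                                        bool AssumedNonNull) {
  return AssumedNonNull
             ? llvm::Attribute::getWithDereferenceableBytes(Ctx, AssumedBytes)
             : llvm::Attribute::getWithDereferenceableOrNullBytes(Ctx, AssumedBytes);
}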
bool CanBeNull, CanBeFreed; @@ -4966,7 +5014,7 @@ struct AADereferenceableFloating : AADereferenceableImpl { Base->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed); T.GlobalState.indicatePessimisticFixpoint(); } else { - const DerefState &DS = AA.getState(); + const DerefState &DS = AA->getState(); DerefBytes = DS.DerefBytesState.getAssumed(); T.GlobalState &= DS.GlobalState; } @@ -4981,7 +5029,7 @@ struct AADereferenceableFloating : AADereferenceableImpl { T.takeAssumedDerefBytesMinimum( std::max(int64_t(0), DerefBytes - OffsetSExt)); - if (this == &AA) { + if (this == AA) { if (!Stripped) { // If nothing was stripped IR information is all we got. T.takeKnownDerefBytesMaximum( @@ -5016,9 +5064,10 @@ struct AADereferenceableFloating : AADereferenceableImpl { /// Dereferenceable attribute for a return value. struct AADereferenceableReturned final : AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl> { + using Base = + AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl>; AADereferenceableReturned(const IRPosition &IRP, Attributor &A) - : AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl>( - IRP, A) {} + : Base(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -5095,8 +5144,9 @@ static unsigned getKnownAlignForUse(Attributor &A, AAAlign &QueryingAA, IRPosition IRP = IRPosition::callsite_argument(*CB, ArgNo); // As long as we only use known information there is no need to track // dependences here. - auto &AlignAA = A.getAAFor<AAAlign>(QueryingAA, IRP, DepClassTy::NONE); - MA = MaybeAlign(AlignAA.getKnownAlign()); + auto *AlignAA = A.getAAFor<AAAlign>(QueryingAA, IRP, DepClassTy::NONE); + if (AlignAA) + MA = MaybeAlign(AlignAA->getKnownAlign()); } const DataLayout &DL = A.getDataLayout(); @@ -5122,7 +5172,7 @@ static unsigned getKnownAlignForUse(Attributor &A, AAAlign &QueryingAA, // gcd(Offset, Alignment) is an alignment. uint32_t gcd = std::gcd(uint32_t(abs((int32_t)Offset)), Alignment); - Alignment = llvm::PowerOf2Floor(gcd); + Alignment = llvm::bit_floor(gcd); } } @@ -5135,20 +5185,13 @@ struct AAAlignImpl : AAAlign { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { SmallVector<Attribute, 4> Attrs; - getAttrs({Attribute::Alignment}, Attrs); + A.getAttrs(getIRPosition(), {Attribute::Alignment}, Attrs); for (const Attribute &Attr : Attrs) takeKnownMaximum(Attr.getValueAsInt()); Value &V = *getAssociatedValue().stripPointerCasts(); takeKnownMaximum(V.getPointerAlignment(A.getDataLayout()).value()); - if (getIRPosition().isFnInterfaceKind() && - (!getAnchorScope() || - !A.isFunctionIPOAmendable(*getAssociatedFunction()))) { - indicatePessimisticFixpoint(); - return; - } - if (Instruction *CtxI = getCtxI()) followUsesInMBEC(*this, A, getState(), *CtxI); } @@ -5193,7 +5236,7 @@ struct AAAlignImpl : AAAlign { // to avoid making the alignment explicit if it did not improve. /// See AbstractAttribute::getDeducedAttributes - void getDeducedAttributes(LLVMContext &Ctx, + void getDeducedAttributes(Attributor &A, LLVMContext &Ctx, SmallVectorImpl<Attribute> &Attrs) const override { if (getAssumedAlign() > 1) Attrs.emplace_back( @@ -5213,7 +5256,7 @@ struct AAAlignImpl : AAAlign { } /// See AbstractAttribute::getAsStr(). 
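The bit_floor(gcd(...)) computation above (it appears again in AAAlignFloating below) derives the alignment of a pointer that sits at a byte offset from a known-aligned base: any common divisor of the offset and the base alignment also divides the resulting address, so gcd(|Offset|, BaseAlign) is a valid alignment, clamped down to a power of two. A plain C++20 restatement with a worked example:

#include <bit>      // std::bit_floor
#include <cstdint>
#include <cstdlib>  // std::abs
#include <numeric>  // std::gcd

// Base 16-aligned, offset 12: gcd(12, 16) = 4, bit_floor(4) = 4,
// so the offset pointer is at least 4-aligned.
static uint32_t alignmentFromOffset(int32_t Offset, uint32_t BaseAlign) {
  uint32_t G = std::gcd(uint32_t(std::abs(Offset)), BaseAlign);
  return std::bit_floor(G);
}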
- const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return "align<" + std::to_string(getKnownAlign().value()) + "-" + std::to_string(getAssumedAlign().value()) + ">"; } @@ -5243,9 +5286,9 @@ struct AAAlignFloating : AAAlignImpl { auto VisitValueCB = [&](Value &V) -> bool { if (isa<UndefValue>(V) || isa<ConstantPointerNull>(V)) return true; - const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V), + const auto *AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V), DepClassTy::REQUIRED); - if (!Stripped && this == &AA) { + if (!AA || (!Stripped && this == AA)) { int64_t Offset; unsigned Alignment = 1; if (const Value *Base = @@ -5258,7 +5301,7 @@ struct AAAlignFloating : AAAlignImpl { uint32_t gcd = std::gcd(uint32_t(abs((int32_t)Offset)), uint32_t(PA.value())); - Alignment = llvm::PowerOf2Floor(gcd); + Alignment = llvm::bit_floor(gcd); } else { Alignment = V.getPointerAlignment(DL).value(); } @@ -5267,7 +5310,7 @@ struct AAAlignFloating : AAAlignImpl { T.indicatePessimisticFixpoint(); } else { // Use abstract attribute information. - const AAAlign::StateType &DS = AA.getState(); + const AAAlign::StateType &DS = AA->getState(); T ^= DS; } return T.isValidState(); @@ -5293,14 +5336,6 @@ struct AAAlignReturned final using Base = AAReturnedFromReturnedValues<AAAlign, AAAlignImpl>; AAAlignReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - Base::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(aligned) } }; @@ -5351,9 +5386,10 @@ struct AAAlignCallSiteArgument final : AAAlignFloating { if (Argument *Arg = getAssociatedArgument()) { // We only take known information from the argument // so we do not need to track a dependence. - const auto &ArgAlignAA = A.getAAFor<AAAlign>( + const auto *ArgAlignAA = A.getAAFor<AAAlign>( *this, IRPosition::argument(*Arg), DepClassTy::NONE); - takeKnownMaximum(ArgAlignAA.getKnownAlign().value()); + if (ArgAlignAA) + takeKnownMaximum(ArgAlignAA->getKnownAlign().value()); } return Changed; } @@ -5369,14 +5405,6 @@ struct AAAlignCallSiteReturned final AAAlignCallSiteReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - Base::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(align); } }; @@ -5389,14 +5417,14 @@ struct AANoReturnImpl : public AANoReturn { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - AANoReturn::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::NoReturn>( + A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown)); + (void)IsKnown; } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? 
"noreturn" : "may-return"; } @@ -5425,17 +5453,6 @@ struct AANoReturnCallSite final : AANoReturnImpl { AANoReturnCallSite(const IRPosition &IRP, Attributor &A) : AANoReturnImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AANoReturnImpl::initialize(A); - if (Function *F = getAssociatedFunction()) { - const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos, DepClassTy::REQUIRED); - if (!FnAA.isAssumedNoReturn()) - indicatePessimisticFixpoint(); - } - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -5444,8 +5461,11 @@ struct AANoReturnCallSite final : AANoReturnImpl { // redirecting requests to the callee argument. Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + bool IsKnownNoReturn; + if (!AA::hasAssumedIRAttr<Attribute::NoReturn>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoReturn)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -5477,6 +5497,15 @@ struct AAInstanceInfoImpl : public AAInstanceInfo { indicateOptimisticFixpoint(); return; } + if (auto *I = dyn_cast<Instruction>(&V)) { + const auto *CI = + A.getInfoCache().getAnalysisResultForFunction<CycleAnalysis>( + *I->getFunction()); + if (mayBeInCycle(CI, I, /* HeaderOnly */ false)) { + indicatePessimisticFixpoint(); + return; + } + } } /// See AbstractAttribute::updateImpl(...). @@ -5495,9 +5524,10 @@ struct AAInstanceInfoImpl : public AAInstanceInfo { if (!Scope) return indicateOptimisticFixpoint(); - auto &NoRecurseAA = A.getAAFor<AANoRecurse>( - *this, IRPosition::function(*Scope), DepClassTy::OPTIONAL); - if (NoRecurseAA.isAssumedNoRecurse()) + bool IsKnownNoRecurse; + if (AA::hasAssumedIRAttr<Attribute::NoRecurse>( + A, this, IRPosition::function(*Scope), DepClassTy::OPTIONAL, + IsKnownNoRecurse)) return Changed; auto UsePred = [&](const Use &U, bool &Follow) { @@ -5514,15 +5544,16 @@ struct AAInstanceInfoImpl : public AAInstanceInfo { if (auto *CB = dyn_cast<CallBase>(UserI)) { // This check is not guaranteeing uniqueness but for now that we cannot // end up with two versions of \p U thinking it was one. - if (!CB->getCalledFunction() || - !CB->getCalledFunction()->hasLocalLinkage()) + auto *Callee = dyn_cast_if_present<Function>(CB->getCalledOperand()); + if (!Callee || !Callee->hasLocalLinkage()) return true; if (!CB->isArgOperand(&U)) return false; - const auto &ArgInstanceInfoAA = A.getAAFor<AAInstanceInfo>( + const auto *ArgInstanceInfoAA = A.getAAFor<AAInstanceInfo>( *this, IRPosition::callsite_argument(*CB, CB->getArgOperandNo(&U)), DepClassTy::OPTIONAL); - if (!ArgInstanceInfoAA.isAssumedUniqueForAnalysis()) + if (!ArgInstanceInfoAA || + !ArgInstanceInfoAA->isAssumedUniqueForAnalysis()) return false; // If this call base might reach the scope again we might forward the // argument back here. This is very conservative. @@ -5554,7 +5585,7 @@ struct AAInstanceInfoImpl : public AAInstanceInfo { } /// See AbstractState::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return isAssumedUniqueForAnalysis() ? 
"<unique [fAa]>" : "<unknown>"; } @@ -5589,9 +5620,11 @@ struct AAInstanceInfoCallSiteArgument final : AAInstanceInfoImpl { if (!Arg) return indicatePessimisticFixpoint(); const IRPosition &ArgPos = IRPosition::argument(*Arg); - auto &ArgAA = + auto *ArgAA = A.getAAFor<AAInstanceInfo>(*this, ArgPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), ArgAA.getState()); + if (!ArgAA) + return indicatePessimisticFixpoint(); + return clampStateAndIndicateChange(getState(), ArgAA->getState()); } }; @@ -5621,6 +5654,95 @@ struct AAInstanceInfoCallSiteReturned final : AAInstanceInfoFloating { } // namespace /// ----------------------- Variable Capturing --------------------------------- +bool AANoCapture::isImpliedByIR(Attributor &A, const IRPosition &IRP, + Attribute::AttrKind ImpliedAttributeKind, + bool IgnoreSubsumingPositions) { + assert(ImpliedAttributeKind == Attribute::NoCapture && + "Unexpected attribute kind"); + Value &V = IRP.getAssociatedValue(); + if (!IRP.isArgumentPosition()) + return V.use_empty(); + + // You cannot "capture" null in the default address space. + if (isa<UndefValue>(V) || (isa<ConstantPointerNull>(V) && + V.getType()->getPointerAddressSpace() == 0)) { + return true; + } + + if (A.hasAttr(IRP, {Attribute::NoCapture}, + /* IgnoreSubsumingPositions */ true, Attribute::NoCapture)) + return true; + + if (IRP.getPositionKind() == IRP_CALL_SITE_ARGUMENT) + if (Argument *Arg = IRP.getAssociatedArgument()) + if (A.hasAttr(IRPosition::argument(*Arg), + {Attribute::NoCapture, Attribute::ByVal}, + /* IgnoreSubsumingPositions */ true)) { + A.manifestAttrs(IRP, + Attribute::get(V.getContext(), Attribute::NoCapture)); + return true; + } + + if (const Function *F = IRP.getAssociatedFunction()) { + // Check what state the associated function can actually capture. + AANoCapture::StateType State; + determineFunctionCaptureCapabilities(IRP, *F, State); + if (State.isKnown(NO_CAPTURE)) { + A.manifestAttrs(IRP, + Attribute::get(V.getContext(), Attribute::NoCapture)); + return true; + } + } + + return false; +} + +/// Set the NOT_CAPTURED_IN_MEM and NOT_CAPTURED_IN_RET bits in \p Known +/// depending on the ability of the function associated with \p IRP to capture +/// state in memory and through "returning/throwing", respectively. +void AANoCapture::determineFunctionCaptureCapabilities(const IRPosition &IRP, + const Function &F, + BitIntegerState &State) { + // TODO: Once we have memory behavior attributes we should use them here. + + // If we know we cannot communicate or write to memory, we do not care about + // ptr2int anymore. + bool ReadOnly = F.onlyReadsMemory(); + bool NoThrow = F.doesNotThrow(); + bool IsVoidReturn = F.getReturnType()->isVoidTy(); + if (ReadOnly && NoThrow && IsVoidReturn) { + State.addKnownBits(NO_CAPTURE); + return; + } + + // A function cannot capture state in memory if it only reads memory, it can + // however return/throw state and the state might be influenced by the + // pointer value, e.g., loading from a returned pointer might reveal a bit. + if (ReadOnly) + State.addKnownBits(NOT_CAPTURED_IN_MEM); + + // A function cannot communicate state back if it does not through + // exceptions and doesn not return values. + if (NoThrow && IsVoidReturn) + State.addKnownBits(NOT_CAPTURED_IN_RET); + + // Check existing "returned" attributes. 
+ int ArgNo = IRP.getCalleeArgNo(); + if (!NoThrow || ArgNo < 0 || + !F.getAttributes().hasAttrSomewhere(Attribute::Returned)) + return; + + for (unsigned U = 0, E = F.arg_size(); U < E; ++U) + if (F.hasParamAttribute(U, Attribute::Returned)) { + if (U == unsigned(ArgNo)) + State.removeAssumedBits(NOT_CAPTURED_IN_RET); + else if (ReadOnly) + State.addKnownBits(NO_CAPTURE); + else + State.addKnownBits(NOT_CAPTURED_IN_RET); + break; + } +} namespace { /// A class to hold the state of for no-capture attributes. @@ -5629,39 +5751,17 @@ struct AANoCaptureImpl : public AANoCapture { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - if (hasAttr(getAttrKind(), /* IgnoreSubsumingPositions */ true)) { - indicateOptimisticFixpoint(); - return; - } - Function *AnchorScope = getAnchorScope(); - if (isFnInterfaceKind() && - (!AnchorScope || !A.isFunctionIPOAmendable(*AnchorScope))) { - indicatePessimisticFixpoint(); - return; - } - - // You cannot "capture" null in the default address space. - if (isa<ConstantPointerNull>(getAssociatedValue()) && - getAssociatedValue().getType()->getPointerAddressSpace() == 0) { - indicateOptimisticFixpoint(); - return; - } - - const Function *F = - isArgumentPosition() ? getAssociatedFunction() : AnchorScope; - - // Check what state the associated function can actually capture. - if (F) - determineFunctionCaptureCapabilities(getIRPosition(), *F, *this); - else - indicatePessimisticFixpoint(); + bool IsKnown; + assert(!AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, nullptr, getIRPosition(), DepClassTy::NONE, IsKnown)); + (void)IsKnown; } /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override; /// see AbstractAttribute::isAssumedNoCaptureMaybeReturned(...). - void getDeducedAttributes(LLVMContext &Ctx, + void getDeducedAttributes(Attributor &A, LLVMContext &Ctx, SmallVectorImpl<Attribute> &Attrs) const override { if (!isAssumedNoCaptureMaybeReturned()) return; @@ -5674,51 +5774,8 @@ struct AANoCaptureImpl : public AANoCapture { } } - /// Set the NOT_CAPTURED_IN_MEM and NOT_CAPTURED_IN_RET bits in \p Known - /// depending on the ability of the function associated with \p IRP to capture - /// state in memory and through "returning/throwing", respectively. - static void determineFunctionCaptureCapabilities(const IRPosition &IRP, - const Function &F, - BitIntegerState &State) { - // TODO: Once we have memory behavior attributes we should use them here. - - // If we know we cannot communicate or write to memory, we do not care about - // ptr2int anymore. - if (F.onlyReadsMemory() && F.doesNotThrow() && - F.getReturnType()->isVoidTy()) { - State.addKnownBits(NO_CAPTURE); - return; - } - - // A function cannot capture state in memory if it only reads memory, it can - // however return/throw state and the state might be influenced by the - // pointer value, e.g., loading from a returned pointer might reveal a bit. - if (F.onlyReadsMemory()) - State.addKnownBits(NOT_CAPTURED_IN_MEM); - - // A function cannot communicate state back if it does not through - // exceptions and doesn not return values. - if (F.doesNotThrow() && F.getReturnType()->isVoidTy()) - State.addKnownBits(NOT_CAPTURED_IN_RET); - - // Check existing "returned" attributes. 
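determineFunctionCaptureCapabilities above boils down to two observations: a callee can only capture a pointer argument by writing it to memory or by handing it back through a return value or an exception. A compact model of that mapping, with bit names mirroring the state above; the struct and helper are illustrative only, and the special handling of "returned" arguments shown in the surrounding hunks is omitted:

struct CaptureCaps {
  bool NotCapturedInMem = false;
  bool NotCapturedInRet = false;
  bool noCapture() const { return NotCapturedInMem && NotCapturedInRet; }
};

static CaptureCaps captureCapabilities(bool ReadOnly, bool NoThrow,
                                       bool IsVoidReturn) {
  CaptureCaps C;
  C.NotCapturedInMem = ReadOnly;                 // cannot stash the pointer in memory
  C.NotCapturedInRet = NoThrow && IsVoidReturn;  // cannot pass it back to the caller
  return C;
}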
- int ArgNo = IRP.getCalleeArgNo(); - if (F.doesNotThrow() && ArgNo >= 0) { - for (unsigned u = 0, e = F.arg_size(); u < e; ++u) - if (F.hasParamAttribute(u, Attribute::Returned)) { - if (u == unsigned(ArgNo)) - State.removeAssumedBits(NOT_CAPTURED_IN_RET); - else if (F.onlyReadsMemory()) - State.addKnownBits(NO_CAPTURE); - else - State.addKnownBits(NOT_CAPTURED_IN_RET); - break; - } - } - } - /// See AbstractState::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { if (isKnownNoCapture()) return "known not-captured"; if (isAssumedNoCapture()) @@ -5771,12 +5828,15 @@ struct AANoCaptureImpl : public AANoCapture { const IRPosition &CSArgPos = IRPosition::callsite_argument(*CB, ArgNo); // If we have a abstract no-capture attribute for the argument we can use // it to justify a non-capture attribute here. This allows recursion! - auto &ArgNoCaptureAA = - A.getAAFor<AANoCapture>(*this, CSArgPos, DepClassTy::REQUIRED); - if (ArgNoCaptureAA.isAssumedNoCapture()) + bool IsKnownNoCapture; + const AANoCapture *ArgNoCaptureAA = nullptr; + bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, CSArgPos, DepClassTy::REQUIRED, IsKnownNoCapture, false, + &ArgNoCaptureAA); + if (IsAssumedNoCapture) return isCapturedIn(State, /* Memory */ false, /* Integer */ false, /* Return */ false); - if (ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) { + if (ArgNoCaptureAA && ArgNoCaptureAA->isAssumedNoCaptureMaybeReturned()) { Follow = true; return isCapturedIn(State, /* Memory */ false, /* Integer */ false, /* Return */ false); @@ -5830,37 +5890,35 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { // TODO: we could do this in a more sophisticated way inside // AAReturnedValues, e.g., track all values that escape through returns // directly somehow. - auto CheckReturnedArgs = [&](const AAReturnedValues &RVAA) { - if (!RVAA.getState().isValidState()) + auto CheckReturnedArgs = [&](bool &UsedAssumedInformation) { + SmallVector<AA::ValueAndContext> Values; + if (!A.getAssumedSimplifiedValues(IRPosition::returned(*F), this, Values, + AA::ValueScope::Intraprocedural, + UsedAssumedInformation)) return false; bool SeenConstant = false; - for (const auto &It : RVAA.returned_values()) { - if (isa<Constant>(It.first)) { + for (const AA::ValueAndContext &VAC : Values) { + if (isa<Constant>(VAC.getValue())) { if (SeenConstant) return false; SeenConstant = true; - } else if (!isa<Argument>(It.first) || - It.first == getAssociatedArgument()) + } else if (!isa<Argument>(VAC.getValue()) || + VAC.getValue() == getAssociatedArgument()) return false; } return true; }; - const auto &NoUnwindAA = - A.getAAFor<AANoUnwind>(*this, FnPos, DepClassTy::OPTIONAL); - if (NoUnwindAA.isAssumedNoUnwind()) { + bool IsKnownNoUnwind; + if (AA::hasAssumedIRAttr<Attribute::NoUnwind>( + A, this, FnPos, DepClassTy::OPTIONAL, IsKnownNoUnwind)) { bool IsVoidTy = F->getReturnType()->isVoidTy(); - const AAReturnedValues *RVAA = - IsVoidTy ? 
nullptr - : &A.getAAFor<AAReturnedValues>(*this, FnPos, - - DepClassTy::OPTIONAL); - if (IsVoidTy || CheckReturnedArgs(*RVAA)) { + bool UsedAssumedInformation = false; + if (IsVoidTy || CheckReturnedArgs(UsedAssumedInformation)) { T.addKnownBits(NOT_CAPTURED_IN_RET); if (T.isKnown(NOT_CAPTURED_IN_MEM)) return ChangeStatus::UNCHANGED; - if (NoUnwindAA.isKnownNoUnwind() && - (IsVoidTy || RVAA->getState().isAtFixpoint())) { + if (IsKnownNoUnwind && (IsVoidTy || !UsedAssumedInformation)) { addKnownBits(NOT_CAPTURED_IN_RET); if (isKnown(NOT_CAPTURED_IN_MEM)) return indicateOptimisticFixpoint(); @@ -5869,9 +5927,9 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { } auto IsDereferenceableOrNull = [&](Value *O, const DataLayout &DL) { - const auto &DerefAA = A.getAAFor<AADereferenceable>( + const auto *DerefAA = A.getAAFor<AADereferenceable>( *this, IRPosition::value(*O), DepClassTy::OPTIONAL); - return DerefAA.getAssumedDereferenceableBytes(); + return DerefAA && DerefAA->getAssumedDereferenceableBytes(); }; auto UseCheck = [&](const Use &U, bool &Follow) -> bool { @@ -5913,14 +5971,6 @@ struct AANoCaptureCallSiteArgument final : AANoCaptureImpl { AANoCaptureCallSiteArgument(const IRPosition &IRP, Attributor &A) : AANoCaptureImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - if (Argument *Arg = getAssociatedArgument()) - if (Arg->hasByValAttr()) - indicateOptimisticFixpoint(); - AANoCaptureImpl::initialize(A); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -5931,8 +5981,15 @@ struct AANoCaptureCallSiteArgument final : AANoCaptureImpl { if (!Arg) return indicatePessimisticFixpoint(); const IRPosition &ArgPos = IRPosition::argument(*Arg); - auto &ArgAA = A.getAAFor<AANoCapture>(*this, ArgPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), ArgAA.getState()); + bool IsKnownNoCapture; + const AANoCapture *ArgAA = nullptr; + if (AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, ArgPos, DepClassTy::REQUIRED, IsKnownNoCapture, false, + &ArgAA)) + return ChangeStatus::UNCHANGED; + if (!ArgAA || !ArgAA->isAssumedNoCaptureMaybeReturned()) + return indicatePessimisticFixpoint(); + return clampStateAndIndicateChange(getState(), ArgAA->getState()); } /// See AbstractAttribute::trackStatistics() @@ -6023,7 +6080,7 @@ struct AAValueSimplifyImpl : AAValueSimplify { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { LLVM_DEBUG({ dbgs() << "SAV: " << (bool)SimplifiedAssociatedValue << " "; if (SimplifiedAssociatedValue && *SimplifiedAssociatedValue) @@ -6156,19 +6213,21 @@ struct AAValueSimplifyImpl : AAValueSimplify { return false; // This will also pass the call base context. 
- const auto &AA = + const auto *AA = A.getAAFor<AAType>(*this, getIRPosition(), DepClassTy::NONE); + if (!AA) + return false; - std::optional<Constant *> COpt = AA.getAssumedConstant(A); + std::optional<Constant *> COpt = AA->getAssumedConstant(A); if (!COpt) { SimplifiedAssociatedValue = std::nullopt; - A.recordDependence(AA, *this, DepClassTy::OPTIONAL); + A.recordDependence(*AA, *this, DepClassTy::OPTIONAL); return true; } if (auto *C = *COpt) { SimplifiedAssociatedValue = C; - A.recordDependence(AA, *this, DepClassTy::OPTIONAL); + A.recordDependence(*AA, *this, DepClassTy::OPTIONAL); return true; } return false; @@ -6215,11 +6274,10 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { void initialize(Attributor &A) override { AAValueSimplifyImpl::initialize(A); - if (!getAnchorScope() || getAnchorScope()->isDeclaration()) - indicatePessimisticFixpoint(); - if (hasAttr({Attribute::InAlloca, Attribute::Preallocated, - Attribute::StructRet, Attribute::Nest, Attribute::ByVal}, - /* IgnoreSubsumingPositions */ true)) + if (A.hasAttr(getIRPosition(), + {Attribute::InAlloca, Attribute::Preallocated, + Attribute::StructRet, Attribute::Nest, Attribute::ByVal}, + /* IgnoreSubsumingPositions */ true)) indicatePessimisticFixpoint(); } @@ -6266,7 +6324,7 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { bool Success; bool UsedAssumedInformation = false; if (hasCallBaseContext() && - getCallBaseContext()->getCalledFunction() == Arg->getParent()) + getCallBaseContext()->getCalledOperand() == Arg->getParent()) Success = PredForCallSite( AbstractCallSite(&getCallBaseContext()->getCalledOperandUse())); else @@ -6401,10 +6459,7 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl { void initialize(Attributor &A) override { AAValueSimplifyImpl::initialize(A); Function *Fn = getAssociatedFunction(); - if (!Fn) { - indicatePessimisticFixpoint(); - return; - } + assert(Fn && "Did expect an associted function"); for (Argument &Arg : Fn->args()) { if (Arg.hasReturnedAttr()) { auto IRP = IRPosition::callsite_argument(*cast<CallBase>(getCtxI()), @@ -6421,26 +6476,7 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - auto Before = SimplifiedAssociatedValue; - auto &RetAA = A.getAAFor<AAReturnedValues>( - *this, IRPosition::function(*getAssociatedFunction()), - DepClassTy::REQUIRED); - auto PredForReturned = - [&](Value &RetVal, const SmallSetVector<ReturnInst *, 4> &RetInsts) { - bool UsedAssumedInformation = false; - std::optional<Value *> CSRetVal = - A.translateArgumentToCallSiteContent( - &RetVal, *cast<CallBase>(getCtxI()), *this, - UsedAssumedInformation); - SimplifiedAssociatedValue = AA::combineOptionalValuesInAAValueLatice( - SimplifiedAssociatedValue, CSRetVal, getAssociatedType()); - return SimplifiedAssociatedValue != std::optional<Value *>(nullptr); - }; - if (!RetAA.checkForAllReturnedValuesAndReturnInsts(PredForReturned)) - if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); - return Before == SimplifiedAssociatedValue ? 
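The queries above rely on a three-state lattice for SimplifiedAssociatedValue. A sketch of those conventions and of the meet operation, assuming plain std::optional semantics (the real helper is AA::combineOptionalValuesInAAValueLatice, which additionally reconciles the value types):

#include <optional>

// std::nullopt -> nothing decided yet (combines with anything),
// nullptr      -> known to be unsimplifiable,
// a Value*     -> simplified to exactly that value.
template <typename ValueT>
static std::optional<ValueT *> combineSimplified(std::optional<ValueT *> LHS,
                                                 std::optional<ValueT *> RHS) {
  if (!LHS) return RHS;          // LHS still allows anything, take RHS
  if (!RHS) return LHS;          // RHS still allows anything, take LHS
  if (*LHS == *RHS) return LHS;  // both agree (possibly on "unsimplifiable")
  return nullptr;                // disagreement, give up on simplification
}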
ChangeStatus::UNCHANGED - : ChangeStatus ::CHANGED; } void trackStatistics() const override { @@ -6581,7 +6617,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { SCB); } - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { unsigned NumH2SMallocs = 0, NumInvalidMallocs = 0; for (const auto &It : AllocationInfos) { if (It.second->Status == AllocationInfo::INVALID) @@ -6773,10 +6809,10 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { const Function *F = getAnchorScope(); const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F); - const auto &LivenessAA = + const auto *LivenessAA = A.getAAFor<AAIsDead>(*this, IRPosition::function(*F), DepClassTy::NONE); - MustBeExecutedContextExplorer &Explorer = + MustBeExecutedContextExplorer *Explorer = A.getInfoCache().getMustBeExecutedContextExplorer(); bool StackIsAccessibleByOtherThreads = @@ -6813,7 +6849,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { // No need to analyze dead calls, ignore them instead. bool UsedAssumedInformation = false; - if (A.isAssumedDead(*DI.CB, this, &LivenessAA, UsedAssumedInformation, + if (A.isAssumedDead(*DI.CB, this, LivenessAA, UsedAssumedInformation, /* CheckBBLivenessOnly */ true)) continue; @@ -6855,9 +6891,9 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { // doesn't apply as the pointer could be shared and needs to be places in // "shareable" memory. if (!StackIsAccessibleByOtherThreads) { - auto &NoSyncAA = - A.getAAFor<AANoSync>(*this, getIRPosition(), DepClassTy::OPTIONAL); - if (!NoSyncAA.isAssumedNoSync()) { + bool IsKnownNoSycn; + if (!AA::hasAssumedIRAttr<Attribute::NoSync>( + A, this, getIRPosition(), DepClassTy::OPTIONAL, IsKnownNoSycn)) { LLVM_DEBUG( dbgs() << "[H2S] found an escaping use, stack is not accessible by " "other threads and function is not nosync:\n"); @@ -6902,7 +6938,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { return false; } Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode(); - if (!Explorer.findInContextOf(UniqueFree, CtxI)) { + if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) { LLVM_DEBUG( dbgs() << "[H2S] unique free call might not be executed with the allocation " @@ -6938,22 +6974,21 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { } unsigned ArgNo = CB->getArgOperandNo(&U); + auto CBIRP = IRPosition::callsite_argument(*CB, ArgNo); - const auto &NoCaptureAA = A.getAAFor<AANoCapture>( - *this, IRPosition::callsite_argument(*CB, ArgNo), - DepClassTy::OPTIONAL); + bool IsKnownNoCapture; + bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, CBIRP, DepClassTy::OPTIONAL, IsKnownNoCapture); // If a call site argument use is nofree, we are fine. - const auto &ArgNoFreeAA = A.getAAFor<AANoFree>( - *this, IRPosition::callsite_argument(*CB, ArgNo), - DepClassTy::OPTIONAL); + bool IsKnownNoFree; + bool IsAssumedNoFree = AA::hasAssumedIRAttr<Attribute::NoFree>( + A, this, CBIRP, DepClassTy::OPTIONAL, IsKnownNoFree); - bool MaybeCaptured = !NoCaptureAA.isAssumedNoCapture(); - bool MaybeFreed = !ArgNoFreeAA.isAssumedNoFree(); - if (MaybeCaptured || + if (!IsAssumedNoCapture || (AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared && - MaybeFreed)) { - AI.HasPotentiallyFreeingUnknownUses |= MaybeFreed; + !IsAssumedNoFree)) { + AI.HasPotentiallyFreeingUnknownUses |= !IsAssumedNoFree; // Emit a missed remark if this is missed OpenMP globalization. 
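The nocapture/nofree pair queried above implements a simple rule for passing the allocation to another call. A paraphrase under the assumption that the two facts were already obtained; __kmpc_alloc_shared is treated specially because its matching deallocation call is rewritten together with the allocation:

static bool callUseIsHarmless(bool AssumedNoCapture, bool AssumedNoFree,
                              bool IsKmpcAllocShared) {
  if (!AssumedNoCapture)  // the callee may stash the pointer, cannot move to stack
    return false;
  // A potential free is only tolerated for the __kmpc_alloc_shared protocol.
  return AssumedNoFree || IsKmpcAllocShared;
}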
auto Remark = [&](OptimizationRemarkMissed ORM) { @@ -6984,7 +7019,14 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { ValidUsesOnly = false; return true; }; - if (!A.checkForAllUses(Pred, *this, *AI.CB)) + if (!A.checkForAllUses(Pred, *this, *AI.CB, /* CheckBBLivenessOnly */ false, + DepClassTy::OPTIONAL, /* IgnoreDroppableUses */ true, + [&](const Use &OldU, const Use &NewU) { + auto *SI = dyn_cast<StoreInst>(OldU.getUser()); + return !SI || StackIsAccessibleByOtherThreads || + AA::isAssumedThreadLocalObject( + A, *SI->getPointerOperand(), *this); + })) return false; return ValidUsesOnly; }; @@ -7018,7 +7060,8 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { } std::optional<APInt> Size = getSize(A, *this, AI); - if (MaxHeapToStackSize != -1) { + if (AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared && + MaxHeapToStackSize != -1) { if (!Size || Size->ugt(MaxHeapToStackSize)) { LLVM_DEBUG({ if (!Size) @@ -7078,7 +7121,8 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { } /// Identify the type we can chose for a private copy of the underlying - /// argument. None means it is not clear yet, nullptr means there is none. + /// argument. std::nullopt means it is not clear yet, nullptr means there is + /// none. virtual std::optional<Type *> identifyPrivatizableType(Attributor &A) = 0; /// Return a privatizable type that encloses both T0 and T1. @@ -7098,7 +7142,7 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { return PrivatizableType; } - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return isAssumedPrivatizablePtr() ? "[priv]" : "[no-priv]"; } @@ -7118,7 +7162,8 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // rewrite them), there is no need to check them explicitly. bool UsedAssumedInformation = false; SmallVector<Attribute, 1> Attrs; - getAttrs({Attribute::ByVal}, Attrs, /* IgnoreSubsumingPositions */ true); + A.getAttrs(getIRPosition(), {Attribute::ByVal}, Attrs, + /* IgnoreSubsumingPositions */ true); if (!Attrs.empty() && A.checkForAllCallSites([](AbstractCallSite ACS) { return true; }, *this, true, UsedAssumedInformation)) @@ -7141,9 +7186,11 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { return false; // Check that all call sites agree on a type. - auto &PrivCSArgAA = + auto *PrivCSArgAA = A.getAAFor<AAPrivatizablePtr>(*this, ACSArgPos, DepClassTy::REQUIRED); - std::optional<Type *> CSTy = PrivCSArgAA.getPrivatizableType(); + if (!PrivCSArgAA) + return false; + std::optional<Type *> CSTy = PrivCSArgAA->getPrivatizableType(); LLVM_DEBUG({ dbgs() << "[AAPrivatizablePtr] ACSPos: " << ACSArgPos << ", CSTy: "; @@ -7191,7 +7238,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { DepClassTy::OPTIONAL); // Avoid arguments with padding for now. 
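The size gate above, restated: the allocation must have a known constant size that does not exceed the MaxHeapToStackSize threshold (unsigned comparison), with -1 disabling the limit; __kmpc_alloc_shared allocations are exempt. A minimal sketch of the numeric part, with names taken from the hunk:

#include "llvm/ADT/APInt.h"
#include <optional>

static bool sizeAllowsHeapToStack(const std::optional<llvm::APInt> &Size,
                                  int64_t MaxHeapToStackSize) {
  if (MaxHeapToStackSize == -1)  // limit disabled
    return true;
  // Unknown (non-constant) sizes are rejected; known sizes use an unsigned
  // comparison against the threshold.
  return Size && !Size->ugt(MaxHeapToStackSize);
}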
- if (!getIRPosition().hasAttr(Attribute::ByVal) && + if (!A.hasAttr(getIRPosition(), Attribute::ByVal) && !isDenselyPacked(*PrivatizableType, A.getInfoCache().getDL())) { LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] Padding detected\n"); return indicatePessimisticFixpoint(); @@ -7216,7 +7263,9 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { auto CallSiteCheck = [&](AbstractCallSite ACS) { CallBase *CB = ACS.getInstruction(); return TTI->areTypesABICompatible( - CB->getCaller(), CB->getCalledFunction(), ReplacementTypes); + CB->getCaller(), + dyn_cast_if_present<Function>(CB->getCalledOperand()), + ReplacementTypes); }; bool UsedAssumedInformation = false; if (!A.checkForAllCallSites(CallSiteCheck, *this, true, @@ -7264,10 +7313,10 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { if (CBArgNo != int(ArgNo)) continue; - const auto &CBArgPrivAA = A.getAAFor<AAPrivatizablePtr>( + const auto *CBArgPrivAA = A.getAAFor<AAPrivatizablePtr>( *this, IRPosition::argument(CBArg), DepClassTy::REQUIRED); - if (CBArgPrivAA.isValidState()) { - auto CBArgPrivTy = CBArgPrivAA.getPrivatizableType(); + if (CBArgPrivAA && CBArgPrivAA->isValidState()) { + auto CBArgPrivTy = CBArgPrivAA->getPrivatizableType(); if (!CBArgPrivTy) continue; if (*CBArgPrivTy == PrivatizableType) @@ -7298,23 +7347,23 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { assert(DCArgNo >= 0 && unsigned(DCArgNo) < DC->arg_size() && "Expected a direct call operand for callback call operand"); + Function *DCCallee = + dyn_cast_if_present<Function>(DC->getCalledOperand()); LLVM_DEBUG({ dbgs() << "[AAPrivatizablePtr] Argument " << *Arg << " check if be privatized in the context of its parent (" << Arg->getParent()->getName() << ")\n[AAPrivatizablePtr] because it is an argument in a " "direct call of (" - << DCArgNo << "@" << DC->getCalledFunction()->getName() - << ").\n"; + << DCArgNo << "@" << DCCallee->getName() << ").\n"; }); - Function *DCCallee = DC->getCalledFunction(); if (unsigned(DCArgNo) < DCCallee->arg_size()) { - const auto &DCArgPrivAA = A.getAAFor<AAPrivatizablePtr>( + const auto *DCArgPrivAA = A.getAAFor<AAPrivatizablePtr>( *this, IRPosition::argument(*DCCallee->getArg(DCArgNo)), DepClassTy::REQUIRED); - if (DCArgPrivAA.isValidState()) { - auto DCArgPrivTy = DCArgPrivAA.getPrivatizableType(); + if (DCArgPrivAA && DCArgPrivAA->isValidState()) { + auto DCArgPrivTy = DCArgPrivAA->getPrivatizableType(); if (!DCArgPrivTy) return true; if (*DCArgPrivTy == PrivatizableType) @@ -7328,7 +7377,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { << Arg->getParent()->getName() << ")\n[AAPrivatizablePtr] because it is an argument in a " "direct call of (" - << ACS.getInstruction()->getCalledFunction()->getName() + << ACS.getInstruction()->getCalledOperand()->getName() << ").\n[AAPrivatizablePtr] for which the argument " "privatization is not compatible.\n"; }); @@ -7479,7 +7528,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { Argument *Arg = getAssociatedArgument(); // Query AAAlign attribute for alignment of associated argument to // determine the best alignment of loads. - const auto &AlignAA = + const auto *AlignAA = A.getAAFor<AAAlign>(*this, IRPosition::value(*Arg), DepClassTy::NONE); // Callback to repair the associated function. 
A new alloca is placed at the @@ -7510,13 +7559,13 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // of the privatizable type are loaded prior to the call and passed to the // new function version. Attributor::ArgumentReplacementInfo::ACSRepairCBTy ACSRepairCB = - [=, &AlignAA](const Attributor::ArgumentReplacementInfo &ARI, - AbstractCallSite ACS, - SmallVectorImpl<Value *> &NewArgOperands) { + [=](const Attributor::ArgumentReplacementInfo &ARI, + AbstractCallSite ACS, SmallVectorImpl<Value *> &NewArgOperands) { // When no alignment is specified for the load instruction, // natural alignment is assumed. createReplacementValues( - AlignAA.getAssumedAlign(), *PrivatizableType, ACS, + AlignAA ? AlignAA->getAssumedAlign() : Align(0), + *PrivatizableType, ACS, ACS.getCallArgOperand(ARI.getReplacedArg().getArgNo()), NewArgOperands); }; @@ -7568,10 +7617,10 @@ struct AAPrivatizablePtrFloating : public AAPrivatizablePtrImpl { if (CI->isOne()) return AI->getAllocatedType(); if (auto *Arg = dyn_cast<Argument>(Obj)) { - auto &PrivArgAA = A.getAAFor<AAPrivatizablePtr>( + auto *PrivArgAA = A.getAAFor<AAPrivatizablePtr>( *this, IRPosition::argument(*Arg), DepClassTy::REQUIRED); - if (PrivArgAA.isAssumedPrivatizablePtr()) - return PrivArgAA.getPrivatizableType(); + if (PrivArgAA && PrivArgAA->isAssumedPrivatizablePtr()) + return PrivArgAA->getPrivatizableType(); } LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] Underlying object neither valid " @@ -7593,7 +7642,7 @@ struct AAPrivatizablePtrCallSiteArgument final /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - if (getIRPosition().hasAttr(Attribute::ByVal)) + if (A.hasAttr(getIRPosition(), Attribute::ByVal)) indicateOptimisticFixpoint(); } @@ -7606,15 +7655,17 @@ struct AAPrivatizablePtrCallSiteArgument final return indicatePessimisticFixpoint(); const IRPosition &IRP = getIRPosition(); - auto &NoCaptureAA = - A.getAAFor<AANoCapture>(*this, IRP, DepClassTy::REQUIRED); - if (!NoCaptureAA.isAssumedNoCapture()) { + bool IsKnownNoCapture; + bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, IRP, DepClassTy::REQUIRED, IsKnownNoCapture); + if (!IsAssumedNoCapture) { LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] pointer might be captured!\n"); return indicatePessimisticFixpoint(); } - auto &NoAliasAA = A.getAAFor<AANoAlias>(*this, IRP, DepClassTy::REQUIRED); - if (!NoAliasAA.isAssumedNoAlias()) { + bool IsKnownNoAlias; + if (!AA::hasAssumedIRAttr<Attribute::NoAlias>( + A, this, IRP, DepClassTy::REQUIRED, IsKnownNoAlias)) { LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] pointer might alias!\n"); return indicatePessimisticFixpoint(); } @@ -7679,16 +7730,16 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { intersectAssumedBits(BEST_STATE); - getKnownStateFromValue(getIRPosition(), getState()); + getKnownStateFromValue(A, getIRPosition(), getState()); AAMemoryBehavior::initialize(A); } /// Return the memory behavior information encoded in the IR for \p IRP. 
- static void getKnownStateFromValue(const IRPosition &IRP, + static void getKnownStateFromValue(Attributor &A, const IRPosition &IRP, BitIntegerState &State, bool IgnoreSubsumingPositions = false) { SmallVector<Attribute, 2> Attrs; - IRP.getAttrs(AttrKinds, Attrs, IgnoreSubsumingPositions); + A.getAttrs(IRP, AttrKinds, Attrs, IgnoreSubsumingPositions); for (const Attribute &Attr : Attrs) { switch (Attr.getKindAsEnum()) { case Attribute::ReadNone: @@ -7714,7 +7765,7 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior { } /// See AbstractAttribute::getDeducedAttributes(...). - void getDeducedAttributes(LLVMContext &Ctx, + void getDeducedAttributes(Attributor &A, LLVMContext &Ctx, SmallVectorImpl<Attribute> &Attrs) const override { assert(Attrs.size() == 0); if (isAssumedReadNone()) @@ -7728,29 +7779,30 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior { /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { - if (hasAttr(Attribute::ReadNone, /* IgnoreSubsumingPositions */ true)) - return ChangeStatus::UNCHANGED; - const IRPosition &IRP = getIRPosition(); + if (A.hasAttr(IRP, Attribute::ReadNone, + /* IgnoreSubsumingPositions */ true)) + return ChangeStatus::UNCHANGED; + // Check if we would improve the existing attributes first. SmallVector<Attribute, 4> DeducedAttrs; - getDeducedAttributes(IRP.getAnchorValue().getContext(), DeducedAttrs); + getDeducedAttributes(A, IRP.getAnchorValue().getContext(), DeducedAttrs); if (llvm::all_of(DeducedAttrs, [&](const Attribute &Attr) { - return IRP.hasAttr(Attr.getKindAsEnum(), - /* IgnoreSubsumingPositions */ true); + return A.hasAttr(IRP, Attr.getKindAsEnum(), + /* IgnoreSubsumingPositions */ true); })) return ChangeStatus::UNCHANGED; // Clear existing attributes. - IRP.removeAttrs(AttrKinds); + A.removeAttrs(IRP, AttrKinds); // Use the generic manifest method. return IRAttribute::manifest(A); } /// See AbstractState::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { if (isAssumedReadNone()) return "readnone"; if (isAssumedReadOnly()) @@ -7807,15 +7859,10 @@ struct AAMemoryBehaviorArgument : AAMemoryBehaviorFloating { // TODO: Make IgnoreSubsumingPositions a property of an IRAttribute so we // can query it when we use has/getAttr. That would allow us to reuse the // initialize of the base class here. - bool HasByVal = - IRP.hasAttr({Attribute::ByVal}, /* IgnoreSubsumingPositions */ true); - getKnownStateFromValue(IRP, getState(), + bool HasByVal = A.hasAttr(IRP, {Attribute::ByVal}, + /* IgnoreSubsumingPositions */ true); + getKnownStateFromValue(A, IRP, getState(), /* IgnoreSubsumingPositions */ HasByVal); - - // Initialize the use vector with all direct uses of the associated value. 
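Another recurring change is that attribute queries and updates move off the position object (IRP.hasAttr, IRP.getAttrs, IRP.removeAttrs) onto the Attributor itself (A.hasAttr(IRP, ...), A.getAttrs(IRP, ...), A.removeAttrs(IRP, ...)), so a single broker mediates every read and write of IR attributes during the fixpoint iteration. A tiny facade sketch of that idea, with invented Position/AttrBroker names; it is not the real Attributor interface.

#include <iostream>
#include <set>
#include <string>
#include <unordered_map>

// Invented stand-ins: a "position" is just a string key here.
using Position = std::string;
using AttrKind = std::string;

// Broker that owns all attribute state, so additions and removals made while
// the solver runs are visible to every later query through the same object.
class AttrBroker {
  std::unordered_map<Position, std::set<AttrKind>> Attrs;

public:
  bool hasAttr(const Position &P, const AttrKind &K) const {
    auto It = Attrs.find(P);
    return It != Attrs.end() && It->second.count(K) != 0;
  }
  void manifestAttr(const Position &P, const AttrKind &K) {
    Attrs[P].insert(K);
  }
  void removeAttr(const Position &P, const AttrKind &K) {
    auto It = Attrs.find(P);
    if (It != Attrs.end())
      It->second.erase(K);
  }
};

int main() {
  AttrBroker A;
  A.manifestAttr("arg0", "byval");
  std::cout << A.hasAttr("arg0", "byval") << "\n"; // 1
  A.removeAttr("arg0", "byval");
  std::cout << A.hasAttr("arg0", "byval") << "\n"; // 0
}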
- Argument *Arg = getAssociatedArgument(); - if (!Arg || !A.isFunctionIPOAmendable(*(Arg->getParent()))) - indicatePessimisticFixpoint(); } ChangeStatus manifest(Attributor &A) override { @@ -7825,10 +7872,12 @@ struct AAMemoryBehaviorArgument : AAMemoryBehaviorFloating { // TODO: From readattrs.ll: "inalloca parameters are always // considered written" - if (hasAttr({Attribute::InAlloca, Attribute::Preallocated})) { + if (A.hasAttr(getIRPosition(), + {Attribute::InAlloca, Attribute::Preallocated})) { removeKnownBits(NO_WRITES); removeAssumedBits(NO_WRITES); } + A.removeAttrs(getIRPosition(), AttrKinds); return AAMemoryBehaviorFloating::manifest(A); } @@ -7874,9 +7923,11 @@ struct AAMemoryBehaviorCallSiteArgument final : AAMemoryBehaviorArgument { // redirecting requests to the callee argument. Argument *Arg = getAssociatedArgument(); const IRPosition &ArgPos = IRPosition::argument(*Arg); - auto &ArgAA = + auto *ArgAA = A.getAAFor<AAMemoryBehavior>(*this, ArgPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), ArgAA.getState()); + if (!ArgAA) + return indicatePessimisticFixpoint(); + return clampStateAndIndicateChange(getState(), ArgAA->getState()); } /// See AbstractAttribute::trackStatistics() @@ -7898,11 +7949,7 @@ struct AAMemoryBehaviorCallSiteReturned final : AAMemoryBehaviorFloating { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { AAMemoryBehaviorImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); } - /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { // We do not annotate returned values. @@ -7935,16 +7982,9 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl { else if (isAssumedWriteOnly()) ME = MemoryEffects::writeOnly(); - // Intersect with existing memory attribute, as we currently deduce the - // location and modref portion separately. - MemoryEffects ExistingME = F.getMemoryEffects(); - ME &= ExistingME; - if (ME == ExistingME) - return ChangeStatus::UNCHANGED; - - return IRAttributeManifest::manifestAttrs( - A, getIRPosition(), Attribute::getWithMemoryEffects(F.getContext(), ME), - /*ForceReplace*/ true); + A.removeAttrs(getIRPosition(), AttrKinds); + return A.manifestAttrs(getIRPosition(), + Attribute::getWithMemoryEffects(F.getContext(), ME)); } /// See AbstractAttribute::trackStatistics() @@ -7963,14 +8003,6 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { AAMemoryBehaviorCallSite(const IRPosition &IRP, Attributor &A) : AAMemoryBehaviorImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AAMemoryBehaviorImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -7979,9 +8011,11 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { // redirecting requests to the callee argument. 
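The memory-behavior code above manipulates a bit-integer state with separate known and assumed masks (removeKnownBits(NO_WRITES), removeAssumedBits(NO_WRITES), clampStateAndIndicateChange on the callee state). A compact standalone model of that lattice, assuming the usual invariant that the known bits stay a subset of the assumed bits; the bit names are illustrative only.

#include <cstdint>
#include <iostream>

// Illustrative bit positions; the real pass uses NO_READS/NO_WRITES and more.
constexpr uint32_t NO_READS = 1u << 0;
constexpr uint32_t NO_WRITES = 1u << 1;

struct BitState {
  uint32_t Known = 0;                      // proven properties
  uint32_t Assumed = NO_READS | NO_WRITES; // optimistic starting point

  // Assumed may only shrink towards Known; Known may only grow.
  void intersectAssumed(uint32_t Mask) { Assumed &= (Mask | Known); }
  void addKnown(uint32_t Bits) { Known |= Bits; Assumed |= Bits; }
  void removeAssumed(uint32_t Bits) { Assumed &= (~Bits | Known); }
};

// Analogue of clampStateAndIndicateChange: fold in another state and report
// whether the assumed information changed.
bool clamp(BitState &S, const BitState &Other) {
  uint32_t Before = S.Assumed;
  S.intersectAssumed(Other.Assumed);
  S.addKnown(Other.Known);
  return S.Assumed != Before;
}

int main() {
  BitState Arg, Callee;
  Callee.Assumed = NO_READS; // the callee may write
  std::cout << clamp(Arg, Callee) << "\n";        // 1: assumed set shrank
  std::cout << (Arg.Assumed & NO_WRITES) << "\n"; // 0: "no writes" dropped
}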
Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = + auto *FnAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos, DepClassTy::REQUIRED); - return clampStateAndIndicateChange(getState(), FnAA.getState()); + if (!FnAA) + return indicatePessimisticFixpoint(); + return clampStateAndIndicateChange(getState(), FnAA->getState()); } /// See AbstractAttribute::manifest(...). @@ -7996,17 +8030,9 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { else if (isAssumedWriteOnly()) ME = MemoryEffects::writeOnly(); - // Intersect with existing memory attribute, as we currently deduce the - // location and modref portion separately. - MemoryEffects ExistingME = CB.getMemoryEffects(); - ME &= ExistingME; - if (ME == ExistingME) - return ChangeStatus::UNCHANGED; - - return IRAttributeManifest::manifestAttrs( - A, getIRPosition(), - Attribute::getWithMemoryEffects(CB.getContext(), ME), - /*ForceReplace*/ true); + A.removeAttrs(getIRPosition(), AttrKinds); + return A.manifestAttrs( + getIRPosition(), Attribute::getWithMemoryEffects(CB.getContext(), ME)); } /// See AbstractAttribute::trackStatistics() @@ -8030,10 +8056,12 @@ ChangeStatus AAMemoryBehaviorFunction::updateImpl(Attributor &A) { // the local state. No further analysis is required as the other memory // state is as optimistic as it gets. if (const auto *CB = dyn_cast<CallBase>(&I)) { - const auto &MemBehaviorAA = A.getAAFor<AAMemoryBehavior>( + const auto *MemBehaviorAA = A.getAAFor<AAMemoryBehavior>( *this, IRPosition::callsite_function(*CB), DepClassTy::REQUIRED); - intersectAssumedBits(MemBehaviorAA.getAssumed()); - return !isAtFixpoint(); + if (MemBehaviorAA) { + intersectAssumedBits(MemBehaviorAA->getAssumed()); + return !isAtFixpoint(); + } } // Remove access kind modifiers if necessary. @@ -8066,12 +8094,14 @@ ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) { AAMemoryBehavior::base_t FnMemAssumedState = AAMemoryBehavior::StateType::getWorstState(); if (!Arg || !Arg->hasByValAttr()) { - const auto &FnMemAA = + const auto *FnMemAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos, DepClassTy::OPTIONAL); - FnMemAssumedState = FnMemAA.getAssumed(); - S.addKnownBits(FnMemAA.getKnown()); - if ((S.getAssumed() & FnMemAA.getAssumed()) == S.getAssumed()) - return ChangeStatus::UNCHANGED; + if (FnMemAA) { + FnMemAssumedState = FnMemAA->getAssumed(); + S.addKnownBits(FnMemAA->getKnown()); + if ((S.getAssumed() & FnMemAA->getAssumed()) == S.getAssumed()) + return ChangeStatus::UNCHANGED; + } } // The current assumed state used to determine a change. @@ -8081,9 +8111,14 @@ ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) { // it is, any information derived would be irrelevant anyway as we cannot // check the potential aliases introduced by the capture. However, no need // to fall back to anythign less optimistic than the function state. - const auto &ArgNoCaptureAA = - A.getAAFor<AANoCapture>(*this, IRP, DepClassTy::OPTIONAL); - if (!ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) { + bool IsKnownNoCapture; + const AANoCapture *ArgNoCaptureAA = nullptr; + bool IsAssumedNoCapture = AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, IRP, DepClassTy::OPTIONAL, IsKnownNoCapture, false, + &ArgNoCaptureAA); + + if (!IsAssumedNoCapture && + (!ArgNoCaptureAA || !ArgNoCaptureAA->isAssumedNoCaptureMaybeReturned())) { S.intersectAssumedBits(FnMemAssumedState); return (AssumedState != getAssumed()) ? 
ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; @@ -8137,9 +8172,10 @@ bool AAMemoryBehaviorFloating::followUsersOfUseIn(Attributor &A, const Use &U, // need to check call users. if (U.get()->getType()->isPointerTy()) { unsigned ArgNo = CB->getArgOperandNo(&U); - const auto &ArgNoCaptureAA = A.getAAFor<AANoCapture>( - *this, IRPosition::callsite_argument(*CB, ArgNo), DepClassTy::OPTIONAL); - return !ArgNoCaptureAA.isAssumedNoCapture(); + bool IsKnownNoCapture; + return !AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, IRPosition::callsite_argument(*CB, ArgNo), + DepClassTy::OPTIONAL, IsKnownNoCapture); } return true; @@ -8195,11 +8231,13 @@ void AAMemoryBehaviorFloating::analyzeUseIn(Attributor &A, const Use &U, Pos = IRPosition::callsite_argument(*CB, CB->getArgOperandNo(&U)); else Pos = IRPosition::callsite_function(*CB); - const auto &MemBehaviorAA = + const auto *MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(*this, Pos, DepClassTy::OPTIONAL); + if (!MemBehaviorAA) + break; // "assumed" has at most the same bits as the MemBehaviorAA assumed // and at least "known". - intersectAssumedBits(MemBehaviorAA.getAssumed()); + intersectAssumedBits(MemBehaviorAA->getAssumed()); return; } }; @@ -8286,7 +8324,7 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { UseArgMemOnly = !AnchorFn->hasLocalLinkage(); SmallVector<Attribute, 2> Attrs; - IRP.getAttrs({Attribute::Memory}, Attrs, IgnoreSubsumingPositions); + A.getAttrs(IRP, {Attribute::Memory}, Attrs, IgnoreSubsumingPositions); for (const Attribute &Attr : Attrs) { // TODO: We can map MemoryEffects to Attributor locations more precisely. MemoryEffects ME = Attr.getMemoryEffects(); @@ -8304,11 +8342,10 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { else { // Remove location information, only keep read/write info. ME = MemoryEffects(ME.getModRef()); - IRAttributeManifest::manifestAttrs( - A, IRP, - Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(), - ME), - /*ForceReplace*/ true); + A.manifestAttrs(IRP, + Attribute::getWithMemoryEffects( + IRP.getAnchorValue().getContext(), ME), + /*ForceReplace*/ true); } continue; } @@ -8319,11 +8356,10 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { else { // Remove location information, only keep read/write info. ME = MemoryEffects(ME.getModRef()); - IRAttributeManifest::manifestAttrs( - A, IRP, - Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(), - ME), - /*ForceReplace*/ true); + A.manifestAttrs(IRP, + Attribute::getWithMemoryEffects( + IRP.getAnchorValue().getContext(), ME), + /*ForceReplace*/ true); } continue; } @@ -8331,7 +8367,7 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { } /// See AbstractAttribute::getDeducedAttributes(...). - void getDeducedAttributes(LLVMContext &Ctx, + void getDeducedAttributes(Attributor &A, LLVMContext &Ctx, SmallVectorImpl<Attribute> &Attrs) const override { // TODO: We can map Attributor locations to MemoryEffects more precisely. assert(Attrs.size() == 0); @@ -8359,27 +8395,13 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { const IRPosition &IRP = getIRPosition(); SmallVector<Attribute, 1> DeducedAttrs; - getDeducedAttributes(IRP.getAnchorValue().getContext(), DeducedAttrs); + getDeducedAttributes(A, IRP.getAnchorValue().getContext(), DeducedAttrs); if (DeducedAttrs.size() != 1) return ChangeStatus::UNCHANGED; MemoryEffects ME = DeducedAttrs[0].getMemoryEffects(); - // Intersect with existing memory attribute, as we currently deduce the - // location and modref portion separately. 
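Several manifest hunks in this diff (the memory-behavior function/call-site attributes and the memory-location attribute) drop the old step of intersecting with the MemoryEffects already on the IR and bailing out when nothing improved; instead they remove the stale attributes and write the deduced effects directly. A small bitmask sketch of the two strategies, using a toy MemEffects typedef rather than LLVM's MemoryEffects.

#include <cstdint>
#include <iostream>

// Toy effect bits; fewer bits set means a stronger guarantee.
using MemEffects = uint8_t;
constexpr MemEffects Ref = 1; // may read
constexpr MemEffects Mod = 2; // may write

// Old style: intersect with what the IR already carries, skip if unchanged.
bool manifestByIntersection(MemEffects &OnIR, MemEffects Deduced) {
  MemEffects Merged = OnIR & Deduced;
  if (Merged == OnIR)
    return false; // nothing improved, leave the IR alone
  OnIR = Merged;
  return true;
}

// New style (as in the diff): drop the stale attribute and write the deduced
// effects directly; the deduction is trusted to cover prior information.
bool manifestByReplacement(MemEffects &OnIR, MemEffects Deduced) {
  bool Changed = OnIR != Deduced;
  OnIR = Deduced;
  return Changed;
}

int main() {
  MemEffects IR = Ref | Mod; // IR currently says "may read and write"
  std::cout << manifestByIntersection(IR, Ref) << " " << unsigned(IR) << "\n"; // 1 1
  IR = Ref | Mod;
  std::cout << manifestByReplacement(IR, Ref) << " " << unsigned(IR) << "\n";  // 1 1
}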
- SmallVector<Attribute, 1> ExistingAttrs; - IRP.getAttrs({Attribute::Memory}, ExistingAttrs, - /* IgnoreSubsumingPositions */ true); - if (ExistingAttrs.size() == 1) { - MemoryEffects ExistingME = ExistingAttrs[0].getMemoryEffects(); - ME &= ExistingME; - if (ME == ExistingME) - return ChangeStatus::UNCHANGED; - } - - return IRAttributeManifest::manifestAttrs( - A, IRP, - Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(), ME), - /*ForceReplace*/ true); + return A.manifestAttrs(IRP, Attribute::getWithMemoryEffects( + IRP.getAnchorValue().getContext(), ME)); } /// See AAMemoryLocation::checkForAllAccessesToMemoryKind(...). @@ -8492,13 +8514,16 @@ protected: if (!Accesses) Accesses = new (Allocator) AccessSet(); Changed |= Accesses->insert(AccessInfo{I, Ptr, AK}).second; + if (MLK == NO_UNKOWN_MEM) + MLK = NO_LOCATIONS; State.removeAssumedBits(MLK); } /// Determine the underlying locations kinds for \p Ptr, e.g., globals or /// arguments, and update the state and access map accordingly. void categorizePtrValue(Attributor &A, const Instruction &I, const Value &Ptr, - AAMemoryLocation::StateType &State, bool &Changed); + AAMemoryLocation::StateType &State, bool &Changed, + unsigned AccessAS = 0); /// Used to allocate access sets. BumpPtrAllocator &Allocator; @@ -8506,14 +8531,24 @@ protected: void AAMemoryLocationImpl::categorizePtrValue( Attributor &A, const Instruction &I, const Value &Ptr, - AAMemoryLocation::StateType &State, bool &Changed) { + AAMemoryLocation::StateType &State, bool &Changed, unsigned AccessAS) { LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Categorize pointer locations for " << Ptr << " [" << getMemoryLocationsAsStr(State.getAssumed()) << "]\n"); auto Pred = [&](Value &Obj) { + unsigned ObjectAS = Obj.getType()->getPointerAddressSpace(); // TODO: recognize the TBAA used for constant accesses. MemoryLocationsKind MLK = NO_LOCATIONS; + + // Filter accesses to constant (GPU) memory if we have an AS at the access + // site or the object is known to actually have the associated AS. 
+ if ((AccessAS == (unsigned)AA::GPUAddressSpace::Constant || + (ObjectAS == (unsigned)AA::GPUAddressSpace::Constant && + isIdentifiedObject(&Obj))) && + AA::isGPU(*I.getModule())) + return true; + if (isa<UndefValue>(&Obj)) return true; if (isa<Argument>(&Obj)) { @@ -8537,15 +8572,16 @@ void AAMemoryLocationImpl::categorizePtrValue( else MLK = NO_GLOBAL_EXTERNAL_MEM; } else if (isa<ConstantPointerNull>(&Obj) && - !NullPointerIsDefined(getAssociatedFunction(), - Ptr.getType()->getPointerAddressSpace())) { + (!NullPointerIsDefined(getAssociatedFunction(), AccessAS) || + !NullPointerIsDefined(getAssociatedFunction(), ObjectAS))) { return true; } else if (isa<AllocaInst>(&Obj)) { MLK = NO_LOCAL_MEM; } else if (const auto *CB = dyn_cast<CallBase>(&Obj)) { - const auto &NoAliasAA = A.getAAFor<AANoAlias>( - *this, IRPosition::callsite_returned(*CB), DepClassTy::OPTIONAL); - if (NoAliasAA.isAssumedNoAlias()) + bool IsKnownNoAlias; + if (AA::hasAssumedIRAttr<Attribute::NoAlias>( + A, this, IRPosition::callsite_returned(*CB), DepClassTy::OPTIONAL, + IsKnownNoAlias)) MLK = NO_MALLOCED_MEM; else MLK = NO_UNKOWN_MEM; @@ -8556,15 +8592,15 @@ void AAMemoryLocationImpl::categorizePtrValue( assert(MLK != NO_LOCATIONS && "No location specified!"); LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Ptr value can be categorized: " << Obj << " -> " << getMemoryLocationsAsStr(MLK) << "\n"); - updateStateAndAccessesMap(getState(), MLK, &I, &Obj, Changed, + updateStateAndAccessesMap(State, MLK, &I, &Obj, Changed, getAccessKindFromInst(&I)); return true; }; - const auto &AA = A.getAAFor<AAUnderlyingObjects>( + const auto *AA = A.getAAFor<AAUnderlyingObjects>( *this, IRPosition::value(Ptr), DepClassTy::OPTIONAL); - if (!AA.forallUnderlyingObjects(Pred, AA::Intraprocedural)) { + if (!AA || !AA->forallUnderlyingObjects(Pred, AA::Intraprocedural)) { LLVM_DEBUG( dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n"); updateStateAndAccessesMap(State, NO_UNKOWN_MEM, &I, nullptr, Changed, @@ -8589,10 +8625,10 @@ void AAMemoryLocationImpl::categorizeArgumentPointerLocations( // Skip readnone arguments. const IRPosition &ArgOpIRP = IRPosition::callsite_argument(CB, ArgNo); - const auto &ArgOpMemLocationAA = + const auto *ArgOpMemLocationAA = A.getAAFor<AAMemoryBehavior>(*this, ArgOpIRP, DepClassTy::OPTIONAL); - if (ArgOpMemLocationAA.isAssumedReadNone()) + if (ArgOpMemLocationAA && ArgOpMemLocationAA->isAssumedReadNone()) continue; // Categorize potentially accessed pointer arguments as if there was an @@ -8613,22 +8649,27 @@ AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I, if (auto *CB = dyn_cast<CallBase>(&I)) { // First check if we assume any memory is access is visible. 
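The preceding hunk threads the address space of the access into categorizePtrValue and, for GPU modules, skips accesses that provably target constant memory (either because the access itself uses the constant address space, or because the identified underlying object lives there), the idea being that such memory behaves like read-only data. A standalone sketch of that filter; the address-space number and the IsGPUModule flag are made up for illustration.

#include <iostream>

// Made-up address-space number; real targets define their own numbering.
constexpr unsigned ConstantAS = 4;

struct Access {
  unsigned AccessAS;  // address space on the access instruction
  unsigned ObjectAS;  // address space of the underlying object
  bool IdentifiedObj; // is the underlying object precisely known?
};

// Mirrors the added filter: on GPU modules, accesses that provably target the
// constant address space are not tracked as interesting memory locations.
bool skipForMemoryLocations(const Access &A, bool IsGPUModule) {
  if (!IsGPUModule)
    return false;
  if (A.AccessAS == ConstantAS)
    return true;
  return A.ObjectAS == ConstantAS && A.IdentifiedObj;
}

int main() {
  std::cout << skipForMemoryLocations({ConstantAS, 0, false}, true) << "\n"; // 1
  std::cout << skipForMemoryLocations({0, ConstantAS, true}, true) << "\n";  // 1
  std::cout << skipForMemoryLocations({0, ConstantAS, true}, false) << "\n"; // 0
}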
- const auto &CBMemLocationAA = A.getAAFor<AAMemoryLocation>( + const auto *CBMemLocationAA = A.getAAFor<AAMemoryLocation>( *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL); LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Categorize call site: " << I << " [" << CBMemLocationAA << "]\n"); + if (!CBMemLocationAA) { + updateStateAndAccessesMap(AccessedLocs, NO_UNKOWN_MEM, &I, nullptr, + Changed, getAccessKindFromInst(&I)); + return NO_UNKOWN_MEM; + } - if (CBMemLocationAA.isAssumedReadNone()) + if (CBMemLocationAA->isAssumedReadNone()) return NO_LOCATIONS; - if (CBMemLocationAA.isAssumedInaccessibleMemOnly()) { + if (CBMemLocationAA->isAssumedInaccessibleMemOnly()) { updateStateAndAccessesMap(AccessedLocs, NO_INACCESSIBLE_MEM, &I, nullptr, Changed, getAccessKindFromInst(&I)); return AccessedLocs.getAssumed(); } uint32_t CBAssumedNotAccessedLocs = - CBMemLocationAA.getAssumedNotAccessedLocation(); + CBMemLocationAA->getAssumedNotAccessedLocation(); // Set the argmemonly and global bit as we handle them separately below. uint32_t CBAssumedNotAccessedLocsNoArgMem = @@ -8651,7 +8692,7 @@ AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I, getAccessKindFromInst(&I)); return true; }; - if (!CBMemLocationAA.checkForAllAccessesToMemoryKind( + if (!CBMemLocationAA->checkForAllAccessesToMemoryKind( AccessPred, inverseLocation(NO_GLOBAL_MEM, false, false))) return AccessedLocs.getWorstState(); } @@ -8676,7 +8717,8 @@ AAMemoryLocationImpl::categorizeAccessedLocations(Attributor &A, Instruction &I, LLVM_DEBUG( dbgs() << "[AAMemoryLocation] Categorize memory access with pointer: " << I << " [" << *Ptr << "]\n"); - categorizePtrValue(A, I, *Ptr, AccessedLocs, Changed); + categorizePtrValue(A, I, *Ptr, AccessedLocs, Changed, + Ptr->getType()->getPointerAddressSpace()); return AccessedLocs.getAssumed(); } @@ -8695,14 +8737,14 @@ struct AAMemoryLocationFunction final : public AAMemoryLocationImpl { /// See AbstractAttribute::updateImpl(Attributor &A). ChangeStatus updateImpl(Attributor &A) override { - const auto &MemBehaviorAA = + const auto *MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(*this, getIRPosition(), DepClassTy::NONE); - if (MemBehaviorAA.isAssumedReadNone()) { - if (MemBehaviorAA.isKnownReadNone()) + if (MemBehaviorAA && MemBehaviorAA->isAssumedReadNone()) { + if (MemBehaviorAA->isKnownReadNone()) return indicateOptimisticFixpoint(); assert(isAssumedReadNone() && "AAMemoryLocation was not read-none but AAMemoryBehavior was!"); - A.recordDependence(MemBehaviorAA, *this, DepClassTy::OPTIONAL); + A.recordDependence(*MemBehaviorAA, *this, DepClassTy::OPTIONAL); return ChangeStatus::UNCHANGED; } @@ -8747,14 +8789,6 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl { AAMemoryLocationCallSite(const IRPosition &IRP, Attributor &A) : AAMemoryLocationImpl(IRP, A) {} - /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override { - AAMemoryLocationImpl::initialize(A); - Function *F = getAssociatedFunction(); - if (!F || F->isDeclaration()) - indicatePessimisticFixpoint(); - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { // TODO: Once we have call site specific value information we can provide @@ -8763,8 +8797,10 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl { // redirecting requests to the callee argument. 
Function *F = getAssociatedFunction(); const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = + auto *FnAA = A.getAAFor<AAMemoryLocation>(*this, FnPos, DepClassTy::REQUIRED); + if (!FnAA) + return indicatePessimisticFixpoint(); bool Changed = false; auto AccessPred = [&](const Instruction *I, const Value *Ptr, AccessKind Kind, MemoryLocationsKind MLK) { @@ -8772,7 +8808,7 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl { getAccessKindFromInst(I)); return true; }; - if (!FnAA.checkForAllAccessesToMemoryKind(AccessPred, ALL_LOCATIONS)) + if (!FnAA->checkForAllAccessesToMemoryKind(AccessPred, ALL_LOCATIONS)) return indicatePessimisticFixpoint(); return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; } @@ -8808,7 +8844,7 @@ struct AAValueConstantRangeImpl : AAValueConstantRange { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { std::string Str; llvm::raw_string_ostream OS(Str); OS << "range(" << getBitWidth() << ")<"; @@ -9023,15 +9059,6 @@ struct AAValueConstantRangeArgument final AAValueConstantRangeArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} - /// See AbstractAttribute::initialize(..). - void initialize(Attributor &A) override { - if (!getAnchorScope() || getAnchorScope()->isDeclaration()) { - indicatePessimisticFixpoint(); - } else { - Base::initialize(A); - } - } - /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(value_range) @@ -9052,7 +9079,10 @@ struct AAValueConstantRangeReturned : Base(IRP, A) {} /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override {} + void initialize(Attributor &A) override { + if (!A.isFunctionIPOAmendable(*getAssociatedFunction())) + indicatePessimisticFixpoint(); + } /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -9141,17 +9171,21 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return false; - auto &LHSAA = A.getAAFor<AAValueConstantRange>( + auto *LHSAA = A.getAAFor<AAValueConstantRange>( *this, IRPosition::value(*LHS, getCallBaseContext()), DepClassTy::REQUIRED); - QuerriedAAs.push_back(&LHSAA); - auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI); + if (!LHSAA) + return false; + QuerriedAAs.push_back(LHSAA); + auto LHSAARange = LHSAA->getAssumedConstantRange(A, CtxI); - auto &RHSAA = A.getAAFor<AAValueConstantRange>( + auto *RHSAA = A.getAAFor<AAValueConstantRange>( *this, IRPosition::value(*RHS, getCallBaseContext()), DepClassTy::REQUIRED); - QuerriedAAs.push_back(&RHSAA); - auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI); + if (!RHSAA) + return false; + QuerriedAAs.push_back(RHSAA); + auto RHSAARange = RHSAA->getAssumedConstantRange(A, CtxI); auto AssumedRange = LHSAARange.binaryOp(BinOp->getOpcode(), RHSAARange); @@ -9184,12 +9218,14 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { if (!OpV->getType()->isIntegerTy()) return false; - auto &OpAA = A.getAAFor<AAValueConstantRange>( + auto *OpAA = A.getAAFor<AAValueConstantRange>( *this, IRPosition::value(*OpV, getCallBaseContext()), DepClassTy::REQUIRED); - QuerriedAAs.push_back(&OpAA); - T.unionAssumed( - OpAA.getAssumed().castOp(CastI->getOpcode(), getState().getBitWidth())); + if (!OpAA) + return false; + QuerriedAAs.push_back(OpAA); + 
T.unionAssumed(OpAA->getAssumed().castOp(CastI->getOpcode(), + getState().getBitWidth())); return T.isValidState(); } @@ -9224,16 +9260,20 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return false; - auto &LHSAA = A.getAAFor<AAValueConstantRange>( + auto *LHSAA = A.getAAFor<AAValueConstantRange>( *this, IRPosition::value(*LHS, getCallBaseContext()), DepClassTy::REQUIRED); - QuerriedAAs.push_back(&LHSAA); - auto &RHSAA = A.getAAFor<AAValueConstantRange>( + if (!LHSAA) + return false; + QuerriedAAs.push_back(LHSAA); + auto *RHSAA = A.getAAFor<AAValueConstantRange>( *this, IRPosition::value(*RHS, getCallBaseContext()), DepClassTy::REQUIRED); - QuerriedAAs.push_back(&RHSAA); - auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI); - auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI); + if (!RHSAA) + return false; + QuerriedAAs.push_back(RHSAA); + auto LHSAARange = LHSAA->getAssumedConstantRange(A, CtxI); + auto RHSAARange = RHSAA->getAssumedConstantRange(A, CtxI); // If one of them is empty set, we can't decide. if (LHSAARange.isEmptySet() || RHSAARange.isEmptySet()) @@ -9260,8 +9300,10 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { else T.unionAssumed(ConstantRange(/* BitWidth */ 1, /* isFullSet */ true)); - LLVM_DEBUG(dbgs() << "[AAValueConstantRange] " << *CmpI << " " << LHSAA - << " " << RHSAA << "\n"); + LLVM_DEBUG(dbgs() << "[AAValueConstantRange] " << *CmpI << " after " + << (MustTrue ? "true" : (MustFalse ? "false" : "unknown")) + << ": " << T << "\n\t" << *LHSAA << "\t<op>\n\t" + << *RHSAA); // TODO: Track a known state too. return T.isValidState(); @@ -9287,12 +9329,15 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { Value *VPtr = *SimplifiedOpV; // If the value is not instruction, we query AA to Attributor. - const auto &AA = A.getAAFor<AAValueConstantRange>( + const auto *AA = A.getAAFor<AAValueConstantRange>( *this, IRPosition::value(*VPtr, getCallBaseContext()), DepClassTy::REQUIRED); // Clamp operator is not used to utilize a program point CtxI. - T.unionAssumed(AA.getAssumedConstantRange(A, CtxI)); + if (AA) + T.unionAssumed(AA->getAssumedConstantRange(A, CtxI)); + else + return false; return T.isValidState(); } @@ -9454,12 +9499,12 @@ struct AAPotentialConstantValuesImpl : AAPotentialConstantValues { return false; if (!IRP.getAssociatedType()->isIntegerTy()) return false; - auto &PotentialValuesAA = A.getAAFor<AAPotentialConstantValues>( + auto *PotentialValuesAA = A.getAAFor<AAPotentialConstantValues>( *this, IRP, DepClassTy::REQUIRED); - if (!PotentialValuesAA.getState().isValidState()) + if (!PotentialValuesAA || !PotentialValuesAA->getState().isValidState()) return false; - ContainsUndef = PotentialValuesAA.getState().undefIsContained(); - S = PotentialValuesAA.getState().getAssumedSet(); + ContainsUndef = PotentialValuesAA->getState().undefIsContained(); + S = PotentialValuesAA->getState().getAssumedSet(); return true; } @@ -9483,7 +9528,7 @@ struct AAPotentialConstantValuesImpl : AAPotentialConstantValues { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { std::string Str; llvm::raw_string_ostream OS(Str); OS << getState(); @@ -9506,15 +9551,6 @@ struct AAPotentialConstantValuesArgument final AAPotentialConstantValuesArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} - /// See AbstractAttribute::initialize(..). 
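The value-constant-range hunks above fetch a range for each operand (now via nullable pointers) and combine them with binaryOp, giving up as soon as either operand has no usable range. A minimal interval-arithmetic sketch of that propagation step for an addition, using a hand-rolled Interval type instead of LLVM's ConstantRange.

#include <iostream>
#include <optional>

// Hand-rolled closed interval [Lo, Hi]; LLVM's ConstantRange is wrap-aware,
// this toy version is not.
struct Interval {
  long Lo, Hi;
};

// Combine operand ranges for an addition; nullopt models the case where one
// operand has no usable range (the analogue of a null AA result above).
std::optional<Interval> addRanges(const std::optional<Interval> &L,
                                  const std::optional<Interval> &R) {
  if (!L || !R)
    return std::nullopt; // give up, i.e. fall back to the pessimistic state
  return Interval{L->Lo + R->Lo, L->Hi + R->Hi};
}

int main() {
  auto X = addRanges(Interval{0, 3}, Interval{10, 10});
  if (X)
    std::cout << "[" << X->Lo << ", " << X->Hi << "]\n"; // [10, 13]
  auto Y = addRanges(Interval{0, 3}, std::nullopt);
  std::cout << (Y ? "known" : "unknown") << "\n"; // unknown
}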
- void initialize(Attributor &A) override { - if (!getAnchorScope() || getAnchorScope()->isDeclaration()) { - indicatePessimisticFixpoint(); - } else { - Base::initialize(A); - } - } - /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(potential_values) @@ -9529,6 +9565,12 @@ struct AAPotentialConstantValuesReturned AAPotentialConstantValuesReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} + void initialize(Attributor &A) override { + if (!A.isFunctionIPOAmendable(*getAssociatedFunction())) + indicatePessimisticFixpoint(); + Base::initialize(A); + } + /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(potential_values) @@ -9958,9 +10000,11 @@ struct AAPotentialConstantValuesCallSiteArgument ChangeStatus updateImpl(Attributor &A) override { Value &V = getAssociatedValue(); auto AssumedBefore = getAssumed(); - auto &AA = A.getAAFor<AAPotentialConstantValues>( + auto *AA = A.getAAFor<AAPotentialConstantValues>( *this, IRPosition::value(V), DepClassTy::REQUIRED); - const auto &S = AA.getAssumed(); + if (!AA) + return indicatePessimisticFixpoint(); + const auto &S = AA->getAssumed(); unionAssumed(S); return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; @@ -9971,27 +10015,39 @@ struct AAPotentialConstantValuesCallSiteArgument STATS_DECLTRACK_CSARG_ATTR(potential_values) } }; +} // namespace /// ------------------------ NoUndef Attribute --------------------------------- +bool AANoUndef::isImpliedByIR(Attributor &A, const IRPosition &IRP, + Attribute::AttrKind ImpliedAttributeKind, + bool IgnoreSubsumingPositions) { + assert(ImpliedAttributeKind == Attribute::NoUndef && + "Unexpected attribute kind"); + if (A.hasAttr(IRP, {Attribute::NoUndef}, IgnoreSubsumingPositions, + Attribute::NoUndef)) + return true; + + Value &Val = IRP.getAssociatedValue(); + if (IRP.getPositionKind() != IRPosition::IRP_RETURNED && + isGuaranteedNotToBeUndefOrPoison(&Val)) { + LLVMContext &Ctx = Val.getContext(); + A.manifestAttrs(IRP, Attribute::get(Ctx, Attribute::NoUndef)); + return true; + } + + return false; +} + +namespace { struct AANoUndefImpl : AANoUndef { AANoUndefImpl(const IRPosition &IRP, Attributor &A) : AANoUndef(IRP, A) {} /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - if (getIRPosition().hasAttr({Attribute::NoUndef})) { - indicateOptimisticFixpoint(); - return; - } Value &V = getAssociatedValue(); if (isa<UndefValue>(V)) indicatePessimisticFixpoint(); - else if (isa<FreezeInst>(V)) - indicateOptimisticFixpoint(); - else if (getPositionKind() != IRPosition::IRP_RETURNED && - isGuaranteedNotToBeUndefOrPoison(&V)) - indicateOptimisticFixpoint(); - else - AANoUndef::initialize(A); + assert(!isImpliedByIR(A, getIRPosition(), Attribute::NoUndef)); } /// See followUsesInMBEC @@ -10015,7 +10071,7 @@ struct AANoUndefImpl : AANoUndef { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return getAssumed() ? "noundef" : "may-undef-or-poison"; } @@ -10052,33 +10108,39 @@ struct AANoUndefFloating : public AANoUndefImpl { /// See AbstractAttribute::updateImpl(...). 
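AANoUndef::isImpliedByIR, added above, answers the query straight from the IR: either the attribute is already present, or isGuaranteedNotToBeUndefOrPoison proves the property and the attribute is manifested on the spot, so the solver never has to create an abstract attribute for that position. A sketch of that short-circuit with a stub predicate standing in for the real analysis; the types are invented.

#include <iostream>
#include <set>
#include <string>

struct Value {
  std::string Name;
  bool ProvablyDefined; // stub for isGuaranteedNotToBeUndefOrPoison(&V)
};

struct Position {
  Value *V;
  std::set<std::string> Attrs; // attributes already present on the IR
};

// Returns true (and records the attribute) when the IR alone implies the
// property, so the solver can skip creating a summary for this position.
bool impliedByIR(Position &P, const std::string &Kind) {
  if (P.Attrs.count(Kind))
    return true;          // already annotated
  if (P.V->ProvablyDefined) {
    P.Attrs.insert(Kind); // manifest eagerly, as the diff does
    return true;
  }
  return false;           // a real deduction is still needed
}

int main() {
  Value V{"x", true};
  Position P{&V, {}};
  std::cout << impliedByIR(P, "noundef") << "\n"; // 1
  std::cout << P.Attrs.count("noundef") << "\n";  // 1 (now recorded on the IR)
}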
ChangeStatus updateImpl(Attributor &A) override { + auto VisitValueCB = [&](const IRPosition &IRP) -> bool { + bool IsKnownNoUndef; + return AA::hasAssumedIRAttr<Attribute::NoUndef>( + A, this, IRP, DepClassTy::REQUIRED, IsKnownNoUndef); + }; - SmallVector<AA::ValueAndContext> Values; + bool Stripped; bool UsedAssumedInformation = false; + Value *AssociatedValue = &getAssociatedValue(); + SmallVector<AA::ValueAndContext> Values; if (!A.getAssumedSimplifiedValues(getIRPosition(), *this, Values, - AA::AnyScope, UsedAssumedInformation)) { - Values.push_back({getAssociatedValue(), getCtxI()}); + AA::AnyScope, UsedAssumedInformation)) + Stripped = false; + else + Stripped = + Values.size() != 1 || Values.front().getValue() != AssociatedValue; + + if (!Stripped) { + // If we haven't stripped anything we might still be able to use a + // different AA, but only if the IRP changes. Effectively when we + // interpret this not as a call site value but as a floating/argument + // value. + const IRPosition AVIRP = IRPosition::value(*AssociatedValue); + if (AVIRP == getIRPosition() || !VisitValueCB(AVIRP)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; } - StateType T; - auto VisitValueCB = [&](Value &V, const Instruction *CtxI) -> bool { - const auto &AA = A.getAAFor<AANoUndef>(*this, IRPosition::value(V), - DepClassTy::REQUIRED); - if (this == &AA) { - T.indicatePessimisticFixpoint(); - } else { - const AANoUndef::StateType &S = - static_cast<const AANoUndef::StateType &>(AA.getState()); - T ^= S; - } - return T.isValidState(); - }; - for (const auto &VAC : Values) - if (!VisitValueCB(*VAC.getValue(), VAC.getCtxI())) + if (!VisitValueCB(IRPosition::value(*VAC.getValue()))) return indicatePessimisticFixpoint(); - return clampStateAndIndicateChange(getState(), T); + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -10086,18 +10148,26 @@ struct AANoUndefFloating : public AANoUndefImpl { }; struct AANoUndefReturned final - : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl> { + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl, + AANoUndef::StateType, false, + Attribute::NoUndef> { AANoUndefReturned(const IRPosition &IRP, Attributor &A) - : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl>(IRP, A) {} + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl, + AANoUndef::StateType, false, + Attribute::NoUndef>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noundef) } }; struct AANoUndefArgument final - : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl> { + : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl, + AANoUndef::StateType, false, + Attribute::NoUndef> { AANoUndefArgument(const IRPosition &IRP, Attributor &A) - : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl>(IRP, A) {} + : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl, + AANoUndef::StateType, false, + Attribute::NoUndef>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noundef) } @@ -10112,14 +10182,173 @@ struct AANoUndefCallSiteArgument final : AANoUndefFloating { }; struct AANoUndefCallSiteReturned final - : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl> { + : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl, + AANoUndef::StateType, false, + Attribute::NoUndef> { AANoUndefCallSiteReturned(const IRPosition &IRP, Attributor &A) - : 
AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl>(IRP, A) {} + : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl, + AANoUndef::StateType, false, + Attribute::NoUndef>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noundef) } }; +/// ------------------------ NoFPClass Attribute ------------------------------- + +struct AANoFPClassImpl : AANoFPClass { + AANoFPClassImpl(const IRPosition &IRP, Attributor &A) : AANoFPClass(IRP, A) {} + + void initialize(Attributor &A) override { + const IRPosition &IRP = getIRPosition(); + + Value &V = IRP.getAssociatedValue(); + if (isa<UndefValue>(V)) { + indicateOptimisticFixpoint(); + return; + } + + SmallVector<Attribute> Attrs; + A.getAttrs(getIRPosition(), {Attribute::NoFPClass}, Attrs, false); + for (const auto &Attr : Attrs) { + addKnownBits(Attr.getNoFPClass()); + return; + } + + const DataLayout &DL = A.getDataLayout(); + if (getPositionKind() != IRPosition::IRP_RETURNED) { + KnownFPClass KnownFPClass = computeKnownFPClass(&V, DL); + addKnownBits(~KnownFPClass.KnownFPClasses); + } + + if (Instruction *CtxI = getCtxI()) + followUsesInMBEC(*this, A, getState(), *CtxI); + } + + /// See followUsesInMBEC + bool followUseInMBEC(Attributor &A, const Use *U, const Instruction *I, + AANoFPClass::StateType &State) { + const Value *UseV = U->get(); + const DominatorTree *DT = nullptr; + AssumptionCache *AC = nullptr; + const TargetLibraryInfo *TLI = nullptr; + InformationCache &InfoCache = A.getInfoCache(); + + if (Function *F = getAnchorScope()) { + DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F); + AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F); + TLI = InfoCache.getTargetLibraryInfoForFunction(*F); + } + + const DataLayout &DL = A.getDataLayout(); + + KnownFPClass KnownFPClass = + computeKnownFPClass(UseV, DL, + /*InterestedClasses=*/fcAllFlags, + /*Depth=*/0, TLI, AC, I, DT); + State.addKnownBits(~KnownFPClass.KnownFPClasses); + + bool TrackUse = false; + return TrackUse; + } + + const std::string getAsStr(Attributor *A) const override { + std::string Result = "nofpclass"; + raw_string_ostream OS(Result); + OS << getAssumedNoFPClass(); + return Result; + } + + void getDeducedAttributes(Attributor &A, LLVMContext &Ctx, + SmallVectorImpl<Attribute> &Attrs) const override { + Attrs.emplace_back(Attribute::getWithNoFPClass(Ctx, getAssumedNoFPClass())); + } +}; + +struct AANoFPClassFloating : public AANoFPClassImpl { + AANoFPClassFloating(const IRPosition &IRP, Attributor &A) + : AANoFPClassImpl(IRP, A) {} + + /// See AbstractAttribute::updateImpl(...). 
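The new nofpclass code computes which floating-point classes a value may belong to (computeKnownFPClass) and records the complement as known bits, because the attribute lists the classes a value is guaranteed not to be. A small bitmask sketch of that complement step; the class bits and the stub analysis are invented and much smaller than the real set of flags.

#include <cstdint>
#include <iostream>

// Illustrative FP-class bits; the real enumeration is larger.
constexpr uint32_t fcNan = 1u << 0;
constexpr uint32_t fcInf = 1u << 1;
constexpr uint32_t fcNormal = 1u << 2;
constexpr uint32_t fcZero = 1u << 3;
constexpr uint32_t fcAll = fcNan | fcInf | fcNormal | fcZero;

// Stub analysis: pretend we proved the value is a finite normal or zero.
uint32_t computePossibleClasses() { return fcNormal | fcZero; }

int main() {
  uint32_t Possible = computePossibleClasses();
  // nofpclass records what the value can NOT be, hence the complement.
  uint32_t NoFPClassMask = ~Possible & fcAll;
  std::cout << ((NoFPClassMask & fcNan) != 0) << "\n";  // 1: provably not NaN
  std::cout << ((NoFPClassMask & fcZero) != 0) << "\n"; // 0: may still be zero
}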
+ ChangeStatus updateImpl(Attributor &A) override { + SmallVector<AA::ValueAndContext> Values; + bool UsedAssumedInformation = false; + if (!A.getAssumedSimplifiedValues(getIRPosition(), *this, Values, + AA::AnyScope, UsedAssumedInformation)) { + Values.push_back({getAssociatedValue(), getCtxI()}); + } + + StateType T; + auto VisitValueCB = [&](Value &V, const Instruction *CtxI) -> bool { + const auto *AA = A.getAAFor<AANoFPClass>(*this, IRPosition::value(V), + DepClassTy::REQUIRED); + if (!AA || this == AA) { + T.indicatePessimisticFixpoint(); + } else { + const AANoFPClass::StateType &S = + static_cast<const AANoFPClass::StateType &>(AA->getState()); + T ^= S; + } + return T.isValidState(); + }; + + for (const auto &VAC : Values) + if (!VisitValueCB(*VAC.getValue(), VAC.getCtxI())) + return indicatePessimisticFixpoint(); + + return clampStateAndIndicateChange(getState(), T); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FNRET_ATTR(nofpclass) + } +}; + +struct AANoFPClassReturned final + : AAReturnedFromReturnedValues<AANoFPClass, AANoFPClassImpl, + AANoFPClassImpl::StateType, false, Attribute::None, false> { + AANoFPClassReturned(const IRPosition &IRP, Attributor &A) + : AAReturnedFromReturnedValues<AANoFPClass, AANoFPClassImpl, + AANoFPClassImpl::StateType, false, Attribute::None, false>( + IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FNRET_ATTR(nofpclass) + } +}; + +struct AANoFPClassArgument final + : AAArgumentFromCallSiteArguments<AANoFPClass, AANoFPClassImpl> { + AANoFPClassArgument(const IRPosition &IRP, Attributor &A) + : AAArgumentFromCallSiteArguments<AANoFPClass, AANoFPClassImpl>(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nofpclass) } +}; + +struct AANoFPClassCallSiteArgument final : AANoFPClassFloating { + AANoFPClassCallSiteArgument(const IRPosition &IRP, Attributor &A) + : AANoFPClassFloating(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CSARG_ATTR(nofpclass) + } +}; + +struct AANoFPClassCallSiteReturned final + : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl> { + AANoFPClassCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl>(IRP, A) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CSRET_ATTR(nofpclass) + } +}; + struct AACallEdgesImpl : public AACallEdges { AACallEdgesImpl(const IRPosition &IRP, Attributor &A) : AACallEdges(IRP, A) {} @@ -10133,7 +10362,7 @@ struct AACallEdgesImpl : public AACallEdges { return HasUnknownCalleeNonAsm; } - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return "CallEdges[" + std::to_string(HasUnknownCallee) + "," + std::to_string(CalledFunctions.size()) + "]"; } @@ -10191,6 +10420,11 @@ struct AACallEdgesCallSite : public AACallEdgesImpl { SmallVector<AA::ValueAndContext> Values; // Process any value that we might call. 
auto ProcessCalledOperand = [&](Value *V, Instruction *CtxI) { + if (isa<Constant>(V)) { + VisitValue(*V, CtxI); + return; + } + bool UsedAssumedInformation = false; Values.clear(); if (!A.getAssumedSimplifiedValues(IRPosition::value(*V), *this, Values, @@ -10246,14 +10480,16 @@ struct AACallEdgesFunction : public AACallEdgesImpl { auto ProcessCallInst = [&](Instruction &Inst) { CallBase &CB = cast<CallBase>(Inst); - auto &CBEdges = A.getAAFor<AACallEdges>( + auto *CBEdges = A.getAAFor<AACallEdges>( *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); - if (CBEdges.hasNonAsmUnknownCallee()) + if (!CBEdges) + return false; + if (CBEdges->hasNonAsmUnknownCallee()) setHasUnknownCallee(true, Change); - if (CBEdges.hasUnknownCallee()) + if (CBEdges->hasUnknownCallee()) setHasUnknownCallee(false, Change); - for (Function *F : CBEdges.getOptimisticEdges()) + for (Function *F : CBEdges->getOptimisticEdges()) addCalledFunction(F, Change); return true; @@ -10277,8 +10513,9 @@ struct AACallEdgesFunction : public AACallEdgesImpl { struct AAInterFnReachabilityFunction : public CachedReachabilityAA<AAInterFnReachability, Function> { + using Base = CachedReachabilityAA<AAInterFnReachability, Function>; AAInterFnReachabilityFunction(const IRPosition &IRP, Attributor &A) - : CachedReachabilityAA<AAInterFnReachability, Function>(IRP, A) {} + : Base(IRP, A) {} bool instructionCanReach( Attributor &A, const Instruction &From, const Function &To, @@ -10287,10 +10524,10 @@ struct AAInterFnReachabilityFunction assert(From.getFunction() == getAnchorScope() && "Queried the wrong AA!"); auto *NonConstThis = const_cast<AAInterFnReachabilityFunction *>(this); - RQITy StackRQI(A, From, To, ExclusionSet); + RQITy StackRQI(A, From, To, ExclusionSet, false); typename RQITy::Reachable Result; - if (RQITy *RQIPtr = NonConstThis->checkQueryCache(A, StackRQI, Result)) - return NonConstThis->isReachableImpl(A, *RQIPtr); + if (!NonConstThis->checkQueryCache(A, StackRQI, Result)) + return NonConstThis->isReachableImpl(A, StackRQI); return Result == RQITy::Reachable::Yes; } @@ -10305,59 +10542,61 @@ struct AAInterFnReachabilityFunction if (!Visited) Visited = &LocalVisited; - const auto &IntraFnReachability = A.getAAFor<AAIntraFnReachability>( - *this, IRPosition::function(*RQI.From->getFunction()), - DepClassTy::OPTIONAL); - - // Determine call like instructions that we can reach from the inst. - SmallVector<CallBase *> ReachableCallBases; - auto CheckCallBase = [&](Instruction &CBInst) { - if (IntraFnReachability.isAssumedReachable(A, *RQI.From, CBInst, - RQI.ExclusionSet)) - ReachableCallBases.push_back(cast<CallBase>(&CBInst)); - return true; - }; - - bool UsedAssumedInformation = false; - if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this, - UsedAssumedInformation, - /* CheckBBLivenessOnly */ true)) - return rememberResult(A, RQITy::Reachable::Yes, RQI); - - for (CallBase *CB : ReachableCallBases) { - auto &CBEdges = A.getAAFor<AACallEdges>( + auto CheckReachableCallBase = [&](CallBase *CB) { + auto *CBEdges = A.getAAFor<AACallEdges>( *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL); - if (!CBEdges.getState().isValidState()) - return rememberResult(A, RQITy::Reachable::Yes, RQI); + if (!CBEdges || !CBEdges->getState().isValidState()) + return false; // TODO Check To backwards in this case. 
- if (CBEdges.hasUnknownCallee()) - return rememberResult(A, RQITy::Reachable::Yes, RQI); + if (CBEdges->hasUnknownCallee()) + return false; - for (Function *Fn : CBEdges.getOptimisticEdges()) { + for (Function *Fn : CBEdges->getOptimisticEdges()) { if (Fn == RQI.To) - return rememberResult(A, RQITy::Reachable::Yes, RQI); + return false; if (!Visited->insert(Fn).second) continue; if (Fn->isDeclaration()) { if (Fn->hasFnAttribute(Attribute::NoCallback)) continue; // TODO Check To backwards in this case. - return rememberResult(A, RQITy::Reachable::Yes, RQI); + return false; } const AAInterFnReachability *InterFnReachability = this; if (Fn != getAnchorScope()) - InterFnReachability = &A.getAAFor<AAInterFnReachability>( + InterFnReachability = A.getAAFor<AAInterFnReachability>( *this, IRPosition::function(*Fn), DepClassTy::OPTIONAL); const Instruction &FnFirstInst = Fn->getEntryBlock().front(); - if (InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To, + if (!InterFnReachability || + InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To, RQI.ExclusionSet, Visited)) - return rememberResult(A, RQITy::Reachable::Yes, RQI); + return false; } - } + return true; + }; + + const auto *IntraFnReachability = A.getAAFor<AAIntraFnReachability>( + *this, IRPosition::function(*RQI.From->getFunction()), + DepClassTy::OPTIONAL); + + // Determine call like instructions that we can reach from the inst. + auto CheckCallBase = [&](Instruction &CBInst) { + if (!IntraFnReachability || !IntraFnReachability->isAssumedReachable( + A, *RQI.From, CBInst, RQI.ExclusionSet)) + return true; + return CheckReachableCallBase(cast<CallBase>(&CBInst)); + }; + + bool UsedExclusionSet = /* conservative */ true; + bool UsedAssumedInformation = false; + if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this, + UsedAssumedInformation, + /* CheckBBLivenessOnly */ true)) + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet); - return rememberResult(A, RQITy::Reachable::No, RQI); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); } void trackStatistics() const override {} @@ -10376,16 +10615,18 @@ askForAssumedConstant(Attributor &A, const AbstractAttribute &QueryingAA, return nullptr; // This will also pass the call base context. - const auto &AA = A.getAAFor<AAType>(QueryingAA, IRP, DepClassTy::NONE); + const auto *AA = A.getAAFor<AAType>(QueryingAA, IRP, DepClassTy::NONE); + if (!AA) + return nullptr; - std::optional<Constant *> COpt = AA.getAssumedConstant(A); + std::optional<Constant *> COpt = AA->getAssumedConstant(A); if (!COpt.has_value()) { - A.recordDependence(AA, QueryingAA, DepClassTy::OPTIONAL); + A.recordDependence(*AA, QueryingAA, DepClassTy::OPTIONAL); return std::nullopt; } if (auto *C = *COpt) { - A.recordDependence(AA, QueryingAA, DepClassTy::OPTIONAL); + A.recordDependence(*AA, QueryingAA, DepClassTy::OPTIONAL); return C; } return nullptr; @@ -10432,7 +10673,7 @@ struct AAPotentialValuesImpl : AAPotentialValues { } /// See AbstractAttribute::getAsStr(). 
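The inter-function reachability rework above routes every query through a result cache (StackRQI, checkQueryCache, rememberResult) and only walks call edges on a miss, carrying a visited set to avoid cycling. A compact standalone sketch of memoized call-graph reachability; exclusion sets and the intra-function step are left out, and the types are hypothetical.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Fn = std::string;
using CallGraph = std::map<Fn, std::vector<Fn>>;
// Cache of already answered (From, To) queries, the analogue of
// checkQueryCache / rememberResult in the hunks above.
using QueryCache = std::map<std::pair<Fn, Fn>, bool>;

// Plain worklist walk over the call graph; Visited prevents cycling.
bool canReachImpl(const CallGraph &CG, const Fn &From, const Fn &To) {
  std::set<Fn> Visited;
  std::vector<Fn> Worklist{From};
  while (!Worklist.empty()) {
    Fn Cur = Worklist.back();
    Worklist.pop_back();
    if (Cur == To)
      return true;
    if (!Visited.insert(Cur).second)
      continue;
    if (auto It = CG.find(Cur); It != CG.end())
      for (const Fn &Callee : It->second)
        Worklist.push_back(Callee);
  }
  return false;
}

bool canReach(const CallGraph &CG, const Fn &From, const Fn &To,
              QueryCache &Cache) {
  auto Key = std::make_pair(From, To);
  if (auto It = Cache.find(Key); It != Cache.end())
    return It->second; // cache hit: no graph walk at all
  bool Result = canReachImpl(CG, From, To);
  Cache[Key] = Result; // remember the answer for later queries
  return Result;
}

int main() {
  CallGraph CG = {{"a", {"b"}}, {"b", {"c", "a"}}, {"c", {}}};
  QueryCache Cache;
  std::cout << canReach(CG, "a", "c", Cache) << "\n"; // 1
  std::cout << canReach(CG, "c", "a", Cache) << "\n"; // 0
  std::cout << canReach(CG, "a", "c", Cache) << "\n"; // 1, answered from cache
}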
- const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { std::string Str; llvm::raw_string_ostream OS(Str); OS << getState(); @@ -10454,9 +10695,9 @@ struct AAPotentialValuesImpl : AAPotentialValues { return nullptr; } - void addValue(Attributor &A, StateType &State, Value &V, - const Instruction *CtxI, AA::ValueScope S, - Function *AnchorScope) const { + virtual void addValue(Attributor &A, StateType &State, Value &V, + const Instruction *CtxI, AA::ValueScope S, + Function *AnchorScope) const { IRPosition ValIRP = IRPosition::value(V); if (auto *CB = dyn_cast_or_null<CallBase>(CtxI)) { @@ -10474,12 +10715,12 @@ struct AAPotentialValuesImpl : AAPotentialValues { std::optional<Value *> SimpleV = askOtherAA<AAValueConstantRange>(A, *this, ValIRP, Ty); if (SimpleV.has_value() && !*SimpleV) { - auto &PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>( + auto *PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>( *this, ValIRP, DepClassTy::OPTIONAL); - if (PotentialConstantsAA.isValidState()) { - for (const auto &It : PotentialConstantsAA.getAssumedSet()) + if (PotentialConstantsAA && PotentialConstantsAA->isValidState()) { + for (const auto &It : PotentialConstantsAA->getAssumedSet()) State.unionAssumed({{*ConstantInt::get(&Ty, It), nullptr}, S}); - if (PotentialConstantsAA.undefIsContained()) + if (PotentialConstantsAA->undefIsContained()) State.unionAssumed({{*UndefValue::get(&Ty), nullptr}, S}); return; } @@ -10586,14 +10827,23 @@ struct AAPotentialValuesImpl : AAPotentialValues { return ChangeStatus::UNCHANGED; } - bool getAssumedSimplifiedValues(Attributor &A, - SmallVectorImpl<AA::ValueAndContext> &Values, - AA::ValueScope S) const override { + bool getAssumedSimplifiedValues( + Attributor &A, SmallVectorImpl<AA::ValueAndContext> &Values, + AA::ValueScope S, bool RecurseForSelectAndPHI = false) const override { if (!isValidState()) return false; + bool UsedAssumedInformation = false; for (const auto &It : getAssumedSet()) - if (It.second & S) + if (It.second & S) { + if (RecurseForSelectAndPHI && (isa<PHINode>(It.first.getValue()) || + isa<SelectInst>(It.first.getValue()))) { + if (A.getAssumedSimplifiedValues( + IRPosition::inst(*cast<Instruction>(It.first.getValue())), + this, Values, S, UsedAssumedInformation)) + continue; + } Values.push_back(It.first); + } assert(!undefIsContained() && "Undef should be an explicit value!"); return true; } @@ -10607,7 +10857,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { ChangeStatus updateImpl(Attributor &A) override { auto AssumedBefore = getAssumed(); - genericValueTraversal(A); + genericValueTraversal(A, &getAssociatedValue()); return (AssumedBefore == getAssumed()) ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; @@ -10677,9 +10927,11 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { // The index is the operand that we assume is not null. unsigned PtrIdx = LHSIsNull; - auto &PtrNonNullAA = A.getAAFor<AANonNull>( - *this, IRPosition::value(*(PtrIdx ? RHS : LHS)), DepClassTy::REQUIRED); - if (!PtrNonNullAA.isAssumedNonNull()) + bool IsKnownNonNull; + bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, IRPosition::value(*(PtrIdx ? RHS : LHS)), DepClassTy::REQUIRED, + IsKnownNonNull); + if (!IsAssumedNonNull) return false; // The new value depends on the predicate, true for != and false for ==. 
@@ -10743,7 +10995,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { InformationCache &InfoCache = A.getInfoCache(); if (InfoCache.isOnlyUsedByAssume(LI)) { if (!llvm::all_of(PotentialValueOrigins, [&](Instruction *I) { - if (!I) + if (!I || isa<AssumeInst>(I)) return true; if (auto *SI = dyn_cast<StoreInst>(I)) return A.isAssumedDead(SI->getOperandUse(0), this, @@ -10797,21 +11049,37 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { auto GetLivenessInfo = [&](const Function &F) -> LivenessInfo & { LivenessInfo &LI = LivenessAAs[&F]; if (!LI.LivenessAA) - LI.LivenessAA = &A.getAAFor<AAIsDead>(*this, IRPosition::function(F), - DepClassTy::NONE); + LI.LivenessAA = A.getAAFor<AAIsDead>(*this, IRPosition::function(F), + DepClassTy::NONE); return LI; }; if (&PHI == &getAssociatedValue()) { LivenessInfo &LI = GetLivenessInfo(*PHI.getFunction()); + const auto *CI = + A.getInfoCache().getAnalysisResultForFunction<CycleAnalysis>( + *PHI.getFunction()); + + Cycle *C = nullptr; + bool CyclePHI = mayBeInCycle(CI, &PHI, /* HeaderOnly */ true, &C); for (unsigned u = 0, e = PHI.getNumIncomingValues(); u < e; u++) { BasicBlock *IncomingBB = PHI.getIncomingBlock(u); - if (LI.LivenessAA->isEdgeDead(IncomingBB, PHI.getParent())) { + if (LI.LivenessAA && + LI.LivenessAA->isEdgeDead(IncomingBB, PHI.getParent())) { LI.AnyDead = true; continue; } - Worklist.push_back( - {{*PHI.getIncomingValue(u), IncomingBB->getTerminator()}, II.S}); + Value *V = PHI.getIncomingValue(u); + if (V == &PHI) + continue; + + // If the incoming value is not the PHI but an instruction in the same + // cycle we might have multiple versions of it flying around. + if (CyclePHI && isa<Instruction>(V) && + (!C || C->contains(cast<Instruction>(V)->getParent()))) + return false; + + Worklist.push_back({{*V, IncomingBB->getTerminator()}, II.S}); } return true; } @@ -10866,11 +11134,10 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F); const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F); auto *AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F); - OptimizationRemarkEmitter *ORE = nullptr; const DataLayout &DL = I.getModule()->getDataLayout(); SimplifyQuery Q(DL, TLI, DT, AC, &I); - Value *NewV = simplifyInstructionWithOperands(&I, NewOps, Q, ORE); + Value *NewV = simplifyInstructionWithOperands(&I, NewOps, Q); if (!NewV || NewV == &I) return false; @@ -10902,10 +11169,9 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { return false; } - void genericValueTraversal(Attributor &A) { + void genericValueTraversal(Attributor &A, Value *InitialV) { SmallMapVector<const Function *, LivenessInfo, 4> LivenessAAs; - Value *InitialV = &getAssociatedValue(); SmallSet<ItemInfo, 16> Visited; SmallVector<ItemInfo, 16> Worklist; Worklist.push_back({{*InitialV, getCtxI()}, AA::AnyScope}); @@ -10937,14 +11203,15 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { if (V->getType()->isPointerTy()) { NewV = AA::getWithType(*V->stripPointerCasts(), *V->getType()); } else { - auto *CB = dyn_cast<CallBase>(V); - if (CB && CB->getCalledFunction()) { - for (Argument &Arg : CB->getCalledFunction()->args()) - if (Arg.hasReturnedAttr()) { - NewV = CB->getArgOperand(Arg.getArgNo()); - break; - } - } + if (auto *CB = dyn_cast<CallBase>(V)) + if (auto *Callee = + dyn_cast_if_present<Function>(CB->getCalledOperand())) { + for (Argument &Arg : Callee->args()) + if (Arg.hasReturnedAttr()) { + NewV = 
CB->getArgOperand(Arg.getArgNo()); + break; + } + } } if (NewV && NewV != V) { Worklist.push_back({{*NewV, CtxI}, S}); @@ -11062,25 +11329,127 @@ struct AAPotentialValuesArgument final : AAPotentialValuesImpl { } }; -struct AAPotentialValuesReturned - : AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl> { - using Base = - AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl>; +struct AAPotentialValuesReturned : public AAPotentialValuesFloating { + using Base = AAPotentialValuesFloating; AAPotentialValuesReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} /// See AbstractAttribute::initialize(..). void initialize(Attributor &A) override { - if (A.hasSimplificationCallback(getIRPosition())) + Function *F = getAssociatedFunction(); + if (!F || F->isDeclaration() || F->getReturnType()->isVoidTy()) { indicatePessimisticFixpoint(); - else - AAPotentialValues::initialize(A); + return; + } + + for (Argument &Arg : F->args()) + if (Arg.hasReturnedAttr()) { + addValue(A, getState(), Arg, nullptr, AA::AnyScope, F); + ReturnedArg = &Arg; + break; + } + if (!A.isFunctionIPOAmendable(*F) || + A.hasSimplificationCallback(getIRPosition())) { + if (!ReturnedArg) + indicatePessimisticFixpoint(); + else + indicateOptimisticFixpoint(); + } + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + auto AssumedBefore = getAssumed(); + bool UsedAssumedInformation = false; + + SmallVector<AA::ValueAndContext> Values; + Function *AnchorScope = getAnchorScope(); + auto HandleReturnedValue = [&](Value &V, Instruction *CtxI, + bool AddValues) { + for (AA::ValueScope S : {AA::Interprocedural, AA::Intraprocedural}) { + Values.clear(); + if (!A.getAssumedSimplifiedValues(IRPosition::value(V), this, Values, S, + UsedAssumedInformation, + /* RecurseForSelectAndPHI */ true)) + return false; + if (!AddValues) + continue; + for (const AA::ValueAndContext &VAC : Values) + addValue(A, getState(), *VAC.getValue(), + VAC.getCtxI() ? VAC.getCtxI() : CtxI, S, AnchorScope); + } + return true; + }; + + if (ReturnedArg) { + HandleReturnedValue(*ReturnedArg, nullptr, true); + } else { + auto RetInstPred = [&](Instruction &RetI) { + bool AddValues = true; + if (isa<PHINode>(RetI.getOperand(0)) || + isa<SelectInst>(RetI.getOperand(0))) { + addValue(A, getState(), *RetI.getOperand(0), &RetI, AA::AnyScope, + AnchorScope); + AddValues = false; + } + return HandleReturnedValue(*RetI.getOperand(0), &RetI, AddValues); + }; + + if (!A.checkForAllInstructions(RetInstPred, *this, {Instruction::Ret}, + UsedAssumedInformation, + /* CheckBBLivenessOnly */ true)) + return indicatePessimisticFixpoint(); + } + + return (AssumedBefore == getAssumed()) ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + void addValue(Attributor &A, StateType &State, Value &V, + const Instruction *CtxI, AA::ValueScope S, + Function *AnchorScope) const override { + Function *F = getAssociatedFunction(); + if (auto *CB = dyn_cast<CallBase>(&V)) + if (CB->getCalledOperand() == F) + return; + Base::addValue(A, State, V, CtxI, S, AnchorScope); } ChangeStatus manifest(Attributor &A) override { - // We queried AAValueSimplify for the returned values so they will be - // replaced if a simplified form was found. Nothing to do here. 
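The rewritten AAPotentialValuesReturned gathers the simplified values flowing into every ret; its manifest, continued in the next hunk, reduces them to a single value where possible and then either adds a returned argument attribute or rewrites the return operands. A small sketch of the unique-return-value detection on a toy representation; the types here are invented.

#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Toy representation: one entry per return site, holding the (already
// simplified) returned value.
using ReturnedValues = std::vector<std::string>;

// Returns the single value produced by every return site, or nullopt when the
// return sites disagree or there are none.
std::optional<std::string> uniqueReturnValue(const ReturnedValues &Rets) {
  std::optional<std::string> Unique;
  for (const std::string &V : Rets) {
    if (!Unique)
      Unique = V;
    else if (*Unique != V)
      return std::nullopt; // more than one distinct returned value
  }
  return Unique;
}

int main() {
  std::cout << uniqueReturnValue({"%arg0", "%arg0"}).value_or("<none>") << "\n";
  std::cout << uniqueReturnValue({"%arg0", "%x"}).value_or("<none>") << "\n";
}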
- return ChangeStatus::UNCHANGED; + if (ReturnedArg) + return ChangeStatus::UNCHANGED; + SmallVector<AA::ValueAndContext> Values; + if (!getAssumedSimplifiedValues(A, Values, AA::ValueScope::Intraprocedural, + /* RecurseForSelectAndPHI */ true)) + return ChangeStatus::UNCHANGED; + Value *NewVal = getSingleValue(A, *this, getIRPosition(), Values); + if (!NewVal) + return ChangeStatus::UNCHANGED; + + ChangeStatus Changed = ChangeStatus::UNCHANGED; + if (auto *Arg = dyn_cast<Argument>(NewVal)) { + STATS_DECLTRACK(UniqueReturnValue, FunctionReturn, + "Number of function with unique return"); + Changed |= A.manifestAttrs( + IRPosition::argument(*Arg), + {Attribute::get(Arg->getContext(), Attribute::Returned)}); + STATS_DECLTRACK_ARG_ATTR(returned); + } + + auto RetInstPred = [&](Instruction &RetI) { + Value *RetOp = RetI.getOperand(0); + if (isa<UndefValue>(RetOp) || RetOp == NewVal) + return true; + if (AA::isValidAtPosition({*NewVal, RetI}, A.getInfoCache())) + if (A.changeUseAfterManifest(RetI.getOperandUse(0), *NewVal)) + Changed = ChangeStatus::CHANGED; + return true; + }; + bool UsedAssumedInformation = false; + (void)A.checkForAllInstructions(RetInstPred, *this, {Instruction::Ret}, + UsedAssumedInformation, + /* CheckBBLivenessOnly */ true); + return Changed; } ChangeStatus indicatePessimisticFixpoint() override { @@ -11088,9 +11457,11 @@ struct AAPotentialValuesReturned } /// See AbstractAttribute::trackStatistics() - void trackStatistics() const override { - STATS_DECLTRACK_FNRET_ATTR(potential_values) - } + void trackStatistics() const override{ + STATS_DECLTRACK_FNRET_ATTR(potential_values)} + + /// The argumented with an existing `returned` attribute. + Argument *ReturnedArg = nullptr; }; struct AAPotentialValuesFunction : AAPotentialValuesImpl { @@ -11162,7 +11533,7 @@ struct AAPotentialValuesCallSiteReturned : AAPotentialValuesImpl { SmallVector<AA::ValueAndContext> ArgValues; IRPosition IRP = IRPosition::value(*V); if (auto *Arg = dyn_cast<Argument>(V)) - if (Arg->getParent() == CB->getCalledFunction()) + if (Arg->getParent() == CB->getCalledOperand()) IRP = IRPosition::callsite_argument(*CB, Arg->getArgNo()); if (recurseForValue(A, IRP, AA::AnyScope)) continue; @@ -11228,12 +11599,26 @@ struct AAAssumptionInfoImpl : public AAAssumptionInfo { const DenseSet<StringRef> &Known) : AAAssumptionInfo(IRP, A, Known) {} + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + // Don't manifest a universal set if it somehow made it here. + if (getKnown().isUniversal()) + return ChangeStatus::UNCHANGED; + + const IRPosition &IRP = getIRPosition(); + return A.manifestAttrs( + IRP, + Attribute::get(IRP.getAnchorValue().getContext(), AssumptionAttrKey, + llvm::join(getAssumed().getSet(), ",")), + /* ForceReplace */ true); + } + bool hasAssumption(const StringRef Assumption) const override { return isValidState() && setContains(Assumption); } /// See AbstractAttribute::getAsStr() - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { const SetContents &Known = getKnown(); const SetContents &Assumed = getAssumed(); @@ -11264,31 +11649,18 @@ struct AAAssumptionInfoFunction final : AAAssumptionInfoImpl { : AAAssumptionInfoImpl(IRP, A, getAssumptions(*IRP.getAssociatedFunction())) {} - /// See AbstractAttribute::manifest(...). - ChangeStatus manifest(Attributor &A) override { - const auto &Assumptions = getKnown(); - - // Don't manifest a universal set if it somehow made it here. 
- if (Assumptions.isUniversal()) - return ChangeStatus::UNCHANGED; - - Function *AssociatedFunction = getAssociatedFunction(); - - bool Changed = addAssumptions(*AssociatedFunction, Assumptions.getSet()); - - return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { bool Changed = false; auto CallSitePred = [&](AbstractCallSite ACS) { - const auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>( + const auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>( *this, IRPosition::callsite_function(*ACS.getInstruction()), DepClassTy::REQUIRED); + if (!AssumptionAA) + return false; // Get the set of assumptions shared by all of this function's callers. - Changed |= getIntersection(AssumptionAA.getAssumed()); + Changed |= getIntersection(AssumptionAA->getAssumed()); return !getAssumed().empty() || !getKnown().empty(); }; @@ -11319,24 +11691,14 @@ struct AAAssumptionInfoCallSite final : AAAssumptionInfoImpl { A.getAAFor<AAAssumptionInfo>(*this, FnPos, DepClassTy::REQUIRED); } - /// See AbstractAttribute::manifest(...). - ChangeStatus manifest(Attributor &A) override { - // Don't manifest a universal set if it somehow made it here. - if (getKnown().isUniversal()) - return ChangeStatus::UNCHANGED; - - CallBase &AssociatedCall = cast<CallBase>(getAssociatedValue()); - bool Changed = addAssumptions(AssociatedCall, getAssumed().getSet()); - - return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { const IRPosition &FnPos = IRPosition::function(*getAnchorScope()); - auto &AssumptionAA = + auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(*this, FnPos, DepClassTy::REQUIRED); - bool Changed = getIntersection(AssumptionAA.getAssumed()); + if (!AssumptionAA) + return indicatePessimisticFixpoint(); + bool Changed = getIntersection(AssumptionAA->getAssumed()); return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; } @@ -11360,7 +11722,7 @@ private: AACallGraphNode *AACallEdgeIterator::operator*() const { return static_cast<AACallGraphNode *>(const_cast<AACallEdges *>( - &A.getOrCreateAAFor<AACallEdges>(IRPosition::function(**I)))); + A.getOrCreateAAFor<AACallEdges>(IRPosition::function(**I)))); } void AttributorCallGraph::print() { llvm::WriteGraph(outs(), this); } @@ -11374,7 +11736,7 @@ struct AAUnderlyingObjectsImpl AAUnderlyingObjectsImpl(const IRPosition &IRP, Attributor &A) : BaseTy(IRP) {} /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *A) const override { return std::string("UnderlyingObjects ") + (isValidState() ? 
(std::string("inter #") + @@ -11409,24 +11771,33 @@ struct AAUnderlyingObjectsImpl auto *Obj = VAC.getValue(); Value *UO = getUnderlyingObject(Obj); if (UO && UO != VAC.getValue() && SeenObjects.insert(UO).second) { - const auto &OtherAA = A.getAAFor<AAUnderlyingObjects>( + const auto *OtherAA = A.getAAFor<AAUnderlyingObjects>( *this, IRPosition::value(*UO), DepClassTy::OPTIONAL); auto Pred = [&Values](Value &V) { Values.emplace_back(V, nullptr); return true; }; - if (!OtherAA.forallUnderlyingObjects(Pred, Scope)) + if (!OtherAA || !OtherAA->forallUnderlyingObjects(Pred, Scope)) llvm_unreachable( "The forall call should not return false at this position"); continue; } - if (isa<SelectInst>(Obj) || isa<PHINode>(Obj)) { + if (isa<SelectInst>(Obj)) { Changed |= handleIndirect(A, *Obj, UnderlyingObjects, Scope); continue; } + if (auto *PHI = dyn_cast<PHINode>(Obj)) { + // Explicitly look through PHIs as we do not care about dynamically + // uniqueness. + for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) { + Changed |= handleIndirect(A, *PHI->getIncomingValue(u), + UnderlyingObjects, Scope); + } + continue; + } Changed |= UnderlyingObjects.insert(Obj); } @@ -11464,13 +11835,13 @@ private: SmallSetVector<Value *, 8> &UnderlyingObjects, AA::ValueScope Scope) { bool Changed = false; - const auto &AA = A.getAAFor<AAUnderlyingObjects>( + const auto *AA = A.getAAFor<AAUnderlyingObjects>( *this, IRPosition::value(V), DepClassTy::OPTIONAL); auto Pred = [&](Value &V) { Changed |= UnderlyingObjects.insert(&V); return true; }; - if (!AA.forallUnderlyingObjects(Pred, Scope)) + if (!AA || !AA->forallUnderlyingObjects(Pred, Scope)) llvm_unreachable( "The forall call should not return false at this position"); return Changed; @@ -11516,14 +11887,190 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl { AAUnderlyingObjectsFunction(const IRPosition &IRP, Attributor &A) : AAUnderlyingObjectsImpl(IRP, A) {} }; -} +} // namespace + +/// ------------------------ Address Space ------------------------------------ +namespace { +struct AAAddressSpaceImpl : public AAAddressSpace { + AAAddressSpaceImpl(const IRPosition &IRP, Attributor &A) + : AAAddressSpace(IRP, A) {} + + int32_t getAddressSpace() const override { + assert(isValidState() && "the AA is invalid"); + return AssumedAddressSpace; + } + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + assert(getAssociatedType()->isPtrOrPtrVectorTy() && + "Associated value is not a pointer"); + } + + ChangeStatus updateImpl(Attributor &A) override { + int32_t OldAddressSpace = AssumedAddressSpace; + auto *AUO = A.getOrCreateAAFor<AAUnderlyingObjects>(getIRPosition(), this, + DepClassTy::REQUIRED); + auto Pred = [&](Value &Obj) { + if (isa<UndefValue>(&Obj)) + return true; + return takeAddressSpace(Obj.getType()->getPointerAddressSpace()); + }; + + if (!AUO->forallUnderlyingObjects(Pred)) + return indicatePessimisticFixpoint(); + + return OldAddressSpace == AssumedAddressSpace ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::manifest(...). 
+ ChangeStatus manifest(Attributor &A) override { + Value *AssociatedValue = &getAssociatedValue(); + Value *OriginalValue = peelAddrspacecast(AssociatedValue); + if (getAddressSpace() == NoAddressSpace || + static_cast<uint32_t>(getAddressSpace()) == + getAssociatedType()->getPointerAddressSpace()) + return ChangeStatus::UNCHANGED; + + Type *NewPtrTy = PointerType::get(getAssociatedType()->getContext(), + static_cast<uint32_t>(getAddressSpace())); + bool UseOriginalValue = + OriginalValue->getType()->getPointerAddressSpace() == + static_cast<uint32_t>(getAddressSpace()); + + bool Changed = false; + + auto MakeChange = [&](Instruction *I, Use &U) { + Changed = true; + if (UseOriginalValue) { + A.changeUseAfterManifest(U, *OriginalValue); + return; + } + Instruction *CastInst = new AddrSpaceCastInst(OriginalValue, NewPtrTy); + CastInst->insertBefore(cast<Instruction>(I)); + A.changeUseAfterManifest(U, *CastInst); + }; + + auto Pred = [&](const Use &U, bool &) { + if (U.get() != AssociatedValue) + return true; + auto *Inst = dyn_cast<Instruction>(U.getUser()); + if (!Inst) + return true; + // This is a WA to make sure we only change uses from the corresponding + // CGSCC if the AA is run on CGSCC instead of the entire module. + if (!A.isRunOn(Inst->getFunction())) + return true; + if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) + MakeChange(Inst, const_cast<Use &>(U)); + return true; + }; + + // It doesn't matter if we can't check all uses as we can simply + // conservatively ignore those that can not be visited. + (void)A.checkForAllUses(Pred, *this, getAssociatedValue(), + /* CheckBBLivenessOnly */ true); + + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr(Attributor *A) const override { + if (!isValidState()) + return "addrspace(<invalid>)"; + return "addrspace(" + + (AssumedAddressSpace == NoAddressSpace + ? "none" + : std::to_string(AssumedAddressSpace)) + + ")"; + } + +private: + int32_t AssumedAddressSpace = NoAddressSpace; + + bool takeAddressSpace(int32_t AS) { + if (AssumedAddressSpace == NoAddressSpace) { + AssumedAddressSpace = AS; + return true; + } + return AssumedAddressSpace == AS; + } + + static Value *peelAddrspacecast(Value *V) { + if (auto *I = dyn_cast<AddrSpaceCastInst>(V)) + return peelAddrspacecast(I->getPointerOperand()); + if (auto *C = dyn_cast<ConstantExpr>(V)) + if (C->getOpcode() == Instruction::AddrSpaceCast) + return peelAddrspacecast(C->getOperand(0)); + return V; + } +}; + +struct AAAddressSpaceFloating final : AAAddressSpaceImpl { + AAAddressSpaceFloating(const IRPosition &IRP, Attributor &A) + : AAAddressSpaceImpl(IRP, A) {} + + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(addrspace); + } +}; + +struct AAAddressSpaceReturned final : AAAddressSpaceImpl { + AAAddressSpaceReturned(const IRPosition &IRP, Attributor &A) + : AAAddressSpaceImpl(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + // TODO: we don't rewrite function argument for now because it will need to + // rewrite the function signature and all call sites. 
+ (void)indicatePessimisticFixpoint(); + } + + void trackStatistics() const override { + STATS_DECLTRACK_FNRET_ATTR(addrspace); + } +}; + +struct AAAddressSpaceCallSiteReturned final : AAAddressSpaceImpl { + AAAddressSpaceCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AAAddressSpaceImpl(IRP, A) {} + + void trackStatistics() const override { + STATS_DECLTRACK_CSRET_ATTR(addrspace); + } +}; + +struct AAAddressSpaceArgument final : AAAddressSpaceImpl { + AAAddressSpaceArgument(const IRPosition &IRP, Attributor &A) + : AAAddressSpaceImpl(IRP, A) {} + + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(addrspace); } +}; + +struct AAAddressSpaceCallSiteArgument final : AAAddressSpaceImpl { + AAAddressSpaceCallSiteArgument(const IRPosition &IRP, Attributor &A) + : AAAddressSpaceImpl(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + // TODO: we don't rewrite call site argument for now because it will need to + // rewrite the function signature of the callee. + (void)indicatePessimisticFixpoint(); + } + + void trackStatistics() const override { + STATS_DECLTRACK_CSARG_ATTR(addrspace); + } +}; +} // namespace -const char AAReturnedValues::ID = 0; const char AANoUnwind::ID = 0; const char AANoSync::ID = 0; const char AANoFree::ID = 0; const char AANonNull::ID = 0; +const char AAMustProgress::ID = 0; const char AANoRecurse::ID = 0; +const char AANonConvergent::ID = 0; const char AAWillReturn::ID = 0; const char AAUndefinedBehavior::ID = 0; const char AANoAlias::ID = 0; @@ -11543,11 +12090,13 @@ const char AAValueConstantRange::ID = 0; const char AAPotentialConstantValues::ID = 0; const char AAPotentialValues::ID = 0; const char AANoUndef::ID = 0; +const char AANoFPClass::ID = 0; const char AACallEdges::ID = 0; const char AAInterFnReachability::ID = 0; const char AAPointerInfo::ID = 0; const char AAAssumptionInfo::ID = 0; const char AAUnderlyingObjects::ID = 0; +const char AAAddressSpace::ID = 0; // Macro magic to create the static generator function for attributes that // follow the naming scheme. 
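The AAAddressSpace attribute introduced above deduces a single address space for a pointer by intersecting the address spaces of all of its underlying objects: the first observed address space is adopted (takeAddressSpace), every later observation must agree or the state is invalidated, and manifest() then rewrites load/store uses through an addrspacecast of the peeled original value. A minimal standalone sketch of that lattice step, using hypothetical names and no LLVM dependencies, might look like this:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Toy model of the AAAddressSpaceImpl lattice: "no address space yet",
    // a single agreed-upon address space, or an invalid (conflicting) state.
    struct AddressSpaceState {
      static constexpr int32_t NoAddressSpace = -1;
      int32_t Assumed = NoAddressSpace;
      bool Valid = true;

      // Mirrors the takeAddressSpace() idea: adopt the first address space
      // seen, stay valid only while every later observation agrees.
      void meet(int32_t AS) {
        if (!Valid)
          return;
        if (Assumed == NoAddressSpace) {
          Assumed = AS;
          return;
        }
        if (Assumed != AS)
          Valid = false;
      }
    };

    int main() {
      // Hypothetical address spaces of the underlying objects of one pointer.
      std::vector<int32_t> UnderlyingObjectAS = {3, 3, 3};

      AddressSpaceState State;
      for (int32_t AS : UnderlyingObjectAS)
        State.meet(AS);

      if (State.Valid && State.Assumed != AddressSpaceState::NoAddressSpace)
        std::cout << "all objects agree on addrspace(" << State.Assumed << ")\n";
      else
        std::cout << "no single address space; keep the generic pointer\n";
      return 0;
    }

When no single address space survives the intersection, the pointer keeps its original address space and no uses are rewritten.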
@@ -11647,10 +12196,10 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoSync) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoRecurse) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAWillReturn) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoReturn) -CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReturnedValues) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMustProgress) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias) @@ -11663,7 +12212,9 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueConstantRange) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPotentialConstantValues) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPotentialValues) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUndef) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFPClass) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPointerInfo) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAddressSpace) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead) @@ -11672,6 +12223,7 @@ CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior) +CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIntraFnReachability) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAInterFnReachability) diff --git a/llvm/lib/Transforms/IPO/BlockExtractor.cpp b/llvm/lib/Transforms/IPO/BlockExtractor.cpp index a68cf7db7c85..0c406aa9822e 100644 --- a/llvm/lib/Transforms/IPO/BlockExtractor.cpp +++ b/llvm/lib/Transforms/IPO/BlockExtractor.cpp @@ -17,8 +17,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MemoryBuffer.h" diff --git a/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp b/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp index 64bfcb2a9a9f..2c8756c07f87 100644 --- a/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp +++ b/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp @@ -21,8 +21,6 @@ #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" @@ -405,33 +403,3 @@ PreservedAnalyses CalledValuePropagationPass::run(Module &M, runCVP(M); return PreservedAnalyses::all(); } - -namespace { -class CalledValuePropagationLegacyPass : public ModulePass { -public: - static char ID; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - - CalledValuePropagationLegacyPass() : ModulePass(ID) { - initializeCalledValuePropagationLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - return runCVP(M); - } -}; -} // namespace - -char CalledValuePropagationLegacyPass::ID = 0; -INITIALIZE_PASS(CalledValuePropagationLegacyPass, "called-value-propagation", - "Called Value Propagation", false, false) - 
-ModulePass *llvm::createCalledValuePropagationPass() { - return new CalledValuePropagationLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/ConstantMerge.cpp b/llvm/lib/Transforms/IPO/ConstantMerge.cpp index 77bc377f4514..29052c8d997e 100644 --- a/llvm/lib/Transforms/IPO/ConstantMerge.cpp +++ b/llvm/lib/Transforms/IPO/ConstantMerge.cpp @@ -28,8 +28,6 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/IPO.h" #include <algorithm> @@ -251,32 +249,3 @@ PreservedAnalyses ConstantMergePass::run(Module &M, ModuleAnalysisManager &) { return PreservedAnalyses::all(); return PreservedAnalyses::none(); } - -namespace { - -struct ConstantMergeLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - - ConstantMergeLegacyPass() : ModulePass(ID) { - initializeConstantMergeLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - // For this pass, process all of the globals in the module, eliminating - // duplicate constants. - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - return mergeConstants(M); - } -}; - -} // end anonymous namespace - -char ConstantMergeLegacyPass::ID = 0; - -INITIALIZE_PASS(ConstantMergeLegacyPass, "constmerge", - "Merge Duplicate Global Constants", false, false) - -ModulePass *llvm::createConstantMergePass() { - return new ConstantMergeLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp index 4fe7bb6c757c..93d15f59a036 100644 --- a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -14,7 +14,6 @@ #include "llvm/Transforms/IPO/CrossDSOCFI.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Triple.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalObject.h" @@ -23,8 +22,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO.h" using namespace llvm; @@ -35,28 +33,16 @@ STATISTIC(NumTypeIds, "Number of unique type identifiers"); namespace { -struct CrossDSOCFI : public ModulePass { - static char ID; - CrossDSOCFI() : ModulePass(ID) { - initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry()); - } - +struct CrossDSOCFI { MDNode *VeryLikelyWeights; ConstantInt *extractNumericTypeId(MDNode *MD); void buildCFICheck(Module &M); - bool runOnModule(Module &M) override; + bool runOnModule(Module &M); }; } // anonymous namespace -INITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, - false) -INITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false) -char CrossDSOCFI::ID = 0; - -ModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; } - /// Extracts a numeric type identifier from an MDNode containing type metadata. ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) { // This check excludes vtables for classes inside anonymous namespaces. 
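The removals in CalledValuePropagation.cpp, ConstantMerge.cpp and CrossDSOCFI.cpp above all follow the same pattern: the legacy-pass-manager wrappers (the ModulePass subclasses, their INITIALIZE_PASS boilerplate and the create*Pass() factories) are deleted, leaving only the new-pass-manager entry points. Under the new pass manager these passes are scheduled through a ModulePassManager instead; a minimal sketch, assuming the usual PassBuilder analysis registration, is:

    #include "llvm/Analysis/CGSCCPassManager.h"
    #include "llvm/Analysis/LoopAnalysisManager.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/IR/PassManager.h"
    #include "llvm/Passes/PassBuilder.h"
    #include "llvm/Transforms/IPO/CalledValuePropagation.h"
    #include "llvm/Transforms/IPO/ConstantMerge.h"

    using namespace llvm;

    int main() {
      LLVMContext Ctx;
      Module M("example", Ctx); // Placeholder module; real users parse IR instead.

      // Standard new-pass-manager analysis registration.
      LoopAnalysisManager LAM;
      FunctionAnalysisManager FAM;
      CGSCCAnalysisManager CGAM;
      ModuleAnalysisManager MAM;
      PassBuilder PB;
      PB.registerModuleAnalyses(MAM);
      PB.registerCGSCCAnalyses(CGAM);
      PB.registerFunctionAnalyses(FAM);
      PB.registerLoopAnalyses(LAM);
      PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

      // Replaces the removed createCalledValuePropagationPass() /
      // createConstantMergePass() factory entry points.
      ModulePassManager MPM;
      MPM.addPass(CalledValuePropagationPass());
      MPM.addPass(ConstantMergePass());
      MPM.run(M, MAM);
      return 0;
    }

The same cleanup is applied to the other legacy wrappers removed further down in this diff (ElimAvailExtern, ForceFunctionAttrs, FunctionAttrs).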
diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index bf2c65a2402c..01834015f3fd 100644 --- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -16,9 +16,11 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -43,7 +45,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <cassert> #include <utility> @@ -85,6 +86,11 @@ public: virtual bool shouldHackArguments() const { return false; } }; +bool isMustTailCalleeAnalyzable(const CallBase &CB) { + assert(CB.isMustTailCall()); + return CB.getCalledFunction() && !CB.getCalledFunction()->isDeclaration(); +} + } // end anonymous namespace char DAE::ID = 0; @@ -520,8 +526,16 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) { for (const BasicBlock &BB : F) { // If we have any returns of `musttail` results - the signature can't // change - if (BB.getTerminatingMustTailCall() != nullptr) + if (const auto *TC = BB.getTerminatingMustTailCall()) { HasMustTailCalls = true; + // In addition, if the called function is not locally defined (or unknown, + // if this is an indirect call), we can't change the callsite and thus + // can't change this function's signature either. + if (!isMustTailCalleeAnalyzable(*TC)) { + markLive(F); + return; + } + } } if (HasMustTailCalls) { @@ -1081,6 +1095,26 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { return true; } +void DeadArgumentEliminationPass::propagateVirtMustcallLiveness( + const Module &M) { + // If a function was marked "live", and it has musttail callers, they in turn + // can't change either. + LiveFuncSet NewLiveFuncs(LiveFunctions); + while (!NewLiveFuncs.empty()) { + LiveFuncSet Temp; + for (const auto *F : NewLiveFuncs) + for (const auto *U : F->users()) + if (const auto *CB = dyn_cast<CallBase>(U)) + if (CB->isMustTailCall()) + if (!LiveFunctions.count(CB->getParent()->getParent())) + Temp.insert(CB->getParent()->getParent()); + NewLiveFuncs.clear(); + NewLiveFuncs.insert(Temp.begin(), Temp.end()); + for (const auto *F : Temp) + markLive(*F); + } +} + PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, ModuleAnalysisManager &) { bool Changed = false; @@ -1101,6 +1135,8 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, for (auto &F : M) surveyFunction(F); + propagateVirtMustcallLiveness(M); + // Now, remove all dead arguments and return values from each function in // turn. We use make_early_inc_range here because functions will probably get // removed (i.e. replaced by new ones). 
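The DeadArgumentElimination change above adds two musttail rules: a function whose musttail callee cannot be analyzed (an indirect call, or a callee that is only a declaration) is marked live up front, and propagateVirtMustcallLiveness() then closes the live set over musttail callers, because a musttail call site must keep exactly the signature of its callee. A small standalone sketch of that closure, with a hypothetical caller map standing in for the real use lists, could look like this:

    #include <iostream>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    // Toy closure over the "musttail callers" relation: once a callee's
    // signature is frozen (live), every function that musttail-calls it is
    // frozen as well. The caller map is hypothetical; the real pass walks
    // each function's users looking for musttail call sites.
    int main() {
      // MustTailCallers[F] = functions containing a musttail call to F.
      std::map<std::string, std::vector<std::string>> MustTailCallers = {
          {"callee", {"wrapper"}}, {"wrapper", {"outer"}}};

      std::set<std::string> Live = {"callee"}; // e.g. callee is a declaration
      std::vector<std::string> Worklist(Live.begin(), Live.end());

      while (!Worklist.empty()) {
        std::string F = Worklist.back();
        Worklist.pop_back();
        for (const std::string &Caller : MustTailCallers[F])
          if (Live.insert(Caller).second) // newly marked live
            Worklist.push_back(Caller);
      }

      for (const std::string &F : Live)
        std::cout << F << " keeps its original signature\n";
      return 0;
    }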
diff --git a/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp b/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp index 7f138d206fac..2b34d3b5a56e 100644 --- a/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp +++ b/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp @@ -12,24 +12,82 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/ElimAvailExtern.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/GlobalStatus.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; #define DEBUG_TYPE "elim-avail-extern" -STATISTIC(NumFunctions, "Number of functions removed"); +cl::opt<bool> ConvertToLocal( + "avail-extern-to-local", cl::Hidden, + cl::desc("Convert available_externally into locals, renaming them " + "to avoid link-time clashes.")); + +STATISTIC(NumRemovals, "Number of functions removed"); +STATISTIC(NumConversions, "Number of functions converted"); STATISTIC(NumVariables, "Number of global variables removed"); +void deleteFunction(Function &F) { + // This will set the linkage to external + F.deleteBody(); + ++NumRemovals; +} + +/// Create a copy of the thinlto import, mark it local, and redirect direct +/// calls to the copy. Only direct calls are replaced, so that e.g. indirect +/// call function pointer tests would use the global identity of the function. +/// +/// Currently, Value Profiling ("VP") MD_prof data isn't updated to refer to the +/// clone's GUID (which will be different, because the name and linkage is +/// different), under the assumption that the last consumer of this data is +/// upstream the pipeline (e.g. ICP). +static void convertToLocalCopy(Module &M, Function &F) { + assert(F.hasAvailableExternallyLinkage()); + assert(!F.isDeclaration()); + // If we can't find a single use that's a call, just delete the function. + if (F.uses().end() == llvm::find_if(F.uses(), [&](Use &U) { + return isa<CallBase>(U.getUser()); + })) + return deleteFunction(F); + + auto OrigName = F.getName().str(); + // Build a new name. We still need the old name (see below). + // We could just rely on internal linking allowing 2 modules have internal + // functions with the same name, but that just creates more trouble than + // necessary e.g. distinguishing profiles or debugging. Instead, we append the + // module identifier. + auto NewName = OrigName + ".__uniq" + getUniqueModuleId(&M); + F.setName(NewName); + if (auto *SP = F.getSubprogram()) + SP->replaceLinkageName(MDString::get(F.getParent()->getContext(), NewName)); + + F.setLinkage(GlobalValue::InternalLinkage); + // Now make a declaration for the old name. We'll use it if there are non-call + // uses. For those, it would be incorrect to replace them with the local copy: + // for example, one such use could be taking the address of the function and + // passing it to an external function, which, in turn, might compare the + // function pointer to the original (non-local) function pointer, e.g. as part + // of indirect call promotion. 
+ auto *Decl = + Function::Create(F.getFunctionType(), GlobalValue::ExternalLinkage, + F.getAddressSpace(), OrigName, F.getParent()); + F.replaceUsesWithIf(Decl, + [&](Use &U) { return !isa<CallBase>(U.getUser()); }); + ++NumConversions; +} + static bool eliminateAvailableExternally(Module &M) { bool Changed = false; @@ -45,19 +103,21 @@ static bool eliminateAvailableExternally(Module &M) { } GV.removeDeadConstantUsers(); GV.setLinkage(GlobalValue::ExternalLinkage); - NumVariables++; + ++NumVariables; Changed = true; } // Drop the bodies of available externally functions. - for (Function &F : M) { - if (!F.hasAvailableExternallyLinkage()) + for (Function &F : llvm::make_early_inc_range(M)) { + if (F.isDeclaration() || !F.hasAvailableExternallyLinkage()) continue; - if (!F.isDeclaration()) - // This will set the linkage to external - F.deleteBody(); + + if (ConvertToLocal) + convertToLocalCopy(M, F); + else + deleteFunction(F); + F.removeDeadConstantUsers(); - NumFunctions++; Changed = true; } @@ -70,33 +130,3 @@ EliminateAvailableExternallyPass::run(Module &M, ModuleAnalysisManager &) { return PreservedAnalyses::all(); return PreservedAnalyses::none(); } - -namespace { - -struct EliminateAvailableExternallyLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - - EliminateAvailableExternallyLegacyPass() : ModulePass(ID) { - initializeEliminateAvailableExternallyLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - // run - Do the EliminateAvailableExternally pass on the specified module, - // optionally updating the specified callgraph to reflect the changes. - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - return eliminateAvailableExternally(M); - } -}; - -} // end anonymous namespace - -char EliminateAvailableExternallyLegacyPass::ID = 0; - -INITIALIZE_PASS(EliminateAvailableExternallyLegacyPass, "elim-avail-extern", - "Eliminate Available Externally Globals", false, false) - -ModulePass *llvm::createEliminateAvailableExternallyPass() { - return new EliminateAvailableExternallyLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp new file mode 100644 index 000000000000..fa56a5b564ae --- /dev/null +++ b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp @@ -0,0 +1,52 @@ +//===- EmbedBitcodePass.cpp - Pass that embeds the bitcode into a global---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/EmbedBitcodePass.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/Bitcode/BitcodeWriterPass.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" +#include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +#include <memory> +#include <string> + +using namespace llvm; + +PreservedAnalyses EmbedBitcodePass::run(Module &M, ModuleAnalysisManager &AM) { + if (M.getGlobalVariable("llvm.embedded.module", /*AllowInternal=*/true)) + report_fatal_error("Can only embed the module once", + /*gen_crash_diag=*/false); + + Triple T(M.getTargetTriple()); + if (T.getObjectFormat() != Triple::ELF) + report_fatal_error( + "EmbedBitcode pass currently only supports ELF object format", + /*gen_crash_diag=*/false); + + std::unique_ptr<Module> NewModule = CloneModule(M); + MPM.run(*NewModule, AM); + + std::string Data; + raw_string_ostream OS(Data); + if (IsThinLTO) + ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(*NewModule, AM); + else + BitcodeWriterPass(OS, /*ShouldPreserveUseListOrder=*/false, EmitLTOSummary) + .run(*NewModule, AM); + + embedBufferInModule(M, MemoryBufferRef(Data, "ModuleData"), ".llvm.lto"); + + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Transforms/IPO/ExtractGV.cpp b/llvm/lib/Transforms/IPO/ExtractGV.cpp index d5073eed2fef..6414ea69c9f7 100644 --- a/llvm/lib/Transforms/IPO/ExtractGV.cpp +++ b/llvm/lib/Transforms/IPO/ExtractGV.cpp @@ -36,7 +36,7 @@ static void makeVisible(GlobalValue &GV, bool Delete) { } // Map linkonce* to weak* so that llvm doesn't drop this GV. - switch(GV.getLinkage()) { + switch (GV.getLinkage()) { default: llvm_unreachable("Unexpected linkage"); case GlobalValue::LinkOnceAnyLinkage: @@ -48,10 +48,9 @@ static void makeVisible(GlobalValue &GV, bool Delete) { } } - - /// If deleteS is true, this pass deletes the specified global values. - /// Otherwise, it deletes as much of the module as possible, except for the - /// global values specified. +/// If deleteS is true, this pass deletes the specified global values. +/// Otherwise, it deletes as much of the module as possible, except for the +/// global values specified. ExtractGVPass::ExtractGVPass(std::vector<GlobalValue *> &GVs, bool deleteS, bool keepConstInit) : Named(GVs.begin(), GVs.end()), deleteStuff(deleteS), @@ -129,5 +128,22 @@ PreservedAnalyses ExtractGVPass::run(Module &M, ModuleAnalysisManager &) { } } + // Visit the IFuncs. 
+ for (GlobalIFunc &IF : llvm::make_early_inc_range(M.ifuncs())) { + bool Delete = deleteStuff == (bool)Named.count(&IF); + makeVisible(IF, Delete); + + if (!Delete) + continue; + + auto *FuncType = dyn_cast<FunctionType>(IF.getValueType()); + IF.removeFromParent(); + llvm::Value *Declaration = + Function::Create(FuncType, GlobalValue::ExternalLinkage, + IF.getAddressSpace(), IF.getName(), &M); + IF.replaceAllUsesWith(Declaration); + delete &IF; + } + return PreservedAnalyses::none(); } diff --git a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp index b10c2ea13469..74931e1032d1 100644 --- a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -9,8 +9,6 @@ #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" #include "llvm/IR/Function.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -80,32 +78,3 @@ PreservedAnalyses ForceFunctionAttrsPass::run(Module &M, // Just conservatively invalidate analyses, this isn't likely to be important. return PreservedAnalyses::none(); } - -namespace { -struct ForceFunctionAttrsLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - ForceFunctionAttrsLegacyPass() : ModulePass(ID) { - initializeForceFunctionAttrsLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (!hasForceAttributes()) - return false; - - for (Function &F : M.functions()) - forceAttributes(F); - - // Conservatively assume we changed something. - return true; - } -}; -} - -char ForceFunctionAttrsLegacyPass::ID = 0; -INITIALIZE_PASS(ForceFunctionAttrsLegacyPass, "forceattrs", - "Force set function attributes", false, false) - -Pass *llvm::createForceFunctionAttrsLegacyPass() { - return new ForceFunctionAttrsLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 3f61dbe3354e..34299f9dbb23 100644 --- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -50,8 +50,6 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" @@ -154,7 +152,7 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, // If it's not an identified object, it might be an argument. if (!isIdentifiedObject(UO)) ME |= MemoryEffects::argMemOnly(MR); - ME |= MemoryEffects(MemoryEffects::Other, MR); + ME |= MemoryEffects(IRMemLocation::Other, MR); }; // Scan the function body for instructions that may read or write memory. for (Instruction &I : instructions(F)) { @@ -181,17 +179,17 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, if (isa<PseudoProbeInst>(I)) continue; - ME |= CallME.getWithoutLoc(MemoryEffects::ArgMem); + ME |= CallME.getWithoutLoc(IRMemLocation::ArgMem); // If the call accesses captured memory (currently part of "other") and // an argument is captured (currently not tracked), then it may also // access argument memory. 
- ModRefInfo OtherMR = CallME.getModRef(MemoryEffects::Other); + ModRefInfo OtherMR = CallME.getModRef(IRMemLocation::Other); ME |= MemoryEffects::argMemOnly(OtherMR); // Check whether all pointer arguments point to local memory, and // ignore calls that only access local memory. - ModRefInfo ArgMR = CallME.getModRef(MemoryEffects::ArgMem); + ModRefInfo ArgMR = CallME.getModRef(IRMemLocation::ArgMem); if (ArgMR != ModRefInfo::NoModRef) { for (const Use &U : Call->args()) { const Value *Arg = U; @@ -640,7 +638,7 @@ determinePointerAccessAttrs(Argument *A, if (Visited.insert(&UU).second) Worklist.push_back(&UU); } - + if (CB.doesNotAccessMemory()) continue; @@ -723,18 +721,18 @@ static void addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes, continue; // There is nothing to do if an argument is already marked as 'returned'. - if (llvm::any_of(F->args(), - [](const Argument &Arg) { return Arg.hasReturnedAttr(); })) + if (F->getAttributes().hasAttrSomewhere(Attribute::Returned)) continue; - auto FindRetArg = [&]() -> Value * { - Value *RetArg = nullptr; + auto FindRetArg = [&]() -> Argument * { + Argument *RetArg = nullptr; for (BasicBlock &BB : *F) if (auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator())) { // Note that stripPointerCasts should look through functions with // returned arguments. - Value *RetVal = Ret->getReturnValue()->stripPointerCasts(); - if (!isa<Argument>(RetVal) || RetVal->getType() != F->getReturnType()) + auto *RetVal = + dyn_cast<Argument>(Ret->getReturnValue()->stripPointerCasts()); + if (!RetVal || RetVal->getType() != F->getReturnType()) return nullptr; if (!RetArg) @@ -746,9 +744,8 @@ static void addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes, return RetArg; }; - if (Value *RetArg = FindRetArg()) { - auto *A = cast<Argument>(RetArg); - A->addAttr(Attribute::Returned); + if (Argument *RetArg = FindRetArg()) { + RetArg->addAttr(Attribute::Returned); ++NumReturned; Changed.insert(F); } @@ -1379,7 +1376,7 @@ static bool InstrBreaksNonConvergent(Instruction &I, /// Helper for NoUnwind inference predicate InstrBreaksAttribute. static bool InstrBreaksNonThrowing(Instruction &I, const SCCNodeSet &SCCNodes) { - if (!I.mayThrow()) + if (!I.mayThrow(/* IncludePhaseOneUnwind */ true)) return false; if (const auto *CI = dyn_cast<CallInst>(&I)) { if (Function *Callee = CI->getCalledFunction()) { @@ -1410,6 +1407,61 @@ static bool InstrBreaksNoFree(Instruction &I, const SCCNodeSet &SCCNodes) { return true; } +// Return true if this is an atomic which has an ordering stronger than +// unordered. Note that this is different than the predicate we use in +// Attributor. Here we chose to be conservative and consider monotonic +// operations potentially synchronizing. We generally don't do much with +// monotonic operations, so this is simply risk reduction. +static bool isOrderedAtomic(Instruction *I) { + if (!I->isAtomic()) + return false; + + if (auto *FI = dyn_cast<FenceInst>(I)) + // All legal orderings for fence are stronger than monotonic. + return FI->getSyncScopeID() != SyncScope::SingleThread; + else if (isa<AtomicCmpXchgInst>(I) || isa<AtomicRMWInst>(I)) + return true; + else if (auto *SI = dyn_cast<StoreInst>(I)) + return !SI->isUnordered(); + else if (auto *LI = dyn_cast<LoadInst>(I)) + return !LI->isUnordered(); + else { + llvm_unreachable("unknown atomic instruction?"); + } +} + +static bool InstrBreaksNoSync(Instruction &I, const SCCNodeSet &SCCNodes) { + // Volatile may synchronize + if (I.isVolatile()) + return true; + + // An ordered atomic may synchronize. 
(See comment about on monotonic.) + if (isOrderedAtomic(&I)) + return true; + + auto *CB = dyn_cast<CallBase>(&I); + if (!CB) + // Non call site cases covered by the two checks above + return false; + + if (CB->hasFnAttr(Attribute::NoSync)) + return false; + + // Non volatile memset/memcpy/memmoves are nosync + // NOTE: Only intrinsics with volatile flags should be handled here. All + // others should be marked in Intrinsics.td. + if (auto *MI = dyn_cast<MemIntrinsic>(&I)) + if (!MI->isVolatile()) + return false; + + // Speculatively assume in SCC. + if (Function *Callee = CB->getCalledFunction()) + if (SCCNodes.contains(Callee)) + return false; + + return true; +} + /// Attempt to remove convergent function attribute when possible. /// /// Returns true if any changes to function attributes were made. @@ -1441,9 +1493,7 @@ static void inferConvergent(const SCCNodeSet &SCCNodes, } /// Infer attributes from all functions in the SCC by scanning every -/// instruction for compliance to the attribute assumptions. Currently it -/// does: -/// - addition of NoUnwind attribute +/// instruction for compliance to the attribute assumptions. /// /// Returns true if any changes to function attributes were made. static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes, @@ -1495,6 +1545,22 @@ static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes, }, /* RequiresExactDefinition= */ true}); + AI.registerAttrInference(AttributeInferer::InferenceDescriptor{ + Attribute::NoSync, + // Skip already marked functions. + [](const Function &F) { return F.hasNoSync(); }, + // Instructions that break nosync assumption. + [&SCCNodes](Instruction &I) { + return InstrBreaksNoSync(I, SCCNodes); + }, + [](Function &F) { + LLVM_DEBUG(dbgs() + << "Adding nosync attr to fn " << F.getName() << "\n"); + F.setNoSync(); + ++NumNoSync; + }, + /* RequiresExactDefinition= */ true}); + // Perform all the requested attribute inference actions. AI.run(SCCNodes, Changed); } @@ -1622,83 +1688,6 @@ static void addWillReturn(const SCCNodeSet &SCCNodes, } } -// Return true if this is an atomic which has an ordering stronger than -// unordered. Note that this is different than the predicate we use in -// Attributor. Here we chose to be conservative and consider monotonic -// operations potentially synchronizing. We generally don't do much with -// monotonic operations, so this is simply risk reduction. -static bool isOrderedAtomic(Instruction *I) { - if (!I->isAtomic()) - return false; - - if (auto *FI = dyn_cast<FenceInst>(I)) - // All legal orderings for fence are stronger than monotonic. - return FI->getSyncScopeID() != SyncScope::SingleThread; - else if (isa<AtomicCmpXchgInst>(I) || isa<AtomicRMWInst>(I)) - return true; - else if (auto *SI = dyn_cast<StoreInst>(I)) - return !SI->isUnordered(); - else if (auto *LI = dyn_cast<LoadInst>(I)) - return !LI->isUnordered(); - else { - llvm_unreachable("unknown atomic instruction?"); - } -} - -static bool InstrBreaksNoSync(Instruction &I, const SCCNodeSet &SCCNodes) { - // Volatile may synchronize - if (I.isVolatile()) - return true; - - // An ordered atomic may synchronize. (See comment about on monotonic.) - if (isOrderedAtomic(&I)) - return true; - - auto *CB = dyn_cast<CallBase>(&I); - if (!CB) - // Non call site cases covered by the two checks above - return false; - - if (CB->hasFnAttr(Attribute::NoSync)) - return false; - - // Non volatile memset/memcpy/memmoves are nosync - // NOTE: Only intrinsics with volatile flags should be handled here. 
All - // others should be marked in Intrinsics.td. - if (auto *MI = dyn_cast<MemIntrinsic>(&I)) - if (!MI->isVolatile()) - return false; - - // Speculatively assume in SCC. - if (Function *Callee = CB->getCalledFunction()) - if (SCCNodes.contains(Callee)) - return false; - - return true; -} - -// Infer the nosync attribute. -static void addNoSyncAttr(const SCCNodeSet &SCCNodes, - SmallSet<Function *, 8> &Changed) { - AttributeInferer AI; - AI.registerAttrInference(AttributeInferer::InferenceDescriptor{ - Attribute::NoSync, - // Skip already marked functions. - [](const Function &F) { return F.hasNoSync(); }, - // Instructions that break nosync assumption. - [&SCCNodes](Instruction &I) { - return InstrBreaksNoSync(I, SCCNodes); - }, - [](Function &F) { - LLVM_DEBUG(dbgs() - << "Adding nosync attr to fn " << F.getName() << "\n"); - F.setNoSync(); - ++NumNoSync; - }, - /* RequiresExactDefinition= */ true}); - AI.run(SCCNodes, Changed); -} - static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) { SCCNodesResult Res; Res.HasUnknownCall = false; @@ -1756,8 +1745,6 @@ deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) { addNoRecurseAttrs(Nodes.SCCNodes, Changed); } - addNoSyncAttr(Nodes.SCCNodes, Changed); - // Finally, infer the maximal set of attributes from the ones we've inferred // above. This is handling the cases where one attribute on a signature // implies another, but for implementation reasons the inference rule for @@ -1774,6 +1761,13 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &) { + // Skip non-recursive functions if requested. + if (C.size() == 1 && SkipNonRecursive) { + LazyCallGraph::Node &N = *C.begin(); + if (!N->lookup(N)) + return PreservedAnalyses::all(); + } + FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); @@ -1819,40 +1813,12 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, return PA; } -namespace { - -struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass { - // Pass identification, replacement for typeid - static char ID; - - PostOrderFunctionAttrsLegacyPass() : CallGraphSCCPass(ID) { - initializePostOrderFunctionAttrsLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnSCC(CallGraphSCC &SCC) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AssumptionCacheTracker>(); - getAAResultsAnalysisUsage(AU); - CallGraphSCCPass::getAnalysisUsage(AU); - } -}; - -} // end anonymous namespace - -char PostOrderFunctionAttrsLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrsLegacyPass, "function-attrs", - "Deduce function attributes", false, false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(PostOrderFunctionAttrsLegacyPass, "function-attrs", - "Deduce function attributes", false, false) - -Pass *llvm::createPostOrderFunctionAttrsLegacyPass() { - return new PostOrderFunctionAttrsLegacyPass(); +void PostOrderFunctionAttrsPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<PostOrderFunctionAttrsPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + if (SkipNonRecursive) + OS << "<skip-non-recursive>"; } template <typename AARGetterT> @@ -1865,48 +1831,6 @@ 
static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { return !deriveAttrsInPostOrder(Functions, AARGetter).empty(); } -bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) { - if (skipSCC(SCC)) - return false; - return runImpl(SCC, LegacyAARGetter(*this)); -} - -namespace { - -struct ReversePostOrderFunctionAttrsLegacyPass : public ModulePass { - // Pass identification, replacement for typeid - static char ID; - - ReversePostOrderFunctionAttrsLegacyPass() : ModulePass(ID) { - initializeReversePostOrderFunctionAttrsLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<CallGraphWrapperPass>(); - AU.addPreserved<CallGraphWrapperPass>(); - } -}; - -} // end anonymous namespace - -char ReversePostOrderFunctionAttrsLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(ReversePostOrderFunctionAttrsLegacyPass, - "rpo-function-attrs", "Deduce function attributes in RPO", - false, false) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(ReversePostOrderFunctionAttrsLegacyPass, - "rpo-function-attrs", "Deduce function attributes in RPO", - false, false) - -Pass *llvm::createReversePostOrderFunctionAttrsPass() { - return new ReversePostOrderFunctionAttrsLegacyPass(); -} - static bool addNoRecurseAttrsTopDown(Function &F) { // We check the preconditions for the function prior to calling this to avoid // the cost of building up a reversible post-order list. We assert them here @@ -1939,7 +1863,7 @@ static bool addNoRecurseAttrsTopDown(Function &F) { return true; } -static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) { +static bool deduceFunctionAttributeInRPO(Module &M, LazyCallGraph &CG) { // We only have a post-order SCC traversal (because SCCs are inherently // discovered in post-order), so we accumulate them in a vector and then walk // it in reverse. This is simpler than using the RPO iterator infrastructure @@ -1947,17 +1871,18 @@ static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) { // graph. We can also cheat egregiously because we're primarily interested in // synthesizing norecurse and so we can only save the singular SCCs as SCCs // with multiple functions in them will clearly be recursive. 
- SmallVector<Function *, 16> Worklist; - for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) { - if (I->size() != 1) - continue; - Function *F = I->front()->getFunction(); - if (F && !F->isDeclaration() && !F->doesNotRecurse() && - F->hasInternalLinkage()) - Worklist.push_back(F); + SmallVector<Function *, 16> Worklist; + CG.buildRefSCCs(); + for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) { + for (LazyCallGraph::SCC &SCC : RC) { + if (SCC.size() != 1) + continue; + Function &F = SCC.begin()->getFunction(); + if (!F.isDeclaration() && !F.doesNotRecurse() && F.hasInternalLinkage()) + Worklist.push_back(&F); + } } - bool Changed = false; for (auto *F : llvm::reverse(Worklist)) Changed |= addNoRecurseAttrsTopDown(*F); @@ -1965,23 +1890,14 @@ static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) { return Changed; } -bool ReversePostOrderFunctionAttrsLegacyPass::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - - return deduceFunctionAttributeInRPO(M, CG); -} - PreservedAnalyses ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) { - auto &CG = AM.getResult<CallGraphAnalysis>(M); + auto &CG = AM.getResult<LazyCallGraphAnalysis>(M); if (!deduceFunctionAttributeInRPO(M, CG)) return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserve<CallGraphAnalysis>(); + PA.preserve<LazyCallGraphAnalysis>(); return PA; } diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp index 7c994657e5c8..f635b14cd2a9 100644 --- a/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -30,9 +30,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/ModuleSummaryIndex.h" #include "llvm/IRReader/IRReader.h" -#include "llvm/InitializePasses.h" #include "llvm/Linker/IRMover.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -159,39 +157,37 @@ static std::unique_ptr<Module> loadFile(const std::string &FileName, return Result; } -/// Given a list of possible callee implementation for a call site, select one -/// that fits the \p Threshold. -/// -/// FIXME: select "best" instead of first that fits. But what is "best"? -/// - The smallest: more likely to be inlined. -/// - The one with the least outgoing edges (already well optimized). -/// - One from a module already being imported from in order to reduce the -/// number of source modules parsed/linked. -/// - One that has PGO data attached. -/// - [insert you fancy metric here] -static const GlobalValueSummary * -selectCallee(const ModuleSummaryIndex &Index, - ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList, - unsigned Threshold, StringRef CallerModulePath, - FunctionImporter::ImportFailureReason &Reason, - GlobalValue::GUID GUID) { - Reason = FunctionImporter::ImportFailureReason::None; - auto It = llvm::find_if( +/// Given a list of possible callee implementation for a call site, qualify the +/// legality of importing each. The return is a range of pairs. Each pair +/// corresponds to a candidate. The first value is the ImportFailureReason for +/// that candidate, the second is the candidate. 
+static auto qualifyCalleeCandidates( + const ModuleSummaryIndex &Index, + ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList, + StringRef CallerModulePath) { + return llvm::map_range( CalleeSummaryList, - [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { + [&Index, CalleeSummaryList, + CallerModulePath](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) + -> std::pair<FunctionImporter::ImportFailureReason, + const GlobalValueSummary *> { auto *GVSummary = SummaryPtr.get(); - if (!Index.isGlobalValueLive(GVSummary)) { - Reason = FunctionImporter::ImportFailureReason::NotLive; - return false; - } + if (!Index.isGlobalValueLive(GVSummary)) + return {FunctionImporter::ImportFailureReason::NotLive, GVSummary}; - if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) { - Reason = FunctionImporter::ImportFailureReason::InterposableLinkage; - // There is no point in importing these, we can't inline them - return false; - } + if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) + return {FunctionImporter::ImportFailureReason::InterposableLinkage, + GVSummary}; - auto *Summary = cast<FunctionSummary>(GVSummary->getBaseObject()); + auto *Summary = dyn_cast<FunctionSummary>(GVSummary->getBaseObject()); + + // Ignore any callees that aren't actually functions. This could happen + // in the case of GUID hash collisions. It could also happen in theory + // for SamplePGO profiles collected on old versions of the code after + // renaming, since we synthesize edges to any inlined callees appearing + // in the profile. + if (!Summary) + return {FunctionImporter::ImportFailureReason::GlobalVar, GVSummary}; // If this is a local function, make sure we import the copy // in the caller's module. The only time a local function can @@ -205,119 +201,174 @@ selectCallee(const ModuleSummaryIndex &Index, // a local in another module. if (GlobalValue::isLocalLinkage(Summary->linkage()) && CalleeSummaryList.size() > 1 && - Summary->modulePath() != CallerModulePath) { - Reason = - FunctionImporter::ImportFailureReason::LocalLinkageNotInModule; - return false; - } - - if ((Summary->instCount() > Threshold) && - !Summary->fflags().AlwaysInline && !ForceImportAll) { - Reason = FunctionImporter::ImportFailureReason::TooLarge; - return false; - } + Summary->modulePath() != CallerModulePath) + return { + FunctionImporter::ImportFailureReason::LocalLinkageNotInModule, + GVSummary}; // Skip if it isn't legal to import (e.g. may reference unpromotable // locals). - if (Summary->notEligibleToImport()) { - Reason = FunctionImporter::ImportFailureReason::NotEligible; - return false; - } + if (Summary->notEligibleToImport()) + return {FunctionImporter::ImportFailureReason::NotEligible, + GVSummary}; - // Don't bother importing if we can't inline it anyway. - if (Summary->fflags().NoInline && !ForceImportAll) { - Reason = FunctionImporter::ImportFailureReason::NoInline; - return false; - } - - return true; + return {FunctionImporter::ImportFailureReason::None, GVSummary}; }); - if (It == CalleeSummaryList.end()) - return nullptr; +} + +/// Given a list of possible callee implementation for a call site, select one +/// that fits the \p Threshold. If none are found, the Reason will give the last +/// reason for the failure (last, in the order of CalleeSummaryList entries). +/// +/// FIXME: select "best" instead of first that fits. But what is "best"? +/// - The smallest: more likely to be inlined. +/// - The one with the least outgoing edges (already well optimized). 
+/// - One from a module already being imported from in order to reduce the +/// number of source modules parsed/linked. +/// - One that has PGO data attached. +/// - [insert you fancy metric here] +static const GlobalValueSummary * +selectCallee(const ModuleSummaryIndex &Index, + ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList, + unsigned Threshold, StringRef CallerModulePath, + FunctionImporter::ImportFailureReason &Reason) { + auto QualifiedCandidates = + qualifyCalleeCandidates(Index, CalleeSummaryList, CallerModulePath); + for (auto QualifiedValue : QualifiedCandidates) { + Reason = QualifiedValue.first; + if (Reason != FunctionImporter::ImportFailureReason::None) + continue; + auto *Summary = + cast<FunctionSummary>(QualifiedValue.second->getBaseObject()); + + if ((Summary->instCount() > Threshold) && !Summary->fflags().AlwaysInline && + !ForceImportAll) { + Reason = FunctionImporter::ImportFailureReason::TooLarge; + continue; + } - return cast<GlobalValueSummary>(It->get()); + // Don't bother importing if we can't inline it anyway. + if (Summary->fflags().NoInline && !ForceImportAll) { + Reason = FunctionImporter::ImportFailureReason::NoInline; + continue; + } + + return Summary; + } + return nullptr; } namespace { -using EdgeInfo = - std::tuple<const GlobalValueSummary *, unsigned /* Threshold */>; +using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */>; } // anonymous namespace -static bool shouldImportGlobal(const ValueInfo &VI, - const GVSummaryMapTy &DefinedGVSummaries) { - const auto &GVS = DefinedGVSummaries.find(VI.getGUID()); - if (GVS == DefinedGVSummaries.end()) - return true; - // We should not skip import if the module contains a definition with - // interposable linkage type. This is required for correctness in - // the situation with two following conditions: - // * the def with interposable linkage is non-prevailing, - // * there is a prevailing def available for import and marked read-only. - // In this case, the non-prevailing def will be converted to a declaration, - // while the prevailing one becomes internal, thus no definitions will be - // available for linking. In order to prevent undefined symbol link error, - // the prevailing definition must be imported. - // FIXME: Consider adding a check that the suitable prevailing definition - // exists and marked read-only. - if (VI.getSummaryList().size() > 1 && - GlobalValue::isInterposableLinkage(GVS->second->linkage())) - return true; - - return false; -} +/// Import globals referenced by a function or other globals that are being +/// imported, if importing such global is possible. +class GlobalsImporter final { + const ModuleSummaryIndex &Index; + const GVSummaryMapTy &DefinedGVSummaries; + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing; + FunctionImporter::ImportMapTy &ImportList; + StringMap<FunctionImporter::ExportSetTy> *const ExportLists; + + bool shouldImportGlobal(const ValueInfo &VI) { + const auto &GVS = DefinedGVSummaries.find(VI.getGUID()); + if (GVS == DefinedGVSummaries.end()) + return true; + // We should not skip import if the module contains a non-prevailing + // definition with interposable linkage type. This is required for + // correctness in the situation where there is a prevailing def available + // for import and marked read-only. In this case, the non-prevailing def + // will be converted to a declaration, while the prevailing one becomes + // internal, thus no definitions will be available for linking. 
In order to + // prevent undefined symbol link error, the prevailing definition must be + // imported. + // FIXME: Consider adding a check that the suitable prevailing definition + // exists and marked read-only. + if (VI.getSummaryList().size() > 1 && + GlobalValue::isInterposableLinkage(GVS->second->linkage()) && + !IsPrevailing(VI.getGUID(), GVS->second)) + return true; -static void computeImportForReferencedGlobals( - const GlobalValueSummary &Summary, const ModuleSummaryIndex &Index, - const GVSummaryMapTy &DefinedGVSummaries, - SmallVectorImpl<EdgeInfo> &Worklist, - FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists) { - for (const auto &VI : Summary.refs()) { - if (!shouldImportGlobal(VI, DefinedGVSummaries)) { - LLVM_DEBUG( - dbgs() << "Ref ignored! Target already in destination module.\n"); - continue; - } + return false; + } - LLVM_DEBUG(dbgs() << " ref -> " << VI << "\n"); - - // If this is a local variable, make sure we import the copy - // in the caller's module. The only time a local variable can - // share an entry in the index is if there is a local with the same name - // in another module that had the same source file name (in a different - // directory), where each was compiled in their own directory so there - // was not distinguishing path. - auto LocalNotInModule = [&](const GlobalValueSummary *RefSummary) -> bool { - return GlobalValue::isLocalLinkage(RefSummary->linkage()) && - RefSummary->modulePath() != Summary.modulePath(); - }; + void + onImportingSummaryImpl(const GlobalValueSummary &Summary, + SmallVectorImpl<const GlobalVarSummary *> &Worklist) { + for (const auto &VI : Summary.refs()) { + if (!shouldImportGlobal(VI)) { + LLVM_DEBUG( + dbgs() << "Ref ignored! Target already in destination module.\n"); + continue; + } - for (const auto &RefSummary : VI.getSummaryList()) - if (isa<GlobalVarSummary>(RefSummary.get()) && - Index.canImportGlobalVar(RefSummary.get(), /* AnalyzeRefs */ true) && - !LocalNotInModule(RefSummary.get())) { + LLVM_DEBUG(dbgs() << " ref -> " << VI << "\n"); + + // If this is a local variable, make sure we import the copy + // in the caller's module. The only time a local variable can + // share an entry in the index is if there is a local with the same name + // in another module that had the same source file name (in a different + // directory), where each was compiled in their own directory so there + // was not distinguishing path. + auto LocalNotInModule = + [&](const GlobalValueSummary *RefSummary) -> bool { + return GlobalValue::isLocalLinkage(RefSummary->linkage()) && + RefSummary->modulePath() != Summary.modulePath(); + }; + + for (const auto &RefSummary : VI.getSummaryList()) { + const auto *GVS = dyn_cast<GlobalVarSummary>(RefSummary.get()); + // Functions could be referenced by global vars - e.g. a vtable; but we + // don't currently imagine a reason those would be imported here, rather + // than as part of the logic deciding which functions to import (i.e. + // based on profile information). Should we decide to handle them here, + // we can refactor accordingly at that time. + if (!GVS || !Index.canImportGlobalVar(GVS, /* AnalyzeRefs */ true) || + LocalNotInModule(GVS)) + continue; auto ILI = ImportList[RefSummary->modulePath()].insert(VI.getGUID()); // Only update stat and exports if we haven't already imported this // variable. 
if (!ILI.second) break; NumImportedGlobalVarsThinLink++; - // Any references made by this variable will be marked exported later, - // in ComputeCrossModuleImport, after import decisions are complete, - // which is more efficient than adding them here. + // Any references made by this variable will be marked exported + // later, in ComputeCrossModuleImport, after import decisions are + // complete, which is more efficient than adding them here. if (ExportLists) (*ExportLists)[RefSummary->modulePath()].insert(VI); // If variable is not writeonly we attempt to recursively analyze // its references in order to import referenced constants. - if (!Index.isWriteOnly(cast<GlobalVarSummary>(RefSummary.get()))) - Worklist.emplace_back(RefSummary.get(), 0); + if (!Index.isWriteOnly(GVS)) + Worklist.emplace_back(GVS); break; } + } } -} + +public: + GlobalsImporter( + const ModuleSummaryIndex &Index, const GVSummaryMapTy &DefinedGVSummaries, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing, + FunctionImporter::ImportMapTy &ImportList, + StringMap<FunctionImporter::ExportSetTy> *ExportLists) + : Index(Index), DefinedGVSummaries(DefinedGVSummaries), + IsPrevailing(IsPrevailing), ImportList(ImportList), + ExportLists(ExportLists) {} + + void onImportingSummary(const GlobalValueSummary &Summary) { + SmallVector<const GlobalVarSummary *, 128> Worklist; + onImportingSummaryImpl(Summary, Worklist); + while (!Worklist.empty()) + onImportingSummaryImpl(*Worklist.pop_back_val(), Worklist); + } +}; static const char * getFailureName(FunctionImporter::ImportFailureReason Reason) { @@ -348,12 +399,13 @@ getFailureName(FunctionImporter::ImportFailureReason Reason) { static void computeImportForFunction( const FunctionSummary &Summary, const ModuleSummaryIndex &Index, const unsigned Threshold, const GVSummaryMapTy &DefinedGVSummaries, - SmallVectorImpl<EdgeInfo> &Worklist, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing, + SmallVectorImpl<EdgeInfo> &Worklist, GlobalsImporter &GVImporter, FunctionImporter::ImportMapTy &ImportList, StringMap<FunctionImporter::ExportSetTy> *ExportLists, FunctionImporter::ImportThresholdsTy &ImportThresholds) { - computeImportForReferencedGlobals(Summary, Index, DefinedGVSummaries, - Worklist, ImportList, ExportLists); + GVImporter.onImportingSummary(Summary); static int ImportCount = 0; for (const auto &Edge : Summary.calls()) { ValueInfo VI = Edge.first; @@ -432,7 +484,7 @@ static void computeImportForFunction( FunctionImporter::ImportFailureReason Reason; CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, - Summary.modulePath(), Reason, VI.getGUID()); + Summary.modulePath(), Reason); if (!CalleeSummary) { // Update with new larger threshold if this was a retry (otherwise // we would have already inserted with NewThreshold above). Also @@ -519,12 +571,17 @@ static void computeImportForFunction( /// as well as the list of "exports", i.e. the list of symbols referenced from /// another module (that may require promotion). 
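
The hunks below thread an isPrevailing callback through every entry point of the import computation, so that an interposable local definition only forces an import when it is not the prevailing copy. As a rough illustration (not part of this commit), a ThinLTO driver that keeps a resolution map from GUID to the chosen summary could wire the new parameter up as sketched here; the map name PrevailingCopy and the wrapper function are hypothetical.

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/IR/ModuleSummaryIndex.h"
#include "llvm/Transforms/IPO/FunctionImport.h"
using namespace llvm;

// Sketch only: build the prevailing-copy predicate from a caller-maintained
// map and hand it to the import computation extended in this change.
static void computeImportsSketch(
    const ModuleSummaryIndex &Index,
    const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries,
    const DenseMap<GlobalValue::GUID, const GlobalValueSummary *>
        &PrevailingCopy,
    StringMap<FunctionImporter::ImportMapTy> &ImportLists,
    StringMap<FunctionImporter::ExportSetTy> &ExportLists) {
  auto isPrevailing = [&](GlobalValue::GUID GUID,
                          const GlobalValueSummary *S) {
    auto It = PrevailingCopy.find(GUID);
    // No recorded resolution: conservatively treat the copy as prevailing.
    return It == PrevailingCopy.end() || It->second == S;
  };
  ComputeCrossModuleImport(Index, ModuleToDefinedGVSummaries, isPrevailing,
                           ImportLists, ExportLists);
}
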
static void ComputeImportForModule( - const GVSummaryMapTy &DefinedGVSummaries, const ModuleSummaryIndex &Index, - StringRef ModName, FunctionImporter::ImportMapTy &ImportList, + const GVSummaryMapTy &DefinedGVSummaries, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing, + const ModuleSummaryIndex &Index, StringRef ModName, + FunctionImporter::ImportMapTy &ImportList, StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { // Worklist contains the list of function imported in this module, for which // we will analyse the callees and may import further down the callgraph. SmallVector<EdgeInfo, 128> Worklist; + GlobalsImporter GVI(Index, DefinedGVSummaries, isPrevailing, ImportList, + ExportLists); FunctionImporter::ImportThresholdsTy ImportThresholds; // Populate the worklist with the import for the functions in the current @@ -546,8 +603,8 @@ static void ComputeImportForModule( continue; LLVM_DEBUG(dbgs() << "Initialize import for " << VI << "\n"); computeImportForFunction(*FuncSummary, Index, ImportInstrLimit, - DefinedGVSummaries, Worklist, ImportList, - ExportLists, ImportThresholds); + DefinedGVSummaries, isPrevailing, Worklist, GVI, + ImportList, ExportLists, ImportThresholds); } // Process the newly imported functions and add callees to the worklist. @@ -558,11 +615,8 @@ static void ComputeImportForModule( if (auto *FS = dyn_cast<FunctionSummary>(Summary)) computeImportForFunction(*FS, Index, Threshold, DefinedGVSummaries, - Worklist, ImportList, ExportLists, - ImportThresholds); - else - computeImportForReferencedGlobals(*Summary, Index, DefinedGVSummaries, - Worklist, ImportList, ExportLists); + isPrevailing, Worklist, GVI, ImportList, + ExportLists, ImportThresholds); } // Print stats about functions considered but rejected for importing @@ -632,17 +686,23 @@ checkVariableImport(const ModuleSummaryIndex &Index, // Checks that all GUIDs of read/writeonly vars we see in export lists // are also in the import lists. Otherwise we my face linker undefs, // because readonly and writeonly vars are internalized in their - // source modules. - auto IsReadOrWriteOnlyVar = [&](StringRef ModulePath, const ValueInfo &VI) { + // source modules. The exception would be if it has a linkage type indicating + // that there may have been a copy existing in the importing module (e.g. + // linkonce_odr). In that case we cannot accurately do this checking. 
+ auto IsReadOrWriteOnlyVarNeedingImporting = [&](StringRef ModulePath, + const ValueInfo &VI) { auto *GVS = dyn_cast_or_null<GlobalVarSummary>( Index.findSummaryInModule(VI, ModulePath)); - return GVS && (Index.isReadOnly(GVS) || Index.isWriteOnly(GVS)); + return GVS && (Index.isReadOnly(GVS) || Index.isWriteOnly(GVS)) && + !(GVS->linkage() == GlobalValue::AvailableExternallyLinkage || + GVS->linkage() == GlobalValue::WeakODRLinkage || + GVS->linkage() == GlobalValue::LinkOnceODRLinkage); }; for (auto &ExportPerModule : ExportLists) for (auto &VI : ExportPerModule.second) if (!FlattenedImports.count(VI.getGUID()) && - IsReadOrWriteOnlyVar(ExportPerModule.first(), VI)) + IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first(), VI)) return false; return true; @@ -653,6 +713,8 @@ checkVariableImport(const ModuleSummaryIndex &Index, void llvm::ComputeCrossModuleImport( const ModuleSummaryIndex &Index, const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing, StringMap<FunctionImporter::ImportMapTy> &ImportLists, StringMap<FunctionImporter::ExportSetTy> &ExportLists) { // For each module that has function defined, compute the import/export lists. @@ -660,7 +722,7 @@ void llvm::ComputeCrossModuleImport( auto &ImportList = ImportLists[DefinedGVSummaries.first()]; LLVM_DEBUG(dbgs() << "Computing import for Module '" << DefinedGVSummaries.first() << "'\n"); - ComputeImportForModule(DefinedGVSummaries.second, Index, + ComputeImportForModule(DefinedGVSummaries.second, isPrevailing, Index, DefinedGVSummaries.first(), ImportList, &ExportLists); } @@ -759,7 +821,10 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index, /// Compute all the imports for the given module in the Index. void llvm::ComputeCrossModuleImportForModule( - StringRef ModulePath, const ModuleSummaryIndex &Index, + StringRef ModulePath, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing, + const ModuleSummaryIndex &Index, FunctionImporter::ImportMapTy &ImportList) { // Collect the list of functions this module defines. // GUID -> Summary @@ -768,7 +833,8 @@ void llvm::ComputeCrossModuleImportForModule( // Compute the import list for this module. 
LLVM_DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n"); - ComputeImportForModule(FunctionSummaryMap, Index, ModulePath, ImportList); + ComputeImportForModule(FunctionSummaryMap, isPrevailing, Index, ModulePath, + ImportList); #ifndef NDEBUG dumpImportListForModule(Index, ModulePath, ImportList); @@ -1373,8 +1439,9 @@ Expected<bool> FunctionImporter::importFunctions( if (Error Err = Mover.move(std::move(SrcModule), GlobalsToImport.getArrayRef(), nullptr, /*IsPerformingImport=*/true)) - report_fatal_error(Twine("Function Import: link error: ") + - toString(std::move(Err))); + return createStringError(errc::invalid_argument, + Twine("Function Import: link error: ") + + toString(std::move(Err))); ImportedCount += GlobalsToImport.size(); NumImportedModules++; @@ -1394,7 +1461,9 @@ Expected<bool> FunctionImporter::importFunctions( return ImportedCount; } -static bool doImportingForModule(Module &M) { +static bool doImportingForModule( + Module &M, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing) { if (SummaryFile.empty()) report_fatal_error("error: -function-import requires -summary-file\n"); Expected<std::unique_ptr<ModuleSummaryIndex>> IndexPtrOrErr = @@ -1415,8 +1484,8 @@ static bool doImportingForModule(Module &M) { ComputeCrossModuleImportForModuleFromIndex(M.getModuleIdentifier(), *Index, ImportList); else - ComputeCrossModuleImportForModule(M.getModuleIdentifier(), *Index, - ImportList); + ComputeCrossModuleImportForModule(M.getModuleIdentifier(), isPrevailing, + *Index, ImportList); // Conservatively mark all internal values as promoted. This interface is // only used when doing importing via the function importing pass. The pass @@ -1434,7 +1503,7 @@ static bool doImportingForModule(Module &M) { if (renameModuleForThinLTO(M, *Index, /*ClearDSOLocalOnDeclarations=*/false, /*GlobalsToImport=*/nullptr)) { errs() << "Error renaming module\n"; - return false; + return true; } // Perform the import now. @@ -1449,15 +1518,22 @@ static bool doImportingForModule(Module &M) { if (!Result) { logAllUnhandledErrors(Result.takeError(), errs(), "Error importing module: "); - return false; + return true; } - return *Result; + return true; } PreservedAnalyses FunctionImportPass::run(Module &M, ModuleAnalysisManager &AM) { - if (!doImportingForModule(M)) + // This is only used for testing the function import pass via opt, where we + // don't have prevailing information from the LTO context available, so just + // conservatively assume everything is prevailing (which is fine for the very + // limited use of prevailing checking in this pass). 
+ auto isPrevailing = [](GlobalValue::GUID, const GlobalValueSummary *) { + return true; + }; + if (!doImportingForModule(M, isPrevailing)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index 4a7efb28e853..3d6c501e4596 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -48,11 +48,13 @@ #include "llvm/Transforms/IPO/FunctionSpecialization.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -64,42 +66,324 @@ using namespace llvm; #define DEBUG_TYPE "function-specialization" -STATISTIC(NumFuncSpecialized, "Number of functions specialized"); +STATISTIC(NumSpecsCreated, "Number of specializations created"); -static cl::opt<bool> ForceFunctionSpecialization( - "force-function-specialization", cl::init(false), cl::Hidden, - cl::desc("Force function specialization for every call site with a " - "constant argument")); +static cl::opt<bool> ForceSpecialization( + "force-specialization", cl::init(false), cl::Hidden, cl::desc( + "Force function specialization for every call site with a constant " + "argument")); -static cl::opt<unsigned> MaxClonesThreshold( - "func-specialization-max-clones", cl::Hidden, - cl::desc("The maximum number of clones allowed for a single function " - "specialization"), - cl::init(3)); +static cl::opt<unsigned> MaxClones( + "funcspec-max-clones", cl::init(3), cl::Hidden, cl::desc( + "The maximum number of clones allowed for a single function " + "specialization")); -static cl::opt<unsigned> SmallFunctionThreshold( - "func-specialization-size-threshold", cl::Hidden, - cl::desc("Don't specialize functions that have less than this theshold " - "number of instructions"), - cl::init(100)); +static cl::opt<unsigned> MaxIncomingPhiValues( + "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc( + "The maximum number of incoming values a PHI node can have to be " + "considered during the specialization bonus estimation")); -static cl::opt<unsigned> - AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden, - cl::desc("Average loop iteration count cost"), - cl::init(10)); +static cl::opt<unsigned> MinFunctionSize( + "funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc( + "Don't specialize functions that have less than this number of " + "instructions")); -static cl::opt<bool> SpecializeOnAddresses( - "func-specialization-on-address", cl::init(false), cl::Hidden, - cl::desc("Enable function specialization on the address of global values")); +static cl::opt<bool> SpecializeOnAddress( + "funcspec-on-address", cl::init(false), cl::Hidden, cl::desc( + "Enable function specialization on the address of global values")); // Disabled by default as it can significantly increase compilation times. 
// // https://llvm-compile-time-tracker.com // https://github.com/nikic/llvm-compile-time-tracker -static cl::opt<bool> EnableSpecializationForLiteralConstant( - "function-specialization-for-literal-constant", cl::init(false), cl::Hidden, - cl::desc("Enable specialization of functions that take a literal constant " - "as an argument.")); +static cl::opt<bool> SpecializeLiteralConstant( + "funcspec-for-literal-constant", cl::init(false), cl::Hidden, cl::desc( + "Enable specialization of functions that take a literal constant as an " + "argument")); + +// Estimates the instruction cost of all the basic blocks in \p WorkList. +// The successors of such blocks are added to the list as long as they are +// executable and they have a unique predecessor. \p WorkList represents +// the basic blocks of a specialization which become dead once we replace +// instructions that are known to be constants. The aim here is to estimate +// the combination of size and latency savings in comparison to the non +// specialized version of the function. +static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList, + DenseSet<BasicBlock *> &DeadBlocks, + ConstMap &KnownConstants, SCCPSolver &Solver, + BlockFrequencyInfo &BFI, + TargetTransformInfo &TTI) { + Cost Bonus = 0; + + // Accumulate the instruction cost of each basic block weighted by frequency. + while (!WorkList.empty()) { + BasicBlock *BB = WorkList.pop_back_val(); + + uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() / + BFI.getEntryFreq(); + if (!Weight) + continue; + + // These blocks are considered dead as far as the InstCostVisitor is + // concerned. They haven't been proven dead yet by the Solver, but + // may become if we propagate the constant specialization arguments. + if (!DeadBlocks.insert(BB).second) + continue; + + for (Instruction &I : *BB) { + // Disregard SSA copies. + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::ssa_copy) + continue; + // If it's a known constant we have already accounted for it. + if (KnownConstants.contains(&I)) + continue; + + Bonus += Weight * + TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus + << " after user " << I << "\n"); + } + + // Keep adding dead successors to the list as long as they are + // executable and they have a unique predecessor. + for (BasicBlock *SuccBB : successors(BB)) + if (Solver.isBlockExecutable(SuccBB) && + SuccBB->getUniquePredecessor() == BB) + WorkList.push_back(SuccBB); + } + return Bonus; +} + +static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) { + if (auto *C = dyn_cast<Constant>(V)) + return C; + if (auto It = KnownConstants.find(V); It != KnownConstants.end()) + return It->second; + return nullptr; +} + +Cost InstCostVisitor::getBonusFromPendingPHIs() { + Cost Bonus = 0; + while (!PendingPHIs.empty()) { + Instruction *Phi = PendingPHIs.pop_back_val(); + Bonus += getUserBonus(Phi); + } + return Bonus; +} + +Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) { + // Cache the iterator before visiting. + LastVisited = Use ? 
KnownConstants.insert({Use, C}).first + : KnownConstants.end(); + + if (auto *I = dyn_cast<SwitchInst>(User)) + return estimateSwitchInst(*I); + + if (auto *I = dyn_cast<BranchInst>(User)) + return estimateBranchInst(*I); + + C = visit(*User); + if (!C) + return 0; + + KnownConstants.insert({User, C}); + + uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() / + BFI.getEntryFreq(); + if (!Weight) + return 0; + + Cost Bonus = Weight * + TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus + << " for user " << *User << "\n"); + + for (auto *U : User->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + if (UI != User && Solver.isBlockExecutable(UI->getParent())) + Bonus += getUserBonus(UI, User, C); + + return Bonus; +} + +Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + if (I.getCondition() != LastVisited->first) + return 0; + + auto *C = dyn_cast<ConstantInt>(LastVisited->second); + if (!C) + return 0; + + BasicBlock *Succ = I.findCaseValue(C)->getCaseSuccessor(); + // Initialize the worklist with the dead basic blocks. These are the + // destination labels which are different from the one corresponding + // to \p C. They should be executable and have a unique predecessor. + SmallVector<BasicBlock *> WorkList; + for (const auto &Case : I.cases()) { + BasicBlock *BB = Case.getCaseSuccessor(); + if (BB == Succ || !Solver.isBlockExecutable(BB) || + BB->getUniquePredecessor() != I.getParent()) + continue; + WorkList.push_back(BB); + } + + return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI, + TTI); +} + +Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + if (I.getCondition() != LastVisited->first) + return 0; + + BasicBlock *Succ = I.getSuccessor(LastVisited->second->isOneValue()); + // Initialize the worklist with the dead successor as long as + // it is executable and has a unique predecessor. 
+ SmallVector<BasicBlock *> WorkList; + if (Solver.isBlockExecutable(Succ) && + Succ->getUniquePredecessor() == I.getParent()) + WorkList.push_back(Succ); + + return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI, + TTI); +} + +Constant *InstCostVisitor::visitPHINode(PHINode &I) { + if (I.getNumIncomingValues() > MaxIncomingPhiValues) + return nullptr; + + bool Inserted = VisitedPHIs.insert(&I).second; + Constant *Const = nullptr; + + for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) { + Value *V = I.getIncomingValue(Idx); + if (auto *Inst = dyn_cast<Instruction>(V)) + if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx))) + continue; + Constant *C = findConstantFor(V, KnownConstants); + if (!C) { + if (Inserted) + PendingPHIs.push_back(&I); + return nullptr; + } + if (!Const) + Const = C; + else if (C != Const) + return nullptr; + } + return Const; +} + +Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second)) + return LastVisited->second; + return nullptr; +} + +Constant *InstCostVisitor::visitCallBase(CallBase &I) { + Function *F = I.getCalledFunction(); + if (!F || !canConstantFoldCallTo(&I, F)) + return nullptr; + + SmallVector<Constant *, 8> Operands; + Operands.reserve(I.getNumOperands()); + + for (unsigned Idx = 0, E = I.getNumOperands() - 1; Idx != E; ++Idx) { + Value *V = I.getOperand(Idx); + Constant *C = findConstantFor(V, KnownConstants); + if (!C) + return nullptr; + Operands.push_back(C); + } + + auto Ops = ArrayRef(Operands.begin(), Operands.end()); + return ConstantFoldCall(&I, F, Ops); +} + +Constant *InstCostVisitor::visitLoadInst(LoadInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + if (isa<ConstantPointerNull>(LastVisited->second)) + return nullptr; + return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL); +} + +Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { + SmallVector<Constant *, 8> Operands; + Operands.reserve(I.getNumOperands()); + + for (unsigned Idx = 0, E = I.getNumOperands(); Idx != E; ++Idx) { + Value *V = I.getOperand(Idx); + Constant *C = findConstantFor(V, KnownConstants); + if (!C) + return nullptr; + Operands.push_back(C); + } + + auto Ops = ArrayRef(Operands.begin(), Operands.end()); + return ConstantFoldInstOperands(&I, Ops, DL); +} + +Constant *InstCostVisitor::visitSelectInst(SelectInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + if (I.getCondition() != LastVisited->first) + return nullptr; + + Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue() + : I.getTrueValue(); + Constant *C = findConstantFor(V, KnownConstants); + return C; +} + +Constant *InstCostVisitor::visitCastInst(CastInst &I) { + return ConstantFoldCastOperand(I.getOpcode(), LastVisited->second, + I.getType(), DL); +} + +Constant *InstCostVisitor::visitCmpInst(CmpInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + bool Swap = I.getOperand(1) == LastVisited->first; + Value *V = Swap ? I.getOperand(0) : I.getOperand(1); + Constant *Other = findConstantFor(V, KnownConstants); + if (!Other) + return nullptr; + + Constant *Const = LastVisited->second; + return Swap ? 
+ ConstantFoldCompareInstOperands(I.getPredicate(), Other, Const, DL) + : ConstantFoldCompareInstOperands(I.getPredicate(), Const, Other, DL); +} + +Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL); +} + +Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + + bool Swap = I.getOperand(1) == LastVisited->first; + Value *V = Swap ? I.getOperand(0) : I.getOperand(1); + Constant *Other = findConstantFor(V, KnownConstants); + if (!Other) + return nullptr; + + Constant *Const = LastVisited->second; + return dyn_cast_or_null<Constant>(Swap ? + simplifyBinOp(I.getOpcode(), Other, Const, SimplifyQuery(DL)) + : simplifyBinOp(I.getOpcode(), Const, Other, SimplifyQuery(DL))); +} Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) { @@ -125,6 +409,10 @@ Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca, // Bail if there is any other unknown usage. return nullptr; } + + if (!StoreValue) + return nullptr; + return getCandidateConstant(StoreValue); } @@ -165,49 +453,37 @@ Constant *FunctionSpecializer::getConstantStackValue(CallInst *Call, // ret void // } // -void FunctionSpecializer::promoteConstantStackValues() { - // Iterate over the argument tracked functions see if there - // are any new constant values for the call instruction via - // stack variables. - for (Function &F : M) { - if (!Solver.isArgumentTrackedFunction(&F)) +// See if there are any new constant values for the callers of \p F via +// stack variables and promote them to global variables. +void FunctionSpecializer::promoteConstantStackValues(Function *F) { + for (User *U : F->users()) { + + auto *Call = dyn_cast<CallInst>(U); + if (!Call) continue; - for (auto *User : F.users()) { + if (!Solver.isBlockExecutable(Call->getParent())) + continue; - auto *Call = dyn_cast<CallInst>(User); - if (!Call) - continue; + for (const Use &U : Call->args()) { + unsigned Idx = Call->getArgOperandNo(&U); + Value *ArgOp = Call->getArgOperand(Idx); + Type *ArgOpType = ArgOp->getType(); - if (!Solver.isBlockExecutable(Call->getParent())) + if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy()) continue; - bool Changed = false; - for (const Use &U : Call->args()) { - unsigned Idx = Call->getArgOperandNo(&U); - Value *ArgOp = Call->getArgOperand(Idx); - Type *ArgOpType = ArgOp->getType(); - - if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy()) - continue; - - auto *ConstVal = getConstantStackValue(Call, ArgOp); - if (!ConstVal) - continue; - - Value *GV = new GlobalVariable(M, ConstVal->getType(), true, - GlobalValue::InternalLinkage, ConstVal, - "funcspec.arg"); - if (ArgOpType != ConstVal->getType()) - GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType); + auto *ConstVal = getConstantStackValue(Call, ArgOp); + if (!ConstVal) + continue; - Call->setArgOperand(Idx, GV); - Changed = true; - } + Value *GV = new GlobalVariable(M, ConstVal->getType(), true, + GlobalValue::InternalLinkage, ConstVal, + "funcspec.arg"); + if (ArgOpType != ConstVal->getType()) + GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType); - // Add the changed CallInst to Solver Worklist - if (Changed) - Solver.visitCall(*Call); + Call->setArgOperand(Idx, GV); } } } @@ -230,7 +506,7 @@ static void removeSSACopy(Function &F) { /// Remove any ssa_copy 
intrinsics that may have been introduced. void FunctionSpecializer::cleanUpSSA() { - for (Function *F : SpecializedFuncs) + for (Function *F : Specializations) removeSSACopy(*F); } @@ -249,6 +525,16 @@ template <> struct llvm::DenseMapInfo<SpecSig> { } }; +FunctionSpecializer::~FunctionSpecializer() { + LLVM_DEBUG( + if (NumSpecsCreated > 0) + dbgs() << "FnSpecialization: Created " << NumSpecsCreated + << " specializations in module " << M.getName() << "\n"); + // Eliminate dead code. + removeDeadFunctions(); + cleanUpSSA(); +} + /// Attempt to specialize functions in the module to enable constant /// propagation across function boundaries. /// @@ -262,17 +548,37 @@ bool FunctionSpecializer::run() { if (!isCandidateFunction(&F)) continue; - auto Cost = getSpecializationCost(&F); - if (!Cost.isValid()) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialization cost for " - << F.getName() << "\n"); - continue; + auto [It, Inserted] = FunctionMetrics.try_emplace(&F); + CodeMetrics &Metrics = It->second; + //Analyze the function. + if (Inserted) { + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(&F, &GetAC(F), EphValues); + for (BasicBlock &BB : F) + Metrics.analyzeBasicBlock(&BB, GetTTI(F), EphValues); } + // If the code metrics reveal that we shouldn't duplicate the function, + // or if the code size implies that this function is easy to get inlined, + // then we shouldn't specialize it. + if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() || + (!ForceSpecialization && !F.hasFnAttribute(Attribute::NoInline) && + Metrics.NumInsts < MinFunctionSize)) + continue; + + // TODO: For now only consider recursive functions when running multiple + // times. This should change if specialization on literal constants gets + // enabled. + if (!Inserted && !Metrics.isRecursive && !SpecializeLiteralConstant) + continue; + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for " - << F.getName() << " is " << Cost << "\n"); + << F.getName() << " is " << Metrics.NumInsts << "\n"); + + if (Inserted && Metrics.isRecursive) + promoteConstantStackValues(&F); - if (!findSpecializations(&F, Cost, AllSpecs, SM)) { + if (!findSpecializations(&F, Metrics.NumInsts, AllSpecs, SM)) { LLVM_DEBUG( dbgs() << "FnSpecialization: No possible specializations found for " << F.getName() << "\n"); @@ -292,11 +598,11 @@ bool FunctionSpecializer::run() { // Choose the most profitable specialisations, which fit in the module // specialization budget, which is derived from maximum number of // specializations per specialization candidate function. 
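
The selection loop that follows keeps only the NSpecs highest-scoring entries by treating the index array as a min-heap on score and evicting the current worst candidate through a spare slot. A standalone restatement of that idiom, with scores simplified to plain integers (the names and types here are illustrative, not the pass's own):

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Keep the indices of the N best-scoring entries. With a greater-than
// comparator the standard heap functions maintain a min-heap on score, so
// each push_heap/pop_heap pair moves the current worst index into the spare
// slot, where the next candidate overwrites it.
static std::vector<unsigned> keepBestN(const std::vector<int> &Scores,
                                       std::size_t N) {
  N = std::min(N, Scores.size());
  auto Worse = [&](unsigned I, unsigned J) { return Scores[I] > Scores[J]; };
  std::vector<unsigned> Best(N + 1);
  std::iota(Best.begin(), Best.begin() + N, 0u);
  if (Scores.size() > N) {
    std::make_heap(Best.begin(), Best.begin() + N, Worse);
    for (std::size_t I = N, E = Scores.size(); I < E; ++I) {
      Best[N] = static_cast<unsigned>(I);           // candidate in spare slot
      std::push_heap(Best.begin(), Best.end(), Worse);
      std::pop_heap(Best.begin(), Best.end(), Worse); // worst lands in Best[N]
    }
  }
  Best.resize(N); // drop the scratch slot
  return Best;
}
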
- auto CompareGain = [&AllSpecs](unsigned I, unsigned J) { - return AllSpecs[I].Gain > AllSpecs[J].Gain; + auto CompareScore = [&AllSpecs](unsigned I, unsigned J) { + return AllSpecs[I].Score > AllSpecs[J].Score; }; const unsigned NSpecs = - std::min(NumCandidates * MaxClonesThreshold, unsigned(AllSpecs.size())); + std::min(NumCandidates * MaxClones, unsigned(AllSpecs.size())); SmallVector<unsigned> BestSpecs(NSpecs + 1); std::iota(BestSpecs.begin(), BestSpecs.begin() + NSpecs, 0); if (AllSpecs.size() > NSpecs) { @@ -305,11 +611,11 @@ bool FunctionSpecializer::run() { << "FnSpecialization: Specializing the " << NSpecs << " most profitable candidates.\n"); - std::make_heap(BestSpecs.begin(), BestSpecs.begin() + NSpecs, CompareGain); + std::make_heap(BestSpecs.begin(), BestSpecs.begin() + NSpecs, CompareScore); for (unsigned I = NSpecs, N = AllSpecs.size(); I < N; ++I) { BestSpecs[NSpecs] = I; - std::push_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain); - std::pop_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain); + std::push_heap(BestSpecs.begin(), BestSpecs.end(), CompareScore); + std::pop_heap(BestSpecs.begin(), BestSpecs.end(), CompareScore); } } @@ -317,7 +623,7 @@ bool FunctionSpecializer::run() { for (unsigned I = 0; I < NSpecs; ++I) { const Spec &S = AllSpecs[BestSpecs[I]]; dbgs() << "FnSpecialization: Function " << S.F->getName() - << " , gain " << S.Gain << "\n"; + << " , score " << S.Score << "\n"; for (const ArgInfo &Arg : S.Sig.Args) dbgs() << "FnSpecialization: FormalArg = " << Arg.Formal->getNameOrAsOperand() @@ -353,12 +659,37 @@ bool FunctionSpecializer::run() { updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End); } - promoteConstantStackValues(); - LLVM_DEBUG(if (NbFunctionsSpecialized) dbgs() - << "FnSpecialization: Specialized " << NbFunctionsSpecialized - << " functions in module " << M.getName() << "\n"); + for (Function *F : Clones) { + if (F->getReturnType()->isVoidTy()) + continue; + if (F->getReturnType()->isStructTy()) { + auto *STy = cast<StructType>(F->getReturnType()); + if (!Solver.isStructLatticeConstant(F, STy)) + continue; + } else { + auto It = Solver.getTrackedRetVals().find(F); + assert(It != Solver.getTrackedRetVals().end() && + "Return value ought to be tracked"); + if (SCCPSolver::isOverdefined(It->second)) + continue; + } + for (User *U : F->users()) { + if (auto *CS = dyn_cast<CallBase>(U)) { + //The user instruction does not call our function. + if (CS->getCalledFunction() != F) + continue; + Solver.resetLatticeValueFor(CS); + } + } + } + + // Rerun the solver to notify the users of the modified callsites. + Solver.solveWhileResolvedUndefs(); + + for (Function *F : OriginalFuncs) + if (FunctionMetrics[F].isRecursive) + promoteConstantStackValues(F); - NumFuncSpecialized += NbFunctionsSpecialized; return true; } @@ -373,24 +704,6 @@ void FunctionSpecializer::removeDeadFunctions() { FullySpecialized.clear(); } -// Compute the code metrics for function \p F. -CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) { - auto I = FunctionMetrics.insert({F, CodeMetrics()}); - CodeMetrics &Metrics = I.first->second; - if (I.second) { - // The code metrics were not cached. 
- SmallPtrSet<const Value *, 32> EphValues; - CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); - for (BasicBlock &BB : *F) - Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); - - LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function " - << F->getName() << " is " << Metrics.NumInsts - << " instructions\n"); - } - return Metrics; -} - /// Clone the function \p F and remove the ssa_copy intrinsics added by /// the SCCPSolver in the cloned version. static Function *cloneCandidateFunction(Function *F) { @@ -400,13 +713,13 @@ static Function *cloneCandidateFunction(Function *F) { return Clone; } -bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost, +bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost, SmallVectorImpl<Spec> &AllSpecs, SpecMap &SM) { // A mapping from a specialisation signature to the index of the respective // entry in the all specialisation array. Used to ensure uniqueness of // specialisations. - DenseMap<SpecSig, unsigned> UM; + DenseMap<SpecSig, unsigned> UniqueSpecs; // Get a list of interesting arguments. SmallVector<Argument *> Args; @@ -417,7 +730,6 @@ bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost, if (Args.empty()) return false; - bool Found = false; for (User *U : F->users()) { if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) continue; @@ -454,7 +766,7 @@ bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost, continue; // Check if we have encountered the same specialisation already. - if (auto It = UM.find(S); It != UM.end()) { + if (auto It = UniqueSpecs.find(S); It != UniqueSpecs.end()) { // Existing specialisation. Add the call to the list to rewrite, unless // it's a recursive call. A specialisation, generated because of a // recursive call may end up as not the best specialisation for all @@ -467,42 +779,42 @@ bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost, AllSpecs[Index].CallSites.push_back(&CS); } else { // Calculate the specialisation gain. - InstructionCost Gain = 0 - Cost; + Cost Score = 0; + InstCostVisitor Visitor = getInstCostVisitorFor(F); for (ArgInfo &A : S.Args) - Gain += - getSpecializationBonus(A.Formal, A.Actual, Solver.getLoopInfo(*F)); + Score += getSpecializationBonus(A.Formal, A.Actual, Visitor); + Score += Visitor.getBonusFromPendingPHIs(); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = " + << Score << "\n"); // Discard unprofitable specialisations. - if (!ForceFunctionSpecialization && Gain <= 0) + if (!ForceSpecialization && Score <= SpecCost) continue; // Create a new specialisation entry. - auto &Spec = AllSpecs.emplace_back(F, S, Gain); + auto &Spec = AllSpecs.emplace_back(F, S, Score); if (CS.getFunction() != F) Spec.CallSites.push_back(&CS); const unsigned Index = AllSpecs.size() - 1; - UM[S] = Index; + UniqueSpecs[S] = Index; if (auto [It, Inserted] = SM.try_emplace(F, Index, Index + 1); !Inserted) It->second.second = Index + 1; - Found = true; } } - return Found; + return !UniqueSpecs.empty(); } bool FunctionSpecializer::isCandidateFunction(Function *F) { - if (F->isDeclaration()) + if (F->isDeclaration() || F->arg_empty()) return false; if (F->hasFnAttribute(Attribute::NoDuplicate)) return false; - if (!Solver.isArgumentTrackedFunction(F)) - return false; - // Do not specialize the cloned function again. 
- if (SpecializedFuncs.contains(F)) + if (Specializations.contains(F)) return false; // If we're optimizing the function for size, we shouldn't specialize it. @@ -524,86 +836,50 @@ bool FunctionSpecializer::isCandidateFunction(Function *F) { return true; } -Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) { +Function *FunctionSpecializer::createSpecialization(Function *F, + const SpecSig &S) { Function *Clone = cloneCandidateFunction(F); + // The original function does not neccessarily have internal linkage, but the + // clone must. + Clone->setLinkage(GlobalValue::InternalLinkage); + // Initialize the lattice state of the arguments of the function clone, // marking the argument on which we specialized the function constant // with the given value. - Solver.markArgInFuncSpecialization(Clone, S.Args); - - Solver.addArgumentTrackedFunction(Clone); + Solver.setLatticeValueForSpecializationArguments(Clone, S.Args); Solver.markBlockExecutable(&Clone->front()); + Solver.addArgumentTrackedFunction(Clone); + Solver.addTrackedFunction(Clone); // Mark all the specialized functions - SpecializedFuncs.insert(Clone); - NbFunctionsSpecialized++; + Specializations.insert(Clone); + ++NumSpecsCreated; return Clone; } -/// Compute and return the cost of specializing function \p F. -InstructionCost FunctionSpecializer::getSpecializationCost(Function *F) { - CodeMetrics &Metrics = analyzeFunction(F); - // If the code metrics reveal that we shouldn't duplicate the function, we - // shouldn't specialize it. Set the specialization cost to Invalid. - // Or if the lines of codes implies that this function is easy to get - // inlined so that we shouldn't specialize it. - if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() || - (!ForceFunctionSpecialization && - !F->hasFnAttribute(Attribute::NoInline) && - Metrics.NumInsts < SmallFunctionThreshold)) - return InstructionCost::getInvalid(); - - // Otherwise, set the specialization cost to be the cost of all the - // instructions in the function. - return Metrics.NumInsts * InlineConstants::getInstrCost(); -} - -static InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI, - const LoopInfo &LI) { - auto *I = dyn_cast_or_null<Instruction>(U); - // If not an instruction we do not know how to evaluate. - // Keep minimum possible cost for now so that it doesnt affect - // specialization. - if (!I) - return std::numeric_limits<unsigned>::min(); - - InstructionCost Cost = - TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency); - - // Increase the cost if it is inside the loop. - unsigned LoopDepth = LI.getLoopDepth(I->getParent()); - Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth); - - // Traverse recursively if there are more uses. - // TODO: Any other instructions to be added here? - if (I->mayReadFromMemory() || I->isCast()) - for (auto *User : I->users()) - Cost += getUserBonus(User, TTI, LI); - - return Cost; -} - /// Compute a bonus for replacing argument \p A with constant \p C. 
-InstructionCost -FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, - const LoopInfo &LI) { - Function *F = A->getParent(); - auto &TTI = (GetTTI)(*F); +Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, + InstCostVisitor &Visitor) { LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " << C->getNameOrAsOperand() << "\n"); - InstructionCost TotalCost = 0; - for (auto *U : A->users()) { - TotalCost += getUserBonus(U, TTI, LI); - LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; - TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n"); - } + Cost TotalCost = 0; + for (auto *U : A->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + if (Solver.isBlockExecutable(UI->getParent())) + TotalCost += Visitor.getUserBonus(UI, A, C); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus " + << TotalCost << " for argument " << *A << "\n"); // The below heuristic is only concerned with exposing inlining // opportunities via indirect call promotion. If the argument is not a // (potentially casted) function pointer, give up. + // + // TODO: Perhaps we should consider checking such inlining opportunities + // while traversing the users of the specialization arguments ? Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts()); if (!CalledFunction) return TotalCost; @@ -661,16 +937,9 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) { if (A->user_empty()) return false; - // For now, don't attempt to specialize functions based on the values of - // composite types. - Type *ArgTy = A->getType(); - if (!ArgTy->isSingleValueType()) - return false; - - // Specialization of integer and floating point types needs to be explicitly - // enabled. - if (!EnableSpecializationForLiteralConstant && - (ArgTy->isIntegerTy() || ArgTy->isFloatingPointTy())) + Type *Ty = A->getType(); + if (!Ty->isPointerTy() && (!SpecializeLiteralConstant || + (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isStructTy()))) return false; // SCCP solver does not record an argument that will be constructed on @@ -678,54 +947,46 @@ bool FunctionSpecializer::isArgumentInteresting(Argument *A) { if (A->hasByValAttr() && !A->getParent()->onlyReadsMemory()) return false; + // For non-argument-tracked functions every argument is overdefined. + if (!Solver.isArgumentTrackedFunction(A->getParent())) + return true; + // Check the lattice value and decide if we should attemt to specialize, // based on this argument. No point in specialization, if the lattice value // is already a constant. - const ValueLatticeElement &LV = Solver.getLatticeValueFor(A); - if (LV.isUnknownOrUndef() || LV.isConstant() || - (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter " - << A->getNameOrAsOperand() << " is already constant\n"); - return false; - } - - LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter " - << A->getNameOrAsOperand() << "\n"); - - return true; + bool IsOverdefined = Ty->isStructTy() + ? 
any_of(Solver.getStructLatticeValueFor(A), SCCPSolver::isOverdefined) + : SCCPSolver::isOverdefined(Solver.getLatticeValueFor(A)); + + LLVM_DEBUG( + if (IsOverdefined) + dbgs() << "FnSpecialization: Found interesting parameter " + << A->getNameOrAsOperand() << "\n"; + else + dbgs() << "FnSpecialization: Nothing to do, parameter " + << A->getNameOrAsOperand() << " is already constant\n"; + ); + return IsOverdefined; } -/// Check if the valuy \p V (an actual argument) is a constant or can only +/// Check if the value \p V (an actual argument) is a constant or can only /// have a constant value. Return that constant. Constant *FunctionSpecializer::getCandidateConstant(Value *V) { if (isa<PoisonValue>(V)) return nullptr; - // TrackValueOfGlobalVariable only tracks scalar global variables. - if (auto *GV = dyn_cast<GlobalVariable>(V)) { - // Check if we want to specialize on the address of non-constant - // global values. - if (!GV->isConstant() && !SpecializeOnAddresses) - return nullptr; - - if (!GV->getValueType()->isSingleValueType()) - return nullptr; - } - // Select for possible specialisation values that are constants or // are deduced to be constants or constant ranges with a single element. Constant *C = dyn_cast<Constant>(V); - if (!C) { - const ValueLatticeElement &LV = Solver.getLatticeValueFor(V); - if (LV.isConstant()) - C = LV.getConstant(); - else if (LV.isConstantRange() && LV.getConstantRange().isSingleElement()) { - assert(V->getType()->isIntegerTy() && "Non-integral constant range"); - C = Constant::getIntegerValue(V->getType(), - *LV.getConstantRange().getSingleElement()); - } else + if (!C) + C = Solver.getConstantOrNull(V); + + // Don't specialize on (anything derived from) the address of a non-constant + // global variable, unless explicitly enabled. + if (C && C->getType()->isPointerTy() && !C->isNullValue()) + if (auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(C)); + GV && !(GV->isConstant() || SpecializeOnAddress)) return nullptr; - } return C; } @@ -747,7 +1008,7 @@ void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin, // Find the best matching specialisation. const Spec *BestSpec = nullptr; for (const Spec &S : make_range(Begin, End)) { - if (!S.Clone || (BestSpec && S.Gain <= BestSpec->Gain)) + if (!S.Clone || (BestSpec && S.Score <= BestSpec->Score)) continue; if (any_of(S.Sig.Args, [CS, this](const ArgInfo &Arg) { @@ -772,7 +1033,7 @@ void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin, // If the function has been completely specialized, the original function // is no longer needed. Mark it unreachable. 
- if (NCallsLeft == 0) { + if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F)) { Solver.markFunctionUnreachable(F); FullySpecialized.insert(F); } diff --git a/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/llvm/lib/Transforms/IPO/GlobalDCE.cpp index 2f2bb174a8c8..e36d524d7667 100644 --- a/llvm/lib/Transforms/IPO/GlobalDCE.cpp +++ b/llvm/lib/Transforms/IPO/GlobalDCE.cpp @@ -21,8 +21,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/CtorUtils.h" @@ -42,47 +40,6 @@ STATISTIC(NumIFuncs, "Number of indirect functions removed"); STATISTIC(NumVariables, "Number of global variables removed"); STATISTIC(NumVFuncs, "Number of virtual functions removed"); -namespace { - class GlobalDCELegacyPass : public ModulePass { - public: - static char ID; // Pass identification, replacement for typeid - GlobalDCELegacyPass() : ModulePass(ID) { - initializeGlobalDCELegacyPassPass(*PassRegistry::getPassRegistry()); - } - - // run - Do the GlobalDCE pass on the specified module, optionally updating - // the specified callgraph to reflect the changes. - // - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - // We need a minimally functional dummy module analysis manager. It needs - // to at least know about the possibility of proxying a function analysis - // manager. - FunctionAnalysisManager DummyFAM; - ModuleAnalysisManager DummyMAM; - DummyMAM.registerPass( - [&] { return FunctionAnalysisManagerModuleProxy(DummyFAM); }); - - auto PA = Impl.run(M, DummyMAM); - return !PA.areAllPreserved(); - } - - private: - GlobalDCEPass Impl; - }; -} - -char GlobalDCELegacyPass::ID = 0; -INITIALIZE_PASS(GlobalDCELegacyPass, "globaldce", - "Dead Global Elimination", false, false) - -// Public interface to the GlobalDCEPass. -ModulePass *llvm::createGlobalDCEPass() { - return new GlobalDCELegacyPass(); -} - /// Returns true if F is effectively empty. static bool isEmptyFunction(Function *F) { // Skip external functions. 
@@ -163,12 +120,6 @@ void GlobalDCEPass::ScanVTables(Module &M) { SmallVector<MDNode *, 2> Types; LLVM_DEBUG(dbgs() << "Building type info -> vtable map\n"); - auto *LTOPostLinkMD = - cast_or_null<ConstantAsMetadata>(M.getModuleFlag("LTOPostLink")); - bool LTOPostLink = - LTOPostLinkMD && - (cast<ConstantInt>(LTOPostLinkMD->getValue())->getZExtValue() != 0); - for (GlobalVariable &GV : M.globals()) { Types.clear(); GV.getMetadata(LLVMContext::MD_type, Types); @@ -195,7 +146,7 @@ void GlobalDCEPass::ScanVTables(Module &M) { if (auto GO = dyn_cast<GlobalObject>(&GV)) { GlobalObject::VCallVisibility TypeVis = GO->getVCallVisibility(); if (TypeVis == GlobalObject::VCallVisibilityTranslationUnit || - (LTOPostLink && + (InLTOPostLink && TypeVis == GlobalObject::VCallVisibilityLinkageUnit)) { LLVM_DEBUG(dbgs() << GV.getName() << " is safe for VFE\n"); VFESafeVTables.insert(&GV); @@ -236,29 +187,36 @@ void GlobalDCEPass::ScanTypeCheckedLoadIntrinsics(Module &M) { LLVM_DEBUG(dbgs() << "Scanning type.checked.load intrinsics\n"); Function *TypeCheckedLoadFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); - - if (!TypeCheckedLoadFunc) - return; - - for (auto *U : TypeCheckedLoadFunc->users()) { - auto CI = dyn_cast<CallInst>(U); - if (!CI) - continue; - - auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - Value *TypeIdValue = CI->getArgOperand(2); - auto *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); - - if (Offset) { - ScanVTableLoad(CI->getFunction(), TypeId, Offset->getZExtValue()); - } else { - // type.checked.load with a non-constant offset, so assume every entry in - // every matching vtable is used. - for (const auto &VTableInfo : TypeIdMap[TypeId]) { - VFESafeVTables.erase(VTableInfo.first); + Function *TypeCheckedLoadRelativeFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative)); + + auto scan = [&](Function *CheckedLoadFunc) { + if (!CheckedLoadFunc) + return; + + for (auto *U : CheckedLoadFunc->users()) { + auto CI = dyn_cast<CallInst>(U); + if (!CI) + continue; + + auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + Value *TypeIdValue = CI->getArgOperand(2); + auto *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); + + if (Offset) { + ScanVTableLoad(CI->getFunction(), TypeId, Offset->getZExtValue()); + } else { + // type.checked.load with a non-constant offset, so assume every entry + // in every matching vtable is used. + for (const auto &VTableInfo : TypeIdMap[TypeId]) { + VFESafeVTables.erase(VTableInfo.first); + } } } - } + }; + + scan(TypeCheckedLoadFunc); + scan(TypeCheckedLoadRelativeFunc); } void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) { @@ -271,7 +229,7 @@ void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) { // Don't attempt VFE in that case. 
auto *Val = mdconst::dyn_extract_or_null<ConstantInt>( M.getModuleFlag("Virtual Function Elim")); - if (!Val || Val->getZExtValue() == 0) + if (!Val || Val->isZero()) return; ScanVTables(M); @@ -458,3 +416,11 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { return PreservedAnalyses::none(); return PreservedAnalyses::all(); } + +void GlobalDCEPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<GlobalDCEPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + if (InLTOPostLink) + OS << "<vfe-linkage-unit-visibility>"; +} diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 0317a8bcb6bc..1ccc523ead8a 100644 --- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -53,8 +53,6 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -206,8 +204,10 @@ CleanupPointerRootUsers(GlobalVariable *GV, // chain of computation and the store to the global in Dead[n].second. SmallVector<std::pair<Instruction *, Instruction *>, 32> Dead; + SmallVector<User *> Worklist(GV->users()); // Constants can't be pointers to dynamically allocated memory. - for (User *U : llvm::make_early_inc_range(GV->users())) { + while (!Worklist.empty()) { + User *U = Worklist.pop_back_val(); if (StoreInst *SI = dyn_cast<StoreInst>(U)) { Value *V = SI->getValueOperand(); if (isa<Constant>(V)) { @@ -235,18 +235,8 @@ CleanupPointerRootUsers(GlobalVariable *GV, Dead.push_back(std::make_pair(I, MTI)); } } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) { - if (CE->use_empty()) { - CE->destroyConstant(); - Changed = true; - } - } else if (Constant *C = dyn_cast<Constant>(U)) { - if (isSafeToDestroyConstant(C)) { - C->destroyConstant(); - // This could have invalidated UI, start over from scratch. - Dead.clear(); - CleanupPointerRootUsers(GV, GetTLI); - return true; - } + if (isa<GEPOperator>(CE)) + append_range(Worklist, CE->users()); } } @@ -268,6 +258,7 @@ CleanupPointerRootUsers(GlobalVariable *GV, } } + GV->removeDeadConstantUsers(); return Changed; } @@ -335,10 +326,19 @@ static bool CleanupConstantGlobalUsers(GlobalVariable *GV, return Changed; } +/// Part of the global at a specific offset, which is only accessed through +/// loads and stores with the given type. +struct GlobalPart { + Type *Ty; + Constant *Initializer = nullptr; + bool IsLoaded = false; + bool IsStored = false; +}; + /// Look at all uses of the global and determine which (offset, type) pairs it /// can be split into. -static bool collectSRATypes(DenseMap<uint64_t, Type *> &Types, GlobalValue *GV, - const DataLayout &DL) { +static bool collectSRATypes(DenseMap<uint64_t, GlobalPart> &Parts, + GlobalVariable *GV, const DataLayout &DL) { SmallVector<Use *, 16> Worklist; SmallPtrSet<Use *, 16> Visited; auto AppendUses = [&](Value *V) { @@ -373,14 +373,41 @@ static bool collectSRATypes(DenseMap<uint64_t, Type *> &Types, GlobalValue *GV, // TODO: We currently require that all accesses at a given offset must // use the same type. This could be relaxed. 
Type *Ty = getLoadStoreType(V); - auto It = Types.try_emplace(Offset.getZExtValue(), Ty).first; - if (Ty != It->second) + const auto &[It, Inserted] = + Parts.try_emplace(Offset.getZExtValue(), GlobalPart{Ty}); + if (Ty != It->second.Ty) return false; + if (Inserted) { + It->second.Initializer = + ConstantFoldLoadFromConst(GV->getInitializer(), Ty, Offset, DL); + if (!It->second.Initializer) { + LLVM_DEBUG(dbgs() << "Global SRA: Failed to evaluate initializer of " + << *GV << " with type " << *Ty << " at offset " + << Offset.getZExtValue()); + return false; + } + } + // Scalable types not currently supported. if (isa<ScalableVectorType>(Ty)) return false; + auto IsStored = [](Value *V, Constant *Initializer) { + auto *SI = dyn_cast<StoreInst>(V); + if (!SI) + return false; + + Constant *StoredConst = dyn_cast<Constant>(SI->getOperand(0)); + if (!StoredConst) + return true; + + // Don't consider stores that only write the initializer value. + return Initializer != StoredConst; + }; + + It->second.IsLoaded |= isa<LoadInst>(V); + It->second.IsStored |= IsStored(V, It->second.Initializer); continue; } @@ -410,6 +437,7 @@ static void transferSRADebugInfo(GlobalVariable *GV, GlobalVariable *NGV, DIExpression *Expr = GVE->getExpression(); int64_t CurVarOffsetInBytes = 0; uint64_t CurVarOffsetInBits = 0; + uint64_t FragmentEndInBits = FragmentOffsetInBits + FragmentSizeInBits; // Calculate the offset (Bytes), Continue if unknown. if (!Expr->extractIfOffset(CurVarOffsetInBytes)) @@ -423,27 +451,50 @@ static void transferSRADebugInfo(GlobalVariable *GV, GlobalVariable *NGV, CurVarOffsetInBits = CHAR_BIT * (uint64_t)CurVarOffsetInBytes; // Current var starts after the fragment, ignore. - if (CurVarOffsetInBits >= (FragmentOffsetInBits + FragmentSizeInBits)) + if (CurVarOffsetInBits >= FragmentEndInBits) continue; uint64_t CurVarSize = Var->getType()->getSizeInBits(); + uint64_t CurVarEndInBits = CurVarOffsetInBits + CurVarSize; // Current variable ends before start of fragment, ignore. - if (CurVarSize != 0 && - (CurVarOffsetInBits + CurVarSize) <= FragmentOffsetInBits) + if (CurVarSize != 0 && /* CurVarSize is known */ + CurVarEndInBits <= FragmentOffsetInBits) continue; - // Current variable fits in the fragment. - if (CurVarOffsetInBits == FragmentOffsetInBits && - CurVarSize == FragmentSizeInBits) - Expr = DIExpression::get(Expr->getContext(), {}); - // If the FragmentSize is smaller than the variable, + // Current variable fits in (not greater than) the fragment, + // does not need fragment expression. + if (CurVarSize != 0 && /* CurVarSize is known */ + CurVarOffsetInBits >= FragmentOffsetInBits && + CurVarEndInBits <= FragmentEndInBits) { + uint64_t CurVarOffsetInFragment = + (CurVarOffsetInBits - FragmentOffsetInBits) / 8; + if (CurVarOffsetInFragment != 0) + Expr = DIExpression::get(Expr->getContext(), {dwarf::DW_OP_plus_uconst, + CurVarOffsetInFragment}); + else + Expr = DIExpression::get(Expr->getContext(), {}); + auto *NGVE = + DIGlobalVariableExpression::get(GVE->getContext(), Var, Expr); + NGV->addDebugInfo(NGVE); + continue; + } + // Current variable does not fit in single fragment, // emit a fragment expression. 
- else if (FragmentSizeInBits < VarSize) { + if (FragmentSizeInBits < VarSize) { + if (CurVarOffsetInBits > FragmentOffsetInBits) + continue; + uint64_t CurVarFragmentOffsetInBits = + FragmentOffsetInBits - CurVarOffsetInBits; + uint64_t CurVarFragmentSizeInBits = FragmentSizeInBits; + if (CurVarSize != 0 && CurVarEndInBits < FragmentEndInBits) + CurVarFragmentSizeInBits -= (FragmentEndInBits - CurVarEndInBits); + if (CurVarOffsetInBits) + Expr = DIExpression::get(Expr->getContext(), {}); if (auto E = DIExpression::createFragmentExpression( - Expr, FragmentOffsetInBits, FragmentSizeInBits)) + Expr, CurVarFragmentOffsetInBits, CurVarFragmentSizeInBits)) Expr = *E; else - return; + continue; } auto *NGVE = DIGlobalVariableExpression::get(GVE->getContext(), Var, Expr); NGV->addDebugInfo(NGVE); @@ -459,52 +510,45 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { assert(GV->hasLocalLinkage()); // Collect types to split into. - DenseMap<uint64_t, Type *> Types; - if (!collectSRATypes(Types, GV, DL) || Types.empty()) + DenseMap<uint64_t, GlobalPart> Parts; + if (!collectSRATypes(Parts, GV, DL) || Parts.empty()) return nullptr; // Make sure we don't SRA back to the same type. - if (Types.size() == 1 && Types.begin()->second == GV->getValueType()) + if (Parts.size() == 1 && Parts.begin()->second.Ty == GV->getValueType()) return nullptr; - // Don't perform SRA if we would have to split into many globals. - if (Types.size() > 16) + // Don't perform SRA if we would have to split into many globals. Ignore + // parts that are either only loaded or only stored, because we expect them + // to be optimized away. + unsigned NumParts = count_if(Parts, [](const auto &Pair) { + return Pair.second.IsLoaded && Pair.second.IsStored; + }); + if (NumParts > 16) return nullptr; // Sort by offset. - SmallVector<std::pair<uint64_t, Type *>, 16> TypesVector; - append_range(TypesVector, Types); + SmallVector<std::tuple<uint64_t, Type *, Constant *>, 16> TypesVector; + for (const auto &Pair : Parts) { + TypesVector.push_back( + {Pair.first, Pair.second.Ty, Pair.second.Initializer}); + } sort(TypesVector, llvm::less_first()); // Check that the types are non-overlapping. uint64_t Offset = 0; - for (const auto &Pair : TypesVector) { + for (const auto &[OffsetForTy, Ty, _] : TypesVector) { // Overlaps with previous type. - if (Pair.first < Offset) + if (OffsetForTy < Offset) return nullptr; - Offset = Pair.first + DL.getTypeAllocSize(Pair.second); + Offset = OffsetForTy + DL.getTypeAllocSize(Ty); } // Some accesses go beyond the end of the global, don't bother. if (Offset > DL.getTypeAllocSize(GV->getValueType())) return nullptr; - // Collect initializers for new globals. - Constant *OrigInit = GV->getInitializer(); - DenseMap<uint64_t, Constant *> Initializers; - for (const auto &Pair : Types) { - Constant *NewInit = ConstantFoldLoadFromConst(OrigInit, Pair.second, - APInt(64, Pair.first), DL); - if (!NewInit) { - LLVM_DEBUG(dbgs() << "Global SRA: Failed to evaluate initializer of " - << *GV << " with type " << *Pair.second << " at offset " - << Pair.first << "\n"); - return nullptr; - } - Initializers.insert({Pair.first, NewInit}); - } - LLVM_DEBUG(dbgs() << "PERFORMING GLOBAL SRA ON: " << *GV << "\n"); // Get the alignment of the global, either explicit or target-specific. @@ -515,26 +559,24 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { // Create replacement globals. 
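The interesting part of the transferSRADebugInfo hunk that closes just above is the offset and size arithmetic deciding whether a variable needs a fragment expression (DIExpression::createFragmentExpression) and, if so, which slice of the variable the new global covers. A standalone sketch of that arithmetic, in bit units and with a hypothetical function name, followed by one worked case:

#include <cstdint>
#include <optional>
#include <utility>

// Given a debug variable at VarOffBits with VarSizeBits (0 = unknown) inside
// the old global, and a replacement global covering
// [FragOffBits, FragOffBits + FragSizeBits), return the (offset, size) of the
// fragment to emit, or std::nullopt when the whole variable fits in the
// fragment and only a byte offset (or nothing) is needed. Mirrors the
// arithmetic in the hunk above.
std::optional<std::pair<uint64_t, uint64_t>>
fragmentForVar(uint64_t VarOffBits, uint64_t VarSizeBits, uint64_t FragOffBits,
               uint64_t FragSizeBits) {
  uint64_t FragEnd = FragOffBits + FragSizeBits;
  uint64_t VarEnd = VarOffBits + VarSizeBits;
  if (VarSizeBits != 0 && VarOffBits >= FragOffBits && VarEnd <= FragEnd)
    return std::nullopt; // fits entirely: byte offset (VarOffBits-FragOffBits)/8
  // Variable is larger than / extends past the fragment: describe the slice.
  // (The real code skips the case VarOffBits > FragOffBits before this point.)
  uint64_t SliceOff = FragOffBits - VarOffBits;
  uint64_t SliceSize = FragSizeBits;
  if (VarSizeBits != 0 && VarEnd < FragEnd)
    SliceSize -= FragEnd - VarEnd;
  return std::make_pair(SliceOff, SliceSize);
}
// e.g. a 128-bit variable at bit 32 with the fragment covering bits [64, 128)
// yields the slice (offset 32, size 64); a 32-bit variable at bit 96 fits
// entirely and only needs the byte offset (96 - 64) / 8 = 4.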
DenseMap<uint64_t, GlobalVariable *> NewGlobals; unsigned NameSuffix = 0; - for (auto &Pair : TypesVector) { - uint64_t Offset = Pair.first; - Type *Ty = Pair.second; + for (auto &[OffsetForTy, Ty, Initializer] : TypesVector) { GlobalVariable *NGV = new GlobalVariable( *GV->getParent(), Ty, false, GlobalVariable::InternalLinkage, - Initializers[Offset], GV->getName() + "." + Twine(NameSuffix++), GV, + Initializer, GV->getName() + "." + Twine(NameSuffix++), GV, GV->getThreadLocalMode(), GV->getAddressSpace()); NGV->copyAttributesFrom(GV); - NewGlobals.insert({Offset, NGV}); + NewGlobals.insert({OffsetForTy, NGV}); // Calculate the known alignment of the field. If the original aggregate // had 256 byte alignment for example, something might depend on that: // propagate info to each field. - Align NewAlign = commonAlignment(StartAlignment, Offset); + Align NewAlign = commonAlignment(StartAlignment, OffsetForTy); if (NewAlign > DL.getABITypeAlign(Ty)) NGV->setAlignment(NewAlign); // Copy over the debug info for the variable. - transferSRADebugInfo(GV, NGV, Offset * 8, DL.getTypeAllocSizeInBits(Ty), - VarSize); + transferSRADebugInfo(GV, NGV, OffsetForTy * 8, + DL.getTypeAllocSizeInBits(Ty), VarSize); } // Replace uses of the original global with uses of the new global. @@ -621,8 +663,9 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V, if (II->getCalledOperand() != V) { return false; // Not calling the ptr } - } else if (const BitCastInst *CI = dyn_cast<BitCastInst>(U)) { - if (!AllUsesOfValueWillTrapIfNull(CI, PHIs)) return false; + } else if (const AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(U)) { + if (!AllUsesOfValueWillTrapIfNull(CI, PHIs)) + return false; } else if (const GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(U)) { if (!AllUsesOfValueWillTrapIfNull(GEPI, PHIs)) return false; } else if (const PHINode *PN = dyn_cast<PHINode>(U)) { @@ -735,10 +778,9 @@ static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) { UI = V->user_begin(); } } - } else if (CastInst *CI = dyn_cast<CastInst>(I)) { - Changed |= OptimizeAwayTrappingUsesOfValue(CI, - ConstantExpr::getCast(CI->getOpcode(), - NewV, CI->getType())); + } else if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(I)) { + Changed |= OptimizeAwayTrappingUsesOfValue( + CI, ConstantExpr::getAddrSpaceCast(NewV, CI->getType())); if (CI->use_empty()) { Changed = true; CI->eraseFromParent(); @@ -803,7 +845,8 @@ static bool OptimizeAwayTrappingUsesOfLoads( assert((isa<PHINode>(GlobalUser) || isa<SelectInst>(GlobalUser) || isa<ConstantExpr>(GlobalUser) || isa<CmpInst>(GlobalUser) || isa<BitCastInst>(GlobalUser) || - isa<GetElementPtrInst>(GlobalUser)) && + isa<GetElementPtrInst>(GlobalUser) || + isa<AddrSpaceCastInst>(GlobalUser)) && "Only expect load and stores!"); } } @@ -976,7 +1019,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, cast<StoreInst>(InitBool->user_back())->eraseFromParent(); delete InitBool; } else - GV->getParent()->getGlobalList().insert(GV->getIterator(), InitBool); + GV->getParent()->insertGlobalVariable(GV->getIterator(), InitBool); // Now the GV is dead, nuke it and the allocation.. 
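When each replacement global is created in the hunk above, the original global's alignment is propagated through commonAlignment(StartAlignment, OffsetForTy) and only applied if it beats the ABI alignment of the part's type. Assuming commonAlignment reduces to the largest power of two dividing both the original alignment and the byte offset, the computation is the following small sketch (hypothetical helper name):

#include <cassert>
#include <cstdint>

// Largest power of two dividing both the original alignment and the byte
// offset of the new global inside the old one; assumed to match what
// commonAlignment(StartAlignment, OffsetForTy) computes in the hunk above.
uint64_t alignmentAtOffset(uint64_t StartAlignment, uint64_t Offset) {
  assert(StartAlignment && (StartAlignment & (StartAlignment - 1)) == 0 &&
         "alignment must be a power of two");
  if (Offset == 0)
    return StartAlignment;
  uint64_t V = StartAlignment | Offset;
  return V & ~(V - 1); // lowest set bit
}
// e.g. alignmentAtOffset(256, 8) == 8 and alignmentAtOffset(256, 12) == 4.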
GV->eraseFromParent(); @@ -1103,9 +1146,6 @@ optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, nullptr /* F */, GV->getInitializer()->getType()->getPointerAddressSpace())) { if (Constant *SOVC = dyn_cast<Constant>(StoredOnceVal)) { - if (GV->getInitializer()->getType() != SOVC->getType()) - SOVC = ConstantExpr::getBitCast(SOVC, GV->getInitializer()->getType()); - // Optimize away any trapping uses of the loaded value. if (OptimizeAwayTrappingUsesOfLoads(GV, SOVC, DL, GetTLI)) return true; @@ -1158,7 +1198,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { GV->getThreadLocalMode(), GV->getType()->getAddressSpace()); NewGV->copyAttributesFrom(GV); - GV->getParent()->getGlobalList().insert(GV->getIterator(), NewGV); + GV->getParent()->insertGlobalVariable(GV->getIterator(), NewGV); Constant *InitVal = GV->getInitializer(); assert(InitVal->getType() != Type::getInt1Ty(GV->getContext()) && @@ -1330,18 +1370,6 @@ static bool isPointerValueDeadOnEntryToFunction( SmallVector<LoadInst *, 4> Loads; SmallVector<StoreInst *, 4> Stores; for (auto *U : GV->users()) { - if (Operator::getOpcode(U) == Instruction::BitCast) { - for (auto *UU : U->users()) { - if (auto *LI = dyn_cast<LoadInst>(UU)) - Loads.push_back(LI); - else if (auto *SI = dyn_cast<StoreInst>(UU)) - Stores.push_back(SI); - else - return false; - } - continue; - } - Instruction *I = dyn_cast<Instruction>(U); if (!I) return false; @@ -1391,62 +1419,6 @@ static bool isPointerValueDeadOnEntryToFunction( return true; } -/// C may have non-instruction users. Can all of those users be turned into -/// instructions? -static bool allNonInstructionUsersCanBeMadeInstructions(Constant *C) { - // We don't do this exhaustively. The most common pattern that we really need - // to care about is a constant GEP or constant bitcast - so just looking - // through one single ConstantExpr. - // - // The set of constants that this function returns true for must be able to be - // handled by makeAllConstantUsesInstructions. - for (auto *U : C->users()) { - if (isa<Instruction>(U)) - continue; - if (!isa<ConstantExpr>(U)) - // Non instruction, non-constantexpr user; cannot convert this. - return false; - for (auto *UU : U->users()) - if (!isa<Instruction>(UU)) - // A constantexpr used by another constant. We don't try and recurse any - // further but just bail out at this point. - return false; - } - - return true; -} - -/// C may have non-instruction users, and -/// allNonInstructionUsersCanBeMadeInstructions has returned true. Convert the -/// non-instruction users to instructions. -static void makeAllConstantUsesInstructions(Constant *C) { - SmallVector<ConstantExpr*,4> Users; - for (auto *U : C->users()) { - if (isa<ConstantExpr>(U)) - Users.push_back(cast<ConstantExpr>(U)); - else - // We should never get here; allNonInstructionUsersCanBeMadeInstructions - // should not have returned true for C. - assert( - isa<Instruction>(U) && - "Can't transform non-constantexpr non-instruction to instruction!"); - } - - SmallVector<Value*,4> UUsers; - for (auto *U : Users) { - UUsers.clear(); - append_range(UUsers, U->users()); - for (auto *UU : UUsers) { - Instruction *UI = cast<Instruction>(UU); - Instruction *NewU = U->getAsInstruction(UI); - UI->replaceUsesOfWith(U, NewU); - } - // We've replaced all the uses, so destroy the constant. (destroyConstant - // will update value handles and metadata.) 
- U->destroyConstant(); - } -} - // For a global variable with one store, if the store dominates any loads, // those loads will always load the stored value (as opposed to the // initializer), even in the presence of recursion. @@ -1504,7 +1476,6 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, GV->getValueType()->isSingleValueType() && GV->getType()->getAddressSpace() == 0 && !GV->isExternallyInitialized() && - allNonInstructionUsersCanBeMadeInstructions(GV) && GS.AccessingFunction->doesNotRecurse() && isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV, LookupDomTree)) { @@ -1520,8 +1491,6 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, if (!isa<UndefValue>(GV->getInitializer())) new StoreInst(GV->getInitializer(), Alloca, &FirstI); - makeAllConstantUsesInstructions(GV); - GV->replaceAllUsesWith(Alloca); GV->eraseFromParent(); ++NumLocalized; @@ -2142,15 +2111,22 @@ static void setUsedInitializer(GlobalVariable &V, return; } + // Get address space of pointers in the array of pointers. + const Type *UsedArrayType = V.getValueType(); + const auto *VAT = cast<ArrayType>(UsedArrayType); + const auto *VEPT = cast<PointerType>(VAT->getArrayElementType()); + // Type of pointer to the array of pointers. - PointerType *Int8PtrTy = Type::getInt8PtrTy(V.getContext(), 0); + PointerType *Int8PtrTy = + Type::getInt8PtrTy(V.getContext(), VEPT->getAddressSpace()); SmallVector<Constant *, 8> UsedArray; for (GlobalValue *GV : Init) { - Constant *Cast - = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy); + Constant *Cast = + ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy); UsedArray.push_back(Cast); } + // Sort to get deterministic order. array_pod_sort(UsedArray.begin(), UsedArray.end(), compareNames); ArrayType *ATy = ArrayType::get(Int8PtrTy, UsedArray.size()); @@ -2241,22 +2217,11 @@ static bool hasUseOtherThanLLVMUsed(GlobalAlias &GA, const LLVMUsed &U) { return !U.usedCount(&GA) && !U.compilerUsedCount(&GA); } -static bool hasMoreThanOneUseOtherThanLLVMUsed(GlobalValue &V, - const LLVMUsed &U) { - unsigned N = 2; - assert((!U.usedCount(&V) || !U.compilerUsedCount(&V)) && - "We should have removed the duplicated " - "element from llvm.compiler.used"); - if (U.usedCount(&V) || U.compilerUsedCount(&V)) - ++N; - return V.hasNUsesOrMore(N); -} - -static bool mayHaveOtherReferences(GlobalAlias &GA, const LLVMUsed &U) { - if (!GA.hasLocalLinkage()) +static bool mayHaveOtherReferences(GlobalValue &GV, const LLVMUsed &U) { + if (!GV.hasLocalLinkage()) return true; - return U.usedCount(&GA) || U.compilerUsedCount(&GA); + return U.usedCount(&GV) || U.compilerUsedCount(&GV); } static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U, @@ -2270,21 +2235,16 @@ static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U, if (!mayHaveOtherReferences(GA, U)) return Ret; - // If the aliasee has internal linkage, give it the name and linkage - // of the alias, and delete the alias. This turns: + // If the aliasee has internal linkage and no other references (e.g., + // @llvm.used, @llvm.compiler.used), give it the name and linkage of the + // alias, and delete the alias. This turns: // define internal ... @f(...) // @a = alias ... @f // into: // define ... @a(...) Constant *Aliasee = GA.getAliasee(); GlobalValue *Target = cast<GlobalValue>(Aliasee->stripPointerCasts()); - if (!Target->hasLocalLinkage()) - return Ret; - - // Do not perform the transform if multiple aliases potentially target the - // aliasee. 
This check also ensures that it is safe to replace the section - // and other attributes of the aliasee with those of the alias. - if (hasMoreThanOneUseOtherThanLLVMUsed(*Target, U)) + if (mayHaveOtherReferences(*Target, U)) return Ret; RenameTarget = true; @@ -2360,7 +2320,7 @@ OptimizeGlobalAliases(Module &M, continue; // Delete the alias. - M.getAliasList().erase(&J); + M.eraseAlias(&J); ++NumAliasesRemoved; Changed = true; } @@ -2562,65 +2522,3 @@ PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) { PA.preserveSet<CFGAnalyses>(); return PA; } - -namespace { - -struct GlobalOptLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - - GlobalOptLegacyPass() : ModulePass(ID) { - initializeGlobalOptLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - auto &DL = M.getDataLayout(); - auto LookupDomTree = [this](Function &F) -> DominatorTree & { - return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); - }; - auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - auto GetTTI = [this](Function &F) -> TargetTransformInfo & { - return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - }; - - auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { - return this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); - }; - - auto ChangedCFGCallback = [&LookupDomTree](Function &F) { - auto &DT = LookupDomTree(F); - DT.recalculate(F); - }; - - return optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, LookupDomTree, - ChangedCFGCallback, nullptr); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - } -}; - -} // end anonymous namespace - -char GlobalOptLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(GlobalOptLegacyPass, "globalopt", - "Global Variable Optimizer", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(GlobalOptLegacyPass, "globalopt", - "Global Variable Optimizer", false, false) - -ModulePass *llvm::createGlobalOptimizerPass() { - return new GlobalOptLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/GlobalSplit.cpp b/llvm/lib/Transforms/IPO/GlobalSplit.cpp index 7d9e6135b2eb..84e9c219f935 100644 --- a/llvm/lib/Transforms/IPO/GlobalSplit.cpp +++ b/llvm/lib/Transforms/IPO/GlobalSplit.cpp @@ -29,8 +29,6 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/IPO.h" #include <cstdint> @@ -149,8 +147,12 @@ static bool splitGlobals(Module &M) { M.getFunction(Intrinsic::getName(Intrinsic::type_test)); Function *TypeCheckedLoadFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + Function *TypeCheckedLoadRelativeFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative)); if ((!TypeTestFunc || TypeTestFunc->use_empty()) && - (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + (!TypeCheckedLoadFunc || 
TypeCheckedLoadFunc->use_empty()) && + (!TypeCheckedLoadRelativeFunc || + TypeCheckedLoadRelativeFunc->use_empty())) return false; bool Changed = false; @@ -159,33 +161,6 @@ static bool splitGlobals(Module &M) { return Changed; } -namespace { - -struct GlobalSplit : public ModulePass { - static char ID; - - GlobalSplit() : ModulePass(ID) { - initializeGlobalSplitPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - return splitGlobals(M); - } -}; - -} // end anonymous namespace - -char GlobalSplit::ID = 0; - -INITIALIZE_PASS(GlobalSplit, "globalsplit", "Global splitter", false, false) - -ModulePass *llvm::createGlobalSplitPass() { - return new GlobalSplit; -} - PreservedAnalyses GlobalSplitPass::run(Module &M, ModuleAnalysisManager &AM) { if (!splitGlobals(M)) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp index 95e8ae0fd22f..599ace9ca79f 100644 --- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -46,8 +46,6 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -169,23 +167,6 @@ static bool markFunctionCold(Function &F, bool UpdateEntryCount = false) { return Changed; } -class HotColdSplittingLegacyPass : public ModulePass { -public: - static char ID; - HotColdSplittingLegacyPass() : ModulePass(ID) { - initializeHotColdSplittingLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - AU.addRequired<ProfileSummaryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addUsedIfAvailable<AssumptionCacheTracker>(); - } - - bool runOnModule(Module &M) override; -}; - } // end anonymous namespace /// Check whether \p F is inherently cold. 
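The splitGlobals hunk above now also looks up the type_checked_load_relative intrinsic and bails out early unless at least one of the three type-checking intrinsics is present and actually used. The early-exit shape, reduced to a standalone helper with a hypothetical Fn record in place of llvm::Function:

#include <initializer_list>

struct Fn {
  bool HasUses = false;
};

// True if at least one of the (possibly null) functions has uses; mirrors
// the "(!F || F->use_empty())" chain in the hunk above.
bool anyUsed(std::initializer_list<const Fn *> Fns) {
  for (const Fn *F : Fns)
    if (F && F->HasUses)
      return true;
  return false;
}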
@@ -713,32 +694,6 @@ bool HotColdSplitting::run(Module &M) { return Changed; } -bool HotColdSplittingLegacyPass::runOnModule(Module &M) { - if (skipModule(M)) - return false; - ProfileSummaryInfo *PSI = - &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - auto GTTI = [this](Function &F) -> TargetTransformInfo & { - return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - }; - auto GBFI = [this](Function &F) { - return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); - }; - std::unique_ptr<OptimizationRemarkEmitter> ORE; - std::function<OptimizationRemarkEmitter &(Function &)> GetORE = - [&ORE](Function &F) -> OptimizationRemarkEmitter & { - ORE.reset(new OptimizationRemarkEmitter(&F)); - return *ORE; - }; - auto LookupAC = [this](Function &F) -> AssumptionCache * { - if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>()) - return ACT->lookupAssumptionCache(F); - return nullptr; - }; - - return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M); -} - PreservedAnalyses HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); @@ -769,15 +724,3 @@ HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { return PreservedAnalyses::none(); return PreservedAnalyses::all(); } - -char HotColdSplittingLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(HotColdSplittingLegacyPass, "hotcoldsplit", - "Hot Cold Splitting", false, false) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_END(HotColdSplittingLegacyPass, "hotcoldsplit", - "Hot Cold Splitting", false, false) - -ModulePass *llvm::createHotColdSplittingPass() { - return new HotColdSplittingLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/IPO.cpp b/llvm/lib/Transforms/IPO/IPO.cpp index 4163c448dc8f..5ad1289277a7 100644 --- a/llvm/lib/Transforms/IPO/IPO.cpp +++ b/llvm/lib/Transforms/IPO/IPO.cpp @@ -12,9 +12,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm-c/Transforms/IPO.h" -#include "llvm-c/Initialization.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" @@ -23,104 +20,10 @@ using namespace llvm; void llvm::initializeIPO(PassRegistry &Registry) { - initializeAnnotation2MetadataLegacyPass(Registry); - initializeCalledValuePropagationLegacyPassPass(Registry); - initializeConstantMergeLegacyPassPass(Registry); - initializeCrossDSOCFIPass(Registry); initializeDAEPass(Registry); initializeDAHPass(Registry); - initializeForceFunctionAttrsLegacyPassPass(Registry); - initializeGlobalDCELegacyPassPass(Registry); - initializeGlobalOptLegacyPassPass(Registry); - initializeGlobalSplitPass(Registry); - initializeHotColdSplittingLegacyPassPass(Registry); - initializeIROutlinerLegacyPassPass(Registry); initializeAlwaysInlinerLegacyPassPass(Registry); - initializeSimpleInlinerPass(Registry); - initializeInferFunctionAttrsLegacyPassPass(Registry); - initializeInternalizeLegacyPassPass(Registry); initializeLoopExtractorLegacyPassPass(Registry); initializeSingleLoopExtractorPass(Registry); - initializeMergeFunctionsLegacyPassPass(Registry); - initializePartialInlinerLegacyPassPass(Registry); - initializeAttributorLegacyPassPass(Registry); - initializeAttributorCGSCCLegacyPassPass(Registry); - initializePostOrderFunctionAttrsLegacyPassPass(Registry); - 
initializeReversePostOrderFunctionAttrsLegacyPassPass(Registry); - initializeIPSCCPLegacyPassPass(Registry); - initializeStripDeadPrototypesLegacyPassPass(Registry); - initializeStripSymbolsPass(Registry); - initializeStripDebugDeclarePass(Registry); - initializeStripDeadDebugInfoPass(Registry); - initializeStripNonDebugSymbolsPass(Registry); initializeBarrierNoopPass(Registry); - initializeEliminateAvailableExternallyLegacyPassPass(Registry); -} - -void LLVMInitializeIPO(LLVMPassRegistryRef R) { - initializeIPO(*unwrap(R)); -} - -void LLVMAddCalledValuePropagationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCalledValuePropagationPass()); -} - -void LLVMAddConstantMergePass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createConstantMergePass()); -} - -void LLVMAddDeadArgEliminationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createDeadArgEliminationPass()); -} - -void LLVMAddFunctionAttrsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createPostOrderFunctionAttrsLegacyPass()); -} - -void LLVMAddFunctionInliningPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createFunctionInliningPass()); -} - -void LLVMAddAlwaysInlinerPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(llvm::createAlwaysInlinerLegacyPass()); -} - -void LLVMAddGlobalDCEPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createGlobalDCEPass()); -} - -void LLVMAddGlobalOptimizerPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createGlobalOptimizerPass()); -} - -void LLVMAddIPSCCPPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createIPSCCPPass()); -} - -void LLVMAddMergeFunctionsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createMergeFunctionsPass()); -} - -void LLVMAddInternalizePass(LLVMPassManagerRef PM, unsigned AllButMain) { - auto PreserveMain = [=](const GlobalValue &GV) { - return AllButMain && GV.getName() == "main"; - }; - unwrap(PM)->add(createInternalizePass(PreserveMain)); -} - -void LLVMAddInternalizePassWithMustPreservePredicate( - LLVMPassManagerRef PM, - void *Context, - LLVMBool (*Pred)(LLVMValueRef, void *)) { - unwrap(PM)->add(createInternalizePass([=](const GlobalValue &GV) { - return Pred(wrap(&GV), Context) == 0 ? 
false : true; - })); -} - -void LLVMAddStripDeadPrototypesPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createStripDeadPrototypesPass()); -} - -void LLVMAddStripSymbolsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createStripSymbolsPass()); } diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp index f5c52e5c7f5d..e258299c6a4c 100644 --- a/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -22,8 +22,6 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Mangler.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" #include <optional> @@ -179,10 +177,8 @@ static void getSortedConstantKeys(std::vector<Value *> &SortedKeys, stable_sort(SortedKeys, [](const Value *LHS, const Value *RHS) { assert(LHS && RHS && "Expected non void values."); - const ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS); - const ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS); - assert(RHSC && "Not a constant integer in return value?"); - assert(LHSC && "Not a constant integer in return value?"); + const ConstantInt *LHSC = cast<ConstantInt>(LHS); + const ConstantInt *RHSC = cast<ConstantInt>(RHS); return LHSC->getLimitedValue() < RHSC->getLimitedValue(); }); @@ -590,7 +586,7 @@ collectRegionsConstants(OutlinableRegion &Region, // While this value is a register, it might not have been previously, // make sure we don't already have a constant mapped to this global value // number. - if (GVNToConstant.find(GVN) != GVNToConstant.end()) + if (GVNToConstant.contains(GVN)) ConstantsTheSame = false; NotSame.insert(GVN); @@ -818,7 +814,7 @@ static void mapInputsToGVNs(IRSimilarityCandidate &C, // replacement. for (Value *Input : CurrentInputs) { assert(Input && "Have a nullptr as an input"); - if (OutputMappings.find(Input) != OutputMappings.end()) + if (OutputMappings.contains(Input)) Input = OutputMappings.find(Input)->second; assert(C.getGVN(Input) && "Could not find a numbering for the given input"); EndInputNumbers.push_back(*C.getGVN(Input)); @@ -840,7 +836,7 @@ remapExtractedInputs(const ArrayRef<Value *> ArgInputs, // Get the global value number for each input that will be extracted as an // argument by the code extractor, remapping if needed for reloaded values. for (Value *Input : ArgInputs) { - if (OutputMappings.find(Input) != OutputMappings.end()) + if (OutputMappings.contains(Input)) Input = OutputMappings.find(Input)->second; RemappedArgInputs.insert(Input); } @@ -1332,7 +1328,7 @@ findExtractedOutputToOverallOutputMapping(Module &M, OutlinableRegion &Region, unsigned AggArgIdx = 0; for (unsigned Jdx = TypeIndex; Jdx < ArgumentSize; Jdx++) { - if (Group.ArgumentTypes[Jdx] != PointerType::getUnqual(Output->getType())) + if (!isa<PointerType>(Group.ArgumentTypes[Jdx])) continue; if (AggArgsUsed.contains(Jdx)) @@ -1483,8 +1479,7 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { } // If it is a constant, we simply add it to the argument list as a value. 
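Several IROutliner hunks below and above swap the Map.find(Key) != Map.end() spelling for Map.contains(Key). The two are equivalent membership tests; a tiny standalone check with a standard container (llvm::DenseMap::contains is taken from the diff itself, and contains() on std::unordered_map needs C++20):

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> OutputMappings{{"in", 7}};
  // Same question, two spellings.
  assert((OutputMappings.find("in") != OutputMappings.end()) ==
         OutputMappings.contains("in"));
  assert(!OutputMappings.contains("out"));
  return 0;
}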
- if (Region.AggArgToConstant.find(AggArgIdx) != - Region.AggArgToConstant.end()) { + if (Region.AggArgToConstant.contains(AggArgIdx)) { Constant *CST = Region.AggArgToConstant.find(AggArgIdx)->second; LLVM_DEBUG(dbgs() << "Setting argument " << AggArgIdx << " to value " << *CST << "\n"); @@ -1818,8 +1813,7 @@ replaceArgumentUses(OutlinableRegion &Region, for (unsigned ArgIdx = 0; ArgIdx < Region.ExtractedFunction->arg_size(); ArgIdx++) { - assert(Region.ExtractedArgToAgg.find(ArgIdx) != - Region.ExtractedArgToAgg.end() && + assert(Region.ExtractedArgToAgg.contains(ArgIdx) && "No mapping from extracted to outlined?"); unsigned AggArgIdx = Region.ExtractedArgToAgg.find(ArgIdx)->second; Argument *AggArg = Group.OutlinedFunction->getArg(AggArgIdx); @@ -2700,7 +2694,7 @@ void IROutliner::updateOutputMapping(OutlinableRegion &Region, if (!OutputIdx) return; - if (OutputMappings.find(Outputs[*OutputIdx]) == OutputMappings.end()) { + if (!OutputMappings.contains(Outputs[*OutputIdx])) { LLVM_DEBUG(dbgs() << "Mapping extracted output " << *LI << " to " << *Outputs[*OutputIdx] << "\n"); OutputMappings.insert(std::make_pair(LI, Outputs[*OutputIdx])); @@ -3024,46 +3018,6 @@ bool IROutliner::run(Module &M) { return doOutline(M) > 0; } -// Pass Manager Boilerplate -namespace { -class IROutlinerLegacyPass : public ModulePass { -public: - static char ID; - IROutlinerLegacyPass() : ModulePass(ID) { - initializeIROutlinerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<IRSimilarityIdentifierWrapperPass>(); - } - - bool runOnModule(Module &M) override; -}; -} // namespace - -bool IROutlinerLegacyPass::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - std::unique_ptr<OptimizationRemarkEmitter> ORE; - auto GORE = [&ORE](Function &F) -> OptimizationRemarkEmitter & { - ORE.reset(new OptimizationRemarkEmitter(&F)); - return *ORE; - }; - - auto GTTI = [this](Function &F) -> TargetTransformInfo & { - return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - }; - - auto GIRSI = [this](Module &) -> IRSimilarityIdentifier & { - return this->getAnalysis<IRSimilarityIdentifierWrapperPass>().getIRSI(); - }; - - return IROutliner(GTTI, GIRSI, GORE).run(M); -} - PreservedAnalyses IROutlinerPass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); @@ -3088,14 +3042,3 @@ PreservedAnalyses IROutlinerPass::run(Module &M, ModuleAnalysisManager &AM) { return PreservedAnalyses::none(); return PreservedAnalyses::all(); } - -char IROutlinerLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(IROutlinerLegacyPass, "iroutliner", "IR Outliner", false, - false) -INITIALIZE_PASS_DEPENDENCY(IRSimilarityIdentifierWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(IROutlinerLegacyPass, "iroutliner", "IR Outliner", false, - false) - -ModulePass *llvm::createIROutlinerPass() { return new IROutlinerLegacyPass(); } diff --git a/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp b/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp index 76f8f1a7a482..18d5911d10f1 100644 --- a/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -10,7 +10,6 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Function.h" 
#include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -52,38 +51,3 @@ PreservedAnalyses InferFunctionAttrsPass::run(Module &M, // out all the passes. return PreservedAnalyses::none(); } - -namespace { -struct InferFunctionAttrsLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - InferFunctionAttrsLegacyPass() : ModulePass(ID) { - initializeInferFunctionAttrsLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - return inferAllPrototypeAttributes(M, GetTLI); - } -}; -} - -char InferFunctionAttrsLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(InferFunctionAttrsLegacyPass, "inferattrs", - "Infer set function attributes", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(InferFunctionAttrsLegacyPass, "inferattrs", - "Infer set function attributes", false, false) - -Pass *llvm::createInferFunctionAttrsLegacyPass() { - return new InferFunctionAttrsLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/InlineSimple.cpp b/llvm/lib/Transforms/IPO/InlineSimple.cpp deleted file mode 100644 index eba0d6636d6c..000000000000 --- a/llvm/lib/Transforms/IPO/InlineSimple.cpp +++ /dev/null @@ -1,118 +0,0 @@ -//===- InlineSimple.cpp - Code to perform simple function inlining --------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements bottom-up inlining of functions into callees. -// -//===----------------------------------------------------------------------===// - -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/InitializePasses.h" -#include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/IPO/Inliner.h" - -using namespace llvm; - -#define DEBUG_TYPE "inline" - -namespace { - -/// Actual inliner pass implementation. -/// -/// The common implementation of the inlining logic is shared between this -/// inliner pass and the always inliner pass. The two passes use different cost -/// analyses to determine when to inline. 
-class SimpleInliner : public LegacyInlinerBase { - - InlineParams Params; - -public: - SimpleInliner() : LegacyInlinerBase(ID), Params(llvm::getInlineParams()) { - initializeSimpleInlinerPass(*PassRegistry::getPassRegistry()); - } - - explicit SimpleInliner(InlineParams Params) - : LegacyInlinerBase(ID), Params(std::move(Params)) { - initializeSimpleInlinerPass(*PassRegistry::getPassRegistry()); - } - - static char ID; // Pass identification, replacement for typeid - - InlineCost getInlineCost(CallBase &CB) override { - Function *Callee = CB.getCalledFunction(); - TargetTransformInfo &TTI = TTIWP->getTTI(*Callee); - - bool RemarksEnabled = false; - const auto &BBs = *CB.getCaller(); - if (!BBs.empty()) { - auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBs.front()); - if (DI.isEnabled()) - RemarksEnabled = true; - } - OptimizationRemarkEmitter ORE(CB.getCaller()); - - std::function<AssumptionCache &(Function &)> GetAssumptionCache = - [&](Function &F) -> AssumptionCache & { - return ACT->getAssumptionCache(F); - }; - return llvm::getInlineCost(CB, Params, TTI, GetAssumptionCache, GetTLI, - /*GetBFI=*/nullptr, PSI, - RemarksEnabled ? &ORE : nullptr); - } - - bool runOnSCC(CallGraphSCC &SCC) override; - void getAnalysisUsage(AnalysisUsage &AU) const override; - -private: - TargetTransformInfoWrapperPass *TTIWP; - -}; - -} // end anonymous namespace - -char SimpleInliner::ID = 0; -INITIALIZE_PASS_BEGIN(SimpleInliner, "inline", "Function Integration/Inlining", - false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(SimpleInliner, "inline", "Function Integration/Inlining", - false, false) - -Pass *llvm::createFunctionInliningPass() { return new SimpleInliner(); } - -Pass *llvm::createFunctionInliningPass(int Threshold) { - return new SimpleInliner(llvm::getInlineParams(Threshold)); -} - -Pass *llvm::createFunctionInliningPass(unsigned OptLevel, - unsigned SizeOptLevel, - bool DisableInlineHotCallSite) { - auto Param = llvm::getInlineParams(OptLevel, SizeOptLevel); - if (DisableInlineHotCallSite) - Param.HotCallSiteThreshold = 0; - return new SimpleInliner(Param); -} - -Pass *llvm::createFunctionInliningPass(InlineParams &Params) { - return new SimpleInliner(Params); -} - -bool SimpleInliner::runOnSCC(CallGraphSCC &SCC) { - TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>(); - return LegacyInlinerBase::runOnSCC(SCC); -} - -void SimpleInliner::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<TargetTransformInfoWrapperPass>(); - LegacyInlinerBase::getAnalysisUsage(AU); -} diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp index 5bcfc38c585b..3e00aebce372 100644 --- a/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/llvm/lib/Transforms/IPO/Inliner.cpp @@ -27,7 +27,6 @@ #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CGSCCPassManager.h" -#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/LazyCallGraph.h" @@ -71,20 +70,7 @@ using namespace llvm; #define DEBUG_TYPE "inline" STATISTIC(NumInlined, "Number of functions inlined"); -STATISTIC(NumCallsDeleted, "Number of call sites deleted, not inlined"); STATISTIC(NumDeleted, "Number of 
functions deleted because all callers found"); -STATISTIC(NumMergedAllocas, "Number of allocas merged together"); - -/// Flag to disable manual alloca merging. -/// -/// Merging of allocas was originally done as a stack-size saving technique -/// prior to LLVM's code generator having support for stack coloring based on -/// lifetime markers. It is now in the process of being removed. To experiment -/// with disabling it and relying fully on lifetime marker based stack -/// coloring, you can pass this flag to LLVM. -static cl::opt<bool> - DisableInlinedAllocaMerging("disable-inlined-alloca-merging", - cl::init(false), cl::Hidden); static cl::opt<int> IntraSCCCostMultiplier( "intra-scc-cost-multiplier", cl::init(2), cl::Hidden, @@ -108,9 +94,6 @@ static cl::opt<bool> EnablePostSCCAdvisorPrinting("enable-scc-inline-advisor-printing", cl::init(false), cl::Hidden); -namespace llvm { -extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats; -} static cl::opt<std::string> CGSCCInlineReplayFile( "cgscc-inline-replay", cl::init(""), cl::value_desc("filename"), @@ -163,174 +146,6 @@ static cl::opt<CallSiteFormat::Format> CGSCCInlineReplayFormat( "<Line Number>:<Column Number>.<Discriminator> (default)")), cl::desc("How cgscc inline replay file is formatted"), cl::Hidden); -LegacyInlinerBase::LegacyInlinerBase(char &ID) : CallGraphSCCPass(ID) {} - -LegacyInlinerBase::LegacyInlinerBase(char &ID, bool InsertLifetime) - : CallGraphSCCPass(ID), InsertLifetime(InsertLifetime) {} - -/// For this class, we declare that we require and preserve the call graph. -/// If the derived class implements this method, it should -/// always explicitly call the implementation here. -void LegacyInlinerBase::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<ProfileSummaryInfoWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - getAAResultsAnalysisUsage(AU); - CallGraphSCCPass::getAnalysisUsage(AU); -} - -using InlinedArrayAllocasTy = DenseMap<ArrayType *, std::vector<AllocaInst *>>; - -/// Look at all of the allocas that we inlined through this call site. If we -/// have already inlined other allocas through other calls into this function, -/// then we know that they have disjoint lifetimes and that we can merge them. -/// -/// There are many heuristics possible for merging these allocas, and the -/// different options have different tradeoffs. One thing that we *really* -/// don't want to hurt is SRoA: once inlining happens, often allocas are no -/// longer address taken and so they can be promoted. -/// -/// Our "solution" for that is to only merge allocas whose outermost type is an -/// array type. These are usually not promoted because someone is using a -/// variable index into them. These are also often the most important ones to -/// merge. -/// -/// A better solution would be to have real memory lifetime markers in the IR -/// and not have the inliner do any merging of allocas at all. This would -/// allow the backend to do proper stack slot coloring of all allocas that -/// *actually make it to the backend*, which is really what we want. -/// -/// Because we don't have this information, we do this simple and useful hack. -static void mergeInlinedArrayAllocas(Function *Caller, InlineFunctionInfo &IFI, - InlinedArrayAllocasTy &InlinedArrayAllocas, - int InlineHistory) { - SmallPtrSet<AllocaInst *, 16> UsedAllocas; - - // When processing our SCC, check to see if the call site was inlined from - // some other call site. 
For example, if we're processing "A" in this code: - // A() { B() } - // B() { x = alloca ... C() } - // C() { y = alloca ... } - // Assume that C was not inlined into B initially, and so we're processing A - // and decide to inline B into A. Doing this makes an alloca available for - // reuse and makes a callsite (C) available for inlining. When we process - // the C call site we don't want to do any alloca merging between X and Y - // because their scopes are not disjoint. We could make this smarter by - // keeping track of the inline history for each alloca in the - // InlinedArrayAllocas but this isn't likely to be a significant win. - if (InlineHistory != -1) // Only do merging for top-level call sites in SCC. - return; - - // Loop over all the allocas we have so far and see if they can be merged with - // a previously inlined alloca. If not, remember that we had it. - for (unsigned AllocaNo = 0, E = IFI.StaticAllocas.size(); AllocaNo != E; - ++AllocaNo) { - AllocaInst *AI = IFI.StaticAllocas[AllocaNo]; - - // Don't bother trying to merge array allocations (they will usually be - // canonicalized to be an allocation *of* an array), or allocations whose - // type is not itself an array (because we're afraid of pessimizing SRoA). - ArrayType *ATy = dyn_cast<ArrayType>(AI->getAllocatedType()); - if (!ATy || AI->isArrayAllocation()) - continue; - - // Get the list of all available allocas for this array type. - std::vector<AllocaInst *> &AllocasForType = InlinedArrayAllocas[ATy]; - - // Loop over the allocas in AllocasForType to see if we can reuse one. Note - // that we have to be careful not to reuse the same "available" alloca for - // multiple different allocas that we just inlined, we use the 'UsedAllocas' - // set to keep track of which "available" allocas are being used by this - // function. Also, AllocasForType can be empty of course! - bool MergedAwayAlloca = false; - for (AllocaInst *AvailableAlloca : AllocasForType) { - Align Align1 = AI->getAlign(); - Align Align2 = AvailableAlloca->getAlign(); - - // The available alloca has to be in the right function, not in some other - // function in this SCC. - if (AvailableAlloca->getParent() != AI->getParent()) - continue; - - // If the inlined function already uses this alloca then we can't reuse - // it. - if (!UsedAllocas.insert(AvailableAlloca).second) - continue; - - // Otherwise, we *can* reuse it, RAUW AI into AvailableAlloca and declare - // success! - LLVM_DEBUG(dbgs() << " ***MERGED ALLOCA: " << *AI - << "\n\t\tINTO: " << *AvailableAlloca << '\n'); - - // Move affected dbg.declare calls immediately after the new alloca to - // avoid the situation when a dbg.declare precedes its alloca. - if (auto *L = LocalAsMetadata::getIfExists(AI)) - if (auto *MDV = MetadataAsValue::getIfExists(AI->getContext(), L)) - for (User *U : MDV->users()) - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) - DDI->moveBefore(AvailableAlloca->getNextNode()); - - AI->replaceAllUsesWith(AvailableAlloca); - - if (Align1 > Align2) - AvailableAlloca->setAlignment(AI->getAlign()); - - AI->eraseFromParent(); - MergedAwayAlloca = true; - ++NumMergedAllocas; - IFI.StaticAllocas[AllocaNo] = nullptr; - break; - } - - // If we already nuked the alloca, we're done with it. 
- if (MergedAwayAlloca) - continue; - - // If we were unable to merge away the alloca either because there are no - // allocas of the right type available or because we reused them all - // already, remember that this alloca came from an inlined function and mark - // it used so we don't reuse it for other allocas from this inline - // operation. - AllocasForType.push_back(AI); - UsedAllocas.insert(AI); - } -} - -/// If it is possible to inline the specified call site, -/// do so and update the CallGraph for this operation. -/// -/// This function also does some basic book-keeping to update the IR. The -/// InlinedArrayAllocas map keeps track of any allocas that are already -/// available from other functions inlined into the caller. If we are able to -/// inline this call site we attempt to reuse already available allocas or add -/// any new allocas to the set if not possible. -static InlineResult inlineCallIfPossible( - CallBase &CB, InlineFunctionInfo &IFI, - InlinedArrayAllocasTy &InlinedArrayAllocas, int InlineHistory, - bool InsertLifetime, function_ref<AAResults &(Function &)> &AARGetter, - ImportedFunctionsInliningStatistics &ImportedFunctionsStats) { - Function *Callee = CB.getCalledFunction(); - Function *Caller = CB.getCaller(); - - AAResults &AAR = AARGetter(*Callee); - - // Try to inline the function. Get the list of static allocas that were - // inlined. - InlineResult IR = - InlineFunction(CB, IFI, - /*MergeAttributes=*/true, &AAR, InsertLifetime); - if (!IR.isSuccess()) - return IR; - - if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) - ImportedFunctionsStats.recordInline(*Caller, *Callee); - - if (!DisableInlinedAllocaMerging) - mergeInlinedArrayAllocas(Caller, IFI, InlinedArrayAllocas, InlineHistory); - - return IR; // success -} - /// Return true if the specified inline history ID /// indicates an inline history that includes the specified function. static bool inlineHistoryIncludes( @@ -346,361 +161,6 @@ static bool inlineHistoryIncludes( return false; } -bool LegacyInlinerBase::doInitialization(CallGraph &CG) { - if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) - ImportedFunctionsStats.setModuleInfo(CG.getModule()); - return false; // No changes to CallGraph. -} - -bool LegacyInlinerBase::runOnSCC(CallGraphSCC &SCC) { - if (skipSCC(SCC)) - return false; - return inlineCalls(SCC); -} - -static bool -inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, - std::function<AssumptionCache &(Function &)> GetAssumptionCache, - ProfileSummaryInfo *PSI, - std::function<const TargetLibraryInfo &(Function &)> GetTLI, - bool InsertLifetime, - function_ref<InlineCost(CallBase &CB)> GetInlineCost, - function_ref<AAResults &(Function &)> AARGetter, - ImportedFunctionsInliningStatistics &ImportedFunctionsStats) { - SmallPtrSet<Function *, 8> SCCFunctions; - LLVM_DEBUG(dbgs() << "Inliner visiting SCC:"); - for (CallGraphNode *Node : SCC) { - Function *F = Node->getFunction(); - if (F) - SCCFunctions.insert(F); - LLVM_DEBUG(dbgs() << " " << (F ? F->getName() : "INDIRECTNODE")); - } - - // Scan through and identify all call sites ahead of time so that we only - // inline call sites in the original functions, not call sites that result - // from inlining other functions. - SmallVector<std::pair<CallBase *, int>, 16> CallSites; - - // When inlining a callee produces new call sites, we want to keep track of - // the fact that they were inlined from the callee. This allows us to avoid - // infinite inlining in some obscure cases. 
To represent this, we use an - // index into the InlineHistory vector. - SmallVector<std::pair<Function *, int>, 8> InlineHistory; - - for (CallGraphNode *Node : SCC) { - Function *F = Node->getFunction(); - if (!F || F->isDeclaration()) - continue; - - OptimizationRemarkEmitter ORE(F); - for (BasicBlock &BB : *F) - for (Instruction &I : BB) { - auto *CB = dyn_cast<CallBase>(&I); - // If this isn't a call, or it is a call to an intrinsic, it can - // never be inlined. - if (!CB || isa<IntrinsicInst>(I)) - continue; - - // If this is a direct call to an external function, we can never inline - // it. If it is an indirect call, inlining may resolve it to be a - // direct call, so we keep it. - if (Function *Callee = CB->getCalledFunction()) - if (Callee->isDeclaration()) { - using namespace ore; - - setInlineRemark(*CB, "unavailable definition"); - ORE.emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) - << NV("Callee", Callee) << " will not be inlined into " - << NV("Caller", CB->getCaller()) - << " because its definition is unavailable" - << setIsVerbose(); - }); - continue; - } - - CallSites.push_back(std::make_pair(CB, -1)); - } - } - - LLVM_DEBUG(dbgs() << ": " << CallSites.size() << " call sites.\n"); - - // If there are no calls in this function, exit early. - if (CallSites.empty()) - return false; - - // Now that we have all of the call sites, move the ones to functions in the - // current SCC to the end of the list. - unsigned FirstCallInSCC = CallSites.size(); - for (unsigned I = 0; I < FirstCallInSCC; ++I) - if (Function *F = CallSites[I].first->getCalledFunction()) - if (SCCFunctions.count(F)) - std::swap(CallSites[I--], CallSites[--FirstCallInSCC]); - - InlinedArrayAllocasTy InlinedArrayAllocas; - InlineFunctionInfo InlineInfo(&CG, GetAssumptionCache, PSI); - - // Now that we have all of the call sites, loop over them and inline them if - // it looks profitable to do so. - bool Changed = false; - bool LocalChange; - do { - LocalChange = false; - // Iterate over the outer loop because inlining functions can cause indirect - // calls to become direct calls. - // CallSites may be modified inside so ranged for loop can not be used. - for (unsigned CSi = 0; CSi != CallSites.size(); ++CSi) { - auto &P = CallSites[CSi]; - CallBase &CB = *P.first; - const int InlineHistoryID = P.second; - - Function *Caller = CB.getCaller(); - Function *Callee = CB.getCalledFunction(); - - // We can only inline direct calls to non-declarations. - if (!Callee || Callee->isDeclaration()) - continue; - - bool IsTriviallyDead = isInstructionTriviallyDead(&CB, &GetTLI(*Caller)); - - if (!IsTriviallyDead) { - // If this call site was obtained by inlining another function, verify - // that the include path for the function did not include the callee - // itself. If so, we'd be recursively inlining the same function, - // which would provide the same callsites, which would cause us to - // infinitely inline. - if (InlineHistoryID != -1 && - inlineHistoryIncludes(Callee, InlineHistoryID, InlineHistory)) { - setInlineRemark(CB, "recursive"); - continue; - } - } - - // FIXME for new PM: because of the old PM we currently generate ORE and - // in turn BFI on demand. With the new PM, the ORE dependency should - // just become a regular analysis dependency. - OptimizationRemarkEmitter ORE(Caller); - - auto OIC = shouldInline(CB, GetInlineCost, ORE); - // If the policy determines that we should inline this function, - // delete the call instead. 
- if (!OIC) - continue; - - // If this call site is dead and it is to a readonly function, we should - // just delete the call instead of trying to inline it, regardless of - // size. This happens because IPSCCP propagates the result out of the - // call and then we're left with the dead call. - if (IsTriviallyDead) { - LLVM_DEBUG(dbgs() << " -> Deleting dead call: " << CB << "\n"); - // Update the call graph by deleting the edge from Callee to Caller. - setInlineRemark(CB, "trivially dead"); - CG[Caller]->removeCallEdgeFor(CB); - CB.eraseFromParent(); - ++NumCallsDeleted; - } else { - // Get DebugLoc to report. CB will be invalid after Inliner. - DebugLoc DLoc = CB.getDebugLoc(); - BasicBlock *Block = CB.getParent(); - - // Attempt to inline the function. - using namespace ore; - - InlineResult IR = inlineCallIfPossible( - CB, InlineInfo, InlinedArrayAllocas, InlineHistoryID, - InsertLifetime, AARGetter, ImportedFunctionsStats); - if (!IR.isSuccess()) { - setInlineRemark(CB, std::string(IR.getFailureReason()) + "; " + - inlineCostStr(*OIC)); - ORE.emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, - Block) - << NV("Callee", Callee) << " will not be inlined into " - << NV("Caller", Caller) << ": " - << NV("Reason", IR.getFailureReason()); - }); - continue; - } - ++NumInlined; - - emitInlinedIntoBasedOnCost(ORE, DLoc, Block, *Callee, *Caller, *OIC); - - // If inlining this function gave us any new call sites, throw them - // onto our worklist to process. They are useful inline candidates. - if (!InlineInfo.InlinedCalls.empty()) { - // Create a new inline history entry for this, so that we remember - // that these new callsites came about due to inlining Callee. - int NewHistoryID = InlineHistory.size(); - InlineHistory.push_back(std::make_pair(Callee, InlineHistoryID)); - -#ifndef NDEBUG - // Make sure no dupplicates in the inline candidates. This could - // happen when a callsite is simpilfied to reusing the return value - // of another callsite during function cloning, thus the other - // callsite will be reconsidered here. - DenseSet<CallBase *> DbgCallSites; - for (auto &II : CallSites) - DbgCallSites.insert(II.first); -#endif - - for (Value *Ptr : InlineInfo.InlinedCalls) { -#ifndef NDEBUG - assert(DbgCallSites.count(dyn_cast<CallBase>(Ptr)) == 0); -#endif - CallSites.push_back( - std::make_pair(dyn_cast<CallBase>(Ptr), NewHistoryID)); - } - } - } - - // If we inlined or deleted the last possible call site to the function, - // delete the function body now. - if (Callee && Callee->use_empty() && Callee->hasLocalLinkage() && - // TODO: Can remove if in SCC now. - !SCCFunctions.count(Callee) && - // The function may be apparently dead, but if there are indirect - // callgraph references to the node, we cannot delete it yet, this - // could invalidate the CGSCC iterator. - CG[Callee]->getNumReferences() == 0) { - LLVM_DEBUG(dbgs() << " -> Deleting dead function: " - << Callee->getName() << "\n"); - CallGraphNode *CalleeNode = CG[Callee]; - - // Remove any call graph edges from the callee to its callees. - CalleeNode->removeAllCalledFunctions(); - - // Removing the node for callee from the call graph and delete it. - delete CG.removeFunctionFromModule(CalleeNode); - ++NumDeleted; - } - - // Remove this call site from the list. If possible, use - // swap/pop_back for efficiency, but do not use it if doing so would - // move a call site to a function in this SCC before the - // 'FirstCallInSCC' barrier. 
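The deleted inliner loop around this point recycles CallSites slots with the usual swap-with-back trick when the SCC is singular, then re-examines the same index; in the general case it falls back to an ordinary erase so the FirstCallInSCC ordering stays intact. The constant-time variant in isolation, as a small standalone sketch:

#include <cstddef>
#include <vector>

// O(1) removal when element order does not matter: overwrite slot I with the
// last element and shrink. The caller re-examines index I afterwards, which
// is what the "--CSi" in the removed code accomplished.
void swapAndPop(std::vector<int> &V, std::size_t I) {
  V[I] = V.back();
  V.pop_back();
}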
- if (SCC.isSingular()) { - CallSites[CSi] = CallSites.back(); - CallSites.pop_back(); - } else { - CallSites.erase(CallSites.begin() + CSi); - } - --CSi; - - Changed = true; - LocalChange = true; - } - } while (LocalChange); - - return Changed; -} - -bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) { - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - ACT = &getAnalysis<AssumptionCacheTracker>(); - PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - GetTLI = [&](Function &F) -> const TargetLibraryInfo & { - return getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { - return ACT->getAssumptionCache(F); - }; - return inlineCallsImpl( - SCC, CG, GetAssumptionCache, PSI, GetTLI, InsertLifetime, - [&](CallBase &CB) { return getInlineCost(CB); }, LegacyAARGetter(*this), - ImportedFunctionsStats); -} - -/// Remove now-dead linkonce functions at the end of -/// processing to avoid breaking the SCC traversal. -bool LegacyInlinerBase::doFinalization(CallGraph &CG) { - if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) - ImportedFunctionsStats.dump(InlinerFunctionImportStats == - InlinerFunctionImportStatsOpts::Verbose); - return removeDeadFunctions(CG); -} - -/// Remove dead functions that are not included in DNR (Do Not Remove) list. -bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG, - bool AlwaysInlineOnly) { - SmallVector<CallGraphNode *, 16> FunctionsToRemove; - SmallVector<Function *, 16> DeadFunctionsInComdats; - - auto RemoveCGN = [&](CallGraphNode *CGN) { - // Remove any call graph edges from the function to its callees. - CGN->removeAllCalledFunctions(); - - // Remove any edges from the external node to the function's call graph - // node. These edges might have been made irrelegant due to - // optimization of the program. - CG.getExternalCallingNode()->removeAnyCallEdgeTo(CGN); - - // Removing the node for callee from the call graph and delete it. - FunctionsToRemove.push_back(CGN); - }; - - // Scan for all of the functions, looking for ones that should now be removed - // from the program. Insert the dead ones in the FunctionsToRemove set. - for (const auto &I : CG) { - CallGraphNode *CGN = I.second.get(); - Function *F = CGN->getFunction(); - if (!F || F->isDeclaration()) - continue; - - // Handle the case when this function is called and we only want to care - // about always-inline functions. This is a bit of a hack to share code - // between here and the InlineAlways pass. - if (AlwaysInlineOnly && !F->hasFnAttribute(Attribute::AlwaysInline)) - continue; - - // If the only remaining users of the function are dead constants, remove - // them. - F->removeDeadConstantUsers(); - - if (!F->isDefTriviallyDead()) - continue; - - // It is unsafe to drop a function with discardable linkage from a COMDAT - // without also dropping the other members of the COMDAT. - // The inliner doesn't visit non-function entities which are in COMDAT - // groups so it is unsafe to do so *unless* the linkage is local. - if (!F->hasLocalLinkage()) { - if (F->hasComdat()) { - DeadFunctionsInComdats.push_back(F); - continue; - } - } - - RemoveCGN(CGN); - } - if (!DeadFunctionsInComdats.empty()) { - // Filter out the functions whose comdats remain alive. - filterDeadComdatFunctions(DeadFunctionsInComdats); - // Remove the rest. 
- for (Function *F : DeadFunctionsInComdats) - RemoveCGN(CG[F]); - } - - if (FunctionsToRemove.empty()) - return false; - - // Now that we know which functions to delete, do so. We didn't want to do - // this inline, because that would invalidate our CallGraph::iterator - // objects. :( - // - // Note that it doesn't matter that we are iterating over a non-stable order - // here to do this, it doesn't matter which order the functions are deleted - // in. - array_pod_sort(FunctionsToRemove.begin(), FunctionsToRemove.end()); - FunctionsToRemove.erase( - std::unique(FunctionsToRemove.begin(), FunctionsToRemove.end()), - FunctionsToRemove.end()); - for (CallGraphNode *CGN : FunctionsToRemove) { - delete CG.removeFunctionFromModule(CGN); - ++NumDeleted; - } - return true; -} - InlineAdvisor & InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, FunctionAnalysisManager &FAM, Module &M) { @@ -729,8 +189,7 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, CGSCCInlineReplayFallback, {CGSCCInlineReplayFormat}}, /*EmitRemarks=*/true, - InlineContext{LTOPhase, - InlinePass::ReplayCGSCCInliner}); + InlineContext{LTOPhase, InlinePass::ReplayCGSCCInliner}); return *OwnedAdvisor; } @@ -871,9 +330,12 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (InlineHistoryID != -1 && inlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) { - LLVM_DEBUG(dbgs() << "Skipping inlining due to history: " - << F.getName() << " -> " << Callee.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Skipping inlining due to history: " << F.getName() + << " -> " << Callee.getName() << "\n"); setInlineRemark(*CB, "recursive"); + // Set noinline so that we don't forget this decision across CGSCC + // iterations. + CB->setIsNoInline(); continue; } @@ -911,7 +373,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // Setup the data structure used to plumb customization into the // `InlineFunction` routine. InlineFunctionInfo IFI( - /*cg=*/nullptr, GetAssumptionCache, PSI, + GetAssumptionCache, PSI, &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), &FAM.getResult<BlockFrequencyAnalysis>(Callee)); @@ -1193,13 +655,13 @@ void ModuleInlinerWrapperPass::printPipeline( // on Params and Mode). 
if (!MPM.isEmpty()) { MPM.printPipeline(OS, MapClassName2PassName); - OS << ","; + OS << ','; } OS << "cgscc("; if (MaxDevirtIterations != 0) OS << "devirt<" << MaxDevirtIterations << ">("; PM.printPipeline(OS, MapClassName2PassName); if (MaxDevirtIterations != 0) - OS << ")"; - OS << ")"; + OS << ')'; + OS << ')'; } diff --git a/llvm/lib/Transforms/IPO/Internalize.cpp b/llvm/lib/Transforms/IPO/Internalize.cpp index 85b1a8303d33..0b8fde6489f8 100644 --- a/llvm/lib/Transforms/IPO/Internalize.cpp +++ b/llvm/lib/Transforms/IPO/Internalize.cpp @@ -19,19 +19,18 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/Internalize.h" +#include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSet.h" -#include "llvm/ADT/Triple.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/GlobPattern.h" #include "llvm/Support/LineIterator.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO.h" using namespace llvm; @@ -183,9 +182,8 @@ void InternalizePass::checkComdat( Info.External = true; } -bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { +bool InternalizePass::internalizeModule(Module &M) { bool Changed = false; - CallGraphNode *ExternalNode = CG ? CG->getExternalCallingNode() : nullptr; SmallVector<GlobalValue *, 4> Used; collectUsedGlobalVariables(M, Used, false); @@ -242,10 +240,6 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { continue; Changed = true; - if (ExternalNode) - // Remove a callgraph edge from the external node to this function. - ExternalNode->removeOneAbstractEdgeTo((*CG)[&I]); - ++NumFunctions; LLVM_DEBUG(dbgs() << "Internalizing func " << I.getName() << "\n"); } @@ -277,55 +271,8 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { InternalizePass::InternalizePass() : MustPreserveGV(PreserveAPIList()) {} PreservedAnalyses InternalizePass::run(Module &M, ModuleAnalysisManager &AM) { - if (!internalizeModule(M, AM.getCachedResult<CallGraphAnalysis>(M))) + if (!internalizeModule(M)) return PreservedAnalyses::all(); - PreservedAnalyses PA; - PA.preserve<CallGraphAnalysis>(); - return PA; -} - -namespace { -class InternalizeLegacyPass : public ModulePass { - // Client supplied callback to control wheter a symbol must be preserved. - std::function<bool(const GlobalValue &)> MustPreserveGV; - -public: - static char ID; // Pass identification, replacement for typeid - - InternalizeLegacyPass() : ModulePass(ID), MustPreserveGV(PreserveAPIList()) {} - - InternalizeLegacyPass(std::function<bool(const GlobalValue &)> MustPreserveGV) - : ModulePass(ID), MustPreserveGV(std::move(MustPreserveGV)) { - initializeInternalizeLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - CallGraphWrapperPass *CGPass = - getAnalysisIfAvailable<CallGraphWrapperPass>(); - CallGraph *CG = CGPass ? 
&CGPass->getCallGraph() : nullptr; - return internalizeModule(M, MustPreserveGV, CG); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addPreserved<CallGraphWrapperPass>(); - } -}; -} - -char InternalizeLegacyPass::ID = 0; -INITIALIZE_PASS(InternalizeLegacyPass, "internalize", - "Internalize Global Symbols", false, false) - -ModulePass *llvm::createInternalizePass() { - return new InternalizeLegacyPass(); -} - -ModulePass *llvm::createInternalizePass( - std::function<bool(const GlobalValue &)> MustPreserveGV) { - return new InternalizeLegacyPass(std::move(MustPreserveGV)); + return PreservedAnalyses::none(); } diff --git a/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/llvm/lib/Transforms/IPO/LoopExtractor.cpp index ad1927c09803..9a5876f85ba7 100644 --- a/llvm/lib/Transforms/IPO/LoopExtractor.cpp +++ b/llvm/lib/Transforms/IPO/LoopExtractor.cpp @@ -283,8 +283,8 @@ void LoopExtractorPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (NumLoops == 1) OS << "single"; - OS << ">"; + OS << '>'; } diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index ddfcace6acf8..9b4b3efd7283 100644 --- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -24,7 +24,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TinyPtrVector.h" -#include "llvm/ADT/Triple.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" @@ -51,12 +51,11 @@ #include "llvm/IR/ModuleSummaryIndexYAML.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ReplaceConstant.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -69,6 +68,7 @@ #include "llvm/Support/TrailingObjects.h" #include "llvm/Support/YAMLTraits.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -172,7 +172,7 @@ BitSetInfo BitSetBuilder::build() { BSI.AlignLog2 = 0; if (Mask != 0) - BSI.AlignLog2 = countTrailingZeros(Mask); + BSI.AlignLog2 = llvm::countr_zero(Mask); // Build the compressed bitset while normalizing the offsets against the // computed alignment. @@ -242,7 +242,7 @@ bool lowertypetests::isJumpTableCanonical(Function *F) { return false; auto *CI = mdconst::extract_or_null<ConstantInt>( F->getParent()->getModuleFlag("CFI Canonical Jump Tables")); - if (!CI || CI->getZExtValue() != 0) + if (!CI || !CI->isZero()) return true; return F->hasFnAttribute("cfi-canonical-jump-table"); } @@ -406,6 +406,15 @@ class LowerTypeTestsModule { Triple::OSType OS; Triple::ObjectFormatType ObjectFormat; + // Determines which kind of Thumb jump table we generate. If arch is + // either 'arm' or 'thumb' we need to find this out, because + // selectJumpTableArmEncoding may decide to use Thumb in either case. 
+ bool CanUseArmJumpTable = false, CanUseThumbBWJumpTable = false; + + // The jump table type we ended up deciding on. (Usually the same as + // Arch, except that 'arm' and 'thumb' are often interchangeable.) + Triple::ArchType JumpTableArch = Triple::UnknownArch; + IntegerType *Int1Ty = Type::getInt1Ty(M.getContext()); IntegerType *Int8Ty = Type::getInt8Ty(M.getContext()); PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext()); @@ -481,6 +490,8 @@ class LowerTypeTestsModule { void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals); + Triple::ArchType + selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions); unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS, @@ -518,7 +529,8 @@ class LowerTypeTestsModule { void replaceDirectCalls(Value *Old, Value *New); public: - LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary, + LowerTypeTestsModule(Module &M, ModuleAnalysisManager &AM, + ModuleSummaryIndex *ExportSummary, const ModuleSummaryIndex *ImportSummary, bool DropTypeTests); @@ -526,7 +538,7 @@ public: // Lower the module using the action and summary passed as command line // arguments. For testing purposes only. - static bool runForTesting(Module &M); + static bool runForTesting(Module &M, ModuleAnalysisManager &AM); }; } // end anonymous namespace @@ -686,7 +698,7 @@ static bool isKnownTypeIdMember(Metadata *TypeId, const DataLayout &DL, } if (auto GEP = dyn_cast<GEPOperator>(V)) { - APInt APOffset(DL.getPointerSizeInBits(0), 0); + APInt APOffset(DL.getIndexSizeInBits(0), 0); bool Result = GEP->accumulateConstantOffset(DL, APOffset); if (!Result) return false; @@ -1182,31 +1194,36 @@ static const unsigned kX86JumpTableEntrySize = 8; static const unsigned kX86IBTJumpTableEntrySize = 16; static const unsigned kARMJumpTableEntrySize = 4; static const unsigned kARMBTIJumpTableEntrySize = 8; +static const unsigned kARMv6MJumpTableEntrySize = 16; static const unsigned kRISCVJumpTableEntrySize = 8; unsigned LowerTypeTestsModule::getJumpTableEntrySize() { - switch (Arch) { - case Triple::x86: - case Triple::x86_64: - if (const auto *MD = mdconst::extract_or_null<ConstantInt>( + switch (JumpTableArch) { + case Triple::x86: + case Triple::x86_64: + if (const auto *MD = mdconst::extract_or_null<ConstantInt>( M.getModuleFlag("cf-protection-branch"))) - if (MD->getZExtValue()) - return kX86IBTJumpTableEntrySize; - return kX86JumpTableEntrySize; - case Triple::arm: - case Triple::thumb: + if (MD->getZExtValue()) + return kX86IBTJumpTableEntrySize; + return kX86JumpTableEntrySize; + case Triple::arm: + return kARMJumpTableEntrySize; + case Triple::thumb: + if (CanUseThumbBWJumpTable) return kARMJumpTableEntrySize; - case Triple::aarch64: - if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( + else + return kARMv6MJumpTableEntrySize; + case Triple::aarch64: + if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( M.getModuleFlag("branch-target-enforcement"))) - if (BTE->getZExtValue()) - return kARMBTIJumpTableEntrySize; - return kARMJumpTableEntrySize; - case Triple::riscv32: - case Triple::riscv64: - return kRISCVJumpTableEntrySize; - default: - report_fatal_error("Unsupported architecture for jump tables"); + if (BTE->getZExtValue()) + return kARMBTIJumpTableEntrySize; + return kARMJumpTableEntrySize; + case Triple::riscv32: + case Triple::riscv64: + return kRISCVJumpTableEntrySize; + default: + 
report_fatal_error("Unsupported architecture for jump tables"); } } @@ -1223,7 +1240,7 @@ void LowerTypeTestsModule::createJumpTableEntry( bool Endbr = false; if (const auto *MD = mdconst::extract_or_null<ConstantInt>( Dest->getParent()->getModuleFlag("cf-protection-branch"))) - Endbr = MD->getZExtValue() != 0; + Endbr = !MD->isZero(); if (Endbr) AsmOS << (JumpTableArch == Triple::x86 ? "endbr32\n" : "endbr64\n"); AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n"; @@ -1240,7 +1257,32 @@ void LowerTypeTestsModule::createJumpTableEntry( AsmOS << "bti c\n"; AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::thumb) { - AsmOS << "b.w $" << ArgIndex << "\n"; + if (!CanUseThumbBWJumpTable) { + // In Armv6-M, this sequence will generate a branch without corrupting + // any registers. We use two stack words; in the second, we construct the + // address we'll pop into pc, and the first is used to save and restore + // r0 which we use as a temporary register. + // + // To support position-independent use cases, the offset of the target + // function is stored as a relative offset (which will expand into an + // R_ARM_REL32 relocation in ELF, and presumably the equivalent in other + // object file types), and added to pc after we load it. (The alternative + // B.W is automatically pc-relative.) + // + // There are five 16-bit Thumb instructions here, so the .balign 4 adds a + // sixth halfword of padding, and then the offset consumes a further 4 + // bytes, for a total of 16, which is very convenient since entries in + // this jump table need to have power-of-two size. + AsmOS << "push {r0,r1}\n" + << "ldr r0, 1f\n" + << "0: add r0, r0, pc\n" + << "str r0, [sp, #4]\n" + << "pop {r0,pc}\n" + << ".balign 4\n" + << "1: .word $" << ArgIndex << " - (0b + 4)\n"; + } else { + AsmOS << "b.w $" << ArgIndex << "\n"; + } } else if (JumpTableArch == Triple::riscv32 || JumpTableArch == Triple::riscv64) { AsmOS << "tail $" << ArgIndex << "@plt\n"; @@ -1325,11 +1367,27 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( F->getAddressSpace(), "", &M); replaceCfiUses(F, PlaceholderFn, IsJumpTableCanonical); - Constant *Target = ConstantExpr::getSelect( - ConstantExpr::getICmp(CmpInst::ICMP_NE, F, - Constant::getNullValue(F->getType())), - JT, Constant::getNullValue(F->getType())); - PlaceholderFn->replaceAllUsesWith(Target); + convertUsersOfConstantsToInstructions(PlaceholderFn); + // Don't use range based loop, because use list will be modified. + while (!PlaceholderFn->use_empty()) { + Use &U = *PlaceholderFn->use_begin(); + auto *InsertPt = dyn_cast<Instruction>(U.getUser()); + assert(InsertPt && "Non-instruction users should have been eliminated"); + auto *PN = dyn_cast<PHINode>(InsertPt); + if (PN) + InsertPt = PN->getIncomingBlock(U)->getTerminator(); + IRBuilder Builder(InsertPt); + Value *ICmp = Builder.CreateICmp(CmpInst::ICMP_NE, F, + Constant::getNullValue(F->getType())); + Value *Select = Builder.CreateSelect(ICmp, JT, + Constant::getNullValue(F->getType())); + // For phi nodes, we need to update the incoming value for all operands + // with the same predecessor. + if (PN) + PN->setIncomingValueForBlock(InsertPt->getParent(), Select); + else + U.set(Select); + } PlaceholderFn->eraseFromParent(); } @@ -1352,12 +1410,19 @@ static bool isThumbFunction(Function *F, Triple::ArchType ModuleArch) { // Each jump table must be either ARM or Thumb as a whole for the bit-test math // to work. 
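The Armv6-M entry added above stores its branch target as a pc-relative word ("1: .word $target - (0b + 4)") and rebuilds the absolute address at run time with "add r0, r0, pc". Below is a minimal sketch of that arithmetic only, not part of this commit; the concrete addresses (Label0, Target) are invented for illustration.

#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical addresses: Label0 is the "0: add r0, r0, pc" instruction,
  // Target is the jump table callee.
  const uint64_t Label0 = 0x00010000;
  const uint64_t Target = 0x00024680;

  // The assembler stores: target - (0b + 4).
  const int64_t StoredWord =
      static_cast<int64_t>(Target) - static_cast<int64_t>(Label0 + 4);

  // In Thumb state, reading pc in "add r0, r0, pc" yields the address of the
  // add instruction plus 4, so the addition reconstructs the absolute target.
  const uint64_t PcAsRead = Label0 + 4;
  assert(PcAsRead + StoredWord == Target);
  return 0;
}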
Pick one that matches the majority of members to minimize interop // veneers inserted by the linker. -static Triple::ArchType -selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions, - Triple::ArchType ModuleArch) { - if (ModuleArch != Triple::arm && ModuleArch != Triple::thumb) - return ModuleArch; +Triple::ArchType LowerTypeTestsModule::selectJumpTableArmEncoding( + ArrayRef<GlobalTypeMember *> Functions) { + if (Arch != Triple::arm && Arch != Triple::thumb) + return Arch; + + if (!CanUseThumbBWJumpTable && CanUseArmJumpTable) { + // In architectures that provide Arm and Thumb-1 but not Thumb-2, + // we should always prefer the Arm jump table format, because the + // Thumb-1 one is larger and slower. + return Triple::arm; + } + // Otherwise, go with majority vote. unsigned ArmCount = 0, ThumbCount = 0; for (const auto GTM : Functions) { if (!GTM->isJumpTableCanonical()) { @@ -1368,7 +1433,7 @@ selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions, } Function *F = cast<Function>(GTM->getGlobal()); - ++(isThumbFunction(F, ModuleArch) ? ThumbCount : ArmCount); + ++(isThumbFunction(F, Arch) ? ThumbCount : ArmCount); } return ArmCount > ThumbCount ? Triple::arm : Triple::thumb; @@ -1381,8 +1446,6 @@ void LowerTypeTestsModule::createJumpTable( SmallVector<Value *, 16> AsmArgs; AsmArgs.reserve(Functions.size() * 2); - Triple::ArchType JumpTableArch = selectJumpTableArmEncoding(Functions, Arch); - for (GlobalTypeMember *GTM : Functions) createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs, cast<Function>(GTM->getGlobal())); @@ -1399,9 +1462,11 @@ void LowerTypeTestsModule::createJumpTable( F->addFnAttr("target-features", "-thumb-mode"); if (JumpTableArch == Triple::thumb) { F->addFnAttr("target-features", "+thumb-mode"); - // Thumb jump table assembly needs Thumb2. The following attribute is added - // by Clang for -march=armv7. - F->addFnAttr("target-cpu", "cortex-a8"); + if (CanUseThumbBWJumpTable) { + // Thumb jump table assembly needs Thumb2. The following attribute is + // added by Clang for -march=armv7. + F->addFnAttr("target-cpu", "cortex-a8"); + } } // When -mbranch-protection= is used, the inline asm adds a BTI. Suppress BTI // for the function to avoid double BTI. This is a no-op without @@ -1521,6 +1586,10 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( // FIXME: find a better way to represent the jumptable in the IR. assert(!Functions.empty()); + // Decide on the jump table encoding, so that we know how big the + // entries will be. + JumpTableArch = selectJumpTableArmEncoding(Functions); + // Build a simple layout based on the regular layout of jump tables. DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout; unsigned EntrySize = getJumpTableEntrySize(); @@ -1706,18 +1775,31 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet( /// Lower all type tests in this module. 
LowerTypeTestsModule::LowerTypeTestsModule( - Module &M, ModuleSummaryIndex *ExportSummary, + Module &M, ModuleAnalysisManager &AM, ModuleSummaryIndex *ExportSummary, const ModuleSummaryIndex *ImportSummary, bool DropTypeTests) : M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary), DropTypeTests(DropTypeTests || ClDropTypeTests) { assert(!(ExportSummary && ImportSummary)); Triple TargetTriple(M.getTargetTriple()); Arch = TargetTriple.getArch(); + if (Arch == Triple::arm) + CanUseArmJumpTable = true; + if (Arch == Triple::arm || Arch == Triple::thumb) { + auto &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + for (Function &F : M) { + auto &TTI = FAM.getResult<TargetIRAnalysis>(F); + if (TTI.hasArmWideBranch(false)) + CanUseArmJumpTable = true; + if (TTI.hasArmWideBranch(true)) + CanUseThumbBWJumpTable = true; + } + } OS = TargetTriple.getOS(); ObjectFormat = TargetTriple.getObjectFormat(); } -bool LowerTypeTestsModule::runForTesting(Module &M) { +bool LowerTypeTestsModule::runForTesting(Module &M, ModuleAnalysisManager &AM) { ModuleSummaryIndex Summary(/*HaveGVs=*/false); // Handle the command-line summary arguments. This code is for testing @@ -1735,7 +1817,8 @@ bool LowerTypeTestsModule::runForTesting(Module &M) { bool Changed = LowerTypeTestsModule( - M, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, + M, AM, + ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr, /*DropTypeTests*/ false) .lower(); @@ -2186,9 +2269,9 @@ bool LowerTypeTestsModule::lower() { unsigned MaxUniqueId = 0; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I); MI != GlobalClasses.member_end(); ++MI) { - if (auto *MD = MI->dyn_cast<Metadata *>()) + if (auto *MD = dyn_cast_if_present<Metadata *>(*MI)) MaxUniqueId = std::max(MaxUniqueId, TypeIdInfo[MD].UniqueId); - else if (auto *BF = MI->dyn_cast<ICallBranchFunnel *>()) + else if (auto *BF = dyn_cast_if_present<ICallBranchFunnel *>(*MI)) MaxUniqueId = std::max(MaxUniqueId, BF->UniqueId); } Sets.emplace_back(I, MaxUniqueId); @@ -2204,12 +2287,12 @@ bool LowerTypeTestsModule::lower() { for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(S.first); MI != GlobalClasses.member_end(); ++MI) { - if (MI->is<Metadata *>()) - TypeIds.push_back(MI->get<Metadata *>()); - else if (MI->is<GlobalTypeMember *>()) - Globals.push_back(MI->get<GlobalTypeMember *>()); + if (isa<Metadata *>(*MI)) + TypeIds.push_back(cast<Metadata *>(*MI)); + else if (isa<GlobalTypeMember *>(*MI)) + Globals.push_back(cast<GlobalTypeMember *>(*MI)); else - ICallBranchFunnels.push_back(MI->get<ICallBranchFunnel *>()); + ICallBranchFunnels.push_back(cast<ICallBranchFunnel *>(*MI)); } // Order type identifiers by unique ID for determinism. 
This ordering is
@@ -2298,10 +2381,10 @@ PreservedAnalyses LowerTypeTestsPass::run(Module &M,
ModuleAnalysisManager &AM) {
bool Changed;
if (UseCommandLine)
- Changed = LowerTypeTestsModule::runForTesting(M);
+ Changed = LowerTypeTestsModule::runForTesting(M, AM);
else
Changed =
- LowerTypeTestsModule(M, ExportSummary, ImportSummary, DropTypeTests)
+ LowerTypeTestsModule(M, AM, ExportSummary, ImportSummary, DropTypeTests)
.lower();
if (!Changed)
return PreservedAnalyses::all();
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
new file mode 100644
index 000000000000..f835fb26fcb8
--- /dev/null
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -0,0 +1,3277 @@
+//==-- MemProfContextDisambiguation.cpp - Disambiguate contexts -------------=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements support for context disambiguation of allocation
+// calls for profile guided heap optimization. Specifically, it uses Memprof
+// profiles which indicate context specific allocation behavior (currently
+// distinguishing cold vs hot memory allocations). Cloning is performed to
+// expose the cold allocation call contexts, and the allocation calls are
+// subsequently annotated with an attribute for later transformation.
+//
+// The transformations can be performed either directly on IR (regular LTO), or
+// on a ThinLTO index (and later applied to the IR during the ThinLTO backend).
+// Both types of LTO operate on the same base graph representation, which
+// uses CRTP to support either IR or Index formats. 
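The header comment above notes that the regular-LTO (IR) and ThinLTO (index) flavors share one graph implementation via CRTP. Below is a minimal sketch of that static-dispatch pattern; the names (GraphBase, IRGraph, stackId) are invented for illustration, and this is not the pass's actual class layout.

#include <cstdint>
#include <iostream>

// The base template calls a hook on the derived type through static_cast, so
// two concrete graphs can share one algorithm without virtual dispatch. This
// mirrors how CallsiteContextGraph forwards calls such as getStackId().
template <typename Derived> struct GraphBase {
  uint64_t stackId(uint64_t IdOrIndex) const {
    return static_cast<const Derived *>(this)->getStackId(IdOrIndex);
  }
};

struct IRGraph : GraphBase<IRGraph> {
  // For an IR-based graph, the id is assumed to already be the full stack id.
  uint64_t getStackId(uint64_t IdOrIndex) const { return IdOrIndex; }
};

int main() {
  IRGraph G;
  std::cout << G.stackId(42) << "\n"; // prints 42
  return 0;
}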
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/MemProfContextDisambiguation.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/MemoryProfileInfo.h" +#include "llvm/Analysis/ModuleSummaryAnalysis.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndex.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include <sstream> +#include <vector> +using namespace llvm; +using namespace llvm::memprof; + +#define DEBUG_TYPE "memprof-context-disambiguation" + +STATISTIC(FunctionClonesAnalysis, + "Number of function clones created during whole program analysis"); +STATISTIC(FunctionClonesThinBackend, + "Number of function clones created during ThinLTO backend"); +STATISTIC(FunctionsClonedThinBackend, + "Number of functions that had clones created during ThinLTO backend"); +STATISTIC(AllocTypeNotCold, "Number of not cold static allocations (possibly " + "cloned) during whole program analysis"); +STATISTIC(AllocTypeCold, "Number of cold static allocations (possibly cloned) " + "during whole program analysis"); +STATISTIC(AllocTypeNotColdThinBackend, + "Number of not cold static allocations (possibly cloned) during " + "ThinLTO backend"); +STATISTIC(AllocTypeColdThinBackend, "Number of cold static allocations " + "(possibly cloned) during ThinLTO backend"); +STATISTIC(OrigAllocsThinBackend, + "Number of original (not cloned) allocations with memprof profiles " + "during ThinLTO backend"); +STATISTIC( + AllocVersionsThinBackend, + "Number of allocation versions (including clones) during ThinLTO backend"); +STATISTIC(MaxAllocVersionsThinBackend, + "Maximum number of allocation versions created for an original " + "allocation during ThinLTO backend"); +STATISTIC(UnclonableAllocsThinBackend, + "Number of unclonable ambigous allocations during ThinLTO backend"); + +static cl::opt<std::string> DotFilePathPrefix( + "memprof-dot-file-path-prefix", cl::init(""), cl::Hidden, + cl::value_desc("filename"), + cl::desc("Specify the path prefix of the MemProf dot files.")); + +static cl::opt<bool> ExportToDot("memprof-export-to-dot", cl::init(false), + cl::Hidden, + cl::desc("Export graph to dot files.")); + +static cl::opt<bool> + DumpCCG("memprof-dump-ccg", cl::init(false), cl::Hidden, + cl::desc("Dump CallingContextGraph to stdout after each stage.")); + +static cl::opt<bool> + VerifyCCG("memprof-verify-ccg", cl::init(false), cl::Hidden, + cl::desc("Perform verification checks on CallingContextGraph.")); + +static cl::opt<bool> + VerifyNodes("memprof-verify-nodes", cl::init(false), cl::Hidden, + cl::desc("Perform frequent verification checks on nodes.")); + +static cl::opt<std::string> MemProfImportSummary( + "memprof-import-summary", + cl::desc("Import summary to use for testing the ThinLTO backend via opt"), + cl::Hidden); + +// Indicate we are linking with an allocator that supports hot/cold operator +// new interfaces. 
+cl::opt<bool> SupportsHotColdNew( + "supports-hot-cold-new", cl::init(false), cl::Hidden, + cl::desc("Linking with hot/cold operator new interfaces")); + +namespace { +/// CRTP base for graphs built from either IR or ThinLTO summary index. +/// +/// The graph represents the call contexts in all memprof metadata on allocation +/// calls, with nodes for the allocations themselves, as well as for the calls +/// in each context. The graph is initially built from the allocation memprof +/// metadata (or summary) MIBs. It is then updated to match calls with callsite +/// metadata onto the nodes, updating it to reflect any inlining performed on +/// those calls. +/// +/// Each MIB (representing an allocation's call context with allocation +/// behavior) is assigned a unique context id during the graph build. The edges +/// and nodes in the graph are decorated with the context ids they carry. This +/// is used to correctly update the graph when cloning is performed so that we +/// can uniquify the context for a single (possibly cloned) allocation. +template <typename DerivedCCG, typename FuncTy, typename CallTy> +class CallsiteContextGraph { +public: + CallsiteContextGraph() = default; + CallsiteContextGraph(const CallsiteContextGraph &) = default; + CallsiteContextGraph(CallsiteContextGraph &&) = default; + + /// Main entry point to perform analysis and transformations on graph. + bool process(); + + /// Perform cloning on the graph necessary to uniquely identify the allocation + /// behavior of an allocation based on its context. + void identifyClones(); + + /// Assign callsite clones to functions, cloning functions as needed to + /// accommodate the combinations of their callsite clones reached by callers. + /// For regular LTO this clones functions and callsites in the IR, but for + /// ThinLTO the cloning decisions are noted in the summaries and later applied + /// in applyImport. + bool assignFunctions(); + + void dump() const; + void print(raw_ostream &OS) const; + + friend raw_ostream &operator<<(raw_ostream &OS, + const CallsiteContextGraph &CCG) { + CCG.print(OS); + return OS; + } + + friend struct GraphTraits< + const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *>; + friend struct DOTGraphTraits< + const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *>; + + void exportToDot(std::string Label) const; + + /// Represents a function clone via FuncTy pointer and clone number pair. + struct FuncInfo final + : public std::pair<FuncTy *, unsigned /*Clone number*/> { + using Base = std::pair<FuncTy *, unsigned>; + FuncInfo(const Base &B) : Base(B) {} + FuncInfo(FuncTy *F = nullptr, unsigned CloneNo = 0) : Base(F, CloneNo) {} + explicit operator bool() const { return this->first != nullptr; } + FuncTy *func() const { return this->first; } + unsigned cloneNo() const { return this->second; } + }; + + /// Represents a callsite clone via CallTy and clone number pair. 
+ struct CallInfo final : public std::pair<CallTy, unsigned /*Clone number*/> { + using Base = std::pair<CallTy, unsigned>; + CallInfo(const Base &B) : Base(B) {} + CallInfo(CallTy Call = nullptr, unsigned CloneNo = 0) + : Base(Call, CloneNo) {} + explicit operator bool() const { return (bool)this->first; } + CallTy call() const { return this->first; } + unsigned cloneNo() const { return this->second; } + void setCloneNo(unsigned N) { this->second = N; } + void print(raw_ostream &OS) const { + if (!operator bool()) { + assert(!cloneNo()); + OS << "null Call"; + return; + } + call()->print(OS); + OS << "\t(clone " << cloneNo() << ")"; + } + void dump() const { + print(dbgs()); + dbgs() << "\n"; + } + friend raw_ostream &operator<<(raw_ostream &OS, const CallInfo &Call) { + Call.print(OS); + return OS; + } + }; + + struct ContextEdge; + + /// Node in the Callsite Context Graph + struct ContextNode { + // Keep this for now since in the IR case where we have an Instruction* it + // is not as immediately discoverable. Used for printing richer information + // when dumping graph. + bool IsAllocation; + + // Keeps track of when the Call was reset to null because there was + // recursion. + bool Recursive = false; + + // The corresponding allocation or interior call. + CallInfo Call; + + // For alloc nodes this is a unique id assigned when constructed, and for + // callsite stack nodes it is the original stack id when the node is + // constructed from the memprof MIB metadata on the alloc nodes. Note that + // this is only used when matching callsite metadata onto the stack nodes + // created when processing the allocation memprof MIBs, and for labeling + // nodes in the dot graph. Therefore we don't bother to assign a value for + // clones. + uint64_t OrigStackOrAllocId = 0; + + // This will be formed by ORing together the AllocationType enum values + // for contexts including this node. + uint8_t AllocTypes = 0; + + // Edges to all callees in the profiled call stacks. + // TODO: Should this be a map (from Callee node) for more efficient lookup? + std::vector<std::shared_ptr<ContextEdge>> CalleeEdges; + + // Edges to all callers in the profiled call stacks. + // TODO: Should this be a map (from Caller node) for more efficient lookup? + std::vector<std::shared_ptr<ContextEdge>> CallerEdges; + + // The set of IDs for contexts including this node. + DenseSet<uint32_t> ContextIds; + + // List of clones of this ContextNode, initially empty. + std::vector<ContextNode *> Clones; + + // If a clone, points to the original uncloned node. 
+ ContextNode *CloneOf = nullptr; + + ContextNode(bool IsAllocation) : IsAllocation(IsAllocation), Call() {} + + ContextNode(bool IsAllocation, CallInfo C) + : IsAllocation(IsAllocation), Call(C) {} + + void addClone(ContextNode *Clone) { + if (CloneOf) { + CloneOf->Clones.push_back(Clone); + Clone->CloneOf = CloneOf; + } else { + Clones.push_back(Clone); + assert(!Clone->CloneOf); + Clone->CloneOf = this; + } + } + + ContextNode *getOrigNode() { + if (!CloneOf) + return this; + return CloneOf; + } + + void addOrUpdateCallerEdge(ContextNode *Caller, AllocationType AllocType, + unsigned int ContextId); + + ContextEdge *findEdgeFromCallee(const ContextNode *Callee); + ContextEdge *findEdgeFromCaller(const ContextNode *Caller); + void eraseCalleeEdge(const ContextEdge *Edge); + void eraseCallerEdge(const ContextEdge *Edge); + + void setCall(CallInfo C) { Call = C; } + + bool hasCall() const { return (bool)Call.call(); } + + void printCall(raw_ostream &OS) const { Call.print(OS); } + + // True if this node was effectively removed from the graph, in which case + // its context id set, caller edges, and callee edges should all be empty. + bool isRemoved() const { + assert(ContextIds.empty() == + (CalleeEdges.empty() && CallerEdges.empty())); + return ContextIds.empty(); + } + + void dump() const; + void print(raw_ostream &OS) const; + + friend raw_ostream &operator<<(raw_ostream &OS, const ContextNode &Node) { + Node.print(OS); + return OS; + } + }; + + /// Edge in the Callsite Context Graph from a ContextNode N to a caller or + /// callee. + struct ContextEdge { + ContextNode *Callee; + ContextNode *Caller; + + // This will be formed by ORing together the AllocationType enum values + // for contexts including this edge. + uint8_t AllocTypes = 0; + + // The set of IDs for contexts including this edge. + DenseSet<uint32_t> ContextIds; + + ContextEdge(ContextNode *Callee, ContextNode *Caller, uint8_t AllocType, + DenseSet<uint32_t> ContextIds) + : Callee(Callee), Caller(Caller), AllocTypes(AllocType), + ContextIds(ContextIds) {} + + DenseSet<uint32_t> &getContextIds() { return ContextIds; } + + void dump() const; + void print(raw_ostream &OS) const; + + friend raw_ostream &operator<<(raw_ostream &OS, const ContextEdge &Edge) { + Edge.print(OS); + return OS; + } + }; + + /// Helper to remove callee edges that have allocation type None (due to not + /// carrying any context ids) after transformations. + void removeNoneTypeCalleeEdges(ContextNode *Node); + +protected: + /// Get a list of nodes corresponding to the stack ids in the given callsite + /// context. + template <class NodeT, class IteratorT> + std::vector<uint64_t> + getStackIdsWithContextNodes(CallStack<NodeT, IteratorT> &CallsiteContext); + + /// Adds nodes for the given allocation and any stack ids on its memprof MIB + /// metadata (or summary). + ContextNode *addAllocNode(CallInfo Call, const FuncTy *F); + + /// Adds nodes for the given MIB stack ids. + template <class NodeT, class IteratorT> + void addStackNodesForMIB(ContextNode *AllocNode, + CallStack<NodeT, IteratorT> &StackContext, + CallStack<NodeT, IteratorT> &CallsiteContext, + AllocationType AllocType); + + /// Matches all callsite metadata (or summary) to the nodes created for + /// allocation memprof MIB metadata, synthesizing new nodes to reflect any + /// inlining performed on those callsite instructions. + void updateStackNodes(); + + /// Update graph to conservatively handle any callsite stack nodes that target + /// multiple different callee target functions. 
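The addClone/getOrigNode logic above keeps the clone list flat: every clone records the original node in CloneOf, and the original owns the list of all clones, even when addClone is invoked on a node that is itself a clone. A small standalone sketch of that invariant, using a simplified Node type rather than the real ContextNode:

#include <cassert>
#include <vector>

struct Node {
  Node *CloneOf = nullptr;
  std::vector<Node *> Clones;

  void addClone(Node *Clone) {
    if (CloneOf) {
      // Re-root onto the original so clone lists never nest.
      CloneOf->Clones.push_back(Clone);
      Clone->CloneOf = CloneOf;
    } else {
      Clones.push_back(Clone);
      Clone->CloneOf = this;
    }
  }
};

int main() {
  Node Orig, C1, C2;
  Orig.addClone(&C1);
  C1.addClone(&C2);                // added through a clone...
  assert(C2.CloneOf == &Orig);     // ...but still points at the original
  assert(Orig.Clones.size() == 2); // and the original tracks both clones
  assert(C1.Clones.empty());       // clones never own other clones
  return 0;
}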
+ void handleCallsitesWithMultipleTargets(); + + /// Save lists of calls with MemProf metadata in each function, for faster + /// iteration. + std::vector<std::pair<FuncTy *, std::vector<CallInfo>>> + FuncToCallsWithMetadata; + + /// Map from callsite node to the enclosing caller function. + std::map<const ContextNode *, const FuncTy *> NodeToCallingFunc; + +private: + using EdgeIter = typename std::vector<std::shared_ptr<ContextEdge>>::iterator; + + using CallContextInfo = std::tuple<CallTy, std::vector<uint64_t>, + const FuncTy *, DenseSet<uint32_t>>; + + /// Assigns the given Node to calls at or inlined into the location with + /// the Node's stack id, after post order traversing and processing its + /// caller nodes. Uses the call information recorded in the given + /// StackIdToMatchingCalls map, and creates new nodes for inlined sequences + /// as needed. Called by updateStackNodes which sets up the given + /// StackIdToMatchingCalls map. + void assignStackNodesPostOrder( + ContextNode *Node, DenseSet<const ContextNode *> &Visited, + DenseMap<uint64_t, std::vector<CallContextInfo>> &StackIdToMatchingCalls); + + /// Duplicates the given set of context ids, updating the provided + /// map from each original id with the newly generated context ids, + /// and returning the new duplicated id set. + DenseSet<uint32_t> duplicateContextIds( + const DenseSet<uint32_t> &StackSequenceContextIds, + DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds); + + /// Propagates all duplicated context ids across the graph. + void propagateDuplicateContextIds( + const DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds); + + /// Connect the NewNode to OrigNode's callees if TowardsCallee is true, + /// else to its callers. Also updates OrigNode's edges to remove any context + /// ids moved to the newly created edge. + void connectNewNode(ContextNode *NewNode, ContextNode *OrigNode, + bool TowardsCallee); + + /// Get the stack id corresponding to the given Id or Index (for IR this will + /// return itself, for a summary index this will return the id recorded in the + /// index for that stack id index value). + uint64_t getStackId(uint64_t IdOrIndex) const { + return static_cast<const DerivedCCG *>(this)->getStackId(IdOrIndex); + } + + /// Returns true if the given call targets the given function. + bool calleeMatchesFunc(CallTy Call, const FuncTy *Func) { + return static_cast<DerivedCCG *>(this)->calleeMatchesFunc(Call, Func); + } + + /// Get a list of nodes corresponding to the stack ids in the given + /// callsite's context. + std::vector<uint64_t> getStackIdsWithContextNodesForCall(CallTy Call) { + return static_cast<DerivedCCG *>(this)->getStackIdsWithContextNodesForCall( + Call); + } + + /// Get the last stack id in the context for callsite. + uint64_t getLastStackId(CallTy Call) { + return static_cast<DerivedCCG *>(this)->getLastStackId(Call); + } + + /// Update the allocation call to record type of allocated memory. + void updateAllocationCall(CallInfo &Call, AllocationType AllocType) { + AllocType == AllocationType::Cold ? AllocTypeCold++ : AllocTypeNotCold++; + static_cast<DerivedCCG *>(this)->updateAllocationCall(Call, AllocType); + } + + /// Update non-allocation call to invoke (possibly cloned) function + /// CalleeFunc. 
+ void updateCall(CallInfo &CallerCall, FuncInfo CalleeFunc) { + static_cast<DerivedCCG *>(this)->updateCall(CallerCall, CalleeFunc); + } + + /// Clone the given function for the given callsite, recording mapping of all + /// of the functions tracked calls to their new versions in the CallMap. + /// Assigns new clones to clone number CloneNo. + FuncInfo cloneFunctionForCallsite( + FuncInfo &Func, CallInfo &Call, std::map<CallInfo, CallInfo> &CallMap, + std::vector<CallInfo> &CallsWithMetadataInFunc, unsigned CloneNo) { + return static_cast<DerivedCCG *>(this)->cloneFunctionForCallsite( + Func, Call, CallMap, CallsWithMetadataInFunc, CloneNo); + } + + /// Gets a label to use in the dot graph for the given call clone in the given + /// function. + std::string getLabel(const FuncTy *Func, const CallTy Call, + unsigned CloneNo) const { + return static_cast<const DerivedCCG *>(this)->getLabel(Func, Call, CloneNo); + } + + /// Helpers to find the node corresponding to the given call or stackid. + ContextNode *getNodeForInst(const CallInfo &C); + ContextNode *getNodeForAlloc(const CallInfo &C); + ContextNode *getNodeForStackId(uint64_t StackId); + + /// Removes the node information recorded for the given call. + void unsetNodeForInst(const CallInfo &C); + + /// Computes the alloc type corresponding to the given context ids, by + /// unioning their recorded alloc types. + uint8_t computeAllocType(DenseSet<uint32_t> &ContextIds); + + /// Returns the alloction type of the intersection of the contexts of two + /// nodes (based on their provided context id sets), optimized for the case + /// when Node1Ids is smaller than Node2Ids. + uint8_t intersectAllocTypesImpl(const DenseSet<uint32_t> &Node1Ids, + const DenseSet<uint32_t> &Node2Ids); + + /// Returns the alloction type of the intersection of the contexts of two + /// nodes (based on their provided context id sets). + uint8_t intersectAllocTypes(const DenseSet<uint32_t> &Node1Ids, + const DenseSet<uint32_t> &Node2Ids); + + /// Create a clone of Edge's callee and move Edge to that new callee node, + /// performing the necessary context id and allocation type updates. + /// If callee's caller edge iterator is supplied, it is updated when removing + /// the edge from that list. + ContextNode * + moveEdgeToNewCalleeClone(const std::shared_ptr<ContextEdge> &Edge, + EdgeIter *CallerEdgeI = nullptr); + + /// Change the callee of Edge to existing callee clone NewCallee, performing + /// the necessary context id and allocation type updates. + /// If callee's caller edge iterator is supplied, it is updated when removing + /// the edge from that list. + void moveEdgeToExistingCalleeClone(const std::shared_ptr<ContextEdge> &Edge, + ContextNode *NewCallee, + EdgeIter *CallerEdgeI = nullptr, + bool NewClone = false); + + /// Recursively perform cloning on the graph for the given Node and its + /// callers, in order to uniquely identify the allocation behavior of an + /// allocation given its context. + void identifyClones(ContextNode *Node, + DenseSet<const ContextNode *> &Visited); + + /// Map from each context ID to the AllocationType assigned to that context. + std::map<uint32_t, AllocationType> ContextIdToAllocationType; + + /// Identifies the context node created for a stack id when adding the MIB + /// contexts to the graph. This is used to locate the context nodes when + /// trying to assign the corresponding callsites with those stack ids to these + /// nodes. 
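Many of the helpers declared here operate on allocation types encoded as ORed bit flags, looked up per context id. The sketch below distills that encoding into a runnable form; the enum, map, and function are simplified stand-ins (the real AllocationType also has a Hot value), not the pass's definitions. The early exit once both NotCold and Cold are set matches the comment that no further refinement is possible.

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>

enum class AllocationType : uint8_t { None = 0, NotCold = 1, Cold = 2 };

uint8_t computeAllocType(
    const std::unordered_set<uint32_t> &ContextIds,
    const std::unordered_map<uint32_t, AllocationType> &TypeOfContext) {
  const uint8_t BothTypes =
      (uint8_t)AllocationType::NotCold | (uint8_t)AllocationType::Cold;
  uint8_t AllocType = (uint8_t)AllocationType::None;
  for (uint32_t Id : ContextIds) {
    AllocType |= (uint8_t)TypeOfContext.at(Id);
    if (AllocType == BothTypes) // both bits set: no further refinement
      return AllocType;
  }
  return AllocType;
}

int main() {
  std::unordered_map<uint32_t, AllocationType> TypeOfContext = {
      {1, AllocationType::Cold}, {2, AllocationType::NotCold}};
  assert(computeAllocType({1}, TypeOfContext) == (uint8_t)AllocationType::Cold);
  assert(computeAllocType({1, 2}, TypeOfContext) == 3); // NotCold | Cold
  return 0;
}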
+ std::map<uint64_t, ContextNode *> StackEntryIdToContextNodeMap; + + /// Maps to track the calls to their corresponding nodes in the graph. + MapVector<CallInfo, ContextNode *> AllocationCallToContextNodeMap; + MapVector<CallInfo, ContextNode *> NonAllocationCallToContextNodeMap; + + /// Owner of all ContextNode unique_ptrs. + std::vector<std::unique_ptr<ContextNode>> NodeOwner; + + /// Perform sanity checks on graph when requested. + void check() const; + + /// Keeps track of the last unique context id assigned. + unsigned int LastContextId = 0; +}; + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +using ContextNode = + typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode; +template <typename DerivedCCG, typename FuncTy, typename CallTy> +using ContextEdge = + typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge; +template <typename DerivedCCG, typename FuncTy, typename CallTy> +using FuncInfo = + typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::FuncInfo; +template <typename DerivedCCG, typename FuncTy, typename CallTy> +using CallInfo = + typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::CallInfo; + +/// CRTP derived class for graphs built from IR (regular LTO). +class ModuleCallsiteContextGraph + : public CallsiteContextGraph<ModuleCallsiteContextGraph, Function, + Instruction *> { +public: + ModuleCallsiteContextGraph( + Module &M, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter); + +private: + friend CallsiteContextGraph<ModuleCallsiteContextGraph, Function, + Instruction *>; + + uint64_t getStackId(uint64_t IdOrIndex) const; + bool calleeMatchesFunc(Instruction *Call, const Function *Func); + uint64_t getLastStackId(Instruction *Call); + std::vector<uint64_t> getStackIdsWithContextNodesForCall(Instruction *Call); + void updateAllocationCall(CallInfo &Call, AllocationType AllocType); + void updateCall(CallInfo &CallerCall, FuncInfo CalleeFunc); + CallsiteContextGraph<ModuleCallsiteContextGraph, Function, + Instruction *>::FuncInfo + cloneFunctionForCallsite(FuncInfo &Func, CallInfo &Call, + std::map<CallInfo, CallInfo> &CallMap, + std::vector<CallInfo> &CallsWithMetadataInFunc, + unsigned CloneNo); + std::string getLabel(const Function *Func, const Instruction *Call, + unsigned CloneNo) const; + + const Module &Mod; + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter; +}; + +/// Represents a call in the summary index graph, which can either be an +/// allocation or an interior callsite node in an allocation's context. +/// Holds a pointer to the corresponding data structure in the index. +struct IndexCall : public PointerUnion<CallsiteInfo *, AllocInfo *> { + IndexCall() : PointerUnion() {} + IndexCall(std::nullptr_t) : IndexCall() {} + IndexCall(CallsiteInfo *StackNode) : PointerUnion(StackNode) {} + IndexCall(AllocInfo *AllocNode) : PointerUnion(AllocNode) {} + IndexCall(PointerUnion PT) : PointerUnion(PT) {} + + IndexCall *operator->() { return this; } + + PointerUnion<CallsiteInfo *, AllocInfo *> getBase() const { return *this; } + + void print(raw_ostream &OS) const { + if (auto *AI = llvm::dyn_cast_if_present<AllocInfo *>(getBase())) { + OS << *AI; + } else { + auto *CI = llvm::dyn_cast_if_present<CallsiteInfo *>(getBase()); + assert(CI); + OS << *CI; + } + } +}; + +/// CRTP derived class for graphs built from summary index (ThinLTO). 
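NodeOwner above is the single owner of every ContextNode, while the lookup maps hold raw pointers into it. A short sketch of this ownership pattern with simplified types (ContextNode and AddStackNode here are stand-ins): node addresses stay stable because only the vector's pointer storage ever reallocates, never the nodes themselves, so the non-owning pointers in the maps remain valid.

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

struct ContextNode {
  uint64_t OrigStackOrAllocId = 0;
};

int main() {
  std::vector<std::unique_ptr<ContextNode>> NodeOwner;
  std::map<uint64_t, ContextNode *> StackEntryIdToContextNodeMap;

  auto AddStackNode = [&](uint64_t StackId) {
    NodeOwner.push_back(std::make_unique<ContextNode>());
    ContextNode *N = NodeOwner.back().get();
    N->OrigStackOrAllocId = StackId;
    StackEntryIdToContextNodeMap[StackId] = N; // non-owning reference
    return N;
  };

  AddStackNode(0x1234);
  return 0;
}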
+class IndexCallsiteContextGraph + : public CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary, + IndexCall> { +public: + IndexCallsiteContextGraph( + ModuleSummaryIndex &Index, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing); + +private: + friend CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary, + IndexCall>; + + uint64_t getStackId(uint64_t IdOrIndex) const; + bool calleeMatchesFunc(IndexCall &Call, const FunctionSummary *Func); + uint64_t getLastStackId(IndexCall &Call); + std::vector<uint64_t> getStackIdsWithContextNodesForCall(IndexCall &Call); + void updateAllocationCall(CallInfo &Call, AllocationType AllocType); + void updateCall(CallInfo &CallerCall, FuncInfo CalleeFunc); + CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary, + IndexCall>::FuncInfo + cloneFunctionForCallsite(FuncInfo &Func, CallInfo &Call, + std::map<CallInfo, CallInfo> &CallMap, + std::vector<CallInfo> &CallsWithMetadataInFunc, + unsigned CloneNo); + std::string getLabel(const FunctionSummary *Func, const IndexCall &Call, + unsigned CloneNo) const; + + // Saves mapping from function summaries containing memprof records back to + // its VI, for use in checking and debugging. + std::map<const FunctionSummary *, ValueInfo> FSToVIMap; + + const ModuleSummaryIndex &Index; +}; +} // namespace + +namespace llvm { +template <> +struct DenseMapInfo<typename CallsiteContextGraph< + ModuleCallsiteContextGraph, Function, Instruction *>::CallInfo> + : public DenseMapInfo<std::pair<Instruction *, unsigned>> {}; +template <> +struct DenseMapInfo<typename CallsiteContextGraph< + IndexCallsiteContextGraph, FunctionSummary, IndexCall>::CallInfo> + : public DenseMapInfo<std::pair<IndexCall, unsigned>> {}; +template <> +struct DenseMapInfo<IndexCall> + : public DenseMapInfo<PointerUnion<CallsiteInfo *, AllocInfo *>> {}; +} // end namespace llvm + +namespace { + +struct FieldSeparator { + bool Skip = true; + const char *Sep; + + FieldSeparator(const char *Sep = ", ") : Sep(Sep) {} +}; + +raw_ostream &operator<<(raw_ostream &OS, FieldSeparator &FS) { + if (FS.Skip) { + FS.Skip = false; + return OS; + } + return OS << FS.Sep; +} + +// Map the uint8_t alloc types (which may contain NotCold|Cold) to the alloc +// type we should actually use on the corresponding allocation. +// If we can't clone a node that has NotCold+Cold alloc type, we will fall +// back to using NotCold. So don't bother cloning to distinguish NotCold+Cold +// from NotCold. +AllocationType allocTypeToUse(uint8_t AllocTypes) { + assert(AllocTypes != (uint8_t)AllocationType::None); + if (AllocTypes == + ((uint8_t)AllocationType::NotCold | (uint8_t)AllocationType::Cold)) + return AllocationType::NotCold; + else + return (AllocationType)AllocTypes; +} + +// Helper to check if the alloc types for all edges recorded in the +// InAllocTypes vector match the alloc types for all edges in the Edges +// vector. +template <typename DerivedCCG, typename FuncTy, typename CallTy> +bool allocTypesMatch( + const std::vector<uint8_t> &InAllocTypes, + const std::vector<std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>>> + &Edges) { + return std::equal( + InAllocTypes.begin(), InAllocTypes.end(), Edges.begin(), + [](const uint8_t &l, + const std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>> &r) { + // Can share if one of the edges is None type - don't + // care about the type along that edge as it doesn't + // exist for those context ids. 
+ if (l == (uint8_t)AllocationType::None || + r->AllocTypes == (uint8_t)AllocationType::None) + return true; + return allocTypeToUse(l) == allocTypeToUse(r->AllocTypes); + }); +} + +} // end anonymous namespace + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode * +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getNodeForInst( + const CallInfo &C) { + ContextNode *Node = getNodeForAlloc(C); + if (Node) + return Node; + + return NonAllocationCallToContextNodeMap.lookup(C); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode * +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getNodeForAlloc( + const CallInfo &C) { + return AllocationCallToContextNodeMap.lookup(C); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode * +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getNodeForStackId( + uint64_t StackId) { + auto StackEntryNode = StackEntryIdToContextNodeMap.find(StackId); + if (StackEntryNode != StackEntryIdToContextNodeMap.end()) + return StackEntryNode->second; + return nullptr; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::unsetNodeForInst( + const CallInfo &C) { + AllocationCallToContextNodeMap.erase(C) || + NonAllocationCallToContextNodeMap.erase(C); + assert(!AllocationCallToContextNodeMap.count(C) && + !NonAllocationCallToContextNodeMap.count(C)); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: + addOrUpdateCallerEdge(ContextNode *Caller, AllocationType AllocType, + unsigned int ContextId) { + for (auto &Edge : CallerEdges) { + if (Edge->Caller == Caller) { + Edge->AllocTypes |= (uint8_t)AllocType; + Edge->getContextIds().insert(ContextId); + return; + } + } + std::shared_ptr<ContextEdge> Edge = std::make_shared<ContextEdge>( + this, Caller, (uint8_t)AllocType, DenseSet<uint32_t>({ContextId})); + CallerEdges.push_back(Edge); + Caller->CalleeEdges.push_back(Edge); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph< + DerivedCCG, FuncTy, CallTy>::removeNoneTypeCalleeEdges(ContextNode *Node) { + for (auto EI = Node->CalleeEdges.begin(); EI != Node->CalleeEdges.end();) { + auto Edge = *EI; + if (Edge->AllocTypes == (uint8_t)AllocationType::None) { + assert(Edge->ContextIds.empty()); + Edge->Callee->eraseCallerEdge(Edge.get()); + EI = Node->CalleeEdges.erase(EI); + } else + ++EI; + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge * +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: + findEdgeFromCallee(const ContextNode *Callee) { + for (const auto &Edge : CalleeEdges) + if (Edge->Callee == Callee) + return Edge.get(); + return nullptr; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge * +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: + findEdgeFromCaller(const ContextNode *Caller) { + for (const auto &Edge : CallerEdges) + if (Edge->Caller == Caller) + return Edge.get(); + return nullptr; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void 
CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: + eraseCalleeEdge(const ContextEdge *Edge) { + auto EI = + std::find_if(CalleeEdges.begin(), CalleeEdges.end(), + [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) { + return CalleeEdge.get() == Edge; + }); + assert(EI != CalleeEdges.end()); + CalleeEdges.erase(EI); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: + eraseCallerEdge(const ContextEdge *Edge) { + auto EI = + std::find_if(CallerEdges.begin(), CallerEdges.end(), + [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) { + return CallerEdge.get() == Edge; + }); + assert(EI != CallerEdges.end()); + CallerEdges.erase(EI); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +uint8_t CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::computeAllocType( + DenseSet<uint32_t> &ContextIds) { + uint8_t BothTypes = + (uint8_t)AllocationType::Cold | (uint8_t)AllocationType::NotCold; + uint8_t AllocType = (uint8_t)AllocationType::None; + for (auto Id : ContextIds) { + AllocType |= (uint8_t)ContextIdToAllocationType[Id]; + // Bail early if alloc type reached both, no further refinement. + if (AllocType == BothTypes) + return AllocType; + } + return AllocType; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +uint8_t +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::intersectAllocTypesImpl( + const DenseSet<uint32_t> &Node1Ids, const DenseSet<uint32_t> &Node2Ids) { + uint8_t BothTypes = + (uint8_t)AllocationType::Cold | (uint8_t)AllocationType::NotCold; + uint8_t AllocType = (uint8_t)AllocationType::None; + for (auto Id : Node1Ids) { + if (!Node2Ids.count(Id)) + continue; + AllocType |= (uint8_t)ContextIdToAllocationType[Id]; + // Bail early if alloc type reached both, no further refinement. + if (AllocType == BothTypes) + return AllocType; + } + return AllocType; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +uint8_t CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::intersectAllocTypes( + const DenseSet<uint32_t> &Node1Ids, const DenseSet<uint32_t> &Node2Ids) { + if (Node1Ids.size() < Node2Ids.size()) + return intersectAllocTypesImpl(Node1Ids, Node2Ids); + else + return intersectAllocTypesImpl(Node2Ids, Node1Ids); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode * +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addAllocNode( + CallInfo Call, const FuncTy *F) { + assert(!getNodeForAlloc(Call)); + NodeOwner.push_back( + std::make_unique<ContextNode>(/*IsAllocation=*/true, Call)); + ContextNode *AllocNode = NodeOwner.back().get(); + AllocationCallToContextNodeMap[Call] = AllocNode; + NodeToCallingFunc[AllocNode] = F; + // Use LastContextId as a uniq id for MIB allocation nodes. + AllocNode->OrigStackOrAllocId = LastContextId; + // Alloc type should be updated as we add in the MIBs. We should assert + // afterwards that it is not still None. 
+ AllocNode->AllocTypes = (uint8_t)AllocationType::None; + + return AllocNode; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +template <class NodeT, class IteratorT> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addStackNodesForMIB( + ContextNode *AllocNode, CallStack<NodeT, IteratorT> &StackContext, + CallStack<NodeT, IteratorT> &CallsiteContext, AllocationType AllocType) { + // Treating the hot alloc type as NotCold before the disambiguation for "hot" + // is done. + if (AllocType == AllocationType::Hot) + AllocType = AllocationType::NotCold; + + ContextIdToAllocationType[++LastContextId] = AllocType; + + // Update alloc type and context ids for this MIB. + AllocNode->AllocTypes |= (uint8_t)AllocType; + AllocNode->ContextIds.insert(LastContextId); + + // Now add or update nodes for each stack id in alloc's context. + // Later when processing the stack ids on non-alloc callsites we will adjust + // for any inlining in the context. + ContextNode *PrevNode = AllocNode; + // Look for recursion (direct recursion should have been collapsed by + // module summary analysis, here we should just be detecting mutual + // recursion). Mark these nodes so we don't try to clone. + SmallSet<uint64_t, 8> StackIdSet; + // Skip any on the allocation call (inlining). + for (auto ContextIter = StackContext.beginAfterSharedPrefix(CallsiteContext); + ContextIter != StackContext.end(); ++ContextIter) { + auto StackId = getStackId(*ContextIter); + ContextNode *StackNode = getNodeForStackId(StackId); + if (!StackNode) { + NodeOwner.push_back( + std::make_unique<ContextNode>(/*IsAllocation=*/false)); + StackNode = NodeOwner.back().get(); + StackEntryIdToContextNodeMap[StackId] = StackNode; + StackNode->OrigStackOrAllocId = StackId; + } + auto Ins = StackIdSet.insert(StackId); + if (!Ins.second) + StackNode->Recursive = true; + StackNode->ContextIds.insert(LastContextId); + StackNode->AllocTypes |= (uint8_t)AllocType; + PrevNode->addOrUpdateCallerEdge(StackNode, AllocType, LastContextId); + PrevNode = StackNode; + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +DenseSet<uint32_t> +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::duplicateContextIds( + const DenseSet<uint32_t> &StackSequenceContextIds, + DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds) { + DenseSet<uint32_t> NewContextIds; + for (auto OldId : StackSequenceContextIds) { + NewContextIds.insert(++LastContextId); + OldToNewContextIds[OldId].insert(LastContextId); + assert(ContextIdToAllocationType.count(OldId)); + // The new context has the same allocation type as original. + ContextIdToAllocationType[LastContextId] = ContextIdToAllocationType[OldId]; + } + return NewContextIds; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: + propagateDuplicateContextIds( + const DenseMap<uint32_t, DenseSet<uint32_t>> &OldToNewContextIds) { + // Build a set of duplicated context ids corresponding to the input id set. + auto GetNewIds = [&OldToNewContextIds](const DenseSet<uint32_t> &ContextIds) { + DenseSet<uint32_t> NewIds; + for (auto Id : ContextIds) + if (auto NewId = OldToNewContextIds.find(Id); + NewId != OldToNewContextIds.end()) + NewIds.insert(NewId->second.begin(), NewId->second.end()); + return NewIds; + }; + + // Recursively update context ids sets along caller edges. 
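The caller-edge propagation that follows is written as a generic lambda that receives itself as its final parameter, which lets it recurse without a named helper function or a std::function indirection. A toy sketch of that idiom (Node, Visit, and the graph here are invented for illustration):

#include <cassert>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node *> Callers;
};

int main() {
  Node A, B, C;
  A.Callers = {&B};
  B.Callers = {&C};

  std::unordered_set<const Node *> Visited;
  // The trailing auto&& parameter is the lambda itself, enabling recursion.
  auto Visit = [&](Node *N, auto &&Visit) -> void {
    if (!Visited.insert(N).second)
      return;
    for (Node *Caller : N->Callers)
      Visit(Caller, Visit);
  };
  Visit(&A, Visit);
  assert(Visited.size() == 3);
  return 0;
}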
+ auto UpdateCallers = [&](ContextNode *Node, + DenseSet<const ContextEdge *> &Visited, + auto &&UpdateCallers) -> void { + for (const auto &Edge : Node->CallerEdges) { + auto Inserted = Visited.insert(Edge.get()); + if (!Inserted.second) + continue; + ContextNode *NextNode = Edge->Caller; + DenseSet<uint32_t> NewIdsToAdd = GetNewIds(Edge->getContextIds()); + // Only need to recursively iterate to NextNode via this caller edge if + // it resulted in any added ids to NextNode. + if (!NewIdsToAdd.empty()) { + Edge->getContextIds().insert(NewIdsToAdd.begin(), NewIdsToAdd.end()); + NextNode->ContextIds.insert(NewIdsToAdd.begin(), NewIdsToAdd.end()); + UpdateCallers(NextNode, Visited, UpdateCallers); + } + } + }; + + DenseSet<const ContextEdge *> Visited; + for (auto &Entry : AllocationCallToContextNodeMap) { + auto *Node = Entry.second; + // Update ids on the allocation nodes before calling the recursive + // update along caller edges, since this simplifies the logic during + // that traversal. + DenseSet<uint32_t> NewIdsToAdd = GetNewIds(Node->ContextIds); + Node->ContextIds.insert(NewIdsToAdd.begin(), NewIdsToAdd.end()); + UpdateCallers(Node, Visited, UpdateCallers); + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::connectNewNode( + ContextNode *NewNode, ContextNode *OrigNode, bool TowardsCallee) { + // Make a copy of the context ids, since this will be adjusted below as they + // are moved. + DenseSet<uint32_t> RemainingContextIds = NewNode->ContextIds; + auto &OrigEdges = + TowardsCallee ? OrigNode->CalleeEdges : OrigNode->CallerEdges; + // Increment iterator in loop so that we can remove edges as needed. + for (auto EI = OrigEdges.begin(); EI != OrigEdges.end();) { + auto Edge = *EI; + // Remove any matching context ids from Edge, return set that were found and + // removed, these are the new edge's context ids. Also update the remaining + // (not found ids). + DenseSet<uint32_t> NewEdgeContextIds, NotFoundContextIds; + set_subtract(Edge->getContextIds(), RemainingContextIds, NewEdgeContextIds, + NotFoundContextIds); + RemainingContextIds.swap(NotFoundContextIds); + // If no matching context ids for this edge, skip it. + if (NewEdgeContextIds.empty()) { + ++EI; + continue; + } + if (TowardsCallee) { + auto NewEdge = std::make_shared<ContextEdge>( + Edge->Callee, NewNode, computeAllocType(NewEdgeContextIds), + NewEdgeContextIds); + NewNode->CalleeEdges.push_back(NewEdge); + NewEdge->Callee->CallerEdges.push_back(NewEdge); + } else { + auto NewEdge = std::make_shared<ContextEdge>( + NewNode, Edge->Caller, computeAllocType(NewEdgeContextIds), + NewEdgeContextIds); + NewNode->CallerEdges.push_back(NewEdge); + NewEdge->Caller->CalleeEdges.push_back(NewEdge); + } + // Remove old edge if context ids empty. + if (Edge->getContextIds().empty()) { + if (TowardsCallee) { + Edge->Callee->eraseCallerEdge(Edge.get()); + EI = OrigNode->CalleeEdges.erase(EI); + } else { + Edge->Caller->eraseCalleeEdge(Edge.get()); + EI = OrigNode->CallerEdges.erase(EI); + } + continue; + } + ++EI; + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: + assignStackNodesPostOrder(ContextNode *Node, + DenseSet<const ContextNode *> &Visited, + DenseMap<uint64_t, std::vector<CallContextInfo>> + &StackIdToMatchingCalls) { + auto Inserted = Visited.insert(Node); + if (!Inserted.second) + return; + // Post order traversal. 
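// The UpdateCallers lambda above recurses by taking itself as its final
// "auto &&UpdateCallers" parameter. A small self-contained sketch of that
// idiom on a toy caller graph (node-level visited set instead of the
// edge-level set used above; all names here are illustrative assumptions):
#include <unordered_set>
#include <vector>

namespace ccg_sketch {
struct ToyNode {
  std::vector<ToyNode *> Callers;
  std::unordered_set<int> Ids;
};

inline void propagateIds(ToyNode *Start) {
  std::unordered_set<const ToyNode *> Visited;
  auto Walk = [&](ToyNode *N, auto &&Self) -> void {
    if (!Visited.insert(N).second)
      return;
    for (ToyNode *Caller : N->Callers) {
      // Push this node's ids up to its caller, then keep walking upward.
      Caller->Ids.insert(N->Ids.begin(), N->Ids.end());
      Self(Caller, Self);
    }
  };
  Walk(Start, Walk);
}
} // namespace ccg_sketch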
Iterate over a copy since we may add nodes and + therefore new callers during the recursive call, invalidating any + iterator over the original edge vector. We don't need to process these + new nodes as they were already processed on creation. + auto CallerEdges = Node->CallerEdges; + for (auto &Edge : CallerEdges) { + // Skip any that have been removed during the recursion. + if (!Edge) + continue; + assignStackNodesPostOrder(Edge->Caller, Visited, StackIdToMatchingCalls); + } + + // If this node's stack id is in the map, update the graph to contain new + // nodes representing any inlining at interior callsites. Note we move the + // associated context ids over to the new nodes. + + // Ignore this node if it is for an allocation or we didn't record any + // stack id lists ending at it. + if (Node->IsAllocation || + !StackIdToMatchingCalls.count(Node->OrigStackOrAllocId)) + return; + + auto &Calls = StackIdToMatchingCalls[Node->OrigStackOrAllocId]; + // Handle the simple case first. A single call with a single stack id. + // In this case there is no need to create any new context nodes, simply + // assign the context node for stack id to this Call. + if (Calls.size() == 1) { + auto &[Call, Ids, Func, SavedContextIds] = Calls[0]; + if (Ids.size() == 1) { + assert(SavedContextIds.empty()); + // It should be this Node. + assert(Node == getNodeForStackId(Ids[0])); + if (Node->Recursive) + return; + Node->setCall(Call); + NonAllocationCallToContextNodeMap[Call] = Node; + NodeToCallingFunc[Node] = Func; + return; + } + } + + // Find the node for the last stack id, which should be the same + // across all calls recorded for this id, and is this node's id. + uint64_t LastId = Node->OrigStackOrAllocId; + ContextNode *LastNode = getNodeForStackId(LastId); + // We should only have kept stack ids that had nodes. + assert(LastNode); + + for (unsigned I = 0; I < Calls.size(); I++) { + auto &[Call, Ids, Func, SavedContextIds] = Calls[I]; + // Skip any for which we didn't assign any ids; these don't get a node in + // the graph. + if (SavedContextIds.empty()) + continue; + + assert(LastId == Ids.back()); + + ContextNode *FirstNode = getNodeForStackId(Ids[0]); + assert(FirstNode); + + // Recompute the context ids for this stack id sequence (the + // intersection of the context ids of the corresponding nodes). + // Start with the ids we saved in the map for this call, which could be + // duplicated context ids. We have to recompute as we might have + // overlap between the saved context ids for different last nodes, and + // removed them already during the post order traversal. + set_intersect(SavedContextIds, FirstNode->ContextIds); + ContextNode *PrevNode = nullptr; + for (auto Id : Ids) { + ContextNode *CurNode = getNodeForStackId(Id); + // We should only have kept stack ids that had nodes and weren't + // recursive. + assert(CurNode); + assert(!CurNode->Recursive); + if (!PrevNode) { + PrevNode = CurNode; + continue; + } + auto *Edge = CurNode->findEdgeFromCallee(PrevNode); + if (!Edge) { + SavedContextIds.clear(); + break; + } + PrevNode = CurNode; + set_intersect(SavedContextIds, Edge->getContextIds()); + + // If we now have no context ids for the clone, skip this call. + if (SavedContextIds.empty()) + break; + } + if (SavedContextIds.empty()) + continue; + + // Create new context node.
+ NodeOwner.push_back( + std::make_unique<ContextNode>(/*IsAllocation=*/false, Call)); + ContextNode *NewNode = NodeOwner.back().get(); + NodeToCallingFunc[NewNode] = Func; + NonAllocationCallToContextNodeMap[Call] = NewNode; + NewNode->ContextIds = SavedContextIds; + NewNode->AllocTypes = computeAllocType(NewNode->ContextIds); + + // Connect to callees of innermost stack frame in inlined call chain. + // This updates context ids for FirstNode's callee's to reflect those + // moved to NewNode. + connectNewNode(NewNode, FirstNode, /*TowardsCallee=*/true); + + // Connect to callers of outermost stack frame in inlined call chain. + // This updates context ids for FirstNode's caller's to reflect those + // moved to NewNode. + connectNewNode(NewNode, LastNode, /*TowardsCallee=*/false); + + // Now we need to remove context ids from edges/nodes between First and + // Last Node. + PrevNode = nullptr; + for (auto Id : Ids) { + ContextNode *CurNode = getNodeForStackId(Id); + // We should only have kept stack ids that had nodes. + assert(CurNode); + + // Remove the context ids moved to NewNode from CurNode, and the + // edge from the prior node. + set_subtract(CurNode->ContextIds, NewNode->ContextIds); + if (PrevNode) { + auto *PrevEdge = CurNode->findEdgeFromCallee(PrevNode); + assert(PrevEdge); + set_subtract(PrevEdge->getContextIds(), NewNode->ContextIds); + if (PrevEdge->getContextIds().empty()) { + PrevNode->eraseCallerEdge(PrevEdge); + CurNode->eraseCalleeEdge(PrevEdge); + } + } + PrevNode = CurNode; + } + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::updateStackNodes() { + // Map of stack id to all calls with that as the last (outermost caller) + // callsite id that has a context node (some might not due to pruning + // performed during matching of the allocation profile contexts). + // The CallContextInfo contains the Call and a list of its stack ids with + // ContextNodes, the function containing Call, and the set of context ids + // the analysis will eventually identify for use in any new node created + // for that callsite. + DenseMap<uint64_t, std::vector<CallContextInfo>> StackIdToMatchingCalls; + for (auto &[Func, CallsWithMetadata] : FuncToCallsWithMetadata) { + for (auto &Call : CallsWithMetadata) { + // Ignore allocations, already handled. + if (AllocationCallToContextNodeMap.count(Call)) + continue; + auto StackIdsWithContextNodes = + getStackIdsWithContextNodesForCall(Call.call()); + // If there were no nodes created for MIBs on allocs (maybe this was in + // the unambiguous part of the MIB stack that was pruned), ignore. + if (StackIdsWithContextNodes.empty()) + continue; + // Otherwise, record this Call along with the list of ids for the last + // (outermost caller) stack id with a node. + StackIdToMatchingCalls[StackIdsWithContextNodes.back()].push_back( + {Call.call(), StackIdsWithContextNodes, Func, {}}); + } + } + + // First make a pass through all stack ids that correspond to a call, + // as identified in the above loop. Compute the context ids corresponding to + // each of these calls when they correspond to multiple stack ids due to + // due to inlining. Perform any duplication of context ids required when + // there is more than one call with the same stack ids. Their (possibly newly + // duplicated) context ids are saved in the StackIdToMatchingCalls map. 
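// For reference, a container-only sketch of the grouping performed a few
// lines above in updateStackNodes(): each call is keyed on the last
// (outermost) stack id that still has a context node, so that all calls
// sharing that id can be matched, and if necessary duplicated, together.
// CallRecord is an illustrative stand-in, not the CallContextInfo tuple used
// by the commit.
#include <cstdint>
#include <unordered_map>
#include <vector>

namespace ccg_sketch {
struct CallRecord {
  int CallId = 0;
  std::vector<uint64_t> StackIds; // innermost frame first, as above
};

std::unordered_map<uint64_t, std::vector<CallRecord>>
groupByOutermostStackId(const std::vector<CallRecord> &Calls) {
  std::unordered_map<uint64_t, std::vector<CallRecord>> Grouped;
  for (const CallRecord &C : Calls) {
    if (C.StackIds.empty()) // fully pruned: no node to attach to
      continue;
    Grouped[C.StackIds.back()].push_back(C);
  }
  return Grouped;
}
} // namespace ccg_sketch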
+ DenseMap<uint32_t, DenseSet<uint32_t>> OldToNewContextIds; + for (auto &It : StackIdToMatchingCalls) { + auto &Calls = It.getSecond(); + // Skip single calls with a single stack id. These don't need a new node. + if (Calls.size() == 1) { + auto &Ids = std::get<1>(Calls[0]); + if (Ids.size() == 1) + continue; + } + // In order to do the best and maximal matching of inlined calls to context + // node sequences we will sort the vectors of stack ids in descending order + // of length, and within each length, lexicographically by stack id. The + // latter is so that we can specially handle calls that have identical stack + // id sequences (either due to cloning or artificially because of the MIB + // context pruning). + std::stable_sort(Calls.begin(), Calls.end(), + [](const CallContextInfo &A, const CallContextInfo &B) { + auto &IdsA = std::get<1>(A); + auto &IdsB = std::get<1>(B); + return IdsA.size() > IdsB.size() || + (IdsA.size() == IdsB.size() && IdsA < IdsB); + }); + + // Find the node for the last stack id, which should be the same + // across all calls recorded for this id, and is the id for this + // entry in the StackIdToMatchingCalls map. + uint64_t LastId = It.getFirst(); + ContextNode *LastNode = getNodeForStackId(LastId); + // We should only have kept stack ids that had nodes. + assert(LastNode); + + if (LastNode->Recursive) + continue; + + // Initialize the context ids with the last node's. We will subsequently + // refine the context ids by computing the intersection along all edges. + DenseSet<uint32_t> LastNodeContextIds = LastNode->ContextIds; + assert(!LastNodeContextIds.empty()); + + for (unsigned I = 0; I < Calls.size(); I++) { + auto &[Call, Ids, Func, SavedContextIds] = Calls[I]; + assert(SavedContextIds.empty()); + assert(LastId == Ids.back()); + + // First compute the context ids for this stack id sequence (the + // intersection of the context ids of the corresponding nodes). + // Start with the remaining saved ids for the last node. + assert(!LastNodeContextIds.empty()); + DenseSet<uint32_t> StackSequenceContextIds = LastNodeContextIds; + + ContextNode *PrevNode = LastNode; + ContextNode *CurNode = LastNode; + bool Skip = false; + + // Iterate backwards through the stack Ids, starting after the last Id + // in the list, which was handled once outside for all Calls. + for (auto IdIter = Ids.rbegin() + 1; IdIter != Ids.rend(); IdIter++) { + auto Id = *IdIter; + CurNode = getNodeForStackId(Id); + // We should only have kept stack ids that had nodes. + assert(CurNode); + + if (CurNode->Recursive) { + Skip = true; + break; + } + + auto *Edge = CurNode->findEdgeFromCaller(PrevNode); + // If there is no edge then the nodes belong to different MIB contexts, + // and we should skip this inlined context sequence. For example, this + // particular inlined context may include stack ids A->B, and we may + // indeed have nodes for both A and B, but it is possible that they were + // never profiled in sequence in a single MIB for any allocation (i.e. + // we might have profiled an allocation that involves the callsite A, + // but through a different one of its callee callsites, and we might + // have profiled an allocation that involves callsite B, but reached + // from a different caller callsite). + if (!Edge) { + Skip = true; + break; + } + PrevNode = CurNode; + + // Update the context ids, which is the intersection of the ids along + // all edges in the sequence. 
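// A standalone sketch of the comparator used by the stable_sort above:
// longer stack-id sequences sort first, and ties are broken lexicographically
// so that identical sequences end up adjacent (which is what lets the later
// loop detect duplicates by peeking at entry I + 1). The bare
// vector-of-vectors input is an assumption for illustration.
#include <algorithm>
#include <cstdint>
#include <vector>

namespace ccg_sketch {
inline void sortByLengthThenLex(std::vector<std::vector<uint64_t>> &Seqs) {
  std::stable_sort(Seqs.begin(), Seqs.end(),
                   [](const std::vector<uint64_t> &A,
                      const std::vector<uint64_t> &B) {
                     return A.size() > B.size() ||
                            (A.size() == B.size() && A < B);
                   });
}
} // namespace ccg_sketch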
+ set_intersect(StackSequenceContextIds, Edge->getContextIds()); + + // If we now have no context ids for clone, skip this call. + if (StackSequenceContextIds.empty()) { + Skip = true; + break; + } + } + if (Skip) + continue; + + // If some of this call's stack ids did not have corresponding nodes (due + // to pruning), don't include any context ids for contexts that extend + // beyond these nodes. Otherwise we would be matching part of unrelated / + // not fully matching stack contexts. To do this, subtract any context ids + // found in caller nodes of the last node found above. + if (Ids.back() != getLastStackId(Call)) { + for (const auto &PE : CurNode->CallerEdges) { + set_subtract(StackSequenceContextIds, PE->getContextIds()); + if (StackSequenceContextIds.empty()) + break; + } + // If we now have no context ids for clone, skip this call. + if (StackSequenceContextIds.empty()) + continue; + } + + // Check if the next set of stack ids is the same (since the Calls vector + // of tuples is sorted by the stack ids we can just look at the next one). + bool DuplicateContextIds = false; + if (I + 1 < Calls.size()) { + auto NextIds = std::get<1>(Calls[I + 1]); + DuplicateContextIds = Ids == NextIds; + } + + // If we don't have duplicate context ids, then we can assign all the + // context ids computed for the original node sequence to this call. + // If there are duplicate calls with the same stack ids then we synthesize + // new context ids that are duplicates of the originals. These are + // assigned to SavedContextIds, which is a reference into the map entry + // for this call, allowing us to access these ids later on. + OldToNewContextIds.reserve(OldToNewContextIds.size() + + StackSequenceContextIds.size()); + SavedContextIds = + DuplicateContextIds + ? duplicateContextIds(StackSequenceContextIds, OldToNewContextIds) + : StackSequenceContextIds; + assert(!SavedContextIds.empty()); + + if (!DuplicateContextIds) { + // Update saved last node's context ids to remove those that are + // assigned to other calls, so that it is ready for the next call at + // this stack id. + set_subtract(LastNodeContextIds, StackSequenceContextIds); + if (LastNodeContextIds.empty()) + break; + } + } + } + + // Propagate the duplicate context ids over the graph. + propagateDuplicateContextIds(OldToNewContextIds); + + if (VerifyCCG) + check(); + + // Now perform a post-order traversal over the graph, starting with the + // allocation nodes, essentially processing nodes from callers to callees. + // For any that contains an id in the map, update the graph to contain new + // nodes representing any inlining at interior callsites. Note we move the + // associated context ids over to the new nodes. + DenseSet<const ContextNode *> Visited; + for (auto &Entry : AllocationCallToContextNodeMap) + assignStackNodesPostOrder(Entry.second, Visited, StackIdToMatchingCalls); +} + +uint64_t ModuleCallsiteContextGraph::getLastStackId(Instruction *Call) { + CallStack<MDNode, MDNode::op_iterator> CallsiteContext( + Call->getMetadata(LLVMContext::MD_callsite)); + return CallsiteContext.back(); +} + +uint64_t IndexCallsiteContextGraph::getLastStackId(IndexCall &Call) { + assert(isa<CallsiteInfo *>(Call.getBase())); + CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator> + CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase())); + // Need to convert index into stack id. 
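// The refinement above leans on two in-place set helpers: set_intersect()
// drops every context id not also present on the current edge, and
// set_subtract() removes ids that were claimed elsewhere (pruned contexts,
// other calls). A minimal sketch of both on a plain hash set, purely to show
// the intended semantics; the commit itself uses LLVM's own set utilities.
#include <cstdint>
#include <unordered_set>

namespace ccg_sketch {
using IdSet = std::unordered_set<uint32_t>;

// Keep only the ids of S1 that also appear in S2.
inline void setIntersect(IdSet &S1, const IdSet &S2) {
  for (auto It = S1.begin(); It != S1.end();) {
    if (S2.count(*It))
      ++It;
    else
      It = S1.erase(It);
  }
}

// Remove from S1 every id that appears in S2.
inline void setSubtract(IdSet &S1, const IdSet &S2) {
  for (uint32_t Id : S2)
    S1.erase(Id);
}
} // namespace ccg_sketch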
+ return Index.getStackIdAtIndex(CallsiteContext.back()); +} + +static const std::string MemProfCloneSuffix = ".memprof."; + +static std::string getMemProfFuncName(Twine Base, unsigned CloneNo) { + // We use CloneNo == 0 to refer to the original version, which doesn't get + // renamed with a suffix. + if (!CloneNo) + return Base.str(); + return (Base + MemProfCloneSuffix + Twine(CloneNo)).str(); +} + +std::string ModuleCallsiteContextGraph::getLabel(const Function *Func, + const Instruction *Call, + unsigned CloneNo) const { + return (Twine(Call->getFunction()->getName()) + " -> " + + cast<CallBase>(Call)->getCalledFunction()->getName()) + .str(); +} + +std::string IndexCallsiteContextGraph::getLabel(const FunctionSummary *Func, + const IndexCall &Call, + unsigned CloneNo) const { + auto VI = FSToVIMap.find(Func); + assert(VI != FSToVIMap.end()); + if (isa<AllocInfo *>(Call.getBase())) + return (VI->second.name() + " -> alloc").str(); + else { + auto *Callsite = dyn_cast_if_present<CallsiteInfo *>(Call.getBase()); + return (VI->second.name() + " -> " + + getMemProfFuncName(Callsite->Callee.name(), + Callsite->Clones[CloneNo])) + .str(); + } +} + +std::vector<uint64_t> +ModuleCallsiteContextGraph::getStackIdsWithContextNodesForCall( + Instruction *Call) { + CallStack<MDNode, MDNode::op_iterator> CallsiteContext( + Call->getMetadata(LLVMContext::MD_callsite)); + return getStackIdsWithContextNodes<MDNode, MDNode::op_iterator>( + CallsiteContext); +} + +std::vector<uint64_t> +IndexCallsiteContextGraph::getStackIdsWithContextNodesForCall(IndexCall &Call) { + assert(isa<CallsiteInfo *>(Call.getBase())); + CallStack<CallsiteInfo, SmallVector<unsigned>::const_iterator> + CallsiteContext(dyn_cast_if_present<CallsiteInfo *>(Call.getBase())); + return getStackIdsWithContextNodes<CallsiteInfo, + SmallVector<unsigned>::const_iterator>( + CallsiteContext); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +template <class NodeT, class IteratorT> +std::vector<uint64_t> +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getStackIdsWithContextNodes( + CallStack<NodeT, IteratorT> &CallsiteContext) { + std::vector<uint64_t> StackIds; + for (auto IdOrIndex : CallsiteContext) { + auto StackId = getStackId(IdOrIndex); + ContextNode *Node = getNodeForStackId(StackId); + if (!Node) + break; + StackIds.push_back(StackId); + } + return StackIds; +} + +ModuleCallsiteContextGraph::ModuleCallsiteContextGraph( + Module &M, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) + : Mod(M), OREGetter(OREGetter) { + for (auto &F : M) { + std::vector<CallInfo> CallsWithMetadata; + for (auto &BB : F) { + for (auto &I : BB) { + if (!isa<CallBase>(I)) + continue; + if (auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof)) { + CallsWithMetadata.push_back(&I); + auto *AllocNode = addAllocNode(&I, &F); + auto *CallsiteMD = I.getMetadata(LLVMContext::MD_callsite); + assert(CallsiteMD); + CallStack<MDNode, MDNode::op_iterator> CallsiteContext(CallsiteMD); + // Add all of the MIBs and their stack nodes. + for (auto &MDOp : MemProfMD->operands()) { + auto *MIBMD = cast<const MDNode>(MDOp); + MDNode *StackNode = getMIBStackNode(MIBMD); + assert(StackNode); + CallStack<MDNode, MDNode::op_iterator> StackContext(StackNode); + addStackNodesForMIB<MDNode, MDNode::op_iterator>( + AllocNode, StackContext, CallsiteContext, + getMIBAllocType(MIBMD)); + } + assert(AllocNode->AllocTypes != (uint8_t)AllocationType::None); + // Memprof and callsite metadata on memory allocations no longer + // needed. 
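// getMemProfFuncName() above encodes the clone number in the symbol name and
// leaves clone 0 (the original function) unchanged. A plain std::string
// rendition of that rule, for illustration only:
#include <string>

namespace ccg_sketch {
inline std::string memprofCloneName(const std::string &Base, unsigned CloneNo) {
  if (CloneNo == 0)
    return Base; // the original version keeps its name
  return Base + ".memprof." + std::to_string(CloneNo);
}
} // namespace ccg_sketch
// e.g. memprofCloneName("foo", 2) produces "foo.memprof.2", matching the
// ".memprof." suffix convention defined above.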
+ I.setMetadata(LLVMContext::MD_memprof, nullptr); + I.setMetadata(LLVMContext::MD_callsite, nullptr); + } + // For callsite metadata, add to list for this function for later use. + else if (I.getMetadata(LLVMContext::MD_callsite)) + CallsWithMetadata.push_back(&I); + } + } + if (!CallsWithMetadata.empty()) + FuncToCallsWithMetadata.push_back({&F, CallsWithMetadata}); + } + + if (DumpCCG) { + dbgs() << "CCG before updating call stack chains:\n"; + dbgs() << *this; + } + + if (ExportToDot) + exportToDot("prestackupdate"); + + updateStackNodes(); + + handleCallsitesWithMultipleTargets(); + + // Strip off remaining callsite metadata, no longer needed. + for (auto &FuncEntry : FuncToCallsWithMetadata) + for (auto &Call : FuncEntry.second) + Call.call()->setMetadata(LLVMContext::MD_callsite, nullptr); +} + +IndexCallsiteContextGraph::IndexCallsiteContextGraph( + ModuleSummaryIndex &Index, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing) + : Index(Index) { + for (auto &I : Index) { + auto VI = Index.getValueInfo(I); + for (auto &S : VI.getSummaryList()) { + // We should only add the prevailing nodes. Otherwise we may try to clone + // in a weak copy that won't be linked (and may be different than the + // prevailing version). + // We only keep the memprof summary on the prevailing copy now when + // building the combined index, as a space optimization, however don't + // rely on this optimization. The linker doesn't resolve local linkage + // values so don't check whether those are prevailing. + if (!GlobalValue::isLocalLinkage(S->linkage()) && + !isPrevailing(VI.getGUID(), S.get())) + continue; + auto *FS = dyn_cast<FunctionSummary>(S.get()); + if (!FS) + continue; + std::vector<CallInfo> CallsWithMetadata; + if (!FS->allocs().empty()) { + for (auto &AN : FS->mutableAllocs()) { + // This can happen because of recursion elimination handling that + // currently exists in ModuleSummaryAnalysis. Skip these for now. + // We still added them to the summary because we need to be able to + // correlate properly in applyImport in the backends. + if (AN.MIBs.empty()) + continue; + CallsWithMetadata.push_back({&AN}); + auto *AllocNode = addAllocNode({&AN}, FS); + // Pass an empty CallStack to the CallsiteContext (second) + // parameter, since for ThinLTO we already collapsed out the inlined + // stack ids on the allocation call during ModuleSummaryAnalysis. + CallStack<MIBInfo, SmallVector<unsigned>::const_iterator> + EmptyContext; + // Now add all of the MIBs and their stack nodes. + for (auto &MIB : AN.MIBs) { + CallStack<MIBInfo, SmallVector<unsigned>::const_iterator> + StackContext(&MIB); + addStackNodesForMIB<MIBInfo, SmallVector<unsigned>::const_iterator>( + AllocNode, StackContext, EmptyContext, MIB.AllocType); + } + assert(AllocNode->AllocTypes != (uint8_t)AllocationType::None); + // Initialize version 0 on the summary alloc node to the current alloc + // type, unless it has both types in which case make it default, so + // that in the case where we aren't able to clone the original version + // always ends up with the default allocation behavior. + AN.Versions[0] = (uint8_t)allocTypeToUse(AllocNode->AllocTypes); + } + } + // For callsite metadata, add to list for this function for later use. 
+ if (!FS->callsites().empty()) + for (auto &SN : FS->mutableCallsites()) + CallsWithMetadata.push_back({&SN}); + + if (!CallsWithMetadata.empty()) + FuncToCallsWithMetadata.push_back({FS, CallsWithMetadata}); + + if (!FS->allocs().empty() || !FS->callsites().empty()) + FSToVIMap[FS] = VI; + } + } + + if (DumpCCG) { + dbgs() << "CCG before updating call stack chains:\n"; + dbgs() << *this; + } + + if (ExportToDot) + exportToDot("prestackupdate"); + + updateStackNodes(); + + handleCallsitesWithMultipleTargets(); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, + CallTy>::handleCallsitesWithMultipleTargets() { + // Look for and workaround callsites that call multiple functions. + // This can happen for indirect calls, which needs better handling, and in + // more rare cases (e.g. macro expansion). + // TODO: To fix this for indirect calls we will want to perform speculative + // devirtualization using either the normal PGO info with ICP, or using the + // information in the profiled MemProf contexts. We can do this prior to + // this transformation for regular LTO, and for ThinLTO we can simulate that + // effect in the summary and perform the actual speculative devirtualization + // while cloning in the ThinLTO backend. + for (auto Entry = NonAllocationCallToContextNodeMap.begin(); + Entry != NonAllocationCallToContextNodeMap.end();) { + auto *Node = Entry->second; + assert(Node->Clones.empty()); + // Check all node callees and see if in the same function. + bool Removed = false; + auto Call = Node->Call.call(); + for (auto &Edge : Node->CalleeEdges) { + if (!Edge->Callee->hasCall()) + continue; + assert(NodeToCallingFunc.count(Edge->Callee)); + // Check if the called function matches that of the callee node. + if (calleeMatchesFunc(Call, NodeToCallingFunc[Edge->Callee])) + continue; + // Work around by setting Node to have a null call, so it gets + // skipped during cloning. Otherwise assignFunctions will assert + // because its data structures are not designed to handle this case. + Entry = NonAllocationCallToContextNodeMap.erase(Entry); + Node->setCall(CallInfo()); + Removed = true; + break; + } + if (!Removed) + Entry++; + } +} + +uint64_t ModuleCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const { + // In the Module (IR) case this is already the Id. + return IdOrIndex; +} + +uint64_t IndexCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const { + // In the Index case this is an index into the stack id list in the summary + // index, convert it to an Id. + return Index.getStackIdAtIndex(IdOrIndex); +} + +bool ModuleCallsiteContextGraph::calleeMatchesFunc(Instruction *Call, + const Function *Func) { + auto *CB = dyn_cast<CallBase>(Call); + if (!CB->getCalledOperand()) + return false; + auto *CalleeVal = CB->getCalledOperand()->stripPointerCasts(); + auto *CalleeFunc = dyn_cast<Function>(CalleeVal); + if (CalleeFunc == Func) + return true; + auto *Alias = dyn_cast<GlobalAlias>(CalleeVal); + return Alias && Alias->getAliasee() == Func; +} + +bool IndexCallsiteContextGraph::calleeMatchesFunc(IndexCall &Call, + const FunctionSummary *Func) { + ValueInfo Callee = + dyn_cast_if_present<CallsiteInfo *>(Call.getBase())->Callee; + // If there is no summary list then this is a call to an externally defined + // symbol. + AliasSummary *Alias = + Callee.getSummaryList().empty() + ? 
nullptr + : dyn_cast<AliasSummary>(Callee.getSummaryList()[0].get()); + assert(FSToVIMap.count(Func)); + return Callee == FSToVIMap[Func] || + // If callee is an alias, check the aliasee, since only function + // summary base objects will contain the stack node summaries and thus + // get a context node. + (Alias && Alias->getAliaseeVI() == FSToVIMap[Func]); +} + +static std::string getAllocTypeString(uint8_t AllocTypes) { + if (!AllocTypes) + return "None"; + std::string Str; + if (AllocTypes & (uint8_t)AllocationType::NotCold) + Str += "NotCold"; + if (AllocTypes & (uint8_t)AllocationType::Cold) + Str += "Cold"; + return Str; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::dump() + const { + print(dbgs()); + dbgs() << "\n"; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::print( + raw_ostream &OS) const { + OS << "Node " << this << "\n"; + OS << "\t"; + printCall(OS); + if (Recursive) + OS << " (recursive)"; + OS << "\n"; + OS << "\tAllocTypes: " << getAllocTypeString(AllocTypes) << "\n"; + OS << "\tContextIds:"; + std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end()); + std::sort(SortedIds.begin(), SortedIds.end()); + for (auto Id : SortedIds) + OS << " " << Id; + OS << "\n"; + OS << "\tCalleeEdges:\n"; + for (auto &Edge : CalleeEdges) + OS << "\t\t" << *Edge << "\n"; + OS << "\tCallerEdges:\n"; + for (auto &Edge : CallerEdges) + OS << "\t\t" << *Edge << "\n"; + if (!Clones.empty()) { + OS << "\tClones: "; + FieldSeparator FS; + for (auto *Clone : Clones) + OS << FS << Clone; + OS << "\n"; + } else if (CloneOf) { + OS << "\tClone of " << CloneOf << "\n"; + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge::dump() + const { + print(dbgs()); + dbgs() << "\n"; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextEdge::print( + raw_ostream &OS) const { + OS << "Edge from Callee " << Callee << " to Caller: " << Caller + << " AllocTypes: " << getAllocTypeString(AllocTypes); + OS << " ContextIds:"; + std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end()); + std::sort(SortedIds.begin(), SortedIds.end()); + for (auto Id : SortedIds) + OS << " " << Id; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::dump() const { + print(dbgs()); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::print( + raw_ostream &OS) const { + OS << "Callsite Context Graph:\n"; + using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *; + for (const auto Node : nodes<GraphType>(this)) { + if (Node->isRemoved()) + continue; + Node->print(OS); + OS << "\n"; + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +static void checkEdge( + const std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>> &Edge) { + // Confirm that alloc type is not None and that we have at least one context + // id. 
+ assert(Edge->AllocTypes != (uint8_t)AllocationType::None); + assert(!Edge->ContextIds.empty()); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +static void checkNode(const ContextNode<DerivedCCG, FuncTy, CallTy> *Node, + bool CheckEdges = true) { + if (Node->isRemoved()) + return; + // Node's context ids should be the union of both its callee and caller edge + // context ids. + if (Node->CallerEdges.size()) { + auto EI = Node->CallerEdges.begin(); + auto &FirstEdge = *EI; + EI++; + DenseSet<uint32_t> CallerEdgeContextIds(FirstEdge->ContextIds); + for (; EI != Node->CallerEdges.end(); EI++) { + const auto &Edge = *EI; + if (CheckEdges) + checkEdge<DerivedCCG, FuncTy, CallTy>(Edge); + set_union(CallerEdgeContextIds, Edge->ContextIds); + } + // Node can have more context ids than callers if some contexts terminate at + // node and some are longer. + assert(Node->ContextIds == CallerEdgeContextIds || + set_is_subset(CallerEdgeContextIds, Node->ContextIds)); + } + if (Node->CalleeEdges.size()) { + auto EI = Node->CalleeEdges.begin(); + auto &FirstEdge = *EI; + EI++; + DenseSet<uint32_t> CalleeEdgeContextIds(FirstEdge->ContextIds); + for (; EI != Node->CalleeEdges.end(); EI++) { + const auto &Edge = *EI; + if (CheckEdges) + checkEdge<DerivedCCG, FuncTy, CallTy>(Edge); + set_union(CalleeEdgeContextIds, Edge->ContextIds); + } + assert(Node->ContextIds == CalleeEdgeContextIds); + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::check() const { + using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *; + for (const auto Node : nodes<GraphType>(this)) { + checkNode<DerivedCCG, FuncTy, CallTy>(Node, /*CheckEdges=*/false); + for (auto &Edge : Node->CallerEdges) + checkEdge<DerivedCCG, FuncTy, CallTy>(Edge); + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +struct GraphTraits<const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *> { + using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *; + using NodeRef = const ContextNode<DerivedCCG, FuncTy, CallTy> *; + + using NodePtrTy = std::unique_ptr<ContextNode<DerivedCCG, FuncTy, CallTy>>; + static NodeRef getNode(const NodePtrTy &P) { return P.get(); } + + using nodes_iterator = + mapped_iterator<typename std::vector<NodePtrTy>::const_iterator, + decltype(&getNode)>; + + static nodes_iterator nodes_begin(GraphType G) { + return nodes_iterator(G->NodeOwner.begin(), &getNode); + } + + static nodes_iterator nodes_end(GraphType G) { + return nodes_iterator(G->NodeOwner.end(), &getNode); + } + + static NodeRef getEntryNode(GraphType G) { + return G->NodeOwner.begin()->get(); + } + + using EdgePtrTy = std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>>; + static const ContextNode<DerivedCCG, FuncTy, CallTy> * + GetCallee(const EdgePtrTy &P) { + return P->Callee; + } + + using ChildIteratorType = + mapped_iterator<typename std::vector<std::shared_ptr<ContextEdge< + DerivedCCG, FuncTy, CallTy>>>::const_iterator, + decltype(&GetCallee)>; + + static ChildIteratorType child_begin(NodeRef N) { + return ChildIteratorType(N->CalleeEdges.begin(), &GetCallee); + } + + static ChildIteratorType child_end(NodeRef N) { + return ChildIteratorType(N->CalleeEdges.end(), &GetCallee); + } +}; + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +struct DOTGraphTraits<const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *> + : public DefaultDOTGraphTraits { + DOTGraphTraits(bool IsSimple = 
false) : DefaultDOTGraphTraits(IsSimple) {} + + using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *; + using GTraits = GraphTraits<GraphType>; + using NodeRef = typename GTraits::NodeRef; + using ChildIteratorType = typename GTraits::ChildIteratorType; + + static std::string getNodeLabel(NodeRef Node, GraphType G) { + std::string LabelString = + (Twine("OrigId: ") + (Node->IsAllocation ? "Alloc" : "") + + Twine(Node->OrigStackOrAllocId)) + .str(); + LabelString += "\n"; + if (Node->hasCall()) { + auto Func = G->NodeToCallingFunc.find(Node); + assert(Func != G->NodeToCallingFunc.end()); + LabelString += + G->getLabel(Func->second, Node->Call.call(), Node->Call.cloneNo()); + } else { + LabelString += "null call"; + if (Node->Recursive) + LabelString += " (recursive)"; + else + LabelString += " (external)"; + } + return LabelString; + } + + static std::string getNodeAttributes(NodeRef Node, GraphType) { + std::string AttributeString = (Twine("tooltip=\"") + getNodeId(Node) + " " + + getContextIds(Node->ContextIds) + "\"") + .str(); + AttributeString += + (Twine(",fillcolor=\"") + getColor(Node->AllocTypes) + "\"").str(); + AttributeString += ",style=\"filled\""; + if (Node->CloneOf) { + AttributeString += ",color=\"blue\""; + AttributeString += ",style=\"filled,bold,dashed\""; + } else + AttributeString += ",style=\"filled\""; + return AttributeString; + } + + static std::string getEdgeAttributes(NodeRef, ChildIteratorType ChildIter, + GraphType) { + auto &Edge = *(ChildIter.getCurrent()); + return (Twine("tooltip=\"") + getContextIds(Edge->ContextIds) + "\"" + + Twine(",fillcolor=\"") + getColor(Edge->AllocTypes) + "\"") + .str(); + } + + // Since the NodeOwners list includes nodes that are no longer connected to + // the graph, skip them here. + static bool isNodeHidden(NodeRef Node, GraphType) { + return Node->isRemoved(); + } + +private: + static std::string getContextIds(const DenseSet<uint32_t> &ContextIds) { + std::string IdString = "ContextIds:"; + if (ContextIds.size() < 100) { + std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end()); + std::sort(SortedIds.begin(), SortedIds.end()); + for (auto Id : SortedIds) + IdString += (" " + Twine(Id)).str(); + } else { + IdString += (" (" + Twine(ContextIds.size()) + " ids)").str(); + } + return IdString; + } + + static std::string getColor(uint8_t AllocTypes) { + if (AllocTypes == (uint8_t)AllocationType::NotCold) + // Color "brown1" actually looks like a lighter red. + return "brown1"; + if (AllocTypes == (uint8_t)AllocationType::Cold) + return "cyan"; + if (AllocTypes == + ((uint8_t)AllocationType::NotCold | (uint8_t)AllocationType::Cold)) + // Lighter purple. + return "mediumorchid1"; + return "gray"; + } + + static std::string getNodeId(NodeRef Node) { + std::stringstream SStream; + SStream << std::hex << "N0x" << (unsigned long long)Node; + std::string Result = SStream.str(); + return Result; + } +}; + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::exportToDot( + std::string Label) const { + WriteGraph(this, "", false, Label, + DotFilePathPrefix + "ccg." 
+ Label + ".dot"); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode * +CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::moveEdgeToNewCalleeClone( + const std::shared_ptr<ContextEdge> &Edge, EdgeIter *CallerEdgeI) { + ContextNode *Node = Edge->Callee; + NodeOwner.push_back( + std::make_unique<ContextNode>(Node->IsAllocation, Node->Call)); + ContextNode *Clone = NodeOwner.back().get(); + Node->addClone(Clone); + assert(NodeToCallingFunc.count(Node)); + NodeToCallingFunc[Clone] = NodeToCallingFunc[Node]; + moveEdgeToExistingCalleeClone(Edge, Clone, CallerEdgeI, /*NewClone=*/true); + return Clone; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: + moveEdgeToExistingCalleeClone(const std::shared_ptr<ContextEdge> &Edge, + ContextNode *NewCallee, EdgeIter *CallerEdgeI, + bool NewClone) { + // NewCallee and Edge's current callee must be clones of the same original + // node (Edge's current callee may be the original node too). + assert(NewCallee->getOrigNode() == Edge->Callee->getOrigNode()); + auto &EdgeContextIds = Edge->getContextIds(); + ContextNode *OldCallee = Edge->Callee; + if (CallerEdgeI) + *CallerEdgeI = OldCallee->CallerEdges.erase(*CallerEdgeI); + else + OldCallee->eraseCallerEdge(Edge.get()); + Edge->Callee = NewCallee; + NewCallee->CallerEdges.push_back(Edge); + // Don't need to update Edge's context ids since we are simply reconnecting + // it. + set_subtract(OldCallee->ContextIds, EdgeContextIds); + NewCallee->ContextIds.insert(EdgeContextIds.begin(), EdgeContextIds.end()); + NewCallee->AllocTypes |= Edge->AllocTypes; + OldCallee->AllocTypes = computeAllocType(OldCallee->ContextIds); + // OldCallee alloc type should be None iff its context id set is now empty. + assert((OldCallee->AllocTypes == (uint8_t)AllocationType::None) == + OldCallee->ContextIds.empty()); + // Now walk the old callee node's callee edges and move Edge's context ids + // over to the corresponding edge into the clone (which is created here if + // this is a newly created clone). + for (auto &OldCalleeEdge : OldCallee->CalleeEdges) { + // The context ids moving to the new callee are the subset of this edge's + // context ids and the context ids on the caller edge being moved. + DenseSet<uint32_t> EdgeContextIdsToMove = + set_intersection(OldCalleeEdge->getContextIds(), EdgeContextIds); + set_subtract(OldCalleeEdge->getContextIds(), EdgeContextIdsToMove); + OldCalleeEdge->AllocTypes = + computeAllocType(OldCalleeEdge->getContextIds()); + if (!NewClone) { + // Update context ids / alloc type on corresponding edge to NewCallee. + // There is a chance this may not exist if we are reusing an existing + // clone, specifically during function assignment, where we would have + // removed none type edges after creating the clone. If we can't find + // a corresponding edge there, fall through to the cloning below. 
+ if (auto *NewCalleeEdge = + NewCallee->findEdgeFromCallee(OldCalleeEdge->Callee)) { + NewCalleeEdge->getContextIds().insert(EdgeContextIdsToMove.begin(), + EdgeContextIdsToMove.end()); + NewCalleeEdge->AllocTypes |= computeAllocType(EdgeContextIdsToMove); + continue; + } + } + auto NewEdge = std::make_shared<ContextEdge>( + OldCalleeEdge->Callee, NewCallee, + computeAllocType(EdgeContextIdsToMove), EdgeContextIdsToMove); + NewCallee->CalleeEdges.push_back(NewEdge); + NewEdge->Callee->CallerEdges.push_back(NewEdge); + } + if (VerifyCCG) { + checkNode<DerivedCCG, FuncTy, CallTy>(OldCallee, /*CheckEdges=*/false); + checkNode<DerivedCCG, FuncTy, CallTy>(NewCallee, /*CheckEdges=*/false); + for (const auto &OldCalleeEdge : OldCallee->CalleeEdges) + checkNode<DerivedCCG, FuncTy, CallTy>(OldCalleeEdge->Callee, + /*CheckEdges=*/false); + for (const auto &NewCalleeEdge : NewCallee->CalleeEdges) + checkNode<DerivedCCG, FuncTy, CallTy>(NewCalleeEdge->Callee, + /*CheckEdges=*/false); + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones() { + DenseSet<const ContextNode *> Visited; + for (auto &Entry : AllocationCallToContextNodeMap) + identifyClones(Entry.second, Visited); +} + +// Helper function to check whether an AllocType is cold or notcold or both. +bool checkColdOrNotCold(uint8_t AllocType) { + return (AllocType == (uint8_t)AllocationType::Cold) || + (AllocType == (uint8_t)AllocationType::NotCold) || + (AllocType == + ((uint8_t)AllocationType::Cold | (uint8_t)AllocationType::NotCold)); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( + ContextNode *Node, DenseSet<const ContextNode *> &Visited) { + if (VerifyNodes) + checkNode<DerivedCCG, FuncTy, CallTy>(Node); + assert(!Node->CloneOf); + + // If Node has a null call, then either it wasn't found in the module (regular + // LTO) or summary index (ThinLTO), or there were other conditions blocking + // cloning (e.g. recursion, calls multiple targets, etc.). + // Do this here so that we don't try to recursively clone callers below, which + // isn't useful at least for this node. + if (!Node->hasCall()) + return; + +#ifndef NDEBUG + auto Insert = +#endif + Visited.insert(Node); + // We should not have visited this node yet. + assert(Insert.second); + // The recursive call to identifyClones may delete the current edge from the + // CallerEdges vector. Make a copy and iterate on that, simpler than passing + // in an iterator and having the recursive call erase from it. Other edges may + // also get removed during the recursion, which will have null Callee and + // Caller pointers (and are deleted later), so we skip those below. + { + auto CallerEdges = Node->CallerEdges; + for (auto &Edge : CallerEdges) { + // Skip any that have been removed by an earlier recursive call. + if (Edge->Callee == nullptr && Edge->Caller == nullptr) { + assert(!std::count(Node->CallerEdges.begin(), Node->CallerEdges.end(), + Edge)); + continue; + } + // Ignore any caller we previously visited via another edge. + if (!Visited.count(Edge->Caller) && !Edge->Caller->CloneOf) { + identifyClones(Edge->Caller, Visited); + } + } + } + + // Check if we reached an unambiguous call or have only a single caller. + if (hasSingleAllocType(Node->AllocTypes) || Node->CallerEdges.size() <= 1) + return; + + // We need to clone. + + // Try to keep the original version as alloc type NotCold.
This will make + // cases with indirect calls or any other situation with an unknown call to + // the original function get the default behavior. We do this by sorting the + // CallerEdges of the Node we will clone by alloc type. + // + // Give NotCold edge the lowest sort priority so those edges are at the end of + // the caller edges vector, and stay on the original version (since the below + // code clones greedily until it finds all remaining edges have the same type + // and leaves the remaining ones on the original Node). + // + // We shouldn't actually have any None type edges, so the sorting priority for + // that is arbitrary, and we assert in that case below. + const unsigned AllocTypeCloningPriority[] = {/*None*/ 3, /*NotCold*/ 4, + /*Cold*/ 1, + /*NotColdCold*/ 2}; + std::stable_sort(Node->CallerEdges.begin(), Node->CallerEdges.end(), + [&](const std::shared_ptr<ContextEdge> &A, + const std::shared_ptr<ContextEdge> &B) { + assert(checkColdOrNotCold(A->AllocTypes) && + checkColdOrNotCold(B->AllocTypes)); + + if (A->AllocTypes == B->AllocTypes) + // Use the first context id for each edge as a + // tie-breaker. + return *A->ContextIds.begin() < *B->ContextIds.begin(); + return AllocTypeCloningPriority[A->AllocTypes] < + AllocTypeCloningPriority[B->AllocTypes]; + }); + + assert(Node->AllocTypes != (uint8_t)AllocationType::None); + + // Iterate until we find no more opportunities for disambiguating the alloc + // types via cloning. In most cases this loop will terminate once the Node + // has a single allocation type, in which case no more cloning is needed. + // We need to be able to remove Edge from CallerEdges, so need to adjust + // iterator inside the loop. + for (auto EI = Node->CallerEdges.begin(); EI != Node->CallerEdges.end();) { + auto CallerEdge = *EI; + + // See if cloning the prior caller edge left this node with a single alloc + // type or a single caller. In that case no more cloning of Node is needed. + if (hasSingleAllocType(Node->AllocTypes) || Node->CallerEdges.size() <= 1) + break; + + // Compute the node callee edge alloc types corresponding to the context ids + // for this caller edge. + std::vector<uint8_t> CalleeEdgeAllocTypesForCallerEdge; + CalleeEdgeAllocTypesForCallerEdge.reserve(Node->CalleeEdges.size()); + for (auto &CalleeEdge : Node->CalleeEdges) + CalleeEdgeAllocTypesForCallerEdge.push_back(intersectAllocTypes( + CalleeEdge->getContextIds(), CallerEdge->getContextIds())); + + // Don't clone if doing so will not disambiguate any alloc types amongst + // caller edges (including the callee edges that would be cloned). + // Otherwise we will simply move all edges to the clone. + // + // First check if by cloning we will disambiguate the caller allocation + // type from node's allocation type. Query allocTypeToUse so that we don't + // bother cloning to distinguish NotCold+Cold from NotCold. Note that + // neither of these should be None type. + // + // Then check if by cloning node at least one of the callee edges will be + // disambiguated by splitting out different context ids. + assert(CallerEdge->AllocTypes != (uint8_t)AllocationType::None); + assert(Node->AllocTypes != (uint8_t)AllocationType::None); + if (allocTypeToUse(CallerEdge->AllocTypes) == + allocTypeToUse(Node->AllocTypes) && + allocTypesMatch<DerivedCCG, FuncTy, CallTy>( + CalleeEdgeAllocTypesForCallerEdge, Node->CalleeEdges)) { + ++EI; + continue; + } + + // First see if we can use an existing clone. Check each clone and its + // callee edges for matching alloc types. 
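// A self-contained sketch of the caller-edge ordering set up by the
// stable_sort earlier in identifyClones() above: Cold-only edges get the
// highest cloning priority and NotCold-only edges the lowest, so NotCold
// traffic tends to stay on the original node; ties fall back to the smallest
// context id for determinism. The two-bit encoding below is an assumption
// mirroring the code above.
#include <algorithm>
#include <cstdint>
#include <vector>

namespace ccg_sketch {
struct CallerEdgeInfo {
  uint8_t AllocTypes = 0;      // bit 0 = NotCold, bit 1 = Cold (assumed)
  uint32_t FirstContextId = 0; // used only as a tie-breaker
};

inline void sortForCloning(std::vector<CallerEdgeInfo> &Edges) {
  static const unsigned Priority[] = {/*None*/ 3, /*NotCold*/ 4,
                                      /*Cold*/ 1, /*NotColdCold*/ 2};
  std::stable_sort(Edges.begin(), Edges.end(),
                   [](const CallerEdgeInfo &A, const CallerEdgeInfo &B) {
                     if (A.AllocTypes == B.AllocTypes)
                       return A.FirstContextId < B.FirstContextId;
                     return Priority[A.AllocTypes] < Priority[B.AllocTypes];
                   });
}
} // namespace ccg_sketch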
+ ContextNode *Clone = nullptr; + for (auto *CurClone : Node->Clones) { + if (allocTypeToUse(CurClone->AllocTypes) != + allocTypeToUse(CallerEdge->AllocTypes)) + continue; + + if (!allocTypesMatch<DerivedCCG, FuncTy, CallTy>( + CalleeEdgeAllocTypesForCallerEdge, CurClone->CalleeEdges)) + continue; + Clone = CurClone; + break; + } + + // The edge iterator is adjusted when we move the CallerEdge to the clone. + if (Clone) + moveEdgeToExistingCalleeClone(CallerEdge, Clone, &EI); + else + Clone = moveEdgeToNewCalleeClone(CallerEdge, &EI); + + assert(EI == Node->CallerEdges.end() || + Node->AllocTypes != (uint8_t)AllocationType::None); + // Sanity check that no alloc types on clone or its edges are None. + assert(Clone->AllocTypes != (uint8_t)AllocationType::None); + assert(llvm::none_of( + Clone->CallerEdges, [&](const std::shared_ptr<ContextEdge> &E) { + return E->AllocTypes == (uint8_t)AllocationType::None; + })); + } + + // Cloning may have resulted in some cloned callee edges with type None, + // because they aren't carrying any contexts. Remove those edges. + for (auto *Clone : Node->Clones) { + removeNoneTypeCalleeEdges(Clone); + if (VerifyNodes) + checkNode<DerivedCCG, FuncTy, CallTy>(Clone); + } + // We should still have some context ids on the original Node. + assert(!Node->ContextIds.empty()); + + // Remove any callee edges that ended up with alloc type None after creating + // clones and updating callee edges. + removeNoneTypeCalleeEdges(Node); + + // Sanity check that no alloc types on node or edges are None. + assert(Node->AllocTypes != (uint8_t)AllocationType::None); + assert(llvm::none_of(Node->CalleeEdges, + [&](const std::shared_ptr<ContextEdge> &E) { + return E->AllocTypes == (uint8_t)AllocationType::None; + })); + assert(llvm::none_of(Node->CallerEdges, + [&](const std::shared_ptr<ContextEdge> &E) { + return E->AllocTypes == (uint8_t)AllocationType::None; + })); + + if (VerifyNodes) + checkNode<DerivedCCG, FuncTy, CallTy>(Node); +} + +void ModuleCallsiteContextGraph::updateAllocationCall( + CallInfo &Call, AllocationType AllocType) { + std::string AllocTypeString = getAllocTypeAttributeString(AllocType); + auto A = llvm::Attribute::get(Call.call()->getFunction()->getContext(), + "memprof", AllocTypeString); + cast<CallBase>(Call.call())->addFnAttr(A); + OREGetter(Call.call()->getFunction()) + .emit(OptimizationRemark(DEBUG_TYPE, "MemprofAttribute", Call.call()) + << ore::NV("AllocationCall", Call.call()) << " in clone " + << ore::NV("Caller", Call.call()->getFunction()) + << " marked with memprof allocation attribute " + << ore::NV("Attribute", AllocTypeString)); +} + +void IndexCallsiteContextGraph::updateAllocationCall(CallInfo &Call, + AllocationType AllocType) { + auto *AI = Call.call().dyn_cast<AllocInfo *>(); + assert(AI); + assert(AI->Versions.size() > Call.cloneNo()); + AI->Versions[Call.cloneNo()] = (uint8_t)AllocType; +} + +void ModuleCallsiteContextGraph::updateCall(CallInfo &CallerCall, + FuncInfo CalleeFunc) { + if (CalleeFunc.cloneNo() > 0) + cast<CallBase>(CallerCall.call())->setCalledFunction(CalleeFunc.func()); + OREGetter(CallerCall.call()->getFunction()) + .emit(OptimizationRemark(DEBUG_TYPE, "MemprofCall", CallerCall.call()) + << ore::NV("Call", CallerCall.call()) << " in clone " + << ore::NV("Caller", CallerCall.call()->getFunction()) + << " assigned to call function clone " + << ore::NV("Callee", CalleeFunc.func())); +} + +void IndexCallsiteContextGraph::updateCall(CallInfo &CallerCall, + FuncInfo CalleeFunc) { + auto *CI = 
CallerCall.call().dyn_cast<CallsiteInfo *>(); + assert(CI && + "Caller cannot be an allocation which should not have profiled calls"); + assert(CI->Clones.size() > CallerCall.cloneNo()); + CI->Clones[CallerCall.cloneNo()] = CalleeFunc.cloneNo(); +} + +CallsiteContextGraph<ModuleCallsiteContextGraph, Function, + Instruction *>::FuncInfo +ModuleCallsiteContextGraph::cloneFunctionForCallsite( + FuncInfo &Func, CallInfo &Call, std::map<CallInfo, CallInfo> &CallMap, + std::vector<CallInfo> &CallsWithMetadataInFunc, unsigned CloneNo) { + // Use existing LLVM facilities for cloning and obtaining Call in clone + ValueToValueMapTy VMap; + auto *NewFunc = CloneFunction(Func.func(), VMap); + std::string Name = getMemProfFuncName(Func.func()->getName(), CloneNo); + assert(!Func.func()->getParent()->getFunction(Name)); + NewFunc->setName(Name); + for (auto &Inst : CallsWithMetadataInFunc) { + // This map always has the initial version in it. + assert(Inst.cloneNo() == 0); + CallMap[Inst] = {cast<Instruction>(VMap[Inst.call()]), CloneNo}; + } + OREGetter(Func.func()) + .emit(OptimizationRemark(DEBUG_TYPE, "MemprofClone", Func.func()) + << "created clone " << ore::NV("NewFunction", NewFunc)); + return {NewFunc, CloneNo}; +} + +CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary, + IndexCall>::FuncInfo +IndexCallsiteContextGraph::cloneFunctionForCallsite( + FuncInfo &Func, CallInfo &Call, std::map<CallInfo, CallInfo> &CallMap, + std::vector<CallInfo> &CallsWithMetadataInFunc, unsigned CloneNo) { + // Check how many clones we have of Call (and therefore function). + // The next clone number is the current size of versions array. + // Confirm this matches the CloneNo provided by the caller, which is based on + // the number of function clones we have. + assert(CloneNo == + (Call.call().is<AllocInfo *>() + ? Call.call().dyn_cast<AllocInfo *>()->Versions.size() + : Call.call().dyn_cast<CallsiteInfo *>()->Clones.size())); + // Walk all the instructions in this function. Create a new version for + // each (by adding an entry to the Versions/Clones summary array), and copy + // over the version being called for the function clone being cloned here. + // Additionally, add an entry to the CallMap for the new function clone, + // mapping the original call (clone 0, what is in CallsWithMetadataInFunc) + // to the new call clone. + for (auto &Inst : CallsWithMetadataInFunc) { + // This map always has the initial version in it. + assert(Inst.cloneNo() == 0); + if (auto *AI = Inst.call().dyn_cast<AllocInfo *>()) { + assert(AI->Versions.size() == CloneNo); + // We assign the allocation type later (in updateAllocationCall), just add + // an entry for it here. + AI->Versions.push_back(0); + } else { + auto *CI = Inst.call().dyn_cast<CallsiteInfo *>(); + assert(CI && CI->Clones.size() == CloneNo); + // We assign the clone number later (in updateCall), just add an entry for + // it here. + CI->Clones.push_back(0); + } + CallMap[Inst] = {Inst.call(), CloneNo}; + } + return {Func.func(), CloneNo}; +} + +// This method assigns cloned callsites to functions, cloning the functions as +// needed. 
The assignment is greedy and proceeds roughly as follows: +// +// For each function Func: +// For each call with graph Node having clones: +// Initialize ClonesWorklist to Node and its clones +// Initialize NodeCloneCount to 0 +// While ClonesWorklist is not empty: +// Clone = pop front ClonesWorklist +// NodeCloneCount++ +// If Func has been cloned less than NodeCloneCount times: +// If NodeCloneCount is 1: +// Assign Clone to original Func +// Continue +// Create a new function clone +// If other callers not assigned to call a function clone yet: +// Assign them to call new function clone +// Continue +// Assign any other caller calling the cloned version to new clone +// +// For each caller of Clone: +// If caller is assigned to call a specific function clone: +// If we cannot assign Clone to that function clone: +// Create new callsite Clone NewClone +// Add NewClone to ClonesWorklist +// Continue +// Assign Clone to existing caller's called function clone +// Else: +// If Clone not already assigned to a function clone: +// Assign to first function clone without assignment +// Assign caller to selected function clone +template <typename DerivedCCG, typename FuncTy, typename CallTy> +bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::assignFunctions() { + bool Changed = false; + + // Keep track of the assignment of nodes (callsites) to function clones they + // call. + DenseMap<ContextNode *, FuncInfo> CallsiteToCalleeFuncCloneMap; + + // Update caller node to call function version CalleeFunc, by recording the + // assignment in CallsiteToCalleeFuncCloneMap. + auto RecordCalleeFuncOfCallsite = [&](ContextNode *Caller, + const FuncInfo &CalleeFunc) { + assert(Caller->hasCall()); + CallsiteToCalleeFuncCloneMap[Caller] = CalleeFunc; + }; + + // Walk all functions for which we saw calls with memprof metadata, and handle + // cloning for each of its calls. + for (auto &[Func, CallsWithMetadata] : FuncToCallsWithMetadata) { + FuncInfo OrigFunc(Func); + // Map from each clone of OrigFunc to a map of remappings of each call of + // interest (from original uncloned call to the corresponding cloned call in + // that function clone). + std::map<FuncInfo, std::map<CallInfo, CallInfo>> FuncClonesToCallMap; + for (auto &Call : CallsWithMetadata) { + ContextNode *Node = getNodeForInst(Call); + // Skip call if we do not have a node for it (all uses of its stack ids + // were either on inlined chains or pruned from the MIBs), or if we did + // not create any clones for it. + if (!Node || Node->Clones.empty()) + continue; + assert(Node->hasCall() && + "Not having a call should have prevented cloning"); + + // Track the assignment of function clones to clones of the current + // callsite Node being handled. + std::map<FuncInfo, ContextNode *> FuncCloneToCurNodeCloneMap; + + // Assign callsite version CallsiteClone to function version FuncClone, + // and also assign (possibly cloned) Call to CallsiteClone. + auto AssignCallsiteCloneToFuncClone = [&](const FuncInfo &FuncClone, + CallInfo &Call, + ContextNode *CallsiteClone, + bool IsAlloc) { + // Record the clone of callsite node assigned to this function clone. 
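// The pseudocode above describes the greedy callsite-to-function-clone
// assignment. The fragment below sketches only the outer worklist shape for
// the degenerate case where no caller has been pinned to a function clone
// yet, so the Nth callsite clone simply lands on the (N-1)th function clone;
// every name here is illustrative and all of the caller-consistency handling
// in assignFunctions() is deliberately left out.
#include <deque>
#include <map>
#include <string>

namespace ccg_sketch {
struct CallsiteClone {
  int Id = 0;
};

inline std::map<int, std::string>
assignGreedily(const std::string &OrigFunc,
               std::deque<CallsiteClone *> Worklist) {
  std::map<int, std::string> CloneToFunc; // callsite clone id -> function name
  unsigned NodeCloneCount = 0;
  while (!Worklist.empty()) {
    CallsiteClone *Clone = Worklist.front();
    Worklist.pop_front();
    ++NodeCloneCount;
    // The first callsite copy stays on the original function; each further
    // copy conceptually triggers a new function clone, named here with the
    // same ".memprof.<N>" convention as getMemProfFuncName() above.
    CloneToFunc[Clone->Id] =
        NodeCloneCount == 1
            ? OrigFunc
            : OrigFunc + ".memprof." + std::to_string(NodeCloneCount - 1);
  }
  return CloneToFunc;
}
} // namespace ccg_sketch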
+ FuncCloneToCurNodeCloneMap[FuncClone] = CallsiteClone; + + assert(FuncClonesToCallMap.count(FuncClone)); + std::map<CallInfo, CallInfo> &CallMap = FuncClonesToCallMap[FuncClone]; + CallInfo CallClone(Call); + if (CallMap.count(Call)) + CallClone = CallMap[Call]; + CallsiteClone->setCall(CallClone); + }; + + // Keep track of the clones of callsite Node that need to be assigned to + // function clones. This list may be expanded in the loop body below if we + // find additional cloning is required. + std::deque<ContextNode *> ClonesWorklist; + // Ignore original Node if we moved all of its contexts to clones. + if (!Node->ContextIds.empty()) + ClonesWorklist.push_back(Node); + ClonesWorklist.insert(ClonesWorklist.end(), Node->Clones.begin(), + Node->Clones.end()); + + // Now walk through all of the clones of this callsite Node that we need, + // and determine the assignment to a corresponding clone of the current + // function (creating new function clones as needed). + unsigned NodeCloneCount = 0; + while (!ClonesWorklist.empty()) { + ContextNode *Clone = ClonesWorklist.front(); + ClonesWorklist.pop_front(); + NodeCloneCount++; + if (VerifyNodes) + checkNode<DerivedCCG, FuncTy, CallTy>(Clone); + + // Need to create a new function clone if we have more callsite clones + // than existing function clones, which would have been assigned to an + // earlier clone in the list (we assign callsite clones to function + // clones greedily). + if (FuncClonesToCallMap.size() < NodeCloneCount) { + // If this is the first callsite copy, assign to original function. + if (NodeCloneCount == 1) { + // Since FuncClonesToCallMap is empty in this case, no clones have + // been created for this function yet, and no callers should have + // been assigned a function clone for this callee node yet. + assert(llvm::none_of( + Clone->CallerEdges, [&](const std::shared_ptr<ContextEdge> &E) { + return CallsiteToCalleeFuncCloneMap.count(E->Caller); + })); + // Initialize with empty call map, assign Clone to original function + // and its callers, and skip to the next clone. + FuncClonesToCallMap[OrigFunc] = {}; + AssignCallsiteCloneToFuncClone( + OrigFunc, Call, Clone, + AllocationCallToContextNodeMap.count(Call)); + for (auto &CE : Clone->CallerEdges) { + // Ignore any caller that does not have a recorded callsite Call. + if (!CE->Caller->hasCall()) + continue; + RecordCalleeFuncOfCallsite(CE->Caller, OrigFunc); + } + continue; + } + + // First locate which copy of OrigFunc to clone again. If a caller + // of this callsite clone was already assigned to call a particular + // function clone, we need to redirect all of those callers to the + // new function clone, and update their other callees within this + // function. + FuncInfo PreviousAssignedFuncClone; + auto EI = llvm::find_if( + Clone->CallerEdges, [&](const std::shared_ptr<ContextEdge> &E) { + return CallsiteToCalleeFuncCloneMap.count(E->Caller); + }); + bool CallerAssignedToCloneOfFunc = false; + if (EI != Clone->CallerEdges.end()) { + const std::shared_ptr<ContextEdge> &Edge = *EI; + PreviousAssignedFuncClone = + CallsiteToCalleeFuncCloneMap[Edge->Caller]; + CallerAssignedToCloneOfFunc = true; + } + + // Clone function and save it along with the CallInfo map created + // during cloning in the FuncClonesToCallMap. 
+ std::map<CallInfo, CallInfo> NewCallMap; + unsigned CloneNo = FuncClonesToCallMap.size(); + assert(CloneNo > 0 && "Clone 0 is the original function, which " + "should already exist in the map"); + FuncInfo NewFuncClone = cloneFunctionForCallsite( + OrigFunc, Call, NewCallMap, CallsWithMetadata, CloneNo); + FuncClonesToCallMap.emplace(NewFuncClone, std::move(NewCallMap)); + FunctionClonesAnalysis++; + Changed = true; + + // If no caller callsites were already assigned to a clone of this + // function, we can simply assign this clone to the new func clone + // and update all callers to it, then skip to the next clone. + if (!CallerAssignedToCloneOfFunc) { + AssignCallsiteCloneToFuncClone( + NewFuncClone, Call, Clone, + AllocationCallToContextNodeMap.count(Call)); + for (auto &CE : Clone->CallerEdges) { + // Ignore any caller that does not have a recorded callsite Call. + if (!CE->Caller->hasCall()) + continue; + RecordCalleeFuncOfCallsite(CE->Caller, NewFuncClone); + } + continue; + } + + // We may need to do additional node cloning in this case. + // Reset the CallsiteToCalleeFuncCloneMap entry for any callers + // that were previously assigned to call PreviousAssignedFuncClone, + // to record that they now call NewFuncClone. + for (auto CE : Clone->CallerEdges) { + // Ignore any caller that does not have a recorded callsite Call. + if (!CE->Caller->hasCall()) + continue; + + if (!CallsiteToCalleeFuncCloneMap.count(CE->Caller) || + // We subsequently fall through to later handling that + // will perform any additional cloning required for + // callers that were calling other function clones. + CallsiteToCalleeFuncCloneMap[CE->Caller] != + PreviousAssignedFuncClone) + continue; + + RecordCalleeFuncOfCallsite(CE->Caller, NewFuncClone); + + // If we are cloning a function that was already assigned to some + // callers, then essentially we are creating new callsite clones + // of the other callsites in that function that are reached by those + // callers. Clone the other callees of the current callsite's caller + // that were already assigned to PreviousAssignedFuncClone + // accordingly. This is important since we subsequently update the + // calls from the nodes in the graph and their assignments to callee + // functions recorded in CallsiteToCalleeFuncCloneMap. + for (auto CalleeEdge : CE->Caller->CalleeEdges) { + // Skip any that have been removed on an earlier iteration when + // cleaning up newly None type callee edges. + if (!CalleeEdge) + continue; + ContextNode *Callee = CalleeEdge->Callee; + // Skip the current callsite, we are looking for other + // callsites Caller calls, as well as any that does not have a + // recorded callsite Call. + if (Callee == Clone || !Callee->hasCall()) + continue; + ContextNode *NewClone = moveEdgeToNewCalleeClone(CalleeEdge); + removeNoneTypeCalleeEdges(NewClone); + // Moving the edge may have resulted in some none type + // callee edges on the original Callee. + removeNoneTypeCalleeEdges(Callee); + assert(NewClone->AllocTypes != (uint8_t)AllocationType::None); + // If the Callee node was already assigned to call a specific + // function version, make sure its new clone is assigned to call + // that same function clone. + if (CallsiteToCalleeFuncCloneMap.count(Callee)) + RecordCalleeFuncOfCallsite( + NewClone, CallsiteToCalleeFuncCloneMap[Callee]); + // Update NewClone with the new Call clone of this callsite's Call + // created for the new function clone created earlier. 
+ // Recall that we have already ensured when building the graph + // that each caller can only call callsites within the same + // function, so we are guaranteed that Callee Call is in the + // current OrigFunc. + // CallMap is set up as indexed by original Call at clone 0. + CallInfo OrigCall(Callee->getOrigNode()->Call); + OrigCall.setCloneNo(0); + std::map<CallInfo, CallInfo> &CallMap = + FuncClonesToCallMap[NewFuncClone]; + assert(CallMap.count(OrigCall)); + CallInfo NewCall(CallMap[OrigCall]); + assert(NewCall); + NewClone->setCall(NewCall); + } + } + // Fall through to handling below to perform the recording of the + // function for this callsite clone. This enables handling of cases + // where the callers were assigned to different clones of a function. + } + + // See if we can use existing function clone. Walk through + // all caller edges to see if any have already been assigned to + // a clone of this callsite's function. If we can use it, do so. If not, + // because that function clone is already assigned to a different clone + // of this callsite, then we need to clone again. + // Basically, this checking is needed to handle the case where different + // caller functions/callsites may need versions of this function + // containing different mixes of callsite clones across the different + // callsites within the function. If that happens, we need to create + // additional function clones to handle the various combinations. + // + // Keep track of any new clones of this callsite created by the + // following loop, as well as any existing clone that we decided to + // assign this clone to. + std::map<FuncInfo, ContextNode *> FuncCloneToNewCallsiteCloneMap; + FuncInfo FuncCloneAssignedToCurCallsiteClone; + // We need to be able to remove Edge from CallerEdges, so need to adjust + // iterator in the loop. + for (auto EI = Clone->CallerEdges.begin(); + EI != Clone->CallerEdges.end();) { + auto Edge = *EI; + // Ignore any caller that does not have a recorded callsite Call. + if (!Edge->Caller->hasCall()) { + EI++; + continue; + } + // If this caller already assigned to call a version of OrigFunc, need + // to ensure we can assign this callsite clone to that function clone. + if (CallsiteToCalleeFuncCloneMap.count(Edge->Caller)) { + FuncInfo FuncCloneCalledByCaller = + CallsiteToCalleeFuncCloneMap[Edge->Caller]; + // First we need to confirm that this function clone is available + // for use by this callsite node clone. + // + // While FuncCloneToCurNodeCloneMap is built only for this Node and + // its callsite clones, one of those callsite clones X could have + // been assigned to the same function clone called by Edge's caller + // - if Edge's caller calls another callsite within Node's original + // function, and that callsite has another caller reaching clone X. + // We need to clone Node again in this case. + if ((FuncCloneToCurNodeCloneMap.count(FuncCloneCalledByCaller) && + FuncCloneToCurNodeCloneMap[FuncCloneCalledByCaller] != + Clone) || + // Detect when we have multiple callers of this callsite that + // have already been assigned to specific, and different, clones + // of OrigFunc (due to other unrelated callsites in Func they + // reach via call contexts). Is this Clone of callsite Node + // assigned to a different clone of OrigFunc? If so, clone Node + // again. 
+ (FuncCloneAssignedToCurCallsiteClone && + FuncCloneAssignedToCurCallsiteClone != + FuncCloneCalledByCaller)) { + // We need to use a different newly created callsite clone, in + // order to assign it to another new function clone on a + // subsequent iteration over the Clones array (adjusted below). + // Note we specifically do not reset the + // CallsiteToCalleeFuncCloneMap entry for this caller, so that + // when this new clone is processed later we know which version of + // the function to copy (so that other callsite clones we have + // assigned to that function clone are properly cloned over). See + // comments in the function cloning handling earlier. + + // Check if we already have cloned this callsite again while + // walking through caller edges, for a caller calling the same + // function clone. If so, we can move this edge to that new clone + // rather than creating yet another new clone. + if (FuncCloneToNewCallsiteCloneMap.count( + FuncCloneCalledByCaller)) { + ContextNode *NewClone = + FuncCloneToNewCallsiteCloneMap[FuncCloneCalledByCaller]; + moveEdgeToExistingCalleeClone(Edge, NewClone, &EI); + // Cleanup any none type edges cloned over. + removeNoneTypeCalleeEdges(NewClone); + } else { + // Create a new callsite clone. + ContextNode *NewClone = moveEdgeToNewCalleeClone(Edge, &EI); + removeNoneTypeCalleeEdges(NewClone); + FuncCloneToNewCallsiteCloneMap[FuncCloneCalledByCaller] = + NewClone; + // Add to list of clones and process later. + ClonesWorklist.push_back(NewClone); + assert(EI == Clone->CallerEdges.end() || + Clone->AllocTypes != (uint8_t)AllocationType::None); + assert(NewClone->AllocTypes != (uint8_t)AllocationType::None); + } + // Moving the caller edge may have resulted in some none type + // callee edges. + removeNoneTypeCalleeEdges(Clone); + // We will handle the newly created callsite clone in a subsequent + // iteration over this Node's Clones. Continue here since we + // already adjusted iterator EI while moving the edge. + continue; + } + + // Otherwise, we can use the function clone already assigned to this + // caller. + if (!FuncCloneAssignedToCurCallsiteClone) { + FuncCloneAssignedToCurCallsiteClone = FuncCloneCalledByCaller; + // Assign Clone to FuncCloneCalledByCaller + AssignCallsiteCloneToFuncClone( + FuncCloneCalledByCaller, Call, Clone, + AllocationCallToContextNodeMap.count(Call)); + } else + // Don't need to do anything - callsite is already calling this + // function clone. + assert(FuncCloneAssignedToCurCallsiteClone == + FuncCloneCalledByCaller); + + } else { + // We have not already assigned this caller to a version of + // OrigFunc. Do the assignment now. + + // First check if we have already assigned this callsite clone to a + // clone of OrigFunc for another caller during this iteration over + // its caller edges. + if (!FuncCloneAssignedToCurCallsiteClone) { + // Find first function in FuncClonesToCallMap without an assigned + // clone of this callsite Node. We should always have one + // available at this point due to the earlier cloning when the + // FuncClonesToCallMap size was smaller than the clone number. 
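The caller-edge walk above manages its iterator by hand because moving an edge onto another clone removes it from CallerEdges mid-loop. A tiny sketch of the same erase-while-iterating idiom on a plain std::vector, with a made-up filter standing in for the edge moves.

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> Edges = {1, 2, 3, 4, 5};

  for (auto EI = Edges.begin(); EI != Edges.end();) {
    if (*EI % 2 == 0) {
      // "Move" the edge elsewhere; here we just drop it. erase() hands back
      // the iterator to the next element, so we must not advance again.
      EI = Edges.erase(EI);
      continue;
    }
    ++EI; // element kept: advance normally
  }

  for (int E : Edges)
    std::printf("%d ", E); // prints: 1 3 5
  std::printf("\n");
}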
+ for (auto &CF : FuncClonesToCallMap) { + if (!FuncCloneToCurNodeCloneMap.count(CF.first)) { + FuncCloneAssignedToCurCallsiteClone = CF.first; + break; + } + } + assert(FuncCloneAssignedToCurCallsiteClone); + // Assign Clone to FuncCloneAssignedToCurCallsiteClone + AssignCallsiteCloneToFuncClone( + FuncCloneAssignedToCurCallsiteClone, Call, Clone, + AllocationCallToContextNodeMap.count(Call)); + } else + assert(FuncCloneToCurNodeCloneMap + [FuncCloneAssignedToCurCallsiteClone] == Clone); + // Update callers to record function version called. + RecordCalleeFuncOfCallsite(Edge->Caller, + FuncCloneAssignedToCurCallsiteClone); + } + + EI++; + } + } + if (VerifyCCG) { + checkNode<DerivedCCG, FuncTy, CallTy>(Node); + for (const auto &PE : Node->CalleeEdges) + checkNode<DerivedCCG, FuncTy, CallTy>(PE->Callee); + for (const auto &CE : Node->CallerEdges) + checkNode<DerivedCCG, FuncTy, CallTy>(CE->Caller); + for (auto *Clone : Node->Clones) { + checkNode<DerivedCCG, FuncTy, CallTy>(Clone); + for (const auto &PE : Clone->CalleeEdges) + checkNode<DerivedCCG, FuncTy, CallTy>(PE->Callee); + for (const auto &CE : Clone->CallerEdges) + checkNode<DerivedCCG, FuncTy, CallTy>(CE->Caller); + } + } + } + } + + auto UpdateCalls = [&](ContextNode *Node, + DenseSet<const ContextNode *> &Visited, + auto &&UpdateCalls) { + auto Inserted = Visited.insert(Node); + if (!Inserted.second) + return; + + for (auto *Clone : Node->Clones) + UpdateCalls(Clone, Visited, UpdateCalls); + + for (auto &Edge : Node->CallerEdges) + UpdateCalls(Edge->Caller, Visited, UpdateCalls); + + // Skip if either no call to update, or if we ended up with no context ids + // (we moved all edges onto other clones). + if (!Node->hasCall() || Node->ContextIds.empty()) + return; + + if (Node->IsAllocation) { + updateAllocationCall(Node->Call, allocTypeToUse(Node->AllocTypes)); + return; + } + + if (!CallsiteToCalleeFuncCloneMap.count(Node)) + return; + + auto CalleeFunc = CallsiteToCalleeFuncCloneMap[Node]; + updateCall(Node->Call, CalleeFunc); + }; + + // Performs DFS traversal starting from allocation nodes to update calls to + // reflect cloning decisions recorded earlier. For regular LTO this will + // update the actual calls in the IR to call the appropriate function clone + // (and add attributes to allocation calls), whereas for ThinLTO the decisions + // are recorded in the summary entries. + DenseSet<const ContextNode *> Visited; + for (auto &Entry : AllocationCallToContextNodeMap) + UpdateCalls(Entry.second, Visited, UpdateCalls); + + return Changed; +} + +static SmallVector<std::unique_ptr<ValueToValueMapTy>, 4> createFunctionClones( + Function &F, unsigned NumClones, Module &M, OptimizationRemarkEmitter &ORE, + std::map<const Function *, SmallPtrSet<const GlobalAlias *, 1>> + &FuncToAliasMap) { + // The first "clone" is the original copy, we should only call this if we + // needed to create new clones. + assert(NumClones > 1); + SmallVector<std::unique_ptr<ValueToValueMapTy>, 4> VMaps; + VMaps.reserve(NumClones - 1); + FunctionsClonedThinBackend++; + for (unsigned I = 1; I < NumClones; I++) { + VMaps.emplace_back(std::make_unique<ValueToValueMapTy>()); + auto *NewF = CloneFunction(&F, *VMaps.back()); + FunctionClonesThinBackend++; + // Strip memprof and callsite metadata from clone as they are no longer + // needed. 
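The UpdateCalls helper above recurses by passing the generic lambda to itself and uses a visited set so shared callers are updated only once. A small standalone sketch of that pattern over a toy caller graph; the Node type and its edges are made up for illustration.

#include <cstdio>
#include <set>
#include <vector>

struct Node {
  int Id;
  std::vector<Node *> Callers; // edges followed by the traversal
};

int main() {
  Node A{0, {}}, B{1, {&A}}, C{2, {&A, &B}};

  std::set<const Node *> Visited;
  // A generic lambda can recurse by taking itself as an extra parameter,
  // mirroring the UpdateCalls(Node, Visited, UpdateCalls) call shape above.
  auto Visit = [&](Node *N, auto &&Self) -> void {
    if (!Visited.insert(N).second)
      return;
    for (Node *Caller : N->Callers)
      Self(Caller, Self);
    std::printf("update node %d\n", N->Id); // "update" happens after recursion
  };

  Visit(&C, Visit);
}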
+ for (auto &BB : *NewF) { + for (auto &Inst : BB) { + Inst.setMetadata(LLVMContext::MD_memprof, nullptr); + Inst.setMetadata(LLVMContext::MD_callsite, nullptr); + } + } + std::string Name = getMemProfFuncName(F.getName(), I); + auto *PrevF = M.getFunction(Name); + if (PrevF) { + // We might have created this when adjusting callsite in another + // function. It should be a declaration. + assert(PrevF->isDeclaration()); + NewF->takeName(PrevF); + PrevF->replaceAllUsesWith(NewF); + PrevF->eraseFromParent(); + } else + NewF->setName(Name); + ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofClone", &F) + << "created clone " << ore::NV("NewFunction", NewF)); + + // Now handle aliases to this function, and clone those as well. + if (!FuncToAliasMap.count(&F)) + continue; + for (auto *A : FuncToAliasMap[&F]) { + std::string Name = getMemProfFuncName(A->getName(), I); + auto *PrevA = M.getNamedAlias(Name); + auto *NewA = GlobalAlias::create(A->getValueType(), + A->getType()->getPointerAddressSpace(), + A->getLinkage(), Name, NewF); + NewA->copyAttributesFrom(A); + if (PrevA) { + // We might have created this when adjusting callsite in another + // function. It should be a declaration. + assert(PrevA->isDeclaration()); + NewA->takeName(PrevA); + PrevA->replaceAllUsesWith(NewA); + PrevA->eraseFromParent(); + } + } + } + return VMaps; +} + +// Locate the summary for F. This is complicated by the fact that it might +// have been internalized or promoted. +static ValueInfo findValueInfoForFunc(const Function &F, const Module &M, + const ModuleSummaryIndex *ImportSummary) { + // FIXME: Ideally we would retain the original GUID in some fashion on the + // function (e.g. as metadata), but for now do our best to locate the + // summary without that information. + ValueInfo TheFnVI = ImportSummary->getValueInfo(F.getGUID()); + if (!TheFnVI) + // See if theFn was internalized, by checking index directly with + // original name (this avoids the name adjustment done by getGUID() for + // internal symbols). + TheFnVI = ImportSummary->getValueInfo(GlobalValue::getGUID(F.getName())); + if (TheFnVI) + return TheFnVI; + // Now query with the original name before any promotion was performed. + StringRef OrigName = + ModuleSummaryIndex::getOriginalNameBeforePromote(F.getName()); + std::string OrigId = GlobalValue::getGlobalIdentifier( + OrigName, GlobalValue::InternalLinkage, M.getSourceFileName()); + TheFnVI = ImportSummary->getValueInfo(GlobalValue::getGUID(OrigId)); + if (TheFnVI) + return TheFnVI; + // Could be a promoted local imported from another module. We need to pass + // down more info here to find the original module id. For now, try with + // the OrigName which might have been stored in the OidGuidMap in the + // index. This would not work if there were same-named locals in multiple + // modules, however. + auto OrigGUID = + ImportSummary->getGUIDFromOriginalID(GlobalValue::getGUID(OrigName)); + if (OrigGUID) + TheFnVI = ImportSummary->getValueInfo(OrigGUID); + return TheFnVI; +} + +bool MemProfContextDisambiguation::applyImport(Module &M) { + assert(ImportSummary); + bool Changed = false; + + auto IsMemProfClone = [](const Function &F) { + return F.getName().contains(MemProfCloneSuffix); + }; + + // We also need to clone any aliases that reference cloned functions, because + // the modified callsites may invoke via the alias. Keep track of the aliases + // for each function. 
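findValueInfoForFunc above has to try several GUIDs because the function may have been internalized (GUID of the bare name) or promoted (GUID of the source-file-qualified local name). A standalone sketch of just those GUID computations, under the assumption of a local named bar originally defined in a.cpp; both names are hypothetical.

#include "llvm/IR/GlobalValue.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  // GUID as the index would record an external (or already promoted) symbol.
  GlobalValue::GUID External = GlobalValue::getGUID("bar");

  // GUID of the same symbol when it was an internal definition in a.cpp:
  // local symbols are qualified with the source file name before hashing.
  std::string LocalId = GlobalValue::getGlobalIdentifier(
      "bar", GlobalValue::InternalLinkage, "a.cpp");
  GlobalValue::GUID Local = GlobalValue::getGUID(LocalId);

  errs() << "external GUID: " << External << "\n"
         << "local id: " << LocalId << ", GUID: " << Local << "\n";
}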
+ std::map<const Function *, SmallPtrSet<const GlobalAlias *, 1>> + FuncToAliasMap; + for (auto &A : M.aliases()) { + auto *Aliasee = A.getAliaseeObject(); + if (auto *F = dyn_cast<Function>(Aliasee)) + FuncToAliasMap[F].insert(&A); + } + + for (auto &F : M) { + if (F.isDeclaration() || IsMemProfClone(F)) + continue; + + OptimizationRemarkEmitter ORE(&F); + + SmallVector<std::unique_ptr<ValueToValueMapTy>, 4> VMaps; + bool ClonesCreated = false; + unsigned NumClonesCreated = 0; + auto CloneFuncIfNeeded = [&](unsigned NumClones) { + // We should at least have version 0 which is the original copy. + assert(NumClones > 0); + // If only one copy needed use original. + if (NumClones == 1) + return; + // If we already performed cloning of this function, confirm that the + // requested number of clones matches (the thin link should ensure the + // number of clones for each constituent callsite is consistent within + // each function), before returning. + if (ClonesCreated) { + assert(NumClonesCreated == NumClones); + return; + } + VMaps = createFunctionClones(F, NumClones, M, ORE, FuncToAliasMap); + // The first "clone" is the original copy, which doesn't have a VMap. + assert(VMaps.size() == NumClones - 1); + Changed = true; + ClonesCreated = true; + NumClonesCreated = NumClones; + }; + + // Locate the summary for F. + ValueInfo TheFnVI = findValueInfoForFunc(F, M, ImportSummary); + // If not found, this could be an imported local (see comment in + // findValueInfoForFunc). Skip for now as it will be cloned in its original + // module (where it would have been promoted to global scope so should + // satisfy any reference in this module). + if (!TheFnVI) + continue; + + auto *GVSummary = + ImportSummary->findSummaryInModule(TheFnVI, M.getModuleIdentifier()); + if (!GVSummary) + // Must have been imported, use the first summary (might be multiple if + // this was a linkonce_odr). + GVSummary = TheFnVI.getSummaryList().front().get(); + + // If this was an imported alias skip it as we won't have the function + // summary, and it should be cloned in the original module. + if (isa<AliasSummary>(GVSummary)) + continue; + + auto *FS = cast<FunctionSummary>(GVSummary->getBaseObject()); + + if (FS->allocs().empty() && FS->callsites().empty()) + continue; + + auto SI = FS->callsites().begin(); + auto AI = FS->allocs().begin(); + + // Assume for now that the instructions are in the exact same order + // as when the summary was created, but confirm this is correct by + // matching the stack ids. + for (auto &BB : F) { + for (auto &I : BB) { + auto *CB = dyn_cast<CallBase>(&I); + // Same handling as when creating module summary. + if (!mayHaveMemprofSummary(CB)) + continue; + + CallStack<MDNode, MDNode::op_iterator> CallsiteContext( + I.getMetadata(LLVMContext::MD_callsite)); + auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof); + + // Include allocs that were already assigned a memprof function + // attribute in the statistics. + if (CB->getAttributes().hasFnAttr("memprof")) { + assert(!MemProfMD); + CB->getAttributes().getFnAttr("memprof").getValueAsString() == "cold" + ? AllocTypeColdThinBackend++ + : AllocTypeNotColdThinBackend++; + OrigAllocsThinBackend++; + AllocVersionsThinBackend++; + if (!MaxAllocVersionsThinBackend) + MaxAllocVersionsThinBackend = 1; + // Remove any remaining callsite metadata and we can skip the rest of + // the handling for this instruction, since no cloning needed. 
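The "memprof" marker consulted above is an ordinary string function attribute whose value records the allocation type. A minimal sketch of constructing and inspecting such an attribute; the "cold" value matches the strings this pass emits, everything else is illustrative.

#include "llvm/IR/Attributes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  // A string attribute is a key/value pair attached to a call or function.
  Attribute A = Attribute::get(Ctx, "memprof", "cold");
  if (A.isStringAttribute() && A.getValueAsString() == "cold")
    errs() << "allocation would carry " << A.getAsString() << "\n";
}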
+ I.setMetadata(LLVMContext::MD_callsite, nullptr); + continue; + } + + if (MemProfMD) { + // Consult the next alloc node. + assert(AI != FS->allocs().end()); + auto &AllocNode = *(AI++); + + // Sanity check that the MIB stack ids match between the summary and + // instruction metadata. + auto MIBIter = AllocNode.MIBs.begin(); + for (auto &MDOp : MemProfMD->operands()) { + assert(MIBIter != AllocNode.MIBs.end()); + LLVM_ATTRIBUTE_UNUSED auto StackIdIndexIter = + MIBIter->StackIdIndices.begin(); + auto *MIBMD = cast<const MDNode>(MDOp); + MDNode *StackMDNode = getMIBStackNode(MIBMD); + assert(StackMDNode); + SmallVector<unsigned> StackIdsFromMetadata; + CallStack<MDNode, MDNode::op_iterator> StackContext(StackMDNode); + for (auto ContextIter = + StackContext.beginAfterSharedPrefix(CallsiteContext); + ContextIter != StackContext.end(); ++ContextIter) { + // If this is a direct recursion, simply skip the duplicate + // entries, to be consistent with how the summary ids were + // generated during ModuleSummaryAnalysis. + if (!StackIdsFromMetadata.empty() && + StackIdsFromMetadata.back() == *ContextIter) + continue; + assert(StackIdIndexIter != MIBIter->StackIdIndices.end()); + assert(ImportSummary->getStackIdAtIndex(*StackIdIndexIter) == + *ContextIter); + StackIdIndexIter++; + } + MIBIter++; + } + + // Perform cloning if not yet done. + CloneFuncIfNeeded(/*NumClones=*/AllocNode.Versions.size()); + + OrigAllocsThinBackend++; + AllocVersionsThinBackend += AllocNode.Versions.size(); + if (MaxAllocVersionsThinBackend < AllocNode.Versions.size()) + MaxAllocVersionsThinBackend = AllocNode.Versions.size(); + + // If there is only one version that means we didn't end up + // considering this function for cloning, and in that case the alloc + // will still be none type or should have gotten the default NotCold. + // Skip that after calling clone helper since that does some sanity + // checks that confirm we haven't decided yet that we need cloning. + if (AllocNode.Versions.size() == 1) { + assert((AllocationType)AllocNode.Versions[0] == + AllocationType::NotCold || + (AllocationType)AllocNode.Versions[0] == + AllocationType::None); + UnclonableAllocsThinBackend++; + continue; + } + + // All versions should have a singular allocation type. + assert(llvm::none_of(AllocNode.Versions, [](uint8_t Type) { + return Type == ((uint8_t)AllocationType::NotCold | + (uint8_t)AllocationType::Cold); + })); + + // Update the allocation types per the summary info. + for (unsigned J = 0; J < AllocNode.Versions.size(); J++) { + // Ignore any that didn't get an assigned allocation type. + if (AllocNode.Versions[J] == (uint8_t)AllocationType::None) + continue; + AllocationType AllocTy = (AllocationType)AllocNode.Versions[J]; + AllocTy == AllocationType::Cold ? AllocTypeColdThinBackend++ + : AllocTypeNotColdThinBackend++; + std::string AllocTypeString = getAllocTypeAttributeString(AllocTy); + auto A = llvm::Attribute::get(F.getContext(), "memprof", + AllocTypeString); + CallBase *CBClone; + // Copy 0 is the original function. + if (!J) + CBClone = CB; + else + // Since VMaps are only created for new clones, we index with + // clone J-1 (J==0 is the original clone and does not have a VMaps + // entry). 
+ CBClone = cast<CallBase>((*VMaps[J - 1])[CB]); + CBClone->addFnAttr(A); + ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofAttribute", CBClone) + << ore::NV("AllocationCall", CBClone) << " in clone " + << ore::NV("Caller", CBClone->getFunction()) + << " marked with memprof allocation attribute " + << ore::NV("Attribute", AllocTypeString)); + } + } else if (!CallsiteContext.empty()) { + // Consult the next callsite node. + assert(SI != FS->callsites().end()); + auto &StackNode = *(SI++); + +#ifndef NDEBUG + // Sanity check that the stack ids match between the summary and + // instruction metadata. + auto StackIdIndexIter = StackNode.StackIdIndices.begin(); + for (auto StackId : CallsiteContext) { + assert(StackIdIndexIter != StackNode.StackIdIndices.end()); + assert(ImportSummary->getStackIdAtIndex(*StackIdIndexIter) == + StackId); + StackIdIndexIter++; + } +#endif + + // Perform cloning if not yet done. + CloneFuncIfNeeded(/*NumClones=*/StackNode.Clones.size()); + + // Should have skipped indirect calls via mayHaveMemprofSummary. + assert(CB->getCalledFunction()); + assert(!IsMemProfClone(*CB->getCalledFunction())); + + // Update the calls per the summary info. + // Save orig name since it gets updated in the first iteration + // below. + auto CalleeOrigName = CB->getCalledFunction()->getName(); + for (unsigned J = 0; J < StackNode.Clones.size(); J++) { + // Do nothing if this version calls the original version of its + // callee. + if (!StackNode.Clones[J]) + continue; + auto NewF = M.getOrInsertFunction( + getMemProfFuncName(CalleeOrigName, StackNode.Clones[J]), + CB->getCalledFunction()->getFunctionType()); + CallBase *CBClone; + // Copy 0 is the original function. + if (!J) + CBClone = CB; + else + CBClone = cast<CallBase>((*VMaps[J - 1])[CB]); + CBClone->setCalledFunction(NewF); + ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofCall", CBClone) + << ore::NV("Call", CBClone) << " in clone " + << ore::NV("Caller", CBClone->getFunction()) + << " assigned to call function clone " + << ore::NV("Callee", NewF.getCallee())); + } + } + // Memprof and callsite metadata on memory allocations no longer needed. + I.setMetadata(LLVMContext::MD_memprof, nullptr); + I.setMetadata(LLVMContext::MD_callsite, nullptr); + } + } + } + + return Changed; +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::process() { + if (DumpCCG) { + dbgs() << "CCG before cloning:\n"; + dbgs() << *this; + } + if (ExportToDot) + exportToDot("postbuild"); + + if (VerifyCCG) { + check(); + } + + identifyClones(); + + if (VerifyCCG) { + check(); + } + + if (DumpCCG) { + dbgs() << "CCG after cloning:\n"; + dbgs() << *this; + } + if (ExportToDot) + exportToDot("cloned"); + + bool Changed = assignFunctions(); + + if (DumpCCG) { + dbgs() << "CCG after assigning function clones:\n"; + dbgs() << *this; + } + if (ExportToDot) + exportToDot("clonefuncassign"); + + return Changed; +} + +bool MemProfContextDisambiguation::processModule( + Module &M, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { + + // If we have an import summary, then the cloning decisions were made during + // the thin link on the index. Apply them and return. + if (ImportSummary) + return applyImport(M); + + // TODO: If/when other types of memprof cloning are enabled beyond just for + // hot and cold, we will need to change this to individually control the + // AllocationType passed to addStackNodesForMIB during CCG construction. 
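Redirecting a call to a function clone, as done above in the ThinLTO backend, only requires a declaration of the clone with the same signature; getOrInsertFunction creates that declaration on demand. A standalone sketch of the pattern; all symbol names and the ".memprof.1" suffix are illustrative.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("demo", Ctx);
  FunctionType *FT = FunctionType::get(Type::getVoidTy(Ctx), /*isVarArg=*/false);

  // A callee declaration and a caller containing a single call to it.
  FunctionCallee Orig = M.getOrInsertFunction("callee", FT);
  Function *Caller =
      Function::Create(FT, Function::ExternalLinkage, "caller", &M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", Caller));
  CallInst *CB = B.CreateCall(Orig);
  B.CreateRetVoid();

  // Point the call at a clone instead; the declaration is created lazily.
  FunctionCallee Clone = M.getOrInsertFunction("callee.memprof.1", FT);
  CB->setCalledFunction(Clone);

  M.print(errs(), nullptr);
}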
+ // Note that we specifically check this after applying imports above, so that + // the option isn't needed to be passed to distributed ThinLTO backend + // clang processes, which won't necessarily have visibility into the linker + // dependences. Instead the information is communicated from the LTO link to + // the backends via the combined summary index. + if (!SupportsHotColdNew) + return false; + + ModuleCallsiteContextGraph CCG(M, OREGetter); + return CCG.process(); +} + +MemProfContextDisambiguation::MemProfContextDisambiguation( + const ModuleSummaryIndex *Summary) + : ImportSummary(Summary) { + if (ImportSummary) { + // The MemProfImportSummary should only be used for testing ThinLTO + // distributed backend handling via opt, in which case we don't have a + // summary from the pass pipeline. + assert(MemProfImportSummary.empty()); + return; + } + if (MemProfImportSummary.empty()) + return; + + auto ReadSummaryFile = + errorOrToExpected(MemoryBuffer::getFile(MemProfImportSummary)); + if (!ReadSummaryFile) { + logAllUnhandledErrors(ReadSummaryFile.takeError(), errs(), + "Error loading file '" + MemProfImportSummary + + "': "); + return; + } + auto ImportSummaryForTestingOrErr = getModuleSummaryIndex(**ReadSummaryFile); + if (!ImportSummaryForTestingOrErr) { + logAllUnhandledErrors(ImportSummaryForTestingOrErr.takeError(), errs(), + "Error parsing file '" + MemProfImportSummary + + "': "); + return; + } + ImportSummaryForTesting = std::move(*ImportSummaryForTestingOrErr); + ImportSummary = ImportSummaryForTesting.get(); +} + +PreservedAnalyses MemProfContextDisambiguation::run(Module &M, + ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { + return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); + }; + if (!processModule(M, OREGetter)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +void MemProfContextDisambiguation::run( + ModuleSummaryIndex &Index, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + isPrevailing) { + // TODO: If/when other types of memprof cloning are enabled beyond just for + // hot and cold, we will need to change this to individually control the + // AllocationType passed to addStackNodesForMIB during CCG construction. + // The index was set from the option, so these should be in sync. + assert(Index.withSupportsHotColdNew() == SupportsHotColdNew); + if (!SupportsHotColdNew) + return; + + IndexCallsiteContextGraph CCG(Index, isPrevailing); + CCG.process(); +} diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp index 590f62ca58dd..feda5d6459cb 100644 --- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -112,8 +112,6 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -294,34 +292,8 @@ private: // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree. 
DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree; }; - -class MergeFunctionsLegacyPass : public ModulePass { -public: - static char ID; - - MergeFunctionsLegacyPass(): ModulePass(ID) { - initializeMergeFunctionsLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - MergeFunctions MF; - return MF.runOnModule(M); - } -}; - } // end anonymous namespace -char MergeFunctionsLegacyPass::ID = 0; -INITIALIZE_PASS(MergeFunctionsLegacyPass, "mergefunc", - "Merge Functions", false, false) - -ModulePass *llvm::createMergeFunctionsPass() { - return new MergeFunctionsLegacyPass(); -} - PreservedAnalyses MergeFunctionsPass::run(Module &M, ModuleAnalysisManager &AM) { MergeFunctions MF; diff --git a/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/llvm/lib/Transforms/IPO/ModuleInliner.cpp index ee382657f5e6..5e91ab80d750 100644 --- a/llvm/lib/Transforms/IPO/ModuleInliner.cpp +++ b/llvm/lib/Transforms/IPO/ModuleInliner.cpp @@ -138,17 +138,12 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M, // // TODO: Here is a huge amount duplicate code between the module inliner and // the SCC inliner, which need some refactoring. - auto Calls = getInlineOrder(FAM, Params); + auto Calls = getInlineOrder(FAM, Params, MAM, M); assert(Calls != nullptr && "Expected an initialized InlineOrder"); // Populate the initial list of calls in this module. for (Function &F : M) { auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - // We want to generally process call sites top-down in order for - // simplifications stemming from replacing the call with the returned value - // after inlining to be visible to subsequent inlining decisions. - // FIXME: Using instructions sequence is a really bad way to do this. - // Instead we should do an actual RPO walk of the function body. for (Instruction &I : instructions(F)) if (auto *CB = dyn_cast<CallBase>(&I)) if (Function *Callee = CB->getCalledFunction()) { @@ -213,7 +208,7 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M, // Setup the data structure used to plumb customization into the // `InlineFunction` routine. 
InlineFunctionInfo IFI( - /*cg=*/nullptr, GetAssumptionCache, PSI, + GetAssumptionCache, PSI, &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), &FAM.getResult<BlockFrequencyAnalysis>(Callee)); diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index bee154dab10f..588f3901e3cb 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -22,8 +22,10 @@ #include "llvm/ADT/EnumeratedArray.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" @@ -36,6 +38,8 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instruction.h" @@ -44,7 +48,7 @@ #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/InitializePasses.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/IPO/Attributor.h" @@ -188,9 +192,9 @@ struct AAICVTracker; struct OMPInformationCache : public InformationCache { OMPInformationCache(Module &M, AnalysisGetter &AG, BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC, - KernelSet &Kernels) + bool OpenMPPostLink) : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M), - Kernels(Kernels) { + OpenMPPostLink(OpenMPPostLink) { OMPBuilder.initialize(); initializeRuntimeFunctions(M); @@ -417,7 +421,7 @@ struct OMPInformationCache : public InformationCache { // TODO: We directly convert uses into proper calls and unknown uses. for (Use &U : RFI.Declaration->uses()) { if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) { - if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) { + if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) { RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U); ++NumUses; } @@ -448,6 +452,24 @@ struct OMPInformationCache : public InformationCache { CI->setCallingConv(Fn->getCallingConv()); } + // Helper function to determine if it's legal to create a call to the runtime + // functions. + bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) { + // We can always emit calls if we haven't yet linked in the runtime. + if (!OpenMPPostLink) + return true; + + // Once the runtime has been already been linked in we cannot emit calls to + // any undefined functions. + for (RuntimeFunction Fn : Fns) { + RuntimeFunctionInfo &RFI = RFIs[Fn]; + + if (RFI.Declaration && RFI.Declaration->isDeclaration()) + return false; + } + return true; + } + /// Helper to initialize all runtime function information for those defined /// in OpenMPKinds.def. void initializeRuntimeFunctions(Module &M) { @@ -518,11 +540,11 @@ struct OMPInformationCache : public InformationCache { // TODO: We should attach the attributes defined in OMPKinds.def. } - /// Collection of known kernels (\see Kernel) in the module. - KernelSet &Kernels; - /// Collection of known OpenMP runtime functions.. DenseSet<const Function *> RTLFunctions; + + /// Indicates if we have already linked in the OpenMP device library. 
+ bool OpenMPPostLink = false; }; template <typename Ty, bool InsertInvalidates = true> @@ -808,7 +830,7 @@ struct OpenMPOpt { return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE); } - /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice. + /// Run all OpenMP optimizations on the underlying SCC. bool run(bool IsModulePass) { if (SCC.empty()) return false; @@ -816,8 +838,7 @@ struct OpenMPOpt { bool Changed = false; LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size() - << " functions in a slice with " - << OMPInfoCache.ModuleSlice.size() << " functions\n"); + << " functions\n"); if (IsModulePass) { Changed |= runAttributor(IsModulePass); @@ -882,7 +903,7 @@ struct OpenMPOpt { /// Print OpenMP GPU kernels for testing. void printKernels() const { for (Function *F : SCC) { - if (!OMPInfoCache.Kernels.count(F)) + if (!omp::isKernel(*F)) continue; auto Remark = [&](OptimizationRemarkAnalysis ORA) { @@ -1412,7 +1433,10 @@ private: Changed |= WasSplit; return WasSplit; }; - RFI.foreachUse(SCC, SplitMemTransfers); + if (OMPInfoCache.runtimeFnsAvailable( + {OMPRTL___tgt_target_data_begin_mapper_issue, + OMPRTL___tgt_target_data_begin_mapper_wait})) + RFI.foreachUse(SCC, SplitMemTransfers); return Changed; } @@ -1681,37 +1705,27 @@ private: }; if (!ReplVal) { - for (Use *U : *UV) + auto *DT = + OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F); + if (!DT) + return false; + Instruction *IP = nullptr; + for (Use *U : *UV) { if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) { + if (IP) + IP = DT->findNearestCommonDominator(IP, CI); + else + IP = CI; if (!CanBeMoved(*CI)) continue; - - // If the function is a kernel, dedup will move - // the runtime call right after the kernel init callsite. Otherwise, - // it will move it to the beginning of the caller function. - if (isKernel(F)) { - auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; - auto *KernelInitUV = KernelInitRFI.getUseVector(F); - - if (KernelInitUV->empty()) - continue; - - assert(KernelInitUV->size() == 1 && - "Expected a single __kmpc_target_init in kernel\n"); - - CallInst *KernelInitCI = - getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI); - assert(KernelInitCI && - "Expected a call to __kmpc_target_init in kernel\n"); - - CI->moveAfter(KernelInitCI); - } else - CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt()); - ReplVal = CI; - break; + if (!ReplVal) + ReplVal = CI; } + } if (!ReplVal) return false; + assert(IP && "Expected insertion point!"); + cast<Instruction>(ReplVal)->moveBefore(IP); } // If we use a call as a replacement value we need to make sure the ident is @@ -1809,9 +1823,6 @@ private: /// ///{{ - /// Check if \p F is a kernel, hence entry point for target offloading. - bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); } - /// Cache to remember the unique kernel for a function. DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap; @@ -1920,7 +1931,8 @@ public: }; Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { - if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F)) + if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() && + !OMPInfoCache.CGSCC->contains(&F)) return nullptr; // Use a scope to keep the lifetime of the CachedKernel short. 
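The reworked deduplication above folds findNearestCommonDominator over every use to pick one insertion point that dominates them all. A standalone sketch of the underlying DominatorTree query on a hand-built diamond CFG; the function and block names are made up.

#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("demo", Ctx);
  FunctionType *FT = FunctionType::get(Type::getVoidTy(Ctx),
                                       {Type::getInt1Ty(Ctx)}, false);
  Function *F = Function::Create(FT, Function::ExternalLinkage, "diamond", &M);

  // entry -> (then | else) -> exit
  BasicBlock *Entry = BasicBlock::Create(Ctx, "entry", F);
  BasicBlock *Then = BasicBlock::Create(Ctx, "then", F);
  BasicBlock *Else = BasicBlock::Create(Ctx, "else", F);
  BasicBlock *Exit = BasicBlock::Create(Ctx, "exit", F);
  IRBuilder<> B(Entry);
  B.CreateCondBr(F->getArg(0), Then, Else);
  B.SetInsertPoint(Then);  B.CreateBr(Exit);
  B.SetInsertPoint(Else);  B.CreateBr(Exit);
  B.SetInsertPoint(Exit);  B.CreateRetVoid();

  // If two uses live in "then" and "else", the hoisting point that dominates
  // both is their nearest common dominator: "entry".
  DominatorTree DT(*F);
  BasicBlock *IP = DT.findNearestCommonDominator(Then, Else);
  errs() << "insertion block: " << IP->getName() << "\n";
}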
@@ -2095,12 +2107,6 @@ struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { using Base = StateWrapper<BooleanState, AbstractAttribute>; AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {} - void initialize(Attributor &A) override { - Function *F = getAnchorScope(); - if (!F || !A.isFunctionIPOAmendable(*F)) - indicatePessimisticFixpoint(); - } - /// Returns true if value is assumed to be tracked. bool isAssumedTracked() const { return getAssumed(); } @@ -2146,7 +2152,9 @@ struct AAICVTrackerFunction : public AAICVTracker { : AAICVTracker(IRP, A) {} // FIXME: come up with better string. - const std::string getAsStr() const override { return "ICVTrackerFunction"; } + const std::string getAsStr(Attributor *) const override { + return "ICVTrackerFunction"; + } // FIXME: come up with some stats. void trackStatistics() const override {} @@ -2242,11 +2250,12 @@ struct AAICVTrackerFunction : public AAICVTracker { if (CalledFunction->isDeclaration()) return nullptr; - const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( + const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>( *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); - if (ICVTrackingAA.isAssumedTracked()) { - std::optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV); + if (ICVTrackingAA->isAssumedTracked()) { + std::optional<Value *> URV = + ICVTrackingAA->getUniqueReplacementValue(ICV); if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I), OMPInfoCache))) return URV; @@ -2337,7 +2346,7 @@ struct AAICVTrackerFunctionReturned : AAICVTracker { : AAICVTracker(IRP, A) {} // FIXME: come up with better string. - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *) const override { return "ICVTrackerFunctionReturned"; } @@ -2362,10 +2371,10 @@ struct AAICVTrackerFunctionReturned : AAICVTracker { ChangeStatus updateImpl(Attributor &A) override { ChangeStatus Changed = ChangeStatus::UNCHANGED; - const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( + const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>( *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); - if (!ICVTrackingAA.isAssumedTracked()) + if (!ICVTrackingAA->isAssumedTracked()) return indicatePessimisticFixpoint(); for (InternalControlVar ICV : TrackableICVs) { @@ -2374,7 +2383,7 @@ struct AAICVTrackerFunctionReturned : AAICVTracker { auto CheckReturnInst = [&](Instruction &I) { std::optional<Value *> NewReplVal = - ICVTrackingAA.getReplacementValue(ICV, &I, A); + ICVTrackingAA->getReplacementValue(ICV, &I, A); // If we found a second ICV value there is no unique returned value. if (UniqueICVValue && UniqueICVValue != NewReplVal) @@ -2407,9 +2416,7 @@ struct AAICVTrackerCallSite : AAICVTracker { : AAICVTracker(IRP, A) {} void initialize(Attributor &A) override { - Function *F = getAnchorScope(); - if (!F || !A.isFunctionIPOAmendable(*F)) - indicatePessimisticFixpoint(); + assert(getAnchorScope() && "Expected anchor function"); // We only initialize this AA for getters, so we need to know which ICV it // gets. @@ -2438,7 +2445,9 @@ struct AAICVTrackerCallSite : AAICVTracker { } // FIXME: come up with better string. - const std::string getAsStr() const override { return "ICVTrackerCallSite"; } + const std::string getAsStr(Attributor *) const override { + return "ICVTrackerCallSite"; + } // FIXME: come up with some stats. 
void trackStatistics() const override {} @@ -2447,15 +2456,15 @@ struct AAICVTrackerCallSite : AAICVTracker { std::optional<Value *> ReplVal; ChangeStatus updateImpl(Attributor &A) override { - const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( + const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>( *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); // We don't have any information, so we assume it changes the ICV. - if (!ICVTrackingAA.isAssumedTracked()) + if (!ICVTrackingAA->isAssumedTracked()) return indicatePessimisticFixpoint(); std::optional<Value *> NewReplVal = - ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); + ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A); if (ReplVal == NewReplVal) return ChangeStatus::UNCHANGED; @@ -2477,7 +2486,7 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker { : AAICVTracker(IRP, A) {} // FIXME: come up with better string. - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *) const override { return "ICVTrackerCallSiteReturned"; } @@ -2503,18 +2512,18 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker { ChangeStatus updateImpl(Attributor &A) override { ChangeStatus Changed = ChangeStatus::UNCHANGED; - const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( + const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>( *this, IRPosition::returned(*getAssociatedFunction()), DepClassTy::REQUIRED); // We don't have any information, so we assume it changes the ICV. - if (!ICVTrackingAA.isAssumedTracked()) + if (!ICVTrackingAA->isAssumedTracked()) return indicatePessimisticFixpoint(); for (InternalControlVar ICV : TrackableICVs) { std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; std::optional<Value *> NewReplVal = - ICVTrackingAA.getUniqueReplacementValue(ICV); + ICVTrackingAA->getUniqueReplacementValue(ICV); if (ReplVal == NewReplVal) continue; @@ -2530,26 +2539,28 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) : AAExecutionDomain(IRP, A) {} - ~AAExecutionDomainFunction() { - delete RPOT; - } + ~AAExecutionDomainFunction() { delete RPOT; } void initialize(Attributor &A) override { - if (getAnchorScope()->isDeclaration()) { - indicatePessimisticFixpoint(); - return; - } - RPOT = new ReversePostOrderTraversal<Function *>(getAnchorScope()); + Function *F = getAnchorScope(); + assert(F && "Expected anchor function"); + RPOT = new ReversePostOrderTraversal<Function *>(F); } - const std::string getAsStr() const override { - unsigned TotalBlocks = 0, InitialThreadBlocks = 0; + const std::string getAsStr(Attributor *) const override { + unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0; for (auto &It : BEDMap) { + if (!It.getFirst()) + continue; TotalBlocks++; InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly; + AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly && + It.getSecond().IsReachingAlignedBarrierOnly; } return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" + - std::to_string(TotalBlocks) + " executed by initial thread only"; + std::to_string(AlignedBlocks) + " of " + + std::to_string(TotalBlocks) + + " executed by initial thread / aligned"; } /// See AbstractAttribute::trackStatistics(). 
@@ -2572,7 +2583,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { SmallPtrSet<CallBase *, 16> DeletedBarriers; auto HandleAlignedBarrier = [&](CallBase *CB) { - const ExecutionDomainTy &ED = CEDMap[CB]; + const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr]; if (!ED.IsReachedFromAlignedBarrierOnly || ED.EncounteredNonLocalSideEffect) return; @@ -2596,6 +2607,8 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { CallBase *LastCB = Worklist.pop_back_val(); if (!Visited.insert(LastCB)) continue; + if (LastCB->getFunction() != getAnchorScope()) + continue; if (!DeletedBarriers.count(LastCB)) { A.deleteAfterManifest(*LastCB); continue; @@ -2603,7 +2616,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { // The final aligned barrier (LastCB) reaching the kernel end was // removed already. This means we can go one step further and remove // the barriers encoutered last before (LastCB). - const ExecutionDomainTy &LastED = CEDMap[LastCB]; + const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}]; Worklist.append(LastED.AlignedBarriers.begin(), LastED.AlignedBarriers.end()); } @@ -2619,14 +2632,17 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { for (auto *CB : AlignedBarriers) HandleAlignedBarrier(CB); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); // Handle the "kernel end barrier" for kernels too. - if (OMPInfoCache.Kernels.count(getAnchorScope())) + if (omp::isKernel(*getAnchorScope())) HandleAlignedBarrier(nullptr); return Changed; } + bool isNoOpFence(const FenceInst &FI) const override { + return getState().isValidState() && !NonNoOpFences.count(&FI); + } + /// Merge barrier and assumption information from \p PredED into the successor /// \p ED. void @@ -2636,12 +2652,12 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { /// Merge all information from \p PredED into the successor \p ED. If /// \p InitialEdgeOnly is set, only the initial edge will enter the block /// represented by \p ED from this predecessor. - void mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED, + bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED, bool InitialEdgeOnly = false); /// Accumulate information for the entry block in \p EntryBBED. - void handleEntryBB(Attributor &A, ExecutionDomainTy &EntryBBED); + bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED); /// See AbstractAttribute::updateImpl. ChangeStatus updateImpl(Attributor &A) override; @@ -2651,14 +2667,18 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override { if (!isValidState()) return false; + assert(BB.getParent() == getAnchorScope() && "Block is out of scope!"); return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly; } bool isExecutedInAlignedRegion(Attributor &A, const Instruction &I) const override { - if (!isValidState() || isa<CallBase>(I)) + assert(I.getFunction() == getAnchorScope() && + "Instruction is out of scope!"); + if (!isValidState()) return false; + bool ForwardIsOk = true; const Instruction *CurI; // Check forward until a call or the block end is reached. 
@@ -2667,15 +2687,18 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { auto *CB = dyn_cast<CallBase>(CurI); if (!CB) continue; - const auto &It = CEDMap.find(CB); + if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) + return true; + const auto &It = CEDMap.find({CB, PRE}); if (It == CEDMap.end()) continue; - if (!It->getSecond().IsReachedFromAlignedBarrierOnly) - return false; + if (!It->getSecond().IsReachingAlignedBarrierOnly) + ForwardIsOk = false; + break; } while ((CurI = CurI->getNextNonDebugInstruction())); - if (!CurI && !BEDMap.lookup(I.getParent()).IsReachedFromAlignedBarrierOnly) - return false; + if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly) + ForwardIsOk = false; // Check backward until a call or the block beginning is reached. CurI = &I; @@ -2683,33 +2706,30 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { auto *CB = dyn_cast<CallBase>(CurI); if (!CB) continue; - const auto &It = CEDMap.find(CB); + if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) + return true; + const auto &It = CEDMap.find({CB, POST}); if (It == CEDMap.end()) continue; - if (!AA::isNoSyncInst(A, *CB, *this)) { - if (It->getSecond().IsReachedFromAlignedBarrierOnly) - break; - return false; - } - - Function *Callee = CB->getCalledFunction(); - if (!Callee || Callee->isDeclaration()) - return false; - const auto &EDAA = A.getAAFor<AAExecutionDomain>( - *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL); - if (!EDAA.getState().isValidState()) - return false; - if (!EDAA.getFunctionExecutionDomain().IsReachedFromAlignedBarrierOnly) - return false; - break; + if (It->getSecond().IsReachedFromAlignedBarrierOnly) + break; + return false; } while ((CurI = CurI->getPrevNonDebugInstruction())); - if (!CurI && - !llvm::all_of( - predecessors(I.getParent()), [&](const BasicBlock *PredBB) { - return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly; - })) { + // Delayed decision on the forward pass to allow aligned barrier detection + // in the backwards traversal. + if (!ForwardIsOk) return false; + + if (!CurI) { + const BasicBlock *BB = I.getParent(); + if (BB == &BB->getParent()->getEntryBlock()) + return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly; + if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) { + return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly; + })) { + return false; + } } // On neither traversal we found a anything but aligned barriers. @@ -2721,15 +2741,16 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { "No request should be made against an invalid state!"); return BEDMap.lookup(&BB); } - ExecutionDomainTy getExecutionDomain(const CallBase &CB) const override { + std::pair<ExecutionDomainTy, ExecutionDomainTy> + getExecutionDomain(const CallBase &CB) const override { assert(isValidState() && "No request should be made against an invalid state!"); - return CEDMap.lookup(&CB); + return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})}; } ExecutionDomainTy getFunctionExecutionDomain() const override { assert(isValidState() && "No request should be made against an invalid state!"); - return BEDMap.lookup(nullptr); + return InterProceduralED; } ///} @@ -2778,12 +2799,28 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { return false; }; + /// Mapping containing information about the function for other AAs. 
+ ExecutionDomainTy InterProceduralED; + + enum Direction { PRE = 0, POST = 1 }; /// Mapping containing information per block. DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap; - DenseMap<const CallBase *, ExecutionDomainTy> CEDMap; + DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy> + CEDMap; SmallSetVector<CallBase *, 16> AlignedBarriers; ReversePostOrderTraversal<Function *> *RPOT = nullptr; + + /// Set \p R to \V and report true if that changed \p R. + static bool setAndRecord(bool &R, bool V) { + bool Eq = (R == V); + R = V; + return !Eq; + } + + /// Collection of fences known to be non-no-opt. All fences not in this set + /// can be assumed no-opt. + SmallPtrSet<const FenceInst *, 8> NonNoOpFences; }; void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions( @@ -2795,62 +2832,82 @@ void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions( ED.addAlignedBarrier(A, *AB); } -void AAExecutionDomainFunction::mergeInPredecessor( +bool AAExecutionDomainFunction::mergeInPredecessor( Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED, bool InitialEdgeOnly) { - ED.IsExecutedByInitialThreadOnly = - InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly && - ED.IsExecutedByInitialThreadOnly); - - ED.IsReachedFromAlignedBarrierOnly = ED.IsReachedFromAlignedBarrierOnly && - PredED.IsReachedFromAlignedBarrierOnly; - ED.EncounteredNonLocalSideEffect = - ED.EncounteredNonLocalSideEffect | PredED.EncounteredNonLocalSideEffect; + + bool Changed = false; + Changed |= + setAndRecord(ED.IsExecutedByInitialThreadOnly, + InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly && + ED.IsExecutedByInitialThreadOnly)); + + Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly, + ED.IsReachedFromAlignedBarrierOnly && + PredED.IsReachedFromAlignedBarrierOnly); + Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect, + ED.EncounteredNonLocalSideEffect | + PredED.EncounteredNonLocalSideEffect); + // Do not track assumptions and barriers as part of Changed. 
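The CEDMap declared above packs the PRE/POST direction into a spare low bit of the call-site pointer, so one map holds both the state just before and just after each call without doubling the key size. A standalone sketch of using a PointerIntPair as a DenseMap key; the CallSite payload type is a hypothetical stand-in for CallBase.

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

struct CallSite { int Id; };          // stand-in for a call instruction
enum Direction { PRE = 0, POST = 1 }; // fits in one spare pointer bit

int main() {
  CallSite CS{42};

  // One map, two entries per call site: the execution domain on entry to
  // the call and the one it leaves behind.
  DenseMap<PointerIntPair<CallSite *, 1, Direction>, const char *> EDMap;
  EDMap[{&CS, PRE}] = "reached only from aligned barriers";
  EDMap[{&CS, POST}] = "reaches an aligned barrier";

  errs() << "pre:  " << EDMap[{&CS, PRE}] << "\n"
         << "post: " << EDMap[{&CS, POST}] << "\n";
}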
if (ED.IsReachedFromAlignedBarrierOnly) mergeInPredecessorBarriersAndAssumptions(A, ED, PredED); else ED.clearAssumeInstAndAlignedBarriers(); + return Changed; } -void AAExecutionDomainFunction::handleEntryBB(Attributor &A, +bool AAExecutionDomainFunction::handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED) { - SmallVector<ExecutionDomainTy> PredExecDomains; + SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs; auto PredForCallSite = [&](AbstractCallSite ACS) { - const auto &EDAA = A.getAAFor<AAExecutionDomain>( + const auto *EDAA = A.getAAFor<AAExecutionDomain>( *this, IRPosition::function(*ACS.getInstruction()->getFunction()), DepClassTy::OPTIONAL); - if (!EDAA.getState().isValidState()) + if (!EDAA || !EDAA->getState().isValidState()) return false; - PredExecDomains.emplace_back( - EDAA.getExecutionDomain(*cast<CallBase>(ACS.getInstruction()))); + CallSiteEDs.emplace_back( + EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction()))); return true; }; + ExecutionDomainTy ExitED; bool AllCallSitesKnown; if (A.checkForAllCallSites(PredForCallSite, *this, /* RequiresAllCallSites */ true, AllCallSitesKnown)) { - for (const auto &PredED : PredExecDomains) - mergeInPredecessor(A, EntryBBED, PredED); + for (const auto &[CSInED, CSOutED] : CallSiteEDs) { + mergeInPredecessor(A, EntryBBED, CSInED); + ExitED.IsReachingAlignedBarrierOnly &= + CSOutED.IsReachingAlignedBarrierOnly; + } } else { // We could not find all predecessors, so this is either a kernel or a // function with external linkage (or with some other weird uses). - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - if (OMPInfoCache.Kernels.count(getAnchorScope())) { + if (omp::isKernel(*getAnchorScope())) { EntryBBED.IsExecutedByInitialThreadOnly = false; EntryBBED.IsReachedFromAlignedBarrierOnly = true; EntryBBED.EncounteredNonLocalSideEffect = false; + ExitED.IsReachingAlignedBarrierOnly = true; } else { EntryBBED.IsExecutedByInitialThreadOnly = false; EntryBBED.IsReachedFromAlignedBarrierOnly = false; EntryBBED.EncounteredNonLocalSideEffect = true; + ExitED.IsReachingAlignedBarrierOnly = false; } } + bool Changed = false; auto &FnED = BEDMap[nullptr]; - FnED.IsReachingAlignedBarrierOnly &= - EntryBBED.IsReachedFromAlignedBarrierOnly; + Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly, + FnED.IsReachedFromAlignedBarrierOnly & + EntryBBED.IsReachedFromAlignedBarrierOnly); + Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly, + FnED.IsReachingAlignedBarrierOnly & + ExitED.IsReachingAlignedBarrierOnly); + Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly, + EntryBBED.IsExecutedByInitialThreadOnly); + return Changed; } ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { @@ -2860,36 +2917,28 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { // Helper to deal with an aligned barrier encountered during the forward // traversal. \p CB is the aligned barrier, \p ED is the execution domain when // it was encountered. - auto HandleAlignedBarrier = [&](CallBase *CB, ExecutionDomainTy &ED) { - if (CB) - Changed |= AlignedBarriers.insert(CB); + auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) { + Changed |= AlignedBarriers.insert(&CB); // First, update the barrier ED kept in the separate CEDMap. 
- auto &CallED = CEDMap[CB]; - mergeInPredecessor(A, CallED, ED); + auto &CallInED = CEDMap[{&CB, PRE}]; + Changed |= mergeInPredecessor(A, CallInED, ED); + CallInED.IsReachingAlignedBarrierOnly = true; // Next adjust the ED we use for the traversal. ED.EncounteredNonLocalSideEffect = false; ED.IsReachedFromAlignedBarrierOnly = true; // Aligned barrier collection has to come last. ED.clearAssumeInstAndAlignedBarriers(); - if (CB) - ED.addAlignedBarrier(A, *CB); + ED.addAlignedBarrier(A, CB); + auto &CallOutED = CEDMap[{&CB, POST}]; + Changed |= mergeInPredecessor(A, CallOutED, ED); }; - auto &LivenessAA = + auto *LivenessAA = A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); - // Set \p R to \V and report true if that changed \p R. - auto SetAndRecord = [&](bool &R, bool V) { - bool Eq = (R == V); - R = V; - return !Eq; - }; - - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - Function *F = getAnchorScope(); BasicBlock &EntryBB = F->getEntryBlock(); - bool IsKernel = OMPInfoCache.Kernels.count(F); + bool IsKernel = omp::isKernel(*F); SmallVector<Instruction *> SyncInstWorklist; for (auto &RIt : *RPOT) { @@ -2899,18 +2948,19 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { // TODO: We use local reasoning since we don't have a divergence analysis // running as well. We could basically allow uniform branches here. bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel; + bool IsExplicitlyAligned = IsEntryBB && IsKernel; ExecutionDomainTy ED; // Propagate "incoming edges" into information about this block. if (IsEntryBB) { - handleEntryBB(A, ED); + Changed |= handleCallees(A, ED); } else { // For live non-entry blocks we only propagate // information via live edges. - if (LivenessAA.isAssumedDead(&BB)) + if (LivenessAA && LivenessAA->isAssumedDead(&BB)) continue; for (auto *PredBB : predecessors(&BB)) { - if (LivenessAA.isEdgeDead(PredBB, &BB)) + if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB)) continue; bool InitialEdgeOnly = isInitialThreadOnlyEdge( A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB); @@ -2922,7 +2972,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { // information to calls. for (Instruction &I : BB) { bool UsedAssumedInformation; - if (A.isAssumedDead(I, *this, &LivenessAA, UsedAssumedInformation, + if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation, /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL, /* CheckForDeadStore */ true)) continue; @@ -2939,6 +2989,33 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { continue; } + if (auto *FI = dyn_cast<FenceInst>(&I)) { + if (!ED.EncounteredNonLocalSideEffect) { + // An aligned fence without non-local side-effects is a no-op. + if (ED.IsReachedFromAlignedBarrierOnly) + continue; + // A non-aligned fence without non-local side-effects is a no-op + // if the ordering only publishes non-local side-effects (or less). 
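The switch that follows encodes the fence-ordering rule described in the comment just above: when the thread has not yet produced a non-local side effect, a release-or-weaker fence has nothing to publish and is skipped, while acquire-or-stronger fences are recorded in NonNoOpFences. Restated as a standalone predicate, as a sketch only (assumes llvm/Support/AtomicOrdering.h):

#include "llvm/Support/AtomicOrdering.h"

// True if a fence with ordering AO can be ignored given that the current
// thread has produced no non-local side effect the fence could publish.
static bool fenceIsNoOpWithoutPriorNonLocalEffects(llvm::AtomicOrdering AO) {
  switch (AO) {
  case llvm::AtomicOrdering::NotAtomic:
  case llvm::AtomicOrdering::Unordered:
  case llvm::AtomicOrdering::Monotonic:
  case llvm::AtomicOrdering::Release:
    return true; // nothing to publish and no acquire semantics requested
  case llvm::AtomicOrdering::Acquire:
  case llvm::AtomicOrdering::AcquireRelease:
  case llvm::AtomicOrdering::SequentiallyConsistent:
    return false; // may need to observe writes made by other threads
  }
  return false; // unreachable; keeps -Wreturn-type quiet
}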
+ switch (FI->getOrdering()) { + case AtomicOrdering::NotAtomic: + continue; + case AtomicOrdering::Unordered: + continue; + case AtomicOrdering::Monotonic: + continue; + case AtomicOrdering::Acquire: + break; + case AtomicOrdering::Release: + continue; + case AtomicOrdering::AcquireRelease: + break; + case AtomicOrdering::SequentiallyConsistent: + break; + }; + } + NonNoOpFences.insert(FI); + } + auto *CB = dyn_cast<CallBase>(&I); bool IsNoSync = AA::isNoSyncInst(A, I, *this); bool IsAlignedBarrier = @@ -2946,14 +3023,16 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock); AlignedBarrierLastInBlock &= IsNoSync; + IsExplicitlyAligned &= IsNoSync; // Next we check for calls. Aligned barriers are handled // explicitly, everything else is kept for the backward traversal and will // also affect our state. if (CB) { if (IsAlignedBarrier) { - HandleAlignedBarrier(CB, ED); + HandleAlignedBarrier(*CB, ED); AlignedBarrierLastInBlock = true; + IsExplicitlyAligned = true; continue; } @@ -2971,20 +3050,20 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { // Record how we entered the call, then accumulate the effect of the // call in ED for potential use by the callee. - auto &CallED = CEDMap[CB]; - mergeInPredecessor(A, CallED, ED); + auto &CallInED = CEDMap[{CB, PRE}]; + Changed |= mergeInPredecessor(A, CallInED, ED); // If we have a sync-definition we can check if it starts/ends in an // aligned barrier. If we are unsure we assume any sync breaks // alignment. Function *Callee = CB->getCalledFunction(); if (!IsNoSync && Callee && !Callee->isDeclaration()) { - const auto &EDAA = A.getAAFor<AAExecutionDomain>( + const auto *EDAA = A.getAAFor<AAExecutionDomain>( *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL); - if (EDAA.getState().isValidState()) { - const auto &CalleeED = EDAA.getFunctionExecutionDomain(); + if (EDAA && EDAA->getState().isValidState()) { + const auto &CalleeED = EDAA->getFunctionExecutionDomain(); ED.IsReachedFromAlignedBarrierOnly = - CalleeED.IsReachedFromAlignedBarrierOnly; + CalleeED.IsReachedFromAlignedBarrierOnly; AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly; if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly) ED.EncounteredNonLocalSideEffect |= @@ -2992,19 +3071,27 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { else ED.EncounteredNonLocalSideEffect = CalleeED.EncounteredNonLocalSideEffect; - if (!CalleeED.IsReachingAlignedBarrierOnly) + if (!CalleeED.IsReachingAlignedBarrierOnly) { + Changed |= + setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false); SyncInstWorklist.push_back(&I); + } if (CalleeED.IsReachedFromAlignedBarrierOnly) mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED); + auto &CallOutED = CEDMap[{CB, POST}]; + Changed |= mergeInPredecessor(A, CallOutED, ED); continue; } } - ED.IsReachedFromAlignedBarrierOnly = - IsNoSync && ED.IsReachedFromAlignedBarrierOnly; + if (!IsNoSync) { + ED.IsReachedFromAlignedBarrierOnly = false; + Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false); + SyncInstWorklist.push_back(&I); + } AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly; ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory(); - if (!IsNoSync) - SyncInstWorklist.push_back(&I); + auto &CallOutED = CEDMap[{CB, POST}]; + Changed |= mergeInPredecessor(A, CallOutED, ED); } if (!I.mayHaveSideEffects() && !I.mayReadFromMemory()) @@ -3013,7 +3100,7 @@ ChangeStatus 
AAExecutionDomainFunction::updateImpl(Attributor &A) { // If we have a callee we try to use fine-grained information to // determine local side-effects. if (CB) { - const auto &MemAA = A.getAAFor<AAMemoryLocation>( + const auto *MemAA = A.getAAFor<AAMemoryLocation>( *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL); auto AccessPred = [&](const Instruction *I, const Value *Ptr, @@ -3021,13 +3108,14 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { AAMemoryLocation::MemoryLocationsKind) { return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I); }; - if (MemAA.getState().isValidState() && - MemAA.checkForAllAccessesToMemoryKind( + if (MemAA && MemAA->getState().isValidState() && + MemAA->checkForAllAccessesToMemoryKind( AccessPred, AAMemoryLocation::ALL_LOCATIONS)) continue; } - if (!I.mayHaveSideEffects() && OMPInfoCache.isOnlyUsedByAssume(I)) + auto &InfoCache = A.getInfoCache(); + if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I)) continue; if (auto *LI = dyn_cast<LoadInst>(&I)) @@ -3039,18 +3127,28 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { ED.EncounteredNonLocalSideEffect = true; } + bool IsEndAndNotReachingAlignedBarriersOnly = false; if (!isa<UnreachableInst>(BB.getTerminator()) && !BB.getTerminator()->getNumSuccessors()) { - auto &FnED = BEDMap[nullptr]; - mergeInPredecessor(A, FnED, ED); + Changed |= mergeInPredecessor(A, InterProceduralED, ED); - if (IsKernel) - HandleAlignedBarrier(nullptr, ED); + auto &FnED = BEDMap[nullptr]; + if (IsKernel && !IsExplicitlyAligned) + FnED.IsReachingAlignedBarrierOnly = false; + Changed |= mergeInPredecessor(A, FnED, ED); + + if (!FnED.IsReachingAlignedBarrierOnly) { + IsEndAndNotReachingAlignedBarriersOnly = true; + SyncInstWorklist.push_back(BB.getTerminator()); + auto &BBED = BEDMap[&BB]; + Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false); + } } ExecutionDomainTy &StoredED = BEDMap[&BB]; - ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly; + ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly & + !IsEndAndNotReachingAlignedBarriersOnly; // Check if we computed anything different as part of the forward // traversal. 
We do not take assumptions and aligned barriers into account @@ -3074,36 +3172,38 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { while (!SyncInstWorklist.empty()) { Instruction *SyncInst = SyncInstWorklist.pop_back_val(); Instruction *CurInst = SyncInst; - bool HitAlignedBarrier = false; + bool HitAlignedBarrierOrKnownEnd = false; while ((CurInst = CurInst->getPrevNode())) { auto *CB = dyn_cast<CallBase>(CurInst); if (!CB) continue; - auto &CallED = CEDMap[CB]; - if (SetAndRecord(CallED.IsReachingAlignedBarrierOnly, false)) - Changed = true; - HitAlignedBarrier = AlignedBarriers.count(CB); - if (HitAlignedBarrier) + auto &CallOutED = CEDMap[{CB, POST}]; + Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false); + auto &CallInED = CEDMap[{CB, PRE}]; + HitAlignedBarrierOrKnownEnd = + AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly; + if (HitAlignedBarrierOrKnownEnd) break; + Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false); } - if (HitAlignedBarrier) + if (HitAlignedBarrierOrKnownEnd) continue; BasicBlock *SyncBB = SyncInst->getParent(); for (auto *PredBB : predecessors(SyncBB)) { - if (LivenessAA.isEdgeDead(PredBB, SyncBB)) + if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB)) continue; if (!Visited.insert(PredBB)) continue; - SyncInstWorklist.push_back(PredBB->getTerminator()); auto &PredED = BEDMap[PredBB]; - if (SetAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) + if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) { Changed = true; + SyncInstWorklist.push_back(PredBB->getTerminator()); + } } if (SyncBB != &EntryBB) continue; - auto &FnED = BEDMap[nullptr]; - if (SetAndRecord(FnED.IsReachingAlignedBarrierOnly, false)) - Changed = true; + Changed |= + setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false); } return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; @@ -3146,7 +3246,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A) : AAHeapToShared(IRP, A) {} - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *) const override { return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) + " malloc calls eligible."; } @@ -3261,7 +3361,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue()); auto *SharedMem = new GlobalVariable( *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage, - UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr, + PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr, GlobalValue::NotThreadLocal, static_cast<unsigned>(AddressSpace::Shared)); auto *NewBuffer = @@ -3270,7 +3370,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { auto Remark = [&](OptimizationRemark OR) { return OR << "Replaced globalized variable with " << ore::NV("SharedMemory", AllocSize->getZExtValue()) - << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ") + << (AllocSize->isOne() ? 
" byte " : " bytes ") << "of shared memory."; }; A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark); @@ -3278,7 +3378,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { MaybeAlign Alignment = CB->getRetAlign(); assert(Alignment && "HeapToShared on allocation without alignment attribute"); - SharedMem->setAlignment(MaybeAlign(Alignment)); + SharedMem->setAlignment(*Alignment); A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer); A.deleteAfterManifest(*CB); @@ -3315,9 +3415,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared { MallocCalls.remove(CB); continue; } - const auto &ED = A.getAAFor<AAExecutionDomain>( + const auto *ED = A.getAAFor<AAExecutionDomain>( *this, IRPosition::function(*F), DepClassTy::REQUIRED); - if (!ED.isExecutedByInitialThreadOnly(*CB)) + if (!ED || !ED->isExecutedByInitialThreadOnly(*CB)) MallocCalls.remove(CB); } } @@ -3346,7 +3446,7 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { void trackStatistics() const override {} /// See AbstractAttribute::getAsStr() - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *) const override { if (!isValidState()) return "<invalid>"; return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD" @@ -3456,22 +3556,7 @@ struct AAKernelInfoFunction : AAKernelInfo { Attributor::SimplifictionCallbackTy StateMachineSimplifyCB = [&](const IRPosition &IRP, const AbstractAttribute *AA, bool &UsedAssumedInformation) -> std::optional<Value *> { - // IRP represents the "use generic state machine" argument of an - // __kmpc_target_init call. We will answer this one with the internal - // state. As long as we are not in an invalid state, we will create a - // custom state machine so the value should be a `i1 false`. If we are - // in an invalid state, we won't change the value that is in the IR. - if (!ReachedKnownParallelRegions.isValidState()) - return nullptr; - // If we have disabled state machine rewrites, don't make a custom one. - if (DisableOpenMPOptStateMachineRewrite) return nullptr; - if (AA) - A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); - UsedAssumedInformation = !isAtFixpoint(); - auto *FalseVal = - ConstantInt::getBool(IRP.getAnchorValue().getContext(), false); - return FalseVal; }; Attributor::SimplifictionCallbackTy ModeSimplifyCB = @@ -3622,10 +3707,11 @@ struct AAKernelInfoFunction : AAKernelInfo { Function *Kernel = getAnchorScope(); Module &M = *Kernel->getParent(); Type *Int8Ty = Type::getInt8Ty(M.getContext()); - new GlobalVariable(M, Int8Ty, /* isConstant */ true, - GlobalValue::WeakAnyLinkage, - ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0), - Kernel->getName() + "_nested_parallelism"); + auto *GV = new GlobalVariable( + M, Int8Ty, /* isConstant */ true, GlobalValue::WeakAnyLinkage, + ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0), + Kernel->getName() + "_nested_parallelism"); + GV->setVisibility(GlobalValue::HiddenVisibility); // If we can we change the execution mode to SPMD-mode otherwise we build a // custom state machine. @@ -3914,6 +4000,12 @@ struct AAKernelInfoFunction : AAKernelInfo { bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + // We cannot change to SPMD mode if the runtime functions aren't availible. 
+ if (!OMPInfoCache.runtimeFnsAvailable( + {OMPRTL___kmpc_get_hardware_thread_id_in_block, + OMPRTL___kmpc_barrier_simple_spmd})) + return false; + if (!SPMDCompatibilityTracker.isAssumed()) { for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) { if (!NonCompatibleI) @@ -3951,7 +4043,7 @@ struct AAKernelInfoFunction : AAKernelInfo { auto *CB = cast<CallBase>(Kernel->user_back()); Kernel = CB->getCaller(); } - assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!"); + assert(omp::isKernel(*Kernel) && "Expected kernel function!"); // Check if the kernel is already in SPMD mode, if so, return success. GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( @@ -4021,6 +4113,13 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!ReachedKnownParallelRegions.isValidState()) return ChangeStatus::UNCHANGED; + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + if (!OMPInfoCache.runtimeFnsAvailable( + {OMPRTL___kmpc_get_hardware_num_threads_in_block, + OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic, + OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel})) + return ChangeStatus::UNCHANGED; + const int InitModeArgNo = 1; const int InitUseStateMachineArgNo = 2; @@ -4167,7 +4266,6 @@ struct AAKernelInfoFunction : AAKernelInfo { BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB); Module &M = *Kernel->getParent(); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); FunctionCallee BlockHwSizeFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_get_hardware_num_threads_in_block); @@ -4220,10 +4318,7 @@ struct AAKernelInfoFunction : AAKernelInfo { if (WorkFnAI->getType()->getPointerAddressSpace() != (unsigned int)AddressSpace::Generic) { WorkFnAI = new AddrSpaceCastInst( - WorkFnAI, - PointerType::getWithSamePointeeType( - cast<PointerType>(WorkFnAI->getType()), - (unsigned int)AddressSpace::Generic), + WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic), WorkFnAI->getName() + ".generic", StateMachineBeginBB); WorkFnAI->setDebugLoc(DLoc); } @@ -4345,19 +4440,20 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!I.mayWriteToMemory()) return true; if (auto *SI = dyn_cast<StoreInst>(&I)) { - const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>( + const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>( *this, IRPosition::value(*SI->getPointerOperand()), DepClassTy::OPTIONAL); - auto &HS = A.getAAFor<AAHeapToStack>( + auto *HS = A.getAAFor<AAHeapToStack>( *this, IRPosition::function(*I.getFunction()), DepClassTy::OPTIONAL); - if (UnderlyingObjsAA.forallUnderlyingObjects([&](Value &Obj) { + if (UnderlyingObjsAA && + UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) { if (AA::isAssumedThreadLocalObject(A, Obj, *this)) return true; // Check for AAHeapToStack moved objects which must not be // guarded. auto *CB = dyn_cast<CallBase>(&Obj); - return CB && HS.isAssumedHeapToStack(*CB); + return CB && HS && HS->isAssumedHeapToStack(*CB); })) return true; } @@ -4392,14 +4488,14 @@ struct AAKernelInfoFunction : AAKernelInfo { // we cannot fix the internal spmd-zation state either. 
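Stepping back to the <kernel>_nested_parallelism change earlier in this hunk: the global is now created through a named local so that its visibility can be set. A sketch of the same construction; the stated rationale is an inference, not spelled out in the patch:

#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
using namespace llvm;

// Emit the one-byte, weak, hidden per-kernel nested-parallelism flag.
static GlobalVariable *emitNestedParallelismFlag(Module &M, Function &Kernel,
                                                 bool NestedParallelism) {
  Type *Int8Ty = Type::getInt8Ty(M.getContext());
  auto *GV = new GlobalVariable(
      M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
      ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0),
      Kernel.getName() + "_nested_parallelism");
  // Hidden visibility keeps the flag out of the device image's exported
  // symbol table, so it cannot be preempted or referenced from outside the
  // image (inference; the patch itself only sets the visibility).
  GV->setVisibility(GlobalValue::HiddenVisibility);
  return GV;
}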
int SPMD = 0, Generic = 0; for (auto *Kernel : ReachingKernelEntries) { - auto &CBAA = A.getAAFor<AAKernelInfo>( + auto *CBAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL); - if (CBAA.SPMDCompatibilityTracker.isValidState() && - CBAA.SPMDCompatibilityTracker.isAssumed()) + if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() && + CBAA->SPMDCompatibilityTracker.isAssumed()) ++SPMD; else ++Generic; - if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint()) + if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint()) UsedAssumedInformationFromReachingKernels = true; } if (SPMD != 0 && Generic != 0) @@ -4413,14 +4509,16 @@ struct AAKernelInfoFunction : AAKernelInfo { bool AllSPMDStatesWereFixed = true; auto CheckCallInst = [&](Instruction &I) { auto &CB = cast<CallBase>(I); - auto &CBAA = A.getAAFor<AAKernelInfo>( + auto *CBAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); - getState() ^= CBAA.getState(); - AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint(); + if (!CBAA) + return false; + getState() ^= CBAA->getState(); + AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint(); AllParallelRegionStatesWereFixed &= - CBAA.ReachedKnownParallelRegions.isAtFixpoint(); + CBAA->ReachedKnownParallelRegions.isAtFixpoint(); AllParallelRegionStatesWereFixed &= - CBAA.ReachedUnknownParallelRegions.isAtFixpoint(); + CBAA->ReachedUnknownParallelRegions.isAtFixpoint(); return true; }; @@ -4460,10 +4558,10 @@ private: assert(Caller && "Caller is nullptr"); - auto &CAA = A.getOrCreateAAFor<AAKernelInfo>( + auto *CAA = A.getOrCreateAAFor<AAKernelInfo>( IRPosition::function(*Caller), this, DepClassTy::REQUIRED); - if (CAA.ReachingKernelEntries.isValidState()) { - ReachingKernelEntries ^= CAA.ReachingKernelEntries; + if (CAA && CAA->ReachingKernelEntries.isValidState()) { + ReachingKernelEntries ^= CAA->ReachingKernelEntries; return true; } @@ -4491,9 +4589,9 @@ private: assert(Caller && "Caller is nullptr"); - auto &CAA = + auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller)); - if (CAA.ParallelLevels.isValidState()) { + if (CAA && CAA->ParallelLevels.isValidState()) { // Any function that is called by `__kmpc_parallel_51` will not be // folded as the parallel level in the function is updated. In order to // get it right, all the analysis would depend on the implentation. That @@ -4504,7 +4602,7 @@ private: return true; } - ParallelLevels ^= CAA.ParallelLevels; + ParallelLevels ^= CAA->ParallelLevels; return true; } @@ -4538,11 +4636,11 @@ struct AAKernelInfoCallSite : AAKernelInfo { CallBase &CB = cast<CallBase>(getAssociatedValue()); Function *Callee = getAssociatedFunction(); - auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>( + auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>( *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); // Check for SPMD-mode assumptions. - if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) { + if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) { SPMDCompatibilityTracker.indicateOptimisticFixpoint(); indicateOptimisticFixpoint(); } @@ -4567,8 +4665,9 @@ struct AAKernelInfoCallSite : AAKernelInfo { // Unknown callees might contain parallel regions, except if they have // an appropriate assumption attached. 
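A change that recurs throughout these hunks: Attributor::getAAFor (and getOrCreateAAFor) now hands back a pointer that may be null instead of a reference, so every query is null-checked before use and a null result is treated as "no usable information" (the guarded assumption check just below is one instance). The general shape, sketched as it would appear inside one of the surrounding member functions; the local names are illustrative:

// Query an abstract attribute; nullptr means no usable information, which the
// caller must fold into the pessimistic outcome.
const auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
    *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
const bool KnownSPMDAmenable =
    AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable");
const bool KnownNoParallelism =
    AssumptionAA && (AssumptionAA->hasAssumption("omp_no_openmp") ||
                     AssumptionAA->hasAssumption("omp_no_parallelism"));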
- if (!(AssumptionAA.hasAssumption("omp_no_openmp") || - AssumptionAA.hasAssumption("omp_no_parallelism"))) + if (!AssumptionAA || + !(AssumptionAA->hasAssumption("omp_no_openmp") || + AssumptionAA->hasAssumption("omp_no_parallelism"))) ReachedUnknownParallelRegions.insert(&CB); // If SPMDCompatibilityTracker is not fixed, we need to give up on the @@ -4643,11 +4742,11 @@ struct AAKernelInfoCallSite : AAKernelInfo { CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) { ReachedKnownParallelRegions.insert(ParallelRegion); /// Check nested parallelism - auto &FnAA = A.getAAFor<AAKernelInfo>( + auto *FnAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); - NestedParallelism |= !FnAA.getState().isValidState() || - !FnAA.ReachedKnownParallelRegions.empty() || - !FnAA.ReachedUnknownParallelRegions.empty(); + NestedParallelism |= !FnAA || !FnAA->getState().isValidState() || + !FnAA->ReachedKnownParallelRegions.empty() || + !FnAA->ReachedUnknownParallelRegions.empty(); break; } // The condition above should usually get the parallel region function @@ -4691,10 +4790,12 @@ struct AAKernelInfoCallSite : AAKernelInfo { // If F is not a runtime function, propagate the AAKernelInfo of the callee. if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { const IRPosition &FnPos = IRPosition::function(*F); - auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); - if (getState() == FnAA.getState()) + auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); + if (!FnAA) + return indicatePessimisticFixpoint(); + if (getState() == FnAA->getState()) return ChangeStatus::UNCHANGED; - getState() = FnAA.getState(); + getState() = FnAA->getState(); return ChangeStatus::CHANGED; } @@ -4707,9 +4808,9 @@ struct AAKernelInfoCallSite : AAKernelInfo { CallBase &CB = cast<CallBase>(getAssociatedValue()); - auto &HeapToStackAA = A.getAAFor<AAHeapToStack>( + auto *HeapToStackAA = A.getAAFor<AAHeapToStack>( *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); - auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>( + auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>( *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); RuntimeFunction RF = It->getSecond(); @@ -4718,13 +4819,15 @@ struct AAKernelInfoCallSite : AAKernelInfo { // If neither HeapToStack nor HeapToShared assume the call is removed, // assume SPMD incompatibility. 
case OMPRTL___kmpc_alloc_shared: - if (!HeapToStackAA.isAssumedHeapToStack(CB) && - !HeapToSharedAA.isAssumedHeapToShared(CB)) + if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) && + (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB))) SPMDCompatibilityTracker.insert(&CB); break; case OMPRTL___kmpc_free_shared: - if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) && - !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB)) + if ((!HeapToStackAA || + !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) && + (!HeapToSharedAA || + !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB))) SPMDCompatibilityTracker.insert(&CB); break; default: @@ -4770,7 +4873,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { : AAFoldRuntimeCall(IRP, A) {} /// See AbstractAttribute::getAsStr() - const std::string getAsStr() const override { + const std::string getAsStr(Attributor *) const override { if (!isValidState()) return "<invalid>"; @@ -4883,28 +4986,29 @@ private: unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; - auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( + auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); - if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) + if (!CallerKernelInfoAA || + !CallerKernelInfoAA->ReachingKernelEntries.isValidState()) return indicatePessimisticFixpoint(); - for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { - auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), + for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) { + auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), DepClassTy::REQUIRED); - if (!AA.isValidState()) { + if (!AA || !AA->isValidState()) { SimplifiedValue = nullptr; return indicatePessimisticFixpoint(); } - if (AA.SPMDCompatibilityTracker.isAssumed()) { - if (AA.SPMDCompatibilityTracker.isAtFixpoint()) + if (AA->SPMDCompatibilityTracker.isAssumed()) { + if (AA->SPMDCompatibilityTracker.isAtFixpoint()) ++KnownSPMDCount; else ++AssumedSPMDCount; } else { - if (AA.SPMDCompatibilityTracker.isAtFixpoint()) + if (AA->SPMDCompatibilityTracker.isAtFixpoint()) ++KnownNonSPMDCount; else ++AssumedNonSPMDCount; @@ -4943,16 +5047,17 @@ private: ChangeStatus foldParallelLevel(Attributor &A) { std::optional<Value *> SimplifiedValueBefore = SimplifiedValue; - auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( + auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); - if (!CallerKernelInfoAA.ParallelLevels.isValidState()) + if (!CallerKernelInfoAA || + !CallerKernelInfoAA->ParallelLevels.isValidState()) return indicatePessimisticFixpoint(); - if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) + if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState()) return indicatePessimisticFixpoint(); - if (CallerKernelInfoAA.ReachingKernelEntries.empty()) { + if (CallerKernelInfoAA->ReachingKernelEntries.empty()) { assert(!SimplifiedValue && "SimplifiedValue should keep none at this point"); return ChangeStatus::UNCHANGED; @@ -4960,19 +5065,19 @@ private: unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; - for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { - auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), + for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) { + 
auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), DepClassTy::REQUIRED); - if (!AA.SPMDCompatibilityTracker.isValidState()) + if (!AA || !AA->SPMDCompatibilityTracker.isValidState()) return indicatePessimisticFixpoint(); - if (AA.SPMDCompatibilityTracker.isAssumed()) { - if (AA.SPMDCompatibilityTracker.isAtFixpoint()) + if (AA->SPMDCompatibilityTracker.isAssumed()) { + if (AA->SPMDCompatibilityTracker.isAtFixpoint()) ++KnownSPMDCount; else ++AssumedSPMDCount; } else { - if (AA.SPMDCompatibilityTracker.isAtFixpoint()) + if (AA->SPMDCompatibilityTracker.isAtFixpoint()) ++KnownNonSPMDCount; else ++AssumedNonSPMDCount; @@ -5005,14 +5110,15 @@ private: int32_t CurrentAttrValue = -1; std::optional<Value *> SimplifiedValueBefore = SimplifiedValue; - auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( + auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); - if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) + if (!CallerKernelInfoAA || + !CallerKernelInfoAA->ReachingKernelEntries.isValidState()) return indicatePessimisticFixpoint(); // Iterate over the kernels that reach this function - for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { + for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) { int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1); if (NextAttrVal == -1 || @@ -5135,6 +5241,8 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) { A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F)); if (!DisableOpenMPOptDeglobalization) A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F)); + if (F.hasFnAttribute(Attribute::Convergent)) + A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F)); for (auto &I : instructions(F)) { if (auto *LI = dyn_cast<LoadInst>(&I)) { @@ -5147,6 +5255,10 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) { A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI)); continue; } + if (auto *FI = dyn_cast<FenceInst>(&I)) { + A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI)); + continue; + } if (auto *II = dyn_cast<IntrinsicInst>(&I)) { if (II->getIntrinsicID() == Intrinsic::assume) { A.getOrCreateAAFor<AAPotentialValues>( @@ -5304,6 +5416,8 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { }); }; + bool Changed = false; + // Create internal copies of each function if this is a kernel Module. This // allows iterprocedural passes to see every call edge. DenseMap<Function *, Function *> InternalizedMap; @@ -5319,7 +5433,8 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { } } - Attributor::internalizeFunctions(InternalizeFns, InternalizedMap); + Changed |= + Attributor::internalizeFunctions(InternalizeFns, InternalizedMap); } // Look at every function in the Module unless it was internalized. @@ -5332,7 +5447,7 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { } if (SCC.empty()) - return PreservedAnalyses::all(); + return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); AnalysisGetter AG(FAM); @@ -5343,7 +5458,9 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { BumpPtrAllocator Allocator; CallGraphUpdater CGUpdater; - OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels); + bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink || + LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink; + OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? SetFixpointIterations : 32; @@ -5356,11 +5473,14 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { AC.OREGetter = OREGetter; AC.PassName = DEBUG_TYPE; AC.InitializationCallback = OpenMPOpt::registerAAsForFunction; + AC.IPOAmendableCB = [](const Function &F) { + return F.hasFnAttribute("kernel"); + }; Attributor A(Functions, InfoCache, AC); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); - bool Changed = OMPOpt.run(true); + Changed |= OMPOpt.run(true); // Optionally inline device functions for potentially better performance. if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M)) @@ -5417,9 +5537,11 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, CallGraphUpdater CGUpdater; CGUpdater.initialize(CG, C, AM, UR); + bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink || + LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink; SetVector<Function *> Functions(SCC.begin(), SCC.end()); OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, - /*CGSCC*/ &Functions, Kernels); + /*CGSCC*/ &Functions, PostLink); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? SetFixpointIterations : 32; @@ -5447,6 +5569,8 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); } +bool llvm::omp::isKernel(Function &Fn) { return Fn.hasFnAttribute("kernel"); } + KernelSet llvm::omp::getDeviceKernels(Module &M) { // TODO: Create a more cross-platform way of determining device kernels. 
NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations"); @@ -5467,6 +5591,7 @@ KernelSet llvm::omp::getDeviceKernels(Module &M) { if (!KernelFn) continue; + assert(isKernel(*KernelFn) && "Inconsistent kernel function annotation"); ++NumOpenMPTargetRegionKernels; Kernels.insert(KernelFn); diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp index 310e4d4164a5..b88ba2dec24b 100644 --- a/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/IPO/PartialInlining.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -41,8 +42,6 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/User.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/BlockFrequency.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/Casting.h" @@ -342,52 +341,6 @@ private: OptimizationRemarkEmitter &ORE) const; }; -struct PartialInlinerLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - - PartialInlinerLegacyPass() : ModulePass(ID) { - initializePartialInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<ProfileSummaryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - AssumptionCacheTracker *ACT = &getAnalysis<AssumptionCacheTracker>(); - TargetTransformInfoWrapperPass *TTIWP = - &getAnalysis<TargetTransformInfoWrapperPass>(); - ProfileSummaryInfo &PSI = - getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - - auto GetAssumptionCache = [&ACT](Function &F) -> AssumptionCache & { - return ACT->getAssumptionCache(F); - }; - - auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * { - return ACT->lookupAssumptionCache(F); - }; - - auto GetTTI = [&TTIWP](Function &F) -> TargetTransformInfo & { - return TTIWP->getTTI(F); - }; - - auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - - return PartialInlinerImpl(GetAssumptionCache, LookupAssumptionCache, GetTTI, - GetTLI, PSI) - .run(M); - } -}; - } // end anonymous namespace std::unique_ptr<FunctionOutliningMultiRegionInfo> @@ -1027,7 +980,7 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( // Go through all Outline Candidate Regions and update all BasicBlock // information. 
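The PartialInlinerLegacyPass wrapper being deleted here (together with its INITIALIZE_PASS registration and createPartialInliningPass further down) existed only for the legacy pass manager; the new-pass-manager entry point PartialInlinerPass::run is kept. A sketch of scheduling the pass without the legacy wrapper, assuming the usual new-PM boilerplate (the helper name is illustrative):

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/IPO/PartialInlining.h"
using namespace llvm;

// Run just the partial inliner over a module via the new pass manager.
static void runPartialInliner(Module &M) {
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;

  PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  ModulePassManager MPM;
  MPM.addPass(PartialInlinerPass());
  MPM.run(M, MAM);
}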
- for (FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegionInfo : + for (const FunctionOutliningMultiRegionInfo::OutlineRegionInfo &RegionInfo : OI->ORI) { SmallVector<BasicBlock *, 8> Region; for (BasicBlock *BB : RegionInfo.Region) @@ -1226,14 +1179,14 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { ToExtract.push_back(ClonedOI->NonReturnBlock); OutlinedRegionCost += PartialInlinerImpl::computeBBInlineCost( ClonedOI->NonReturnBlock, ClonedFuncTTI); - for (BasicBlock &BB : *ClonedFunc) - if (!ToBeInlined(&BB) && &BB != ClonedOI->NonReturnBlock) { - ToExtract.push_back(&BB); + for (BasicBlock *BB : depth_first(&ClonedFunc->getEntryBlock())) + if (!ToBeInlined(BB) && BB != ClonedOI->NonReturnBlock) { + ToExtract.push_back(BB); // FIXME: the code extractor may hoist/sink more code // into the outlined function which may make the outlining // overhead (the difference of the outlined function cost // and OutliningRegionCost) look larger. - OutlinedRegionCost += computeBBInlineCost(&BB, ClonedFuncTTI); + OutlinedRegionCost += computeBBInlineCost(BB, ClonedFuncTTI); } // Extract the body of the if. @@ -1429,7 +1382,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { OR << ore::NV("Callee", Cloner.OrigFunc) << " partially inlined into " << ore::NV("Caller", CB->getCaller()); - InlineFunctionInfo IFI(nullptr, GetAssumptionCache, &PSI); + InlineFunctionInfo IFI(GetAssumptionCache, &PSI); // We can only forward varargs when we outlined a single region, else we // bail on vararg functions. if (!InlineFunction(*CB, IFI, /*MergeAttributes=*/false, nullptr, true, @@ -1497,21 +1450,6 @@ bool PartialInlinerImpl::run(Module &M) { return Changed; } -char PartialInlinerLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner", - "Partial Inliner", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner", - "Partial Inliner", false, false) - -ModulePass *llvm::createPartialInliningPass() { - return new PartialInlinerLegacyPass(); -} - PreservedAnalyses PartialInlinerPass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp deleted file mode 100644 index 6b91c8494f39..000000000000 --- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ /dev/null @@ -1,517 +0,0 @@ -//===- PassManagerBuilder.cpp - Build Standard Pass -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the PassManagerBuilder class, which is used to set up a -// "standard" optimization sequence suitable for languages like C and C++. 
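The whole of PassManagerBuilder.cpp, whose deleted body follows, was the legacy-pass-manager way of assembling the standard -O pipelines. With the analysis managers wired up as in the sketch above, the commonly used new-pass-manager counterpart of populateModulePassManager is a single PassBuilder call (OptimizationLevel comes from llvm/Passes/OptimizationLevel.h):

// New-PM equivalent of populateModulePassManager at -O2 (sketch).
ModulePassManager MPM = PB.buildPerModuleDefaultPipeline(OptimizationLevel::O2);
MPM.run(M, MAM);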
-// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/IPO/PassManagerBuilder.h" -#include "llvm-c/Transforms/PassManagerBuilder.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/ScopedNoAliasAA.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/TypeBasedAliasAnalysis.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/ManagedStatic.h" -#include "llvm/Target/CGPassBuilderOption.h" -#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" -#include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/IPO/Attributor.h" -#include "llvm/Transforms/IPO/ForceFunctionAttrs.h" -#include "llvm/Transforms/IPO/FunctionAttrs.h" -#include "llvm/Transforms/IPO/InferFunctionAttrs.h" -#include "llvm/Transforms/InstCombine/InstCombine.h" -#include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/GVN.h" -#include "llvm/Transforms/Scalar/LICM.h" -#include "llvm/Transforms/Scalar/LoopUnrollPass.h" -#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" -#include "llvm/Transforms/Utils.h" -#include "llvm/Transforms/Vectorize.h" - -using namespace llvm; - -PassManagerBuilder::PassManagerBuilder() { - OptLevel = 2; - SizeLevel = 0; - LibraryInfo = nullptr; - Inliner = nullptr; - DisableUnrollLoops = false; - SLPVectorize = false; - LoopVectorize = true; - LoopsInterleaved = true; - LicmMssaOptCap = SetLicmMssaOptCap; - LicmMssaNoAccForPromotionCap = SetLicmMssaNoAccForPromotionCap; - DisableGVNLoadPRE = false; - ForgetAllSCEVInLoopUnroll = ForgetSCEVInLoopUnroll; - VerifyInput = false; - VerifyOutput = false; - MergeFunctions = false; - DivergentTarget = false; - CallGraphProfile = true; -} - -PassManagerBuilder::~PassManagerBuilder() { - delete LibraryInfo; - delete Inliner; -} - -void PassManagerBuilder::addInitialAliasAnalysisPasses( - legacy::PassManagerBase &PM) const { - // Add TypeBasedAliasAnalysis before BasicAliasAnalysis so that - // BasicAliasAnalysis wins if they disagree. This is intended to help - // support "obvious" type-punning idioms. - PM.add(createTypeBasedAAWrapperPass()); - PM.add(createScopedNoAliasAAWrapperPass()); -} - -void PassManagerBuilder::populateFunctionPassManager( - legacy::FunctionPassManager &FPM) { - // Add LibraryInfo if we have some. - if (LibraryInfo) - FPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); - - if (OptLevel == 0) return; - - addInitialAliasAnalysisPasses(FPM); - - // Lower llvm.expect to metadata before attempting transforms. - // Compare/branch metadata may alter the behavior of passes like SimplifyCFG. - FPM.add(createLowerExpectIntrinsicPass()); - FPM.add(createCFGSimplificationPass()); - FPM.add(createSROAPass()); - FPM.add(createEarlyCSEPass()); -} - -void PassManagerBuilder::addFunctionSimplificationPasses( - legacy::PassManagerBase &MPM) { - // Start of function pass. - // Break up aggregate allocas, using SSAUpdater. - assert(OptLevel >= 1 && "Calling function optimizer with no optimization level!"); - MPM.add(createSROAPass()); - MPM.add(createEarlyCSEPass(true /* Enable mem-ssa. */)); // Catch trivial redundancies - - if (OptLevel > 1) { - // Speculative execution if the target has divergent branches; otherwise nop. - MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass()); - - MPM.add(createJumpThreadingPass()); // Thread jumps. 
- MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals - } - MPM.add( - createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp( - true))); // Merge & remove BBs - // Combine silly seq's - MPM.add(createInstructionCombiningPass()); - if (SizeLevel == 0) - MPM.add(createLibCallsShrinkWrapPass()); - - // TODO: Investigate the cost/benefit of tail call elimination on debugging. - if (OptLevel > 1) - MPM.add(createTailCallEliminationPass()); // Eliminate tail calls - MPM.add( - createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp( - true))); // Merge & remove BBs - MPM.add(createReassociatePass()); // Reassociate expressions - - // Begin the loop pass pipeline. - - // The simple loop unswitch pass relies on separate cleanup passes. Schedule - // them first so when we re-process a loop they run before other loop - // passes. - MPM.add(createLoopInstSimplifyPass()); - MPM.add(createLoopSimplifyCFGPass()); - - // Try to remove as much code from the loop header as possible, - // to reduce amount of IR that will have to be duplicated. However, - // do not perform speculative hoisting the first time as LICM - // will destroy metadata that may not need to be destroyed if run - // after loop rotation. - // TODO: Investigate promotion cap for O1. - MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - /*AllowSpeculation=*/false)); - // Rotate Loop - disable header duplication at -Oz - MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1, false)); - // TODO: Investigate promotion cap for O1. - MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - /*AllowSpeculation=*/true)); - MPM.add(createSimpleLoopUnswitchLegacyPass(OptLevel == 3)); - // FIXME: We break the loop pass pipeline here in order to do full - // simplifycfg. Eventually loop-simplifycfg should be enhanced to replace the - // need for this. - MPM.add(createCFGSimplificationPass( - SimplifyCFGOptions().convertSwitchRangeToICmp(true))); - MPM.add(createInstructionCombiningPass()); - // We resume loop passes creating a second loop pipeline here. - MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. - MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars - MPM.add(createLoopDeletionPass()); // Delete dead loops - - // Unroll small loops and perform peeling. - MPM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, - ForgetAllSCEVInLoopUnroll)); - // This ends the loop pass pipelines. - - // Break up allocas that may now be splittable after loop unrolling. - MPM.add(createSROAPass()); - - if (OptLevel > 1) { - MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds - MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies - } - MPM.add(createSCCPPass()); // Constant prop with SCCP - - // Delete dead bit computations (instcombine runs after to fold away the dead - // computations, and then ADCE will run later to exploit any new DCE - // opportunities that creates). - MPM.add(createBitTrackingDCEPass()); // Delete dead bit computations - - // Run instcombine after redundancy elimination to exploit opportunities - // opened up by them. - MPM.add(createInstructionCombiningPass()); - if (OptLevel > 1) { - MPM.add(createJumpThreadingPass()); // Thread jumps - MPM.add(createCorrelatedValuePropagationPass()); - } - MPM.add(createAggressiveDCEPass()); // Delete dead instructions - - MPM.add(createMemCpyOptPass()); // Remove memcpy / form memset - // TODO: Investigate if this is too expensive at O1. 
- if (OptLevel > 1) { - MPM.add(createDeadStoreEliminationPass()); // Delete dead stores - MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - /*AllowSpeculation=*/true)); - } - - // Merge & remove BBs and sink & hoist common instructions. - MPM.add(createCFGSimplificationPass( - SimplifyCFGOptions().hoistCommonInsts(true).sinkCommonInsts(true))); - // Clean up after everything. - MPM.add(createInstructionCombiningPass()); -} - -/// FIXME: Should LTO cause any differences to this set of passes? -void PassManagerBuilder::addVectorPasses(legacy::PassManagerBase &PM, - bool IsFullLTO) { - PM.add(createLoopVectorizePass(!LoopsInterleaved, !LoopVectorize)); - - if (IsFullLTO) { - // The vectorizer may have significantly shortened a loop body; unroll - // again. Unroll small loops to hide loop backedge latency and saturate any - // parallel execution resources of an out-of-order processor. We also then - // need to clean up redundancies and loop invariant code. - // FIXME: It would be really good to use a loop-integrated instruction - // combiner for cleanup here so that the unrolling and LICM can be pipelined - // across the loop nests. - PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops, - ForgetAllSCEVInLoopUnroll)); - PM.add(createWarnMissedTransformationsPass()); - } - - if (!IsFullLTO) { - // Eliminate loads by forwarding stores from the previous iteration to loads - // of the current iteration. - PM.add(createLoopLoadEliminationPass()); - } - // Cleanup after the loop optimization passes. - PM.add(createInstructionCombiningPass()); - - // Now that we've formed fast to execute loop structures, we do further - // optimizations. These are run afterward as they might block doing complex - // analyses and transforms such as what are needed for loop vectorization. - - // Cleanup after loop vectorization, etc. Simplification passes like CVP and - // GVN, loop transforms, and others have already run, so it's now better to - // convert to more optimized IR using more aggressive simplify CFG options. - // The extra sinking transform can create larger basic blocks, so do this - // before SLP vectorization. - PM.add(createCFGSimplificationPass(SimplifyCFGOptions() - .forwardSwitchCondToPhi(true) - .convertSwitchRangeToICmp(true) - .convertSwitchToLookupTable(true) - .needCanonicalLoops(false) - .hoistCommonInsts(true) - .sinkCommonInsts(true))); - - if (IsFullLTO) { - PM.add(createSCCPPass()); // Propagate exposed constants - PM.add(createInstructionCombiningPass()); // Clean up again - PM.add(createBitTrackingDCEPass()); - } - - // Optimize parallel scalar instruction chains into SIMD instructions. - if (SLPVectorize) { - PM.add(createSLPVectorizerPass()); - } - - // Enhance/cleanup vector code. - PM.add(createVectorCombinePass()); - - if (!IsFullLTO) { - PM.add(createInstructionCombiningPass()); - - // Unroll small loops - PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops, - ForgetAllSCEVInLoopUnroll)); - - if (!DisableUnrollLoops) { - // LoopUnroll may generate some redundency to cleanup. - PM.add(createInstructionCombiningPass()); - - // Runtime unrolling will introduce runtime check in loop prologue. If the - // unrolled loop is a inner loop, then the prologue will be inside the - // outer loop. LICM pass can help to promote the runtime check out if the - // checked value is loop invariant. 
- PM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - /*AllowSpeculation=*/true)); - } - - PM.add(createWarnMissedTransformationsPass()); - } - - // After vectorization and unrolling, assume intrinsics may tell us more - // about pointer alignments. - PM.add(createAlignmentFromAssumptionsPass()); - - if (IsFullLTO) - PM.add(createInstructionCombiningPass()); -} - -void PassManagerBuilder::populateModulePassManager( - legacy::PassManagerBase &MPM) { - MPM.add(createAnnotation2MetadataLegacyPass()); - - // Allow forcing function attributes as a debugging and tuning aid. - MPM.add(createForceFunctionAttrsLegacyPass()); - - // If all optimizations are disabled, just run the always-inline pass and, - // if enabled, the function merging pass. - if (OptLevel == 0) { - if (Inliner) { - MPM.add(Inliner); - Inliner = nullptr; - } - - // FIXME: The BarrierNoopPass is a HACK! The inliner pass above implicitly - // creates a CGSCC pass manager, but we don't want to add extensions into - // that pass manager. To prevent this we insert a no-op module pass to reset - // the pass manager to get the same behavior as EP_OptimizerLast in non-O0 - // builds. The function merging pass is - if (MergeFunctions) - MPM.add(createMergeFunctionsPass()); - return; - } - - // Add LibraryInfo if we have some. - if (LibraryInfo) - MPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); - - addInitialAliasAnalysisPasses(MPM); - - // Infer attributes about declarations if possible. - MPM.add(createInferFunctionAttrsLegacyPass()); - - if (OptLevel > 2) - MPM.add(createCallSiteSplittingPass()); - - MPM.add(createIPSCCPPass()); // IP SCCP - MPM.add(createCalledValuePropagationPass()); - - MPM.add(createGlobalOptimizerPass()); // Optimize out global vars - // Promote any localized global vars. - MPM.add(createPromoteMemoryToRegisterPass()); - - MPM.add(createDeadArgEliminationPass()); // Dead argument elimination - - MPM.add(createInstructionCombiningPass()); // Clean up after IPCP & DAE - MPM.add( - createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp( - true))); // Clean up after IPCP & DAE - - // We add a module alias analysis pass here. In part due to bugs in the - // analysis infrastructure this "works" in that the analysis stays alive - // for the entire SCC pass run below. - MPM.add(createGlobalsAAWrapperPass()); - - // Start of CallGraph SCC passes. - bool RunInliner = false; - if (Inliner) { - MPM.add(Inliner); - Inliner = nullptr; - RunInliner = true; - } - - MPM.add(createPostOrderFunctionAttrsLegacyPass()); - - addFunctionSimplificationPasses(MPM); - - // FIXME: This is a HACK! The inliner pass above implicitly creates a CGSCC - // pass manager that we are specifically trying to avoid. To prevent this - // we must insert a no-op module pass to reset the pass manager. - MPM.add(createBarrierNoopPass()); - - if (OptLevel > 1) - // Remove avail extern fns and globals definitions if we aren't - // compiling an object file for later LTO. For LTO we want to preserve - // these so they are eligible for inlining at link-time. Note if they - // are unreferenced they will be removed by GlobalDCE later, so - // this only impacts referenced available externally globals. - // Eventually they will be suppressed during codegen, but eliminating - // here enables more opportunity for GlobalDCE as it may make - // globals referenced by available external functions dead - // and saves running remaining passes on the eliminated functions. 
- MPM.add(createEliminateAvailableExternallyPass()); - - MPM.add(createReversePostOrderFunctionAttrsPass()); - - // The inliner performs some kind of dead code elimination as it goes, - // but there are cases that are not really caught by it. We might - // at some point consider teaching the inliner about them, but it - // is OK for now to run GlobalOpt + GlobalDCE in tandem as their - // benefits generally outweight the cost, making the whole pipeline - // faster. - if (RunInliner) { - MPM.add(createGlobalOptimizerPass()); - MPM.add(createGlobalDCEPass()); - } - - // We add a fresh GlobalsModRef run at this point. This is particularly - // useful as the above will have inlined, DCE'ed, and function-attr - // propagated everything. We should at this point have a reasonably minimal - // and richly annotated call graph. By computing aliasing and mod/ref - // information for all local globals here, the late loop passes and notably - // the vectorizer will be able to use them to help recognize vectorizable - // memory operations. - // - // Note that this relies on a bug in the pass manager which preserves - // a module analysis into a function pass pipeline (and throughout it) so - // long as the first function pass doesn't invalidate the module analysis. - // Thus both Float2Int and LoopRotate have to preserve AliasAnalysis for - // this to work. Fortunately, it is trivial to preserve AliasAnalysis - // (doing nothing preserves it as it is required to be conservatively - // correct in the face of IR changes). - MPM.add(createGlobalsAAWrapperPass()); - - MPM.add(createFloat2IntPass()); - MPM.add(createLowerConstantIntrinsicsPass()); - - // Re-rotate loops in all our loop nests. These may have fallout out of - // rotated form due to GVN or other transformations, and the vectorizer relies - // on the rotated form. Disable header duplication at -Oz. - MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1, false)); - - // Distribute loops to allow partial vectorization. I.e. isolate dependences - // into separate loop that would otherwise inhibit vectorization. This is - // currently only performed for loops marked with the metadata - // llvm.loop.distribute=true or when -enable-loop-distribute is specified. - MPM.add(createLoopDistributePass()); - - addVectorPasses(MPM, /* IsFullLTO */ false); - - // FIXME: We shouldn't bother with this anymore. - MPM.add(createStripDeadPrototypesPass()); // Get rid of dead prototypes - - // GlobalOpt already deletes dead functions and globals, at -O2 try a - // late pass of GlobalDCE. It is capable of deleting dead cycles. - if (OptLevel > 1) { - MPM.add(createGlobalDCEPass()); // Remove dead fns and globals. - MPM.add(createConstantMergePass()); // Merge dup global constants - } - - if (MergeFunctions) - MPM.add(createMergeFunctionsPass()); - - // LoopSink pass sinks instructions hoisted by LICM, which serves as a - // canonicalization pass that enables other optimizations. As a result, - // LoopSink pass needs to be a very late IR pass to avoid undoing LICM - // result too early. - MPM.add(createLoopSinkPass()); - // Get rid of LCSSA nodes. - MPM.add(createInstSimplifyLegacyPass()); - - // This hoists/decomposes div/rem ops. It should run after other sink/hoist - // passes to avoid re-sinking, but before SimplifyCFG because it can allow - // flattening of blocks. - MPM.add(createDivRemPairsPass()); - - // LoopSink (and other loop passes since the last simplifyCFG) might have - // resulted in single-entry-single-exit or empty blocks. Clean up the CFG. 
- MPM.add(createCFGSimplificationPass( - SimplifyCFGOptions().convertSwitchRangeToICmp(true))); -} - -LLVMPassManagerBuilderRef LLVMPassManagerBuilderCreate() { - PassManagerBuilder *PMB = new PassManagerBuilder(); - return wrap(PMB); -} - -void LLVMPassManagerBuilderDispose(LLVMPassManagerBuilderRef PMB) { - PassManagerBuilder *Builder = unwrap(PMB); - delete Builder; -} - -void -LLVMPassManagerBuilderSetOptLevel(LLVMPassManagerBuilderRef PMB, - unsigned OptLevel) { - PassManagerBuilder *Builder = unwrap(PMB); - Builder->OptLevel = OptLevel; -} - -void -LLVMPassManagerBuilderSetSizeLevel(LLVMPassManagerBuilderRef PMB, - unsigned SizeLevel) { - PassManagerBuilder *Builder = unwrap(PMB); - Builder->SizeLevel = SizeLevel; -} - -void -LLVMPassManagerBuilderSetDisableUnitAtATime(LLVMPassManagerBuilderRef PMB, - LLVMBool Value) { - // NOTE: The DisableUnitAtATime switch has been removed. -} - -void -LLVMPassManagerBuilderSetDisableUnrollLoops(LLVMPassManagerBuilderRef PMB, - LLVMBool Value) { - PassManagerBuilder *Builder = unwrap(PMB); - Builder->DisableUnrollLoops = Value; -} - -void -LLVMPassManagerBuilderSetDisableSimplifyLibCalls(LLVMPassManagerBuilderRef PMB, - LLVMBool Value) { - // NOTE: The simplify-libcalls pass has been removed. -} - -void -LLVMPassManagerBuilderUseInlinerWithThreshold(LLVMPassManagerBuilderRef PMB, - unsigned Threshold) { - PassManagerBuilder *Builder = unwrap(PMB); - Builder->Inliner = createFunctionInliningPass(Threshold); -} - -void -LLVMPassManagerBuilderPopulateFunctionPassManager(LLVMPassManagerBuilderRef PMB, - LLVMPassManagerRef PM) { - PassManagerBuilder *Builder = unwrap(PMB); - legacy::FunctionPassManager *FPM = unwrap<legacy::FunctionPassManager>(PM); - Builder->populateFunctionPassManager(*FPM); -} - -void -LLVMPassManagerBuilderPopulateModulePassManager(LLVMPassManagerBuilderRef PMB, - LLVMPassManagerRef PM) { - PassManagerBuilder *Builder = unwrap(PMB); - legacy::PassManagerBase *MPM = unwrap(PM); - Builder->populateModulePassManager(*MPM); -} diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp index 5c1582ddfdae..e2e6364df906 100644 --- a/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/llvm/lib/Transforms/IPO/SCCP.cpp @@ -13,14 +13,14 @@ #include "llvm/Transforms/IPO/SCCP.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/InitializePasses.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/CommandLine.h" @@ -42,8 +42,8 @@ STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable"); STATISTIC(NumInstReplaced, "Number of instructions replaced with (simpler) instruction"); -static cl::opt<unsigned> FuncSpecializationMaxIters( - "func-specialization-max-iters", cl::init(1), cl::Hidden, cl::desc( +static cl::opt<unsigned> FuncSpecMaxIters( + "funcspec-max-iters", cl::init(1), cl::Hidden, cl::desc( "The maximum number of iterations function specialization is run")); static void findReturnsToZap(Function &F, @@ -111,10 +111,12 @@ static bool runIPSCCP( std::function<const TargetLibraryInfo &(Function &)> GetTLI, std::function<TargetTransformInfo &(Function &)> GetTTI, 
std::function<AssumptionCache &(Function &)> GetAC, - function_ref<AnalysisResultsForFn(Function &)> getAnalysis, + std::function<DominatorTree &(Function &)> GetDT, + std::function<BlockFrequencyInfo &(Function &)> GetBFI, bool IsFuncSpecEnabled) { SCCPSolver Solver(DL, GetTLI, M.getContext()); - FunctionSpecializer Specializer(Solver, M, FAM, GetTLI, GetTTI, GetAC); + FunctionSpecializer Specializer(Solver, M, FAM, GetBFI, GetTLI, GetTTI, + GetAC); // Loop over all functions, marking arguments to those with their addresses // taken or that are external as overdefined. @@ -122,7 +124,9 @@ static bool runIPSCCP( if (F.isDeclaration()) continue; - Solver.addAnalysis(F, getAnalysis(F)); + DominatorTree &DT = GetDT(F); + AssumptionCache &AC = GetAC(F); + Solver.addPredicateInfo(F, DT, AC); // Determine if we can track the function's return values. If so, add the // function to the solver's set of return-tracked functions. @@ -158,7 +162,7 @@ static bool runIPSCCP( if (IsFuncSpecEnabled) { unsigned Iters = 0; - while (Iters++ < FuncSpecializationMaxIters && Specializer.run()); + while (Iters++ < FuncSpecMaxIters && Specializer.run()); } // Iterate over all of the instructions in the module, replacing them with @@ -187,8 +191,8 @@ static bool runIPSCCP( if (ME == MemoryEffects::unknown()) return AL; - ME |= MemoryEffects(MemoryEffects::Other, - ME.getModRef(MemoryEffects::ArgMem)); + ME |= MemoryEffects(IRMemLocation::Other, + ME.getModRef(IRMemLocation::ArgMem)); return AL.addFnAttribute( F.getContext(), Attribute::getWithMemoryEffects(F.getContext(), ME)); @@ -223,10 +227,9 @@ static bool runIPSCCP( BB, InsertedValues, NumInstRemoved, NumInstReplaced); } - DomTreeUpdater DTU = IsFuncSpecEnabled && Specializer.isClonedFunction(&F) - ? DomTreeUpdater(DomTreeUpdater::UpdateStrategy::Lazy) - : Solver.getDTU(F); - + DominatorTree *DT = FAM->getCachedResult<DominatorTreeAnalysis>(F); + PostDominatorTree *PDT = FAM->getCachedResult<PostDominatorTreeAnalysis>(F); + DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); // Change dead blocks to unreachable. We do it after replacing constants // in all executable blocks, because changeToUnreachable may remove PHI // nodes in executable blocks we found values for. The function's entry @@ -292,13 +295,6 @@ static bool runIPSCCP( if (!CB || CB->getCalledFunction() != F) continue; - // Limit to cases where the return value is guaranteed to be neither - // poison nor undef. Poison will be outside any range and currently - // values outside of the specified range cause immediate undefined - // behavior. - if (!isGuaranteedNotToBeUndefOrPoison(CB, nullptr, CB)) - continue; - // Do not touch existing metadata for now. // TODO: We should be able to take the intersection of the existing // metadata and the inferred range. @@ -338,9 +334,14 @@ static bool runIPSCCP( // Remove the returned attribute for zapped functions and the // corresponding call sites. 
+ // Also remove any attributes that convert an undef return value into + // immediate undefined behavior + AttributeMask UBImplyingAttributes = + AttributeFuncs::getUBImplyingAttributes(); for (Function *F : FuncZappedReturn) { for (Argument &A : F->args()) F->removeParamAttr(A.getArgNo(), Attribute::Returned); + F->removeRetAttrs(UBImplyingAttributes); for (Use &U : F->uses()) { CallBase *CB = dyn_cast<CallBase>(U.getUser()); if (!CB) { @@ -354,6 +355,7 @@ static bool runIPSCCP( for (Use &Arg : CB->args()) CB->removeParamAttr(CB->getArgOperandNo(&Arg), Attribute::Returned); + CB->removeRetAttrs(UBImplyingAttributes); } } @@ -368,9 +370,9 @@ static bool runIPSCCP( while (!GV->use_empty()) { StoreInst *SI = cast<StoreInst>(GV->user_back()); SI->eraseFromParent(); - MadeChanges = true; } - M.getGlobalList().erase(GV); + MadeChanges = true; + M.eraseGlobalVariable(GV); ++NumGlobalConst; } @@ -389,15 +391,15 @@ PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { auto GetAC = [&FAM](Function &F) -> AssumptionCache & { return FAM.getResult<AssumptionAnalysis>(F); }; - auto getAnalysis = [&FAM, this](Function &F) -> AnalysisResultsForFn { - DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); - return { - std::make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)), - &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F), - isFuncSpecEnabled() ? &FAM.getResult<LoopAnalysis>(F) : nullptr }; + auto GetDT = [&FAM](Function &F) -> DominatorTree & { + return FAM.getResult<DominatorTreeAnalysis>(F); }; + auto GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & { + return FAM.getResult<BlockFrequencyAnalysis>(F); + }; + - if (!runIPSCCP(M, DL, &FAM, GetTLI, GetTTI, GetAC, getAnalysis, + if (!runIPSCCP(M, DL, &FAM, GetTLI, GetTTI, GetAC, GetDT, GetBFI, isFuncSpecEnabled())) return PreservedAnalyses::all(); @@ -407,73 +409,3 @@ PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { PA.preserve<FunctionAnalysisManagerModuleProxy>(); return PA; } - -namespace { - -//===--------------------------------------------------------------------===// -// -/// IPSCCP Class - This class implements interprocedural Sparse Conditional -/// Constant Propagation. -/// -class IPSCCPLegacyPass : public ModulePass { -public: - static char ID; - - IPSCCPLegacyPass() : ModulePass(ID) { - initializeIPSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - const DataLayout &DL = M.getDataLayout(); - auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - auto GetTTI = [this](Function &F) -> TargetTransformInfo & { - return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - }; - auto GetAC = [this](Function &F) -> AssumptionCache & { - return this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - }; - auto getAnalysis = [this](Function &F) -> AnalysisResultsForFn { - DominatorTree &DT = - this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); - return { - std::make_unique<PredicateInfo>( - F, DT, - this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - F)), - nullptr, // We cannot preserve the LI, DT or PDT with the legacy pass - nullptr, // manager, so set them to nullptr. 
- nullptr}; - }; - - return runIPSCCP(M, DL, nullptr, GetTLI, GetTTI, GetAC, getAnalysis, false); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - } -}; - -} // end anonymous namespace - -char IPSCCPLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", - "Interprocedural Sparse Conditional Constant Propagation", - false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", - "Interprocedural Sparse Conditional Constant Propagation", - false, false) - -// createIPSCCPPass - This is the public interface to this file. -ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); } - diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp index 93b368fd72a6..a53baecd4776 100644 --- a/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -35,9 +35,9 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" -#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ReplayInlineAdvisor.h" @@ -58,8 +58,6 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/PseudoProbe.h" #include "llvm/IR/ValueSymbolTable.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/ProfileData/SampleProf.h" #include "llvm/ProfileData/SampleProfReader.h" @@ -67,6 +65,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorOr.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/ProfiledCallGraph.h" @@ -129,6 +128,11 @@ static cl::opt<std::string> SampleProfileRemappingFile( "sample-profile-remapping-file", cl::init(""), cl::value_desc("filename"), cl::desc("Profile remapping file loaded by -sample-profile"), cl::Hidden); +static cl::opt<bool> SalvageStaleProfile( + "salvage-stale-profile", cl::Hidden, cl::init(false), + cl::desc("Salvage stale profile by fuzzy matching and use the remapped " + "location for sample profile query.")); + static cl::opt<bool> ReportProfileStaleness( "report-profile-staleness", cl::Hidden, cl::init(false), cl::desc("Compute and report stale profile statistical metrics.")); @@ -138,6 +142,11 @@ static cl::opt<bool> PersistProfileStaleness( cl::desc("Compute stale profile statistical metrics and write it into the " "native object file(.llvm_stats section).")); +static cl::opt<bool> FlattenProfileForMatching( + "flatten-profile-for-matching", cl::Hidden, cl::init(true), + cl::desc( + "Use flattened profile for stale profile detection and matching.")); + static cl::opt<bool> ProfileSampleAccurate( "profile-sample-accurate", cl::Hidden, cl::init(false), cl::desc("If the sample profile is accurate, we will mark all un-sampled " @@ -173,9 +182,6 @@ static cl::opt<bool> cl::desc("Process functions in a top-down order " "defined by the 
profiled call graph when " "-sample-profile-top-down-load is on.")); -cl::opt<bool> - SortProfiledSCC("sort-profiled-scc-member", cl::init(true), cl::Hidden, - cl::desc("Sort profiled recursion by edge weights.")); static cl::opt<bool> ProfileSizeInline( "sample-profile-inline-size", cl::Hidden, cl::init(false), @@ -191,6 +197,11 @@ static cl::opt<bool> DisableSampleLoaderInlining( "pass, and merge (or scale) profiles (as configured by " "--sample-profile-merge-inlinee).")); +namespace llvm { +cl::opt<bool> + SortProfiledSCC("sort-profiled-scc-member", cl::init(true), cl::Hidden, + cl::desc("Sort profiled recursion by edge weights.")); + cl::opt<int> ProfileInlineGrowthLimit( "sample-profile-inline-growth-limit", cl::Hidden, cl::init(12), cl::desc("The size growth ratio limit for proirity-based sample profile " @@ -214,6 +225,7 @@ cl::opt<int> SampleHotCallSiteThreshold( cl::opt<int> SampleColdCallSiteThreshold( "sample-profile-cold-inline-threshold", cl::Hidden, cl::init(45), cl::desc("Threshold for inlining cold callsites")); +} // namespace llvm static cl::opt<unsigned> ProfileICPRelativeHotness( "sample-profile-icp-relative-hotness", cl::Hidden, cl::init(25), @@ -307,7 +319,9 @@ static cl::opt<bool> AnnotateSampleProfileInlinePhase( cl::desc("Annotate LTO phase (prelink / postlink), or main (no LTO) for " "sample-profile inline pass name.")); +namespace llvm { extern cl::opt<bool> EnableExtTspBlockPlacement; +} namespace { @@ -428,6 +442,11 @@ class SampleProfileMatcher { Module &M; SampleProfileReader &Reader; const PseudoProbeManager *ProbeManager; + SampleProfileMap FlattenedProfiles; + // For each function, the matcher generates a map, of which each entry is a + // mapping from the source location of current build to the source location in + // the profile. + StringMap<LocToLocMap> FuncMappings; // Profile mismatching statstics. 
uint64_t TotalProfiledCallsites = 0; @@ -442,9 +461,43 @@ class SampleProfileMatcher { public: SampleProfileMatcher(Module &M, SampleProfileReader &Reader, const PseudoProbeManager *ProbeManager) - : M(M), Reader(Reader), ProbeManager(ProbeManager) {} - void detectProfileMismatch(); - void detectProfileMismatch(const Function &F, const FunctionSamples &FS); + : M(M), Reader(Reader), ProbeManager(ProbeManager) { + if (FlattenProfileForMatching) { + ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles, + FunctionSamples::ProfileIsCS); + } + } + void runOnModule(); + +private: + FunctionSamples *getFlattenedSamplesFor(const Function &F) { + StringRef CanonFName = FunctionSamples::getCanonicalFnName(F); + auto It = FlattenedProfiles.find(CanonFName); + if (It != FlattenedProfiles.end()) + return &It->second; + return nullptr; + } + void runOnFunction(const Function &F, const FunctionSamples &FS); + void countProfileMismatches( + const FunctionSamples &FS, + const std::unordered_set<LineLocation, LineLocationHash> + &MatchedCallsiteLocs, + uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites); + + LocToLocMap &getIRToProfileLocationMap(const Function &F) { + auto Ret = FuncMappings.try_emplace( + FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap()); + return Ret.first->second; + } + void distributeIRToProfileLocationMap(); + void distributeIRToProfileLocationMap(FunctionSamples &FS); + void populateProfileCallsites( + const FunctionSamples &FS, + StringMap<std::set<LineLocation>> &CalleeToCallsitesMap); + void runStaleProfileMatching( + const std::map<LineLocation, StringRef> &IRLocations, + StringMap<std::set<LineLocation>> &CalleeToCallsitesMap, + LocToLocMap &IRToProfileLocationMap); }; /// Sample profile pass. @@ -452,15 +505,16 @@ public: /// This pass reads profile data from the file specified by /// -sample-profile-file and annotates every affected function with the /// profile information found in that file. 
-class SampleProfileLoader final - : public SampleProfileLoaderBaseImpl<BasicBlock> { +class SampleProfileLoader final : public SampleProfileLoaderBaseImpl<Function> { public: SampleProfileLoader( StringRef Name, StringRef RemapName, ThinOrFullLTOPhase LTOPhase, + IntrusiveRefCntPtr<vfs::FileSystem> FS, std::function<AssumptionCache &(Function &)> GetAssumptionCache, std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo, std::function<const TargetLibraryInfo &(Function &)> GetTLI) - : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)), + : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName), + std::move(FS)), GetAC(std::move(GetAssumptionCache)), GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)), LTOPhase(LTOPhase), @@ -471,13 +525,12 @@ public: bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr); bool runOnModule(Module &M, ModuleAnalysisManager *AM, - ProfileSummaryInfo *_PSI, CallGraph *CG); + ProfileSummaryInfo *_PSI, LazyCallGraph &CG); protected: bool runOnFunction(Function &F, ModuleAnalysisManager *AM); bool emitAnnotations(Function &F); ErrorOr<uint64_t> getInstWeight(const Instruction &I) override; - ErrorOr<uint64_t> getProbeWeight(const Instruction &I); const FunctionSamples *findCalleeFunctionSamples(const CallBase &I) const; const FunctionSamples * findFunctionSamples(const Instruction &I) const override; @@ -512,8 +565,8 @@ protected: void promoteMergeNotInlinedContextSamples( MapVector<CallBase *, const FunctionSamples *> NonInlinedCallSites, const Function &F); - std::vector<Function *> buildFunctionOrder(Module &M, CallGraph *CG); - std::unique_ptr<ProfiledCallGraph> buildProfiledCallGraph(CallGraph &CG); + std::vector<Function *> buildFunctionOrder(Module &M, LazyCallGraph &CG); + std::unique_ptr<ProfiledCallGraph> buildProfiledCallGraph(Module &M); void generateMDProfMetadata(Function &F); /// Map from function name to Function *. Used to find the function from @@ -573,9 +626,6 @@ protected: // External inline advisor used to replay inline decision from remarks. std::unique_ptr<InlineAdvisor> ExternalInlineAdvisor; - // A pseudo probe helper to correlate the imported sample counts. - std::unique_ptr<PseudoProbeManager> ProbeManager; - // A helper to implement the sample profile matching algorithm. std::unique_ptr<SampleProfileMatcher> MatchingManager; @@ -586,6 +636,50 @@ private: }; } // end anonymous namespace +namespace llvm { +template <> +inline bool SampleProfileInference<Function>::isExit(const BasicBlock *BB) { + return succ_empty(BB); +} + +template <> +inline void SampleProfileInference<Function>::findUnlikelyJumps( + const std::vector<const BasicBlockT *> &BasicBlocks, + BlockEdgeMap &Successors, FlowFunction &Func) { + for (auto &Jump : Func.Jumps) { + const auto *BB = BasicBlocks[Jump.Source]; + const auto *Succ = BasicBlocks[Jump.Target]; + const Instruction *TI = BB->getTerminator(); + // Check if a block ends with InvokeInst and mark non-taken branch unlikely. 
+ // In that case block Succ should be a landing pad + if (Successors[BB].size() == 2 && Successors[BB].back() == Succ) { + if (isa<InvokeInst>(TI)) { + Jump.IsUnlikely = true; + } + } + const Instruction *SuccTI = Succ->getTerminator(); + // Check if the target block contains UnreachableInst and mark it unlikely + if (SuccTI->getNumSuccessors() == 0) { + if (isa<UnreachableInst>(SuccTI)) { + Jump.IsUnlikely = true; + } + } + } +} + +template <> +void SampleProfileLoaderBaseImpl<Function>::computeDominanceAndLoopInfo( + Function &F) { + DT.reset(new DominatorTree); + DT->recalculate(F); + + PDT.reset(new PostDominatorTree(F)); + + LI.reset(new LoopInfo); + LI->analyze(*DT); +} +} // namespace llvm + ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { if (FunctionSamples::ProfileIsProbeBased) return getProbeWeight(Inst); @@ -614,68 +708,6 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { return getInstWeightImpl(Inst); } -// Here use error_code to represent: 1) The dangling probe. 2) Ignore the weight -// of non-probe instruction. So if all instructions of the BB give error_code, -// tell the inference algorithm to infer the BB weight. -ErrorOr<uint64_t> SampleProfileLoader::getProbeWeight(const Instruction &Inst) { - assert(FunctionSamples::ProfileIsProbeBased && - "Profile is not pseudo probe based"); - std::optional<PseudoProbe> Probe = extractProbe(Inst); - // Ignore the non-probe instruction. If none of the instruction in the BB is - // probe, we choose to infer the BB's weight. - if (!Probe) - return std::error_code(); - - const FunctionSamples *FS = findFunctionSamples(Inst); - // If none of the instruction has FunctionSample, we choose to return zero - // value sample to indicate the BB is cold. This could happen when the - // instruction is from inlinee and no profile data is found. - // FIXME: This should not be affected by the source drift issue as 1) if the - // newly added function is top-level inliner, it won't match the CFG checksum - // in the function profile or 2) if it's the inlinee, the inlinee should have - // a profile, otherwise it wouldn't be inlined. For non-probe based profile, - // we can improve it by adding a switch for profile-sample-block-accurate for - // block level counts in the future. - if (!FS) - return 0; - - // For non-CS profile, If a direct call/invoke instruction is inlined in - // profile (findCalleeFunctionSamples returns non-empty result), but not - // inlined here, it means that the inlined callsite has no sample, thus the - // call instruction should have 0 count. - // For CS profile, the callsite count of previously inlined callees is - // populated with the entry count of the callees. 
- if (!FunctionSamples::ProfileIsCS) - if (const auto *CB = dyn_cast<CallBase>(&Inst)) - if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB)) - return 0; - - const ErrorOr<uint64_t> &R = FS->findSamplesAt(Probe->Id, 0); - if (R) { - uint64_t Samples = R.get() * Probe->Factor; - bool FirstMark = CoverageTracker.markSamplesUsed(FS, Probe->Id, 0, Samples); - if (FirstMark) { - ORE->emit([&]() { - OptimizationRemarkAnalysis Remark(DEBUG_TYPE, "AppliedSamples", &Inst); - Remark << "Applied " << ore::NV("NumSamples", Samples); - Remark << " samples from profile (ProbeId="; - Remark << ore::NV("ProbeId", Probe->Id); - Remark << ", Factor="; - Remark << ore::NV("Factor", Probe->Factor); - Remark << ", OriginalSamples="; - Remark << ore::NV("OriginalSamples", R.get()); - Remark << ")"; - return Remark; - }); - } - LLVM_DEBUG(dbgs() << " " << Probe->Id << ":" << Inst - << " - weight: " << R.get() << " - factor: " - << format("%0.2f", Probe->Factor) << ")\n"); - return Samples; - } - return R; -} - /// Get the FunctionSamples for a call instruction. /// /// The FunctionSamples of a call/invoke instruction \p Inst is the inlined @@ -1041,8 +1073,8 @@ void SampleProfileLoader::findExternalInlineCandidate( DenseSet<GlobalValue::GUID> &InlinedGUIDs, const StringMap<Function *> &SymbolMap, uint64_t Threshold) { - // If ExternalInlineAdvisor wants to inline an external function - // make sure it's imported + // If ExternalInlineAdvisor(ReplayInlineAdvisor) wants to inline an external + // function make sure it's imported if (CB && getExternalInlineAdvisorShouldInline(*CB)) { // Samples may not exist for replayed function, if so // just add the direct GUID and move on @@ -1055,7 +1087,13 @@ void SampleProfileLoader::findExternalInlineCandidate( Threshold = 0; } - assert(Samples && "expect non-null caller profile"); + // In some rare cases, call instruction could be changed after being pushed + // into inline candidate queue, this is because earlier inlining may expose + // constant propagation which can change indirect call to direct call. When + // this happens, we may fail to find matching function samples for the + // candidate later, even if a match was found when the candidate was enqueued. + if (!Samples) + return; // For AutoFDO profile, retrieve candidate profiles by walking over // the nested inlinee profiles. @@ -1255,7 +1293,7 @@ bool SampleProfileLoader::tryInlineCandidate( if (!Cost) return false; - InlineFunctionInfo IFI(nullptr, GetAC); + InlineFunctionInfo IFI(GetAC); IFI.UpdateProfile = false; InlineResult IR = InlineFunction(CB, IFI, /*MergeAttributes=*/true); @@ -1784,9 +1822,10 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { if (!ProbeManager->profileIsValid(F, *Samples)) { LLVM_DEBUG( dbgs() << "Profile is invalid due to CFG mismatch for Function " - << F.getName()); + << F.getName() << "\n"); ++NumMismatchedProfile; - return false; + if (!SalvageStaleProfile) + return false; } ++NumMatchedProfile; } else { @@ -1813,7 +1852,7 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { } std::unique_ptr<ProfiledCallGraph> -SampleProfileLoader::buildProfiledCallGraph(CallGraph &CG) { +SampleProfileLoader::buildProfiledCallGraph(Module &M) { std::unique_ptr<ProfiledCallGraph> ProfiledCG; if (FunctionSamples::ProfileIsCS) ProfiledCG = std::make_unique<ProfiledCallGraph>(*ContextTracker); @@ -1823,18 +1862,17 @@ SampleProfileLoader::buildProfiledCallGraph(CallGraph &CG) { // Add all functions into the profiled call graph even if they are not in // the profile. 
This makes sure functions missing from the profile still // gets a chance to be processed. - for (auto &Node : CG) { - const auto *F = Node.first; - if (!F || F->isDeclaration() || !F->hasFnAttribute("use-sample-profile")) + for (Function &F : M) { + if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) continue; - ProfiledCG->addProfiledFunction(FunctionSamples::getCanonicalFnName(*F)); + ProfiledCG->addProfiledFunction(FunctionSamples::getCanonicalFnName(F)); } return ProfiledCG; } std::vector<Function *> -SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) { +SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) { std::vector<Function *> FunctionOrderList; FunctionOrderList.reserve(M.size()); @@ -1842,7 +1880,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) { errs() << "WARNING: -use-profiled-call-graph ignored, should be used " "together with -sample-profile-top-down-load.\n"; - if (!ProfileTopDownLoad || CG == nullptr) { + if (!ProfileTopDownLoad) { if (ProfileMergeInlinee) { // Disable ProfileMergeInlinee if profile is not loaded in top down order, // because the profile for a function may be used for the profile @@ -1858,8 +1896,6 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) { return FunctionOrderList; } - assert(&CG->getModule() == &M); - if (UseProfiledCallGraph || (FunctionSamples::ProfileIsCS && !UseProfiledCallGraph.getNumOccurrences())) { // Use profiled call edges to augment the top-down order. There are cases @@ -1910,7 +1946,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) { // static call edges are not so important when they don't correspond to a // context in the profile. - std::unique_ptr<ProfiledCallGraph> ProfiledCG = buildProfiledCallGraph(*CG); + std::unique_ptr<ProfiledCallGraph> ProfiledCG = buildProfiledCallGraph(M); scc_iterator<ProfiledCallGraph *> CGI = scc_begin(ProfiledCG.get()); while (!CGI.isAtEnd()) { auto Range = *CGI; @@ -1927,25 +1963,27 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) { ++CGI; } } else { - scc_iterator<CallGraph *> CGI = scc_begin(CG); - while (!CGI.isAtEnd()) { - for (CallGraphNode *Node : *CGI) { - auto *F = Node->getFunction(); - if (F && !F->isDeclaration() && F->hasFnAttribute("use-sample-profile")) - FunctionOrderList.push_back(F); + CG.buildRefSCCs(); + for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) { + for (LazyCallGraph::SCC &C : RC) { + for (LazyCallGraph::Node &N : C) { + Function &F = N.getFunction(); + if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile")) + FunctionOrderList.push_back(&F); + } } - ++CGI; } } + std::reverse(FunctionOrderList.begin(), FunctionOrderList.end()); + LLVM_DEBUG({ dbgs() << "Function processing order:\n"; - for (auto F : reverse(FunctionOrderList)) { + for (auto F : FunctionOrderList) { dbgs() << F->getName() << "\n"; } }); - std::reverse(FunctionOrderList.begin(), FunctionOrderList.end()); return FunctionOrderList; } @@ -1954,7 +1992,7 @@ bool SampleProfileLoader::doInitialization(Module &M, auto &Ctx = M.getContext(); auto ReaderOrErr = SampleProfileReader::create( - Filename, Ctx, FSDiscriminatorPass::Base, RemappingFilename); + Filename, Ctx, *FS, FSDiscriminatorPass::Base, RemappingFilename); if (std::error_code EC = ReaderOrErr.getError()) { std::string Msg = "Could not open profile: " + EC.message(); Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); @@ -2016,6 +2054,16 @@ bool SampleProfileLoader::doInitialization(Module &M, 
UsePreInlinerDecision = true; } + // Enable stale profile matching by default for probe-based profile. + // Currently the matching relies on if the checksum mismatch is detected, + // which is currently only available for pseudo-probe mode. Removing the + // checksum check could cause regressions for some cases, so further tuning + // might be needed if we want to enable it for all cases. + if (Reader->profileIsProbeBased() && + !SalvageStaleProfile.getNumOccurrences()) { + SalvageStaleProfile = true; + } + if (!Reader->profileIsCS()) { // Non-CS profile should be fine without a function size budget for the // inliner since the contexts in the profile are either all from inlining @@ -2046,7 +2094,8 @@ bool SampleProfileLoader::doInitialization(Module &M, } } - if (ReportProfileStaleness || PersistProfileStaleness) { + if (ReportProfileStaleness || PersistProfileStaleness || + SalvageStaleProfile) { MatchingManager = std::make_unique<SampleProfileMatcher>(M, *Reader, ProbeManager.get()); } @@ -2054,8 +2103,167 @@ bool SampleProfileLoader::doInitialization(Module &M, return true; } -void SampleProfileMatcher::detectProfileMismatch(const Function &F, - const FunctionSamples &FS) { +void SampleProfileMatcher::countProfileMismatches( + const FunctionSamples &FS, + const std::unordered_set<LineLocation, LineLocationHash> + &MatchedCallsiteLocs, + uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) { + + auto isInvalidLineOffset = [](uint32_t LineOffset) { + return LineOffset & 0x8000; + }; + + // Check if there are any callsites in the profile that does not match to any + // IR callsites, those callsite samples will be discarded. + for (auto &I : FS.getBodySamples()) { + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + + uint64_t Count = I.second.getSamples(); + if (!I.second.getCallTargets().empty()) { + TotalCallsiteSamples += Count; + FuncProfiledCallsites++; + if (!MatchedCallsiteLocs.count(Loc)) { + MismatchedCallsiteSamples += Count; + FuncMismatchedCallsites++; + } + } + } + + for (auto &I : FS.getCallsiteSamples()) { + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + + uint64_t Count = 0; + for (auto &FM : I.second) { + Count += FM.second.getHeadSamplesEstimate(); + } + TotalCallsiteSamples += Count; + FuncProfiledCallsites++; + if (!MatchedCallsiteLocs.count(Loc)) { + MismatchedCallsiteSamples += Count; + FuncMismatchedCallsites++; + } + } +} + +// Populate the anchors(direct callee name) from profile. +void SampleProfileMatcher::populateProfileCallsites( + const FunctionSamples &FS, + StringMap<std::set<LineLocation>> &CalleeToCallsitesMap) { + for (const auto &I : FS.getBodySamples()) { + const auto &Loc = I.first; + const auto &CTM = I.second.getCallTargets(); + // Filter out possible indirect calls, use direct callee name as anchor. + if (CTM.size() == 1) { + StringRef CalleeName = CTM.begin()->first(); + const auto &Candidates = CalleeToCallsitesMap.try_emplace( + CalleeName, std::set<LineLocation>()); + Candidates.first->second.insert(Loc); + } + } + + for (const auto &I : FS.getCallsiteSamples()) { + const LineLocation &Loc = I.first; + const auto &CalleeMap = I.second; + // Filter out possible indirect calls, use direct callee name as anchor. 
+ if (CalleeMap.size() == 1) { + StringRef CalleeName = CalleeMap.begin()->first; + const auto &Candidates = CalleeToCallsitesMap.try_emplace( + CalleeName, std::set<LineLocation>()); + Candidates.first->second.insert(Loc); + } + } +} + +// Call target name anchor based profile fuzzy matching. +// Input: +// For IR locations, the anchor is the callee name of direct callsite; For +// profile locations, it's the call target name for BodySamples or inlinee's +// profile name for CallsiteSamples. +// Matching heuristic: +// First match all the anchors in lexical order, then split the non-anchor +// locations between the two anchors evenly, first half are matched based on the +// start anchor, second half are matched based on the end anchor. +// For example, given: +// IR locations: [1, 2(foo), 3, 5, 6(bar), 7] +// Profile locations: [1, 2, 3(foo), 4, 7, 8(bar), 9] +// The matching gives: +// [1, 2(foo), 3, 5, 6(bar), 7] +// | | | | | | +// [1, 2, 3(foo), 4, 7, 8(bar), 9] +// The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9]. +void SampleProfileMatcher::runStaleProfileMatching( + const std::map<LineLocation, StringRef> &IRLocations, + StringMap<std::set<LineLocation>> &CalleeToCallsitesMap, + LocToLocMap &IRToProfileLocationMap) { + assert(IRToProfileLocationMap.empty() && + "Run stale profile matching only once per function"); + + auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) { + // Skip the unchanged location mapping to save memory. + if (From != To) + IRToProfileLocationMap.insert({From, To}); + }; + + // Use function's beginning location as the initial anchor. + int32_t LocationDelta = 0; + SmallVector<LineLocation> LastMatchedNonAnchors; + + for (const auto &IR : IRLocations) { + const auto &Loc = IR.first; + StringRef CalleeName = IR.second; + bool IsMatchedAnchor = false; + // Match the anchor location in lexical order. + if (!CalleeName.empty()) { + auto ProfileAnchors = CalleeToCallsitesMap.find(CalleeName); + if (ProfileAnchors != CalleeToCallsitesMap.end() && + !ProfileAnchors->second.empty()) { + auto CI = ProfileAnchors->second.begin(); + const auto Candidate = *CI; + ProfileAnchors->second.erase(CI); + InsertMatching(Loc, Candidate); + LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName + << " is matched from " << Loc << " to " << Candidate + << "\n"); + LocationDelta = Candidate.LineOffset - Loc.LineOffset; + + // Match backwards for non-anchor locations. + // The locations in LastMatchedNonAnchors have been matched forwards + // based on the previous anchor, spilt it evenly and overwrite the + // second half based on the current anchor. + for (size_t I = (LastMatchedNonAnchors.size() + 1) / 2; + I < LastMatchedNonAnchors.size(); I++) { + const auto &L = LastMatchedNonAnchors[I]; + uint32_t CandidateLineOffset = L.LineOffset + LocationDelta; + LineLocation Candidate(CandidateLineOffset, L.Discriminator); + InsertMatching(L, Candidate); + LLVM_DEBUG(dbgs() << "Location is rematched backwards from " << L + << " to " << Candidate << "\n"); + } + + IsMatchedAnchor = true; + LastMatchedNonAnchors.clear(); + } + } + + // Match forwards for non-anchor locations. 
+ if (!IsMatchedAnchor) { + uint32_t CandidateLineOffset = Loc.LineOffset + LocationDelta; + LineLocation Candidate(CandidateLineOffset, Loc.Discriminator); + InsertMatching(Loc, Candidate); + LLVM_DEBUG(dbgs() << "Location is matched from " << Loc << " to " + << Candidate << "\n"); + LastMatchedNonAnchors.emplace_back(Loc); + } + } +} + +void SampleProfileMatcher::runOnFunction(const Function &F, + const FunctionSamples &FS) { + bool IsFuncHashMismatch = false; if (FunctionSamples::ProfileIsProbeBased) { uint64_t Count = FS.getTotalSamples(); TotalFuncHashSamples += Count; @@ -2063,16 +2271,24 @@ void SampleProfileMatcher::detectProfileMismatch(const Function &F, if (!ProbeManager->profileIsValid(F, FS)) { MismatchedFuncHashSamples += Count; NumMismatchedFuncHash++; - return; + IsFuncHashMismatch = true; } } std::unordered_set<LineLocation, LineLocationHash> MatchedCallsiteLocs; + // The value of the map is the name of direct callsite and use empty StringRef + // for non-direct-call site. + std::map<LineLocation, StringRef> IRLocations; - // Go through all the callsites on the IR and flag the callsite if the target - // name is the same as the one in the profile. + // Extract profile matching anchors and profile mismatch metrics in the IR. for (auto &BB : F) { for (auto &I : BB) { + // TODO: Support line-number based location(AutoFDO). + if (FunctionSamples::ProfileIsProbeBased && isa<PseudoProbeInst>(&I)) { + if (std::optional<PseudoProbe> Probe = extractProbe(I)) + IRLocations.emplace(LineLocation(Probe->Id, 0), StringRef()); + } + if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) continue; @@ -2084,6 +2300,17 @@ void SampleProfileMatcher::detectProfileMismatch(const Function &F, if (Function *Callee = CB->getCalledFunction()) CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); + // Force to overwrite the callee name in case any non-call location was + // written before. + auto R = IRLocations.emplace(IRCallsite, CalleeName); + R.first->second = CalleeName; + assert((!FunctionSamples::ProfileIsProbeBased || R.second || + R.first->second == CalleeName) && + "Overwrite non-call or different callee name location for " + "pseudo probe callsite"); + + // Go through all the callsites on the IR and flag the callsite if the + // target name is the same as the one in the profile. const auto CTM = FS.findCallTargetMapAt(IRCallsite); const auto CallsiteFS = FS.findFunctionSamplesMapAt(IRCallsite); @@ -2105,55 +2332,54 @@ void SampleProfileMatcher::detectProfileMismatch(const Function &F, } } - auto isInvalidLineOffset = [](uint32_t LineOffset) { - return LineOffset & 0x8000; - }; + // Detect profile mismatch for profile staleness metrics report. + if (ReportProfileStaleness || PersistProfileStaleness) { + uint64_t FuncMismatchedCallsites = 0; + uint64_t FuncProfiledCallsites = 0; + countProfileMismatches(FS, MatchedCallsiteLocs, FuncMismatchedCallsites, + FuncProfiledCallsites); + TotalProfiledCallsites += FuncProfiledCallsites; + NumMismatchedCallsites += FuncMismatchedCallsites; + LLVM_DEBUG({ + if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch && + FuncMismatchedCallsites) + dbgs() << "Function checksum is matched but there are " + << FuncMismatchedCallsites << "/" << FuncProfiledCallsites + << " mismatched callsites.\n"; + }); + } - // Check if there are any callsites in the profile that does not match to any - // IR callsites, those callsite samples will be discarded. 
- for (auto &I : FS.getBodySamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) - continue; + if (IsFuncHashMismatch && SalvageStaleProfile) { + LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName() + << "\n"); - uint64_t Count = I.second.getSamples(); - if (!I.second.getCallTargets().empty()) { - TotalCallsiteSamples += Count; - TotalProfiledCallsites++; - if (!MatchedCallsiteLocs.count(Loc)) { - MismatchedCallsiteSamples += Count; - NumMismatchedCallsites++; - } - } - } + StringMap<std::set<LineLocation>> CalleeToCallsitesMap; + populateProfileCallsites(FS, CalleeToCallsitesMap); - for (auto &I : FS.getCallsiteSamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) - continue; + // The matching result will be saved to IRToProfileLocationMap, create a new + // map for each function. + auto &IRToProfileLocationMap = getIRToProfileLocationMap(F); - uint64_t Count = 0; - for (auto &FM : I.second) { - Count += FM.second.getHeadSamplesEstimate(); - } - TotalCallsiteSamples += Count; - TotalProfiledCallsites++; - if (!MatchedCallsiteLocs.count(Loc)) { - MismatchedCallsiteSamples += Count; - NumMismatchedCallsites++; - } + runStaleProfileMatching(IRLocations, CalleeToCallsitesMap, + IRToProfileLocationMap); } } -void SampleProfileMatcher::detectProfileMismatch() { +void SampleProfileMatcher::runOnModule() { for (auto &F : M) { if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) continue; - FunctionSamples *FS = Reader.getSamplesFor(F); + FunctionSamples *FS = nullptr; + if (FlattenProfileForMatching) + FS = getFlattenedSamplesFor(F); + else + FS = Reader.getSamplesFor(F); if (!FS) continue; - detectProfileMismatch(F, *FS); + runOnFunction(F, *FS); } + if (SalvageStaleProfile) + distributeIRToProfileLocationMap(); if (ReportProfileStaleness) { if (FunctionSamples::ProfileIsProbeBased) { @@ -2196,8 +2422,31 @@ void SampleProfileMatcher::detectProfileMismatch() { } } +void SampleProfileMatcher::distributeIRToProfileLocationMap( + FunctionSamples &FS) { + const auto ProfileMappings = FuncMappings.find(FS.getName()); + if (ProfileMappings != FuncMappings.end()) { + FS.setIRToProfileLocationMap(&(ProfileMappings->second)); + } + + for (auto &Inlinees : FS.getCallsiteSamples()) { + for (auto FS : Inlinees.second) { + distributeIRToProfileLocationMap(FS.second); + } + } +} + +// Use a central place to distribute the matching results. Outlined and inlined +// profile with the function name will be set to the same pointer. 
+void SampleProfileMatcher::distributeIRToProfileLocationMap() { + for (auto &I : Reader.getProfiles()) { + distributeIRToProfileLocationMap(I.second); + } +} + bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, - ProfileSummaryInfo *_PSI, CallGraph *CG) { + ProfileSummaryInfo *_PSI, + LazyCallGraph &CG) { GUIDToFuncNameMapper Mapper(M, *Reader, GUIDToFuncNameMap); PSI = _PSI; @@ -2240,8 +2489,10 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, assert(SymbolMap.count(StringRef()) == 0 && "No empty StringRef should be added in SymbolMap"); - if (ReportProfileStaleness || PersistProfileStaleness) - MatchingManager->detectProfileMismatch(); + if (ReportProfileStaleness || PersistProfileStaleness || + SalvageStaleProfile) { + MatchingManager->runOnModule(); + } bool retval = false; for (auto *F : buildFunctionOrder(M, CG)) { @@ -2327,6 +2578,11 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) return emitAnnotations(F); return false; } +SampleProfileLoaderPass::SampleProfileLoaderPass( + std::string File, std::string RemappingFile, ThinOrFullLTOPhase LTOPhase, + IntrusiveRefCntPtr<vfs::FileSystem> FS) + : ProfileFileName(File), ProfileRemappingFileName(RemappingFile), + LTOPhase(LTOPhase), FS(std::move(FS)) {} PreservedAnalyses SampleProfileLoaderPass::run(Module &M, ModuleAnalysisManager &AM) { @@ -2343,18 +2599,21 @@ PreservedAnalyses SampleProfileLoaderPass::run(Module &M, return FAM.getResult<TargetLibraryAnalysis>(F); }; + if (!FS) + FS = vfs::getRealFileSystem(); + SampleProfileLoader SampleLoader( ProfileFileName.empty() ? SampleProfileFile : ProfileFileName, ProfileRemappingFileName.empty() ? SampleProfileRemappingFile : ProfileRemappingFileName, - LTOPhase, GetAssumptionCache, GetTTI, GetTLI); + LTOPhase, FS, GetAssumptionCache, GetTTI, GetTLI); if (!SampleLoader.doInitialization(M, &FAM)) return PreservedAnalyses::all(); ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); - CallGraph &CG = AM.getResult<CallGraphAnalysis>(M); - if (!SampleLoader.runOnModule(M, &AM, PSI, &CG)) + LazyCallGraph &CG = AM.getResult<LazyCallGraphAnalysis>(M); + if (!SampleLoader.runOnModule(M, &AM, PSI, CG)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp index c4844dbe7f3c..0a42de7224b4 100644 --- a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -13,6 +13,7 @@ #include "llvm/Transforms/IPO/SampleProfileProbe.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/EHUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -32,7 +33,7 @@ #include <vector> using namespace llvm; -#define DEBUG_TYPE "sample-profile-probe" +#define DEBUG_TYPE "pseudo-probe" STATISTIC(ArtificialDbgLine, "Number of probes that have an artificial debug line"); @@ -55,11 +56,7 @@ static uint64_t getCallStackHash(const DILocation *DIL) { while (InlinedAt) { Hash ^= MD5Hash(std::to_string(InlinedAt->getLine())); Hash ^= MD5Hash(std::to_string(InlinedAt->getColumn())); - const DISubprogram *SP = InlinedAt->getScope()->getSubprogram(); - // Use linkage name for C++ if possible. 
- auto Name = SP->getLinkageName(); - if (Name.empty()) - Name = SP->getName(); + auto Name = InlinedAt->getSubprogramLinkageName(); Hash ^= MD5Hash(Name); InlinedAt = InlinedAt->getInlinedAt(); } @@ -169,47 +166,6 @@ void PseudoProbeVerifier::verifyProbeFactors( } } -PseudoProbeManager::PseudoProbeManager(const Module &M) { - if (NamedMDNode *FuncInfo = M.getNamedMetadata(PseudoProbeDescMetadataName)) { - for (const auto *Operand : FuncInfo->operands()) { - const auto *MD = cast<MDNode>(Operand); - auto GUID = - mdconst::dyn_extract<ConstantInt>(MD->getOperand(0))->getZExtValue(); - auto Hash = - mdconst::dyn_extract<ConstantInt>(MD->getOperand(1))->getZExtValue(); - GUIDToProbeDescMap.try_emplace(GUID, PseudoProbeDescriptor(GUID, Hash)); - } - } -} - -const PseudoProbeDescriptor * -PseudoProbeManager::getDesc(const Function &F) const { - auto I = GUIDToProbeDescMap.find( - Function::getGUID(FunctionSamples::getCanonicalFnName(F))); - return I == GUIDToProbeDescMap.end() ? nullptr : &I->second; -} - -bool PseudoProbeManager::moduleIsProbed(const Module &M) const { - return M.getNamedMetadata(PseudoProbeDescMetadataName); -} - -bool PseudoProbeManager::profileIsValid(const Function &F, - const FunctionSamples &Samples) const { - const auto *Desc = getDesc(F); - if (!Desc) { - LLVM_DEBUG(dbgs() << "Probe descriptor missing for Function " << F.getName() - << "\n"); - return false; - } else { - if (Desc->getFunctionHash() != Samples.getFunctionHash()) { - LLVM_DEBUG(dbgs() << "Hash mismatch for Function " << F.getName() - << "\n"); - return false; - } - } - return true; -} - SampleProfileProber::SampleProfileProber(Function &Func, const std::string &CurModuleUniqueId) : F(&Func), CurModuleUniqueId(CurModuleUniqueId) { @@ -253,8 +209,14 @@ void SampleProfileProber::computeCFGHash() { } void SampleProfileProber::computeProbeIdForBlocks() { + DenseSet<BasicBlock *> KnownColdBlocks; + computeEHOnlyBlocks(*F, KnownColdBlocks); + // Insert pseudo probe to non-cold blocks only. This will reduce IR size as + // well as the binary size while retaining the profile quality. for (auto &BB : *F) { - BlockProbeIds[&BB] = ++LastProbeId; + ++LastProbeId; + if (!KnownColdBlocks.contains(&BB)) + BlockProbeIds[&BB] = LastProbeId; } } @@ -283,9 +245,16 @@ uint32_t SampleProfileProber::getCallsiteId(const Instruction *Call) const { void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) { Module *M = F.getParent(); MDBuilder MDB(F.getContext()); - // Compute a GUID without considering the function's linkage type. This is - // fine since function name is the only key in the profile database. - uint64_t Guid = Function::getGUID(F.getName()); + // Since the GUID from probe desc and inline stack are computed seperately, we + // need to make sure their names are consistent, so here also use the name + // from debug info. + StringRef FName = F.getName(); + if (auto *SP = F.getSubprogram()) { + FName = SP->getLinkageName(); + if (FName.empty()) + FName = SP->getName(); + } + uint64_t Guid = Function::getGUID(FName); // Assign an artificial debug line to a probe that doesn't come with a real // line. A probe not having a debug line will get an incomplete inline @@ -339,6 +308,14 @@ void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) { Builder.getInt64(PseudoProbeFullDistributionFactor)}; auto *Probe = Builder.CreateCall(ProbeFn, Args); AssignDebugLoc(Probe); + // Reset the dwarf discriminator if the debug location comes with any. 
The + // discriminator field may be used by FS-AFDO later in the pipeline. + if (auto DIL = Probe->getDebugLoc()) { + if (DIL->getDiscriminator()) { + DIL = DIL->cloneWithDiscriminator(0); + Probe->setDebugLoc(DIL); + } + } } // Probe both direct calls and indirect calls. Direct calls are probed so that @@ -351,12 +328,13 @@ void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) { ? (uint32_t)PseudoProbeType::DirectCall : (uint32_t)PseudoProbeType::IndirectCall; AssignDebugLoc(Call); - // Levarge the 32-bit discriminator field of debug data to store the ID and - // type of a callsite probe. This gets rid of the dependency on plumbing a - // customized metadata through the codegen pipeline. - uint32_t V = PseudoProbeDwarfDiscriminator::packProbeData( - Index, Type, 0, PseudoProbeDwarfDiscriminator::FullDistributionFactor); if (auto DIL = Call->getDebugLoc()) { + // Levarge the 32-bit discriminator field of debug data to store the ID + // and type of a callsite probe. This gets rid of the dependency on + // plumbing a customized metadata through the codegen pipeline. + uint32_t V = PseudoProbeDwarfDiscriminator::packProbeData( + Index, Type, 0, + PseudoProbeDwarfDiscriminator::FullDistributionFactor); DIL = DIL->cloneWithDiscriminator(V); Call->setDebugLoc(DIL); } @@ -368,28 +346,10 @@ void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) { // - FunctionHash. // - FunctionName auto Hash = getFunctionHash(); - auto *MD = MDB.createPseudoProbeDesc(Guid, Hash, &F); + auto *MD = MDB.createPseudoProbeDesc(Guid, Hash, FName); auto *NMD = M->getNamedMetadata(PseudoProbeDescMetadataName); assert(NMD && "llvm.pseudo_probe_desc should be pre-created"); NMD->addOperand(MD); - - // Preserve a comdat group to hold all probes materialized later. This - // allows that when the function is considered dead and removed, the - // materialized probes are disposed too. - // Imported functions are defined in another module. They do not need - // the following handling since same care will be taken for them in their - // original module. The pseudo probes inserted into an imported functions - // above will naturally not be emitted since the imported function is free - // from object emission. However they will be emitted together with the - // inliner functions that the imported function is inlined into. We are not - // creating a comdat group for an import function since it's useless anyway. 
- if (!F.isDeclarationForLinker()) { - if (TM) { - auto Triple = TM->getTargetTriple(); - if (Triple.supportsCOMDAT() && TM->getFunctionSections()) - getOrCreateFunctionComdat(F, Triple); - } - } } PreservedAnalyses SampleProfileProbePass::run(Module &M, diff --git a/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp b/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp index 0f2412dce1c9..53d5b18dcead 100644 --- a/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp +++ b/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp @@ -16,8 +16,6 @@ #include "llvm/Transforms/IPO/StripDeadPrototypes.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" using namespace llvm; @@ -56,30 +54,3 @@ PreservedAnalyses StripDeadPrototypesPass::run(Module &M, return PreservedAnalyses::none(); return PreservedAnalyses::all(); } - -namespace { - -class StripDeadPrototypesLegacyPass : public ModulePass { -public: - static char ID; // Pass identification, replacement for typeid - StripDeadPrototypesLegacyPass() : ModulePass(ID) { - initializeStripDeadPrototypesLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - return stripDeadPrototypes(M); - } -}; - -} // end anonymous namespace - -char StripDeadPrototypesLegacyPass::ID = 0; -INITIALIZE_PASS(StripDeadPrototypesLegacyPass, "strip-dead-prototypes", - "Strip Unused Function Prototypes", false, false) - -ModulePass *llvm::createStripDeadPrototypesPass() { - return new StripDeadPrototypesLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/StripSymbols.cpp b/llvm/lib/Transforms/IPO/StripSymbols.cpp index 34f8c4316cca..147513452789 100644 --- a/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -30,110 +30,12 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/TypeFinder.h" #include "llvm/IR/ValueSymbolTable.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; -namespace { - class StripSymbols : public ModulePass { - bool OnlyDebugInfo; - public: - static char ID; // Pass identification, replacement for typeid - explicit StripSymbols(bool ODI = false) - : ModulePass(ID), OnlyDebugInfo(ODI) { - initializeStripSymbolsPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - }; - - class StripNonDebugSymbols : public ModulePass { - public: - static char ID; // Pass identification, replacement for typeid - explicit StripNonDebugSymbols() - : ModulePass(ID) { - initializeStripNonDebugSymbolsPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - }; - - class StripDebugDeclare : public ModulePass { - public: - static char ID; // Pass identification, replacement for typeid - explicit StripDebugDeclare() - : ModulePass(ID) { - initializeStripDebugDeclarePass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - }; - - class StripDeadDebugInfo : public ModulePass { - public: - static char ID; // Pass identification, replacement for typeid - explicit 
StripDeadDebugInfo() - : ModulePass(ID) { - initializeStripDeadDebugInfoPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - }; -} - -char StripSymbols::ID = 0; -INITIALIZE_PASS(StripSymbols, "strip", - "Strip all symbols from a module", false, false) - -ModulePass *llvm::createStripSymbolsPass(bool OnlyDebugInfo) { - return new StripSymbols(OnlyDebugInfo); -} - -char StripNonDebugSymbols::ID = 0; -INITIALIZE_PASS(StripNonDebugSymbols, "strip-nondebug", - "Strip all symbols, except dbg symbols, from a module", - false, false) - -ModulePass *llvm::createStripNonDebugSymbolsPass() { - return new StripNonDebugSymbols(); -} - -char StripDebugDeclare::ID = 0; -INITIALIZE_PASS(StripDebugDeclare, "strip-debug-declare", - "Strip all llvm.dbg.declare intrinsics", false, false) - -ModulePass *llvm::createStripDebugDeclarePass() { - return new StripDebugDeclare(); -} - -char StripDeadDebugInfo::ID = 0; -INITIALIZE_PASS(StripDeadDebugInfo, "strip-dead-debug-info", - "Strip debug info for unused symbols", false, false) - -ModulePass *llvm::createStripDeadDebugInfoPass() { - return new StripDeadDebugInfo(); -} - /// OnlyUsedBy - Return true if V is only used by Usr. static bool OnlyUsedBy(Value *V, Value *Usr) { for (User *U : V->users()) @@ -234,24 +136,6 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) { return true; } -bool StripSymbols::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - bool Changed = false; - Changed |= StripDebugInfo(M); - if (!OnlyDebugInfo) - Changed |= StripSymbolNames(M, false); - return Changed; -} - -bool StripNonDebugSymbols::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - return StripSymbolNames(M, true); -} - static bool stripDebugDeclareImpl(Module &M) { Function *Declare = M.getFunction("llvm.dbg.declare"); @@ -290,50 +174,6 @@ static bool stripDebugDeclareImpl(Module &M) { return true; } -bool StripDebugDeclare::runOnModule(Module &M) { - if (skipModule(M)) - return false; - return stripDebugDeclareImpl(M); -} - -/// Collects compilation units referenced by functions or lexical scopes. -/// Accepts any DIScope and uses recursive bottom-up approach to reach either -/// DISubprogram or DILexicalBlockBase. 
-static void -collectCUsWithScope(const DIScope *Scope, std::set<DICompileUnit *> &LiveCUs, - SmallPtrSet<const DIScope *, 8> &VisitedScopes) { - if (!Scope) - return; - - auto InS = VisitedScopes.insert(Scope); - if (!InS.second) - return; - - if (const auto *SP = dyn_cast<DISubprogram>(Scope)) { - if (SP->getUnit()) - LiveCUs.insert(SP->getUnit()); - return; - } - if (const auto *LB = dyn_cast<DILexicalBlockBase>(Scope)) { - const DISubprogram *SP = LB->getSubprogram(); - if (SP && SP->getUnit()) - LiveCUs.insert(SP->getUnit()); - return; - } - - collectCUsWithScope(Scope->getScope(), LiveCUs, VisitedScopes); -} - -static void -collectCUsForInlinedFuncs(const DILocation *Loc, - std::set<DICompileUnit *> &LiveCUs, - SmallPtrSet<const DIScope *, 8> &VisitedScopes) { - if (!Loc || !Loc->getInlinedAt()) - return; - collectCUsWithScope(Loc->getScope(), LiveCUs, VisitedScopes); - collectCUsForInlinedFuncs(Loc->getInlinedAt(), LiveCUs, VisitedScopes); -} - static bool stripDeadDebugInfoImpl(Module &M) { bool Changed = false; @@ -361,19 +201,15 @@ static bool stripDeadDebugInfoImpl(Module &M) { } std::set<DICompileUnit *> LiveCUs; - SmallPtrSet<const DIScope *, 8> VisitedScopes; - // Any CU is live if is referenced from a subprogram metadata that is attached - // to a function defined or inlined in the module. - for (const Function &Fn : M.functions()) { - collectCUsWithScope(Fn.getSubprogram(), LiveCUs, VisitedScopes); - for (const_inst_iterator I = inst_begin(&Fn), E = inst_end(&Fn); I != E; - ++I) { - if (!I->getDebugLoc()) - continue; - const DILocation *DILoc = I->getDebugLoc().get(); - collectCUsForInlinedFuncs(DILoc, LiveCUs, VisitedScopes); - } + DebugInfoFinder LiveCUFinder; + for (const Function &F : M.functions()) { + if (auto *SP = cast_or_null<DISubprogram>(F.getSubprogram())) + LiveCUFinder.processSubprogram(SP); + for (const Instruction &I : instructions(F)) + LiveCUFinder.processInstruction(M, I); } + auto FoundCUs = LiveCUFinder.compile_units(); + LiveCUs.insert(FoundCUs.begin(), FoundCUs.end()); bool HasDeadCUs = false; for (DICompileUnit *DIC : F.compile_units()) { @@ -424,39 +260,34 @@ static bool stripDeadDebugInfoImpl(Module &M) { return Changed; } -/// Remove any debug info for global variables/functions in the given module for -/// which said global variable/function no longer exists (i.e. is null). -/// -/// Debugging information is encoded in llvm IR using metadata. This is designed -/// such a way that debug info for symbols preserved even if symbols are -/// optimized away by the optimizer. This special pass removes debug info for -/// such symbols. 
-bool StripDeadDebugInfo::runOnModule(Module &M) { - if (skipModule(M)) - return false; - return stripDeadDebugInfoImpl(M); -} - PreservedAnalyses StripSymbolsPass::run(Module &M, ModuleAnalysisManager &AM) { StripDebugInfo(M); StripSymbolNames(M, false); - return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } PreservedAnalyses StripNonDebugSymbolsPass::run(Module &M, ModuleAnalysisManager &AM) { StripSymbolNames(M, true); - return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } PreservedAnalyses StripDebugDeclarePass::run(Module &M, ModuleAnalysisManager &AM) { stripDebugDeclareImpl(M); - return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } PreservedAnalyses StripDeadDebugInfoPass::run(Module &M, ModuleAnalysisManager &AM) { stripDeadDebugInfoImpl(M); - return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } diff --git a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 670097010085..fc1e70b1b3d3 100644 --- a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -18,9 +18,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" #include "llvm/Object/ModuleSymbolTable.h" -#include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" @@ -148,6 +146,14 @@ void promoteTypeIds(Module &M, StringRef ModuleId) { } } + if (Function *TypeCheckedLoadRelativeFunc = M.getFunction( + Intrinsic::getName(Intrinsic::type_checked_load_relative))) { + for (const Use &U : TypeCheckedLoadRelativeFunc->uses()) { + auto CI = cast<CallInst>(U.getUser()); + ExternalizeTypeId(CI, 2); + } + } + for (GlobalObject &GO : M.global_objects()) { SmallVector<MDNode *, 1> MDs; GO.getMetadata(LLVMContext::MD_type, MDs); @@ -196,6 +202,13 @@ void simplifyExternals(Module &M) { F.eraseFromParent(); } + for (GlobalIFunc &I : llvm::make_early_inc_range(M.ifuncs())) { + if (I.use_empty()) + I.eraseFromParent(); + else + assert(I.getResolverFunction() && "ifunc misses its resolver function"); + } + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) { if (GV.isDeclaration() && GV.use_empty()) { GV.eraseFromParent(); @@ -246,6 +259,16 @@ static void cloneUsedGlobalVariables(const Module &SrcM, Module &DestM, appendToUsed(DestM, NewUsed); } +#ifndef NDEBUG +static bool enableUnifiedLTO(Module &M) { + bool UnifiedLTO = false; + if (auto *MD = + mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("UnifiedLTO"))) + UnifiedLTO = MD->getZExtValue(); + return UnifiedLTO; +} +#endif + // If it's possible to split M into regular and thin LTO parts, do so and write // a multi-module bitcode file with the two parts to OS. Otherwise, write only a // regular LTO bitcode file to OS. @@ -254,18 +277,20 @@ void splitAndWriteThinLTOBitcode( function_ref<AAResults &(Function &)> AARGetter, Module &M) { std::string ModuleId = getUniqueModuleId(&M); if (ModuleId.empty()) { + assert(!enableUnifiedLTO(M)); // We couldn't generate a module ID for this module, write it out as a // regular LTO module with an index for summary-based dead stripping. 
ProfileSummaryInfo PSI(M); M.addModuleFlag(Module::Error, "ThinLTO", uint32_t(0)); ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, &PSI); - WriteBitcodeToFile(M, OS, /*ShouldPreserveUseListOrder=*/false, &Index); + WriteBitcodeToFile(M, OS, /*ShouldPreserveUseListOrder=*/false, &Index, + /*UnifiedLTO=*/false); if (ThinLinkOS) // We don't have a ThinLTO part, but still write the module to the // ThinLinkOS if requested so that the expected output file is produced. WriteBitcodeToFile(M, *ThinLinkOS, /*ShouldPreserveUseListOrder=*/false, - &Index); + &Index, /*UnifiedLTO=*/false); return; } @@ -503,15 +528,17 @@ bool hasTypeMetadata(Module &M) { return false; } -void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, +bool writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, function_ref<AAResults &(Function &)> AARGetter, Module &M, const ModuleSummaryIndex *Index) { std::unique_ptr<ModuleSummaryIndex> NewIndex = nullptr; // See if this module has any type metadata. If so, we try to split it // or at least promote type ids to enable WPD. if (hasTypeMetadata(M)) { - if (enableSplitLTOUnit(M)) - return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M); + if (enableSplitLTOUnit(M)) { + splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M); + return true; + } // Promote type ids as needed for index-based WPD. std::string ModuleId = getUniqueModuleId(&M); if (!ModuleId.empty()) { @@ -544,6 +571,7 @@ void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, // given OS. if (ThinLinkOS && Index) writeThinLinkBitcodeToFile(M, *ThinLinkOS, *Index, ModHash); + return false; } } // anonymous namespace @@ -552,10 +580,11 @@ PreservedAnalyses llvm::ThinLTOBitcodeWriterPass::run(Module &M, ModuleAnalysisManager &AM) { FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - writeThinLTOBitcode(OS, ThinLinkOS, - [&FAM](Function &F) -> AAResults & { - return FAM.getResult<AAManager>(F); - }, - M, &AM.getResult<ModuleSummaryIndexAnalysis>(M)); - return PreservedAnalyses::all(); + bool Changed = writeThinLTOBitcode( + OS, ThinLinkOS, + [&FAM](Function &F) -> AAResults & { + return FAM.getResult<AAManager>(F); + }, + M, &AM.getResult<ModuleSummaryIndexAnalysis>(M)); + return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 487a0a4a97f7..d33258642365 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -58,7 +58,6 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -84,9 +83,6 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/ModuleSummaryIndexYAML.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/PassRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Errc.h" @@ -94,6 +90,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/GlobPattern.h" #include "llvm/Support/MathExtras.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -259,7 +256,7 @@ wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, if (I < B.size()) BitsUsed |= B[I]; if (BitsUsed != 0xff) - return (MinByte + I) * 8 + countTrailingZeros(uint8_t(~BitsUsed)); + return (MinByte + I) * 8 + llvm::countr_zero(uint8_t(~BitsUsed)); } } else { // Find a free (Size/8) byte region in each member of Used. @@ -313,9 +310,10 @@ void wholeprogramdevirt::setAfterReturnValues( } } -VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) +VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM) : Fn(Fn), TM(TM), - IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {} + IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), + WasDevirt(false) {} namespace { @@ -379,6 +377,7 @@ namespace { // conditions // 1) All summaries are live. // 2) All function summaries indicate it's unreachable +// 3) There is no non-function with the same GUID (which is rare) bool mustBeUnreachableFunction(ValueInfo TheFnVI) { if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) { // Returns false if ValueInfo is absent, or the summary list is empty @@ -391,12 +390,13 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) { // In general either all summaries should be live or all should be dead. if (!Summary->isLive()) return false; - if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) { + if (auto *FS = dyn_cast<FunctionSummary>(Summary->getBaseObject())) { if (!FS->fflags().MustBeUnreachable) return false; } - // Do nothing if a non-function has the same GUID (which is rare). - // This is correct since non-function summaries are not relevant. + // Be conservative if a non-function has the same GUID (which is rare). + else + return false; } // All function summaries are live and all of them agree that the function is // unreachble. @@ -567,6 +567,10 @@ struct DevirtModule { // optimize a call more than once. SmallPtrSet<CallBase *, 8> OptimizedCalls; + // Store calls that had their ptrauth bundle removed. They are to be deleted + // at the end of the optimization. + SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved; + // This map keeps track of the number of "unsafe" uses of a loaded function // pointer. The key is the associated llvm.type.test intrinsic call generated // by this pass. 
An unsafe use is one that calls the loaded function pointer @@ -761,7 +765,7 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return FAM.getResult<DominatorTreeAnalysis>(F); }; if (UseCommandLine) { - if (DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree)) + if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } @@ -892,8 +896,7 @@ static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting. const auto &ModPaths = Summary->modulePaths(); if (ClSummaryAction != PassSummaryAction::Import && - ModPaths.find(ModuleSummaryIndex::getRegularLTOModuleName()) == - ModPaths.end()) + !ModPaths.contains(ModuleSummaryIndex::getRegularLTOModuleName())) return createStringError( errc::invalid_argument, "combined summary should contain Regular LTO module"); @@ -958,7 +961,7 @@ void DevirtModule::buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { DenseMap<GlobalVariable *, VTableBits *> GVToBits; - Bits.reserve(M.getGlobalList().size()); + Bits.reserve(M.global_size()); SmallVector<MDNode *, 2> Types; for (GlobalVariable &GV : M.globals()) { Types.clear(); @@ -1003,11 +1006,17 @@ bool DevirtModule::tryFindVirtualCallTargets( return false; Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), - TM.Offset + ByteOffset, M); + TM.Offset + ByteOffset, M, TM.Bits->GV); if (!Ptr) return false; - auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts()); + auto C = Ptr->stripPointerCasts(); + // Make sure this is a function or alias to a function. + auto Fn = dyn_cast<Function>(C); + auto A = dyn_cast<GlobalAlias>(C); + if (!Fn && A) + Fn = dyn_cast<Function>(A->getAliasee()); + if (!Fn) return false; @@ -1024,7 +1033,11 @@ bool DevirtModule::tryFindVirtualCallTargets( if (mustBeUnreachableFunction(Fn, ExportSummary)) continue; - TargetsForSlot.push_back({Fn, &TM}); + // Save the symbol used in the vtable to use as the devirtualization + // target. + auto GV = dyn_cast<GlobalValue>(C); + assert(GV); + TargetsForSlot.push_back({GV, &TM}); } // Give up if we couldn't find any targets. @@ -1156,6 +1169,14 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, // !callees metadata. CB.setMetadata(LLVMContext::MD_prof, nullptr); CB.setMetadata(LLVMContext::MD_callees, nullptr); + if (CB.getCalledOperand() && + CB.getOperandBundle(LLVMContext::OB_ptrauth)) { + auto *NewCS = + CallBase::removeOperandBundle(&CB, LLVMContext::OB_ptrauth, &CB); + CB.replaceAllUsesWith(NewCS); + // Schedule for deletion at the end of pass run. + CallsWithPtrAuthBundleRemoved.push_back(&CB); + } } // This use is no longer unsafe. @@ -1205,7 +1226,7 @@ bool DevirtModule::trySingleImplDevirt( WholeProgramDevirtResolution *Res) { // See if the program contains a single implementation of this virtual // function. 
- Function *TheFn = TargetsForSlot[0].Fn; + auto *TheFn = TargetsForSlot[0].Fn; for (auto &&Target : TargetsForSlot) if (TheFn != Target.Fn) return false; @@ -1379,9 +1400,20 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, IsExported = true; if (CSInfo.AllCallSitesDevirted) return; + + std::map<CallBase *, CallBase *> CallBases; for (auto &&VCallSite : CSInfo.CallSites) { CallBase &CB = VCallSite.CB; + if (CallBases.find(&CB) != CallBases.end()) { + // When finding devirtualizable calls, it's possible to find the same + // vtable passed to multiple llvm.type.test or llvm.type.checked.load + // calls, which can cause duplicate call sites to be recorded in + // [Const]CallSites. If we've already found one of these + // call instances, just ignore it. It will be replaced later. + continue; + } + // Jump tables are only profitable if the retpoline mitigation is enabled. Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features"); if (!FSAttr.isValid() || @@ -1428,8 +1460,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, AttributeList::get(M.getContext(), Attrs.getFnAttrs(), Attrs.getRetAttrs(), NewArgAttrs)); - CB.replaceAllUsesWith(NewCS); - CB.eraseFromParent(); + CallBases[&CB] = NewCS; // This use is no longer unsafe. if (VCallSite.NumUnsafeUses) @@ -1439,6 +1470,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // retpoline mitigation, which would mean that they are lowered to // llvm.type.test and therefore require an llvm.type.test resolution for the // type identifier. + + std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) { + CBs.first->replaceAllUsesWith(CBs.second); + CBs.first->eraseFromParent(); + }); }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) @@ -1451,23 +1487,30 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs( // Evaluate each function and store the result in each target's RetVal // field. for (VirtualCallTarget &Target : TargetsForSlot) { - if (Target.Fn->arg_size() != Args.size() + 1) + // TODO: Skip for now if the vtable symbol was an alias to a function, + // need to evaluate whether it would be correct to analyze the aliasee + // function for this optimization. 
+ auto Fn = dyn_cast<Function>(Target.Fn); + if (!Fn) + return false; + + if (Fn->arg_size() != Args.size() + 1) return false; Evaluator Eval(M.getDataLayout(), nullptr); SmallVector<Constant *, 2> EvalArgs; EvalArgs.push_back( - Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); + Constant::getNullValue(Fn->getFunctionType()->getParamType(0))); for (unsigned I = 0; I != Args.size(); ++I) { - auto *ArgTy = dyn_cast<IntegerType>( - Target.Fn->getFunctionType()->getParamType(I + 1)); + auto *ArgTy = + dyn_cast<IntegerType>(Fn->getFunctionType()->getParamType(I + 1)); if (!ArgTy) return false; EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); } Constant *RetVal; - if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || + if (!Eval.EvaluateFunction(Fn, RetVal, EvalArgs) || !isa<ConstantInt>(RetVal)) return false; Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue(); @@ -1675,8 +1718,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, OREGetter, IsBitSet); } else { - Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); - Value *Val = B.CreateLoad(RetType, ValAddr); + Value *Val = B.CreateLoad(RetType, Addr); NumVirtConstProp++; Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, OREGetter, Val); @@ -1688,8 +1730,14 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, bool DevirtModule::tryVirtualConstProp( MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res, VTableSlot Slot) { + // TODO: Skip for now if the vtable symbol was an alias to a function, + // need to evaluate whether it would be correct to analyze the aliasee + // function for this optimization. + auto Fn = dyn_cast<Function>(TargetsForSlot[0].Fn); + if (!Fn) + return false; // This only works if the function returns an integer. - auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); + auto RetType = dyn_cast<IntegerType>(Fn->getReturnType()); if (!RetType) return false; unsigned BitWidth = RetType->getBitWidth(); @@ -1707,11 +1755,18 @@ bool DevirtModule::tryVirtualConstProp( // inline all implementations of the virtual function into each call site, // rather than using function attributes to perform local optimization. for (VirtualCallTarget &Target : TargetsForSlot) { - if (Target.Fn->isDeclaration() || - !computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) + // TODO: Skip for now if the vtable symbol was an alias to a function, + // need to evaluate whether it would be correct to analyze the aliasee + // function for this optimization. + auto Fn = dyn_cast<Function>(Target.Fn); + if (!Fn) + return false; + + if (Fn->isDeclaration() || + !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn)) .doesNotAccessMemory() || - Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || - Target.Fn->getReturnType() != RetType) + Fn->arg_empty() || !Fn->arg_begin()->use_empty() || + Fn->getReturnType() != RetType) return false; } @@ -1947,9 +2002,23 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { // This helps avoid unnecessary spills. IRBuilder<> LoadB( (LoadedPtrs.size() == 1 && !HasNonCallUses) ? 
LoadedPtrs[0] : CI); - Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); - Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + + Value *LoadedValue = nullptr; + if (TypeCheckedLoadFunc->getIntrinsicID() == + Intrinsic::type_checked_load_relative) { + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty)); + LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr); + LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy); + GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy); + LoadedValue = LoadB.CreateAdd(GEP, LoadedValue); + LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy); + } else { + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = + LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); + LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + } for (Instruction *LoadedPtr : LoadedPtrs) { LoadedPtr->replaceAllUsesWith(LoadedValue); @@ -2130,6 +2199,8 @@ bool DevirtModule::run() { M.getFunction(Intrinsic::getName(Intrinsic::type_test)); Function *TypeCheckedLoadFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + Function *TypeCheckedLoadRelativeFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative)); Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); // Normally if there are no users of the devirtualization intrinsics in the @@ -2138,7 +2209,9 @@ bool DevirtModule::run() { if (!ExportSummary && (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || AssumeFunc->use_empty()) && - (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) && + (!TypeCheckedLoadRelativeFunc || + TypeCheckedLoadRelativeFunc->use_empty())) return false; // Rebuild type metadata into a map for easy lookup. @@ -2152,6 +2225,9 @@ bool DevirtModule::run() { if (TypeCheckedLoadFunc) scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); + if (TypeCheckedLoadRelativeFunc) + scanTypeCheckedLoadUsers(TypeCheckedLoadRelativeFunc); + if (ImportSummary) { for (auto &S : CallSlots) importResolution(S.first, S.second); @@ -2219,7 +2295,7 @@ bool DevirtModule::run() { // For each (type, offset) pair: bool DidVirtualConstProp = false; - std::map<std::string, Function*> DevirtTargets; + std::map<std::string, GlobalValue *> DevirtTargets; for (auto &S : CallSlots) { // Search each of the members of the type identifier for the virtual // function implementation at offset S.first.ByteOffset, and add to @@ -2274,7 +2350,14 @@ bool DevirtModule::run() { if (RemarksEnabled) { // Generate remarks for each devirtualized function. 
for (const auto &DT : DevirtTargets) { - Function *F = DT.second; + GlobalValue *GV = DT.second; + auto F = dyn_cast<Function>(GV); + if (!F) { + auto A = dyn_cast<GlobalAlias>(GV); + assert(A && isa<Function>(A->getAliasee())); + F = dyn_cast<Function>(A->getAliasee()); + assert(F); + } using namespace ore; OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) @@ -2299,6 +2382,9 @@ bool DevirtModule::run() { for (GlobalVariable &GV : M.globals()) GV.eraseMetadata(LLVMContext::MD_vcall_visibility); + for (auto *CI : CallsWithPtrAuthBundleRemoved) + CI->eraseFromParent(); + return true; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index b68efc993723..91ca44e0f11e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -797,7 +797,7 @@ static Value *checkForNegativeOperand(BinaryOperator &I, // LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2)) // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2)) if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1)))) - if (C1->countTrailingZeros() == 0) + if (C1->countr_zero() == 0) if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) { Value *NewOr = Builder.CreateOr(Z, ~(*C2)); return Builder.CreateSub(RHS, NewOr, "sub"); @@ -880,8 +880,15 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1); // ~X + C --> (C-1) - X - if (match(Op0, m_Not(m_Value(X)))) - return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X); + if (match(Op0, m_Not(m_Value(X)))) { + // ~X + C has NSW and (C-1) won't oveflow => (C-1)-X can have NSW + auto *COne = ConstantInt::get(Op1C->getType(), 1); + bool WillNotSOV = willNotOverflowSignedSub(Op1C, COne, Add); + BinaryOperator *Res = + BinaryOperator::CreateSub(ConstantExpr::getSub(Op1C, COne), X); + Res->setHasNoSignedWrap(Add.hasNoSignedWrap() && WillNotSOV); + return Res; + } // (iN X s>> (N - 1)) + 1 --> zext (X > -1) const APInt *C; @@ -975,6 +982,16 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { } } + // Fold (add (zext (add X, -1)), 1) -> (zext X) if X is non-zero. + // TODO: There's a general form for any constant on the outer add. 
+ if (C->isOne()) { + if (match(Op0, m_ZExt(m_Add(m_Value(X), m_AllOnes())))) { + const SimplifyQuery Q = SQ.getWithInstruction(&Add); + if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ZExtInst(X, Ty); + } + } + return nullptr; } @@ -1366,6 +1383,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Instruction *X = foldNoWrapAdd(I, Builder)) return X; + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1)) @@ -1421,6 +1441,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { Value *Sub = Builder.CreateSub(A, B); return BinaryOperator::CreateAdd(Sub, ConstantExpr::getAdd(C1, C2)); } + + // Canonicalize a constant sub operand as an add operand for better folding: + // (C1 - A) + B --> (B - A) + C1 + if (match(&I, m_c_Add(m_OneUse(m_Sub(m_ImmConstant(C1), m_Value(A))), + m_Value(B)))) { + Value *Sub = Builder.CreateSub(B, A, "reass.sub"); + return BinaryOperator::CreateAdd(Sub, C1); + } } // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) @@ -1439,7 +1467,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) && - C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countLeadingZeros())) { + C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) { Constant *NewMask = ConstantInt::get(RHS->getType(), *C1 - 1); return BinaryOperator::CreateAnd(A, NewMask); } @@ -1451,6 +1479,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A)))))) return new ZExtInst(B, LHS->getType()); + // zext(A) + sext(A) --> 0 if A is i1 + if (match(&I, m_c_BinOp(m_ZExt(m_Value(A)), m_SExt(m_Deferred(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); + // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); @@ -1515,7 +1548,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { const APInt *NegPow2C; if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))), m_Value(B)))) { - Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros()); + Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countr_zero()); Value *Shl = Builder.CreateShl(A, ShiftAmtC); return BinaryOperator::CreateSub(B, Shl); } @@ -1536,6 +1569,13 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Instruction *Ashr = foldAddToAshr(I)) return Ashr; + // min(A, B) + max(A, B) => A + B. + if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)), + m_c_SMin(m_Deferred(A), m_Deferred(B))), + m_c_Add(m_UMax(m_Value(A), m_Value(B)), + m_c_UMin(m_Deferred(A), m_Deferred(B)))))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I); + // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. @@ -1575,6 +1615,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, {Builder.CreateOr(A, B)})); + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + return Changed ? 
&I : nullptr; } @@ -1786,6 +1832,20 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); } + // minumum(X, Y) + maximum(X, Y) => X + Y. + if (match(&I, + m_c_FAdd(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)), + m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X), + m_Deferred(Y))))) { + BinaryOperator *Result = BinaryOperator::CreateFAddFMF(X, Y, &I); + // We cannot preserve ninf if nnan flag is not set. + // If X is NaN and Y is Inf then in original program we had NaN + NaN, + // while in optimized version NaN + Inf and this is a poison with ninf flag. + if (!Result->hasNoNaNs()) + Result->setHasNoInfs(false); + return Result; + } + return nullptr; } @@ -1956,8 +2016,17 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { Constant *C2; // C-(X+C2) --> (C-C2)-X - if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) - return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); + if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) { + // C-C2 never overflow, and C-(X+C2), (X+C2) has NSW + // => (C-C2)-X can have NSW + bool WillNotSOV = willNotOverflowSignedSub(C, C2, I); + BinaryOperator *Res = + BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); + auto *OBO1 = cast<OverflowingBinaryOperator>(Op1); + Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() && + WillNotSOV); + return Res; + } } auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * { @@ -2325,7 +2394,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { const APInt *AddC, *AndC; if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) && match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) { - unsigned Cttz = AddC->countTrailingZeros(); + unsigned Cttz = AddC->countr_zero(); APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); if ((HighMask & *AndC).isZero()) return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC))); @@ -2388,6 +2457,21 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return replaceInstUsesWith(I, Mul); } + // max(X,Y) nsw/nuw - min(X,Y) --> abs(X nsw - Y) + if (match(Op0, m_OneUse(m_c_SMax(m_Value(X), m_Value(Y)))) && + match(Op1, m_OneUse(m_c_SMin(m_Specific(X), m_Specific(Y))))) { + if (I.hasNoUnsignedWrap() || I.hasNoSignedWrap()) { + Value *Sub = + Builder.CreateSub(X, Y, "sub", /*HasNUW=*/false, /*HasNSW=*/true); + Value *Call = + Builder.CreateBinaryIntrinsic(Intrinsic::abs, Sub, Builder.getTrue()); + return replaceInstUsesWith(I, Call); + } + } + + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + return TryToNarrowDeduceFlags(); } @@ -2567,7 +2651,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // Note that if this fsub was really an fneg, the fadd with -0.0 will get // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. 
- if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { + if (I.hasNoSignedZeros() || cannotBeNegativeZero(Op0, SQ.DL, SQ.TLI)) { if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 97a001b2ed32..8a1fb6b7f17e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -625,7 +625,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, return RHS; } - if (Mask & BMask_Mixed) { + if (Mask & (BMask_Mixed | BMask_NotMixed)) { + // Mixed: // (icmp eq (A & B), C) & (icmp eq (A & D), E) // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of @@ -636,24 +637,50 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D, having a single bit set. + + // NotMixed: + // (icmp ne (A & B), C) & (icmp ne (A & D), E) + // -> (icmp ne (A & (B & D)), (C & E)) + // Check the intersection (B & D) for inequality. + // Assume that (B & D) == B || (B & D) == D, i.e B/D is a subset of D/B + // and (B & D) & (C ^ E) == 0, bits of C and E, which are shared by both the + // B and the D, don't contradict. + // Note that we can assume (~B & C) == 0 && (~D & E) == 0, previous + // operation should delete these icmps if it hadn't been met. + const APInt *OldConstC, *OldConstE; if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE))) return nullptr; - const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC; - const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE; + auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * { + CC = IsNot ? CmpInst::getInversePredicate(CC) : CC; + const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC; + const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE; - // If there is a conflict, we should actually return a false for the - // whole construct. - if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) - return ConstantInt::get(LHS->getType(), !IsAnd); + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) + return IsNot ? 
nullptr : ConstantInt::get(LHS->getType(), !IsAnd); - Value *NewOr1 = Builder.CreateOr(B, D); - Value *NewAnd = Builder.CreateAnd(A, NewOr1); - Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); - return Builder.CreateICmp(NewCC, NewAnd, NewOr2); - } + if (IsNot && !ConstB->isSubsetOf(*ConstD) && !ConstD->isSubsetOf(*ConstB)) + return nullptr; + APInt BD, CE; + if (IsNot) { + BD = *ConstB & *ConstD; + CE = ConstC & ConstE; + } else { + BD = *ConstB | *ConstD; + CE = ConstC | ConstE; + } + Value *NewAnd = Builder.CreateAnd(A, BD); + Value *CEVal = ConstantInt::get(A->getType(), CE); + return Builder.CreateICmp(CC, CEVal, NewAnd); + }; + + if (Mask & BMask_Mixed) + return FoldBMixed(NewCC, false); + if (Mask & BMask_NotMixed) // can be else also + return FoldBMixed(NewCC, true); + } return nullptr; } @@ -928,6 +955,108 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, return nullptr; } +/// Try to fold (icmp(A & B) == 0) & (icmp(A & D) != E) into (icmp A u< D) iff +/// B is a contiguous set of ones starting from the most significant bit +/// (negative power of 2), D and E are equal, and D is a contiguous set of ones +/// starting at the most significant zero bit in B. Parameter B supports masking +/// using undef/poison in either scalar or vector values. +static Value *foldNegativePower2AndShiftedMask( + Value *A, Value *B, Value *D, Value *E, ICmpInst::Predicate PredL, + ICmpInst::Predicate PredR, InstCombiner::BuilderTy &Builder) { + assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && + "Expected equality predicates for masked type of icmps."); + if (PredL != ICmpInst::ICMP_EQ || PredR != ICmpInst::ICMP_NE) + return nullptr; + + if (!match(B, m_NegatedPower2()) || !match(D, m_ShiftedMask()) || + !match(E, m_ShiftedMask())) + return nullptr; + + // Test scalar arguments for conversion. B has been validated earlier to be a + // negative power of two and thus is guaranteed to have one or more contiguous + // ones starting from the MSB followed by zero or more contiguous zeros. D has + // been validated earlier to be a shifted set of one or more contiguous ones. + // In order to match, B leading ones and D leading zeros should be equal. The + // predicate that B be a negative power of 2 prevents the condition of there + // ever being zero leading ones. Thus 0 == 0 cannot occur. The predicate that + // D always be a shifted mask prevents the condition of D equaling 0. This + // prevents matching the condition where B contains the maximum number of + // leading one bits (-1) and D contains the maximum number of leading zero + // bits (0). + auto isReducible = [](const Value *B, const Value *D, const Value *E) { + const APInt *BCst, *DCst, *ECst; + return match(B, m_APIntAllowUndef(BCst)) && match(D, m_APInt(DCst)) && + match(E, m_APInt(ECst)) && *DCst == *ECst && + (isa<UndefValue>(B) || + (BCst->countLeadingOnes() == DCst->countLeadingZeros())); + }; + + // Test vector type arguments for conversion. 
+ if (const auto *BVTy = dyn_cast<VectorType>(B->getType())) { + const auto *BFVTy = dyn_cast<FixedVectorType>(BVTy); + const auto *BConst = dyn_cast<Constant>(B); + const auto *DConst = dyn_cast<Constant>(D); + const auto *EConst = dyn_cast<Constant>(E); + + if (!BFVTy || !BConst || !DConst || !EConst) + return nullptr; + + for (unsigned I = 0; I != BFVTy->getNumElements(); ++I) { + const auto *BElt = BConst->getAggregateElement(I); + const auto *DElt = DConst->getAggregateElement(I); + const auto *EElt = EConst->getAggregateElement(I); + + if (!BElt || !DElt || !EElt) + return nullptr; + if (!isReducible(BElt, DElt, EElt)) + return nullptr; + } + } else { + // Test scalar type arguments for conversion. + if (!isReducible(B, D, E)) + return nullptr; + } + return Builder.CreateICmp(ICmpInst::ICMP_ULT, A, D); +} + +/// Try to fold ((icmp X u< P) & (icmp(X & M) != M)) or ((icmp X s> -1) & +/// (icmp(X & M) != M)) into (icmp X u< M). Where P is a power of 2, M < P, and +/// M is a contiguous shifted mask starting at the right most significant zero +/// bit in P. SGT is supported as when P is the largest representable power of +/// 2, an earlier optimization converts the expression into (icmp X s> -1). +/// Parameter P supports masking using undef/poison in either scalar or vector +/// values. +static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool JoinedByAnd, + InstCombiner::BuilderTy &Builder) { + if (!JoinedByAnd) + return nullptr; + Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; + ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(), + CmpPred1 = Cmp1->getPredicate(); + // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u< + // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X & + // SignMask) == 0). + std::optional<std::pair<unsigned, unsigned>> MaskPair = + getMaskedTypeForICmpPair(A, B, C, D, E, Cmp0, Cmp1, CmpPred0, CmpPred1); + if (!MaskPair) + return nullptr; + + const auto compareBMask = BMask_NotMixed | BMask_NotAllOnes; + unsigned CmpMask0 = MaskPair->first; + unsigned CmpMask1 = MaskPair->second; + if ((CmpMask0 & Mask_AllZeros) && (CmpMask1 == compareBMask)) { + if (Value *V = foldNegativePower2AndShiftedMask(A, B, D, E, CmpPred0, + CmpPred1, Builder)) + return V; + } else if ((CmpMask0 == compareBMask) && (CmpMask1 & Mask_AllZeros)) { + if (Value *V = foldNegativePower2AndShiftedMask(A, D, B, C, CmpPred1, + CmpPred0, Builder)) + return V; + } + return nullptr; +} + /// Commuted variants are assumed to be handled by calling this function again /// with the parameters swapped. static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, @@ -1313,9 +1442,44 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, return Right; } + // Turn at least two fcmps with constants into llvm.is.fpclass. + // + // If we can represent a combined value test with one class call, we can + // potentially eliminate 4-6 instructions. If we can represent a test with a + // single fcmp with fneg and fabs, that's likely a better canonical form. + if (LHS->hasOneUse() && RHS->hasOneUse()) { + auto [ClassValRHS, ClassMaskRHS] = + fcmpToClassTest(PredR, *RHS->getFunction(), RHS0, RHS1); + if (ClassValRHS) { + auto [ClassValLHS, ClassMaskLHS] = + fcmpToClassTest(PredL, *LHS->getFunction(), LHS0, LHS1); + if (ClassValLHS == ClassValRHS) { + unsigned CombinedMask = IsAnd ? 
(ClassMaskLHS & ClassMaskRHS) + : (ClassMaskLHS | ClassMaskRHS); + return Builder.CreateIntrinsic( + Intrinsic::is_fpclass, {ClassValLHS->getType()}, + {ClassValLHS, Builder.getInt32(CombinedMask)}); + } + } + } + return nullptr; } +/// Match an fcmp against a special value that performs a test possible by +/// llvm.is.fpclass. +static bool matchIsFPClassLikeFCmp(Value *Op, Value *&ClassVal, + uint64_t &ClassMask) { + auto *FCmp = dyn_cast<FCmpInst>(Op); + if (!FCmp || !FCmp->hasOneUse()) + return false; + + std::tie(ClassVal, ClassMask) = + fcmpToClassTest(FCmp->getPredicate(), *FCmp->getParent()->getParent(), + FCmp->getOperand(0), FCmp->getOperand(1)); + return ClassVal != nullptr; +} + /// or (is_fpclass x, mask0), (is_fpclass x, mask1) /// -> is_fpclass x, (mask0 | mask1) /// and (is_fpclass x, mask0), (is_fpclass x, mask1) @@ -1324,13 +1488,25 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, /// -> is_fpclass x, (mask0 ^ mask1) Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO, Value *Op0, Value *Op1) { - Value *ClassVal; + Value *ClassVal0 = nullptr; + Value *ClassVal1 = nullptr; uint64_t ClassMask0, ClassMask1; - if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( - m_Value(ClassVal), m_ConstantInt(ClassMask0)))) && + // Restrict to folding one fcmp into one is.fpclass for now, don't introduce a + // new class. + // + // TODO: Support forming is.fpclass out of 2 separate fcmps when codegen is + // better. + + bool IsLHSClass = + match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( + m_Value(ClassVal0), m_ConstantInt(ClassMask0)))); + bool IsRHSClass = match(Op1, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( - m_Specific(ClassVal), m_ConstantInt(ClassMask1))))) { + m_Value(ClassVal1), m_ConstantInt(ClassMask1)))); + if ((((IsLHSClass || matchIsFPClassLikeFCmp(Op0, ClassVal0, ClassMask0)) && + (IsRHSClass || matchIsFPClassLikeFCmp(Op1, ClassVal1, ClassMask1)))) && + ClassVal0 == ClassVal1) { unsigned NewClassMask; switch (BO.getOpcode()) { case Instruction::And: @@ -1346,11 +1522,24 @@ Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO, llvm_unreachable("not a binary logic operator"); } - // TODO: Also check for special fcmps - auto *II = cast<IntrinsicInst>(Op0); - II->setArgOperand( - 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); - return replaceInstUsesWith(BO, II); + if (IsLHSClass) { + auto *II = cast<IntrinsicInst>(Op0); + II->setArgOperand( + 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); + return replaceInstUsesWith(BO, II); + } + + if (IsRHSClass) { + auto *II = cast<IntrinsicInst>(Op1); + II->setArgOperand( + 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); + return replaceInstUsesWith(BO, II); + } + + CallInst *NewClass = + Builder.CreateIntrinsic(Intrinsic::is_fpclass, {ClassVal0->getType()}, + {ClassVal0, Builder.getInt32(NewClassMask)}); + return replaceInstUsesWith(BO, NewClass); } return nullptr; @@ -1523,6 +1712,39 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding"); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // fold bitwise(A >> BW - 1, zext(icmp)) (BW is the scalar bits of the + // type of A) + // -> bitwise(zext(A < 0), zext(icmp)) + // -> zext(bitwise(A < 0, icmp)) + auto FoldBitwiseICmpZeroWithICmp = [&](Value *Op0, + Value *Op1) -> Instruction * { + ICmpInst::Predicate Pred; + Value *A; + bool IsMatched = 
+ match(Op0, + m_OneUse(m_LShr( + m_Value(A), + m_SpecificInt(Op0->getType()->getScalarSizeInBits() - 1)))) && + match(Op1, m_OneUse(m_ZExt(m_ICmp(Pred, m_Value(), m_Value())))); + + if (!IsMatched) + return nullptr; + + auto *ICmpL = + Builder.CreateICmpSLT(A, Constant::getNullValue(A->getType())); + auto *ICmpR = cast<ZExtInst>(Op1)->getOperand(0); + auto *BitwiseOp = Builder.CreateBinOp(LogicOpc, ICmpL, ICmpR); + + return new ZExtInst(BitwiseOp, Op0->getType()); + }; + + if (auto *Ret = FoldBitwiseICmpZeroWithICmp(Op0, Op1)) + return Ret; + + if (auto *Ret = FoldBitwiseICmpZeroWithICmp(Op1, Op0)) + return Ret; + CastInst *Cast0 = dyn_cast<CastInst>(Op0); if (!Cast0) return nullptr; @@ -1906,16 +2128,16 @@ static Instruction *canonicalizeLogicFirst(BinaryOperator &I, return nullptr; unsigned Width = Ty->getScalarSizeInBits(); - unsigned LastOneMath = Width - C2->countTrailingZeros(); + unsigned LastOneMath = Width - C2->countr_zero(); switch (OpC) { case Instruction::And: - if (C->countLeadingOnes() < LastOneMath) + if (C->countl_one() < LastOneMath) return nullptr; break; case Instruction::Xor: case Instruction::Or: - if (C->countLeadingZeros() < LastOneMath) + if (C->countl_zero() < LastOneMath) return nullptr; break; default: @@ -1923,7 +2145,51 @@ static Instruction *canonicalizeLogicFirst(BinaryOperator &I, } Value *NewBinOp = Builder.CreateBinOp(OpC, X, ConstantInt::get(Ty, *C)); - return BinaryOperator::CreateAdd(NewBinOp, ConstantInt::get(Ty, *C2)); + return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, NewBinOp, + ConstantInt::get(Ty, *C2), Op0); +} + +// binop(shift(ShiftedC1, ShAmt), shift(ShiftedC2, add(ShAmt, AddC))) -> +// shift(binop(ShiftedC1, shift(ShiftedC2, AddC)), ShAmt) +// where both shifts are the same and AddC is a valid shift amount. +Instruction *InstCombinerImpl::foldBinOpOfDisplacedShifts(BinaryOperator &I) { + assert((I.isBitwiseLogicOp() || I.getOpcode() == Instruction::Add) && + "Unexpected opcode"); + + Value *ShAmt; + Constant *ShiftedC1, *ShiftedC2, *AddC; + Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + if (!match(&I, + m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)), + m_Shift(m_ImmConstant(ShiftedC2), + m_Add(m_Deferred(ShAmt), m_ImmConstant(AddC)))))) + return nullptr; + + // Make sure the add constant is a valid shift amount. + if (!match(AddC, + m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(BitWidth, BitWidth)))) + return nullptr; + + // Avoid constant expressions. + auto *Op0Inst = dyn_cast<Instruction>(I.getOperand(0)); + auto *Op1Inst = dyn_cast<Instruction>(I.getOperand(1)); + if (!Op0Inst || !Op1Inst) + return nullptr; + + // Both shifts must be the same. + Instruction::BinaryOps ShiftOp = + static_cast<Instruction::BinaryOps>(Op0Inst->getOpcode()); + if (ShiftOp != Op1Inst->getOpcode()) + return nullptr; + + // For adds, only left shifts are supported. 
+ if (I.getOpcode() == Instruction::Add && ShiftOp != Instruction::Shl) + return nullptr; + + Value *NewC = Builder.CreateBinOp( + I.getOpcode(), ShiftedC1, Builder.CreateBinOp(ShiftOp, ShiftedC2, AddC)); + return BinaryOperator::Create(ShiftOp, NewC, ShAmt); } // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches @@ -1964,6 +2230,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *X, *Y; @@ -2033,7 +2302,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) { // If we add zeros to every bit below a mask, the add has no effect: // (X + AddC) & LowMaskC --> X & LowMaskC - unsigned Ctlz = C->countLeadingZeros(); + unsigned Ctlz = C->countl_zero(); APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz)); if ((*AddC & LowMask).isZero()) return BinaryOperator::CreateAnd(X, Op1); @@ -2150,7 +2419,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { const APInt *C3 = C; Value *X; if (C3->isPowerOf2()) { - Constant *Log2C3 = ConstantInt::get(Ty, C3->countTrailingZeros()); + Constant *Log2C3 = ConstantInt::get(Ty, C3->countr_zero()); if (match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)), m_ImmConstant(C2)))) && match(C1, m_Power2())) { @@ -2407,6 +2676,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) return Folded; + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + return nullptr; } @@ -2718,34 +2990,47 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, return nullptr; } -// (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1) -// (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1) -static Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, - bool IsAnd, bool IsLogical, - IRBuilderBase &Builder) { +// (icmp eq X, C) | (icmp ult Other, (X - C)) -> (icmp ule Other, (X - (C + 1))) +// (icmp ne X, C) & (icmp uge Other, (X - C)) -> (icmp ugt Other, (X - (C + 1))) +static Value *foldAndOrOfICmpEqConstantAndICmp(ICmpInst *LHS, ICmpInst *RHS, + bool IsAnd, bool IsLogical, + IRBuilderBase &Builder) { + Value *LHS0 = LHS->getOperand(0); + Value *RHS0 = RHS->getOperand(0); + Value *RHS1 = RHS->getOperand(1); + ICmpInst::Predicate LPred = IsAnd ? LHS->getInversePredicate() : LHS->getPredicate(); ICmpInst::Predicate RPred = IsAnd ? 
RHS->getInversePredicate() : RHS->getPredicate(); - Value *LHS0 = LHS->getOperand(0); - if (LPred != ICmpInst::ICMP_EQ || !match(LHS->getOperand(1), m_Zero()) || + + const APInt *CInt; + if (LPred != ICmpInst::ICMP_EQ || + !match(LHS->getOperand(1), m_APIntAllowUndef(CInt)) || !LHS0->getType()->isIntOrIntVectorTy() || !(LHS->hasOneUse() || RHS->hasOneUse())) return nullptr; + auto MatchRHSOp = [LHS0, CInt](const Value *RHSOp) { + return match(RHSOp, + m_Add(m_Specific(LHS0), m_SpecificIntAllowUndef(-*CInt))) || + (CInt->isZero() && RHSOp == LHS0); + }; + Value *Other; - if (RPred == ICmpInst::ICMP_ULT && RHS->getOperand(1) == LHS0) - Other = RHS->getOperand(0); - else if (RPred == ICmpInst::ICMP_UGT && RHS->getOperand(0) == LHS0) - Other = RHS->getOperand(1); + if (RPred == ICmpInst::ICMP_ULT && MatchRHSOp(RHS1)) + Other = RHS0; + else if (RPred == ICmpInst::ICMP_UGT && MatchRHSOp(RHS0)) + Other = RHS1; else return nullptr; if (IsLogical) Other = Builder.CreateFreeze(Other); + return Builder.CreateICmp( IsAnd ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE, - Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())), + Builder.CreateSub(LHS0, ConstantInt::get(LHS0->getType(), *CInt + 1)), Other); } @@ -2792,12 +3077,12 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return V; if (Value *V = - foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, IsLogical, Builder)) + foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder)) return V; // We can treat logical like bitwise here, because both operands are used on // the LHS, and as such poison from both will propagate. - if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, - /*IsLogical*/ false, Builder)) + if (Value *V = foldAndOrOfICmpEqConstantAndICmp(RHS, LHS, IsAnd, + /*IsLogical*/ false, Builder)) return V; if (Value *V = @@ -2836,6 +3121,9 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder)) return V; + if (Value *V = foldPowerOf2AndShiftedMask(LHS, RHS, IsAnd, Builder)) + return V; + // TODO: Verify whether this is safe for logical and/or. if (!IsLogical) { if (Value *X = foldUnsignedUnderflowCheck(LHS, RHS, IsAnd, Q, Builder)) @@ -2849,7 +3137,7 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. + // TODO: Remove this and below when foldLogOpOfMaskedICmps can handle undefs. if (!IsLogical && PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && PredL == PredR && match(LHS1, m_ZeroInt()) && match(RHS1, m_ZeroInt()) && LHS0->getType() == RHS0->getType()) { @@ -2858,6 +3146,16 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Constant::getNullValue(NewOr->getType())); } + // (icmp ne A, -1) | (icmp ne B, -1) --> (icmp ne (A&B), -1) + // (icmp eq A, -1) & (icmp eq B, -1) --> (icmp eq (A&B), -1) + if (!IsLogical && PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + PredL == PredR && match(LHS1, m_AllOnes()) && match(RHS1, m_AllOnes()) && + LHS0->getType() == RHS0->getType()) { + Value *NewAnd = Builder.CreateAnd(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewAnd, + Constant::getAllOnesValue(LHS0->getType())); + } + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). 
if (!LHSC || !RHSC) return nullptr; @@ -2998,6 +3296,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Concat = matchOrConcat(I, Builder)) return replaceInstUsesWith(I, Concat); + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + Value *X, *Y; const APInt *CV; if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && @@ -3416,6 +3717,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) return Folded; + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + return nullptr; } @@ -3715,6 +4019,24 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor, return nullptr; } +static bool canFreelyInvert(InstCombiner &IC, Value *Op, + Instruction *IgnoredUser) { + auto *I = dyn_cast<Instruction>(Op); + return I && IC.isFreeToInvert(I, /*WillInvertAllUses=*/true) && + InstCombiner::canFreelyInvertAllUsersOf(I, IgnoredUser); +} + +static Value *freelyInvert(InstCombinerImpl &IC, Value *Op, + Instruction *IgnoredUser) { + auto *I = cast<Instruction>(Op); + IC.Builder.SetInsertPoint(&*I->getInsertionPointAfterDef()); + Value *NotOp = IC.Builder.CreateNot(Op, Op->getName() + ".not"); + Op->replaceUsesWithIf(NotOp, + [NotOp](Use &U) { return U.getUser() != NotOp; }); + IC.freelyInvertAllUsersOf(NotOp, IgnoredUser); + return NotOp; +} + // Transform // z = ~(x &/| y) // into: @@ -3739,28 +4061,11 @@ bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) { return false; // And can the operands be adapted? - for (Value *Op : {Op0, Op1}) - if (!(InstCombiner::isFreeToInvert(Op, /*WillInvertAllUses=*/true) && - (match(Op, m_ImmConstant()) || - (isa<Instruction>(Op) && - InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op), - /*IgnoredUser=*/&I))))) - return false; + if (!canFreelyInvert(*this, Op0, &I) || !canFreelyInvert(*this, Op1, &I)) + return false; - for (Value **Op : {&Op0, &Op1}) { - Value *NotOp; - if (auto *C = dyn_cast<Constant>(*Op)) { - NotOp = ConstantExpr::getNot(C); - } else { - Builder.SetInsertPoint( - &*cast<Instruction>(*Op)->getInsertionPointAfterDef()); - NotOp = Builder.CreateNot(*Op, (*Op)->getName() + ".not"); - (*Op)->replaceUsesWithIf( - NotOp, [NotOp](Use &U) { return U.getUser() != NotOp; }); - freelyInvertAllUsersOf(NotOp, /*IgnoredUser=*/&I); - } - *Op = NotOp; - } + Op0 = freelyInvert(*this, Op0, &I); + Op1 = freelyInvert(*this, Op1, &I); Builder.SetInsertPoint(I.getInsertionPointAfterDef()); Value *NewLogicOp; @@ -3794,20 +4099,11 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) { Value *NotOp0 = nullptr; Value *NotOp1 = nullptr; Value **OpToInvert = nullptr; - if (match(Op0, m_Not(m_Value(NotOp0))) && - InstCombiner::isFreeToInvert(Op1, /*WillInvertAllUses=*/true) && - (match(Op1, m_ImmConstant()) || - (isa<Instruction>(Op1) && - InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op1), - /*IgnoredUser=*/&I)))) { + if (match(Op0, m_Not(m_Value(NotOp0))) && canFreelyInvert(*this, Op1, &I)) { Op0 = NotOp0; OpToInvert = &Op1; } else if (match(Op1, m_Not(m_Value(NotOp1))) && - InstCombiner::isFreeToInvert(Op0, /*WillInvertAllUses=*/true) && - (match(Op0, m_ImmConstant()) || - (isa<Instruction>(Op0) && - InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op0), - /*IgnoredUser=*/&I)))) { + canFreelyInvert(*this, Op0, &I)) { Op1 = NotOp1; OpToInvert = &Op0; } else @@ -3817,19 +4113,7 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) { if 
(!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) return false; - if (auto *C = dyn_cast<Constant>(*OpToInvert)) { - *OpToInvert = ConstantExpr::getNot(C); - } else { - Builder.SetInsertPoint( - &*cast<Instruction>(*OpToInvert)->getInsertionPointAfterDef()); - Value *NotOpToInvert = - Builder.CreateNot(*OpToInvert, (*OpToInvert)->getName() + ".not"); - (*OpToInvert)->replaceUsesWithIf(NotOpToInvert, [NotOpToInvert](Use &U) { - return U.getUser() != NotOpToInvert; - }); - freelyInvertAllUsersOf(NotOpToInvert, /*IgnoredUser=*/&I); - *OpToInvert = NotOpToInvert; - } + *OpToInvert = freelyInvert(*this, *OpToInvert, &I); Builder.SetInsertPoint(&*I.getInsertionPointAfterDef()); Value *NewBinOp; @@ -3896,8 +4180,8 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) return BinaryOperator::CreateAShr(X, Y); - // Bit-hack form of a signbit test: - // iN ~X >>s (N-1) --> sext i1 (X > -1) to iN + // Bit-hack form of a signbit test for iN type: + // ~(X >>s (N - 1)) --> sext i1 (X > -1) to iN unsigned FullShift = Ty->getScalarSizeInBits() - 1; if (match(NotVal, m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))))) { Value *IsNotNeg = Builder.CreateIsNotNeg(X, "isnotneg"); @@ -4071,6 +4355,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { if (Instruction *R = foldNot(I)) return R; + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) // This it a special case in haveNoCommonBitsSet, but the computeKnownBits // calls in there are unnecessary as SimplifyDemandedInstructionBits should @@ -4280,6 +4567,23 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { } } + // (A & B) ^ (A | C) --> A ? ~B : C -- There are 4 commuted variants. + if (I.getType()->isIntOrIntVectorTy(1) && + match(Op0, m_OneUse(m_LogicalAnd(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_LogicalOr(m_Value(C), m_Value(D))))) { + bool NeedFreeze = isa<SelectInst>(Op0) && isa<SelectInst>(Op1) && B == D; + if (B == C || B == D) + std::swap(A, B); + if (A == C) + std::swap(C, D); + if (A == D) { + if (NeedFreeze) + A = Builder.CreateFreeze(A); + Value *NotB = Builder.CreateNot(B); + return SelectInst::Create(A, NotB, C); + } + } + if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) if (Value *V = foldXorOfICmps(LHS, RHS, I)) @@ -4313,5 +4617,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { if (Instruction *Folded = canonicalizeConditionalNegationViaMathToSelect(I)) return Folded; + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index e73667f9c02e..cba282cea72b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -116,24 +116,10 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { return &RMWI; } - AtomicOrdering Ordering = RMWI.getOrdering(); - assert(Ordering != AtomicOrdering::NotAtomic && - Ordering != AtomicOrdering::Unordered && + assert(RMWI.getOrdering() != AtomicOrdering::NotAtomic && + RMWI.getOrdering() != AtomicOrdering::Unordered && "AtomicRMWs don't make sense with Unordered or NotAtomic"); - // Any atomicrmw xchg with no uses can be converted to a atomic store if the - // ordering is compatible. 
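The new boolean-only fold (A & B) ^ (A | C) --> A ? ~B : C in visitXor above can be checked exhaustively, since it is restricted to i1 values where ~B is simply !B. A small sketch in plain C++ (xor on bool is written as !=), illustrative only:

    #include <cassert>

    int main() {
      for (bool a : {false, true})
        for (bool b : {false, true})
          for (bool c : {false, true})
            // (a & b) ^ (a | c)  ==  a ? !b : c
            assert(((a && b) != (a || c)) == (a ? !b : c));
      return 0;
    }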
- if (RMWI.getOperation() == AtomicRMWInst::Xchg && - RMWI.use_empty()) { - if (Ordering != AtomicOrdering::Release && - Ordering != AtomicOrdering::Monotonic) - return nullptr; - new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(), - /*isVolatile*/ false, RMWI.getAlign(), Ordering, - RMWI.getSyncScopeID(), &RMWI); - return eraseInstFromFunction(RMWI); - } - if (!isIdempotentRMW(RMWI)) return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index fbf1327143a8..d3ec6a7aa667 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -27,6 +27,7 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -439,9 +440,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType()); ElementCount VF = WideLoadTy->getElementCount(); - Constant *EC = - ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue()); - Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC; + Value *RunTimeVF = Builder.CreateElementCount(Builder.getInt32Ty(), VF); Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1)); Value *Extract = Builder.CreateExtractElement(II.getArgOperand(0), LastLane); @@ -533,16 +532,15 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType())); } - // If the operand is a select with constant arm(s), try to hoist ctlz/cttz. - if (auto *Sel = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = IC.FoldOpIntoSelect(II, Sel)) - return R; - if (IsTZ) { // cttz(-x) -> cttz(x) if (match(Op0, m_Neg(m_Value(X)))) return IC.replaceOperand(II, 0, X); + // cttz(-x & x) -> cttz(x) + if (match(Op0, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) + return IC.replaceOperand(II, 0, X); + // cttz(sext(x)) -> cttz(zext(x)) if (match(Op0, m_OneUse(m_SExt(m_Value(X))))) { auto *Zext = IC.Builder.CreateZExt(X, II.getType()); @@ -599,8 +597,7 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { } // Add range metadata since known bits can't completely reflect what we know. - // TODO: Handle splat vectors. - auto *IT = dyn_cast<IntegerType>(Op0->getType()); + auto *IT = cast<IntegerType>(Op0->getType()->getScalarType()); if (IT && IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { Metadata *LowAndHigh[] = { ConstantAsMetadata::get(ConstantInt::get(IT, DefiniteZeros)), @@ -657,11 +654,6 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { return CastInst::Create(Instruction::ZExt, NarrowPop, Ty); } - // If the operand is a select with constant arm(s), try to hoist ctpop. - if (auto *Sel = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = IC.FoldOpIntoSelect(II, Sel)) - return R; - KnownBits Known(BitWidth); IC.computeKnownBits(Op0, Known, 0, &II); @@ -683,12 +675,8 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { Constant::getNullValue(Ty)), Ty); - // FIXME: Try to simplify vectors of integers. 
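The added cttz(-x & x) --> cttz(x) case relies on -x & x isolating the lowest set bit of x, which has the same number of trailing zeros as x itself (and both sides agree when x is zero). An exhaustive 16-bit check, purely illustrative and requiring C++20 for std::countr_zero:

    #include <bit>      // std::countr_zero (C++20)
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned v = 0; v < 0x10000; ++v) {
        uint16_t x = static_cast<uint16_t>(v);
        uint16_t lowBit = static_cast<uint16_t>(-x & x); // isolate lowest set bit
        assert(std::countr_zero(lowBit) == std::countr_zero(x));
      }
      return 0;
    }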
- auto *IT = dyn_cast<IntegerType>(Ty); - if (!IT) - return nullptr; - // Add range metadata since known bits can't completely reflect what we know. + auto *IT = cast<IntegerType>(Ty->getScalarType()); unsigned MinCount = Known.countMinPopulation(); unsigned MaxCount = Known.countMaxPopulation(); if (IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { @@ -830,10 +818,204 @@ InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { return nullptr; } +static bool inputDenormalIsIEEE(const Function &F, const Type *Ty) { + Ty = Ty->getScalarType(); + return F.getDenormalMode(Ty->getFltSemantics()).Input == DenormalMode::IEEE; +} + +static bool inputDenormalIsDAZ(const Function &F, const Type *Ty) { + Ty = Ty->getScalarType(); + return F.getDenormalMode(Ty->getFltSemantics()).inputsAreZero(); +} + +/// \returns the compare predicate type if the test performed by +/// llvm.is.fpclass(x, \p Mask) is equivalent to fcmp o__ x, 0.0 with the +/// floating-point environment assumed for \p F for type \p Ty +static FCmpInst::Predicate fpclassTestIsFCmp0(FPClassTest Mask, + const Function &F, Type *Ty) { + switch (static_cast<unsigned>(Mask)) { + case fcZero: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OEQ; + break; + case fcZero | fcSubnormal: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OEQ; + break; + case fcPositive | fcNegZero: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OGE; + break; + case fcPositive | fcNegZero | fcNegSubnormal: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OGE; + break; + case fcPosSubnormal | fcPosNormal | fcPosInf: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OGT; + break; + case fcNegative | fcPosZero: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OLE; + break; + case fcNegative | fcPosZero | fcPosSubnormal: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OLE; + break; + case fcNegSubnormal | fcNegNormal | fcNegInf: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OLT; + break; + case fcPosNormal | fcPosInf: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OGT; + break; + case fcNegNormal | fcNegInf: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OLT; + break; + case ~fcZero & ~fcNan: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_ONE; + break; + case ~(fcZero | fcSubnormal) & ~fcNan: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_ONE; + break; + default: + break; + } + + return FCmpInst::BAD_FCMP_PREDICATE; +} + +Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { + Value *Src0 = II.getArgOperand(0); + Value *Src1 = II.getArgOperand(1); + const ConstantInt *CMask = cast<ConstantInt>(Src1); + FPClassTest Mask = static_cast<FPClassTest>(CMask->getZExtValue()); + const bool IsUnordered = (Mask & fcNan) == fcNan; + const bool IsOrdered = (Mask & fcNan) == fcNone; + const FPClassTest OrderedMask = Mask & ~fcNan; + const FPClassTest OrderedInvertedMask = ~OrderedMask & ~fcNan; + + const bool IsStrict = II.isStrictFP(); + + Value *FNegSrc; + if (match(Src0, m_FNeg(m_Value(FNegSrc)))) { + // is.fpclass (fneg x), mask -> is.fpclass x, (fneg mask) + + II.setArgOperand(1, ConstantInt::get(Src1->getType(), fneg(Mask))); + return replaceOperand(II, 0, FNegSrc); + } + + Value *FAbsSrc; + if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) { + II.setArgOperand(1, ConstantInt::get(Src1->getType(), fabs(Mask))); + return replaceOperand(II, 0, FAbsSrc); + } + + // TODO: is.fpclass(x, fcInf) -> fabs(x) == inf + + if 
((OrderedMask == fcPosInf || OrderedMask == fcNegInf) && + (IsOrdered || IsUnordered) && !IsStrict) { + // is.fpclass(x, fcPosInf) -> fcmp oeq x, +inf + // is.fpclass(x, fcNegInf) -> fcmp oeq x, -inf + // is.fpclass(x, fcPosInf|fcNan) -> fcmp ueq x, +inf + // is.fpclass(x, fcNegInf|fcNan) -> fcmp ueq x, -inf + Constant *Inf = + ConstantFP::getInfinity(Src0->getType(), OrderedMask == fcNegInf); + Value *EqInf = IsUnordered ? Builder.CreateFCmpUEQ(Src0, Inf) + : Builder.CreateFCmpOEQ(Src0, Inf); + + EqInf->takeName(&II); + return replaceInstUsesWith(II, EqInf); + } + + if ((OrderedInvertedMask == fcPosInf || OrderedInvertedMask == fcNegInf) && + (IsOrdered || IsUnordered) && !IsStrict) { + // is.fpclass(x, ~fcPosInf) -> fcmp one x, +inf + // is.fpclass(x, ~fcNegInf) -> fcmp one x, -inf + // is.fpclass(x, ~fcPosInf|fcNan) -> fcmp une x, +inf + // is.fpclass(x, ~fcNegInf|fcNan) -> fcmp une x, -inf + Constant *Inf = ConstantFP::getInfinity(Src0->getType(), + OrderedInvertedMask == fcNegInf); + Value *NeInf = IsUnordered ? Builder.CreateFCmpUNE(Src0, Inf) + : Builder.CreateFCmpONE(Src0, Inf); + NeInf->takeName(&II); + return replaceInstUsesWith(II, NeInf); + } + + if (Mask == fcNan && !IsStrict) { + // Equivalent of isnan. Replace with standard fcmp if we don't care about FP + // exceptions. + Value *IsNan = + Builder.CreateFCmpUNO(Src0, ConstantFP::getZero(Src0->getType())); + IsNan->takeName(&II); + return replaceInstUsesWith(II, IsNan); + } + + if (Mask == (~fcNan & fcAllFlags) && !IsStrict) { + // Equivalent of !isnan. Replace with standard fcmp. + Value *FCmp = + Builder.CreateFCmpORD(Src0, ConstantFP::getZero(Src0->getType())); + FCmp->takeName(&II); + return replaceInstUsesWith(II, FCmp); + } + + FCmpInst::Predicate PredType = FCmpInst::BAD_FCMP_PREDICATE; + + // Try to replace with an fcmp with 0 + // + // is.fpclass(x, fcZero) -> fcmp oeq x, 0.0 + // is.fpclass(x, fcZero | fcNan) -> fcmp ueq x, 0.0 + // is.fpclass(x, ~fcZero & ~fcNan) -> fcmp one x, 0.0 + // is.fpclass(x, ~fcZero) -> fcmp une x, 0.0 + // + // is.fpclass(x, fcPosSubnormal | fcPosNormal | fcPosInf) -> fcmp ogt x, 0.0 + // is.fpclass(x, fcPositive | fcNegZero) -> fcmp oge x, 0.0 + // + // is.fpclass(x, fcNegSubnormal | fcNegNormal | fcNegInf) -> fcmp olt x, 0.0 + // is.fpclass(x, fcNegative | fcPosZero) -> fcmp ole x, 0.0 + // + if (!IsStrict && (IsOrdered || IsUnordered) && + (PredType = fpclassTestIsFCmp0(OrderedMask, *II.getFunction(), + Src0->getType())) != + FCmpInst::BAD_FCMP_PREDICATE) { + Constant *Zero = ConstantFP::getZero(Src0->getType()); + // Equivalent of == 0. + Value *FCmp = Builder.CreateFCmp( + IsUnordered ? FCmpInst::getUnorderedPredicate(PredType) : PredType, + Src0, Zero); + + FCmp->takeName(&II); + return replaceInstUsesWith(II, FCmp); + } + + KnownFPClass Known = computeKnownFPClass( + Src0, DL, Mask, 0, &getTargetLibraryInfo(), &AC, &II, &DT); + + // Clear test bits we know must be false from the source value. + // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other + // fp_class (ninf x), ninf|pinf|other -> fp_class (ninf x), other + if ((Mask & Known.KnownFPClasses) != Mask) { + II.setArgOperand( + 1, ConstantInt::get(Src1->getType(), Mask & Known.KnownFPClasses)); + return &II; + } + + // If none of the tests which can return false are possible, fold to true. 
+ // fp_class (nnan x), ~(qnan|snan) -> true + // fp_class (ninf x), ~(ninf|pinf) -> true + if (Mask == Known.KnownFPClasses) + return replaceInstUsesWith(II, ConstantInt::get(II.getType(), true)); + + return nullptr; +} + static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, - const DataLayout &DL, - AssumptionCache *AC, - DominatorTree *DT) { + const DataLayout &DL, AssumptionCache *AC, + DominatorTree *DT) { KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT); if (Known.isNonNegative()) return false; @@ -848,6 +1030,19 @@ static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } +/// Return true if two values \p Op0 and \p Op1 are known to have the same sign. +static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI, + const DataLayout &DL, AssumptionCache *AC, + DominatorTree *DT) { + std::optional<bool> Known1 = getKnownSign(Op1, CxtI, DL, AC, DT); + if (!Known1) + return false; + std::optional<bool> Known0 = getKnownSign(Op0, CxtI, DL, AC, DT); + if (!Known0) + return false; + return *Known0 == *Known1; +} + /// Try to canonicalize min/max(X + C0, C1) as min/max(X, C1 - C0) + C0. This /// can trigger other combines. static Instruction *moveAddAfterMinMax(IntrinsicInst *II, @@ -991,7 +1186,8 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II, /// If this min/max has a constant operand and an operand that is a matching /// min/max with a constant operand, constant-fold the 2 constant operands. -static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) { +static Value *reassociateMinMaxWithConstants(IntrinsicInst *II, + IRBuilderBase &Builder) { Intrinsic::ID MinMaxID = II->getIntrinsicID(); auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0)); if (!LHS || LHS->getIntrinsicID() != MinMaxID) @@ -1004,12 +1200,10 @@ static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) { // max (max X, C0), C1 --> max X, (max C0, C1) --> max X, NewC ICmpInst::Predicate Pred = MinMaxIntrinsic::getPredicate(MinMaxID); - Constant *CondC = ConstantExpr::getICmp(Pred, C0, C1); - Constant *NewC = ConstantExpr::getSelect(CondC, C0, C1); - - Module *Mod = II->getModule(); - Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType()); - return CallInst::Create(MinMax, {LHS->getArgOperand(0), NewC}); + Value *CondC = Builder.CreateICmp(Pred, C0, C1); + Value *NewC = Builder.CreateSelect(CondC, C0, C1); + return Builder.CreateIntrinsic(MinMaxID, II->getType(), + {LHS->getArgOperand(0), NewC}); } /// If this min/max has a matching min/max operand with a constant, try to push @@ -1149,15 +1343,60 @@ foldShuffledIntrinsicOperands(IntrinsicInst *II, return new ShuffleVectorInst(NewIntrinsic, Mask); } +/// Fold the following cases and accepts bswap and bitreverse intrinsics: +/// bswap(logic_op(bswap(x), y)) --> logic_op(x, bswap(y)) +/// bswap(logic_op(bswap(x), bswap(y))) --> logic_op(x, y) (ignores multiuse) +template <Intrinsic::ID IntrID> +static Instruction *foldBitOrderCrossLogicOp(Value *V, + InstCombiner::BuilderTy &Builder) { + static_assert(IntrID == Intrinsic::bswap || IntrID == Intrinsic::bitreverse, + "This helper only supports BSWAP and BITREVERSE intrinsics"); + + Value *X, *Y; + // Find bitwise logic op. Check that it is a BinaryOperator explicitly so we + // don't match ConstantExpr that aren't meaningful for this transform. 
+ if (match(V, m_OneUse(m_BitwiseLogic(m_Value(X), m_Value(Y)))) && + isa<BinaryOperator>(V)) { + Value *OldReorderX, *OldReorderY; + BinaryOperator::BinaryOps Op = cast<BinaryOperator>(V)->getOpcode(); + + // If both X and Y are bswap/bitreverse, the transform reduces the number + // of instructions even if there's multiuse. + // If only one operand is bswap/bitreverse, we need to ensure the operand + // have only one use. + if (match(X, m_Intrinsic<IntrID>(m_Value(OldReorderX))) && + match(Y, m_Intrinsic<IntrID>(m_Value(OldReorderY)))) { + return BinaryOperator::Create(Op, OldReorderX, OldReorderY); + } + + if (match(X, m_OneUse(m_Intrinsic<IntrID>(m_Value(OldReorderX))))) { + Value *NewReorder = Builder.CreateUnaryIntrinsic(IntrID, Y); + return BinaryOperator::Create(Op, OldReorderX, NewReorder); + } + + if (match(Y, m_OneUse(m_Intrinsic<IntrID>(m_Value(OldReorderY))))) { + Value *NewReorder = Builder.CreateUnaryIntrinsic(IntrID, X); + return BinaryOperator::Create(Op, NewReorder, OldReorderY); + } + } + return nullptr; +} + /// CallInst simplification. This mostly only handles folding of intrinsic /// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Don't try to simplify calls without uses. It will not do anything useful, // but will result in the following folds being skipped. - if (!CI.use_empty()) - if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI))) + if (!CI.use_empty()) { + SmallVector<Value *, 4> Args; + Args.reserve(CI.arg_size()); + for (Value *Op : CI.args()) + Args.push_back(Op); + if (Value *V = simplifyCall(&CI, CI.getCalledOperand(), Args, + SQ.getWithInstruction(&CI))) return replaceInstUsesWith(CI, V); + } if (Value *FreedOp = getFreedOperand(&CI, &TLI)) return visitFree(CI, FreedOp); @@ -1176,7 +1415,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // not a multiple of element size then behavior is undefined. 
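foldBitOrderCrossLogicOp above is justified by the fact that a byte swap (or bit reversal) is a pure bit permutation: it commutes with bitwise and/or/xor and is its own inverse. A self-contained sketch of the bswap case, using a hand-rolled 32-bit byte swap so it does not depend on any compiler builtin (helper name is illustrative):

    #include <cassert>
    #include <cstdint>

    // Portable stand-in for a 32-bit bswap.
    static uint32_t bswap32(uint32_t x) {
      return (x >> 24) | ((x >> 8) & 0x0000FF00u) |
             ((x << 8) & 0x00FF0000u) | (x << 24);
    }

    int main() {
      const uint32_t vals[] = {0u, 1u, 0x12345678u, 0xDEADBEEFu, 0xFFFFFFFFu};
      for (uint32_t x : vals)
        for (uint32_t y : vals) {
          // bswap(logic_op(bswap(x), bswap(y))) == logic_op(x, y)
          assert(bswap32(bswap32(x) & bswap32(y)) == (x & y));
          assert(bswap32(bswap32(x) | bswap32(y)) == (x | y));
          assert(bswap32(bswap32(x) ^ bswap32(y)) == (x ^ y));
          // bswap(logic_op(bswap(x), y)) == logic_op(x, bswap(y))
          assert(bswap32(bswap32(x) & y) == (x & bswap32(y)));
        }
      return 0;
    }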
if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II)) if (ConstantInt *NumBytes = dyn_cast<ConstantInt>(AMI->getLength())) - if (NumBytes->getSExtValue() < 0 || + if (NumBytes->isNegative() || (NumBytes->getZExtValue() % AMI->getElementSizeInBytes() != 0)) { CreateNonTerminatorUnreachable(AMI); assert(AMI->getType()->isVoidTy() && @@ -1267,10 +1506,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Intrinsic::ID IID = II->getIntrinsicID(); switch (IID) { - case Intrinsic::objectsize: - if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false)) + case Intrinsic::objectsize: { + SmallVector<Instruction *> InsertedInstructions; + if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false, + &InsertedInstructions)) { + for (Instruction *Inserted : InsertedInstructions) + Worklist.add(Inserted); return replaceInstUsesWith(CI, V); + } return nullptr; + } case Intrinsic::abs: { Value *IIOperand = II->getArgOperand(0); bool IntMinIsPoison = cast<Constant>(II->getArgOperand(1))->isOneValue(); @@ -1377,6 +1622,46 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } + // (umax X, (xor X, Pow2)) + // -> (or X, Pow2) + // (umin X, (xor X, Pow2)) + // -> (and X, ~Pow2) + // (smax X, (xor X, Pos_Pow2)) + // -> (or X, Pos_Pow2) + // (smin X, (xor X, Pos_Pow2)) + // -> (and X, ~Pos_Pow2) + // (smax X, (xor X, Neg_Pow2)) + // -> (and X, ~Neg_Pow2) + // (smin X, (xor X, Neg_Pow2)) + // -> (or X, Neg_Pow2) + if ((match(I0, m_c_Xor(m_Specific(I1), m_Value(X))) || + match(I1, m_c_Xor(m_Specific(I0), m_Value(X)))) && + isKnownToBeAPowerOfTwo(X, /* OrZero */ true)) { + bool UseOr = IID == Intrinsic::smax || IID == Intrinsic::umax; + bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin; + + if (IID == Intrinsic::smax || IID == Intrinsic::smin) { + auto KnownSign = getKnownSign(X, II, DL, &AC, &DT); + if (KnownSign == std::nullopt) { + UseOr = false; + UseAndN = false; + } else if (*KnownSign /* true is Signed. */) { + UseOr ^= true; + UseAndN ^= true; + Type *Ty = I0->getType(); + // Negative power of 2 must be IntMin. It's possible to be able to + // prove negative / power of 2 without actually having known bits, so + // just get the value by hand. + X = Constant::getIntegerValue( + Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits())); + } + } + if (UseOr) + return BinaryOperator::CreateOr(I0, X); + else if (UseAndN) + return BinaryOperator::CreateAnd(I0, Builder.CreateNot(X)); + } + // If we can eliminate ~A and Y is free to invert: // max ~A, Y --> ~(min A, ~Y) // @@ -1436,13 +1721,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *SAdd = matchSAddSubSat(*II)) return SAdd; - if (match(I1, m_ImmConstant())) - if (auto *Sel = dyn_cast<SelectInst>(I0)) - if (Instruction *R = FoldOpIntoSelect(*II, Sel)) - return R; - - if (Instruction *NewMinMax = reassociateMinMaxWithConstants(II)) - return NewMinMax; + if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder)) + return replaceInstUsesWith(*II, NewMinMax); if (Instruction *R = reassociateMinMaxWithConstantInOperand(II, Builder)) return R; @@ -1453,15 +1733,21 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } case Intrinsic::bitreverse: { + Value *IIOperand = II->getArgOperand(0); // bitrev (zext i1 X to ?) --> X ? 
SignBitC : 0 Value *X; - if (match(II->getArgOperand(0), m_ZExt(m_Value(X))) && + if (match(IIOperand, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { Type *Ty = II->getType(); APInt SignBit = APInt::getSignMask(Ty->getScalarSizeInBits()); return SelectInst::Create(X, ConstantInt::get(Ty, SignBit), ConstantInt::getNullValue(Ty)); } + + if (Instruction *crossLogicOpFold = + foldBitOrderCrossLogicOp<Intrinsic::bitreverse>(IIOperand, Builder)) + return crossLogicOpFold; + break; } case Intrinsic::bswap: { @@ -1511,6 +1797,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *V = Builder.CreateLShr(X, CV); return new TruncInst(V, IIOperand->getType()); } + + if (Instruction *crossLogicOpFold = + foldBitOrderCrossLogicOp<Intrinsic::bswap>(IIOperand, Builder)) { + return crossLogicOpFold; + } + break; } case Intrinsic::masked_load: @@ -1616,6 +1908,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Function *Bswap = Intrinsic::getDeclaration(Mod, Intrinsic::bswap, Ty); return CallInst::Create(Bswap, { Op0 }); } + if (Instruction *BitOp = + matchBSwapOrBitReverse(*II, /*MatchBSwaps*/ true, + /*MatchBitReversals*/ true)) + return BitOp; } // Left or right might be masked. @@ -1983,7 +2279,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::copysign: { Value *Mag = II->getArgOperand(0), *Sign = II->getArgOperand(1); - if (SignBitMustBeZero(Sign, &TLI)) { + if (SignBitMustBeZero(Sign, DL, &TLI)) { // If we know that the sign argument is positive, reduce to FABS: // copysign Mag, +Sign --> fabs Mag Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); @@ -2079,6 +2375,42 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::ldexp: { + // ldexp(ldexp(x, a), b) -> ldexp(x, a + b) + // + // The danger is if the first ldexp would overflow to infinity or underflow + // to zero, but the combined exponent avoids it. We ignore this with + // reassoc. + // + // It's also safe to fold if we know both exponents are >= 0 or <= 0 since + // it would just double down on the overflow/underflow which would occur + // anyway. + // + // TODO: Could do better if we had range tracking for the input value + // exponent. Also could broaden sign check to cover == 0 case. + Value *Src = II->getArgOperand(0); + Value *Exp = II->getArgOperand(1); + Value *InnerSrc; + Value *InnerExp; + if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ldexp>( + m_Value(InnerSrc), m_Value(InnerExp)))) && + Exp->getType() == InnerExp->getType()) { + FastMathFlags FMF = II->getFastMathFlags(); + FastMathFlags InnerFlags = cast<FPMathOperator>(Src)->getFastMathFlags(); + + if ((FMF.allowReassoc() && InnerFlags.allowReassoc()) || + signBitMustBeTheSame(Exp, InnerExp, II, DL, &AC, &DT)) { + // TODO: Add nsw/nuw probably safe if integer type exceeds exponent + // width. + Value *NewExp = Builder.CreateAdd(InnerExp, Exp); + II->setArgOperand(1, NewExp); + II->setFastMathFlags(InnerFlags); // Or the inner flags. 
+ return replaceOperand(*II, 0, InnerSrc); + } + } + + break; + } case Intrinsic::ptrauth_auth: case Intrinsic::ptrauth_resign: { // (sign|resign) + (auth|resign) can be folded by omitting the middle @@ -2380,12 +2712,34 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { isValidAssumeForContext(II, LHS, &DT)) { MDNode *MD = MDNode::get(II->getContext(), std::nullopt); LHS->setMetadata(LLVMContext::MD_nonnull, MD); + LHS->setMetadata(LLVMContext::MD_noundef, MD); return RemoveConditionFromAssume(II); // TODO: apply nonnull return attributes to calls and invokes // TODO: apply range metadata for range check patterns? } + // Separate storage assumptions apply to the underlying allocations, not any + // particular pointer within them. When evaluating the hints for AA purposes + // we getUnderlyingObject them; by precomputing the answers here we can + // avoid having to do so repeatedly there. + for (unsigned Idx = 0; Idx < II->getNumOperandBundles(); Idx++) { + OperandBundleUse OBU = II->getOperandBundleAt(Idx); + if (OBU.getTagName() == "separate_storage") { + assert(OBU.Inputs.size() == 2); + auto MaybeSimplifyHint = [&](const Use &U) { + Value *Hint = U.get(); + // Not having a limit is safe because InstCombine removes unreachable + // code. + Value *UnderlyingObject = getUnderlyingObject(Hint, /*MaxLookup*/ 0); + if (Hint != UnderlyingObject) + replaceUse(const_cast<Use &>(U), UnderlyingObject); + }; + MaybeSimplifyHint(OBU.Inputs[0]); + MaybeSimplifyHint(OBU.Inputs[1]); + } + } + // Convert nonnull assume like: // %A = icmp ne i32* %PTR, null // call void @llvm.assume(i1 %A) @@ -2479,6 +2833,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Known.isAllOnes() && isAssumeWithEmptyBundle(cast<AssumeInst>(*II))) return eraseInstFromFunction(*II); + // assume(false) is unreachable. + if (match(IIOperand, m_CombineOr(m_Zero(), m_Undef()))) { + CreateNonTerminatorUnreachable(II); + return eraseInstFromFunction(*II); + } + // Update the cache of affected values for this assumption (we might be // here because we just simplified the condition). AC.updateAffectedValues(cast<AssumeInst>(II)); @@ -2545,7 +2905,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { for (i = 0; i != SubVecNumElts; ++i) WidenMask.push_back(i); for (; i != VecNumElts; ++i) - WidenMask.push_back(UndefMaskElem); + WidenMask.push_back(PoisonMaskElem); Value *WidenShuffle = Builder.CreateShuffleVector(SubVec, WidenMask); @@ -2840,7 +3200,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { int Sz = Mask.size(); SmallBitVector UsedIndices(Sz); for (int Idx : Mask) { - if (Idx == UndefMaskElem || UsedIndices.test(Idx)) + if (Idx == PoisonMaskElem || UsedIndices.test(Idx)) break; UsedIndices.set(Idx); } @@ -2852,6 +3212,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::is_fpclass: { + if (Instruction *I = foldIntrinsicIsFPClass(*II)) + return I; + break; + } default: { // Handle target specific intrinsics std::optional<Instruction *> V = targetInstCombineIntrinsic(*II); @@ -2861,6 +3226,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } + // Try to fold intrinsic into select operands. This is legal if: + // * The intrinsic is speculatable. + // * The select condition is not a vector, or the intrinsic does not + // perform cross-lane operations. 
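The ldexp(ldexp(x, a), b) --> ldexp(x, a + b) rewrite above is exact when the inner call neither overflows nor underflows, and when both exponents have the same sign any overflow or underflow happens in the same direction for the nested and the combined form, so the results still agree. A short illustration with the C library ldexp, assuming IEEE-754 doubles:

    #include <cassert>
    #include <cmath>

    int main() {
      // No overflow or underflow on the inner call: exactly equal.
      assert(std::ldexp(std::ldexp(1.5, 10), 20) == std::ldexp(1.5, 30));

      // Same-sign exponents: both forms overflow to +infinity and still match.
      double nested   = std::ldexp(std::ldexp(3.0, 900), 300);
      double combined = std::ldexp(3.0, 1200);
      assert(std::isinf(nested) && std::isinf(combined) && nested == combined);
      return 0;
    }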
+ switch (IID) { + case Intrinsic::ctlz: + case Intrinsic::cttz: + case Intrinsic::ctpop: + case Intrinsic::umin: + case Intrinsic::umax: + case Intrinsic::smin: + case Intrinsic::smax: + case Intrinsic::usub_sat: + case Intrinsic::uadd_sat: + case Intrinsic::ssub_sat: + case Intrinsic::sadd_sat: + for (Value *Op : II->args()) + if (auto *Sel = dyn_cast<SelectInst>(Op)) + if (Instruction *R = FoldOpIntoSelect(*II, Sel)) + return R; + [[fallthrough]]; + default: + break; + } + if (Instruction *Shuf = foldShuffledIntrinsicOperands(II, Builder)) return Shuf; @@ -2907,49 +3297,6 @@ Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) { return visitCallBase(CBI); } -/// If this cast does not affect the value passed through the varargs area, we -/// can eliminate the use of the cast. -static bool isSafeToEliminateVarargsCast(const CallBase &Call, - const DataLayout &DL, - const CastInst *const CI, - const int ix) { - if (!CI->isLosslessCast()) - return false; - - // If this is a GC intrinsic, avoid munging types. We need types for - // statepoint reconstruction in SelectionDAG. - // TODO: This is probably something which should be expanded to all - // intrinsics since the entire point of intrinsics is that - // they are understandable by the optimizer. - if (isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) || - isa<GCResultInst>(Call)) - return false; - - // Opaque pointers are compatible with any byval types. - PointerType *SrcTy = cast<PointerType>(CI->getOperand(0)->getType()); - if (SrcTy->isOpaque()) - return true; - - // The size of ByVal or InAlloca arguments is derived from the type, so we - // can't change to a type with a different size. If the size were - // passed explicitly we could avoid this check. - if (!Call.isPassPointeeByValueArgument(ix)) - return true; - - // The transform currently only handles type replacement for byval, not other - // type-carrying attributes. - if (!Call.isByValArgument(ix)) - return false; - - Type *SrcElemTy = SrcTy->getNonOpaquePointerElementType(); - Type *DstElemTy = Call.getParamByValType(ix); - if (!SrcElemTy->isSized() || !DstElemTy->isSized()) - return false; - if (DL.getTypeAllocSize(SrcElemTy) != DL.getTypeAllocSize(DstElemTy)) - return false; - return true; -} - Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { if (!CI->getCalledFunction()) return nullptr; @@ -2965,7 +3312,7 @@ Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { auto InstCombineErase = [this](Instruction *I) { eraseInstFromFunction(*I); }; - LibCallSimplifier Simplifier(DL, &TLI, ORE, BFI, PSI, InstCombineRAUW, + LibCallSimplifier Simplifier(DL, &TLI, &AC, ORE, BFI, PSI, InstCombineRAUW, InstCombineErase); if (Value *With = Simplifier.optimizeCall(CI, Builder)) { ++NumSimplified; @@ -3198,32 +3545,6 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { if (IntrinsicInst *II = findInitTrampoline(Callee)) return transformCallThroughTrampoline(Call, *II); - // TODO: Drop this transform once opaque pointer transition is done. - FunctionType *FTy = Call.getFunctionType(); - if (FTy->isVarArg()) { - int ix = FTy->getNumParams(); - // See if we can optimize any arguments passed through the varargs area of - // the call. - for (auto I = Call.arg_begin() + FTy->getNumParams(), E = Call.arg_end(); - I != E; ++I, ++ix) { - CastInst *CI = dyn_cast<CastInst>(*I); - if (CI && isSafeToEliminateVarargsCast(Call, DL, CI, ix)) { - replaceUse(*I, CI->getOperand(0)); - - // Update the byval type to match the pointer type. 
- // Not necessary for opaque pointers. - PointerType *NewTy = cast<PointerType>(CI->getOperand(0)->getType()); - if (!NewTy->isOpaque() && Call.isByValArgument(ix)) { - Call.removeParamAttr(ix, Attribute::ByVal); - Call.addParamAttr(ix, Attribute::getWithByValType( - Call.getContext(), - NewTy->getNonOpaquePointerElementType())); - } - Changed = true; - } - } - } - if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) { InlineAsm *IA = cast<InlineAsm>(Callee); if (!IA->canThrow()) { @@ -3381,13 +3702,17 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { } /// If the callee is a constexpr cast of a function, attempt to move the cast to -/// the arguments of the call/callbr/invoke. +/// the arguments of the call/invoke. +/// CallBrInst is not supported. bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { auto *Callee = dyn_cast<Function>(Call.getCalledOperand()->stripPointerCasts()); if (!Callee) return false; + assert(!isa<CallBrInst>(Call) && + "CallBr's don't have a single point after a def to insert at"); + // If this is a call to a thunk function, don't remove the cast. Thunks are // used to transparently forward all incoming parameters and outgoing return // values, so it's important to leave the cast in place. @@ -3433,7 +3758,7 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { return false; // Attribute not compatible with transformed value. } - // If the callbase is an invoke/callbr instruction, and the return value is + // If the callbase is an invoke instruction, and the return value is // used by a PHI node in a successor, we cannot change the return type of // the call because there is no place to put the cast instruction (without // breaking the critical edge). Bail out in this case. @@ -3441,8 +3766,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { BasicBlock *PhisNotSupportedBlock = nullptr; if (auto *II = dyn_cast<InvokeInst>(Caller)) PhisNotSupportedBlock = II->getNormalDest(); - if (auto *CB = dyn_cast<CallBrInst>(Caller)) - PhisNotSupportedBlock = CB->getDefaultDest(); if (PhisNotSupportedBlock) for (User *U : Caller->users()) if (PHINode *PN = dyn_cast<PHINode>(U)) @@ -3490,24 +3813,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (CallerPAL.hasParamAttr(i, Attribute::ByVal) != Callee->getAttributes().hasParamAttr(i, Attribute::ByVal)) return false; // Cannot transform to or from byval. - - // If the parameter is passed as a byval argument, then we have to have a - // sized type and the sized type has to have the same size as the old type. - if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) { - PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy); - if (!ParamPTy) - return false; - - if (!ParamPTy->isOpaque()) { - Type *ParamElTy = ParamPTy->getNonOpaquePointerElementType(); - if (!ParamElTy->isSized()) - return false; - - Type *CurElTy = Call.getParamByValType(i); - if (DL.getTypeAllocSize(CurElTy) != DL.getTypeAllocSize(ParamElTy)) - return false; - } - } } if (Callee->isDeclaration()) { @@ -3568,16 +3873,8 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { // type. Note that we made sure all incompatible ones are safe to drop. 
AttributeMask IncompatibleAttrs = AttributeFuncs::typeIncompatible( ParamTy, AttributeFuncs::ASK_SAFE_TO_DROP); - if (CallerPAL.hasParamAttr(i, Attribute::ByVal) && - !ParamTy->isOpaquePointerTy()) { - AttrBuilder AB(Ctx, CallerPAL.getParamAttrs(i).removeAttributes( - Ctx, IncompatibleAttrs)); - AB.addByValAttr(ParamTy->getNonOpaquePointerElementType()); - ArgAttrs.push_back(AttributeSet::get(Ctx, AB)); - } else { - ArgAttrs.push_back( - CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs)); - } + ArgAttrs.push_back( + CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs)); } // If the function takes more arguments than the call was taking, add them @@ -3626,9 +3923,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { NewCall = Builder.CreateInvoke(Callee, II->getNormalDest(), II->getUnwindDest(), Args, OpBundles); - } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) { - NewCall = Builder.CreateCallBr(Callee, CBI->getDefaultDest(), - CBI->getIndirectDests(), Args, OpBundles); } else { NewCall = Builder.CreateCall(Callee, Args, OpBundles); cast<CallInst>(NewCall)->setTailCallKind( diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 3f851a2b2182..5c84f666616d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -25,166 +25,6 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -/// Analyze 'Val', seeing if it is a simple linear expression. -/// If so, decompose it, returning some value X, such that Val is -/// X*Scale+Offset. -/// -static Value *decomposeSimpleLinearExpr(Value *Val, unsigned &Scale, - uint64_t &Offset) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(Val)) { - Offset = CI->getZExtValue(); - Scale = 0; - return ConstantInt::get(Val->getType(), 0); - } - - if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) { - // Cannot look past anything that might overflow. - // We specifically require nuw because we store the Scale in an unsigned - // and perform an unsigned divide on it. - OverflowingBinaryOperator *OBI = dyn_cast<OverflowingBinaryOperator>(Val); - if (OBI && !OBI->hasNoUnsignedWrap()) { - Scale = 1; - Offset = 0; - return Val; - } - - if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (I->getOpcode() == Instruction::Shl) { - // This is a value scaled by '1 << the shift amt'. - Scale = UINT64_C(1) << RHS->getZExtValue(); - Offset = 0; - return I->getOperand(0); - } - - if (I->getOpcode() == Instruction::Mul) { - // This value is scaled by 'RHS'. - Scale = RHS->getZExtValue(); - Offset = 0; - return I->getOperand(0); - } - - if (I->getOpcode() == Instruction::Add) { - // We have X+C. Check to see if we really have (X*C2)+C1, - // where C1 is divisible by C2. - unsigned SubScale; - Value *SubVal = - decomposeSimpleLinearExpr(I->getOperand(0), SubScale, Offset); - Offset += RHS->getZExtValue(); - Scale = SubScale; - return SubVal; - } - } - } - - // Otherwise, we can't look past this. - Scale = 1; - Offset = 0; - return Val; -} - -/// If we find a cast of an allocation instruction, try to eliminate the cast by -/// moving the type information into the alloc. -Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, - AllocaInst &AI) { - PointerType *PTy = cast<PointerType>(CI.getType()); - // Opaque pointers don't have an element type we could replace with. 
- if (PTy->isOpaque()) - return nullptr; - - IRBuilderBase::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(&AI); - - // Get the type really allocated and the type casted to. - Type *AllocElTy = AI.getAllocatedType(); - Type *CastElTy = PTy->getNonOpaquePointerElementType(); - if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr; - - // This optimisation does not work for cases where the cast type - // is scalable and the allocated type is not. This because we need to - // know how many times the casted type fits into the allocated type. - // For the opposite case where the allocated type is scalable and the - // cast type is not this leads to poor code quality due to the - // introduction of 'vscale' into the calculations. It seems better to - // bail out for this case too until we've done a proper cost-benefit - // analysis. - bool AllocIsScalable = isa<ScalableVectorType>(AllocElTy); - bool CastIsScalable = isa<ScalableVectorType>(CastElTy); - if (AllocIsScalable != CastIsScalable) return nullptr; - - Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy); - Align CastElTyAlign = DL.getABITypeAlign(CastElTy); - if (CastElTyAlign < AllocElTyAlign) return nullptr; - - // If the allocation has multiple uses, only promote it if we are strictly - // increasing the alignment of the resultant allocation. If we keep it the - // same, we open the door to infinite loops of various kinds. - if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr; - - // The alloc and cast types should be either both fixed or both scalable. - uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinValue(); - uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinValue(); - if (CastElTySize == 0 || AllocElTySize == 0) return nullptr; - - // If the allocation has multiple uses, only promote it if we're not - // shrinking the amount of memory being allocated. - uint64_t AllocElTyStoreSize = - DL.getTypeStoreSize(AllocElTy).getKnownMinValue(); - uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinValue(); - if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr; - - // See if we can satisfy the modulus by pulling a scale out of the array - // size argument. - unsigned ArraySizeScale; - uint64_t ArrayOffset; - Value *NumElements = // See if the array size is a decomposable linear expr. - decomposeSimpleLinearExpr(AI.getOperand(0), ArraySizeScale, ArrayOffset); - - // If we can now satisfy the modulus, by using a non-1 scale, we really can - // do the xform. - if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 || - (AllocElTySize*ArrayOffset ) % CastElTySize != 0) return nullptr; - - // We don't currently support arrays of scalable types. - assert(!AllocIsScalable || (ArrayOffset == 1 && ArraySizeScale == 0)); - - unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize; - Value *Amt = nullptr; - if (Scale == 1) { - Amt = NumElements; - } else { - Amt = ConstantInt::get(AI.getArraySize()->getType(), Scale); - // Insert before the alloca, not before the cast. 
- Amt = Builder.CreateMul(Amt, NumElements); - } - - if (uint64_t Offset = (AllocElTySize*ArrayOffset)/CastElTySize) { - Value *Off = ConstantInt::get(AI.getArraySize()->getType(), - Offset, true); - Amt = Builder.CreateAdd(Amt, Off); - } - - AllocaInst *New = Builder.CreateAlloca(CastElTy, AI.getAddressSpace(), Amt); - New->setAlignment(AI.getAlign()); - New->takeName(&AI); - New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); - New->setMetadata(LLVMContext::MD_DIAssignID, - AI.getMetadata(LLVMContext::MD_DIAssignID)); - - replaceAllDbgUsesWith(AI, *New, *New, DT); - - // If the allocation has multiple real uses, insert a cast and change all - // things that used it to use the new cast. This will also hack on CI, but it - // will die soon. - if (!AI.hasOneUse()) { - // New is the allocation instruction, pointer typed. AI is the original - // allocation instruction, also pointer typed. Thus, cast to use is BitCast. - Value *NewCast = Builder.CreateBitCast(New, AI.getType(), "tmpcast"); - replaceInstUsesWith(AI, NewCast); - eraseInstFromFunction(AI); - } - return replaceInstUsesWith(CI, New); -} - /// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns /// true for, actually insert the code to evaluate the expression. Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, @@ -252,6 +92,20 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, Res = CastInst::Create( static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty); break; + case Instruction::Call: + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + default: + llvm_unreachable("Unsupported call!"); + case Intrinsic::vscale: { + Function *Fn = + Intrinsic::getDeclaration(I->getModule(), Intrinsic::vscale, {Ty}); + Res = CallInst::Create(Fn->getFunctionType(), Fn); + break; + } + } + } + break; default: // TODO: Can handle more cases here. llvm_unreachable("Unreachable!"); @@ -294,6 +148,10 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); Type *Ty = CI.getType(); + if (auto *SrcC = dyn_cast<Constant>(Src)) + if (Constant *Res = ConstantFoldCastOperand(CI.getOpcode(), SrcC, Ty, DL)) + return replaceInstUsesWith(CI, Res); + // Try to eliminate a cast of a cast. if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) { @@ -501,16 +359,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, // If the integer type can hold the max FP value, it is safe to cast // directly to that type. Otherwise, we may create poison via overflow // that did not exist in the original code. - // - // The max FP value is pow(2, MaxExponent) * (1 + MaxFraction), so we need - // at least one more bit than the MaxExponent to hold the max FP value. Type *InputTy = I->getOperand(0)->getType()->getScalarType(); const fltSemantics &Semantics = InputTy->getFltSemantics(); - uint32_t MinBitWidth = APFloatBase::semanticsMaxExponent(Semantics); - // Extra sign bit needed. - if (I->getOpcode() == Instruction::FPToSI) - ++MinBitWidth; - return Ty->getScalarSizeInBits() > MinBitWidth; + uint32_t MinBitWidth = + APFloatBase::semanticsIntSizeInBits(Semantics, + I->getOpcode() == Instruction::FPToSI); + return Ty->getScalarSizeInBits() >= MinBitWidth; } default: // TODO: Can handle more cases here. 
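The switch to APFloatBase::semanticsIntSizeInBits in the FPToUI/FPToSI case above reflects that a format whose maximum exponent is E has a largest finite value just below 2^(E+1), so E+1 bits suffice for an unsigned result and one more bit covers the sign. For float (E = 127) the bound is easy to check directly; a throwaway illustration in plain C++, not tied to any LLVM API:

    #include <cassert>
    #include <cfloat>
    #include <cmath>

    int main() {
      // FLT_MAX = (2 - 2^-23) * 2^127 lies strictly between 2^127 and 2^128,
      // so 128 bits hold any non-negative finite float and 129 bits cover the
      // signed case.
      const long double fltMax = FLT_MAX;
      assert(fltMax > std::ldexp(1.0L, 127));
      assert(fltMax < std::ldexp(1.0L, 128));
      return 0;
    }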
@@ -881,13 +735,12 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { Value *And = Builder.CreateAnd(X, MaskC); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } - if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_Constant(C)), + if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_ImmConstant(C)), m_Deferred(X))))) { // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); Constant *MaskC = ConstantExpr::getShl(One, C); - MaskC = ConstantExpr::getOr(MaskC, One); - Value *And = Builder.CreateAnd(X, MaskC); + Value *And = Builder.CreateAnd(X, Builder.CreateOr(MaskC, One)); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } } @@ -904,11 +757,18 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // removed by the trunc. if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, APInt(SrcWidth, MaxShiftAmt)))) { + auto GetNewShAmt = [&](unsigned Width) { + Constant *MaxAmt = ConstantInt::get(SrcTy, Width - 1, false); + Constant *Cmp = + ConstantFoldCompareInstOperands(ICmpInst::ICMP_ULT, C, MaxAmt, DL); + Constant *ShAmt = ConstantFoldSelectInstruction(Cmp, C, MaxAmt); + return ConstantFoldCastOperand(Instruction::Trunc, ShAmt, A->getType(), + DL); + }; + // trunc (lshr (sext A), C) --> ashr A, C if (A->getType() == DestTy) { - Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); - Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); - ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + Constant *ShAmt = GetNewShAmt(DestWidth); ShAmt = Constant::mergeUndefsWith(ShAmt, C); return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt) : BinaryOperator::CreateAShr(A, ShAmt); @@ -916,9 +776,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // The types are mismatched, so create a cast after shifting: // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) if (Src->hasOneUse()) { - Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false); - Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); - ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + Constant *ShAmt = GetNewShAmt(AWidth); Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact); return CastInst::CreateIntegerCast(Shift, DestTy, true); } @@ -998,7 +856,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { } } - if (match(Src, m_VScale(DL))) { + if (match(Src, m_VScale())) { if (Trunc.getFunction() && Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = @@ -1217,6 +1075,13 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, return false; return true; } + case Instruction::Call: + // llvm.vscale() can always be executed in larger type, because the + // value is automatically zero-extended. + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + if (II->getIntrinsicID() == Intrinsic::vscale) + return true; + return false; default: // TODO: Can handle more cases here. return false; @@ -1226,7 +1091,8 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { // If this zero extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this zext. - if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back())) + if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back()) && + !isa<Constant>(Zext.getOperand(0))) return nullptr; // If one of the common conversion will work, do it. 
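The GetNewShAmt helper in the visitTrunc hunk above constant-folds the clamped shift amount for trunc (lshr (sext A), C) --> ashr A, umin(C, width - 1). With i8 as the narrow type and i32 as the wide type the rewrite can be brute-forced over the shift amounts the MaxShiftAmt check admits (those whose zero bits are all discarded by the trunc); a sketch assuming C++20 two's-complement conversion and arithmetic-shift semantics:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      // trunc i8 (lshr (sext A to i32), C)  ==  ashr A, min(C, 7)
      for (int a = -128; a <= 127; ++a)
        for (unsigned c = 0; c <= 24; ++c) { // 24 = 32 - 8 zero bits removed by the trunc
          uint32_t wide = static_cast<uint32_t>(static_cast<int32_t>(a)); // sext
          int8_t viaWide = static_cast<int8_t>(wide >> c);                // lshr + trunc
          int8_t direct = static_cast<int8_t>(static_cast<int8_t>(a) >>
                                              std::min(c, 7u));          // ashr
          assert(viaWide == direct);
        }
      return 0;
    }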
@@ -1340,7 +1206,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { return BinaryOperator::CreateAnd(X, ZextC); } - if (match(Src, m_VScale(DL))) { + if (match(Src, m_VScale())) { if (Zext.getFunction() && Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = @@ -1402,7 +1268,7 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { // sext ((x & 2^n) == 0) -> (x >> n) - 1 // sext ((x & 2^n) != 2^n) -> (x >> n) - 1 - unsigned ShiftAmt = KnownZeroMask.countTrailingZeros(); + unsigned ShiftAmt = KnownZeroMask.countr_zero(); // Perform a right shift to place the desired bit in the LSB. if (ShiftAmt) In = Builder.CreateLShr(In, @@ -1416,7 +1282,7 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, } else { // sext ((x & 2^n) != 0) -> (x << bitwidth-n) a>> bitwidth-1 // sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1 - unsigned ShiftAmt = KnownZeroMask.countLeadingZeros(); + unsigned ShiftAmt = KnownZeroMask.countl_zero(); // Perform a left shift to place the desired bit in the MSB. if (ShiftAmt) In = Builder.CreateShl(In, @@ -1611,7 +1477,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { } } - if (match(Src, m_VScale(DL))) { + if (match(Src, m_VScale())) { if (Sext.getFunction() && Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = @@ -2687,57 +2553,6 @@ Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI, return RetVal; } -static Instruction *convertBitCastToGEP(BitCastInst &CI, IRBuilderBase &Builder, - const DataLayout &DL) { - Value *Src = CI.getOperand(0); - PointerType *SrcPTy = cast<PointerType>(Src->getType()); - PointerType *DstPTy = cast<PointerType>(CI.getType()); - - // Bitcasts involving opaque pointers cannot be converted into a GEP. - if (SrcPTy->isOpaque() || DstPTy->isOpaque()) - return nullptr; - - Type *DstElTy = DstPTy->getNonOpaquePointerElementType(); - Type *SrcElTy = SrcPTy->getNonOpaquePointerElementType(); - - // When the type pointed to is not sized the cast cannot be - // turned into a gep. - if (!SrcElTy->isSized()) - return nullptr; - - // If the source and destination are pointers, and this cast is equivalent - // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. - // This can enhance SROA and other transforms that want type-safe pointers. - unsigned NumZeros = 0; - while (SrcElTy && SrcElTy != DstElTy) { - SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0); - ++NumZeros; - } - - // If we found a path from the src to dest, create the getelementptr now. - if (SrcElTy == DstElTy) { - SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0)); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - SrcPTy->getNonOpaquePointerElementType(), Src, Idxs); - - // If the source pointer is dereferenceable, then assume it points to an - // allocated object and apply "inbounds" to the GEP. - bool CanBeNull, CanBeFreed; - if (Src->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed)) { - // In a non-default address space (not 0), a null pointer can not be - // assumed inbounds, so ignore that case (dereferenceable_or_null). - // The reason is that 'null' is not treated differently in these address - // spaces, and we consequently ignore the 'gep inbounds' special case - // for 'null' which allows 'inbounds' on 'null' if the indices are - // zeros. 
- if (SrcPTy->getAddressSpace() == 0 || !CanBeNull) - GEP->setIsInBounds(); - } - return GEP; - } - return nullptr; -} - Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { // If the operands are integer typed then apply the integer transforms, // otherwise just apply the common ones. @@ -2750,19 +2565,6 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { if (DestTy == Src->getType()) return replaceInstUsesWith(CI, Src); - if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) { - // If we are casting a alloca to a pointer to a type of the same - // size, rewrite the allocation instruction to allocate the "right" type. - // There is no need to modify malloc calls because it is their bitcast that - // needs to be cleaned up. - if (AllocaInst *AI = dyn_cast<AllocaInst>(Src)) - if (Instruction *V = PromoteCastOfAllocation(CI, *AI)) - return V; - - if (Instruction *I = convertBitCastToGEP(CI, Builder, DL)) - return I; - } - if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) { // Beware: messing with this target-specific oddity may cause trouble. if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) { @@ -2905,23 +2707,5 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { } Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) { - // If the destination pointer element type is not the same as the source's - // first do a bitcast to the destination type, and then the addrspacecast. - // This allows the cast to be exposed to other transforms. - Value *Src = CI.getOperand(0); - PointerType *SrcTy = cast<PointerType>(Src->getType()->getScalarType()); - PointerType *DestTy = cast<PointerType>(CI.getType()->getScalarType()); - - if (!SrcTy->hasSameElementTypeAs(DestTy)) { - Type *MidTy = - PointerType::getWithSamePointeeType(DestTy, SrcTy->getAddressSpace()); - // Handle vectors of pointers. - if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) - MidTy = VectorType::get(MidTy, VT->getElementCount()); - - Value *NewBitCast = Builder.CreateBitCast(Src, MidTy); - return new AddrSpaceCastInst(NewBitCast, CI.getType()); - } - return commonPointerCastTransforms(CI); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 1480a0ff9e2f..656f04370e17 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -198,7 +199,11 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( } // If the element is masked, handle it. - if (AndCst) Elt = ConstantExpr::getAnd(Elt, AndCst); + if (AndCst) { + Elt = ConstantFoldBinaryOpOperands(Instruction::And, Elt, AndCst, DL); + if (!Elt) + return nullptr; + } // Find out if the comparison would be true or false for the i'th element. Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, @@ -276,14 +281,14 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( // order the state machines in complexity of the generated code. Value *Idx = GEP->getOperand(2); - // If the index is larger than the pointer size of the target, truncate the - // index down like the GEP would do implicitly. 
We don't have to do this for - // an inbounds GEP because the index can't be out of range. + // If the index is larger than the pointer offset size of the target, truncate + // the index down like the GEP would do implicitly. We don't have to do this + // for an inbounds GEP because the index can't be out of range. if (!GEP->isInBounds()) { - Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); - unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); - if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > PtrSize) - Idx = Builder.CreateTrunc(Idx, IntPtrTy); + Type *PtrIdxTy = DL.getIndexType(GEP->getType()); + unsigned OffsetSize = PtrIdxTy->getIntegerBitWidth(); + if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > OffsetSize) + Idx = Builder.CreateTrunc(Idx, PtrIdxTy); } // If inbounds keyword is not present, Idx * ElementSize can overflow. @@ -295,10 +300,10 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx. unsigned ElementSize = DL.getTypeAllocSize(Init->getType()->getArrayElementType()); - auto MaskIdx = [&](Value* Idx){ - if (!GEP->isInBounds() && countTrailingZeros(ElementSize) != 0) { + auto MaskIdx = [&](Value *Idx) { + if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) { Value *Mask = ConstantInt::get(Idx->getType(), -1); - Mask = Builder.CreateLShr(Mask, countTrailingZeros(ElementSize)); + Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize)); Idx = Builder.CreateAnd(Idx, Mask); } return Idx; @@ -533,7 +538,8 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V, /// pointer. static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, const DataLayout &DL, - SetVector<Value *> &Explored) { + SetVector<Value *> &Explored, + InstCombiner &IC) { // Perform all the substitutions. This is a bit tricky because we can // have cycles in our use-def chains. // 1. Create the PHI nodes without any incoming values. @@ -562,7 +568,7 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, // Create all the other instructions. for (Value *Val : Explored) { - if (NewInsts.find(Val) != NewInsts.end()) + if (NewInsts.contains(Val)) continue; if (auto *CI = dyn_cast<CastInst>(Val)) { @@ -610,7 +616,7 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) { Value *NewIncoming = PHI->getIncomingValue(I); - if (NewInsts.find(NewIncoming) != NewInsts.end()) + if (NewInsts.contains(NewIncoming)) NewIncoming = NewInsts[NewIncoming]; NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I)); @@ -635,7 +641,10 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, Val->getName() + ".ptr"); NewVal = Builder.CreateBitOrPointerCast( NewVal, Val->getType(), Val->getName() + ".conv"); - Val->replaceAllUsesWith(NewVal); + IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal); + // Add old instruction to worklist for DCE. We don't directly remove it + // here because the original compare is one of the users. + IC.addToWorklist(cast<Instruction>(Val)); } return NewInsts[Start]; @@ -688,7 +697,8 @@ getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) { /// between GEPLHS and RHS. static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, - const DataLayout &DL) { + const DataLayout &DL, + InstCombiner &IC) { // FIXME: Support vector of pointers. 
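The MaskIdx lambda above clears the top countr_zero(ElementSize) bits of the index: when the element size is a power of two 2^k, multiplying by it makes those high k bits irrelevant modulo the pointer-offset width, so a wrapping (non-inbounds) index and its masked form produce the same byte offset. A tiny 8-bit model of that argument (C++20 for std::countr_zero; the 8-bit width stands in for the real offset width):

    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned elemSize = 4; // power-of-two element size (2^2)
      const uint8_t mask =
          static_cast<uint8_t>(0xFFu >> std::countr_zero(elemSize));
      for (unsigned idx = 0; idx < 256; ++idx) {
        const uint8_t masked = static_cast<uint8_t>(idx) & mask;
        // Byte offsets agree modulo the 8-bit "offset" width.
        assert(static_cast<uint8_t>(idx * elemSize) ==
               static_cast<uint8_t>(masked * elemSize));
      }
      return 0;
    }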
if (GEPLHS->getType()->isVectorTy()) return nullptr; @@ -712,7 +722,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, // can't have overflow on either side. We can therefore re-write // this as: // OFFSET1 cmp OFFSET2 - Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes); + Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes, IC); // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written // GEP having PtrBase as the pointer base, and has returned in NewRHS the @@ -740,7 +750,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, RHS = RHS->stripPointerCasts(); Value *PtrBase = GEPLHS->getOperand(0); - if (PtrBase == RHS && GEPLHS->isInBounds()) { + if (PtrBase == RHS && (GEPLHS->isInBounds() || ICmpInst::isEquality(Cond))) { // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). Value *Offset = EmitGEPOffset(GEPLHS); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, @@ -831,7 +841,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Otherwise, the base pointers are different and the indices are // different. Try convert this to an indexed compare by looking through // PHIs/casts. - return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this); } // If one of the GEPs has all zero indices, recurse. @@ -883,7 +893,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Only lower this if the icmp is the only user of the GEP or if we expect // the result to fold to a constant! - if (GEPsInBounds && (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && + if ((GEPsInBounds || CmpInst::isEquality(Cond)) && + (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) Value *L = EmitGEPOffset(GEPLHS); @@ -894,13 +905,10 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Try convert this to an indexed compare by looking through PHIs/casts as a // last resort. - return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this); } -Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI, - const AllocaInst *Alloca) { - assert(ICI.isEquality() && "Cannot fold non-equality comparison."); - +bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) { // It would be tempting to fold away comparisons between allocas and any // pointer not based on that alloca (e.g. an argument). However, even // though such pointers cannot alias, they can still compare equal. @@ -909,67 +917,72 @@ Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI, // doesn't escape we can argue that it's impossible to guess its value, and we // can therefore act as if any such guesses are wrong. // - // The code below checks that the alloca doesn't escape, and that it's only - // used in a comparison once (the current instruction). The - // single-comparison-use condition ensures that we're trivially folding all - // comparisons against the alloca consistently, and avoids the risk of - // erroneously folding a comparison of the pointer with itself. - - unsigned MaxIter = 32; // Break cycles and bound to constant-time. 
+ // However, we need to ensure that this folding is consistent: We can't fold + // one comparison to false, and then leave a different comparison against the + // same value alone (as it might evaluate to true at runtime, leading to a + // contradiction). As such, this code ensures that all comparisons are folded + // at the same time, and there are no other escapes. + + struct CmpCaptureTracker : public CaptureTracker { + AllocaInst *Alloca; + bool Captured = false; + /// The value of the map is a bit mask of which icmp operands the alloca is + /// used in. + SmallMapVector<ICmpInst *, unsigned, 4> ICmps; + + CmpCaptureTracker(AllocaInst *Alloca) : Alloca(Alloca) {} + + void tooManyUses() override { Captured = true; } + + bool captured(const Use *U) override { + auto *ICmp = dyn_cast<ICmpInst>(U->getUser()); + // We need to check that U is based *only* on the alloca, and doesn't + // have other contributions from a select/phi operand. + // TODO: We could check whether getUnderlyingObjects() reduces to one + // object, which would allow looking through phi nodes. + if (ICmp && ICmp->isEquality() && getUnderlyingObject(*U) == Alloca) { + // Collect equality icmps of the alloca, and don't treat them as + // captures. + auto Res = ICmps.insert({ICmp, 0}); + Res.first->second |= 1u << U->getOperandNo(); + return false; + } - SmallVector<const Use *, 32> Worklist; - for (const Use &U : Alloca->uses()) { - if (Worklist.size() >= MaxIter) - return nullptr; - Worklist.push_back(&U); - } + Captured = true; + return true; + } + }; - unsigned NumCmps = 0; - while (!Worklist.empty()) { - assert(Worklist.size() <= MaxIter); - const Use *U = Worklist.pop_back_val(); - const Value *V = U->getUser(); - --MaxIter; + CmpCaptureTracker Tracker(Alloca); + PointerMayBeCaptured(Alloca, &Tracker); + if (Tracker.Captured) + return false; - if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) || - isa<SelectInst>(V)) { - // Track the uses. - } else if (isa<LoadInst>(V)) { - // Loading from the pointer doesn't escape it. - continue; - } else if (const auto *SI = dyn_cast<StoreInst>(V)) { - // Storing *to* the pointer is fine, but storing the pointer escapes it. - if (SI->getValueOperand() == U->get()) - return nullptr; - continue; - } else if (isa<ICmpInst>(V)) { - if (NumCmps++) - return nullptr; // Found more than one cmp. - continue; - } else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) { - switch (Intrin->getIntrinsicID()) { - // These intrinsics don't escape or compare the pointer. Memset is safe - // because we don't allow ptrtoint. Memcpy and memmove are safe because - // we don't allow stores, so src cannot point to V. - case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: - case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset: - continue; - default: - return nullptr; - } - } else { - return nullptr; + bool Changed = false; + for (auto [ICmp, Operands] : Tracker.ICmps) { + switch (Operands) { + case 1: + case 2: { + // The alloca is only used in one icmp operand. Assume that the + // equality is false. 
+ auto *Res = ConstantInt::get( + ICmp->getType(), ICmp->getPredicate() == ICmpInst::ICMP_NE); + replaceInstUsesWith(*ICmp, Res); + eraseInstFromFunction(*ICmp); + Changed = true; + break; } - for (const Use &U : V->uses()) { - if (Worklist.size() >= MaxIter) - return nullptr; - Worklist.push_back(&U); + case 3: + // Both icmp operands are based on the alloca, so this is comparing + // pointer offsets, without leaking any information about the address + // of the alloca. Ignore such comparisons. + break; + default: + llvm_unreachable("Cannot happen"); } } - auto *Res = ConstantInt::get(ICI.getType(), - !CmpInst::isTrueWhenEqual(ICI.getPredicate())); - return replaceInstUsesWith(ICI, Res); + return Changed; } /// Fold "icmp pred (X+C), X". @@ -1058,9 +1071,9 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, int Shift; if (IsAShr && AP1.isNegative()) - Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); + Shift = AP1.countl_one() - AP2.countl_one(); else - Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); + Shift = AP1.countl_zero() - AP2.countl_zero(); if (Shift > 0) { if (IsAShr && AP1 == AP2.ashr(Shift)) { @@ -1097,7 +1110,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A, if (AP2.isZero()) return nullptr; - unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + unsigned AP2TrailingZeros = AP2.countr_zero(); if (!AP1 && AP2TrailingZeros != 0) return getICmp( @@ -1108,7 +1121,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A, return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); // Get the distance between the lowest bits that are set. - int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + int Shift = AP1.countr_zero() - AP2TrailingZeros; if (Shift > 0 && AP2.shl(Shift) == AP1) return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); @@ -1143,7 +1156,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. if (!CI2->getValue().isPowerOf2()) return nullptr; - unsigned NewWidth = CI2->getValue().countTrailingZeros(); + unsigned NewWidth = CI2->getValue().countr_zero(); if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr; @@ -1295,6 +1308,48 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) { return new ICmpInst(Pred, X, Cmp.getOperand(1)); } + // (icmp eq/ne (mul X Y)) -> (icmp eq/ne X/Y) if we know about whether X/Y are + // odd/non-zero/there is no overflow. + if (match(Cmp.getOperand(0), m_Mul(m_Value(X), m_Value(Y))) && + ICmpInst::isEquality(Pred)) { + + KnownBits XKnown = computeKnownBits(X, 0, &Cmp); + // if X % 2 != 0 + // (icmp eq/ne Y) + if (XKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, Y, Cmp.getOperand(1)); + + KnownBits YKnown = computeKnownBits(Y, 0, &Cmp); + // if Y % 2 != 0 + // (icmp eq/ne X) + if (YKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + + auto *BO0 = cast<OverflowingBinaryOperator>(Cmp.getOperand(0)); + if (BO0->hasNoUnsignedWrap() || BO0->hasNoSignedWrap()) { + const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); + // `isKnownNonZero` does more analysis than just `!KnownBits.One.isZero()` + // but to avoid unnecessary work, first just if this is an obvious case. 
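The comments above reduce `icmp eq/ne (mul X, Y), 0` to a compare of the other operand once one factor is known odd (or known non-zero when the multiply cannot wrap). A quick standalone check of the odd-factor identity in 8-bit arithmetic, illustrative only and not taken from this patch:

// If X is odd it is invertible modulo 2^N, so X * Y == 0 (mod 2^N) exactly
// when Y == 0. Verified exhaustively for N = 8.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 1; X < 256; X += 2)   // every odd 8-bit value
    for (unsigned Y = 0; Y < 256; ++Y) {
      bool ProdIsZero = static_cast<uint8_t>(X * Y) == 0;
      assert(ProdIsZero == (Y == 0));
    }
  return 0;
}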
+ + // if X non-zero and NoOverflow(X * Y) + // (icmp eq/ne Y) + if (!XKnown.One.isZero() || isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(Pred, Y, Cmp.getOperand(1)); + + // if Y non-zero and NoOverflow(X * Y) + // (icmp eq/ne X) + if (!YKnown.One.isZero() || isKnownNonZero(Y, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + } + // Note, we are skipping cases: + // if Y % 2 != 0 AND X % 2 != 0 + // (false/true) + // if X non-zero and Y non-zero and NoOverflow(X * Y) + // (false/true) + // Those can be simplified later as we would have already replaced the (icmp + // eq/ne (mul X, Y)) with (icmp eq/ne X/Y) and if X/Y is known non-zero that + // will fold to a constant elsewhere. + } return nullptr; } @@ -1331,17 +1386,18 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) { if (auto *Phi = dyn_cast<PHINode>(Op0)) if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) { - Type *Ty = Cmp.getType(); - Builder.SetInsertPoint(Phi); - PHINode *NewPhi = - Builder.CreatePHI(Ty, Phi->getNumOperands()); - for (BasicBlock *Predecessor : predecessors(Phi->getParent())) { - auto *Input = - cast<Constant>(Phi->getIncomingValueForBlock(Predecessor)); - auto *BoolInput = ConstantExpr::getCompare(Pred, Input, C); - NewPhi->addIncoming(BoolInput, Predecessor); + SmallVector<Constant *> Ops; + for (Value *V : Phi->incoming_values()) { + Constant *Res = + ConstantFoldCompareInstOperands(Pred, cast<Constant>(V), C, DL); + if (!Res) + return nullptr; + Ops.push_back(Res); } - NewPhi->takeName(&Cmp); + Builder.SetInsertPoint(Phi); + PHINode *NewPhi = Builder.CreatePHI(Cmp.getType(), Phi->getNumOperands()); + for (auto [V, Pred] : zip(Ops, Phi->blocks())) + NewPhi->addIncoming(V, Pred); return replaceInstUsesWith(Cmp, NewPhi); } @@ -1369,11 +1425,8 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { if (TrueBB == FalseBB) return nullptr; - // Try to simplify this compare to T/F based on the dominating condition. - std::optional<bool> Imp = - isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB); - if (Imp) - return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp)); + // We already checked simple implication in InstSimplify, only handle complex + // cases here. CmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1); @@ -1475,7 +1528,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, KnownBits Known = computeKnownBits(X, 0, &Cmp); // If all the high bits are known, we can do this xform. - if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) { + if ((Known.Zero | Known.One).countl_one() >= SrcBits - DstBits) { // Pull in the high bits from known-ones set. APInt NewRHS = C.zext(SrcBits); NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); @@ -1781,17 +1834,12 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, ++UsesRemoved; // Compute A & ((1 << B) | 1) - Value *NewOr = nullptr; - if (auto *C = dyn_cast<Constant>(B)) { - if (UsesRemoved >= 1) - NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); - } else { - if (UsesRemoved >= 3) - NewOr = Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(), - /*HasNUW=*/true), - One, Or->getName()); - } - if (NewOr) { + unsigned RequireUsesRemoved = match(B, m_ImmConstant()) ? 
1 : 3; + if (UsesRemoved >= RequireUsesRemoved) { + Value *NewOr = + Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName()); return replaceOperand(Cmp, 0, NewAnd); } @@ -1819,6 +1867,15 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType())); } + // (X & X) < 0 --> X == MinSignedC + // (X & X) > -1 --> X != MinSignedC + if (match(And, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) { + Constant *MinSignedC = ConstantInt::get( + X->getType(), + APInt::getSignedMinValue(X->getType()->getScalarSizeInBits())); + auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; + return new ICmpInst(NewPred, X, MinSignedC); + } } // TODO: These all require that Y is constant too, so refactor with the above. @@ -1846,6 +1903,30 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); } + // If we are testing the intersection of 2 select-of-nonzero-constants with no + // common bits set, it's the same as checking if exactly one select condition + // is set: + // ((A ? TC : FC) & (B ? TC : FC)) == 0 --> xor A, B + // ((A ? TC : FC) & (B ? TC : FC)) != 0 --> not(xor A, B) + // TODO: Generalize for non-constant values. + // TODO: Handle signed/unsigned predicates. + // TODO: Handle other bitwise logic connectors. + // TODO: Extend to handle a non-zero compare constant. + if (C.isZero() && (Pred == CmpInst::ICMP_EQ || And->hasOneUse())) { + assert(Cmp.isEquality() && "Not expecting non-equality predicates"); + Value *A, *B; + const APInt *TC, *FC; + if (match(X, m_Select(m_Value(A), m_APInt(TC), m_APInt(FC))) && + match(Y, + m_Select(m_Value(B), m_SpecificInt(*TC), m_SpecificInt(*FC))) && + !TC->isZero() && !FC->isZero() && !TC->intersects(*FC)) { + Value *R = Builder.CreateXor(A, B); + if (Pred == CmpInst::ICMP_NE) + R = Builder.CreateNot(R); + return replaceInstUsesWith(Cmp, R); + } + } + // ((zext i1 X) & Y) == 0 --> !((trunc Y) & X) // ((zext i1 X) & Y) != 0 --> ((trunc Y) & X) // ((zext i1 X) & Y) == 1 --> ((trunc Y) & X) @@ -1863,6 +1944,59 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return nullptr; } +/// Fold icmp eq/ne (or (xor (X1, X2), xor(X3, X4))), 0. +static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or, + InstCombiner::BuilderTy &Builder) { + // Are we using xors to bitwise check for a pair or pairs of (in)equalities? + // Convert to a shorter form that has more potential to be folded even + // further. 
+ // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) + // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) + // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) == 0 --> + // (X1 == X2) && (X3 == X4) && (X5 == X6) + // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) != 0 --> + // (X1 != X2) || (X3 != X4) || (X5 != X6) + // TODO: Implement for sub + SmallVector<std::pair<Value *, Value *>, 2> CmpValues; + SmallVector<Value *, 16> WorkList(1, Or); + + while (!WorkList.empty()) { + auto MatchOrOperatorArgument = [&](Value *OrOperatorArgument) { + Value *Lhs, *Rhs; + + if (match(OrOperatorArgument, + m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) { + CmpValues.emplace_back(Lhs, Rhs); + } else { + WorkList.push_back(OrOperatorArgument); + } + }; + + Value *CurrentValue = WorkList.pop_back_val(); + Value *OrOperatorLhs, *OrOperatorRhs; + + if (!match(CurrentValue, + m_Or(m_Value(OrOperatorLhs), m_Value(OrOperatorRhs)))) { + return nullptr; + } + + MatchOrOperatorArgument(OrOperatorRhs); + MatchOrOperatorArgument(OrOperatorLhs); + } + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; + Value *LhsCmp = Builder.CreateICmp(Pred, CmpValues.rbegin()->first, + CmpValues.rbegin()->second); + + for (auto It = CmpValues.rbegin() + 1; It != CmpValues.rend(); ++It) { + Value *RhsCmp = Builder.CreateICmp(Pred, It->first, It->second); + LhsCmp = Builder.CreateBinOp(BOpc, LhsCmp, RhsCmp); + } + + return LhsCmp; +} + /// Fold icmp (or X, Y), C. Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, @@ -1909,6 +2043,30 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, return new ICmpInst(NewPred, X, NewC); } + const APInt *OrC; + // icmp(X | OrC, C) --> icmp(X, 0) + if (C.isNonNegative() && match(Or, m_Or(m_Value(X), m_APInt(OrC)))) { + switch (Pred) { + // X | OrC s< C --> X s< 0 iff OrC s>= C s>= 0 + case ICmpInst::ICMP_SLT: + // X | OrC s>= C --> X s>= 0 iff OrC s>= C s>= 0 + case ICmpInst::ICMP_SGE: + if (OrC->sge(C)) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + break; + // X | OrC s<= C --> X s< 0 iff OrC s> C s>= 0 + case ICmpInst::ICMP_SLE: + // X | OrC s> C --> X s>= 0 iff OrC s> C s>= 0 + case ICmpInst::ICMP_SGT: + if (OrC->sgt(C)) + return new ICmpInst(ICmpInst::getFlippedStrictnessPredicate(Pred), X, + ConstantInt::getNullValue(X->getType())); + break; + default: + break; + } + } + if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse()) return nullptr; @@ -1924,18 +2082,8 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, return BinaryOperator::Create(BOpc, CmpP, CmpQ); } - // Are we using xors to bitwise check for a pair of (in)equalities? Convert to - // a shorter form that has more potential to be folded even further. - Value *X1, *X2, *X3, *X4; - if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) && - match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) { - // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) - // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) - Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2); - Value *Cmp34 = Builder.CreateICmp(Pred, X3, X4); - auto BOpc = Pred == CmpInst::ICMP_EQ ? 
Instruction::And : Instruction::Or; - return BinaryOperator::Create(BOpc, Cmp12, Cmp34); - } + if (Value *V = foldICmpOrXorChain(Cmp, Or, Builder)) + return replaceInstUsesWith(Cmp, V); return nullptr; } @@ -1969,21 +2117,29 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); } - if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap())) + if (MulC->isZero()) return nullptr; - // If the multiply does not wrap, try to divide the compare constant by the - // multiplication factor. + // If the multiply does not wrap or the constant is odd, try to divide the + // compare constant by the multiplication factor. if (Cmp.isEquality()) { - // (mul nsw X, MulC) == C --> X == C /s MulC + // (mul nsw X, MulC) eq/ne C --> X eq/ne C /s MulC if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC)); return new ICmpInst(Pred, X, NewC); } - // (mul nuw X, MulC) == C --> X == C /u MulC - if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); - return new ICmpInst(Pred, X, NewC); + + // C % MulC == 0 is weaker than we could use if MulC is odd because it + // correct to transform if MulC * N == C including overflow. I.e with i8 + // (icmp eq (mul X, 5), 101) -> (icmp eq X, 225) but since 101 % 5 != 0, we + // miss that case. + if (C.urem(*MulC).isZero()) { + // (mul nuw X, MulC) eq/ne C --> X eq/ne C /u MulC + // (mul X, OddC) eq/ne N * C --> X eq/ne N + if ((*MulC & 1).isOne() || Mul->hasNoUnsignedWrap()) { + Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); + return new ICmpInst(Pred, X, NewC); + } } } @@ -1992,27 +2148,32 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, // (X * MulC) > C --> X > (C / MulC) // TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE? Constant *NewC = nullptr; - if (Mul->hasNoSignedWrap()) { + if (Mul->hasNoSignedWrap() && ICmpInst::isSigned(Pred)) { // MININT / -1 --> overflow. if (C.isMinSignedValue() && MulC->isAllOnes()) return nullptr; if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { NewC = ConstantInt::get( MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); - if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) + } else { + assert((Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) && + "Unexpected predicate"); NewC = ConstantInt::get( MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); - } else { - assert(Mul->hasNoUnsignedWrap() && "Expected mul nuw"); - if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) + } + } else if (Mul->hasNoUnsignedWrap() && ICmpInst::isUnsigned(Pred)) { + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) { NewC = ConstantInt::get( MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); - if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) + } else { + assert((Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) && + "Unexpected predicate"); NewC = ConstantInt::get( MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); + } } return NewC ? 
new ICmpInst(Pred, X, NewC) : nullptr; @@ -2070,6 +2231,32 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + // (icmp pred (shl nuw&nsw X, Y), Csle0) + // -> (icmp pred X, Csle0) + // + // The idea is the nuw/nsw essentially freeze the sign bit for the shift op + // so X's must be what is used. + if (C.sle(0) && Shl->hasNoUnsignedWrap() && Shl->hasNoSignedWrap()) + return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1)); + + // (icmp eq/ne (shl nuw|nsw X, Y), 0) + // -> (icmp eq/ne X, 0) + if (ICmpInst::isEquality(Pred) && C.isZero() && + (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap())) + return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1)); + + // (icmp slt (shl nsw X, Y), 0/1) + // -> (icmp slt X, 0/1) + // (icmp sgt (shl nsw X, Y), 0/-1) + // -> (icmp sgt X, 0/-1) + // + // NB: sge/sle with a constant will canonicalize to sgt/slt. + if (Shl->hasNoSignedWrap() && + (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) + if (C.isZero() || (Pred == ICmpInst::ICMP_SGT ? C.isAllOnes() : C.isOne())) + return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1)); + const APInt *ShiftAmt; if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) return foldICmpShlOne(Cmp, Shl, C); @@ -2080,7 +2267,6 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, if (ShiftAmt->uge(TypeBits)) return nullptr; - ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Shl->getOperand(0); Type *ShType = Shl->getType(); @@ -2107,11 +2293,6 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead; isSignTest may change 'Pred', so only - // do that if we're sure to not continue on in this function. - if (isSignTest(Pred, C)) - return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); } // NUW guarantees that we are only shifting out zero bits from the high bits, @@ -2189,7 +2370,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, // free on the target. It has the additional benefit of comparing to a // smaller constant that may be more target-friendly. unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); - if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt && + if (Shl->hasOneUse() && Amt != 0 && C.countr_zero() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); if (auto *ShVTy = dyn_cast<VectorType>(ShType)) @@ -2237,9 +2418,8 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, assert(ShiftValC->uge(C) && "Expected simplify of compare"); assert((IsUGT || !C.isZero()) && "Expected X u< 0 to simplify"); - unsigned CmpLZ = - IsUGT ? C.countLeadingZeros() : (C - 1).countLeadingZeros(); - unsigned ShiftLZ = ShiftValC->countLeadingZeros(); + unsigned CmpLZ = IsUGT ? C.countl_zero() : (C - 1).countl_zero(); + unsigned ShiftLZ = ShiftValC->countl_zero(); Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ); auto NewPred = IsUGT ? 
CmpInst::ICMP_ULT : CmpInst::ICMP_UGE; return new ICmpInst(NewPred, Shr->getOperand(1), NewC); @@ -3184,18 +3364,30 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( } break; } - case Instruction::And: { - const APInt *BOC; - if (match(BOp1, m_APInt(BOC))) { - // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (C == *BOC && C.isPowerOf2()) - return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, - BO, Constant::getNullValue(RHS->getType())); - } - break; - } case Instruction::UDiv: - if (C.isZero()) { + case Instruction::SDiv: + if (BO->isExact()) { + // div exact X, Y eq/ne 0 -> X eq/ne 0 + // div exact X, Y eq/ne 1 -> X eq/ne Y + // div exact X, Y eq/ne C -> + // if Y * C never-overflow && OneUse: + // -> Y * C eq/ne X + if (C.isZero()) + return new ICmpInst(Pred, BOp0, Constant::getNullValue(BO->getType())); + else if (C.isOne()) + return new ICmpInst(Pred, BOp0, BOp1); + else if (BO->hasOneUse()) { + OverflowResult OR = computeOverflow( + Instruction::Mul, BO->getOpcode() == Instruction::SDiv, BOp1, + Cmp.getOperand(1), BO); + if (OR == OverflowResult::NeverOverflows) { + Value *YC = + Builder.CreateMul(BOp1, ConstantInt::get(BO->getType(), C)); + return new ICmpInst(Pred, YC, BOp0); + } + } + } + if (BO->getOpcode() == Instruction::UDiv && C.isZero()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -3207,6 +3399,44 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( return nullptr; } +static Instruction *foldCtpopPow2Test(ICmpInst &I, IntrinsicInst *CtpopLhs, + const APInt &CRhs, + InstCombiner::BuilderTy &Builder, + const SimplifyQuery &Q) { + assert(CtpopLhs->getIntrinsicID() == Intrinsic::ctpop && + "Non-ctpop intrin in ctpop fold"); + if (!CtpopLhs->hasOneUse()) + return nullptr; + + // Power of 2 test: + // isPow2OrZero : ctpop(X) u< 2 + // isPow2 : ctpop(X) == 1 + // NotPow2OrZero: ctpop(X) u> 1 + // NotPow2 : ctpop(X) != 1 + // If we know any bit of X can be folded to: + // IsPow2 : X & (~Bit) == 0 + // NotPow2 : X & (~Bit) != 0 + const ICmpInst::Predicate Pred = I.getPredicate(); + if (((I.isEquality() || Pred == ICmpInst::ICMP_UGT) && CRhs == 1) || + (Pred == ICmpInst::ICMP_ULT && CRhs == 2)) { + Value *Op = CtpopLhs->getArgOperand(0); + KnownBits OpKnown = computeKnownBits(Op, Q.DL, + /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT); + // No need to check for count > 1, that should be already constant folded. + if (OpKnown.countMinPopulation() == 1) { + Value *And = Builder.CreateAnd( + Op, Constant::getIntegerValue(Op->getType(), ~(OpKnown.One))); + return new ICmpInst( + (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_ULT) + ? ICmpInst::ICMP_EQ + : ICmpInst::ICMP_NE, + And, Constant::getNullValue(Op->getType())); + } + } + + return nullptr; +} + /// Fold an equality icmp with LLVM intrinsic and constant operand. 
Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { @@ -3227,6 +3457,11 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C.byteSwap())); + case Intrinsic::bitreverse: + // bitreverse(A) == C -> A == bitreverse(C) + return new ICmpInst(Pred, II->getArgOperand(0), + ConstantInt::get(Ty, C.reverseBits())); + case Intrinsic::ctlz: case Intrinsic::cttz: { // ctz(A) == bitwidth(A) -> A == 0 and likewise for != @@ -3277,15 +3512,22 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( } break; + case Intrinsic::umax: case Intrinsic::uadd_sat: { // uadd.sat(a, b) == 0 -> (a | b) == 0 - if (C.isZero()) { + // umax(a, b) == 0 -> (a | b) == 0 + if (C.isZero() && II->hasOneUse()) { Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); return new ICmpInst(Pred, Or, Constant::getNullValue(Ty)); } break; } + case Intrinsic::ssub_sat: + // ssub.sat(a, b) == 0 -> a == b + if (C.isZero()) + return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1)); + break; case Intrinsic::usub_sat: { // usub.sat(a, b) == 0 -> a <= b if (C.isZero()) { @@ -3303,7 +3545,9 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( } /// Fold an icmp with LLVM intrinsics -static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { +static Instruction * +foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { assert(Cmp.isEquality()); ICmpInst::Predicate Pred = Cmp.getPredicate(); @@ -3321,16 +3565,32 @@ static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { // original values. return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); case Intrinsic::fshl: - case Intrinsic::fshr: + case Intrinsic::fshr: { // If both operands are rotated by same amount, just compare the // original values. if (IIOp0->getOperand(0) != IIOp0->getOperand(1)) break; if (IIOp1->getOperand(0) != IIOp1->getOperand(1)) break; - if (IIOp0->getOperand(2) != IIOp1->getOperand(2)) - break; - return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + if (IIOp0->getOperand(2) == IIOp1->getOperand(2)) + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + + // rotate(X, AmtX) == rotate(Y, AmtY) + // -> rotate(X, AmtX - AmtY) == Y + // Do this if either both rotates have one use or if only one has one use + // and AmtX/AmtY are constants. + unsigned OneUses = IIOp0->hasOneUse() + IIOp1->hasOneUse(); + if (OneUses == 2 || + (OneUses == 1 && match(IIOp0->getOperand(2), m_ImmConstant()) && + match(IIOp1->getOperand(2), m_ImmConstant()))) { + Value *SubAmt = + Builder.CreateSub(IIOp0->getOperand(2), IIOp1->getOperand(2)); + Value *CombinedRotate = Builder.CreateIntrinsic( + Op0->getType(), IIOp0->getIntrinsicID(), + {IIOp0->getOperand(0), IIOp0->getOperand(0), SubAmt}); + return new ICmpInst(Pred, IIOp1->getOperand(0), CombinedRotate); + } + } break; default: break; } @@ -3421,16 +3681,119 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, return foldICmpBinOpEqualityWithConstant(Cmp, BO, C); } +static Instruction * +foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred, + SaturatingInst *II, const APInt &C, + InstCombiner::BuilderTy &Builder) { + // This transform may end up producing more than one instruction for the + // intrinsic, so limit it to one user of the intrinsic. 
+ if (!II->hasOneUse()) + return nullptr; + + // Let Y = [add/sub]_sat(X, C) pred C2 + // SatVal = The saturating value for the operation + // WillWrap = Whether or not the operation will underflow / overflow + // => Y = (WillWrap ? SatVal : (X binop C)) pred C2 + // => Y = WillWrap ? (SatVal pred C2) : ((X binop C) pred C2) + // + // When (SatVal pred C2) is true, then + // Y = WillWrap ? true : ((X binop C) pred C2) + // => Y = WillWrap || ((X binop C) pred C2) + // else + // Y = WillWrap ? false : ((X binop C) pred C2) + // => Y = !WillWrap ? ((X binop C) pred C2) : false + // => Y = !WillWrap && ((X binop C) pred C2) + Value *Op0 = II->getOperand(0); + Value *Op1 = II->getOperand(1); + + const APInt *COp1; + // This transform only works when the intrinsic has an integral constant or + // splat vector as the second operand. + if (!match(Op1, m_APInt(COp1))) + return nullptr; + + APInt SatVal; + switch (II->getIntrinsicID()) { + default: + llvm_unreachable( + "This function only works with usub_sat and uadd_sat for now!"); + case Intrinsic::uadd_sat: + SatVal = APInt::getAllOnes(C.getBitWidth()); + break; + case Intrinsic::usub_sat: + SatVal = APInt::getZero(C.getBitWidth()); + break; + } + + // Check (SatVal pred C2) + bool SatValCheck = ICmpInst::compare(SatVal, C, Pred); + + // !WillWrap. + ConstantRange C1 = ConstantRange::makeExactNoWrapRegion( + II->getBinaryOp(), *COp1, II->getNoWrapKind()); + + // WillWrap. + if (SatValCheck) + C1 = C1.inverse(); + + ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C); + if (II->getBinaryOp() == Instruction::Add) + C2 = C2.sub(*COp1); + else + C2 = C2.add(*COp1); + + Instruction::BinaryOps CombiningOp = + SatValCheck ? Instruction::BinaryOps::Or : Instruction::BinaryOps::And; + + std::optional<ConstantRange> Combination; + if (CombiningOp == Instruction::BinaryOps::Or) + Combination = C1.exactUnionWith(C2); + else /* CombiningOp == Instruction::BinaryOps::And */ + Combination = C1.exactIntersectWith(C2); + + if (!Combination) + return nullptr; + + CmpInst::Predicate EquivPred; + APInt EquivInt; + APInt EquivOffset; + + Combination->getEquivalentICmp(EquivPred, EquivInt, EquivOffset); + + return new ICmpInst( + EquivPred, + Builder.CreateAdd(Op0, ConstantInt::get(Op1->getType(), EquivOffset)), + ConstantInt::get(Op1->getType(), EquivInt)); +} + /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + + // Handle folds that apply for any kind of icmp. 
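One concrete, hand-derived instance of the saturating-arithmetic fold dispatched just below (the constant and the resulting compare are an example of my own, not taken from the patch or its tests): `usub.sat(x, 5) u< 3` collapses to `x u< 8`, the same shape of single compare that foldICmpUSubSatOrUAddSatWithConstant derives from its ConstantRange computation.

// Exhaustive 8-bit check of the hand-derived equivalence above. The local
// usub_sat8 helper stands in for the llvm.usub.sat intrinsic at i8.
#include <cassert>
#include <cstdint>

static uint8_t usub_sat8(uint8_t A, uint8_t B) { return A > B ? A - B : 0; }

int main() {
  for (unsigned X = 0; X < 256; ++X) {
    bool Original = usub_sat8(static_cast<uint8_t>(X), 5) < 3; // usub.sat(x, 5) u< 3
    bool Folded = X < 8;                                       // x u< 8
    assert(Original == Folded);
  }
  return 0;
}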
+ switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::uadd_sat: + case Intrinsic::usub_sat: + if (auto *Folded = foldICmpUSubSatOrUAddSatWithConstant( + Pred, cast<SaturatingInst>(II), C, Builder)) + return Folded; + break; + case Intrinsic::ctpop: { + const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); + if (Instruction *R = foldCtpopPow2Test(Cmp, II, C, Builder, Q)) + return R; + } break; + } + if (Cmp.isEquality()) return foldICmpEqIntrinsicWithConstant(Cmp, II, C); Type *Ty = II->getType(); unsigned BitWidth = C.getBitWidth(); - ICmpInst::Predicate Pred = Cmp.getPredicate(); switch (II->getIntrinsicID()) { case Intrinsic::ctpop: { // (ctpop X > BitWidth - 1) --> X == -1 @@ -3484,6 +3847,21 @@ Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, } break; } + case Intrinsic::ssub_sat: + // ssub.sat(a, b) spred 0 -> a spred b + if (ICmpInst::isSigned(Pred)) { + if (C.isZero()) + return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1)); + // X s<= 0 is cannonicalized to X s< 1 + if (Pred == ICmpInst::ICMP_SLT && C.isOne()) + return new ICmpInst(ICmpInst::ICMP_SLE, II->getArgOperand(0), + II->getArgOperand(1)); + // X s>= 0 is cannonicalized to X s> -1 + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes()) + return new ICmpInst(ICmpInst::ICMP_SGE, II->getArgOperand(0), + II->getArgOperand(1)); + } + break; default: break; } @@ -4014,20 +4392,60 @@ Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) { return Res; } -static Instruction *foldICmpXNegX(ICmpInst &I) { +static Instruction *foldICmpXNegX(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { CmpInst::Predicate Pred; Value *X; - if (!match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) - return nullptr; + if (match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) { + + if (ICmpInst::isSigned(Pred)) + Pred = ICmpInst::getSwappedPredicate(Pred); + else if (ICmpInst::isUnsigned(Pred)) + Pred = ICmpInst::getSignedPredicate(Pred); + // else for equality-comparisons just keep the predicate. + + return ICmpInst::Create(Instruction::ICmp, Pred, X, + Constant::getNullValue(X->getType()), I.getName()); + } + + // A value is not equal to its negation unless that value is 0 or + // MinSignedValue, ie: a != -a --> (a & MaxSignedVal) != 0 + if (match(&I, m_c_ICmp(Pred, m_OneUse(m_Neg(m_Value(X))), m_Deferred(X))) && + ICmpInst::isEquality(Pred)) { + Type *Ty = X->getType(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + Constant *MaxSignedVal = + ConstantInt::get(Ty, APInt::getSignedMaxValue(BitWidth)); + Value *And = Builder.CreateAnd(X, MaxSignedVal); + Constant *Zero = Constant::getNullValue(Ty); + return CmpInst::Create(Instruction::ICmp, Pred, And, Zero); + } + + return nullptr; +} - if (ICmpInst::isSigned(Pred)) +static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; + // Normalize xor operand as operand 0. + CmpInst::Predicate Pred = I.getPredicate(); + if (match(Op1, m_c_Xor(m_Specific(Op0), m_Value()))) { + std::swap(Op0, Op1); Pred = ICmpInst::getSwappedPredicate(Pred); - else if (ICmpInst::isUnsigned(Pred)) - Pred = ICmpInst::getSignedPredicate(Pred); - // else for equality-comparisons just keep the predicate. 
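The `a != -a --> (a & MaxSignedVal) != 0` rewrite added a little further up relies on two's complement having exactly two fixed points of negation, 0 and the minimum signed value. A small exhaustive check at 8 bits, illustrative only:

// Only 0x00 and 0x80 equal their own negation modulo 2^8, so "a != -a" is the
// same as testing the low 7 bits: (a & 0x7F) != 0.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A) {
    bool NotEqualToNeg =
        static_cast<uint8_t>(A) != static_cast<uint8_t>(0u - A);
    assert(NotEqualToNeg == ((A & 0x7F) != 0));
  }
  return 0;
}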
+ } + if (!match(Op0, m_c_Xor(m_Specific(Op1), m_Value(A)))) + return nullptr; - return ICmpInst::Create(Instruction::ICmp, Pred, X, - Constant::getNullValue(X->getType()), I.getName()); + // icmp (X ^ Y_NonZero) u>= X --> icmp (X ^ Y_NonZero) u> X + // icmp (X ^ Y_NonZero) u<= X --> icmp (X ^ Y_NonZero) u< X + // icmp (X ^ Y_NonZero) s>= X --> icmp (X ^ Y_NonZero) s> X + // icmp (X ^ Y_NonZero) s<= X --> icmp (X ^ Y_NonZero) s< X + CmpInst::Predicate PredOut = CmpInst::getStrictPredicate(Pred); + if (PredOut != Pred && + isKnownNonZero(A, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(PredOut, Op0, Op1); + + return nullptr; } /// Try to fold icmp (binop), X or icmp X, (binop). @@ -4045,7 +4463,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (!BO0 && !BO1) return nullptr; - if (Instruction *NewICmp = foldICmpXNegX(I)) + if (Instruction *NewICmp = foldICmpXNegX(I, Builder)) return NewICmp; const CmpInst::Predicate Pred = I.getPredicate(); @@ -4326,17 +4744,41 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, ConstantExpr::getNeg(RHSC)); } + if (Instruction * R = foldICmpXorXX(I, Q, *this)) + return R; + { - // Try to remove shared constant multiplier from equality comparison: - // X * C == Y * C (with no overflowing/aliasing) --> X == Y - Value *X, *Y; - const APInt *C; - if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 && - match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality()) - if (!C->countTrailingZeros() || - (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || - (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) - return new ICmpInst(Pred, X, Y); + // Try to remove shared multiplier from comparison: + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z + Value *X, *Y, *Z; + if (Pred == ICmpInst::getUnsignedPredicate(Pred) && + ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) && + match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) || + (match(Op0, m_Mul(m_Value(Z), m_Value(X))) && + match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) { + bool NonZero; + if (ICmpInst::isEquality(Pred)) { + KnownBits ZKnown = computeKnownBits(Z, 0, &I); + // if Z % 2 != 0 + // X * Z eq/ne Y * Z -> X eq/ne Y + if (ZKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, X, Y); + NonZero = !ZKnown.One.isZero() || + isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + // if Z != 0 and nsw(X * Z) and nsw(Y * Z) + // X * Z eq/ne Y * Z -> X eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoSignedWrap() && + BO1->hasNoSignedWrap()) + return new ICmpInst(Pred, X, Y); + } else + NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + + // If Z != 0 and nuw(X * Z) and nuw(Y * Z) + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoUnsignedWrap() && + BO1->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, Y); + } } BinaryOperator *SRem = nullptr; @@ -4405,7 +4847,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, !C->isOne()) { // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask) // Mask = -1 >> count-trailing-zeros(C). 
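The mask in the comment above keeps exactly the low bits that survive multiplication by C. Spot-checking the identity for one sample constant of my choosing, C = 12 (two trailing zero bits, so Mask = 0xFF >> 2 at i8); this check is illustrative and not part of the patch:

// C = 12 = 4 * 3: the odd factor 3 is invertible mod 2^8, so x*12 == y*12
// (mod 2^8) iff the low 6 bits of x and y match.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C = 12;            // sample multiplier, countr_zero(C) == 2
  const unsigned Mask = 0xFFu >> 2; // -1 >> count-trailing-zeros(C) at i8
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y) {
      bool MulEq = static_cast<uint8_t>(X * C) == static_cast<uint8_t>(Y * C);
      bool MaskEq = (X & Mask) == (Y & Mask);
      assert(MulEq == MaskEq);
    }
  return 0;
}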
- if (unsigned TZs = C->countTrailingZeros()) { + if (unsigned TZs = C->countr_zero()) { Constant *Mask = ConstantInt::get( BO0->getType(), APInt::getLowBitsSet(C->getBitWidth(), C->getBitWidth() - TZs)); @@ -4569,6 +5011,59 @@ static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { return nullptr; } +// Canonicalize checking for a power-of-2-or-zero value: +static Instruction *foldICmpPow2Test(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + const CmpInst::Predicate Pred = I.getPredicate(); + Value *A = nullptr; + bool CheckIs; + if (I.isEquality()) { + // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) + // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants) + if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()), + m_Deferred(A)))) || + !match(Op1, m_ZeroInt())) + A = nullptr; + + // (A & -A) == A --> ctpop(A) < 2 (four commuted variants) + // (-A & A) != A --> ctpop(A) > 1 (four commuted variants) + if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1))))) + A = Op1; + else if (match(Op1, + m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0))))) + A = Op0; + + CheckIs = Pred == ICmpInst::ICMP_EQ; + } else if (ICmpInst::isUnsigned(Pred)) { + // (A ^ (A-1)) u>= A --> ctpop(A) < 2 (two commuted variants) + // ((A-1) ^ A) u< A --> ctpop(A) > 1 (two commuted variants) + + if ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) && + match(Op0, m_OneUse(m_c_Xor(m_Add(m_Specific(Op1), m_AllOnes()), + m_Specific(Op1))))) { + A = Op1; + CheckIs = Pred == ICmpInst::ICMP_UGE; + } else if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) && + match(Op1, m_OneUse(m_c_Xor(m_Add(m_Specific(Op0), m_AllOnes()), + m_Specific(Op0))))) { + A = Op0; + CheckIs = Pred == ICmpInst::ICMP_ULE; + } + } + + if (A) { + Type *Ty = A->getType(); + CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A); + return CheckIs ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, + ConstantInt::get(Ty, 2)) + : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, + ConstantInt::get(Ty, 1)); + } + + return nullptr; +} + Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { if (!I.isEquality()) return nullptr; @@ -4604,6 +5099,21 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } + // canoncalize: + // (icmp eq/ne (and X, C), X) + // -> (icmp eq/ne (and X, ~C), 0) + { + Constant *CMask; + A = nullptr; + if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_ImmConstant(CMask))))) + A = Op1; + else if (match(Op1, m_OneUse(m_And(m_Specific(Op0), m_ImmConstant(CMask))))) + A = Op0; + if (A) + return new ICmpInst(Pred, Builder.CreateAnd(A, Builder.CreateNot(CMask)), + Constant::getNullValue(A->getType())); + } + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { // A == (A^B) -> B == 0 Value *OtherVal = A == Op0 ? 
B : A; @@ -4659,22 +5169,36 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { // (B & (Pow2C-1)) != zext A --> A != trunc B const APInt *MaskC; if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) && - MaskC->countTrailingOnes() == A->getType()->getScalarSizeInBits()) + MaskC->countr_one() == A->getType()->getScalarSizeInBits()) return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); + } - // Test if 2 values have different or same signbits: - // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 - // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 + // Test if 2 values have different or same signbits: + // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 + // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 + // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0 + // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1 + Instruction *ExtI; + if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); + Instruction *ShiftI; Value *X, *Y; ICmpInst::Predicate Pred2; - if (match(Op0, m_LShr(m_Value(X), m_SpecificIntAllowUndef(OpWidth - 1))) && + if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), + m_Shr(m_Value(X), + m_SpecificIntAllowUndef(OpWidth - 1)))) && match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) && Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) { - Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); - Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) : - Builder.CreateIsNotNeg(Xor); - return replaceInstUsesWith(I, R); + unsigned ExtOpc = ExtI->getOpcode(); + unsigned ShiftOpc = ShiftI->getOpcode(); + if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) || + (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) { + Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); + Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) + : Builder.CreateIsNotNeg(Xor); + return replaceInstUsesWith(I, R); + } } } @@ -4737,33 +5261,9 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } - if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I)) + if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I, Builder)) return ICmp; - // Canonicalize checking for a power-of-2-or-zero value: - // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) - // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants) - if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()), - m_Deferred(A)))) || - !match(Op1, m_ZeroInt())) - A = nullptr; - - // (A & -A) == A --> ctpop(A) < 2 (four commuted variants) - // (-A & A) != A --> ctpop(A) > 1 (four commuted variants) - if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1))))) - A = Op1; - else if (match(Op1, - m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0))))) - A = Op0; - - if (A) { - Type *Ty = A->getType(); - CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A); - return Pred == ICmpInst::ICMP_EQ - ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2)) - : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1)); - } - // Match icmp eq (trunc (lshr A, BW), (ashr (trunc A), BW-1)), which checks the // top BW/2 + 1 bits are all the same. 
Create "A >=s INT_MIN && A <=s INT_MAX", // which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps @@ -4794,11 +5294,23 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { return new ICmpInst(CmpInst::getInversePredicate(Pred), Op1, ConstantInt::getNullValue(Op1->getType())); + // Canonicalize: + // icmp eq/ne X, OneUse(rotate-right(X)) + // -> icmp eq/ne X, rotate-left(X) + // We generally try to convert rotate-right -> rotate-left, this just + // canonicalizes another case. + CmpInst::Predicate PredUnused = Pred; + if (match(&I, m_c_ICmp(PredUnused, m_Value(A), + m_OneUse(m_Intrinsic<Intrinsic::fshr>( + m_Deferred(A), m_Deferred(A), m_Value(B)))))) + return new ICmpInst( + Pred, A, + Builder.CreateIntrinsic(Op0->getType(), Intrinsic::fshl, {A, A, B})); + return nullptr; } -static Instruction *foldICmpWithTrunc(ICmpInst &ICmp, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) { ICmpInst::Predicate Pred = ICmp.getPredicate(); Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1); @@ -4836,6 +5348,25 @@ static Instruction *foldICmpWithTrunc(ICmpInst &ICmp, return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC); } + if (auto *II = dyn_cast<IntrinsicInst>(X)) { + if (II->getIntrinsicID() == Intrinsic::cttz || + II->getIntrinsicID() == Intrinsic::ctlz) { + unsigned MaxRet = SrcBits; + // If the "is_zero_poison" argument is set, then we know at least + // one bit is set in the input, so the result is always at least one + // less than the full bitwidth of that input. + if (match(II->getArgOperand(1), m_One())) + MaxRet--; + + // Make sure the destination is wide enough to hold the largest output of + // the intrinsic. + if (llvm::Log2_32(MaxRet) + 1 <= Op0->getType()->getScalarSizeInBits()) + if (Instruction *I = + foldICmpIntrinsicWithConstant(ICmp, II, C->zext(SrcBits))) + return I; + } + } + return nullptr; } @@ -4855,10 +5386,19 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0)); bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1)); - // If we have mismatched casts, treat the zext of a non-negative source as - // a sext to simulate matching casts. Otherwise, we are done. - // TODO: Can we handle some predicates (equality) without non-negative? if (IsZext0 != IsZext1) { + // If X and Y and both i1 + // (icmp eq/ne (zext X) (sext Y)) + // eq -> (icmp eq (or X, Y), 0) + // ne -> (icmp ne (or X, Y), 0) + if (ICmp.isEquality() && X->getType()->isIntOrIntVectorTy(1) && + Y->getType()->isIntOrIntVectorTy(1)) + return new ICmpInst(ICmp.getPredicate(), Builder.CreateOr(X, Y), + Constant::getNullValue(X->getType())); + + // If we have mismatched casts, treat the zext of a non-negative source as + // a sext to simulate matching casts. Otherwise, we are done. + // TODO: Can we handle some predicates (equality) without non-negative? 
if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) || (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT))) IsSignedExt = true; @@ -4993,7 +5533,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); } - if (Instruction *R = foldICmpWithTrunc(ICmp, Builder)) + if (Instruction *R = foldICmpWithTrunc(ICmp)) return R; return foldICmpWithZextOrSext(ICmp); @@ -5153,7 +5693,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, return nullptr; if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) { const APInt &CVal = CI->getValue(); - if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth) + if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth) return nullptr; } else { // In this case we could have the operand of the binary operation @@ -5334,44 +5874,18 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { // bits doesn't impact the outcome of the comparison, because any value // greater than the RHS must differ in a bit higher than these due to carry. case ICmpInst::ICMP_UGT: - return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes()); + return APInt::getBitsSetFrom(BitWidth, RHS->countr_one()); // Similarly, for a ULT comparison, we don't care about the trailing zeros. // Any value less than the RHS must differ in a higher bit because of carries. case ICmpInst::ICMP_ULT: - return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros()); + return APInt::getBitsSetFrom(BitWidth, RHS->countr_zero()); default: return APInt::getAllOnes(BitWidth); } } -/// Check if the order of \p Op0 and \p Op1 as operands in an ICmpInst -/// should be swapped. -/// The decision is based on how many times these two operands are reused -/// as subtract operands and their positions in those instructions. -/// The rationale is that several architectures use the same instruction for -/// both subtract and cmp. Thus, it is better if the order of those operands -/// match. -/// \return true if Op0 and Op1 should be swapped. -static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) { - // Filter out pointer values as those cannot appear directly in subtract. - // FIXME: we may want to go through inttoptrs or bitcasts. - if (Op0->getType()->isPointerTy()) - return false; - // If a subtract already has the same operands as a compare, swapping would be - // bad. If a subtract has the same operands as a compare but in reverse order, - // then swapping is good. - int GoodToSwap = 0; - for (const User *U : Op0->users()) { - if (match(U, m_Sub(m_Specific(Op1), m_Specific(Op0)))) - GoodToSwap++; - else if (match(U, m_Sub(m_Specific(Op0), m_Specific(Op1)))) - GoodToSwap--; - } - return GoodToSwap > 0; -} - /// Check that one use is in the same block as the definition and all /// other uses are in blocks dominated by a given block. 
/// @@ -5638,14 +6152,14 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { const APInt *C1; if (match(LHS, m_Shl(m_Power2(C1), m_Value(X)))) { Type *XTy = X->getType(); - unsigned Log2C1 = C1->countTrailingZeros(); + unsigned Log2C1 = C1->countr_zero(); APInt C2 = Op0KnownZeroInverted; APInt C2Pow2 = (C2 & ~(*C1 - 1)) + *C1; if (C2Pow2.isPowerOf2()) { // iff (C1 is pow2) & ((C2 & ~(C1-1)) + C1) is pow2): // ((C1 << X) & C2) == 0 -> X >= (Log2(C2+C1) - Log2(C1)) // ((C1 << X) & C2) != 0 -> X < (Log2(C2+C1) - Log2(C1)) - unsigned Log2C2 = C2Pow2.countTrailingZeros(); + unsigned Log2C2 = C2Pow2.countr_zero(); auto *CmpC = ConstantInt::get(XTy, Log2C2 - Log2C1); auto NewPred = Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT; @@ -5653,6 +6167,12 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { } } } + + // Op0 eq C_Pow2 -> Op0 ne 0 if Op0 is known to be C_Pow2 or zero. + if (Op1Known.isConstant() && Op1Known.getConstant().isPowerOf2() && + (Op0Known & Op1Known) == Op0Known) + return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0, + ConstantInt::getNullValue(Op1->getType())); break; } case ICmpInst::ICMP_ULT: { @@ -5733,8 +6253,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { /// If one operand of an icmp is effectively a bool (value range of {0,1}), /// then try to reduce patterns based on that limit. -static Instruction *foldICmpUsingBoolRange(ICmpInst &I, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { Value *X, *Y; ICmpInst::Predicate Pred; @@ -5750,6 +6269,60 @@ static Instruction *foldICmpUsingBoolRange(ICmpInst &I, Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE) return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y); + const APInt *C; + if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) && + match(I.getOperand(1), m_APInt(C)) && + X->getType()->isIntOrIntVectorTy(1) && + Y->getType()->isIntOrIntVectorTy(1)) { + unsigned BitWidth = C->getBitWidth(); + Pred = I.getPredicate(); + APInt Zero = APInt::getZero(BitWidth); + APInt MinusOne = APInt::getAllOnes(BitWidth); + APInt One(BitWidth, 1); + if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) || + (C->slt(Zero) && Pred == ICmpInst::ICMP_SLT)) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) || + (C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT)) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + if (I.getOperand(0)->hasOneUse()) { + APInt NewC = *C; + // canonicalize predicate to eq/ne + if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) || + (*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) { + // x s< 0 in [-1, 1] --> x == -1 + // x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1 + NewC = MinusOne; + Pred = ICmpInst::ICMP_EQ; + } else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) || + (*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) { + // x s> -1 in [-1, 1] --> x != -1 + // x u< -1 in [-1, 1] --> x != -1 + Pred = ICmpInst::ICMP_NE; + } else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) { + // x s> 0 in [-1, 1] --> x == 1 + NewC = One; + Pred = ICmpInst::ICMP_EQ; + } else if (*C == One && Pred == ICmpInst::ICMP_SLT) { + // x s< 1 in [-1, 1] --> x != 1 + Pred = ICmpInst::ICMP_NE; + } + + if (NewC == MinusOne) { + if (Pred == ICmpInst::ICMP_EQ) + return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y); + if (Pred == 
ICmpInst::ICMP_NE) + return BinaryOperator::CreateOr(X, Builder.CreateNot(Y)); + } else if (NewC == One) { + if (Pred == ICmpInst::ICMP_EQ) + return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y)); + if (Pred == ICmpInst::ICMP_NE) + return BinaryOperator::CreateOr(Builder.CreateNot(X), Y); + } + } + } + return nullptr; } @@ -6162,8 +6735,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { /// Orders the operands of the compare so that they are listed from most /// complex to least complex. This puts constants before unary operators, /// before binary operators. - if (Op0Cplxity < Op1Cplxity || - (Op0Cplxity == Op1Cplxity && swapMayExposeCSEOpportunities(Op0, Op1))) { + if (Op0Cplxity < Op1Cplxity) { I.swapOperands(); std::swap(Op0, Op1); Changed = true; @@ -6205,7 +6777,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; - if (Instruction *Res = foldICmpUsingBoolRange(I, Builder)) + if (Instruction *Res = foldICmpUsingBoolRange(I)) return Res; if (Instruction *Res = foldICmpUsingKnownBits(I)) @@ -6288,15 +6860,46 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I)) return NI; + // In case of a comparison with two select instructions having the same + // condition, check whether one of the resulting branches can be simplified. + // If so, just compare the other branch and select the appropriate result. + // For example: + // %tmp1 = select i1 %cmp, i32 %y, i32 %x + // %tmp2 = select i1 %cmp, i32 %z, i32 %x + // %cmp2 = icmp slt i32 %tmp2, %tmp1 + // The icmp will result false for the false value of selects and the result + // will depend upon the comparison of true values of selects if %cmp is + // true. Thus, transform this into: + // %cmp = icmp slt i32 %y, %z + // %sel = select i1 %cond, i1 %cmp, i1 false + // This handles similar cases to transform. + { + Value *Cond, *A, *B, *C, *D; + if (match(Op0, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) && + match(Op1, m_Select(m_Specific(Cond), m_Value(C), m_Value(D))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + // Check whether comparison of TrueValues can be simplified + if (Value *Res = simplifyICmpInst(Pred, A, C, SQ)) { + Value *NewICMP = Builder.CreateICmp(Pred, B, D); + return SelectInst::Create(Cond, Res, NewICMP); + } + // Check whether comparison of FalseValues can be simplified + if (Value *Res = simplifyICmpInst(Pred, B, D, SQ)) { + Value *NewICMP = Builder.CreateICmp(Pred, A, C); + return SelectInst::Create(Cond, NewICMP, Res); + } + } + } + // Try to optimize equality comparisons against alloca-based pointers. 
if (Op0->getType()->isPointerTy() && I.isEquality()) { assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0))) - if (Instruction *New = foldAllocaCmp(I, Alloca)) - return New; + if (foldAllocaCmp(Alloca)) + return nullptr; if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1))) - if (Instruction *New = foldAllocaCmp(I, Alloca)) - return New; + if (foldAllocaCmp(Alloca)) + return nullptr; } if (Instruction *Res = foldICmpBitCast(I)) @@ -6363,6 +6966,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpEquality(I)) return Res; + if (Instruction *Res = foldICmpPow2Test(I, Builder)) + return Res; + if (Instruction *Res = foldICmpOfUAddOv(I)) return Res; @@ -6717,7 +7323,7 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { Mode.Input == DenormalMode::PositiveZero) { auto replaceFCmp = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) { - Constant *Zero = ConstantFP::getNullValue(X->getType()); + Constant *Zero = ConstantFP::getZero(X->getType()); return new FCmpInst(P, X, Zero, "", I); }; @@ -6813,7 +7419,7 @@ static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) { // Replace the negated operand with 0.0: // fcmp Pred Op0, -Op0 --> fcmp Pred Op0, 0.0 - Constant *Zero = ConstantFP::getNullValue(Op0->getType()); + Constant *Zero = ConstantFP::getZero(Op0->getType()); return new FCmpInst(Pred, Op0, Zero, "", &I); } @@ -6863,11 +7469,13 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, // then canonicalize the operand to 0.0. if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { - if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) - return replaceOperand(I, 0, ConstantFP::getNullValue(OpType)); + if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, DL, &TLI, 0, + &AC, &I, &DT)) + return replaceOperand(I, 0, ConstantFP::getZero(OpType)); - if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) - return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); + if (!match(Op1, m_PosZeroFP()) && + isKnownNeverNaN(Op1, DL, &TLI, 0, &AC, &I, &DT)) + return replaceOperand(I, 1, ConstantFP::getZero(OpType)); } // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y @@ -6896,7 +7504,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0: // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0 if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) - return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); + return replaceOperand(I, 1, ConstantFP::getZero(OpType)); // Ignore signbit of bitcasted int when comparing equality to FP 0.0: // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0 @@ -6985,11 +7593,11 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { case FCmpInst::FCMP_ONE: // X is ordered and not equal to an impossible constant --> ordered return new FCmpInst(FCmpInst::FCMP_ORD, X, - ConstantFP::getNullValue(X->getType())); + ConstantFP::getZero(X->getType())); case FCmpInst::FCMP_UEQ: // X is unordered or equal to an impossible constant --> unordered return new FCmpInst(FCmpInst::FCMP_UNO, X, - ConstantFP::getNullValue(X->getType())); + ConstantFP::getZero(X->getType())); case FCmpInst::FCMP_UNE: // X is unordered or not equal to an impossible constant --> true return replaceInstUsesWith(I, 
ConstantInt::getTrue(I.getType())); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index f4e88b122383..701579e1de48 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -150,7 +150,6 @@ public: Instruction *visitPHINode(PHINode &PN); Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP); Instruction *visitGEPOfGEP(GetElementPtrInst &GEP, GEPOperator *Src); - Instruction *visitGEPOfBitcast(BitCastInst *BCI, GetElementPtrInst &GEP); Instruction *visitAllocaInst(AllocaInst &AI); Instruction *visitAllocSite(Instruction &FI); Instruction *visitFree(CallInst &FI, Value *FreedOp); @@ -330,8 +329,7 @@ private: Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); Instruction *matchSAddSubSat(IntrinsicInst &MinMax1); Instruction *foldNot(BinaryOperator &I); - - void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); + Instruction *foldBinOpOfDisplacedShifts(BinaryOperator &I); /// Determine if a pair of casts can be replaced by a single cast. /// @@ -378,6 +376,7 @@ private: Instruction *foldLShrOverflowBit(BinaryOperator &I); Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV); Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II); + Instruction *foldIntrinsicIsFPClass(IntrinsicInst &II); Instruction *foldFPSignBitOps(BinaryOperator &I); Instruction *foldFDivConstantDivisor(BinaryOperator &I); @@ -393,12 +392,12 @@ public: /// without having to rewrite the CFG from within InstCombine. void CreateNonTerminatorUnreachable(Instruction *InsertAt) { auto &Ctx = InsertAt->getContext(); - new StoreInst(ConstantInt::getTrue(Ctx), - PoisonValue::get(Type::getInt1PtrTy(Ctx)), - InsertAt); + auto *SI = new StoreInst(ConstantInt::getTrue(Ctx), + PoisonValue::get(Type::getInt1PtrTy(Ctx)), + /*isVolatile*/ false, Align(1)); + InsertNewInstBefore(SI, *InsertAt); } - /// Combiner aware instruction erasure. /// /// When dealing with an instruction that has side effects or produces a void @@ -411,12 +410,11 @@ public: // Make sure that we reprocess all operands now that we reduced their // use counts. - for (Use &Operand : I.operands()) - if (auto *Inst = dyn_cast<Instruction>(Operand)) - Worklist.add(Inst); - + SmallVector<Value *> Ops(I.operands()); Worklist.remove(&I); I.eraseFromParent(); + for (Value *Op : Ops) + Worklist.handleUseCountDecrement(Op); MadeIRChange = true; return nullptr; // Don't do anything with FI } @@ -450,6 +448,18 @@ public: Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS, Value *RHS); + // (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C)) + // -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C) + // (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt)) + // -> (BinOp (logic_shift (BinOp X, Y)), Mask) + Instruction *foldBinOpShiftWithShift(BinaryOperator &I); + + /// Tries to simplify binops of select and cast of the select condition. + /// + /// (Binop (cast C), (select C, T, F)) + /// -> (select C, C0, C1) + Instruction *foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I); + /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). 
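The factorization described above is value-preserving for integer types without any overflow precondition, because distributivity holds in wrap-around (modular) arithmetic. A quick plain-C++ illustration, with values chosen so the intermediate products do wrap:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t A = 0x90000001u, B = 0x7fffffffu, C = 0x40000003u;
  // (A*B)+(A*C) and A*(B+C) agree modulo 2^32 even though both overflow.
  assert(A * B + A * C == A * (B + C));
  return 0;
}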
Value *tryFactorizationFolds(BinaryOperator &I); @@ -549,7 +559,7 @@ public: ICmpInst::Predicate Cond, Instruction &I); Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI, Value *RHS, const ICmpInst &I); - Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca); + bool foldAllocaCmp(AllocaInst *Alloca); Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, @@ -564,6 +574,7 @@ public: Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); Instruction *foldICmpWithConstant(ICmpInst &Cmp); + Instruction *foldICmpUsingBoolRange(ICmpInst &I); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, @@ -623,6 +634,7 @@ public: Instruction *foldICmpEqIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, const APInt &C); Instruction *foldICmpBitCast(ICmpInst &Cmp); + Instruction *foldICmpWithTrunc(ICmpInst &Cmp); // Helpers of visitSelectInst(). Instruction *foldSelectOfBools(SelectInst &SI); @@ -634,10 +646,11 @@ public: SelectPatternFlavor SPF2, Value *C); Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI); + bool replaceInInstruction(Value *V, Value *Old, Value *New, + unsigned Depth = 0); Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside); - Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); bool mergeStoreIntoSuccessor(StoreInst &SI); /// Given an initial instruction, check to see if it is the root of a @@ -651,10 +664,12 @@ public: Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned); - /// Returns a value X such that Val = X * Scale, or null if none. - /// - /// If the multiplication is known not to overflow then NoSignedWrap is set. - Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); + bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock); + + bool removeInstructionsBeforeUnreachable(Instruction &I); + bool handleUnreachableFrom(Instruction *I); + bool handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc); + void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); }; class Negator final { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 41bc65620ff6..6aa20ee26b9a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -32,7 +32,7 @@ STATISTIC(NumDeadStore, "Number of dead stores eliminated"); STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global"); static cl::opt<unsigned> MaxCopiedFromConstantUsers( - "instcombine-max-copied-from-constant-users", cl::init(128), + "instcombine-max-copied-from-constant-users", cl::init(300), cl::desc("Maximum users to visit in copy from constant transform"), cl::Hidden); @@ -219,7 +219,7 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, // Now that I is pointing to the first non-allocation-inst in the block, // insert our getelementptr instruction... 
// - Type *IdxTy = IC.getDataLayout().getIntPtrType(AI.getType()); + Type *IdxTy = IC.getDataLayout().getIndexType(AI.getType()); Value *NullIdx = Constant::getNullValue(IdxTy); Value *Idx[2] = {NullIdx, NullIdx}; Instruction *GEP = GetElementPtrInst::CreateInBounds( @@ -235,11 +235,12 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, if (isa<UndefValue>(AI.getArraySize())) return IC.replaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); - // Ensure that the alloca array size argument has type intptr_t, so that - // any casting is exposed early. - Type *IntPtrTy = IC.getDataLayout().getIntPtrType(AI.getType()); - if (AI.getArraySize()->getType() != IntPtrTy) { - Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), IntPtrTy, false); + // Ensure that the alloca array size argument has type equal to the offset + // size of the alloca() pointer, which, in the tyical case, is intptr_t, + // so that any casting is exposed early. + Type *PtrIdxTy = IC.getDataLayout().getIndexType(AI.getType()); + if (AI.getArraySize()->getType() != PtrIdxTy) { + Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), PtrIdxTy, false); return IC.replaceOperand(AI, 0, V); } @@ -259,8 +260,8 @@ namespace { // instruction. class PointerReplacer { public: - PointerReplacer(InstCombinerImpl &IC, Instruction &Root) - : IC(IC), Root(Root) {} + PointerReplacer(InstCombinerImpl &IC, Instruction &Root, unsigned SrcAS) + : IC(IC), Root(Root), FromAS(SrcAS) {} bool collectUsers(); void replacePointer(Value *V); @@ -273,11 +274,21 @@ private: return I == &Root || Worklist.contains(I); } + bool isEqualOrValidAddrSpaceCast(const Instruction *I, + unsigned FromAS) const { + const auto *ASC = dyn_cast<AddrSpaceCastInst>(I); + if (!ASC) + return false; + unsigned ToAS = ASC->getDestAddressSpace(); + return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS); + } + SmallPtrSet<Instruction *, 32> ValuesToRevisit; SmallSetVector<Instruction *, 4> Worklist; MapVector<Value *, Value *> WorkMap; InstCombinerImpl &IC; Instruction &Root; + unsigned FromAS; }; } // end anonymous namespace @@ -341,6 +352,8 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) { if (MI->isVolatile()) return false; Worklist.insert(Inst); + } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) { + Worklist.insert(Inst); } else if (Inst->isLifetimeStartOrEnd()) { continue; } else { @@ -391,9 +404,8 @@ void PointerReplacer::replace(Instruction *I) { } else if (auto *BC = dyn_cast<BitCastInst>(I)) { auto *V = getReplacement(BC->getOperand(0)); assert(V && "Operand not replaced"); - auto *NewT = PointerType::getWithSamePointeeType( - cast<PointerType>(BC->getType()), - V->getType()->getPointerAddressSpace()); + auto *NewT = PointerType::get(BC->getType()->getContext(), + V->getType()->getPointerAddressSpace()); auto *NewI = new BitCastInst(V, NewT); IC.InsertNewInstWith(NewI, *BC); NewI->takeName(BC); @@ -426,6 +438,22 @@ void PointerReplacer::replace(Instruction *I) { IC.eraseInstFromFunction(*MemCpy); WorkMap[MemCpy] = NewI; + } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) { + auto *V = getReplacement(ASC->getPointerOperand()); + assert(V && "Operand not replaced"); + assert(isEqualOrValidAddrSpaceCast( + ASC, V->getType()->getPointerAddressSpace()) && + "Invalid address space cast!"); + auto *NewV = V; + if (V->getType()->getPointerAddressSpace() != + ASC->getType()->getPointerAddressSpace()) { + auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), ""); + NewI->takeName(ASC); + IC.InsertNewInstWith(NewI, *ASC); 
+ NewV = NewI; + } + IC.replaceInstUsesWith(*ASC, NewV); + IC.eraseInstFromFunction(*ASC); } else { llvm_unreachable("should never reach here"); } @@ -435,7 +463,7 @@ void PointerReplacer::replacePointer(Value *V) { #ifndef NDEBUG auto *PT = cast<PointerType>(Root.getType()); auto *NT = cast<PointerType>(V->getType()); - assert(PT != NT && PT->hasSameElementTypeAs(NT) && "Invalid usage"); + assert(PT != NT && "Invalid usage"); #endif WorkMap[&Root] = V; @@ -518,7 +546,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { return NewI; } - PointerReplacer PtrReplacer(*this, AI); + PointerReplacer PtrReplacer(*this, AI, SrcAddrSpace); if (PtrReplacer.collectUsers()) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); @@ -739,6 +767,11 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { // the knowledge that padding exists for the rest of the pipeline. const DataLayout &DL = IC.getDataLayout(); auto *SL = DL.getStructLayout(ST); + + // Don't unpack for structure with scalable vector. + if (SL->getSizeInBits().isScalable()) + return nullptr; + if (SL->hasPadding()) return nullptr; @@ -979,17 +1012,15 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, // If we're indexing into an object with a variable index for the memory // access, but the object has only one element, we can assume that the index // will always be zero. If we replace the GEP, return it. -template <typename T> static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr, - T &MemI) { + Instruction &MemI) { if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) { unsigned Idx; if (canReplaceGEPIdxWithZero(IC, GEPI, &MemI, Idx)) { Instruction *NewGEPI = GEPI->clone(); NewGEPI->setOperand(Idx, ConstantInt::get(GEPI->getOperand(Idx)->getType(), 0)); - NewGEPI->insertBefore(GEPI); - MemI.setOperand(MemI.getPointerOperandIndex(), NewGEPI); + IC.InsertNewInstBefore(NewGEPI, *GEPI); return NewGEPI; } } @@ -1024,6 +1055,8 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { Value *Op = LI.getOperand(0); + if (Value *Res = simplifyLoadInst(&LI, Op, SQ.getWithInstruction(&LI))) + return replaceInstUsesWith(LI, Res); // Try to canonicalize the loaded type. if (Instruction *Res = combineLoadToOperationType(*this, LI)) @@ -1036,10 +1069,8 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { LI.setAlignment(KnownAlign); // Replace GEP indices if possible. - if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) { - Worklist.push(NewGEPI); - return &LI; - } + if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) + return replaceOperand(LI, 0, NewGEPI); if (Instruction *Res = unpackLoadToAggregate(*this, LI)) return Res; @@ -1065,13 +1096,7 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { // load null/undef -> unreachable // TODO: Consider a target hook for valid address spaces for this xforms. if (canSimplifyNullLoadOrGEP(LI, Op)) { - // Insert a new store to null instruction before the load to indicate - // that this code is not reachable. We do this instead of inserting - // an unreachable instruction directly because we cannot modify the - // CFG. 
- StoreInst *SI = new StoreInst(PoisonValue::get(LI.getType()), - Constant::getNullValue(Op->getType()), &LI); - SI->setDebugLoc(LI.getDebugLoc()); + CreateNonTerminatorUnreachable(&LI); return replaceInstUsesWith(LI, PoisonValue::get(LI.getType())); } @@ -1261,6 +1286,11 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { // the knowledge that padding exists for the rest of the pipeline. const DataLayout &DL = IC.getDataLayout(); auto *SL = DL.getStructLayout(ST); + + // Don't unpack for structure with scalable vector. + if (SL->getSizeInBits().isScalable()) + return false; + if (SL->hasPadding()) return false; @@ -1443,10 +1473,8 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { return eraseInstFromFunction(SI); // Replace GEP indices if possible. - if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { - Worklist.push(NewGEPI); - return &SI; - } + if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) + return replaceOperand(SI, 1, NewGEPI); // Don't hack volatile/ordered stores. // FIXME: Some bits are legal for ordered atomic stores; needs refactoring. @@ -1530,6 +1558,16 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { return nullptr; // Do not modify these! } + // This is a non-terminator unreachable marker. Don't remove it. + if (isa<UndefValue>(Ptr)) { + // Remove all instructions after the marker and guaranteed-to-transfer + // instructions before the marker. + if (handleUnreachableFrom(SI.getNextNode()) || + removeInstructionsBeforeUnreachable(SI)) + return &SI; + return nullptr; + } + // store undef, Ptr -> noop // FIXME: This is technically incorrect because it might overwrite a poison // value. Change to PoisonValue once #52930 is resolved. @@ -1571,6 +1609,17 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { if (!OtherBr || BBI == OtherBB->begin()) return false; + auto OtherStoreIsMergeable = [&](StoreInst *OtherStore) -> bool { + if (!OtherStore || + OtherStore->getPointerOperand() != SI.getPointerOperand()) + return false; + + auto *SIVTy = SI.getValueOperand()->getType(); + auto *OSVTy = OtherStore->getValueOperand()->getType(); + return CastInst::isBitOrNoopPointerCastable(OSVTy, SIVTy, DL) && + SI.hasSameSpecialState(OtherStore); + }; + // If the other block ends in an unconditional branch, check for the 'if then // else' case. There is an instruction before the branch. StoreInst *OtherStore = nullptr; @@ -1586,8 +1635,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { // If this isn't a store, isn't a store to the same location, or is not the // right kind of store, bail out. OtherStore = dyn_cast<StoreInst>(BBI); - if (!OtherStore || OtherStore->getOperand(1) != SI.getOperand(1) || - !SI.isSameOperationAs(OtherStore)) + if (!OtherStoreIsMergeable(OtherStore)) return false; } else { // Otherwise, the other block ended with a conditional branch. If one of the @@ -1601,12 +1649,10 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { // lives in OtherBB. for (;; --BBI) { // Check to see if we find the matching store. - if ((OtherStore = dyn_cast<StoreInst>(BBI))) { - if (OtherStore->getOperand(1) != SI.getOperand(1) || - !SI.isSameOperationAs(OtherStore)) - return false; + OtherStore = dyn_cast<StoreInst>(BBI); + if (OtherStoreIsMergeable(OtherStore)) break; - } + // If we find something that may be using or overwriting the stored // value, or if we run out of instructions, we can't do the transform. 
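The hunk above loosens the classic if/then/else store merge: the two stores no longer need identical value types, only values that are bit- or no-op-pointer-castable to each other with otherwise matching store attributes. A rough C++-level analogue of the overall rewrite (storeBefore/storeAfter are illustrative names; the real transform runs on IR stores and inserts a PHI in the successor block):

#include <cassert>

static void storeBefore(bool C, int A, int B, int *P) {
  if (C)
    *P = A;          // store in the "then" block
  else
    *P = B;          // store in the "else" block
}

static void storeAfter(bool C, int A, int B, int *P) {
  *P = C ? A : B;    // single merged store fed by a PHI of the two values
}

int main() {
  int X = 0, Y = 0;
  storeBefore(true, 1, 2, &X);
  storeAfter(true, 1, 2, &Y);
  assert(X == Y);
  storeBefore(false, 1, 2, &X);
  storeAfter(false, 1, 2, &Y);
  assert(X == Y);
  return 0;
}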
if (BBI->mayReadFromMemory() || BBI->mayThrow() || @@ -1624,14 +1670,17 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { } // Insert a PHI node now if we need it. - Value *MergedVal = OtherStore->getOperand(0); + Value *MergedVal = OtherStore->getValueOperand(); // The debug locations of the original instructions might differ. Merge them. DebugLoc MergedLoc = DILocation::getMergedLocation(SI.getDebugLoc(), OtherStore->getDebugLoc()); - if (MergedVal != SI.getOperand(0)) { - PHINode *PN = PHINode::Create(MergedVal->getType(), 2, "storemerge"); - PN->addIncoming(SI.getOperand(0), SI.getParent()); - PN->addIncoming(OtherStore->getOperand(0), OtherBB); + if (MergedVal != SI.getValueOperand()) { + PHINode *PN = + PHINode::Create(SI.getValueOperand()->getType(), 2, "storemerge"); + PN->addIncoming(SI.getValueOperand(), SI.getParent()); + Builder.SetInsertPoint(OtherStore); + PN->addIncoming(Builder.CreateBitOrPointerCast(MergedVal, PN->getType()), + OtherBB); MergedVal = InsertNewInstBefore(PN, DestBB->front()); PN->setDebugLoc(MergedLoc); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 97f129e200de..50458e2773e6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -185,6 +185,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, return nullptr; } +static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, + bool AssumeNonZero, bool DoFold); + Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = @@ -270,7 +273,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (match(Op0, m_ZExtOrSExt(m_Value(X))) && match(Op1, m_APIntAllowUndef(NegPow2C))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); - unsigned ShiftAmt = NegPow2C->countTrailingZeros(); + unsigned ShiftAmt = NegPow2C->countr_zero(); if (ShiftAmt >= BitWidth - SrcWidth) { Value *N = Builder.CreateNeg(X, X->getName() + ".neg"); Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z"); @@ -471,6 +474,40 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + + // min(X, Y) * max(X, Y) => X * Y. + if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)), + m_c_SMin(m_Deferred(X), m_Deferred(Y))), + m_c_Mul(m_UMax(m_Value(X), m_Value(Y)), + m_c_UMin(m_Deferred(X), m_Deferred(Y)))))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I); + + // (mul Op0 Op1): + // if Log2(Op0) folds away -> + // (shl Op1, Log2(Op0)) + // if Log2(Op1) folds away -> + // (shl Op0, Log2(Op1)) + if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res); + // We can only propegate nuw flag. + Shl->setHasNoUnsignedWrap(HasNUW); + return Shl; + } + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res); + // We can only propegate nuw flag. 
+ Shl->setHasNoUnsignedWrap(HasNUW); + return Shl; + } + bool Changed = false; if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; @@ -765,6 +802,20 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { I.hasNoSignedZeros() && match(Start, m_Zero())) return replaceInstUsesWith(I, Start); + // minimun(X, Y) * maximum(X, Y) => X * Y. + if (match(&I, + m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)), + m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X), + m_Deferred(Y))))) { + BinaryOperator *Result = BinaryOperator::CreateFMulFMF(X, Y, &I); + // We cannot preserve ninf if nnan flag is not set. + // If X is NaN and Y is Inf then in original program we had NaN * NaN, + // while in optimized version NaN * Inf and this is a poison with ninf flag. + if (!Result->hasNoNaNs()) + Result->setHasNoInfs(false); + return Result; + } + return nullptr; } @@ -976,9 +1027,9 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { ConstantInt::get(Ty, Product)); } + APInt Quotient(C2->getBitWidth(), /*val=*/0ULL, IsSigned); if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) || (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. if (isMultiple(*C2, *C1, Quotient, IsSigned)) { @@ -1003,7 +1054,6 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { C1->ult(C1->getBitWidth() - 1)) || (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) && C1->ult(C1->getBitWidth()))) { - APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); APInt C1Shifted = APInt::getOneBitSet( C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue())); @@ -1026,6 +1076,23 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } } + // Distribute div over add to eliminate a matching div/mul pair: + // ((X * C2) + C1) / C2 --> X + C1/C2 + // We need a multiple of the divisor for a signed add constant, but + // unsigned is fine with any constant pair. + if (IsSigned && + match(Op0, m_NSWAdd(m_NSWMul(m_Value(X), m_SpecificInt(*C2)), + m_APInt(C1))) && + isMultiple(*C1, *C2, Quotient, IsSigned)) { + return BinaryOperator::CreateNSWAdd(X, ConstantInt::get(Ty, Quotient)); + } + if (!IsSigned && + match(Op0, m_NUWAdd(m_NUWMul(m_Value(X), m_SpecificInt(*C2)), + m_APInt(C1)))) { + return BinaryOperator::CreateNUWAdd(X, + ConstantInt::get(Ty, C1->udiv(*C2))); + } + if (!C2->isZero()) // avoid X udiv 0 if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I)) return FoldedDiv; @@ -1121,7 +1188,7 @@ static const unsigned MaxDepth = 6; // actual instructions, otherwise return a non-null dummy value. Return nullptr // on failure. static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool DoFold) { + bool AssumeNonZero, bool DoFold) { auto IfFold = [DoFold](function_ref<Value *()> Fn) { if (!DoFold) return reinterpret_cast<Value *>(-1); @@ -1147,14 +1214,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: Require one use? Value *X, *Y; if (match(Op, m_ZExt(m_Value(X)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); }); // log2(X << Y) -> log2(X) + Y // FIXME: Require one use unless X is 1? 
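The identity in the comment above only holds when the shift does not drop set bits, which is what the nuw/nsw (or caller-asserted non-zero) requirement added in the lines that follow guarantees. A stand-alone check of the unsigned, non-wrapping case (floorLog2 is an illustrative helper, not an LLVM API):

#include <cassert>
#include <cstdint>

// Floor of log2 for a nonzero 32-bit value: index of the highest set bit.
static unsigned floorLog2(uint32_t V) {
  unsigned L = 0;
  while (V >>= 1)
    ++L;
  return L;
}

int main() {
  const uint32_t Xs[] = {1, 2, 8, 12};
  const unsigned Ys[] = {0, 1, 5};
  for (uint32_t X : Xs)
    for (unsigned Y : Ys)
      if (((uint64_t)X << Y) <= UINT32_MAX)   // the shift would not wrap (nuw)
        assert(floorLog2(X << Y) == floorLog2(X) + Y);
  return 0;
}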
- if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) - return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) { + auto *BO = cast<OverflowingBinaryOperator>(Op); + // nuw will be set if the `shl` is trivially non-zero. + if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + } // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) // FIXME: missed optimization: if one of the hands of select is/contains @@ -1162,8 +1233,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: can both hands contain undef? // FIXME: Require one use? if (SelectInst *SI = dyn_cast<SelectInst>(Op)) - if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, + AssumeNonZero, DoFold)) + if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, + AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateSelect(SI->getOperand(0), LogX, LogY); }); @@ -1171,13 +1244,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // log2(umin(X, Y)) -> umin(log2(X), log2(Y)) // log2(umax(X, Y)) -> umax(log2(X), log2(Y)) auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op); - if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) - if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold)) + if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) { + // Use AssumeNonZero as false here. Otherwise we can hit case where + // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow). + if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) + if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) return IfFold([&]() { - return Builder.CreateBinaryIntrinsic( - MinMax->getIntrinsicID(), LogX, LogY); + return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX, + LogY); }); + } return nullptr; } @@ -1297,8 +1375,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { } // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. - if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, + /*AssumeNonZero*/ true, /*DoFold*/ true); return replaceInstUsesWith( I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); } @@ -1359,7 +1439,8 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { // (sext X) sdiv C --> sext (X sdiv C) Value *Op0Src; if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) && - Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) { + Op0Src->getType()->getScalarSizeInBits() >= + Op1C->getSignificantBits()) { // In the general case, we need to make sure that the dividend is not the // minimum signed value because dividing that by -1 is UB. 
But here, we @@ -1402,7 +1483,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { KnownBits KnownDividend = computeKnownBits(Op0, 0, &I); if (!I.isExact() && (match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) && - KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) { + KnownDividend.countMinTrailingZeros() >= Op1C->countr_zero()) { I.setIsExact(); return &I; } @@ -1681,6 +1762,111 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { return nullptr; } +// Variety of transform for: +// (urem/srem (mul X, Y), (mul X, Z)) +// (urem/srem (shl X, Y), (shl X, Z)) +// (urem/srem (shl Y, X), (shl Z, X)) +// NB: The shift cases are really just extensions of the mul case. We treat +// shift as Val * (1 << Amt). +static Instruction *simplifyIRemMulShl(BinaryOperator &I, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr; + APInt Y, Z; + bool ShiftByX = false; + + // If V is not nullptr, it will be matched using m_Specific. + auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool { + const APInt *Tmp = nullptr; + if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) || + (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp))))) + C = *Tmp; + else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) || + (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp))))) + C = APInt(Tmp->getBitWidth(), 1) << *Tmp; + if (Tmp != nullptr) + return true; + + // Reset `V` so we don't start with specific value on next match attempt. + V = nullptr; + return false; + }; + + auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool { + const APInt *Tmp = nullptr; + if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) || + (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) { + C = *Tmp; + return true; + } + + // Reset `V` so we don't start with specific value on next match attempt. + V = nullptr; + return false; + }; + + if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) { + // pass + } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) { + ShiftByX = true; + } else { + return nullptr; + } + + bool IsSRem = I.getOpcode() == Instruction::SRem; + + OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0); + // TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >= + // Z or Z >= Y. + bool BO0HasNSW = BO0->hasNoSignedWrap(); + bool BO0HasNUW = BO0->hasNoUnsignedWrap(); + bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW; + + APInt RemYZ = IsSRem ? Y.srem(Z) : Y.urem(Z); + // (rem (mul nuw/nsw X, Y), (mul X, Z)) + // if (rem Y, Z) == 0 + // -> 0 + if (RemYZ.isZero() && BO0NoWrap) + return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType())); + + // Helper function to emit either (RemSimplificationC << X) or + // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as + // (shl V, X) or (mul V, X) respectively. + auto CreateMulOrShift = + [&](const APInt &RemSimplificationC) -> BinaryOperator * { + Value *RemSimplification = + ConstantInt::get(I.getType(), RemSimplificationC); + return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X) + : BinaryOperator::CreateMul(X, RemSimplification); + }; + + OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1); + bool BO1HasNSW = BO1->hasNoSignedWrap(); + bool BO1HasNUW = BO1->hasNoUnsignedWrap(); + bool BO1NoWrap = IsSRem ? 
BO1HasNSW : BO1HasNUW; + // (rem (mul X, Y), (mul nuw/nsw X, Z)) + // if (rem Y, Z) == Y + // -> (mul nuw/nsw X, Y) + if (RemYZ == Y && BO1NoWrap) { + BinaryOperator *BO = CreateMulOrShift(Y); + // Copy any overflow flags from Op0. + BO->setHasNoSignedWrap(IsSRem || BO0HasNSW); + BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW); + return BO; + } + + // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z)) + // if Y >= Z + // -> (mul {nuw} nsw X, (rem Y, Z)) + if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { + BinaryOperator *BO = CreateMulOrShift(RemYZ); + BO->setHasNoSignedWrap(); + BO->setHasNoUnsignedWrap(BO0HasNUW); + return BO; + } + + return nullptr; +} + /// This function implements the transforms common to both integer remainder /// instructions (urem and srem). It is called by the visitors to those integer /// remainder instructions. @@ -1733,6 +1919,9 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) { } } + if (Instruction *R = simplifyIRemMulShl(I, *this)) + return R; + return nullptr; } @@ -1782,8 +1971,21 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { // urem Op0, (sext i1 X) --> (Op0 == -1) ? 0 : Op0 Value *X; if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty)); - return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0); + Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *Cmp = + Builder.CreateICmpEQ(FrozenOp0, ConstantInt::getAllOnesValue(Ty)); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); + } + + // For "(X + 1) % Op1" and if (X u< Op1) => (X + 1) == Op1 ? 0 : X + 1 . + if (match(Op0, m_Add(m_Value(X), m_One()))) { + Value *Val = + simplifyICmpInst(ICmpInst::ICMP_ULT, X, Op1, SQ.getWithInstruction(&I)); + if (Val && match(Val, m_One())) { + Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, Op1); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); + } } return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 7f59729f0085..2f6aa85062a5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -316,7 +316,7 @@ Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) { for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) { if (auto *NewOp = simplifyIntToPtrRoundTripCast(PN.getIncomingValue(OpNum))) { - PN.setIncomingValue(OpNum, NewOp); + replaceOperand(PN, OpNum, NewOp); OperandWithRoundTripCast = true; } } @@ -745,6 +745,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LLVMContext::MD_dereferenceable, LLVMContext::MD_dereferenceable_or_null, LLVMContext::MD_access_group, + LLVMContext::MD_noundef, }; for (unsigned ID : KnownIDs) @@ -1388,11 +1389,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // If all PHI operands are the same operation, pull them through the PHI, // reducing code size. 
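At the source level the comment above describes the usual "hoist the common operation past the join" rewrite; a rough C++ analogue of what foldPHIArgOpIntoPHI achieves (joinedBefore/joinedAfter are illustrative names, and the real transform rewrites IR PHI nodes rather than source):

#include <cassert>

// Before: the same operation is performed on both incoming paths.
static int joinedBefore(bool C, int A, int B) {
  int R;
  if (C)
    R = A + 7;
  else
    R = B + 7;
  return R;
}

// After: one join value (the "PHI") feeding a single shared operation.
static int joinedAfter(bool C, int A, int B) {
  int R0 = C ? A : B;
  return R0 + 7;
}

int main() {
  for (bool C : {false, true})
    assert(joinedBefore(C, 3, -5) == joinedAfter(C, 3, -5));
  return 0;
}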
- if (isa<Instruction>(PN.getIncomingValue(0)) && - isa<Instruction>(PN.getIncomingValue(1)) && - cast<Instruction>(PN.getIncomingValue(0))->getOpcode() == - cast<Instruction>(PN.getIncomingValue(1))->getOpcode() && - PN.getIncomingValue(0)->hasOneUser()) + auto *Inst0 = dyn_cast<Instruction>(PN.getIncomingValue(0)); + auto *Inst1 = dyn_cast<Instruction>(PN.getIncomingValue(1)); + if (Inst0 && Inst1 && Inst0->getOpcode() == Inst1->getOpcode() && + Inst0->hasOneUser()) if (Instruction *Result = foldPHIArgOpIntoPHI(PN)) return Result; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index e7d8208f94fd..661c50062223 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -98,7 +98,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, // +0.0 compares equal to -0.0, and so it does not behave as required for this // transform. Bail out if we can not exclude that possibility. if (isa<FPMathOperator>(BO)) - if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI)) + if (!BO->hasNoSignedZeros() && + !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI)) return nullptr; // BO = binop Y, X @@ -386,6 +387,32 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp}); } } + + // select c, (ldexp v, e0), (ldexp v, e1) -> ldexp v, (select c, e0, e1) + // select c, (ldexp v0, e), (ldexp v1, e) -> ldexp (select c, v0, v1), e + // + // select c, (ldexp v0, e0), (ldexp v1, e1) -> + // ldexp (select c, v0, v1), (select c, e0, e1) + if (TII->getIntrinsicID() == Intrinsic::ldexp) { + Value *LdexpVal0 = TII->getArgOperand(0); + Value *LdexpExp0 = TII->getArgOperand(1); + Value *LdexpVal1 = FII->getArgOperand(0); + Value *LdexpExp1 = FII->getArgOperand(1); + if (LdexpExp0->getType() == LdexpExp1->getType()) { + FPMathOperator *SelectFPOp = cast<FPMathOperator>(&SI); + FastMathFlags FMF = cast<FPMathOperator>(TII)->getFastMathFlags(); + FMF &= cast<FPMathOperator>(FII)->getFastMathFlags(); + FMF |= SelectFPOp->getFastMathFlags(); + + Value *SelectVal = Builder.CreateSelect(Cond, LdexpVal0, LdexpVal1); + Value *SelectExp = Builder.CreateSelect(Cond, LdexpExp0, LdexpExp1); + + CallInst *NewLdexp = Builder.CreateIntrinsic( + TII->getType(), Intrinsic::ldexp, {SelectVal, SelectExp}); + NewLdexp->setFastMathFlags(FMF); + return replaceInstUsesWith(SI, NewLdexp); + } + } } // icmp with a common operand also can have the common operand @@ -429,6 +456,21 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, !OtherOpF->getType()->isVectorTy())) return nullptr; + // If we are sinking div/rem after a select, we may need to freeze the + // condition because div/rem may induce immediate UB with a poison operand. + // For example, the following transform is not safe if Cond can ever be poison + // because we can replace poison with zero and then we have div-by-zero that + // didn't exist in the original code: + // Cond ? x/y : x/z --> x / (Cond ? y : z) + auto *BO = dyn_cast<BinaryOperator>(TI); + if (BO && BO->isIntDivRem() && !isGuaranteedNotToBePoison(Cond)) { + // A udiv/urem with a common divisor is safe because UB can only occur with + // div-by-zero, and that would be present in the original code. 
+ if (BO->getOpcode() == Instruction::SDiv || + BO->getOpcode() == Instruction::SRem || MatchIsOpZero) + Cond = Builder.CreateFreeze(Cond); + } + // If we reach here, they do have operations in common. Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, SI.getName() + ".v", &SI); @@ -461,7 +503,7 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) { /// optimization. Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { - // See the comment above GetSelectFoldableOperands for a description of the + // See the comment above getSelectFoldableOperands for a description of the // transformation we are doing here. auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, Value *FalseVal, @@ -496,7 +538,7 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, - Swapped ? OOp : C); + Swapped ? OOp : C, "", &SI); if (isa<FPMathOperator>(&SI)) cast<Instruction>(NewSel)->setFastMathFlags(FMF); NewSel->takeName(TVI); @@ -569,6 +611,44 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, } /// We want to turn: +/// (select (icmp eq (and X, C1), 0), 0, (shl [nsw/nuw] X, C2)); +/// iff C1 is a mask and the number of its leading zeros is equal to C2 +/// into: +/// shl X, C2 +static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal, + Value *FVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred; + Value *AndVal; + if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero()))) + return nullptr; + + if (Pred == ICmpInst::ICMP_NE) { + Pred = ICmpInst::ICMP_EQ; + std::swap(TVal, FVal); + } + + Value *X; + const APInt *C2, *C1; + if (Pred != ICmpInst::ICMP_EQ || + !match(AndVal, m_And(m_Value(X), m_APInt(C1))) || + !match(TVal, m_Zero()) || !match(FVal, m_Shl(m_Specific(X), m_APInt(C2)))) + return nullptr; + + if (!C1->isMask() || + C1->countLeadingZeros() != static_cast<unsigned>(C2->getZExtValue())) + return nullptr; + + auto *FI = dyn_cast<Instruction>(FVal); + if (!FI) + return nullptr; + + FI->setHasNoSignedWrap(false); + FI->setHasNoUnsignedWrap(false); + return FVal; +} + +/// We want to turn: /// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1 /// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0 /// into: @@ -935,10 +1015,53 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return nullptr; } +/// Try to match patterns with select and subtract as absolute difference. +static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + auto *TI = dyn_cast<Instruction>(TVal); + auto *FI = dyn_cast<Instruction>(FVal); + if (!TI || !FI) + return nullptr; + + // Normalize predicate to gt/lt rather than ge/le. + ICmpInst::Predicate Pred = Cmp->getStrictPredicate(); + Value *A = Cmp->getOperand(0); + Value *B = Cmp->getOperand(1); + + // Normalize "A - B" as the true value of the select. + if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) { + std::swap(FI, TI); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + // With any pair of no-wrap subtracts: + // (A > B) ? 
(A - B) : (B - A) --> abs(A - B) + if (Pred == CmpInst::ICMP_SGT && + match(TI, m_Sub(m_Specific(A), m_Specific(B))) && + match(FI, m_Sub(m_Specific(B), m_Specific(A))) && + (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) && + (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) { + // The remaining subtract is not "nuw" any more. + // If there's one use of the subtract (no other use than the use we are + // about to replace), then we know that the sub is "nsw" in this context + // even if it was only "nuw" before. If there's another use, then we can't + // add "nsw" to the existing instruction because it may not be safe in the + // other user's context. + TI->setHasNoUnsignedWrap(false); + if (!TI->hasNoSignedWrap()) + TI->setHasNoSignedWrap(TI->hasOneUse()); + return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue()); + } + + return nullptr; +} + /// Fold the following code sequence: /// \code /// int a = ctlz(x & -x); // x ? 31 - a : a; +// // or +// x ? 31 - a : 32; /// \code /// /// into: @@ -953,15 +1076,19 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal, if (ICI->getPredicate() == ICmpInst::ICMP_NE) std::swap(TrueVal, FalseVal); + Value *Ctlz; if (!match(FalseVal, - m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1)))) + m_Xor(m_Value(Ctlz), m_SpecificInt(BitWidth - 1)))) return nullptr; - if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>())) + if (!match(Ctlz, m_Intrinsic<Intrinsic::ctlz>())) + return nullptr; + + if (TrueVal != Ctlz && !match(TrueVal, m_SpecificInt(BitWidth))) return nullptr; Value *X = ICI->getOperand(0); - auto *II = cast<IntrinsicInst>(TrueVal); + auto *II = cast<IntrinsicInst>(Ctlz); if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X))))) return nullptr; @@ -1038,99 +1165,6 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, return nullptr; } -/// Return true if we find and adjust an icmp+select pattern where the compare -/// is with a constant that can be incremented or decremented to match the -/// minimum or maximum idiom. -static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { - ICmpInst::Predicate Pred = Cmp.getPredicate(); - Value *CmpLHS = Cmp.getOperand(0); - Value *CmpRHS = Cmp.getOperand(1); - Value *TrueVal = Sel.getTrueValue(); - Value *FalseVal = Sel.getFalseValue(); - - // We may move or edit the compare, so make sure the select is the only user. - const APInt *CmpC; - if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC))) - return false; - - // These transforms only work for selects of integers or vector selects of - // integer vectors. - Type *SelTy = Sel.getType(); - auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType()); - if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy()) - return false; - - Constant *AdjustedRHS; - if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) - AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1); - else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) - AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1); - else - return false; - - // X > C ? X : C+1 --> X < C+1 ? C+1 : X - // X < C ? X : C-1 --> X > C-1 ? C-1 : X - if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || - (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { - ; // Nothing to do here. Values match without any sign/zero extension. - } - // Types do not match. Instead of calculating this with mixed types, promote - // all to the larger type. 
This enables scalar evolution to analyze this - // expression. - else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) { - Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy); - - // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X - // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X - // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X - // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X - if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = SextRHS; - } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && - SextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = SextRHS; - } else if (Cmp.isUnsigned()) { - Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy); - // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X - // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X - // zext + signed compare cannot be changed: - // 0xff <s 0x00, but 0x00ff >s 0x0000 - if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = ZextRHS; - } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && - ZextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = ZextRHS; - } else { - return false; - } - } else { - return false; - } - } else { - return false; - } - - Pred = ICmpInst::getSwappedPredicate(Pred); - CmpRHS = AdjustedRHS; - std::swap(FalseVal, TrueVal); - Cmp.setPredicate(Pred); - Cmp.setOperand(0, CmpLHS); - Cmp.setOperand(1, CmpRHS); - Sel.setOperand(1, TrueVal); - Sel.setOperand(2, FalseVal); - Sel.swapProfMetadata(); - - // Move the compare instruction right before the select instruction. Otherwise - // the sext/zext value may be defined after the compare instruction uses it. - Cmp.moveBefore(&Sel); - - return true; -} - static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, InstCombinerImpl &IC) { Value *LHS, *RHS; @@ -1182,8 +1216,8 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, return nullptr; } -static bool replaceInInstruction(Value *V, Value *Old, Value *New, - InstCombiner &IC, unsigned Depth = 0) { +bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New, + unsigned Depth) { // Conservatively limit replacement to two instructions upwards. if (Depth == 2) return false; @@ -1195,10 +1229,11 @@ static bool replaceInInstruction(Value *V, Value *Old, Value *New, bool Changed = false; for (Use &U : I->operands()) { if (U == Old) { - IC.replaceUse(U, New); + replaceUse(U, New); + Worklist.add(I); Changed = true; } else { - Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1); + Changed |= replaceInInstruction(U, Old, New, Depth + 1); } } return Changed; @@ -1254,7 +1289,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // FIXME: Support vectors. 
if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) && !Cmp.getType()->isVectorTy()) - if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this)) + if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS)) return &Sel; } if (TrueVal != CmpRHS && @@ -1593,13 +1628,32 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal, return nullptr; } -static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) { +static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI, + InstCombiner::BuilderTy &Builder) { const APInt *CmpC; Value *V; CmpInst::Predicate Pred; if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC)))) return nullptr; + // Match clamp away from min/max value as a max/min operation. + Value *TVal = SI.getTrueValue(); + Value *FVal = SI.getFalseValue(); + if (Pred == ICmpInst::ICMP_EQ && V == FVal) { + // (V == UMIN) ? UMIN+1 : V --> umax(V, UMIN+1) + if (CmpC->isMinValue() && match(TVal, m_SpecificInt(*CmpC + 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::umax, V, TVal); + // (V == UMAX) ? UMAX-1 : V --> umin(V, UMAX-1) + if (CmpC->isMaxValue() && match(TVal, m_SpecificInt(*CmpC - 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::umin, V, TVal); + // (V == SMIN) ? SMIN+1 : V --> smax(V, SMIN+1) + if (CmpC->isMinSignedValue() && match(TVal, m_SpecificInt(*CmpC + 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::smax, V, TVal); + // (V == SMAX) ? SMAX-1 : V --> smin(V, SMAX-1) + if (CmpC->isMaxSignedValue() && match(TVal, m_SpecificInt(*CmpC - 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::smin, V, TVal); + } + BinaryOperator *BO; const APInt *C; CmpInst::Predicate CPred; @@ -1632,7 +1686,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this)) return NewSPF; - if (Value *V = foldSelectInstWithICmpConst(SI, ICI)) + if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) @@ -1642,18 +1696,17 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) return NewSel; - bool Changed = adjustMinMax(SI, *ICI); - if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); // NOTE: if we wanted to, this is where to detect integer MIN/MAX + bool Changed = false; Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { + if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS) && !isa<Constant>(CmpLHS)) { if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { // Transform (X == C) ? X : Y -> (X == C) ? C : Y SI.setOperand(1, CmpRHS); @@ -1683,7 +1736,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring // decomposeBitTestICmp() might help. 
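The new foldSelectInstWithICmpConst cases earlier in this hunk recognize a select that merely steps a value away from an extreme as a min/max. The scalar intuition, checked in plain C++ for the unsigned cases (std::max/std::min stand in for the umax/umin intrinsics):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

int main() {
  const uint32_t UMAX = std::numeric_limits<uint32_t>::max();
  const uint32_t Vals[] = {0, 1, 7, UMAX - 1, UMAX};
  for (uint32_t V : Vals) {
    // (V == UMIN) ? UMIN+1 : V  ==  umax(V, UMIN+1)
    assert((V == 0 ? 1u : V) == std::max<uint32_t>(V, 1));
    // (V == UMAX) ? UMAX-1 : V  ==  umin(V, UMAX-1)
    assert((V == UMAX ? UMAX - 1 : V) == std::min<uint32_t>(V, UMAX - 1));
  }
  return 0;
}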
- { + if (TrueVal->getType()->isIntOrIntVectorTy()) { unsigned BitWidth = DL.getTypeSizeInBits(TrueVal->getType()->getScalarType()); APInt MinSignedValue = APInt::getSignedMinValue(BitWidth); @@ -1735,6 +1788,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) return V; + if (Value *V = foldSelectICmpAndZeroShl(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder)) return V; @@ -1756,6 +1812,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } @@ -2418,7 +2477,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { // in the case of a shuffle with no undefined mask elements. ArrayRef<int> Mask; if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && - !is_contained(Mask, UndefMaskElem) && + !is_contained(Mask, PoisonMaskElem) && cast<ShuffleVectorInst>(TVal)->isSelect()) { if (X == FVal) { // select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X) @@ -2432,7 +2491,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { } } if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && - !is_contained(Mask, UndefMaskElem) && + !is_contained(Mask, PoisonMaskElem) && cast<ShuffleVectorInst>(FVal)->isSelect()) { if (X == TVal) { // select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y) @@ -2965,6 +3024,14 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) return replaceOperand(SI, 0, A); + // select a, (select ~a, true, b), false -> select a, b, false + if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) && + match(FalseVal, m_Zero())) + return replaceOperand(SI, 1, B); + // select a, true, (select ~a, b, false) -> select a, true, b + if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) && + match(TrueVal, m_One())) + return replaceOperand(SI, 2, B); // ~(A & B) & (A | B) --> A ^ B if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))), @@ -3077,6 +3144,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return nullptr; } +// Return true if we can safely remove the select instruction for std::bit_ceil +// pattern. +static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, + const APInt *Cond1, Value *CtlzOp, + unsigned BitWidth) { + // The challenge in recognizing std::bit_ceil(X) is that the operand is used + // for the CTLZ proper and select condition, each possibly with some + // operation like add and sub. + // + // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the + // select instruction would select 1, which allows us to get rid of the select + // instruction. + // + // To see if we can do so, we do some symbolic execution with ConstantRange. + // Specifically, we compute the range of values that Cond0 could take when + // Cond == false. Then we successively transform the range until we obtain + // the range of values that CtlzOp could take. 
+ // + // Conceptually, we follow the def-use chain backward from Cond0 while + // transforming the range for Cond0 until we meet the common ancestor of Cond0 + // and CtlzOp. Then we follow the def-use chain forward until we obtain the + // range for CtlzOp. That said, we only follow at most one ancestor from + // Cond0. Likewise, we only follow at most one ancestor from CtrlOp. + + ConstantRange CR = ConstantRange::makeExactICmpRegion( + CmpInst::getInversePredicate(Pred), *Cond1); + + // Match the operation that's used to compute CtlzOp from CommonAncestor. If + // CtlzOp == CommonAncestor, return true as no operation is needed. If a + // match is found, execute the operation on CR, update CR, and return true. + // Otherwise, return false. + auto MatchForward = [&](Value *CommonAncestor) { + const APInt *C = nullptr; + if (CtlzOp == CommonAncestor) + return true; + if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) { + CR = CR.add(*C); + return true; + } + if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) { + CR = ConstantRange(*C).sub(CR); + return true; + } + if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) { + CR = CR.binaryNot(); + return true; + } + return false; + }; + + const APInt *C = nullptr; + Value *CommonAncestor; + if (MatchForward(Cond0)) { + // Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated. + } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) { + CR = CR.sub(*C); + if (!MatchForward(CommonAncestor)) + return false; + // Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated. + } else { + return false; + } + + // Return true if all the values in the range are either 0 or negative (if + // treated as signed). We do so by evaluating: + // + // CR - 1 u>= (1 << BitWidth) - 1. + APInt IntMax = APInt::getSignMask(BitWidth) - 1; + CR = CR.sub(APInt(BitWidth, 1)); + return CR.icmp(ICmpInst::ICMP_UGE, IntMax); +} + +// Transform the std::bit_ceil(X) pattern like: +// +// %dec = add i32 %x, -1 +// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false) +// %sub = sub i32 32, %ctlz +// %shl = shl i32 1, %sub +// %ugt = icmp ugt i32 %x, 1 +// %sel = select i1 %ugt, i32 %shl, i32 1 +// +// into: +// +// %dec = add i32 %x, -1 +// %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false) +// %neg = sub i32 0, %ctlz +// %masked = and i32 %ctlz, 31 +// %shl = shl i32 1, %sub +// +// Note that the select is optimized away while the shift count is masked with +// 31. We handle some variations of the input operand like std::bit_ceil(X + +// 1). +static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) { + Type *SelType = SI.getType(); + unsigned BitWidth = SelType->getScalarSizeInBits(); + + Value *FalseVal = SI.getFalseValue(); + Value *TrueVal = SI.getTrueValue(); + ICmpInst::Predicate Pred; + const APInt *Cond1; + Value *Cond0, *Ctlz, *CtlzOp; + if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1)))) + return nullptr; + + if (match(TrueVal, m_One())) { + std::swap(FalseVal, TrueVal); + Pred = CmpInst::getInversePredicate(Pred); + } + + if (!match(FalseVal, m_One()) || + !match(TrueVal, + m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth), + m_Value(Ctlz)))))) || + !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) || + !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth)) + return nullptr; + + // Build 1 << (-CTLZ & (BitWidth-1)). 
The negation likely corresponds to a + // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth + // is an integer constant. Masking with BitWidth-1 comes free on some + // hardware as part of the shift instruction. + Value *Neg = Builder.CreateNeg(Ctlz); + Value *Masked = + Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1)); + return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1), + Masked); +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -3253,6 +3448,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { std::swap(NewT, NewF); Value *NewSI = Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI); + if (Gep->isInBounds()) + return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI}); return GetElementPtrInst::Create(ElementType, Ptr, {NewSI}); }; if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal)) @@ -3364,25 +3561,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } - auto canMergeSelectThroughBinop = [](BinaryOperator *BO) { - // The select might be preventing a division by 0. - switch (BO->getOpcode()) { - default: - return true; - case Instruction::SRem: - case Instruction::URem: - case Instruction::SDiv: - case Instruction::UDiv: - return false; - } - }; - // Try to simplify a binop sandwiched between 2 selects with the same - // condition. + // condition. This is not valid for div/rem because the select might be + // preventing a division-by-zero. + // TODO: A div/rem restriction is conservative; use something like + // isSafeToSpeculativelyExecute(). // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) BinaryOperator *TrueBO; - if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && - canMergeSelectThroughBinop(TrueBO)) { + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && !TrueBO->isIntDivRem()) { if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { if (TrueBOSI->getCondition() == CondVal) { replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue()); @@ -3401,8 +3587,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) BinaryOperator *FalseBO; - if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && - canMergeSelectThroughBinop(FalseBO)) { + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && !FalseBO->isIntDivRem()) { if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { if (FalseBOSI->getCondition() == CondVal) { replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue()); @@ -3516,5 +3701,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (sinkNotIntoOtherHandOfLogicalOp(SI)) return &SI; + if (Instruction *I = foldBitCeil(SI, Builder)) + return I; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index ec505381cc86..89dad455f015 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -322,15 +322,20 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return BinaryOperator::Create(Instruction::And, NewShift, NewMask); } -/// If we have a shift-by-constant of a bitwise logic op that itself has a -/// shift-by-constant operand with identical opcode, we may be able to convert -/// that into 2 independent shifts followed by the logic op. 
This eliminates a -/// a use of an intermediate value (reduces dependency chain). -static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, +/// If we have a shift-by-constant of a bin op (bitwise logic op or add/sub w/ +/// shl) that itself has a shift-by-constant operand with identical opcode, we +/// may be able to convert that into 2 independent shifts followed by the logic +/// op. This eliminates a use of an intermediate value (reduces dependency +/// chain). +static Instruction *foldShiftOfShiftedBinOp(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { assert(I.isShift() && "Expected a shift as input"); - auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0)); - if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse()) + auto *BinInst = dyn_cast<BinaryOperator>(I.getOperand(0)); + if (!BinInst || + (!BinInst->isBitwiseLogicOp() && + BinInst->getOpcode() != Instruction::Add && + BinInst->getOpcode() != Instruction::Sub) || + !BinInst->hasOneUse()) return nullptr; Constant *C0, *C1; @@ -338,6 +343,12 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, return nullptr; Instruction::BinaryOps ShiftOpcode = I.getOpcode(); + // Transform for add/sub only works with shl. + if ((BinInst->getOpcode() == Instruction::Add || + BinInst->getOpcode() == Instruction::Sub) && + ShiftOpcode != Instruction::Shl) + return nullptr; + Type *Ty = I.getType(); // Find a matching one-use shift by constant. The fold is not valid if the sum @@ -352,19 +363,25 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); }; - // Logic ops are commutative, so check each operand for a match. - if (matchFirstShift(LogicInst->getOperand(0))) - Y = LogicInst->getOperand(1); - else if (matchFirstShift(LogicInst->getOperand(1))) - Y = LogicInst->getOperand(0); - else + // Logic ops and Add are commutative, so check each operand for a match. Sub + // is not so we cannot reoder if we match operand(1) and need to keep the + // operands in their original positions. + bool FirstShiftIsOp1 = false; + if (matchFirstShift(BinInst->getOperand(0))) + Y = BinInst->getOperand(1); + else if (matchFirstShift(BinInst->getOperand(1))) { + Y = BinInst->getOperand(0); + FirstShiftIsOp1 = BinInst->getOpcode() == Instruction::Sub; + } else return nullptr; - // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) + // shift (binop (shift X, C0), Y), C1 -> binop (shift X, C0+C1), (shift Y, C1) Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1); Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1); - return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); + Value *Op1 = FirstShiftIsOp1 ? NewShift2 : NewShift1; + Value *Op2 = FirstShiftIsOp1 ? 
NewShift1 : NewShift2; + return BinaryOperator::Create(BinInst->getOpcode(), Op1, Op2); } Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { @@ -463,9 +480,12 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { return replaceOperand(I, 1, Rem); } - if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) + if (Instruction *Logic = foldShiftOfShiftedBinOp(I, Builder)) return Logic; + if (match(Op1, m_Or(m_Value(), m_SpecificInt(BitWidth - 1)))) + return replaceOperand(I, 1, ConstantInt::get(Ty, BitWidth - 1)); + return nullptr; } @@ -570,8 +590,7 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, const APInt *MulConst; // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`) return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) && - MulConst->isNegatedPowerOf2() && - MulConst->countTrailingZeros() == NumBits; + MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits; } } } @@ -900,8 +919,10 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { // Replace the uses of the original add with a zext of the // NarrowAdd's result. Note that all users at this stage are known to // be ShAmt-sized truncs, or the lshr itself. - if (!Add->hasOneUse()) + if (!Add->hasOneUse()) { replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty)); + eraseInstFromFunction(*AddInst); + } // Replace the LShr with a zext of the overflow check. return new ZExtInst(Overflow, Ty); @@ -1133,6 +1154,14 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { return BinaryOperator::CreateLShr( ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + // Canonicalize "extract lowest set bit" using cttz to and-with-negate: + // 1 << (cttz X) --> -X & X + if (match(Op1, + m_OneUse(m_Intrinsic<Intrinsic::cttz>(m_Value(X), m_Value())))) { + Value *NegX = Builder.CreateNeg(X, "neg"); + return BinaryOperator::CreateAnd(NegX, X); + } + // The only way to shift out the 1 is with an over-shift, so that would // be poison with or without "nuw". Undef is excluded because (undef << X) // is not undef (it is zero). diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 77d675422966..00eece9534b0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -168,7 +168,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care // about the high bits of the operands. auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); // Right fill the mask of bits for the operands to demand the most // significant bit and all those below it. DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); @@ -195,7 +195,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - Known = LHSKnown & RHSKnown; + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, DL, &AC, CxtI, &DT); // If the client is only demanding bits that we know, return the known // constant. 
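Among the InstCombineShifts.cpp changes above, visitShl gains the canonicalization "1 << (cttz X) --> -X & X", the standard extract-lowest-set-bit identity. A standalone spot check (illustrative only, not part of the patch; assumes C++20 <bit> and skips x == 0, where the shift amount equals the bit width and the original expression is poison in IR):

    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t X = 1; X <= (1u << 20); ++X) {
        uint32_t ViaShift = 1u << std::countr_zero(X); // 1 << cttz(x)
        uint32_t ViaMask = X & (0u - X);               // x & -x
        assert(ViaShift == ViaMask);
      }
      return 0;
    }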
@@ -224,7 +225,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - Known = LHSKnown | RHSKnown; + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, DL, &AC, CxtI, &DT); // If the client is only demanding bits that we know, return the known // constant. @@ -262,7 +264,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - Known = LHSKnown ^ RHSKnown; + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, DL, &AC, CxtI, &DT); // If the client is only demanding bits that we know, return the known // constant. @@ -381,7 +384,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; // Only known if known in both the LHS and RHS. - Known = KnownBits::commonBits(LHSKnown, RHSKnown); + Known = LHSKnown.intersectWith(RHSKnown); break; } case Instruction::Trunc: { @@ -393,7 +396,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // The shift amount must be valid (not poison) in the narrow type, and // it must not be greater than the high bits demanded of the result. if (C->ult(VTy->getScalarSizeInBits()) && - C->ule(DemandedMask.countLeadingZeros())) { + C->ule(DemandedMask.countl_zero())) { // trunc (lshr X, C) --> lshr (trunc X), C IRBuilderBase::InsertPointGuard Guard(Builder); Builder.SetInsertPoint(I); @@ -508,7 +511,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Right fill the mask of bits for the operands to demand the most // significant bit and all those below it. - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) @@ -517,7 +520,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If low order bits are not demanded and known to be zero in one operand, // then we don't need to demand them from the other operand, since they // can't cause overflow into any bits that are demanded in the result. - unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); APInt DemandedFromLHS = DemandedFromOps; DemandedFromLHS.clearLowBits(NTZ); if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || @@ -539,7 +542,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, case Instruction::Sub: { // Right fill the mask of bits for the operands to demand the most // significant bit and all those below it. 
- unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) @@ -548,7 +551,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If low order bits are not demanded and are known to be zero in RHS, // then we don't need to demand them from LHS, since they can't cause a // borrow from any bits that are demanded in the result. - unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); APInt DemandedFromLHS = DemandedFromOps; DemandedFromLHS.clearLowBits(NTZ); if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || @@ -578,10 +581,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. // If we demand exactly one bit N and we have "X * (C' << N)" where C' is // odd (has LSB set), then the left-shifted low bit of X is the answer. - unsigned CTZ = DemandedMask.countTrailingZeros(); + unsigned CTZ = DemandedMask.countr_zero(); const APInt *C; - if (match(I->getOperand(1), m_APInt(C)) && - C->countTrailingZeros() == CTZ) { + if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) { Constant *ShiftC = ConstantInt::get(VTy, CTZ); Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); return InsertNewInstWith(Shl, *I); @@ -619,7 +621,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); Value *X; Constant *C; - if (DemandedMask.countTrailingZeros() >= ShiftAmt && + if (DemandedMask.countr_zero() >= ShiftAmt && match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC); @@ -642,29 +644,15 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - bool SignBitZero = Known.Zero.isSignBitSet(); - bool SignBitOne = Known.One.isSignBitSet(); - Known.Zero <<= ShiftAmt; - Known.One <<= ShiftAmt; - // low bits known zero. - if (ShiftAmt) - Known.Zero.setLowBits(ShiftAmt); - - // If this shift has "nsw" keyword, then the result is either a poison - // value or has the same sign bit as the first operand. - if (IOp->hasNoSignedWrap()) { - if (SignBitZero) - Known.Zero.setSignBit(); - else if (SignBitOne) - Known.One.setSignBit(); - if (Known.hasConflict()) - return UndefValue::get(VTy); - } + Known = KnownBits::shl(Known, + KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), + /* NUW */ IOp->hasNoUnsignedWrap(), + /* NSW */ IOp->hasNoSignedWrap()); } else { // This is a variable shift, so we can't shift the demand mask by a known // amount. But if we are not demanding high bits, then we are not // demanding those bits from the pre-shifted operand either. - if (unsigned CTLZ = DemandedMask.countLeadingZeros()) { + if (unsigned CTLZ = DemandedMask.countl_zero()) { APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ)); if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) { // We can't guarantee that nsw/nuw hold after simplifying the operand. 
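The multiply case above keys on a single demanded bit N: when the constant multiplier is C' << N with C' odd, bit N of the product equals bit N of X << N, i.e. the low bit of X. A standalone spot check with arbitrarily chosen N and C' (illustrative only, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned N = 5;               // the single demanded bit
      const uint32_t COdd = 0x01010101u;  // arbitrary odd C'
      const uint32_t C = COdd << N;       // multiplier with exactly N trailing zeros
      for (uint32_t X = 0; X < (1u << 16); ++X) {
        uint32_t ProductBit = (X * C >> N) & 1u; // bit N of X * C
        uint32_t ShlBit = (X << N >> N) & 1u;    // bit N of X << N, i.e. X & 1
        assert(ProductBit == ShlBit);
      }
      return 0;
    }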
@@ -683,11 +671,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If we are just demanding the shifted sign bit and below, then this can // be treated as an ASHR in disguise. - if (DemandedMask.countLeadingZeros() >= ShiftAmt) { + if (DemandedMask.countl_zero() >= ShiftAmt) { // If we only want bits that already match the signbit then we don't // need to shift. - unsigned NumHiDemandedBits = - BitWidth - DemandedMask.countTrailingZeros(); + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); if (SignBits >= NumHiDemandedBits) @@ -734,7 +721,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If we only want bits that already match the signbit then we don't need // to shift. - unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros(); + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); if (SignBits >= NumHiDemandedBits) return I->getOperand(0); @@ -757,7 +744,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); // If any of the high bits are demanded, we should set the sign bit as // demanded. - if (DemandedMask.countLeadingZeros() <= ShiftAmt) + if (DemandedMask.countl_zero() <= ShiftAmt) DemandedMaskIn.setSignBit(); // If the shift is exact, then it does demand the low bits (and knows that @@ -797,7 +784,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { // TODO: Take the demanded mask of the result into account. - unsigned RHSTrailingZeros = SA->countTrailingZeros(); + unsigned RHSTrailingZeros = SA->countr_zero(); APInt DemandedMaskIn = APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) { @@ -807,9 +794,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; } - // Increase high zero bits from the input. - Known.Zero.setHighBits(std::min( - BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); + Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA), + cast<BinaryOperator>(I)->isExact()); } else { computeKnownBits(I, Known, Depth, CxtI); } @@ -851,25 +837,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } } - // The sign bit is the LHS's sign bit, except when the result of the - // remainder is zero. - if (DemandedMask.isSignBitSet()) { - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - // If it's known zero, our sign bit is also zero. 
- if (LHSKnown.isNonNegative()) - Known.makeNonNegative(); - } + computeKnownBits(I, Known, Depth, CxtI); break; } case Instruction::URem: { - KnownBits Known2(BitWidth); APInt AllOnes = APInt::getAllOnes(BitWidth); - if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) || - SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1)) + if (SimplifyDemandedBits(I, 0, AllOnes, LHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 1, AllOnes, RHSKnown, Depth + 1)) return I; - unsigned Leaders = Known2.countMinLeadingZeros(); - Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; + Known = KnownBits::urem(LHSKnown, RHSKnown); break; } case Instruction::Call: { @@ -897,8 +874,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, case Intrinsic::bswap: { // If the only bits demanded come from one byte of the bswap result, // just shift the input byte into position to eliminate the bswap. - unsigned NLZ = DemandedMask.countLeadingZeros(); - unsigned NTZ = DemandedMask.countTrailingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); + unsigned NTZ = DemandedMask.countr_zero(); // Round NTZ down to the next byte. If we have 11 trailing zeros, then // we need all the bits down to bit 8. Likewise, round NLZ. If we @@ -935,9 +912,28 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt)); APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt)); - if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) || - SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) - return I; + if (I->getOperand(0) != I->getOperand(1)) { + if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, + Depth + 1) || + SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) + return I; + } else { // fshl is a rotate + // Avoid converting rotate into funnel shift. + // Only simplify if one operand is constant. + LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I); + if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) && + !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) { + replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One)); + return I; + } + + RHSKnown = computeKnownBits(I->getOperand(1), Depth + 1, I); + if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) && + !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) { + replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One)); + return I; + } + } Known.Zero = LHSKnown.Zero.shl(ShiftAmt) | RHSKnown.Zero.lshr(BitWidth - ShiftAmt); @@ -951,7 +947,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // The lowest non-zero bit of DemandMask is higher than the highest // non-zero bit of C. const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); + unsigned CTZ = DemandedMask.countr_zero(); if (match(II->getArgOperand(1), m_APInt(C)) && CTZ >= C->getActiveBits()) return II->getArgOperand(0); @@ -963,9 +959,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // non-one bit of C. // This comes from using DeMorgans on the above umax example. 
const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); + unsigned CTZ = DemandedMask.countr_zero(); if (match(II->getArgOperand(1), m_APInt(C)) && - CTZ >= C->getBitWidth() - C->countLeadingOnes()) + CTZ >= C->getBitWidth() - C->countl_one()) return II->getArgOperand(0); break; } @@ -1014,6 +1010,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown & RHSKnown; + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1033,6 +1030,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown | RHSKnown; + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1054,6 +1052,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown ^ RHSKnown; + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1071,7 +1070,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Add: { - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); // If an operand adds zeros to every bit below the highest demanded bit, @@ -1084,10 +1083,13 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown); + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::Sub: { - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); // If an operand subtracts zeros from every bit below the highest demanded @@ -1096,6 +1098,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown); + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::AShr: { @@ -1541,7 +1547,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Found constant vector with single element - convert to insertelement. 
if (Op && Value) { Instruction *New = InsertElementInst::Create( - Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx), + Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx), Shuffle->getName()); InsertNewInstWith(New, *Shuffle); return New; @@ -1552,7 +1558,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, SmallVector<int, 16> Elts; for (unsigned i = 0; i < VWidth; ++i) { if (UndefElts[i]) - Elts.push_back(UndefMaskElem); + Elts.push_back(PoisonMaskElem); else Elts.push_back(Shuffle->getMaskValue(i)); } @@ -1653,7 +1659,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // corresponding input elements are undef. for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) { APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio); - if (SubUndef.countPopulation() == Ratio) + if (SubUndef.popcount() == Ratio) UndefElts.setBit(OutIdx); } } else { @@ -1712,6 +1718,54 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // UB/poison potential, but that should be refined. BinaryOperator *BO; if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) { + Value *X = BO->getOperand(0); + Value *Y = BO->getOperand(1); + + // Look for an equivalent binop except that one operand has been shuffled. + // If the demand for this binop only includes elements that are the same as + // the other binop, then we may be able to replace this binop with a use of + // the earlier one. + // + // Example: + // %other_bo = bo (shuf X, {0}), Y + // %this_extracted_bo = extelt (bo X, Y), 0 + // --> + // %other_bo = bo (shuf X, {0}), Y + // %this_extracted_bo = extelt %other_bo, 0 + // + // TODO: Handle demand of an arbitrary single element or more than one + // element instead of just element 0. + // TODO: Unlike general demanded elements transforms, this should be safe + // for any (div/rem/shift) opcode too. + if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() && + BO->hasOneUse() ) { + + auto findShufBO = [&](bool MatchShufAsOp0) -> User * { + // Try to use shuffle-of-operand in place of an operand: + // bo X, Y --> bo (shuf X), Y + // bo X, Y --> bo X, (shuf Y) + BinaryOperator::BinaryOps Opcode = BO->getOpcode(); + Value *ShufOp = MatchShufAsOp0 ? X : Y; + Value *OtherOp = MatchShufAsOp0 ? Y : X; + for (User *U : OtherOp->users()) { + auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_ZeroMask()); + if (BO->isCommutative() + ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp))) + : MatchShufAsOp0 + ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp))) + : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf))) + if (DT.dominates(U, I)) + return U; + } + return nullptr; + }; + + if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true)) + return ShufBO; + if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false)) + return ShufBO; + } + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); simplifyAndSetOp(I, 1, DemandedElts, UndefElts2); @@ -1723,7 +1777,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // If we've proven all of the lanes undef, return an undef value. // TODO: Intersect w/demanded lanes if (UndefElts.isAllOnes()) - return UndefValue::get(I->getType());; + return UndefValue::get(I->getType()); return MadeChange ? 
I : nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 61e62adbe327..4a5ffef2b08e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -171,8 +171,11 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, } } - for (auto *E : Extracts) + for (auto *E : Extracts) { replaceInstUsesWith(*E, scalarPHI); + // Add old extract to worklist for DCE. + addToWorklist(E); + } return &EI; } @@ -384,7 +387,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) { /// return it with the canonical type if it isn't already canonical. We /// arbitrarily pick 64 bit as our canonical type. The actual bitwidth doesn't /// matter, we just want a consistent type to simplify CSE. -ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) { +static ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) { const unsigned IndexBW = IndexC->getType()->getBitWidth(); if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64) return nullptr; @@ -543,16 +546,16 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { ->getNumElements(); if (SrcIdx < 0) - return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); + return replaceInstUsesWith(EI, PoisonValue::get(EI.getType())); if (SrcIdx < (int)LHSWidth) Src = SVI->getOperand(0); else { SrcIdx -= LHSWidth; Src = SVI->getOperand(1); } - Type *Int32Ty = Type::getInt32Ty(EI.getContext()); + Type *Int64Ty = Type::getInt64Ty(EI.getContext()); return ExtractElementInst::Create( - Src, ConstantInt::get(Int32Ty, SrcIdx, false)); + Src, ConstantInt::get(Int64Ty, SrcIdx, false)); } } else if (auto *CI = dyn_cast<CastInst>(I)) { // Canonicalize extractelement(cast) -> cast(extractelement). @@ -594,6 +597,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { SrcVec, DemandedElts, UndefElts, 0 /* Depth */, true /* AllowMultipleUsers */)) { if (V != SrcVec) { + Worklist.addValue(SrcVec); SrcVec->replaceAllUsesWith(V); return &EI; } @@ -640,11 +644,11 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, return false; unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue(); - if (isa<UndefValue>(ScalarOp)) { // inserting undef into vector. + if (isa<PoisonValue>(ScalarOp)) { // inserting poison into vector. // We can handle this if the vector we are inserting into is // transitively ok. if (collectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { - // If so, update the mask to reflect the inserted undef. + // If so, update the mask to reflect the inserted poison. Mask[InsertedIdx] = -1; return true; } @@ -680,7 +684,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, /// If we have insertion into a vector that is wider than the vector that we /// are extracting from, try to widen the source vector to allow a single /// shufflevector to replace one or more insert/extract pairs. -static void replaceExtractElements(InsertElementInst *InsElt, +static bool replaceExtractElements(InsertElementInst *InsElt, ExtractElementInst *ExtElt, InstCombinerImpl &IC) { auto *InsVecType = cast<FixedVectorType>(InsElt->getType()); @@ -691,7 +695,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, // The inserted-to vector must be wider than the extracted-from vector. 
if (InsVecType->getElementType() != ExtVecType->getElementType() || NumExtElts >= NumInsElts) - return; + return false; // Create a shuffle mask to widen the extended-from vector using poison // values. The mask selects all of the values of the original vector followed @@ -719,7 +723,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, // that will delete our widening shuffle. This would trigger another attempt // here to create that shuffle, and we spin forever. if (InsertionBlock != InsElt->getParent()) - return; + return false; // TODO: This restriction matches the check in visitInsertElementInst() and // prevents an infinite loop caused by not turning the extract/insert pair @@ -727,7 +731,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, // folds for shufflevectors because we're afraid to generate shuffle masks // that the backend can't handle. if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) - return; + return false; auto *WideVec = new ShuffleVectorInst(ExtVecOp, ExtendMask); @@ -747,9 +751,14 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (!OldExt || OldExt->getParent() != WideVec->getParent()) continue; auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); - NewExt->insertAfter(OldExt); + IC.InsertNewInstWith(NewExt, *OldExt); IC.replaceInstUsesWith(*OldExt, NewExt); + // Add the old extracts to the worklist for DCE. We can't remove the + // extracts directly, because they may still be used by the calling code. + IC.addToWorklist(OldExt); } + + return true; } /// We are building a shuffle to create V, which is a sequence of insertelement, @@ -764,7 +773,7 @@ using ShuffleOps = std::pair<Value *, Value *>; static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, Value *PermittedRHS, - InstCombinerImpl &IC) { + InstCombinerImpl &IC, bool &Rerun) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements(); @@ -795,13 +804,14 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, // otherwise we'd end up with a shuffle of three inputs. if (EI->getOperand(0) == PermittedRHS || PermittedRHS == nullptr) { Value *RHS = EI->getOperand(0); - ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC); + ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC, Rerun); assert(LR.second == nullptr || LR.second == RHS); if (LR.first->getType() != RHS->getType()) { // Although we are giving up for now, see if we can create extracts // that match the inserts for another round of combining. - replaceExtractElements(IEI, EI, IC); + if (replaceExtractElements(IEI, EI, IC)) + Rerun = true; // We tried our best, but we can't find anything compatible with RHS // further up the chain. Return a trivial shuffle. 
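replaceExtractElements, which now reports through its bool result whether it rewrote any extracts (so collectShuffleElements can set Rerun and the caller can retry the match), relies on the invariant that widening the extracted-from vector with extra don't-care lanes never changes an in-range extract. A tiny scalar model of that invariant (illustrative only, not the LLVM code):

    #include <array>
    #include <cassert>

    int main() {
      std::array<int, 2> Narrow{10, 20};
      // Widened form: original lanes first, then padding lanes whose value
      // does not matter (poison in IR; -1 is just a placeholder here).
      std::array<int, 4> Wide{Narrow[0], Narrow[1], -1, -1};
      // Any in-range extract can be redirected from Narrow to Wide without
      // changing its value, which is what lets the insert/extract chain be
      // rewritten as a single shuffle later.
      for (unsigned I = 0; I < Narrow.size(); ++I)
        assert(Narrow[I] == Wide[I]);
      return 0;
    }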
@@ -1129,6 +1139,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( /// It should be transformed to: /// %0 = insertvalue { i8, i32 } undef, i8 %y, 0 Instruction *InstCombinerImpl::visitInsertValueInst(InsertValueInst &I) { + if (Value *V = simplifyInsertValueInst( + I.getAggregateOperand(), I.getInsertedValueOperand(), I.getIndices(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + bool IsRedundant = false; ArrayRef<unsigned int> FirstIndices = I.getIndices(); @@ -1235,22 +1250,22 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { if (FirstIE == &InsElt) return nullptr; - // If we are not inserting into an undef vector, make sure we've seen an + // If we are not inserting into a poison vector, make sure we've seen an // insert into every element. // TODO: If the base vector is not undef, it might be better to create a splat // and then a select-shuffle (blend) with the base vector. - if (!match(FirstIE->getOperand(0), m_Undef())) + if (!match(FirstIE->getOperand(0), m_Poison())) if (!ElementPresent.all()) return nullptr; // Create the insert + shuffle. - Type *Int32Ty = Type::getInt32Ty(InsElt.getContext()); + Type *Int64Ty = Type::getInt64Ty(InsElt.getContext()); PoisonValue *PoisonVec = PoisonValue::get(VecTy); - Constant *Zero = ConstantInt::get(Int32Ty, 0); + Constant *Zero = ConstantInt::get(Int64Ty, 0); if (!cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) FirstIE = InsertElementInst::Create(PoisonVec, SplatVal, Zero, "", &InsElt); - // Splat from element 0, but replace absent elements with undef in the mask. + // Splat from element 0, but replace absent elements with poison in the mask. SmallVector<int, 16> Mask(NumElements, 0); for (unsigned i = 0; i != NumElements; ++i) if (!ElementPresent[i]) @@ -1339,7 +1354,7 @@ static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { // (demanded elements analysis may unset it later). return nullptr; } else { - assert(OldMask[i] == UndefMaskElem && + assert(OldMask[i] == PoisonMaskElem && "Unexpected shuffle mask element for identity shuffle"); NewMask[i] = IdxC; } @@ -1465,10 +1480,10 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { } ++ValI; } - // Remaining values are filled with 'undef' values. + // Remaining values are filled with 'poison' values. for (unsigned I = 0; I < NumElts; ++I) { if (!Values[I]) { - Values[I] = UndefValue::get(InsElt.getType()->getElementType()); + Values[I] = PoisonValue::get(InsElt.getType()->getElementType()); Mask[I] = I; } } @@ -1676,16 +1691,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { // Try to form a shuffle from a chain of extract-insert ops. if (isShuffleRootCandidate(IE)) { - SmallVector<int, 16> Mask; - ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this); - - // The proposed shuffle may be trivial, in which case we shouldn't - // perform the combine. - if (LR.first != &IE && LR.second != &IE) { - // We now have a shuffle of LHS, RHS, Mask. - if (LR.second == nullptr) - LR.second = UndefValue::get(LR.first->getType()); - return new ShuffleVectorInst(LR.first, LR.second, Mask); + bool Rerun = true; + while (Rerun) { + Rerun = false; + + SmallVector<int, 16> Mask; + ShuffleOps LR = + collectShuffleElements(&IE, Mask, nullptr, *this, Rerun); + + // The proposed shuffle may be trivial, in which case we shouldn't + // perform the combine. 
+ if (LR.first != &IE && LR.second != &IE) { + // We now have a shuffle of LHS, RHS, Mask. + if (LR.second == nullptr) + LR.second = PoisonValue::get(LR.first->getType()); + return new ShuffleVectorInst(LR.first, LR.second, Mask); + } } } } @@ -1815,9 +1836,9 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, /// Rebuild a new instruction just like 'I' but with the new operands given. /// In the event of type mismatch, the type of the operands is correct. -static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { - // We don't want to use the IRBuilder here because we want the replacement - // instructions to appear next to 'I', not the builder's insertion point. +static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps, + IRBuilderBase &Builder) { + Builder.SetInsertPoint(I); switch (I->getOpcode()) { case Instruction::Add: case Instruction::FAdd: @@ -1839,28 +1860,29 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { case Instruction::Xor: { BinaryOperator *BO = cast<BinaryOperator>(I); assert(NewOps.size() == 2 && "binary operator with #ops != 2"); - BinaryOperator *New = - BinaryOperator::Create(cast<BinaryOperator>(I)->getOpcode(), - NewOps[0], NewOps[1], "", BO); - if (isa<OverflowingBinaryOperator>(BO)) { - New->setHasNoUnsignedWrap(BO->hasNoUnsignedWrap()); - New->setHasNoSignedWrap(BO->hasNoSignedWrap()); - } - if (isa<PossiblyExactOperator>(BO)) { - New->setIsExact(BO->isExact()); + Value *New = Builder.CreateBinOp(cast<BinaryOperator>(I)->getOpcode(), + NewOps[0], NewOps[1]); + if (auto *NewI = dyn_cast<Instruction>(New)) { + if (isa<OverflowingBinaryOperator>(BO)) { + NewI->setHasNoUnsignedWrap(BO->hasNoUnsignedWrap()); + NewI->setHasNoSignedWrap(BO->hasNoSignedWrap()); + } + if (isa<PossiblyExactOperator>(BO)) { + NewI->setIsExact(BO->isExact()); + } + if (isa<FPMathOperator>(BO)) + NewI->copyFastMathFlags(I); } - if (isa<FPMathOperator>(BO)) - New->copyFastMathFlags(I); return New; } case Instruction::ICmp: assert(NewOps.size() == 2 && "icmp with #ops != 2"); - return new ICmpInst(I, cast<ICmpInst>(I)->getPredicate(), - NewOps[0], NewOps[1]); + return Builder.CreateICmp(cast<ICmpInst>(I)->getPredicate(), NewOps[0], + NewOps[1]); case Instruction::FCmp: assert(NewOps.size() == 2 && "fcmp with #ops != 2"); - return new FCmpInst(I, cast<FCmpInst>(I)->getPredicate(), - NewOps[0], NewOps[1]); + return Builder.CreateFCmp(cast<FCmpInst>(I)->getPredicate(), NewOps[0], + NewOps[1]); case Instruction::Trunc: case Instruction::ZExt: case Instruction::SExt: @@ -1876,27 +1898,26 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { I->getType()->getScalarType(), cast<VectorType>(NewOps[0]->getType())->getElementCount()); assert(NewOps.size() == 1 && "cast with #ops != 1"); - return CastInst::Create(cast<CastInst>(I)->getOpcode(), NewOps[0], DestTy, - "", I); + return Builder.CreateCast(cast<CastInst>(I)->getOpcode(), NewOps[0], + DestTy); } case Instruction::GetElementPtr: { Value *Ptr = NewOps[0]; ArrayRef<Value*> Idx = NewOps.slice(1); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - cast<GetElementPtrInst>(I)->getSourceElementType(), Ptr, Idx, "", I); - GEP->setIsInBounds(cast<GetElementPtrInst>(I)->isInBounds()); - return GEP; + return Builder.CreateGEP(cast<GEPOperator>(I)->getSourceElementType(), + Ptr, Idx, "", + cast<GEPOperator>(I)->isInBounds()); } } llvm_unreachable("failed to rebuild vector instructions"); } -static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { +static Value 
*evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask, + IRBuilderBase &Builder) { // Mask.size() does not need to be equal to the number of vector elements. assert(V->getType()->isVectorTy() && "can't reorder non-vector elements"); Type *EltTy = V->getType()->getScalarType(); - Type *I32Ty = IntegerType::getInt32Ty(V->getContext()); if (match(V, m_Undef())) return UndefValue::get(FixedVectorType::get(EltTy, Mask.size())); @@ -1950,15 +1971,14 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // as well. E.g. GetElementPtr may have scalar operands even if the // return value is a vector, so we need to examine the operand type. if (I->getOperand(i)->getType()->isVectorTy()) - V = evaluateInDifferentElementOrder(I->getOperand(i), Mask); + V = evaluateInDifferentElementOrder(I->getOperand(i), Mask, Builder); else V = I->getOperand(i); NewOps.push_back(V); NeedsRebuild |= (V != I->getOperand(i)); } - if (NeedsRebuild) { - return buildNew(I, NewOps); - } + if (NeedsRebuild) + return buildNew(I, NewOps, Builder); return I; } case Instruction::InsertElement: { @@ -1979,11 +1999,12 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // If element is not in Mask, no need to handle the operand 1 (element to // be inserted). Just evaluate values in operand 0 according to Mask. if (!Found) - return evaluateInDifferentElementOrder(I->getOperand(0), Mask); + return evaluateInDifferentElementOrder(I->getOperand(0), Mask, Builder); - Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask); - return InsertElementInst::Create(V, I->getOperand(1), - ConstantInt::get(I32Ty, Index), "", I); + Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask, + Builder); + Builder.SetInsertPoint(I); + return Builder.CreateInsertElement(V, I->getOperand(1), Index); } } llvm_unreachable("failed to reorder elements of vector instruction!"); @@ -2140,7 +2161,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { ConstantExpr::getShuffleVector(IdC, C, Mask); bool MightCreatePoisonOrUB = - is_contained(Mask, UndefMaskElem) && + is_contained(Mask, PoisonMaskElem) && (Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode)); if (MightCreatePoisonOrUB) NewC = InstCombiner::getSafeVectorConstantForBinop(BOpcode, NewC, true); @@ -2154,7 +2175,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { // An undef shuffle mask element may propagate as an undef constant element in // the new binop. That would produce poison where the original code might not. // If we already made a safe constant, then there's no danger. - if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) + if (is_contained(Mask, PoisonMaskElem) && !MightCreatePoisonOrUB) NewBO->dropPoisonGeneratingFlags(); return NewBO; } @@ -2178,8 +2199,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, // Insert into element 0 of an undef vector. UndefValue *UndefVec = UndefValue::get(Shuf.getType()); - Constant *Zero = Builder.getInt32(0); - Value *NewIns = Builder.CreateInsertElement(UndefVec, X, Zero); + Value *NewIns = Builder.CreateInsertElement(UndefVec, X, (uint64_t)0); // Splat from element 0. Any mask element that is undefined remains undefined. 
// For example: @@ -2189,7 +2209,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts, 0); for (unsigned i = 0; i != NumMaskElts; ++i) - if (Mask[i] == UndefMaskElem) + if (Mask[i] == PoisonMaskElem) NewMask[i] = Mask[i]; return new ShuffleVectorInst(NewIns, NewMask); @@ -2274,7 +2294,7 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { // mask element, the result is undefined, but it is not poison or undefined // behavior. That is not necessarily true for div/rem/shift. bool MightCreatePoisonOrUB = - is_contained(Mask, UndefMaskElem) && + is_contained(Mask, PoisonMaskElem) && (Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc)); if (MightCreatePoisonOrUB) NewC = InstCombiner::getSafeVectorConstantForBinop(BOpc, NewC, @@ -2325,7 +2345,7 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { NewI->andIRFlags(B1); if (DropNSW) NewI->setHasNoSignedWrap(false); - if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) + if (is_contained(Mask, PoisonMaskElem) && !MightCreatePoisonOrUB) NewI->dropPoisonGeneratingFlags(); } return replaceInstUsesWith(Shuf, NewBO); @@ -2361,7 +2381,7 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, SrcType->getScalarSizeInBits() / DestType->getScalarSizeInBits(); ArrayRef<int> Mask = Shuf.getShuffleMask(); for (unsigned i = 0, e = Mask.size(); i != e; ++i) { - if (Mask[i] == UndefMaskElem) + if (Mask[i] == PoisonMaskElem) continue; uint64_t LSBIndex = IsBigEndian ? (i + 1) * TruncRatio - 1 : i * TruncRatio; assert(LSBIndex <= INT32_MAX && "Overflowed 32-bits"); @@ -2407,37 +2427,51 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, return SelectInst::Create(NarrowCond, NarrowX, NarrowY); } -/// Canonicalize FP negate after shuffle. -static Instruction *foldFNegShuffle(ShuffleVectorInst &Shuf, - InstCombiner::BuilderTy &Builder) { - Instruction *FNeg0; +/// Canonicalize FP negate/abs after shuffle. +static Instruction *foldShuffleOfUnaryOps(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + auto *S0 = dyn_cast<Instruction>(Shuf.getOperand(0)); Value *X; - if (!match(Shuf.getOperand(0), m_CombineAnd(m_Instruction(FNeg0), - m_FNeg(m_Value(X))))) + if (!S0 || !match(S0, m_CombineOr(m_FNeg(m_Value(X)), m_FAbs(m_Value(X))))) return nullptr; - // shuffle (fneg X), Mask --> fneg (shuffle X, Mask) - if (FNeg0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) { + bool IsFNeg = S0->getOpcode() == Instruction::FNeg; + + // Match 1-input (unary) shuffle. + // shuffle (fneg/fabs X), Mask --> fneg/fabs (shuffle X, Mask) + if (S0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) { Value *NewShuf = Builder.CreateShuffleVector(X, Shuf.getShuffleMask()); - return UnaryOperator::CreateFNegFMF(NewShuf, FNeg0); + if (IsFNeg) + return UnaryOperator::CreateFNegFMF(NewShuf, S0); + + Function *FAbs = Intrinsic::getDeclaration(Shuf.getModule(), + Intrinsic::fabs, Shuf.getType()); + CallInst *NewF = CallInst::Create(FAbs, {NewShuf}); + NewF->setFastMathFlags(S0->getFastMathFlags()); + return NewF; } - Instruction *FNeg1; + // Match 2-input (binary) shuffle. 
+ auto *S1 = dyn_cast<Instruction>(Shuf.getOperand(1)); Value *Y; - if (!match(Shuf.getOperand(1), m_CombineAnd(m_Instruction(FNeg1), - m_FNeg(m_Value(Y))))) + if (!S1 || !match(S1, m_CombineOr(m_FNeg(m_Value(Y)), m_FAbs(m_Value(Y)))) || + S0->getOpcode() != S1->getOpcode() || + (!S0->hasOneUse() && !S1->hasOneUse())) return nullptr; - // shuffle (fneg X), (fneg Y), Mask --> fneg (shuffle X, Y, Mask) - if (FNeg0->hasOneUse() || FNeg1->hasOneUse()) { - Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask()); - Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewShuf); - NewFNeg->copyIRFlags(FNeg0); - NewFNeg->andIRFlags(FNeg1); - return NewFNeg; + // shuf (fneg/fabs X), (fneg/fabs Y), Mask --> fneg/fabs (shuf X, Y, Mask) + Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask()); + Instruction *NewF; + if (IsFNeg) { + NewF = UnaryOperator::CreateFNeg(NewShuf); + } else { + Function *FAbs = Intrinsic::getDeclaration(Shuf.getModule(), + Intrinsic::fabs, Shuf.getType()); + NewF = CallInst::Create(FAbs, {NewShuf}); } - - return nullptr; + NewF->copyIRFlags(S0); + NewF->andIRFlags(S1); + return NewF; } /// Canonicalize casts after shuffle. @@ -2533,7 +2567,7 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { for (unsigned i = 0; i != NumElts; ++i) { int ExtractMaskElt = Shuf.getMaskValue(i); int MaskElt = Mask[i]; - NewMask[i] = ExtractMaskElt == UndefMaskElem ? ExtractMaskElt : MaskElt; + NewMask[i] = ExtractMaskElt == PoisonMaskElem ? ExtractMaskElt : MaskElt; } return new ShuffleVectorInst(X, Y, NewMask); } @@ -2699,7 +2733,8 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { // splatting the first element of the result of the BinOp Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) { if (!match(SVI.getOperand(1), m_Undef()) || - !match(SVI.getShuffleMask(), m_ZeroMask())) + !match(SVI.getShuffleMask(), m_ZeroMask()) || + !SVI.getOperand(0)->hasOneUse()) return nullptr; Value *Op0 = SVI.getOperand(0); @@ -2759,7 +2794,6 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { } ArrayRef<int> Mask = SVI.getShuffleMask(); - Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); // Peek through a bitcasted shuffle operand by scaling the mask. If the // simulated shuffle can simplify, then this shuffle is unnecessary: @@ -2815,7 +2849,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Instruction *I = narrowVectorSelect(SVI, Builder)) return I; - if (Instruction *I = foldFNegShuffle(SVI, Builder)) + if (Instruction *I = foldShuffleOfUnaryOps(SVI, Builder)) return I; if (Instruction *I = foldCastShuffle(SVI, Builder)) @@ -2840,7 +2874,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return I; if (match(RHS, m_Undef()) && canEvaluateShuffled(LHS, Mask)) { - Value *V = evaluateInDifferentElementOrder(LHS, Mask); + Value *V = evaluateInDifferentElementOrder(LHS, Mask, Builder); return replaceInstUsesWith(SVI, V); } @@ -2916,15 +2950,15 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { unsigned SrcElemsPerTgtElem = TgtElemBitWidth / SrcElemBitWidth; assert(SrcElemsPerTgtElem); BegIdx /= SrcElemsPerTgtElem; - bool BCAlreadyExists = NewBCs.find(CastSrcTy) != NewBCs.end(); + bool BCAlreadyExists = NewBCs.contains(CastSrcTy); auto *NewBC = BCAlreadyExists ? 
NewBCs[CastSrcTy] : Builder.CreateBitCast(V, CastSrcTy, SVI.getName() + ".bc"); if (!BCAlreadyExists) NewBCs[CastSrcTy] = NewBC; - auto *Ext = Builder.CreateExtractElement( - NewBC, ConstantInt::get(Int32Ty, BegIdx), SVI.getName() + ".extract"); + auto *Ext = Builder.CreateExtractElement(NewBC, BegIdx, + SVI.getName() + ".extract"); // The shufflevector isn't being replaced: the bitcast that used it // is. InstCombine will visit the newly-created instructions. replaceInstUsesWith(*BC, Ext); @@ -3042,7 +3076,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { for (unsigned i = 0; i < VWidth; ++i) { int eltMask; if (Mask[i] < 0) { - // This element is an undef value. + // This element is a poison value. eltMask = -1; } else if (Mask[i] < (int)LHSWidth) { // This element is from left hand side vector operand. @@ -3051,27 +3085,27 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // new mask value for the element. if (newLHS != LHS) { eltMask = LHSMask[Mask[i]]; - // If the value selected is an undef value, explicitly specify it + // If the value selected is an poison value, explicitly specify it // with a -1 mask value. - if (eltMask >= (int)LHSOp0Width && isa<UndefValue>(LHSOp1)) + if (eltMask >= (int)LHSOp0Width && isa<PoisonValue>(LHSOp1)) eltMask = -1; } else eltMask = Mask[i]; } else { // This element is from right hand side vector operand // - // If the value selected is an undef value, explicitly specify it + // If the value selected is a poison value, explicitly specify it // with a -1 mask value. (case 1) - if (match(RHS, m_Undef())) + if (match(RHS, m_Poison())) eltMask = -1; // If RHS is going to be replaced (case 3 or 4), calculate the // new mask value for the element. else if (newRHS != RHS) { eltMask = RHSMask[Mask[i]-LHSWidth]; - // If the value selected is an undef value, explicitly specify it + // If the value selected is an poison value, explicitly specify it // with a -1 mask value. if (eltMask >= (int)RHSOp0Width) { - assert(match(RHSShuffle->getOperand(1), m_Undef()) && + assert(match(RHSShuffle->getOperand(1), m_Poison()) && "should have been check above"); eltMask = -1; } @@ -3102,7 +3136,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // or is a splat, do the replacement. 
if (isSplat || newMask == LHSMask || newMask == RHSMask || newMask == Mask) { if (!newRHS) - newRHS = UndefValue::get(newLHS->getType()); + newRHS = PoisonValue::get(newLHS->getType()); return new ShuffleVectorInst(newLHS, newRHS, newMask); } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index fb6f4f96ea48..afd6e034f46d 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -33,8 +33,6 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm-c/Initialization.h" -#include "llvm-c/Transforms/InstCombine.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -47,7 +45,6 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyBlockFrequencyInfo.h" @@ -70,6 +67,7 @@ #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" @@ -78,7 +76,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" @@ -117,6 +114,11 @@ using namespace llvm::PatternMatch; STATISTIC(NumWorklistIterations, "Number of instruction combining iterations performed"); +STATISTIC(NumOneIteration, "Number of functions with one iteration"); +STATISTIC(NumTwoIterations, "Number of functions with two iterations"); +STATISTIC(NumThreeIterations, "Number of functions with three iterations"); +STATISTIC(NumFourOrMoreIterations, + "Number of functions with four or more iterations"); STATISTIC(NumCombined , "Number of insts combined"); STATISTIC(NumConstProp, "Number of constant folds"); @@ -129,7 +131,6 @@ DEBUG_COUNTER(VisitCounter, "instcombine-visit", "Controls which instructions are visited"); // FIXME: these limits eventually should be as low as 2. 
-static constexpr unsigned InstCombineDefaultMaxIterations = 1000; #ifndef NDEBUG static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100; #else @@ -144,11 +145,6 @@ static cl::opt<unsigned> MaxSinkNumUsers( "instcombine-max-sink-users", cl::init(32), cl::desc("Maximum number of undroppable users for instruction sinking")); -static cl::opt<unsigned> LimitMaxIterations( - "instcombine-max-iterations", - cl::desc("Limit the maximum number of instruction combining iterations"), - cl::init(InstCombineDefaultMaxIterations)); - static cl::opt<unsigned> InfiniteLoopDetectionThreshold( "instcombine-infinite-loop-threshold", cl::desc("Number of instruction combining iterations considered an " @@ -203,6 +199,10 @@ std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( return std::nullopt; } +bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const { + return TTI.isValidAddrSpaceCast(FromAS, ToAS); +} + Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { return llvm::emitGEPOffset(&Builder, DL, GEP); } @@ -360,13 +360,17 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC) Type *DestTy = C1->getType(); Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy); - Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2); + Constant *FoldedC = + ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, IC.getDataLayout()); + if (!FoldedC) + return false; + IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0)); IC.replaceOperand(*BinOp1, 1, FoldedC); return true; } -// Simplifies IntToPtr/PtrToInt RoundTrip Cast To BitCast. +// Simplifies IntToPtr/PtrToInt RoundTrip Cast. // inttoptr ( ptrtoint (x) ) --> x Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) { auto *IntToPtr = dyn_cast<IntToPtrInst>(Val); @@ -378,10 +382,8 @@ Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) { CastTy->getPointerAddressSpace() == PtrToInt->getSrcTy()->getPointerAddressSpace() && DL.getTypeSizeInBits(PtrToInt->getSrcTy()) == - DL.getTypeSizeInBits(PtrToInt->getDestTy())) { - return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy, - "", PtrToInt); - } + DL.getTypeSizeInBits(PtrToInt->getDestTy())) + return PtrToInt->getOperand(0); } return nullptr; } @@ -732,6 +734,207 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ, return RetVal; } +// (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C)) +// IFF +// 1) the logic_shifts match +// 2) either both binops are binops and one is `and` or +// BinOp1 is `and` +// (logic_shift (inv_logic_shift C1, C), C) == C1 or +// +// -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C) +// +// (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt)) +// IFF +// 1) the logic_shifts match +// 2) BinOp1 == BinOp2 (if BinOp == `add`, then also requires `shl`). +// +// -> (BinOp (logic_shift (BinOp X, Y)), Mask) +Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { + auto IsValidBinOpc = [](unsigned Opc) { + switch (Opc) { + default: + return false; + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + // Skip Sub as we only match constant masks which will canonicalize to use + // add. + return true; + } + }; + + // Check if we can distribute binop arbitrarily. `add` + `lshr` has extra + // constraints. 
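The "completely distributable" case of the fold described in the header comment above can be sanity-checked with ordinary integer arithmetic. A small standalone C++ spot check for BinOp1 == BinOp2 == xor with a matching lshr (this illustrates the underlying bit identity only, it is not LLVM code):

#include <cassert>
#include <cstdint>

int main() {
  // Completely distributable case: BinOp1 == BinOp2 == xor, shift == lshr.
  //   ((X >> C) ^ Z) ^ (Y >> C)  ==>  ((X ^ Y) >> C) ^ Z
  // For add this only holds with shl, which is what IsCompletelyDistributable
  // guards below.
  std::uint32_t X = 0xDEADBEEFu, Y = 0x01234567u, Z = 0x0F0F0F0Fu;
  for (unsigned C = 0; C < 32; ++C) {
    std::uint32_t Before = ((X >> C) ^ Z) ^ (Y >> C);
    std::uint32_t After = ((X ^ Y) >> C) ^ Z;
    assert(Before == After);
  }
}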
+ auto IsCompletelyDistributable = [](unsigned BinOpc1, unsigned BinOpc2, + unsigned ShOpc) { + return (BinOpc1 != Instruction::Add && BinOpc2 != Instruction::Add) || + ShOpc == Instruction::Shl; + }; + + auto GetInvShift = [](unsigned ShOpc) { + return ShOpc == Instruction::LShr ? Instruction::Shl : Instruction::LShr; + }; + + auto CanDistributeBinops = [&](unsigned BinOpc1, unsigned BinOpc2, + unsigned ShOpc, Constant *CMask, + Constant *CShift) { + // If the BinOp1 is `and` we don't need to check the mask. + if (BinOpc1 == Instruction::And) + return true; + + // For all other possible transfers we need complete distributable + // binop/shift (anything but `add` + `lshr`). + if (!IsCompletelyDistributable(BinOpc1, BinOpc2, ShOpc)) + return false; + + // If BinOp2 is `and`, any mask works (this only really helps for non-splat + // vecs, otherwise the mask will be simplified and the following check will + // handle it). + if (BinOpc2 == Instruction::And) + return true; + + // Otherwise, need mask that meets the below requirement. + // (logic_shift (inv_logic_shift Mask, ShAmt), ShAmt) == Mask + return ConstantExpr::get( + ShOpc, ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift), + CShift) == CMask; + }; + + auto MatchBinOp = [&](unsigned ShOpnum) -> Instruction * { + Constant *CMask, *CShift; + Value *X, *Y, *ShiftedX, *Mask, *Shift; + if (!match(I.getOperand(ShOpnum), + m_OneUse(m_LogicalShift(m_Value(Y), m_Value(Shift))))) + return nullptr; + if (!match(I.getOperand(1 - ShOpnum), + m_BinOp(m_Value(ShiftedX), m_Value(Mask)))) + return nullptr; + + if (!match(ShiftedX, + m_OneUse(m_LogicalShift(m_Value(X), m_Specific(Shift))))) + return nullptr; + + // Make sure we are matching instruction shifts and not ConstantExpr + auto *IY = dyn_cast<Instruction>(I.getOperand(ShOpnum)); + auto *IX = dyn_cast<Instruction>(ShiftedX); + if (!IY || !IX) + return nullptr; + + // LHS and RHS need same shift opcode + unsigned ShOpc = IY->getOpcode(); + if (ShOpc != IX->getOpcode()) + return nullptr; + + // Make sure binop is real instruction and not ConstantExpr + auto *BO2 = dyn_cast<Instruction>(I.getOperand(1 - ShOpnum)); + if (!BO2) + return nullptr; + + unsigned BinOpc = BO2->getOpcode(); + // Make sure we have valid binops. + if (!IsValidBinOpc(I.getOpcode()) || !IsValidBinOpc(BinOpc)) + return nullptr; + + // If BinOp1 == BinOp2 and it's bitwise or shl with add, then just + // distribute to drop the shift irrelevant of constants. + if (BinOpc == I.getOpcode() && + IsCompletelyDistributable(I.getOpcode(), BinOpc, ShOpc)) { + Value *NewBinOp2 = Builder.CreateBinOp(I.getOpcode(), X, Y); + Value *NewBinOp1 = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(ShOpc), NewBinOp2, Shift); + return BinaryOperator::Create(I.getOpcode(), NewBinOp1, Mask); + } + + // Otherwise we can only distribute by constant shifting the mask, so + // ensure we have constants. + if (!match(Shift, m_ImmConstant(CShift))) + return nullptr; + if (!match(Mask, m_ImmConstant(CMask))) + return nullptr; + + // Check if we can distribute the binops. 
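When a constant mask is involved, the rewrite pulls the mask inside the shift. A standalone C++ spot check of the rewritten form for one opcode combination (BinOp1 = or, BinOp2 = and, shift = shl), where any mask is acceptable because the shifted operand has zero low bits; this is an illustration of the bit identity, not the LLVM code:

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // Spot-check:  ((X << C) & M) | (Y << C)  ==>  ((X & (M >> C)) | Y) << C
  // The low C bits of (X << C) are zero, so the low C bits of M are irrelevant.
  std::uint32_t Seed = 0x12345678u;
  auto Next = [&Seed]() {
    Seed = Seed * 1664525u + 1013904223u; // simple LCG, for sampling only
    return Seed;
  };
  for (int I = 0; I < 1000; ++I) {
    std::uint32_t X = Next(), Y = Next(), M = Next();
    unsigned C = Next() % 32;
    std::uint32_t Before = ((X << C) & M) | (Y << C);
    std::uint32_t After = ((X & (M >> C)) | Y) << C;
    assert(Before == After);
  }
  std::puts("identity holds on sampled inputs");
}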
+ if (!CanDistributeBinops(I.getOpcode(), BinOpc, ShOpc, CMask, CShift)) + return nullptr; + + Constant *NewCMask = ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift); + Value *NewBinOp2 = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(BinOpc), X, NewCMask); + Value *NewBinOp1 = Builder.CreateBinOp(I.getOpcode(), Y, NewBinOp2); + return BinaryOperator::Create(static_cast<Instruction::BinaryOps>(ShOpc), + NewBinOp1, CShift); + }; + + if (Instruction *R = MatchBinOp(0)) + return R; + return MatchBinOp(1); +} + +// (Binop (zext C), (select C, T, F)) +// -> (select C, (binop 1, T), (binop 0, F)) +// +// (Binop (sext C), (select C, T, F)) +// -> (select C, (binop -1, T), (binop 0, F)) +// +// Attempt to simplify binary operations into a select with folded args, when +// one operand of the binop is a select instruction and the other operand is a +// zext/sext extension, whose value is the select condition. +Instruction * +InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) { + // TODO: this simplification may be extended to any speculatable instruction, + // not just binops, and would possibly be handled better in FoldOpIntoSelect. + Instruction::BinaryOps Opc = I.getOpcode(); + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + Value *A, *CondVal, *TrueVal, *FalseVal; + Value *CastOp; + + auto MatchSelectAndCast = [&](Value *CastOp, Value *SelectOp) { + return match(CastOp, m_ZExtOrSExt(m_Value(A))) && + A->getType()->getScalarSizeInBits() == 1 && + match(SelectOp, m_Select(m_Value(CondVal), m_Value(TrueVal), + m_Value(FalseVal))); + }; + + // Make sure one side of the binop is a select instruction, and the other is a + // zero/sign extension operating on a i1. + if (MatchSelectAndCast(LHS, RHS)) + CastOp = LHS; + else if (MatchSelectAndCast(RHS, LHS)) + CastOp = RHS; + else + return nullptr; + + auto NewFoldedConst = [&](bool IsTrueArm, Value *V) { + bool IsCastOpRHS = (CastOp == RHS); + bool IsZExt = isa<ZExtInst>(CastOp); + Constant *C; + + if (IsTrueArm) { + C = Constant::getNullValue(V->getType()); + } else if (IsZExt) { + unsigned BitWidth = V->getType()->getScalarSizeInBits(); + C = Constant::getIntegerValue(V->getType(), APInt(BitWidth, 1)); + } else { + C = Constant::getAllOnesValue(V->getType()); + } + + return IsCastOpRHS ? Builder.CreateBinOp(Opc, V, C) + : Builder.CreateBinOp(Opc, C, V); + }; + + // If the value used in the zext/sext is the select condition, or the negated + // of the select condition, the binop can be simplified. + if (CondVal == A) + return SelectInst::Create(CondVal, NewFoldedConst(false, TrueVal), + NewFoldedConst(true, FalseVal)); + + if (match(A, m_Not(m_Specific(CondVal)))) + return SelectInst::Create(CondVal, NewFoldedConst(true, TrueVal), + NewFoldedConst(false, FalseVal)); + + return nullptr; +} + Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); @@ -948,6 +1151,7 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, /// Freely adapt every user of V as-if V was changed to !V. /// WARNING: only if canFreelyInvertAllUsersOf() said this can be done. void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) { + assert(!isa<Constant>(I) && "Shouldn't invert users of constant"); for (User *U : make_early_inc_range(I->users())) { if (U == IgnoredUser) continue; // Don't consider this user. 
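Returning to foldBinOpOfSelectAndCastOfSelectCondition above, the documented rewrite can be spot-checked with ordinary integers. A self-contained C++ check using add as the binop (purely illustrative, not the LLVM implementation):

#include <cassert>

int main() {
  for (int CI = 0; CI <= 1; ++CI) {
    bool C = CI;
    int T = 7, F = -3;
    // (add (zext i1 C), (select C, T, F)) == (select C, add(1, T), add(0, F))
    int Before = static_cast<int>(C) + (C ? T : F);
    int After = C ? (1 + T) : (0 + F);
    assert(Before == After);
    // (add (sext i1 C), (select C, T, F)) == (select C, add(-1, T), add(0, F))
    int BeforeS = (C ? -1 : 0) + (C ? T : F);
    int AfterS = C ? (-1 + T) : (0 + F);
    assert(BeforeS == AfterS);
  }
}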
@@ -1033,63 +1237,39 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) { return SelectInst::Create(X, TVal, FVal); } -static Constant *constantFoldOperationIntoSelectOperand( - Instruction &I, SelectInst *SI, Value *SO) { - auto *ConstSO = dyn_cast<Constant>(SO); - if (!ConstSO) - return nullptr; - +static Constant *constantFoldOperationIntoSelectOperand(Instruction &I, + SelectInst *SI, + bool IsTrueArm) { SmallVector<Constant *> ConstOps; for (Value *Op : I.operands()) { - if (Op == SI) - ConstOps.push_back(ConstSO); - else if (auto *C = dyn_cast<Constant>(Op)) - ConstOps.push_back(C); - else - llvm_unreachable("Operands should be select or constant"); - } - return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout()); -} + CmpInst::Predicate Pred; + Constant *C = nullptr; + if (Op == SI) { + C = dyn_cast<Constant>(IsTrueArm ? SI->getTrueValue() + : SI->getFalseValue()); + } else if (match(SI->getCondition(), + m_ICmp(Pred, m_Specific(Op), m_Constant(C))) && + Pred == (IsTrueArm ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + isGuaranteedNotToBeUndefOrPoison(C)) { + // Pass + } else { + C = dyn_cast<Constant>(Op); + } + if (C == nullptr) + return nullptr; -static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, - InstCombiner::BuilderTy &Builder) { - if (auto *Cast = dyn_cast<CastInst>(&I)) - return Builder.CreateCast(Cast->getOpcode(), SO, I.getType()); - - if (auto *II = dyn_cast<IntrinsicInst>(&I)) { - assert(canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())) && - "Expected constant-foldable intrinsic"); - Intrinsic::ID IID = II->getIntrinsicID(); - if (II->arg_size() == 1) - return Builder.CreateUnaryIntrinsic(IID, SO); - - // This works for real binary ops like min/max (where we always expect the - // constant operand to be canonicalized as op1) and unary ops with a bonus - // constant argument like ctlz/cttz. - // TODO: Handle non-commutative binary intrinsics as below for binops. - assert(II->arg_size() == 2 && "Expected binary intrinsic"); - assert(isa<Constant>(II->getArgOperand(1)) && "Expected constant operand"); - return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1)); + ConstOps.push_back(C); } - if (auto *EI = dyn_cast<ExtractElementInst>(&I)) - return Builder.CreateExtractElement(SO, EI->getIndexOperand()); - - assert(I.isBinaryOp() && "Unexpected opcode for select folding"); - - // Figure out if the constant is the left or the right argument. - bool ConstIsRHS = isa<Constant>(I.getOperand(1)); - Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS)); - - Value *Op0 = SO, *Op1 = ConstOperand; - if (!ConstIsRHS) - std::swap(Op0, Op1); + return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout()); +} - Value *NewBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), Op0, - Op1, SO->getName() + ".op"); - if (auto *NewBOI = dyn_cast<Instruction>(NewBO)) - NewBOI->copyIRFlags(&I); - return NewBO; +static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI, + Value *NewOp, InstCombiner &IC) { + Instruction *Clone = I.clone(); + Clone->replaceUsesOfWith(SI, NewOp); + IC.InsertNewInstBefore(Clone, *SI); + return Clone; } Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, @@ -1122,56 +1302,17 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return nullptr; } - // Test if a CmpInst instruction is used exclusively by a select as - // part of a minimum or maximum operation. 
If so, refrain from doing - // any other folding. This helps out other analyses which understand - // non-obfuscated minimum and maximum idioms, such as ScalarEvolution - // and CodeGen. And in this case, at least one of the comparison - // operands has at least one user besides the compare (the select), - // which would often largely negate the benefit of folding anyway. - if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) { - if (CI->hasOneUse()) { - Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); - - // FIXME: This is a hack to avoid infinite looping with min/max patterns. - // We have to ensure that vector constants that only differ with - // undef elements are treated as equivalent. - auto areLooselyEqual = [](Value *A, Value *B) { - if (A == B) - return true; - - // Test for vector constants. - Constant *ConstA, *ConstB; - if (!match(A, m_Constant(ConstA)) || !match(B, m_Constant(ConstB))) - return false; - - // TODO: Deal with FP constants? - if (!A->getType()->isIntOrIntVectorTy() || A->getType() != B->getType()) - return false; - - // Compare for equality including undefs as equal. - auto *Cmp = ConstantExpr::getCompare(ICmpInst::ICMP_EQ, ConstA, ConstB); - const APInt *C; - return match(Cmp, m_APIntAllowUndef(C)) && C->isOne(); - }; - - if ((areLooselyEqual(TV, Op0) && areLooselyEqual(FV, Op1)) || - (areLooselyEqual(FV, Op0) && areLooselyEqual(TV, Op1))) - return nullptr; - } - } - // Make sure that one of the select arms constant folds successfully. - Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, TV); - Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, FV); + Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ true); + Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ false); if (!NewTV && !NewFV) return nullptr; // Create an instruction for the arm that did not fold. if (!NewTV) - NewTV = foldOperationIntoSelectOperand(Op, TV, Builder); + NewTV = foldOperationIntoSelectOperand(Op, SI, TV, *this); if (!NewFV) - NewFV = foldOperationIntoSelectOperand(Op, FV, Builder); + NewFV = foldOperationIntoSelectOperand(Op, SI, FV, *this); return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } @@ -1263,6 +1404,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { PHINode *NewPN = PHINode::Create(I.getType(), PN->getNumIncomingValues()); InsertNewInstBefore(NewPN, *PN); NewPN->takeName(PN); + NewPN->setDebugLoc(PN->getDebugLoc()); // If we are going to have to insert a new computation, do so right before the // predecessor's terminator. @@ -1291,6 +1433,10 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { replaceInstUsesWith(*User, NewPN); eraseInstFromFunction(*User); } + + replaceAllDbgUsesWith(const_cast<PHINode &>(*PN), + const_cast<PHINode &>(*NewPN), + const_cast<PHINode &>(*PN), DT); return replaceInstUsesWith(I, NewPN); } @@ -1301,7 +1447,7 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) { auto *Phi0 = dyn_cast<PHINode>(BO.getOperand(0)); auto *Phi1 = dyn_cast<PHINode>(BO.getOperand(1)); if (!Phi0 || !Phi1 || !Phi0->hasOneUse() || !Phi1->hasOneUse() || - Phi0->getNumOperands() != 2 || Phi1->getNumOperands() != 2) + Phi0->getNumOperands() != Phi1->getNumOperands()) return nullptr; // TODO: Remove the restriction for binop being in the same block as the phis. 
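The new constantFoldOperationIntoSelectOperand above may also substitute the compared constant for a non-select operand in the arm chosen by an equality condition. A small standalone C++ illustration of why that arm-local substitution is sound (add as the operation; the undef/poison check on the constant is not modeled here):

#include <cassert>

int main() {
  for (int X : {3, 10, 42}) {
    int A = 5, B = 9;
    // add (select (icmp eq X, 10), A, B), X
    int Before = (X == 10 ? A : B) + X;
    // In the true arm X is known to be 10, so that arm folds to A + 10;
    // the false arm keeps the original operand.
    int After = X == 10 ? (A + 10) : (B + X);
    assert(Before == After);
  }
}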
@@ -1309,6 +1455,51 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) { BO.getParent() != Phi1->getParent()) return nullptr; + // Fold if there is at least one specific constant value in phi0 or phi1's + // incoming values that comes from the same block and this specific constant + // value can be used to do optimization for specific binary operator. + // For example: + // %phi0 = phi i32 [0, %bb0], [%i, %bb1] + // %phi1 = phi i32 [%j, %bb0], [0, %bb1] + // %add = add i32 %phi0, %phi1 + // ==> + // %add = phi i32 [%j, %bb0], [%i, %bb1] + Constant *C = ConstantExpr::getBinOpIdentity(BO.getOpcode(), BO.getType(), + /*AllowRHSConstant*/ false); + if (C) { + SmallVector<Value *, 4> NewIncomingValues; + auto CanFoldIncomingValuePair = [&](std::tuple<Use &, Use &> T) { + auto &Phi0Use = std::get<0>(T); + auto &Phi1Use = std::get<1>(T); + if (Phi0->getIncomingBlock(Phi0Use) != Phi1->getIncomingBlock(Phi1Use)) + return false; + Value *Phi0UseV = Phi0Use.get(); + Value *Phi1UseV = Phi1Use.get(); + if (Phi0UseV == C) + NewIncomingValues.push_back(Phi1UseV); + else if (Phi1UseV == C) + NewIncomingValues.push_back(Phi0UseV); + else + return false; + return true; + }; + + if (all_of(zip(Phi0->operands(), Phi1->operands()), + CanFoldIncomingValuePair)) { + PHINode *NewPhi = + PHINode::Create(Phi0->getType(), Phi0->getNumOperands()); + assert(NewIncomingValues.size() == Phi0->getNumOperands() && + "The number of collected incoming values should equal the number " + "of the original PHINode operands!"); + for (unsigned I = 0; I < Phi0->getNumOperands(); I++) + NewPhi->addIncoming(NewIncomingValues[I], Phi0->getIncomingBlock(I)); + return NewPhi; + } + } + + if (Phi0->getNumOperands() != 2 || Phi1->getNumOperands() != 2) + return nullptr; + // Match a pair of incoming constants for one of the predecessor blocks. BasicBlock *ConstBB, *OtherBB; Constant *C0, *C1; @@ -1374,28 +1565,6 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { return nullptr; } -/// Given a pointer type and a constant offset, determine whether or not there -/// is a sequence of GEP indices into the pointed type that will land us at the -/// specified offset. If so, fill them into NewIndices and return the resultant -/// element type, otherwise return null. -static Type *findElementAtOffset(PointerType *PtrTy, int64_t IntOffset, - SmallVectorImpl<Value *> &NewIndices, - const DataLayout &DL) { - // Only used by visitGEPOfBitcast(), which is skipped for opaque pointers. - Type *Ty = PtrTy->getNonOpaquePointerElementType(); - if (!Ty->isSized()) - return nullptr; - - APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), IntOffset); - SmallVector<APInt> Indices = DL.getGEPIndicesForOffset(Ty, Offset); - if (!Offset.isZero()) - return nullptr; - - for (const APInt &Index : Indices) - NewIndices.push_back(ConstantInt::get(PtrTy->getContext(), Index)); - return Ty; -} - static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) { // If this GEP has only 0 indices, it is the same pointer as // Src. If Src is not a trivial GEP too, don't combine @@ -1406,248 +1575,6 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) { return true; } -/// Return a value X such that Val = X * Scale, or null if none. -/// If the multiplication is known not to overflow, then NoSignedWrap is set. 
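The identity-constant phi fold added above (the %phi0/%phi1 example in the comment) can be sketched with plain containers: if every predecessor contributes the identity value through one of the two phis, the binop collapses to a single phi of the other values. A hypothetical standalone C++ helper using 0 as the identity for add; this is not the LLVM implementation and assumes both phis have the same number of predecessors:

#include <cassert>
#include <optional>
#include <vector>

static std::optional<std::vector<int>>
foldAddOfPhis(const std::vector<int> &Phi0, const std::vector<int> &Phi1) {
  std::vector<int> Merged;
  for (size_t I = 0; I != Phi0.size(); ++I) {
    if (Phi0[I] == 0)
      Merged.push_back(Phi1[I]); // this block contributes the identity in phi0
    else if (Phi1[I] == 0)
      Merged.push_back(Phi0[I]); // this block contributes the identity in phi1
    else
      return std::nullopt; // no identity in this predecessor; give up
  }
  return Merged;
}

int main() {
  // %phi0 = phi i32 [0, %bb0], [%i, %bb1]
  // %phi1 = phi i32 [%j, %bb0], [0, %bb1]
  // %add  = add i32 %phi0, %phi1   ==>   phi i32 [%j, %bb0], [%i, %bb1]
  auto Folded = foldAddOfPhis({0, 5}, {7, 0}); // %i = 5, %j = 7
  assert(Folded && (*Folded == std::vector<int>{7, 5}));
}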
-Value *InstCombinerImpl::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { - assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!"); - assert(cast<IntegerType>(Val->getType())->getBitWidth() == - Scale.getBitWidth() && "Scale not compatible with value!"); - - // If Val is zero or Scale is one then Val = Val * Scale. - if (match(Val, m_Zero()) || Scale == 1) { - NoSignedWrap = true; - return Val; - } - - // If Scale is zero then it does not divide Val. - if (Scale.isMinValue()) - return nullptr; - - // Look through chains of multiplications, searching for a constant that is - // divisible by Scale. For example, descaling X*(Y*(Z*4)) by a factor of 4 - // will find the constant factor 4 and produce X*(Y*Z). Descaling X*(Y*8) by - // a factor of 4 will produce X*(Y*2). The principle of operation is to bore - // down from Val: - // - // Val = M1 * X || Analysis starts here and works down - // M1 = M2 * Y || Doesn't descend into terms with more - // M2 = Z * 4 \/ than one use - // - // Then to modify a term at the bottom: - // - // Val = M1 * X - // M1 = Z * Y || Replaced M2 with Z - // - // Then to work back up correcting nsw flags. - - // Op - the term we are currently analyzing. Starts at Val then drills down. - // Replaced with its descaled value before exiting from the drill down loop. - Value *Op = Val; - - // Parent - initially null, but after drilling down notes where Op came from. - // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the - // 0'th operand of Val. - std::pair<Instruction *, unsigned> Parent; - - // Set if the transform requires a descaling at deeper levels that doesn't - // overflow. - bool RequireNoSignedWrap = false; - - // Log base 2 of the scale. Negative if not a power of 2. - int32_t logScale = Scale.exactLogBase2(); - - for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { - // If Op is a constant divisible by Scale then descale to the quotient. - APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth. - APInt::sdivrem(CI->getValue(), Scale, Quotient, Remainder); - if (!Remainder.isMinValue()) - // Not divisible by Scale. - return nullptr; - // Replace with the quotient in the parent. - Op = ConstantInt::get(CI->getType(), Quotient); - NoSignedWrap = true; - break; - } - - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) { - if (BO->getOpcode() == Instruction::Mul) { - // Multiplication. - NoSignedWrap = BO->hasNoSignedWrap(); - if (RequireNoSignedWrap && !NoSignedWrap) - return nullptr; - - // There are three cases for multiplication: multiplication by exactly - // the scale, multiplication by a constant different to the scale, and - // multiplication by something else. - Value *LHS = BO->getOperand(0); - Value *RHS = BO->getOperand(1); - - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Multiplication by a constant. - if (CI->getValue() == Scale) { - // Multiplication by exactly the scale, replace the multiplication - // by its left-hand side in the parent. - Op = LHS; - break; - } - - // Otherwise drill down into the constant. - if (!Op->hasOneUse()) - return nullptr; - - Parent = std::make_pair(BO, 1); - continue; - } - - // Multiplication by something else. Drill down into the left-hand side - // since that's where the reassociate pass puts the good stuff. 
- if (!Op->hasOneUse()) - return nullptr; - - Parent = std::make_pair(BO, 0); - continue; - } - - if (logScale > 0 && BO->getOpcode() == Instruction::Shl && - isa<ConstantInt>(BO->getOperand(1))) { - // Multiplication by a power of 2. - NoSignedWrap = BO->hasNoSignedWrap(); - if (RequireNoSignedWrap && !NoSignedWrap) - return nullptr; - - Value *LHS = BO->getOperand(0); - int32_t Amt = cast<ConstantInt>(BO->getOperand(1))-> - getLimitedValue(Scale.getBitWidth()); - // Op = LHS << Amt. - - if (Amt == logScale) { - // Multiplication by exactly the scale, replace the multiplication - // by its left-hand side in the parent. - Op = LHS; - break; - } - if (Amt < logScale || !Op->hasOneUse()) - return nullptr; - - // Multiplication by more than the scale. Reduce the multiplying amount - // by the scale in the parent. - Parent = std::make_pair(BO, 1); - Op = ConstantInt::get(BO->getType(), Amt - logScale); - break; - } - } - - if (!Op->hasOneUse()) - return nullptr; - - if (CastInst *Cast = dyn_cast<CastInst>(Op)) { - if (Cast->getOpcode() == Instruction::SExt) { - // Op is sign-extended from a smaller type, descale in the smaller type. - unsigned SmallSize = Cast->getSrcTy()->getPrimitiveSizeInBits(); - APInt SmallScale = Scale.trunc(SmallSize); - // Suppose Op = sext X, and we descale X as Y * SmallScale. We want to - // descale Op as (sext Y) * Scale. In order to have - // sext (Y * SmallScale) = (sext Y) * Scale - // some conditions need to hold however: SmallScale must sign-extend to - // Scale and the multiplication Y * SmallScale should not overflow. - if (SmallScale.sext(Scale.getBitWidth()) != Scale) - // SmallScale does not sign-extend to Scale. - return nullptr; - assert(SmallScale.exactLogBase2() == logScale); - // Require that Y * SmallScale must not overflow. - RequireNoSignedWrap = true; - - // Drill down through the cast. - Parent = std::make_pair(Cast, 0); - Scale = SmallScale; - continue; - } - - if (Cast->getOpcode() == Instruction::Trunc) { - // Op is truncated from a larger type, descale in the larger type. - // Suppose Op = trunc X, and we descale X as Y * sext Scale. Then - // trunc (Y * sext Scale) = (trunc Y) * Scale - // always holds. However (trunc Y) * Scale may overflow even if - // trunc (Y * sext Scale) does not, so nsw flags need to be cleared - // from this point up in the expression (see later). - if (RequireNoSignedWrap) - return nullptr; - - // Drill down through the cast. - unsigned LargeSize = Cast->getSrcTy()->getPrimitiveSizeInBits(); - Parent = std::make_pair(Cast, 0); - Scale = Scale.sext(LargeSize); - if (logScale + 1 == (int32_t)Cast->getType()->getPrimitiveSizeInBits()) - logScale = -1; - assert(Scale.exactLogBase2() == logScale); - continue; - } - } - - // Unsupported expression, bail out. - return nullptr; - } - - // If Op is zero then Val = Op * Scale. - if (match(Op, m_Zero())) { - NoSignedWrap = true; - return Op; - } - - // We know that we can successfully descale, so from here on we can safely - // modify the IR. Op holds the descaled version of the deepest term in the - // expression. NoSignedWrap is 'true' if multiplying Op by Scale is known - // not to overflow. - - if (!Parent.first) - // The expression only had one term. - return Op; - - // Rewrite the parent using the descaled version of its operand. 
- assert(Parent.first->hasOneUse() && "Drilled down when more than one use!"); - assert(Op != Parent.first->getOperand(Parent.second) && - "Descaling was a no-op?"); - replaceOperand(*Parent.first, Parent.second, Op); - Worklist.push(Parent.first); - - // Now work back up the expression correcting nsw flags. The logic is based - // on the following observation: if X * Y is known not to overflow as a signed - // multiplication, and Y is replaced by a value Z with smaller absolute value, - // then X * Z will not overflow as a signed multiplication either. As we work - // our way up, having NoSignedWrap 'true' means that the descaled value at the - // current level has strictly smaller absolute value than the original. - Instruction *Ancestor = Parent.first; - do { - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Ancestor)) { - // If the multiplication wasn't nsw then we can't say anything about the - // value of the descaled multiplication, and we have to clear nsw flags - // from this point on up. - bool OpNoSignedWrap = BO->hasNoSignedWrap(); - NoSignedWrap &= OpNoSignedWrap; - if (NoSignedWrap != OpNoSignedWrap) { - BO->setHasNoSignedWrap(NoSignedWrap); - Worklist.push(Ancestor); - } - } else if (Ancestor->getOpcode() == Instruction::Trunc) { - // The fact that the descaled input to the trunc has smaller absolute - // value than the original input doesn't tell us anything useful about - // the absolute values of the truncations. - NoSignedWrap = false; - } - assert((Ancestor->getOpcode() != Instruction::SExt || NoSignedWrap) && - "Failed to keep proper track of nsw flags while drilling down?"); - - if (Ancestor == Val) - // Got to the top, all done! - return Val; - - // Move up one level in the expression. - assert(Ancestor->hasOneUse() && "Drilled down when more than one use!"); - Ancestor = Ancestor->user_back(); - } while (true); -} - Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { if (!isa<VectorType>(Inst.getType())) return nullptr; @@ -1748,9 +1675,9 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { // TODO: Allow arbitrary shuffles by shuffling after binop? // That might be legal, but we have to deal with poison. if (LShuf->isSelect() && - !is_contained(LShuf->getShuffleMask(), UndefMaskElem) && + !is_contained(LShuf->getShuffleMask(), PoisonMaskElem) && RShuf->isSelect() && - !is_contained(RShuf->getShuffleMask(), UndefMaskElem)) { + !is_contained(RShuf->getShuffleMask(), PoisonMaskElem)) { // Example: // LHS = shuffle V1, V2, <0, 5, 6, 3> // RHS = shuffle V2, V1, <0, 5, 6, 3> @@ -1991,50 +1918,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) return nullptr; - if (Src->getResultElementType() == GEP.getSourceElementType() && - Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 && - Src->hasOneUse()) { - Value *GO1 = GEP.getOperand(1); - Value *SO1 = Src->getOperand(1); - - if (LI) { - // Try to reassociate loop invariant GEP chains to enable LICM. - if (Loop *L = LI->getLoopFor(GEP.getParent())) { - // Reassociate the two GEPs if SO1 is variant in the loop and GO1 is - // invariant: this breaks the dependence between GEPs and allows LICM - // to hoist the invariant part out of the loop. - if (L->isLoopInvariant(GO1) && !L->isLoopInvariant(SO1)) { - // The swapped GEPs are inbounds if both original GEPs are inbounds - // and the sign of the offsets is the same. For simplicity, only - // handle both offsets being non-negative. 
- bool IsInBounds = Src->isInBounds() && GEP.isInBounds() && - isKnownNonNegative(SO1, DL, 0, &AC, &GEP, &DT) && - isKnownNonNegative(GO1, DL, 0, &AC, &GEP, &DT); - // Put NewSrc at same location as %src. - Builder.SetInsertPoint(cast<Instruction>(Src)); - Value *NewSrc = Builder.CreateGEP(GEP.getSourceElementType(), - Src->getPointerOperand(), GO1, - Src->getName(), IsInBounds); - GetElementPtrInst *NewGEP = GetElementPtrInst::Create( - GEP.getSourceElementType(), NewSrc, {SO1}); - NewGEP->setIsInBounds(IsInBounds); - return NewGEP; - } - } - } - } - - // Note that if our source is a gep chain itself then we wait for that - // chain to be resolved before we perform this transformation. This - // avoids us creating a TON of code in some cases. - if (auto *SrcGEP = dyn_cast<GEPOperator>(Src->getOperand(0))) - if (SrcGEP->getNumOperands() == 2 && shouldMergeGEPs(*Src, *SrcGEP)) - return nullptr; // Wait until our source is folded to completion. - // For constant GEPs, use a more general offset-based folding approach. - // Only do this for opaque pointers, as the result element type may change. Type *PtrTy = Src->getType()->getScalarType(); - if (PtrTy->isOpaquePointerTy() && GEP.hasAllConstantIndices() && + if (GEP.hasAllConstantIndices() && (Src->hasOneUse() || Src->hasAllConstantIndices())) { // Split Src into a variable part and a constant suffix. gep_type_iterator GTI = gep_type_begin(*Src); @@ -2077,13 +1963,11 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, // If both GEP are constant-indexed, and cannot be merged in either way, // convert them to a GEP of i8. if (Src->hasAllConstantIndices()) - return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) - ? GetElementPtrInst::CreateInBounds( - Builder.getInt8Ty(), Src->getOperand(0), - Builder.getInt(OffsetOld), GEP.getName()) - : GetElementPtrInst::Create( - Builder.getInt8Ty(), Src->getOperand(0), - Builder.getInt(OffsetOld), GEP.getName()); + return replaceInstUsesWith( + GEP, Builder.CreateGEP( + Builder.getInt8Ty(), Src->getOperand(0), + Builder.getInt(OffsetOld), "", + isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)))); return nullptr; } @@ -2100,13 +1984,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, IsInBounds &= Idx.isNonNegative() == ConstIndices[0].isNonNegative(); } - return IsInBounds - ? GetElementPtrInst::CreateInBounds(Src->getSourceElementType(), - Src->getOperand(0), Indices, - GEP.getName()) - : GetElementPtrInst::Create(Src->getSourceElementType(), - Src->getOperand(0), Indices, - GEP.getName()); + return replaceInstUsesWith( + GEP, Builder.CreateGEP(Src->getSourceElementType(), Src->getOperand(0), + Indices, "", IsInBounds)); } if (Src->getResultElementType() != GEP.getSourceElementType()) @@ -2160,118 +2040,10 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, } if (!Indices.empty()) - return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) - ? GetElementPtrInst::CreateInBounds( - Src->getSourceElementType(), Src->getOperand(0), Indices, - GEP.getName()) - : GetElementPtrInst::Create(Src->getSourceElementType(), - Src->getOperand(0), Indices, - GEP.getName()); - - return nullptr; -} - -// Note that we may have also stripped an address space cast in between. -Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI, - GetElementPtrInst &GEP) { - // With opaque pointers, there is no pointer element type we can use to - // adjust the GEP type. 
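The offset-based folding above merges constant-index GEPs by accumulating a single byte offset from the base pointer. A minimal standalone C++ analogy using offsetof on a hypothetical struct (illustrative only; the real fold works on IR types and the DataLayout):

#include <cassert>
#include <cstddef>

struct Inner { int A; int B; };
struct Outer { long L; Inner In[4]; };

int main() {
  Outer O{};
  // Chained constant-index addressing accumulates to one byte offset from &O,
  // which is what folding constant-index GEPs into a single i8 GEP expresses.
  char *ViaTwoSteps = reinterpret_cast<char *>(&O.In[2].B);
  char *ViaOneOffset = reinterpret_cast<char *>(&O) + offsetof(Outer, In) +
                       2 * sizeof(Inner) + offsetof(Inner, B);
  assert(ViaTwoSteps == ViaOneOffset);
}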
- PointerType *SrcType = cast<PointerType>(BCI->getSrcTy()); - if (SrcType->isOpaque()) - return nullptr; - - Type *GEPEltType = GEP.getSourceElementType(); - Type *SrcEltType = SrcType->getNonOpaquePointerElementType(); - Value *SrcOp = BCI->getOperand(0); - - // GEP directly using the source operand if this GEP is accessing an element - // of a bitcasted pointer to vector or array of the same dimensions: - // gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z - // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z - auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy, - const DataLayout &DL) { - auto *VecVTy = cast<FixedVectorType>(VecTy); - return ArrTy->getArrayElementType() == VecVTy->getElementType() && - ArrTy->getArrayNumElements() == VecVTy->getNumElements() && - DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy); - }; - if (GEP.getNumOperands() == 3 && - ((GEPEltType->isArrayTy() && isa<FixedVectorType>(SrcEltType) && - areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) || - (isa<FixedVectorType>(GEPEltType) && SrcEltType->isArrayTy() && - areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) { - - // Create a new GEP here, as using `setOperand()` followed by - // `setSourceElementType()` won't actually update the type of the - // existing GEP Value. Causing issues if this Value is accessed when - // constructing an AddrSpaceCastInst - SmallVector<Value *, 8> Indices(GEP.indices()); - Value *NGEP = - Builder.CreateGEP(SrcEltType, SrcOp, Indices, "", GEP.isInBounds()); - NGEP->takeName(&GEP); - - // Preserve GEP address space to satisfy users - if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(NGEP, GEP.getType()); - - return replaceInstUsesWith(GEP, NGEP); - } - - // See if we can simplify: - // X = bitcast A* to B* - // Y = gep X, <...constant indices...> - // into a gep of the original struct. This is important for SROA and alias - // analysis of unions. If "A" is also a bitcast, wait for A/X to be merged. - unsigned OffsetBits = DL.getIndexTypeSizeInBits(GEP.getType()); - APInt Offset(OffsetBits, 0); - - // If the bitcast argument is an allocation, The bitcast is for convertion - // to actual type of allocation. Removing such bitcasts, results in having - // GEPs with i8* base and pure byte offsets. That means GEP is not aware of - // struct or array hierarchy. - // By avoiding such GEPs, phi translation and MemoryDependencyAnalysis have - // a better chance to succeed. - if (!isa<BitCastInst>(SrcOp) && GEP.accumulateConstantOffset(DL, Offset) && - !isAllocationFn(SrcOp, &TLI)) { - // If this GEP instruction doesn't move the pointer, just replace the GEP - // with a bitcast of the real input to the dest type. - if (!Offset) { - // If the bitcast is of an allocation, and the allocation will be - // converted to match the type of the cast, don't touch this. - if (isa<AllocaInst>(SrcOp)) { - // See if the bitcast simplifies, if so, don't nuke this GEP yet. - if (Instruction *I = visitBitCast(*BCI)) { - if (I != BCI) { - I->takeName(BCI); - I->insertInto(BCI->getParent(), BCI->getIterator()); - replaceInstUsesWith(*BCI, I); - } - return &GEP; - } - } - - if (SrcType->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(SrcOp, GEP.getType()); - return new BitCastInst(SrcOp, GEP.getType()); - } - - // Otherwise, if the offset is non-zero, we need to find out if there is a - // field at Offset in 'A's type. If so, we can pull the cast through the - // GEP. 
- SmallVector<Value *, 8> NewIndices; - if (findElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices, DL)) { - Value *NGEP = Builder.CreateGEP(SrcEltType, SrcOp, NewIndices, "", - GEP.isInBounds()); - - if (NGEP->getType() == GEP.getType()) - return replaceInstUsesWith(GEP, NGEP); - NGEP->takeName(&GEP); - - if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(NGEP, GEP.getType()); - return new BitCastInst(NGEP, GEP.getType()); - } - } + return replaceInstUsesWith( + GEP, Builder.CreateGEP( + Src->getSourceElementType(), Src->getOperand(0), Indices, "", + isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)))); return nullptr; } @@ -2497,192 +2269,6 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEPType->isVectorTy()) return nullptr; - // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). - Value *StrippedPtr = PtrOp->stripPointerCasts(); - PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType()); - - // TODO: The basic approach of these folds is not compatible with opaque - // pointers, because we can't use bitcasts as a hint for a desirable GEP - // type. Instead, we should perform canonicalization directly on the GEP - // type. For now, skip these. - if (StrippedPtr != PtrOp && !StrippedPtrTy->isOpaque()) { - bool HasZeroPointerIndex = false; - Type *StrippedPtrEltTy = StrippedPtrTy->getNonOpaquePointerElementType(); - - if (auto *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) - HasZeroPointerIndex = C->isZero(); - - // Transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... - // into : GEP [10 x i8]* X, i32 0, ... - // - // Likewise, transform: GEP (bitcast i8* X to [0 x i8]*), i32 0, ... - // into : GEP i8* X, ... - // - // This occurs when the program declares an array extern like "int X[];" - if (HasZeroPointerIndex) { - if (auto *CATy = dyn_cast<ArrayType>(GEPEltType)) { - // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? - if (CATy->getElementType() == StrippedPtrEltTy) { - // -> GEP i8* X, ... - SmallVector<Value *, 8> Idx(drop_begin(GEP.indices())); - GetElementPtrInst *Res = GetElementPtrInst::Create( - StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName()); - Res->setIsInBounds(GEP.isInBounds()); - if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) - return Res; - // Insert Res, and create an addrspacecast. - // e.g., - // GEP (addrspacecast i8 addrspace(1)* X to [0 x i8]*), i32 0, ... - // -> - // %0 = GEP i8 addrspace(1)* X, ... - // addrspacecast i8 addrspace(1)* %0 to i8* - return new AddrSpaceCastInst(Builder.Insert(Res), GEPType); - } - - if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrEltTy)) { - // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ? - if (CATy->getElementType() == XATy->getElementType()) { - // -> GEP [10 x i8]* X, i32 0, ... - // At this point, we know that the cast source type is a pointer - // to an array of the same type as the destination pointer - // array. Because the array type is never stepped over (there - // is a leading zero) we can fold the cast into this GEP. - if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) { - GEP.setSourceElementType(XATy); - return replaceOperand(GEP, 0, StrippedPtr); - } - // Cannot replace the base pointer directly because StrippedPtr's - // address space is different. Instead, create a new GEP followed by - // an addrspacecast. - // e.g., - // GEP (addrspacecast [10 x i8] addrspace(1)* X to [0 x i8]*), - // i32 0, ... - // -> - // %0 = GEP [10 x i8] addrspace(1)* X, ... 
- // addrspacecast i8 addrspace(1)* %0 to i8* - SmallVector<Value *, 8> Idx(GEP.indices()); - Value *NewGEP = - Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, - GEP.getName(), GEP.isInBounds()); - return new AddrSpaceCastInst(NewGEP, GEPType); - } - } - } - } else if (GEP.getNumOperands() == 2 && !IsGEPSrcEleScalable) { - // Skip if GEP source element type is scalable. The type alloc size is - // unknown at compile-time. - // Transform things like: %t = getelementptr i32* - // bitcast ([2 x i32]* %str to i32*), i32 %V into: %t1 = getelementptr [2 - // x i32]* %str, i32 0, i32 %V; bitcast - if (StrippedPtrEltTy->isArrayTy() && - DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) == - DL.getTypeAllocSize(GEPEltType)) { - Type *IdxType = DL.getIndexType(GEPType); - Value *Idx[2] = {Constant::getNullValue(IdxType), GEP.getOperand(1)}; - Value *NewGEP = Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, - GEP.getName(), GEP.isInBounds()); - - // V and GEP are both pointer types --> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); - } - - // Transform things like: - // %V = mul i64 %N, 4 - // %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V - // into: %t1 = getelementptr i32* %arr, i32 %N; bitcast - if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { - // Check that changing the type amounts to dividing the index by a scale - // factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); - uint64_t SrcSize = - DL.getTypeAllocSize(StrippedPtrEltTy).getFixedValue(); - if (ResSize && SrcSize % ResSize == 0) { - Value *Idx = GEP.getOperand(1); - unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); - uint64_t Scale = SrcSize / ResSize; - - // Earlier transforms ensure that the index has the right type - // according to Data Layout, which considerably simplifies the - // logic by eliminating implicit casts. - assert(Idx->getType() == DL.getIndexType(GEPType) && - "Index type does not match the Data Layout preferences"); - - bool NSW; - if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { - // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. - // If the multiplication NewIdx * Scale may overflow then the new - // GEP may not be "inbounds". - Value *NewGEP = - Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, NewIdx, - GEP.getName(), GEP.isInBounds() && NSW); - - // The NewGEP must be pointer typed, so must the old one -> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEPType); - } - } - } - - // Similarly, transform things like: - // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp - // (where tmp = 8*tmp2) into: - // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast - if (GEPEltType->isSized() && StrippedPtrEltTy->isSized() && - StrippedPtrEltTy->isArrayTy()) { - // Check that changing to the array element type amounts to dividing the - // index by a scale factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); - uint64_t ArrayEltSize = - DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) - .getFixedValue(); - if (ResSize && ArrayEltSize % ResSize == 0) { - Value *Idx = GEP.getOperand(1); - unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); - uint64_t Scale = ArrayEltSize / ResSize; - - // Earlier transforms ensure that the index has the right type - // according to the Data Layout, which considerably simplifies - // the logic by eliminating implicit casts. 
- assert(Idx->getType() == DL.getIndexType(GEPType) && - "Index type does not match the Data Layout preferences"); - - bool NSW; - if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { - // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. - // If the multiplication NewIdx * Scale may overflow then the new - // GEP may not be "inbounds". - Type *IndTy = DL.getIndexType(GEPType); - Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx}; - - Value *NewGEP = - Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Off, - GEP.getName(), GEP.isInBounds() && NSW); - // The NewGEP must be pointer typed, so must the old one -> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEPType); - } - } - } - } - } - - // addrspacecast between types is canonicalized as a bitcast, then an - // addrspacecast. To take advantage of the below bitcast + struct GEP, look - // through the addrspacecast. - Value *ASCStrippedPtrOp = PtrOp; - if (auto *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { - // X = bitcast A addrspace(1)* to B addrspace(1)* - // Y = addrspacecast A addrspace(1)* to B addrspace(2)* - // Z = gep Y, <...constant indices...> - // Into an addrspacecasted GEP of the struct. - if (auto *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) - ASCStrippedPtrOp = BC; - } - - if (auto *BCI = dyn_cast<BitCastInst>(ASCStrippedPtrOp)) - if (Instruction *I = visitGEPOfBitcast(BCI, GEP)) - return I; - if (!GEP.isInBounds()) { unsigned IdxWidth = DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()); @@ -2690,12 +2276,13 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Value *UnderlyingPtrOp = PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL, BasePtrOffset); - if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { + bool CanBeNull, CanBeFreed; + uint64_t DerefBytes = UnderlyingPtrOp->getPointerDereferenceableBytes( + DL, CanBeNull, CanBeFreed); + if (!CanBeNull && !CanBeFreed && DerefBytes != 0) { if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && BasePtrOffset.isNonNegative()) { - APInt AllocSize( - IdxWidth, - DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinValue()); + APInt AllocSize(IdxWidth, DerefBytes); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( GEP.getSourceElementType(), PtrOp, Indices, GEP.getName()); @@ -2881,8 +2468,11 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { if (II->getIntrinsicID() == Intrinsic::objectsize) { - Value *Result = - lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/true); + SmallVector<Instruction *> InsertedInstructions; + Value *Result = lowerObjectSizeCall( + II, DL, &TLI, AA, /*MustSucceed=*/true, &InsertedInstructions); + for (Instruction *Inserted : InsertedInstructions) + Worklist.add(Inserted); replaceInstUsesWith(*I, Result); eraseInstFromFunction(*I); Users[i] = nullptr; // Skip examining in the next loop. @@ -3089,50 +2679,27 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) { return nullptr; } -static bool isMustTailCall(Value *V) { - if (auto *CI = dyn_cast<CallInst>(V)) - return CI->isMustTailCall(); - return false; -} - Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) { - if (RI.getNumOperands() == 0) // ret void - return nullptr; - - Value *ResultOp = RI.getOperand(0); - Type *VTy = ResultOp->getType(); - if (!VTy->isIntegerTy() || isa<Constant>(ResultOp)) - return nullptr; - - // Don't replace result of musttail calls. 
- if (isMustTailCall(ResultOp)) - return nullptr; - - // There might be assume intrinsics dominating this return that completely - // determine the value. If so, constant fold it. - KnownBits Known = computeKnownBits(ResultOp, 0, &RI); - if (Known.isConstant()) - return replaceOperand(RI, 0, - Constant::getIntegerValue(VTy, Known.getConstant())); - + // Nothing for now. return nullptr; } // WARNING: keep in sync with SimplifyCFGOpt::simplifyUnreachable()! -Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) { +bool InstCombinerImpl::removeInstructionsBeforeUnreachable(Instruction &I) { // Try to remove the previous instruction if it must lead to unreachable. // This includes instructions like stores and "llvm.assume" that may not get // removed by simple dead code elimination. + bool Changed = false; while (Instruction *Prev = I.getPrevNonDebugInstruction()) { // While we theoretically can erase EH, that would result in a block that // used to start with an EH no longer starting with EH, which is invalid. // To make it valid, we'd need to fixup predecessors to no longer refer to // this block, but that changes CFG, which is not allowed in InstCombine. if (Prev->isEHPad()) - return nullptr; // Can not drop any more instructions. We're done here. + break; // Can not drop any more instructions. We're done here. if (!isGuaranteedToTransferExecutionToSuccessor(Prev)) - return nullptr; // Can not drop any more instructions. We're done here. + break; // Can not drop any more instructions. We're done here. // Otherwise, this instruction can be freely erased, // even if it is not side-effect free. @@ -3140,9 +2707,13 @@ Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) { // another unreachable block), so convert those to poison. replaceInstUsesWith(*Prev, PoisonValue::get(Prev->getType())); eraseInstFromFunction(*Prev); + Changed = true; } - assert(I.getParent()->sizeWithoutDebug() == 1 && "The block is now empty."); - // FIXME: recurse into unconditional predecessors? + return Changed; +} + +Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) { + removeInstructionsBeforeUnreachable(I); return nullptr; } @@ -3175,6 +2746,57 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) { return nullptr; } +// Under the assumption that I is unreachable, remove it and following +// instructions. +bool InstCombinerImpl::handleUnreachableFrom(Instruction *I) { + bool Changed = false; + BasicBlock *BB = I->getParent(); + for (Instruction &Inst : make_early_inc_range( + make_range(std::next(BB->getTerminator()->getReverseIterator()), + std::next(I->getReverseIterator())))) { + if (!Inst.use_empty() && !Inst.getType()->isTokenTy()) { + replaceInstUsesWith(Inst, PoisonValue::get(Inst.getType())); + Changed = true; + } + if (Inst.isEHPad() || Inst.getType()->isTokenTy()) + continue; + eraseInstFromFunction(Inst); + Changed = true; + } + + // Replace phi node operands in successor blocks with poison. + for (BasicBlock *Succ : successors(BB)) + for (PHINode &PN : Succ->phis()) + for (Use &U : PN.incoming_values()) + if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) { + replaceUse(U, PoisonValue::get(PN.getType())); + addToWorklist(&PN); + Changed = true; + } + + // TODO: Successor blocks may also be dead. 
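handleUnreachableFrom above drops every instruction from a point known to be unreachable up to the block terminator, keeping EH pads because erasing them would change the block's shape. A toy standalone C++ sketch of that pruning over a list of pseudo-instructions (the Inst struct and names are hypothetical; replacement of uses with poison is not modeled):

#include <cstdio>
#include <string>
#include <vector>

struct Inst {
  std::string Name;
  bool IsEHPad = false;
};

// Drop everything from index From up to (but not including) the terminator,
// except EH pads.
static void dropFromUnreachablePoint(std::vector<Inst> &Block, size_t From) {
  std::vector<Inst> Kept(Block.begin(), Block.begin() + From);
  for (size_t I = From; I + 1 < Block.size(); ++I)
    if (Block[I].IsEHPad)
      Kept.push_back(Block[I]);
  Kept.push_back(Block.back()); // keep the terminator
  Block = std::move(Kept);
}

int main() {
  std::vector<Inst> Block = {{"%a = add ..."}, {"%b = mul ..."},
                             {"store ..."},    {"br ..."}};
  dropFromUnreachablePoint(Block, 1); // everything from %b on is unreachable
  for (const Inst &I : Block)
    std::printf("%s\n", I.Name.c_str()); // prints "%a = add ..." and "br ..."
}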
+ return Changed; +} + +bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB, + BasicBlock *LiveSucc) { + bool Changed = false; + for (BasicBlock *Succ : successors(BB)) { + // The live successor isn't dead. + if (Succ == LiveSucc) + continue; + + if (!all_of(predecessors(Succ), [&](BasicBlock *Pred) { + return DT.dominates(BasicBlockEdge(BB, Succ), + BasicBlockEdge(Pred, Succ)); + })) + continue; + + Changed |= handleUnreachableFrom(&Succ->front()); + } + return Changed; +} + Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { if (BI.isUnconditional()) return visitUnconditionalBranchInst(BI); @@ -3218,6 +2840,14 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { return &BI; } + if (isa<UndefValue>(Cond) && + handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr)) + return &BI; + if (auto *CI = dyn_cast<ConstantInt>(Cond)) + if (handlePotentiallyDeadSuccessors(BI.getParent(), + BI.getSuccessor(!CI->getZExtValue()))) + return &BI; + return nullptr; } @@ -3236,6 +2866,14 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { return replaceOperand(SI, 0, Op0); } + if (isa<UndefValue>(Cond) && + handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr)) + return &SI; + if (auto *CI = dyn_cast<ConstantInt>(Cond)) + if (handlePotentiallyDeadSuccessors( + SI.getParent(), SI.findCaseValue(CI)->getCaseSuccessor())) + return &SI; + KnownBits Known = computeKnownBits(Cond, 0, &SI); unsigned LeadingKnownZeros = Known.countMinLeadingZeros(); unsigned LeadingKnownOnes = Known.countMinLeadingOnes(); @@ -3243,10 +2881,10 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { // Compute the number of leading bits we can ignore. // TODO: A better way to determine this would use ComputeNumSignBits(). for (const auto &C : SI.cases()) { - LeadingKnownZeros = std::min( - LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); - LeadingKnownOnes = std::min( - LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes()); + LeadingKnownZeros = + std::min(LeadingKnownZeros, C.getCaseValue()->getValue().countl_zero()); + LeadingKnownOnes = + std::min(LeadingKnownOnes, C.getCaseValue()->getValue().countl_one()); } unsigned NewWidth = Known.getBitWidth() - std::max(LeadingKnownZeros, LeadingKnownOnes); @@ -3412,6 +3050,11 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { return R; if (LoadInst *L = dyn_cast<LoadInst>(Agg)) { + // Bail out if the aggregate contains scalable vector type + if (auto *STy = dyn_cast<StructType>(Agg->getType()); + STy && STy->containsScalableVectorType()) + return nullptr; + // If the (non-volatile) load only has one use, we can rewrite this to a // load from a GEP. This reduces the size of the load. If a load is used // only by extractvalue instructions then this either must have been @@ -3965,6 +3608,17 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) { return Changed; } +// Check if any direct or bitcast user of this value is a shuffle instruction. 
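The switch handling above narrows the condition width by counting leading bits that are constant across all case values. A self-contained C++20 sketch of the same computation using std::countl_zero and std::countl_one over hypothetical case values (simplified: the known bits of the condition itself are ignored here):

#include <algorithm>
#include <bit>
#include <cstdint>
#include <cstdio>

int main() {
  // Case values of a hypothetical switch over a 32-bit condition.
  std::uint32_t Cases[] = {1, 2, 7, 13};
  unsigned LeadingZeros = 32, LeadingOnes = 32;
  for (std::uint32_t C : Cases) {
    LeadingZeros = std::min(LeadingZeros, unsigned(std::countl_zero(C)));
    LeadingOnes = std::min(LeadingOnes, unsigned(std::countl_one(C)));
  }
  // Leading bits that agree across all case values can be ignored; the switch
  // can then be performed on a narrower integer.
  unsigned NewWidth = 32 - std::max(LeadingZeros, LeadingOnes);
  std::printf("switch can be narrowed to %u bits\n", NewWidth); // prints 4
}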
+static bool isUsedWithinShuffleVector(Value *V) { + for (auto *U : V->users()) { + if (isa<ShuffleVectorInst>(U)) + return true; + else if (match(U, m_BitCast(m_Specific(V))) && isUsedWithinShuffleVector(U)) + return true; + } + return false; +} + Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { Value *Op0 = I.getOperand(0); @@ -4014,8 +3668,14 @@ Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { return BestValue; }; - if (match(Op0, m_Undef())) + if (match(Op0, m_Undef())) { + // Don't fold freeze(undef/poison) if it's used as a vector operand in + // a shuffle. This may improve codegen for shuffles that allow + // unspecified inputs. + if (isUsedWithinShuffleVector(&I)) + return nullptr; return replaceInstUsesWith(I, getUndefReplacement(I.getType())); + } Constant *C; if (match(Op0, m_Constant(C)) && C->containsUndefOrPoisonElement()) { @@ -4078,8 +3738,8 @@ static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) { /// beginning of DestBlock, which can only happen if it's safe to move the /// instruction past all of the instructions between it and the end of its /// block. -static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, - TargetLibraryInfo &TLI) { +bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, + BasicBlock *DestBlock) { BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -4126,10 +3786,13 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, return false; } - I->dropDroppableUses([DestBlock](const Use *U) { - if (auto *I = dyn_cast<Instruction>(U->getUser())) - return I->getParent() != DestBlock; - return true; + I->dropDroppableUses([&](const Use *U) { + auto *I = dyn_cast<Instruction>(U->getUser()); + if (I && I->getParent() != DestBlock) { + Worklist.add(I); + return true; + } + return false; }); /// FIXME: We could remove droppable uses that are not dominated by /// the new position. @@ -4227,23 +3890,6 @@ bool InstCombinerImpl::run() { if (!DebugCounter::shouldExecute(VisitCounter)) continue; - // Instruction isn't dead, see if we can constant propagate it. - if (!I->use_empty() && - (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { - if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) { - LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I - << '\n'); - - // Add operands to the worklist. - replaceInstUsesWith(*I, C); - ++NumConstProp; - if (isInstructionTriviallyDead(I, &TLI)) - eraseInstFromFunction(*I); - MadeIRChange = true; - continue; - } - } - // See if we can trivially sink this instruction to its user if we can // prove that the successor is not executed more frequently than our block. // Return the UserBlock if successful. @@ -4319,7 +3965,7 @@ bool InstCombinerImpl::run() { if (OptBB) { auto *UserParent = *OptBB; // Okay, the CFG is simple enough, try to sink this instruction. - if (TryToSinkInstruction(I, UserParent, TLI)) { + if (tryToSinkInstruction(I, UserParent)) { LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); MadeIRChange = true; // We'll add uses of the sunk instruction below, but since @@ -4520,15 +4166,21 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // Recursively visit successors. If this is a branch or switch on a // constant, only visit the reachable successor. 
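    // A distilled sketch of that rule for the branch case (a hypothetical
    // helper, shown only for illustration; the real code below also handles
    // switches and treats a branch or switch on undef as UB, visiting no
    // successor):
    //
    //   static BasicBlock *onlyReachableSuccessor(BranchInst *BI) {
    //     if (!BI->isConditional())
    //       return BI->getSuccessor(0);
    //     if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition()))
    //       return BI->getSuccessor(Cond->isZero() ? 1 : 0);
    //     return nullptr; // Both successors may be reachable.
    //   }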
Instruction *TI = BB->getTerminator(); - if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { - if (BI->isConditional() && isa<ConstantInt>(BI->getCondition())) { - bool CondVal = cast<ConstantInt>(BI->getCondition())->getZExtValue(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI); BI && BI->isConditional()) { + if (isa<UndefValue>(BI->getCondition())) + // Branch on undef is UB. + continue; + if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition())) { + bool CondVal = Cond->getZExtValue(); BasicBlock *ReachableBB = BI->getSuccessor(!CondVal); Worklist.push_back(ReachableBB); continue; } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { + if (isa<UndefValue>(SI->getCondition())) + // Switch on undef is UB. + continue; + if (auto *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor()); continue; } @@ -4584,7 +4236,6 @@ static bool combineInstructionsOverFunction( DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) { auto &DL = F.getParent()->getDataLayout(); - MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue()); /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. @@ -4601,13 +4252,6 @@ static bool combineInstructionsOverFunction( bool MadeIRChange = false; if (ShouldLowerDbgDeclare) MadeIRChange = LowerDbgDeclare(F); - // LowerDbgDeclare calls RemoveRedundantDbgInstrs, but LowerDbgDeclare will - // almost never return true when running an assignment tracking build. Take - // this opportunity to do some clean up for assignment tracking builds too. - if (!MadeIRChange && isAssignmentTrackingEnabled(*F.getParent())) { - for (auto &BB : F) - RemoveRedundantDbgInstrs(&BB); - } // Iterate while there is work to do. unsigned Iteration = 0; @@ -4643,13 +4287,29 @@ static bool combineInstructionsOverFunction( MadeIRChange = true; } + if (Iteration == 1) + ++NumOneIteration; + else if (Iteration == 2) + ++NumTwoIterations; + else if (Iteration == 3) + ++NumThreeIterations; + else + ++NumFourOrMoreIterations; + return MadeIRChange; } -InstCombinePass::InstCombinePass() : MaxIterations(LimitMaxIterations) {} +InstCombinePass::InstCombinePass(InstCombineOptions Opts) : Options(Opts) {} -InstCombinePass::InstCombinePass(unsigned MaxIterations) - : MaxIterations(MaxIterations) {} +void InstCombinePass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<InstCombinePass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << '<'; + OS << "max-iterations=" << Options.MaxIterations << ";"; + OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info"; + OS << '>'; +} PreservedAnalyses InstCombinePass::run(Function &F, FunctionAnalysisManager &AM) { @@ -4659,7 +4319,11 @@ PreservedAnalyses InstCombinePass::run(Function &F, auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); + // TODO: Only use LoopInfo when the option is set. This requires that the + // callers in the pass pipeline explicitly set the option. 
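  // In other words (illustrative): with use-loop-info set, LoopAnalysis is
  // computed on demand below; otherwise only an already-cached result is
  // consumed, so LI may legitimately be null.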
auto *LI = AM.getCachedResult<LoopAnalysis>(F); + if (!LI && Options.UseLoopInfo) + LI = &AM.getResult<LoopAnalysis>(F); auto *AA = &AM.getResult<AAManager>(F); auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); @@ -4669,7 +4333,7 @@ PreservedAnalyses InstCombinePass::run(Function &F, &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, - BFI, PSI, MaxIterations, LI)) + BFI, PSI, Options.MaxIterations, LI)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); @@ -4718,18 +4382,13 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { nullptr; return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, - BFI, PSI, MaxIterations, LI); + BFI, PSI, + InstCombineDefaultMaxIterations, LI); } char InstructionCombiningPass::ID = 0; -InstructionCombiningPass::InstructionCombiningPass() - : FunctionPass(ID), MaxIterations(InstCombineDefaultMaxIterations) { - initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); -} - -InstructionCombiningPass::InstructionCombiningPass(unsigned MaxIterations) - : FunctionPass(ID), MaxIterations(MaxIterations) { +InstructionCombiningPass::InstructionCombiningPass() : FunctionPass(ID) { initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); } @@ -4752,18 +4411,6 @@ void llvm::initializeInstCombine(PassRegistry &Registry) { initializeInstructionCombiningPassPass(Registry); } -void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { - initializeInstructionCombiningPassPass(*unwrap(R)); -} - FunctionPass *llvm::createInstructionCombiningPass() { return new InstructionCombiningPass(); } - -FunctionPass *llvm::createInstructionCombiningPass(unsigned MaxIterations) { - return new InstructionCombiningPass(MaxIterations); -} - -void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createInstructionCombiningPass()); -} diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 599eeeabc143..bde5fba20f3b 100644 --- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -24,7 +24,6 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -70,6 +69,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h" @@ -492,7 +492,7 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, bool IsMIPS64 = TargetTriple.isMIPS64(); bool IsArmOrThumb = TargetTriple.isARM() || TargetTriple.isThumb(); bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64; - bool IsLoongArch64 = TargetTriple.getArch() == Triple::loongarch64; + bool IsLoongArch64 = TargetTriple.isLoongArch64(); bool IsRISCV64 = TargetTriple.getArch() == Triple::riscv64; bool IsWindows = TargetTriple.isOSWindows(); bool IsFuchsia = TargetTriple.isOSFuchsia(); @@ -656,6 +656,7 @@ struct AddressSanitizer { : UseAfterReturn), SSGI(SSGI) { C = &(M.getContext()); + DL = &M.getDataLayout(); LongSize 
= M.getDataLayout().getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); Int8PtrTy = Type::getInt8PtrTy(*C); @@ -667,17 +668,8 @@ struct AddressSanitizer { assert(this->UseAfterReturn != AsanDetectStackUseAfterReturnMode::Invalid); } - uint64_t getAllocaSizeInBytes(const AllocaInst &AI) const { - uint64_t ArraySize = 1; - if (AI.isArrayAllocation()) { - const ConstantInt *CI = dyn_cast<ConstantInt>(AI.getArraySize()); - assert(CI && "non-constant array size"); - ArraySize = CI->getZExtValue(); - } - Type *Ty = AI.getAllocatedType(); - uint64_t SizeInBytes = - AI.getModule()->getDataLayout().getTypeAllocSize(Ty); - return SizeInBytes * ArraySize; + TypeSize getAllocaSizeInBytes(const AllocaInst &AI) const { + return *AI.getAllocationSize(AI.getModule()->getDataLayout()); } /// Check if we want (and can) handle this alloca. @@ -692,19 +684,27 @@ struct AddressSanitizer { const DataLayout &DL); void instrumentPointerComparisonOrSubtraction(Instruction *I); void instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, - Value *Addr, uint32_t TypeSize, bool IsWrite, + Value *Addr, MaybeAlign Alignment, + uint32_t TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp); Instruction *instrumentAMDGPUAddress(Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, - uint32_t TypeSize, bool IsWrite, + uint32_t TypeStoreSize, bool IsWrite, Value *SizeArgument); void instrumentUnusualSizeOrAlignment(Instruction *I, Instruction *InsertBefore, Value *Addr, - uint32_t TypeSize, bool IsWrite, + TypeSize TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp); + void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, const DataLayout &DL, + Type *IntptrTy, Value *Mask, Value *EVL, + Value *Stride, Instruction *I, Value *Addr, + MaybeAlign Alignment, unsigned Granularity, + Type *OpType, bool IsWrite, + Value *SizeArgument, bool UseCalls, + uint32_t Exp); Value *createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong, - Value *ShadowValue, uint32_t TypeSize); + Value *ShadowValue, uint32_t TypeStoreSize); Instruction *generateCrashCode(Instruction *InsertBefore, Value *Addr, bool IsWrite, size_t AccessSizeIndex, Value *SizeArgument, uint32_t Exp); @@ -724,7 +724,7 @@ private: bool LooksLikeCodeInBug11395(Instruction *I); bool GlobalIsLinkerInitialized(GlobalVariable *G); bool isSafeAccess(ObjectSizeOffsetVisitor &ObjSizeVis, Value *Addr, - uint64_t TypeSize) const; + TypeSize TypeStoreSize) const; /// Helper to cleanup per-function state. struct FunctionStateRAII { @@ -743,6 +743,7 @@ private: }; LLVMContext *C; + const DataLayout *DL; Triple TargetTriple; int LongSize; bool CompileKernel; @@ -1040,7 +1041,9 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { /// Collect Alloca instructions we want (and can) handle. void visitAllocaInst(AllocaInst &AI) { - if (!ASan.isInterestingAlloca(AI)) { + // FIXME: Handle scalable vectors instead of ignoring them. + if (!ASan.isInterestingAlloca(AI) || + isa<ScalableVectorType>(AI.getAllocatedType())) { if (AI.isStaticAlloca()) { // Skip over allocas that are present *before* the first instrumented // alloca, we don't want to move those around. 
@@ -1133,10 +1136,10 @@ void AddressSanitizerPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<AddressSanitizerPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (Options.CompileKernel) OS << "kernel"; - OS << ">"; + OS << '>'; } AddressSanitizerPass::AddressSanitizerPass( @@ -1176,8 +1179,8 @@ PreservedAnalyses AddressSanitizerPass::run(Module &M, return PA; } -static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { - size_t Res = countTrailingZeros(TypeSize / 8); +static size_t TypeStoreSizeToSizeIndex(uint32_t TypeSize) { + size_t Res = llvm::countr_zero(TypeSize / 8); assert(Res < kNumberOfAccessSizes); return Res; } @@ -1227,7 +1230,7 @@ Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) { // Instrument memset/memmove/memcpy void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { - IRBuilder<> IRB(MI); + InstrumentationIRBuilder IRB(MI); if (isa<MemTransferInst>(MI)) { IRB.CreateCall( isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy, @@ -1254,7 +1257,7 @@ bool AddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { bool IsInteresting = (AI.getAllocatedType()->isSized() && // alloca() may be called with 0 size, ignore it. - ((!AI.isStaticAlloca()) || getAllocaSizeInBytes(AI) > 0) && + ((!AI.isStaticAlloca()) || !getAllocaSizeInBytes(AI).isZero()) && // We are only interested in allocas not promotable to registers. // Promotable allocas are common under -O0. (!ClSkipPromotableAllocas || !isAllocaPromotable(&AI)) && @@ -1326,9 +1329,12 @@ void AddressSanitizer::getInterestingMemoryOperands( XCHG->getCompareOperand()->getType(), std::nullopt); } else if (auto CI = dyn_cast<CallInst>(I)) { - if (CI->getIntrinsicID() == Intrinsic::masked_load || - CI->getIntrinsicID() == Intrinsic::masked_store) { - bool IsWrite = CI->getIntrinsicID() == Intrinsic::masked_store; + switch (CI->getIntrinsicID()) { + case Intrinsic::masked_load: + case Intrinsic::masked_store: + case Intrinsic::masked_gather: + case Intrinsic::masked_scatter: { + bool IsWrite = CI->getType()->isVoidTy(); // Masked store has an initial operand for the value. unsigned OpOffset = IsWrite ? 1 : 0; if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads) @@ -1344,7 +1350,76 @@ void AddressSanitizer::getInterestingMemoryOperands( Alignment = Op->getMaybeAlignValue(); Value *Mask = CI->getOperand(2 + OpOffset); Interesting.emplace_back(I, OpOffset, IsWrite, Ty, Alignment, Mask); - } else { + break; + } + case Intrinsic::masked_expandload: + case Intrinsic::masked_compressstore: { + bool IsWrite = CI->getIntrinsicID() == Intrinsic::masked_compressstore; + unsigned OpOffset = IsWrite ? 1 : 0; + if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads) + return; + auto BasePtr = CI->getOperand(OpOffset); + if (ignoreAccess(I, BasePtr)) + return; + MaybeAlign Alignment = BasePtr->getPointerAlignment(*DL); + Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType(); + + IRBuilder IB(I); + Value *Mask = CI->getOperand(1 + OpOffset); + // Use the popcount of Mask as the effective vector length. 
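      // Worked example (hypothetical mask): for a masked_compressstore with
      // mask <i1 1, i1 0, i1 1, i1 1>, the zero-extended vector is
      // <1, 0, 1, 1> and the add reduction yields EVL = 3. Since
      // expandload/compressstore touch popcount(mask) consecutive elements,
      // checking the first EVL element addresses with an all-true mask covers
      // exactly the accessed memory.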
+ Type *ExtTy = VectorType::get(IntptrTy, cast<VectorType>(Ty)); + Value *ExtMask = IB.CreateZExt(Mask, ExtTy); + Value *EVL = IB.CreateAddReduce(ExtMask); + Value *TrueMask = ConstantInt::get(Mask->getType(), 1); + Interesting.emplace_back(I, OpOffset, IsWrite, Ty, Alignment, TrueMask, + EVL); + break; + } + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::experimental_vp_strided_load: + case Intrinsic::experimental_vp_strided_store: { + auto *VPI = cast<VPIntrinsic>(CI); + unsigned IID = CI->getIntrinsicID(); + bool IsWrite = CI->getType()->isVoidTy(); + if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads) + return; + unsigned PtrOpNo = *VPI->getMemoryPointerParamPos(IID); + Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType(); + MaybeAlign Alignment = VPI->getOperand(PtrOpNo)->getPointerAlignment(*DL); + Value *Stride = nullptr; + if (IID == Intrinsic::experimental_vp_strided_store || + IID == Intrinsic::experimental_vp_strided_load) { + Stride = VPI->getOperand(PtrOpNo + 1); + // Use the pointer alignment as the element alignment if the stride is a + // mutiple of the pointer alignment. Otherwise, the element alignment + // should be Align(1). + unsigned PointerAlign = Alignment.valueOrOne().value(); + if (!isa<ConstantInt>(Stride) || + cast<ConstantInt>(Stride)->getZExtValue() % PointerAlign != 0) + Alignment = Align(1); + } + Interesting.emplace_back(I, PtrOpNo, IsWrite, Ty, Alignment, + VPI->getMaskParam(), VPI->getVectorLengthParam(), + Stride); + break; + } + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: { + auto *VPI = cast<VPIntrinsic>(CI); + unsigned IID = CI->getIntrinsicID(); + bool IsWrite = IID == Intrinsic::vp_scatter; + if (IsWrite ? !ClInstrumentWrites : !ClInstrumentReads) + return; + unsigned PtrOpNo = *VPI->getMemoryPointerParamPos(IID); + Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType(); + MaybeAlign Alignment = VPI->getPointerAlignment(); + Interesting.emplace_back(I, PtrOpNo, IsWrite, Ty, Alignment, + VPI->getMaskParam(), + VPI->getVectorLengthParam()); + break; + } + default: for (unsigned ArgNo = 0; ArgNo < CI->arg_size(); ArgNo++) { if (!ClInstrumentByval || !CI->isByValArgument(ArgNo) || ignoreAccess(I, CI->getArgOperand(ArgNo))) @@ -1416,57 +1491,94 @@ void AddressSanitizer::instrumentPointerComparisonOrSubtraction( static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I, Instruction *InsertBefore, Value *Addr, MaybeAlign Alignment, unsigned Granularity, - uint32_t TypeSize, bool IsWrite, + TypeSize TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp) { // Instrument a 1-, 2-, 4-, 8-, or 16- byte access with one check // if the data is properly aligned. 
- if ((TypeSize == 8 || TypeSize == 16 || TypeSize == 32 || TypeSize == 64 || - TypeSize == 128) && - (!Alignment || *Alignment >= Granularity || *Alignment >= TypeSize / 8)) - return Pass->instrumentAddress(I, InsertBefore, Addr, TypeSize, IsWrite, - nullptr, UseCalls, Exp); - Pass->instrumentUnusualSizeOrAlignment(I, InsertBefore, Addr, TypeSize, + if (!TypeStoreSize.isScalable()) { + const auto FixedSize = TypeStoreSize.getFixedValue(); + switch (FixedSize) { + case 8: + case 16: + case 32: + case 64: + case 128: + if (!Alignment || *Alignment >= Granularity || + *Alignment >= FixedSize / 8) + return Pass->instrumentAddress(I, InsertBefore, Addr, Alignment, + FixedSize, IsWrite, nullptr, UseCalls, + Exp); + } + } + Pass->instrumentUnusualSizeOrAlignment(I, InsertBefore, Addr, TypeStoreSize, IsWrite, nullptr, UseCalls, Exp); } -static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, - const DataLayout &DL, Type *IntptrTy, - Value *Mask, Instruction *I, - Value *Addr, MaybeAlign Alignment, - unsigned Granularity, Type *OpType, - bool IsWrite, Value *SizeArgument, - bool UseCalls, uint32_t Exp) { - auto *VTy = cast<FixedVectorType>(OpType); - uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType()); - unsigned Num = VTy->getNumElements(); +void AddressSanitizer::instrumentMaskedLoadOrStore( + AddressSanitizer *Pass, const DataLayout &DL, Type *IntptrTy, Value *Mask, + Value *EVL, Value *Stride, Instruction *I, Value *Addr, + MaybeAlign Alignment, unsigned Granularity, Type *OpType, bool IsWrite, + Value *SizeArgument, bool UseCalls, uint32_t Exp) { + auto *VTy = cast<VectorType>(OpType); + TypeSize ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType()); auto Zero = ConstantInt::get(IntptrTy, 0); - for (unsigned Idx = 0; Idx < Num; ++Idx) { - Value *InstrumentedAddress = nullptr; - Instruction *InsertBefore = I; - if (auto *Vector = dyn_cast<ConstantVector>(Mask)) { - // dyn_cast as we might get UndefValue - if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) { - if (Masked->isZero()) - // Mask is constant false, so no instrumentation needed. - continue; - // If we have a true or undef value, fall through to doInstrumentAddress - // with InsertBefore == I - } + + IRBuilder IB(I); + Instruction *LoopInsertBefore = I; + if (EVL) { + // The end argument of SplitBlockAndInsertForLane is assumed bigger + // than zero, so we should check whether EVL is zero here. + Type *EVLType = EVL->getType(); + Value *IsEVLZero = IB.CreateICmpNE(EVL, ConstantInt::get(EVLType, 0)); + LoopInsertBefore = SplitBlockAndInsertIfThen(IsEVLZero, I, false); + IB.SetInsertPoint(LoopInsertBefore); + // Cast EVL to IntptrTy. + EVL = IB.CreateZExtOrTrunc(EVL, IntptrTy); + // To avoid undefined behavior for extracting with out of range index, use + // the minimum of evl and element count as trip count. + Value *EC = IB.CreateElementCount(IntptrTy, VTy->getElementCount()); + EVL = IB.CreateBinaryIntrinsic(Intrinsic::umin, EVL, EC); + } else { + EVL = IB.CreateElementCount(IntptrTy, VTy->getElementCount()); + } + + // Cast Stride to IntptrTy. 
+ if (Stride) + Stride = IB.CreateZExtOrTrunc(Stride, IntptrTy); + + SplitBlockAndInsertForEachLane(EVL, LoopInsertBefore, + [&](IRBuilderBase &IRB, Value *Index) { + Value *MaskElem = IRB.CreateExtractElement(Mask, Index); + if (auto *MaskElemC = dyn_cast<ConstantInt>(MaskElem)) { + if (MaskElemC->isZero()) + // No check + return; + // Unconditional check } else { - IRBuilder<> IRB(I); - Value *MaskElem = IRB.CreateExtractElement(Mask, Idx); - Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false); - InsertBefore = ThenTerm; + // Conditional check + Instruction *ThenTerm = SplitBlockAndInsertIfThen( + MaskElem, &*IRB.GetInsertPoint(), false); + IRB.SetInsertPoint(ThenTerm); } - IRBuilder<> IRB(InsertBefore); - InstrumentedAddress = - IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)}); - doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment, - Granularity, ElemTypeSize, IsWrite, SizeArgument, - UseCalls, Exp); - } + Value *InstrumentedAddress; + if (isa<VectorType>(Addr->getType())) { + assert( + cast<VectorType>(Addr->getType())->getElementType()->isPointerTy() && + "Expected vector of pointer."); + InstrumentedAddress = IRB.CreateExtractElement(Addr, Index); + } else if (Stride) { + Index = IRB.CreateMul(Index, Stride); + Addr = IRB.CreateBitCast(Addr, Type::getInt8PtrTy(*C)); + InstrumentedAddress = IRB.CreateGEP(Type::getInt8Ty(*C), Addr, {Index}); + } else { + InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index}); + } + doInstrumentAddress(Pass, I, &*IRB.GetInsertPoint(), + InstrumentedAddress, Alignment, Granularity, + ElemTypeSize, IsWrite, SizeArgument, UseCalls, Exp); + }); } void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, @@ -1492,7 +1604,7 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, // dynamically initialized global is always valid. GlobalVariable *G = dyn_cast<GlobalVariable>(getUnderlyingObject(Addr)); if (G && (!ClInitializers || GlobalIsLinkerInitialized(G)) && - isSafeAccess(ObjSizeVis, Addr, O.TypeSize)) { + isSafeAccess(ObjSizeVis, Addr, O.TypeStoreSize)) { NumOptimizedAccessesToGlobalVar++; return; } @@ -1501,7 +1613,7 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, if (ClOpt && ClOptStack) { // A direct inbounds access to a stack variable is always valid. if (isa<AllocaInst>(getUnderlyingObject(Addr)) && - isSafeAccess(ObjSizeVis, Addr, O.TypeSize)) { + isSafeAccess(ObjSizeVis, Addr, O.TypeStoreSize)) { NumOptimizedAccessesToStackVar++; return; } @@ -1514,12 +1626,13 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, unsigned Granularity = 1 << Mapping.Scale; if (O.MaybeMask) { - instrumentMaskedLoadOrStore(this, DL, IntptrTy, O.MaybeMask, O.getInsn(), - Addr, O.Alignment, Granularity, O.OpType, - O.IsWrite, nullptr, UseCalls, Exp); + instrumentMaskedLoadOrStore(this, DL, IntptrTy, O.MaybeMask, O.MaybeEVL, + O.MaybeStride, O.getInsn(), Addr, O.Alignment, + Granularity, O.OpType, O.IsWrite, nullptr, + UseCalls, Exp); } else { doInstrumentAddress(this, O.getInsn(), O.getInsn(), Addr, O.Alignment, - Granularity, O.TypeSize, O.IsWrite, nullptr, UseCalls, + Granularity, O.TypeStoreSize, O.IsWrite, nullptr, UseCalls, Exp); } } @@ -1529,7 +1642,7 @@ Instruction *AddressSanitizer::generateCrashCode(Instruction *InsertBefore, size_t AccessSizeIndex, Value *SizeArgument, uint32_t Exp) { - IRBuilder<> IRB(InsertBefore); + InstrumentationIRBuilder IRB(InsertBefore); Value *ExpVal = Exp == 0 ? 
nullptr : ConstantInt::get(IRB.getInt32Ty(), Exp); CallInst *Call = nullptr; if (SizeArgument) { @@ -1554,15 +1667,15 @@ Instruction *AddressSanitizer::generateCrashCode(Instruction *InsertBefore, Value *AddressSanitizer::createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong, Value *ShadowValue, - uint32_t TypeSize) { + uint32_t TypeStoreSize) { size_t Granularity = static_cast<size_t>(1) << Mapping.Scale; // Addr & (Granularity - 1) Value *LastAccessedByte = IRB.CreateAnd(AddrLong, ConstantInt::get(IntptrTy, Granularity - 1)); // (Addr & (Granularity - 1)) + size - 1 - if (TypeSize / 8 > 1) + if (TypeStoreSize / 8 > 1) LastAccessedByte = IRB.CreateAdd( - LastAccessedByte, ConstantInt::get(IntptrTy, TypeSize / 8 - 1)); + LastAccessedByte, ConstantInt::get(IntptrTy, TypeStoreSize / 8 - 1)); // (uint8_t) ((Addr & (Granularity-1)) + size - 1) LastAccessedByte = IRB.CreateIntCast(LastAccessedByte, ShadowValue->getType(), false); @@ -1572,7 +1685,7 @@ Value *AddressSanitizer::createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong, Instruction *AddressSanitizer::instrumentAMDGPUAddress( Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, - uint32_t TypeSize, bool IsWrite, Value *SizeArgument) { + uint32_t TypeStoreSize, bool IsWrite, Value *SizeArgument) { // Do not instrument unsupported addrspaces. if (isUnsupportedAMDGPUAddrspace(Addr)) return nullptr; @@ -1595,18 +1708,19 @@ Instruction *AddressSanitizer::instrumentAMDGPUAddress( void AddressSanitizer::instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, - uint32_t TypeSize, bool IsWrite, + MaybeAlign Alignment, + uint32_t TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp) { if (TargetTriple.isAMDGPU()) { InsertBefore = instrumentAMDGPUAddress(OrigIns, InsertBefore, Addr, - TypeSize, IsWrite, SizeArgument); + TypeStoreSize, IsWrite, SizeArgument); if (!InsertBefore) return; } - IRBuilder<> IRB(InsertBefore); - size_t AccessSizeIndex = TypeSizeToSizeIndex(TypeSize); + InstrumentationIRBuilder IRB(InsertBefore); + size_t AccessSizeIndex = TypeStoreSizeToSizeIndex(TypeStoreSize); const ASanAccessInfo AccessInfo(IsWrite, CompileKernel, AccessSizeIndex); if (UseCalls && ClOptimizeCallbacks) { @@ -1631,17 +1745,19 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, } Type *ShadowTy = - IntegerType::get(*C, std::max(8U, TypeSize >> Mapping.Scale)); + IntegerType::get(*C, std::max(8U, TypeStoreSize >> Mapping.Scale)); Type *ShadowPtrTy = PointerType::get(ShadowTy, 0); Value *ShadowPtr = memToShadow(AddrLong, IRB); - Value *ShadowValue = - IRB.CreateLoad(ShadowTy, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); + const uint64_t ShadowAlign = + std::max<uint64_t>(Alignment.valueOrOne().value() >> Mapping.Scale, 1); + Value *ShadowValue = IRB.CreateAlignedLoad( + ShadowTy, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy), Align(ShadowAlign)); Value *Cmp = IRB.CreateIsNotNull(ShadowValue); size_t Granularity = 1ULL << Mapping.Scale; Instruction *CrashTerm = nullptr; - if (ClAlwaysSlowPath || (TypeSize < 8 * Granularity)) { + if (ClAlwaysSlowPath || (TypeStoreSize < 8 * Granularity)) { // We use branch weights for the slow path check, to indicate that the slow // path is rarely taken. This seems to be the case for SPEC benchmarks. 
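    // Rough illustration of the slow path (assuming the default 8-byte shadow
    // granularity): a shadow byte of 4 means only the first 4 bytes of that
    // granule are addressable, so a 2-byte access at offset 2 is fine
    // (2 + 2 - 1 = 3 < 4) while the same access at offset 3 is reported
    // (3 + 2 - 1 = 4 >= 4).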
Instruction *CheckTerm = SplitBlockAndInsertIfThen( @@ -1649,7 +1765,7 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, assert(cast<BranchInst>(CheckTerm)->isUnconditional()); BasicBlock *NextBB = CheckTerm->getSuccessor(0); IRB.SetInsertPoint(CheckTerm); - Value *Cmp2 = createSlowPathCmp(IRB, AddrLong, ShadowValue, TypeSize); + Value *Cmp2 = createSlowPathCmp(IRB, AddrLong, ShadowValue, TypeStoreSize); if (Recover) { CrashTerm = SplitBlockAndInsertIfThen(Cmp2, CheckTerm, false); } else { @@ -1665,7 +1781,8 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, Instruction *Crash = generateCrashCode(CrashTerm, AddrLong, IsWrite, AccessSizeIndex, SizeArgument, Exp); - Crash->setDebugLoc(OrigIns->getDebugLoc()); + if (OrigIns->getDebugLoc()) + Crash->setDebugLoc(OrigIns->getDebugLoc()); } // Instrument unusual size or unusual alignment. @@ -1673,10 +1790,12 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, // and the last bytes. We call __asan_report_*_n(addr, real_size) to be able // to report the actual access size. void AddressSanitizer::instrumentUnusualSizeOrAlignment( - Instruction *I, Instruction *InsertBefore, Value *Addr, uint32_t TypeSize, + Instruction *I, Instruction *InsertBefore, Value *Addr, TypeSize TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp) { - IRBuilder<> IRB(InsertBefore); - Value *Size = ConstantInt::get(IntptrTy, TypeSize / 8); + InstrumentationIRBuilder IRB(InsertBefore); + Value *NumBits = IRB.CreateTypeSize(IntptrTy, TypeStoreSize); + Value *Size = IRB.CreateLShr(NumBits, ConstantInt::get(IntptrTy, 3)); + Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); if (UseCalls) { if (Exp == 0) @@ -1686,11 +1805,13 @@ void AddressSanitizer::instrumentUnusualSizeOrAlignment( IRB.CreateCall(AsanMemoryAccessCallbackSized[IsWrite][1], {AddrLong, Size, ConstantInt::get(IRB.getInt32Ty(), Exp)}); } else { + Value *SizeMinusOne = IRB.CreateSub(Size, ConstantInt::get(IntptrTy, 1)); Value *LastByte = IRB.CreateIntToPtr( - IRB.CreateAdd(AddrLong, ConstantInt::get(IntptrTy, TypeSize / 8 - 1)), + IRB.CreateAdd(AddrLong, SizeMinusOne), Addr->getType()); - instrumentAddress(I, InsertBefore, Addr, 8, IsWrite, Size, false, Exp); - instrumentAddress(I, InsertBefore, LastByte, 8, IsWrite, Size, false, Exp); + instrumentAddress(I, InsertBefore, Addr, {}, 8, IsWrite, Size, false, Exp); + instrumentAddress(I, InsertBefore, LastByte, {}, 8, IsWrite, Size, false, + Exp); } } @@ -2306,7 +2427,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, G->getThreadLocalMode(), G->getAddressSpace()); NewGlobal->copyAttributesFrom(G); NewGlobal->setComdat(G->getComdat()); - NewGlobal->setAlignment(MaybeAlign(getMinRedzoneSizeForGlobal())); + NewGlobal->setAlignment(Align(getMinRedzoneSizeForGlobal())); // Don't fold globals with redzones. ODR violation detector and redzone // poisoning implicitly creates a dependence on the global's address, so it // is no longer valid for it to be marked unnamed_addr. @@ -3485,7 +3606,11 @@ void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { // base object. For example, it is a field access or an array access with // constant inbounds index. bool AddressSanitizer::isSafeAccess(ObjectSizeOffsetVisitor &ObjSizeVis, - Value *Addr, uint64_t TypeSize) const { + Value *Addr, TypeSize TypeStoreSize) const { + if (TypeStoreSize.isScalable()) + // TODO: We can use vscale_range to convert a scalable value to an + // upper bound on the access size. 
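    // (For instance, under vscale_range(1,16) a <vscale x 4 x i32> store is
    // at most 16 * 16 = 256 bytes, which would give a usable upper bound
    // here.)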
+ return false; SizeOffsetType SizeOffset = ObjSizeVis.compute(Addr); if (!ObjSizeVis.bothKnown(SizeOffset)) return false; uint64_t Size = SizeOffset.first.getZExtValue(); @@ -3495,5 +3620,5 @@ bool AddressSanitizer::isSafeAccess(ObjectSizeOffsetVisitor &ObjSizeVis, // . Size >= Offset (unsigned) // . Size - Offset >= NeededSize (unsigned) return Offset >= 0 && Size >= uint64_t(Offset) && - Size - uint64_t(Offset) >= TypeSize / 8; + Size - uint64_t(Offset) >= TypeStoreSize / 8; } diff --git a/llvm/lib/Transforms/Instrumentation/BlockCoverageInference.cpp b/llvm/lib/Transforms/Instrumentation/BlockCoverageInference.cpp new file mode 100644 index 000000000000..0e49984c6ee3 --- /dev/null +++ b/llvm/lib/Transforms/Instrumentation/BlockCoverageInference.cpp @@ -0,0 +1,368 @@ +//===-- BlockCoverageInference.cpp - Minimal Execution Coverage -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Our algorithm works by first identifying a subset of nodes that must always +// be instrumented. We call these nodes ambiguous because knowing the coverage +// of all remaining nodes is not enough to infer their coverage status. +// +// In general a node v is ambiguous if there exists two entry-to-terminal paths +// P_1 and P_2 such that: +// 1. v not in P_1 but P_1 visits a predecessor of v, and +// 2. v not in P_2 but P_2 visits a successor of v. +// +// If a node v is not ambiguous, then if condition 1 fails, we can infer v’s +// coverage from the coverage of its predecessors, or if condition 2 fails, we +// can infer v’s coverage from the coverage of its successors. +// +// Sadly, there are example CFGs where it is not possible to infer all nodes +// from the ambiguous nodes alone. Our algorithm selects a minimum number of +// extra nodes to add to the ambiguous nodes to form a valid instrumentation S. 
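// As a small worked example (not taken from the paper): in a diamond CFG
// Entry -> {A, B} -> Exit, both A and B are ambiguous, because the path
// Entry -> B -> Exit avoids A while visiting a predecessor (Entry) and a
// successor (Exit) of A, and symmetrically for B. Entry and Exit lie on every
// entry-to-terminal path, so their coverage can be inferred from the
// instrumented blocks, and instrumenting just {A, B} suffices here.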
+// +// Details on this algorithm can be found in https://arxiv.org/abs/2208.13907 +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/BlockCoverageInference.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Support/CRC.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "pgo-block-coverage" + +STATISTIC(NumFunctions, "Number of total functions that BCI has processed"); +STATISTIC(NumIneligibleFunctions, + "Number of functions for which BCI cannot run on"); +STATISTIC(NumBlocks, "Number of total basic blocks that BCI has processed"); +STATISTIC(NumInstrumentedBlocks, + "Number of basic blocks instrumented for coverage"); + +BlockCoverageInference::BlockCoverageInference(const Function &F, + bool ForceInstrumentEntry) + : F(F), ForceInstrumentEntry(ForceInstrumentEntry) { + findDependencies(); + assert(!ForceInstrumentEntry || shouldInstrumentBlock(F.getEntryBlock())); + + ++NumFunctions; + for (auto &BB : F) { + ++NumBlocks; + if (shouldInstrumentBlock(BB)) + ++NumInstrumentedBlocks; + } +} + +BlockCoverageInference::BlockSet +BlockCoverageInference::getDependencies(const BasicBlock &BB) const { + assert(BB.getParent() == &F); + BlockSet Dependencies; + auto It = PredecessorDependencies.find(&BB); + if (It != PredecessorDependencies.end()) + Dependencies.set_union(It->second); + It = SuccessorDependencies.find(&BB); + if (It != SuccessorDependencies.end()) + Dependencies.set_union(It->second); + return Dependencies; +} + +uint64_t BlockCoverageInference::getInstrumentedBlocksHash() const { + JamCRC JC; + uint64_t Index = 0; + for (auto &BB : F) { + if (shouldInstrumentBlock(BB)) { + uint8_t Data[8]; + support::endian::write64le(Data, Index); + JC.update(Data); + } + Index++; + } + return JC.getCRC(); +} + +bool BlockCoverageInference::shouldInstrumentBlock(const BasicBlock &BB) const { + assert(BB.getParent() == &F); + auto It = PredecessorDependencies.find(&BB); + if (It != PredecessorDependencies.end() && It->second.size()) + return false; + It = SuccessorDependencies.find(&BB); + if (It != SuccessorDependencies.end() && It->second.size()) + return false; + return true; +} + +void BlockCoverageInference::findDependencies() { + assert(PredecessorDependencies.empty() && SuccessorDependencies.empty()); + // Empirical analysis shows that this algorithm finishes within 5 seconds for + // functions with fewer than 1.5K blocks. + if (F.hasFnAttribute(Attribute::NoReturn) || F.size() > 1500) { + ++NumIneligibleFunctions; + return; + } + + SmallVector<const BasicBlock *, 4> TerminalBlocks; + for (auto &BB : F) + if (succ_empty(&BB)) + TerminalBlocks.push_back(&BB); + + // Traverse the CFG backwards from the terminal blocks to make sure every + // block can reach some terminal block. Otherwise this algorithm will not work + // and we must fall back to instrumenting every block. + df_iterator_default_set<const BasicBlock *> Visited; + for (auto *BB : TerminalBlocks) + for (auto *N : inverse_depth_first_ext(BB, Visited)) + (void)N; + if (F.size() != Visited.size()) { + ++NumIneligibleFunctions; + return; + } + + // The current implementation for computing `PredecessorDependencies` and + // `SuccessorDependencies` runs in quadratic time with respect to the number + // of basic blocks. 
While we do have a more complicated linear time algorithm + // in https://arxiv.org/abs/2208.13907 we do not know if it will give a + // significant speedup in practice given that most functions tend to be + // relatively small in size for intended use cases. + auto &EntryBlock = F.getEntryBlock(); + for (auto &BB : F) { + // The set of blocks that are reachable while avoiding BB. + BlockSet ReachableFromEntry, ReachableFromTerminal; + getReachableAvoiding(EntryBlock, BB, /*IsForward=*/true, + ReachableFromEntry); + for (auto *TerminalBlock : TerminalBlocks) + getReachableAvoiding(*TerminalBlock, BB, /*IsForward=*/false, + ReachableFromTerminal); + + auto Preds = predecessors(&BB); + bool HasSuperReachablePred = llvm::any_of(Preds, [&](auto *Pred) { + return ReachableFromEntry.count(Pred) && + ReachableFromTerminal.count(Pred); + }); + if (!HasSuperReachablePred) + for (auto *Pred : Preds) + if (ReachableFromEntry.count(Pred)) + PredecessorDependencies[&BB].insert(Pred); + + auto Succs = successors(&BB); + bool HasSuperReachableSucc = llvm::any_of(Succs, [&](auto *Succ) { + return ReachableFromEntry.count(Succ) && + ReachableFromTerminal.count(Succ); + }); + if (!HasSuperReachableSucc) + for (auto *Succ : Succs) + if (ReachableFromTerminal.count(Succ)) + SuccessorDependencies[&BB].insert(Succ); + } + + if (ForceInstrumentEntry) { + // Force the entry block to be instrumented by clearing the blocks it can + // infer coverage from. + PredecessorDependencies[&EntryBlock].clear(); + SuccessorDependencies[&EntryBlock].clear(); + } + + // Construct a graph where blocks are connected if there is a mutual + // dependency between them. This graph has a special property that it contains + // only paths. + DenseMap<const BasicBlock *, BlockSet> AdjacencyList; + for (auto &BB : F) { + for (auto *Succ : successors(&BB)) { + if (SuccessorDependencies[&BB].count(Succ) && + PredecessorDependencies[Succ].count(&BB)) { + AdjacencyList[&BB].insert(Succ); + AdjacencyList[Succ].insert(&BB); + } + } + } + + // Given a path with at least one node, return the next node on the path. + auto getNextOnPath = [&](BlockSet &Path) -> const BasicBlock * { + assert(Path.size()); + auto &Neighbors = AdjacencyList[Path.back()]; + if (Path.size() == 1) { + // This is the first node on the path, return its neighbor. + assert(Neighbors.size() == 1); + return Neighbors.front(); + } else if (Neighbors.size() == 2) { + // This is the middle of the path, find the neighbor that is not on the + // path already. + assert(Path.size() >= 2); + return Path.count(Neighbors[0]) ? Neighbors[1] : Neighbors[0]; + } + // This is the end of the path. + assert(Neighbors.size() == 1); + return nullptr; + }; + + // Remove all cycles in the inferencing graph. + for (auto &BB : F) { + if (AdjacencyList[&BB].size() == 1) { + // We found the head of some path. + BlockSet Path; + Path.insert(&BB); + while (const BasicBlock *Next = getNextOnPath(Path)) + Path.insert(Next); + LLVM_DEBUG(dbgs() << "Found path: " << getBlockNames(Path) << "\n"); + + // Remove these nodes from the graph so we don't discover this path again. + for (auto *BB : Path) + AdjacencyList[BB].clear(); + + // Finally, remove the cycles. 
+ if (PredecessorDependencies[Path.front()].size()) { + for (auto *BB : Path) + if (BB != Path.back()) + SuccessorDependencies[BB].clear(); + } else { + for (auto *BB : Path) + if (BB != Path.front()) + PredecessorDependencies[BB].clear(); + } + } + } + LLVM_DEBUG(dump(dbgs())); +} + +void BlockCoverageInference::getReachableAvoiding(const BasicBlock &Start, + const BasicBlock &Avoid, + bool IsForward, + BlockSet &Reachable) const { + df_iterator_default_set<const BasicBlock *> Visited; + Visited.insert(&Avoid); + if (IsForward) { + auto Range = depth_first_ext(&Start, Visited); + Reachable.insert(Range.begin(), Range.end()); + } else { + auto Range = inverse_depth_first_ext(&Start, Visited); + Reachable.insert(Range.begin(), Range.end()); + } +} + +namespace llvm { +class DotFuncBCIInfo { +private: + const BlockCoverageInference *BCI; + const DenseMap<const BasicBlock *, bool> *Coverage; + +public: + DotFuncBCIInfo(const BlockCoverageInference *BCI, + const DenseMap<const BasicBlock *, bool> *Coverage) + : BCI(BCI), Coverage(Coverage) {} + + const Function &getFunction() { return BCI->F; } + + bool isInstrumented(const BasicBlock *BB) const { + return BCI->shouldInstrumentBlock(*BB); + } + + bool isCovered(const BasicBlock *BB) const { + return Coverage && Coverage->lookup(BB); + } + + bool isDependent(const BasicBlock *Src, const BasicBlock *Dest) const { + return BCI->getDependencies(*Src).count(Dest); + } +}; + +template <> +struct GraphTraits<DotFuncBCIInfo *> : public GraphTraits<const BasicBlock *> { + static NodeRef getEntryNode(DotFuncBCIInfo *Info) { + return &(Info->getFunction().getEntryBlock()); + } + + // nodes_iterator/begin/end - Allow iteration over all nodes in the graph + using nodes_iterator = pointer_iterator<Function::const_iterator>; + + static nodes_iterator nodes_begin(DotFuncBCIInfo *Info) { + return nodes_iterator(Info->getFunction().begin()); + } + + static nodes_iterator nodes_end(DotFuncBCIInfo *Info) { + return nodes_iterator(Info->getFunction().end()); + } + + static size_t size(DotFuncBCIInfo *Info) { + return Info->getFunction().size(); + } +}; + +template <> +struct DOTGraphTraits<DotFuncBCIInfo *> : public DefaultDOTGraphTraits { + + DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {} + + static std::string getGraphName(DotFuncBCIInfo *Info) { + return "BCI CFG for " + Info->getFunction().getName().str(); + } + + std::string getNodeLabel(const BasicBlock *Node, DotFuncBCIInfo *Info) { + return Node->getName().str(); + } + + std::string getEdgeAttributes(const BasicBlock *Src, const_succ_iterator I, + DotFuncBCIInfo *Info) { + const BasicBlock *Dest = *I; + if (Info->isDependent(Src, Dest)) + return "color=red"; + if (Info->isDependent(Dest, Src)) + return "color=blue"; + return ""; + } + + std::string getNodeAttributes(const BasicBlock *Node, DotFuncBCIInfo *Info) { + std::string Result; + if (Info->isInstrumented(Node)) + Result += "style=filled,fillcolor=gray"; + if (Info->isCovered(Node)) + Result += std::string(Result.empty() ? 
"" : ",") + "color=red"; + return Result; + } +}; + +} // namespace llvm + +void BlockCoverageInference::viewBlockCoverageGraph( + const DenseMap<const BasicBlock *, bool> *Coverage) const { + DotFuncBCIInfo Info(this, Coverage); + WriteGraph(&Info, "BCI", false, + "Block Coverage Inference for " + F.getName()); +} + +void BlockCoverageInference::dump(raw_ostream &OS) const { + OS << "Minimal block coverage for function \'" << F.getName() + << "\' (Instrumented=*)\n"; + for (auto &BB : F) { + OS << (shouldInstrumentBlock(BB) ? "* " : " ") << BB.getName() << "\n"; + auto It = PredecessorDependencies.find(&BB); + if (It != PredecessorDependencies.end() && It->second.size()) + OS << " PredDeps = " << getBlockNames(It->second) << "\n"; + It = SuccessorDependencies.find(&BB); + if (It != SuccessorDependencies.end() && It->second.size()) + OS << " SuccDeps = " << getBlockNames(It->second) << "\n"; + } + OS << " Instrumented Blocks Hash = 0x" + << Twine::utohexstr(getInstrumentedBlocksHash()) << "\n"; +} + +std::string +BlockCoverageInference::getBlockNames(ArrayRef<const BasicBlock *> BBs) { + std::string Result; + raw_string_ostream OS(Result); + OS << "["; + if (!BBs.empty()) { + OS << BBs.front()->getName(); + BBs = BBs.drop_front(); + } + for (auto *BB : BBs) + OS << ", " << BB->getName(); + OS << "]"; + return OS.str(); +} diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index 8b1d39ad412f..709095184af5 100644 --- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -23,8 +23,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -56,7 +54,7 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal, const DataLayout &DL, TargetLibraryInfo &TLI, ObjectSizeOffsetEvaluator &ObjSizeEval, BuilderTy &IRB, ScalarEvolution &SE) { - uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType()); + TypeSize NeededSize = DL.getTypeStoreSize(InstVal->getType()); LLVM_DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize) << " bytes\n"); @@ -71,8 +69,8 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal, Value *Offset = SizeOffset.second; ConstantInt *SizeCI = dyn_cast<ConstantInt>(Size); - Type *IntTy = DL.getIntPtrType(Ptr->getType()); - Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize); + Type *IndexTy = DL.getIndexType(Ptr->getType()); + Value *NeededSizeVal = IRB.CreateTypeSize(IndexTy, NeededSize); auto SizeRange = SE.getUnsignedRange(SE.getSCEV(Size)); auto OffsetRange = SE.getUnsignedRange(SE.getSCEV(Offset)); @@ -97,7 +95,7 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal, Value *Or = IRB.CreateOr(Cmp2, Cmp3); if ((!SizeCI || SizeCI->getValue().slt(0)) && !SizeRange.getSignedMin().isNonNegative()) { - Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0)); + Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IndexTy, 0)); Or = IRB.CreateOr(Cmp1, Or); } diff --git a/llvm/lib/Transforms/Instrumentation/CFGMST.h b/llvm/lib/Transforms/Instrumentation/CFGMST.h deleted file mode 100644 index 2abe8d12de3c..000000000000 --- a/llvm/lib/Transforms/Instrumentation/CFGMST.h +++ /dev/null @@ -1,303 +0,0 @@ -//===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===// 
-// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a Union-find algorithm to compute Minimum Spanning Tree -// for a given CFG. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H -#define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/BranchProbabilityInfo.h" -#include "llvm/Analysis/CFG.h" -#include "llvm/Support/BranchProbability.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include <utility> -#include <vector> - -#define DEBUG_TYPE "cfgmst" - -using namespace llvm; - -namespace llvm { - -/// An union-find based Minimum Spanning Tree for CFG -/// -/// Implements a Union-find algorithm to compute Minimum Spanning Tree -/// for a given CFG. -template <class Edge, class BBInfo> class CFGMST { -public: - Function &F; - - // Store all the edges in CFG. It may contain some stale edges - // when Removed is set. - std::vector<std::unique_ptr<Edge>> AllEdges; - - // This map records the auxiliary information for each BB. - DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos; - - // Whehter the function has an exit block with no successors. - // (For function with an infinite loop, this block may be absent) - bool ExitBlockFound = false; - - // Find the root group of the G and compress the path from G to the root. - BBInfo *findAndCompressGroup(BBInfo *G) { - if (G->Group != G) - G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group)); - return static_cast<BBInfo *>(G->Group); - } - - // Union BB1 and BB2 into the same group and return true. - // Returns false if BB1 and BB2 are already in the same group. - bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) { - BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1)); - BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2)); - - if (BB1G == BB2G) - return false; - - // Make the smaller rank tree a direct child or the root of high rank tree. - if (BB1G->Rank < BB2G->Rank) - BB1G->Group = BB2G; - else { - BB2G->Group = BB1G; - // If the ranks are the same, increment root of one tree by one. - if (BB1G->Rank == BB2G->Rank) - BB1G->Rank++; - } - return true; - } - - // Give BB, return the auxiliary information. - BBInfo &getBBInfo(const BasicBlock *BB) const { - auto It = BBInfos.find(BB); - assert(It->second.get() != nullptr); - return *It->second.get(); - } - - // Give BB, return the auxiliary information if it's available. - BBInfo *findBBInfo(const BasicBlock *BB) const { - auto It = BBInfos.find(BB); - if (It == BBInfos.end()) - return nullptr; - return It->second.get(); - } - - // Traverse the CFG using a stack. Find all the edges and assign the weight. - // Edges with large weight will be put into MST first so they are less likely - // to be instrumented. - void buildEdges() { - LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n"); - - const BasicBlock *Entry = &(F.getEntryBlock()); - uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2); - // If we want to instrument the entry count, lower the weight to 0. 
- if (InstrumentFuncEntry) - EntryWeight = 0; - Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr, - *ExitOutgoing = nullptr, *ExitIncoming = nullptr; - uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0; - - // Add a fake edge to the entry. - EntryIncoming = &addEdge(nullptr, Entry, EntryWeight); - LLVM_DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName() - << " w = " << EntryWeight << "\n"); - - // Special handling for single BB functions. - if (succ_empty(Entry)) { - addEdge(Entry, nullptr, EntryWeight); - return; - } - - static const uint32_t CriticalEdgeMultiplier = 1000; - - for (BasicBlock &BB : F) { - Instruction *TI = BB.getTerminator(); - uint64_t BBWeight = - (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2); - uint64_t Weight = 2; - if (int successors = TI->getNumSuccessors()) { - for (int i = 0; i != successors; ++i) { - BasicBlock *TargetBB = TI->getSuccessor(i); - bool Critical = isCriticalEdge(TI, i); - uint64_t scaleFactor = BBWeight; - if (Critical) { - if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier) - scaleFactor *= CriticalEdgeMultiplier; - else - scaleFactor = UINT64_MAX; - } - if (BPI != nullptr) - Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor); - if (Weight == 0) - Weight++; - auto *E = &addEdge(&BB, TargetBB, Weight); - E->IsCritical = Critical; - LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to " - << TargetBB->getName() << " w=" << Weight << "\n"); - - // Keep track of entry/exit edges: - if (&BB == Entry) { - if (Weight > MaxEntryOutWeight) { - MaxEntryOutWeight = Weight; - EntryOutgoing = E; - } - } - - auto *TargetTI = TargetBB->getTerminator(); - if (TargetTI && !TargetTI->getNumSuccessors()) { - if (Weight > MaxExitInWeight) { - MaxExitInWeight = Weight; - ExitIncoming = E; - } - } - } - } else { - ExitBlockFound = true; - Edge *ExitO = &addEdge(&BB, nullptr, BBWeight); - if (BBWeight > MaxExitOutWeight) { - MaxExitOutWeight = BBWeight; - ExitOutgoing = ExitO; - } - LLVM_DEBUG(dbgs() << " Edge: from " << BB.getName() << " to fake exit" - << " w = " << BBWeight << "\n"); - } - } - - // Entry/exit edge adjustment heurisitic: - // prefer instrumenting entry edge over exit edge - // if possible. Those exit edges may never have a chance to be - // executed (for instance the program is an event handling loop) - // before the profile is asynchronously dumped. - // - // If EntryIncoming and ExitOutgoing has similar weight, make sure - // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing - // and ExitIncoming has similar weight, make sure ExitIncoming becomes - // the min-edge. - uint64_t EntryInWeight = EntryWeight; - - if (EntryInWeight >= MaxExitOutWeight && - EntryInWeight * 2 < MaxExitOutWeight * 3) { - EntryIncoming->Weight = MaxExitOutWeight; - ExitOutgoing->Weight = EntryInWeight + 1; - } - - if (MaxEntryOutWeight >= MaxExitInWeight && - MaxEntryOutWeight * 2 < MaxExitInWeight * 3) { - EntryOutgoing->Weight = MaxExitInWeight; - ExitIncoming->Weight = MaxEntryOutWeight + 1; - } - } - - // Sort CFG edges based on its weight. - void sortEdgesByWeight() { - llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1, - const std::unique_ptr<Edge> &Edge2) { - return Edge1->Weight > Edge2->Weight; - }); - } - - // Traverse all the edges and compute the Minimum Weight Spanning Tree - // using union-find algorithm. - void computeMinimumSpanningTree() { - // First, put all the critical edge with landing-pad as the Dest to MST. 
- // This works around the insufficient support of critical edges split - // when destination BB is a landing pad. - for (auto &Ei : AllEdges) { - if (Ei->Removed) - continue; - if (Ei->IsCritical) { - if (Ei->DestBB && Ei->DestBB->isLandingPad()) { - if (unionGroups(Ei->SrcBB, Ei->DestBB)) - Ei->InMST = true; - } - } - } - - for (auto &Ei : AllEdges) { - if (Ei->Removed) - continue; - // If we detect infinite loops, force - // instrumenting the entry edge: - if (!ExitBlockFound && Ei->SrcBB == nullptr) - continue; - if (unionGroups(Ei->SrcBB, Ei->DestBB)) - Ei->InMST = true; - } - } - - // Dump the Debug information about the instrumentation. - void dumpEdges(raw_ostream &OS, const Twine &Message) const { - if (!Message.str().empty()) - OS << Message << "\n"; - OS << " Number of Basic Blocks: " << BBInfos.size() << "\n"; - for (auto &BI : BBInfos) { - const BasicBlock *BB = BI.first; - OS << " BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << " " - << BI.second->infoString() << "\n"; - } - - OS << " Number of Edges: " << AllEdges.size() - << " (*: Instrument, C: CriticalEdge, -: Removed)\n"; - uint32_t Count = 0; - for (auto &EI : AllEdges) - OS << " Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->" - << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n"; - } - - // Add an edge to AllEdges with weight W. - Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) { - uint32_t Index = BBInfos.size(); - auto Iter = BBInfos.end(); - bool Inserted; - std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr)); - if (Inserted) { - // Newly inserted, update the real info. - Iter->second = std::move(std::make_unique<BBInfo>(Index)); - Index++; - } - std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr)); - if (Inserted) - // Newly inserted, update the real info. - Iter->second = std::move(std::make_unique<BBInfo>(Index)); - AllEdges.emplace_back(new Edge(Src, Dest, W)); - return *AllEdges.back(); - } - - BranchProbabilityInfo *BPI; - BlockFrequencyInfo *BFI; - - // If function entry will be always instrumented. 
- bool InstrumentFuncEntry; - -public: - CFGMST(Function &Func, bool InstrumentFuncEntry_, - BranchProbabilityInfo *BPI_ = nullptr, - BlockFrequencyInfo *BFI_ = nullptr) - : F(Func), BPI(BPI_), BFI(BFI_), - InstrumentFuncEntry(InstrumentFuncEntry_) { - buildEdges(); - sortEdgesByWeight(); - computeMinimumSpanningTree(); - if (AllEdges.size() > 1 && InstrumentFuncEntry) - std::iter_swap(std::move(AllEdges.begin()), - std::move(AllEdges.begin() + AllEdges.size() - 1)); - } -}; - -} // end namespace llvm - -#undef DEBUG_TYPE // "cfgmst" - -#endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H diff --git a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp index 1c630e9ee424..d53e12ad1ff5 100644 --- a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp @@ -15,7 +15,6 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Transforms/Instrumentation.h" #include <optional> @@ -46,8 +45,7 @@ addModuleFlags(Module &M, } static bool runCGProfilePass( - Module &M, function_ref<BlockFrequencyInfo &(Function &)> GetBFI, - function_ref<TargetTransformInfo &(Function &)> GetTTI, bool LazyBFI) { + Module &M, FunctionAnalysisManager &FAM) { MapVector<std::pair<Function *, Function *>, uint64_t> Counts; InstrProfSymtab Symtab; auto UpdateCounts = [&](TargetTransformInfo &TTI, Function *F, @@ -64,15 +62,13 @@ static bool runCGProfilePass( (void)(bool) Symtab.create(M); for (auto &F : M) { // Avoid extra cost of running passes for BFI when the function doesn't have - // entry count. Since LazyBlockFrequencyInfoPass only exists in LPM, check - // if using LazyBlockFrequencyInfoPass. - // TODO: Remove LazyBFI when LazyBlockFrequencyInfoPass is available in NPM. - if (F.isDeclaration() || (LazyBFI && !F.getEntryCount())) + // entry count. 
+ if (F.isDeclaration() || !F.getEntryCount()) continue; - auto &BFI = GetBFI(F); + auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); if (BFI.getEntryFreq() == 0) continue; - TargetTransformInfo &TTI = GetTTI(F); + TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); for (auto &BB : F) { std::optional<uint64_t> BBCount = BFI.getBlockProfileCount(&BB); if (!BBCount) @@ -105,14 +101,7 @@ static bool runCGProfilePass( PreservedAnalyses CGProfilePass::run(Module &M, ModuleAnalysisManager &MAM) { FunctionAnalysisManager &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - auto GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & { - return FAM.getResult<BlockFrequencyAnalysis>(F); - }; - auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & { - return FAM.getResult<TargetIRAnalysis>(F); - }; - - runCGProfilePass(M, GetBFI, GetTTI, false); + runCGProfilePass(M, FAM); return PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index a072ba278fce..3e3be536defc 100644 --- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -30,7 +30,6 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ProfDataUtils.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/MemoryBuffer.h" @@ -1888,8 +1887,7 @@ void CHR::fixupBranch(Region *R, CHRScope *Scope, assert((IsTrueBiased || Scope->FalseBiasedRegions.count(R)) && "Must be truthy or falsy"); auto *BI = cast<BranchInst>(R->getEntry()->getTerminator()); - assert(BranchBiasMap.find(R) != BranchBiasMap.end() && - "Must be in the bias map"); + assert(BranchBiasMap.contains(R) && "Must be in the bias map"); BranchProbability Bias = BranchBiasMap[R]; assert(Bias >= getCHRBiasThreshold() && "Must be highly biased"); // Take the min. @@ -1931,8 +1929,7 @@ void CHR::fixupSelect(SelectInst *SI, CHRScope *Scope, bool IsTrueBiased = Scope->TrueBiasedSelects.count(SI); assert((IsTrueBiased || Scope->FalseBiasedSelects.count(SI)) && "Must be biased"); - assert(SelectBiasMap.find(SI) != SelectBiasMap.end() && - "Must be in the bias map"); + assert(SelectBiasMap.contains(SI) && "Must be in the bias map"); BranchProbability Bias = SelectBiasMap[SI]; assert(Bias >= getCHRBiasThreshold() && "Must be highly biased"); // Take the min. @@ -1962,11 +1959,8 @@ void CHR::addToMergedCondition(bool IsTrueBiased, Value *Cond, Cond = IRB.CreateXor(ConstantInt::getTrue(F.getContext()), Cond); } - // Select conditions can be poison, while branching on poison is immediate - // undefined behavior. As such, we need to freeze potentially poisonous - // conditions derived from selects. - if (isa<SelectInst>(BranchOrSelect) && - !isGuaranteedNotToBeUndefOrPoison(Cond)) + // Freeze potentially poisonous conditions. + if (!isGuaranteedNotToBeUndefOrPoison(Cond)) Cond = IRB.CreateFreeze(Cond); // Use logical and to avoid propagating poison from later conditions. 
@@ -2080,10 +2074,14 @@ ControlHeightReductionPass::ControlHeightReductionPass() { PreservedAnalyses ControlHeightReductionPass::run( Function &F, FunctionAnalysisManager &FAM) { + auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F); + auto PPSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + // If there is no profile summary, we should not do CHR. + if (!PPSI || !PPSI->hasProfileSummary()) + return PreservedAnalyses::all(); + auto &PSI = *PPSI; auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); - auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F); - auto &PSI = *MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); auto &RI = FAM.getResult<RegionInfoAnalysis>(F); auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); bool Changed = CHR(F, BFI, DT, PSI, RI, ORE).run(); diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index e9614b48fde7..8caee5bed8ed 100644 --- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -67,12 +67,13 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" -#include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -96,14 +97,13 @@ #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/SpecialCaseList.h" #include "llvm/Support/VirtualFileSystem.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" @@ -305,6 +305,14 @@ const MemoryMapParams Linux_X86_64_MemoryMapParams = { }; // NOLINTEND(readability-identifier-naming) +// loongarch64 Linux +const MemoryMapParams Linux_LoongArch64_MemoryMapParams = { + 0, // AndMask (not used) + 0x500000000000, // XorMask + 0, // ShadowBase (not used) + 0x100000000000, // OriginBase +}; + namespace { class DFSanABIList { @@ -1128,6 +1136,9 @@ bool DataFlowSanitizer::initializeModule(Module &M) { case Triple::x86_64: MapParams = &Linux_X86_64_MemoryMapParams; break; + case Triple::loongarch64: + MapParams = &Linux_LoongArch64_MemoryMapParams; + break; default: report_fatal_error("unsupported architecture"); } @@ -1256,7 +1267,7 @@ void DataFlowSanitizer::addGlobalNameSuffix(GlobalValue *GV) { size_t Pos = Asm.find(SearchStr); if (Pos != std::string::npos) { Asm.replace(Pos, SearchStr.size(), ".symver " + GVName + Suffix + ","); - Pos = Asm.find("@"); + Pos = Asm.find('@'); if (Pos == std::string::npos) report_fatal_error(Twine("unsupported .symver: ", Asm)); @@ -2156,9 +2167,8 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowFast( ShadowSize == 4 ? 
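The loadShadowFast changes in the DataFlowSanitizer hunk above keep the same fast-path idea while dropping the now-unnecessary bitcast: shadow is read in wide (32/64-bit) chunks and the chunks are OR-ed together, so a single comparison tells whether any label in the range is set. The sketch below is a conceptual, host-side model of that combining step only; the array, names, and tail handling are assumptions, and the real pass of course emits IR loads at ShadowAlign alignment rather than calling a helper.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Combine Size shadow bytes eight at a time; any nonzero label anywhere in
// the range leaves the combined value nonzero.
uint64_t combineShadow(const uint8_t *Shadow, std::size_t Size) {
  uint64_t Combined = 0;
  std::size_t Ofs = 0;
  for (; Ofs + 8 <= Size; Ofs += 8) {
    uint64_t Wide;
    std::memcpy(&Wide, Shadow + Ofs, sizeof(Wide));
    Combined |= Wide;
  }
  for (; Ofs < Size; ++Ofs)   // leftover tail bytes
    Combined |= Shadow[Ofs];
  return Combined;            // zero means "no taint label set in this range"
}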
Type::getInt32Ty(*DFS.Ctx) : Type::getInt64Ty(*DFS.Ctx); IRBuilder<> IRB(Pos); - Value *WideAddr = IRB.CreateBitCast(ShadowAddr, WideShadowTy->getPointerTo()); Value *CombinedWideShadow = - IRB.CreateAlignedLoad(WideShadowTy, WideAddr, ShadowAlign); + IRB.CreateAlignedLoad(WideShadowTy, ShadowAddr, ShadowAlign); unsigned WideShadowBitWidth = WideShadowTy->getIntegerBitWidth(); const uint64_t BytesPerWideShadow = WideShadowBitWidth / DFS.ShadowWidthBits; @@ -2195,10 +2205,10 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowFast( // shadow). for (uint64_t ByteOfs = BytesPerWideShadow; ByteOfs < Size; ByteOfs += BytesPerWideShadow) { - WideAddr = IRB.CreateGEP(WideShadowTy, WideAddr, - ConstantInt::get(DFS.IntptrTy, 1)); + ShadowAddr = IRB.CreateGEP(WideShadowTy, ShadowAddr, + ConstantInt::get(DFS.IntptrTy, 1)); Value *NextWideShadow = - IRB.CreateAlignedLoad(WideShadowTy, WideAddr, ShadowAlign); + IRB.CreateAlignedLoad(WideShadowTy, ShadowAddr, ShadowAlign); CombinedWideShadow = IRB.CreateOr(CombinedWideShadow, NextWideShadow); if (ShouldTrackOrigins) { Value *NextOrigin = DFS.loadNextOrigin(Pos, OriginAlign, &OriginAddr); @@ -2526,8 +2536,9 @@ void DFSanFunction::storeOrigin(Instruction *Pos, Value *Addr, uint64_t Size, ConstantInt::get(DFS.IntptrTy, Size), Origin}); } else { Value *Cmp = convertToBool(CollapsedShadow, IRB, "_dfscmp"); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); Instruction *CheckTerm = SplitBlockAndInsertIfThen( - Cmp, &*IRB.GetInsertPoint(), false, DFS.OriginStoreWeights, &DT); + Cmp, &*IRB.GetInsertPoint(), false, DFS.OriginStoreWeights, &DTU); IRBuilder<> IRBNew(CheckTerm); paintOrigin(IRBNew, updateOrigin(Origin, IRBNew), StoreOriginAddr, Size, OriginAlignment); diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 9f3ca8b02fd9..21f0b1a92293 100644 --- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -13,7 +13,6 @@ // //===----------------------------------------------------------------------===// -#include "CFGMST.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" @@ -21,10 +20,10 @@ #include "llvm/ADT/StringMap.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugLoc.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" @@ -38,6 +37,7 @@ #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Instrumentation/CFGMST.h" #include "llvm/Transforms/Instrumentation/GCOVProfiler.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> @@ -919,15 +919,21 @@ bool GCOVProfiler::emitProfileNotes( IRBuilder<> Builder(E.Place, E.Place->getFirstInsertionPt()); Value *V = Builder.CreateConstInBoundsGEP2_64( Counters->getValueType(), Counters, 0, I); + // Disable sanitizers to decrease size bloat. We don't expect + // sanitizers to catch interesting issues. 
+ Instruction *Inst; if (Options.Atomic) { - Builder.CreateAtomicRMW(AtomicRMWInst::Add, V, Builder.getInt64(1), - MaybeAlign(), AtomicOrdering::Monotonic); + Inst = Builder.CreateAtomicRMW(AtomicRMWInst::Add, V, + Builder.getInt64(1), MaybeAlign(), + AtomicOrdering::Monotonic); } else { - Value *Count = + LoadInst *OldCount = Builder.CreateLoad(Builder.getInt64Ty(), V, "gcov_ctr"); - Count = Builder.CreateAdd(Count, Builder.getInt64(1)); - Builder.CreateStore(Count, V); + OldCount->setNoSanitizeMetadata(); + Value *NewCount = Builder.CreateAdd(OldCount, Builder.getInt64(1)); + Inst = Builder.CreateStore(NewCount, V); } + Inst->setNoSanitizeMetadata(); } } } diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 34c61f83ad30..28db47a19092 100644 --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -1,4 +1,4 @@ -//===- HWAddressSanitizer.cpp - detector of uninitialized reads -------===// +//===- HWAddressSanitizer.cpp - memory access error detector --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -17,7 +17,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/StackSafetyAnalysis.h" @@ -50,6 +49,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/MemoryTaggingSupport.h" @@ -136,14 +136,6 @@ static cl::opt<bool> cl::desc("detect use after scope within function"), cl::Hidden, cl::init(false)); -static cl::opt<bool> ClUARRetagToZero( - "hwasan-uar-retag-to-zero", - cl::desc("Clear alloca tags before returning from the function to allow " - "non-instrumented and instrumented function calls mix. When set " - "to false, allocas are retagged before returning from the " - "function to detect use after return."), - cl::Hidden, cl::init(true)); - static cl::opt<bool> ClGenerateTagsWithCalls( "hwasan-generate-tags-with-calls", cl::desc("generate new tags with runtime library calls"), cl::Hidden, @@ -247,7 +239,9 @@ bool shouldInstrumentStack(const Triple &TargetTriple) { } bool shouldInstrumentWithCalls(const Triple &TargetTriple) { - return ClInstrumentWithCalls || TargetTriple.getArch() == Triple::x86_64; + return ClInstrumentWithCalls.getNumOccurrences() + ? 
ClInstrumentWithCalls + : TargetTriple.getArch() == Triple::x86_64; } bool mightUseStackSafetyAnalysis(bool DisableOptimization) { @@ -282,7 +276,7 @@ public: void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; } - bool sanitizeFunction(Function &F, FunctionAnalysisManager &FAM); + void sanitizeFunction(Function &F, FunctionAnalysisManager &FAM); void initializeModule(); void createHwasanCtorComdat(); @@ -313,16 +307,15 @@ public: void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); - bool instrumentStack(memtag::StackInfo &Info, Value *StackTag, + bool instrumentStack(memtag::StackInfo &Info, Value *StackTag, Value *UARTag, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI); Value *readRegister(IRBuilder<> &IRB, StringRef Name); bool instrumentLandingPads(SmallVectorImpl<Instruction *> &RetVec); Value *getNextTagWithCall(IRBuilder<> &IRB); Value *getStackBaseTag(IRBuilder<> &IRB); - Value *getAllocaTag(IRBuilder<> &IRB, Value *StackTag, AllocaInst *AI, - unsigned AllocaNo); - Value *getUARTag(IRBuilder<> &IRB, Value *StackTag); + Value *getAllocaTag(IRBuilder<> &IRB, Value *StackTag, unsigned AllocaNo); + Value *getUARTag(IRBuilder<> &IRB); Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty); Value *applyTagMask(IRBuilder<> &IRB, Value *OldTag); @@ -344,8 +337,6 @@ private: Module &M; const StackSafetyGlobalInfo *SSI; Triple TargetTriple; - FunctionCallee HWAsanMemmove, HWAsanMemcpy, HWAsanMemset; - FunctionCallee HWAsanHandleVfork; /// This struct defines the shadow mapping using the rule: /// shadow = (mem >> Scale) + Offset. @@ -387,6 +378,7 @@ private: bool InstrumentStack; bool DetectUseAfterScope; bool UsePageAliases; + bool UseMatchAllCallback; std::optional<uint8_t> MatchAllTag; @@ -398,6 +390,9 @@ private: FunctionCallee HwasanMemoryAccessCallback[2][kNumberOfAccessSizes]; FunctionCallee HwasanMemoryAccessCallbackSized[2]; + FunctionCallee HwasanMemmove, HwasanMemcpy, HwasanMemset; + FunctionCallee HwasanHandleVfork; + FunctionCallee HwasanTagMemoryFunc; FunctionCallee HwasanGenerateTagFunc; FunctionCallee HwasanRecordFrameRecordFunc; @@ -420,12 +415,9 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M, SSI = &MAM.getResult<StackSafetyGlobalAnalysis>(M); HWAddressSanitizer HWASan(M, Options.CompileKernel, Options.Recover, SSI); - bool Modified = false; auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); for (Function &F : M) - Modified |= HWASan.sanitizeFunction(F, FAM); - if (!Modified) - return PreservedAnalyses::all(); + HWASan.sanitizeFunction(F, FAM); PreservedAnalyses PA = PreservedAnalyses::none(); // GlobalsAA is considered stateless and does not get invalidated unless @@ -438,12 +430,12 @@ void HWAddressSanitizerPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<HWAddressSanitizerPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (Options.CompileKernel) OS << "kernel;"; if (Options.Recover) OS << "recover"; - OS << ">"; + OS << '>'; } void HWAddressSanitizer::createHwasanCtorComdat() { @@ -594,6 +586,7 @@ void HWAddressSanitizer::initializeModule() { } else if (CompileKernel) { MatchAllTag = 0xFF; } + UseMatchAllCallback = !CompileKernel && MatchAllTag.has_value(); // If we don't have personality function support, fall back to landing pads. 
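The new MatchAllTag/UseMatchAllCallback plumbing above lets the checking callbacks skip reports when a pointer carries the designated "match-all" tag (0xFF in kernel mode), which is why the *_match_all callback variants take an extra i8 argument. A hedged, source-level sketch of that check; the helper names and the direct shadow-tag parameter are illustrative, not the runtime's API:

#include <cstdint>

constexpr unsigned kPointerTagShift = 56;  // tag lives in the pointer's top byte, as in the pass
constexpr uint8_t kMatchAllTag = 0xFF;     // kernel-mode value from the hunk above

inline uint8_t pointerTag(uint64_t Addr) {
  return static_cast<uint8_t>(Addr >> kPointerTagShift);
}

// Returns true when a mismatch should be reported. A pointer whose tag equals
// the match-all tag is always accepted.
bool shouldReport(uint64_t Addr, uint8_t ShadowTag) {
  uint8_t Tag = pointerTag(Addr);
  if (Tag == kMatchAllTag)
    return false;
  return Tag != ShadowTag;
}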
InstrumentLandingPads = ClInstrumentLandingPads.getNumOccurrences() @@ -631,51 +624,73 @@ void HWAddressSanitizer::initializeModule() { void HWAddressSanitizer::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); + const std::string MatchAllStr = UseMatchAllCallback ? "_match_all" : ""; + FunctionType *HwasanMemoryAccessCallbackSizedFnTy, + *HwasanMemoryAccessCallbackFnTy, *HwasanMemTransferFnTy, + *HwasanMemsetFnTy; + if (UseMatchAllCallback) { + HwasanMemoryAccessCallbackSizedFnTy = + FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false); + HwasanMemoryAccessCallbackFnTy = + FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false); + HwasanMemTransferFnTy = FunctionType::get( + Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy, Int8Ty}, false); + HwasanMemsetFnTy = FunctionType::get( + Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}, false); + } else { + HwasanMemoryAccessCallbackSizedFnTy = + FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false); + HwasanMemoryAccessCallbackFnTy = + FunctionType::get(VoidTy, {IntptrTy}, false); + HwasanMemTransferFnTy = + FunctionType::get(Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy}, false); + HwasanMemsetFnTy = + FunctionType::get(Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy}, false); + } + for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) { const std::string TypeStr = AccessIsWrite ? "store" : "load"; const std::string EndingStr = Recover ? "_noabort" : ""; HwasanMemoryAccessCallbackSized[AccessIsWrite] = M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + TypeStr + "N" + EndingStr, - FunctionType::get(IRB.getVoidTy(), {IntptrTy, IntptrTy}, false)); + ClMemoryAccessCallbackPrefix + TypeStr + "N" + MatchAllStr + EndingStr, + HwasanMemoryAccessCallbackSizedFnTy); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { HwasanMemoryAccessCallback[AccessIsWrite][AccessSizeIndex] = - M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + TypeStr + - itostr(1ULL << AccessSizeIndex) + EndingStr, - FunctionType::get(IRB.getVoidTy(), {IntptrTy}, false)); + M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + TypeStr + + itostr(1ULL << AccessSizeIndex) + + MatchAllStr + EndingStr, + HwasanMemoryAccessCallbackFnTy); } } - HwasanTagMemoryFunc = M.getOrInsertFunction( - "__hwasan_tag_memory", IRB.getVoidTy(), Int8PtrTy, Int8Ty, IntptrTy); + const std::string MemIntrinCallbackPrefix = + (CompileKernel && !ClKasanMemIntrinCallbackPrefix) + ? std::string("") + : ClMemoryAccessCallbackPrefix; + + HwasanMemmove = M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memmove" + MatchAllStr, HwasanMemTransferFnTy); + HwasanMemcpy = M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memcpy" + MatchAllStr, HwasanMemTransferFnTy); + HwasanMemset = M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memset" + MatchAllStr, HwasanMemsetFnTy); + + HwasanTagMemoryFunc = M.getOrInsertFunction("__hwasan_tag_memory", VoidTy, + Int8PtrTy, Int8Ty, IntptrTy); HwasanGenerateTagFunc = M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty); - HwasanRecordFrameRecordFunc = M.getOrInsertFunction( - "__hwasan_add_frame_record", IRB.getVoidTy(), Int64Ty); + HwasanRecordFrameRecordFunc = + M.getOrInsertFunction("__hwasan_add_frame_record", VoidTy, Int64Ty); - ShadowGlobal = M.getOrInsertGlobal("__hwasan_shadow", - ArrayType::get(IRB.getInt8Ty(), 0)); + ShadowGlobal = + M.getOrInsertGlobal("__hwasan_shadow", ArrayType::get(Int8Ty, 0)); - const std::string MemIntrinCallbackPrefix = - (CompileKernel && !ClKasanMemIntrinCallbackPrefix) - ? 
std::string("") - : ClMemoryAccessCallbackPrefix; - HWAsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy); - HWAsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy); - HWAsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset", - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt32Ty(), IntptrTy); - - HWAsanHandleVfork = - M.getOrInsertFunction("__hwasan_handle_vfork", IRB.getVoidTy(), IntptrTy); + HwasanHandleVfork = + M.getOrInsertFunction("__hwasan_handle_vfork", VoidTy, IntptrTy); } Value *HWAddressSanitizer::getOpaqueNoopCast(IRBuilder<> &IRB, Value *Val) { @@ -788,7 +803,7 @@ static unsigned getPointerOperandIndex(Instruction *I) { } static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { - size_t Res = countTrailingZeros(TypeSize / 8); + size_t Res = llvm::countr_zero(TypeSize / 8); assert(Res < kNumberOfAccessSizes); return Res; } @@ -847,8 +862,8 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, IRBuilder<> IRB(InsertBefore); Value *PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy); - Value *PtrTag = IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift), - IRB.getInt8Ty()); + Value *PtrTag = + IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift), Int8Ty); Value *AddrLong = untagPointer(IRB, PtrLong); Value *Shadow = memToShadow(AddrLong, IRB); Value *MemTag = IRB.CreateLoad(Int8Ty, Shadow); @@ -897,7 +912,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, case Triple::x86_64: // The signal handler will find the data address in rdi. Asm = InlineAsm::get( - FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), + FunctionType::get(VoidTy, {PtrLong->getType()}, false), "int3\nnopl " + itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)) + "(%rax)", @@ -908,7 +923,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, case Triple::aarch64_be: // The signal handler will find the data address in x0. Asm = InlineAsm::get( - FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), + FunctionType::get(VoidTy, {PtrLong->getType()}, false), "brk #" + itostr(0x900 + (AccessInfo & HWASanAccessInfo::RuntimeMask)), "{x0}", /*hasSideEffects=*/true); @@ -916,7 +931,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, case Triple::riscv64: // The signal handler will find the data address in x10. Asm = InlineAsm::get( - FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), + FunctionType::get(VoidTy, {PtrLong->getType()}, false), "ebreak\naddiw x0, x11, " + itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)), "{x10}", @@ -943,17 +958,35 @@ bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) { void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { IRBuilder<> IRB(MI); if (isa<MemTransferInst>(MI)) { - IRB.CreateCall( - isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + if (UseMatchAllCallback) { + IRB.CreateCall( + isa<MemMoveInst>(MI) ? 
HwasanMemmove : HwasanMemcpy, + {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false), + ConstantInt::get(Int8Ty, *MatchAllTag)}); + } else { + IRB.CreateCall( + isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy, + {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + } } else if (isa<MemSetInst>(MI)) { - IRB.CreateCall( - HWAsanMemset, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + if (UseMatchAllCallback) { + IRB.CreateCall( + HwasanMemset, + {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false), + ConstantInt::get(Int8Ty, *MatchAllTag)}); + } else { + IRB.CreateCall( + HwasanMemset, + {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + } } MI->eraseFromParent(); } @@ -967,23 +1000,40 @@ bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) { return false; // FIXME IRBuilder<> IRB(O.getInsn()); - if (isPowerOf2_64(O.TypeSize) && - (O.TypeSize / 8 <= (1ULL << (kNumberOfAccessSizes - 1))) && + if (!O.TypeStoreSize.isScalable() && isPowerOf2_64(O.TypeStoreSize) && + (O.TypeStoreSize / 8 <= (1ULL << (kNumberOfAccessSizes - 1))) && (!O.Alignment || *O.Alignment >= Mapping.getObjectAlignment() || - *O.Alignment >= O.TypeSize / 8)) { - size_t AccessSizeIndex = TypeSizeToSizeIndex(O.TypeSize); + *O.Alignment >= O.TypeStoreSize / 8)) { + size_t AccessSizeIndex = TypeSizeToSizeIndex(O.TypeStoreSize); if (InstrumentWithCalls) { - IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex], - IRB.CreatePointerCast(Addr, IntptrTy)); + if (UseMatchAllCallback) { + IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex], + {IRB.CreatePointerCast(Addr, IntptrTy), + ConstantInt::get(Int8Ty, *MatchAllTag)}); + } else { + IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex], + IRB.CreatePointerCast(Addr, IntptrTy)); + } } else if (OutlinedChecks) { instrumentMemAccessOutline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn()); } else { instrumentMemAccessInline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn()); } } else { - IRB.CreateCall(HwasanMemoryAccessCallbackSized[O.IsWrite], - {IRB.CreatePointerCast(Addr, IntptrTy), - ConstantInt::get(IntptrTy, O.TypeSize / 8)}); + if (UseMatchAllCallback) { + IRB.CreateCall( + HwasanMemoryAccessCallbackSized[O.IsWrite], + {IRB.CreatePointerCast(Addr, IntptrTy), + IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize), + ConstantInt::get(IntptrTy, 8)), + ConstantInt::get(Int8Ty, *MatchAllTag)}); + } else { + IRB.CreateCall( + HwasanMemoryAccessCallbackSized[O.IsWrite], + {IRB.CreatePointerCast(Addr, IntptrTy), + IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize), + ConstantInt::get(IntptrTy, 8))}); + } } untagPointerOperand(O.getInsn(), Addr); @@ -996,14 +1046,15 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, if (!UseShortGranules) Size = AlignedSize; - Value *JustTag = IRB.CreateTrunc(Tag, 
IRB.getInt8Ty()); + Tag = IRB.CreateTrunc(Tag, Int8Ty); if (InstrumentWithCalls) { IRB.CreateCall(HwasanTagMemoryFunc, - {IRB.CreatePointerCast(AI, Int8PtrTy), JustTag, + {IRB.CreatePointerCast(AI, Int8PtrTy), Tag, ConstantInt::get(IntptrTy, AlignedSize)}); } else { size_t ShadowSize = Size >> Mapping.Scale; - Value *ShadowPtr = memToShadow(IRB.CreatePointerCast(AI, IntptrTy), IRB); + Value *AddrLong = untagPointer(IRB, IRB.CreatePointerCast(AI, IntptrTy)); + Value *ShadowPtr = memToShadow(AddrLong, IRB); // If this memset is not inlined, it will be intercepted in the hwasan // runtime library. That's OK, because the interceptor skips the checks if // the address is in the shadow region. @@ -1011,14 +1062,14 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, // llvm.memset right here into either a sequence of stores, or a call to // hwasan_tag_memory. if (ShadowSize) - IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, Align(1)); + IRB.CreateMemSet(ShadowPtr, Tag, ShadowSize, Align(1)); if (Size != AlignedSize) { const uint8_t SizeRemainder = Size % Mapping.getObjectAlignment().value(); IRB.CreateStore(ConstantInt::get(Int8Ty, SizeRemainder), IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize)); - IRB.CreateStore(JustTag, IRB.CreateConstGEP1_32( - Int8Ty, IRB.CreateBitCast(AI, Int8PtrTy), - AlignedSize - 1)); + IRB.CreateStore(Tag, IRB.CreateConstGEP1_32( + Int8Ty, IRB.CreatePointerCast(AI, Int8PtrTy), + AlignedSize - 1)); } } } @@ -1037,21 +1088,18 @@ unsigned HWAddressSanitizer::retagMask(unsigned AllocaNo) { // mask allocated (temporally) nearby. The program that generated this list // can be found at: // https://github.com/google/sanitizers/blob/master/hwaddress-sanitizer/sort_masks.py - static unsigned FastMasks[] = {0, 128, 64, 192, 32, 96, 224, 112, 240, - 48, 16, 120, 248, 56, 24, 8, 124, 252, - 60, 28, 12, 4, 126, 254, 62, 30, 14, - 6, 2, 127, 63, 31, 15, 7, 3, 1}; + static const unsigned FastMasks[] = { + 0, 128, 64, 192, 32, 96, 224, 112, 240, 48, 16, 120, + 248, 56, 24, 8, 124, 252, 60, 28, 12, 4, 126, 254, + 62, 30, 14, 6, 2, 127, 63, 31, 15, 7, 3, 1}; return FastMasks[AllocaNo % std::size(FastMasks)]; } Value *HWAddressSanitizer::applyTagMask(IRBuilder<> &IRB, Value *OldTag) { - if (TargetTriple.getArch() == Triple::x86_64) { - Constant *TagMask = ConstantInt::get(IntptrTy, TagMaskByte); - Value *NewTag = IRB.CreateAnd(OldTag, TagMask); - return NewTag; - } - // aarch64 uses 8-bit tags, so no mask is needed. - return OldTag; + if (TagMaskByte == 0xFF) + return OldTag; // No need to clear the tag byte. + return IRB.CreateAnd(OldTag, + ConstantInt::get(OldTag->getType(), TagMaskByte)); } Value *HWAddressSanitizer::getNextTagWithCall(IRBuilder<> &IRB) { @@ -1060,7 +1108,7 @@ Value *HWAddressSanitizer::getNextTagWithCall(IRBuilder<> &IRB) { Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { if (ClGenerateTagsWithCalls) - return getNextTagWithCall(IRB); + return nullptr; if (StackBaseTag) return StackBaseTag; // Extract some entropy from the stack pointer for the tags. 
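getStackBaseTag, getAllocaTag, retagMask and applyTagMask above derive per-alloca tags cheaply: a base tag is taken from stack-pointer entropy, each alloca XORs in a small precomputed "retag mask" so temporally nearby allocas get distinct tags, and the result is masked down on targets whose TagMaskByte is narrower than a full byte (x86_64 aliasing mode). A standalone arithmetic sketch; the entropy mix and the shortened mask table are assumptions for illustration:

#include <cstdint>

// First few of the precomputed retag masks from the pass (the full table has 36 entries).
static const unsigned kFastMasks[] = {0, 128, 64, 192, 32, 96, 224, 112};

// AArch64 uses the whole tag byte; x86_64 aliasing mode keeps only part of it.
uint64_t applyTagMask(uint64_t Tag, uint64_t TagMaskByte) {
  return TagMaskByte == 0xFF ? Tag : (Tag & TagMaskByte);
}

uint64_t allocaTag(uint64_t StackPointer, unsigned AllocaNo, uint64_t TagMaskByte) {
  // Illustrative stand-in for the stack-pointer entropy extraction.
  uint64_t StackBaseTag = (StackPointer >> 20) ^ (StackPointer >> 56);
  uint64_t Tag = StackBaseTag ^ kFastMasks[AllocaNo % (sizeof(kFastMasks) / sizeof(kFastMasks[0]))];
  return applyTagMask(Tag & 0xFF, TagMaskByte);
}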
@@ -1075,19 +1123,20 @@ Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { } Value *HWAddressSanitizer::getAllocaTag(IRBuilder<> &IRB, Value *StackTag, - AllocaInst *AI, unsigned AllocaNo) { + unsigned AllocaNo) { if (ClGenerateTagsWithCalls) return getNextTagWithCall(IRB); - return IRB.CreateXor(StackTag, - ConstantInt::get(IntptrTy, retagMask(AllocaNo))); + return IRB.CreateXor( + StackTag, ConstantInt::get(StackTag->getType(), retagMask(AllocaNo))); } -Value *HWAddressSanitizer::getUARTag(IRBuilder<> &IRB, Value *StackTag) { - if (ClUARRetagToZero) - return ConstantInt::get(IntptrTy, 0); - if (ClGenerateTagsWithCalls) - return getNextTagWithCall(IRB); - return IRB.CreateXor(StackTag, ConstantInt::get(IntptrTy, TagMaskByte)); +Value *HWAddressSanitizer::getUARTag(IRBuilder<> &IRB) { + Value *StackPointerLong = getSP(IRB); + Value *UARTag = + applyTagMask(IRB, IRB.CreateLShr(StackPointerLong, PointerTagShift)); + + UARTag->setName("hwasan.uar.tag"); + return UARTag; } // Add a tag to an address. @@ -1117,12 +1166,12 @@ Value *HWAddressSanitizer::untagPointer(IRBuilder<> &IRB, Value *PtrLong) { // Kernel addresses have 0xFF in the most significant byte. UntaggedPtrLong = IRB.CreateOr(PtrLong, ConstantInt::get(PtrLong->getType(), - 0xFFULL << PointerTagShift)); + TagMaskByte << PointerTagShift)); } else { // Userspace addresses have 0x00. - UntaggedPtrLong = - IRB.CreateAnd(PtrLong, ConstantInt::get(PtrLong->getType(), - ~(0xFFULL << PointerTagShift))); + UntaggedPtrLong = IRB.CreateAnd( + PtrLong, ConstantInt::get(PtrLong->getType(), + ~(TagMaskByte << PointerTagShift))); } return UntaggedPtrLong; } @@ -1135,8 +1184,7 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { Function *ThreadPointerFunc = Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); Value *SlotPtr = IRB.CreatePointerCast( - IRB.CreateConstGEP1_32(IRB.getInt8Ty(), - IRB.CreateCall(ThreadPointerFunc), 0x30), + IRB.CreateConstGEP1_32(Int8Ty, IRB.CreateCall(ThreadPointerFunc), 0x30), Ty->getPointerTo(0)); return SlotPtr; } @@ -1162,8 +1210,7 @@ Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) { M, Intrinsic::frameaddress, IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); CachedSP = IRB.CreatePtrToInt( - IRB.CreateCall(GetStackPointerFn, - {Constant::getNullValue(IRB.getInt32Ty())}), + IRB.CreateCall(GetStackPointerFn, {Constant::getNullValue(Int32Ty)}), IntptrTy); } return CachedSP; @@ -1280,7 +1327,7 @@ bool HWAddressSanitizer::instrumentLandingPads( for (auto *LP : LandingPadVec) { IRBuilder<> IRB(LP->getNextNode()); IRB.CreateCall( - HWAsanHandleVfork, + HwasanHandleVfork, {readRegister(IRB, (TargetTriple.getArch() == Triple::x86_64) ? "rsp" : "sp")}); } @@ -1293,7 +1340,7 @@ static bool isLifetimeIntrinsic(Value *V) { } bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, - Value *StackTag, + Value *StackTag, Value *UARTag, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI) { @@ -1311,9 +1358,10 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, IRBuilder<> IRB(AI->getNextNode()); // Replace uses of the alloca with tagged address. - Value *Tag = getAllocaTag(IRB, StackTag, AI, N); + Value *Tag = getAllocaTag(IRB, StackTag, N); Value *AILong = IRB.CreatePointerCast(AI, IntptrTy); - Value *Replacement = tagPointer(IRB, AI->getType(), AILong, Tag); + Value *AINoTagLong = untagPointer(IRB, AILong); + Value *Replacement = tagPointer(IRB, AI->getType(), AINoTagLong, Tag); std::string Name = AI->hasName() ? 
AI->getName().str() : "alloca." + itostr(N); Replacement->setName(Name + ".hwasan"); @@ -1340,7 +1388,7 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, llvm::for_each(Info.LifetimeStart, HandleLifetime); llvm::for_each(Info.LifetimeEnd, HandleLifetime); - AI->replaceUsesWithIf(Replacement, [AICast, AILong](Use &U) { + AI->replaceUsesWithIf(Replacement, [AICast, AILong](const Use &U) { auto *User = U.getUser(); return User != AILong && User != AICast && !isLifetimeIntrinsic(User); }); @@ -1359,9 +1407,8 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, auto TagEnd = [&](Instruction *Node) { IRB.SetInsertPoint(Node); - Value *UARTag = getUARTag(IRB, StackTag); // When untagging, use the `AlignedSize` because we need to set the tags - // for the entire alloca to zero. If we used `Size` here, we would + // for the entire alloca to original. If we used `Size` here, we would // keep the last granule tagged, and store zero in the last byte of the // last granule, due to how short granules are implemented. tagAlloca(IRB, AI, UARTag, AlignedSize); @@ -1402,13 +1449,13 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, return true; } -bool HWAddressSanitizer::sanitizeFunction(Function &F, +void HWAddressSanitizer::sanitizeFunction(Function &F, FunctionAnalysisManager &FAM) { if (&F == HwasanCtorFunction) - return false; + return; if (!F.hasFnAttribute(Attribute::SanitizeHWAddress)) - return false; + return; LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n"); @@ -1436,22 +1483,19 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F, initializeCallbacks(*F.getParent()); - bool Changed = false; - if (!LandingPadVec.empty()) - Changed |= instrumentLandingPads(LandingPadVec); + instrumentLandingPads(LandingPadVec); if (SInfo.AllocasToInstrument.empty() && F.hasPersonalityFn() && F.getPersonalityFn()->getName() == kHwasanPersonalityThunkName) { // __hwasan_personality_thunk is a no-op for functions without an // instrumented stack, so we can drop it. F.setPersonalityFn(nullptr); - Changed = true; } if (SInfo.AllocasToInstrument.empty() && OperandsToInstrument.empty() && IntrinToInstrument.empty()) - return Changed; + return; assert(!ShadowBase); @@ -1466,9 +1510,9 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F, const DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); const PostDominatorTree &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F); const LoopInfo &LI = FAM.getResult<LoopAnalysis>(F); - Value *StackTag = - ClGenerateTagsWithCalls ? nullptr : getStackBaseTag(EntryIRB); - instrumentStack(SInfo, StackTag, DT, PDT, LI); + Value *StackTag = getStackBaseTag(EntryIRB); + Value *UARTag = getUARTag(EntryIRB); + instrumentStack(SInfo, StackTag, UARTag, DT, PDT, LI); } // If we split the entry block, move any allocas that were originally in the @@ -1495,8 +1539,6 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F, ShadowBase = nullptr; StackBaseTag = nullptr; CachedSP = nullptr; - - return true; } void HWAddressSanitizer::instrumentGlobal(GlobalVariable *GV, uint8_t Tag) { @@ -1605,11 +1647,14 @@ void HWAddressSanitizer::instrumentGlobals() { Hasher.final(Hash); uint8_t Tag = Hash[0]; + assert(TagMaskByte >= 16); + for (GlobalVariable *GV : Globals) { - Tag &= TagMaskByte; - // Skip tag 0 in order to avoid collisions with untagged memory. 
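tagPointer/untagPointer above (and the new getUARTag) all manipulate the tag stored in the pointer's top byte: PointerTagShift is 56, userspace pointers are untagged by clearing those bits, and kernel pointers by OR-ing the mask back in because their canonical form has 0xFF there. A minimal sketch of the userspace arithmetic, with the constants treated as assumptions:

#include <cstdint>

constexpr unsigned kPointerTagShift = 56;
constexpr uint64_t kTagMaskByte = 0xFFull;  // AArch64; x86_64 aliasing mode uses a smaller mask

uint64_t tagPointer(uint64_t UntaggedAddr, uint64_t Tag) {
  return UntaggedAddr | ((Tag & kTagMaskByte) << kPointerTagShift);
}

// Userspace form: the tag bits are simply cleared. (Kernel addresses would
// instead OR (kTagMaskByte << kPointerTagShift) back in.)
uint64_t untagPointer(uint64_t TaggedAddr) {
  return TaggedAddr & ~(kTagMaskByte << kPointerTagShift);
}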
- if (Tag == 0) - Tag = 1; + // Don't allow globals to be tagged with something that looks like a + // short-granule tag, otherwise we lose inter-granule overflow detection, as + // the fast path shadow-vs-address check succeeds. + if (Tag < 16 || Tag > TagMaskByte) + Tag = 16; instrumentGlobal(GV, Tag++); } } diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index b66e761d53b0..5c9799235017 100644 --- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -104,25 +104,24 @@ static cl::opt<bool> namespace { -// The class for main data structure to promote indirect calls to conditional -// direct calls. -class ICallPromotionFunc { +// Promote indirect calls to conditional direct calls, keeping track of +// thresholds. +class IndirectCallPromoter { private: Function &F; - Module *M; // Symtab that maps indirect call profile values to function names and // defines. - InstrProfSymtab *Symtab; + InstrProfSymtab *const Symtab; - bool SamplePGO; + const bool SamplePGO; OptimizationRemarkEmitter &ORE; // A struct that records the direct target and it's call count. struct PromotionCandidate { - Function *TargetFunction; - uint64_t Count; + Function *const TargetFunction; + const uint64_t Count; PromotionCandidate(Function *F, uint64_t C) : TargetFunction(F), Count(C) {} }; @@ -143,11 +142,11 @@ private: uint64_t &TotalCount); public: - ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab, - bool SamplePGO, OptimizationRemarkEmitter &ORE) - : F(Func), M(Modu), Symtab(Symtab), SamplePGO(SamplePGO), ORE(ORE) {} - ICallPromotionFunc(const ICallPromotionFunc &) = delete; - ICallPromotionFunc &operator=(const ICallPromotionFunc &) = delete; + IndirectCallPromoter(Function &Func, InstrProfSymtab *Symtab, bool SamplePGO, + OptimizationRemarkEmitter &ORE) + : F(Func), Symtab(Symtab), SamplePGO(SamplePGO), ORE(ORE) {} + IndirectCallPromoter(const IndirectCallPromoter &) = delete; + IndirectCallPromoter &operator=(const IndirectCallPromoter &) = delete; bool processFunction(ProfileSummaryInfo *PSI); }; @@ -156,8 +155,8 @@ public: // Indirect-call promotion heuristic. The direct targets are sorted based on // the count. Stop at the first target that is not promoted. -std::vector<ICallPromotionFunc::PromotionCandidate> -ICallPromotionFunc::getPromotionCandidatesForCallSite( +std::vector<IndirectCallPromoter::PromotionCandidate> +IndirectCallPromoter::getPromotionCandidatesForCallSite( const CallBase &CB, const ArrayRef<InstrProfValueData> &ValueDataRef, uint64_t TotalCount, uint32_t NumCandidates) { std::vector<PromotionCandidate> Ret; @@ -276,7 +275,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee, } // Promote indirect-call to conditional direct-call for one callsite. -uint32_t ICallPromotionFunc::tryToPromote( +uint32_t IndirectCallPromoter::tryToPromote( CallBase &CB, const std::vector<PromotionCandidate> &Candidates, uint64_t &TotalCount) { uint32_t NumPromoted = 0; @@ -295,7 +294,7 @@ uint32_t ICallPromotionFunc::tryToPromote( // Traverse all the indirect-call callsite and get the value profile // annotation to perform indirect-call promotion. 
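The IndirectCallPromoter changes above keep the same transformation as before the rename: a hot indirect call is rewritten into a guarded direct call, with the original indirect call kept as the fallback. The promoted call has roughly the following source-level shape; this is a sketch of the idea, not the pass's actual IR output:

// Before: the profile says `handler` almost always points at hot_handler.
//   result = handler(x);
//
// After promotion (conceptually):
int promotedCall(int (*handler)(int), int (*hot_handler)(int), int x) {
  if (handler == hot_handler)
    return hot_handler(x);   // direct call: inlinable and branch-predicted
  return handler(x);         // cold fallback keeps the original semantics
}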
-bool ICallPromotionFunc::processFunction(ProfileSummaryInfo *PSI) { +bool IndirectCallPromoter::processFunction(ProfileSummaryInfo *PSI) { bool Changed = false; ICallPromotionAnalysis ICallAnalysis; for (auto *CB : findIndirectCalls(F)) { @@ -319,16 +318,15 @@ bool ICallPromotionFunc::processFunction(ProfileSummaryInfo *PSI) { if (TotalCount == 0 || NumPromoted == NumVals) continue; // Otherwise we need update with the un-promoted records back. - annotateValueSite(*M, *CB, ICallProfDataRef.slice(NumPromoted), TotalCount, - IPVK_IndirectCallTarget, NumCandidates); + annotateValueSite(*F.getParent(), *CB, ICallProfDataRef.slice(NumPromoted), + TotalCount, IPVK_IndirectCallTarget, NumCandidates); } return Changed; } // A wrapper function that does the actual work. -static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, - bool InLTO, bool SamplePGO, - ModuleAnalysisManager *AM = nullptr) { +static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, bool InLTO, + bool SamplePGO, ModuleAnalysisManager &MAM) { if (DisableICP) return false; InstrProfSymtab Symtab; @@ -342,19 +340,12 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, if (F.isDeclaration() || F.hasOptNone()) continue; - std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; - OptimizationRemarkEmitter *ORE; - if (AM) { - auto &FAM = - AM->getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - } else { - OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F); - ORE = OwnedORE.get(); - } + auto &FAM = + MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO, *ORE); - bool FuncChanged = ICallPromotion.processFunction(PSI); + IndirectCallPromoter CallPromoter(F, &Symtab, SamplePGO, ORE); + bool FuncChanged = CallPromoter.processFunction(PSI); if (ICPDUMPAFTER && FuncChanged) { LLVM_DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); @@ -369,11 +360,11 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, } PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, - ModuleAnalysisManager &AM) { - ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); + ModuleAnalysisManager &MAM) { + ProfileSummaryInfo *PSI = &MAM.getResult<ProfileSummaryAnalysis>(M); if (!promoteIndirectCalls(M, PSI, InLTO | ICPLTOMode, - SamplePGO | ICPSamplePGOMode, &AM)) + SamplePGO | ICPSamplePGOMode, MAM)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp b/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp index d7561c193aa3..6882dd83f429 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp +++ b/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp @@ -15,9 +15,6 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/PassRegistry.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index c0409206216e..a7b1953ce81c 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ 
b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -16,7 +16,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" @@ -47,6 +46,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> @@ -421,6 +421,9 @@ bool InstrProfiling::lowerIntrinsics(Function *F) { } else if (auto *IPI = dyn_cast<InstrProfIncrementInst>(&Instr)) { lowerIncrement(IPI); MadeChange = true; + } else if (auto *IPC = dyn_cast<InstrProfTimestampInst>(&Instr)) { + lowerTimestamp(IPC); + MadeChange = true; } else if (auto *IPC = dyn_cast<InstrProfCoverInst>(&Instr)) { lowerCover(IPC); MadeChange = true; @@ -510,6 +513,7 @@ static bool containsProfilingIntrinsics(Module &M) { return containsIntrinsic(llvm::Intrinsic::instrprof_cover) || containsIntrinsic(llvm::Intrinsic::instrprof_increment) || containsIntrinsic(llvm::Intrinsic::instrprof_increment_step) || + containsIntrinsic(llvm::Intrinsic::instrprof_timestamp) || containsIntrinsic(llvm::Intrinsic::instrprof_value_profile); } @@ -540,18 +544,19 @@ bool InstrProfiling::run( // the instrumented function. This is counting the number of instrumented // target value sites to enter it as field in the profile data variable. for (Function &F : M) { - InstrProfIncrementInst *FirstProfIncInst = nullptr; + InstrProfInstBase *FirstProfInst = nullptr; for (BasicBlock &BB : F) for (auto I = BB.begin(), E = BB.end(); I != E; I++) if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(I)) computeNumValueSiteCounts(Ind); - else if (FirstProfIncInst == nullptr) - FirstProfIncInst = dyn_cast<InstrProfIncrementInst>(I); + else if (FirstProfInst == nullptr && + (isa<InstrProfIncrementInst>(I) || isa<InstrProfCoverInst>(I))) + FirstProfInst = dyn_cast<InstrProfInstBase>(I); // Value profiling intrinsic lowering requires per-function profile data // variable to be created first. 
- if (FirstProfIncInst != nullptr) - static_cast<void>(getOrCreateRegionCounters(FirstProfIncInst)); + if (FirstProfInst != nullptr) + static_cast<void>(getOrCreateRegionCounters(FirstProfInst)); } for (Function &F : M) @@ -669,6 +674,9 @@ Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) { auto *Counters = getOrCreateRegionCounters(I); IRBuilder<> Builder(I); + if (isa<InstrProfTimestampInst>(I)) + Counters->setAlignment(Align(8)); + auto *Addr = Builder.CreateConstInBoundsGEP2_32( Counters->getValueType(), Counters, 0, I->getIndex()->getZExtValue()); @@ -710,6 +718,21 @@ void InstrProfiling::lowerCover(InstrProfCoverInst *CoverInstruction) { CoverInstruction->eraseFromParent(); } +void InstrProfiling::lowerTimestamp( + InstrProfTimestampInst *TimestampInstruction) { + assert(TimestampInstruction->getIndex()->isZeroValue() && + "timestamp probes are always the first probe for a function"); + auto &Ctx = M->getContext(); + auto *TimestampAddr = getCounterAddress(TimestampInstruction); + IRBuilder<> Builder(TimestampInstruction); + auto *CalleeTy = + FunctionType::get(Type::getVoidTy(Ctx), TimestampAddr->getType(), false); + auto Callee = M->getOrInsertFunction( + INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_SET_TIMESTAMP), CalleeTy); + Builder.CreateCall(Callee, {TimestampAddr}); + TimestampInstruction->eraseFromParent(); +} + void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { auto *Addr = getCounterAddress(Inc); @@ -823,6 +846,72 @@ static inline bool shouldRecordFunctionAddr(Function *F) { return F->hasAddressTaken() || F->hasLinkOnceLinkage(); } +static inline bool shouldUsePublicSymbol(Function *Fn) { + // It isn't legal to make an alias of this function at all + if (Fn->isDeclarationForLinker()) + return true; + + // Symbols with local linkage can just use the symbol directly without + // introducing relocations + if (Fn->hasLocalLinkage()) + return true; + + // PGO + ThinLTO + CFI cause duplicate symbols to be introduced due to some + // unfavorable interaction between the new alias and the alias renaming done + // in LowerTypeTests under ThinLTO. For comdat functions that would normally + // be deduplicated, but the renaming scheme ends up preventing renaming, since + // it creates unique names for each alias, resulting in duplicated symbols. In + // the future, we should update the CFI related passes to migrate these + // aliases to the same module as the jump-table they refer to will be defined. + if (Fn->hasMetadata(LLVMContext::MD_type)) + return true; + + // For comdat functions, an alias would need the same linkage as the original + // function and hidden visibility. There is no point in adding an alias with + // identical linkage an visibility to avoid introducing symbolic relocations. + if (Fn->hasComdat() && + (Fn->getVisibility() == GlobalValue::VisibilityTypes::HiddenVisibility)) + return true; + + // its OK to use an alias + return false; +} + +static inline Constant *getFuncAddrForProfData(Function *Fn) { + auto *Int8PtrTy = Type::getInt8PtrTy(Fn->getContext()); + // Store a nullptr in __llvm_profd, if we shouldn't use a real address + if (!shouldRecordFunctionAddr(Fn)) + return ConstantPointerNull::get(Int8PtrTy); + + // If we can't use an alias, we must use the public symbol, even though this + // may require a symbolic relocation. + if (shouldUsePublicSymbol(Fn)) + return ConstantExpr::getBitCast(Fn, Int8PtrTy); + + // When possible use a private alias to avoid symbolic relocations. 
+ auto *GA = GlobalAlias::create(GlobalValue::LinkageTypes::PrivateLinkage, + Fn->getName() + ".local", Fn); + + // When the instrumented function is a COMDAT function, we cannot use a + // private alias. If we did, we would create reference to a local label in + // this function's section. If this version of the function isn't selected by + // the linker, then the metadata would introduce a reference to a discarded + // section. So, for COMDAT functions, we need to adjust the linkage of the + // alias. Using hidden visibility avoids a dynamic relocation and an entry in + // the dynamic symbol table. + // + // Note that this handles COMDAT functions with visibility other than Hidden, + // since that case is covered in shouldUsePublicSymbol() + if (Fn->hasComdat()) { + GA->setLinkage(Fn->getLinkage()); + GA->setVisibility(GlobalValue::VisibilityTypes::HiddenVisibility); + } + + // appendToCompilerUsed(*Fn->getParent(), {GA}); + + return ConstantExpr::getBitCast(GA, Int8PtrTy); +} + static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { // Don't do this for Darwin. compiler-rt uses linker magic. if (TT.isOSDarwin()) @@ -1014,9 +1103,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { }; auto *DataTy = StructType::get(Ctx, ArrayRef(DataTypes)); - Constant *FunctionAddr = shouldRecordFunctionAddr(Fn) - ? ConstantExpr::getBitCast(Fn, Int8PtrTy) - : ConstantPointerNull::get(Int8PtrTy); + Constant *FunctionAddr = getFuncAddrForProfData(Fn); Constant *Int16ArrayVals[IPVK_Last + 1]; for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) @@ -1116,6 +1203,7 @@ void InstrProfiling::emitVNodes() { Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName()); VNodesVar->setSection( getInstrProfSectionName(IPSK_vnodes, TT.getObjectFormat())); + VNodesVar->setAlignment(M->getDataLayout().getABITypeAlign(VNodesTy)); // VNodesVar is used by runtime but not referenced via relocation by other // sections. Conservatively make it linker retained. 
UsedVars.push_back(VNodesVar); diff --git a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp index ab72650ae801..806afc8fcdf7 100644 --- a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -12,12 +12,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Instrumentation.h" -#include "llvm-c/Initialization.h" -#include "llvm/ADT/Triple.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/PassRegistry.h" +#include "llvm/TargetParser/Triple.h" using namespace llvm; diff --git a/llvm/lib/Transforms/Instrumentation/KCFI.cpp b/llvm/lib/Transforms/Instrumentation/KCFI.cpp index 7978c766f0f0..b1a26880c701 100644 --- a/llvm/lib/Transforms/Instrumentation/KCFI.cpp +++ b/llvm/lib/Transforms/Instrumentation/KCFI.cpp @@ -24,10 +24,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Target/TargetMachine.h" -#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -76,6 +73,7 @@ PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) { IntegerType *Int32Ty = Type::getInt32Ty(Ctx); MDNode *VeryUnlikelyWeights = MDBuilder(Ctx).createBranchWeights(1, (1U << 20) - 1); + Triple T(M.getTargetTriple()); for (CallInst *CI : KCFICalls) { // Get the expected hash value. @@ -96,14 +94,24 @@ PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) { // Emit a check and trap if the target hash doesn't match. IRBuilder<> Builder(Call); - Value *HashPtr = Builder.CreateConstInBoundsGEP1_32( - Int32Ty, Call->getCalledOperand(), -1); + Value *FuncPtr = Call->getCalledOperand(); + // ARM uses the least significant bit of the function pointer to select + // between ARM and Thumb modes for the callee. Instructions are always + // at least 16-bit aligned, so clear the LSB before we compute the hash + // location. 
+ if (T.isARM() || T.isThumb()) { + FuncPtr = Builder.CreateIntToPtr( + Builder.CreateAnd(Builder.CreatePtrToInt(FuncPtr, Int32Ty), + ConstantInt::get(Int32Ty, -2)), + FuncPtr->getType()); + } + Value *HashPtr = Builder.CreateConstInBoundsGEP1_32(Int32Ty, FuncPtr, -1); Value *Test = Builder.CreateICmpNE(Builder.CreateLoad(Int32Ty, HashPtr), ConstantInt::get(Int32Ty, ExpectedHash)); Instruction *ThenTerm = SplitBlockAndInsertIfThen(Test, Call, false, VeryUnlikelyWeights); Builder.SetInsertPoint(ThenTerm); - Builder.CreateCall(Intrinsic::getDeclaration(&M, Intrinsic::trap)); + Builder.CreateCall(Intrinsic::getDeclaration(&M, Intrinsic::debugtrap)); ++NumKCFIChecks; } diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp index 2a1601fab45f..789ed005d03d 100644 --- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp @@ -18,10 +18,12 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryProfileInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" @@ -30,18 +32,30 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h" +#include "llvm/ProfileData/InstrProfReader.h" +#include "llvm/Support/BLAKE3.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/HashBuilder.h" +#include "llvm/Support/VirtualFileSystem.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include <map> +#include <set> using namespace llvm; +using namespace llvm::memprof; #define DEBUG_TYPE "memprof" +namespace llvm { +extern cl::opt<bool> PGOWarnMissing; +extern cl::opt<bool> NoPGOWarnMismatch; +extern cl::opt<bool> NoPGOWarnMismatchComdatWeak; +} // namespace llvm + constexpr int LLVM_MEM_PROFILER_VERSION = 1; // Size of memory mapped to a single shadow location. 
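The KCFI hunk above emits, before each indirect call, a load of the 32-bit type hash stored just before the callee's entry (the GEP with index -1), a compare against the call site's expected hash, and a trap (now llvm.debugtrap) on mismatch; on ARM/Thumb the low bit of the function pointer is cleared first so the hash is read from the real entry address. A hedged, source-level model of that guard; the helper name and the direct memory read are illustrative, and reading before a function entry is only meaningful when the hash was actually emitted there:

#include <cstdint>

inline void kcfiCheck(const void *Callee, uint32_t Expected, bool IsArmOrThumb) {
  uintptr_t Addr = reinterpret_cast<uintptr_t>(Callee);
  if (IsArmOrThumb)
    Addr &= ~uintptr_t(1);                 // strip the Thumb mode bit before indexing
  uint32_t Stored = *(reinterpret_cast<const uint32_t *>(Addr) - 1);
  if (Stored != Expected)
    __builtin_trap();                      // the pass now emits llvm.debugtrap here
}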
@@ -130,6 +144,7 @@ STATISTIC(NumInstrumentedReads, "Number of instrumented reads"); STATISTIC(NumInstrumentedWrites, "Number of instrumented writes"); STATISTIC(NumSkippedStackReads, "Number of non-instrumented stack reads"); STATISTIC(NumSkippedStackWrites, "Number of non-instrumented stack writes"); +STATISTIC(NumOfMemProfMissing, "Number of functions without memory profile."); namespace { @@ -603,3 +618,297 @@ bool MemProfiler::instrumentFunction(Function &F) { return FunctionModified; } + +static void addCallsiteMetadata(Instruction &I, + std::vector<uint64_t> &InlinedCallStack, + LLVMContext &Ctx) { + I.setMetadata(LLVMContext::MD_callsite, + buildCallstackMetadata(InlinedCallStack, Ctx)); +} + +static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset, + uint32_t Column) { + llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::support::endianness::little> + HashBuilder; + HashBuilder.add(Function, LineOffset, Column); + llvm::BLAKE3Result<8> Hash = HashBuilder.final(); + uint64_t Id; + std::memcpy(&Id, Hash.data(), sizeof(Hash)); + return Id; +} + +static uint64_t computeStackId(const memprof::Frame &Frame) { + return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column); +} + +static void addCallStack(CallStackTrie &AllocTrie, + const AllocationInfo *AllocInfo) { + SmallVector<uint64_t> StackIds; + for (const auto &StackFrame : AllocInfo->CallStack) + StackIds.push_back(computeStackId(StackFrame)); + auto AllocType = getAllocType(AllocInfo->Info.getTotalLifetimeAccessDensity(), + AllocInfo->Info.getAllocCount(), + AllocInfo->Info.getTotalLifetime()); + AllocTrie.addCallStack(AllocType, StackIds); +} + +// Helper to compare the InlinedCallStack computed from an instruction's debug +// info to a list of Frames from profile data (either the allocation data or a +// callsite). For callsites, the StartIndex to use in the Frame array may be +// non-zero. +static bool +stackFrameIncludesInlinedCallStack(ArrayRef<Frame> ProfileCallStack, + ArrayRef<uint64_t> InlinedCallStack, + unsigned StartIndex = 0) { + auto StackFrame = ProfileCallStack.begin() + StartIndex; + auto InlCallStackIter = InlinedCallStack.begin(); + for (; StackFrame != ProfileCallStack.end() && + InlCallStackIter != InlinedCallStack.end(); + ++StackFrame, ++InlCallStackIter) { + uint64_t StackId = computeStackId(*StackFrame); + if (StackId != *InlCallStackIter) + return false; + } + // Return true if we found and matched all stack ids from the call + // instruction. 
+ return InlCallStackIter == InlinedCallStack.end(); +} + +static void readMemprof(Module &M, Function &F, + IndexedInstrProfReader *MemProfReader, + const TargetLibraryInfo &TLI) { + auto &Ctx = M.getContext(); + + auto FuncName = getPGOFuncName(F); + auto FuncGUID = Function::getGUID(FuncName); + Expected<memprof::MemProfRecord> MemProfResult = + MemProfReader->getMemProfRecord(FuncGUID); + if (Error E = MemProfResult.takeError()) { + handleAllErrors(std::move(E), [&](const InstrProfError &IPE) { + auto Err = IPE.get(); + bool SkipWarning = false; + LLVM_DEBUG(dbgs() << "Error in reading profile for Func " << FuncName + << ": "); + if (Err == instrprof_error::unknown_function) { + NumOfMemProfMissing++; + SkipWarning = !PGOWarnMissing; + LLVM_DEBUG(dbgs() << "unknown function"); + } else if (Err == instrprof_error::hash_mismatch) { + SkipWarning = + NoPGOWarnMismatch || + (NoPGOWarnMismatchComdatWeak && + (F.hasComdat() || + F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); + LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")"); + } + + if (SkipWarning) + return; + + std::string Msg = (IPE.message() + Twine(" ") + F.getName().str() + + Twine(" Hash = ") + std::to_string(FuncGUID)) + .str(); + + Ctx.diagnose( + DiagnosticInfoPGOProfile(M.getName().data(), Msg, DS_Warning)); + }); + return; + } + + // Build maps of the location hash to all profile data with that leaf location + // (allocation info and the callsites). + std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo; + // For the callsites we need to record the index of the associated frame in + // the frame array (see comments below where the map entries are added). + std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>> + LocHashToCallSites; + const auto MemProfRec = std::move(MemProfResult.get()); + for (auto &AI : MemProfRec.AllocSites) { + // Associate the allocation info with the leaf frame. The later matching + // code will match any inlined call sequences in the IR with a longer prefix + // of call stack frames. + uint64_t StackId = computeStackId(AI.CallStack[0]); + LocHashToAllocInfo[StackId].insert(&AI); + } + for (auto &CS : MemProfRec.CallSites) { + // Need to record all frames from leaf up to and including this function, + // as any of these may or may not have been inlined at this point. + unsigned Idx = 0; + for (auto &StackFrame : CS) { + uint64_t StackId = computeStackId(StackFrame); + LocHashToCallSites[StackId].insert(std::make_pair(&CS, Idx++)); + // Once we find this function, we can stop recording. + if (StackFrame.Function == FuncGUID) + break; + } + assert(Idx <= CS.size() && CS[Idx - 1].Function == FuncGUID); + } + + auto GetOffset = [](const DILocation *DIL) { + return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) & + 0xffff; + }; + + // Now walk the instructions, looking up the associated profile data using + // dbug locations. + for (auto &BB : F) { + for (auto &I : BB) { + if (I.isDebugOrPseudoInst()) + continue; + // We are only interested in calls (allocation or interior call stack + // context calls). + auto *CI = dyn_cast<CallBase>(&I); + if (!CI) + continue; + auto *CalledFunction = CI->getCalledFunction(); + if (CalledFunction && CalledFunction->isIntrinsic()) + continue; + // List of call stack ids computed from the location hashes on debug + // locations (leaf to inlined at root). + std::vector<uint64_t> InlinedCallStack; + // Was the leaf location found in one of the profile maps? 
+ bool LeafFound = false; + // If leaf was found in a map, iterators pointing to its location in both + // of the maps. It might exist in neither, one, or both (the latter case + // can happen because we don't currently have discriminators to + // distinguish the case when a single line/col maps to both an allocation + // and another callsite). + std::map<uint64_t, std::set<const AllocationInfo *>>::iterator + AllocInfoIter; + std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, + unsigned>>>::iterator CallSitesIter; + for (const DILocation *DIL = I.getDebugLoc(); DIL != nullptr; + DIL = DIL->getInlinedAt()) { + // Use C++ linkage name if possible. Need to compile with + // -fdebug-info-for-profiling to get linkage name. + StringRef Name = DIL->getScope()->getSubprogram()->getLinkageName(); + if (Name.empty()) + Name = DIL->getScope()->getSubprogram()->getName(); + auto CalleeGUID = Function::getGUID(Name); + auto StackId = + computeStackId(CalleeGUID, GetOffset(DIL), DIL->getColumn()); + // LeafFound will only be false on the first iteration, since we either + // set it true or break out of the loop below. + if (!LeafFound) { + AllocInfoIter = LocHashToAllocInfo.find(StackId); + CallSitesIter = LocHashToCallSites.find(StackId); + // Check if the leaf is in one of the maps. If not, no need to look + // further at this call. + if (AllocInfoIter == LocHashToAllocInfo.end() && + CallSitesIter == LocHashToCallSites.end()) + break; + LeafFound = true; + } + InlinedCallStack.push_back(StackId); + } + // If leaf not in either of the maps, skip inst. + if (!LeafFound) + continue; + + // First add !memprof metadata from allocation info, if we found the + // instruction's leaf location in that map, and if the rest of the + // instruction's locations match the prefix Frame locations on an + // allocation context with the same leaf. + if (AllocInfoIter != LocHashToAllocInfo.end()) { + // Only consider allocations via new, to reduce unnecessary metadata, + // since those are the only allocations that will be targeted initially. + if (!isNewLikeFn(CI, &TLI)) + continue; + // We may match this instruction's location list to multiple MIB + // contexts. Add them to a Trie specialized for trimming the contexts to + // the minimal needed to disambiguate contexts with unique behavior. + CallStackTrie AllocTrie; + for (auto *AllocInfo : AllocInfoIter->second) { + // Check the full inlined call stack against this one. + // If we found and thus matched all frames on the call, include + // this MIB. + if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack, + InlinedCallStack)) + addCallStack(AllocTrie, AllocInfo); + } + // We might not have matched any to the full inlined call stack. + // But if we did, create and attach metadata, or a function attribute if + // all contexts have identical profiled behavior. + if (!AllocTrie.empty()) { + // MemprofMDAttached will be false if a function attribute was + // attached. + bool MemprofMDAttached = AllocTrie.buildAndAttachMIBMetadata(CI); + assert(MemprofMDAttached == I.hasMetadata(LLVMContext::MD_memprof)); + if (MemprofMDAttached) { + // Add callsite metadata for the instruction's location list so that + // it simpler later on to identify which part of the MIB contexts + // are from this particular instruction (including during inlining, + // when the callsite metdata will be updated appropriately). 
+ // FIXME: can this be changed to strip out the matching stack + // context ids from the MIB contexts and not add any callsite + // metadata here to save space? + addCallsiteMetadata(I, InlinedCallStack, Ctx); + } + } + continue; + } + + // Otherwise, add callsite metadata. If we reach here then we found the + // instruction's leaf location in the callsites map and not the allocation + // map. + assert(CallSitesIter != LocHashToCallSites.end()); + for (auto CallStackIdx : CallSitesIter->second) { + // If we found and thus matched all frames on the call, create and + // attach call stack metadata. + if (stackFrameIncludesInlinedCallStack( + *CallStackIdx.first, InlinedCallStack, CallStackIdx.second)) { + addCallsiteMetadata(I, InlinedCallStack, Ctx); + // Only need to find one with a matching call stack and add a single + // callsite metadata. + break; + } + } + } + } +} + +MemProfUsePass::MemProfUsePass(std::string MemoryProfileFile, + IntrusiveRefCntPtr<vfs::FileSystem> FS) + : MemoryProfileFileName(MemoryProfileFile), FS(FS) { + if (!FS) + this->FS = vfs::getRealFileSystem(); +} + +PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) { + LLVM_DEBUG(dbgs() << "Read in memory profile:"); + auto &Ctx = M.getContext(); + auto ReaderOrErr = IndexedInstrProfReader::create(MemoryProfileFileName, *FS); + if (Error E = ReaderOrErr.takeError()) { + handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) { + Ctx.diagnose( + DiagnosticInfoPGOProfile(MemoryProfileFileName.data(), EI.message())); + }); + return PreservedAnalyses::all(); + } + + std::unique_ptr<IndexedInstrProfReader> MemProfReader = + std::move(ReaderOrErr.get()); + if (!MemProfReader) { + Ctx.diagnose(DiagnosticInfoPGOProfile( + MemoryProfileFileName.data(), StringRef("Cannot get MemProfReader"))); + return PreservedAnalyses::all(); + } + + if (!MemProfReader->hasMemoryProfile()) { + Ctx.diagnose(DiagnosticInfoPGOProfile(MemoryProfileFileName.data(), + "Not a memory profile")); + return PreservedAnalyses::all(); + } + + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + for (auto &F : M) { + if (F.isDeclaration()) + continue; + + const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + readMemprof(M, F, MemProfReader.get(), TLI); + } + + return PreservedAnalyses::none(); +} diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index fe8b8ce0dc86..83d90049abc3 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -122,6 +122,10 @@ /// Arbitrary sized accesses are handled with: /// __msan_metadata_ptr_for_load_n(ptr, size) /// __msan_metadata_ptr_for_store_n(ptr, size); +/// Note that the sanitizer code has to deal with how shadow/origin pairs +/// returned by the these functions are represented in different ABIs. In +/// the X86_64 ABI they are returned in RDX:RAX, and in the SystemZ ABI they +/// are written to memory pointed to by a hidden parameter. /// - TLS variables are stored in a single per-task struct. A call to a /// function __msan_get_context_state() returning a pointer to that struct /// is inserted into every instrumented function before the entry block; @@ -135,7 +139,7 @@ /// Also, KMSAN currently ignores uninitialized memory passed into inline asm /// calls, making sure we're on the safe side wrt. possible false positives. /// -/// KernelMemorySanitizer only supports X86_64 at the moment. 
+/// KernelMemorySanitizer only supports X86_64 and SystemZ at the moment. /// // // FIXME: This sanitizer does not yet handle scalable vectors @@ -152,11 +156,11 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallingConv.h" @@ -190,6 +194,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -434,6 +439,14 @@ static const MemoryMapParams Linux_AArch64_MemoryMapParams = { 0x0200000000000, // OriginBase }; +// loongarch64 Linux +static const MemoryMapParams Linux_LoongArch64_MemoryMapParams = { + 0, // AndMask (not used) + 0x500000000000, // XorMask + 0, // ShadowBase (not used) + 0x100000000000, // OriginBase +}; + // aarch64 FreeBSD static const MemoryMapParams FreeBSD_AArch64_MemoryMapParams = { 0x1800000000000, // AndMask @@ -491,6 +504,11 @@ static const PlatformMemoryMapParams Linux_ARM_MemoryMapParams = { &Linux_AArch64_MemoryMapParams, }; +static const PlatformMemoryMapParams Linux_LoongArch_MemoryMapParams = { + nullptr, + &Linux_LoongArch64_MemoryMapParams, +}; + static const PlatformMemoryMapParams FreeBSD_ARM_MemoryMapParams = { nullptr, &FreeBSD_AArch64_MemoryMapParams, @@ -543,6 +561,10 @@ private: void createKernelApi(Module &M, const TargetLibraryInfo &TLI); void createUserspaceApi(Module &M, const TargetLibraryInfo &TLI); + template <typename... ArgsTy> + FunctionCallee getOrInsertMsanMetadataFunction(Module &M, StringRef Name, + ArgsTy... Args); + /// True if we're compiling the Linux kernel. bool CompileKernel; /// Track origins (allocation points) of uninitialized values. @@ -550,6 +572,7 @@ private: bool Recover; bool EagerChecks; + Triple TargetTriple; LLVMContext *C; Type *IntptrTy; Type *OriginTy; @@ -620,13 +643,18 @@ private: /// Functions for poisoning/unpoisoning local variables FunctionCallee MsanPoisonAllocaFn, MsanUnpoisonAllocaFn; - /// Each of the MsanMetadataPtrXxx functions returns a pair of shadow/origin - /// pointers. + /// Pair of shadow/origin pointers. + Type *MsanMetadata; + + /// Each of the MsanMetadataPtrXxx functions returns a MsanMetadata. FunctionCallee MsanMetadataPtrForLoadN, MsanMetadataPtrForStoreN; FunctionCallee MsanMetadataPtrForLoad_1_8[4]; FunctionCallee MsanMetadataPtrForStore_1_8[4]; FunctionCallee MsanInstrumentAsmStoreFn; + /// Storage for return values of the MsanMetadataPtrXxx functions. + Value *MsanMetadataAlloca; + /// Helper to choose between different MsanMetadataPtrXxx(). 
FunctionCallee getKmsanShadowOriginAccessFn(bool isStore, int size); @@ -706,7 +734,7 @@ void MemorySanitizerPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<MemorySanitizerPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (Options.Recover) OS << "recover;"; if (Options.Kernel) @@ -714,7 +742,7 @@ void MemorySanitizerPass::printPipeline( if (Options.EagerChecks) OS << "eager-checks;"; OS << "track-origins=" << Options.TrackOrigins; - OS << ">"; + OS << '>'; } /// Create a non-const global initialized with the given string. @@ -729,6 +757,21 @@ static GlobalVariable *createPrivateConstGlobalForString(Module &M, GlobalValue::PrivateLinkage, StrConst, ""); } +template <typename... ArgsTy> +FunctionCallee +MemorySanitizer::getOrInsertMsanMetadataFunction(Module &M, StringRef Name, + ArgsTy... Args) { + if (TargetTriple.getArch() == Triple::systemz) { + // SystemZ ABI: shadow/origin pair is returned via a hidden parameter. + return M.getOrInsertFunction(Name, Type::getVoidTy(*C), + PointerType::get(MsanMetadata, 0), + std::forward<ArgsTy>(Args)...); + } + + return M.getOrInsertFunction(Name, MsanMetadata, + std::forward<ArgsTy>(Args)...); +} + /// Create KMSAN API callbacks. void MemorySanitizer::createKernelApi(Module &M, const TargetLibraryInfo &TLI) { IRBuilder<> IRB(*C); @@ -758,25 +801,25 @@ void MemorySanitizer::createKernelApi(Module &M, const TargetLibraryInfo &TLI) { MsanGetContextStateFn = M.getOrInsertFunction( "__msan_get_context_state", PointerType::get(MsanContextStateTy, 0)); - Type *RetTy = StructType::get(PointerType::get(IRB.getInt8Ty(), 0), - PointerType::get(IRB.getInt32Ty(), 0)); + MsanMetadata = StructType::get(PointerType::get(IRB.getInt8Ty(), 0), + PointerType::get(IRB.getInt32Ty(), 0)); for (int ind = 0, size = 1; ind < 4; ind++, size <<= 1) { std::string name_load = "__msan_metadata_ptr_for_load_" + std::to_string(size); std::string name_store = "__msan_metadata_ptr_for_store_" + std::to_string(size); - MsanMetadataPtrForLoad_1_8[ind] = M.getOrInsertFunction( - name_load, RetTy, PointerType::get(IRB.getInt8Ty(), 0)); - MsanMetadataPtrForStore_1_8[ind] = M.getOrInsertFunction( - name_store, RetTy, PointerType::get(IRB.getInt8Ty(), 0)); + MsanMetadataPtrForLoad_1_8[ind] = getOrInsertMsanMetadataFunction( + M, name_load, PointerType::get(IRB.getInt8Ty(), 0)); + MsanMetadataPtrForStore_1_8[ind] = getOrInsertMsanMetadataFunction( + M, name_store, PointerType::get(IRB.getInt8Ty(), 0)); } - MsanMetadataPtrForLoadN = M.getOrInsertFunction( - "__msan_metadata_ptr_for_load_n", RetTy, - PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty()); - MsanMetadataPtrForStoreN = M.getOrInsertFunction( - "__msan_metadata_ptr_for_store_n", RetTy, + MsanMetadataPtrForLoadN = getOrInsertMsanMetadataFunction( + M, "__msan_metadata_ptr_for_load_n", PointerType::get(IRB.getInt8Ty(), 0), + IRB.getInt64Ty()); + MsanMetadataPtrForStoreN = getOrInsertMsanMetadataFunction( + M, "__msan_metadata_ptr_for_store_n", PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty()); // Functions for poisoning and unpoisoning memory. 
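For context on the SystemZ-related KMSAN changes above: each __msan_metadata_ptr_for_* callback hands back a shadow/origin pointer pair, and that pair either comes back by value (x86_64, in RDX:RAX) or is written through a hidden first parameter (SystemZ), which is why getOrInsertMsanMetadataFunction now picks between two declaration shapes. The C++ declarations below are an illustrative sketch of those two shapes only; the names are simplified stand-ins for __msan_metadata_ptr_for_load_1 and the types are simplified from the IR the pass builds, not prototypes taken from the compiler-rt headers.

// Illustrative sketch of the two callback ABIs the pass must handle.
struct MsanMetadata { char *Shadow; unsigned *Origin; }; // shadow/origin pair

// Default shape (e.g. x86_64): the pair comes back by value.
MsanMetadata metadata_ptr_for_load_1(char *Addr);

// SystemZ shape: the pair is stored through a hidden out-parameter and the
// callback returns void; createMetadataCall then reloads the pair from a
// per-function alloca (MsanMetadataAlloca).
void metadata_ptr_for_load_1(MsanMetadata *Out, char *Addr);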
@@ -927,6 +970,8 @@ FunctionCallee MemorySanitizer::getKmsanShadowOriginAccessFn(bool isStore, void MemorySanitizer::initializeModule(Module &M) { auto &DL = M.getDataLayout(); + TargetTriple = Triple(M.getTargetTriple()); + bool ShadowPassed = ClShadowBase.getNumOccurrences() > 0; bool OriginPassed = ClOriginBase.getNumOccurrences() > 0; // Check the overrides first @@ -937,7 +982,6 @@ void MemorySanitizer::initializeModule(Module &M) { CustomMapParams.OriginBase = ClOriginBase; MapParams = &CustomMapParams; } else { - Triple TargetTriple(M.getTargetTriple()); switch (TargetTriple.getOS()) { case Triple::FreeBSD: switch (TargetTriple.getArch()) { @@ -986,6 +1030,9 @@ void MemorySanitizer::initializeModule(Module &M) { case Triple::aarch64_be: MapParams = Linux_ARM_MemoryMapParams.bits64; break; + case Triple::loongarch64: + MapParams = Linux_LoongArch_MemoryMapParams.bits64; + break; default: report_fatal_error("unsupported architecture"); } @@ -1056,10 +1103,14 @@ struct MemorySanitizerVisitor; static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, MemorySanitizerVisitor &Visitor); -static unsigned TypeSizeToSizeIndex(unsigned TypeSize) { - if (TypeSize <= 8) +static unsigned TypeSizeToSizeIndex(TypeSize TS) { + if (TS.isScalable()) + // Scalable types unconditionally take slowpaths. + return kNumberOfAccessSizes; + unsigned TypeSizeFixed = TS.getFixedValue(); + if (TypeSizeFixed <= 8) return 0; - return Log2_32_Ceil((TypeSize + 7) / 8); + return Log2_32_Ceil((TypeSizeFixed + 7) / 8); } namespace { @@ -1178,13 +1229,30 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Fill memory range with the given origin value. void paintOrigin(IRBuilder<> &IRB, Value *Origin, Value *OriginPtr, - unsigned Size, Align Alignment) { + TypeSize TS, Align Alignment) { const DataLayout &DL = F.getParent()->getDataLayout(); const Align IntptrAlignment = DL.getABITypeAlign(MS.IntptrTy); unsigned IntptrSize = DL.getTypeStoreSize(MS.IntptrTy); assert(IntptrAlignment >= kMinOriginAlignment); assert(IntptrSize >= kOriginSize); + // Note: The loop based formation works for fixed length vectors too, + // however we prefer to unroll and specialize alignment below. 
+ if (TS.isScalable()) { + Value *Size = IRB.CreateTypeSize(IRB.getInt32Ty(), TS); + Value *RoundUp = IRB.CreateAdd(Size, IRB.getInt32(kOriginSize - 1)); + Value *End = IRB.CreateUDiv(RoundUp, IRB.getInt32(kOriginSize)); + auto [InsertPt, Index] = + SplitBlockAndInsertSimpleForLoop(End, &*IRB.GetInsertPoint()); + IRB.SetInsertPoint(InsertPt); + + Value *GEP = IRB.CreateGEP(MS.OriginTy, OriginPtr, Index); + IRB.CreateAlignedStore(Origin, GEP, kMinOriginAlignment); + return; + } + + unsigned Size = TS.getFixedValue(); + unsigned Ofs = 0; Align CurrentAlignment = Alignment; if (Alignment >= IntptrAlignment && IntptrSize > kOriginSize) { @@ -1212,7 +1280,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *OriginPtr, Align Alignment) { const DataLayout &DL = F.getParent()->getDataLayout(); const Align OriginAlignment = std::max(kMinOriginAlignment, Alignment); - unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType()); + TypeSize StoreSize = DL.getTypeStoreSize(Shadow->getType()); Value *ConvertedShadow = convertShadowToScalar(Shadow, IRB); if (auto *ConstantShadow = dyn_cast<Constant>(ConvertedShadow)) { if (!ClCheckConstantShadow || ConstantShadow->isZeroValue()) { @@ -1229,7 +1297,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Fallback to runtime check, which still can be optimized out later. } - unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); + TypeSize TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); if (instrumentWithCalls(ConvertedShadow) && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { @@ -1325,7 +1393,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void materializeOneCheck(IRBuilder<> &IRB, Value *ConvertedShadow, Value *Origin) { const DataLayout &DL = F.getParent()->getDataLayout(); - unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); + TypeSize TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); if (instrumentWithCalls(ConvertedShadow) && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { @@ -1443,6 +1511,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { MS.RetvalOriginTLS = IRB.CreateGEP(MS.MsanContextStateTy, ContextState, {Zero, IRB.getInt32(6)}, "retval_origin"); + if (MS.TargetTriple.getArch() == Triple::systemz) + MS.MsanMetadataAlloca = IRB.CreateAlloca(MS.MsanMetadata, 0u); } /// Add MemorySanitizer instrumentation to a function. @@ -1505,8 +1575,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { const DataLayout &DL = F.getParent()->getDataLayout(); if (VectorType *VT = dyn_cast<VectorType>(OrigTy)) { uint32_t EltSize = DL.getTypeSizeInBits(VT->getElementType()); - return FixedVectorType::get(IntegerType::get(*MS.C, EltSize), - cast<FixedVectorType>(VT)->getNumElements()); + return VectorType::get(IntegerType::get(*MS.C, EltSize), + VT->getElementCount()); } if (ArrayType *AT = dyn_cast<ArrayType>(OrigTy)) { return ArrayType::get(getShadowTy(AT->getElementType()), @@ -1524,14 +1594,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return IntegerType::get(*MS.C, TypeSize); } - /// Flatten a vector type. 
- Type *getShadowTyNoVec(Type *ty) { - if (VectorType *vt = dyn_cast<VectorType>(ty)) - return IntegerType::get(*MS.C, - vt->getPrimitiveSizeInBits().getFixedValue()); - return ty; - } - /// Extract combined shadow of struct elements as a bool Value *collapseStructShadow(StructType *Struct, Value *Shadow, IRBuilder<> &IRB) { @@ -1541,8 +1603,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { for (unsigned Idx = 0; Idx < Struct->getNumElements(); Idx++) { // Combine by ORing together each element's bool shadow Value *ShadowItem = IRB.CreateExtractValue(Shadow, Idx); - Value *ShadowInner = convertShadowToScalar(ShadowItem, IRB); - Value *ShadowBool = convertToBool(ShadowInner, IRB); + Value *ShadowBool = convertToBool(ShadowItem, IRB); if (Aggregator != FalseVal) Aggregator = IRB.CreateOr(Aggregator, ShadowBool); @@ -1578,11 +1639,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return collapseStructShadow(Struct, V, IRB); if (ArrayType *Array = dyn_cast<ArrayType>(V->getType())) return collapseArrayShadow(Array, V, IRB); - Type *Ty = V->getType(); - Type *NoVecTy = getShadowTyNoVec(Ty); - if (Ty == NoVecTy) - return V; - return IRB.CreateBitCast(V, NoVecTy); + if (isa<VectorType>(V->getType())) { + if (isa<ScalableVectorType>(V->getType())) + return convertShadowToScalar(IRB.CreateOrReduce(V), IRB); + unsigned BitWidth = + V->getType()->getPrimitiveSizeInBits().getFixedValue(); + return IRB.CreateBitCast(V, IntegerType::get(*MS.C, BitWidth)); + } + return V; } // Convert a scalar value to an i1 by comparing with 0 @@ -1597,28 +1661,28 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } Type *ptrToIntPtrType(Type *PtrTy) const { - if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(PtrTy)) { - return FixedVectorType::get(ptrToIntPtrType(VectTy->getElementType()), - VectTy->getNumElements()); + if (VectorType *VectTy = dyn_cast<VectorType>(PtrTy)) { + return VectorType::get(ptrToIntPtrType(VectTy->getElementType()), + VectTy->getElementCount()); } assert(PtrTy->isIntOrPtrTy()); return MS.IntptrTy; } Type *getPtrToShadowPtrType(Type *IntPtrTy, Type *ShadowTy) const { - if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(IntPtrTy)) { - return FixedVectorType::get( + if (VectorType *VectTy = dyn_cast<VectorType>(IntPtrTy)) { + return VectorType::get( getPtrToShadowPtrType(VectTy->getElementType(), ShadowTy), - VectTy->getNumElements()); + VectTy->getElementCount()); } assert(IntPtrTy == MS.IntptrTy); return ShadowTy->getPointerTo(); } Constant *constToIntPtr(Type *IntPtrTy, uint64_t C) const { - if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(IntPtrTy)) { - return ConstantDataVector::getSplat( - VectTy->getNumElements(), constToIntPtr(VectTy->getElementType(), C)); + if (VectorType *VectTy = dyn_cast<VectorType>(IntPtrTy)) { + return ConstantVector::getSplat( + VectTy->getElementCount(), constToIntPtr(VectTy->getElementType(), C)); } assert(IntPtrTy == MS.IntptrTy); return ConstantInt::get(MS.IntptrTy, C); @@ -1681,24 +1745,37 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return std::make_pair(ShadowPtr, OriginPtr); } + template <typename... ArgsTy> + Value *createMetadataCall(IRBuilder<> &IRB, FunctionCallee Callee, + ArgsTy... 
Args) { + if (MS.TargetTriple.getArch() == Triple::systemz) { + IRB.CreateCall(Callee, + {MS.MsanMetadataAlloca, std::forward<ArgsTy>(Args)...}); + return IRB.CreateLoad(MS.MsanMetadata, MS.MsanMetadataAlloca); + } + + return IRB.CreateCall(Callee, {std::forward<ArgsTy>(Args)...}); + } + std::pair<Value *, Value *> getShadowOriginPtrKernelNoVec(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, bool isStore) { Value *ShadowOriginPtrs; const DataLayout &DL = F.getParent()->getDataLayout(); - int Size = DL.getTypeStoreSize(ShadowTy); + TypeSize Size = DL.getTypeStoreSize(ShadowTy); FunctionCallee Getter = MS.getKmsanShadowOriginAccessFn(isStore, Size); Value *AddrCast = IRB.CreatePointerCast(Addr, PointerType::get(IRB.getInt8Ty(), 0)); if (Getter) { - ShadowOriginPtrs = IRB.CreateCall(Getter, AddrCast); + ShadowOriginPtrs = createMetadataCall(IRB, Getter, AddrCast); } else { Value *SizeVal = ConstantInt::get(MS.IntptrTy, Size); - ShadowOriginPtrs = IRB.CreateCall(isStore ? MS.MsanMetadataPtrForStoreN - : MS.MsanMetadataPtrForLoadN, - {AddrCast, SizeVal}); + ShadowOriginPtrs = createMetadataCall( + IRB, + isStore ? MS.MsanMetadataPtrForStoreN : MS.MsanMetadataPtrForLoadN, + AddrCast, SizeVal); } Value *ShadowPtr = IRB.CreateExtractValue(ShadowOriginPtrs, 0); ShadowPtr = IRB.CreatePointerCast(ShadowPtr, PointerType::get(ShadowTy, 0)); @@ -1714,14 +1791,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> &IRB, Type *ShadowTy, bool isStore) { - FixedVectorType *VectTy = dyn_cast<FixedVectorType>(Addr->getType()); + VectorType *VectTy = dyn_cast<VectorType>(Addr->getType()); if (!VectTy) { assert(Addr->getType()->isPointerTy()); return getShadowOriginPtrKernelNoVec(Addr, IRB, ShadowTy, isStore); } // TODO: Support callbacs with vectors of addresses. - unsigned NumElements = VectTy->getNumElements(); + unsigned NumElements = cast<FixedVectorType>(VectTy)->getNumElements(); Value *ShadowPtrs = ConstantInt::getNullValue( FixedVectorType::get(ShadowTy->getPointerTo(), NumElements)); Value *OriginPtrs = nullptr; @@ -2367,9 +2444,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Constant *ConstOrigin = dyn_cast<Constant>(OpOrigin); // No point in adding something that might result in 0 origin value. 
if (!ConstOrigin || !ConstOrigin->isNullValue()) { - Value *FlatShadow = MSV->convertShadowToScalar(OpShadow, IRB); - Value *Cond = - IRB.CreateICmpNE(FlatShadow, MSV->getCleanShadow(FlatShadow)); + Value *Cond = MSV->convertToBool(OpShadow, IRB); Origin = IRB.CreateSelect(Cond, OpOrigin, Origin); } } @@ -2434,8 +2509,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (dstTy->isIntegerTy() && srcTy->isIntegerTy()) return IRB.CreateIntCast(V, dstTy, Signed); if (dstTy->isVectorTy() && srcTy->isVectorTy() && - cast<FixedVectorType>(dstTy)->getNumElements() == - cast<FixedVectorType>(srcTy)->getNumElements()) + cast<VectorType>(dstTy)->getElementCount() == + cast<VectorType>(srcTy)->getElementCount()) return IRB.CreateIntCast(V, dstTy, Signed); Value *V1 = IRB.CreateBitCast(V, Type::getIntNTy(*MS.C, srcSizeInBits)); Value *V2 = @@ -2487,7 +2562,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ConstantInt *Elt = dyn_cast<ConstantInt>(ConstArg->getAggregateElement(Idx))) { const APInt &V = Elt->getValue(); - APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); + APInt V2 = APInt(V.getBitWidth(), 1) << V.countr_zero(); Elements.push_back(ConstantInt::get(EltTy, V2)); } else { Elements.push_back(ConstantInt::get(EltTy, 1)); @@ -2497,7 +2572,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } else { if (ConstantInt *Elt = dyn_cast<ConstantInt>(ConstArg)) { const APInt &V = Elt->getValue(); - APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); + APInt V2 = APInt(V.getBitWidth(), 1) << V.countr_zero(); ShadowMul = ConstantInt::get(Ty, V2); } else { ShadowMul = ConstantInt::get(Ty, 1); @@ -3356,7 +3431,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } Type *ShadowTy = getShadowTy(&I); - Type *ElementShadowTy = cast<FixedVectorType>(ShadowTy)->getElementType(); + Type *ElementShadowTy = cast<VectorType>(ShadowTy)->getElementType(); auto [ShadowPtr, OriginPtr] = getShadowOriginPtr(Ptr, IRB, ElementShadowTy, {}, /*isStore*/ false); @@ -3382,7 +3457,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Shadow = getShadow(Values); Type *ElementShadowTy = - getShadowTy(cast<FixedVectorType>(Values->getType())->getElementType()); + getShadowTy(cast<VectorType>(Values->getType())->getElementType()); auto [ShadowPtr, OriginPtrs] = getShadowOriginPtr(Ptr, IRB, ElementShadowTy, {}, /*isStore*/ true); @@ -3415,7 +3490,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } Type *ShadowTy = getShadowTy(&I); - Type *ElementShadowTy = cast<FixedVectorType>(ShadowTy)->getElementType(); + Type *ElementShadowTy = cast<VectorType>(ShadowTy)->getElementType(); auto [ShadowPtrs, OriginPtrs] = getShadowOriginPtr( Ptrs, IRB, ElementShadowTy, Alignment, /*isStore*/ false); @@ -3448,7 +3523,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Shadow = getShadow(Values); Type *ElementShadowTy = - getShadowTy(cast<FixedVectorType>(Values->getType())->getElementType()); + getShadowTy(cast<VectorType>(Values->getType())->getElementType()); auto [ShadowPtrs, OriginPtrs] = getShadowOriginPtr( Ptrs, IRB, ElementShadowTy, Alignment, /*isStore*/ true); @@ -3520,8 +3595,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *MaskedPassThruShadow = IRB.CreateAnd( getShadow(PassThru), IRB.CreateSExt(IRB.CreateNeg(Mask), ShadowTy)); - Value 
*ConvertedShadow = convertShadowToScalar(MaskedPassThruShadow, IRB); - Value *NotNull = convertToBool(ConvertedShadow, IRB, "_mscmp"); + Value *NotNull = convertToBool(MaskedPassThruShadow, IRB, "_mscmp"); Value *PtrOrigin = IRB.CreateLoad(MS.OriginTy, OriginPtr); Value *Origin = IRB.CreateSelect(NotNull, getOrigin(PassThru), PtrOrigin); @@ -3645,11 +3719,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&I, getOrigin(&I, 0)); } + void handleIsFpClass(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Shadow = getShadow(&I, 0); + setShadow(&I, IRB.CreateICmpNE(Shadow, getCleanShadow(Shadow))); + setOrigin(&I, getOrigin(&I, 0)); + } + void visitIntrinsicInst(IntrinsicInst &I) { switch (I.getIntrinsicID()) { case Intrinsic::abs: handleAbsIntrinsic(I); break; + case Intrinsic::is_fpclass: + handleIsFpClass(I); + break; case Intrinsic::lifetime_start: handleLifetimeStart(I); break; @@ -4391,11 +4475,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Origins are always i32, so any vector conditions must be flattened. // FIXME: consider tracking vector origins for app vectors? if (B->getType()->isVectorTy()) { - Type *FlatTy = getShadowTyNoVec(B->getType()); - B = IRB.CreateICmpNE(IRB.CreateBitCast(B, FlatTy), - ConstantInt::getNullValue(FlatTy)); - Sb = IRB.CreateICmpNE(IRB.CreateBitCast(Sb, FlatTy), - ConstantInt::getNullValue(FlatTy)); + B = convertToBool(B, IRB); + Sb = convertToBool(Sb, IRB); } // a = select b, c, d // Oa = Sb ? Ob : (b ? Oc : Od) @@ -4490,9 +4571,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } if (!ElemTy->isSized()) return; - int Size = DL.getTypeStoreSize(ElemTy); Value *Ptr = IRB.CreatePointerCast(Operand, IRB.getInt8PtrTy()); - Value *SizeVal = ConstantInt::get(MS.IntptrTy, Size); + Value *SizeVal = + IRB.CreateTypeSize(MS.IntptrTy, DL.getTypeStoreSize(ElemTy)); IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Ptr, SizeVal}); } @@ -4600,8 +4681,8 @@ struct VarArgAMD64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy = nullptr; - Value *VAArgTLSOriginCopy = nullptr; + AllocaInst *VAArgTLSCopy = nullptr; + AllocaInst *VAArgTLSOriginCopy = nullptr; Value *VAArgOverflowSize = nullptr; SmallVector<CallInst *, 16> VAStartInstrumentationList; @@ -4721,7 +4802,7 @@ struct VarArgAMD64Helper : public VarArgHelper { IRB.CreateAlignedStore(Shadow, ShadowBase, kShadowTLSAlignment); if (MS.TrackOrigins) { Value *Origin = MSV.getOrigin(A); - unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType()); + TypeSize StoreSize = DL.getTypeStoreSize(Shadow->getType()); MSV.paintOrigin(IRB, Origin, OriginBase, StoreSize, std::max(kShadowTLSAlignment, kMinOriginAlignment)); } @@ -4797,11 +4878,20 @@ struct VarArgAMD64Helper : public VarArgHelper { Value *CopySize = IRB.CreateAdd( ConstantInt::get(MS.IntptrTy, AMD64FpEndOffset), VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize); + VAArgTLSCopy->setAlignment(kShadowTLSAlignment); + IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()), + CopySize, kShadowTLSAlignment, false); + + Value *SrcSize = IRB.CreateBinaryIntrinsic( + Intrinsic::umin, CopySize, + ConstantInt::get(MS.IntptrTy, kParamTLSSize)); + IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS, + kShadowTLSAlignment, SrcSize); if (MS.TrackOrigins) { VAArgTLSOriginCopy = 
IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSOriginCopy, Align(8), MS.VAArgOriginTLS, - Align(8), CopySize); + VAArgTLSOriginCopy->setAlignment(kShadowTLSAlignment); + IRB.CreateMemCpy(VAArgTLSOriginCopy, kShadowTLSAlignment, + MS.VAArgOriginTLS, kShadowTLSAlignment, SrcSize); } } @@ -4859,7 +4949,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy = nullptr; + AllocaInst *VAArgTLSCopy = nullptr; Value *VAArgSize = nullptr; SmallVector<CallInst *, 16> VAStartInstrumentationList; @@ -4944,7 +5034,15 @@ struct VarArgMIPS64Helper : public VarArgHelper { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize); + VAArgTLSCopy->setAlignment(kShadowTLSAlignment); + IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()), + CopySize, kShadowTLSAlignment, false); + + Value *SrcSize = IRB.CreateBinaryIntrinsic( + Intrinsic::umin, CopySize, + ConstantInt::get(MS.IntptrTy, kParamTLSSize)); + IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS, + kShadowTLSAlignment, SrcSize); } // Instrument va_start. @@ -4986,7 +5084,7 @@ struct VarArgAArch64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy = nullptr; + AllocaInst *VAArgTLSCopy = nullptr; Value *VAArgOverflowSize = nullptr; SmallVector<CallInst *, 16> VAStartInstrumentationList; @@ -5130,7 +5228,15 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *CopySize = IRB.CreateAdd( ConstantInt::get(MS.IntptrTy, AArch64VAEndOffset), VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize); + VAArgTLSCopy->setAlignment(kShadowTLSAlignment); + IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()), + CopySize, kShadowTLSAlignment, false); + + Value *SrcSize = IRB.CreateBinaryIntrinsic( + Intrinsic::umin, CopySize, + ConstantInt::get(MS.IntptrTy, kParamTLSSize)); + IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS, + kShadowTLSAlignment, SrcSize); } Value *GrArgSize = ConstantInt::get(MS.IntptrTy, kAArch64GrArgSize); @@ -5230,7 +5336,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy = nullptr; + AllocaInst *VAArgTLSCopy = nullptr; Value *VAArgSize = nullptr; SmallVector<CallInst *, 16> VAStartInstrumentationList; @@ -5373,8 +5479,17 @@ struct VarArgPowerPC64Helper : public VarArgHelper { if (!VAStartInstrumentationList.empty()) { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. + VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize); + VAArgTLSCopy->setAlignment(kShadowTLSAlignment); + IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()), + CopySize, kShadowTLSAlignment, false); + + Value *SrcSize = IRB.CreateBinaryIntrinsic( + Intrinsic::umin, CopySize, + ConstantInt::get(MS.IntptrTy, kParamTLSSize)); + IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS, + kShadowTLSAlignment, SrcSize); } // Instrument va_start. 
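The va_arg hunks above (AMD64, MIPS64, AArch64, PowerPC64, and SystemZ further down) all switch to the same backup sequence: the copy of va_arg_tls is now a kShadowTLSAlignment-aligned alloca that is first zeroed and then filled with at most kParamTLSSize bytes, so the memcpy can no longer read past the TLS buffer when CopySize overshoots it. The function below is a rough, hypothetical C++ model of that sequence, with plain memory operations standing in for the CreateMemSet and the umin-clamped CreateMemCpy the pass emits; it is a sketch, not code from the patch.

#include <algorithm>
#include <cstddef>
#include <cstring>

// Models the zero-then-clamped-copy backup of va_arg_tls.
void backupVAArgTLS(unsigned char *Copy, const unsigned char *VAArgTLS,
                    std::size_t CopySize, std::size_t ParamTLSSize) {
  std::memset(Copy, 0, CopySize); // shadow bytes beyond the copied region read as clean
  std::memcpy(Copy, VAArgTLS, std::min(CopySize, ParamTLSSize)); // never over-read the TLS area
}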
@@ -5416,8 +5531,9 @@ struct VarArgSystemZHelper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy = nullptr; - Value *VAArgTLSOriginCopy = nullptr; + bool IsSoftFloatABI; + AllocaInst *VAArgTLSCopy = nullptr; + AllocaInst *VAArgTLSOriginCopy = nullptr; Value *VAArgOverflowSize = nullptr; SmallVector<CallInst *, 16> VAStartInstrumentationList; @@ -5434,9 +5550,10 @@ struct VarArgSystemZHelper : public VarArgHelper { VarArgSystemZHelper(Function &F, MemorySanitizer &MS, MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV) {} + : F(F), MS(MS), MSV(MSV), + IsSoftFloatABI(F.getFnAttribute("use-soft-float").getValueAsBool()) {} - ArgKind classifyArgument(Type *T, bool IsSoftFloatABI) { + ArgKind classifyArgument(Type *T) { // T is a SystemZABIInfo::classifyArgumentType() output, and there are // only a few possibilities of what it can be. In particular, enums, single // element structs and large types have already been taken care of. @@ -5474,9 +5591,6 @@ struct VarArgSystemZHelper : public VarArgHelper { } void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override { - bool IsSoftFloatABI = CB.getCalledFunction() - ->getFnAttribute("use-soft-float") - .getValueAsBool(); unsigned GpOffset = SystemZGpOffset; unsigned FpOffset = SystemZFpOffset; unsigned VrIndex = 0; @@ -5487,7 +5601,7 @@ struct VarArgSystemZHelper : public VarArgHelper { // SystemZABIInfo does not produce ByVal parameters. assert(!CB.paramHasAttr(ArgNo, Attribute::ByVal)); Type *T = A->getType(); - ArgKind AK = classifyArgument(T, IsSoftFloatABI); + ArgKind AK = classifyArgument(T); if (AK == ArgKind::Indirect) { T = PointerType::get(T, 0); AK = ArgKind::GeneralPurpose; @@ -5587,7 +5701,7 @@ struct VarArgSystemZHelper : public VarArgHelper { IRB.CreateStore(Shadow, ShadowBase); if (MS.TrackOrigins) { Value *Origin = MSV.getOrigin(A); - unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType()); + TypeSize StoreSize = DL.getTypeStoreSize(Shadow->getType()); MSV.paintOrigin(IRB, Origin, OriginBase, StoreSize, kMinOriginAlignment); } @@ -5642,11 +5756,15 @@ struct VarArgSystemZHelper : public VarArgHelper { MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); // TODO(iii): copy only fragments filled by visitCallBase() + // TODO(iii): support packed-stack && !use-soft-float + // For use-soft-float functions, it is enough to copy just the GPRs. + unsigned RegSaveAreaSize = + IsSoftFloatABI ? 
SystemZGpEndOffset : SystemZRegSaveAreaSize; IRB.CreateMemCpy(RegSaveAreaShadowPtr, Alignment, VAArgTLSCopy, Alignment, - SystemZRegSaveAreaSize); + RegSaveAreaSize); if (MS.TrackOrigins) IRB.CreateMemCpy(RegSaveAreaOriginPtr, Alignment, VAArgTLSOriginCopy, - Alignment, SystemZRegSaveAreaSize); + Alignment, RegSaveAreaSize); } void copyOverflowArea(IRBuilder<> &IRB, Value *VAListTag) { @@ -5688,11 +5806,20 @@ struct VarArgSystemZHelper : public VarArgHelper { IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, SystemZOverflowOffset), VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize); + VAArgTLSCopy->setAlignment(kShadowTLSAlignment); + IRB.CreateMemSet(VAArgTLSCopy, Constant::getNullValue(IRB.getInt8Ty()), + CopySize, kShadowTLSAlignment, false); + + Value *SrcSize = IRB.CreateBinaryIntrinsic( + Intrinsic::umin, CopySize, + ConstantInt::get(MS.IntptrTy, kParamTLSSize)); + IRB.CreateMemCpy(VAArgTLSCopy, kShadowTLSAlignment, MS.VAArgTLS, + kShadowTLSAlignment, SrcSize); if (MS.TrackOrigins) { VAArgTLSOriginCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSOriginCopy, Align(8), MS.VAArgOriginTLS, - Align(8), CopySize); + VAArgTLSOriginCopy->setAlignment(kShadowTLSAlignment); + IRB.CreateMemCpy(VAArgTLSOriginCopy, kShadowTLSAlignment, + MS.VAArgOriginTLS, kShadowTLSAlignment, SrcSize); } } diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 4d4eb6f8ce80..3c8f25d73c62 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -48,7 +48,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" -#include "CFGMST.h" #include "ValueProfileCollector.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -56,17 +55,13 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/MemoryProfileInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -78,6 +73,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalValue.h" @@ -99,7 +95,6 @@ #include "llvm/IR/Value.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/ProfileData/InstrProfReader.h" -#include "llvm/Support/BLAKE3.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CRC.h" #include "llvm/Support/Casting.h" @@ -109,27 +104,27 @@ #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GraphWriter.h" -#include "llvm/Support/HashBuilder.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include 
"llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Instrumentation/BlockCoverageInference.h" +#include "llvm/Transforms/Instrumentation/CFGMST.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/MisExpect.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <cassert> #include <cstdint> -#include <map> #include <memory> #include <numeric> #include <optional> -#include <set> #include <string> #include <unordered_map> #include <utility> #include <vector> using namespace llvm; -using namespace llvm::memprof; using ProfileCount = Function::ProfileCount; using VPCandidateInfo = ValueProfileCollector::CandidateInfo; @@ -144,7 +139,6 @@ STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts."); STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile."); STATISTIC(NumOfPGOMissing, "Number of functions without profile."); -STATISTIC(NumOfMemProfMissing, "Number of functions without memory profile."); STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations."); STATISTIC(NumOfCSPGOInstrument, "Number of edges instrumented in CSPGO."); STATISTIC(NumOfCSPGOSelectInsts, @@ -159,6 +153,7 @@ STATISTIC(NumOfCSPGOFunc, STATISTIC(NumOfCSPGOMismatch, "Number of functions having mismatch profile in CSPGO."); STATISTIC(NumOfCSPGOMissing, "Number of functions without profile in CSPGO."); +STATISTIC(NumCoveredBlocks, "Number of basic blocks that were executed"); // Command line option to specify the file to read profile from. This is // mainly used for testing. @@ -200,31 +195,31 @@ static cl::opt<bool> DoComdatRenaming( cl::desc("Append function hash to the name of COMDAT function to avoid " "function hash mismatch due to the preinliner")); +namespace llvm { // Command line option to enable/disable the warning about missing profile // information. -static cl::opt<bool> - PGOWarnMissing("pgo-warn-missing-function", cl::init(false), cl::Hidden, - cl::desc("Use this option to turn on/off " - "warnings about missing profile data for " - "functions.")); +cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function", cl::init(false), + cl::Hidden, + cl::desc("Use this option to turn on/off " + "warnings about missing profile data for " + "functions.")); -namespace llvm { // Command line option to enable/disable the warning about a hash mismatch in // the profile data. cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden, cl::desc("Use this option to turn off/on " "warnings about profile cfg mismatch.")); -} // namespace llvm // Command line option to enable/disable the warning about a hash mismatch in // the profile data for Comdat functions, which often turns out to be false // positive due to the pre-instrumentation inline. -static cl::opt<bool> NoPGOWarnMismatchComdatWeak( +cl::opt<bool> NoPGOWarnMismatchComdatWeak( "no-pgo-warn-mismatch-comdat-weak", cl::init(true), cl::Hidden, cl::desc("The option is used to turn on/off " "warnings about hash mismatch for comdat " "or weak functions.")); +} // namespace llvm // Command line option to enable/disable select instruction instrumentation. 
static cl::opt<bool> @@ -268,6 +263,19 @@ static cl::opt<bool> PGOFunctionEntryCoverage( cl::desc( "Use this option to enable function entry coverage instrumentation.")); +static cl::opt<bool> PGOBlockCoverage( + "pgo-block-coverage", + cl::desc("Use this option to enable basic block coverage instrumentation")); + +static cl::opt<bool> + PGOViewBlockCoverageGraph("pgo-view-block-coverage-graph", + cl::desc("Create a dot file of CFGs with block " + "coverage inference information")); + +static cl::opt<bool> PGOTemporalInstrumentation( + "pgo-temporal-instrumentation", + cl::desc("Use this option to enable temporal instrumentation")); + static cl::opt<bool> PGOFixEntryCount("pgo-fix-entry-count", cl::init(true), cl::Hidden, cl::desc("Fix function entry count in profile use.")); @@ -305,10 +313,6 @@ static cl::opt<unsigned> PGOFunctionSizeThreshold( "pgo-function-size-threshold", cl::Hidden, cl::desc("Do not instrument functions smaller than this threshold.")); -static cl::opt<bool> MatchMemProf( - "pgo-match-memprof", cl::init(true), cl::Hidden, - cl::desc("Perform matching and annotation of memprof profiles.")); - static cl::opt<unsigned> PGOFunctionCriticalEdgeThreshold( "pgo-critical-edge-threshold", cl::init(20000), cl::Hidden, cl::desc("Do not instrument functions with the number of critical edges " @@ -344,7 +348,7 @@ static std::string getBranchCondString(Instruction *TI) { std::string result; raw_string_ostream OS(result); - OS << CmpInst::getPredicateName(CI->getPredicate()) << "_"; + OS << CI->getPredicate() << "_"; CI->getOperand(0)->getType()->print(OS, true); Value *RHS = CI->getOperand(1); @@ -383,6 +387,10 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) { if (PGOFunctionEntryCoverage) ProfileVersion |= VARIANT_MASK_BYTE_COVERAGE | VARIANT_MASK_FUNCTION_ENTRY_ONLY; + if (PGOBlockCoverage) + ProfileVersion |= VARIANT_MASK_BYTE_COVERAGE; + if (PGOTemporalInstrumentation) + ProfileVersion |= VARIANT_MASK_TEMPORAL_PROF; auto IRLevelVersionVariable = new GlobalVariable( M, IntTy64, true, GlobalValue::WeakAnyLinkage, Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), VarName); @@ -415,35 +423,37 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { GlobalVariable *FuncNameVar = nullptr; uint64_t FuncHash = 0; PGOUseFunc *UseFunc = nullptr; + bool HasSingleByteCoverage; - SelectInstVisitor(Function &Func) : F(Func) {} + SelectInstVisitor(Function &Func, bool HasSingleByteCoverage) + : F(Func), HasSingleByteCoverage(HasSingleByteCoverage) {} - void countSelects(Function &Func) { + void countSelects() { NSIs = 0; Mode = VM_counting; - visit(Func); + visit(F); } // Visit the IR stream and instrument all select instructions. \p // Ind is a pointer to the counter index variable; \p TotalNC // is the total number of counters; \p FNV is the pointer to the // PGO function name var; \p FHash is the function hash. - void instrumentSelects(Function &Func, unsigned *Ind, unsigned TotalNC, - GlobalVariable *FNV, uint64_t FHash) { + void instrumentSelects(unsigned *Ind, unsigned TotalNC, GlobalVariable *FNV, + uint64_t FHash) { Mode = VM_instrument; CurCtrIdx = Ind; TotalNumCtrs = TotalNC; FuncHash = FHash; FuncNameVar = FNV; - visit(Func); + visit(F); } // Visit the IR stream and annotate all select instructions. 
- void annotateSelects(Function &Func, PGOUseFunc *UF, unsigned *Ind) { + void annotateSelects(PGOUseFunc *UF, unsigned *Ind) { Mode = VM_annotate; UseFunc = UF; CurCtrIdx = Ind; - visit(Func); + visit(F); } void instrumentOneSelectInst(SelectInst &SI); @@ -457,52 +467,41 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { unsigned getNumOfSelectInsts() const { return NSIs; } }; -} // end anonymous namespace - -namespace { - -/// An MST based instrumentation for PGO -/// -/// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO -/// in the function level. +/// This class implements the CFG edges for the Minimum Spanning Tree (MST) +/// based instrumentation. +/// Note that the CFG can be a multi-graph. So there might be multiple edges +/// with the same SrcBB and DestBB. struct PGOEdge { - // This class implements the CFG edges. Note the CFG can be a multi-graph. - // So there might be multiple edges with same SrcBB and DestBB. - const BasicBlock *SrcBB; - const BasicBlock *DestBB; + BasicBlock *SrcBB; + BasicBlock *DestBB; uint64_t Weight; bool InMST = false; bool Removed = false; bool IsCritical = false; - PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W = 1) + PGOEdge(BasicBlock *Src, BasicBlock *Dest, uint64_t W = 1) : SrcBB(Src), DestBB(Dest), Weight(W) {} - // Return the information string of an edge. + /// Return the information string of an edge. std::string infoString() const { return (Twine(Removed ? "-" : " ") + (InMST ? " " : "*") + - (IsCritical ? "c" : " ") + " W=" + Twine(Weight)).str(); + (IsCritical ? "c" : " ") + " W=" + Twine(Weight)) + .str(); } }; -// This class stores the auxiliary information for each BB. -struct BBInfo { - BBInfo *Group; +/// This class stores the auxiliary information for each BB in the MST. +struct PGOBBInfo { + PGOBBInfo *Group; uint32_t Index; uint32_t Rank = 0; - BBInfo(unsigned IX) : Group(this), Index(IX) {} + PGOBBInfo(unsigned IX) : Group(this), Index(IX) {} - // Return the information string of this object. + /// Return the information string of this object. std::string infoString() const { return (Twine("Index=") + Twine(Index)).str(); } - - // Empty function -- only applicable to UseBBInfo. - void addOutEdge(PGOEdge *E LLVM_ATTRIBUTE_UNUSED) {} - - // Empty function -- only applicable to UseBBInfo. - void addInEdge(PGOEdge *E LLVM_ATTRIBUTE_UNUSED) {} }; // This class implements the CFG edges. Note the CFG can be a multi-graph. @@ -534,6 +533,16 @@ public: // The Minimum Spanning Tree of function CFG. CFGMST<Edge, BBInfo> MST; + const std::optional<BlockCoverageInference> BCI; + + static std::optional<BlockCoverageInference> + constructBCI(Function &Func, bool HasSingleByteCoverage, + bool InstrumentFuncEntry) { + if (HasSingleByteCoverage) + return BlockCoverageInference(Func, InstrumentFuncEntry); + return {}; + } + // Collect all the BBs that will be instrumented, and store them in // InstrumentBBs. void getInstrumentBBs(std::vector<BasicBlock *> &InstrumentBBs); @@ -549,9 +558,9 @@ public: BBInfo *findBBInfo(const BasicBlock *BB) const { return MST.findBBInfo(BB); } // Dump edges and BB information. 
- void dumpInfo(std::string Str = "") const { - MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " + - Twine(FunctionHash) + "\t" + Str); + void dumpInfo(StringRef Str = "") const { + MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + + " Hash: " + Twine(FunctionHash) + "\t" + Str); } FuncPGOInstrumentation( @@ -559,12 +568,16 @@ public: std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr, bool IsCS = false, - bool InstrumentFuncEntry = true) + bool InstrumentFuncEntry = true, bool HasSingleByteCoverage = false) : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), VPC(Func, TLI), - TLI(TLI), ValueSites(IPVK_Last + 1), SIVisitor(Func), - MST(F, InstrumentFuncEntry, BPI, BFI) { + TLI(TLI), ValueSites(IPVK_Last + 1), + SIVisitor(Func, HasSingleByteCoverage), + MST(F, InstrumentFuncEntry, BPI, BFI), + BCI(constructBCI(Func, HasSingleByteCoverage, InstrumentFuncEntry)) { + if (BCI && PGOViewBlockCoverageGraph) + BCI->viewBlockCoverageGraph(); // This should be done before CFG hash computation. - SIVisitor.countSelects(Func); + SIVisitor.countSelects(); ValueSites[IPVK_MemOPSize] = VPC.get(IPVK_MemOPSize); if (!IsCS) { NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); @@ -637,7 +650,11 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { updateJCH((uint64_t)SIVisitor.getNumOfSelectInsts()); updateJCH((uint64_t)ValueSites[IPVK_IndirectCallTarget].size()); updateJCH((uint64_t)ValueSites[IPVK_MemOPSize].size()); - updateJCH((uint64_t)MST.AllEdges.size()); + if (BCI) { + updateJCH(BCI->getInstrumentedBlocksHash()); + } else { + updateJCH((uint64_t)MST.AllEdges.size()); + } // Hash format for context sensitive profile. Reserve 4 bits for other // information. @@ -725,11 +742,18 @@ void FuncPGOInstrumentation<Edge, BBInfo>::renameComdatFunction() { } } -// Collect all the BBs that will be instruments and return them in -// InstrumentBBs and setup InEdges/OutEdge for UseBBInfo. +/// Collect all the BBs that will be instruments and add them to +/// `InstrumentBBs`. template <class Edge, class BBInfo> void FuncPGOInstrumentation<Edge, BBInfo>::getInstrumentBBs( std::vector<BasicBlock *> &InstrumentBBs) { + if (BCI) { + for (auto &BB : F) + if (BCI->shouldInstrumentBlock(BB)) + InstrumentBBs.push_back(&BB); + return; + } + // Use a worklist as we will update the vector during the iteration. std::vector<Edge *> EdgeList; EdgeList.reserve(MST.AllEdges.size()); @@ -741,18 +765,6 @@ void FuncPGOInstrumentation<Edge, BBInfo>::getInstrumentBBs( if (InstrBB) InstrumentBBs.push_back(InstrBB); } - - // Set up InEdges/OutEdges for all BBs. - for (auto &E : MST.AllEdges) { - if (E->Removed) - continue; - const BasicBlock *SrcBB = E->SrcBB; - const BasicBlock *DestBB = E->DestBB; - BBInfo &SrcInfo = getBBInfo(SrcBB); - BBInfo &DestInfo = getBBInfo(DestBB); - SrcInfo.addOutEdge(E.get()); - DestInfo.addInEdge(E.get()); - } } // Given a CFG E to be instrumented, find which BB to place the instrumented @@ -762,8 +774,8 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { if (E->InMST || E->Removed) return nullptr; - BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB); - BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB); + BasicBlock *SrcBB = E->SrcBB; + BasicBlock *DestBB = E->DestBB; // For a fake edge, instrument the real BB. 
if (SrcBB == nullptr) return DestBB; @@ -852,12 +864,15 @@ static void instrumentOneFunc( BlockFrequencyInfo *BFI, std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, bool IsCS) { - // Split indirectbr critical edges here before computing the MST rather than - // later in getInstrBB() to avoid invalidating it. - SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI); + if (!PGOBlockCoverage) { + // Split indirectbr critical edges here before computing the MST rather than + // later in getInstrBB() to avoid invalidating it. + SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI); + } - FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo( - F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry); + FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo( + F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry, + PGOBlockCoverage); Type *I8PtrTy = Type::getInt8PtrTy(M->getContext()); auto Name = ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy); @@ -880,6 +895,18 @@ static void instrumentOneFunc( InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts(); uint32_t I = 0; + if (PGOTemporalInstrumentation) { + NumCounters += PGOBlockCoverage ? 8 : 1; + auto &EntryBB = F.getEntryBlock(); + IRBuilder<> Builder(&EntryBB, EntryBB.getFirstInsertionPt()); + // llvm.instrprof.timestamp(i8* <name>, i64 <hash>, i32 <num-counters>, + // i32 <index>) + Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::instrprof_timestamp), + {Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I)}); + I += PGOBlockCoverage ? 8 : 1; + } + for (auto *InstrBB : InstrumentBBs) { IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt()); assert(Builder.GetInsertPoint() != InstrBB->end() && @@ -887,12 +914,14 @@ static void instrumentOneFunc( // llvm.instrprof.increment(i8* <name>, i64 <hash>, i32 <num-counters>, // i32 <index>) Builder.CreateCall( - Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment), + Intrinsic::getDeclaration(M, PGOBlockCoverage + ? Intrinsic::instrprof_cover + : Intrinsic::instrprof_increment), {Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I++)}); } // Now instrument select instructions: - FuncInfo.SIVisitor.instrumentSelects(F, &I, NumCounters, FuncInfo.FuncNameVar, + FuncInfo.SIVisitor.instrumentSelects(&I, NumCounters, FuncInfo.FuncNameVar, FuncInfo.FunctionHash); assert(I == NumCounters); @@ -947,12 +976,11 @@ namespace { // This class represents a CFG edge in profile use compilation. struct PGOUseEdge : public PGOEdge { + using PGOEdge::PGOEdge; + bool CountValid = false; uint64_t CountValue = 0; - PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W = 1) - : PGOEdge(Src, Dest, W) {} - // Set edge count value void setEdgeCount(uint64_t Value) { CountValue = Value; @@ -971,7 +999,7 @@ struct PGOUseEdge : public PGOEdge { using DirectEdges = SmallVector<PGOUseEdge *, 2>; // This class stores the auxiliary information for each BB. -struct UseBBInfo : public BBInfo { +struct PGOUseBBInfo : public PGOBBInfo { uint64_t CountValue = 0; bool CountValid; int32_t UnknownCountInEdge = 0; @@ -979,10 +1007,7 @@ struct UseBBInfo : public BBInfo { DirectEdges InEdges; DirectEdges OutEdges; - UseBBInfo(unsigned IX) : BBInfo(IX), CountValid(false) {} - - UseBBInfo(unsigned IX, uint64_t C) - : BBInfo(IX), CountValue(C), CountValid(true) {} + PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {} // Set the profile count value for this BB. 
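The counter-slot arithmetic in instrumentOneFunc above is easy to lose in the diff, so here it is written out separately. This is only a sketch of the layout implied by the hunk: temporal instrumentation reserves one slot ahead of the block counters, or eight in block-coverage mode, and select-instruction counters come after the block counters.

#include <cstdint>

struct CounterLayout {
  uint32_t NumCounters;    // value passed as <num-counters> to the intrinsics
  uint32_t FirstBlockSlot; // index handed to the first block counter
};

CounterLayout layoutCounters(uint32_t NumInstrumentedBlocks,
                             uint32_t NumSelects, bool Temporal,
                             bool BlockCoverage) {
  uint32_t TimestampSlots = Temporal ? (BlockCoverage ? 8u : 1u) : 0u;
  uint32_t Total = TimestampSlots + NumInstrumentedBlocks + NumSelects;
  return {Total, TimestampSlots};
}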
void setBBInfoCount(uint64_t Value) { @@ -993,8 +1018,9 @@ struct UseBBInfo : public BBInfo { // Return the information string of this object. std::string infoString() const { if (!CountValid) - return BBInfo::infoString(); - return (Twine(BBInfo::infoString()) + " Count=" + Twine(CountValue)).str(); + return PGOBBInfo::infoString(); + return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue)) + .str(); } // Add an OutEdge and update the edge count. @@ -1030,22 +1056,25 @@ public: PGOUseFunc(Function &Func, Module *Modu, TargetLibraryInfo &TLI, std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFIin, - ProfileSummaryInfo *PSI, bool IsCS, bool InstrumentFuncEntry) + ProfileSummaryInfo *PSI, bool IsCS, bool InstrumentFuncEntry, + bool HasSingleByteCoverage) : F(Func), M(Modu), BFI(BFIin), PSI(PSI), FuncInfo(Func, TLI, ComdatMembers, false, BPI, BFIin, IsCS, - InstrumentFuncEntry), + InstrumentFuncEntry, HasSingleByteCoverage), FreqAttr(FFA_Normal), IsCS(IsCS) {} + void handleInstrProfError(Error Err, uint64_t MismatchedFuncSum); + // Read counts for the instrumented BB from profile. bool readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, InstrProfRecord::CountPseudoKind &PseudoKind); - // Read memprof data for the instrumented function from profile. - bool readMemprof(IndexedInstrProfReader *PGOReader); - // Populate the counts for all BBs. void populateCounters(); + // Set block coverage based on profile coverage values. + void populateCoverage(IndexedInstrProfReader *PGOReader); + // Set the branch weights based on the count values. void setBranchWeights(); @@ -1071,22 +1100,21 @@ public: InstrProfRecord &getProfileRecord() { return ProfileRecord; } // Return the auxiliary BB information. - UseBBInfo &getBBInfo(const BasicBlock *BB) const { + PGOUseBBInfo &getBBInfo(const BasicBlock *BB) const { return FuncInfo.getBBInfo(BB); } // Return the auxiliary BB information if available. - UseBBInfo *findBBInfo(const BasicBlock *BB) const { + PGOUseBBInfo *findBBInfo(const BasicBlock *BB) const { return FuncInfo.findBBInfo(BB); } Function &getFunc() const { return F; } - void dumpInfo(std::string Str = "") const { - FuncInfo.dumpInfo(Str); - } + void dumpInfo(StringRef Str = "") const { FuncInfo.dumpInfo(Str); } uint64_t getProgramMaxCount() const { return ProgramMaxCount; } + private: Function &F; Module *M; @@ -1094,7 +1122,7 @@ private: ProfileSummaryInfo *PSI; // This member stores the shared information with class PGOGenFunc. - FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo; + FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> FuncInfo; // The maximum count value in the profile. This is only used in PGO use // compilation. @@ -1122,9 +1150,6 @@ private: // one unknown edge. void setEdgeCount(DirectEdges &Edges, uint64_t Value); - // Return FuncName string; - std::string getFuncName() const { return FuncInfo.FuncName; } - // Set the hot/cold inline hints based on the count values. // FIXME: This function should be removed once the functionality in // the inliner is implemented. @@ -1138,6 +1163,24 @@ private: } // end anonymous namespace +/// Set up InEdges/OutEdges for all BBs in the MST. +static void +setupBBInfoEdges(FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> &FuncInfo) { + // This is not required when there is block coverage inference. 
+ if (FuncInfo.BCI) + return; + for (auto &E : FuncInfo.MST.AllEdges) { + if (E->Removed) + continue; + const BasicBlock *SrcBB = E->SrcBB; + const BasicBlock *DestBB = E->DestBB; + PGOUseBBInfo &SrcInfo = FuncInfo.getBBInfo(SrcBB); + PGOUseBBInfo &DestInfo = FuncInfo.getBBInfo(DestBB); + SrcInfo.addOutEdge(E.get()); + DestInfo.addInEdge(E.get()); + } +} + // Visit all the edges and assign the count value for the instrumented // edges and the BB. Return false on error. bool PGOUseFunc::setInstrumentedCounts( @@ -1145,6 +1188,9 @@ bool PGOUseFunc::setInstrumentedCounts( std::vector<BasicBlock *> InstrumentBBs; FuncInfo.getInstrumentBBs(InstrumentBBs); + + setupBBInfoEdges(FuncInfo); + unsigned NumCounters = InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts(); // The number of counters here should match the number of counters @@ -1158,7 +1204,7 @@ bool PGOUseFunc::setInstrumentedCounts( uint32_t I = 0; for (BasicBlock *InstrBB : InstrumentBBs) { uint64_t CountValue = CountFromProfile[I++]; - UseBBInfo &Info = getBBInfo(InstrBB); + PGOUseBBInfo &Info = getBBInfo(InstrBB); // If we reach here, we know that we have some nonzero count // values in this function. The entry count should not be 0. // Fix it if necessary. @@ -1183,7 +1229,7 @@ bool PGOUseFunc::setInstrumentedCounts( if (E->Removed || E->InMST) continue; const BasicBlock *SrcBB = E->SrcBB; - UseBBInfo &SrcInfo = getBBInfo(SrcBB); + PGOUseBBInfo &SrcInfo = getBBInfo(SrcBB); // If only one out-edge, the edge profile count should be the same as BB // profile count. @@ -1191,7 +1237,7 @@ bool PGOUseFunc::setInstrumentedCounts( setEdgeCount(E.get(), SrcInfo.CountValue); else { const BasicBlock *DestBB = E->DestBB; - UseBBInfo &DestInfo = getBBInfo(DestBB); + PGOUseBBInfo &DestInfo = getBBInfo(DestBB); // If only one in-edge, the edge profile count should be the same as BB // profile count. if (DestInfo.CountValid && DestInfo.InEdges.size() == 1) @@ -1222,8 +1268,7 @@ void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) { } // Emit function metadata indicating PGO profile mismatch. -static void annotateFunctionWithHashMismatch(Function &F, - LLVMContext &ctx) { +static void annotateFunctionWithHashMismatch(Function &F, LLVMContext &ctx) { const char MetadataName[] = "instr_prof_hash_mismatch"; SmallVector<Metadata *, 2> Names; // If this metadata already exists, ignore. 
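The invariant behind setInstrumentedCounts and setEdgeCount above is plain flow conservation: a block's count equals the sum of its in-edge counts and also the sum of its out-edge counts, so once all but one incident edge is known the last one is forced. A toy, self-contained version of that single step (not the PGOUseFunc data structures):

#include <cstdint>
#include <optional>
#include <vector>

struct ToyEdge {
  std::optional<uint64_t> Count; // empty until the count is known
};

// Returns true if the single unknown edge could be filled in from the block
// count and the already-known edges on the same side of the block.
bool solveUnknownEdge(uint64_t BlockCount, std::vector<ToyEdge> &Edges) {
  uint64_t KnownSum = 0;
  ToyEdge *Unknown = nullptr;
  for (ToyEdge &E : Edges) {
    if (E.Count) {
      KnownSum += *E.Count;
    } else if (Unknown) {
      return false; // more than one unknown: not solvable yet
    } else {
      Unknown = &E;
    }
  }
  if (!Unknown)
    return false; // nothing left to infer on this side
  Unknown->Count = BlockCount >= KnownSum ? BlockCount - KnownSum : 0;
  return true;
}

populateCounters, further down, effectively keeps sweeping the function and applying this step until no block has a solvable edge left.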
@@ -1231,7 +1276,7 @@ static void annotateFunctionWithHashMismatch(Function &F, if (Existing) { MDTuple *Tuple = cast<MDTuple>(Existing); for (const auto &N : Tuple->operands()) { - if (cast<MDString>(N.get())->getString() == MetadataName) + if (N.equalsStr(MetadataName)) return; Names.push_back(N.get()); } @@ -1243,255 +1288,44 @@ static void annotateFunctionWithHashMismatch(Function &F, F.setMetadata(LLVMContext::MD_annotation, MD); } -static void addCallsiteMetadata(Instruction &I, - std::vector<uint64_t> &InlinedCallStack, - LLVMContext &Ctx) { - I.setMetadata(LLVMContext::MD_callsite, - buildCallstackMetadata(InlinedCallStack, Ctx)); -} - -static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset, - uint32_t Column) { - llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::support::endianness::little> - HashBuilder; - HashBuilder.add(Function, LineOffset, Column); - llvm::BLAKE3Result<8> Hash = HashBuilder.final(); - uint64_t Id; - std::memcpy(&Id, Hash.data(), sizeof(Hash)); - return Id; -} - -static uint64_t computeStackId(const memprof::Frame &Frame) { - return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column); -} - -static void addCallStack(CallStackTrie &AllocTrie, - const AllocationInfo *AllocInfo) { - SmallVector<uint64_t> StackIds; - for (auto StackFrame : AllocInfo->CallStack) - StackIds.push_back(computeStackId(StackFrame)); - auto AllocType = getAllocType(AllocInfo->Info.getMaxAccessCount(), - AllocInfo->Info.getMinSize(), - AllocInfo->Info.getMinLifetime()); - AllocTrie.addCallStack(AllocType, StackIds); -} - -// Helper to compare the InlinedCallStack computed from an instruction's debug -// info to a list of Frames from profile data (either the allocation data or a -// callsite). For callsites, the StartIndex to use in the Frame array may be -// non-zero. -static bool -stackFrameIncludesInlinedCallStack(ArrayRef<Frame> ProfileCallStack, - ArrayRef<uint64_t> InlinedCallStack, - unsigned StartIndex = 0) { - auto StackFrame = ProfileCallStack.begin() + StartIndex; - auto InlCallStackIter = InlinedCallStack.begin(); - for (; StackFrame != ProfileCallStack.end() && - InlCallStackIter != InlinedCallStack.end(); - ++StackFrame, ++InlCallStackIter) { - uint64_t StackId = computeStackId(*StackFrame); - if (StackId != *InlCallStackIter) - return false; - } - // Return true if we found and matched all stack ids from the call - // instruction. 
- return InlCallStackIter == InlinedCallStack.end(); -} - -bool PGOUseFunc::readMemprof(IndexedInstrProfReader *PGOReader) { - if (!MatchMemProf) - return true; - - auto &Ctx = M->getContext(); - - auto FuncGUID = Function::getGUID(FuncInfo.FuncName); - Expected<memprof::MemProfRecord> MemProfResult = - PGOReader->getMemProfRecord(FuncGUID); - if (Error E = MemProfResult.takeError()) { - handleAllErrors(std::move(E), [&](const InstrProfError &IPE) { - auto Err = IPE.get(); - bool SkipWarning = false; - LLVM_DEBUG(dbgs() << "Error in reading profile for Func " - << FuncInfo.FuncName << ": "); - if (Err == instrprof_error::unknown_function) { - NumOfMemProfMissing++; - SkipWarning = !PGOWarnMissing; - LLVM_DEBUG(dbgs() << "unknown function"); - } else if (Err == instrprof_error::hash_mismatch) { - SkipWarning = - NoPGOWarnMismatch || - (NoPGOWarnMismatchComdatWeak && - (F.hasComdat() || - F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); - LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")"); - } - - if (SkipWarning) - return; - - std::string Msg = - (IPE.message() + Twine(" ") + F.getName().str() + Twine(" Hash = ") + - std::to_string(FuncInfo.FunctionHash)) - .str(); - - Ctx.diagnose( - DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); - }); - return false; - } - - // Build maps of the location hash to all profile data with that leaf location - // (allocation info and the callsites). - std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo; - // For the callsites we need to record the index of the associated frame in - // the frame array (see comments below where the map entries are added). - std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>> - LocHashToCallSites; - const auto MemProfRec = std::move(MemProfResult.get()); - for (auto &AI : MemProfRec.AllocSites) { - // Associate the allocation info with the leaf frame. The later matching - // code will match any inlined call sequences in the IR with a longer prefix - // of call stack frames. - uint64_t StackId = computeStackId(AI.CallStack[0]); - LocHashToAllocInfo[StackId].insert(&AI); - } - for (auto &CS : MemProfRec.CallSites) { - // Need to record all frames from leaf up to and including this function, - // as any of these may or may not have been inlined at this point. - unsigned Idx = 0; - for (auto &StackFrame : CS) { - uint64_t StackId = computeStackId(StackFrame); - LocHashToCallSites[StackId].insert(std::make_pair(&CS, Idx++)); - // Once we find this function, we can stop recording. - if (StackFrame.Function == FuncGUID) - break; +void PGOUseFunc::handleInstrProfError(Error Err, uint64_t MismatchedFuncSum) { + handleAllErrors(std::move(Err), [&](const InstrProfError &IPE) { + auto &Ctx = M->getContext(); + auto Err = IPE.get(); + bool SkipWarning = false; + LLVM_DEBUG(dbgs() << "Error in reading profile for Func " + << FuncInfo.FuncName << ": "); + if (Err == instrprof_error::unknown_function) { + IsCS ? NumOfCSPGOMissing++ : NumOfPGOMissing++; + SkipWarning = !PGOWarnMissing; + LLVM_DEBUG(dbgs() << "unknown function"); + } else if (Err == instrprof_error::hash_mismatch || + Err == instrprof_error::malformed) { + IsCS ? 
NumOfCSPGOMismatch++ : NumOfPGOMismatch++; + SkipWarning = + NoPGOWarnMismatch || + (NoPGOWarnMismatchComdatWeak && + (F.hasComdat() || F.getLinkage() == GlobalValue::WeakAnyLinkage || + F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); + LLVM_DEBUG(dbgs() << "hash mismatch (hash= " << FuncInfo.FunctionHash + << " skip=" << SkipWarning << ")"); + // Emit function metadata indicating PGO profile mismatch. + annotateFunctionWithHashMismatch(F, M->getContext()); } - assert(Idx <= CS.size() && CS[Idx - 1].Function == FuncGUID); - } - - auto GetOffset = [](const DILocation *DIL) { - return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) & - 0xffff; - }; - - // Now walk the instructions, looking up the associated profile data using - // dbug locations. - for (auto &BB : F) { - for (auto &I : BB) { - if (I.isDebugOrPseudoInst()) - continue; - // We are only interested in calls (allocation or interior call stack - // context calls). - auto *CI = dyn_cast<CallBase>(&I); - if (!CI) - continue; - auto *CalledFunction = CI->getCalledFunction(); - if (CalledFunction && CalledFunction->isIntrinsic()) - continue; - // List of call stack ids computed from the location hashes on debug - // locations (leaf to inlined at root). - std::vector<uint64_t> InlinedCallStack; - // Was the leaf location found in one of the profile maps? - bool LeafFound = false; - // If leaf was found in a map, iterators pointing to its location in both - // of the maps. It might exist in neither, one, or both (the latter case - // can happen because we don't currently have discriminators to - // distinguish the case when a single line/col maps to both an allocation - // and another callsite). - std::map<uint64_t, std::set<const AllocationInfo *>>::iterator - AllocInfoIter; - std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, - unsigned>>>::iterator CallSitesIter; - for (const DILocation *DIL = I.getDebugLoc(); DIL != nullptr; - DIL = DIL->getInlinedAt()) { - // Use C++ linkage name if possible. Need to compile with - // -fdebug-info-for-profiling to get linkage name. - StringRef Name = DIL->getScope()->getSubprogram()->getLinkageName(); - if (Name.empty()) - Name = DIL->getScope()->getSubprogram()->getName(); - auto CalleeGUID = Function::getGUID(Name); - auto StackId = - computeStackId(CalleeGUID, GetOffset(DIL), DIL->getColumn()); - // LeafFound will only be false on the first iteration, since we either - // set it true or break out of the loop below. - if (!LeafFound) { - AllocInfoIter = LocHashToAllocInfo.find(StackId); - CallSitesIter = LocHashToCallSites.find(StackId); - // Check if the leaf is in one of the maps. If not, no need to look - // further at this call. - if (AllocInfoIter == LocHashToAllocInfo.end() && - CallSitesIter == LocHashToCallSites.end()) - break; - LeafFound = true; - } - InlinedCallStack.push_back(StackId); - } - // If leaf not in either of the maps, skip inst. - if (!LeafFound) - continue; - // First add !memprof metadata from allocation info, if we found the - // instruction's leaf location in that map, and if the rest of the - // instruction's locations match the prefix Frame locations on an - // allocation context with the same leaf. - if (AllocInfoIter != LocHashToAllocInfo.end()) { - // Only consider allocations via new, to reduce unnecessary metadata, - // since those are the only allocations that will be targeted initially. - if (!isNewLikeFn(CI, &FuncInfo.TLI)) - continue; - // We may match this instruction's location list to multiple MIB - // contexts. 
Add them to a Trie specialized for trimming the contexts to - // the minimal needed to disambiguate contexts with unique behavior. - CallStackTrie AllocTrie; - for (auto *AllocInfo : AllocInfoIter->second) { - // Check the full inlined call stack against this one. - // If we found and thus matched all frames on the call, include - // this MIB. - if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack, - InlinedCallStack)) - addCallStack(AllocTrie, AllocInfo); - } - // We might not have matched any to the full inlined call stack. - // But if we did, create and attach metadata, or a function attribute if - // all contexts have identical profiled behavior. - if (!AllocTrie.empty()) { - // MemprofMDAttached will be false if a function attribute was - // attached. - bool MemprofMDAttached = AllocTrie.buildAndAttachMIBMetadata(CI); - assert(MemprofMDAttached == I.hasMetadata(LLVMContext::MD_memprof)); - if (MemprofMDAttached) { - // Add callsite metadata for the instruction's location list so that - // it simpler later on to identify which part of the MIB contexts - // are from this particular instruction (including during inlining, - // when the callsite metdata will be updated appropriately). - // FIXME: can this be changed to strip out the matching stack - // context ids from the MIB contexts and not add any callsite - // metadata here to save space? - addCallsiteMetadata(I, InlinedCallStack, Ctx); - } - } - continue; - } + LLVM_DEBUG(dbgs() << " IsCS=" << IsCS << "\n"); + if (SkipWarning) + return; - // Otherwise, add callsite metadata. If we reach here then we found the - // instruction's leaf location in the callsites map and not the allocation - // map. - assert(CallSitesIter != LocHashToCallSites.end()); - for (auto CallStackIdx : CallSitesIter->second) { - // If we found and thus matched all frames on the call, create and - // attach call stack metadata. - if (stackFrameIncludesInlinedCallStack( - *CallStackIdx.first, InlinedCallStack, CallStackIdx.second)) { - addCallsiteMetadata(I, InlinedCallStack, Ctx); - // Only need to find one with a matching call stack and add a single - // callsite metadata. - break; - } - } - } - } + std::string Msg = + IPE.message() + std::string(" ") + F.getName().str() + + std::string(" Hash = ") + std::to_string(FuncInfo.FunctionHash) + + std::string(" up to ") + std::to_string(MismatchedFuncSum) + + std::string(" count discarded"); - return true; + Ctx.diagnose( + DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); + }); } // Read the profile from ProfileFileName and assign the value to the @@ -1504,42 +1338,7 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord( FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum); if (Error E = Result.takeError()) { - handleAllErrors(std::move(E), [&](const InstrProfError &IPE) { - auto Err = IPE.get(); - bool SkipWarning = false; - LLVM_DEBUG(dbgs() << "Error in reading profile for Func " - << FuncInfo.FuncName << ": "); - if (Err == instrprof_error::unknown_function) { - IsCS ? NumOfCSPGOMissing++ : NumOfPGOMissing++; - SkipWarning = !PGOWarnMissing; - LLVM_DEBUG(dbgs() << "unknown function"); - } else if (Err == instrprof_error::hash_mismatch || - Err == instrprof_error::malformed) { - IsCS ? 
NumOfCSPGOMismatch++ : NumOfPGOMismatch++; - SkipWarning = - NoPGOWarnMismatch || - (NoPGOWarnMismatchComdatWeak && - (F.hasComdat() || F.getLinkage() == GlobalValue::WeakAnyLinkage || - F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); - LLVM_DEBUG(dbgs() << "hash mismatch (hash= " << FuncInfo.FunctionHash - << " skip=" << SkipWarning << ")"); - // Emit function metadata indicating PGO profile mismatch. - annotateFunctionWithHashMismatch(F, M->getContext()); - } - - LLVM_DEBUG(dbgs() << " IsCS=" << IsCS << "\n"); - if (SkipWarning) - return; - - std::string Msg = - IPE.message() + std::string(" ") + F.getName().str() + - std::string(" Hash = ") + std::to_string(FuncInfo.FunctionHash) + - std::string(" up to ") + std::to_string(MismatchedFuncSum) + - std::string(" count discarded"); - - Ctx.diagnose( - DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); - }); + handleInstrProfError(std::move(E), MismatchedFuncSum); return false; } ProfileRecord = std::move(Result.get()); @@ -1569,8 +1368,9 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, dbgs() << "Inconsistent number of counts, skipping this function"); Ctx.diagnose(DiagnosticInfoPGOProfile( M->getName().data(), - Twine("Inconsistent number of counts in ") + F.getName().str() - + Twine(": the profile may be stale or there is a function name collision."), + Twine("Inconsistent number of counts in ") + F.getName().str() + + Twine(": the profile may be stale or there is a function name " + "collision."), DS_Warning)); return false; } @@ -1578,6 +1378,113 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, return true; } +void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) { + uint64_t MismatchedFuncSum = 0; + Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord( + FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum); + if (auto Err = Result.takeError()) { + handleInstrProfError(std::move(Err), MismatchedFuncSum); + return; + } + + std::vector<uint64_t> &CountsFromProfile = Result.get().Counts; + DenseMap<const BasicBlock *, bool> Coverage; + unsigned Index = 0; + for (auto &BB : F) + if (FuncInfo.BCI->shouldInstrumentBlock(BB)) + Coverage[&BB] = (CountsFromProfile[Index++] != 0); + assert(Index == CountsFromProfile.size()); + + // For each B in InverseDependencies[A], if A is covered then B is covered. + DenseMap<const BasicBlock *, DenseSet<const BasicBlock *>> + InverseDependencies; + for (auto &BB : F) { + for (auto *Dep : FuncInfo.BCI->getDependencies(BB)) { + // If Dep is covered then BB is covered. + InverseDependencies[Dep].insert(&BB); + } + } + + // Infer coverage of the non-instrumented blocks using a flood-fill algorithm. + std::stack<const BasicBlock *> CoveredBlocksToProcess; + for (auto &[BB, IsCovered] : Coverage) + if (IsCovered) + CoveredBlocksToProcess.push(BB); + + while (!CoveredBlocksToProcess.empty()) { + auto *CoveredBlock = CoveredBlocksToProcess.top(); + assert(Coverage[CoveredBlock]); + CoveredBlocksToProcess.pop(); + for (auto *BB : InverseDependencies[CoveredBlock]) { + // If CoveredBlock is covered then BB is covered. + if (Coverage[BB]) + continue; + Coverage[BB] = true; + CoveredBlocksToProcess.push(BB); + } + } + + // Annotate block coverage. + MDBuilder MDB(F.getContext()); + // We set the entry count to 10000 if the entry block is covered so that BFI + // can propagate a fraction of this count to the other covered blocks. + F.setEntryCount(Coverage[&F.getEntryBlock()] ? 
10000 : 0); + for (auto &BB : F) { + // For a block A and its successor B, we set the edge weight as follows: + // If A is covered and B is covered, set weight=1. + // If A is covered and B is uncovered, set weight=0. + // If A is uncovered, set weight=1. + // This setup will allow BFI to give nonzero profile counts to only covered + // blocks. + SmallVector<unsigned, 4> Weights; + for (auto *Succ : successors(&BB)) + Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0); + if (Weights.size() >= 2) + BB.getTerminator()->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(Weights)); + } + + unsigned NumCorruptCoverage = 0; + DominatorTree DT(F); + LoopInfo LI(DT); + BranchProbabilityInfo BPI(F, LI); + BlockFrequencyInfo BFI(F, BPI, LI); + auto IsBlockDead = [&](const BasicBlock &BB) -> std::optional<bool> { + if (auto C = BFI.getBlockProfileCount(&BB)) + return C == 0; + return {}; + }; + LLVM_DEBUG(dbgs() << "Block Coverage: (Instrumented=*, Covered=X)\n"); + for (auto &BB : F) { + LLVM_DEBUG(dbgs() << (FuncInfo.BCI->shouldInstrumentBlock(BB) ? "* " : " ") + << (Coverage[&BB] ? "X " : " ") << " " << BB.getName() + << "\n"); + // In some cases it is possible to find a covered block that has no covered + // successors, e.g., when a block calls a function that may call exit(). In + // those cases, BFI could find its successor to be covered while BCI could + // find its successor to be dead. + if (Coverage[&BB] == IsBlockDead(BB).value_or(false)) { + LLVM_DEBUG( + dbgs() << "Found inconsistent block covearge for " << BB.getName() + << ": BCI=" << (Coverage[&BB] ? "Covered" : "Dead") << " BFI=" + << (IsBlockDead(BB).value() ? "Dead" : "Covered") << "\n"); + ++NumCorruptCoverage; + } + if (Coverage[&BB]) + ++NumCoveredBlocks; + } + if (PGOVerifyBFI && NumCorruptCoverage) { + auto &Ctx = M->getContext(); + Ctx.diagnose(DiagnosticInfoPGOProfile( + M->getName().data(), + Twine("Found inconsistent block coverage for function ") + F.getName() + + " in " + Twine(NumCorruptCoverage) + " blocks.", + DS_Warning)); + } + if (PGOViewBlockCoverageGraph) + FuncInfo.BCI->viewBlockCoverageGraph(&Coverage); +} + // Populate the counters from instrumented BBs to all BBs. // In the end of this operation, all BBs should have a valid count value. void PGOUseFunc::populateCounters() { @@ -1590,7 +1497,7 @@ void PGOUseFunc::populateCounters() { // For efficient traversal, it's better to start from the end as most // of the instrumented edges are at the end. for (auto &BB : reverse(F)) { - UseBBInfo *Count = findBBInfo(&BB); + PGOUseBBInfo *Count = findBBInfo(&BB); if (Count == nullptr) continue; if (!Count->CountValid) { @@ -1629,7 +1536,7 @@ void PGOUseFunc::populateCounters() { } LLVM_DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n"); - (void) NumPasses; + (void)NumPasses; #ifndef NDEBUG // Assert every BB has a valid counter. for (auto &BB : F) { @@ -1655,7 +1562,7 @@ void PGOUseFunc::populateCounters() { markFunctionAttributes(FuncEntryCount, FuncMaxCount); // Now annotate select instructions - FuncInfo.SIVisitor.annotateSelects(F, this, &CountPosition); + FuncInfo.SIVisitor.annotateSelects(this, &CountPosition); assert(CountPosition == ProfileCountSize); LLVM_DEBUG(FuncInfo.dumpInfo("after reading profile.")); @@ -1679,7 +1586,7 @@ void PGOUseFunc::setBranchWeights() { continue; // We have a non-zero Branch BB. 
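The coverage inference in populateCoverage above can be read independently of the LLVM types: the instrumented blocks seed a worklist, and every block whose coverage is implied by an already-covered block is marked in turn. A reduced, self-contained sketch of that flood fill, with blocks shrunk to strings:

#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// "InverseDeps[A] contains B" means: if A is covered then B is covered too.
using CoverageMap = std::unordered_map<std::string, bool>;
using InverseDeps =
    std::unordered_map<std::string, std::unordered_set<std::string>>;

void floodFillCoverage(CoverageMap &Coverage, const InverseDeps &Deps) {
  std::stack<std::string> Worklist;
  for (const auto &[Block, IsCovered] : Coverage)
    if (IsCovered)
      Worklist.push(Block);
  while (!Worklist.empty()) {
    std::string Covered = Worklist.top();
    Worklist.pop();
    auto It = Deps.find(Covered);
    if (It == Deps.end())
      continue;
    for (const std::string &Implied : It->second) {
      bool &Flag = Coverage[Implied];
      if (!Flag) {
        Flag = true;
        Worklist.push(Implied); // newly covered blocks propagate further
      }
    }
  }
}

Each block is pushed at most once, so the fill runs in time linear in the number of blocks plus dependency edges.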
- const UseBBInfo &BBCountInfo = getBBInfo(&BB); + const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB); unsigned Size = BBCountInfo.OutEdges.size(); SmallVector<uint64_t, 2> EdgeCounts(Size, 0); uint64_t MaxCount = 0; @@ -1704,11 +1611,11 @@ void PGOUseFunc::setBranchWeights() { // when there is no exit block and the code exits via a noreturn function. auto &Ctx = M->getContext(); Ctx.diagnose(DiagnosticInfoPGOProfile( - M->getName().data(), - Twine("Profile in ") + F.getName().str() + - Twine(" partially ignored") + - Twine(", possibly due to the lack of a return path."), - DS_Warning)); + M->getName().data(), + Twine("Profile in ") + F.getName().str() + + Twine(" partially ignored") + + Twine(", possibly due to the lack of a return path."), + DS_Warning)); } } } @@ -1730,15 +1637,13 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() { // duplication. if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) { Instruction *TI = BB.getTerminator(); - const UseBBInfo &BBCountInfo = getBBInfo(&BB); + const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB); setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue); } } } void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { - if (PGOFunctionEntryCoverage) - return; Module *M = F.getParent(); IRBuilder<> Builder(&SI); Type *Int64Ty = Builder.getInt64Ty(); @@ -1771,7 +1676,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) { } void SelectInstVisitor::visitSelectInst(SelectInst &SI) { - if (!PGOInstrSelect) + if (!PGOInstrSelect || PGOFunctionEntryCoverage || HasSingleByteCoverage) return; // FIXME: do not handle this yet. if (SI.getCondition()->getType()->isVectorTy()) @@ -1815,8 +1720,8 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) { Ctx.diagnose(DiagnosticInfoPGOProfile( M->getName().data(), Twine("Inconsistent number of value sites for ") + - Twine(ValueProfKindDescr[Kind]) + - Twine(" profiling in \"") + F.getName().str() + + Twine(ValueProfKindDescr[Kind]) + Twine(" profiling in \"") + + F.getName().str() + Twine("\", possibly due to the use of a stale profile."), DS_Warning)); return; @@ -1907,17 +1812,20 @@ static bool InstrumentAllFunctions( } PreservedAnalyses -PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &AM) { +PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &MAM) { createProfileFileNameVar(M, CSInstrName); // The variable in a comdat may be discarded by LTO. Ensure the declaration // will be retained. 
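setBranchWeights above collects 64-bit out-edge counts and their maximum before handing them to setProfMetadata, and MD_prof branch weights are 32-bit, so the counts have to be downscaled by a common factor derived from the maximum. The sketch below mirrors that intent only; it is not the calculateCountScale/scaleBranchCount implementation.

#include <cstdint>
#include <vector>

std::vector<uint32_t> scaleToWeights(const std::vector<uint64_t> &EdgeCounts,
                                     uint64_t MaxCount) {
  // Choose a scale so that MaxCount / Scale is guaranteed to fit in 32 bits.
  uint64_t Scale = MaxCount / UINT32_MAX + 1;
  std::vector<uint32_t> Weights;
  Weights.reserve(EdgeCounts.size());
  for (uint64_t C : EdgeCounts)
    Weights.push_back(static_cast<uint32_t>(C / Scale));
  return Weights;
}

Dividing every count by the same scale keeps the ratios between successors, which is all the later consumers of branch weights care about.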
appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true)); - return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<FunctionAnalysisManagerModuleProxy>(); + PA.preserveSet<AllAnalysesOn<Function>>(); + return PA; } PreservedAnalyses PGOInstrumentationGen::run(Module &M, - ModuleAnalysisManager &AM) { - auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + ModuleAnalysisManager &MAM) { + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupTLI = [&FAM](Function &F) -> TargetLibraryInfo & { return FAM.getResult<TargetLibraryAnalysis>(F); }; @@ -1991,7 +1899,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI, BlockFrequencyInfo NBFI(F, NBPI, LI); // bool PrintFunc = false; bool HotBBOnly = PGOVerifyHotBFI; - std::string Msg; + StringRef Msg; OptimizationRemarkEmitter ORE(&F); unsigned BBNum = 0, BBMisMatchNum = 0, NonZeroBBNum = 0; @@ -2059,6 +1967,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI, static bool annotateAllFunctions( Module &M, StringRef ProfileFileName, StringRef ProfileRemappingFileName, + vfs::FileSystem &FS, function_ref<TargetLibraryInfo &(Function &)> LookupTLI, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, @@ -2066,8 +1975,8 @@ static bool annotateAllFunctions( LLVM_DEBUG(dbgs() << "Read in profile counters: "); auto &Ctx = M.getContext(); // Read the counter array from file. - auto ReaderOrErr = - IndexedInstrProfReader::create(ProfileFileName, ProfileRemappingFileName); + auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName, FS, + ProfileRemappingFileName); if (Error E = ReaderOrErr.takeError()) { handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) { Ctx.diagnose( @@ -2087,17 +1996,11 @@ static bool annotateAllFunctions( return false; // TODO: might need to change the warning once the clang option is finalized. - if (!PGOReader->isIRLevelProfile() && !PGOReader->hasMemoryProfile()) { + if (!PGOReader->isIRLevelProfile()) { Ctx.diagnose(DiagnosticInfoPGOProfile( ProfileFileName.data(), "Not an IR level instrumentation profile")); return false; } - if (PGOReader->hasSingleByteCoverage()) { - Ctx.diagnose(DiagnosticInfoPGOProfile( - ProfileFileName.data(), - "Cannot use coverage profiles for optimization")); - return false; - } if (PGOReader->functionEntryOnly()) { Ctx.diagnose(DiagnosticInfoPGOProfile( ProfileFileName.data(), @@ -2123,25 +2026,25 @@ static bool annotateAllFunctions( bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled(); if (PGOInstrumentEntry.getNumOccurrences() > 0) InstrumentFuncEntry = PGOInstrumentEntry; + bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage(); for (auto &F : M) { if (skipPGO(F)) continue; auto &TLI = LookupTLI(F); auto *BPI = LookupBPI(F); auto *BFI = LookupBFI(F); - // Split indirectbr critical edges here before computing the MST rather than - // later in getInstrBB() to avoid invalidating it. - SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI); + if (!HasSingleByteCoverage) { + // Split indirectbr critical edges here before computing the MST rather + // than later in getInstrBB() to avoid invalidating it. 
+ SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, + BFI); + } PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, PSI, IsCS, - InstrumentFuncEntry); - // Read and match memprof first since we do this via debug info and can - // match even if there is an IR mismatch detected for regular PGO below. - if (PGOReader->hasMemoryProfile()) - Func.readMemprof(PGOReader.get()); - - if (!PGOReader->isIRLevelProfile()) + InstrumentFuncEntry, HasSingleByteCoverage); + if (HasSingleByteCoverage) { + Func.populateCoverage(PGOReader.get()); continue; - + } // When PseudoKind is set to a vaule other than InstrProfRecord::NotPseudo, // it means the profile for the function is unrepresentative and this // function is actually hot / warm. We will reset the function hot / cold @@ -2249,21 +2152,24 @@ static bool annotateAllFunctions( return true; } -PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename, - std::string RemappingFilename, - bool IsCS) +PGOInstrumentationUse::PGOInstrumentationUse( + std::string Filename, std::string RemappingFilename, bool IsCS, + IntrusiveRefCntPtr<vfs::FileSystem> VFS) : ProfileFileName(std::move(Filename)), - ProfileRemappingFileName(std::move(RemappingFilename)), IsCS(IsCS) { + ProfileRemappingFileName(std::move(RemappingFilename)), IsCS(IsCS), + FS(std::move(VFS)) { if (!PGOTestProfileFile.empty()) ProfileFileName = PGOTestProfileFile; if (!PGOTestProfileRemappingFile.empty()) ProfileRemappingFileName = PGOTestProfileRemappingFile; + if (!FS) + FS = vfs::getRealFileSystem(); } PreservedAnalyses PGOInstrumentationUse::run(Module &M, - ModuleAnalysisManager &AM) { + ModuleAnalysisManager &MAM) { - auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupTLI = [&FAM](Function &F) -> TargetLibraryInfo & { return FAM.getResult<TargetLibraryAnalysis>(F); }; @@ -2274,9 +2180,9 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M, return &FAM.getResult<BlockFrequencyAnalysis>(F); }; - auto *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); + auto *PSI = &MAM.getResult<ProfileSummaryAnalysis>(M); - if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, + if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, *FS, LookupTLI, LookupBPI, LookupBFI, PSI, IsCS)) return PreservedAnalyses::all(); @@ -2285,7 +2191,7 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M, static std::string getSimpleNodeName(const BasicBlock *Node) { if (!Node->getName().empty()) - return std::string(Node->getName()); + return Node->getName().str(); std::string SimpleNodeName; raw_string_ostream OS(SimpleNodeName); @@ -2294,8 +2200,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) { } void llvm::setProfMetadata(Module *M, Instruction *TI, - ArrayRef<uint64_t> EdgeCounts, - uint64_t MaxCount) { + ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) { MDBuilder MDB(M->getContext()); assert(MaxCount > 0 && "Bad max count"); uint64_t Scale = calculateCountScale(MaxCount); @@ -2384,7 +2289,7 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits { raw_string_ostream OS(Result); OS << getSimpleNodeName(Node) << ":\\l"; - UseBBInfo *BI = Graph->findBBInfo(Node); + PGOUseBBInfo *BI = Graph->findBBInfo(Node); OS << "Count : "; if (BI && BI->CountValid) OS << BI->CountValue << "\\l"; diff --git a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp 
b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index 35db8483fc91..2906fe190984 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -317,7 +317,7 @@ bool MemOPSizeOpt::perform(MemOp MO) { } if (!SeenSizeId.insert(V).second) { - errs() << "Invalid Profile Data in Function " << Func.getName() + errs() << "warning: Invalid Profile Data in Function " << Func.getName() << ": Two identical values in MemOp value counts.\n"; return false; } diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp index 142b9c38e5fc..d83a3a991c89 100644 --- a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp +++ b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp @@ -15,8 +15,9 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -31,15 +32,19 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Support/SpecialCaseList.h" +#include "llvm/Support/StringSaver.h" +#include "llvm/Support/VirtualFileSystem.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <array> #include <cstdint> +#include <memory> using namespace llvm; @@ -49,7 +54,7 @@ namespace { //===--- Constants --------------------------------------------------------===// -constexpr uint32_t kVersionBase = 1; // occupies lower 16 bits +constexpr uint32_t kVersionBase = 2; // occupies lower 16 bits constexpr uint32_t kVersionPtrSizeRel = (1u << 16); // offsets are pointer-sized constexpr int kCtorDtorPriority = 2; @@ -59,7 +64,6 @@ class MetadataInfo { public: const StringRef FunctionPrefix; const StringRef SectionSuffix; - const uint32_t FeatureMask; static const MetadataInfo Covered; static const MetadataInfo Atomics; @@ -67,16 +71,13 @@ public: private: // Forbid construction elsewhere. explicit constexpr MetadataInfo(StringRef FunctionPrefix, - StringRef SectionSuffix, uint32_t Feature) - : FunctionPrefix(FunctionPrefix), SectionSuffix(SectionSuffix), - FeatureMask(Feature) {} + StringRef SectionSuffix) + : FunctionPrefix(FunctionPrefix), SectionSuffix(SectionSuffix) {} }; -const MetadataInfo MetadataInfo::Covered{"__sanitizer_metadata_covered", - kSanitizerBinaryMetadataCoveredSection, - kSanitizerBinaryMetadataNone}; -const MetadataInfo MetadataInfo::Atomics{"__sanitizer_metadata_atomics", - kSanitizerBinaryMetadataAtomicsSection, - kSanitizerBinaryMetadataAtomics}; +const MetadataInfo MetadataInfo::Covered{ + "__sanitizer_metadata_covered", kSanitizerBinaryMetadataCoveredSection}; +const MetadataInfo MetadataInfo::Atomics{ + "__sanitizer_metadata_atomics", kSanitizerBinaryMetadataAtomicsSection}; // The only instances of MetadataInfo are the constants above, so a set of // them may simply store pointers to them. 
To deterministically generate code, @@ -89,6 +90,11 @@ cl::opt<bool> ClWeakCallbacks( "sanitizer-metadata-weak-callbacks", cl::desc("Declare callbacks extern weak, and only call if non-null."), cl::Hidden, cl::init(true)); +cl::opt<bool> + ClNoSanitize("sanitizer-metadata-nosanitize-attr", + cl::desc("Mark some metadata features uncovered in functions " + "with associated no_sanitize attributes."), + cl::Hidden, cl::init(true)); cl::opt<bool> ClEmitCovered("sanitizer-metadata-covered", cl::desc("Emit PCs for covered functions."), @@ -120,24 +126,20 @@ transformOptionsFromCl(SanitizerBinaryMetadataOptions &&Opts) { class SanitizerBinaryMetadata { public: - SanitizerBinaryMetadata(Module &M, SanitizerBinaryMetadataOptions Opts) + SanitizerBinaryMetadata(Module &M, SanitizerBinaryMetadataOptions Opts, + std::unique_ptr<SpecialCaseList> Ignorelist) : Mod(M), Options(transformOptionsFromCl(std::move(Opts))), - TargetTriple(M.getTargetTriple()), IRB(M.getContext()) { + Ignorelist(std::move(Ignorelist)), TargetTriple(M.getTargetTriple()), + IRB(M.getContext()) { // FIXME: Make it work with other formats. assert(TargetTriple.isOSBinFormatELF() && "ELF only"); + assert(!(TargetTriple.isNVPTX() || TargetTriple.isAMDGPU()) && + "Device targets are not supported"); } bool run(); private: - // Return enabled feature mask of per-instruction metadata. - uint32_t getEnabledPerInstructionFeature() const { - uint32_t FeatureMask = 0; - if (Options.Atomics) - FeatureMask |= MetadataInfo::Atomics.FeatureMask; - return FeatureMask; - } - uint32_t getVersion() const { uint32_t Version = kVersionBase; const auto CM = Mod.getCodeModel(); @@ -156,7 +158,7 @@ private: // to determine if a memory operation is atomic or not in modules compiled // with SanitizerBinaryMetadata. bool runOn(Instruction &I, MetadataInfoSet &MIS, MDBuilder &MDB, - uint32_t &FeatureMask); + uint64_t &FeatureMask); // Get start/end section marker pointer. GlobalVariable *getSectionMarker(const Twine &MarkerName, Type *Ty); @@ -170,10 +172,16 @@ private: // Returns the section end marker name. Twine getSectionEnd(StringRef SectionSuffix); + // Returns true if the access to the address should be considered "atomic". + bool pretendAtomicAccess(const Value *Addr); + Module &Mod; const SanitizerBinaryMetadataOptions Options; + std::unique_ptr<SpecialCaseList> Ignorelist; const Triple TargetTriple; IRBuilder<> IRB; + BumpPtrAllocator Alloc; + UniqueStringSaver StringPool{Alloc}; }; bool SanitizerBinaryMetadata::run() { @@ -218,17 +226,23 @@ bool SanitizerBinaryMetadata::run() { (MI->FunctionPrefix + "_del").str(), InitTypes, InitArgs, /*VersionCheckName=*/StringRef(), /*Weak=*/ClWeakCallbacks) .first; - Constant *CtorData = nullptr; - Constant *DtorData = nullptr; + Constant *CtorComdatKey = nullptr; + Constant *DtorComdatKey = nullptr; if (TargetTriple.supportsCOMDAT()) { - // Use COMDAT to deduplicate constructor/destructor function. + // Use COMDAT to deduplicate constructor/destructor function. The COMDAT + // key needs to be a non-local linkage. Ctor->setComdat(Mod.getOrInsertComdat(Ctor->getName())); Dtor->setComdat(Mod.getOrInsertComdat(Dtor->getName())); - CtorData = Ctor; - DtorData = Dtor; + Ctor->setLinkage(GlobalValue::ExternalLinkage); + Dtor->setLinkage(GlobalValue::ExternalLinkage); + // DSOs should _not_ call another constructor/destructor! 
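The version word used by this pass, per the constants above, packs a base version into the low 16 bits and feature bits such as kVersionPtrSizeRel into the high bits. A small sketch of that composition; the decision that drives the flag (taken from the module's code model in the pass) is simplified here to a plain bool:

#include <cstdint>

constexpr uint32_t kBase = 2;              // mirrors kVersionBase
constexpr uint32_t kPtrSizeRel = 1u << 16; // mirrors kVersionPtrSizeRel

uint32_t makeVersion(bool PtrSizedOffsets) {
  uint32_t Version = kBase;
  if (PtrSizedOffsets) // decided from the code model in the real getVersion()
    Version |= kPtrSizeRel;
  return Version;
}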
+ Ctor->setVisibility(GlobalValue::HiddenVisibility); + Dtor->setVisibility(GlobalValue::HiddenVisibility); + CtorComdatKey = Ctor; + DtorComdatKey = Dtor; } - appendToGlobalCtors(Mod, Ctor, kCtorDtorPriority, CtorData); - appendToGlobalDtors(Mod, Dtor, kCtorDtorPriority, DtorData); + appendToGlobalCtors(Mod, Ctor, kCtorDtorPriority, CtorComdatKey); + appendToGlobalDtors(Mod, Dtor, kCtorDtorPriority, DtorComdatKey); } return true; @@ -239,6 +253,8 @@ void SanitizerBinaryMetadata::runOn(Function &F, MetadataInfoSet &MIS) { return; if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) return; + if (Ignorelist && Ignorelist->inSection("metadata", "fun", F.getName())) + return; // Don't touch available_externally functions, their actual body is elsewhere. if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return; @@ -247,18 +263,18 @@ void SanitizerBinaryMetadata::runOn(Function &F, MetadataInfoSet &MIS) { // The metadata features enabled for this function, stored along covered // metadata (if enabled). - uint32_t FeatureMask = getEnabledPerInstructionFeature(); + uint64_t FeatureMask = 0; // Don't emit unnecessary covered metadata for all functions to save space. bool RequiresCovered = false; - // We can only understand if we need to set UAR feature after looking - // at the instructions. So we need to check instructions even if FeatureMask - // is empty. - if (FeatureMask || Options.UAR) { + + if (Options.Atomics || Options.UAR) { for (BasicBlock &BB : F) for (Instruction &I : BB) RequiresCovered |= runOn(I, MIS, MDB, FeatureMask); } + if (ClNoSanitize && F.hasFnAttribute("no_sanitize_thread")) + FeatureMask &= ~kSanitizerBinaryMetadataAtomics; if (F.isVarArg()) FeatureMask &= ~kSanitizerBinaryMetadataUAR; if (FeatureMask & kSanitizerBinaryMetadataUAR) { @@ -274,9 +290,8 @@ void SanitizerBinaryMetadata::runOn(Function &F, MetadataInfoSet &MIS) { const auto *MI = &MetadataInfo::Covered; MIS.insert(MI); const StringRef Section = getSectionName(MI->SectionSuffix); - // The feature mask will be placed after the size (32 bit) of the function, - // so in total one covered entry will use `sizeof(void*) + 4 + 4`. - Constant *CFM = IRB.getInt32(FeatureMask); + // The feature mask will be placed after the function size. + Constant *CFM = IRB.getInt64(FeatureMask); F.setMetadata(LLVMContext::MD_pcsections, MDB.createPCSections({{Section, {CFM}}})); } @@ -338,23 +353,80 @@ bool useAfterReturnUnsafe(Instruction &I) { return false; } +bool SanitizerBinaryMetadata::pretendAtomicAccess(const Value *Addr) { + if (!Addr) + return false; + + Addr = Addr->stripInBoundsOffsets(); + auto *GV = dyn_cast<GlobalVariable>(Addr); + if (!GV) + return false; + + // Some compiler-generated accesses are known racy, to avoid false positives + // in data-race analysis pretend they're atomic. + if (GV->hasSection()) { + const auto OF = Triple(Mod.getTargetTriple()).getObjectFormat(); + const auto ProfSec = + getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false); + if (GV->getSection().endswith(ProfSec)) + return true; + } + if (GV->getName().startswith("__llvm_gcov") || + GV->getName().startswith("__llvm_gcda")) + return true; + + return false; +} + +// Returns true if the memory at `Addr` may be shared with other threads. +bool maybeSharedMutable(const Value *Addr) { + // By default assume memory may be shared. 
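The per-function feature mask handled in runOn above is assembled from the per-instruction scan and then trimmed by function-level properties: a no_sanitize_thread attribute drops the atomics bit, and vararg functions drop the use-after-return bit. A toy model of that set-then-clear logic — the bit values are placeholders, not the real kSanitizerBinaryMetadata* constants:

#include <cstdint>

enum : uint64_t {
  FeatureAtomics = 1u << 0, // placeholder bit
  FeatureUAR = 1u << 1,     // placeholder bit
};

uint64_t finalizeFeatureMask(uint64_t Mask, bool HasNoSanitizeThreadAttr,
                             bool IsVarArg) {
  if (HasNoSanitizeThreadAttr)
    Mask &= ~FeatureAtomics; // function opted out of thread sanitization
  if (IsVarArg)
    Mask &= ~FeatureUAR;     // vararg functions are skipped for UAR
  return Mask;
}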
+ if (!Addr) + return true; + + if (isa<AllocaInst>(getUnderlyingObject(Addr)) && + !PointerMayBeCaptured(Addr, true, true)) + return false; // Object is on stack but does not escape. + + Addr = Addr->stripInBoundsOffsets(); + if (auto *GV = dyn_cast<GlobalVariable>(Addr)) { + if (GV->isConstant()) + return false; // Shared, but not mutable. + } + + return true; +} + bool SanitizerBinaryMetadata::runOn(Instruction &I, MetadataInfoSet &MIS, - MDBuilder &MDB, uint32_t &FeatureMask) { + MDBuilder &MDB, uint64_t &FeatureMask) { SmallVector<const MetadataInfo *, 1> InstMetadata; bool RequiresCovered = false; + // Only call if at least 1 type of metadata is requested. + assert(Options.UAR || Options.Atomics); + if (Options.UAR && !(FeatureMask & kSanitizerBinaryMetadataUAR)) { if (useAfterReturnUnsafe(I)) FeatureMask |= kSanitizerBinaryMetadataUAR; } - if (Options.Atomics && I.mayReadOrWriteMemory()) { - auto SSID = getAtomicSyncScopeID(&I); - if (SSID.has_value() && *SSID != SyncScope::SingleThread) { - NumMetadataAtomics++; - InstMetadata.push_back(&MetadataInfo::Atomics); + if (Options.Atomics) { + const Value *Addr = nullptr; + if (auto *SI = dyn_cast<StoreInst>(&I)) + Addr = SI->getPointerOperand(); + else if (auto *LI = dyn_cast<LoadInst>(&I)) + Addr = LI->getPointerOperand(); + + if (I.mayReadOrWriteMemory() && maybeSharedMutable(Addr)) { + auto SSID = getAtomicSyncScopeID(&I); + if ((SSID.has_value() && *SSID != SyncScope::SingleThread) || + pretendAtomicAccess(Addr)) { + NumMetadataAtomics++; + InstMetadata.push_back(&MetadataInfo::Atomics); + } + FeatureMask |= kSanitizerBinaryMetadataAtomics; + RequiresCovered = true; } - RequiresCovered = true; } // Attach MD_pcsections to instruction. @@ -381,8 +453,9 @@ SanitizerBinaryMetadata::getSectionMarker(const Twine &MarkerName, Type *Ty) { } StringRef SanitizerBinaryMetadata::getSectionName(StringRef SectionSuffix) { - // FIXME: Other TargetTriple (req. string pool) - return SectionSuffix; + // FIXME: Other TargetTriples. + // Request ULEB128 encoding for all integer constants. 
+ return StringPool.save(SectionSuffix + "!C"); } Twine SanitizerBinaryMetadata::getSectionStart(StringRef SectionSuffix) { @@ -396,12 +469,20 @@ Twine SanitizerBinaryMetadata::getSectionEnd(StringRef SectionSuffix) { } // namespace SanitizerBinaryMetadataPass::SanitizerBinaryMetadataPass( - SanitizerBinaryMetadataOptions Opts) - : Options(std::move(Opts)) {} + SanitizerBinaryMetadataOptions Opts, ArrayRef<std::string> IgnorelistFiles) + : Options(std::move(Opts)), IgnorelistFiles(std::move(IgnorelistFiles)) {} PreservedAnalyses SanitizerBinaryMetadataPass::run(Module &M, AnalysisManager<Module> &AM) { - SanitizerBinaryMetadata Pass(M, Options); + std::unique_ptr<SpecialCaseList> Ignorelist; + if (!IgnorelistFiles.empty()) { + Ignorelist = SpecialCaseList::createOrDie(IgnorelistFiles, + *vfs::getRealFileSystem()); + if (Ignorelist->inSection("metadata", "src", M.getSourceFileName())) + return PreservedAnalyses::all(); + } + + SanitizerBinaryMetadata Pass(M, Options, std::move(Ignorelist)); if (Pass.run()) return PreservedAnalyses::none(); return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 23a88c3cfba2..f22918141f6e 100644 --- a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -13,13 +13,12 @@ #include "llvm/Transforms/Instrumentation/SanitizerCoverage.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Triple.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" @@ -28,11 +27,10 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/SpecialCaseList.h" #include "llvm/Support/VirtualFileSystem.h" -#include "llvm/Transforms/Instrumentation.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -250,10 +248,6 @@ private: std::pair<Value *, Value *> CreateSecStartEnd(Module &M, const char *Section, Type *Ty); - void SetNoSanitizeMetadata(Instruction *I) { - I->setMetadata(LLVMContext::MD_nosanitize, MDNode::get(*C, std::nullopt)); - } - std::string getSectionName(const std::string &Section) const; std::string getSectionStart(const std::string &Section) const; std::string getSectionEnd(const std::string &Section) const; @@ -809,7 +803,7 @@ void ModuleSanitizerCoverage::InjectCoverageForIndirectCalls( assert(Options.TracePC || Options.TracePCGuard || Options.Inline8bitCounters || Options.InlineBoolFlag); for (auto *I : IndirCalls) { - IRBuilder<> IRB(I); + InstrumentationIRBuilder IRB(I); CallBase &CB = cast<CallBase>(*I); Value *Callee = CB.getCalledOperand(); if (isa<InlineAsm>(Callee)) @@ -826,7 +820,7 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch( Function &, ArrayRef<Instruction *> SwitchTraceTargets) { for (auto *I : SwitchTraceTargets) { if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { - IRBuilder<> IRB(I); + InstrumentationIRBuilder IRB(I); SmallVector<Constant *, 16> Initializers; Value *Cond = SI->getCondition(); if 
(Cond->getType()->getScalarSizeInBits() > @@ -864,7 +858,7 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch( void ModuleSanitizerCoverage::InjectTraceForDiv( Function &, ArrayRef<BinaryOperator *> DivTraceTargets) { for (auto *BO : DivTraceTargets) { - IRBuilder<> IRB(BO); + InstrumentationIRBuilder IRB(BO); Value *A1 = BO->getOperand(1); if (isa<ConstantInt>(A1)) continue; if (!A1->getType()->isIntegerTy()) @@ -882,7 +876,7 @@ void ModuleSanitizerCoverage::InjectTraceForDiv( void ModuleSanitizerCoverage::InjectTraceForGep( Function &, ArrayRef<GetElementPtrInst *> GepTraceTargets) { for (auto *GEP : GepTraceTargets) { - IRBuilder<> IRB(GEP); + InstrumentationIRBuilder IRB(GEP); for (Use &Idx : GEP->indices()) if (!isa<ConstantInt>(Idx) && Idx->getType()->isIntegerTy()) IRB.CreateCall(SanCovTraceGepFunction, @@ -904,7 +898,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( Type *PointerType[5] = {Int8PtrTy, Int16PtrTy, Int32PtrTy, Int64PtrTy, Int128PtrTy}; for (auto *LI : Loads) { - IRBuilder<> IRB(LI); + InstrumentationIRBuilder IRB(LI); auto Ptr = LI->getPointerOperand(); int Idx = CallbackIdx(LI->getType()); if (Idx < 0) @@ -913,7 +907,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( IRB.CreatePointerCast(Ptr, PointerType[Idx])); } for (auto *SI : Stores) { - IRBuilder<> IRB(SI); + InstrumentationIRBuilder IRB(SI); auto Ptr = SI->getPointerOperand(); int Idx = CallbackIdx(SI->getValueOperand()->getType()); if (Idx < 0) @@ -927,7 +921,7 @@ void ModuleSanitizerCoverage::InjectTraceForCmp( Function &, ArrayRef<Instruction *> CmpTraceTargets) { for (auto *I : CmpTraceTargets) { if (ICmpInst *ICMP = dyn_cast<ICmpInst>(I)) { - IRBuilder<> IRB(ICMP); + InstrumentationIRBuilder IRB(ICMP); Value *A0 = ICMP->getOperand(0); Value *A1 = ICMP->getOperand(1); if (!A0->getType()->isIntegerTy()) @@ -994,8 +988,8 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, auto Load = IRB.CreateLoad(Int8Ty, CounterPtr); auto Inc = IRB.CreateAdd(Load, ConstantInt::get(Int8Ty, 1)); auto Store = IRB.CreateStore(Inc, CounterPtr); - SetNoSanitizeMetadata(Load); - SetNoSanitizeMetadata(Store); + Load->setNoSanitizeMetadata(); + Store->setNoSanitizeMetadata(); } if (Options.InlineBoolFlag) { auto FlagPtr = IRB.CreateGEP( @@ -1006,8 +1000,8 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, SplitBlockAndInsertIfThen(IRB.CreateIsNull(Load), &*IP, false); IRBuilder<> ThenIRB(ThenTerm); auto Store = ThenIRB.CreateStore(ConstantInt::getTrue(Int1Ty), FlagPtr); - SetNoSanitizeMetadata(Load); - SetNoSanitizeMetadata(Store); + Load->setNoSanitizeMetadata(); + Store->setNoSanitizeMetadata(); } if (Options.StackDepth && IsEntryBB && !IsLeafFunc) { // Check stack depth. If it's the deepest so far, record it. 
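The stack-depth instrumentation referenced above (emitted only in the entry block of non-leaf functions) boils down to comparing the current frame address against the lowest one seen so far. Rendered as plain C++ rather than emitted IR, and with the thread-local defined locally only so the sketch is self-contained — in reality it lives in the sanitizer runtime, referenced by the pass as SanCovLowestStack:

#include <cstdint>
#include <limits>

thread_local uintptr_t SketchLowestStack =
    std::numeric_limits<uintptr_t>::max();

inline void recordStackDepth() {
  uintptr_t Frame = reinterpret_cast<uintptr_t>(__builtin_frame_address(0));
  if (Frame < SketchLowestStack) // stacks grow downwards on common targets
    SketchLowestStack = Frame;   // record the deepest frame seen so far
}

The loads and stores the pass emits for this get no-sanitize metadata, as shown in the surrounding hunks, so the bookkeeping itself is never reported as a race.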
@@ -1023,8 +1017,8 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, auto ThenTerm = SplitBlockAndInsertIfThen(IsStackLower, &*IP, false); IRBuilder<> ThenIRB(ThenTerm); auto Store = ThenIRB.CreateStore(FrameAddrInt, SanCovLowestStack); - SetNoSanitizeMetadata(LowestStack); - SetNoSanitizeMetadata(Store); + LowestStack->setNoSanitizeMetadata(); + Store->setNoSanitizeMetadata(); } } diff --git a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index a127e81ce643..ce35eefb63fa 100644 --- a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -689,7 +689,7 @@ static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) { // replaced back with intrinsics. If that becomes wrong at some point, // we will need to call e.g. __tsan_memset to avoid the intrinsics. bool ThreadSanitizer::instrumentMemIntrinsic(Instruction *I) { - IRBuilder<> IRB(I); + InstrumentationIRBuilder IRB(I); if (MemSetInst *M = dyn_cast<MemSetInst>(I)) { IRB.CreateCall( MemsetFn, @@ -813,8 +813,6 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { int ThreadSanitizer::getMemoryAccessFuncIndex(Type *OrigTy, Value *Addr, const DataLayout &DL) { assert(OrigTy->isSized()); - assert( - cast<PointerType>(Addr->getType())->isOpaqueOrPointeeTypeMatches(OrigTy)); uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); if (TypeSize != 8 && TypeSize != 16 && TypeSize != 32 && TypeSize != 64 && TypeSize != 128) { @@ -822,7 +820,7 @@ int ThreadSanitizer::getMemoryAccessFuncIndex(Type *OrigTy, Value *Addr, // Ignore all unusual sizes. return -1; } - size_t Idx = countTrailingZeros(TypeSize / 8); + size_t Idx = llvm::countr_zero(TypeSize / 8); assert(Idx < kNumberOfAccessSizes); return Idx; } diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARC.h b/llvm/lib/Transforms/ObjCARC/ObjCARC.h index d4570ff908f1..9e68bd574851 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARC.h +++ b/llvm/lib/Transforms/ObjCARC/ObjCARC.h @@ -22,9 +22,9 @@ #ifndef LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H #define LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/Analysis/ObjCARCUtil.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/Transforms/Utils/Local.h" namespace llvm { diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp index ab90ef090ae0..c397ab63f388 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -31,9 +31,9 @@ #include "ProvenanceAnalysis.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/ObjCARCUtil.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Operator.h" diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index a374958f9707..adf86526ebf1 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -36,7 +36,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/ObjCARCAliasAnalysis.h" #include 
"llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/Analysis/ObjCARCInstKind.h" @@ -46,6 +45,7 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstIterator.h" @@ -933,8 +933,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, if (IsNullOrUndef(CI->getArgOperand(0))) { Changed = true; new StoreInst(ConstantInt::getTrue(CI->getContext()), - UndefValue::get(Type::getInt1PtrTy(CI->getContext())), CI); - Value *NewValue = UndefValue::get(CI->getType()); + PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI); + Value *NewValue = PoisonValue::get(CI->getType()); LLVM_DEBUG( dbgs() << "A null pointer-to-weak-pointer is undefined behavior." "\nOld = " @@ -952,9 +952,9 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, IsNullOrUndef(CI->getArgOperand(1))) { Changed = true; new StoreInst(ConstantInt::getTrue(CI->getContext()), - UndefValue::get(Type::getInt1PtrTy(CI->getContext())), CI); + PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI); - Value *NewValue = UndefValue::get(CI->getType()); + Value *NewValue = PoisonValue::get(CI->getType()); LLVM_DEBUG( dbgs() << "A null pointer-to-weak-pointer is undefined behavior." "\nOld = " diff --git a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp index 2fa25a79ae9d..23855231c5b9 100644 --- a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp +++ b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp @@ -42,40 +42,21 @@ bool ProvenanceAnalysis::relatedSelect(const SelectInst *A, const Value *B) { // If the values are Selects with the same condition, we can do a more precise // check: just check for relations between the values on corresponding arms. - if (const SelectInst *SB = dyn_cast<SelectInst>(B)) { + if (const SelectInst *SB = dyn_cast<SelectInst>(B)) if (A->getCondition() == SB->getCondition()) return related(A->getTrueValue(), SB->getTrueValue()) || related(A->getFalseValue(), SB->getFalseValue()); - // Check both arms of B individually. Return false if neither arm is related - // to A. - if (!(related(SB->getTrueValue(), A) || related(SB->getFalseValue(), A))) - return false; - } - // Check both arms of the Select node individually. return related(A->getTrueValue(), B) || related(A->getFalseValue(), B); } bool ProvenanceAnalysis::relatedPHI(const PHINode *A, const Value *B) { - - auto comparePHISources = [this](const PHINode *PNA, const Value *B) -> bool { - // Check each unique source of the PHI node against B. - SmallPtrSet<const Value *, 4> UniqueSrc; - for (Value *PV1 : PNA->incoming_values()) { - if (UniqueSrc.insert(PV1).second && related(PV1, B)) - return true; - } - - // All of the arms checked out. - return false; - }; - - if (const PHINode *PNB = dyn_cast<PHINode>(B)) { - // If the values are PHIs in the same block, we can do a more precise as - // well as efficient check: just check for relations between the values on - // corresponding edges. + // If the values are PHIs in the same block, we can do a more precise as well + // as efficient check: just check for relations between the values on + // corresponding edges. 
+ if (const PHINode *PNB = dyn_cast<PHINode>(B)) if (PNB->getParent() == A->getParent()) { for (unsigned i = 0, e = A->getNumIncomingValues(); i != e; ++i) if (related(A->getIncomingValue(i), @@ -84,11 +65,15 @@ bool ProvenanceAnalysis::relatedPHI(const PHINode *A, return false; } - if (!comparePHISources(PNB, A)) - return false; + // Check each unique source of the PHI node against B. + SmallPtrSet<const Value *, 4> UniqueSrc; + for (Value *PV1 : A->incoming_values()) { + if (UniqueSrc.insert(PV1).second && related(PV1, B)) + return true; } - return comparePHISources(A, B); + // All of the arms checked out. + return false; } /// Test if the value of P, or any value covered by its provenance, is ever @@ -140,19 +125,22 @@ bool ProvenanceAnalysis::relatedCheck(const Value *A, const Value *B) { bool BIsIdentified = IsObjCIdentifiedObject(B); // An ObjC-Identified object can't alias a load if it is never locally stored. - - // Check for an obvious escape. - if ((AIsIdentified && isa<LoadInst>(B) && !IsStoredObjCPointer(A)) || - (BIsIdentified && isa<LoadInst>(A) && !IsStoredObjCPointer(B))) - return false; - - if ((AIsIdentified && isa<LoadInst>(B)) || - (BIsIdentified && isa<LoadInst>(A))) - return true; - - // Both pointers are identified and escapes aren't an evident problem. - if (AIsIdentified && BIsIdentified && !isa<LoadInst>(A) && !isa<LoadInst>(B)) - return false; + if (AIsIdentified) { + // Check for an obvious escape. + if (isa<LoadInst>(B)) + return IsStoredObjCPointer(A); + if (BIsIdentified) { + // Check for an obvious escape. + if (isa<LoadInst>(A)) + return IsStoredObjCPointer(B); + // Both pointers are identified and escapes aren't an evident problem. + return false; + } + } else if (BIsIdentified) { + // Check for an obvious escape. + if (isa<LoadInst>(A)) + return IsStoredObjCPointer(B); + } // Special handling for PHI and Select. if (const PHINode *PN = dyn_cast<PHINode>(A)) @@ -179,15 +167,12 @@ bool ProvenanceAnalysis::related(const Value *A, const Value *B) { // Begin by inserting a conservative value into the map. If the insertion // fails, we have the answer already. If it succeeds, leave it there until we // compute the real answer to guard against recursive queries. 
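The caching scheme in the closing comment above (its body continues in the next hunk) is memoization with a provisional entry: a conservative answer is parked in the cache before recursing, so a query that loops back through a cycle of PHIs or Selects terminates with the safe answer instead of recursing forever. A generic sketch of the idea; relatedCheckImpl and the cache type are illustrative names, not the ProvenanceAnalysis members:

// Sketch: memoize with a conservative placeholder to break recursion cycles.
using ValuePair = std::pair<const Value *, const Value *>;
static bool relatedCached(const Value *A, const Value *B,
                          DenseMap<ValuePair, bool> &Cache) {
  auto Ins = Cache.insert({{A, B}, /*conservative*/ true});
  if (!Ins.second)
    return Ins.first->second;           // cached result, or an in-flight query
  bool Result = relatedCheckImpl(A, B); // may re-enter relatedCached
  Cache[{A, B}] = Result;               // overwrite the placeholder
  return Result;
}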
- if (A > B) std::swap(A, B); std::pair<CachedResultsTy::iterator, bool> Pair = CachedResults.insert(std::make_pair(ValuePairTy(A, B), true)); if (!Pair.second) return Pair.first->second; bool Result = relatedCheck(A, B); - assert(relatedCheck(B, A) == Result && - "relatedCheck result depending on order of parameters!"); CachedResults[ValuePairTy(A, B)] = Result; return Result; } diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp index 253293582945..24354211341f 100644 --- a/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -26,6 +26,7 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -42,14 +43,11 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstddef> @@ -113,6 +111,12 @@ struct BlockInfoType { bool terminatorIsLive() const { return TerminatorLiveInfo->Live; } }; +struct ADCEChanged { + bool ChangedAnything = false; + bool ChangedNonDebugInstr = false; + bool ChangedControlFlow = false; +}; + class AggressiveDeadCodeElimination { Function &F; @@ -179,7 +183,7 @@ class AggressiveDeadCodeElimination { /// Remove instructions not marked live, return if any instruction was /// removed. - bool removeDeadInstructions(); + ADCEChanged removeDeadInstructions(); /// Identify connected sections of the control flow graph which have /// dead terminators and rewrite the control flow graph to remove them. @@ -197,12 +201,12 @@ public: PostDominatorTree &PDT) : F(F), DT(DT), PDT(PDT) {} - bool performDeadCodeElimination(); + ADCEChanged performDeadCodeElimination(); }; } // end anonymous namespace -bool AggressiveDeadCodeElimination::performDeadCodeElimination() { +ADCEChanged AggressiveDeadCodeElimination::performDeadCodeElimination() { initialize(); markLiveInstructions(); return removeDeadInstructions(); @@ -504,9 +508,10 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { // Routines to update the CFG and SSA information before removing dead code. // //===----------------------------------------------------------------------===// -bool AggressiveDeadCodeElimination::removeDeadInstructions() { +ADCEChanged AggressiveDeadCodeElimination::removeDeadInstructions() { + ADCEChanged Changed; // Updates control and dataflow around dead blocks - bool RegionsUpdated = updateDeadRegions(); + Changed.ChangedControlFlow = updateDeadRegions(); LLVM_DEBUG({ for (Instruction &I : instructions(F)) { @@ -554,6 +559,8 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { continue; // Fallthrough and drop the intrinsic. + } else { + Changed.ChangedNonDebugInstr = true; } // Prepare to delete. 
@@ -569,7 +576,9 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { I->eraseFromParent(); } - return !Worklist.empty() || RegionsUpdated; + Changed.ChangedAnything = Changed.ChangedControlFlow || !Worklist.empty(); + + return Changed; } // A dead region is the set of dead blocks with a common live post-dominator. @@ -699,62 +708,25 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) { // to update analysis if it is already available. auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F); - if (!AggressiveDeadCodeElimination(F, DT, PDT).performDeadCodeElimination()) + ADCEChanged Changed = + AggressiveDeadCodeElimination(F, DT, PDT).performDeadCodeElimination(); + if (!Changed.ChangedAnything) return PreservedAnalyses::all(); PreservedAnalyses PA; - // TODO: We could track if we have actually done CFG changes. - if (!RemoveControlFlowFlag) + if (!Changed.ChangedControlFlow) { PA.preserveSet<CFGAnalyses>(); - else { - PA.preserve<DominatorTreeAnalysis>(); - PA.preserve<PostDominatorTreeAnalysis>(); - } - return PA; -} - -namespace { - -struct ADCELegacyPass : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - - ADCELegacyPass() : FunctionPass(ID) { - initializeADCELegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - // ADCE does not need DominatorTree, but require DominatorTree here - // to update analysis if it is already available. - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; - auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - return AggressiveDeadCodeElimination(F, DT, PDT) - .performDeadCodeElimination(); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<PostDominatorTreeWrapperPass>(); - if (!RemoveControlFlowFlag) - AU.setPreservesCFG(); - else { - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<PostDominatorTreeWrapperPass>(); + if (!Changed.ChangedNonDebugInstr) { + // Only removing debug instructions does not affect MemorySSA. + // + // Therefore we preserve MemorySSA when only removing debug instructions + // since otherwise later passes may behave differently which then makes + // the presence of debug info affect code generation. 
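The preservation logic that starts above and closes in the next hunk is the payoff of the new ADCEChanged flags: CFG analyses survive when no control flow changed, and MemorySSA additionally survives when only debug instructions were removed, so that the presence of debug info does not perturb later passes. Condensed into one helper (a sketch, not code from the patch; flag and analysis names as used in this diff):

// Sketch: translate ADCE's change flags into preserved analyses.
static PreservedAnalyses preservedFor(const ADCEChanged &Changed) {
  if (!Changed.ChangedAnything)
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  if (!Changed.ChangedControlFlow) {
    PA.preserveSet<CFGAnalyses>();
    if (!Changed.ChangedNonDebugInstr)
      PA.preserve<MemorySSAAnalysis>(); // only debug instructions went away
  }
  PA.preserve<DominatorTreeAnalysis>();
  PA.preserve<PostDominatorTreeAnalysis>();
  return PA;
}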
+ PA.preserve<MemorySSAAnalysis>(); } - AU.addPreserved<GlobalsAAWrapperPass>(); } -}; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<PostDominatorTreeAnalysis>(); -} // end anonymous namespace - -char ADCELegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(ADCELegacyPass, "adce", - "Aggressive Dead Code Elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_END(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination", - false, false) - -FunctionPass *llvm::createAggressiveDCEPass() { return new ADCELegacyPass(); } + return PA; +} diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index f419f7bd769f..b259c76fc3a5 100644 --- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -28,13 +28,10 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" -#define AA_NAME "alignment-from-assumptions" -#define DEBUG_TYPE AA_NAME +#define DEBUG_TYPE "alignment-from-assumptions" using namespace llvm; STATISTIC(NumLoadAlignChanged, @@ -44,46 +41,6 @@ STATISTIC(NumStoreAlignChanged, STATISTIC(NumMemIntAlignChanged, "Number of memory intrinsics changed by alignment assumptions"); -namespace { -struct AlignmentFromAssumptions : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - AlignmentFromAssumptions() : FunctionPass(ID) { - initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - - AU.setPreservesCFG(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - } - - AlignmentFromAssumptionsPass Impl; -}; -} - -char AlignmentFromAssumptions::ID = 0; -static const char aip_name[] = "Alignment from assumptions"; -INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME, - aip_name, false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME, - aip_name, false, false) - -FunctionPass *llvm::createAlignmentFromAssumptionsPass() { - return new AlignmentFromAssumptions(); -} - // Given an expression for the (constant) alignment, AlignSCEV, and an // expression for the displacement between a pointer and the aligned address, // DiffSCEV, compute the alignment of the displaced pointer if it can be reduced @@ -317,17 +274,6 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, return true; } -bool AlignmentFromAssumptions::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - - return Impl.runImpl(F, AC, SE, DT); -} - bool 
AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, ScalarEvolution *SE_, DominatorTree *DT_) { diff --git a/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp b/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp index 79f7e253d45b..b182f46cc515 100644 --- a/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp +++ b/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp @@ -16,7 +16,6 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/MemoryOpRemark.h" using namespace llvm; @@ -58,7 +57,12 @@ static void runImpl(Function &F, const TargetLibraryInfo &TLI) { for (const MDOperand &Op : I.getMetadata(LLVMContext::MD_annotation)->operands()) { - auto Iter = Mapping.insert({cast<MDString>(Op.get())->getString(), 0}); + StringRef AnnotationStr = + isa<MDString>(Op.get()) + ? cast<MDString>(Op.get())->getString() + : cast<MDString>(cast<MDTuple>(Op.get())->getOperand(0).get()) + ->getString(); + auto Iter = Mapping.insert({AnnotationStr, 0}); Iter.first->second++; } } diff --git a/llvm/lib/Transforms/Scalar/BDCE.cpp b/llvm/lib/Transforms/Scalar/BDCE.cpp index 187927b3dede..1fa2c75b0f42 100644 --- a/llvm/lib/Transforms/Scalar/BDCE.cpp +++ b/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -23,11 +23,8 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -116,7 +113,7 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { const uint32_t SrcBitSize = SE->getSrcTy()->getScalarSizeInBits(); auto *const DstTy = SE->getDestTy(); const uint32_t DestBitSize = DstTy->getScalarSizeInBits(); - if (Demanded.countLeadingZeros() >= (DestBitSize - SrcBitSize)) { + if (Demanded.countl_zero() >= (DestBitSize - SrcBitSize)) { clearAssumptionsOfUsers(SE, DB); IRBuilder<> Builder(SE); I.replaceAllUsesWith( @@ -173,34 +170,3 @@ PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserveSet<CFGAnalyses>(); return PA; } - -namespace { -struct BDCELegacyPass : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - BDCELegacyPass() : FunctionPass(ID) { - initializeBDCELegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - auto &DB = getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); - return bitTrackingDCE(F, DB); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DemandedBitsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } -}; -} - -char BDCELegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(BDCELegacyPass, "bdce", - "Bit-Tracking Dead Code Elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) -INITIALIZE_PASS_END(BDCELegacyPass, "bdce", - "Bit-Tracking Dead Code Elimination", false, false) - -FunctionPass *llvm::createBitTrackingDCEPass() { return new BDCELegacyPass(); } diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp index 6665a927826d..aeb7c5d461f0 100644 --- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -535,45 +535,6 @@ static bool 
doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI, return Changed; } -namespace { -struct CallSiteSplittingLegacyPass : public FunctionPass { - static char ID; - CallSiteSplittingLegacyPass() : FunctionPass(ID) { - initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - FunctionPass::getAnalysisUsage(AU); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - return doCallSiteSplitting(F, TLI, TTI, DT); - } -}; -} // namespace - -char CallSiteSplittingLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting", - "Call-site splitting", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting", - "Call-site splitting", false, false) -FunctionPass *llvm::createCallSiteSplittingPass() { - return new CallSiteSplittingLegacyPass(); -} - PreservedAnalyses CallSiteSplittingPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 8858545bbc5d..611e64bd0976 100644 --- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -155,16 +155,19 @@ bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { Fn.getEntryBlock(), &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI()); - if (MadeChange) { - LLVM_DEBUG(dbgs() << "********** Function after Constant Hoisting: " - << Fn.getName() << '\n'); - LLVM_DEBUG(dbgs() << Fn); - } LLVM_DEBUG(dbgs() << "********** End Constant Hoisting **********\n"); return MadeChange; } +void ConstantHoistingPass::collectMatInsertPts( + const RebasedConstantListType &RebasedConstants, + SmallVectorImpl<Instruction *> &MatInsertPts) const { + for (const RebasedConstantInfo &RCI : RebasedConstants) + for (const ConstantUser &U : RCI.Uses) + MatInsertPts.emplace_back(findMatInsertPt(U.Inst, U.OpndIdx)); +} + /// Find the constant materialization insertion point. Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, unsigned Idx) const { @@ -312,14 +315,15 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, /// Find an insertion point that dominates all uses. SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint( - const ConstantInfo &ConstInfo) const { + const ConstantInfo &ConstInfo, + const ArrayRef<Instruction *> MatInsertPts) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); // Collect all basic blocks. 
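The new collectMatInsertPts above flattens the two-level RebasedConstants/Uses structure into one vector of materialization insertion points; later hunks in emitBaseConstants walk the same nesting with a running MatCtr index, so the two traversal orders must stay in lockstep. The invariant, written out as a check (a sketch, not part of the patch):

// Sketch: MatInsertPts[i] is the insertion point of the i-th (constant, use)
// pair in this exact visit order, which is what the MatCtr indexing further
// below relies on.
unsigned I = 0;
for (const RebasedConstantInfo &RCI : ConstInfo.RebasedConstants)
  for (const ConstantUser &U : RCI.Uses)
    assert(MatInsertPts[I++] == findMatInsertPt(U.Inst, U.OpndIdx) &&
           "insertion points must match the (constant, use) visit order");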
SetVector<BasicBlock *> BBs; SetVector<Instruction *> InsertPts; - for (auto const &RCI : ConstInfo.RebasedConstants) - for (auto const &U : RCI.Uses) - BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent()); + + for (Instruction *MatInsertPt : MatInsertPts) + BBs.insert(MatInsertPt->getParent()); if (BBs.count(Entry)) { InsertPts.insert(&Entry->front()); @@ -328,12 +332,8 @@ SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint( if (BFI) { findBestInsertionSet(*DT, *BFI, Entry, BBs); - for (auto *BB : BBs) { - BasicBlock::iterator InsertPt = BB->begin(); - for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) - ; - InsertPts.insert(&*InsertPt); - } + for (BasicBlock *BB : BBs) + InsertPts.insert(&*BB->getFirstInsertionPt()); return InsertPts; } @@ -410,8 +410,8 @@ void ConstantHoistingPass::collectConstantCandidates( // Get offset from the base GV. PointerType *GVPtrTy = cast<PointerType>(BaseGV->getType()); - IntegerType *PtrIntTy = DL->getIntPtrType(*Ctx, GVPtrTy->getAddressSpace()); - APInt Offset(DL->getTypeSizeInBits(PtrIntTy), /*val*/0, /*isSigned*/true); + IntegerType *OffsetTy = DL->getIndexType(*Ctx, GVPtrTy->getAddressSpace()); + APInt Offset(DL->getTypeSizeInBits(OffsetTy), /*val*/ 0, /*isSigned*/ true); auto *GEPO = cast<GEPOperator>(ConstExpr); // TODO: If we have a mix of inbounds and non-inbounds GEPs, then basing a @@ -432,7 +432,7 @@ void ConstantHoistingPass::collectConstantCandidates( // to be cheaper than compute it by <Base + Offset>, which can be lowered to // an ADD instruction or folded into Load/Store instruction. InstructionCost Cost = - TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy, + TTI->getIntImmCostInst(Instruction::Add, 1, Offset, OffsetTy, TargetTransformInfo::TCK_SizeAndLatency, Inst); ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV]; ConstCandMapType::iterator Itr; @@ -751,45 +751,41 @@ static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) { /// Emit materialization code for all rebased constants and update their /// users. void ConstantHoistingPass::emitBaseConstants(Instruction *Base, - Constant *Offset, - Type *Ty, - const ConstantUser &ConstUser) { + UserAdjustment *Adj) { Instruction *Mat = Base; // The same offset can be dereferenced to different types in nested struct. - if (!Offset && Ty && Ty != Base->getType()) - Offset = ConstantInt::get(Type::getInt32Ty(*Ctx), 0); + if (!Adj->Offset && Adj->Ty && Adj->Ty != Base->getType()) + Adj->Offset = ConstantInt::get(Type::getInt32Ty(*Ctx), 0); - if (Offset) { - Instruction *InsertionPt = findMatInsertPt(ConstUser.Inst, - ConstUser.OpndIdx); - if (Ty) { + if (Adj->Offset) { + if (Adj->Ty) { // Constant being rebased is a ConstantExpr. - PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx, - cast<PointerType>(Ty)->getAddressSpace()); - Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", InsertionPt); - Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, - Offset, "mat_gep", InsertionPt); - Mat = new BitCastInst(Mat, Ty, "mat_bitcast", InsertionPt); + PointerType *Int8PtrTy = Type::getInt8PtrTy( + *Ctx, cast<PointerType>(Adj->Ty)->getAddressSpace()); + Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", Adj->MatInsertPt); + Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Adj->Offset, + "mat_gep", Adj->MatInsertPt); + Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", Adj->MatInsertPt); } else // Constant being rebased is a ConstantInt. 
- Mat = BinaryOperator::Create(Instruction::Add, Base, Offset, - "const_mat", InsertionPt); + Mat = BinaryOperator::Create(Instruction::Add, Base, Adj->Offset, + "const_mat", Adj->MatInsertPt); LLVM_DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0) - << " + " << *Offset << ") in BB " + << " + " << *Adj->Offset << ") in BB " << Mat->getParent()->getName() << '\n' << *Mat << '\n'); - Mat->setDebugLoc(ConstUser.Inst->getDebugLoc()); + Mat->setDebugLoc(Adj->User.Inst->getDebugLoc()); } - Value *Opnd = ConstUser.Inst->getOperand(ConstUser.OpndIdx); + Value *Opnd = Adj->User.Inst->getOperand(Adj->User.OpndIdx); // Visit constant integer. if (isa<ConstantInt>(Opnd)) { - LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); - if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat) && Offset) + LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n'); + if (!updateOperand(Adj->User.Inst, Adj->User.OpndIdx, Mat) && Adj->Offset) Mat->eraseFromParent(); - LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "To : " << *Adj->User.Inst << '\n'); return; } @@ -809,9 +805,9 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, << "To : " << *ClonedCastInst << '\n'); } - LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); - updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ClonedCastInst); - LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n'); + updateOperand(Adj->User.Inst, Adj->User.OpndIdx, ClonedCastInst); + LLVM_DEBUG(dbgs() << "To : " << *Adj->User.Inst << '\n'); return; } @@ -819,28 +815,27 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) { if (isa<GEPOperator>(ConstExpr)) { // Operand is a ConstantGEP, replace it. - updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat); + updateOperand(Adj->User.Inst, Adj->User.OpndIdx, Mat); return; } // Aside from constant GEPs, only constant cast expressions are collected. assert(ConstExpr->isCast() && "ConstExpr should be a cast"); - Instruction *ConstExprInst = ConstExpr->getAsInstruction( - findMatInsertPt(ConstUser.Inst, ConstUser.OpndIdx)); + Instruction *ConstExprInst = ConstExpr->getAsInstruction(Adj->MatInsertPt); ConstExprInst->setOperand(0, Mat); // Use the same debug location as the instruction we are about to update. - ConstExprInst->setDebugLoc(ConstUser.Inst->getDebugLoc()); + ConstExprInst->setDebugLoc(Adj->User.Inst->getDebugLoc()); LLVM_DEBUG(dbgs() << "Create instruction: " << *ConstExprInst << '\n' << "From : " << *ConstExpr << '\n'); - LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); - if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ConstExprInst)) { + LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n'); + if (!updateOperand(Adj->User.Inst, Adj->User.OpndIdx, ConstExprInst)) { ConstExprInst->eraseFromParent(); - if (Offset) + if (Adj->Offset) Mat->eraseFromParent(); } - LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "To : " << *Adj->User.Inst << '\n'); return; } } @@ -851,8 +846,11 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { bool MadeChange = false; SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec = BaseGV ? 
ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; - for (auto const &ConstInfo : ConstInfoVec) { - SetVector<Instruction *> IPSet = findConstantInsertionPoint(ConstInfo); + for (const consthoist::ConstantInfo &ConstInfo : ConstInfoVec) { + SmallVector<Instruction *, 4> MatInsertPts; + collectMatInsertPts(ConstInfo.RebasedConstants, MatInsertPts); + SetVector<Instruction *> IPSet = + findConstantInsertionPoint(ConstInfo, MatInsertPts); // We can have an empty set if the function contains unreachable blocks. if (IPSet.empty()) continue; @@ -862,22 +860,21 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { unsigned NotRebasedNum = 0; for (Instruction *IP : IPSet) { // First, collect constants depending on this IP of the base. - unsigned Uses = 0; - using RebasedUse = std::tuple<Constant *, Type *, ConstantUser>; - SmallVector<RebasedUse, 4> ToBeRebased; + UsesNum = 0; + SmallVector<UserAdjustment, 4> ToBeRebased; + unsigned MatCtr = 0; for (auto const &RCI : ConstInfo.RebasedConstants) { + UsesNum += RCI.Uses.size(); for (auto const &U : RCI.Uses) { - Uses++; - BasicBlock *OrigMatInsertBB = - findMatInsertPt(U.Inst, U.OpndIdx)->getParent(); + Instruction *MatInsertPt = MatInsertPts[MatCtr++]; + BasicBlock *OrigMatInsertBB = MatInsertPt->getParent(); // If Base constant is to be inserted in multiple places, // generate rebase for U using the Base dominating U. if (IPSet.size() == 1 || DT->dominates(IP->getParent(), OrigMatInsertBB)) - ToBeRebased.push_back(RebasedUse(RCI.Offset, RCI.Ty, U)); + ToBeRebased.emplace_back(RCI.Offset, RCI.Ty, MatInsertPt, U); } } - UsesNum = Uses; // If only few constants depend on this IP of base, skip rebasing, // assuming the base and the rebased have the same materialization cost. @@ -905,15 +902,12 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { << *Base << '\n'); // Emit materialization code for rebased constants depending on this IP. - for (auto const &R : ToBeRebased) { - Constant *Off = std::get<0>(R); - Type *Ty = std::get<1>(R); - ConstantUser U = std::get<2>(R); - emitBaseConstants(Base, Off, Ty, U); + for (UserAdjustment &R : ToBeRebased) { + emitBaseConstants(Base, &R); ReBasesNum++; // Use the same debug location as the last user of the constant. 
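emitBaseConstants now takes a single UserAdjustment* in place of the old (Offset, Ty, ConstantUser) parameter triple. The struct itself is declared in the ConstantHoisting header and is not part of this excerpt; judging only from the accesses in these hunks (Adj->Offset, Adj->Ty, Adj->MatInsertPt, Adj->User), it plausibly looks like the sketch below. Treat the member kinds, and whether User is stored by value or by reference, as assumptions:

// Hypothetical reconstruction, inferred from the uses in this diff only.
struct UserAdjustment {
  Constant *Offset;         // rebase offset; may be rewritten by the callee
  Type *Ty;                 // non-null when rebasing a ConstantExpr
  Instruction *MatInsertPt; // precomputed materialization insertion point
  ConstantUser User;        // the use being rewritten (value vs. ref unknown)
  UserAdjustment(Constant *Offset, Type *Ty, Instruction *MatInsertPt,
                 const ConstantUser &User)
      : Offset(Offset), Ty(Ty), MatInsertPt(MatInsertPt), User(User) {}
};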
Base->setDebugLoc(DILocation::getMergedLocation( - Base->getDebugLoc(), U.Inst->getDebugLoc())); + Base->getDebugLoc(), R.User.Inst->getDebugLoc())); } assert(!Base->use_empty() && "The use list is empty!?"); assert(isa<Instruction>(Base->user_back()) && diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 12fcb6aa9846..15628d32280d 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstraintSystem.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -26,13 +27,18 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Verifier.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <cmath> +#include <optional> #include <string> using namespace llvm; @@ -48,6 +54,10 @@ static cl::opt<unsigned> MaxRows("constraint-elimination-max-rows", cl::init(500), cl::Hidden, cl::desc("Maximum number of rows to keep in constraint system")); +static cl::opt<bool> DumpReproducers( + "constraint-elimination-dump-reproducers", cl::init(false), cl::Hidden, + cl::desc("Dump IR to reproduce successful transformations.")); + static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max(); static int64_t MinSignedConstraintValue = std::numeric_limits<int64_t>::min(); @@ -65,7 +75,86 @@ static int64_t addWithOverflow(int64_t A, int64_t B) { return Result; } +static Instruction *getContextInstForUse(Use &U) { + Instruction *UserI = cast<Instruction>(U.getUser()); + if (auto *Phi = dyn_cast<PHINode>(UserI)) + UserI = Phi->getIncomingBlock(U)->getTerminator(); + return UserI; +} + namespace { +/// Represents either +/// * a condition that holds on entry to a block (=conditional fact) +/// * an assume (=assume fact) +/// * a use of a compare instruction to simplify. +/// It also tracks the Dominator DFS in and out numbers for each entry. 
+struct FactOrCheck { + union { + Instruction *Inst; + Use *U; + }; + unsigned NumIn; + unsigned NumOut; + bool HasInst; + bool Not; + + FactOrCheck(DomTreeNode *DTN, Instruction *Inst, bool Not) + : Inst(Inst), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), + HasInst(true), Not(Not) {} + + FactOrCheck(DomTreeNode *DTN, Use *U) + : U(U), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), + HasInst(false), Not(false) {} + + static FactOrCheck getFact(DomTreeNode *DTN, Instruction *Inst, + bool Not = false) { + return FactOrCheck(DTN, Inst, Not); + } + + static FactOrCheck getCheck(DomTreeNode *DTN, Use *U) { + return FactOrCheck(DTN, U); + } + + static FactOrCheck getCheck(DomTreeNode *DTN, CallInst *CI) { + return FactOrCheck(DTN, CI, false); + } + + bool isCheck() const { + return !HasInst || + match(Inst, m_Intrinsic<Intrinsic::ssub_with_overflow>()); + } + + Instruction *getContextInst() const { + if (HasInst) + return Inst; + return getContextInstForUse(*U); + } + Instruction *getInstructionToSimplify() const { + assert(isCheck()); + if (HasInst) + return Inst; + // The use may have been simplified to a constant already. + return dyn_cast<Instruction>(*U); + } + bool isConditionFact() const { return !isCheck() && isa<CmpInst>(Inst); } +}; + +/// Keep state required to build worklist. +struct State { + DominatorTree &DT; + SmallVector<FactOrCheck, 64> WorkList; + + State(DominatorTree &DT) : DT(DT) {} + + /// Process block \p BB and add known facts to work-list. + void addInfoFor(BasicBlock &BB); + + /// Returns true if we can add a known condition from BB to its successor + /// block Succ. + bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const { + return DT.dominates(BasicBlockEdge(&BB, Succ), Succ); + } +}; class ConstraintInfo; @@ -100,12 +189,13 @@ struct ConstraintTy { SmallVector<SmallVector<int64_t, 8>> ExtraInfo; bool IsSigned = false; - bool IsEq = false; ConstraintTy() = default; - ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned) - : Coefficients(Coefficients), IsSigned(IsSigned) {} + ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned, bool IsEq, + bool IsNe) + : Coefficients(Coefficients), IsSigned(IsSigned), IsEq(IsEq), IsNe(IsNe) { + } unsigned size() const { return Coefficients.size(); } @@ -114,6 +204,21 @@ struct ConstraintTy { /// Returns true if all preconditions for this list of constraints are /// satisfied given \p CS and the corresponding \p Value2Index mapping. bool isValid(const ConstraintInfo &Info) const; + + bool isEq() const { return IsEq; } + + bool isNe() const { return IsNe; } + + /// Check if the current constraint is implied by the given ConstraintSystem. + /// + /// \return true or false if the constraint is proven to be respectively true, + /// or false. When the constraint cannot be proven to be either true or false, + /// std::nullopt is returned. + std::optional<bool> isImpliedBy(const ConstraintSystem &CS) const; + +private: + bool IsEq = false; + bool IsNe = false; }; /// Wrapper encapsulating separate constraint systems and corresponding value @@ -123,8 +228,6 @@ struct ConstraintTy { /// based on signed-ness, certain conditions can be transferred between the two /// systems. 
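ConstraintTy::isImpliedBy, declared above and implemented further down in this diff, folds the previous pair of queries (condition implied, negated condition implied) into one tri-state answer. A short sketch of how a caller consumes it; R and CSToUse stand for a constraint and the matching system, as elsewhere in this file:

// Sketch: the three possible outcomes of isImpliedBy.
if (std::optional<bool> Implied = R.isImpliedBy(CSToUse)) {
  if (*Implied)
    ; // proven true: the compare can be replaced by a true constant
  else
    ; // proven false: it can be replaced by a false constant
} else {
  ; // neither direction proven: leave the IR alone
}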
class ConstraintInfo { - DenseMap<Value *, unsigned> UnsignedValue2Index; - DenseMap<Value *, unsigned> SignedValue2Index; ConstraintSystem UnsignedCS; ConstraintSystem SignedCS; @@ -132,13 +235,14 @@ class ConstraintInfo { const DataLayout &DL; public: - ConstraintInfo(const DataLayout &DL) : DL(DL) {} + ConstraintInfo(const DataLayout &DL, ArrayRef<Value *> FunctionArgs) + : UnsignedCS(FunctionArgs), SignedCS(FunctionArgs), DL(DL) {} DenseMap<Value *, unsigned> &getValue2Index(bool Signed) { - return Signed ? SignedValue2Index : UnsignedValue2Index; + return Signed ? SignedCS.getValue2Index() : UnsignedCS.getValue2Index(); } const DenseMap<Value *, unsigned> &getValue2Index(bool Signed) const { - return Signed ? SignedValue2Index : UnsignedValue2Index; + return Signed ? SignedCS.getValue2Index() : UnsignedCS.getValue2Index(); } ConstraintSystem &getCS(bool Signed) { @@ -235,9 +339,8 @@ static bool canUseSExt(ConstantInt *CI) { } static Decomposition -decomposeGEP(GetElementPtrInst &GEP, - SmallVectorImpl<PreconditionTy> &Preconditions, bool IsSigned, - const DataLayout &DL) { +decomposeGEP(GEPOperator &GEP, SmallVectorImpl<PreconditionTy> &Preconditions, + bool IsSigned, const DataLayout &DL) { // Do not reason about pointers where the index size is larger than 64 bits, // as the coefficients used to encode constraints are 64 bit integers. if (DL.getIndexTypeSizeInBits(GEP.getPointerOperand()->getType()) > 64) @@ -257,7 +360,7 @@ decomposeGEP(GetElementPtrInst &GEP, // Handle the (gep (gep ....), C) case by incrementing the constant // coefficient of the inner GEP, if C is a constant. - auto *InnerGEP = dyn_cast<GetElementPtrInst>(GEP.getPointerOperand()); + auto *InnerGEP = dyn_cast<GEPOperator>(GEP.getPointerOperand()); if (VariableOffsets.empty() && InnerGEP && InnerGEP->getNumOperands() == 2) { auto Result = decompose(InnerGEP, Preconditions, IsSigned, DL); Result.add(ConstantOffset.getSExtValue()); @@ -320,6 +423,13 @@ static Decomposition decompose(Value *V, if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) return MergeResults(Op0, Op1, IsSigned); + ConstantInt *CI; + if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI)))) { + auto Result = decompose(Op0, Preconditions, IsSigned, DL); + Result.mul(CI->getSExtValue()); + return Result; + } + return V; } @@ -329,7 +439,7 @@ static Decomposition decompose(Value *V, return int64_t(CI->getZExtValue()); } - if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) + if (auto *GEP = dyn_cast<GEPOperator>(V)) return decomposeGEP(*GEP, Preconditions, IsSigned, DL); Value *Op0; @@ -363,10 +473,17 @@ static Decomposition decompose(Value *V, return MergeResults(Op0, CI, true); } + // Decompose or as an add if there are no common bits between the operands. 
+ if (match(V, m_Or(m_Value(Op0), m_ConstantInt(CI))) && + haveNoCommonBitsSet(Op0, CI, DL)) { + return MergeResults(Op0, CI, IsSigned); + } + if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) { - int64_t Mult = int64_t(std::pow(int64_t(2), CI->getSExtValue())); + if (CI->getSExtValue() < 0 || CI->getSExtValue() >= 64) + return {V, IsKnownNonNegative}; auto Result = decompose(Op1, Preconditions, IsSigned, DL); - Result.mul(Mult); + Result.mul(int64_t{1} << CI->getSExtValue()); return Result; } @@ -390,6 +507,8 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, SmallVectorImpl<Value *> &NewVariables) const { assert(NewVariables.empty() && "NewVariables must be empty when passed in"); bool IsEq = false; + bool IsNe = false; + // Try to convert Pred to one of ULE/SLT/SLE/SLT. switch (Pred) { case CmpInst::ICMP_UGT: @@ -409,10 +528,13 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, } break; case CmpInst::ICMP_NE: - if (!match(Op1, m_Zero())) - return {}; - Pred = CmpInst::getSwappedPredicate(CmpInst::ICMP_UGT); - std::swap(Op0, Op1); + if (match(Op1, m_Zero())) { + Pred = CmpInst::getSwappedPredicate(CmpInst::ICMP_UGT); + std::swap(Op0, Op1); + } else { + IsNe = true; + Pred = CmpInst::ICMP_ULE; + } break; default: break; @@ -459,11 +581,10 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, // subtracting all coefficients from B. ConstraintTy Res( SmallVector<int64_t, 8>(Value2Index.size() + NewVariables.size() + 1, 0), - IsSigned); + IsSigned, IsEq, IsNe); // Collect variables that are known to be positive in all uses in the // constraint. DenseMap<Value *, bool> KnownNonNegativeVariables; - Res.IsEq = IsEq; auto &R = Res.Coefficients; for (const auto &KV : VariablesA) { R[GetOrAddIndex(KV.Variable)] += KV.Coefficient; @@ -473,7 +594,9 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, } for (const auto &KV : VariablesB) { - R[GetOrAddIndex(KV.Variable)] -= KV.Coefficient; + if (SubOverflow(R[GetOrAddIndex(KV.Variable)], KV.Coefficient, + R[GetOrAddIndex(KV.Variable)])) + return {}; auto I = KnownNonNegativeVariables.insert({KV.Variable, KV.IsKnownNonNegative}); I.first->second &= KV.IsKnownNonNegative; @@ -501,8 +624,8 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, // Add extra constraints for variables that are known positive. 
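Two of the decomposition changes above rest on plain integer identities: an or with a constant behaves like an add when the operands share no set bits, and a left shift by c is a multiply by 1 << c, now guarded to the 0..63 range instead of detouring through std::pow. A standalone demonstration, independent of any LLVM types:

// Sketch: the identities behind the or-as-add and shl-as-mul decompositions.
#include <cassert>
#include <cstdint>
int main() {
  int64_t A = 0x1100, C = 0x0011;              // (A & C) == 0: no common bits
  assert((A | C) == A + C);                    // or decomposes like an add
  int64_t Sh = 5;
  assert((A << Sh) == A * (int64_t{1} << Sh)); // shl decomposes as a multiply
  return 0;
}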
for (auto &KV : KnownNonNegativeVariables) { - if (!KV.second || (Value2Index.find(KV.first) == Value2Index.end() && - NewIndexMap.find(KV.first) == NewIndexMap.end())) + if (!KV.second || + (!Value2Index.contains(KV.first) && !NewIndexMap.contains(KV.first))) continue; SmallVector<int64_t, 8> C(Value2Index.size() + NewVariables.size() + 1, 0); C[GetOrAddIndex(KV.first)] = -1; @@ -524,7 +647,7 @@ ConstraintTy ConstraintInfo::getConstraintForSolving(CmpInst::Predicate Pred, SmallVector<Value *> NewVariables; ConstraintTy R = getConstraint(Pred, Op0, Op1, NewVariables); - if (R.IsEq || !NewVariables.empty()) + if (!NewVariables.empty()) return {}; return R; } @@ -536,10 +659,54 @@ bool ConstraintTy::isValid(const ConstraintInfo &Info) const { }); } +std::optional<bool> +ConstraintTy::isImpliedBy(const ConstraintSystem &CS) const { + bool IsConditionImplied = CS.isConditionImplied(Coefficients); + + if (IsEq || IsNe) { + auto NegatedOrEqual = ConstraintSystem::negateOrEqual(Coefficients); + bool IsNegatedOrEqualImplied = + !NegatedOrEqual.empty() && CS.isConditionImplied(NegatedOrEqual); + + // In order to check that `%a == %b` is true (equality), both conditions `%a + // >= %b` and `%a <= %b` must hold true. When checking for equality (`IsEq` + // is true), we return true if they both hold, false in the other cases. + if (IsConditionImplied && IsNegatedOrEqualImplied) + return IsEq; + + auto Negated = ConstraintSystem::negate(Coefficients); + bool IsNegatedImplied = !Negated.empty() && CS.isConditionImplied(Negated); + + auto StrictLessThan = ConstraintSystem::toStrictLessThan(Coefficients); + bool IsStrictLessThanImplied = + !StrictLessThan.empty() && CS.isConditionImplied(StrictLessThan); + + // In order to check that `%a != %b` is true (non-equality), either + // condition `%a > %b` or `%a < %b` must hold true. When checking for + // non-equality (`IsNe` is true), we return true if one of the two holds, + // false in the other cases. + if (IsNegatedImplied || IsStrictLessThanImplied) + return IsNe; + + return std::nullopt; + } + + if (IsConditionImplied) + return true; + + auto Negated = ConstraintSystem::negate(Coefficients); + auto IsNegatedImplied = !Negated.empty() && CS.isConditionImplied(Negated); + if (IsNegatedImplied) + return false; + + // Neither the condition nor its negated holds, did not prove anything. 
+ return std::nullopt; +} + bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A, Value *B) const { auto R = getConstraintForSolving(Pred, A, B); - return R.Preconditions.empty() && !R.empty() && + return R.isValid(*this) && getCS(R.IsSigned).isConditionImplied(R.Coefficients); } @@ -568,11 +735,15 @@ void ConstraintInfo::transferToOtherSystem( if (doesHold(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0))) addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack); break; - case CmpInst::ICMP_SGT: + case CmpInst::ICMP_SGT: { if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1))) addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), NumIn, NumOut, DFSInStack); + if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) + addFact(CmpInst::ICMP_UGT, A, B, NumIn, NumOut, DFSInStack); + break; + } case CmpInst::ICMP_SGE: if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) { addFact(CmpInst::ICMP_UGE, A, B, NumIn, NumOut, DFSInStack); @@ -581,77 +752,13 @@ void ConstraintInfo::transferToOtherSystem( } } -namespace { -/// Represents either -/// * a condition that holds on entry to a block (=conditional fact) -/// * an assume (=assume fact) -/// * an instruction to simplify. -/// It also tracks the Dominator DFS in and out numbers for each entry. -struct FactOrCheck { - Instruction *Inst; - unsigned NumIn; - unsigned NumOut; - bool IsCheck; - bool Not; - - FactOrCheck(DomTreeNode *DTN, Instruction *Inst, bool IsCheck, bool Not) - : Inst(Inst), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), - IsCheck(IsCheck), Not(Not) {} - - static FactOrCheck getFact(DomTreeNode *DTN, Instruction *Inst, - bool Not = false) { - return FactOrCheck(DTN, Inst, false, Not); - } - - static FactOrCheck getCheck(DomTreeNode *DTN, Instruction *Inst) { - return FactOrCheck(DTN, Inst, true, false); - } - - bool isAssumeFact() const { - if (!IsCheck && isa<IntrinsicInst>(Inst)) { - assert(match(Inst, m_Intrinsic<Intrinsic::assume>())); - return true; - } - return false; - } - - bool isConditionFact() const { return !IsCheck && isa<CmpInst>(Inst); } -}; - -/// Keep state required to build worklist. -struct State { - DominatorTree &DT; - SmallVector<FactOrCheck, 64> WorkList; - - State(DominatorTree &DT) : DT(DT) {} - - /// Process block \p BB and add known facts to work-list. - void addInfoFor(BasicBlock &BB); - - /// Returns true if we can add a known condition from BB to its successor - /// block Succ. - bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const { - return DT.dominates(BasicBlockEdge(&BB, Succ), Succ); - } -}; - -} // namespace - #ifndef NDEBUG -static void dumpWithNames(const ConstraintSystem &CS, - DenseMap<Value *, unsigned> &Value2Index) { - SmallVector<std::string> Names(Value2Index.size(), ""); - for (auto &KV : Value2Index) { - Names[KV.second - 1] = std::string("%") + KV.first->getName().str(); - } - CS.dump(Names); -} -static void dumpWithNames(ArrayRef<int64_t> C, - DenseMap<Value *, unsigned> &Value2Index) { - ConstraintSystem CS; +static void dumpConstraint(ArrayRef<int64_t> C, + const DenseMap<Value *, unsigned> &Value2Index) { + ConstraintSystem CS(Value2Index); CS.addVariableRowFill(C); - dumpWithNames(CS, Value2Index); + CS.dump(); } #endif @@ -661,12 +768,24 @@ void State::addInfoFor(BasicBlock &BB) { // Queue conditions and assumes. 
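The new ICMP_SGT case in transferToOtherSystem above moves a signed fact into the unsigned system: once B is known signed-non-negative, A >s B forces A into the non-negative range as well, where the signed and unsigned orderings coincide, so A >u B may also be recorded. A throwaway check of that reasoning:

// Sketch: why "B >=s 0 and A >s B" lets the pass also record "A >u B".
#include <cassert>
#include <cstdint>
int main() {
  int64_t B = 7, A = 9;              // B >= 0 and A > B in the signed order
  assert(B >= 0 && A > B);
  assert(uint64_t(A) > uint64_t(B)); // the unsigned order agrees
  return 0;
}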
for (Instruction &I : BB) { if (auto Cmp = dyn_cast<ICmpInst>(&I)) { - WorkList.push_back(FactOrCheck::getCheck(DT.getNode(&BB), Cmp)); + for (Use &U : Cmp->uses()) { + auto *UserI = getContextInstForUse(U); + auto *DTN = DT.getNode(UserI->getParent()); + if (!DTN) + continue; + WorkList.push_back(FactOrCheck::getCheck(DTN, &U)); + } continue; } if (match(&I, m_Intrinsic<Intrinsic::ssub_with_overflow>())) { - WorkList.push_back(FactOrCheck::getCheck(DT.getNode(&BB), &I)); + WorkList.push_back( + FactOrCheck::getCheck(DT.getNode(&BB), cast<CallInst>(&I))); + continue; + } + + if (isa<MinMaxIntrinsic>(&I)) { + WorkList.push_back(FactOrCheck::getFact(DT.getNode(&BB), &I)); continue; } @@ -748,7 +867,160 @@ void State::addInfoFor(BasicBlock &BB) { FactOrCheck::getFact(DT.getNode(Br->getSuccessor(1)), CmpI, true)); } -static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) { +namespace { +/// Helper to keep track of a condition and if it should be treated as negated +/// for reproducer construction. +/// Pred == Predicate::BAD_ICMP_PREDICATE indicates that this entry is a +/// placeholder to keep the ReproducerCondStack in sync with DFSInStack. +struct ReproducerEntry { + ICmpInst::Predicate Pred; + Value *LHS; + Value *RHS; + + ReproducerEntry(ICmpInst::Predicate Pred, Value *LHS, Value *RHS) + : Pred(Pred), LHS(LHS), RHS(RHS) {} +}; +} // namespace + +/// Helper function to generate a reproducer function for simplifying \p Cond. +/// The reproducer function contains a series of @llvm.assume calls, one for +/// each condition in \p Stack. For each condition, the operand instruction are +/// cloned until we reach operands that have an entry in \p Value2Index. Those +/// will then be added as function arguments. \p DT is used to order cloned +/// instructions. The reproducer function will get added to \p M, if it is +/// non-null. Otherwise no reproducer function is generated. +static void generateReproducer(CmpInst *Cond, Module *M, + ArrayRef<ReproducerEntry> Stack, + ConstraintInfo &Info, DominatorTree &DT) { + if (!M) + return; + + LLVMContext &Ctx = Cond->getContext(); + + LLVM_DEBUG(dbgs() << "Creating reproducer for " << *Cond << "\n"); + + ValueToValueMapTy Old2New; + SmallVector<Value *> Args; + SmallPtrSet<Value *, 8> Seen; + // Traverse Cond and its operands recursively until we reach a value that's in + // Value2Index or not an instruction, or not a operation that + // ConstraintElimination can decompose. Such values will be considered as + // external inputs to the reproducer, they are collected and added as function + // arguments later. 
+ auto CollectArguments = [&](ArrayRef<Value *> Ops, bool IsSigned) { + auto &Value2Index = Info.getValue2Index(IsSigned); + SmallVector<Value *, 4> WorkList(Ops); + while (!WorkList.empty()) { + Value *V = WorkList.pop_back_val(); + if (!Seen.insert(V).second) + continue; + if (Old2New.find(V) != Old2New.end()) + continue; + if (isa<Constant>(V)) + continue; + + auto *I = dyn_cast<Instruction>(V); + if (Value2Index.contains(V) || !I || + !isa<CmpInst, BinaryOperator, GEPOperator, CastInst>(V)) { + Old2New[V] = V; + Args.push_back(V); + LLVM_DEBUG(dbgs() << " found external input " << *V << "\n"); + } else { + append_range(WorkList, I->operands()); + } + } + }; + + for (auto &Entry : Stack) + if (Entry.Pred != ICmpInst::BAD_ICMP_PREDICATE) + CollectArguments({Entry.LHS, Entry.RHS}, ICmpInst::isSigned(Entry.Pred)); + CollectArguments(Cond, ICmpInst::isSigned(Cond->getPredicate())); + + SmallVector<Type *> ParamTys; + for (auto *P : Args) + ParamTys.push_back(P->getType()); + + FunctionType *FTy = FunctionType::get(Cond->getType(), ParamTys, + /*isVarArg=*/false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, + Cond->getModule()->getName() + + Cond->getFunction()->getName() + "repro", + M); + // Add arguments to the reproducer function for each external value collected. + for (unsigned I = 0; I < Args.size(); ++I) { + F->getArg(I)->setName(Args[I]->getName()); + Old2New[Args[I]] = F->getArg(I); + } + + BasicBlock *Entry = BasicBlock::Create(Ctx, "entry", F); + IRBuilder<> Builder(Entry); + Builder.CreateRet(Builder.getTrue()); + Builder.SetInsertPoint(Entry->getTerminator()); + + // Clone instructions in \p Ops and their operands recursively until reaching + // an value in Value2Index (external input to the reproducer). Update Old2New + // mapping for the original and cloned instructions. Sort instructions to + // clone by dominance, then insert the cloned instructions in the function. + auto CloneInstructions = [&](ArrayRef<Value *> Ops, bool IsSigned) { + SmallVector<Value *, 4> WorkList(Ops); + SmallVector<Instruction *> ToClone; + auto &Value2Index = Info.getValue2Index(IsSigned); + while (!WorkList.empty()) { + Value *V = WorkList.pop_back_val(); + if (Old2New.find(V) != Old2New.end()) + continue; + + auto *I = dyn_cast<Instruction>(V); + if (!Value2Index.contains(V) && I) { + Old2New[V] = nullptr; + ToClone.push_back(I); + append_range(WorkList, I->operands()); + } + } + + sort(ToClone, + [&DT](Instruction *A, Instruction *B) { return DT.dominates(A, B); }); + for (Instruction *I : ToClone) { + Instruction *Cloned = I->clone(); + Old2New[I] = Cloned; + Old2New[I]->setName(I->getName()); + Cloned->insertBefore(&*Builder.GetInsertPoint()); + Cloned->dropUnknownNonDebugMetadata(); + Cloned->setDebugLoc({}); + } + }; + + // Materialize the assumptions for the reproducer using the entries in Stack. + // That is, first clone the operands of the condition recursively until we + // reach an external input to the reproducer and add them to the reproducer + // function. Then add an ICmp for the condition (with the inverse predicate if + // the entry is negated) and an assert using the ICmp. 
+ for (auto &Entry : Stack) { + if (Entry.Pred == ICmpInst::BAD_ICMP_PREDICATE) + continue; + + LLVM_DEBUG( + dbgs() << " Materializing assumption icmp " << Entry.Pred << ' '; + Entry.LHS->printAsOperand(dbgs(), /*PrintType=*/true); dbgs() << ", "; + Entry.RHS->printAsOperand(dbgs(), /*PrintType=*/false); dbgs() << "\n"); + CloneInstructions({Entry.LHS, Entry.RHS}, CmpInst::isSigned(Entry.Pred)); + + auto *Cmp = Builder.CreateICmp(Entry.Pred, Entry.LHS, Entry.RHS); + Builder.CreateAssumption(Cmp); + } + + // Finally, clone the condition to reproduce and remap instruction operands in + // the reproducer using Old2New. + CloneInstructions(Cond, CmpInst::isSigned(Cond->getPredicate())); + Entry->getTerminator()->setOperand(0, Cond); + remapInstructionsInBlocks({Entry}, Old2New); + + assert(!verifyFunction(*F, &dbgs())); +} + +static std::optional<bool> checkCondition(CmpInst *Cmp, ConstraintInfo &Info, + unsigned NumIn, unsigned NumOut, + Instruction *ContextInst) { LLVM_DEBUG(dbgs() << "Checking " << *Cmp << "\n"); CmpInst::Predicate Pred = Cmp->getPredicate(); @@ -758,7 +1030,7 @@ static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) { auto R = Info.getConstraintForSolving(Pred, A, B); if (R.empty() || !R.isValid(Info)){ LLVM_DEBUG(dbgs() << " failed to decompose condition\n"); - return false; + return std::nullopt; } auto &CSToUse = Info.getCS(R.IsSigned); @@ -773,39 +1045,107 @@ static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) { CSToUse.popLastConstraint(); }); - bool Changed = false; - if (CSToUse.isConditionImplied(R.Coefficients)) { + if (auto ImpliedCondition = R.isImpliedBy(CSToUse)) { if (!DebugCounter::shouldExecute(EliminatedCounter)) - return false; + return std::nullopt; LLVM_DEBUG({ - dbgs() << "Condition " << *Cmp << " implied by dominating constraints\n"; - dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); + if (*ImpliedCondition) { + dbgs() << "Condition " << *Cmp; + } else { + auto InversePred = Cmp->getInversePredicate(); + dbgs() << "Condition " << CmpInst::getPredicateName(InversePred) << " " + << *A << ", " << *B; + } + dbgs() << " implied by dominating constraints\n"; + CSToUse.dump(); }); - Constant *TrueC = - ConstantInt::getTrue(CmpInst::makeCmpResultType(Cmp->getType())); - Cmp->replaceUsesWithIf(TrueC, [](Use &U) { + return ImpliedCondition; + } + + return std::nullopt; +} + +static bool checkAndReplaceCondition( + CmpInst *Cmp, ConstraintInfo &Info, unsigned NumIn, unsigned NumOut, + Instruction *ContextInst, Module *ReproducerModule, + ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT) { + auto ReplaceCmpWithConstant = [&](CmpInst *Cmp, bool IsTrue) { + generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT); + Constant *ConstantC = ConstantInt::getBool( + CmpInst::makeCmpResultType(Cmp->getType()), IsTrue); + Cmp->replaceUsesWithIf(ConstantC, [&DT, NumIn, NumOut, + ContextInst](Use &U) { + auto *UserI = getContextInstForUse(U); + auto *DTN = DT.getNode(UserI->getParent()); + if (!DTN || DTN->getDFSNumIn() < NumIn || DTN->getDFSNumOut() > NumOut) + return false; + if (UserI->getParent() == ContextInst->getParent() && + UserI->comesBefore(ContextInst)) + return false; + // Conditions in an assume trivially simplify to true. Skip uses // in assume calls to not destroy the available information. 
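The replaceUsesWithIf filter above limits the rewrite to uses whose context instruction is dominated by the point where the fact holds, and it tests that with dominator-tree DFS in/out numbers rather than a full DT.dominates call. The interval-containment property it relies on, spelled out (valid only after DT.updateDFSNumbers(), which eliminateConstraints runs at its start in a later hunk):

// Sketch: constant-time "X dominates Y" via DFS interval containment.
static bool dominatesByDFS(const DomTreeNode *X, const DomTreeNode *Y) {
  return X->getDFSNumIn() <= Y->getDFSNumIn() &&
         Y->getDFSNumOut() <= X->getDFSNumOut();
}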
auto *II = dyn_cast<IntrinsicInst>(U.getUser()); return !II || II->getIntrinsicID() != Intrinsic::assume; }); NumCondsRemoved++; + return true; + }; + + if (auto ImpliedCondition = + checkCondition(Cmp, Info, NumIn, NumOut, ContextInst)) + return ReplaceCmpWithConstant(Cmp, *ImpliedCondition); + return false; +} + +static void +removeEntryFromStack(const StackEntry &E, ConstraintInfo &Info, + Module *ReproducerModule, + SmallVectorImpl<ReproducerEntry> &ReproducerCondStack, + SmallVectorImpl<StackEntry> &DFSInStack) { + Info.popLastConstraint(E.IsSigned); + // Remove variables in the system that went out of scope. + auto &Mapping = Info.getValue2Index(E.IsSigned); + for (Value *V : E.ValuesToRelease) + Mapping.erase(V); + Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size()); + DFSInStack.pop_back(); + if (ReproducerModule) + ReproducerCondStack.pop_back(); +} + +/// Check if the first condition for an AND implies the second. +static bool checkAndSecondOpImpliedByFirst( + FactOrCheck &CB, ConstraintInfo &Info, Module *ReproducerModule, + SmallVectorImpl<ReproducerEntry> &ReproducerCondStack, + SmallVectorImpl<StackEntry> &DFSInStack) { + CmpInst::Predicate Pred; + Value *A, *B; + Instruction *And = CB.getContextInst(); + if (!match(And->getOperand(0), m_ICmp(Pred, m_Value(A), m_Value(B)))) + return false; + + // Optimistically add fact from first condition. + unsigned OldSize = DFSInStack.size(); + Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); + if (OldSize == DFSInStack.size()) + return false; + + bool Changed = false; + // Check if the second condition can be simplified now. + if (auto ImpliedCondition = + checkCondition(cast<ICmpInst>(And->getOperand(1)), Info, CB.NumIn, + CB.NumOut, CB.getContextInst())) { + And->setOperand(1, ConstantInt::getBool(And->getType(), *ImpliedCondition)); Changed = true; } - if (CSToUse.isConditionImplied(ConstraintSystem::negate(R.Coefficients))) { - if (!DebugCounter::shouldExecute(EliminatedCounter)) - return false; - LLVM_DEBUG({ - dbgs() << "Condition !" << *Cmp << " implied by dominating constraints\n"; - dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); - }); - Constant *FalseC = - ConstantInt::getFalse(CmpInst::makeCmpResultType(Cmp->getType())); - Cmp->replaceAllUsesWith(FalseC); - NumCondsRemoved++; - Changed = true; + // Remove entries again. + while (OldSize < DFSInStack.size()) { + StackEntry E = DFSInStack.back(); + removeEntryFromStack(E, Info, ReproducerModule, ReproducerCondStack, + DFSInStack); } return Changed; } @@ -817,10 +1157,12 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B, // hold. SmallVector<Value *> NewVariables; auto R = getConstraint(Pred, A, B, NewVariables); - if (!R.isValid(*this)) + + // TODO: Support non-equality for facts as well. + if (!R.isValid(*this) || R.isNe()) return; - LLVM_DEBUG(dbgs() << "Adding '" << CmpInst::getPredicateName(Pred) << " "; + LLVM_DEBUG(dbgs() << "Adding '" << Pred << " "; A->printAsOperand(dbgs(), false); dbgs() << ", "; B->printAsOperand(dbgs(), false); dbgs() << "'\n"); bool Added = false; @@ -842,14 +1184,14 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B, LLVM_DEBUG({ dbgs() << " constraint: "; - dumpWithNames(R.Coefficients, getValue2Index(R.IsSigned)); + dumpConstraint(R.Coefficients, getValue2Index(R.IsSigned)); dbgs() << "\n"; }); DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, std::move(ValuesToRelease)); - if (R.IsEq) { + if (R.isEq()) { // Also add the inverted constraint for equality constraints. 
for (auto &Coeff : R.Coefficients) Coeff *= -1; @@ -921,12 +1263,17 @@ tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info, return Changed; } -static bool eliminateConstraints(Function &F, DominatorTree &DT) { +static bool eliminateConstraints(Function &F, DominatorTree &DT, + OptimizationRemarkEmitter &ORE) { bool Changed = false; DT.updateDFSNumbers(); - - ConstraintInfo Info(F.getParent()->getDataLayout()); + SmallVector<Value *> FunctionArgs; + for (Value &Arg : F.args()) + FunctionArgs.push_back(&Arg); + ConstraintInfo Info(F.getParent()->getDataLayout(), FunctionArgs); State S(DT); + std::unique_ptr<Module> ReproducerModule( + DumpReproducers ? new Module(F.getName(), F.getContext()) : nullptr); // First, collect conditions implied by branches and blocks with their // Dominator DFS in and out numbers. @@ -961,7 +1308,9 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { return true; if (B.isConditionFact()) return false; - return A.Inst->comesBefore(B.Inst); + auto *InstA = A.getContextInst(); + auto *InstB = B.getContextInst(); + return InstA->comesBefore(InstB); } return A.NumIn < B.NumIn; }); @@ -970,6 +1319,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { // Finally, process ordered worklist and eliminate implied conditions. SmallVector<StackEntry, 16> DFSInStack; + SmallVector<ReproducerEntry> ReproducerCondStack; for (FactOrCheck &CB : S.WorkList) { // First, pop entries from the stack that are out-of-scope for CB. Remove // the corresponding entry from the constraint system. @@ -983,61 +1333,96 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { break; LLVM_DEBUG({ dbgs() << "Removing "; - dumpWithNames(Info.getCS(E.IsSigned).getLastConstraint(), - Info.getValue2Index(E.IsSigned)); + dumpConstraint(Info.getCS(E.IsSigned).getLastConstraint(), + Info.getValue2Index(E.IsSigned)); dbgs() << "\n"; }); - - Info.popLastConstraint(E.IsSigned); - // Remove variables in the system that went out of scope. - auto &Mapping = Info.getValue2Index(E.IsSigned); - for (Value *V : E.ValuesToRelease) - Mapping.erase(V); - Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size()); - DFSInStack.pop_back(); + removeEntryFromStack(E, Info, ReproducerModule.get(), ReproducerCondStack, + DFSInStack); } - LLVM_DEBUG({ - dbgs() << "Processing "; - if (CB.IsCheck) - dbgs() << "condition to simplify: " << *CB.Inst; - else - dbgs() << "fact to add to the system: " << *CB.Inst; - dbgs() << "\n"; - }); + LLVM_DEBUG(dbgs() << "Processing "); // For a block, check if any CmpInsts become known based on the current set // of constraints. 
- if (CB.IsCheck) { - if (auto *II = dyn_cast<WithOverflowInst>(CB.Inst)) { + if (CB.isCheck()) { + Instruction *Inst = CB.getInstructionToSimplify(); + if (!Inst) + continue; + LLVM_DEBUG(dbgs() << "condition to simplify: " << *Inst << "\n"); + if (auto *II = dyn_cast<WithOverflowInst>(Inst)) { Changed |= tryToSimplifyOverflowMath(II, Info, ToRemove); - } else if (auto *Cmp = dyn_cast<ICmpInst>(CB.Inst)) { - Changed |= checkAndReplaceCondition(Cmp, Info); + } else if (auto *Cmp = dyn_cast<ICmpInst>(Inst)) { + bool Simplified = checkAndReplaceCondition( + Cmp, Info, CB.NumIn, CB.NumOut, CB.getContextInst(), + ReproducerModule.get(), ReproducerCondStack, S.DT); + if (!Simplified && match(CB.getContextInst(), + m_LogicalAnd(m_Value(), m_Specific(Inst)))) { + Simplified = + checkAndSecondOpImpliedByFirst(CB, Info, ReproducerModule.get(), + ReproducerCondStack, DFSInStack); + } + Changed |= Simplified; } continue; } - ICmpInst::Predicate Pred; - Value *A, *B; - Value *Cmp = CB.Inst; - match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp))); - if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) { + LLVM_DEBUG(dbgs() << "fact to add to the system: " << *CB.Inst << "\n"); + auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) { if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) { LLVM_DEBUG( dbgs() << "Skip adding constraint because system has too many rows.\n"); - continue; + return; + } + + Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); + if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size()) + ReproducerCondStack.emplace_back(Pred, A, B); + + Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); + if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size()) { + // Add dummy entries to ReproducerCondStack to keep it in sync with + // DFSInStack. + for (unsigned I = 0, + E = (DFSInStack.size() - ReproducerCondStack.size()); + I < E; ++I) { + ReproducerCondStack.emplace_back(ICmpInst::BAD_ICMP_PREDICATE, + nullptr, nullptr); + } } + }; + ICmpInst::Predicate Pred; + if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) { + Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); + AddFact(Pred, MinMax, MinMax->getLHS()); + AddFact(Pred, MinMax, MinMax->getRHS()); + continue; + } + + Value *A, *B; + Value *Cmp = CB.Inst; + match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp))); + if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) { // Use the inverse predicate if required. 
if (CB.Not) Pred = CmpInst::getInversePredicate(Pred); - Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); - Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); + AddFact(Pred, A, B); } } + if (ReproducerModule && !ReproducerModule->functions().empty()) { + std::string S; + raw_string_ostream StringS(S); + ReproducerModule->print(StringS, nullptr); + StringS.flush(); + OptimizationRemark Rem(DEBUG_TYPE, "Reproducer", &F); + Rem << ore::NV("module") << S; + ORE.emit(Rem); + } + #ifndef NDEBUG unsigned SignedEntries = count_if(DFSInStack, [](const StackEntry &E) { return E.IsSigned; }); @@ -1055,7 +1440,8 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { PreservedAnalyses ConstraintEliminationPass::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - if (!eliminateConstraints(F, DT)) + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + if (!eliminateConstraints(F, DT, ORE)) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 90b4b521e7de..48b27a1ea0a2 100644 --- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -36,11 +36,8 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <optional> @@ -97,60 +94,33 @@ STATISTIC(NumMinMax, "Number of llvm.[us]{min,max} intrinsics removed"); STATISTIC(NumUDivURemsNarrowedExpanded, "Number of bound udiv's/urem's expanded"); -namespace { - - class CorrelatedValuePropagation : public FunctionPass { - public: - static char ID; - - CorrelatedValuePropagation(): FunctionPass(ID) { - initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LazyValueInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<LazyValueInfoWrapperPass>(); - } - }; - -} // end anonymous namespace - -char CorrelatedValuePropagation::ID = 0; - -INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation", - "Value Propagation", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) -INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation", - "Value Propagation", false, false) - -// Public interface to the Value Propagation pass -Pass *llvm::createCorrelatedValuePropagationPass() { - return new CorrelatedValuePropagation(); -} - static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { - if (S->getType()->isVectorTy()) return false; - if (isa<Constant>(S->getCondition())) return false; - - Constant *C = LVI->getConstant(S->getCondition(), S); - if (!C) return false; + if (S->getType()->isVectorTy() || isa<Constant>(S->getCondition())) + return false; - ConstantInt *CI = dyn_cast<ConstantInt>(C); - if (!CI) return false; + bool Changed = false; + for (Use &U : make_early_inc_range(S->uses())) { + auto *I = cast<Instruction>(U.getUser()); 
+ Constant *C; + if (auto *PN = dyn_cast<PHINode>(I)) + C = LVI->getConstantOnEdge(S->getCondition(), PN->getIncomingBlock(U), + I->getParent(), I); + else + C = LVI->getConstant(S->getCondition(), I); + + auto *CI = dyn_cast_or_null<ConstantInt>(C); + if (!CI) + continue; - Value *ReplaceWith = CI->isOne() ? S->getTrueValue() : S->getFalseValue(); - S->replaceAllUsesWith(ReplaceWith); - S->eraseFromParent(); + U.set(CI->isOne() ? S->getTrueValue() : S->getFalseValue()); + Changed = true; + ++NumSelects; + } - ++NumSelects; + if (Changed && S->use_empty()) + S->eraseFromParent(); - return true; + return Changed; } /// Try to simplify a phi with constant incoming values that match the edge @@ -698,7 +668,7 @@ enum class Domain { NonNegative, NonPositive, Unknown }; static Domain getDomain(const ConstantRange &CR) { if (CR.isAllNonNegative()) return Domain::NonNegative; - if (CR.icmp(ICmpInst::ICMP_SLE, APInt::getNullValue(CR.getBitWidth()))) + if (CR.icmp(ICmpInst::ICMP_SLE, APInt::getZero(CR.getBitWidth()))) return Domain::NonPositive; return Domain::Unknown; } @@ -717,7 +687,6 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR, // What is the smallest bit width that can accommodate the entire value ranges // of both of the operands? - std::array<std::optional<ConstantRange>, 2> CRs; unsigned MinSignedBits = std::max(LCR.getMinSignedBits(), RCR.getMinSignedBits()); @@ -804,10 +773,18 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, IRBuilder<> B(Instr); Value *ExpandedOp; - if (IsRem) { + if (XCR.icmp(ICmpInst::ICMP_UGE, YCR)) { + // If X is between Y and 2*Y the result is known. + if (IsRem) + ExpandedOp = B.CreateNUWSub(X, Y); + else + ExpandedOp = ConstantInt::get(Instr->getType(), 1); + } else if (IsRem) { // NOTE: this transformation introduces two uses of X, // but it may be undef so we must freeze it first. 
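// A small standalone check (plain unsigned integers, not ConstantRanges) of
// the fact relied on a few lines above: whenever Y <= X < 2*Y, the quotient
// X / Y is exactly 1 and the remainder X % Y equals X - Y, so the expensive
// division can be replaced by a compare and a subtraction.
#include <cassert>
int main() {
  for (unsigned Y = 1; Y < 100; ++Y)
    for (unsigned X = Y; X < 2 * Y; ++X) {
      assert(X / Y == 1);      // the quotient collapses to the constant 1
      assert(X % Y == X - Y);  // the remainder becomes a single subtraction
    }
  return 0;
}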
- Value *FrozenX = B.CreateFreeze(X, X->getName() + ".frozen"); + Value *FrozenX = X; + if (!isGuaranteedNotToBeUndefOrPoison(X)) + FrozenX = B.CreateFreeze(X, X->getName() + ".frozen"); auto *AdjX = B.CreateNUWSub(FrozenX, Y, Instr->getName() + ".urem"); auto *Cmp = B.CreateICmp(ICmpInst::ICMP_ULT, FrozenX, Y, Instr->getName() + ".cmp"); @@ -1008,7 +985,8 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { if (SDI->getType()->isVectorTy()) return false; - ConstantRange LRange = LVI->getConstantRangeAtUse(SDI->getOperandUse(0)); + ConstantRange LRange = + LVI->getConstantRangeAtUse(SDI->getOperandUse(0), /*UndefAllowed*/ false); unsigned OrigWidth = SDI->getType()->getIntegerBitWidth(); ConstantRange NegOneOrZero = ConstantRange(APInt(OrigWidth, (uint64_t)-1, true), APInt(OrigWidth, 1)); @@ -1040,7 +1018,8 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) { return false; const Use &Base = SDI->getOperandUse(0); - if (!LVI->getConstantRangeAtUse(Base).isAllNonNegative()) + if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false) + .isAllNonNegative()) return false; ++NumSExt; @@ -1222,16 +1201,6 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, return FnChanged; } -bool CorrelatedValuePropagation::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - - LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - - return runImpl(F, LVI, DT, getBestSimplifyQuery(*this, F)); -} - PreservedAnalyses CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) { LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F); diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 658d0fcb53fa..f2efe60bdf88 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -70,11 +70,8 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/SSAUpdaterBulk.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -168,51 +165,8 @@ private: OptimizationRemarkEmitter *ORE; }; -class DFAJumpThreadingLegacyPass : public FunctionPass { -public: - static char ID; // Pass identification - DFAJumpThreadingLegacyPass() : FunctionPass(ID) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - AssumptionCache *AC = - &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - TargetTransformInfo *TTI = - &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - OptimizationRemarkEmitter *ORE = - &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - - return DFAJumpThreading(AC, DT, TTI, ORE).run(F); - } -}; } // end anonymous namespace -char DFAJumpThreadingLegacyPass::ID = 0; 
-INITIALIZE_PASS_BEGIN(DFAJumpThreadingLegacyPass, "dfa-jump-threading", - "DFA Jump Threading", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_END(DFAJumpThreadingLegacyPass, "dfa-jump-threading", - "DFA Jump Threading", false, false) - -// Public interface to the DFA Jump Threading pass -FunctionPass *llvm::createDFAJumpThreadingPass() { - return new DFAJumpThreadingLegacyPass(); -} - namespace { /// Create a new basic block and sink \p SIToSink into it. @@ -625,7 +579,7 @@ private: continue; PathsType SuccPaths = paths(Succ, Visited, PathDepth + 1); - for (PathType Path : SuccPaths) { + for (const PathType &Path : SuccPaths) { PathType NewPath(Path); NewPath.push_front(BB); Res.push_back(NewPath); @@ -978,7 +932,7 @@ private: SSAUpdaterBulk SSAUpdate; SmallVector<Use *, 16> UsesToRename; - for (auto KV : NewDefs) { + for (const auto &KV : NewDefs) { Instruction *I = KV.first; BasicBlock *BB = I->getParent(); std::vector<Instruction *> Cloned = KV.second; diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index 9c0b4d673145..d3fbe49439a8 100644 --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -69,15 +69,12 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" @@ -462,10 +459,10 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI, "Should not hit the entry block because SI must be dominated by LI"); for (BasicBlock *Pred : predecessors(B)) { PHITransAddr PredAddr = Addr; - if (PredAddr.NeedsPHITranslationFromBlock(B)) { - if (!PredAddr.IsPotentiallyPHITranslatable()) + if (PredAddr.needsPHITranslationFromBlock(B)) { + if (!PredAddr.isPotentiallyPHITranslatable()) return false; - if (PredAddr.PHITranslateValue(B, Pred, DT, false)) + if (!PredAddr.translateValue(B, Pred, DT, false)) return false; } Value *TranslatedPtr = PredAddr.getAddr(); @@ -485,41 +482,75 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI, return true; } -static void shortenAssignment(Instruction *Inst, uint64_t OldOffsetInBits, - uint64_t OldSizeInBits, uint64_t NewSizeInBits, - bool IsOverwriteEnd) { - DIExpression::FragmentInfo DeadFragment; - DeadFragment.SizeInBits = OldSizeInBits - NewSizeInBits; - DeadFragment.OffsetInBits = +static void shortenAssignment(Instruction *Inst, Value *OriginalDest, + uint64_t OldOffsetInBits, uint64_t OldSizeInBits, + uint64_t NewSizeInBits, bool IsOverwriteEnd) { + const DataLayout &DL = Inst->getModule()->getDataLayout(); + uint64_t DeadSliceSizeInBits = OldSizeInBits - NewSizeInBits; + uint64_t DeadSliceOffsetInBits = OldOffsetInBits + (IsOverwriteEnd ? 
NewSizeInBits : 0); - - auto CreateDeadFragExpr = [Inst, DeadFragment]() { - // FIXME: This should be using the DIExpression in the Alloca's dbg.assign - // for the variable, since that could also contain a fragment? - return *DIExpression::createFragmentExpression( - DIExpression::get(Inst->getContext(), std::nullopt), + auto SetDeadFragExpr = [](DbgAssignIntrinsic *DAI, + DIExpression::FragmentInfo DeadFragment) { + // createFragmentExpression expects an offset relative to the existing + // fragment offset if there is one. + uint64_t RelativeOffset = DeadFragment.OffsetInBits - + DAI->getExpression() + ->getFragmentInfo() + .value_or(DIExpression::FragmentInfo(0, 0)) + .OffsetInBits; + if (auto NewExpr = DIExpression::createFragmentExpression( + DAI->getExpression(), RelativeOffset, DeadFragment.SizeInBits)) { + DAI->setExpression(*NewExpr); + return; + } + // Failed to create a fragment expression for this so discard the value, + // making this a kill location. + auto *Expr = *DIExpression::createFragmentExpression( + DIExpression::get(DAI->getContext(), std::nullopt), DeadFragment.OffsetInBits, DeadFragment.SizeInBits); + DAI->setExpression(Expr); + DAI->setKillLocation(); }; // A DIAssignID to use so that the inserted dbg.assign intrinsics do not // link to any instructions. Created in the loop below (once). DIAssignID *LinkToNothing = nullptr; + LLVMContext &Ctx = Inst->getContext(); + auto GetDeadLink = [&Ctx, &LinkToNothing]() { + if (!LinkToNothing) + LinkToNothing = DIAssignID::getDistinct(Ctx); + return LinkToNothing; + }; // Insert an unlinked dbg.assign intrinsic for the dead fragment after each - // overlapping dbg.assign intrinsic. - for (auto *DAI : at::getAssignmentMarkers(Inst)) { - if (auto FragInfo = DAI->getExpression()->getFragmentInfo()) { - if (!DIExpression::fragmentsOverlap(*FragInfo, DeadFragment)) - continue; + // overlapping dbg.assign intrinsic. The loop invalidates the iterators + // returned by getAssignmentMarkers so save a copy of the markers to iterate + // over. + auto LinkedRange = at::getAssignmentMarkers(Inst); + SmallVector<DbgAssignIntrinsic *> Linked(LinkedRange.begin(), + LinkedRange.end()); + for (auto *DAI : Linked) { + std::optional<DIExpression::FragmentInfo> NewFragment; + if (!at::calculateFragmentIntersect(DL, OriginalDest, DeadSliceOffsetInBits, + DeadSliceSizeInBits, DAI, + NewFragment) || + !NewFragment) { + // We couldn't calculate the intersecting fragment for some reason. Be + // cautious and unlink the whole assignment from the store. + DAI->setKillAddress(); + DAI->setAssignId(GetDeadLink()); + continue; } + // No intersect. + if (NewFragment->SizeInBits == 0) + continue; // Fragments overlap: insert a new dbg.assign for this dead part. 
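// A small standalone sketch (plain integers; it does not reproduce the exact
// semantics of calculateFragmentIntersect) of the bit-range bookkeeping in
// shortenAssignment above: the dead slice is intersected with an existing
// fragment, an empty intersection is skipped, and a non-empty one is
// re-expressed relative to the fragment's own start before a new fragment
// expression is built.
#include <algorithm>
#include <cassert>
#include <cstdint>
struct BitSlice { uint64_t OffsetInBits, SizeInBits; };
static BitSlice intersect(BitSlice A, BitSlice B) {
  uint64_t Lo = std::max(A.OffsetInBits, B.OffsetInBits);
  uint64_t Hi = std::min(A.OffsetInBits + A.SizeInBits,
                         B.OffsetInBits + B.SizeInBits);
  return {Lo, Hi > Lo ? Hi - Lo : 0};  // SizeInBits == 0 means "no overlap"
}
int main() {
  BitSlice DeadSlice = {32, 32}, Fragment = {0, 48};
  BitSlice Overlap = intersect(DeadSlice, Fragment);
  assert(Overlap.OffsetInBits == 32 && Overlap.SizeInBits == 16);
  uint64_t RelativeOffset = Overlap.OffsetInBits - Fragment.OffsetInBits;
  assert(RelativeOffset == 32);        // offset expressed inside the fragment
  return 0;
}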
auto *NewAssign = cast<DbgAssignIntrinsic>(DAI->clone()); NewAssign->insertAfter(DAI); - if (!LinkToNothing) - LinkToNothing = DIAssignID::getDistinct(Inst->getContext()); - NewAssign->setAssignId(LinkToNothing); - NewAssign->setExpression(CreateDeadFragExpr()); + NewAssign->setAssignId(GetDeadLink()); + if (NewFragment) + SetDeadFragExpr(NewAssign, *NewFragment); NewAssign->setKillAddress(); } } @@ -596,8 +627,8 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, DeadIntrinsic->setLength(TrimmedLength); DeadIntrinsic->setDestAlignment(PrefAlign); + Value *OrigDest = DeadIntrinsic->getRawDest(); if (!IsOverwriteEnd) { - Value *OrigDest = DeadIntrinsic->getRawDest(); Type *Int8PtrTy = Type::getInt8PtrTy(DeadIntrinsic->getContext(), OrigDest->getType()->getPointerAddressSpace()); @@ -616,7 +647,7 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, } // Update attached dbg.assign intrinsics. Assume 8-bit byte. - shortenAssignment(DeadI, DeadStart * 8, DeadSize * 8, NewSize * 8, + shortenAssignment(DeadI, OrigDest, DeadStart * 8, DeadSize * 8, NewSize * 8, IsOverwriteEnd); // Finally update start and size of dead access. @@ -730,7 +761,7 @@ tryToMergePartialOverlappingStores(StoreInst *KillingI, StoreInst *DeadI, } namespace { -// Returns true if \p I is an intrisnic that does not read or write memory. +// Returns true if \p I is an intrinsic that does not read or write memory. bool isNoopIntrinsic(Instruction *I) { if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { @@ -740,7 +771,6 @@ bool isNoopIntrinsic(Instruction *I) { case Intrinsic::launder_invariant_group: case Intrinsic::assume: return true; - case Intrinsic::dbg_addr: case Intrinsic::dbg_declare: case Intrinsic::dbg_label: case Intrinsic::dbg_value: @@ -2039,7 +2069,6 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, const LoopInfo &LI) { bool MadeChange = false; - MSSA.ensureOptimizedUses(); DSEState State(F, AA, MSSA, DT, PDT, AC, TLI, LI); // For each store: for (unsigned I = 0; I < State.MemDefs.size(); I++) { @@ -2241,79 +2270,3 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<LoopAnalysis>(); return PA; } - -namespace { - -/// A legacy pass for the legacy pass manager that wraps \c DSEPass. 
-class DSELegacyPass : public FunctionPass { -public: - static char ID; // Pass identification, replacement for typeid - - DSELegacyPass() : FunctionPass(ID) { - initializeDSELegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); - PostDominatorTree &PDT = - getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - AssumptionCache &AC = - getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - - bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, AC, TLI, LI); - -#ifdef LLVM_ENABLE_STATS - if (AreStatisticsEnabled()) - for (auto &I : instructions(F)) - NumRemainingStores += isa<StoreInst>(&I); -#endif - - return Changed; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<PostDominatorTreeWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); - } -}; - -} // end anonymous namespace - -char DSELegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false, - false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false, - false) - -FunctionPass *llvm::createDeadStoreEliminationPass() { - return new DSELegacyPass(); -} diff --git a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp index 303951643a0b..57d3f312186e 100644 --- a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp @@ -21,10 +21,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/DebugCounter.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BypassSlowDivision.h" #include <optional> @@ -371,6 +368,10 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, Mul->insertAfter(RemInst); Sub->insertAfter(Mul); + // If DivInst has the exact flag, remove it. Otherwise this optimization + // may replace a well-defined value 'X % Y' with poison. + DivInst->dropPoisonGeneratingFlags(); + // If X can be undef, X should be frozen first. 
// For example, let's assume that Y = 1 & X = undef: // %div = sdiv undef, 1 // %div = undef @@ -413,44 +414,6 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, // Pass manager boilerplate below here. -namespace { -struct DivRemPairsLegacyPass : public FunctionPass { - static char ID; - DivRemPairsLegacyPass() : FunctionPass(ID) { - initializeDivRemPairsLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.setPreservesCFG(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - FunctionPass::getAnalysisUsage(AU); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - return optimizeDivRem(F, TTI, DT); - } -}; -} // namespace - -char DivRemPairsLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(DivRemPairsLegacyPass, "div-rem-pairs", - "Hoist/decompose integer division and remainder", false, - false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(DivRemPairsLegacyPass, "div-rem-pairs", - "Hoist/decompose integer division and remainder", false, - false) -FunctionPass *llvm::createDivRemPairsPass() { - return new DivRemPairsLegacyPass(); -} - PreservedAnalyses DivRemPairsPass::run(Function &F, FunctionAnalysisManager &FAM) { TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 26821c7ee81e..67e8e82e408f 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -218,6 +218,19 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A, return true; } +static unsigned hashCallInst(CallInst *CI) { + // Don't CSE convergent calls in different basic blocks, because they + // implicitly depend on the set of threads that is currently executing. + if (CI->isConvergent()) { + return hash_combine( + CI->getOpcode(), CI->getParent(), + hash_combine_range(CI->value_op_begin(), CI->value_op_end())); + } + return hash_combine( + CI->getOpcode(), + hash_combine_range(CI->value_op_begin(), CI->value_op_end())); +} + static unsigned getHashValueImpl(SimpleValue Val) { Instruction *Inst = Val.Inst; // Hash in all of the operands as pointers. @@ -318,6 +331,11 @@ static unsigned getHashValueImpl(SimpleValue Val) { return hash_combine(GCR->getOpcode(), GCR->getOperand(0), GCR->getBasePtr(), GCR->getDerivedPtr()); + // Don't CSE convergent calls in different basic blocks, because they + // implicitly depend on the set of threads that is currently executing. + if (CallInst *CI = dyn_cast<CallInst>(Inst)) + return hashCallInst(CI); + // Mix in the opcode. return hash_combine( Inst->getOpcode(), @@ -344,8 +362,16 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) { if (LHSI->getOpcode() != RHSI->getOpcode()) return false; - if (LHSI->isIdenticalToWhenDefined(RHSI)) + if (LHSI->isIdenticalToWhenDefined(RHSI)) { + // Convergent calls implicitly depend on the set of threads that is + // currently executing, so conservatively return false if they are in + // different basic blocks. 
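// A minimal standalone sketch (plain structs, not llvm::CallInst; the hash
// mixing is only illustrative) of the rule enforced here and in hashCallInst
// above: for convergent calls the parent block participates in both hashing
// and equality, so two otherwise identical convergent calls in different
// blocks are never treated as the same value.
#include <cassert>
#include <cstddef>
#include <functional>
struct SketchCall { int Opcode; int BlockId; bool Convergent; };
static std::size_t sketchHash(const SketchCall &C) {
  std::size_t H = std::hash<int>()(C.Opcode);
  if (C.Convergent)                 // the block only matters when convergent
    H = H * 31 + std::hash<int>()(C.BlockId);
  return H;
}
static bool sketchEqual(const SketchCall &A, const SketchCall &B) {
  if (A.Opcode != B.Opcode || A.Convergent != B.Convergent)
    return false;
  return !A.Convergent || A.BlockId == B.BlockId;
}
int main() {
  SketchCall A{1, /*BlockId=*/0, /*Convergent=*/true};
  SketchCall B{1, /*BlockId=*/1, /*Convergent=*/true};
  assert(!sketchEqual(A, B));       // identical convergent calls, different blocks
  (void)sketchHash(A);
  return 0;
}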
+ if (CallInst *CI = dyn_cast<CallInst>(LHSI); + CI && CI->isConvergent() && LHSI->getParent() != RHSI->getParent()) + return false; + return true; + } // If we're not strictly identical, we still might be a commutable instruction if (BinaryOperator *LHSBinOp = dyn_cast<BinaryOperator>(LHSI)) { @@ -508,15 +534,21 @@ unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) { Instruction *Inst = Val.Inst; // Hash all of the operands as pointers and mix in the opcode. - return hash_combine( - Inst->getOpcode(), - hash_combine_range(Inst->value_op_begin(), Inst->value_op_end())); + return hashCallInst(cast<CallInst>(Inst)); } bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) { - Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst; if (LHS.isSentinel() || RHS.isSentinel()) - return LHSI == RHSI; + return LHS.Inst == RHS.Inst; + + CallInst *LHSI = cast<CallInst>(LHS.Inst); + CallInst *RHSI = cast<CallInst>(RHS.Inst); + + // Convergent calls implicitly depend on the set of threads that is + // currently executing, so conservatively return false if they are in + // different basic blocks. + if (LHSI->isConvergent() && LHSI->getParent() != RHSI->getParent()) + return false; return LHSI->isIdenticalTo(RHSI); } @@ -578,12 +610,13 @@ public: unsigned Generation = 0; int MatchingId = -1; bool IsAtomic = false; + bool IsLoad = false; LoadValue() = default; LoadValue(Instruction *Inst, unsigned Generation, unsigned MatchingId, - bool IsAtomic) + bool IsAtomic, bool IsLoad) : DefInst(Inst), Generation(Generation), MatchingId(MatchingId), - IsAtomic(IsAtomic) {} + IsAtomic(IsAtomic), IsLoad(IsLoad) {} }; using LoadMapAllocator = @@ -802,17 +835,7 @@ private: Type *getValueType() const { // TODO: handle target-specific intrinsics. - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { - switch (II->getIntrinsicID()) { - case Intrinsic::masked_load: - return II->getType(); - case Intrinsic::masked_store: - return II->getArgOperand(0)->getType(); - default: - return nullptr; - } - } - return getLoadStoreType(Inst); + return Inst->getAccessType(); } bool mayReadFromMemory() const { @@ -1476,6 +1499,9 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } + if (InVal.IsLoad) + if (auto *I = dyn_cast<Instruction>(Op)) + combineMetadataForCSE(I, &Inst, false); if (!Inst.use_empty()) Inst.replaceAllUsesWith(Op); salvageKnowledge(&Inst, &AC); @@ -1490,7 +1516,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { AvailableLoads.insert(MemInst.getPointerOperand(), LoadValue(&Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic())); + MemInst.isAtomic(), + MemInst.isLoad())); LastStore = nullptr; continue; } @@ -1614,7 +1641,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { AvailableLoads.insert(MemInst.getPointerOperand(), LoadValue(&Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic())); + MemInst.isAtomic(), + MemInst.isLoad())); // Remember that this was the last unordered store we saw for DSE. 
We // don't yet handle DSE on ordered or volatile stores since we don't @@ -1710,10 +1738,10 @@ void EarlyCSEPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<EarlyCSEPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (UseMemorySSA) OS << "memssa"; - OS << ">"; + OS << '>'; } namespace { diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp index f66d1b914b0b..ccca8bcc1a56 100644 --- a/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -20,12 +20,9 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include <deque> #define DEBUG_TYPE "float2int" @@ -49,35 +46,6 @@ MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden, cl::desc("Max integer bitwidth to consider in float2int" "(default=64)")); -namespace { - struct Float2IntLegacyPass : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - Float2IntLegacyPass() : FunctionPass(ID) { - initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - return Impl.runImpl(F, DT); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } - - private: - Float2IntPass Impl; - }; -} - -char Float2IntLegacyPass::ID = 0; -INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false) - // Given a FCmp predicate, return a matching ICmp predicate if one // exists, otherwise return BAD_ICMP_PREDICATE. static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) { @@ -187,7 +155,7 @@ void Float2IntPass::walkBackwards() { Instruction *I = Worklist.back(); Worklist.pop_back(); - if (SeenInsts.find(I) != SeenInsts.end()) + if (SeenInsts.contains(I)) // Seen already. continue; @@ -371,7 +339,7 @@ bool Float2IntPass::validateAndTransform() { ConvertedToTy = I->getType(); for (User *U : I->users()) { Instruction *UI = dyn_cast<Instruction>(U); - if (!UI || SeenInsts.find(UI) == SeenInsts.end()) { + if (!UI || !SeenInsts.contains(UI)) { LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n"); Fail = true; break; @@ -391,8 +359,9 @@ bool Float2IntPass::validateAndTransform() { // The number of bits required is the maximum of the upper and // lower limits, plus one so it can be signed. - unsigned MinBW = std::max(R.getLower().getMinSignedBits(), - R.getUpper().getMinSignedBits()) + 1; + unsigned MinBW = std::max(R.getLower().getSignificantBits(), + R.getUpper().getSignificantBits()) + + 1; LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n"); // If we've run off the realms of the exactly representable integers, @@ -427,7 +396,7 @@ bool Float2IntPass::validateAndTransform() { } Value *Float2IntPass::convert(Instruction *I, Type *ToTy) { - if (ConvertedInsts.find(I) != ConvertedInsts.end()) + if (ConvertedInsts.contains(I)) // Already converted this instruction. 
return ConvertedInsts[I]; @@ -528,9 +497,6 @@ bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) { return Modified; } -namespace llvm { -FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); } - PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) { const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); if (!runImpl(F, DT)) @@ -540,4 +506,3 @@ PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserveSet<CFGAnalyses>(); return PA; } -} // End namespace llvm diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index 6158894e3437..03e8a2507b45 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -94,6 +94,8 @@ STATISTIC(NumGVNSimpl, "Number of instructions simplified"); STATISTIC(NumGVNEqProp, "Number of equalities propagated"); STATISTIC(NumPRELoad, "Number of loads PRE'd"); STATISTIC(NumPRELoopLoad, "Number of loop loads PRE'd"); +STATISTIC(NumPRELoadMoved2CEPred, + "Number of loads moved to predecessor of a critical edge in PRE"); STATISTIC(IsValueFullyAvailableInBlockNumSpeculationsMax, "Number of blocks speculated as available in " @@ -127,6 +129,11 @@ static cl::opt<uint32_t> MaxNumVisitedInsts( cl::desc("Max number of visited instructions when trying to find " "dominating value of select dependency (default = 100)")); +static cl::opt<uint32_t> MaxNumInsnsPerBlock( + "gvn-max-num-insns", cl::Hidden, cl::init(100), + cl::desc("Max number of instructions to scan in each basic block in GVN " + "(default = 100)")); + struct llvm::GVNPass::Expression { uint32_t opcode; bool commutative = false; @@ -416,10 +423,9 @@ GVNPass::Expression GVNPass::ValueTable::createGEPExpr(GetElementPtrInst *GEP) { unsigned BitWidth = DL.getIndexTypeSizeInBits(PtrTy); MapVector<Value *, APInt> VariableOffsets; APInt ConstantOffset(BitWidth, 0); - if (PtrTy->isOpaquePointerTy() && - GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) { - // For opaque pointers, convert into offset representation, to recognize - // equivalent address calculations that use different type encoding. + if (GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) { + // Convert into offset representation, to recognize equivalent address + // calculations that use different type encoding. LLVMContext &Context = GEP->getContext(); E.opcode = GEP->getOpcode(); E.type = nullptr; @@ -432,8 +438,8 @@ GVNPass::Expression GVNPass::ValueTable::createGEPExpr(GetElementPtrInst *GEP) { E.varargs.push_back( lookupOrAdd(ConstantInt::get(Context, ConstantOffset))); } else { - // If converting to offset representation fails (for typed pointers and - // scalable vectors), fall back to type-based implementation: + // If converting to offset representation fails (for scalable vectors), + // fall back to type-based implementation: E.opcode = GEP->getOpcode(); E.type = GEP->getSourceElementType(); for (Use &Op : GEP->operands()) @@ -461,28 +467,34 @@ void GVNPass::ValueTable::add(Value *V, uint32_t num) { } uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) { - if (AA->doesNotAccessMemory(C) && - // FIXME: Currently the calls which may access the thread id may - // be considered as not accessing the memory. But this is - // problematic for coroutines, since coroutines may resume in a - // different thread. So we disable the optimization here for the - // correctness. However, it may block many other correct - // optimizations. 
Revert this one when we detect the memory - // accessing kind more precisely. - !C->getFunction()->isPresplitCoroutine()) { + // FIXME: Currently the calls which may access the thread id may + // be considered as not accessing the memory. But this is + // problematic for coroutines, since coroutines may resume in a + // different thread. So we disable the optimization here for the + // correctness. However, it may block many other correct + // optimizations. Revert this one when we detect the memory + // accessing kind more precisely. + if (C->getFunction()->isPresplitCoroutine()) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + + // Do not combine convergent calls since they implicitly depend on the set of + // threads that is currently executing, and they might be in different basic + // blocks. + if (C->isConvergent()) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + + if (AA->doesNotAccessMemory(C)) { Expression exp = createExpr(C); uint32_t e = assignExpNewValueNum(exp).first; valueNumbering[C] = e; return e; - } else if (MD && AA->onlyReadsMemory(C) && - // FIXME: Currently the calls which may access the thread id may - // be considered as not accessing the memory. But this is - // problematic for coroutines, since coroutines may resume in a - // different thread. So we disable the optimization here for the - // correctness. However, it may block many other correct - // optimizations. Revert this one when we detect the memory - // accessing kind more precisely. - !C->getFunction()->isPresplitCoroutine()) { + } + + if (MD && AA->onlyReadsMemory(C)) { Expression exp = createExpr(C); auto ValNum = assignExpNewValueNum(exp); if (ValNum.second) { @@ -572,10 +584,10 @@ uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) { uint32_t v = lookupOrAdd(cdep); valueNumbering[C] = v; return v; - } else { - valueNumbering[C] = nextValueNumber; - return nextValueNumber++; } + + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; } /// Returns true if a value number exists for the specified value. @@ -708,10 +720,8 @@ void GVNPass::ValueTable::erase(Value *V) { /// verifyRemoved - Verify that the value is removed from all internal data /// structures. void GVNPass::ValueTable::verifyRemoved(const Value *V) const { - for (DenseMap<Value*, uint32_t>::const_iterator - I = valueNumbering.begin(), E = valueNumbering.end(); I != E; ++I) { - assert(I->first != V && "Inst still occurs in value numbering map!"); - } + assert(!valueNumbering.contains(V) && + "Inst still occurs in value numbering map!"); } //===----------------------------------------------------------------------===// @@ -772,7 +782,7 @@ void GVNPass::printPipeline( static_cast<PassInfoMixin<GVNPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (Options.AllowPRE != std::nullopt) OS << (*Options.AllowPRE ? "" : "no-") << "pre;"; if (Options.AllowLoadPRE != std::nullopt) @@ -782,7 +792,7 @@ void GVNPass::printPipeline( << "split-backedge-load-pre;"; if (Options.AllowMemDep != std::nullopt) OS << (*Options.AllowMemDep ? "" : "no-") << "memdep"; - OS << ">"; + OS << '>'; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -930,6 +940,18 @@ static bool IsValueFullyAvailableInBlock( return !UnavailableBB; } +/// If the specified OldValue exists in ValuesPerBlock, replace its value with +/// NewValue. 
+static void replaceValuesPerBlockEntry( + SmallVectorImpl<AvailableValueInBlock> &ValuesPerBlock, Value *OldValue, + Value *NewValue) { + for (AvailableValueInBlock &V : ValuesPerBlock) { + if ((V.AV.isSimpleValue() && V.AV.getSimpleValue() == OldValue) || + (V.AV.isCoercedLoadValue() && V.AV.getCoercedLoadValue() == OldValue)) + V = AvailableValueInBlock::get(V.BB, NewValue); + } +} + /// Given a set of loads specified by ValuesPerBlock, /// construct SSA form, allowing us to eliminate Load. This returns the value /// that should be used at Load's definition site. @@ -986,7 +1008,7 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, if (isSimpleValue()) { Res = getSimpleValue(); if (Res->getType() != LoadTy) { - Res = getStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL); + Res = getValueForLoad(Res, Offset, LoadTy, InsertPt, DL); LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset << " " << *getSimpleValue() << '\n' @@ -997,14 +1019,23 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, LoadInst *CoercedLoad = getCoercedLoadValue(); if (CoercedLoad->getType() == LoadTy && Offset == 0) { Res = CoercedLoad; + combineMetadataForCSE(CoercedLoad, Load, false); } else { - Res = getLoadValueForLoad(CoercedLoad, Offset, LoadTy, InsertPt, DL); - // We would like to use gvn.markInstructionForDeletion here, but we can't - // because the load is already memoized into the leader map table that GVN - // tracks. It is potentially possible to remove the load from the table, - // but then there all of the operations based on it would need to be - // rehashed. Just leave the dead load around. - gvn.getMemDep().removeInstruction(CoercedLoad); + Res = getValueForLoad(CoercedLoad, Offset, LoadTy, InsertPt, DL); + // We are adding a new user for this load, for which the original + // metadata may not hold. Additionally, the new load may have a different + // size and type, so their metadata cannot be combined in any + // straightforward way. + // Drop all metadata that is not known to cause immediate UB on violation, + // unless the load has !noundef, in which case all metadata violations + // will be promoted to UB. + // TODO: We can combine noalias/alias.scope metadata here, because it is + // independent of the load type. + if (!CoercedLoad->hasMetadata(LLVMContext::MD_noundef)) + CoercedLoad->dropUnknownNonDebugMetadata( + {LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_invariant_load, LLVMContext::MD_invariant_group}); LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset << " " << *getCoercedLoadValue() << '\n' << *Res << '\n' @@ -1314,9 +1345,67 @@ void GVNPass::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, "post condition violation"); } +/// Given the following code, v1 is partially available on some edges, but not +/// available on the edge from PredBB. This function tries to find if there is +/// another identical load in the other successor of PredBB. +/// +/// v0 = load %addr +/// br %LoadBB +/// +/// LoadBB: +/// v1 = load %addr +/// ... +/// +/// PredBB: +/// ... +/// br %cond, label %LoadBB, label %SuccBB +/// +/// SuccBB: +/// v2 = load %addr +/// ... +/// +LoadInst *GVNPass::findLoadToHoistIntoPred(BasicBlock *Pred, BasicBlock *LoadBB, + LoadInst *Load) { + // For simplicity we handle a Pred has 2 successors only. 
+ auto *Term = Pred->getTerminator(); + if (Term->getNumSuccessors() != 2 || Term->isExceptionalTerminator()) + return nullptr; + auto *SuccBB = Term->getSuccessor(0); + if (SuccBB == LoadBB) + SuccBB = Term->getSuccessor(1); + if (!SuccBB->getSinglePredecessor()) + return nullptr; + + unsigned int NumInsts = MaxNumInsnsPerBlock; + for (Instruction &Inst : *SuccBB) { + if (Inst.isDebugOrPseudoInst()) + continue; + if (--NumInsts == 0) + return nullptr; + + if (!Inst.isIdenticalTo(Load)) + continue; + + MemDepResult Dep = MD->getDependency(&Inst); + // If an identical load doesn't depends on any local instructions, it can + // be safely moved to PredBB. + // Also check for the implicit control flow instructions. See the comments + // in PerformLoadPRE for details. + if (Dep.isNonLocal() && !ICF->isDominatedByICFIFromSameBlock(&Inst)) + return cast<LoadInst>(&Inst); + + // Otherwise there is something in the same BB clobbers the memory, we can't + // move this and later load to PredBB. + return nullptr; + } + + return nullptr; +} + void GVNPass::eliminatePartiallyRedundantLoad( LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, - MapVector<BasicBlock *, Value *> &AvailableLoads) { + MapVector<BasicBlock *, Value *> &AvailableLoads, + MapVector<BasicBlock *, LoadInst *> *CriticalEdgePredAndLoad) { for (const auto &AvailableLoad : AvailableLoads) { BasicBlock *UnavailableBlock = AvailableLoad.first; Value *LoadPtr = AvailableLoad.second; @@ -1370,10 +1459,29 @@ void GVNPass::eliminatePartiallyRedundantLoad( AvailableValueInBlock::get(UnavailableBlock, NewLoad)); MD->invalidateCachedPointerInfo(LoadPtr); LLVM_DEBUG(dbgs() << "GVN INSERTED " << *NewLoad << '\n'); + + // For PredBB in CriticalEdgePredAndLoad we need to replace the uses of old + // load instruction with the new created load instruction. + if (CriticalEdgePredAndLoad) { + auto I = CriticalEdgePredAndLoad->find(UnavailableBlock); + if (I != CriticalEdgePredAndLoad->end()) { + ++NumPRELoadMoved2CEPred; + ICF->insertInstructionTo(NewLoad, UnavailableBlock); + LoadInst *OldLoad = I->second; + combineMetadataForCSE(NewLoad, OldLoad, false); + OldLoad->replaceAllUsesWith(NewLoad); + replaceValuesPerBlockEntry(ValuesPerBlock, OldLoad, NewLoad); + if (uint32_t ValNo = VN.lookup(OldLoad, false)) + removeFromLeaderTable(ValNo, OldLoad, OldLoad->getParent()); + VN.erase(OldLoad); + removeInstruction(OldLoad); + } + } } // Perform PHI construction. Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this); + // ConstructSSAForLoadSet is responsible for combining metadata. Load->replaceAllUsesWith(V); if (isa<PHINode>(V)) V->takeName(Load); @@ -1456,7 +1564,12 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, for (BasicBlock *UnavailableBB : UnavailableBlocks) FullyAvailableBlocks[UnavailableBB] = AvailabilityState::Unavailable; - SmallVector<BasicBlock *, 4> CriticalEdgePred; + // The edge from Pred to LoadBB is a critical edge will be splitted. + SmallVector<BasicBlock *, 4> CriticalEdgePredSplit; + // The edge from Pred to LoadBB is a critical edge, another successor of Pred + // contains a load can be moved to Pred. This data structure maps the Pred to + // the movable load. + MapVector<BasicBlock *, LoadInst *> CriticalEdgePredAndLoad; for (BasicBlock *Pred : predecessors(LoadBB)) { // If any predecessor block is an EH pad that does not allow non-PHI // instructions before the terminator, we can't PRE the load. 
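// A minimal standalone sketch (simple successor/predecessor counts, not
// llvm::BasicBlock) of the "critical edge" notion that the two containers
// above distinguish: an edge Pred -> Succ is critical when Pred has more than
// one successor and Succ has more than one predecessor, so a new load cannot
// be placed on that edge without either splitting the edge or reusing an
// identical load found in Pred's other successor.
#include <cassert>
#include <cstddef>
static bool isCriticalEdgeSketch(std::size_t NumPredSuccessors,
                                 std::size_t NumSuccPredecessors) {
  return NumPredSuccessors > 1 && NumSuccPredecessors > 1;
}
int main() {
  assert(isCriticalEdgeSketch(2, 2));   // conditional branch into a join block
  assert(!isCriticalEdgeSketch(1, 2));  // single successor: the edge is not critical
  return 0;
}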
@@ -1496,7 +1609,10 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, return false; } - CriticalEdgePred.push_back(Pred); + if (LoadInst *LI = findLoadToHoistIntoPred(Pred, LoadBB, Load)) + CriticalEdgePredAndLoad[Pred] = LI; + else + CriticalEdgePredSplit.push_back(Pred); } else { // Only add the predecessors that will not be split for now. PredLoads[Pred] = nullptr; @@ -1504,31 +1620,38 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, } // Decide whether PRE is profitable for this load. - unsigned NumUnavailablePreds = PredLoads.size() + CriticalEdgePred.size(); + unsigned NumInsertPreds = PredLoads.size() + CriticalEdgePredSplit.size(); + unsigned NumUnavailablePreds = NumInsertPreds + + CriticalEdgePredAndLoad.size(); assert(NumUnavailablePreds != 0 && "Fully available value should already be eliminated!"); + (void)NumUnavailablePreds; - // If this load is unavailable in multiple predecessors, reject it. + // If we need to insert new load in multiple predecessors, reject it. // FIXME: If we could restructure the CFG, we could make a common pred with // all the preds that don't have an available Load and insert a new load into // that one block. - if (NumUnavailablePreds != 1) + if (NumInsertPreds > 1) return false; // Now we know where we will insert load. We must ensure that it is safe // to speculatively execute the load at that points. if (MustEnsureSafetyOfSpeculativeExecution) { - if (CriticalEdgePred.size()) + if (CriticalEdgePredSplit.size()) if (!isSafeToSpeculativelyExecute(Load, LoadBB->getFirstNonPHI(), AC, DT)) return false; for (auto &PL : PredLoads) if (!isSafeToSpeculativelyExecute(Load, PL.first->getTerminator(), AC, DT)) return false; + for (auto &CEP : CriticalEdgePredAndLoad) + if (!isSafeToSpeculativelyExecute(Load, CEP.first->getTerminator(), AC, + DT)) + return false; } // Split critical edges, and update the unavailable predecessors accordingly. - for (BasicBlock *OrigPred : CriticalEdgePred) { + for (BasicBlock *OrigPred : CriticalEdgePredSplit) { BasicBlock *NewPred = splitCriticalEdges(OrigPred, LoadBB); assert(!PredLoads.count(OrigPred) && "Split edges shouldn't be in map!"); PredLoads[NewPred] = nullptr; @@ -1536,6 +1659,9 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, << LoadBB->getName() << '\n'); } + for (auto &CEP : CriticalEdgePredAndLoad) + PredLoads[CEP.first] = nullptr; + // Check if the load can safely be moved to all the unavailable predecessors. bool CanDoPRE = true; const DataLayout &DL = Load->getModule()->getDataLayout(); @@ -1555,8 +1681,8 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, BasicBlock *Cur = Load->getParent(); while (Cur != LoadBB) { PHITransAddr Address(LoadPtr, DL, AC); - LoadPtr = Address.PHITranslateWithInsertion( - Cur, Cur->getSinglePredecessor(), *DT, NewInsts); + LoadPtr = Address.translateWithInsertion(Cur, Cur->getSinglePredecessor(), + *DT, NewInsts); if (!LoadPtr) { CanDoPRE = false; break; @@ -1566,8 +1692,8 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, if (LoadPtr) { PHITransAddr Address(LoadPtr, DL, AC); - LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, *DT, - NewInsts); + LoadPtr = Address.translateWithInsertion(LoadBB, UnavailablePred, *DT, + NewInsts); } // If we couldn't find or insert a computation of this phi translated value, // we fail PRE. 
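// A standalone sketch (plain counters, not the GVN data structures above) of
// the profitability gate in PerformLoadPRE: PRE is only attempted when at
// most one brand-new load would have to be inserted, while predecessors for
// which an identical load can be moved in from a neighbouring successor do
// not count against that limit.
#include <cassert>
static bool worthInsertingLoads(unsigned PredsNeedingNewLoad,
                                unsigned PredsWithMovableLoad) {
  (void)PredsWithMovableLoad;  // these reuse an existing load, so no new insert
  return PredsNeedingNewLoad <= 1;
}
int main() {
  assert(worthInsertingLoads(1, 3));   // one new load plus moved loads: attempted
  assert(!worthInsertingLoads(2, 0));  // two brand-new loads: rejected
  return 0;
}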
@@ -1592,7 +1718,7 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, } // HINT: Don't revert the edge-splitting as following transformation may // also need to split these critical edges. - return !CriticalEdgePred.empty(); + return !CriticalEdgePredSplit.empty(); } // Okay, we can eliminate this load by inserting a reload in the predecessor @@ -1617,7 +1743,8 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, VN.lookupOrAdd(I); } - eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, PredLoads); + eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, PredLoads, + &CriticalEdgePredAndLoad); ++NumPRELoad; return true; } @@ -1696,7 +1823,8 @@ bool GVNPass::performLoopLoadPRE(LoadInst *Load, AvailableLoads[Preheader] = LoadPtr; LLVM_DEBUG(dbgs() << "GVN REMOVING PRE LOOP LOAD: " << *Load << '\n'); - eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, AvailableLoads); + eliminatePartiallyRedundantLoad(Load, ValuesPerBlock, AvailableLoads, + /*CriticalEdgePredAndLoad*/ nullptr); ++NumPRELoopLoad; return true; } @@ -1772,6 +1900,7 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) { // Perform PHI construction. Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this); + // ConstructSSAForLoadSet is responsible for combining metadata. Load->replaceAllUsesWith(V); if (isa<PHINode>(V)) @@ -1823,7 +1952,7 @@ static bool impliesEquivalanceIfTrue(CmpInst* Cmp) { if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero()) return true; if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero()) - return true;; + return true; // TODO: Handle vector floating point constants } return false; @@ -1849,7 +1978,7 @@ static bool impliesEquivalanceIfFalse(CmpInst* Cmp) { if (isa<ConstantFP>(LHS) && !cast<ConstantFP>(LHS)->isZero()) return true; if (isa<ConstantFP>(RHS) && !cast<ConstantFP>(RHS)->isZero()) - return true;; + return true; // TODO: Handle vector floating point constants } return false; @@ -1907,10 +2036,14 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { MSSAU->insertDef(cast<MemoryDef>(NewDef), /*RenameUses=*/false); } } - if (isAssumeWithEmptyBundle(*IntrinsicI)) + if (isAssumeWithEmptyBundle(*IntrinsicI)) { markInstructionForDeletion(IntrinsicI); + return true; + } return false; - } else if (isa<Constant>(V)) { + } + + if (isa<Constant>(V)) { // If it's not false, and constant, it must evaluate to true. This means our // assume is assume(true), and thus, pointless, and we don't want to do // anything more here. @@ -2043,8 +2176,8 @@ bool GVNPass::processLoad(LoadInst *L) { Value *AvailableValue = AV->MaterializeAdjustedValue(L, L, *this); - // Replace the load! - patchAndReplaceAllUsesWith(L, AvailableValue); + // MaterializeAdjustedValue is responsible for combining metadata. + L->replaceAllUsesWith(AvailableValue); markInstructionForDeletion(L); if (MSSAU) MSSAU->removeMemoryAccess(L); @@ -2543,7 +2676,9 @@ bool GVNPass::processInstruction(Instruction *I) { // Failure, just remember this instance for future use. addToLeaderTable(Num, I, I->getParent()); return false; - } else if (Repl == I) { + } + + if (Repl == I) { // If I was the result of a shortcut PRE, it might already be in the table // and the best replacement for itself. Nothing to do. 
return false; @@ -2669,12 +2804,7 @@ bool GVNPass::processBlock(BasicBlock *BB) { LLVM_DEBUG(dbgs() << "GVN removed: " << *I << '\n'); salvageKnowledge(I, AC); salvageDebugInfo(*I); - if (MD) MD->removeInstruction(I); - if (MSSAU) - MSSAU->removeMemoryAccess(I); - LLVM_DEBUG(verifyRemoved(I)); - ICF->removeInstruction(I); - I->eraseFromParent(); + removeInstruction(I); } InstrsToErase.clear(); @@ -2765,9 +2895,6 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) { // We don't currently value number ANY inline asm calls. if (CallB->isInlineAsm()) return false; - // Don't do PRE on convergent calls. - if (CallB->isConvergent()) - return false; } uint32_t ValNo = VN.lookup(CurInst); @@ -2855,7 +2982,9 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) { PREInstr = CurInst->clone(); if (!performScalarPREInsertion(PREInstr, PREPred, CurrentBlock, ValNo)) { // If we failed insertion, make sure we remove the instruction. - LLVM_DEBUG(verifyRemoved(PREInstr)); +#ifndef NDEBUG + verifyRemoved(PREInstr); +#endif PREInstr->deleteValue(); return false; } @@ -2894,15 +3023,7 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) { removeFromLeaderTable(ValNo, CurInst, CurrentBlock); LLVM_DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n'); - if (MD) - MD->removeInstruction(CurInst); - if (MSSAU) - MSSAU->removeMemoryAccess(CurInst); - LLVM_DEBUG(verifyRemoved(CurInst)); - // FIXME: Intended to be markInstructionForDeletion(CurInst), but it causes - // some assertion failures. - ICF->removeInstruction(CurInst); - CurInst->eraseFromParent(); + removeInstruction(CurInst); ++NumGVNInstr; return true; @@ -2998,6 +3119,17 @@ void GVNPass::cleanupGlobalSets() { InvalidBlockRPONumbers = true; } +void GVNPass::removeInstruction(Instruction *I) { + if (MD) MD->removeInstruction(I); + if (MSSAU) + MSSAU->removeMemoryAccess(I); +#ifndef NDEBUG + verifyRemoved(I); +#endif + ICF->removeInstruction(I); + I->eraseFromParent(); +} + /// Verify that the specified instruction does not occur in our /// internal data structures. 
void GVNPass::verifyRemoved(const Instruction *Inst) const { diff --git a/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/llvm/lib/Transforms/Scalar/GVNHoist.cpp index bbff497b7d92..b564f00eb9d1 100644 --- a/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -62,13 +62,10 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> @@ -519,39 +516,6 @@ private: std::pair<unsigned, unsigned> hoistExpressions(Function &F); }; -class GVNHoistLegacyPass : public FunctionPass { -public: - static char ID; - - GVNHoistLegacyPass() : FunctionPass(ID) { - initializeGVNHoistLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto &MD = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); - auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); - - GVNHoist G(&DT, &PDT, &AA, &MD, &MSSA); - return G.run(F); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<MemoryDependenceWrapperPass>(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } -}; - bool GVNHoist::run(Function &F) { NumFuncArgs = F.arg_size(); VN.setDomTree(DT); @@ -808,15 +772,20 @@ bool GVNHoist::valueAnticipable(CHIArgs C, Instruction *TI) const { void GVNHoist::checkSafety(CHIArgs C, BasicBlock *BB, GVNHoist::InsKind K, SmallVectorImpl<CHIArg> &Safe) { int NumBBsOnAllPaths = MaxNumberOfBBSInPath; + const Instruction *T = BB->getTerminator(); for (auto CHI : C) { Instruction *Insn = CHI.I; if (!Insn) // No instruction was inserted in this CHI. continue; + // If the Terminator is some kind of "exotic terminator" that produces a + // value (such as InvokeInst, CallBrInst, or CatchSwitchInst) which the CHI + // uses, it is not safe to hoist the use above the def. 
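For illustration only, a hypothetical IR fragment (names invented) of the case the comment above describes: the hoisting block's terminator itself produces the value, so a candidate instruction that uses that value cannot be moved above its definition.

  bb:
    %v = invoke i32 @callee() to label %normal unwind label %unwind
  normal:
    %use = add i32 %v, 1    ; uses %v, so it must not be hoisted into %bb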
+ if (!T->use_empty() && is_contained(Insn->operands(), cast<const Value>(T))) + continue; if (K == InsKind::Scalar) { if (safeToHoistScalar(BB, Insn->getParent(), NumBBsOnAllPaths)) Safe.push_back(CHI); } else { - auto *T = BB->getTerminator(); if (MemoryUseOrDef *UD = MSSA->getMemoryAccess(Insn)) if (safeToHoistLdSt(T, Insn, UD, K, NumBBsOnAllPaths)) Safe.push_back(CHI); @@ -1251,17 +1220,3 @@ PreservedAnalyses GVNHoistPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<MemorySSAAnalysis>(); return PA; } - -char GVNHoistLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(GVNHoistLegacyPass, "gvn-hoist", - "Early GVN Hoisting of Expressions", false, false) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(GVNHoistLegacyPass, "gvn-hoist", - "Early GVN Hoisting of Expressions", false, false) - -FunctionPass *llvm::createGVNHoistPass() { return new GVNHoistLegacyPass(); } diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp index 5fb8a77051fb..26a6978656e6 100644 --- a/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -54,8 +54,6 @@ #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ArrayRecycler.h" #include "llvm/Support/AtomicOrdering.h" @@ -63,7 +61,6 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Scalar/GVNExpression.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -154,7 +151,7 @@ public: void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) { for (auto II = Insts.begin(); II != Insts.end();) { - if (!llvm::is_contained(Blocks, (*II)->getParent())) { + if (!Blocks.contains((*II)->getParent())) { ActiveBlocks.remove((*II)->getParent()); II = Insts.erase(II); } else { @@ -272,7 +269,7 @@ public: auto VI = Values.begin(); while (BI != Blocks.end()) { assert(VI != Values.end()); - if (!llvm::is_contained(NewBlocks, *BI)) { + if (!NewBlocks.contains(*BI)) { BI = Blocks.erase(BI); VI = Values.erase(VI); } else { @@ -886,29 +883,6 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, NumRemoved += Insts.size() - 1; } -//////////////////////////////////////////////////////////////////////////////// -// Pass machinery / boilerplate - -class GVNSinkLegacyPass : public FunctionPass { -public: - static char ID; - - GVNSinkLegacyPass() : FunctionPass(ID) { - initializeGVNSinkLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - GVNSink G; - return G.run(F); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addPreserved<GlobalsAAWrapperPass>(); - } -}; - } // end anonymous namespace PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) { @@ -917,14 +891,3 @@ PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) { return PreservedAnalyses::all(); return PreservedAnalyses::none(); } - -char GVNSinkLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(GVNSinkLegacyPass, "gvn-sink", - "Early GVN sinking of 
Expressions", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_END(GVNSinkLegacyPass, "gvn-sink", - "Early GVN sinking of Expressions", false, false) - -FunctionPass *llvm::createGVNSinkPass() { return new GVNSinkLegacyPass(); } diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp index abe0babc3f12..62b40a23e38c 100644 --- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -69,6 +69,7 @@ using namespace llvm; STATISTIC(GuardsEliminated, "Number of eliminated guards"); STATISTIC(CondBranchEliminated, "Number of eliminated conditional branches"); +STATISTIC(FreezeAdded, "Number of freeze instruction introduced"); static cl::opt<bool> WidenBranchGuards("guard-widening-widen-branch-guards", cl::Hidden, @@ -113,6 +114,23 @@ static void eliminateGuard(Instruction *GuardInst, MemorySSAUpdater *MSSAU) { ++GuardsEliminated; } +/// Find a point at which the widened condition of \p Guard should be inserted. +/// When it is represented as intrinsic call, we can do it right before the call +/// instruction. However, when we are dealing with widenable branch, we must +/// account for the following situation: widening should not turn a +/// loop-invariant condition into a loop-variant. It means that if +/// widenable.condition() call is invariant (w.r.t. any loop), the new wide +/// condition should stay invariant. Otherwise there can be a miscompile, like +/// the one described at https://github.com/llvm/llvm-project/issues/60234. The +/// safest way to do it is to expand the new condition at WC's block. +static Instruction *findInsertionPointForWideCondition(Instruction *Guard) { + Value *Condition, *WC; + BasicBlock *IfTrue, *IfFalse; + if (parseWidenableBranch(Guard, Condition, WC, IfTrue, IfFalse)) + return cast<Instruction>(WC); + return Guard; +} + class GuardWideningImpl { DominatorTree &DT; PostDominatorTree *PDT; @@ -170,16 +188,16 @@ class GuardWideningImpl { bool InvertCond); /// Helper to check if \p V can be hoisted to \p InsertPos. - bool isAvailableAt(const Value *V, const Instruction *InsertPos) const { + bool canBeHoistedTo(const Value *V, const Instruction *InsertPos) const { SmallPtrSet<const Instruction *, 8> Visited; - return isAvailableAt(V, InsertPos, Visited); + return canBeHoistedTo(V, InsertPos, Visited); } - bool isAvailableAt(const Value *V, const Instruction *InsertPos, - SmallPtrSetImpl<const Instruction *> &Visited) const; + bool canBeHoistedTo(const Value *V, const Instruction *InsertPos, + SmallPtrSetImpl<const Instruction *> &Visited) const; /// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c - /// isAvailableAt returned true. + /// canBeHoistedTo returned true. void makeAvailableAt(Value *V, Instruction *InsertPos) const; /// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try @@ -192,6 +210,10 @@ class GuardWideningImpl { bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt, Value *&Result, bool InvertCondition); + /// Adds freeze to Orig and push it as far as possible very aggressively. + /// Also replaces all uses of frozen instruction with frozen version. + Value *freezeAndPush(Value *Orig, Instruction *InsertPt); + /// Represents a range check of the form \c Base + \c Offset u< \c Length, /// with the constraint that \c Length is not negative. 
\c CheckInst is the /// pre-existing instruction in the IR that computes the result of this range @@ -263,8 +285,8 @@ class GuardWideningImpl { void widenGuard(Instruction *ToWiden, Value *NewCondition, bool InvertCondition) { Value *Result; - - widenCondCommon(getCondition(ToWiden), NewCondition, ToWiden, Result, + Instruction *InsertPt = findInsertionPointForWideCondition(ToWiden); + widenCondCommon(getCondition(ToWiden), NewCondition, InsertPt, Result, InvertCondition); if (isGuardAsWidenableBranch(ToWiden)) { setWidenableBranchCond(cast<BranchInst>(ToWiden), Result); @@ -422,7 +444,10 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, HoistingOutOfLoop = true; } - if (!isAvailableAt(getCondition(DominatedInstr), DominatingGuard)) + auto *WideningPoint = findInsertionPointForWideCondition(DominatingGuard); + if (!canBeHoistedTo(getCondition(DominatedInstr), WideningPoint)) + return WS_IllegalOrNegative; + if (!canBeHoistedTo(getCondition(DominatingGuard), WideningPoint)) return WS_IllegalOrNegative; // If the guard was conditional executed, it may never be reached @@ -440,30 +465,70 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, if (HoistingOutOfLoop) return WS_Positive; - // Returns true if we might be hoisting above explicit control flow. Note - // that this completely ignores implicit control flow (guards, calls which - // throw, etc...). That choice appears arbitrary. - auto MaybeHoistingOutOfIf = [&]() { - auto *DominatingBlock = DominatingGuard->getParent(); - auto *DominatedBlock = DominatedInstr->getParent(); - if (isGuardAsWidenableBranch(DominatingGuard)) - DominatingBlock = cast<BranchInst>(DominatingGuard)->getSuccessor(0); + // For a given basic block \p BB, return its successor which is guaranteed or + // highly likely will be taken as its successor. + auto GetLikelySuccessor = [](const BasicBlock * BB)->const BasicBlock * { + if (auto *UniqueSucc = BB->getUniqueSuccessor()) + return UniqueSucc; + auto *Term = BB->getTerminator(); + Value *Cond = nullptr; + const BasicBlock *IfTrue = nullptr, *IfFalse = nullptr; + using namespace PatternMatch; + if (!match(Term, m_Br(m_Value(Cond), m_BasicBlock(IfTrue), + m_BasicBlock(IfFalse)))) + return nullptr; + // For constant conditions, only one dynamical successor is possible + if (auto *ConstCond = dyn_cast<ConstantInt>(Cond)) + return ConstCond->isAllOnesValue() ? IfTrue : IfFalse; + // If one of successors ends with deopt, another one is likely. + if (IfFalse->getPostdominatingDeoptimizeCall()) + return IfTrue; + if (IfTrue->getPostdominatingDeoptimizeCall()) + return IfFalse; + // TODO: Use branch frequency metatada to allow hoisting through non-deopt + // branches? + return nullptr; + }; + + // Returns true if we might be hoisting above explicit control flow into a + // considerably hotter block. Note that this completely ignores implicit + // control flow (guards, calls which throw, etc...). That choice appears + // arbitrary (we assume that implicit control flow exits are all rare). + auto MaybeHoistingToHotterBlock = [&]() { + const auto *DominatingBlock = DominatingGuard->getParent(); + const auto *DominatedBlock = DominatedInstr->getParent(); + + // Descend as low as we can, always taking the likely successor. 
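As a sketch of the CFG shape this descent is meant to see through (block and value names invented), a branch whose cold side ends in a deoptimize call has a clear likely successor:

  guard.block:
    %c = icmp ult i32 %i, %len
    br i1 %c, label %hot, label %cold
  cold:
    call void (...) @llvm.experimental.deoptimize.isVoid() [ "deopt"() ]
    ret void
  hot:
    ; GetLikelySuccessor(guard.block) would return %hot, so the walk continues here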
+ assert(DT.isReachableFromEntry(DominatingBlock) && "Unreached code"); + assert(DT.isReachableFromEntry(DominatedBlock) && "Unreached code"); + assert(DT.dominates(DominatingBlock, DominatedBlock) && "No dominance"); + while (DominatedBlock != DominatingBlock) { + auto *LikelySucc = GetLikelySuccessor(DominatingBlock); + // No likely successor? + if (!LikelySucc) + break; + // Only go down the dominator tree. + if (!DT.properlyDominates(DominatingBlock, LikelySucc)) + break; + DominatingBlock = LikelySucc; + } - // Same Block? + // Found? if (DominatedBlock == DominatingBlock) return false; - // Obvious successor (common loop header/preheader case) - if (DominatedBlock == DominatingBlock->getUniqueSuccessor()) - return false; + // We followed the likely successor chain and went past the dominated + // block. It means that the dominated guard is in dead/very cold code. + if (!DT.dominates(DominatingBlock, DominatedBlock)) + return true; // TODO: diamond, triangle cases if (!PDT) return true; return !PDT->dominates(DominatedBlock, DominatingBlock); }; - return MaybeHoistingOutOfIf() ? WS_IllegalOrNegative : WS_Neutral; + return MaybeHoistingToHotterBlock() ? WS_IllegalOrNegative : WS_Neutral; } -bool GuardWideningImpl::isAvailableAt( +bool GuardWideningImpl::canBeHoistedTo( const Value *V, const Instruction *Loc, SmallPtrSetImpl<const Instruction *> &Visited) const { auto *Inst = dyn_cast<Instruction>(V); @@ -482,7 +547,7 @@ bool GuardWideningImpl::isAvailableAt( assert(DT.isReachableFromEntry(Inst->getParent()) && "We did a DFS from the block entry!"); return all_of(Inst->operands(), - [&](Value *Op) { return isAvailableAt(Op, Loc, Visited); }); + [&](Value *Op) { return canBeHoistedTo(Op, Loc, Visited); }); } void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { @@ -491,14 +556,115 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { return; assert(isSafeToSpeculativelyExecute(Inst, Loc, &AC, &DT) && - !Inst->mayReadFromMemory() && "Should've checked with isAvailableAt!"); + !Inst->mayReadFromMemory() && + "Should've checked with canBeHoistedTo!"); for (Value *Op : Inst->operands()) makeAvailableAt(Op, Loc); Inst->moveBefore(Loc); - // If we moved instruction before guard we must clean poison generating flags. - Inst->dropPoisonGeneratingFlags(); +} + +// Return Instruction before which we can insert freeze for the value V as close +// to def as possible. If there is no place to add freeze, return nullptr. +static Instruction *getFreezeInsertPt(Value *V, const DominatorTree &DT) { + auto *I = dyn_cast<Instruction>(V); + if (!I) + return &*DT.getRoot()->getFirstNonPHIOrDbgOrAlloca(); + + auto *Res = I->getInsertionPointAfterDef(); + // If there is no place to add freeze - return nullptr. + if (!Res || !DT.dominates(I, Res)) + return nullptr; + + // If there is a User dominated by original I, then it should be dominated + // by Freeze instruction as well. 
+ if (any_of(I->users(), [&](User *U) { + Instruction *User = cast<Instruction>(U); + return Res != User && DT.dominates(I, User) && !DT.dominates(Res, User); + })) + return nullptr; + return Res; +} + +Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) { + if (isGuaranteedNotToBePoison(Orig, nullptr, InsertPt, &DT)) + return Orig; + Instruction *InsertPtAtDef = getFreezeInsertPt(Orig, DT); + if (!InsertPtAtDef) + return new FreezeInst(Orig, "gw.freeze", InsertPt); + if (isa<Constant>(Orig) || isa<GlobalValue>(Orig)) + return new FreezeInst(Orig, "gw.freeze", InsertPtAtDef); + + SmallSet<Value *, 16> Visited; + SmallVector<Value *, 16> Worklist; + SmallSet<Instruction *, 16> DropPoisonFlags; + SmallVector<Value *, 16> NeedFreeze; + DenseMap<Value *, FreezeInst *> CacheOfFreezes; + + // A bit overloaded data structures. Visited contains constant/GV + // if we already met it. In this case CacheOfFreezes has a freeze if it is + // required. + auto handleConstantOrGlobal = [&](Use &U) { + Value *Def = U.get(); + if (!isa<Constant>(Def) && !isa<GlobalValue>(Def)) + return false; + + if (Visited.insert(Def).second) { + if (isGuaranteedNotToBePoison(Def, nullptr, InsertPt, &DT)) + return true; + CacheOfFreezes[Def] = new FreezeInst(Def, Def->getName() + ".gw.fr", + getFreezeInsertPt(Def, DT)); + } + + if (CacheOfFreezes.count(Def)) + U.set(CacheOfFreezes[Def]); + return true; + }; + + Worklist.push_back(Orig); + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + if (isGuaranteedNotToBePoison(V, nullptr, InsertPt, &DT)) + continue; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I || canCreateUndefOrPoison(cast<Operator>(I), + /*ConsiderFlagsAndMetadata*/ false)) { + NeedFreeze.push_back(V); + continue; + } + // Check all operands. If for any of them we cannot insert Freeze, + // stop here. Otherwise, iterate. 
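Roughly, the net effect being aimed for by this walk, sketched under the assumption that %a and %b are arguments which cannot be proven non-poison: instead of freezing the widened condition as a whole, freezes are sunk to the leaves and the poison-generating flag on the intermediate add is dropped.

  ; before:  %add = add nuw i32 %a, %b ; %chk = icmp ult i32 %add, %n
  %a.gw.fr = freeze i32 %a
  %b.gw.fr = freeze i32 %b
  %add = add i32 %a.gw.fr, %b.gw.fr
  %chk = icmp ult i32 %add, %n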
+ if (any_of(I->operands(), [&](Value *Op) { + return isa<Instruction>(Op) && !getFreezeInsertPt(Op, DT); + })) { + NeedFreeze.push_back(I); + continue; + } + DropPoisonFlags.insert(I); + for (Use &U : I->operands()) + if (!handleConstantOrGlobal(U)) + Worklist.push_back(U.get()); + } + for (Instruction *I : DropPoisonFlags) + I->dropPoisonGeneratingFlagsAndMetadata(); + + Value *Result = Orig; + for (Value *V : NeedFreeze) { + auto *FreezeInsertPt = getFreezeInsertPt(V, DT); + FreezeInst *FI = new FreezeInst(V, V->getName() + ".gw.fr", FreezeInsertPt); + ++FreezeAdded; + if (V == Orig) + Result = FI; + V->replaceUsesWithIf( + FI, [&](const Use & U)->bool { return U.getUser() != FI; }); + } + + return Result; } bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, @@ -532,6 +698,8 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, if (InsertPt) { ConstantInt *NewRHS = ConstantInt::get(Cond0->getContext(), NewRHSAP); + assert(canBeHoistedTo(LHS, InsertPt) && "must be"); + makeAvailableAt(LHS, InsertPt); Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); } return true; @@ -558,6 +726,7 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, } assert(Result && "Failed to find result value"); Result->setName("wide.chk"); + Result = freezeAndPush(Result, InsertPt); } return true; } @@ -570,6 +739,7 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, makeAvailableAt(Cond1, InsertPt); if (InvertCondition) Cond1 = BinaryOperator::CreateNot(Cond1, "inverted", InsertPt); + Cond1 = freezeAndPush(Cond1, InsertPt); Result = BinaryOperator::CreateAnd(Cond0, Cond1, "wide.chk", InsertPt); } diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index c834e51b5f29..40475d9563b2 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -64,15 +64,12 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -93,15 +90,6 @@ STATISTIC(NumLFTR , "Number of loop exit tests replaced"); STATISTIC(NumElimExt , "Number of IV sign/zero extends eliminated"); STATISTIC(NumElimIV , "Number of congruent IVs eliminated"); -// Trip count verification can be enabled by default under NDEBUG if we -// implement a strong expression equivalence checker in SCEV. Until then, we -// use the verify-indvars flag, which may assert in some cases. -static cl::opt<bool> VerifyIndvars( - "verify-indvars", cl::Hidden, - cl::desc("Verify the ScalarEvolution result after running indvars. Has no " - "effect in release builds. 
(Note: this adds additional SCEV " - "queries potentially changing the analysis result)")); - static cl::opt<ReplaceExitVal> ReplaceExitValue( "replexitval", cl::Hidden, cl::init(OnlyCheapRepl), cl::desc("Choose the strategy to replace exit value in IndVarSimplify"), @@ -416,8 +404,8 @@ bool IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { PHIs.push_back(&PN); bool Changed = false; - for (unsigned i = 0, e = PHIs.size(); i != e; ++i) - if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHIs[i])) + for (WeakTrackingVH &PHI : PHIs) + if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHI)) Changed |= handleFloatingPointIV(L, PN); // If the loop previously had floating-point IV, ScalarEvolution @@ -759,50 +747,6 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) { return Phi != getLoopPhiForCounter(IncV, L); } -/// Return true if undefined behavior would provable be executed on the path to -/// OnPathTo if Root produced a posion result. Note that this doesn't say -/// anything about whether OnPathTo is actually executed or whether Root is -/// actually poison. This can be used to assess whether a new use of Root can -/// be added at a location which is control equivalent with OnPathTo (such as -/// immediately before it) without introducing UB which didn't previously -/// exist. Note that a false result conveys no information. -static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, - Instruction *OnPathTo, - DominatorTree *DT) { - // Basic approach is to assume Root is poison, propagate poison forward - // through all users we can easily track, and then check whether any of those - // users are provable UB and must execute before out exiting block might - // exit. - - // The set of all recursive users we've visited (which are assumed to all be - // poison because of said visit) - SmallSet<const Value *, 16> KnownPoison; - SmallVector<const Instruction*, 16> Worklist; - Worklist.push_back(Root); - while (!Worklist.empty()) { - const Instruction *I = Worklist.pop_back_val(); - - // If we know this must trigger UB on a path leading our target. - if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo)) - return true; - - // If we can't analyze propagation through this instruction, just skip it - // and transitive users. Safe as false is a conservative result. - if (I != Root && !any_of(I->operands(), [&KnownPoison](const Use &U) { - return KnownPoison.contains(U) && propagatesPoison(U); - })) - continue; - - if (KnownPoison.insert(I).second) - for (const User *User : I->users()) - Worklist.push_back(cast<Instruction>(User)); - } - - // Might be non-UB, or might have a path we couldn't prove must execute on - // way to exiting bb. - return false; -} - /// Recursive helper for hasConcreteDef(). Unfortunately, this currently boils /// down to checking that all operands are constant and listing instructions /// that may hide undef. @@ -845,20 +789,6 @@ static bool hasConcreteDef(Value *V) { return hasConcreteDefImpl(V, Visited, 0); } -/// Return true if this IV has any uses other than the (soon to be rewritten) -/// loop exit test. -static bool AlmostDeadIV(PHINode *Phi, BasicBlock *LatchBlock, Value *Cond) { - int LatchIdx = Phi->getBasicBlockIndex(LatchBlock); - Value *IncV = Phi->getIncomingValue(LatchIdx); - - for (User *U : Phi->users()) - if (U != Cond && U != IncV) return false; - - for (User *U : IncV->users()) - if (U != Cond && U != Phi) return false; - return true; -} - /// Return true if the given phi is a "counter" in L. 
A counter is an /// add recurance (of integer or pointer type) with an arbitrary start, and a /// step of 1. Note that L must have exactly one latch. @@ -910,10 +840,6 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, if (!isLoopCounter(Phi, L, SE)) continue; - // Avoid comparing an integer IV against a pointer Limit. - if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy()) - continue; - const auto *AR = cast<SCEVAddRecExpr>(SE->getSCEV(Phi)); // AR may be a pointer type, while BECount is an integer type. @@ -949,9 +875,9 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, const SCEV *Init = AR->getStart(); - if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) { + if (BestPhi && !isAlmostDeadIV(BestPhi, LatchBlock, Cond)) { // Don't force a live loop counter if another IV can be used. - if (AlmostDeadIV(Phi, LatchBlock, Cond)) + if (isAlmostDeadIV(Phi, LatchBlock, Cond)) continue; // Prefer to count-from-zero. This is a more "canonical" counter form. It @@ -979,78 +905,29 @@ static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB, const SCEV *ExitCount, bool UsePostInc, Loop *L, SCEVExpander &Rewriter, ScalarEvolution *SE) { assert(isLoopCounter(IndVar, L, SE)); + assert(ExitCount->getType()->isIntegerTy() && "exit count must be integer"); const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar)); - const SCEV *IVInit = AR->getStart(); assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride"); - // IVInit may be a pointer while ExitCount is an integer when FindLoopCounter - // finds a valid pointer IV. Sign extend ExitCount in order to materialize a - // GEP. Avoid running SCEVExpander on a new pointer value, instead reusing - // the existing GEPs whenever possible. - if (IndVar->getType()->isPointerTy() && - !ExitCount->getType()->isPointerTy()) { - // IVOffset will be the new GEP offset that is interpreted by GEP as a - // signed value. ExitCount on the other hand represents the loop trip count, - // which is an unsigned value. FindLoopCounter only allows induction - // variables that have a positive unit stride of one. This means we don't - // have to handle the case of negative offsets (yet) and just need to zero - // extend ExitCount. - Type *OfsTy = SE->getEffectiveSCEVType(IVInit->getType()); - const SCEV *IVOffset = SE->getTruncateOrZeroExtend(ExitCount, OfsTy); - if (UsePostInc) - IVOffset = SE->getAddExpr(IVOffset, SE->getOne(OfsTy)); - - // Expand the code for the iteration count. - assert(SE->isLoopInvariant(IVOffset, L) && - "Computed iteration count is not loop invariant!"); - - const SCEV *IVLimit = SE->getAddExpr(IVInit, IVOffset); - BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); - return Rewriter.expandCodeFor(IVLimit, IndVar->getType(), BI); - } else { - // In any other case, convert both IVInit and ExitCount to integers before - // comparing. This may result in SCEV expansion of pointers, but in practice - // SCEV will fold the pointer arithmetic away as such: - // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc). - // - // Valid Cases: (1) both integers is most common; (2) both may be pointers - // for simple memset-style loops. - // - // IVInit integer and ExitCount pointer would only occur if a canonical IV - // were generated on top of case #2, which is not expected. - - // For unit stride, IVCount = Start + ExitCount with 2's complement - // overflow. 
- - // For integer IVs, truncate the IV before computing IVInit + BECount, - // unless we know apriori that the limit must be a constant when evaluated - // in the bitwidth of the IV. We prefer (potentially) keeping a truncate - // of the IV in the loop over a (potentially) expensive expansion of the - // widened exit count add(zext(add)) expression. - if (SE->getTypeSizeInBits(IVInit->getType()) - > SE->getTypeSizeInBits(ExitCount->getType())) { - if (isa<SCEVConstant>(IVInit) && isa<SCEVConstant>(ExitCount)) - ExitCount = SE->getZeroExtendExpr(ExitCount, IVInit->getType()); - else - IVInit = SE->getTruncateExpr(IVInit, ExitCount->getType()); - } - - const SCEV *IVLimit = SE->getAddExpr(IVInit, ExitCount); - - if (UsePostInc) - IVLimit = SE->getAddExpr(IVLimit, SE->getOne(IVLimit->getType())); - - // Expand the code for the iteration count. - assert(SE->isLoopInvariant(IVLimit, L) && - "Computed iteration count is not loop invariant!"); - // Ensure that we generate the same type as IndVar, or a smaller integer - // type. In the presence of null pointer values, we have an integer type - // SCEV expression (IVInit) for a pointer type IV value (IndVar). - Type *LimitTy = ExitCount->getType()->isPointerTy() ? - IndVar->getType() : ExitCount->getType(); - BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); - return Rewriter.expandCodeFor(IVLimit, LimitTy, BI); + // For integer IVs, truncate the IV before computing the limit unless we + // know apriori that the limit must be a constant when evaluated in the + // bitwidth of the IV. We prefer (potentially) keeping a truncate of the + // IV in the loop over a (potentially) expensive expansion of the widened + // exit count add(zext(add)) expression. + if (IndVar->getType()->isIntegerTy() && + SE->getTypeSizeInBits(AR->getType()) > + SE->getTypeSizeInBits(ExitCount->getType())) { + const SCEV *IVInit = AR->getStart(); + if (!isa<SCEVConstant>(IVInit) || !isa<SCEVConstant>(ExitCount)) + AR = cast<SCEVAddRecExpr>(SE->getTruncateExpr(AR, ExitCount->getType())); } + + const SCEVAddRecExpr *ARBase = UsePostInc ? AR->getPostIncExpr(*SE) : AR; + const SCEV *IVLimit = ARBase->evaluateAtIteration(ExitCount, *SE); + assert(SE->isLoopInvariant(IVLimit, L) && + "Computed iteration count is not loop invariant!"); + return Rewriter.expandCodeFor(IVLimit, ARBase->getType(), + ExitingBB->getTerminator()); } /// This method rewrites the exit condition of the loop to be a canonical != @@ -1148,8 +1025,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, // a truncate within in. 
bool Extended = false; const SCEV *IV = SE->getSCEV(CmpIndVar); - const SCEV *TruncatedIV = SE->getTruncateExpr(SE->getSCEV(CmpIndVar), - ExitCnt->getType()); + const SCEV *TruncatedIV = SE->getTruncateExpr(IV, ExitCnt->getType()); const SCEV *ZExtTrunc = SE->getZeroExtendExpr(TruncatedIV, CmpIndVar->getType()); @@ -1359,14 +1235,16 @@ createInvariantCond(const Loop *L, BasicBlock *ExitingBB, const ScalarEvolution::LoopInvariantPredicate &LIP, SCEVExpander &Rewriter) { ICmpInst::Predicate InvariantPred = LIP.Pred; - BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); - Rewriter.setInsertPoint(BI); + BasicBlock *Preheader = L->getLoopPreheader(); + assert(Preheader && "Preheader doesn't exist"); + Rewriter.setInsertPoint(Preheader->getTerminator()); auto *LHSV = Rewriter.expandCodeFor(LIP.LHS); auto *RHSV = Rewriter.expandCodeFor(LIP.RHS); bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); if (ExitIfTrue) InvariantPred = ICmpInst::getInversePredicate(InvariantPred); - IRBuilder<> Builder(BI); + IRBuilder<> Builder(Preheader->getTerminator()); + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); return Builder.CreateICmp(InvariantPred, LHSV, RHSV, BI->getCondition()->getName()); } @@ -1519,7 +1397,6 @@ static bool optimizeLoopExitWithUnknownExitCount( auto *NewCond = *Replaced; if (auto *NCI = dyn_cast<Instruction>(NewCond)) { NCI->setName(OldCond->getName() + ".first_iter"); - NCI->moveBefore(cast<Instruction>(OldCond)); } LLVM_DEBUG(dbgs() << "Unknown exit count: Replacing " << *OldCond << " with " << *NewCond << "\n"); @@ -2022,16 +1899,6 @@ bool IndVarSimplify::run(Loop *L) { if (!L->isLoopSimplifyForm()) return false; -#ifndef NDEBUG - // Used below for a consistency check only - // Note: Since the result returned by ScalarEvolution may depend on the order - // in which previous results are added to its cache, the call to - // getBackedgeTakenCount() may change following SCEV queries. - const SCEV *BackedgeTakenCount; - if (VerifyIndvars) - BackedgeTakenCount = SE->getBackedgeTakenCount(L); -#endif - bool Changed = false; // If there are any floating-point recurrences, attempt to // transform them to use integer recurrences. @@ -2180,27 +2047,8 @@ bool IndVarSimplify::run(Loop *L) { // Check a post-condition. assert(L->isRecursivelyLCSSAForm(*DT, *LI) && "Indvars did not preserve LCSSA!"); - - // Verify that LFTR, and any other change have not interfered with SCEV's - // ability to compute trip count. We may have *changed* the exit count, but - // only by reducing it. 
-#ifndef NDEBUG - if (VerifyIndvars && !isa<SCEVCouldNotCompute>(BackedgeTakenCount)) { - SE->forgetLoop(L); - const SCEV *NewBECount = SE->getBackedgeTakenCount(L); - if (SE->getTypeSizeInBits(BackedgeTakenCount->getType()) < - SE->getTypeSizeInBits(NewBECount->getType())) - NewBECount = SE->getTruncateOrNoop(NewBECount, - BackedgeTakenCount->getType()); - else - BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, - NewBECount->getType()); - assert(!SE->isKnownPredicate(ICmpInst::ICMP_ULT, BackedgeTakenCount, - NewBECount) && "indvars must preserve SCEV"); - } if (VerifyMemorySSA && MSSAU) MSSAU->getMemorySSA()->verifyMemorySSA(); -#endif return Changed; } @@ -2222,54 +2070,3 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -namespace { - -struct IndVarSimplifyLegacyPass : public LoopPass { - static char ID; // Pass identification, replacement for typeid - - IndVarSimplifyLegacyPass() : LoopPass(ID) { - initializeIndVarSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - auto *TLI = TLIP ? &TLIP->getTLI(*L->getHeader()->getParent()) : nullptr; - auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); - auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - MemorySSA *MSSA = nullptr; - if (MSSAAnalysis) - MSSA = &MSSAAnalysis->getMSSA(); - - IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA, AllowIVWidening); - return IVS.run(L); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addPreserved<MemorySSAWrapperPass>(); - getLoopAnalysisUsage(AU); - } -}; - -} // end anonymous namespace - -char IndVarSimplifyLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars", - "Induction Variable Simplification", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_END(IndVarSimplifyLegacyPass, "indvars", - "Induction Variable Simplification", false, false) - -Pass *llvm::createIndVarSimplifyPass() { - return new IndVarSimplifyLegacyPass(); -} diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 52a4bc8a9f24..b52589baeee7 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -72,8 +72,6 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -81,7 +79,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -121,6 +119,16 @@ 
static cl::opt<bool> AllowNarrowLatchCondition( cl::desc("If set to true, IRCE may eliminate wide range checks in loops " "with narrow latch condition.")); +static cl::opt<unsigned> MaxTypeSizeForOverflowCheck( + "irce-max-type-size-for-overflow-check", cl::Hidden, cl::init(32), + cl::desc( + "Maximum size of range check type for which can be produced runtime " + "overflow check of its limit's computation")); + +static cl::opt<bool> + PrintScaledBoundaryRangeChecks("irce-print-scaled-boundary-range-checks", + cl::Hidden, cl::init(false)); + static const char *ClonedLoopTag = "irce.loop.clone"; #define DEBUG_TYPE "irce" @@ -145,14 +153,23 @@ class InductiveRangeCheck { Use *CheckUse = nullptr; static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, - Value *&Index, Value *&Length, - bool &IsSigned); + const SCEVAddRecExpr *&Index, + const SCEV *&End); static void extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse, SmallVectorImpl<InductiveRangeCheck> &Checks, SmallPtrSetImpl<Value *> &Visited); + static bool parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS, + ICmpInst::Predicate Pred, ScalarEvolution &SE, + const SCEVAddRecExpr *&Index, + const SCEV *&End); + + static bool reassociateSubLHS(Loop *L, Value *VariantLHS, Value *InvariantRHS, + ICmpInst::Predicate Pred, ScalarEvolution &SE, + const SCEVAddRecExpr *&Index, const SCEV *&End); + public: const SCEV *getBegin() const { return Begin; } const SCEV *getStep() const { return Step; } @@ -219,10 +236,9 @@ public: /// /// NB! There may be conditions feeding into \p BI that aren't inductive range /// checks, and hence don't end up in \p Checks. - static void - extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE, - BranchProbabilityInfo *BPI, - SmallVectorImpl<InductiveRangeCheck> &Checks); + static void extractRangeChecksFromBranch( + BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI, + SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed); }; struct LoopStructure; @@ -250,48 +266,16 @@ public: bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop); }; -class IRCELegacyPass : public FunctionPass { -public: - static char ID; - - IRCELegacyPass() : FunctionPass(ID) { - initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<BranchProbabilityInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - } - - bool runOnFunction(Function &F) override; -}; - } // end anonymous namespace -char IRCELegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce", - "Inductive range check elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination", - false, false) - /// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` cannot -/// be interpreted as a range check, return false and set `Index` and `Length` -/// to `nullptr`. 
Otherwise set `Index` to the value being range checked, and -/// set `Length` to the upper limit `Index` is being range checked. -bool -InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, - ScalarEvolution &SE, Value *&Index, - Value *&Length, bool &IsSigned) { +/// be interpreted as a range check, return false. Otherwise set `Index` to the +/// SCEV being range checked, and set `End` to the upper or lower limit `Index` +/// is being range checked. +bool InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, + ScalarEvolution &SE, + const SCEVAddRecExpr *&Index, + const SCEV *&End) { auto IsLoopInvariant = [&SE, L](Value *V) { return SE.isLoopInvariant(SE.getSCEV(V), L); }; @@ -300,47 +284,79 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); + // Canonicalize to the `Index Pred Invariant` comparison + if (IsLoopInvariant(LHS)) { + std::swap(LHS, RHS); + Pred = CmpInst::getSwappedPredicate(Pred); + } else if (!IsLoopInvariant(RHS)) + // Both LHS and RHS are loop variant + return false; + + if (parseIvAgaisntLimit(L, LHS, RHS, Pred, SE, Index, End)) + return true; + + if (reassociateSubLHS(L, LHS, RHS, Pred, SE, Index, End)) + return true; + + // TODO: support ReassociateAddLHS + return false; +} + +// Try to parse range check in the form of "IV vs Limit" +bool InductiveRangeCheck::parseIvAgaisntLimit(Loop *L, Value *LHS, Value *RHS, + ICmpInst::Predicate Pred, + ScalarEvolution &SE, + const SCEVAddRecExpr *&Index, + const SCEV *&End) { + + auto SIntMaxSCEV = [&](Type *T) { + unsigned BitWidth = cast<IntegerType>(T)->getBitWidth(); + return SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + }; + + const auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(LHS)); + if (!AddRec) + return false; + + // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". + // We can potentially do much better here. + // If we want to adjust upper bound for the unsigned range check as we do it + // for signed one, we will need to pick Unsigned max switch (Pred) { default: return false; - case ICmpInst::ICMP_SLE: - std::swap(LHS, RHS); - [[fallthrough]]; case ICmpInst::ICMP_SGE: - IsSigned = true; if (match(RHS, m_ConstantInt<0>())) { - Index = LHS; - return true; // Lower. + Index = AddRec; + End = SIntMaxSCEV(Index->getType()); + return true; } return false; - case ICmpInst::ICMP_SLT: - std::swap(LHS, RHS); - [[fallthrough]]; case ICmpInst::ICMP_SGT: - IsSigned = true; if (match(RHS, m_ConstantInt<-1>())) { - Index = LHS; - return true; // Lower. - } - - if (IsLoopInvariant(LHS)) { - Index = RHS; - Length = LHS; - return true; // Upper. + Index = AddRec; + End = SIntMaxSCEV(Index->getType()); + return true; } return false; + case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: - std::swap(LHS, RHS); - [[fallthrough]]; - case ICmpInst::ICMP_UGT: - IsSigned = false; - if (IsLoopInvariant(LHS)) { - Index = RHS; - Length = LHS; - return true; // Both lower and upper. 
+ Index = AddRec; + End = SE.getSCEV(RHS); + return true; + + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_ULE: + const SCEV *One = SE.getOne(RHS->getType()); + const SCEV *RHSS = SE.getSCEV(RHS); + bool Signed = Pred == ICmpInst::ICMP_SLE; + if (SE.willNotOverflow(Instruction::BinaryOps::Add, Signed, RHSS, One)) { + Index = AddRec; + End = SE.getAddExpr(RHSS, One); + return true; } return false; } @@ -348,6 +364,126 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, llvm_unreachable("default clause returns!"); } +// Try to parse range check in the form of "IV - Offset vs Limit" or "Offset - +// IV vs Limit" +bool InductiveRangeCheck::reassociateSubLHS( + Loop *L, Value *VariantLHS, Value *InvariantRHS, ICmpInst::Predicate Pred, + ScalarEvolution &SE, const SCEVAddRecExpr *&Index, const SCEV *&End) { + Value *LHS, *RHS; + if (!match(VariantLHS, m_Sub(m_Value(LHS), m_Value(RHS)))) + return false; + + const SCEV *IV = SE.getSCEV(LHS); + const SCEV *Offset = SE.getSCEV(RHS); + const SCEV *Limit = SE.getSCEV(InvariantRHS); + + bool OffsetSubtracted = false; + if (SE.isLoopInvariant(IV, L)) + // "Offset - IV vs Limit" + std::swap(IV, Offset); + else if (SE.isLoopInvariant(Offset, L)) + // "IV - Offset vs Limit" + OffsetSubtracted = true; + else + return false; + + const auto *AddRec = dyn_cast<SCEVAddRecExpr>(IV); + if (!AddRec) + return false; + + // In order to turn "IV - Offset < Limit" into "IV < Limit + Offset", we need + // to be able to freely move values from left side of inequality to right side + // (just as in normal linear arithmetics). Overflows make things much more + // complicated, so we want to avoid this. + // + // Let's prove that the initial subtraction doesn't overflow with all IV's + // values from the safe range constructed for that check. + // + // [Case 1] IV - Offset < Limit + // It doesn't overflow if: + // SINT_MIN <= IV - Offset <= SINT_MAX + // In terms of scaled SINT we need to prove: + // SINT_MIN + Offset <= IV <= SINT_MAX + Offset + // Safe range will be constructed: + // 0 <= IV < Limit + Offset + // It means that 'IV - Offset' doesn't underflow, because: + // SINT_MIN + Offset < 0 <= IV + // and doesn't overflow: + // IV < Limit + Offset <= SINT_MAX + Offset + // + // [Case 2] Offset - IV > Limit + // It doesn't overflow if: + // SINT_MIN <= Offset - IV <= SINT_MAX + // In terms of scaled SINT we need to prove: + // -SINT_MIN >= IV - Offset >= -SINT_MAX + // Offset - SINT_MIN >= IV >= Offset - SINT_MAX + // Safe range will be constructed: + // 0 <= IV < Offset - Limit + // It means that 'Offset - IV' doesn't underflow, because + // Offset - SINT_MAX < 0 <= IV + // and doesn't overflow: + // IV < Offset - Limit <= Offset - SINT_MIN + // + // For the computed upper boundary of the IV's range (Offset +/- Limit) we + // don't know exactly whether it overflows or not. So if we can't prove this + // fact at compile time, we scale boundary computations to a wider type with + // the intention to add runtime overflow check. 
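A sketch of the pattern this reassociation targets (values invented): a range check whose variant side subtracts a loop-invariant offset from the IV.

  %d  = sub i32 %iv, %offset         ; IV - Offset
  %rc = icmp slt i32 %d, %limit      ; IV - Offset < Limit
  ; treated as IV < Offset + Limit, with the sum evaluated in a doubled-width
  ; type (i64 here) whenever its overflow cannot be excluded at compile time.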
+ + auto getExprScaledIfOverflow = [&](Instruction::BinaryOps BinOp, + const SCEV *LHS, + const SCEV *RHS) -> const SCEV * { + const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *, + SCEV::NoWrapFlags, unsigned); + switch (BinOp) { + default: + llvm_unreachable("Unsupported binary op"); + case Instruction::Add: + Operation = &ScalarEvolution::getAddExpr; + break; + case Instruction::Sub: + Operation = &ScalarEvolution::getMinusSCEV; + break; + } + + if (SE.willNotOverflow(BinOp, ICmpInst::isSigned(Pred), LHS, RHS, + cast<Instruction>(VariantLHS))) + return (SE.*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0); + + // We couldn't prove that the expression does not overflow. + // Than scale it to a wider type to check overflow at runtime. + auto *Ty = cast<IntegerType>(LHS->getType()); + if (Ty->getBitWidth() > MaxTypeSizeForOverflowCheck) + return nullptr; + + auto WideTy = IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); + return (SE.*Operation)(SE.getSignExtendExpr(LHS, WideTy), + SE.getSignExtendExpr(RHS, WideTy), SCEV::FlagAnyWrap, + 0); + }; + + if (OffsetSubtracted) + // "IV - Offset < Limit" -> "IV" < Offset + Limit + Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Offset, Limit); + else { + // "Offset - IV > Limit" -> "IV" < Offset - Limit + Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Sub, Offset, Limit); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) { + // "Expr <= Limit" -> "Expr < Limit + 1" + if (Pred == ICmpInst::ICMP_SLE && Limit) + Limit = getExprScaledIfOverflow(Instruction::BinaryOps::Add, Limit, + SE.getOne(Limit->getType())); + if (Limit) { + Index = AddRec; + End = Limit; + return true; + } + } + return false; +} + void InductiveRangeCheck::extractRangeChecksFromCond( Loop *L, ScalarEvolution &SE, Use &ConditionUse, SmallVectorImpl<InductiveRangeCheck> &Checks, @@ -369,32 +505,17 @@ void InductiveRangeCheck::extractRangeChecksFromCond( if (!ICI) return; - Value *Length = nullptr, *Index; - bool IsSigned; - if (!parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned)) + const SCEV *End = nullptr; + const SCEVAddRecExpr *IndexAddRec = nullptr; + if (!parseRangeCheckICmp(L, ICI, SE, IndexAddRec, End)) return; - const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index)); - bool IsAffineIndex = - IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine(); + assert(IndexAddRec && "IndexAddRec was not computed"); + assert(End && "End was not computed"); - if (!IsAffineIndex) + if ((IndexAddRec->getLoop() != L) || !IndexAddRec->isAffine()) return; - const SCEV *End = nullptr; - // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". - // We can potentially do much better here. - if (Length) - End = SE.getSCEV(Length); - else { - // So far we can only reach this point for Signed range check. This may - // change in future. In this case we will need to pick Unsigned max for the - // unsigned range check. 
- unsigned BitWidth = cast<IntegerType>(IndexAddRec->getType())->getBitWidth(); - const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); - End = SIntMax; - } - InductiveRangeCheck IRC; IRC.End = End; IRC.Begin = IndexAddRec->getStart(); @@ -405,16 +526,29 @@ void InductiveRangeCheck::extractRangeChecksFromCond( void InductiveRangeCheck::extractRangeChecksFromBranch( BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI, - SmallVectorImpl<InductiveRangeCheck> &Checks) { + SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed) { if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) return; + unsigned IndexLoopSucc = L->contains(BI->getSuccessor(0)) ? 0 : 1; + assert(L->contains(BI->getSuccessor(IndexLoopSucc)) && + "No edges coming to loop?"); BranchProbability LikelyTaken(15, 16); if (!SkipProfitabilityChecks && BPI && - BPI->getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) + BPI->getEdgeProbability(BI->getParent(), IndexLoopSucc) < LikelyTaken) return; + // IRCE expects branch's true edge comes to loop. Invert branch for opposite + // case. + if (IndexLoopSucc != 0) { + IRBuilder<> Builder(BI); + InvertBranch(BI, Builder); + if (BPI) + BPI->swapSuccEdgesProbabilities(BI->getParent()); + Changed = true; + } + SmallPtrSet<Value *, 8> Visited; InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0), Checks, Visited); @@ -622,7 +756,7 @@ class LoopConstrainer { // Information about the original loop we started out with. Loop &OriginalLoop; - const SCEV *LatchTakenCount = nullptr; + const IntegerType *ExitCountTy = nullptr; BasicBlock *OriginalPreheader = nullptr; // The preheader of the main loop. This may or may not be different from @@ -671,8 +805,7 @@ static bool isSafeDecreasingBound(const SCEV *Start, LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); - LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) - << "\n"); + LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n"); LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); bool IsSigned = ICmpInst::isSigned(Pred); @@ -719,8 +852,7 @@ static bool isSafeIncreasingBound(const SCEV *Start, LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); - LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) - << "\n"); + LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n"); LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); bool IsSigned = ICmpInst::isSigned(Pred); @@ -746,6 +878,19 @@ static bool isSafeIncreasingBound(const SCEV *Start, SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); } +/// Returns estimate for max latch taken count of the loop of the narrowest +/// available type. If the latch block has such estimate, it is returned. +/// Otherwise, we use max exit count of whole loop (that is potentially of wider +/// type than latch check itself), which is still better than no estimate. 
+static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE, + const Loop &L) { + const SCEV *FromBlock = + SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum); + if (isa<SCEVCouldNotCompute>(FromBlock)) + return SE.getSymbolicMaxBackedgeTakenCount(&L); + return FromBlock; +} + std::optional<LoopStructure> LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, const char *&FailureReason) { @@ -788,11 +933,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, return std::nullopt; } - const SCEV *LatchCount = SE.getExitCount(&L, Latch); - if (isa<SCEVCouldNotCompute>(LatchCount)) { + const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L); + if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) { FailureReason = "could not compute latch count"; return std::nullopt; } + assert(SE.getLoopDisposition(MaxBETakenCount, &L) == + ScalarEvolution::LoopInvariant && + "loop variant exit count doesn't make sense!"); ICmpInst::Predicate Pred = ICI->getPredicate(); Value *LeftValue = ICI->getOperand(0); @@ -1017,10 +1165,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, } BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); - assert(SE.getLoopDisposition(LatchCount, &L) == - ScalarEvolution::LoopInvariant && - "loop variant exit count doesn't make sense!"); - assert(!L.contains(LatchExit) && "expected an exit block!"); const DataLayout &DL = Preheader->getModule()->getDataLayout(); SCEVExpander Expander(SE, DL, "irce"); @@ -1062,14 +1206,11 @@ static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, std::optional<LoopConstrainer::SubRanges> LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { - IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); - auto *RTy = cast<IntegerType>(Range.getType()); - // We only support wide range checks and narrow latches. - if (!AllowNarrowLatchCondition && RTy != Ty) + if (!AllowNarrowLatchCondition && RTy != ExitCountTy) return std::nullopt; - if (RTy->getBitWidth() < Ty->getBitWidth()) + if (RTy->getBitWidth() < ExitCountTy->getBitWidth()) return std::nullopt; LoopConstrainer::SubRanges Result; @@ -1403,10 +1544,12 @@ Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, bool LoopConstrainer::run() { BasicBlock *Preheader = nullptr; - LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch); + const SCEV *MaxBETakenCount = + getNarrowestLatchMaxTakenCountEstimate(SE, OriginalLoop); Preheader = OriginalLoop.getLoopPreheader(); - assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr && + assert(!isa<SCEVCouldNotCompute>(MaxBETakenCount) && Preheader != nullptr && "preconditions!"); + ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType()); OriginalPreheader = Preheader; MainLoopPreheader = Preheader; @@ -1574,6 +1717,27 @@ bool LoopConstrainer::run() { CanonicalizeLoop(PostL, false); CanonicalizeLoop(&OriginalLoop, true); + /// At this point: + /// - We've broken a "main loop" out of the loop in a way that the "main loop" + /// runs with the induction variable in a subset of [Begin, End). + /// - There is no overflow when computing "main loop" exit limit. + /// - Max latch taken count of the loop is limited. + /// It guarantees that induction variable will not overflow iterating in the + /// "main loop". 
+ if (auto BO = dyn_cast<BinaryOperator>(MainLoopStructure.IndVarBase)) + if (IsSignedPredicate) + BO->setHasNoSignedWrap(true); + /// TODO: support unsigned predicate. + /// To add NUW flag we need to prove that both operands of BO are + /// non-negative. E.g: + /// ... + /// %iv.next = add nsw i32 %iv, -1 + /// %cmp = icmp ult i32 %iv.next, %n + /// br i1 %cmp, label %loopexit, label %loop + /// + /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will + /// overflow, therefore NUW flag is not legal here. + return true; } @@ -1588,11 +1752,13 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, // if latch check is more narrow. auto *IVType = dyn_cast<IntegerType>(IndVar->getType()); auto *RCType = dyn_cast<IntegerType>(getBegin()->getType()); + auto *EndType = dyn_cast<IntegerType>(getEnd()->getType()); // Do not work with pointer types. if (!IVType || !RCType) return std::nullopt; if (IVType->getBitWidth() > RCType->getBitWidth()) return std::nullopt; + // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is @@ -1631,6 +1797,7 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, assert(!D->getValue()->isZero() && "Recurrence with zero step?"); unsigned BitWidth = RCType->getBitWidth(); const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + const SCEV *SIntMin = SE.getConstant(APInt::getSignedMinValue(BitWidth)); // Subtract Y from X so that it does not go through border of the IV // iteration space. Mathematically, it is equivalent to: @@ -1682,6 +1849,7 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, // This function returns SCEV equal to 1 if X is non-negative 0 otherwise. auto SCEVCheckNonNegative = [&](const SCEV *X) { const Loop *L = IndVar->getLoop(); + const SCEV *Zero = SE.getZero(X->getType()); const SCEV *One = SE.getOne(X->getType()); // Can we trivially prove that X is a non-negative or negative value? if (isKnownNonNegativeInLoop(X, L, SE)) @@ -1693,6 +1861,25 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, const SCEV *NegOne = SE.getNegativeSCEV(One); return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One); }; + + // This function returns SCEV equal to 1 if X will not overflow in terms of + // range check type, 0 otherwise. + auto SCEVCheckWillNotOverflow = [&](const SCEV *X) { + // X doesn't overflow if SINT_MAX >= X. + // Then if (SINT_MAX - X) >= 0, X doesn't overflow + const SCEV *SIntMaxExt = SE.getSignExtendExpr(SIntMax, X->getType()); + const SCEV *OverflowCheck = + SCEVCheckNonNegative(SE.getMinusSCEV(SIntMaxExt, X)); + + // X doesn't underflow if X >= SINT_MIN. + // Then if (X - SINT_MIN) >= 0, X doesn't underflow + const SCEV *SIntMinExt = SE.getSignExtendExpr(SIntMin, X->getType()); + const SCEV *UnderflowCheck = + SCEVCheckNonNegative(SE.getMinusSCEV(X, SIntMinExt)); + + return SE.getMulExpr(OverflowCheck, UnderflowCheck); + }; + // FIXME: Current implementation of ClampedSubtract implicitly assumes that // X is non-negative (in sense of a signed value). We need to re-implement // this function in a way that it will correctly handle negative X as well. @@ -1702,10 +1889,35 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, // Note that this may pessimize elimination of unsigned range checks against // negative values. 
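A worked example of the scaled-boundary case, with invented run-time values: suppose a check was reassociated to IV < Offset + Limit on i32 with Offset = 2,000,000,000 and Limit = 1,000,000,000. The sum, 3,000,000,000, exceeds SINT_MAX (2,147,483,647), so End is carried as an i64 expression; at run time SINT_MAX - End is negative, the overflow factor computed below evaluates to 0, and the safe iteration space collapses to the empty range [0, 0), so no iteration is placed in the unchecked main loop.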
const SCEV *REnd = getEnd(); - const SCEV *EndIsNonNegative = SCEVCheckNonNegative(REnd); + const SCEV *EndWillNotOverflow = SE.getOne(RCType); + + auto PrintRangeCheck = [&](raw_ostream &OS) { + auto L = IndVar->getLoop(); + OS << "irce: in function "; + OS << L->getHeader()->getParent()->getName(); + OS << ", in "; + L->print(OS); + OS << "there is range check with scaled boundary:\n"; + print(OS); + }; + + if (EndType->getBitWidth() > RCType->getBitWidth()) { + assert(EndType->getBitWidth() == RCType->getBitWidth() * 2); + if (PrintScaledBoundaryRangeChecks) + PrintRangeCheck(errs()); + // End is computed with extended type but will be truncated to a narrow one + // type of range check. Therefore we need a check that the result will not + // overflow in terms of narrow type. + EndWillNotOverflow = + SE.getTruncateExpr(SCEVCheckWillNotOverflow(REnd), RCType); + REnd = SE.getTruncateExpr(REnd, RCType); + } + + const SCEV *RuntimeChecks = + SE.getMulExpr(SCEVCheckNonNegative(REnd), EndWillNotOverflow); + const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), RuntimeChecks); + const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), RuntimeChecks); - const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), EndIsNonNegative); - const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), EndIsNonNegative); return InductiveRangeCheck::Range(Begin, End); } @@ -1825,39 +2037,6 @@ PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) { return getLoopPassPreservedAnalyses(); } -bool IRCELegacyPass::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - - ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - BranchProbabilityInfo &BPI = - getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI); - - bool Changed = false; - - for (const auto &L : LI) { - Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, - /*PreserveLCSSA=*/false); - Changed |= formLCSSARecursively(*L, DT, &LI, &SE); - } - - SmallPriorityWorklist<Loop *, 4> Worklist; - appendLoopsToWorklist(LI, Worklist); - auto LPMAddNewLoop = [&](Loop *NL, bool IsSubloop) { - if (!IsSubloop) - appendLoopsToWorklist(*NL, Worklist); - }; - - while (!Worklist.empty()) { - Loop *L = Worklist.pop_back_val(); - Changed |= IRCE.run(L, LPMAddNewLoop); - } - return Changed; -} - bool InductiveRangeCheckElimination::isProfitableToTransform(const Loop &L, LoopStructure &LS) { @@ -1904,14 +2083,15 @@ bool InductiveRangeCheckElimination::run( LLVMContext &Context = Preheader->getContext(); SmallVector<InductiveRangeCheck, 16> RangeChecks; + bool Changed = false; for (auto *BBI : L->getBlocks()) if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI, - RangeChecks); + RangeChecks, Changed); if (RangeChecks.empty()) - return false; + return Changed; auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) { OS << "irce: looking at loop "; L->print(OS); @@ -1932,16 +2112,15 @@ bool InductiveRangeCheckElimination::run( if (!MaybeLoopStructure) { LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason << "\n";); - return false; + return Changed; } LoopStructure LS = *MaybeLoopStructure; if (!isProfitableToTransform(*L, LS)) - return false; + return Changed; const SCEVAddRecExpr *IndVar = 
cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep))); std::optional<InductiveRangeCheck::Range> SafeIterRange; - Instruction *ExprInsertPt = Preheader->getTerminator(); SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate; // Basing on the type of latch predicate, we interpret the IV iteration range @@ -1951,7 +2130,6 @@ bool InductiveRangeCheckElimination::run( auto IntersectRange = LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange; - IRBuilder<> B(ExprInsertPt); for (InductiveRangeCheck &IRC : RangeChecks) { auto Result = IRC.computeSafeIterationSpace(SE, IndVar, LS.IsSignedPredicate); @@ -1967,12 +2145,13 @@ bool InductiveRangeCheckElimination::run( } if (!SafeIterRange) - return false; + return Changed; LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, *SafeIterRange); - bool Changed = LC.run(); - if (Changed) { + if (LC.run()) { + Changed = true; + auto PrintConstrainedLoopInfo = [L]() { dbgs() << "irce: in function "; dbgs() << L->getHeader()->getParent()->getName() << ": "; @@ -1997,7 +2176,3 @@ bool InductiveRangeCheckElimination::run( return Changed; } - -Pass *llvm::createInductiveRangeCheckEliminationPass() { - return new IRCELegacyPass(); -} diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index 114738a35fd1..c2b5a12fd63f 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -76,14 +76,14 @@ // Second, IR rewriting in Step 2 also needs to be circular. For example, // converting %y to addrspace(3) requires the compiler to know the converted // %y2, but converting %y2 needs the converted %y. To address this complication, -// we break these cycles using "undef" placeholders. When converting an +// we break these cycles using "poison" placeholders. When converting an // instruction `I` to a new address space, if its operand `Op` is not converted -// yet, we let `I` temporarily use `undef` and fix all the uses of undef later. +// yet, we let `I` temporarily use `poison` and fix all the uses later. // For instance, our algorithm first converts %y to -// %y' = phi float addrspace(3)* [ %input, undef ] +// %y' = phi float addrspace(3)* [ %input, poison ] // Then, it converts %y2 to // %y2' = getelementptr %y', 1 -// Finally, it fixes the undef in %y' so that +// Finally, it fixes the poison in %y' so that // %y' = phi float addrspace(3)* [ %input, %y2' ] // //===----------------------------------------------------------------------===// @@ -206,7 +206,7 @@ class InferAddressSpacesImpl { Instruction *I, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, - SmallVectorImpl<const Use *> *UndefUsesToFix) const; + SmallVectorImpl<const Use *> *PoisonUsesToFix) const; // Changes the flat address expressions in function F to point to specific // address spaces if InferredAddrSpace says so. 
Postorder is the postorder of @@ -233,7 +233,7 @@ class InferAddressSpacesImpl { Value *V, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, - SmallVectorImpl<const Use *> *UndefUsesToFix) const; + SmallVectorImpl<const Use *> *PoisonUsesToFix) const; unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const; unsigned getPredicatedAddrSpace(const Value &V, Value *Opnd) const; @@ -256,6 +256,12 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", false, false) +static Type *getPtrOrVecOfPtrsWithNewAS(Type *Ty, unsigned NewAddrSpace) { + assert(Ty->isPtrOrPtrVectorTy()); + PointerType *NPT = PointerType::get(Ty->getContext(), NewAddrSpace); + return Ty->getWithNewType(NPT); +} + // Check whether that's no-op pointer bicast using a pair of // `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over // different address spaces. @@ -301,14 +307,14 @@ static bool isAddressExpression(const Value &V, const DataLayout &DL, switch (Op->getOpcode()) { case Instruction::PHI: - assert(Op->getType()->isPointerTy()); + assert(Op->getType()->isPtrOrPtrVectorTy()); return true; case Instruction::BitCast: case Instruction::AddrSpaceCast: case Instruction::GetElementPtr: return true; case Instruction::Select: - return Op->getType()->isPointerTy(); + return Op->getType()->isPtrOrPtrVectorTy(); case Instruction::Call: { const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V); return II && II->getIntrinsicID() == Intrinsic::ptrmask; @@ -373,6 +379,24 @@ bool InferAddressSpacesImpl::rewriteIntrinsicOperands(IntrinsicInst *II, case Intrinsic::ptrmask: // This is handled as an address expression, not as a use memory operation. return false; + case Intrinsic::masked_gather: { + Type *RetTy = II->getType(); + Type *NewPtrTy = NewV->getType(); + Function *NewDecl = + Intrinsic::getDeclaration(M, II->getIntrinsicID(), {RetTy, NewPtrTy}); + II->setArgOperand(0, NewV); + II->setCalledFunction(NewDecl); + return true; + } + case Intrinsic::masked_scatter: { + Type *ValueTy = II->getOperand(0)->getType(); + Type *NewPtrTy = NewV->getType(); + Function *NewDecl = + Intrinsic::getDeclaration(M, II->getIntrinsicID(), {ValueTy, NewPtrTy}); + II->setArgOperand(1, NewV); + II->setCalledFunction(NewDecl); + return true; + } default: { Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); if (!Rewrite) @@ -394,6 +418,14 @@ void InferAddressSpacesImpl::collectRewritableIntrinsicOperands( appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), PostorderStack, Visited); break; + case Intrinsic::masked_gather: + appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), + PostorderStack, Visited); + break; + case Intrinsic::masked_scatter: + appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(1), + PostorderStack, Visited); + break; default: SmallVector<int, 2> OpIndexes; if (TTI->collectFlatAddressOperands(OpIndexes, IID)) { @@ -412,7 +444,7 @@ void InferAddressSpacesImpl::collectRewritableIntrinsicOperands( void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack( Value *V, PostorderStackTy &PostorderStack, DenseSet<Value *> &Visited) const { - assert(V->getType()->isPointerTy()); + assert(V->getType()->isPtrOrPtrVectorTy()); // Generic addressing expressions may be hidden in nested constant // expressions. 
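The masked_gather/masked_scatter handling above relies on being able to rebuild a pointer type, or a vector-of-pointers type, in a different address space. A small self-contained C++ sketch of that type surgery (my own example mirroring the getPtrOrVecOfPtrsWithNewAS recipe; it assumes linking against LLVM's IR and Support libraries):

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  // A generic pointer and a 4-wide vector of generic pointers.
  Type *Ptr = PointerType::get(Ctx, /*AddressSpace=*/0);
  Type *VecOfPtr = FixedVectorType::get(Ptr, 4);

  // Same recipe as getPtrOrVecOfPtrsWithNewAS: build a pointer in the new
  // address space and graft it back as the (vector) element type.
  auto Rewrite = [&](Type *Ty, unsigned NewAS) {
    return Ty->getWithNewType(PointerType::get(Ctx, NewAS));
  };

  Rewrite(Ptr, 3)->print(outs());      // expected output: ptr addrspace(3)
  outs() << "\n";
  Rewrite(VecOfPtr, 3)->print(outs()); // expected output: <4 x ptr addrspace(3)>
  outs() << "\n";
  return 0;
}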
@@ -460,8 +492,7 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { // addressing calculations may also be faster. for (Instruction &I : instructions(F)) { if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { - if (!GEP->getType()->isVectorTy()) - PushPtrOperand(GEP->getPointerOperand()); + PushPtrOperand(GEP->getPointerOperand()); } else if (auto *LI = dyn_cast<LoadInst>(&I)) PushPtrOperand(LI->getPointerOperand()); else if (auto *SI = dyn_cast<StoreInst>(&I)) @@ -480,14 +511,12 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { } else if (auto *II = dyn_cast<IntrinsicInst>(&I)) collectRewritableIntrinsicOperands(II, PostorderStack, Visited); else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) { - // FIXME: Handle vectors of pointers - if (Cmp->getOperand(0)->getType()->isPointerTy()) { + if (Cmp->getOperand(0)->getType()->isPtrOrPtrVectorTy()) { PushPtrOperand(Cmp->getOperand(0)); PushPtrOperand(Cmp->getOperand(1)); } } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) { - if (!ASC->getType()->isVectorTy()) - PushPtrOperand(ASC->getPointerOperand()); + PushPtrOperand(ASC->getPointerOperand()); } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) { if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI)) PushPtrOperand( @@ -521,16 +550,15 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { // A helper function for cloneInstructionWithNewAddressSpace. Returns the clone // of OperandUse.get() in the new address space. If the clone is not ready yet, -// returns an undef in the new address space as a placeholder. -static Value *operandWithNewAddressSpaceOrCreateUndef( +// returns poison in the new address space as a placeholder. +static Value *operandWithNewAddressSpaceOrCreatePoison( const Use &OperandUse, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, - SmallVectorImpl<const Use *> *UndefUsesToFix) { + SmallVectorImpl<const Use *> *PoisonUsesToFix) { Value *Operand = OperandUse.get(); - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast<PointerType>(Operand->getType()), NewAddrSpace); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAddrSpace); if (Constant *C = dyn_cast<Constant>(Operand)) return ConstantExpr::getAddrSpaceCast(C, NewPtrTy); @@ -543,23 +571,22 @@ static Value *operandWithNewAddressSpaceOrCreateUndef( if (I != PredicatedAS.end()) { // Insert an addrspacecast on that operand before the user. unsigned NewAS = I->second; - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast<PointerType>(Operand->getType()), NewAS); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS); auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy); NewI->insertBefore(Inst); NewI->setDebugLoc(Inst->getDebugLoc()); return NewI; } - UndefUsesToFix->push_back(&OperandUse); - return UndefValue::get(NewPtrTy); + PoisonUsesToFix->push_back(&OperandUse); + return PoisonValue::get(NewPtrTy); } // Returns a clone of `I` with its operands converted to those specified in // ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an // operand whose address space needs to be modified might not exist in -// ValueWithNewAddrSpace. In that case, uses undef as a placeholder operand and -// adds that operand use to UndefUsesToFix so that caller can fix them later. +// ValueWithNewAddrSpace. 
In that case, uses poison as a placeholder operand and +// adds that operand use to PoisonUsesToFix so that caller can fix them later. // // Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast // from a pointer whose type already matches. Therefore, this function returns a @@ -571,9 +598,8 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( Instruction *I, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, - SmallVectorImpl<const Use *> *UndefUsesToFix) const { - Type *NewPtrType = PointerType::getWithSamePointeeType( - cast<PointerType>(I->getType()), NewAddrSpace); + SmallVectorImpl<const Use *> *PoisonUsesToFix) const { + Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace); if (I->getOpcode() == Instruction::AddrSpaceCast) { Value *Src = I->getOperand(0); @@ -590,9 +616,9 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( // Technically the intrinsic ID is a pointer typed argument, so specially // handle calls early. assert(II->getIntrinsicID() == Intrinsic::ptrmask); - Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef( + Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison( II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace, - PredicatedAS, UndefUsesToFix); + PredicatedAS, PoisonUsesToFix); Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr); if (Rewrite) { @@ -607,8 +633,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( if (AS != UninitializedAddressSpace) { // For the assumed address space, insert an `addrspacecast` to make that // explicit. - Type *NewPtrTy = PointerType::getWithSamePointeeType( - cast<PointerType>(I->getType()), AS); + Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS); auto *NewI = new AddrSpaceCastInst(I, NewPtrTy); NewI->insertAfter(I); return NewI; @@ -617,19 +642,19 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( // Computes the converted pointer operands. 
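The poison-placeholder scheme used here (and described in the file comment above) is a two-pass rewrite of a possibly cyclic graph: clone with placeholders, remember the unresolved uses, then patch them once every clone exists. A stripped-down standalone C++ model of just that bookkeeping, with all names invented for the illustration:

#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

struct Node {
  const char *Name;
  std::vector<Node *> Operands;
};

int main() {
  // Two nodes that reference each other, like the %y / %y2 cycle in the
  // comment at the top of this file.
  Node Y{"y", {}}, Y2{"y2", {}};
  Y.Operands = {&Y2};
  Y2.Operands = {&Y};

  Node Poison{"poison", {}};
  std::unordered_map<Node *, Node *> Cloned;               // old -> new
  std::vector<std::pair<Node *, size_t>> PoisonUsesToFix;  // (old user, op index)

  // Pass 1: clone each node; an operand whose clone does not exist yet gets
  // the placeholder, and the use is recorded for later.
  for (Node *Old : {&Y, &Y2}) {
    Node *New = new Node{Old->Name, {}};  // leaked for brevity
    for (size_t I = 0, E = Old->Operands.size(); I != E; ++I) {
      auto It = Cloned.find(Old->Operands[I]);
      if (It != Cloned.end()) {
        New->Operands.push_back(It->second);
      } else {
        New->Operands.push_back(&Poison);
        PoisonUsesToFix.push_back({Old, I});
      }
    }
    Cloned[Old] = New;
  }

  // Pass 2: every clone now exists, so patch the recorded placeholder uses.
  for (auto &[OldUser, Idx] : PoisonUsesToFix)
    Cloned[OldUser]->Operands[Idx] = Cloned[OldUser->Operands[Idx]];

  assert(Cloned[&Y]->Operands[0] == Cloned[&Y2]);
  assert(Cloned[&Y2]->Operands[0] == Cloned[&Y]);
  return 0;
}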
SmallVector<Value *, 4> NewPointerOperands; for (const Use &OperandUse : I->operands()) { - if (!OperandUse.get()->getType()->isPointerTy()) + if (!OperandUse.get()->getType()->isPtrOrPtrVectorTy()) NewPointerOperands.push_back(nullptr); else - NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef( + NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreatePoison( OperandUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, - UndefUsesToFix)); + PoisonUsesToFix)); } switch (I->getOpcode()) { case Instruction::BitCast: return new BitCastInst(NewPointerOperands[0], NewPtrType); case Instruction::PHI: { - assert(I->getType()->isPointerTy()); + assert(I->getType()->isPtrOrPtrVectorTy()); PHINode *PHI = cast<PHINode>(I); PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues()); for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) { @@ -648,7 +673,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( return NewGEP; } case Instruction::Select: - assert(I->getType()->isPointerTy()); + assert(I->getType()->isPtrOrPtrVectorTy()); return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], NewPointerOperands[2], "", nullptr, I); case Instruction::IntToPtr: { @@ -674,10 +699,10 @@ static Value *cloneConstantExprWithNewAddressSpace( ConstantExpr *CE, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL, const TargetTransformInfo *TTI) { - Type *TargetType = CE->getType()->isPointerTy() - ? PointerType::getWithSamePointeeType( - cast<PointerType>(CE->getType()), NewAddrSpace) - : CE->getType(); + Type *TargetType = + CE->getType()->isPtrOrPtrVectorTy() + ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace) + : CE->getType(); if (CE->getOpcode() == Instruction::AddrSpaceCast) { // Because CE is flat, the source address space must be specific. @@ -694,18 +719,6 @@ static Value *cloneConstantExprWithNewAddressSpace( return ConstantExpr::getAddrSpaceCast(CE, TargetType); } - if (CE->getOpcode() == Instruction::Select) { - Constant *Src0 = CE->getOperand(1); - Constant *Src1 = CE->getOperand(2); - if (Src0->getType()->getPointerAddressSpace() == - Src1->getType()->getPointerAddressSpace()) { - - return ConstantExpr::getSelect( - CE->getOperand(0), ConstantExpr::getAddrSpaceCast(Src0, TargetType), - ConstantExpr::getAddrSpaceCast(Src1, TargetType)); - } - } - if (CE->getOpcode() == Instruction::IntToPtr) { assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI)); Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0); @@ -758,19 +771,19 @@ static Value *cloneConstantExprWithNewAddressSpace( // ValueWithNewAddrSpace. This function is called on every flat address // expression whose address space needs to be modified, in postorder. // -// See cloneInstructionWithNewAddressSpace for the meaning of UndefUsesToFix. +// See cloneInstructionWithNewAddressSpace for the meaning of PoisonUsesToFix. Value *InferAddressSpacesImpl::cloneValueWithNewAddressSpace( Value *V, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, - SmallVectorImpl<const Use *> *UndefUsesToFix) const { + SmallVectorImpl<const Use *> *PoisonUsesToFix) const { // All values in Postorder are flat address expressions. 
assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace && isAddressExpression(*V, *DL, TTI)); if (Instruction *I = dyn_cast<Instruction>(V)) { Value *NewV = cloneInstructionWithNewAddressSpace( - I, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, UndefUsesToFix); + I, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, PoisonUsesToFix); if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) { if (NewI->getParent() == nullptr) { NewI->insertBefore(I); @@ -1114,7 +1127,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( // operands are converted, the clone is naturally in the new address space by // construction. ValueToValueMapTy ValueWithNewAddrSpace; - SmallVector<const Use *, 32> UndefUsesToFix; + SmallVector<const Use *, 32> PoisonUsesToFix; for (Value* V : Postorder) { unsigned NewAddrSpace = InferredAddrSpace.lookup(V); @@ -1126,7 +1139,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { Value *New = cloneValueWithNewAddressSpace(V, NewAddrSpace, ValueWithNewAddrSpace, - PredicatedAS, &UndefUsesToFix); + PredicatedAS, &PoisonUsesToFix); if (New) ValueWithNewAddrSpace[V] = New; } @@ -1135,16 +1148,16 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( if (ValueWithNewAddrSpace.empty()) return false; - // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace. - for (const Use *UndefUse : UndefUsesToFix) { - User *V = UndefUse->getUser(); + // Fixes all the poison uses generated by cloneInstructionWithNewAddressSpace. + for (const Use *PoisonUse : PoisonUsesToFix) { + User *V = PoisonUse->getUser(); User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V)); if (!NewV) continue; - unsigned OperandNo = UndefUse->getOperandNo(); - assert(isa<UndefValue>(NewV->getOperand(OperandNo))); - NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get())); + unsigned OperandNo = PoisonUse->getOperandNo(); + assert(isa<PoisonValue>(NewV->getOperand(OperandNo))); + NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(PoisonUse->get())); } SmallVector<Instruction *, 16> DeadInstructions; @@ -1238,20 +1251,6 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) { unsigned NewAS = NewV->getType()->getPointerAddressSpace(); if (ASC->getDestAddressSpace() == NewAS) { - if (!cast<PointerType>(ASC->getType()) - ->hasSameElementTypeAs( - cast<PointerType>(NewV->getType()))) { - BasicBlock::iterator InsertPos; - if (Instruction *NewVInst = dyn_cast<Instruction>(NewV)) - InsertPos = std::next(NewVInst->getIterator()); - else if (Instruction *VInst = dyn_cast<Instruction>(V)) - InsertPos = std::next(VInst->getIterator()); - else - InsertPos = ASC->getIterator(); - - NewV = CastInst::Create(Instruction::BitCast, NewV, - ASC->getType(), "", &*InsertPos); - } ASC->replaceAllUsesWith(NewV); DeadInstructions.push_back(ASC); continue; diff --git a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp index 4644905adba3..ee9452ce1c7d 100644 --- a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp +++ b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -11,7 +11,6 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -26,8 
+25,7 @@ using namespace llvm; STATISTIC(NumSimplified, "Number of redundant instructions removed"); -static bool runImpl(Function &F, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { +static bool runImpl(Function &F, const SimplifyQuery &SQ) { SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; bool Changed = false; @@ -51,7 +49,7 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ, DeadInstsInBB.push_back(&I); Changed = true; } else if (!I.use_empty()) { - if (Value *V = simplifyInstruction(&I, SQ, ORE)) { + if (Value *V = simplifyInstruction(&I, SQ)) { // Mark all uses for resimplification next time round the loop. for (User *U : I.users()) Next->insert(cast<Instruction>(U)); @@ -88,7 +86,6 @@ struct InstSimplifyLegacyPass : public FunctionPass { AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } /// Remove instructions that simplify. @@ -102,11 +99,9 @@ struct InstSimplifyLegacyPass : public FunctionPass { &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - OptimizationRemarkEmitter *ORE = - &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); const DataLayout &DL = F.getParent()->getDataLayout(); const SimplifyQuery SQ(DL, TLI, DT, AC); - return runImpl(F, SQ, ORE); + return runImpl(F, SQ); } }; } // namespace @@ -117,7 +112,6 @@ INITIALIZE_PASS_BEGIN(InstSimplifyLegacyPass, "instsimplify", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(InstSimplifyLegacyPass, "instsimplify", "Remove redundant instructions", false, false) @@ -131,10 +125,9 @@ PreservedAnalyses InstSimplifyPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); const DataLayout &DL = F.getParent()->getDataLayout(); const SimplifyQuery SQ(DL, &TLI, &DT, &AC); - bool Changed = runImpl(F, SQ, &ORE); + bool Changed = runImpl(F, SQ); if (!Changed) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index f41eaed2e3e7..24390f1b54f6 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -23,7 +23,6 @@ #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -31,6 +30,7 @@ #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -40,6 +40,7 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ 
-57,15 +58,12 @@ #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/BlockFrequency.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -114,68 +112,6 @@ static cl::opt<bool> ThreadAcrossLoopHeaders( cl::desc("Allow JumpThreading to thread across loop headers, for testing"), cl::init(false), cl::Hidden); - -namespace { - - /// This pass performs 'jump threading', which looks at blocks that have - /// multiple predecessors and multiple successors. If one or more of the - /// predecessors of the block can be proven to always jump to one of the - /// successors, we forward the edge from the predecessor to the successor by - /// duplicating the contents of this block. - /// - /// An example of when this can occur is code like this: - /// - /// if () { ... - /// X = 4; - /// } - /// if (X < 3) { - /// - /// In this case, the unconditional branch at the end of the first if can be - /// revectored to the false side of the second if. - class JumpThreading : public FunctionPass { - JumpThreadingPass Impl; - - public: - static char ID; // Pass identification - - JumpThreading(int T = -1) : FunctionPass(ID), Impl(T) { - initializeJumpThreadingPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<LazyValueInfoWrapperPass>(); - AU.addPreserved<LazyValueInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - } - - void releaseMemory() override { Impl.releaseMemory(); } - }; - -} // end anonymous namespace - -char JumpThreading::ID = 0; - -INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading", - "Jump Threading", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(JumpThreading, "jump-threading", - "Jump Threading", false, false) - -// Public interface to the Jump Threading pass -FunctionPass *llvm::createJumpThreadingPass(int Threshold) { - return new JumpThreading(Threshold); -} - JumpThreadingPass::JumpThreadingPass(int T) { DefaultBBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T); } @@ -306,102 +242,81 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { } } -/// runOnFunction - Toplevel algorithm. 
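With the legacy JumpThreading wrapper, its INITIALIZE_PASS machinery and createJumpThreadingPass() all removed here (the deletion continues just below), the pass is only reachable through the new pass manager. A minimal sketch of how a tool might now schedule it (my own example, not code from the patch; it uses the standard PassBuilder registration dance):

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/Scalar/JumpThreading.h"
using namespace llvm;

// Runs jump threading on a single function via the new pass manager.
void runJumpThreading(Function &F) {
  PassBuilder PB;
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  FunctionPassManager FPM;
  FPM.addPass(JumpThreadingPass());
  FPM.run(F, FAM);
}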
-bool JumpThreading::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - auto TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - // Jump Threading has no sense for the targets with divergent CF - if (TTI->hasBranchDivergence()) - return false; - auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); - auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy); - std::unique_ptr<BlockFrequencyInfo> BFI; - std::unique_ptr<BranchProbabilityInfo> BPI; - if (F.hasProfileData()) { - LoopInfo LI{*DT}; - BPI.reset(new BranchProbabilityInfo(F, LI, TLI)); - BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); - } - - bool Changed = Impl.runImpl(F, TLI, TTI, LVI, AA, &DTU, F.hasProfileData(), - std::move(BFI), std::move(BPI)); - if (PrintLVIAfterJumpThreading) { - dbgs() << "LVI for function '" << F.getName() << "':\n"; - LVI->printLVI(F, DTU.getDomTree(), dbgs()); - } - return Changed; -} - PreservedAnalyses JumpThreadingPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); // Jump Threading has no sense for the targets with divergent CF - if (TTI.hasBranchDivergence()) + if (TTI.hasBranchDivergence(&F)) return PreservedAnalyses::all(); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LVI = AM.getResult<LazyValueAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); - - std::unique_ptr<BlockFrequencyInfo> BFI; - std::unique_ptr<BranchProbabilityInfo> BPI; - if (F.hasProfileData()) { - LoopInfo LI{DT}; - BPI.reset(new BranchProbabilityInfo(F, LI, &TLI)); - BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); - } + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - bool Changed = runImpl(F, &TLI, &TTI, &LVI, &AA, &DTU, F.hasProfileData(), - std::move(BFI), std::move(BPI)); + bool Changed = + runImpl(F, &AM, &TLI, &TTI, &LVI, &AA, + std::make_unique<DomTreeUpdater>( + &DT, nullptr, DomTreeUpdater::UpdateStrategy::Lazy), + std::nullopt, std::nullopt); if (PrintLVIAfterJumpThreading) { dbgs() << "LVI for function '" << F.getName() << "':\n"; - LVI.printLVI(F, DTU.getDomTree(), dbgs()); + LVI.printLVI(F, getDomTreeUpdater()->getDomTree(), dbgs()); } if (!Changed) return PreservedAnalyses::all(); - PreservedAnalyses PA; - PA.preserve<DominatorTreeAnalysis>(); - PA.preserve<LazyValueAnalysis>(); - return PA; + + + getDomTreeUpdater()->flush(); + +#if defined(EXPENSIVE_CHECKS) + assert(getDomTreeUpdater()->getDomTree().verify( + DominatorTree::VerificationLevel::Full) && + "DT broken after JumpThreading"); + assert((!getDomTreeUpdater()->hasPostDomTree() || + getDomTreeUpdater()->getPostDomTree().verify( + PostDominatorTree::VerificationLevel::Full)) && + "PDT broken after JumpThreading"); +#else + assert(getDomTreeUpdater()->getDomTree().verify( + DominatorTree::VerificationLevel::Fast) && + "DT broken after JumpThreading"); + assert((!getDomTreeUpdater()->hasPostDomTree() || + getDomTreeUpdater()->getPostDomTree().verify( + PostDominatorTree::VerificationLevel::Fast)) && + "PDT broken after JumpThreading"); +#endif + + return getPreservedAnalysis(); } -bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, +bool JumpThreadingPass::runImpl(Function &F_, FunctionAnalysisManager *FAM_, + 
TargetLibraryInfo *TLI_, TargetTransformInfo *TTI_, LazyValueInfo *LVI_, - AliasAnalysis *AA_, DomTreeUpdater *DTU_, - bool HasProfileData_, - std::unique_ptr<BlockFrequencyInfo> BFI_, - std::unique_ptr<BranchProbabilityInfo> BPI_) { - LLVM_DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); + AliasAnalysis *AA_, + std::unique_ptr<DomTreeUpdater> DTU_, + std::optional<BlockFrequencyInfo *> BFI_, + std::optional<BranchProbabilityInfo *> BPI_) { + LLVM_DEBUG(dbgs() << "Jump threading on function '" << F_.getName() << "'\n"); + F = &F_; + FAM = FAM_; TLI = TLI_; TTI = TTI_; LVI = LVI_; AA = AA_; - DTU = DTU_; - BFI.reset(); - BPI.reset(); - // When profile data is available, we need to update edge weights after - // successful jump threading, which requires both BPI and BFI being available. - HasProfileData = HasProfileData_; - auto *GuardDecl = F.getParent()->getFunction( + DTU = std::move(DTU_); + BFI = BFI_; + BPI = BPI_; + auto *GuardDecl = F->getParent()->getFunction( Intrinsic::getName(Intrinsic::experimental_guard)); HasGuards = GuardDecl && !GuardDecl->use_empty(); - if (HasProfileData) { - BPI = std::move(BPI_); - BFI = std::move(BFI_); - } // Reduce the number of instructions duplicated when optimizing strictly for // size. if (BBDuplicateThreshold.getNumOccurrences()) BBDupThreshold = BBDuplicateThreshold; - else if (F.hasFnAttribute(Attribute::MinSize)) + else if (F->hasFnAttribute(Attribute::MinSize)) BBDupThreshold = 3; else BBDupThreshold = DefaultBBDupThreshold; @@ -412,22 +327,22 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, assert(DTU && "DTU isn't passed into JumpThreading before using it."); assert(DTU->hasDomTree() && "JumpThreading relies on DomTree to proceed."); DominatorTree &DT = DTU->getDomTree(); - for (auto &BB : F) + for (auto &BB : *F) if (!DT.isReachableFromEntry(&BB)) Unreachable.insert(&BB); if (!ThreadAcrossLoopHeaders) - findLoopHeaders(F); + findLoopHeaders(*F); bool EverChanged = false; bool Changed; do { Changed = false; - for (auto &BB : F) { + for (auto &BB : *F) { if (Unreachable.count(&BB)) continue; while (processBlock(&BB)) // Thread all of the branches we can over BB. - Changed = true; + Changed = ChangedSinceLastAnalysisUpdate = true; // Jump threading may have introduced redundant debug values into BB // which should be removed. @@ -437,7 +352,7 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // Stop processing BB if it's the entry or is now deleted. The following // routines attempt to eliminate BB and locating a suitable replacement // for the entry is non-trivial. - if (&BB == &F.getEntryBlock() || DTU->isBBPendingDeletion(&BB)) + if (&BB == &F->getEntryBlock() || DTU->isBBPendingDeletion(&BB)) continue; if (pred_empty(&BB)) { @@ -448,8 +363,8 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, << '\n'); LoopHeaders.erase(&BB); LVI->eraseBlock(&BB); - DeleteDeadBlock(&BB, DTU); - Changed = true; + DeleteDeadBlock(&BB, DTU.get()); + Changed = ChangedSinceLastAnalysisUpdate = true; continue; } @@ -464,12 +379,12 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // Don't alter Loop headers and latches to ensure another pass can // detect and transform nested loops later. !LoopHeaders.count(&BB) && !LoopHeaders.count(Succ) && - TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) { + TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU.get())) { RemoveRedundantDbgInstrs(Succ); // BB is valid for cleanup here because we passed in DTU. 
F remains // BB's parent until a DTU->getDomTree() event. LVI->eraseBlock(&BB); - Changed = true; + Changed = ChangedSinceLastAnalysisUpdate = true; } } } @@ -1140,8 +1055,8 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { << "' folding terminator: " << *BB->getTerminator() << '\n'); ++NumFolds; - ConstantFoldTerminator(BB, true, nullptr, DTU); - if (HasProfileData) + ConstantFoldTerminator(BB, true, nullptr, DTU.get()); + if (auto *BPI = getBPI()) BPI->eraseBlock(BB); return true; } @@ -1296,7 +1211,7 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) { FICond->eraseFromParent(); DTU->applyUpdatesPermissive({{DominatorTree::Delete, BB, RemoveSucc}}); - if (HasProfileData) + if (auto *BPI = getBPI()) BPI->eraseBlock(BB); return true; } @@ -1740,7 +1655,7 @@ bool JumpThreadingPass::processThreadableEdges(Value *Cond, BasicBlock *BB, ++NumFolds; Term->eraseFromParent(); DTU->applyUpdatesPermissive(Updates); - if (HasProfileData) + if (auto *BPI = getBPI()) BPI->eraseBlock(BB); // If the condition is now dead due to the removal of the old terminator, @@ -1993,7 +1908,7 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) { LoopHeaders.insert(BB); LVI->eraseBlock(SinglePred); - MergeBasicBlockIntoOnlyPred(BB, DTU); + MergeBasicBlockIntoOnlyPred(BB, DTU.get()); // Now that BB is merged into SinglePred (i.e. SinglePred code followed by // BB code within one basic block `BB`), we need to invalidate the LVI @@ -2038,6 +1953,7 @@ void JumpThreadingPass::updateSSA( // PHI insertion, of which we are prepared to do, clean these up now. SSAUpdater SSAUpdate; SmallVector<Use *, 16> UsesToRename; + SmallVector<DbgValueInst *, 4> DbgValues; for (Instruction &I : *BB) { // Scan all uses of this instruction to see if it is used outside of its @@ -2053,8 +1969,16 @@ void JumpThreadingPass::updateSSA( UsesToRename.push_back(&U); } + // Find debug values outside of the block + findDbgValues(DbgValues, &I); + DbgValues.erase(remove_if(DbgValues, + [&](const DbgValueInst *DbgVal) { + return DbgVal->getParent() == BB; + }), + DbgValues.end()); + // If there are no uses outside the block, we're done with this instruction. - if (UsesToRename.empty()) + if (UsesToRename.empty() && DbgValues.empty()) continue; LLVM_DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n"); @@ -2067,6 +1991,11 @@ void JumpThreadingPass::updateSSA( while (!UsesToRename.empty()) SSAUpdate.RewriteUse(*UsesToRename.pop_back_val()); + if (!DbgValues.empty()) { + SSAUpdate.UpdateDebugValues(&I, DbgValues); + DbgValues.clear(); + } + LLVM_DEBUG(dbgs() << "\n"); } } @@ -2298,6 +2227,11 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB, LLVM_DEBUG(dbgs() << " Threading through '" << PredBB->getName() << "' and '" << BB->getName() << "'\n"); + // Build BPI/BFI before any changes are made to IR. + bool HasProfile = doesBlockHaveProfileData(BB); + auto *BFI = getOrCreateBFI(HasProfile); + auto *BPI = getOrCreateBPI(BFI != nullptr); + BranchInst *CondBr = cast<BranchInst>(BB->getTerminator()); BranchInst *PredBBBranch = cast<BranchInst>(PredBB->getTerminator()); @@ -2307,7 +2241,8 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB, NewBB->moveAfter(PredBB); // Set the block frequency of NewBB. 
- if (HasProfileData) { + if (BFI) { + assert(BPI && "It's expected BPI to exist along with BFI"); auto NewBBFreq = BFI->getBlockFreq(PredPredBB) * BPI->getEdgeProbability(PredPredBB, PredBB); BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); @@ -2320,7 +2255,7 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB, cloneInstructions(PredBB->begin(), PredBB->end(), NewBB, PredPredBB); // Copy the edge probabilities from PredBB to NewBB. - if (HasProfileData) + if (BPI) BPI->copyEdgeProbabilities(PredBB, NewBB); // Update the terminator of PredPredBB to jump to NewBB instead of PredBB. @@ -2404,6 +2339,11 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB, assert(!LoopHeaders.count(BB) && !LoopHeaders.count(SuccBB) && "Don't thread across loop headers"); + // Build BPI/BFI before any changes are made to IR. + bool HasProfile = doesBlockHaveProfileData(BB); + auto *BFI = getOrCreateBFI(HasProfile); + auto *BPI = getOrCreateBPI(BFI != nullptr); + // And finally, do it! Start by factoring the predecessors if needed. BasicBlock *PredBB; if (PredBBs.size() == 1) @@ -2427,7 +2367,8 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB, NewBB->moveAfter(PredBB); // Set the block frequency of NewBB. - if (HasProfileData) { + if (BFI) { + assert(BPI && "It's expected BPI to exist along with BFI"); auto NewBBFreq = BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB); BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); @@ -2469,7 +2410,7 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB, SimplifyInstructionsInBlock(NewBB, TLI); // Update the edge weight from BB to SuccBB, which should be less than before. - updateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB); + updateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB, BFI, BPI, HasProfile); // Threaded an edge! ++NumThreads; @@ -2486,10 +2427,13 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB, // Collect the frequencies of all predecessors of BB, which will be used to // update the edge weight of the result of splitting predecessors. DenseMap<BasicBlock *, BlockFrequency> FreqMap; - if (HasProfileData) + auto *BFI = getBFI(); + if (BFI) { + auto *BPI = getOrCreateBPI(true); for (auto *Pred : Preds) FreqMap.insert(std::make_pair( Pred, BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB))); + } // In the case when BB is a LandingPad block we create 2 new predecessors // instead of just one. @@ -2508,10 +2452,10 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB, for (auto *Pred : predecessors(NewBB)) { Updates.push_back({DominatorTree::Delete, Pred, BB}); Updates.push_back({DominatorTree::Insert, Pred, NewBB}); - if (HasProfileData) // Update frequencies between Pred -> NewBB. + if (BFI) // Update frequencies between Pred -> NewBB. NewBBFreq += FreqMap.lookup(Pred); } - if (HasProfileData) // Apply the summed frequency to NewBB. + if (BFI) // Apply the summed frequency to NewBB. 
BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); } @@ -2521,7 +2465,9 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB, bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { const Instruction *TI = BB->getTerminator(); - assert(TI->getNumSuccessors() > 1 && "not a split"); + if (!TI || TI->getNumSuccessors() < 2) + return false; + return hasValidBranchWeightMD(*TI); } @@ -2531,11 +2477,18 @@ bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB, BasicBlock *BB, BasicBlock *NewBB, - BasicBlock *SuccBB) { - if (!HasProfileData) + BasicBlock *SuccBB, + BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI, + bool HasProfile) { + assert(((BFI && BPI) || (!BFI && !BFI)) && + "Both BFI & BPI should either be set or unset"); + + if (!BFI) { + assert(!HasProfile && + "It's expected to have BFI/BPI when profile info exists"); return; - - assert(BFI && BPI && "BFI & BPI should have been created here"); + } // As the edge from PredBB to BB is deleted, we have to update the block // frequency of BB. @@ -2608,7 +2561,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB, // FIXME this locally as well so that BPI and BFI are consistent as well. We // shouldn't make edges extremely likely or unlikely based solely on static // estimation. - if (BBSuccProbs.size() >= 2 && doesBlockHaveProfileData(BB)) { + if (BBSuccProbs.size() >= 2 && HasProfile) { SmallVector<uint32_t, 4> Weights; for (auto Prob : BBSuccProbs) Weights.push_back(Prob.getNumerator()); @@ -2690,6 +2643,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( // mapping and using it to remap operands in the cloned instructions. for (; BI != BB->end(); ++BI) { Instruction *New = BI->clone(); + New->insertInto(PredBB, OldPredBranch->getIterator()); // Remap operands to patch up intra-block references. for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) @@ -2707,7 +2661,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( {BB->getModule()->getDataLayout(), TLI, nullptr, nullptr, New})) { ValueMapping[&*BI] = IV; if (!New->mayHaveSideEffects()) { - New->deleteValue(); + New->eraseFromParent(); New = nullptr; } } else { @@ -2716,7 +2670,6 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( if (New) { // Otherwise, insert the new instruction into the block. New->setName(BI->getName()); - New->insertInto(PredBB, OldPredBranch->getIterator()); // Update Dominance from simplified New instruction operands. for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) if (BasicBlock *SuccBB = dyn_cast<BasicBlock>(New->getOperand(i))) @@ -2740,7 +2693,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( // Remove the unconditional branch at the end of the PredBB block. OldPredBranch->eraseFromParent(); - if (HasProfileData) + if (auto *BPI = getBPI()) BPI->copyEdgeProbabilities(BB, PredBB); DTU->applyUpdatesPermissive(Updates); @@ -2777,21 +2730,30 @@ void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB, BI->copyMetadata(*SI, {LLVMContext::MD_prof}); SIUse->setIncomingValue(Idx, SI->getFalseValue()); SIUse->addIncoming(SI->getTrueValue(), NewBB); - // Set the block frequency of NewBB. 
- if (HasProfileData) { - uint64_t TrueWeight, FalseWeight; - if (extractBranchWeights(*SI, TrueWeight, FalseWeight) && - (TrueWeight + FalseWeight) != 0) { - SmallVector<BranchProbability, 2> BP; - BP.emplace_back(BranchProbability::getBranchProbability( - TrueWeight, TrueWeight + FalseWeight)); - BP.emplace_back(BranchProbability::getBranchProbability( - FalseWeight, TrueWeight + FalseWeight)); + + uint64_t TrueWeight = 1; + uint64_t FalseWeight = 1; + // Copy probabilities from 'SI' to created conditional branch in 'Pred'. + if (extractBranchWeights(*SI, TrueWeight, FalseWeight) && + (TrueWeight + FalseWeight) != 0) { + SmallVector<BranchProbability, 2> BP; + BP.emplace_back(BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight)); + BP.emplace_back(BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight)); + // Update BPI if exists. + if (auto *BPI = getBPI()) BPI->setEdgeProbability(Pred, BP); + } + // Set the block frequency of NewBB. + if (auto *BFI = getBFI()) { + if ((TrueWeight + FalseWeight) == 0) { + TrueWeight = 1; + FalseWeight = 1; } - - auto NewBBFreq = - BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, NewBB); + BranchProbability PredToNewBBProb = BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight); + auto NewBBFreq = BFI->getBlockFreq(Pred) * PredToNewBBProb; BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); } @@ -3112,3 +3074,93 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard, } return true; } + +PreservedAnalyses JumpThreadingPass::getPreservedAnalysis() const { + PreservedAnalyses PA; + PA.preserve<LazyValueAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + + // TODO: We would like to preserve BPI/BFI. Enable once all paths update them. + // TODO: Would be nice to verify BPI/BFI consistency as well. + return PA; +} + +template <typename AnalysisT> +typename AnalysisT::Result *JumpThreadingPass::runExternalAnalysis() { + assert(FAM && "Can't run external analysis without FunctionAnalysisManager"); + + // If there were no changes since last call to 'runExternalAnalysis' then all + // analysis is either up to date or explicitly invalidated. Just go ahead and + // run the "external" analysis. + if (!ChangedSinceLastAnalysisUpdate) { + assert(!DTU->hasPendingUpdates() && + "Lost update of 'ChangedSinceLastAnalysisUpdate'?"); + // Run the "external" analysis. + return &FAM->getResult<AnalysisT>(*F); + } + ChangedSinceLastAnalysisUpdate = false; + + auto PA = getPreservedAnalysis(); + // TODO: This shouldn't be needed once 'getPreservedAnalysis' reports BPI/BFI + // as preserved. + PA.preserve<BranchProbabilityAnalysis>(); + PA.preserve<BlockFrequencyAnalysis>(); + // Report everything except explicitly preserved as invalid. + FAM->invalidate(*F, PA); + // Update DT/PDT. + DTU->flush(); + // Make sure DT/PDT are valid before running "external" analysis. + assert(DTU->getDomTree().verify(DominatorTree::VerificationLevel::Fast)); + assert((!DTU->hasPostDomTree() || + DTU->getPostDomTree().verify( + PostDominatorTree::VerificationLevel::Fast))); + // Run the "external" analysis. + auto *Result = &FAM->getResult<AnalysisT>(*F); + // Update analysis JumpThreading depends on and not explicitly preserved. 
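runExternalAnalysis above, together with the getBPI/getBFI and getOrCreateBPI/getOrCreateBFI helpers defined right after this point, implements a small "reuse until the IR changes, recompute only when forced" protocol. A stripped-down standalone C++ model of that protocol (names invented for the illustration; the real code goes through FunctionAnalysisManager and the DomTreeUpdater):

#include <cassert>
#include <optional>

struct Analysis { int Version = 0; };

class LazyAnalysisCache {
  Analysis Storage;
  std::optional<Analysis *> Cached;     // plays the role of the BPI/BFI members
  bool ChangedSinceLastUpdate = false;  // ~ ChangedSinceLastAnalysisUpdate
  int Generation = 0;

  // ~ runExternalAnalysis(): when the IR changed, stale results have to be
  // invalidated (and the DomTreeUpdater flushed) before recomputing;
  // otherwise the existing result is still valid and is returned as-is.
  Analysis *fetch() {
    if (ChangedSinceLastUpdate) {
      ChangedSinceLastUpdate = false;
      Storage.Version = ++Generation;   // "recompute"
    }
    return &Storage;
  }

public:
  void notifyIRChanged() { ChangedSinceLastUpdate = true; }

  // ~ getBPI(): only hands out a result that is already available.
  Analysis *get() const { return Cached ? *Cached : nullptr; }

  // ~ getOrCreateBPI(Force): computes the analysis only when explicitly forced.
  Analysis *getOrCreate(bool Force) {
    if (Analysis *A = get())
      return A;
    if (Force)
      Cached = fetch();
    return get();
  }
};

int main() {
  LazyAnalysisCache C;
  assert(C.getOrCreate(/*Force=*/false) == nullptr); // nothing cached, not forced
  Analysis *A = C.getOrCreate(/*Force=*/true);       // computed on demand
  C.notifyIRChanged();                               // later transforms dirty it
  assert(C.get() == A); // still cached; a fresh result would come from fetch()
  return 0;
}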
+ TTI = &FAM->getResult<TargetIRAnalysis>(*F); + TLI = &FAM->getResult<TargetLibraryAnalysis>(*F); + AA = &FAM->getResult<AAManager>(*F); + + return Result; +} + +BranchProbabilityInfo *JumpThreadingPass::getBPI() { + if (!BPI) { + assert(FAM && "Can't create BPI without FunctionAnalysisManager"); + BPI = FAM->getCachedResult<BranchProbabilityAnalysis>(*F); + } + return *BPI; +} + +BlockFrequencyInfo *JumpThreadingPass::getBFI() { + if (!BFI) { + assert(FAM && "Can't create BFI without FunctionAnalysisManager"); + BFI = FAM->getCachedResult<BlockFrequencyAnalysis>(*F); + } + return *BFI; +} + +// Important note on validity of BPI/BFI. JumpThreading tries to preserve +// BPI/BFI as it goes. Thus if cached instance exists it will be updated. +// Otherwise, new instance of BPI/BFI is created (up to date by definition). +BranchProbabilityInfo *JumpThreadingPass::getOrCreateBPI(bool Force) { + auto *Res = getBPI(); + if (Res) + return Res; + + if (Force) + BPI = runExternalAnalysis<BranchProbabilityAnalysis>(); + + return *BPI; +} + +BlockFrequencyInfo *JumpThreadingPass::getOrCreateBFI(bool Force) { + auto *Res = getBFI(); + if (Res) + return Res; + + if (Force) + BFI = runExternalAnalysis<BlockFrequencyAnalysis>(); + + return *BFI; +} diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index 2865dece8723..f8fab03f151d 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -44,7 +44,6 @@ #include "llvm/Analysis/AliasSetTracker.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" -#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LazyBlockFrequencyInfo.h" #include "llvm/Analysis/Loads.h" @@ -68,6 +67,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" @@ -102,6 +102,12 @@ STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk"); STATISTIC(NumPromotionCandidates, "Number of promotion candidates"); STATISTIC(NumLoadPromoted, "Number of load-only promotions"); STATISTIC(NumLoadStorePromoted, "Number of load and store promotions"); +STATISTIC(NumMinMaxHoisted, + "Number of min/max expressions hoisted out of the loop"); +STATISTIC(NumGEPsHoisted, + "Number of geps reassociated and hoisted out of the loop"); +STATISTIC(NumAddSubHoisted, "Number of add/subtract expressions reassociated " + "and hoisted out of the loop"); /// Memory promotion is enabled by default. 
static cl::opt<bool> @@ -145,10 +151,10 @@ cl::opt<unsigned> llvm::SetLicmMssaNoAccForPromotionCap( "enable memory promotion.")); static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); -static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo, - TargetTransformInfo *TTI, bool &FreeInLoop, - bool LoopNestMode); +static bool isNotUsedOrFoldableInLoop(const Instruction &I, const Loop *CurLoop, + const LoopSafetyInfo *SafetyInfo, + TargetTransformInfo *TTI, + bool &FoldableInLoop, bool LoopNestMode); static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU, ScalarEvolution *SE, @@ -163,9 +169,15 @@ static bool isSafeToExecuteUnconditionally( AssumptionCache *AC, bool AllowSpeculation); static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, Loop *CurLoop, Instruction &I, - SinkAndHoistLICMFlags &Flags); + SinkAndHoistLICMFlags &Flags, + bool InvariantGroup); static bool pointerInvalidatedByBlock(BasicBlock &BB, MemorySSA &MSSA, MemoryUse &MU); +/// Aggregates various functions for hoisting computations out of loop. +static bool hoistArithmetics(Instruction &I, Loop &L, + ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU, AssumptionCache *AC, + DominatorTree *DT); static Instruction *cloneInstructionInExitBlock( Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU); @@ -280,9 +292,6 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); - - PA.preserve<DominatorTreeAnalysis>(); - PA.preserve<LoopAnalysis>(); PA.preserve<MemorySSAAnalysis>(); return PA; @@ -293,9 +302,9 @@ void LICMPass::printPipeline( static_cast<PassInfoMixin<LICMPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; OS << (Opts.AllowSpeculation ? "" : "no-") << "allowspeculation"; - OS << ">"; + OS << '>'; } PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, @@ -334,9 +343,9 @@ void LNICMPass::printPipeline( static_cast<PassInfoMixin<LNICMPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; OS << (Opts.AllowSpeculation ? 
"" : "no-") << "allowspeculation"; - OS << ">"; + OS << '>'; } char LegacyLICMPass::ID = 0; @@ -351,32 +360,21 @@ INITIALIZE_PASS_END(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false, false) Pass *llvm::createLICMPass() { return new LegacyLICMPass(); } -Pass *llvm::createLICMPass(unsigned LicmMssaOptCap, - unsigned LicmMssaNoAccForPromotionCap, - bool LicmAllowSpeculation) { - return new LegacyLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - LicmAllowSpeculation); -} -llvm::SinkAndHoistLICMFlags::SinkAndHoistLICMFlags(bool IsSink, Loop *L, - MemorySSA *MSSA) +llvm::SinkAndHoistLICMFlags::SinkAndHoistLICMFlags(bool IsSink, Loop &L, + MemorySSA &MSSA) : SinkAndHoistLICMFlags(SetLicmMssaOptCap, SetLicmMssaNoAccForPromotionCap, IsSink, L, MSSA) {} llvm::SinkAndHoistLICMFlags::SinkAndHoistLICMFlags( unsigned LicmMssaOptCap, unsigned LicmMssaNoAccForPromotionCap, bool IsSink, - Loop *L, MemorySSA *MSSA) + Loop &L, MemorySSA &MSSA) : LicmMssaOptCap(LicmMssaOptCap), LicmMssaNoAccForPromotionCap(LicmMssaNoAccForPromotionCap), IsSink(IsSink) { - assert(((L != nullptr) == (MSSA != nullptr)) && - "Unexpected values for SinkAndHoistLICMFlags"); - if (!MSSA) - return; - unsigned AccessCapCount = 0; - for (auto *BB : L->getBlocks()) - if (const auto *Accesses = MSSA->getBlockAccesses(BB)) + for (auto *BB : L.getBlocks()) + if (const auto *Accesses = MSSA.getBlockAccesses(BB)) for (const auto &MA : *Accesses) { (void)MA; ++AccessCapCount; @@ -400,7 +398,6 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, bool Changed = false; assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); - MSSA->ensureOptimizedUses(); // If this loop has metadata indicating that LICM is not to be performed then // just exit. @@ -426,7 +423,7 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, MemorySSAUpdater MSSAU(MSSA); SinkAndHoistLICMFlags Flags(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - /*IsSink=*/true, L, MSSA); + /*IsSink=*/true, *L, *MSSA); // Get the preheader block to move instructions into... BasicBlock *Preheader = L->getLoopPreheader(); @@ -581,14 +578,15 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, // outside of the loop. In this case, it doesn't even matter if the // operands of the instruction are loop invariant. // - bool FreeInLoop = false; + bool FoldableInLoop = false; bool LoopNestMode = OutermostLoop != nullptr; if (!I.mayHaveSideEffects() && - isNotUsedOrFreeInLoop(I, LoopNestMode ? OutermostLoop : CurLoop, - SafetyInfo, TTI, FreeInLoop, LoopNestMode) && + isNotUsedOrFoldableInLoop(I, LoopNestMode ? OutermostLoop : CurLoop, + SafetyInfo, TTI, FoldableInLoop, + LoopNestMode) && canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE)) { if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) { - if (!FreeInLoop) { + if (!FoldableInLoop) { ++II; salvageDebugInfo(I); eraseInstruction(I, *SafetyInfo, MSSAU); @@ -881,6 +879,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, LoopBlocksRPO Worklist(CurLoop); Worklist.perform(LI); bool Changed = false; + BasicBlock *Preheader = CurLoop->getLoopPreheader(); for (BasicBlock *BB : Worklist) { // Only need to process the contents of this block if it is not part of a // subloop (which would already have been processed). @@ -888,21 +887,6 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, continue; for (Instruction &I : llvm::make_early_inc_range(*BB)) { - // Try constant folding this instruction. 
If all the operands are - // constants, it is technically hoistable, but it would be better to - // just fold it. - if (Constant *C = ConstantFoldInstruction( - &I, I.getModule()->getDataLayout(), TLI)) { - LLVM_DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C - << '\n'); - // FIXME MSSA: Such replacements may make accesses unoptimized (D51960). - I.replaceAllUsesWith(C); - if (isInstructionTriviallyDead(&I, TLI)) - eraseInstruction(I, *SafetyInfo, MSSAU); - Changed = true; - continue; - } - // Try hoisting the instruction out to the preheader. We can only do // this if all of the operands of the instruction are loop invariant and // if it is safe to hoist the instruction. We also check block frequency @@ -914,8 +898,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE) && isSafeToExecuteUnconditionally( I, DT, TLI, CurLoop, SafetyInfo, ORE, - CurLoop->getLoopPreheader()->getTerminator(), AC, - AllowSpeculation)) { + Preheader->getTerminator(), AC, AllowSpeculation)) { hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, MSSAU, SE, ORE); HoistedInstructions.push_back(&I); @@ -983,6 +966,13 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, } } + // Try to reassociate instructions so that part of computations can be + // done out of loop. + if (hoistArithmetics(I, *CurLoop, *SafetyInfo, MSSAU, AC, DT)) { + Changed = true; + continue; + } + // Remember possibly hoistable branches so we can actually hoist them // later if needed. if (BranchInst *BI = dyn_cast<BranchInst>(&I)) @@ -1147,6 +1137,20 @@ bool isOnlyMemoryAccess(const Instruction *I, const Loop *L, } } +static MemoryAccess *getClobberingMemoryAccess(MemorySSA &MSSA, + BatchAAResults &BAA, + SinkAndHoistLICMFlags &Flags, + MemoryUseOrDef *MA) { + // See declaration of SetLicmMssaOptCap for usage details. + if (Flags.tooManyClobberingCalls()) + return MA->getDefiningAccess(); + + MemoryAccess *Source = + MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(MA, BAA); + Flags.incrementClobberingCalls(); + return Source; +} + bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop, MemorySSAUpdater &MSSAU, bool TargetExecutesOncePerLoop, @@ -1176,8 +1180,12 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (isLoadInvariantInLoop(LI, DT, CurLoop)) return true; + auto MU = cast<MemoryUse>(MSSA->getMemoryAccess(LI)); + + bool InvariantGroup = LI->hasMetadata(LLVMContext::MD_invariant_group); + bool Invalidated = pointerInvalidatedByLoop( - MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop, I, Flags); + MSSA, MU, CurLoop, I, Flags, InvariantGroup); // Check loop-invariant address because this may also be a sinkable load // whose address is not necessarily loop-invariant. if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand())) @@ -1210,12 +1218,17 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // Assumes don't actually alias anything or throw return true; - if (match(CI, m_Intrinsic<Intrinsic::experimental_widenable_condition>())) - // Widenable conditions don't actually alias anything or throw - return true; - // Handle simple cases by querying alias analysis. MemoryEffects Behavior = AA->getMemoryEffects(CI); + + // FIXME: we don't handle the semantics of thread local well. So that the + // address of thread locals are fake constants in coroutines. 
So We forbid + // to treat onlyReadsMemory call in coroutines as constants now. Note that + // it is possible to hide a thread local access in a onlyReadsMemory call. + // Remove this check after we handle the semantics of thread locals well. + if (Behavior.onlyReadsMemory() && CI->getFunction()->isPresplitCoroutine()) + return false; + if (Behavior.doesNotAccessMemory()) return true; if (Behavior.onlyReadsMemory()) { @@ -1228,7 +1241,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (Op->getType()->isPointerTy() && pointerInvalidatedByLoop( MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(CI)), CurLoop, I, - Flags)) + Flags, /*InvariantGroup=*/false)) return false; return true; } @@ -1258,21 +1271,30 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // arbitrary number of reads in the loop. if (isOnlyMemoryAccess(SI, CurLoop, MSSAU)) return true; - // If there are more accesses than the Promotion cap or no "quota" to - // check clobber, then give up as we're not walking a list that long. - if (Flags.tooManyMemoryAccesses() || Flags.tooManyClobberingCalls()) + // If there are more accesses than the Promotion cap, then give up as we're + // not walking a list that long. + if (Flags.tooManyMemoryAccesses()) + return false; + + auto *SIMD = MSSA->getMemoryAccess(SI); + BatchAAResults BAA(*AA); + auto *Source = getClobberingMemoryAccess(*MSSA, BAA, Flags, SIMD); + // Make sure there are no clobbers inside the loop. + if (!MSSA->isLiveOnEntryDef(Source) && + CurLoop->contains(Source->getBlock())) return false; + // If there are interfering Uses (i.e. their defining access is in the // loop), or ordered loads (stored as Defs!), don't move this store. // Could do better here, but this is conservatively correct. // TODO: Cache set of Uses on the first walk in runOnLoop, update when // moving accesses. Can also extend to dominating uses. - auto *SIMD = MSSA->getMemoryAccess(SI); for (auto *BB : CurLoop->getBlocks()) if (auto *Accesses = MSSA->getBlockAccesses(BB)) { for (const auto &MA : *Accesses) if (const auto *MU = dyn_cast<MemoryUse>(&MA)) { - auto *MD = MU->getDefiningAccess(); + auto *MD = getClobberingMemoryAccess(*MSSA, BAA, Flags, + const_cast<MemoryUse *>(MU)); if (!MSSA->isLiveOnEntryDef(MD) && CurLoop->contains(MD->getBlock())) return false; @@ -1293,17 +1315,13 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // Check if the call may read from the memory location written // to by SI. Check CI's attributes and arguments; the number of // such checks performed is limited above by NoOfMemAccTooLarge. - ModRefInfo MRI = AA->getModRefInfo(CI, MemoryLocation::get(SI)); + ModRefInfo MRI = BAA.getModRefInfo(CI, MemoryLocation::get(SI)); if (isModOrRefSet(MRI)) return false; } } } - auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI); - Flags.incrementClobberingCalls(); - // If there are no clobbering Defs in the loop, store is safe to hoist. - return MSSA->isLiveOnEntryDef(Source) || - !CurLoop->contains(Source->getBlock()); + return true; } assert(!I.mayReadOrWriteMemory() && "unhandled aliasing"); @@ -1326,13 +1344,12 @@ static bool isTriviallyReplaceablePHI(const PHINode &PN, const Instruction &I) { return true; } -/// Return true if the instruction is free in the loop. -static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, +/// Return true if the instruction is foldable in the loop. 
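For intuition, a hypothetical sketch of the case this predicate is meant to capture (identifiers invented, shown at C++ source level rather than IR): an address computation whose only in-loop user is a load is effectively free inside the loop, since it is expected to fold into the load's addressing mode, so an extra user outside the loop can be served by a sunk copy without penalizing the loop body.

    // Illustrative only: 'base', 'n', and 'out' are invented for this sketch.
    long sumAndEscape(const int *base, long n, const int **out) {
      long sum = 0;
      const int *p = nullptr;
      for (long i = 0; i < n; ++i) {
        p = &base[i];   // foldable: expected to fold into the load's addressing mode
        sum += *p;      // the in-loop use
      }
      *out = p;         // the only non-foldable use is outside the loop
      return sum;
    }
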
+static bool isFoldableInLoop(const Instruction &I, const Loop *CurLoop, const TargetTransformInfo *TTI) { - InstructionCost CostI = - TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); - if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { + InstructionCost CostI = + TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); if (CostI != TargetTransformInfo::TCC_Free) return false; // For a GEP, we cannot simply use getInstructionCost because currently @@ -1349,7 +1366,7 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, return true; } - return CostI == TargetTransformInfo::TCC_Free; + return false; } /// Return true if the only users of this instruction are outside of @@ -1358,12 +1375,12 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, /// /// We also return true if the instruction could be folded away in lowering. /// (e.g., a GEP can be folded into a load as an addressing mode in the loop). -static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo, - TargetTransformInfo *TTI, bool &FreeInLoop, - bool LoopNestMode) { +static bool isNotUsedOrFoldableInLoop(const Instruction &I, const Loop *CurLoop, + const LoopSafetyInfo *SafetyInfo, + TargetTransformInfo *TTI, + bool &FoldableInLoop, bool LoopNestMode) { const auto &BlockColors = SafetyInfo->getBlockColors(); - bool IsFree = isFreeInLoop(I, CurLoop, TTI); + bool IsFoldable = isFoldableInLoop(I, CurLoop, TTI); for (const User *U : I.users()) { const Instruction *UI = cast<Instruction>(U); if (const PHINode *PN = dyn_cast<PHINode>(UI)) { @@ -1390,8 +1407,8 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, } if (CurLoop->contains(UI)) { - if (IsFree) { - FreeInLoop = true; + if (IsFoldable) { + FoldableInLoop = true; continue; } return false; @@ -1490,7 +1507,7 @@ static void moveInstructionBefore(Instruction &I, Instruction &Dest, MSSAU.getMemorySSA()->getMemoryAccess(&I))) MSSAU.moveToPlace(OldMemAcc, Dest.getParent(), MemorySSA::BeforeTerminator); if (SE) - SE->forgetValue(&I); + SE->forgetBlockAndLoopDispositions(&I); } static Instruction *sinkThroughTriviallyReplaceablePHI( @@ -1695,6 +1712,8 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, // The PHI must be trivially replaceable. Instruction *New = sinkThroughTriviallyReplaceablePHI( PN, &I, LI, SunkCopies, SafetyInfo, CurLoop, MSSAU); + // As we sink the instruction out of the BB, drop its debug location. + New->dropLocation(); PN->replaceAllUsesWith(New); eraseInstruction(*PN, *SafetyInfo, MSSAU); Changed = true; @@ -1729,7 +1748,7 @@ static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, // time in isGuaranteedToExecute if we don't actually have anything to // drop. It is a compile time optimization, not required for correctness. !SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop)) - I.dropUndefImplyingAttrsAndUnknownMetadata(); + I.dropUBImplyingAttrsAndMetadata(); if (isa<PHINode>(I)) // Move the new node to the end of the phi list in the destination block. @@ -1915,6 +1934,8 @@ bool isNotVisibleOnUnwindInLoop(const Value *Object, const Loop *L, isNotCapturedBeforeOrInLoop(Object, L, DT); } +// We don't consider globals as writable: While the physical memory is writable, +// we may not have provenance to perform the write. bool isWritableObject(const Value *Object) { // TODO: Alloca might not be writable after its lifetime ends. 
// See https://github.com/llvm/llvm-project/issues/51838. @@ -1925,9 +1946,6 @@ bool isWritableObject(const Value *Object) { if (auto *A = dyn_cast<Argument>(Object)) return A->hasByValAttr(); - if (auto *G = dyn_cast<GlobalVariable>(Object)) - return !G->isConstant(); - // TODO: Noalias has nothing to do with writability, this should check for // an allocator function. return isNoAliasCall(Object); @@ -2203,7 +2221,7 @@ bool llvm::promoteLoopAccessesToScalars( }); // Look at all the loop uses, and try to merge their locations. - std::vector<const DILocation *> LoopUsesLocs; + std::vector<DILocation *> LoopUsesLocs; for (auto *U : LoopUses) LoopUsesLocs.push_back(U->getDebugLoc().get()); auto DL = DebugLoc(DILocation::getMergedLocations(LoopUsesLocs)); @@ -2330,19 +2348,24 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, Loop *CurLoop, Instruction &I, - SinkAndHoistLICMFlags &Flags) { + SinkAndHoistLICMFlags &Flags, + bool InvariantGroup) { // For hoisting, use the walker to determine safety if (!Flags.getIsSink()) { - MemoryAccess *Source; - // See declaration of SetLicmMssaOptCap for usage details. - if (Flags.tooManyClobberingCalls()) - Source = MU->getDefiningAccess(); - else { - Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(MU); - Flags.incrementClobberingCalls(); - } + // If hoisting an invariant group, we only need to check that there + // is no store to the loaded pointer between the start of the loop, + // and the load (since all values must be the same). + + // This can be checked in two conditions: + // 1) if the memoryaccess is outside the loop + // 2) the earliest access is at the loop header, + // if the memory loaded is the phi node + + BatchAAResults BAA(MSSA->getAA()); + MemoryAccess *Source = getClobberingMemoryAccess(*MSSA, BAA, Flags, MU); return !MSSA->isLiveOnEntryDef(Source) && - CurLoop->contains(Source->getBlock()); + CurLoop->contains(Source->getBlock()) && + !(InvariantGroup && Source->getBlock() == CurLoop->getHeader() && isa<MemoryPhi>(Source)); } // For sinking, we'd need to check all Defs below this use. The getClobbering @@ -2383,6 +2406,304 @@ bool pointerInvalidatedByBlock(BasicBlock &BB, MemorySSA &MSSA, MemoryUse &MU) { return false; } +/// Try to simplify things like (A < INV_1 AND icmp A < INV_2) into (A < +/// min(INV_1, INV_2)), if INV_1 and INV_2 are both loop invariants and their +/// minimun can be computed outside of loop, and X is not a loop-invariant. 
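To make the and/min case above concrete, a hypothetical source-level before/after (identifiers invented; the pass itself rewrites the icmp pair and emits an smin/umin intrinsic in the preheader):

    // Before: two compares against loop invariants stay inside the loop.
    bool anyBelowBoth(const int *a, long n, int inv1, int inv2) {
      for (long i = 0; i < n; ++i)
        if (a[i] < inv1 && a[i] < inv2)
          return true;
      return false;
    }

    // After: the invariant minimum is computed once, before the loop.
    bool anyBelowBothHoisted(const int *a, long n, int inv1, int inv2) {
      const int invMin = inv1 < inv2 ? inv1 : inv2;  // hoisted to the preheader
      for (long i = 0; i < n; ++i)
        if (a[i] < invMin)
          return true;
      return false;
    }
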
+static bool hoistMinMax(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU) { + bool Inverse = false; + using namespace PatternMatch; + Value *Cond1, *Cond2; + if (match(&I, m_LogicalOr(m_Value(Cond1), m_Value(Cond2)))) { + Inverse = true; + } else if (match(&I, m_LogicalAnd(m_Value(Cond1), m_Value(Cond2)))) { + // Do nothing + } else + return false; + + auto MatchICmpAgainstInvariant = [&](Value *C, ICmpInst::Predicate &P, + Value *&LHS, Value *&RHS) { + if (!match(C, m_OneUse(m_ICmp(P, m_Value(LHS), m_Value(RHS))))) + return false; + if (!LHS->getType()->isIntegerTy()) + return false; + if (!ICmpInst::isRelational(P)) + return false; + if (L.isLoopInvariant(LHS)) { + std::swap(LHS, RHS); + P = ICmpInst::getSwappedPredicate(P); + } + if (L.isLoopInvariant(LHS) || !L.isLoopInvariant(RHS)) + return false; + if (Inverse) + P = ICmpInst::getInversePredicate(P); + return true; + }; + ICmpInst::Predicate P1, P2; + Value *LHS1, *LHS2, *RHS1, *RHS2; + if (!MatchICmpAgainstInvariant(Cond1, P1, LHS1, RHS1) || + !MatchICmpAgainstInvariant(Cond2, P2, LHS2, RHS2)) + return false; + if (P1 != P2 || LHS1 != LHS2) + return false; + + // Everything is fine, we can do the transform. + bool UseMin = ICmpInst::isLT(P1) || ICmpInst::isLE(P1); + assert( + (UseMin || ICmpInst::isGT(P1) || ICmpInst::isGE(P1)) && + "Relational predicate is either less (or equal) or greater (or equal)!"); + Intrinsic::ID id = ICmpInst::isSigned(P1) + ? (UseMin ? Intrinsic::smin : Intrinsic::smax) + : (UseMin ? Intrinsic::umin : Intrinsic::umax); + auto *Preheader = L.getLoopPreheader(); + assert(Preheader && "Loop is not in simplify form?"); + IRBuilder<> Builder(Preheader->getTerminator()); + // We are about to create a new guaranteed use for RHS2 which might not exist + // before (if it was a non-taken input of logical and/or instruction). If it + // was poison, we need to freeze it. Note that no new use for LHS and RHS1 are + // introduced, so they don't need this. + if (isa<SelectInst>(I)) + RHS2 = Builder.CreateFreeze(RHS2, RHS2->getName() + ".fr"); + Value *NewRHS = Builder.CreateBinaryIntrinsic( + id, RHS1, RHS2, nullptr, StringRef("invariant.") + + (ICmpInst::isSigned(P1) ? "s" : "u") + + (UseMin ? "min" : "max")); + Builder.SetInsertPoint(&I); + ICmpInst::Predicate P = P1; + if (Inverse) + P = ICmpInst::getInversePredicate(P); + Value *NewCond = Builder.CreateICmp(P, LHS1, NewRHS); + NewCond->takeName(&I); + I.replaceAllUsesWith(NewCond); + eraseInstruction(I, SafetyInfo, MSSAU); + eraseInstruction(*cast<Instruction>(Cond1), SafetyInfo, MSSAU); + eraseInstruction(*cast<Instruction>(Cond2), SafetyInfo, MSSAU); + return true; +} + +/// Reassociate gep (gep ptr, idx1), idx2 to gep (gep ptr, idx2), idx1 if +/// this allows hoisting the inner GEP. +static bool hoistGEP(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU, AssumptionCache *AC, + DominatorTree *DT) { + auto *GEP = dyn_cast<GetElementPtrInst>(&I); + if (!GEP) + return false; + + auto *Src = dyn_cast<GetElementPtrInst>(GEP->getPointerOperand()); + if (!Src || !Src->hasOneUse() || !L.contains(Src)) + return false; + + Value *SrcPtr = Src->getPointerOperand(); + auto LoopInvariant = [&](Value *V) { return L.isLoopInvariant(V); }; + if (!L.isLoopInvariant(SrcPtr) || !all_of(GEP->indices(), LoopInvariant)) + return false; + + // This can only happen if !AllowSpeculation, otherwise this would already be + // handled. + // FIXME: Should we respect AllowSpeculation in these reassociation folds? 
+ // The flag exists to prevent metadata dropping, which is not relevant here. + if (all_of(Src->indices(), LoopInvariant)) + return false; + + // The swapped GEPs are inbounds if both original GEPs are inbounds + // and the sign of the offsets is the same. For simplicity, only + // handle both offsets being non-negative. + const DataLayout &DL = GEP->getModule()->getDataLayout(); + auto NonNegative = [&](Value *V) { + return isKnownNonNegative(V, DL, 0, AC, GEP, DT); + }; + bool IsInBounds = Src->isInBounds() && GEP->isInBounds() && + all_of(Src->indices(), NonNegative) && + all_of(GEP->indices(), NonNegative); + + BasicBlock *Preheader = L.getLoopPreheader(); + IRBuilder<> Builder(Preheader->getTerminator()); + Value *NewSrc = Builder.CreateGEP(GEP->getSourceElementType(), SrcPtr, + SmallVector<Value *>(GEP->indices()), + "invariant.gep", IsInBounds); + Builder.SetInsertPoint(GEP); + Value *NewGEP = Builder.CreateGEP(Src->getSourceElementType(), NewSrc, + SmallVector<Value *>(Src->indices()), "gep", + IsInBounds); + GEP->replaceAllUsesWith(NewGEP); + eraseInstruction(*GEP, SafetyInfo, MSSAU); + eraseInstruction(*Src, SafetyInfo, MSSAU); + return true; +} + +/// Try to turn things like "LV + C1 < C2" into "LV < C2 - C1". Here +/// C1 and C2 are loop invariants and LV is a loop-variant. +static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS, + Value *InvariantRHS, ICmpInst &ICmp, Loop &L, + ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU, + AssumptionCache *AC, DominatorTree *DT) { + assert(ICmpInst::isSigned(Pred) && "Not supported yet!"); + assert(!L.isLoopInvariant(VariantLHS) && "Precondition."); + assert(L.isLoopInvariant(InvariantRHS) && "Precondition."); + + // Try to represent VariantLHS as sum of invariant and variant operands. + using namespace PatternMatch; + Value *VariantOp, *InvariantOp; + if (!match(VariantLHS, m_NSWAdd(m_Value(VariantOp), m_Value(InvariantOp)))) + return false; + + // LHS itself is a loop-variant, try to represent it in the form: + // "VariantOp + InvariantOp". If it is possible, then we can reassociate. + if (L.isLoopInvariant(VariantOp)) + std::swap(VariantOp, InvariantOp); + if (L.isLoopInvariant(VariantOp) || !L.isLoopInvariant(InvariantOp)) + return false; + + // In order to turn "LV + C1 < C2" into "LV < C2 - C1", we need to be able to + // freely move values from left side of inequality to right side (just as in + // normal linear arithmetics). Overflows make things much more complicated, so + // we want to avoid this. + auto &DL = L.getHeader()->getModule()->getDataLayout(); + bool ProvedNoOverflowAfterReassociate = + computeOverflowForSignedSub(InvariantRHS, InvariantOp, DL, AC, &ICmp, + DT) == llvm::OverflowResult::NeverOverflows; + if (!ProvedNoOverflowAfterReassociate) + return false; + auto *Preheader = L.getLoopPreheader(); + assert(Preheader && "Loop is not in simplify form?"); + IRBuilder<> Builder(Preheader->getTerminator()); + Value *NewCmpOp = Builder.CreateSub(InvariantRHS, InvariantOp, "invariant.op", + /*HasNUW*/ false, /*HasNSW*/ true); + ICmp.setPredicate(Pred); + ICmp.setOperand(0, VariantOp); + ICmp.setOperand(1, NewCmpOp); + eraseInstruction(cast<Instruction>(*VariantLHS), SafetyInfo, MSSAU); + return true; +} + +/// Try to reassociate and hoist the following two patterns: +/// LV - C1 < C2 --> LV < C1 + C2, +/// C1 - LV < C2 --> LV > C1 - C2. 
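A hypothetical source-level illustration of the first pattern above, assuming the required no-overflow facts can be proven (identifiers invented; the pass operates on the icmp/sub pair and hoists the folded invariant into the preheader):

    // Before: the loop-variant subtraction feeds the compare on every iteration.
    long countBefore(long n, long c1, long c2) {
      long count = 0;
      for (long i = 0; i < n; ++i)
        if (i - c1 < c2)            // c1, c2 loop-invariant; i loop-variant
          ++count;
      return count;
    }

    // After: the invariant sum is computed once, outside the loop.
    long countAfter(long n, long c1, long c2) {
      long count = 0;
      const long invOp = c1 + c2;   // valid only because signed overflow was ruled out
      for (long i = 0; i < n; ++i)
        if (i < invOp)
          ++count;
      return count;
    }
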
+static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS, + Value *InvariantRHS, ICmpInst &ICmp, Loop &L, + ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU, + AssumptionCache *AC, DominatorTree *DT) { + assert(ICmpInst::isSigned(Pred) && "Not supported yet!"); + assert(!L.isLoopInvariant(VariantLHS) && "Precondition."); + assert(L.isLoopInvariant(InvariantRHS) && "Precondition."); + + // Try to represent VariantLHS as sum of invariant and variant operands. + using namespace PatternMatch; + Value *VariantOp, *InvariantOp; + if (!match(VariantLHS, m_NSWSub(m_Value(VariantOp), m_Value(InvariantOp)))) + return false; + + bool VariantSubtracted = false; + // LHS itself is a loop-variant, try to represent it in the form: + // "VariantOp + InvariantOp". If it is possible, then we can reassociate. If + // the variant operand goes with minus, we use a slightly different scheme. + if (L.isLoopInvariant(VariantOp)) { + std::swap(VariantOp, InvariantOp); + VariantSubtracted = true; + Pred = ICmpInst::getSwappedPredicate(Pred); + } + if (L.isLoopInvariant(VariantOp) || !L.isLoopInvariant(InvariantOp)) + return false; + + // In order to turn "LV - C1 < C2" into "LV < C2 + C1", we need to be able to + // freely move values from left side of inequality to right side (just as in + // normal linear arithmetics). Overflows make things much more complicated, so + // we want to avoid this. Likewise, for "C1 - LV < C2" we need to prove that + // "C1 - C2" does not overflow. + auto &DL = L.getHeader()->getModule()->getDataLayout(); + if (VariantSubtracted) { + // C1 - LV < C2 --> LV > C1 - C2 + if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, DL, AC, &ICmp, + DT) != llvm::OverflowResult::NeverOverflows) + return false; + } else { + // LV - C1 < C2 --> LV < C1 + C2 + if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, DL, AC, &ICmp, + DT) != llvm::OverflowResult::NeverOverflows) + return false; + } + auto *Preheader = L.getLoopPreheader(); + assert(Preheader && "Loop is not in simplify form?"); + IRBuilder<> Builder(Preheader->getTerminator()); + Value *NewCmpOp = + VariantSubtracted + ? Builder.CreateSub(InvariantOp, InvariantRHS, "invariant.op", + /*HasNUW*/ false, /*HasNSW*/ true) + : Builder.CreateAdd(InvariantOp, InvariantRHS, "invariant.op", + /*HasNUW*/ false, /*HasNSW*/ true); + ICmp.setPredicate(Pred); + ICmp.setOperand(0, VariantOp); + ICmp.setOperand(1, NewCmpOp); + eraseInstruction(cast<Instruction>(*VariantLHS), SafetyInfo, MSSAU); + return true; +} + +/// Reassociate and hoist add/sub expressions. +static bool hoistAddSub(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU, AssumptionCache *AC, + DominatorTree *DT) { + using namespace PatternMatch; + ICmpInst::Predicate Pred; + Value *LHS, *RHS; + if (!match(&I, m_ICmp(Pred, m_Value(LHS), m_Value(RHS)))) + return false; + + // TODO: Support unsigned predicates? + if (!ICmpInst::isSigned(Pred)) + return false; + + // Put variant operand to LHS position. + if (L.isLoopInvariant(LHS)) { + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + // We want to delete the initial operation after reassociation, so only do it + // if it has no other uses. + if (L.isLoopInvariant(LHS) || !L.isLoopInvariant(RHS) || !LHS->hasOneUse()) + return false; + + // TODO: We could go with smarter context, taking common dominator of all I's + // users instead of I itself. 
+ if (hoistAdd(Pred, LHS, RHS, cast<ICmpInst>(I), L, SafetyInfo, MSSAU, AC, DT)) + return true; + + if (hoistSub(Pred, LHS, RHS, cast<ICmpInst>(I), L, SafetyInfo, MSSAU, AC, DT)) + return true; + + return false; +} + +static bool hoistArithmetics(Instruction &I, Loop &L, + ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU, AssumptionCache *AC, + DominatorTree *DT) { + // Optimize complex patterns, such as (x < INV1 && x < INV2), turning them + // into (x < min(INV1, INV2)), and hoisting the invariant part of this + // expression out of the loop. + if (hoistMinMax(I, L, SafetyInfo, MSSAU)) { + ++NumHoisted; + ++NumMinMaxHoisted; + return true; + } + + // Try to hoist GEPs by reassociation. + if (hoistGEP(I, L, SafetyInfo, MSSAU, AC, DT)) { + ++NumHoisted; + ++NumGEPsHoisted; + return true; + } + + // Try to hoist add/sub's by reassociation. + if (hoistAddSub(I, L, SafetyInfo, MSSAU, AC, DT)) { + ++NumHoisted; + ++NumAddSubHoisted; + return true; + } + + return false; +} + /// Little predicate that returns true if the specified basic block is in /// a subloop of the current one, not the current one itself. /// diff --git a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index 7e4dbace043a..c041e3621a16 100644 --- a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -26,8 +26,6 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -73,7 +71,7 @@ static bool isLoopDead(Loop *L, ScalarEvolution &SE, // of the loop. bool AllEntriesInvariant = true; bool AllOutgoingValuesSame = true; - if (!L->hasNoExitBlocks()) { + if (ExitBlock) { for (PHINode &P : ExitBlock->phis()) { Value *incoming = P.getIncomingValueForBlock(ExitingBlocks[0]); @@ -488,6 +486,14 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, LLVM_DEBUG(dbgs() << "Deletion requires at most one exit block.\n"); return LoopDeletionResult::Unmodified; } + + // We can't directly branch to an EH pad. Don't bother handling this edge + // case. + if (ExitBlock && ExitBlock->isEHPad()) { + LLVM_DEBUG(dbgs() << "Cannot delete loop exiting to EH pad.\n"); + return LoopDeletionResult::Unmodified; + } + // Finally, we have to check that the loop really is dead. bool Changed = false; if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader, LI)) { @@ -539,62 +545,3 @@ PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -namespace { -class LoopDeletionLegacyPass : public LoopPass { -public: - static char ID; // Pass ID, replacement for typeid - LoopDeletionLegacyPass() : LoopPass(ID) { - initializeLoopDeletionLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - // Possibly eliminate loop L if it is dead. 
- bool runOnLoop(Loop *L, LPPassManager &) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addPreserved<MemorySSAWrapperPass>(); - getLoopAnalysisUsage(AU); - } -}; -} - -char LoopDeletionLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LoopDeletionLegacyPass, "loop-deletion", - "Delete dead loops", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_END(LoopDeletionLegacyPass, "loop-deletion", - "Delete dead loops", false, false) - -Pass *llvm::createLoopDeletionPass() { return new LoopDeletionLegacyPass(); } - -bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipLoop(L)) - return false; - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - MemorySSA *MSSA = nullptr; - if (MSSAAnalysis) - MSSA = &MSSAAnalysis->getMSSA(); - // For the old PM, we can't use OptimizationRemarkEmitter as an analysis - // pass. Function analyses need to be preserved across loop transformations - // but ORE cannot be preserved (see comment before the pass definition). - OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); - - LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: "); - LLVM_DEBUG(L->dump()); - - LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI, MSSA, ORE); - - // If we can prove the backedge isn't taken, just break it and be done. This - // leaves the loop structure in place which means it can handle dispatching - // to the right exit based on whatever loop invariant structure remains. - if (Result != LoopDeletionResult::Deleted) - Result = merge(Result, breakBackedgeIfNotTaken(L, DT, SE, LI, MSSA, ORE)); - - if (Result == LoopDeletionResult::Deleted) - LPM.markLoopAsDeleted(*L); - - return Result != LoopDeletionResult::Unmodified; -} diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 7b52b7dca85f..27196e46ca56 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -52,13 +52,10 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -995,45 +992,6 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, return Changed; } -namespace { - -/// The pass class. -class LoopDistributeLegacy : public FunctionPass { -public: - static char ID; - - LoopDistributeLegacy() : FunctionPass(ID) { - // The default is set by the caller. 
- initializeLoopDistributeLegacyPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); - - return runImpl(F, LI, DT, SE, ORE, LAIs); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<LoopAccessLegacyAnalysis>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } -}; - -} // end anonymous namespace - PreservedAnalyses LoopDistributePass::run(Function &F, FunctionAnalysisManager &AM) { auto &LI = AM.getResult<LoopAnalysis>(F); @@ -1050,18 +1008,3 @@ PreservedAnalyses LoopDistributePass::run(Function &F, PA.preserve<DominatorTreeAnalysis>(); return PA; } - -char LoopDistributeLegacy::ID; - -static const char ldist_name[] = "Loop Distribution"; - -INITIALIZE_PASS_BEGIN(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, - false) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_END(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, false) - -FunctionPass *llvm::createLoopDistributePass() { return new LoopDistributeLegacy(); } diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index 7d9ce8d35e0b..edc8a4956dd1 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -65,11 +65,8 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -318,12 +315,12 @@ static bool verifyTripCount(Value *RHS, Loop *L, return false; } - // The Extend=false flag is used for getTripCountFromExitCount as we want - // to verify and match it with the pattern matched tripcount. Please note - // that overflow checks are performed in checkOverflow, but are first tried - // to avoid by widening the IV. + // Evaluating in the trip count's type can not overflow here as the overflow + // checks are performed in checkOverflow, but are first tried to avoid by + // widening the IV. const SCEV *SCEVTripCount = - SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false); + SE->getTripCountFromExitCount(BackedgeTakenCount, + BackedgeTakenCount->getType(), L); const SCEV *SCEVRHS = SE->getSCEV(RHS); if (SCEVRHS == SCEVTripCount) @@ -336,7 +333,8 @@ static bool verifyTripCount(Value *RHS, Loop *L, // Find the extended backedge taken count and extended trip count using // SCEV. 
One of these should now match the RHS of the compare. BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); - SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false); + SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, + RHS->getType(), L); if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); return false; @@ -918,20 +916,6 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); } -bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U, - MemorySSAUpdater *MSSAU) { - bool Changed = false; - for (Loop *InnerLoop : LN.getLoops()) { - auto *OuterLoop = InnerLoop->getParentLoop(); - if (!OuterLoop) - continue; - FlattenInfo FI(OuterLoop, InnerLoop); - Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); - } - return Changed; -} - PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { @@ -949,8 +933,14 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. - Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, - MSSAU ? &*MSSAU : nullptr); + for (Loop *InnerLoop : LN.getLoops()) { + auto *OuterLoop = InnerLoop->getParentLoop(); + if (!OuterLoop) + continue; + FlattenInfo FI(OuterLoop, InnerLoop); + Changed |= FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, + MSSAU ? &*MSSAU : nullptr); + } if (!Changed) return PreservedAnalyses::all(); @@ -963,60 +953,3 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -namespace { -class LoopFlattenLegacyPass : public FunctionPass { -public: - static char ID; // Pass ID, replacement for typeid - LoopFlattenLegacyPass() : FunctionPass(ID) { - initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - // Possibly flatten loop L into its child. - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - getLoopAnalysisUsage(AU); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<TargetTransformInfoWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addPreserved<AssumptionCacheTracker>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } -}; -} // namespace - -char LoopFlattenLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", - false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", - false, false) - -FunctionPass *llvm::createLoopFlattenPass() { - return new LoopFlattenLegacyPass(); -} - -bool LoopFlattenLegacyPass::runOnFunction(Function &F) { - ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DominatorTree *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; - auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>(); - auto *TTI = &TTIP.getTTI(F); - auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - - std::optional<MemorySSAUpdater> MSSAU; - if (MSSA) - MSSAU = MemorySSAUpdater(&MSSA->getMSSA()); - - bool Changed = false; - for (Loop *L : *LI) { - auto LN = LoopNest::getLoopNest(*L, *SE); - Changed |= - Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, MSSAU ? &*MSSAU : nullptr); - } - return Changed; -} diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp index 0eecec373736..d35b562be0aa 100644 --- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -57,12 +57,9 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/Verifier.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CodeMoverUtils.h" @@ -2061,51 +2058,6 @@ private: return FC0.L; } }; - -struct LoopFuseLegacy : public FunctionPass { - - static char ID; - - LoopFuseLegacy() : FunctionPass(ID) { - initializeLoopFuseLegacyPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequiredID(LoopSimplifyID); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addRequired<DependenceAnalysisWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<PostDominatorTreeWrapperPass>(); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI(); - auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - const TargetTransformInfo &TTI = - getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - const DataLayout &DL = F.getParent()->getDataLayout(); - - LoopFuser LF(LI, DT, DI, SE, PDT, ORE, DL, AC, TTI); - return LF.fuseLoops(F); - } -}; } // namespace PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) { @@ -2142,19 +2094,3 @@ PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<LoopAnalysis>(); return PA; } - -char LoopFuseLegacy::ID = 0; - -INITIALIZE_PASS_BEGIN(LoopFuseLegacy, "loop-fusion", "Loop Fusion", false, - false) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 
-INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(LoopFuseLegacy, "loop-fusion", "Loop Fusion", false, false) - -FunctionPass *llvm::createLoopFusePass() { return new LoopFuseLegacy(); } diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 035cbdf595a8..8572a442e784 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -84,14 +84,11 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/InstructionCost.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -254,62 +251,8 @@ private: /// @} }; - -class LoopIdiomRecognizeLegacyPass : public LoopPass { -public: - static char ID; - - explicit LoopIdiomRecognizeLegacyPass() : LoopPass(ID) { - initializeLoopIdiomRecognizeLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (DisableLIRP::All) - return false; - - if (skipLoop(L)) - return false; - - AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( - *L->getHeader()->getParent()); - const TargetTransformInfo *TTI = - &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *L->getHeader()->getParent()); - const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout(); - auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - MemorySSA *MSSA = nullptr; - if (MSSAAnalysis) - MSSA = &MSSAAnalysis->getMSSA(); - - // For the old PM, we can't use OptimizationRemarkEmitter as an analysis - // pass. Function analyses need to be preserved across loop transformations - // but ORE cannot be preserved (see comment before the pass definition). - OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); - - LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, MSSA, DL, ORE); - return LIR.runOnLoop(L); - } - - /// This transformation requires natural loop information & requires that - /// loop preheaders be inserted into the CFG. 
- void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - getLoopAnalysisUsage(AU); - } -}; - } // end anonymous namespace -char LoopIdiomRecognizeLegacyPass::ID = 0; - PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { @@ -334,16 +277,6 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, return PA; } -INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom", - "Recognize loop idioms", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(LoopIdiomRecognizeLegacyPass, "loop-idiom", - "Recognize loop idioms", false, false) - -Pass *llvm::createLoopIdiomPass() { return new LoopIdiomRecognizeLegacyPass(); } - static void deleteDeadInstruction(Instruction *I) { I->replaceAllUsesWith(PoisonValue::get(I->getType())); I->eraseFromParent(); @@ -1050,33 +983,6 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, return SE->getMinusSCEV(Start, Index); } -/// Compute trip count from the backedge taken count. -static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr, - Loop *CurLoop, const DataLayout *DL, - ScalarEvolution *SE) { - const SCEV *TripCountS = nullptr; - // The # stored bytes is (BECount+1). Expand the trip count out to - // pointer size if it isn't already. - // - // If we're going to need to zero extend the BE count, check if we can add - // one to it prior to zero extending without overflow. Provided this is safe, - // it allows better simplification of the +1. - if (DL->getTypeSizeInBits(BECount->getType()) < - DL->getTypeSizeInBits(IntPtr) && - SE->isLoopEntryGuardedByCond( - CurLoop, ICmpInst::ICMP_NE, BECount, - SE->getNegativeSCEV(SE->getOne(BECount->getType())))) { - TripCountS = SE->getZeroExtendExpr( - SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW), - IntPtr); - } else { - TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), - SE->getOne(IntPtr), SCEV::FlagNUW); - } - - return TripCountS; -} - /// Compute the number of bytes as a SCEV from the backedge taken count. 
/// /// This also maps the SCEV into the provided type and tries to handle the @@ -1084,8 +990,8 @@ static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr, static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, const SCEV *StoreSizeSCEV, Loop *CurLoop, const DataLayout *DL, ScalarEvolution *SE) { - const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE); - + const SCEV *TripCountSCEV = + SE->getTripCountFromExitCount(BECount, IntPtr, CurLoop); return SE->getMulExpr(TripCountSCEV, SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr), SCEV::FlagNUW); @@ -1168,20 +1074,24 @@ bool LoopIdiomRecognize::processLoopStridedStore( Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator()); + if (!SplatValue && !isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16)) + return Changed; + + AAMDNodes AATags = TheStore->getAAMetadata(); + for (Instruction *Store : Stores) + AATags = AATags.merge(Store->getAAMetadata()); + if (auto CI = dyn_cast<ConstantInt>(NumBytes)) + AATags = AATags.extendTo(CI->getZExtValue()); + else + AATags = AATags.extendTo(-1); + CallInst *NewCall; if (SplatValue) { - AAMDNodes AATags = TheStore->getAAMetadata(); - for (Instruction *Store : Stores) - AATags = AATags.merge(Store->getAAMetadata()); - if (auto CI = dyn_cast<ConstantInt>(NumBytes)) - AATags = AATags.extendTo(CI->getZExtValue()); - else - AATags = AATags.extendTo(-1); - NewCall = Builder.CreateMemSet( BasePtr, SplatValue, NumBytes, MaybeAlign(StoreAlignment), /*isVolatile=*/false, AATags.TBAA, AATags.Scope, AATags.NoAlias); - } else if (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16)) { + } else { + assert (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16)); // Everything is emitted in default address space Type *Int8PtrTy = DestInt8PtrTy; @@ -1199,8 +1109,17 @@ bool LoopIdiomRecognize::processLoopStridedStore( GV->setAlignment(Align(16)); Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy); NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); - } else - return Changed; + + // Set the TBAA info if present. + if (AATags.TBAA) + NewCall->setMetadata(LLVMContext::MD_tbaa, AATags.TBAA); + + if (AATags.Scope) + NewCall->setMetadata(LLVMContext::MD_alias_scope, AATags.Scope); + + if (AATags.NoAlias) + NewCall->setMetadata(LLVMContext::MD_noalias, AATags.NoAlias); + } NewCall->setDebugLoc(TheStore->getDebugLoc()); @@ -2471,7 +2390,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() { // intrinsic/shift we'll use are not cheap. Note that we are okay with *just* // making the loop countable, even if nothing else changes. IntrinsicCostAttributes Attrs( - IntrID, Ty, {UndefValue::get(Ty), /*is_zero_undef=*/Builder.getTrue()}); + IntrID, Ty, {PoisonValue::get(Ty), /*is_zero_poison=*/Builder.getTrue()}); InstructionCost Cost = TTI->getIntrinsicInstrCost(Attrs, CostKind); if (Cost > TargetTransformInfo::TCC_Basic) { LLVM_DEBUG(dbgs() << DEBUG_TYPE @@ -2487,6 +2406,24 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() { // Ok, transform appears worthwhile. MadeChange = true; + if (!isGuaranteedNotToBeUndefOrPoison(BitPos)) { + // BitMask may be computed from BitPos, Freeze BitPos so we can increase + // it's use count. 
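The reason for the freeze below: the rewrite introduces new uses of BitPos (the mask and the trip-count arithmetic are derived from it), a value that might be undef can appear different at each use, and anything derived from poison is itself poison, so duplicating uses risks the new computations disagreeing with the original loop. Freezing first pins BitPos to one well-defined value that all of the new uses share.
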
+ Instruction *InsertPt = nullptr; + if (auto *BitPosI = dyn_cast<Instruction>(BitPos)) + InsertPt = BitPosI->getInsertionPointAfterDef(); + else + InsertPt = &*DT->getRoot()->getFirstNonPHIOrDbgOrAlloca(); + if (!InsertPt) + return false; + FreezeInst *BitPosFrozen = + new FreezeInst(BitPos, BitPos->getName() + ".fr", InsertPt); + BitPos->replaceUsesWithIf(BitPosFrozen, [BitPosFrozen](Use &U) { + return U.getUser() != BitPosFrozen; + }); + BitPos = BitPosFrozen; + } + // Step 1: Compute the loop trip count. Value *LowBitMask = Builder.CreateAdd(BitMask, Constant::getAllOnesValue(Ty), @@ -2495,7 +2432,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() { Builder.CreateOr(LowBitMask, BitMask, BitPos->getName() + ".mask"); Value *XMasked = Builder.CreateAnd(X, Mask, X->getName() + ".masked"); CallInst *XMaskedNumLeadingZeros = Builder.CreateIntrinsic( - IntrID, Ty, {XMasked, /*is_zero_undef=*/Builder.getTrue()}, + IntrID, Ty, {XMasked, /*is_zero_poison=*/Builder.getTrue()}, /*FMFSource=*/nullptr, XMasked->getName() + ".numleadingzeros"); Value *XMaskedNumActiveBits = Builder.CreateSub( ConstantInt::get(Ty, Ty->getScalarSizeInBits()), XMaskedNumLeadingZeros, @@ -2825,7 +2762,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() { // intrinsic we'll use are not cheap. Note that we are okay with *just* // making the loop countable, even if nothing else changes. IntrinsicCostAttributes Attrs( - IntrID, Ty, {UndefValue::get(Ty), /*is_zero_undef=*/Builder.getFalse()}); + IntrID, Ty, {PoisonValue::get(Ty), /*is_zero_poison=*/Builder.getFalse()}); InstructionCost Cost = TTI->getIntrinsicInstrCost(Attrs, CostKind); if (Cost > TargetTransformInfo::TCC_Basic) { LLVM_DEBUG(dbgs() << DEBUG_TYPE @@ -2843,7 +2780,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() { // Step 1: Compute the loop's final IV value / trip count. CallInst *ValNumLeadingZeros = Builder.CreateIntrinsic( - IntrID, Ty, {Val, /*is_zero_undef=*/Builder.getFalse()}, + IntrID, Ty, {Val, /*is_zero_poison=*/Builder.getFalse()}, /*FMFSource=*/nullptr, Val->getName() + ".numleadingzeros"); Value *ValNumActiveBits = Builder.CreateSub( ConstantInt::get(Ty, Ty->getScalarSizeInBits()), ValNumLeadingZeros, diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 0a7c62113c7f..91286ebcea33 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -30,20 +30,16 @@ #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -187,8 +183,7 @@ static void interChangeDependencies(CharMatrix &DepMatrix, unsigned FromIndx, // if the direction matrix, after the same permutation is applied to its // columns, has no ">" direction as the leftmost non-"=" direction in any row. 
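A small worked example of that legality rule, for a hypothetical two-deep nest with the outer index in the first column: a dependence with direction vector (=, <) turns into (<, =) after the columns are swapped; its leftmost non-'=' entry is still '<', so it does not block interchange. A vector of (<, >), however, becomes (>, <), whose leftmost non-'=' entry is '>', meaning the permutation would reverse that dependence, so interchange must be rejected.
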
static bool isLexicographicallyPositive(std::vector<char> &DV) { - for (unsigned Level = 0; Level < DV.size(); ++Level) { - unsigned char Direction = DV[Level]; + for (unsigned char Direction : DV) { if (Direction == '<') return true; if (Direction == '>' || Direction == '*') @@ -736,7 +731,6 @@ bool LoopInterchangeLegality::findInductionAndReductions( if (!L->getLoopLatch() || !L->getLoopPredecessor()) return false; for (PHINode &PHI : L->getHeader()->phis()) { - RecurrenceDescriptor RD; InductionDescriptor ID; if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) Inductions.push_back(&PHI); @@ -1105,8 +1099,7 @@ LoopInterchangeProfitability::isProfitablePerLoopCacheAnalysis( // This is the new cost model returned from loop cache analysis. // A smaller index means the loop should be placed an outer loop, and vice // versa. - if (CostMap.find(InnerLoop) != CostMap.end() && - CostMap.find(OuterLoop) != CostMap.end()) { + if (CostMap.contains(InnerLoop) && CostMap.contains(OuterLoop)) { unsigned InnerIndex = 0, OuterIndex = 0; InnerIndex = CostMap.find(InnerLoop)->second; OuterIndex = CostMap.find(OuterLoop)->second; @@ -1692,12 +1685,11 @@ bool LoopInterchangeTransform::adjustLoopBranches() { // latch. In that case, we need to create LCSSA phis for them, because after // interchanging they will be defined in the new inner loop and used in the // new outer loop. - IRBuilder<> Builder(OuterLoopHeader->getContext()); SmallVector<Instruction *, 4> MayNeedLCSSAPhis; for (Instruction &I : make_range(OuterLoopHeader->begin(), std::prev(OuterLoopHeader->end()))) MayNeedLCSSAPhis.push_back(&I); - formLCSSAForInstructions(MayNeedLCSSAPhis, *DT, *LI, SE, Builder); + formLCSSAForInstructions(MayNeedLCSSAPhis, *DT, *LI, SE); return true; } @@ -1716,52 +1708,6 @@ bool LoopInterchangeTransform::adjustLoopLinks() { return Changed; } -namespace { -/// Main LoopInterchange Pass. 
-struct LoopInterchangeLegacyPass : public LoopPass { - static char ID; - - LoopInterchangeLegacyPass() : LoopPass(ID) { - initializeLoopInterchangeLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DependenceAnalysisWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - - getLoopAnalysisUsage(AU); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *DI = &getAnalysis<DependenceAnalysisWrapperPass>().getDI(); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - std::unique_ptr<CacheCost> CC = nullptr; - return LoopInterchange(SE, LI, DI, DT, CC, ORE).run(L); - } -}; -} // namespace - -char LoopInterchangeLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(LoopInterchangeLegacyPass, "loop-interchange", - "Interchanges loops for cache reuse", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) - -INITIALIZE_PASS_END(LoopInterchangeLegacyPass, "loop-interchange", - "Interchanges loops for cache reuse", false, false) - -Pass *llvm::createLoopInterchangePass() { - return new LoopInterchangeLegacyPass(); -} - PreservedAnalyses LoopInterchangePass::run(LoopNest &LN, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index b615a0a0a9c0..179ccde8d035 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -46,13 +46,10 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopVersioning.h" @@ -91,8 +88,9 @@ struct StoreToLoadForwardingCandidate { StoreToLoadForwardingCandidate(LoadInst *Load, StoreInst *Store) : Load(Load), Store(Store) {} - /// Return true if the dependence from the store to the load has a - /// distance of one. E.g. A[i+1] = A[i] + /// Return true if the dependence from the store to the load has an + /// absolute distance of one. + /// E.g. A[i+1] = A[i] (or A[i-1] = A[i] for descending loop) bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, Loop *L) const { Value *LoadPtr = Load->getPointerOperand(); @@ -106,11 +104,19 @@ struct StoreToLoadForwardingCandidate { DL.getTypeSizeInBits(getLoadStoreType(Store)) && "Should be a known dependence"); - // Currently we only support accesses with unit stride. FIXME: we should be - // able to handle non unit stirde as well as long as the stride is equal to - // the dependence distance. 
- if (getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0) != 1 || - getPtrStride(PSE, LoadType, StorePtr, L).value_or(0) != 1) + int64_t StrideLoad = getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0); + int64_t StrideStore = getPtrStride(PSE, LoadType, StorePtr, L).value_or(0); + if (!StrideLoad || !StrideStore || StrideLoad != StrideStore) + return false; + + // TODO: This check for stride values other than 1 and -1 can be eliminated. + // However, doing so may cause the LoopAccessAnalysis to overcompensate, + // generating numerous non-wrap runtime checks that may undermine the + // benefits of load elimination. To safely implement support for non-unit + // strides, we would need to ensure either that the processed case does not + // require these additional checks, or improve the LAA to handle them more + // efficiently, or potentially both. + if (std::abs(StrideLoad) != 1) return false; unsigned TypeByteSize = DL.getTypeAllocSize(const_cast<Type *>(LoadType)); @@ -123,7 +129,7 @@ struct StoreToLoadForwardingCandidate { auto *Dist = cast<SCEVConstant>( PSE.getSE()->getMinusSCEV(StorePtrSCEV, LoadPtrSCEV)); const APInt &Val = Dist->getAPInt(); - return Val == TypeByteSize; + return Val == TypeByteSize * StrideLoad; } Value *getLoadPtr() const { return Load->getPointerOperand(); } @@ -658,70 +664,6 @@ static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, return Changed; } -namespace { - -/// The pass. Most of the work is delegated to the per-loop -/// LoadEliminationForLoop class. -class LoopLoadElimination : public FunctionPass { -public: - static char ID; - - LoopLoadElimination() : FunctionPass(ID) { - initializeLoopLoadEliminationPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - auto *BFI = (PSI && PSI->hasProfileSummary()) ? - &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : - nullptr; - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - - // Process each loop nest in the function. 
- return eliminateLoadsAcrossLoops(F, LI, DT, BFI, PSI, SE, /*AC*/ nullptr, - LAIs); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequiredID(LoopSimplifyID); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<LoopAccessLegacyAnalysis>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addRequired<ProfileSummaryInfoWrapperPass>(); - LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); - } -}; - -} // end anonymous namespace - -char LoopLoadElimination::ID; - -static const char LLE_name[] = "Loop Load Elimination"; - -INITIALIZE_PASS_BEGIN(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass) -INITIALIZE_PASS_END(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) - -FunctionPass *llvm::createLoopLoadEliminationPass() { - return new LoopLoadElimination(); -} - PreservedAnalyses LoopLoadEliminationPass::run(Function &F, FunctionAnalysisManager &AM) { auto &LI = AM.getResult<LoopAnalysis>(F); @@ -744,5 +686,7 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F, return PreservedAnalyses::all(); PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); return PA; } diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index c98b94b56e48..2c8a3351281b 100644 --- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -59,7 +59,7 @@ void PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, P->printPipeline(OS, MapClassName2PassName); } if (Idx + 1 < Size) - OS << ","; + OS << ','; } } @@ -193,7 +193,7 @@ void FunctionToLoopPassAdaptor::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { OS << (UseMemorySSA ? 
"loop-mssa(" : "loop("); Pass->printPipeline(OS, MapClassName2PassName); - OS << ")"; + OS << ')'; } PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, FunctionAnalysisManager &AM) { diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index 49c0fff84d81..12852ae5c460 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -623,7 +623,8 @@ std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred, GuardStart, GuardLimit); IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); - return Builder.CreateAnd(FirstIterationCheck, LimitCheck); + return Builder.CreateFreeze( + Builder.CreateAnd(FirstIterationCheck, LimitCheck)); } std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( @@ -671,7 +672,8 @@ std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, SE->getOne(Ty)); IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); - return Builder.CreateAnd(FirstIterationCheck, LimitCheck); + return Builder.CreateFreeze( + Builder.CreateAnd(FirstIterationCheck, LimitCheck)); } static void normalizePredicate(ScalarEvolution *SE, Loop *L, @@ -863,7 +865,19 @@ bool LoopPredication::widenWidenableBranchGuardConditions( BI->setCondition(AllChecks); if (InsertAssumesOfPredicatedGuardsConditions) { Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt()); - Builder.CreateAssumption(Cond); + // If this block has other predecessors, we might not be able to use Cond. + // In this case, create a Phi where every other input is `true` and input + // from guard block is Cond. + Value *AssumeCond = Cond; + if (!IfTrueBB->getUniquePredecessor()) { + auto *GuardBB = BI->getParent(); + auto *PN = Builder.CreatePHI(Cond->getType(), pred_size(IfTrueBB), + "assume.cond"); + for (auto *Pred : predecessors(IfTrueBB)) + PN->addIncoming(Pred == GuardBB ? Cond : Builder.getTrue(), Pred); + AssumeCond = PN; + } + Builder.CreateAssumption(AssumeCond); } RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); assert(isGuardAsWidenableBranch(BI) && @@ -1161,6 +1175,11 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { if (ChangedLoop) SE->forgetLoop(L); + // The insertion point for the widening should be at the widenably call, not + // at the WidenableBR. If we do this at the widenableBR, we can incorrectly + // change a loop-invariant condition to a loop-varying one. + auto *IP = cast<Instruction>(WidenableBR->getCondition()); + // The use of umin(all analyzeable exits) instead of latch is subtle, but // important for profitability. We may have a loop which hasn't been fully // canonicalized just yet. If the exit we chose to widen is provably never @@ -1170,21 +1189,9 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { const SCEV *MinEC = getMinAnalyzeableBackedgeTakenCount(*SE, *DT, L); if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() || !SE->isLoopInvariant(MinEC, L) || - !Rewriter.isSafeToExpandAt(MinEC, WidenableBR)) + !Rewriter.isSafeToExpandAt(MinEC, IP)) return ChangedLoop; - // Subtlety: We need to avoid inserting additional uses of the WC. 
We know - // that it can only have one transitive use at the moment, and thus moving - // that use to just before the branch and inserting code before it and then - // modifying the operand is legal. - auto *IP = cast<Instruction>(WidenableBR->getCondition()); - // Here we unconditionally modify the IR, so after this point we should return - // only `true`! - IP->moveBefore(WidenableBR); - if (MSSAU) - if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP)) - MSSAU->moveToPlace(MUD, WidenableBR->getParent(), - MemorySSA::BeforeTerminator); Rewriter.setInsertPoint(IP); IRBuilder<> B(IP); diff --git a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp index a0b3189c7e09..7f62526a4f6d 100644 --- a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -39,13 +39,10 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopReroll.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -157,22 +154,6 @@ namespace { IL_End }; - class LoopRerollLegacyPass : public LoopPass { - public: - static char ID; // Pass ID, replacement for typeid - - LoopRerollLegacyPass() : LoopPass(ID) { - initializeLoopRerollLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - getLoopAnalysisUsage(AU); - } - }; - class LoopReroll { public: LoopReroll(AliasAnalysis *AA, LoopInfo *LI, ScalarEvolution *SE, @@ -490,17 +471,6 @@ namespace { } // end anonymous namespace -char LoopRerollLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(LoopRerollLegacyPass, "loop-reroll", "Reroll loops", - false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(LoopRerollLegacyPass, "loop-reroll", "Reroll loops", false, - false) - -Pass *llvm::createLoopRerollPass() { return new LoopRerollLegacyPass; } - // Returns true if the provided instruction is used outside the given loop. // This operates like Instruction::isUsedOutsideOfBlock, but considers PHIs in // non-loop blocks to be outside the loop. 
@@ -1700,21 +1670,6 @@ bool LoopReroll::runOnLoop(Loop *L) { return Changed; } -bool LoopRerollLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipLoop(L)) - return false; - - auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( - *L->getHeader()->getParent()); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); - - return LoopReroll(AA, LI, SE, TLI, DT, PreserveLCSSA).runOnLoop(L); -} - PreservedAnalyses LoopRerollPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { diff --git a/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/llvm/lib/Transforms/Scalar/LoopRotation.cpp index ba735adc5b27..eee855058706 100644 --- a/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -43,6 +43,21 @@ LoopRotatePass::LoopRotatePass(bool EnableHeaderDuplication, bool PrepareForLTO) : EnableHeaderDuplication(EnableHeaderDuplication), PrepareForLTO(PrepareForLTO) {} +void LoopRotatePass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LoopRotatePass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (!EnableHeaderDuplication) + OS << "no-"; + OS << "header-duplication;"; + + if (!PrepareForLTO) + OS << "no-"; + OS << "prepare-for-lto"; + OS << ">"; +} + PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp index 21025b0bdb33..597c159682c5 100644 --- a/llvm/lib/Transforms/Scalar/LoopSink.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -177,13 +177,27 @@ static bool sinkInstruction( SmallPtrSet<BasicBlock *, 2> BBs; for (auto &U : I.uses()) { Instruction *UI = cast<Instruction>(U.getUser()); - // We cannot sink I to PHI-uses. - if (isa<PHINode>(UI)) - return false; + // We cannot sink I if it has uses outside of the loop. if (!L.contains(LI.getLoopFor(UI->getParent()))) return false; - BBs.insert(UI->getParent()); + + if (!isa<PHINode>(UI)) { + BBs.insert(UI->getParent()); + continue; + } + + // We cannot sink I to PHI-uses, try to look through PHI to find the incoming + // block of the value being used. + PHINode *PN = dyn_cast<PHINode>(UI); + BasicBlock *PhiBB = PN->getIncomingBlock(U); + + // If value's incoming block is from loop preheader directly, there's no + // place to sink to, bailout. + if (L.getLoopPreheader() == PhiBB) + return false; + + BBs.insert(PhiBB); } // findBBsToSinkInto is O(BBs.size() * ColdLoopBBs.size()). We cap the max @@ -238,9 +252,11 @@ static bool sinkInstruction( } } - // Replaces uses of I with IC in N + // Replaces uses of I with IC in N, except PHI-use which is being taken + // care of by defs in PHI's incoming blocks. 
I.replaceUsesWithIf(IC, [N](Use &U) { - return cast<Instruction>(U.getUser())->getParent() == N; + Instruction *UIToReplace = cast<Instruction>(U.getUser()); + return UIToReplace->getParent() == N && !isa<PHINode>(UIToReplace); }); // Replaces uses of I with IC in blocks dominated by N replaceDominatedUsesWith(&I, IC, DT, N); @@ -283,7 +299,7 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, return false; MemorySSAUpdater MSSAU(&MSSA); - SinkAndHoistLICMFlags LICMFlags(/*IsSink=*/true, &L, &MSSA); + SinkAndHoistLICMFlags LICMFlags(/*IsSink=*/true, L, MSSA); bool Changed = false; @@ -323,6 +339,11 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, } PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { + // Enable LoopSink only when runtime profile is available. + // With static profile, the sinking decision may be sub-optimal. + if (!F.hasProfileData()) + return PreservedAnalyses::all(); + LoopInfo &LI = FAM.getResult<LoopAnalysis>(F); // Nothing to do if there are no loops. if (LI.empty()) @@ -348,11 +369,6 @@ PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { if (!Preheader) continue; - // Enable LoopSink only when runtime profile is available. - // With static profile, the sinking decision may be sub-optimal. - if (!Preheader->getParent()->hasProfileData()) - continue; - // Note that we don't pass SCEV here because it is only used to invalidate // loops in SCEV and we don't preserve (or request) SCEV at all making that // unnecessary. diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 4c89f947d7fc..a4369b83e732 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -799,7 +799,7 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS, /// value, and mutate S to point to a new SCEV with that value excluded. static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { - if (C->getAPInt().getMinSignedBits() <= 64) { + if (C->getAPInt().getSignificantBits() <= 64) { S = SE.getConstant(C->getType(), 0); return C->getValue()->getSExtValue(); } @@ -896,9 +896,14 @@ static bool isAddressUse(const TargetTransformInfo &TTI, /// Return the type of the memory being accessed. static MemAccessTy getAccessType(const TargetTransformInfo &TTI, Instruction *Inst, Value *OperandVal) { - MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace); + MemAccessTy AccessTy = MemAccessTy::getUnknown(Inst->getContext()); + + // First get the type of memory being accessed. + if (Type *Ty = Inst->getAccessType()) + AccessTy.MemTy = Ty; + + // Then get the pointer address space. 
if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - AccessTy.MemTy = SI->getOperand(0)->getType(); AccessTy.AddrSpace = SI->getPointerAddressSpace(); } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) { AccessTy.AddrSpace = LI->getPointerAddressSpace(); @@ -923,7 +928,6 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI, II->getArgOperand(0)->getType()->getPointerAddressSpace(); break; case Intrinsic::masked_store: - AccessTy.MemTy = II->getOperand(0)->getType(); AccessTy.AddrSpace = II->getArgOperand(1)->getType()->getPointerAddressSpace(); break; @@ -976,6 +980,7 @@ static bool isHighCostExpansion(const SCEV *S, switch (S->getSCEVType()) { case scUnknown: case scConstant: + case scVScale: return false; case scTruncate: return isHighCostExpansion(cast<SCEVTruncateExpr>(S)->getOperand(), @@ -1414,7 +1419,7 @@ void Cost::RateFormula(const Formula &F, C.ImmCost += 64; // Handle symbolic values conservatively. // TODO: This should probably be the pointer size. else if (Offset != 0) - C.ImmCost += APInt(64, Offset, true).getMinSignedBits(); + C.ImmCost += APInt(64, Offset, true).getSignificantBits(); // Check with target if this offset with this instruction is // specifically not supported. @@ -2498,7 +2503,7 @@ LSRInstance::OptimizeLoopTermCond() { if (C->isOne() || C->isMinusOne()) goto decline_post_inc; // Avoid weird situations. - if (C->getValue().getMinSignedBits() >= 64 || + if (C->getValue().getSignificantBits() >= 64 || C->getValue().isMinSignedValue()) goto decline_post_inc; // Check for possible scaled-address reuse. @@ -2508,13 +2513,13 @@ LSRInstance::OptimizeLoopTermCond() { int64_t Scale = C->getSExtValue(); if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, /*BaseOffset=*/0, - /*HasBaseReg=*/false, Scale, + /*HasBaseReg=*/true, Scale, AccessTy.AddrSpace)) goto decline_post_inc; Scale = -Scale; if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, /*BaseOffset=*/0, - /*HasBaseReg=*/false, Scale, + /*HasBaseReg=*/true, Scale, AccessTy.AddrSpace)) goto decline_post_inc; } @@ -2660,8 +2665,7 @@ LSRUse * LSRInstance::FindUseWithSimilarFormula(const Formula &OrigF, const LSRUse &OrigLU) { // Search all uses for the formula. This could be more clever. - for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { - LSRUse &LU = Uses[LUIdx]; + for (LSRUse &LU : Uses) { // Check whether this use is close enough to OrigLU, to see whether it's // worthwhile looking through its formulae. // Ignore ICmpZero uses because they may contain formulae generated by @@ -2703,6 +2707,8 @@ void LSRInstance::CollectInterestingTypesAndFactors() { SmallVector<const SCEV *, 4> Worklist; for (const IVStrideUse &U : IU) { const SCEV *Expr = IU.getExpr(U); + if (!Expr) + continue; // Collect interesting types. 
Types.insert(SE.getEffectiveSCEVType(Expr->getType())); @@ -2740,13 +2746,13 @@ void LSRInstance::CollectInterestingTypesAndFactors() { if (const SCEVConstant *Factor = dyn_cast_or_null<SCEVConstant>(getExactSDiv(NewStride, OldStride, SE, true))) { - if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero()) + if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero()) Factors.insert(Factor->getAPInt().getSExtValue()); } else if (const SCEVConstant *Factor = dyn_cast_or_null<SCEVConstant>(getExactSDiv(OldStride, NewStride, SE, true))) { - if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero()) + if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero()) Factors.insert(Factor->getAPInt().getSExtValue()); } } @@ -2812,9 +2818,10 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) { /// SCEVUnknown, we simply return the rightmost SCEV operand. static const SCEV *getExprBase(const SCEV *S) { switch (S->getSCEVType()) { - default: // uncluding scUnknown. + default: // including scUnknown. return S; case scConstant: + case scVScale: return nullptr; case scTruncate: return getExprBase(cast<SCEVTruncateExpr>(S)->getOperand()); @@ -3175,7 +3182,7 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, if (!IncConst || !isAddressUse(TTI, UserInst, Operand)) return false; - if (IncConst->getAPInt().getMinSignedBits() > 64) + if (IncConst->getAPInt().getSignificantBits() > 64) return false; MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand); @@ -3320,6 +3327,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { } const SCEV *S = IU.getExpr(U); + if (!S) + continue; PostIncLoopSet TmpPostIncLoops = U.getPostIncLoops(); // Equality (== and !=) ICmps are special. We can rewrite (i == N) as @@ -3352,6 +3361,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // S is normalized, so normalize N before folding it into S // to keep the result normalized. N = normalizeForPostIncUse(N, TmpPostIncLoops, SE); + if (!N) + continue; Kind = LSRUse::ICmpZero; S = SE.getMinusSCEV(N, S); } else if (L->isLoopInvariant(NV) && @@ -3366,6 +3377,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // SCEV can't compute the difference of two unknown pointers. N = SE.getUnknown(NV); N = normalizeForPostIncUse(N, TmpPostIncLoops, SE); + if (!N) + continue; Kind = LSRUse::ICmpZero; S = SE.getMinusSCEV(N, S); assert(!isa<SCEVCouldNotCompute>(S)); @@ -3494,8 +3507,8 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { if (const Instruction *Inst = dyn_cast<Instruction>(V)) { // Look for instructions defined outside the loop. if (L->contains(Inst)) continue; - } else if (isa<UndefValue>(V)) - // Undef doesn't have a live range, so it doesn't matter. + } else if (isa<Constant>(V)) + // Constants can be re-materialized. continue; for (const Use &U : V->uses()) { const Instruction *UserInst = dyn_cast<Instruction>(U.getUser()); @@ -4137,6 +4150,29 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) { } } +/// Extend/Truncate \p Expr to \p ToTy considering post-inc uses in \p Loops. +/// For all PostIncLoopSets in \p Loops, first de-normalize \p Expr, then +/// perform the extension/truncate and normalize again, as the normalized form +/// can result in folds that are not valid in the post-inc use contexts. The +/// expressions for all PostIncLoopSets must match, otherwise return nullptr. 
+static const SCEV * +getAnyExtendConsideringPostIncUses(ArrayRef<PostIncLoopSet> Loops, + const SCEV *Expr, Type *ToTy, + ScalarEvolution &SE) { + const SCEV *Result = nullptr; + for (auto &L : Loops) { + auto *DenormExpr = denormalizeForPostIncUse(Expr, L, SE); + const SCEV *NewDenormExpr = SE.getAnyExtendExpr(DenormExpr, ToTy); + const SCEV *New = normalizeForPostIncUse(NewDenormExpr, L, SE); + if (!New || (Result && New != Result)) + return nullptr; + Result = New; + } + + assert(Result && "failed to create expression"); + return Result; +} + /// Generate reuse formulae from different IV types. void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { // Don't bother truncating symbolic values. @@ -4156,6 +4192,10 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { [](const SCEV *S) { return S->getType()->isPointerTy(); })) return; + SmallVector<PostIncLoopSet> Loops; + for (auto &LF : LU.Fixups) + Loops.push_back(LF.PostIncLoops); + for (Type *SrcTy : Types) { if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) { Formula F = Base; @@ -4165,15 +4205,17 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { // initial node (maybe due to depth limitations), but it can do them while // taking ext. if (F.ScaledReg) { - const SCEV *NewScaledReg = SE.getAnyExtendExpr(F.ScaledReg, SrcTy); - if (NewScaledReg->isZero()) - continue; + const SCEV *NewScaledReg = + getAnyExtendConsideringPostIncUses(Loops, F.ScaledReg, SrcTy, SE); + if (!NewScaledReg || NewScaledReg->isZero()) + continue; F.ScaledReg = NewScaledReg; } bool HasZeroBaseReg = false; for (const SCEV *&BaseReg : F.BaseRegs) { - const SCEV *NewBaseReg = SE.getAnyExtendExpr(BaseReg, SrcTy); - if (NewBaseReg->isZero()) { + const SCEV *NewBaseReg = + getAnyExtendConsideringPostIncUses(Loops, BaseReg, SrcTy, SE); + if (!NewBaseReg || NewBaseReg->isZero()) { HasZeroBaseReg = true; break; } @@ -4379,8 +4421,8 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { if ((C->getAPInt() + NewF.BaseOffset) .abs() .slt(std::abs(NewF.BaseOffset)) && - (C->getAPInt() + NewF.BaseOffset).countTrailingZeros() >= - countTrailingZeros<uint64_t>(NewF.BaseOffset)) + (C->getAPInt() + NewF.BaseOffset).countr_zero() >= + (unsigned)llvm::countr_zero<uint64_t>(NewF.BaseOffset)) goto skip_formula; // Ok, looks good. @@ -4982,6 +5024,32 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); } +// Check if Best and Reg are SCEVs separated by a constant amount C, and if so +// would the addressing offset +C would be legal where the negative offset -C is +// not. 
+static bool IsSimplerBaseSCEVForTarget(const TargetTransformInfo &TTI, + ScalarEvolution &SE, const SCEV *Best, + const SCEV *Reg, + MemAccessTy AccessType) { + if (Best->getType() != Reg->getType() || + (isa<SCEVAddRecExpr>(Best) && isa<SCEVAddRecExpr>(Reg) && + cast<SCEVAddRecExpr>(Best)->getLoop() != + cast<SCEVAddRecExpr>(Reg)->getLoop())) + return false; + const auto *Diff = dyn_cast<SCEVConstant>(SE.getMinusSCEV(Best, Reg)); + if (!Diff) + return false; + + return TTI.isLegalAddressingMode( + AccessType.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/Diff->getAPInt().getSExtValue(), + /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace) && + !TTI.isLegalAddressingMode( + AccessType.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/-Diff->getAPInt().getSExtValue(), + /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace); +} + /// Pick a register which seems likely to be profitable, and then in any use /// which has any reference to that register, delete all formulae which do not /// reference that register. @@ -5010,6 +5078,19 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { Best = Reg; BestNum = Count; } + + // If the scores are the same, but the Reg is simpler for the target + // (for example {x,+,1} as opposed to {x+C,+,1}, where the target can + // handle +C but not -C), opt for the simpler formula. + if (Count == BestNum) { + int LUIdx = RegUses.getUsedByIndices(Reg).find_first(); + if (LUIdx >= 0 && Uses[LUIdx].Kind == LSRUse::Address && + IsSimplerBaseSCEVForTarget(TTI, SE, Best, Reg, + Uses[LUIdx].AccessTy)) { + Best = Reg; + BestNum = Count; + } + } } } assert(Best && "Failed to find best LSRUse candidate"); @@ -5497,6 +5578,13 @@ void LSRInstance::RewriteForPHI( PHINode *PN, const LSRUse &LU, const LSRFixup &LF, const Formula &F, SmallVectorImpl<WeakTrackingVH> &DeadInsts) const { DenseMap<BasicBlock *, Value *> Inserted; + + // Inserting instructions in the loop and using them as PHI's input could + // break LCSSA in case if PHI's parent block is not a loop exit (i.e. the + // corresponding incoming block is not loop exiting). So collect all such + // instructions to form LCSSA for them later. + SmallVector<Instruction *, 4> InsertedNonLCSSAInsts; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (PN->getIncomingValue(i) == LF.OperandValToReplace) { bool needUpdateFixups = false; @@ -5562,6 +5650,13 @@ void LSRInstance::RewriteForPHI( FullV, LF.OperandValToReplace->getType(), "tmp", BB->getTerminator()); + // If the incoming block for this value is not in the loop, it means the + // current PHI is not in a loop exit, so we must create a LCSSA PHI for + // the inserted value. + if (auto *I = dyn_cast<Instruction>(FullV)) + if (L->contains(I) && !L->contains(BB)) + InsertedNonLCSSAInsts.push_back(I); + PN->setIncomingValue(i, FullV); Pair.first->second = FullV; } @@ -5604,6 +5699,8 @@ void LSRInstance::RewriteForPHI( } } } + + formLCSSAForInstructions(InsertedNonLCSSAInsts, DT, LI, &SE); } /// Emit instructions for the leading candidate expression for this LSRUse (this @@ -5643,6 +5740,36 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, DeadInsts.emplace_back(OperandIsInstr); } +// Trying to hoist the IVInc to loop header if all IVInc users are in +// the loop header. It will help backend to generate post index load/store +// when the latch block is different from loop header block. 
+static bool canHoistIVInc(const TargetTransformInfo &TTI, const LSRFixup &Fixup, + const LSRUse &LU, Instruction *IVIncInsertPos, + Loop *L) { + if (LU.Kind != LSRUse::Address) + return false; + + // For now this code do the conservative optimization, only work for + // the header block. Later we can hoist the IVInc to the block post + // dominate all users. + BasicBlock *LHeader = L->getHeader(); + if (IVIncInsertPos->getParent() == LHeader) + return false; + + if (!Fixup.OperandValToReplace || + any_of(Fixup.OperandValToReplace->users(), [&LHeader](User *U) { + Instruction *UI = cast<Instruction>(U); + return UI->getParent() != LHeader; + })) + return false; + + Instruction *I = Fixup.UserInst; + Type *Ty = I->getType(); + return Ty->isIntegerTy() && + ((isa<LoadInst>(I) && TTI.isIndexedLoadLegal(TTI.MIM_PostInc, Ty)) || + (isa<StoreInst>(I) && TTI.isIndexedStoreLegal(TTI.MIM_PostInc, Ty))); +} + /// Rewrite all the fixup locations with new values, following the chosen /// solution. void LSRInstance::ImplementSolution( @@ -5651,8 +5778,6 @@ void LSRInstance::ImplementSolution( // we can remove them after we are done working. SmallVector<WeakTrackingVH, 16> DeadInsts; - Rewriter.setIVIncInsertPos(L, IVIncInsertPos); - // Mark phi nodes that terminate chains so the expander tries to reuse them. for (const IVChain &Chain : IVChainVec) { if (PHINode *PN = dyn_cast<PHINode>(Chain.tailUserInst())) @@ -5662,6 +5787,11 @@ void LSRInstance::ImplementSolution( // Expand the new value definitions and update the users. for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) { + Instruction *InsertPos = + canHoistIVInc(TTI, Fixup, Uses[LUIdx], IVIncInsertPos, L) + ? L->getHeader()->getTerminator() + : IVIncInsertPos; + Rewriter.setIVIncInsertPos(L, InsertPos); Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], DeadInsts); Changed = true; } @@ -5994,7 +6124,7 @@ struct SCEVDbgValueBuilder { } bool pushConst(const SCEVConstant *C) { - if (C->getAPInt().getMinSignedBits() > 64) + if (C->getAPInt().getSignificantBits() > 64) return false; Expr.push_back(llvm::dwarf::DW_OP_consts); Expr.push_back(C->getAPInt().getSExtValue()); @@ -6083,7 +6213,7 @@ struct SCEVDbgValueBuilder { /// SCEV constant value is an identity function. bool isIdentityFunction(uint64_t Op, const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { - if (C->getAPInt().getMinSignedBits() > 64) + if (C->getAPInt().getSignificantBits() > 64) return false; int64_t I = C->getAPInt().getSExtValue(); switch (Op) { @@ -6338,13 +6468,13 @@ static void UpdateDbgValueInst(DVIRecoveryRec &DVIRec, } } -/// Cached location ops may be erased during LSR, in which case an undef is +/// Cached location ops may be erased during LSR, in which case a poison is /// required when restoring from the cache. The type of that location is no -/// longer available, so just use int8. The undef will be replaced by one or +/// longer available, so just use int8. The poison will be replaced by one or /// more locations later when a SCEVDbgValueBuilder selects alternative /// locations to use for the salvage. -static Value *getValueOrUndef(WeakVH &VH, LLVMContext &C) { - return (VH) ? VH : UndefValue::get(llvm::Type::getInt8Ty(C)); +static Value *getValueOrPoison(WeakVH &VH, LLVMContext &C) { + return (VH) ? VH : PoisonValue::get(llvm::Type::getInt8Ty(C)); } /// Restore the DVI's pre-LSR arguments. Substitute undef for any erased values. 
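// The canHoistIVInc helper added above recognizes address-kind fixups whose
// uses all live in the loop header and whose integer load/store supports
// post-indexed addressing on the target; for those, ImplementSolution emits
// the IV increment at the header terminator instead of the default
// IVIncInsertPos. A minimal source-level sketch of the shape this targets;
// the function below is illustrative only and not taken from the patch.
// Keeping the pointer bump next to its single memory use is what lets a
// backend with post-increment addressing fold both into one post-indexed load.
long sumElements(const int *p, const int *end) {
  long s = 0;
  while (p != end) {
    s += *p; // the only memory use of the induction pointer ...
    ++p;     // ... immediately followed by its increment -> post-index load
  }
  return s;
}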
@@ -6363,12 +6493,12 @@ static void restorePreTransformState(DVIRecoveryRec &DVIRec) { // this case was not present before, so force the location back to a single // uncontained Value. Value *CachedValue = - getValueOrUndef(DVIRec.LocationOps[0], DVIRec.DVI->getContext()); + getValueOrPoison(DVIRec.LocationOps[0], DVIRec.DVI->getContext()); DVIRec.DVI->setRawLocation(ValueAsMetadata::get(CachedValue)); } else { SmallVector<ValueAsMetadata *, 3> MetadataLocs; for (WeakVH VH : DVIRec.LocationOps) { - Value *CachedValue = getValueOrUndef(VH, DVIRec.DVI->getContext()); + Value *CachedValue = getValueOrPoison(VH, DVIRec.DVI->getContext()); MetadataLocs.push_back(ValueAsMetadata::get(CachedValue)); } auto ValArrayRef = llvm::ArrayRef<llvm::ValueAsMetadata *>(MetadataLocs); @@ -6431,7 +6561,7 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, // less DWARF ops than an iteration count-based expression. if (std::optional<APInt> Offset = SE.computeConstantDifference(DVIRec.SCEVs[i], SCEVInductionVar)) { - if (Offset->getMinSignedBits() <= 64) + if (Offset->getSignificantBits() <= 64) SalvageExpr->createOffsetExpr(Offset->getSExtValue(), LSRInductionVar); } else if (!SalvageExpr->createIterCountExpr(DVIRec.SCEVs[i], IterCountExpr, SE)) @@ -6607,7 +6737,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE, return nullptr; } -static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *>> +static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>> canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, const LoopInfo &LI) { if (!L->isInnermost()) { @@ -6626,16 +6756,13 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, } BasicBlock *LoopLatch = L->getLoopLatch(); - - // TODO: Can we do something for greater than and less than? - // Terminating condition is foldable when it is an eq/ne icmp - BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); - if (BI->isUnconditional()) + BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); + if (!BI || BI->isUnconditional()) return std::nullopt; - Value *TermCond = BI->getCondition(); - if (!isa<ICmpInst>(TermCond) || !cast<ICmpInst>(TermCond)->isEquality()) { - LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an " - "ICmpInst::eq / ICmpInst::ne\n"); + auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition()); + if (!TermCond) { + LLVM_DEBUG( + dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); return std::nullopt; } if (!TermCond->hasOneUse()) { @@ -6645,89 +6772,42 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, return std::nullopt; } - // For `IsToFold`, a primary IV can be replaced by other affine AddRec when it - // is only used by the terminating condition. To check for this, we may need - // to traverse through a chain of use-def until we can examine the final - // usage. - // *----------------------* - // *---->| LoopHeader: | - // | | PrimaryIV = phi ... 
| - // | *----------------------* - // | | - // | | - // | chain of - // | single use - // used by | - // phi | - // | Value - // | / \ - // | chain of chain of - // | single use single use - // | / \ - // | / \ - // *- Value Value --> used by terminating condition - auto IsToFold = [&](PHINode &PN) -> bool { - Value *V = &PN; - - while (V->getNumUses() == 1) - V = *V->user_begin(); - - if (V->getNumUses() != 2) - return false; + BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0)); + Value *RHS = TermCond->getOperand(1); + if (!LHS || !L->isLoopInvariant(RHS)) + // We could pattern match the inverse form of the icmp, but that is + // non-canonical, and this pass is running *very* late in the pipeline. + return std::nullopt; - Value *VToPN = nullptr; - Value *VToTermCond = nullptr; - for (User *U : V->users()) { - while (U->getNumUses() == 1) { - if (isa<PHINode>(U)) - VToPN = U; - if (U == TermCond) - VToTermCond = U; - U = *U->user_begin(); - } - } - return VToPN && VToTermCond; - }; + // Find the IV used by the current exit condition. + PHINode *ToFold; + Value *ToFoldStart, *ToFoldStep; + if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) + return std::nullopt; - // If this is an IV which we could replace the terminating condition, return - // the final value of the alternative IV on the last iteration. - auto getAlternateIVEnd = [&](PHINode &PN) -> const SCEV * { - // FIXME: This does not properly account for overflow. - const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); - const SCEV *BECount = SE.getBackedgeTakenCount(L); - const SCEV *TermValueS = SE.getAddExpr( - AddRec->getOperand(0), - SE.getTruncateOrZeroExtend( - SE.getMulExpr( - AddRec->getOperand(1), - SE.getTruncateOrZeroExtend( - SE.getAddExpr(BECount, SE.getOne(BECount->getType())), - AddRec->getOperand(1)->getType())), - AddRec->getOperand(0)->getType())); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); - if (!Expander.isSafeToExpand(TermValueS)) { - LLVM_DEBUG( - dbgs() << "Is not safe to expand terminating value for phi node" << PN - << "\n"); - return nullptr; - } - return TermValueS; - }; + // If that IV isn't dead after we rewrite the exit condition in terms of + // another IV, there's no point in doing the transform. 
+ if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) + return std::nullopt; + + const SCEV *BECount = SE.getBackedgeTakenCount(L); + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); - PHINode *ToFold = nullptr; PHINode *ToHelpFold = nullptr; const SCEV *TermValueS = nullptr; - + bool MustDropPoison = false; for (PHINode &PN : L->getHeader()->phis()) { + if (ToFold == &PN) + continue; + if (!SE.isSCEVable(PN.getType())) { LLVM_DEBUG(dbgs() << "IV of phi '" << PN << "' is not SCEV-able, not qualified for the " "terminating condition folding.\n"); continue; } - const SCEV *S = SE.getSCEV(&PN); - const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S); + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); // Only speculate on affine AddRec if (!AddRec || !AddRec->isAffine()) { LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN @@ -6736,12 +6816,63 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, continue; } - if (IsToFold(PN)) - ToFold = &PN; - else if (auto P = getAlternateIVEnd(PN)) { - ToHelpFold = &PN; - TermValueS = P; + // Check that we can compute the value of AddRec on the exiting iteration + // without soundness problems. evaluateAtIteration internally needs + // to multiply the stride of the iteration number - which may wrap around. + // The issue here is subtle because computing the result accounting for + // wrap is insufficient. In order to use the result in an exit test, we + // must also know that AddRec doesn't take the same value on any previous + // iteration. The simplest case to consider is a candidate IV which is + // narrower than the trip count (and thus original IV), but this can + // also happen due to non-unit strides on the candidate IVs. + if (!AddRec->hasNoSelfWrap()) + continue; + + const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); + const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); + if (!Expander.isSafeToExpand(TermValueSLocal)) { + LLVM_DEBUG( + dbgs() << "Is not safe to expand terminating value for phi node" << PN + << "\n"); + continue; } + + // The candidate IV may have been otherwise dead and poison from the + // very first iteration. If we can't disprove that, we can't use the IV. + if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { + LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " + << PN << "\n"); + continue; + } + + // The candidate IV may become poison on the last iteration. If this + // value is not branched on, this is a well defined program. We're + // about to add a new use to this IV, and we have to ensure we don't + // insert UB which didn't previously exist. + bool MustDropPoisonLocal = false; + Instruction *PostIncV = + cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch)); + if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), + &DT)) { + LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" + << PN << "\n"); + + // If this is a complex recurrance with multiple instructions computing + // the backedge value, we might need to strip poison flags from all of + // them. + if (PostIncV->getOperand(0) != &PN) + continue; + + // In order to perform the transform, we need to drop the poison generating + // flags on this instruction (if any). + MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); + } + + // We pick the last legal alternate IV. 
We could expore choosing an optimal + // alternate IV if we had a decent heuristic to do so. + ToHelpFold = &PN; + TermValueS = TermValueSLocal; + MustDropPoison = MustDropPoisonLocal; } LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() @@ -6757,7 +6888,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, if (!ToFold || !ToHelpFold) return std::nullopt; - return std::make_tuple(ToFold, ToHelpFold, TermValueS); + return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); } static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, @@ -6820,7 +6951,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, if (AllowTerminatingConditionFoldingAfterLSR) { if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) { - auto [ToFold, ToHelpFold, TermValueS] = *Opt; + auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; Changed = true; NumTermFold++; @@ -6838,6 +6969,10 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, (void)StartValue; Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); + // See comment in canFoldTermCondOfLoop on why this is sufficient. + if (MustDrop) + cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags(); + // SCEVExpander for both use in preheader and latch const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); @@ -6859,11 +6994,12 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); - // FIXME: We are adding a use of an IV here without account for poison safety. - // This is incorrect. 
- Value *NewTermCond = LatchBuilder.CreateICmp( - OldTermCond->getPredicate(), LoopValue, TermValue, - "lsr_fold_term_cond.replaced_term_cond"); + Value *NewTermCond = + LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, + "lsr_fold_term_cond.replaced_term_cond"); + // Swap successors to exit loop body if IV equals to new TermValue + if (BI->getSuccessor(0) == L->getHeader()) + BI->swapSuccessors(); LLVM_DEBUG(dbgs() << "Old term-cond:\n" << *OldTermCond << "\n" diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 0ae26b494c5a..9c6e4ebf62a9 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -32,15 +32,11 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/PassRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/LoopPeel.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -460,76 +456,6 @@ static bool tryToUnrollAndJamLoop(LoopNest &LN, DominatorTree &DT, LoopInfo &LI, return DidSomething; } -namespace { - -class LoopUnrollAndJam : public LoopPass { -public: - static char ID; // Pass ID, replacement for typeid - unsigned OptLevel; - - LoopUnrollAndJam(int OptLevel = 2) : LoopPass(ID), OptLevel(OptLevel) { - initializeLoopUnrollAndJamPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - - auto *F = L->getHeader()->getParent(); - auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*F); - auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(*F); - - LoopUnrollResult Result = - tryToUnrollAndJamLoop(L, DT, LI, SE, TTI, AC, DI, ORE, OptLevel); - - if (Result == LoopUnrollResult::FullyUnrolled) - LPM.markLoopAsDeleted(*L); - - return Result != LoopUnrollResult::Unmodified; - } - - /// This transformation requires natural loop information & requires that - /// loop preheaders be inserted into the CFG... 
- void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DependenceAnalysisWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - getLoopAnalysisUsage(AU); - } -}; - -} // end anonymous namespace - -char LoopUnrollAndJam::ID = 0; - -INITIALIZE_PASS_BEGIN(LoopUnrollAndJam, "loop-unroll-and-jam", - "Unroll and Jam loops", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_END(LoopUnrollAndJam, "loop-unroll-and-jam", - "Unroll and Jam loops", false, false) - -Pass *llvm::createLoopUnrollAndJamPass(int OptLevel) { - return new LoopUnrollAndJam(OptLevel); -} - PreservedAnalyses LoopUnrollAndJamPass::run(LoopNest &LN, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 1a6065cb3f1a..335b489d3cb2 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -1124,7 +1124,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const TargetTransformInfo &TTI, AssumptionCache &AC, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, bool PreserveLCSSA, int OptLevel, - bool OnlyWhenForced, bool ForgetAllSCEV, + bool OnlyFullUnroll, bool OnlyWhenForced, bool ForgetAllSCEV, std::optional<unsigned> ProvidedCount, std::optional<unsigned> ProvidedThreshold, std::optional<bool> ProvidedAllowPartial, @@ -1133,6 +1133,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, std::optional<bool> ProvidedAllowPeeling, std::optional<bool> ProvidedAllowProfileBasedPeeling, std::optional<unsigned> ProvidedFullUnrollMaxCount) { + LLVM_DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() << "] Loop %" << L->getHeader()->getName() << "\n"); @@ -1304,6 +1305,13 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, return LoopUnrollResult::Unmodified; } + // Do not attempt partial/runtime unrolling in FullLoopUnrolling + if (OnlyFullUnroll && !(UP.Count >= MaxTripCount)) { + LLVM_DEBUG( + dbgs() << "Not attempting partial/runtime unroll in FullLoopUnroll.\n"); + return LoopUnrollResult::Unmodified; + } + // At this point, UP.Runtime indicates that run-time unrolling is allowed. // However, we only want to actually perform it if we don't know the trip // count and the unroll count doesn't divide the known trip multiple. 
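// The OnlyFullUnroll flag threaded through tryToUnrollLoop above adds an early
// exit so that a full-unroll-only invocation never degrades into a partial or
// runtime unroll. A standalone sketch of that gate; the helper below is
// illustrative only, and the real check sits inside tryToUnrollLoop. For
// example, with a computed unroll count of 4 against a known max trip count
// of 8, the full-unroll-only run reports the loop as unmodified and leaves
// partial/runtime unrolling to a later, unrestricted unroll invocation.
static bool passesFullUnrollGate(bool OnlyFullUnroll, unsigned Count,
                                 unsigned MaxTripCount) {
  // Mirrors: if (OnlyFullUnroll && !(UP.Count >= MaxTripCount)) -> Unmodified.
  if (OnlyFullUnroll && !(Count >= MaxTripCount))
    return false; // e.g. Count = 4, MaxTripCount = 8: bail out
  return true;
}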
@@ -1420,10 +1428,10 @@ public: LoopUnrollResult Result = tryToUnrollLoop( L, DT, LI, SE, TTI, AC, ORE, nullptr, nullptr, PreserveLCSSA, OptLevel, - OnlyWhenForced, ForgetAllSCEV, ProvidedCount, ProvidedThreshold, - ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, - ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling, - ProvidedFullUnrollMaxCount); + /*OnlyFullUnroll*/ false, OnlyWhenForced, ForgetAllSCEV, ProvidedCount, + ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime, + ProvidedUpperBound, ProvidedAllowPeeling, + ProvidedAllowProfileBasedPeeling, ProvidedFullUnrollMaxCount); if (Result == LoopUnrollResult::FullyUnrolled) LPM.markLoopAsDeleted(*L); @@ -1469,12 +1477,6 @@ Pass *llvm::createLoopUnrollPass(int OptLevel, bool OnlyWhenForced, AllowPeeling == -1 ? std::nullopt : std::optional<bool>(AllowPeeling)); } -Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced, - bool ForgetAllSCEV) { - return createLoopUnrollPass(OptLevel, OnlyWhenForced, ForgetAllSCEV, -1, -1, - 0, 0, 0, 1); -} - PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &Updater) { @@ -1497,8 +1499,8 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, ORE, /*BFI*/ nullptr, /*PSI*/ nullptr, - /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced, - ForgetSCEV, /*Count*/ std::nullopt, + /*PreserveLCSSA*/ true, OptLevel, /*OnlyFullUnroll*/ true, + OnlyWhenForced, ForgetSCEV, /*Count*/ std::nullopt, /*Threshold*/ std::nullopt, /*AllowPartial*/ false, /*Runtime*/ false, /*UpperBound*/ false, /*AllowPeeling*/ true, @@ -1623,8 +1625,9 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, // flavors of unrolling during construction time (by setting UnrollOpts). LoopUnrollResult Result = tryToUnrollLoop( &L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI, - /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, UnrollOpts.OnlyWhenForced, - UnrollOpts.ForgetSCEV, /*Count*/ std::nullopt, + /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, /*OnlyFullUnroll*/ false, + UnrollOpts.OnlyWhenForced, UnrollOpts.ForgetSCEV, + /*Count*/ std::nullopt, /*Threshold*/ std::nullopt, UnrollOpts.AllowPartial, UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, LocalAllowPeeling, UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount); @@ -1651,7 +1654,7 @@ void LoopUnrollPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<LoopUnrollPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (UnrollOpts.AllowPartial != std::nullopt) OS << (*UnrollOpts.AllowPartial ? "" : "no-") << "partial;"; if (UnrollOpts.AllowPeeling != std::nullopt) @@ -1664,7 +1667,7 @@ void LoopUnrollPass::printPipeline( OS << (*UnrollOpts.AllowProfileBasedPeeling ? 
"" : "no-") << "profile-peeling;"; if (UnrollOpts.FullUnrollMaxCount != std::nullopt) - OS << "full-unroll-max=" << UnrollOpts.FullUnrollMaxCount << ";"; - OS << "O" << UnrollOpts.OptLevel; - OS << ">"; + OS << "full-unroll-max=" << UnrollOpts.FullUnrollMaxCount << ';'; + OS << 'O' << UnrollOpts.OptLevel; + OS << '>'; } diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 848be25a2fe0..13e06c79d0d7 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -77,13 +77,10 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" @@ -113,33 +110,6 @@ static cl::opt<unsigned> LVLoopDepthThreshold( namespace { -struct LoopVersioningLICMLegacyPass : public LoopPass { - static char ID; - - LoopVersioningLICMLegacyPass() : LoopPass(ID) { - initializeLoopVersioningLICMLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override; - - StringRef getPassName() const override { return "Loop Versioning for LICM"; } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequiredID(LCSSAID); - AU.addRequired<LoopAccessLegacyAnalysis>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - } -}; - struct LoopVersioningLICM { // We don't explicitly pass in LoopAccessInfo to the constructor since the // loop versioning might return early due to instructions that are not safe @@ -563,21 +533,6 @@ void LoopVersioningLICM::setNoAliasToLoop(Loop *VerLoop) { } } -bool LoopVersioningLICMLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipLoop(L)) - return false; - - AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - OptimizationRemarkEmitter *ORE = - &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); - - return LoopVersioningLICM(AA, SE, ORE, LAIs, LI, L).run(DT); -} - bool LoopVersioningLICM::run(DominatorTree *DT) { // Do not do the transformation if disabled by metadata. 
if (hasLICMVersioningTransformation(CurLoop) & TM_Disable) @@ -611,26 +566,6 @@ bool LoopVersioningLICM::run(DominatorTree *DT) { return Changed; } -char LoopVersioningLICMLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(LoopVersioningLICMLegacyPass, "loop-versioning-licm", - "Loop Versioning For LICM", false, false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_END(LoopVersioningLICMLegacyPass, "loop-versioning-licm", - "Loop Versioning For LICM", false, false) - -Pass *llvm::createLoopVersioningLICMPass() { - return new LoopVersioningLICMLegacyPass(); -} - namespace llvm { PreservedAnalyses LoopVersioningLICMPass::run(Loop &L, LoopAnalysisManager &AM, diff --git a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp index ef22b0401b1b..b167120a906d 100644 --- a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <optional> @@ -136,10 +137,12 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo &TLI, continue; case Intrinsic::is_constant: NewValue = lowerIsConstantIntrinsic(II); + LLVM_DEBUG(dbgs() << "Folding " << *II << " to " << *NewValue << "\n"); IsConstantIntrinsicsHandled++; break; case Intrinsic::objectsize: NewValue = lowerObjectSizeCall(II, DL, &TLI, true); + LLVM_DEBUG(dbgs() << "Folding " << *II << " to " << *NewValue << "\n"); ObjectSizeIntrinsicsHandled++; break; } diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 17594b98c5bc..f46ea6a20afa 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -72,6 +72,11 @@ static cl::opt<bool> AllowContractEnabled( cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error.")); +static cl::opt<bool> + VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, + cl::desc("Enable/disable matrix shape verification."), + cl::init(false)); + enum class MatrixLayoutTy { ColumnMajor, RowMajor }; static cl::opt<MatrixLayoutTy> MatrixLayout( @@ -267,7 +272,7 @@ class LowerMatrixIntrinsics { unsigned D = isColumnMajor() ? NumColumns : NumRows; for (unsigned J = 0; J < D; ++J) - addVector(UndefValue::get(FixedVectorType::get( + addVector(PoisonValue::get(FixedVectorType::get( EltTy, isColumnMajor() ? 
NumRows : NumColumns))); } @@ -535,6 +540,15 @@ public: auto SIter = ShapeMap.find(V); if (SIter != ShapeMap.end()) { + if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows || + SIter->second.NumColumns != Shape.NumColumns)) { + errs() << "Conflicting shapes (" << SIter->second.NumRows << "x" + << SIter->second.NumColumns << " vs " << Shape.NumRows << "x" + << Shape.NumColumns << ") for " << *V << "\n"; + report_fatal_error( + "Matrix shape verification failed, compilation aborted!"); + } + LLVM_DEBUG(dbgs() << " not overriding existing shape: " << SIter->second.NumRows << " " << SIter->second.NumColumns << " for " << *V << "\n"); @@ -838,10 +852,13 @@ public: auto NewInst = distributeTransposes( TAMA, {R, C}, TAMB, {R, C}, Builder, [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { - auto *FAdd = - cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd")); - setShapeInfo(FAdd, Shape0); - return FAdd; + bool IsFP = I.getType()->isFPOrFPVectorTy(); + auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd") + : LocalBuilder.CreateAdd(T0, T1, "madd"); + + auto *Result = cast<Instruction>(Add); + setShapeInfo(Result, Shape0); + return Result; }); updateShapeAndReplaceAllUsesWith(I, NewInst); eraseFromParentAndMove(&I, II, BB); @@ -978,13 +995,18 @@ public: MatrixInsts.push_back(&I); } - // Second, try to fuse candidates. + // Second, try to lower any dot products SmallPtrSet<Instruction *, 16> FusedInsts; for (CallInst *CI : MaybeFusableInsts) + lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); + + // Third, try to fuse candidates. + for (CallInst *CI : MaybeFusableInsts) LowerMatrixMultiplyFused(CI, FusedInsts); + Changed = !FusedInsts.empty(); - // Third, lower remaining instructions with shape information. + // Fourth, lower remaining instructions with shape information. for (Instruction *Inst : MatrixInsts) { if (FusedInsts.count(Inst)) continue; @@ -1311,6 +1333,165 @@ public: } } + /// Special case for MatMul lowering. Prevents scalar loads of row-major + /// vectors Lowers to vector reduction add instead of sequential add if + /// reassocation is enabled. + void lowerDotProduct(CallInst *MatMul, + SmallPtrSet<Instruction *, 16> &FusedInsts, + FastMathFlags FMF) { + if (FusedInsts.contains(MatMul) || + MatrixLayout != MatrixLayoutTy::ColumnMajor) + return; + ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); + ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + + if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product + return; + + Value *LHS = MatMul->getArgOperand(0); + Value *RHS = MatMul->getArgOperand(1); + + Type *ElementType = cast<VectorType>(LHS->getType())->getElementType(); + bool IsIntVec = ElementType->isIntegerTy(); + + // Floating point reductions require reassocation. + if (!IsIntVec && !FMF.allowReassoc()) + return; + + auto CanBeFlattened = [this](Value *Op) { + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) + return true; + return match( + Op, m_OneUse(m_CombineOr( + m_Load(m_Value()), + m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(), + m_Intrinsic<Intrinsic::matrix_column_major_load>( + m_Value(), m_SpecificInt(1)))))); + }; + // Returns the cost benefit of using \p Op with the dot product lowering. If + // the returned cost is < 0, the argument is cheaper to use in the + // dot-product lowering. 
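+  // A minimal scalar sketch of what the dot-product lowering computes,
+  // assuming a 1xN * Nx1 floating-point multiply with reassociation enabled;
+  // the names below are illustrative and not taken from this patch. Rather
+  // than the sequential chain a[0]*b[0] + a[1]*b[1] + ..., the lowering emits
+  // one elementwise multiply followed by a single llvm.vector.reduce.fadd,
+  // which is only legal because reassoc allows the additions to be reordered;
+  // the cost comparison below decides whether that beats the sequential form.
+  //
+  //   static double dotProductReduceSketch(const double *A, const double *B,
+  //                                        unsigned N) {
+  //     double Acc = 0.0;          // the 0.0 start value of the reduction
+  //     for (unsigned I = 0; I < N; ++I)
+  //       Acc += A[I] * B[I];      // fmul feeding a reassociable fadd reduce
+  //     return Acc;
+  //   }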
+ auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) { + if (!isa<Instruction>(Op)) + return InstructionCost(0); + + FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType()); + Type *EltTy = VecTy->getElementType(); + + if (!CanBeFlattened(Op)) { + InstructionCost EmbedCost(0); + // Roughly estimate the cost for embedding the columns into a vector. + for (unsigned I = 1; I < N; ++I) + EmbedCost -= + TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), + std::nullopt, TTI::TCK_RecipThroughput); + return EmbedCost; + } + + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { + InstructionCost OriginalCost = + TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(), + EltTy) * + N; + InstructionCost NewCost = TTI.getArithmeticInstrCost( + cast<Instruction>(Op)->getOpcode(), VecTy); + return NewCost - OriginalCost; + } + + if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) { + // The transpose can be skipped for the dot product lowering, roughly + // estimate the savings as the cost of embedding the columns in a + // vector. + InstructionCost EmbedCost(0); + for (unsigned I = 1; I < N; ++I) + EmbedCost += + TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), + std::nullopt, TTI::TCK_RecipThroughput); + return EmbedCost; + } + + // Costs for loads. + if (N == 1) + return InstructionCost(0); + + return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - + N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); + }; + auto LHSCost = GetCostForArg(LHS, LShape.NumColumns); + + // We compare the costs of a vector.reduce.add to sequential add. + int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; + int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; + InstructionCost ReductionCost = + TTI.getArithmeticReductionCost( + AddOpCode, cast<VectorType>(LHS->getType()), + IsIntVec ? std::nullopt : std::optional(FMF)) + + TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); + InstructionCost SequentialAddCost = + TTI.getArithmeticInstrCost(AddOpCode, ElementType) * + (LShape.NumColumns - 1) + + TTI.getArithmeticInstrCost(MulOpCode, ElementType) * + (LShape.NumColumns); + if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) + return; + + FusedInsts.insert(MatMul); + IRBuilder<> Builder(MatMul); + auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, + this](Value *Op) -> Value * { + // Matmul must be the only user of loads because we don't use LowerLoad + // for row vectors (LowerLoad results in scalar loads and shufflevectors + // instead of single vector load). + if (!CanBeFlattened(Op)) + return Op; + + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { + ShapeMap[Op] = ShapeMap[Op].t(); + return Op; + } + + FusedInsts.insert(cast<Instruction>(Op)); + // If vector uses the builtin load, lower to a LoadInst + Value *Arg; + if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>( + m_Value(Arg)))) { + auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); + Op->replaceAllUsesWith(NewLoad); + cast<Instruction>(Op)->eraseFromParent(); + return NewLoad; + } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(Arg)))) { + ToRemove.push_back(cast<Instruction>(Op)); + return Arg; + } + + return Op; + }; + LHS = FlattenArg(LHS); + + // Insert mul/fmul and llvm.vector.reduce.fadd + Value *Mul = + IsIntVec ? 
Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS); + + Value *Result; + if (IsIntVec) + Result = Builder.CreateAddReduce(Mul); + else { + Result = Builder.CreateFAddReduce( + ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(), + 0.0), + Mul); + cast<Instruction>(Result)->setFastMathFlags(FMF); + } + + // pack scalar back into a matrix and then replace matmul inst + Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()), + Result, uint64_t(0)); + MatMul->replaceAllUsesWith(Result); + FusedInsts.insert(MatMul); + ToRemove.push_back(MatMul); + } + /// Compute \p Result += \p A * \p B for input matrices with left-associating /// addition. /// @@ -1469,15 +1650,14 @@ public: auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); AllocaInst *Alloca = Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); - Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo()); - Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(), + Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), Load->getAlign(), LoadLoc.Size.getValue()); Builder.SetInsertPoint(Fusion, Fusion->begin()); PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); PHI->addIncoming(Load->getPointerOperand(), Check0); PHI->addIncoming(Load->getPointerOperand(), Check1); - PHI->addIncoming(BC, Copy); + PHI->addIncoming(Alloca, Copy); // Adjust DT. DTUpdates.push_back({DT->Insert, Check0, Check1}); @@ -2397,99 +2577,8 @@ void LowerMatrixIntrinsicsPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (Minimal) OS << "minimal"; - OS << ">"; -} - -namespace { - -class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { -public: - static char ID; - - LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) { - initializeLowerMatrixIntrinsicsLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE); - bool C = LMT.Visit(); - return C; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - } -}; -} // namespace - -static const char pass_name[] = "Lower the matrix intrinsics"; -char LowerMatrixIntrinsicsLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, - false, false) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, - false, false) - -Pass *llvm::createLowerMatrixIntrinsicsPass() { - 
return new LowerMatrixIntrinsicsLegacyPass(); -} - -namespace { - -/// A lightweight version of the matrix lowering pass that only requires TTI. -/// Advanced features that require DT, AA or ORE like tiling are disabled. This -/// is used to lower matrix intrinsics if the main lowering pass is not run, for -/// example with -O0. -class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass { -public: - static char ID; - - LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) { - initializeLowerMatrixIntrinsicsMinimalLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr); - bool C = LMT.Visit(); - return C; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.setPreservesCFG(); - } -}; -} // namespace - -static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)"; -char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass, - "lower-matrix-intrinsics-minimal", pass_name_minimal, - false, false) -INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass, - "lower-matrix-intrinsics-minimal", pass_name_minimal, false, - false) - -Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() { - return new LowerMatrixIntrinsicsMinimalLegacyPass(); + OS << '>'; } diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 64846484f936..68642a01b37c 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -46,13 +46,10 @@ #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> @@ -72,6 +69,7 @@ STATISTIC(NumMemSetInfer, "Number of memsets inferred"); STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy"); STATISTIC(NumCpyToSet, "Number of memcpys converted to memset"); STATISTIC(NumCallSlot, "Number of call slot optimizations performed"); +STATISTIC(NumStackMove, "Number of stack-move optimizations performed"); namespace { @@ -255,54 +253,6 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, // MemCpyOptLegacyPass Pass //===----------------------------------------------------------------------===// -namespace { - -class MemCpyOptLegacyPass : public FunctionPass { - MemCpyOptPass Impl; - -public: - static char ID; // Pass identification, replacement for typeid - - MemCpyOptLegacyPass() : FunctionPass(ID) { - initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - -private: - // This transformation requires dominator postdominator info - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - 
AU.addPreserved<AAResultsWrapperPass>(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } -}; - -} // end anonymous namespace - -char MemCpyOptLegacyPass::ID = 0; - -/// The public interface to this file... -FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); } - -INITIALIZE_PASS_BEGIN(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization", - false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization", - false, false) - // Check that V is either not accessible by the caller, or unwinding cannot // occur between Start and End. static bool mayBeVisibleThroughUnwinding(Value *V, Instruction *Start, @@ -463,7 +413,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, // Check to see if this store is to a constant offset from the start ptr. std::optional<int64_t> Offset = - isPointerOffset(StartPtr, NextStore->getPointerOperand(), DL); + NextStore->getPointerOperand()->getPointerOffsetFrom(StartPtr, DL); if (!Offset) break; @@ -477,7 +427,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, // Check to see if this store is to a constant offset from the start ptr. std::optional<int64_t> Offset = - isPointerOffset(StartPtr, MSI->getDest(), DL); + MSI->getDest()->getPointerOffsetFrom(StartPtr, DL); if (!Offset) break; @@ -781,6 +731,23 @@ bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI, return true; } + // If this is a load-store pair from a stack slot to a stack slot, we + // might be able to perform the stack-move optimization just as we do for + // memcpys from an alloca to an alloca. + if (auto *DestAlloca = dyn_cast<AllocaInst>(SI->getPointerOperand())) { + if (auto *SrcAlloca = dyn_cast<AllocaInst>(LI->getPointerOperand())) { + if (performStackMoveOptzn(LI, SI, DestAlloca, SrcAlloca, + DL.getTypeStoreSize(T), BAA)) { + // Avoid invalidating the iterator. + BBI = SI->getNextNonDebugInstruction()->getIterator(); + eraseInstruction(SI); + eraseInstruction(LI); + ++NumMemCpyInstr; + return true; + } + } + } + return false; } @@ -1200,8 +1167,14 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // still want to eliminate the intermediate value, but we have to generate a // memmove instead of memcpy. bool UseMemMove = false; - if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(MDep)))) + if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(MDep)))) { + // Don't convert llvm.memcpy.inline into memmove because memmove can be + // lowered as a call, and that is not allowed for llvm.memcpy.inline (and + // there is no inline version of llvm.memmove) + if (isa<MemCpyInlineInst>(M)) + return false; UseMemMove = true; + } // If all checks passed, then we can transform M. LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy->memcpy src:\n" @@ -1246,13 +1219,18 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, /// In other words, transform: /// \code /// memset(dst, c, dst_size); +/// ... /// memcpy(dst, src, src_size); /// \endcode /// into: /// \code -/// memcpy(dst, src, src_size); +/// ... /// memset(dst + src_size, c, dst_size <= src_size ? 
0 : dst_size - src_size); +/// memcpy(dst, src, src_size); /// \endcode +/// +/// The memset is sunk to just before the memcpy to ensure that src_size is +/// present when emitting the simplified memset. bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet, BatchAAResults &BAA) { @@ -1300,6 +1278,15 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, IRBuilder<> Builder(MemCpy); + // Preserve the debug location of the old memset for the code emitted here + // related to the new memset. This is correct according to the rules in + // https://llvm.org/docs/HowToUpdateDebugInfo.html about "when to preserve an + // instruction location", given that we move the memset within the basic + // block. + assert(MemSet->getParent() == MemCpy->getParent() && + "Preserving debug location based on moving memset within BB."); + Builder.SetCurrentDebugLocation(MemSet->getDebugLoc()); + // If the sizes have different types, zext the smaller one. if (DestSize->getType() != SrcSize->getType()) { if (DestSize->getType()->getIntegerBitWidth() > @@ -1323,9 +1310,8 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) && "MemCpy must be a MemoryDef"); - // The new memset is inserted after the memcpy, but it is known that its - // defining access is the memset about to be removed which immediately - // precedes the memcpy. + // The new memset is inserted before the memcpy, and it is known that the + // memcpy's defining access is the memset about to be removed. auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); auto *NewAccess = MSSAU->createMemoryAccessBefore( @@ -1440,6 +1426,217 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, return true; } +// Attempts to optimize the pattern whereby memory is copied from an alloca to +// another alloca, where the two allocas don't have conflicting mod/ref. If +// successful, the two allocas can be merged into one and the transfer can be +// deleted. This pattern is generated frequently in Rust, due to the ubiquity of +// move operations in that language. +// +// Once we determine that the optimization is safe to perform, we replace all +// uses of the destination alloca with the source alloca. We also "shrink wrap" +// the lifetime markers of the single merged alloca to before the first use +// and after the last use. Note that the "shrink wrapping" procedure is a safe +// transformation only because we restrict the scope of this optimization to +// allocas that aren't captured. +bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, + AllocaInst *DestAlloca, + AllocaInst *SrcAlloca, uint64_t Size, + BatchAAResults &BAA) { + LLVM_DEBUG(dbgs() << "Stack Move: Attempting to optimize:\n" + << *Store << "\n"); + + // Make sure the two allocas are in the same address space. + if (SrcAlloca->getAddressSpace() != DestAlloca->getAddressSpace()) { + LLVM_DEBUG(dbgs() << "Stack Move: Address space mismatch\n"); + return false; + } + + // 1. Check that copy is full. Calculate the static size of the allocas to be + // merged, bail out if we can't. 
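+  // A minimal source-level sketch (illustrative names, not from the patch) of
+  // the pattern this optimization targets: two fixed-size stack slots where
+  // one is initialized by a full-size copy of the other and neither escapes.
+  // After the transform both locals share a single alloca and the memcpy is
+  // deleted; the size checks right below ensure the copy really covers both
+  // allocas.
+  //
+  //   struct BigValue { char Bytes[128]; };
+  //   int consumeSketch(const BigValue &); // assumed helper, does not capture
+  //   int moveLikeCopySketch() {
+  //     BigValue Src = {};                 // alloca #1
+  //     BigValue Dest = Src;               // alloca #2 + full-size llvm.memcpy
+  //     return consumeSketch(Dest);        // afterwards reads Src's slot
+  //   }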
+ const DataLayout &DL = DestAlloca->getModule()->getDataLayout(); + std::optional<TypeSize> SrcSize = SrcAlloca->getAllocationSize(DL); + if (!SrcSize || SrcSize->isScalable() || Size != SrcSize->getFixedValue()) { + LLVM_DEBUG(dbgs() << "Stack Move: Source alloca size mismatch\n"); + return false; + } + std::optional<TypeSize> DestSize = DestAlloca->getAllocationSize(DL); + if (!DestSize || DestSize->isScalable() || + Size != DestSize->getFixedValue()) { + LLVM_DEBUG(dbgs() << "Stack Move: Destination alloca size mismatch\n"); + return false; + } + + // 2-1. Check that src and dest are static allocas, which are not affected by + // stacksave/stackrestore. + if (!SrcAlloca->isStaticAlloca() || !DestAlloca->isStaticAlloca() || + SrcAlloca->getParent() != Load->getParent() || + SrcAlloca->getParent() != Store->getParent()) + return false; + + // 2-2. Check that src and dest are never captured, unescaped allocas. Also + // collect lifetime markers first/last users in order to shrink wrap the + // lifetimes, and instructions with noalias metadata to remove them. + + SmallVector<Instruction *, 4> LifetimeMarkers; + Instruction *FirstUser = nullptr, *LastUser = nullptr; + SmallSet<Instruction *, 4> NoAliasInstrs; + + // Recursively track the user and check whether modified alias exist. + auto IsDereferenceableOrNull = [](Value *V, const DataLayout &DL) -> bool { + bool CanBeNull, CanBeFreed; + return V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed); + }; + + auto CaptureTrackingWithModRef = + [&](Instruction *AI, + function_ref<bool(Instruction *)> ModRefCallback) -> bool { + SmallVector<Instruction *, 8> Worklist; + Worklist.push_back(AI); + unsigned MaxUsesToExplore = getDefaultMaxUsesToExploreForCaptureTracking(); + Worklist.reserve(MaxUsesToExplore); + SmallSet<const Use *, 20> Visited; + while (!Worklist.empty()) { + Instruction *I = Worklist.back(); + Worklist.pop_back(); + for (const Use &U : I->uses()) { + if (Visited.size() >= MaxUsesToExplore) { + LLVM_DEBUG( + dbgs() + << "Stack Move: Exceeded max uses to see ModRef, bailing\n"); + return false; + } + if (!Visited.insert(&U).second) + continue; + switch (DetermineUseCaptureKind(U, IsDereferenceableOrNull)) { + case UseCaptureKind::MAY_CAPTURE: + return false; + case UseCaptureKind::PASSTHROUGH: + // Instructions cannot have non-instruction users. + Worklist.push_back(cast<Instruction>(U.getUser())); + continue; + case UseCaptureKind::NO_CAPTURE: { + auto *UI = cast<Instruction>(U.getUser()); + if (DestAlloca->getParent() != UI->getParent()) + return false; + if (!FirstUser || UI->comesBefore(FirstUser)) + FirstUser = UI; + if (!LastUser || LastUser->comesBefore(UI)) + LastUser = UI; + if (UI->isLifetimeStartOrEnd()) { + // We note the locations of these intrinsic calls so that we can + // delete them later if the optimization succeeds, this is safe + // since both llvm.lifetime.start and llvm.lifetime.end intrinsics + // conceptually fill all the bytes of the alloca with an undefined + // value. + int64_t Size = cast<ConstantInt>(UI->getOperand(0))->getSExtValue(); + if (Size < 0 || Size == DestSize) { + LifetimeMarkers.push_back(UI); + continue; + } + } + if (UI->hasMetadata(LLVMContext::MD_noalias)) + NoAliasInstrs.insert(UI); + if (!ModRefCallback(UI)) + return false; + } + } + } + } + return true; + }; + + // 3. Check that dest has no Mod/Ref, except full size lifetime intrinsics, + // from the alloca to the Store. 
+ ModRefInfo DestModRef = ModRefInfo::NoModRef; + MemoryLocation DestLoc(DestAlloca, LocationSize::precise(Size)); + auto DestModRefCallback = [&](Instruction *UI) -> bool { + // We don't care about the store itself. + if (UI == Store) + return true; + ModRefInfo Res = BAA.getModRefInfo(UI, DestLoc); + // FIXME: For multi-BB cases, we need to see reachability from it to + // store. + // Bailout if Dest may have any ModRef before Store. + if (UI->comesBefore(Store) && isModOrRefSet(Res)) + return false; + DestModRef |= BAA.getModRefInfo(UI, DestLoc); + + return true; + }; + + if (!CaptureTrackingWithModRef(DestAlloca, DestModRefCallback)) + return false; + + // 3. Check that, from after the Load to the end of the BB, + // 3-1. if the dest has any Mod, src has no Ref, and + // 3-2. if the dest has any Ref, src has no Mod except full-sized lifetimes. + MemoryLocation SrcLoc(SrcAlloca, LocationSize::precise(Size)); + + auto SrcModRefCallback = [&](Instruction *UI) -> bool { + // Any ModRef before Load doesn't matter, also Load and Store can be + // ignored. + if (UI->comesBefore(Load) || UI == Load || UI == Store) + return true; + ModRefInfo Res = BAA.getModRefInfo(UI, SrcLoc); + if ((isModSet(DestModRef) && isRefSet(Res)) || + (isRefSet(DestModRef) && isModSet(Res))) + return false; + + return true; + }; + + if (!CaptureTrackingWithModRef(SrcAlloca, SrcModRefCallback)) + return false; + + // We can do the transformation. First, align the allocas appropriately. + SrcAlloca->setAlignment( + std::max(SrcAlloca->getAlign(), DestAlloca->getAlign())); + + // Merge the two allocas. + DestAlloca->replaceAllUsesWith(SrcAlloca); + eraseInstruction(DestAlloca); + + // Drop metadata on the source alloca. + SrcAlloca->dropUnknownNonDebugMetadata(); + + // Do "shrink wrap" the lifetimes, if the original lifetime intrinsics exists. + if (!LifetimeMarkers.empty()) { + LLVMContext &C = SrcAlloca->getContext(); + IRBuilder<> Builder(C); + + ConstantInt *AllocaSize = ConstantInt::get(Type::getInt64Ty(C), Size); + // Create a new lifetime start marker before the first user of src or alloca + // users. + Builder.SetInsertPoint(FirstUser->getParent(), FirstUser->getIterator()); + Builder.CreateLifetimeStart(SrcAlloca, AllocaSize); + + // Create a new lifetime end marker after the last user of src or alloca + // users. + // FIXME: If the last user is the terminator for the bb, we can insert + // lifetime.end marker to the immidiate post-dominator, but currently do + // nothing. + if (!LastUser->isTerminator()) { + Builder.SetInsertPoint(LastUser->getParent(), ++LastUser->getIterator()); + Builder.CreateLifetimeEnd(SrcAlloca, AllocaSize); + } + + // Remove all other lifetime markers. + for (Instruction *I : LifetimeMarkers) + eraseInstruction(I); + } + + // As this transformation can cause memory accesses that didn't previously + // alias to begin to alias one another, we remove !noalias metadata from any + // uses of either alloca. This is conservative, but more precision doesn't + // seem worthwhile right now. + for (Instruction *I : NoAliasInstrs) + I->setMetadata(LLVMContext::MD_noalias, nullptr); + + LLVM_DEBUG(dbgs() << "Stack Move: Performed staack-move optimization\n"); + NumStackMove++; + return true; +} + /// Perform simplification of memcpy's. 
If we have memcpy A /// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite /// B to be a memcpy from X to Z (or potentially a memmove, depending on @@ -1484,8 +1681,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc, BAA); // Try to turn a partially redundant memset + memcpy into - // memcpy + smaller memset. We don't need the memcpy size for this. - // The memcpy most post-dom the memset, so limit this to the same basic + // smaller memset + memcpy. We don't need the memcpy size for this. + // The memcpy must post-dom the memset, so limit this to the same basic // block. A non-local generalization is likely not worthwhile. if (auto *MD = dyn_cast<MemoryDef>(DestClobber)) if (auto *MDep = dyn_cast_or_null<MemSetInst>(MD->getMemoryInst())) @@ -1496,13 +1693,14 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { MemoryAccess *SrcClobber = MSSA->getWalker()->getClobberingMemoryAccess( AnyClobber, MemoryLocation::getForSource(M), BAA); - // There are four possible optimizations we can do for memcpy: + // There are five possible optimizations we can do for memcpy: // a) memcpy-memcpy xform which exposes redundance for DSE. // b) call-memcpy xform for return slot optimization. // c) memcpy from freshly alloca'd space or space that has just started // its lifetime copies undefined data, and we can therefore eliminate // the memcpy in favor of the data that was already at the destination. // d) memcpy from a just-memset'd source can be turned into memset. + // e) elimination of memcpy via stack-move optimization. if (auto *MD = dyn_cast<MemoryDef>(SrcClobber)) { if (Instruction *MI = MD->getMemoryInst()) { if (auto *CopySize = dyn_cast<ConstantInt>(M->getLength())) { @@ -1521,7 +1719,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { } } if (auto *MDep = dyn_cast<MemCpyInst>(MI)) - return processMemCpyMemCpyDependence(M, MDep, BAA); + if (processMemCpyMemCpyDependence(M, MDep, BAA)) + return true; if (auto *MDep = dyn_cast<MemSetInst>(MI)) { if (performMemCpyToMemSetOptzn(M, MDep, BAA)) { LLVM_DEBUG(dbgs() << "Converted memcpy to memset\n"); @@ -1540,6 +1739,27 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { } } + // If the transfer is from a stack slot to a stack slot, then we may be able + // to perform the stack-move optimization. See the comments in + // performStackMoveOptzn() for more details. + auto *DestAlloca = dyn_cast<AllocaInst>(M->getDest()); + if (!DestAlloca) + return false; + auto *SrcAlloca = dyn_cast<AllocaInst>(M->getSource()); + if (!SrcAlloca) + return false; + ConstantInt *Len = dyn_cast<ConstantInt>(M->getLength()); + if (Len == nullptr) + return false; + if (performStackMoveOptzn(M, M, DestAlloca, SrcAlloca, Len->getZExtValue(), + BAA)) { + // Avoid invalidating the iterator. + BBI = M->getNextNonDebugInstruction()->getIterator(); + eraseInstruction(M); + ++NumMemCpyInstr; + return true; + } + return false; } @@ -1623,24 +1843,110 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { // foo(*a) // It would be invalid to transform the second memcpy into foo(*b). 
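   // A minimal C++ rendering (invented names) of the hazard described in the
   // comment above and ruled out by the writtenBetween() check that follows:
   // once the source of the producing memcpy is modified before the call,
   // forwarding it into the byval argument would no longer pass the snapshot
   // the caller originally took.
   //
   //   struct SnapshotSketch { int X; };
   //   void fooSketch(SnapshotSketch ByValCopy); // by-value models byval
   //   void hazardSketch(SnapshotSketch *A, SnapshotSketch *B) {
   //     *A = *B;        // memcpy(a <- b)
   //     B->X = 42;      // *b = 42 clobbers the source
   //     fooSketch(*A);  // must NOT be rewritten to fooSketch(*B)
   //   }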
if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep), - MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB))) + MSSA->getMemoryAccess(MDep), CallAccess)) return false; - Value *TmpCast = MDep->getSource(); - if (MDep->getSource()->getType() != ByValArg->getType()) { - BitCastInst *TmpBitCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), - "tmpcast", &CB); - // Set the tmpcast's DebugLoc to MDep's - TmpBitCast->setDebugLoc(MDep->getDebugLoc()); - TmpCast = TmpBitCast; - } - LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n" << " " << *MDep << "\n" << " " << CB << "\n"); // Otherwise we're good! Update the byval argument. - CB.setArgOperand(ArgNo, TmpCast); + CB.setArgOperand(ArgNo, MDep->getSource()); + ++NumMemCpyInstr; + return true; +} + +/// This is called on memcpy dest pointer arguments attributed as immutable +/// during call. Try to use memcpy source directly if all of the following +/// conditions are satisfied. +/// 1. The memcpy dst is neither modified during the call nor captured by the +/// call. (if readonly, noalias, nocapture attributes on call-site.) +/// 2. The memcpy dst is an alloca with known alignment & size. +/// 2-1. The memcpy length == the alloca size which ensures that the new +/// pointer is dereferenceable for the required range +/// 2-2. The src pointer has alignment >= the alloca alignment or can be +/// enforced so. +/// 3. The memcpy dst and src is not modified between the memcpy and the call. +/// (if MSSA clobber check is safe.) +/// 4. The memcpy src is not modified during the call. (ModRef check shows no +/// Mod.) +bool MemCpyOptPass::processImmutArgument(CallBase &CB, unsigned ArgNo) { + // 1. Ensure passed argument is immutable during call. + if (!(CB.paramHasAttr(ArgNo, Attribute::NoAlias) && + CB.paramHasAttr(ArgNo, Attribute::NoCapture))) + return false; + const DataLayout &DL = CB.getCaller()->getParent()->getDataLayout(); + Value *ImmutArg = CB.getArgOperand(ArgNo); + + // 2. Check that arg is alloca + // TODO: Even if the arg gets back to branches, we can remove memcpy if all + // the alloca alignments can be enforced to source alignment. + auto *AI = dyn_cast<AllocaInst>(ImmutArg->stripPointerCasts()); + if (!AI) + return false; + + std::optional<TypeSize> AllocaSize = AI->getAllocationSize(DL); + // Can't handle unknown size alloca. + // (e.g. Variable Length Array, Scalable Vector) + if (!AllocaSize || AllocaSize->isScalable()) + return false; + MemoryLocation Loc(ImmutArg, LocationSize::precise(*AllocaSize)); + MemoryUseOrDef *CallAccess = MSSA->getMemoryAccess(&CB); + if (!CallAccess) + return false; + + MemCpyInst *MDep = nullptr; + BatchAAResults BAA(*AA); + MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( + CallAccess->getDefiningAccess(), Loc, BAA); + if (auto *MD = dyn_cast<MemoryDef>(Clobber)) + MDep = dyn_cast_or_null<MemCpyInst>(MD->getMemoryInst()); + + // If the immut argument isn't fed by a memcpy, ignore it. If it is fed by + // a memcpy, check that the arg equals the memcpy dest. + if (!MDep || MDep->isVolatile() || AI != MDep->getDest()) + return false; + + // The address space of the memcpy source must match the immut argument + if (MDep->getSource()->getType()->getPointerAddressSpace() != + ImmutArg->getType()->getPointerAddressSpace()) + return false; + + // 2-1. The length of the memcpy must be equal to the size of the alloca. 
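+  // A minimal source-level sketch (illustrative names and attributes, not
+  // from the patch) of the pattern processImmutArgument rewrites: Tmp is a
+  // stack temporary filled by a memcpy whose length equals the alloca size,
+  // and the callee's parameter is assumed readonly, noalias and nocapture, so
+  // the call can be fed the memcpy source directly and the copy becomes dead.
+  //
+  //   struct BlobSketch { char Data[64]; };
+  //   void inspectSketch(const BlobSketch *P); // readonly/noalias/nocapture
+  //   void forwardImmutSketch(const BlobSketch &Src) {
+  //     BlobSketch Tmp = Src;  // alloca + llvm.memcpy of sizeof(BlobSketch)
+  //     inspectSketch(&Tmp);   // after the transform: inspectSketch(&Src)
+  //   }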
+ auto *MDepLen = dyn_cast<ConstantInt>(MDep->getLength()); + if (!MDepLen || AllocaSize != MDepLen->getValue()) + return false; + + // 2-2. the memcpy source align must be larger than or equal the alloca's + // align. If not so, we check to see if we can force the source of the memcpy + // to the alignment we need. If we fail, we bail out. + Align MemDepAlign = MDep->getSourceAlign().valueOrOne(); + Align AllocaAlign = AI->getAlign(); + if (MemDepAlign < AllocaAlign && + getOrEnforceKnownAlignment(MDep->getSource(), AllocaAlign, DL, &CB, AC, + DT) < AllocaAlign) + return false; + + // 3. Verify that the source doesn't change in between the memcpy and + // the call. + // memcpy(a <- b) + // *b = 42; + // foo(*a) + // It would be invalid to transform the second memcpy into foo(*b). + if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep), + MSSA->getMemoryAccess(MDep), CallAccess)) + return false; + + // 4. The memcpy src must not be modified during the call. + if (isModSet(AA->getModRefInfo(&CB, MemoryLocation::getForSource(MDep)))) + return false; + + LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to Immut src:\n" + << " " << *MDep << "\n" + << " " << CB << "\n"); + + // Otherwise we're good! Update the immut argument. + CB.setArgOperand(ArgNo, MDep->getSource()); ++NumMemCpyInstr; return true; } @@ -1673,9 +1979,12 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) { else if (auto *M = dyn_cast<MemMoveInst>(I)) RepeatInstruction = processMemMove(M); else if (auto *CB = dyn_cast<CallBase>(I)) { - for (unsigned i = 0, e = CB->arg_size(); i != e; ++i) + for (unsigned i = 0, e = CB->arg_size(); i != e; ++i) { if (CB->isByValArgument(i)) MadeChange |= processByValArgument(*CB, i); + else if (CB->onlyReadsMemory(i)) + MadeChange |= processImmutArgument(*CB, i); + } } // Reprocess the instruction if desired. @@ -1730,17 +2039,3 @@ bool MemCpyOptPass::runImpl(Function &F, TargetLibraryInfo *TLI_, return MadeChange; } - -/// This is the main transformation entry point for a function. 
-bool MemCpyOptLegacyPass::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - - auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - - return Impl.runImpl(F, TLI, AA, AC, DT, MSSA); -} diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp index bcedb05890af..311a6435ba7c 100644 --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -42,6 +42,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/MergeICmps.h" +#include "llvm/ADT/SmallString.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" @@ -49,6 +50,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/IRBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -157,7 +159,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { return {}; } - APInt Offset = APInt(DL.getPointerTypeSizeInBits(Addr->getType()), 0); + APInt Offset = APInt(DL.getIndexTypeSizeInBits(Addr->getType()), 0); Value *Base = Addr; auto *GEP = dyn_cast<GetElementPtrInst>(Addr); if (GEP) { @@ -639,10 +641,11 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, if (Comparisons.size() == 1) { LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); - Value *const LhsLoad = - Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs); - Value *const RhsLoad = - Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs); + // Use clone to keep the metadata + Instruction *const LhsLoad = Builder.Insert(FirstCmp.Lhs().LoadI->clone()); + Instruction *const RhsLoad = Builder.Insert(FirstCmp.Rhs().LoadI->clone()); + LhsLoad->replaceUsesOfWith(LhsLoad->getOperand(0), Lhs); + RhsLoad->replaceUsesOfWith(RhsLoad->getOperand(0), Rhs); // There are no blocks to merge, just do the comparison. 
IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); } else { diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 62e75d98448c..6c5453831ade 100644 --- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -78,6 +78,7 @@ #include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" @@ -191,11 +192,16 @@ StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1, MemoryLocation Loc0 = MemoryLocation::get(Store0); MemoryLocation Loc1 = MemoryLocation::get(Store1); - if (AA->isMustAlias(Loc0, Loc1) && Store0->isSameOperationAs(Store1) && + + if (AA->isMustAlias(Loc0, Loc1) && !isStoreSinkBarrierInRange(*Store1->getNextNode(), BB1->back(), Loc1) && - !isStoreSinkBarrierInRange(*Store0->getNextNode(), BB0->back(), Loc0)) { + !isStoreSinkBarrierInRange(*Store0->getNextNode(), BB0->back(), Loc0) && + Store0->hasSameSpecialState(Store1) && + CastInst::isBitOrNoopPointerCastable( + Store0->getValueOperand()->getType(), + Store1->getValueOperand()->getType(), + Store0->getModule()->getDataLayout())) return Store1; - } } return nullptr; } @@ -254,6 +260,13 @@ void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0, S0->applyMergedLocation(S0->getDebugLoc(), S1->getDebugLoc()); S0->mergeDIAssignID(S1); + // Insert bitcast for conflicting typed stores (or just use original value if + // same type). + IRBuilder<> Builder(S0); + auto Cast = Builder.CreateBitOrPointerCast(S0->getValueOperand(), + S1->getValueOperand()->getType()); + S0->setOperand(0, Cast); + // Create the new store to be inserted at the join point. StoreInst *SNew = cast<StoreInst>(S0->clone()); SNew->insertBefore(&*InsertPt); @@ -428,7 +441,7 @@ void MergedLoadStoreMotionPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<MergedLoadStoreMotionPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; OS << (Options.SplitFooterBB ? 
"" : "no-") << "split-footer-bb"; - OS << ">"; + OS << '>'; } diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index 19bee4fa3879..9c3e9a2fd018 100644 --- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -351,9 +351,9 @@ Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) { bool NaryReassociatePass::requiresSignExtension(Value *Index, GetElementPtrInst *GEP) { - unsigned PointerSizeInBits = - DL->getPointerSizeInBits(GEP->getType()->getPointerAddressSpace()); - return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits; + unsigned IndexSizeInBits = + DL->getIndexSizeInBits(GEP->getType()->getPointerAddressSpace()); + return cast<IntegerType>(Index->getType())->getBitWidth() < IndexSizeInBits; } GetElementPtrInst * @@ -449,12 +449,12 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, return nullptr; // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0]))); - Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); - if (RHS->getType() != IntPtrTy) - RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy); + Type *PtrIdxTy = DL->getIndexType(GEP->getType()); + if (RHS->getType() != PtrIdxTy) + RHS = Builder.CreateSExtOrTrunc(RHS, PtrIdxTy); if (IndexedSize != ElementSize) { RHS = Builder.CreateMul( - RHS, ConstantInt::get(IntPtrTy, IndexedSize / ElementSize)); + RHS, ConstantInt::get(PtrIdxTy, IndexedSize / ElementSize)); } GetElementPtrInst *NewGEP = cast<GetElementPtrInst>( Builder.CreateGEP(GEP->getResultElementType(), Candidate, RHS)); diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index d3dba0c5f1d5..1af40e2c4e62 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -93,8 +93,6 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ArrayRecycler.h" #include "llvm/Support/Casting.h" @@ -104,7 +102,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVNExpression.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/Local.h" @@ -1277,10 +1274,17 @@ const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) const { const CallExpression * NewGVN::createCallExpression(CallInst *CI, const MemoryAccess *MA) const { // FIXME: Add operand bundles for calls. - // FIXME: Allow commutative matching for intrinsics. auto *E = new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, MA); setBasicExpressionInfo(CI, E); + if (CI->isCommutative()) { + // Ensure that commutative intrinsics that only differ by a permutation + // of their operands get the same value number by sorting the operand value + // numbers. 
+ assert(CI->getNumOperands() >= 2 && "Unsupported commutative intrinsic!"); + if (shouldSwapOperands(E->getOperand(0), E->getOperand(1))) + E->swapOperands(0, 1); + } return E; } @@ -1453,8 +1457,7 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, if (Offset >= 0) { if (auto *C = dyn_cast<Constant>( lookupOperandLeader(DepSI->getValueOperand()))) { - if (Constant *Res = - getConstantStoreValueForLoad(C, Offset, LoadType, DL)) { + if (Constant *Res = getConstantValueForLoad(C, Offset, LoadType, DL)) { LLVM_DEBUG(dbgs() << "Coercing load from store " << *DepSI << " to constant " << *Res << "\n"); return createConstantExpression(Res); @@ -1470,7 +1473,7 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, // We can coerce a constant load into a load. if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI))) if (auto *PossibleConstant = - getConstantLoadValueForLoad(C, Offset, LoadType, DL)) { + getConstantValueForLoad(C, Offset, LoadType, DL)) { LLVM_DEBUG(dbgs() << "Coercing load from load " << *LI << " to constant " << *PossibleConstant << "\n"); return createConstantExpression(PossibleConstant); @@ -1617,6 +1620,12 @@ NewGVN::ExprResult NewGVN::performSymbolicCallEvaluation(Instruction *I) const { if (CI->getFunction()->isPresplitCoroutine()) return ExprResult::none(); + // Do not combine convergent calls since they implicitly depend on the set of + // threads that is currently executing, and they might be in different basic + // blocks. + if (CI->isConvergent()) + return ExprResult::none(); + if (AA->doesNotAccessMemory(CI)) { return ExprResult::some( createCallExpression(CI, TOPClass->getMemoryLeader())); @@ -1992,6 +2001,7 @@ NewGVN::performSymbolicEvaluation(Value *V, break; case Instruction::BitCast: case Instruction::AddrSpaceCast: + case Instruction::Freeze: return createExpression(I); break; case Instruction::ICmp: @@ -2739,10 +2749,10 @@ NewGVN::makePossiblePHIOfOps(Instruction *I, return nullptr; } // No point in doing this for one-operand phis. - if (OpPHI->getNumOperands() == 1) { - OpPHI = nullptr; - continue; - } + // Since all PHIs for operands must be in the same block, then they must + // have the same number of operands so we can just abort. + if (OpPHI->getNumOperands() == 1) + return nullptr; } if (!OpPHI) @@ -3712,9 +3722,10 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { } // Now insert something that simplifycfg will turn into an unreachable. Type *Int8Ty = Type::getInt8Ty(BB->getContext()); - new StoreInst(PoisonValue::get(Int8Ty), - Constant::getNullValue(Int8Ty->getPointerTo()), - BB->getTerminator()); + new StoreInst( + PoisonValue::get(Int8Ty), + Constant::getNullValue(PointerType::getUnqual(BB->getContext())), + BB->getTerminator()); } void NewGVN::markInstructionForDeletion(Instruction *I) { @@ -4208,61 +4219,6 @@ bool NewGVN::shouldSwapOperandsForIntrinsic(const Value *A, const Value *B, return false; } -namespace { - -class NewGVNLegacyPass : public FunctionPass { -public: - // Pass identification, replacement for typeid. 
- static char ID; - - NewGVNLegacyPass() : FunctionPass(ID) { - initializeNewGVNLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - -private: - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } -}; - -} // end anonymous namespace - -bool NewGVNLegacyPass::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), - &getAnalysis<AAResultsWrapperPass>().getAAResults(), - &getAnalysis<MemorySSAWrapperPass>().getMSSA(), - F.getParent()->getDataLayout()) - .runGVN(); -} - -char NewGVNLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(NewGVNLegacyPass, "newgvn", "Global Value Numbering", - false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_END(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false, - false) - -// createGVNPass - The public interface to this file. -FunctionPass *llvm::createNewGVNPass() { return new NewGVNLegacyPass(); } - PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) { // Apparently the order in which we get these results matter for // the old GVN (see Chandler's comment in GVN.cpp). I'll keep diff --git a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp index e1cc3fc71c3e..0266eb1a9f50 100644 --- a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -47,6 +47,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/PlaceSafepoints.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -67,7 +68,9 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" -#define DEBUG_TYPE "safepoint-placement" +using namespace llvm; + +#define DEBUG_TYPE "place-safepoints" STATISTIC(NumEntrySafepoints, "Number of entry safepoints inserted"); STATISTIC(NumBackedgeSafepoints, "Number of backedge safepoints inserted"); @@ -77,8 +80,6 @@ STATISTIC(CallInLoop, STATISTIC(FiniteExecution, "Number of loops without safepoints finite execution"); -using namespace llvm; - // Ignore opportunities to avoid placing safepoints on backedges, useful for // validation static cl::opt<bool> AllBackedges("spp-all-backedges", cl::Hidden, @@ -97,10 +98,10 @@ static cl::opt<bool> SplitBackedge("spp-split-backedge", cl::Hidden, cl::init(false)); namespace { - /// An analysis pass whose purpose is to identify each of the backedges in /// the function which require a safepoint poll to be inserted. 
-struct PlaceBackedgeSafepointsImpl : public FunctionPass { +class PlaceBackedgeSafepointsLegacyPass : public FunctionPass { +public: static char ID; /// The output of the pass - gives a list of each backedge (described by @@ -111,17 +112,14 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { /// the call-dependent placement opts. bool CallSafepointsEnabled; - ScalarEvolution *SE = nullptr; - DominatorTree *DT = nullptr; - LoopInfo *LI = nullptr; - TargetLibraryInfo *TLI = nullptr; - - PlaceBackedgeSafepointsImpl(bool CallSafepoints = false) + PlaceBackedgeSafepointsLegacyPass(bool CallSafepoints = false) : FunctionPass(ID), CallSafepointsEnabled(CallSafepoints) { - initializePlaceBackedgeSafepointsImplPass(*PassRegistry::getPassRegistry()); + initializePlaceBackedgeSafepointsLegacyPassPass( + *PassRegistry::getPassRegistry()); } bool runOnLoop(Loop *); + void runOnLoopAndSubLoops(Loop *L) { // Visit all the subloops for (Loop *I : *L) @@ -149,39 +147,245 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { // analysis are preserved. AU.setPreservesAll(); } + +private: + ScalarEvolution *SE = nullptr; + DominatorTree *DT = nullptr; + LoopInfo *LI = nullptr; + TargetLibraryInfo *TLI = nullptr; }; -} +} // namespace static cl::opt<bool> NoEntry("spp-no-entry", cl::Hidden, cl::init(false)); static cl::opt<bool> NoCall("spp-no-call", cl::Hidden, cl::init(false)); static cl::opt<bool> NoBackedge("spp-no-backedge", cl::Hidden, cl::init(false)); -namespace { -struct PlaceSafepoints : public FunctionPass { - static char ID; // Pass identification, replacement for typeid +char PlaceBackedgeSafepointsLegacyPass::ID = 0; - PlaceSafepoints() : FunctionPass(ID) { - initializePlaceSafepointsPass(*PassRegistry::getPassRegistry()); - } - bool runOnFunction(Function &F) override; +INITIALIZE_PASS_BEGIN(PlaceBackedgeSafepointsLegacyPass, + "place-backedge-safepoints-impl", + "Place Backedge Safepoints", false, false) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(PlaceBackedgeSafepointsLegacyPass, + "place-backedge-safepoints-impl", + "Place Backedge Safepoints", false, false) - void getAnalysisUsage(AnalysisUsage &AU) const override { - // We modify the graph wholesale (inlining, block insertion, etc). We - // preserve nothing at the moment. We could potentially preserve dom tree - // if that was worth doing - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; -} +static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header, + BasicBlock *Pred, + DominatorTree &DT, + const TargetLibraryInfo &TLI); + +static bool mustBeFiniteCountedLoop(Loop *L, ScalarEvolution *SE, + BasicBlock *Pred); + +static Instruction *findLocationForEntrySafepoint(Function &F, + DominatorTree &DT); + +static bool isGCSafepointPoll(Function &F); +static bool shouldRewriteFunction(Function &F); +static bool enableEntrySafepoints(Function &F); +static bool enableBackedgeSafepoints(Function &F); +static bool enableCallSafepoints(Function &F); -// Insert a safepoint poll immediately before the given instruction. Does -// not handle the parsability of state at the runtime call, that's the -// callers job. 
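// A minimal sketch of the effect of a backedge safepoint poll; the poll
// function name is illustrative (the pass itself inserts and then inlines the
// module's gc.safepoint_poll body via InsertSafepointPoll below). An otherwise
// call-free, potentially unbounded loop gains one point per iteration at which
// the runtime can stop the thread.
//
//   extern void gcSafepointPollSketch();  // assumed hook, not from the patch
//   void spinUntilSketch(volatile bool *Done) {
//     while (!*Done) {
//       // call-free loop body elided
//       gcSafepointPollSketch();          // poll placed at the loop latch
//     }
//   }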
static void InsertSafepointPoll(Instruction *InsertBefore, std::vector<CallBase *> &ParsePointsNeeded /*rval*/, const TargetLibraryInfo &TLI); +bool PlaceBackedgeSafepointsLegacyPass::runOnLoop(Loop *L) { + // Loop through all loop latches (branches controlling backedges). We need + // to place a safepoint on every backedge (potentially). + // Note: In common usage, there will be only one edge due to LoopSimplify + // having run sometime earlier in the pipeline, but this code must be correct + // w.r.t. loops with multiple backedges. + BasicBlock *Header = L->getHeader(); + SmallVector<BasicBlock *, 16> LoopLatches; + L->getLoopLatches(LoopLatches); + for (BasicBlock *Pred : LoopLatches) { + assert(L->contains(Pred)); + + // Make a policy decision about whether this loop needs a safepoint or + // not. Note that this is about unburdening the optimizer in loops, not + // avoiding the runtime cost of the actual safepoint. + if (!AllBackedges) { + if (mustBeFiniteCountedLoop(L, SE, Pred)) { + LLVM_DEBUG(dbgs() << "skipping safepoint placement in finite loop\n"); + FiniteExecution++; + continue; + } + if (CallSafepointsEnabled && + containsUnconditionalCallSafepoint(L, Header, Pred, *DT, *TLI)) { + // Note: This is only semantically legal since we won't do any further + // IPO or inlining before the actual call insertion.. If we hadn't, we + // might latter loose this call safepoint. + LLVM_DEBUG( + dbgs() + << "skipping safepoint placement due to unconditional call\n"); + CallInLoop++; + continue; + } + } + + // TODO: We can create an inner loop which runs a finite number of + // iterations with an outer loop which contains a safepoint. This would + // not help runtime performance that much, but it might help our ability to + // optimize the inner loop. + + // Safepoint insertion would involve creating a new basic block (as the + // target of the current backedge) which does the safepoint (of all live + // variables) and branches to the true header + Instruction *Term = Pred->getTerminator(); + + LLVM_DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term); + + PollLocations.push_back(Term); + } + + return false; +} + +bool PlaceSafepointsPass::runImpl(Function &F, const TargetLibraryInfo &TLI) { + if (F.isDeclaration() || F.empty()) { + // This is a declaration, nothing to do. Must exit early to avoid crash in + // dom tree calculation + return false; + } + + if (isGCSafepointPoll(F)) { + // Given we're inlining this inside of safepoint poll insertion, this + // doesn't make any sense. Note that we do make any contained calls + // parseable after we inline a poll. + return false; + } + + if (!shouldRewriteFunction(F)) + return false; + + bool Modified = false; + + // In various bits below, we rely on the fact that uses are reachable from + // defs. When there are basic blocks unreachable from the entry, dominance + // and reachablity queries return non-sensical results. Thus, we preprocess + // the function to ensure these properties hold. + Modified |= removeUnreachableBlocks(F); + + // STEP 1 - Insert the safepoint polling locations. We do not need to + // actually insert parse points yet. That will be done for all polls and + // calls in a single pass. + + DominatorTree DT; + DT.recalculate(F); + + SmallVector<Instruction *, 16> PollsNeeded; + std::vector<CallBase *> ParsePointNeeded; + + if (enableBackedgeSafepoints(F)) { + // Construct a pass manager to run the LoopPass backedge logic. We + // need the pass manager to handle scheduling all the loop passes + // appropriately. 
Doing this by hand is painful and just not worth messing + // with for the moment. + legacy::FunctionPassManager FPM(F.getParent()); + bool CanAssumeCallSafepoints = enableCallSafepoints(F); + auto *PBS = new PlaceBackedgeSafepointsLegacyPass(CanAssumeCallSafepoints); + FPM.add(PBS); + FPM.run(F); + + // We preserve dominance information when inserting the poll, otherwise + // we'd have to recalculate this on every insert + DT.recalculate(F); + + auto &PollLocations = PBS->PollLocations; + + auto OrderByBBName = [](Instruction *a, Instruction *b) { + return a->getParent()->getName() < b->getParent()->getName(); + }; + // We need the order of list to be stable so that naming ends up stable + // when we split edges. This makes test cases much easier to write. + llvm::sort(PollLocations, OrderByBBName); + + // We can sometimes end up with duplicate poll locations. This happens if + // a single loop is visited more than once. The fact this happens seems + // wrong, but it does happen for the split-backedge.ll test case. + PollLocations.erase(std::unique(PollLocations.begin(), PollLocations.end()), + PollLocations.end()); + + // Insert a poll at each point the analysis pass identified + // The poll location must be the terminator of a loop latch block. + for (Instruction *Term : PollLocations) { + // We are inserting a poll, the function is modified + Modified = true; + + if (SplitBackedge) { + // Split the backedge of the loop and insert the poll within that new + // basic block. This creates a loop with two latches per original + // latch (which is non-ideal), but this appears to be easier to + // optimize in practice than inserting the poll immediately before the + // latch test. + + // Since this is a latch, at least one of the successors must dominate + // it. Its possible that we have a) duplicate edges to the same header + // and b) edges to distinct loop headers. We need to insert pools on + // each. + SetVector<BasicBlock *> Headers; + for (unsigned i = 0; i < Term->getNumSuccessors(); i++) { + BasicBlock *Succ = Term->getSuccessor(i); + if (DT.dominates(Succ, Term->getParent())) { + Headers.insert(Succ); + } + } + assert(!Headers.empty() && "poll location is not a loop latch?"); + + // The split loop structure here is so that we only need to recalculate + // the dominator tree once. Alternatively, we could just keep it up to + // date and use a more natural merged loop. + SetVector<BasicBlock *> SplitBackedges; + for (BasicBlock *Header : Headers) { + BasicBlock *NewBB = SplitEdge(Term->getParent(), Header, &DT); + PollsNeeded.push_back(NewBB->getTerminator()); + NumBackedgeSafepoints++; + } + } else { + // Split the latch block itself, right before the terminator. + PollsNeeded.push_back(Term); + NumBackedgeSafepoints++; + } + } + } + + if (enableEntrySafepoints(F)) { + if (Instruction *Location = findLocationForEntrySafepoint(F, DT)) { + PollsNeeded.push_back(Location); + Modified = true; + NumEntrySafepoints++; + } + // TODO: else we should assert that there was, in fact, a policy choice to + // not insert a entry safepoint poll. + } + + // Now that we've identified all the needed safepoint poll locations, insert + // safepoint polls themselves. 
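The stable-naming dance around PollLocations boils down to a sort followed by std::unique. A standalone sketch over plain block names (in the pass the elements are latch terminators ordered by their parent block's name, and duplicates arise when a loop is visited more than once):

// Standalone sketch: sort for a stable order, then drop adjacent duplicates.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical latch-block names, possibly with duplicates.
  std::vector<std::string> PollLocations = {"loop.latch", "for.body",
                                            "loop.latch", "while.end"};
  std::sort(PollLocations.begin(), PollLocations.end());
  PollLocations.erase(
      std::unique(PollLocations.begin(), PollLocations.end()),
      PollLocations.end());
  for (const std::string &BB : PollLocations)
    std::cout << BB << '\n';   // for.body, loop.latch, while.end
}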
+ for (Instruction *PollLocation : PollsNeeded) { + std::vector<CallBase *> RuntimeCalls; + InsertSafepointPoll(PollLocation, RuntimeCalls, TLI); + llvm::append_range(ParsePointNeeded, RuntimeCalls); + } + + return Modified; +} + +PreservedAnalyses PlaceSafepointsPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + + if (!runImpl(F, TLI)) + return PreservedAnalyses::all(); + + // TODO: can we preserve more? + return PreservedAnalyses::none(); +} + static bool needsStatepoint(CallBase *Call, const TargetLibraryInfo &TLI) { if (callsGCLeafFunction(Call, TLI)) return false; @@ -306,58 +510,6 @@ static void scanInlinedCode(Instruction *Start, Instruction *End, } } -bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) { - // Loop through all loop latches (branches controlling backedges). We need - // to place a safepoint on every backedge (potentially). - // Note: In common usage, there will be only one edge due to LoopSimplify - // having run sometime earlier in the pipeline, but this code must be correct - // w.r.t. loops with multiple backedges. - BasicBlock *Header = L->getHeader(); - SmallVector<BasicBlock*, 16> LoopLatches; - L->getLoopLatches(LoopLatches); - for (BasicBlock *Pred : LoopLatches) { - assert(L->contains(Pred)); - - // Make a policy decision about whether this loop needs a safepoint or - // not. Note that this is about unburdening the optimizer in loops, not - // avoiding the runtime cost of the actual safepoint. - if (!AllBackedges) { - if (mustBeFiniteCountedLoop(L, SE, Pred)) { - LLVM_DEBUG(dbgs() << "skipping safepoint placement in finite loop\n"); - FiniteExecution++; - continue; - } - if (CallSafepointsEnabled && - containsUnconditionalCallSafepoint(L, Header, Pred, *DT, *TLI)) { - // Note: This is only semantically legal since we won't do any further - // IPO or inlining before the actual call insertion.. If we hadn't, we - // might latter loose this call safepoint. - LLVM_DEBUG( - dbgs() - << "skipping safepoint placement due to unconditional call\n"); - CallInLoop++; - continue; - } - } - - // TODO: We can create an inner loop which runs a finite number of - // iterations with an outer loop which contains a safepoint. This would - // not help runtime performance that much, but it might help our ability to - // optimize the inner loop. - - // Safepoint insertion would involve creating a new basic block (as the - // target of the current backedge) which does the safepoint (of all live - // variables) and branches to the true header - Instruction *Term = Pred->getTerminator(); - - LLVM_DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term); - - PollLocations.push_back(Term); - } - - return false; -} - /// Returns true if an entry safepoint is not required before this callsite in /// the caller function. static bool doesNotRequireEntrySafepointBefore(CallBase *Call) { @@ -463,161 +615,9 @@ static bool enableEntrySafepoints(Function &F) { return !NoEntry; } static bool enableBackedgeSafepoints(Function &F) { return !NoBackedge; } static bool enableCallSafepoints(Function &F) { return !NoCall; } -bool PlaceSafepoints::runOnFunction(Function &F) { - if (F.isDeclaration() || F.empty()) { - // This is a declaration, nothing to do. Must exit early to avoid crash in - // dom tree calculation - return false; - } - - if (isGCSafepointPoll(F)) { - // Given we're inlining this inside of safepoint poll insertion, this - // doesn't make any sense. 
Note that we do make any contained calls - // parseable after we inline a poll. - return false; - } - - if (!shouldRewriteFunction(F)) - return false; - - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - - bool Modified = false; - - // In various bits below, we rely on the fact that uses are reachable from - // defs. When there are basic blocks unreachable from the entry, dominance - // and reachablity queries return non-sensical results. Thus, we preprocess - // the function to ensure these properties hold. - Modified |= removeUnreachableBlocks(F); - - // STEP 1 - Insert the safepoint polling locations. We do not need to - // actually insert parse points yet. That will be done for all polls and - // calls in a single pass. - - DominatorTree DT; - DT.recalculate(F); - - SmallVector<Instruction *, 16> PollsNeeded; - std::vector<CallBase *> ParsePointNeeded; - - if (enableBackedgeSafepoints(F)) { - // Construct a pass manager to run the LoopPass backedge logic. We - // need the pass manager to handle scheduling all the loop passes - // appropriately. Doing this by hand is painful and just not worth messing - // with for the moment. - legacy::FunctionPassManager FPM(F.getParent()); - bool CanAssumeCallSafepoints = enableCallSafepoints(F); - auto *PBS = new PlaceBackedgeSafepointsImpl(CanAssumeCallSafepoints); - FPM.add(PBS); - FPM.run(F); - - // We preserve dominance information when inserting the poll, otherwise - // we'd have to recalculate this on every insert - DT.recalculate(F); - - auto &PollLocations = PBS->PollLocations; - - auto OrderByBBName = [](Instruction *a, Instruction *b) { - return a->getParent()->getName() < b->getParent()->getName(); - }; - // We need the order of list to be stable so that naming ends up stable - // when we split edges. This makes test cases much easier to write. - llvm::sort(PollLocations, OrderByBBName); - - // We can sometimes end up with duplicate poll locations. This happens if - // a single loop is visited more than once. The fact this happens seems - // wrong, but it does happen for the split-backedge.ll test case. - PollLocations.erase(std::unique(PollLocations.begin(), - PollLocations.end()), - PollLocations.end()); - - // Insert a poll at each point the analysis pass identified - // The poll location must be the terminator of a loop latch block. - for (Instruction *Term : PollLocations) { - // We are inserting a poll, the function is modified - Modified = true; - - if (SplitBackedge) { - // Split the backedge of the loop and insert the poll within that new - // basic block. This creates a loop with two latches per original - // latch (which is non-ideal), but this appears to be easier to - // optimize in practice than inserting the poll immediately before the - // latch test. - - // Since this is a latch, at least one of the successors must dominate - // it. Its possible that we have a) duplicate edges to the same header - // and b) edges to distinct loop headers. We need to insert pools on - // each. - SetVector<BasicBlock *> Headers; - for (unsigned i = 0; i < Term->getNumSuccessors(); i++) { - BasicBlock *Succ = Term->getSuccessor(i); - if (DT.dominates(Succ, Term->getParent())) { - Headers.insert(Succ); - } - } - assert(!Headers.empty() && "poll location is not a loop latch?"); - - // The split loop structure here is so that we only need to recalculate - // the dominator tree once. Alternatively, we could just keep it up to - // date and use a more natural merged loop. 
- SetVector<BasicBlock *> SplitBackedges; - for (BasicBlock *Header : Headers) { - BasicBlock *NewBB = SplitEdge(Term->getParent(), Header, &DT); - PollsNeeded.push_back(NewBB->getTerminator()); - NumBackedgeSafepoints++; - } - } else { - // Split the latch block itself, right before the terminator. - PollsNeeded.push_back(Term); - NumBackedgeSafepoints++; - } - } - } - - if (enableEntrySafepoints(F)) { - if (Instruction *Location = findLocationForEntrySafepoint(F, DT)) { - PollsNeeded.push_back(Location); - Modified = true; - NumEntrySafepoints++; - } - // TODO: else we should assert that there was, in fact, a policy choice to - // not insert a entry safepoint poll. - } - - // Now that we've identified all the needed safepoint poll locations, insert - // safepoint polls themselves. - for (Instruction *PollLocation : PollsNeeded) { - std::vector<CallBase *> RuntimeCalls; - InsertSafepointPoll(PollLocation, RuntimeCalls, TLI); - llvm::append_range(ParsePointNeeded, RuntimeCalls); - } - - return Modified; -} - -char PlaceBackedgeSafepointsImpl::ID = 0; -char PlaceSafepoints::ID = 0; - -FunctionPass *llvm::createPlaceSafepointsPass() { - return new PlaceSafepoints(); -} - -INITIALIZE_PASS_BEGIN(PlaceBackedgeSafepointsImpl, - "place-backedge-safepoints-impl", - "Place Backedge Safepoints", false, false) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(PlaceBackedgeSafepointsImpl, - "place-backedge-safepoints-impl", - "Place Backedge Safepoints", false, false) - -INITIALIZE_PASS_BEGIN(PlaceSafepoints, "place-safepoints", "Place Safepoints", - false, false) -INITIALIZE_PASS_END(PlaceSafepoints, "place-safepoints", "Place Safepoints", - false, false) - +// Insert a safepoint poll immediately before the given instruction. Does +// not handle the parsability of state at the runtime call, that's the +// callers job. static void InsertSafepointPoll(Instruction *InsertBefore, std::vector<CallBase *> &ParsePointsNeeded /*rval*/, diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 21628b61edd6..40c84e249523 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -52,6 +52,7 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" @@ -70,6 +71,12 @@ STATISTIC(NumChanged, "Number of insts reassociated"); STATISTIC(NumAnnihil, "Number of expr tree annihilated"); STATISTIC(NumFactor , "Number of multiplies factored"); +static cl::opt<bool> + UseCSELocalOpt(DEBUG_TYPE "-use-cse-local", + cl::desc("Only reorder expressions within a basic block " + "when exposing CSE opportunities"), + cl::init(true), cl::Hidden); + #ifndef NDEBUG /// Print out the expression identified in the Ops list. static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) { @@ -620,8 +627,7 @@ static bool LinearizeExprTree(Instruction *I, // The leaves, repeated according to their weights, represent the linearized // form of the expression. - for (unsigned i = 0, e = LeafOrder.size(); i != e; ++i) { - Value *V = LeafOrder[i]; + for (Value *V : LeafOrder) { LeafMap::iterator It = Leaves.find(V); if (It == Leaves.end()) // Node initially thought to be a leaf wasn't. 
@@ -683,10 +689,12 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, for (unsigned i = 0, e = Ops.size(); i != e; ++i) NotRewritable.insert(Ops[i].Op); - // ExpressionChanged - Non-null if the rewritten expression differs from the - // original in some non-trivial way, requiring the clearing of optional flags. - // Flags are cleared from the operator in ExpressionChanged up to I inclusive. - BinaryOperator *ExpressionChanged = nullptr; + // ExpressionChangedStart - Non-null if the rewritten expression differs from + // the original in some non-trivial way, requiring the clearing of optional + // flags. Flags are cleared from the operator in ExpressionChangedStart up to + // ExpressionChangedEnd inclusive. + BinaryOperator *ExpressionChangedStart = nullptr, + *ExpressionChangedEnd = nullptr; for (unsigned i = 0; ; ++i) { // The last operation (which comes earliest in the IR) is special as both // operands will come from Ops, rather than just one with the other being @@ -728,7 +736,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, } LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); - ExpressionChanged = Op; + ExpressionChangedStart = Op; + if (!ExpressionChangedEnd) + ExpressionChangedEnd = Op; MadeChange = true; ++NumChanged; @@ -750,7 +760,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, if (BO && !NotRewritable.count(BO)) NodesToRewrite.push_back(BO); Op->setOperand(1, NewRHS); - ExpressionChanged = Op; + ExpressionChangedStart = Op; + if (!ExpressionChangedEnd) + ExpressionChangedEnd = Op; } LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); MadeChange = true; @@ -787,7 +799,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n'); Op->setOperand(0, NewOp); LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); - ExpressionChanged = Op; + ExpressionChangedStart = Op; + if (!ExpressionChangedEnd) + ExpressionChangedEnd = Op; MadeChange = true; ++NumChanged; Op = NewOp; @@ -797,27 +811,36 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, // starting from the operator specified in ExpressionChanged, and compactify // the operators to just before the expression root to guarantee that the // expression tree is dominated by all of Ops. - if (ExpressionChanged) + if (ExpressionChangedStart) { + bool ClearFlags = true; do { // Preserve FastMathFlags. - if (isa<FPMathOperator>(I)) { - FastMathFlags Flags = I->getFastMathFlags(); - ExpressionChanged->clearSubclassOptionalData(); - ExpressionChanged->setFastMathFlags(Flags); - } else - ExpressionChanged->clearSubclassOptionalData(); - - if (ExpressionChanged == I) + if (ClearFlags) { + if (isa<FPMathOperator>(I)) { + FastMathFlags Flags = I->getFastMathFlags(); + ExpressionChangedStart->clearSubclassOptionalData(); + ExpressionChangedStart->setFastMathFlags(Flags); + } else + ExpressionChangedStart->clearSubclassOptionalData(); + } + + if (ExpressionChangedStart == ExpressionChangedEnd) + ClearFlags = false; + if (ExpressionChangedStart == I) break; // Discard any debug info related to the expressions that has changed (we - // can leave debug infor related to the root, since the result of the - // expression tree should be the same even after reassociation). 
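The ExpressionChangedStart/ExpressionChangedEnd pair introduced in this hunk replaces a single marker so that optional flags (nsw/nuw, fast-math) are dropped only on the operators that were actually rewritten, not on everything up to the root. A simplified standalone sketch over a linearized operator chain; Op and HasNoWrapFlags are illustrative stand-ins.

// Standalone sketch: walk from the first rewritten operator toward the root,
// clearing optional flags only until the last operator that changed.
#include <iostream>
#include <string>
#include <vector>

struct Op {                    // hypothetical operator chain, ordered leaf->root
  std::string Name;
  bool HasNoWrapFlags = true;
};

static void clearChangedRange(std::vector<Op> &Chain, size_t ChangedStart,
                              size_t ChangedEnd) {
  bool ClearFlags = true;
  for (size_t I = ChangedStart; I < Chain.size(); ++I) {
    if (ClearFlags)
      Chain[I].HasNoWrapFlags = false;   // flags may no longer hold here
    if (I == ChangedEnd)
      ClearFlags = false;                // operators above End kept their form
  }
}

int main() {
  std::vector<Op> Chain = {{"t0"}, {"t1"}, {"t2"}, {"t3"}};
  clearChangedRange(Chain, 1, 2);        // only t1 and t2 were rewritten
  for (const Op &O : Chain)
    std::cout << O.Name << (O.HasNoWrapFlags ? " keeps flags\n" : " cleared\n");
}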
- replaceDbgUsesWithUndef(ExpressionChanged); - - ExpressionChanged->moveBefore(I); - ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin()); + // can leave debug info related to the root and any operation that didn't + // change, since the result of the expression tree should be the same + // even after reassociation). + if (ClearFlags) + replaceDbgUsesWithUndef(ExpressionChangedStart); + + ExpressionChangedStart->moveBefore(I); + ExpressionChangedStart = + cast<BinaryOperator>(*ExpressionChangedStart->user_begin()); } while (true); + } // Throw away any left over nodes from the original expression. for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i) @@ -1507,8 +1530,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, // Step 4: Reassemble the Ops if (Changed) { Ops.clear(); - for (unsigned int i = 0, e = Opnds.size(); i < e; i++) { - XorOpnd &O = Opnds[i]; + for (const XorOpnd &O : Opnds) { if (O.isInvalid()) continue; ValueEntry VE(getRank(O.getValue()), O.getValue()); @@ -1644,8 +1666,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, // Add one to FactorOccurrences for each unique factor in this op. SmallPtrSet<Value*, 8> Duplicates; - for (unsigned i = 0, e = Factors.size(); i != e; ++i) { - Value *Factor = Factors[i]; + for (Value *Factor : Factors) { if (!Duplicates.insert(Factor).second) continue; @@ -2048,7 +2069,7 @@ void ReassociatePass::EraseInst(Instruction *I) { // blocks because it's a waste of time and also because it can // lead to infinite loop due to LLVM's non-standard definition // of dominance. - if (ValueRankMap.find(Op) != ValueRankMap.end()) + if (ValueRankMap.contains(Op)) RedoInsts.insert(Op); } @@ -2410,8 +2431,67 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { unsigned BestRank = 0; std::pair<unsigned, unsigned> BestPair; unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin; - for (unsigned i = 0; i < Ops.size() - 1; ++i) - for (unsigned j = i + 1; j < Ops.size(); ++j) { + unsigned LimitIdx = 0; + // With the CSE-driven heuristic, we are about to slap two values at the + // beginning of the expression whereas they could live very late in the CFG. + // When using the CSE-local heuristic we avoid creating dependences from + // completely unrelated part of the CFG by limiting the expression + // reordering on the values that live in the first seen basic block. + // The main idea is that we want to avoid forming expressions that would + // become loop dependent. + if (UseCSELocalOpt) { + const BasicBlock *FirstSeenBB = nullptr; + int StartIdx = Ops.size() - 1; + // Skip the first value of the expression since we need at least two + // values to materialize an expression. I.e., even if this value is + // anchored in a different basic block, the actual first sub expression + // will be anchored on the second value. + for (int i = StartIdx - 1; i != -1; --i) { + const Value *Val = Ops[i].Op; + const auto *CurrLeafInstr = dyn_cast<Instruction>(Val); + const BasicBlock *SeenBB = nullptr; + if (!CurrLeafInstr) { + // The value is free of any CFG dependencies. + // Do as if it lives in the entry block. + // + // We do this to make sure all the values falling on this path are + // seen through the same anchor point. The rationale is these values + // can be combined together to from a sub expression free of any CFG + // dependencies so we want them to stay together. + // We could be cleverer and postpone the anchor down to the first + // anchored value, but that's likely complicated to get right. 
+ // E.g., we wouldn't want to do that if that means being stuck in a + // loop. + // + // For instance, we wouldn't want to change: + // res = arg1 op arg2 op arg3 op ... op loop_val1 op loop_val2 ... + // into + // res = loop_val1 op arg1 op arg2 op arg3 op ... op loop_val2 ... + // Because all the sub expressions with arg2..N would be stuck between + // two loop dependent values. + SeenBB = &I->getParent()->getParent()->getEntryBlock(); + } else { + SeenBB = CurrLeafInstr->getParent(); + } + + if (!FirstSeenBB) { + FirstSeenBB = SeenBB; + continue; + } + if (FirstSeenBB != SeenBB) { + // ith value is in a different basic block. + // Rewind the index once to point to the last value on the same basic + // block. + LimitIdx = i + 1; + LLVM_DEBUG(dbgs() << "CSE reordering: Consider values between [" + << LimitIdx << ", " << StartIdx << "]\n"); + break; + } + } + } + for (unsigned i = Ops.size() - 1; i > LimitIdx; --i) { + // We must use int type to go below zero when LimitIdx is 0. + for (int j = i - 1; j >= (int)LimitIdx; --j) { unsigned Score = 0; Value *Op0 = Ops[i].Op; Value *Op1 = Ops[j].Op; @@ -2429,12 +2509,26 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { } unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); + + // By construction, the operands are sorted in reverse order of their + // topological order. + // So we tend to form (sub) expressions with values that are close to + // each other. + // + // Now to expose more CSE opportunities we want to expose the pair of + // operands that occur the most (as statically computed in + // BuildPairMap.) as the first sub-expression. + // + // If two pairs occur as many times, we pick the one with the + // lowest rank, meaning the one with both operands appearing first in + // the topological order. if (Score > Max || (Score == Max && MaxRank < BestRank)) { - BestPair = {i, j}; + BestPair = {j, i}; Max = Score; BestRank = MaxRank; } } + } if (Max > 1) { auto Op0 = Ops[BestPair.first]; auto Op1 = Ops[BestPair.second]; @@ -2444,6 +2538,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { Ops.push_back(Op1); } } + LLVM_DEBUG(dbgs() << "RAOut after CSE reorder:\t"; PrintOps(I, Ops); + dbgs() << '\n'); // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. RewriteExprTree(I, Ops); diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index bcb012b79c2e..908bda5709a0 100644 --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -27,6 +27,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallingConv.h" @@ -36,6 +37,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GCStrategy.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" @@ -125,6 +127,9 @@ static cl::opt<bool> RematDerivedAtUses("rs4gc-remat-derived-at-uses", /// constant physical memory: llvm.invariant.start. static void stripNonValidData(Module &M); +// Find the GC strategy for a function, or null if it doesn't have one. 
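Putting the two pieces of the new CSE-aware reordering together: first compute LimitIdx so the pair search stays within the block of the deepest operands, then pick the pair with the highest pair-map score, breaking ties toward the lowest maximum rank. The sketch below is standalone and approximate; Operand, Block and the Score lambda are made-up stand-ins for the real Value/BasicBlock/PairMap machinery.

// Standalone sketch of the CSE-local reordering heuristic.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

struct Operand {
  int Block;      // basic block the value lives in (0 == entry / no block)
  unsigned Rank;  // topological rank, as in Reassociate
};

int main() {
  // Operands are kept in reverse topological order, like Ops in the pass.
  std::vector<Operand> Ops = {{2, 7}, {2, 6}, {1, 3}, {1, 2}, {1, 1}};

  // Step 1: find LimitIdx -- stop at the first operand anchored in a block
  // different from the first one seen while scanning from the back.
  unsigned LimitIdx = 0;
  int StartIdx = (int)Ops.size() - 1;
  int FirstSeenBB = -1;
  for (int i = StartIdx - 1; i != -1; --i) {
    int SeenBB = Ops[i].Block;
    if (FirstSeenBB == -1) { FirstSeenBB = SeenBB; continue; }
    if (SeenBB != FirstSeenBB) { LimitIdx = i + 1; break; }
  }

  // Step 2: scan pairs inside [LimitIdx, Ops.size()-1]. The real pass uses the
  // pair-frequency map; here Score is just an arbitrary stand-in.
  auto Score = [](const Operand &A, const Operand &B) {
    return (A.Rank + B.Rank) % 4;
  };
  unsigned Max = 0, BestRank = 0;
  std::pair<unsigned, unsigned> BestPair{0, 0};
  for (unsigned i = Ops.size() - 1; i > LimitIdx; --i)
    for (int j = i - 1; j >= (int)LimitIdx; --j) {
      unsigned S = Score(Ops[i], Ops[j]);
      unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank);
      if (S > Max || (S == Max && MaxRank < BestRank)) {
        BestPair = {(unsigned)j, i};     // lower index first, as in the patch
        Max = S;
        BestRank = MaxRank;
      }
    }
  std::printf("window starts at %u, best pair = (%u, %u)\n", LimitIdx,
              BestPair.first, BestPair.second);
}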
+static std::unique_ptr<GCStrategy> findGCStrategy(Function &F); + static bool shouldRewriteStatepointsIn(Function &F); PreservedAnalyses RewriteStatepointsForGC::run(Module &M, @@ -162,76 +167,6 @@ PreservedAnalyses RewriteStatepointsForGC::run(Module &M, namespace { -class RewriteStatepointsForGCLegacyPass : public ModulePass { - RewriteStatepointsForGC Impl; - -public: - static char ID; // Pass identification, replacement for typeid - - RewriteStatepointsForGCLegacyPass() : ModulePass(ID), Impl() { - initializeRewriteStatepointsForGCLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - bool Changed = false; - for (Function &F : M) { - // Nothing to do for declarations. - if (F.isDeclaration() || F.empty()) - continue; - - // Policy choice says not to rewrite - the most common reason is that - // we're compiling code without a GCStrategy. - if (!shouldRewriteStatepointsIn(F)) - continue; - - TargetTransformInfo &TTI = - getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); - - Changed |= Impl.runOnFunction(F, DT, TTI, TLI); - } - - if (!Changed) - return false; - - // stripNonValidData asserts that shouldRewriteStatepointsIn - // returns true for at least one function in the module. Since at least - // one function changed, we know that the precondition is satisfied. - stripNonValidData(M); - return true; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - // We add and rewrite a bunch of instructions, but don't really do much - // else. We could in theory preserve a lot more analyses here. - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; - -} // end anonymous namespace - -char RewriteStatepointsForGCLegacyPass::ID = 0; - -ModulePass *llvm::createRewriteStatepointsForGCLegacyPass() { - return new RewriteStatepointsForGCLegacyPass(); -} - -INITIALIZE_PASS_BEGIN(RewriteStatepointsForGCLegacyPass, - "rewrite-statepoints-for-gc", - "Make relocations explicit at statepoints", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(RewriteStatepointsForGCLegacyPass, - "rewrite-statepoints-for-gc", - "Make relocations explicit at statepoints", false, false) - -namespace { - struct GCPtrLivenessData { /// Values defined in this block. MapVector<BasicBlock *, SetVector<Value *>> KillSet; @@ -311,37 +246,35 @@ static ArrayRef<Use> GetDeoptBundleOperands(const CallBase *Call) { /// Compute the live-in set for every basic block in the function static void computeLiveInValues(DominatorTree &DT, Function &F, - GCPtrLivenessData &Data); + GCPtrLivenessData &Data, GCStrategy *GC); /// Given results from the dataflow liveness computation, find the set of live /// Values at a particular instruction. 
static void findLiveSetAtInst(Instruction *inst, GCPtrLivenessData &Data, - StatepointLiveSetTy &out); + StatepointLiveSetTy &out, GCStrategy *GC); -// TODO: Once we can get to the GCStrategy, this becomes -// std::optional<bool> isGCManagedPointer(const Type *Ty) const override { +static bool isGCPointerType(Type *T, GCStrategy *GC) { + assert(GC && "GC Strategy for isGCPointerType cannot be null"); -static bool isGCPointerType(Type *T) { - if (auto *PT = dyn_cast<PointerType>(T)) - // For the sake of this example GC, we arbitrarily pick addrspace(1) as our - // GC managed heap. We know that a pointer into this heap needs to be - // updated and that no other pointer does. - return PT->getAddressSpace() == 1; - return false; + if (!isa<PointerType>(T)) + return false; + + // conservative - same as StatepointLowering + return GC->isGCManagedPointer(T).value_or(true); } // Return true if this type is one which a) is a gc pointer or contains a GC // pointer and b) is of a type this code expects to encounter as a live value. // (The insertion code will assert that a type which matches (a) and not (b) // is not encountered.) -static bool isHandledGCPointerType(Type *T) { +static bool isHandledGCPointerType(Type *T, GCStrategy *GC) { // We fully support gc pointers - if (isGCPointerType(T)) + if (isGCPointerType(T, GC)) return true; // We partially support vectors of gc pointers. The code will assert if it // can't handle something. if (auto VT = dyn_cast<VectorType>(T)) - if (isGCPointerType(VT->getElementType())) + if (isGCPointerType(VT->getElementType(), GC)) return true; return false; } @@ -349,23 +282,24 @@ static bool isHandledGCPointerType(Type *T) { #ifndef NDEBUG /// Returns true if this type contains a gc pointer whether we know how to /// handle that type or not. -static bool containsGCPtrType(Type *Ty) { - if (isGCPointerType(Ty)) +static bool containsGCPtrType(Type *Ty, GCStrategy *GC) { + if (isGCPointerType(Ty, GC)) return true; if (VectorType *VT = dyn_cast<VectorType>(Ty)) - return isGCPointerType(VT->getScalarType()); + return isGCPointerType(VT->getScalarType(), GC); if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) - return containsGCPtrType(AT->getElementType()); + return containsGCPtrType(AT->getElementType(), GC); if (StructType *ST = dyn_cast<StructType>(Ty)) - return llvm::any_of(ST->elements(), containsGCPtrType); + return llvm::any_of(ST->elements(), + [GC](Type *Ty) { return containsGCPtrType(Ty, GC); }); return false; } // Returns true if this is a type which a) is a gc pointer or contains a GC // pointer and b) is of a type which the code doesn't expect (i.e. first class // aggregates). Used to trip assertions. -static bool isUnhandledGCPointerType(Type *Ty) { - return containsGCPtrType(Ty) && !isHandledGCPointerType(Ty); +static bool isUnhandledGCPointerType(Type *Ty, GCStrategy *GC) { + return containsGCPtrType(Ty, GC) && !isHandledGCPointerType(Ty, GC); } #endif @@ -382,9 +316,9 @@ static std::string suffixed_name_or(Value *V, StringRef Suffix, // live. Values used by that instruction are considered live. 
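The value_or(true) in isGCPointerType encodes a conservative default: when the strategy cannot classify a pointer it is treated as GC-managed, matching the note about StatepointLowering. A tiny standalone illustration; isManagedPointer is a made-up stand-in for GCStrategy::isGCManagedPointer.

// Standalone sketch: an unknown answer from the strategy defaults to "managed".
#include <iostream>
#include <optional>

static std::optional<bool> isManagedPointer(int AddressSpace) {
  if (AddressSpace == 1)
    return true;                 // definitely on the GC heap
  if (AddressSpace == 0)
    return false;                // definitely not
  return std::nullopt;           // strategy cannot tell
}

int main() {
  for (int AS : {0, 1, 7})
    std::cout << "addrspace(" << AS << ") treated as managed: "
              << isManagedPointer(AS).value_or(true) << '\n'; // unknown -> 1
}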
static void analyzeParsePointLiveness( DominatorTree &DT, GCPtrLivenessData &OriginalLivenessData, CallBase *Call, - PartiallyConstructedSafepointRecord &Result) { + PartiallyConstructedSafepointRecord &Result, GCStrategy *GC) { StatepointLiveSetTy LiveSet; - findLiveSetAtInst(Call, OriginalLivenessData, LiveSet); + findLiveSetAtInst(Call, OriginalLivenessData, LiveSet, GC); if (PrintLiveSet) { dbgs() << "Live Variables:\n"; @@ -692,7 +626,7 @@ static Value *findBaseDefiningValue(Value *I, DefiningValueMapTy &Cache, /// Returns the base defining value for this value. static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache, IsKnownBaseMapTy &KnownBases) { - if (Cache.find(I) == Cache.end()) { + if (!Cache.contains(I)) { auto *BDV = findBaseDefiningValue(I, Cache, KnownBases); Cache[I] = BDV; LLVM_DEBUG(dbgs() << "fBDV-cached: " << I->getName() << " -> " @@ -700,7 +634,7 @@ static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache, << KnownBases[I] << "\n"); } assert(Cache[I] != nullptr); - assert(KnownBases.find(Cache[I]) != KnownBases.end() && + assert(KnownBases.contains(Cache[I]) && "Cached value must be present in known bases map"); return Cache[I]; } @@ -1289,9 +1223,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, if (!BdvSV->isZeroEltSplat()) UpdateOperand(1); // vector operand else { - // Never read, so just use undef + // Never read, so just use poison Value *InVal = BdvSV->getOperand(1); - BaseSV->setOperand(1, UndefValue::get(InVal->getType())); + BaseSV->setOperand(1, PoisonValue::get(InVal->getType())); } } } @@ -1385,20 +1319,21 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, CallBase *Call, PartiallyConstructedSafepointRecord &result, - PointerToBaseTy &PointerToBase); + PointerToBaseTy &PointerToBase, + GCStrategy *GC); static void recomputeLiveInValues( Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate, MutableArrayRef<struct PartiallyConstructedSafepointRecord> records, - PointerToBaseTy &PointerToBase) { + PointerToBaseTy &PointerToBase, GCStrategy *GC) { // TODO-PERF: reuse the original liveness, then simply run the dataflow // again. The old values are still live and will help it stabilize quickly. GCPtrLivenessData RevisedLivenessData; - computeLiveInValues(DT, F, RevisedLivenessData); + computeLiveInValues(DT, F, RevisedLivenessData, GC); for (size_t i = 0; i < records.size(); i++) { struct PartiallyConstructedSafepointRecord &info = records[i]; - recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info, - PointerToBase); + recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info, PointerToBase, + GC); } } @@ -1522,7 +1457,7 @@ static AttributeList legalizeCallAttributes(LLVMContext &Ctx, static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, ArrayRef<Value *> BasePtrs, Instruction *StatepointToken, - IRBuilder<> &Builder) { + IRBuilder<> &Builder, GCStrategy *GC) { if (LiveVariables.empty()) return; @@ -1542,8 +1477,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, // towards a single unified pointer type anyways, we can just cast everything // to an i8* of the right address space. A bitcast is added later to convert // gc_relocate to the actual value's type. 
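The switch from find()/end() to contains() in findBaseDefiningValueCached is a readability change to the usual memoization shape. A standalone sketch (needs C++20 for unordered_map::contains; computeBase stands in for the real recursive walk):

// Standalone sketch: compute once, answer later queries from the cache.
#include <iostream>
#include <string>
#include <unordered_map>

static std::unordered_map<std::string, std::string> Cache;

// Hypothetical expensive computation standing in for findBaseDefiningValue.
static std::string computeBase(const std::string &V) { return V + ".base"; }

static const std::string &findBaseCached(const std::string &V) {
  if (!Cache.contains(V))        // only compute on the first query
    Cache[V] = computeBase(V);
  return Cache[V];
}

int main() {
  std::cout << findBaseCached("gep1") << '\n';
  std::cout << findBaseCached("gep1") << '\n'; // served from the cache
}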
- auto getGCRelocateDecl = [&] (Type *Ty) { - assert(isHandledGCPointerType(Ty)); + auto getGCRelocateDecl = [&](Type *Ty) { + assert(isHandledGCPointerType(Ty, GC)); auto AS = Ty->getScalarType()->getPointerAddressSpace(); Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS); if (auto *VT = dyn_cast<VectorType>(Ty)) @@ -1668,7 +1603,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ const SmallVectorImpl<Value *> &LiveVariables, PartiallyConstructedSafepointRecord &Result, std::vector<DeferredReplacement> &Replacements, - const PointerToBaseTy &PointerToBase) { + const PointerToBaseTy &PointerToBase, + GCStrategy *GC) { assert(BasePtrs.size() == LiveVariables.size()); // Then go ahead and use the builder do actually do the inserts. We insert @@ -1901,7 +1837,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst(); Result.UnwindToken = ExceptionalToken; - CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder); + CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder, GC); // Generate gc relocates and returns for normal block BasicBlock *NormalDest = II->getNormalDest(); @@ -1947,7 +1883,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ Result.StatepointToken = Token; // Second, create a gc.relocate for every live variable - CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder); + CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder, GC); } // Replace an existing gc.statepoint with a new one and a set of gc.relocates @@ -1959,7 +1895,7 @@ static void makeStatepointExplicit(DominatorTree &DT, CallBase *Call, PartiallyConstructedSafepointRecord &Result, std::vector<DeferredReplacement> &Replacements, - const PointerToBaseTy &PointerToBase) { + const PointerToBaseTy &PointerToBase, GCStrategy *GC) { const auto &LiveSet = Result.LiveSet; // Convert to vector for efficient cross referencing. @@ -1976,7 +1912,7 @@ makeStatepointExplicit(DominatorTree &DT, CallBase *Call, // Do the actual rewriting and delete the old statepoint makeStatepointExplicitImpl(Call, BaseVec, LiveVec, Result, Replacements, - PointerToBase); + PointerToBase, GC); } // Helper function for the relocationViaAlloca. 
@@ -2277,12 +2213,13 @@ static void insertUseHolderAfter(CallBase *Call, const ArrayRef<Value *> Values, static void findLiveReferences( Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate, - MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) { + MutableArrayRef<struct PartiallyConstructedSafepointRecord> records, + GCStrategy *GC) { GCPtrLivenessData OriginalLivenessData; - computeLiveInValues(DT, F, OriginalLivenessData); + computeLiveInValues(DT, F, OriginalLivenessData, GC); for (size_t i = 0; i < records.size(); i++) { struct PartiallyConstructedSafepointRecord &info = records[i]; - analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info); + analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info, GC); } } @@ -2684,6 +2621,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, SmallVectorImpl<CallBase *> &ToUpdate, DefiningValueMapTy &DVCache, IsKnownBaseMapTy &KnownBases) { + std::unique_ptr<GCStrategy> GC = findGCStrategy(F); + #ifndef NDEBUG // Validate the input std::set<CallBase *> Uniqued; @@ -2718,9 +2657,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, SmallVector<Value *, 64> DeoptValues; for (Value *Arg : GetDeoptBundleOperands(Call)) { - assert(!isUnhandledGCPointerType(Arg->getType()) && + assert(!isUnhandledGCPointerType(Arg->getType(), GC.get()) && "support for FCA unimplemented"); - if (isHandledGCPointerType(Arg->getType())) + if (isHandledGCPointerType(Arg->getType(), GC.get())) DeoptValues.push_back(Arg); } @@ -2731,7 +2670,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // A) Identify all gc pointers which are statically live at the given call // site. - findLiveReferences(F, DT, ToUpdate, Records); + findLiveReferences(F, DT, ToUpdate, Records, GC.get()); /// Global mapping from live pointers to a base-defining-value. PointerToBaseTy PointerToBase; @@ -2782,7 +2721,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // By selecting base pointers, we've effectively inserted new uses. Thus, we // need to rerun liveness. We may *also* have inserted new defs, but that's // not the key issue. - recomputeLiveInValues(F, DT, ToUpdate, Records, PointerToBase); + recomputeLiveInValues(F, DT, ToUpdate, Records, PointerToBase, GC.get()); if (PrintBasePointers) { errs() << "Base Pairs: (w/Relocation)\n"; @@ -2842,7 +2781,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // the old statepoint calls as we go.) for (size_t i = 0; i < Records.size(); i++) makeStatepointExplicit(DT, ToUpdate[i], Records[i], Replacements, - PointerToBase); + PointerToBase, GC.get()); ToUpdate.clear(); // prevent accident use of invalid calls. @@ -2866,9 +2805,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // Do all the fixups of the original live variables to their relocated selves SmallVector<Value *, 128> Live; - for (size_t i = 0; i < Records.size(); i++) { - PartiallyConstructedSafepointRecord &Info = Records[i]; - + for (const PartiallyConstructedSafepointRecord &Info : Records) { // We can't simply save the live set from the original insertion. One of // the live values might be the result of a call which needs a safepoint. // That Value* no longer exists and we need to use the new gc_result. 
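The overall shape this patch threads through the hunks above: look the GC strategy up once per function, get null when the function carries no gc attribute, and let the strategy answer the policy question. A hedged standalone sketch in which ToyGCStrategy, findStrategy and shouldRewrite are illustrative stand-ins, and the "statepoint-example" check merely stands in for GCStrategy::useRS4GC().

// Standalone sketch: nullable strategy lookup plus a strategy-owned policy bit.
#include <iostream>
#include <memory>
#include <optional>
#include <string>

struct ToyGCStrategy {
  std::string Name;
  bool useRS4GC() const { return Name == "statepoint-example"; }
};

static std::unique_ptr<ToyGCStrategy>
findStrategy(const std::optional<std::string> &GCName) {
  if (!GCName)
    return nullptr;                          // function has no "gc" attribute
  return std::make_unique<ToyGCStrategy>(ToyGCStrategy{*GCName});
}

static bool shouldRewrite(const std::optional<std::string> &GCName) {
  auto Strategy = findStrategy(GCName);
  return Strategy && Strategy->useRS4GC();   // policy lives in the strategy
}

int main() {
  std::cout << shouldRewrite(std::nullopt) << '\n';                       // 0
  std::cout << shouldRewrite(std::string("statepoint-example")) << '\n';  // 1
  std::cout << shouldRewrite(std::string("ocaml")) << '\n';               // 0
}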
@@ -2899,7 +2836,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, #ifndef NDEBUG // Validation check for (auto *Ptr : Live) - assert(isHandledGCPointerType(Ptr->getType()) && + assert(isHandledGCPointerType(Ptr->getType(), GC.get()) && "must be a gc pointer type"); #endif @@ -3019,25 +2956,33 @@ static void stripNonValidDataFromBody(Function &F) { } } - // Delete the invariant.start instructions and RAUW undef. + // Delete the invariant.start instructions and RAUW poison. for (auto *II : InvariantStartInstructions) { - II->replaceAllUsesWith(UndefValue::get(II->getType())); + II->replaceAllUsesWith(PoisonValue::get(II->getType())); II->eraseFromParent(); } } +/// Looks up the GC strategy for a given function, returning null if the +/// function doesn't have a GC tag. The strategy is stored in the cache. +static std::unique_ptr<GCStrategy> findGCStrategy(Function &F) { + if (!F.hasGC()) + return nullptr; + + return getGCStrategy(F.getGC()); +} + /// Returns true if this function should be rewritten by this pass. The main /// point of this function is as an extension point for custom logic. static bool shouldRewriteStatepointsIn(Function &F) { - // TODO: This should check the GCStrategy - if (F.hasGC()) { - const auto &FunctionGCName = F.getGC(); - const StringRef StatepointExampleName("statepoint-example"); - const StringRef CoreCLRName("coreclr"); - return (StatepointExampleName == FunctionGCName) || - (CoreCLRName == FunctionGCName); - } else + if (!F.hasGC()) return false; + + std::unique_ptr<GCStrategy> Strategy = findGCStrategy(F); + + assert(Strategy && "GC strategy is required by function, but was not found"); + + return Strategy->useRS4GC(); } static void stripNonValidData(Module &M) { @@ -3216,7 +3161,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, /// the live-out set of the basic block static void computeLiveInValues(BasicBlock::reverse_iterator Begin, BasicBlock::reverse_iterator End, - SetVector<Value *> &LiveTmp) { + SetVector<Value *> &LiveTmp, GCStrategy *GC) { for (auto &I : make_range(Begin, End)) { // KILL/Def - Remove this definition from LiveIn LiveTmp.remove(&I); @@ -3228,9 +3173,9 @@ static void computeLiveInValues(BasicBlock::reverse_iterator Begin, // USE - Add to the LiveIn set for this instruction for (Value *V : I.operands()) { - assert(!isUnhandledGCPointerType(V->getType()) && + assert(!isUnhandledGCPointerType(V->getType(), GC) && "support for FCA unimplemented"); - if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) { + if (isHandledGCPointerType(V->getType(), GC) && !isa<Constant>(V)) { // The choice to exclude all things constant here is slightly subtle. 
// There are two independent reasons: // - We assume that things which are constant (from LLVM's definition) @@ -3247,7 +3192,8 @@ static void computeLiveInValues(BasicBlock::reverse_iterator Begin, } } -static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp) { +static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp, + GCStrategy *GC) { for (BasicBlock *Succ : successors(BB)) { for (auto &I : *Succ) { PHINode *PN = dyn_cast<PHINode>(&I); @@ -3255,18 +3201,18 @@ static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp) { break; Value *V = PN->getIncomingValueForBlock(BB); - assert(!isUnhandledGCPointerType(V->getType()) && + assert(!isUnhandledGCPointerType(V->getType(), GC) && "support for FCA unimplemented"); - if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) + if (isHandledGCPointerType(V->getType(), GC) && !isa<Constant>(V)) LiveTmp.insert(V); } } } -static SetVector<Value *> computeKillSet(BasicBlock *BB) { +static SetVector<Value *> computeKillSet(BasicBlock *BB, GCStrategy *GC) { SetVector<Value *> KillSet; for (Instruction &I : *BB) - if (isHandledGCPointerType(I.getType())) + if (isHandledGCPointerType(I.getType(), GC)) KillSet.insert(&I); return KillSet; } @@ -3301,14 +3247,14 @@ static void checkBasicSSA(DominatorTree &DT, GCPtrLivenessData &Data, #endif static void computeLiveInValues(DominatorTree &DT, Function &F, - GCPtrLivenessData &Data) { + GCPtrLivenessData &Data, GCStrategy *GC) { SmallSetVector<BasicBlock *, 32> Worklist; // Seed the liveness for each individual block for (BasicBlock &BB : F) { - Data.KillSet[&BB] = computeKillSet(&BB); + Data.KillSet[&BB] = computeKillSet(&BB, GC); Data.LiveSet[&BB].clear(); - computeLiveInValues(BB.rbegin(), BB.rend(), Data.LiveSet[&BB]); + computeLiveInValues(BB.rbegin(), BB.rend(), Data.LiveSet[&BB], GC); #ifndef NDEBUG for (Value *Kill : Data.KillSet[&BB]) @@ -3316,7 +3262,7 @@ static void computeLiveInValues(DominatorTree &DT, Function &F, #endif Data.LiveOut[&BB] = SetVector<Value *>(); - computeLiveOutSeed(&BB, Data.LiveOut[&BB]); + computeLiveOutSeed(&BB, Data.LiveOut[&BB], GC); Data.LiveIn[&BB] = Data.LiveSet[&BB]; Data.LiveIn[&BB].set_union(Data.LiveOut[&BB]); Data.LiveIn[&BB].set_subtract(Data.KillSet[&BB]); @@ -3368,7 +3314,7 @@ static void computeLiveInValues(DominatorTree &DT, Function &F, } static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, - StatepointLiveSetTy &Out) { + StatepointLiveSetTy &Out, GCStrategy *GC) { BasicBlock *BB = Inst->getParent(); // Note: The copy is intentional and required @@ -3379,8 +3325,8 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, // call result is not live (normal), nor are it's arguments // (unless they're used again later). 
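The liveness plumbing being extended here is a standard backward dataflow over GC pointer values. Below is a simplified standalone sketch of the fixed-point iteration; the three-block CFG and its Use/Kill sets are made up, and LiveOut is merged directly from successor live-in rather than seeded from PHIs as the real code does.

// Standalone sketch: LiveIn(B) = Use(B) U (LiveOut(B) \ Kill(B)), iterated to
// a fixed point over a tiny CFG with a backedge.
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct Block {
  std::set<std::string> Use, Kill;   // gen/kill sets for GC pointer values
  std::vector<int> Succs;
};

int main() {
  // 0 -> 1 -> 2, with a backedge 2 -> 1.
  std::vector<Block> CFG = {
      {{}, {"p"}, {1}},              // defines p
      {{"p"}, {"q"}, {2}},           // uses p, defines q
      {{"q", "p"}, {}, {1}},         // uses q and p, loops back
  };
  std::vector<std::set<std::string>> LiveIn(CFG.size()), LiveOut(CFG.size());

  bool Changed = true;
  while (Changed) {                  // simple round-robin fixed point
    Changed = false;
    for (int B = CFG.size() - 1; B >= 0; --B) {
      std::set<std::string> Out;
      for (int S : CFG[B].Succs)
        Out.insert(LiveIn[S].begin(), LiveIn[S].end());
      std::set<std::string> In = CFG[B].Use;
      for (const std::string &V : Out)
        if (!CFG[B].Kill.count(V))
          In.insert(V);
      if (In != LiveIn[B] || Out != LiveOut[B]) {
        LiveIn[B] = std::move(In);
        LiveOut[B] = std::move(Out);
        Changed = true;
      }
    }
  }
  for (size_t B = 0; B < CFG.size(); ++B) {
    std::cout << "block " << B << " live-in:";
    for (const std::string &V : LiveIn[B])
      std::cout << ' ' << V;
    std::cout << '\n';               // block 1: p; block 2: p q
  }
}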
This adjustment is // specifically what we need to relocate - computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(), - LiveOut); + computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(), LiveOut, + GC); LiveOut.remove(Inst); Out.insert(LiveOut.begin(), LiveOut.end()); } @@ -3388,9 +3334,10 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, CallBase *Call, PartiallyConstructedSafepointRecord &Info, - PointerToBaseTy &PointerToBase) { + PointerToBaseTy &PointerToBase, + GCStrategy *GC) { StatepointLiveSetTy Updated; - findLiveSetAtInst(Call, RevisedLivenessData, Updated); + findLiveSetAtInst(Call, RevisedLivenessData, Updated, GC); // We may have base pointers which are now live that weren't before. We need // to update the PointerToBase structure to reflect this. diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp index 7b396c6ee074..fcdc503c54a4 100644 --- a/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -41,7 +41,6 @@ #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -136,54 +135,3 @@ PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<DominatorTreeAnalysis>(); return PA; } - -namespace { - -//===--------------------------------------------------------------------===// -// -/// SCCP Class - This class uses the SCCPSolver to implement a per-function -/// Sparse Conditional Constant Propagator. -/// -class SCCPLegacyPass : public FunctionPass { -public: - // Pass identification, replacement for typeid - static char ID; - - SCCPLegacyPass() : FunctionPass(ID) { - initializeSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - } - - // runOnFunction - Run the Sparse Conditional Constant Propagation - // algorithm, and return true if the function was modified. - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - const DataLayout &DL = F.getParent()->getDataLayout(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DomTreeUpdater DTU(DTWP ? &DTWP->getDomTree() : nullptr, - DomTreeUpdater::UpdateStrategy::Lazy); - return runSCCP(F, DL, TLI, DTU); - } -}; - -} // end anonymous namespace - -char SCCPLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(SCCPLegacyPass, "sccp", - "Sparse Conditional Constant Propagation", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(SCCPLegacyPass, "sccp", - "Sparse Conditional Constant Propagation", false, false) - -// createSCCPPass - This is the public interface to this file. -FunctionPass *llvm::createSCCPPass() { return new SCCPLegacyPass(); } - diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 8339981e1bdc..983a75e1d708 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -118,13 +118,79 @@ STATISTIC(NumVectorized, "Number of vectorized aggregates"); /// GEPs. 
static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), cl::Hidden); +/// Disable running mem2reg during SROA in order to test or debug SROA. +static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false), + cl::Hidden); namespace { + +/// Calculate the fragment of a variable to use when slicing a store +/// based on the slice dimensions, existing fragment, and base storage +/// fragment. +/// Results: +/// UseFrag - Use Target as the new fragment. +/// UseNoFrag - The new slice already covers the whole variable. +/// Skip - The new alloca slice doesn't include this variable. +/// FIXME: Can we use calculateFragmentIntersect instead? +enum FragCalcResult { UseFrag, UseNoFrag, Skip }; +static FragCalcResult +calculateFragment(DILocalVariable *Variable, + uint64_t NewStorageSliceOffsetInBits, + uint64_t NewStorageSliceSizeInBits, + std::optional<DIExpression::FragmentInfo> StorageFragment, + std::optional<DIExpression::FragmentInfo> CurrentFragment, + DIExpression::FragmentInfo &Target) { + // If the base storage describes part of the variable apply the offset and + // the size constraint. + if (StorageFragment) { + Target.SizeInBits = + std::min(NewStorageSliceSizeInBits, StorageFragment->SizeInBits); + Target.OffsetInBits = + NewStorageSliceOffsetInBits + StorageFragment->OffsetInBits; + } else { + Target.SizeInBits = NewStorageSliceSizeInBits; + Target.OffsetInBits = NewStorageSliceOffsetInBits; + } + + // If this slice extracts the entirety of an independent variable from a + // larger alloca, do not produce a fragment expression, as the variable is + // not fragmented. + if (!CurrentFragment) { + if (auto Size = Variable->getSizeInBits()) { + // Treat the current fragment as covering the whole variable. + CurrentFragment = DIExpression::FragmentInfo(*Size, 0); + if (Target == CurrentFragment) + return UseNoFrag; + } + } + + // No additional work to do if there isn't a fragment already, or there is + // but it already exactly describes the new assignment. + if (!CurrentFragment || *CurrentFragment == Target) + return UseFrag; + + // Reject the target fragment if it doesn't fit wholly within the current + // fragment. TODO: We could instead chop up the target to fit in the case of + // a partial overlap. + if (Target.startInBits() < CurrentFragment->startInBits() || + Target.endInBits() > CurrentFragment->endInBits()) + return Skip; + + // Target fits within the current fragment, return it. + return UseFrag; +} + +static DebugVariable getAggregateVariable(DbgVariableIntrinsic *DVI) { + return DebugVariable(DVI->getVariable(), std::nullopt, + DVI->getDebugLoc().getInlinedAt()); +} + /// Find linked dbg.assign and generate a new one with the correct /// FragmentInfo. Link Inst to the new dbg.assign. If Value is nullptr the /// value component is copied from the old dbg.assign to the new. /// \param OldAlloca Alloca for the variable before splitting. -/// \param RelativeOffsetInBits Offset into \p OldAlloca relative to the -/// offset prior to splitting (change in offset). +/// \param IsSplit True if the store (not necessarily alloca) +/// is being split. +/// \param OldAllocaOffsetInBits Offset of the slice taken from OldAlloca. /// \param SliceSizeInBits New number of bits being written to. /// \param OldInst Instruction that is being split. /// \param Inst New instruction performing this part of the @@ -132,8 +198,8 @@ namespace { /// \param Dest Store destination. /// \param Value Stored value. /// \param DL Datalayout. 
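calculateFragment is pure bit arithmetic, so its decision table is easy to exercise on its own. The hedged sketch below approximates the clamping, whole-variable detection, and containment checks; Frag and calcFragment are simplified stand-ins for DIExpression::FragmentInfo and the function above, and the printed values are the enum (0 UseFrag, 1 UseNoFrag, 2 Skip).

// Standalone sketch of the fragment decision used when splitting a store.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>

struct Frag {
  uint64_t OffsetInBits, SizeInBits;
  uint64_t startInBits() const { return OffsetInBits; }
  uint64_t endInBits() const { return OffsetInBits + SizeInBits; }
  bool operator==(const Frag &O) const {
    return OffsetInBits == O.OffsetInBits && SizeInBits == O.SizeInBits;
  }
};

enum Result { UseFrag, UseNoFrag, Skip };

static Result calcFragment(uint64_t SliceOffset, uint64_t SliceSize,
                           std::optional<Frag> Storage,
                           std::optional<Frag> Current,
                           std::optional<uint64_t> VarSizeInBits, Frag &Target) {
  // Clamp the slice to the base storage fragment, if there is one.
  if (Storage) {
    Target.SizeInBits = std::min(SliceSize, Storage->SizeInBits);
    Target.OffsetInBits = SliceOffset + Storage->OffsetInBits;
  } else {
    Target = Frag{SliceOffset, SliceSize};
  }
  // No fragment on the dbg.assign yet: treat it as covering the whole variable
  // and check whether the slice is in fact the whole thing.
  if (!Current && VarSizeInBits) {
    Current = Frag{0, *VarSizeInBits};
    if (Target == *Current)
      return UseNoFrag;
  }
  if (!Current || *Current == Target)
    return UseFrag;
  // Reject slices that do not fit inside the fragment already described.
  if (Target.startInBits() < Current->startInBits() ||
      Target.endInBits() > Current->endInBits())
    return Skip;
  return UseFrag;
}

int main() {
  Frag T{};
  // A 32-bit slice at offset 32 of a 64-bit variable with no prior fragment.
  std::cout << calcFragment(32, 32, std::nullopt, std::nullopt, 64, T) << '\n';
  // The same slice when the dbg.assign already describes bits [0, 32).
  std::cout << calcFragment(32, 32, std::nullopt, Frag{0, 32}, 64, T) << '\n';
}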
-static void migrateDebugInfo(AllocaInst *OldAlloca, - uint64_t RelativeOffsetInBits, +static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, + uint64_t OldAllocaOffsetInBits, uint64_t SliceSizeInBits, Instruction *OldInst, Instruction *Inst, Value *Dest, Value *Value, const DataLayout &DL) { @@ -144,7 +210,9 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, LLVM_DEBUG(dbgs() << " migrateDebugInfo\n"); LLVM_DEBUG(dbgs() << " OldAlloca: " << *OldAlloca << "\n"); - LLVM_DEBUG(dbgs() << " RelativeOffset: " << RelativeOffsetInBits << "\n"); + LLVM_DEBUG(dbgs() << " IsSplit: " << IsSplit << "\n"); + LLVM_DEBUG(dbgs() << " OldAllocaOffsetInBits: " << OldAllocaOffsetInBits + << "\n"); LLVM_DEBUG(dbgs() << " SliceSizeInBits: " << SliceSizeInBits << "\n"); LLVM_DEBUG(dbgs() << " OldInst: " << *OldInst << "\n"); LLVM_DEBUG(dbgs() << " Inst: " << *Inst << "\n"); @@ -152,44 +220,66 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, if (Value) LLVM_DEBUG(dbgs() << " Value: " << *Value << "\n"); + /// Map of aggregate variables to their fragment associated with OldAlloca. + DenseMap<DebugVariable, std::optional<DIExpression::FragmentInfo>> + BaseFragments; + for (auto *DAI : at::getAssignmentMarkers(OldAlloca)) + BaseFragments[getAggregateVariable(DAI)] = + DAI->getExpression()->getFragmentInfo(); + // The new inst needs a DIAssignID unique metadata tag (if OldInst has // one). It shouldn't already have one: assert this assumption. assert(!Inst->getMetadata(LLVMContext::MD_DIAssignID)); DIAssignID *NewID = nullptr; auto &Ctx = Inst->getContext(); DIBuilder DIB(*OldInst->getModule(), /*AllowUnresolved*/ false); - uint64_t AllocaSizeInBits = *OldAlloca->getAllocationSizeInBits(DL); assert(OldAlloca->isStaticAlloca()); for (DbgAssignIntrinsic *DbgAssign : MarkerRange) { LLVM_DEBUG(dbgs() << " existing dbg.assign is: " << *DbgAssign << "\n"); auto *Expr = DbgAssign->getExpression(); + bool SetKillLocation = false; - // Check if the dbg.assign already describes a fragment. - auto GetCurrentFragSize = [AllocaSizeInBits, DbgAssign, - Expr]() -> uint64_t { - if (auto FI = Expr->getFragmentInfo()) - return FI->SizeInBits; - if (auto VarSize = DbgAssign->getVariable()->getSizeInBits()) - return *VarSize; - // The variable type has an unspecified size. This can happen in the - // case of DW_TAG_unspecified_type types, e.g. std::nullptr_t. Because - // there is no fragment and we do not know the size of the variable type, - // we'll guess by looking at the alloca. 
- return AllocaSizeInBits; - }; - uint64_t CurrentFragSize = GetCurrentFragSize(); - bool MakeNewFragment = CurrentFragSize != SliceSizeInBits; - assert(MakeNewFragment || RelativeOffsetInBits == 0); - - assert(SliceSizeInBits <= AllocaSizeInBits); - if (MakeNewFragment) { - assert(RelativeOffsetInBits + SliceSizeInBits <= CurrentFragSize); - auto E = DIExpression::createFragmentExpression( - Expr, RelativeOffsetInBits, SliceSizeInBits); - assert(E && "Failed to create fragment expr!"); - Expr = *E; + if (IsSplit) { + std::optional<DIExpression::FragmentInfo> BaseFragment; + { + auto R = BaseFragments.find(getAggregateVariable(DbgAssign)); + if (R == BaseFragments.end()) + continue; + BaseFragment = R->second; + } + std::optional<DIExpression::FragmentInfo> CurrentFragment = + Expr->getFragmentInfo(); + DIExpression::FragmentInfo NewFragment; + FragCalcResult Result = calculateFragment( + DbgAssign->getVariable(), OldAllocaOffsetInBits, SliceSizeInBits, + BaseFragment, CurrentFragment, NewFragment); + + if (Result == Skip) + continue; + if (Result == UseFrag && !(NewFragment == CurrentFragment)) { + if (CurrentFragment) { + // Rewrite NewFragment to be relative to the existing one (this is + // what createFragmentExpression wants). CalculateFragment has + // already resolved the size for us. FIXME: Should it return the + // relative fragment too? + NewFragment.OffsetInBits -= CurrentFragment->OffsetInBits; + } + // Add the new fragment info to the existing expression if possible. + if (auto E = DIExpression::createFragmentExpression( + Expr, NewFragment.OffsetInBits, NewFragment.SizeInBits)) { + Expr = *E; + } else { + // Otherwise, add the new fragment info to an empty expression and + // discard the value component of this dbg.assign as the value cannot + // be computed with the new fragment. + Expr = *DIExpression::createFragmentExpression( + DIExpression::get(Expr->getContext(), std::nullopt), + NewFragment.OffsetInBits, NewFragment.SizeInBits); + SetKillLocation = true; + } + } } // If we haven't created a DIAssignID ID do that now and attach it to Inst. @@ -198,11 +288,27 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, Inst->setMetadata(LLVMContext::MD_DIAssignID, NewID); } - Value = Value ? Value : DbgAssign->getValue(); + ::Value *NewValue = Value ? Value : DbgAssign->getValue(); auto *NewAssign = DIB.insertDbgAssign( - Inst, Value, DbgAssign->getVariable(), Expr, Dest, + Inst, NewValue, DbgAssign->getVariable(), Expr, Dest, DIExpression::get(Ctx, std::nullopt), DbgAssign->getDebugLoc()); + // If we've updated the value but the original dbg.assign has an arglist + // then kill it now - we can't use the requested new value. + // We can't replace the DIArgList with the new value as it'd leave + // the DIExpression in an invalid state (DW_OP_LLVM_arg operands without + // an arglist). And we can't keep the DIArgList in case the linked store + // is being split - in which case the DIArgList + expression may no longer + // be computing the correct value. + // This should be a very rare situation as it requires the value being + // stored to differ from the dbg.assign (i.e., the value has been + // represented differently in the debug intrinsic for some reason). 
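One detail worth calling out from the split path above: createFragmentExpression expects an offset relative to the fragment already present on the expression, hence the subtraction before it is called. A small standalone illustration with made-up numbers:

// Standalone sketch: rebase the new fragment onto the existing one.
#include <cstdint>
#include <iostream>

struct Frag { uint64_t OffsetInBits, SizeInBits; };

int main() {
  Frag Current{64, 64};   // dbg.assign already describes bits [64, 128)
  Frag New{96, 32};       // the split store writes bits [96, 128) of the var
  Frag Relative{New.OffsetInBits - Current.OffsetInBits, New.SizeInBits};
  std::cout << "DW_OP_LLVM_fragment " << Relative.OffsetInBits << ", "
            << Relative.SizeInBits << '\n';   // 32, 32 within the current frag
}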
+ SetKillLocation |= + Value && (DbgAssign->hasArgList() || + !DbgAssign->getExpression()->isSingleLocationExpression()); + if (SetKillLocation) + NewAssign->setKillLocation(); + // We could use more precision here at the cost of some additional (code) // complexity - if the original dbg.assign was adjacent to its store, we // could position this new dbg.assign adjacent to its store rather than the @@ -888,11 +994,12 @@ private: if (!IsOffsetKnown) return PI.setAborted(&LI); - if (isa<ScalableVectorType>(LI.getType())) + TypeSize Size = DL.getTypeStoreSize(LI.getType()); + if (Size.isScalable()) return PI.setAborted(&LI); - uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedValue(); - return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile()); + return handleLoadOrStore(LI.getType(), LI, Offset, Size.getFixedValue(), + LI.isVolatile()); } void visitStoreInst(StoreInst &SI) { @@ -902,10 +1009,11 @@ private: if (!IsOffsetKnown) return PI.setAborted(&SI); - if (isa<ScalableVectorType>(ValOp->getType())) + TypeSize StoreSize = DL.getTypeStoreSize(ValOp->getType()); + if (StoreSize.isScalable()) return PI.setAborted(&SI); - uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedValue(); + uint64_t Size = StoreSize.getFixedValue(); // If this memory access can be shown to *statically* extend outside the // bounds of the allocation, it's behavior is undefined, so simply @@ -1520,12 +1628,6 @@ static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI, IRB.SetInsertPoint(&LI); - if (auto *TypedPtrTy = LI.getPointerOperandType(); - !TypedPtrTy->isOpaquePointerTy() && SI.getType() != TypedPtrTy) { - TV = IRB.CreateBitOrPointerCast(TV, TypedPtrTy, ""); - FV = IRB.CreateBitOrPointerCast(FV, TypedPtrTy, ""); - } - LoadInst *TL = IRB.CreateAlignedLoad(LI.getType(), TV, LI.getAlign(), LI.getName() + ".sroa.speculate.load.true"); @@ -1581,22 +1683,19 @@ static void rewriteMemOpOfSelect(SelectInst &SI, T &I, bool IsThen = SuccBB == HeadBI->getSuccessor(0); int SuccIdx = IsThen ? 0 : 1; auto *NewMemOpBB = SuccBB == Tail ? Head : SuccBB; + auto &CondMemOp = cast<T>(*I.clone()); if (NewMemOpBB != Head) { NewMemOpBB->setName(Head->getName() + (IsThen ? ".then" : ".else")); if (isa<LoadInst>(I)) ++NumLoadsPredicated; else ++NumStoresPredicated; - } else + } else { + CondMemOp.dropUBImplyingAttrsAndMetadata(); ++NumLoadsSpeculated; - auto &CondMemOp = cast<T>(*I.clone()); + } CondMemOp.insertBefore(NewMemOpBB->getTerminator()); Value *Ptr = SI.getOperand(1 + SuccIdx); - if (auto *PtrTy = Ptr->getType(); - !PtrTy->isOpaquePointerTy() && - PtrTy != CondMemOp.getPointerOperandType()) - Ptr = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( - Ptr, CondMemOp.getPointerOperandType(), "", &CondMemOp); CondMemOp.setOperand(I.getPointerOperandIndex(), Ptr); if (isa<LoadInst>(I)) { CondMemOp.setName(I.getName() + (IsThen ? ".then" : ".else") + ".val"); @@ -1654,238 +1753,16 @@ static bool rewriteSelectInstMemOps(SelectInst &SI, return CFGChanged; } -/// Build a GEP out of a base pointer and indices. -/// -/// This will return the BasePtr if that is valid, or build a new GEP -/// instruction using the IRBuilder if GEP-ing is needed. -static Value *buildGEP(IRBuilderTy &IRB, Value *BasePtr, - SmallVectorImpl<Value *> &Indices, - const Twine &NamePrefix) { - if (Indices.empty()) - return BasePtr; - - // A single zero index is a no-op, so check for this and avoid building a GEP - // in that case. 
- if (Indices.size() == 1 && cast<ConstantInt>(Indices.back())->isZero()) - return BasePtr; - - // buildGEP() is only called for non-opaque pointers. - return IRB.CreateInBoundsGEP( - BasePtr->getType()->getNonOpaquePointerElementType(), BasePtr, Indices, - NamePrefix + "sroa_idx"); -} - -/// Get a natural GEP off of the BasePtr walking through Ty toward -/// TargetTy without changing the offset of the pointer. -/// -/// This routine assumes we've already established a properly offset GEP with -/// Indices, and arrived at the Ty type. The goal is to continue to GEP with -/// zero-indices down through type layers until we find one the same as -/// TargetTy. If we can't find one with the same type, we at least try to use -/// one with the same size. If none of that works, we just produce the GEP as -/// indicated by Indices to have the correct offset. -static Value *getNaturalGEPWithType(IRBuilderTy &IRB, const DataLayout &DL, - Value *BasePtr, Type *Ty, Type *TargetTy, - SmallVectorImpl<Value *> &Indices, - const Twine &NamePrefix) { - if (Ty == TargetTy) - return buildGEP(IRB, BasePtr, Indices, NamePrefix); - - // Offset size to use for the indices. - unsigned OffsetSize = DL.getIndexTypeSizeInBits(BasePtr->getType()); - - // See if we can descend into a struct and locate a field with the correct - // type. - unsigned NumLayers = 0; - Type *ElementTy = Ty; - do { - if (ElementTy->isPointerTy()) - break; - - if (ArrayType *ArrayTy = dyn_cast<ArrayType>(ElementTy)) { - ElementTy = ArrayTy->getElementType(); - Indices.push_back(IRB.getIntN(OffsetSize, 0)); - } else if (VectorType *VectorTy = dyn_cast<VectorType>(ElementTy)) { - ElementTy = VectorTy->getElementType(); - Indices.push_back(IRB.getInt32(0)); - } else if (StructType *STy = dyn_cast<StructType>(ElementTy)) { - if (STy->element_begin() == STy->element_end()) - break; // Nothing left to descend into. - ElementTy = *STy->element_begin(); - Indices.push_back(IRB.getInt32(0)); - } else { - break; - } - ++NumLayers; - } while (ElementTy != TargetTy); - if (ElementTy != TargetTy) - Indices.erase(Indices.end() - NumLayers, Indices.end()); - - return buildGEP(IRB, BasePtr, Indices, NamePrefix); -} - -/// Get a natural GEP from a base pointer to a particular offset and -/// resulting in a particular type. -/// -/// The goal is to produce a "natural" looking GEP that works with the existing -/// composite types to arrive at the appropriate offset and element type for -/// a pointer. TargetTy is the element type the returned GEP should point-to if -/// possible. We recurse by decreasing Offset, adding the appropriate index to -/// Indices, and setting Ty to the result subtype. -/// -/// If no natural GEP can be constructed, this function returns null. -static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL, - Value *Ptr, APInt Offset, Type *TargetTy, - SmallVectorImpl<Value *> &Indices, - const Twine &NamePrefix) { - PointerType *Ty = cast<PointerType>(Ptr->getType()); - - // Don't consider any GEPs through an i8* as natural unless the TargetTy is - // an i8. - if (Ty == IRB.getInt8PtrTy(Ty->getAddressSpace()) && TargetTy->isIntegerTy(8)) - return nullptr; - - Type *ElementTy = Ty->getNonOpaquePointerElementType(); - if (!ElementTy->isSized()) - return nullptr; // We can't GEP through an unsized element. 
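As a rough, standalone illustration (plain C++ with a toy layout, not the DataLayout API) of what the natural-GEP helpers being deleted here used to do: decompose a byte offset into per-field indices of a concrete struct layout. Under opaque pointers that work disappears; the replacement getAdjustedPtr further down emits a single inbounds i8 GEP with the raw byte offset plus a cast to the target pointer type. The struct, field offsets, and numbers below are made up for the example.

#include <cstdint>
#include <iostream>
#include <vector>

struct Field {
  uint64_t Offset; // byte offset of the field within the struct
  uint64_t Size;   // byte size of the field
};

// Decompose ByteOffset into (field index, offset within the field) -- the
// first "layer" the removed getNaturalGEPWithOffset would peel off.
bool decompose(const std::vector<Field> &Layout, uint64_t ByteOffset,
               unsigned &FieldIdx, uint64_t &Remainder) {
  for (unsigned I = 0; I < Layout.size(); ++I) {
    if (ByteOffset >= Layout[I].Offset &&
        ByteOffset < Layout[I].Offset + Layout[I].Size) {
      FieldIdx = I;
      Remainder = ByteOffset - Layout[I].Offset;
      return true;
    }
  }
  return false; // offset points at padding or past the end
}

int main() {
  // Toy layout for struct S { double D; int A[4]; }: D at 0, A at 8.
  std::vector<Field> S = {{0, 8}, {8, 16}};
  unsigned Idx;
  uint64_t Rem;
  if (decompose(S, 12, Idx, Rem))
    // Old style:  gep %S, 0, 1, 1   (field 1, element Rem / 4 == 1)
    // New style:  gep i8, %p, 12    (single raw byte offset)
    std::cout << "field " << Idx << " + " << Rem << " bytes\n";
  return 0;
}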
- - SmallVector<APInt> IntIndices = DL.getGEPIndicesForOffset(ElementTy, Offset); - if (Offset != 0) - return nullptr; - - for (const APInt &Index : IntIndices) - Indices.push_back(IRB.getInt(Index)); - return getNaturalGEPWithType(IRB, DL, Ptr, ElementTy, TargetTy, Indices, - NamePrefix); -} - /// Compute an adjusted pointer from Ptr by Offset bytes where the /// resulting pointer has PointerTy. -/// -/// This tries very hard to compute a "natural" GEP which arrives at the offset -/// and produces the pointer type desired. Where it cannot, it will try to use -/// the natural GEP to arrive at the offset and bitcast to the type. Where that -/// fails, it will try to use an existing i8* and GEP to the byte offset and -/// bitcast to the type. -/// -/// The strategy for finding the more natural GEPs is to peel off layers of the -/// pointer, walking back through bit casts and GEPs, searching for a base -/// pointer from which we can compute a natural GEP with the desired -/// properties. The algorithm tries to fold as many constant indices into -/// a single GEP as possible, thus making each GEP more independent of the -/// surrounding code. static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, APInt Offset, Type *PointerTy, const Twine &NamePrefix) { - // Create i8 GEP for opaque pointers. - if (Ptr->getType()->isOpaquePointerTy()) { - if (Offset != 0) - Ptr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(Offset), - NamePrefix + "sroa_idx"); - return IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, PointerTy, - NamePrefix + "sroa_cast"); - } - - // Even though we don't look through PHI nodes, we could be called on an - // instruction in an unreachable block, which may be on a cycle. - SmallPtrSet<Value *, 4> Visited; - Visited.insert(Ptr); - SmallVector<Value *, 4> Indices; - - // We may end up computing an offset pointer that has the wrong type. If we - // never are able to compute one directly that has the correct type, we'll - // fall back to it, so keep it and the base it was computed from around here. - Value *OffsetPtr = nullptr; - Value *OffsetBasePtr; - - // Remember any i8 pointer we come across to re-use if we need to do a raw - // byte offset. - Value *Int8Ptr = nullptr; - APInt Int8PtrOffset(Offset.getBitWidth(), 0); - - PointerType *TargetPtrTy = cast<PointerType>(PointerTy); - Type *TargetTy = TargetPtrTy->getNonOpaquePointerElementType(); - - // As `addrspacecast` is , `Ptr` (the storage pointer) may have different - // address space from the expected `PointerTy` (the pointer to be used). - // Adjust the pointer type based the original storage pointer. - auto AS = cast<PointerType>(Ptr->getType())->getAddressSpace(); - PointerTy = TargetTy->getPointerTo(AS); - - do { - // First fold any existing GEPs into the offset. - while (GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr)) { - APInt GEPOffset(Offset.getBitWidth(), 0); - if (!GEP->accumulateConstantOffset(DL, GEPOffset)) - break; - Offset += GEPOffset; - Ptr = GEP->getPointerOperand(); - if (!Visited.insert(Ptr).second) - break; - } - - // See if we can perform a natural GEP here. - Indices.clear(); - if (Value *P = getNaturalGEPWithOffset(IRB, DL, Ptr, Offset, TargetTy, - Indices, NamePrefix)) { - // If we have a new natural pointer at the offset, clear out any old - // offset pointer we computed. Unless it is the base pointer or - // a non-instruction, we built a GEP we don't need. Zap it. 
- if (OffsetPtr && OffsetPtr != OffsetBasePtr) - if (Instruction *I = dyn_cast<Instruction>(OffsetPtr)) { - assert(I->use_empty() && "Built a GEP with uses some how!"); - I->eraseFromParent(); - } - OffsetPtr = P; - OffsetBasePtr = Ptr; - // If we also found a pointer of the right type, we're done. - if (P->getType() == PointerTy) - break; - } - - // Stash this pointer if we've found an i8*. - if (Ptr->getType()->isIntegerTy(8)) { - Int8Ptr = Ptr; - Int8PtrOffset = Offset; - } - - // Peel off a layer of the pointer and update the offset appropriately. - if (Operator::getOpcode(Ptr) == Instruction::BitCast) { - Ptr = cast<Operator>(Ptr)->getOperand(0); - } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(Ptr)) { - if (GA->isInterposable()) - break; - Ptr = GA->getAliasee(); - } else { - break; - } - assert(Ptr->getType()->isPointerTy() && "Unexpected operand type!"); - } while (Visited.insert(Ptr).second); - - if (!OffsetPtr) { - if (!Int8Ptr) { - Int8Ptr = IRB.CreateBitCast( - Ptr, IRB.getInt8PtrTy(PointerTy->getPointerAddressSpace()), - NamePrefix + "sroa_raw_cast"); - Int8PtrOffset = Offset; - } - - OffsetPtr = Int8PtrOffset == 0 - ? Int8Ptr - : IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Int8Ptr, - IRB.getInt(Int8PtrOffset), - NamePrefix + "sroa_raw_idx"); - } - Ptr = OffsetPtr; - - // On the off chance we were targeting i8*, guard the bitcast here. - if (cast<PointerType>(Ptr->getType()) != TargetPtrTy) { - Ptr = IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, - TargetPtrTy, - NamePrefix + "sroa_cast"); - } - - return Ptr; + if (Offset != 0) + Ptr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(Offset), + NamePrefix + "sroa_idx"); + return IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, PointerTy, + NamePrefix + "sroa_cast"); } /// Compute the adjusted alignment for a load or store from an offset. @@ -2126,6 +2003,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // Collect the candidate types for vector-based promotion. Also track whether // we have different element types. SmallVector<VectorType *, 4> CandidateTys; + SetVector<Type *> LoadStoreTys; Type *CommonEltTy = nullptr; VectorType *CommonVecPtrTy = nullptr; bool HaveVecPtrTy = false; @@ -2159,15 +2037,40 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { } } }; - // Consider any loads or stores that are the exact size of the slice. - for (const Slice &S : P) - if (S.beginOffset() == P.beginOffset() && - S.endOffset() == P.endOffset()) { - if (auto *LI = dyn_cast<LoadInst>(S.getUse()->getUser())) - CheckCandidateType(LI->getType()); - else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser())) - CheckCandidateType(SI->getValueOperand()->getType()); + // Put load and store types into a set for de-duplication. + for (const Slice &S : P) { + Type *Ty; + if (auto *LI = dyn_cast<LoadInst>(S.getUse()->getUser())) + Ty = LI->getType(); + else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser())) + Ty = SI->getValueOperand()->getType(); + else + continue; + LoadStoreTys.insert(Ty); + // Consider any loads or stores that are the exact size of the slice. + if (S.beginOffset() == P.beginOffset() && S.endOffset() == P.endOffset()) + CheckCandidateType(Ty); + } + // Consider additional vector types where the element type size is a + // multiple of load/store element size. 
+ for (Type *Ty : LoadStoreTys) { + if (!VectorType::isValidElementType(Ty)) + continue; + unsigned TypeSize = DL.getTypeSizeInBits(Ty).getFixedValue(); + // Make a copy of CandidateTys and iterate through it, because we might + // append to CandidateTys in the loop. + SmallVector<VectorType *, 4> CandidateTysCopy = CandidateTys; + for (VectorType *&VTy : CandidateTysCopy) { + unsigned VectorSize = DL.getTypeSizeInBits(VTy).getFixedValue(); + unsigned ElementSize = + DL.getTypeSizeInBits(VTy->getElementType()).getFixedValue(); + if (TypeSize != VectorSize && TypeSize != ElementSize && + VectorSize % TypeSize == 0) { + VectorType *NewVTy = VectorType::get(Ty, VectorSize / TypeSize, false); + CheckCandidateType(NewVTy); + } } + } // If we didn't find a vector type, nothing to do here. if (CandidateTys.empty()) @@ -2195,7 +2098,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // Rank the remaining candidate vector types. This is easy because we know // they're all integer vectors. We sort by ascending number of elements. - auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) { + auto RankVectorTypesComp = [&DL](VectorType *RHSTy, VectorType *LHSTy) { (void)DL; assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() == DL.getTypeSizeInBits(LHSTy).getFixedValue() && @@ -2207,10 +2110,22 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { return cast<FixedVectorType>(RHSTy)->getNumElements() < cast<FixedVectorType>(LHSTy)->getNumElements(); }; - llvm::sort(CandidateTys, RankVectorTypes); - CandidateTys.erase( - std::unique(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes), - CandidateTys.end()); + auto RankVectorTypesEq = [&DL](VectorType *RHSTy, VectorType *LHSTy) { + (void)DL; + assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() == + DL.getTypeSizeInBits(LHSTy).getFixedValue() && + "Cannot have vector types of different sizes!"); + assert(RHSTy->getElementType()->isIntegerTy() && + "All non-integer types eliminated!"); + assert(LHSTy->getElementType()->isIntegerTy() && + "All non-integer types eliminated!"); + return cast<FixedVectorType>(RHSTy)->getNumElements() == + cast<FixedVectorType>(LHSTy)->getNumElements(); + }; + llvm::sort(CandidateTys, RankVectorTypesComp); + CandidateTys.erase(std::unique(CandidateTys.begin(), CandidateTys.end(), + RankVectorTypesEq), + CandidateTys.end()); } else { // The only way to have the same element type in every vector type is to // have the same vector type. Check that and remove all but one. @@ -2554,7 +2469,6 @@ class llvm::sroa::AllocaSliceRewriter // original alloca. uint64_t NewBeginOffset = 0, NewEndOffset = 0; - uint64_t RelativeOffset = 0; uint64_t SliceSize = 0; bool IsSplittable = false; bool IsSplit = false; @@ -2628,14 +2542,13 @@ public: NewBeginOffset = std::max(BeginOffset, NewAllocaBeginOffset); NewEndOffset = std::min(EndOffset, NewAllocaEndOffset); - RelativeOffset = NewBeginOffset - BeginOffset; SliceSize = NewEndOffset - NewBeginOffset; LLVM_DEBUG(dbgs() << " Begin:(" << BeginOffset << ", " << EndOffset << ") NewBegin:(" << NewBeginOffset << ", " << NewEndOffset << ") NewAllocaBegin:(" << NewAllocaBeginOffset << ", " << NewAllocaEndOffset << ")\n"); - assert(IsSplit || RelativeOffset == 0); + assert(IsSplit || NewBeginOffset == BeginOffset); OldUse = I->getUse(); OldPtr = cast<Instruction>(OldUse->get()); @@ -2898,8 +2811,8 @@ private: Pass.DeadInsts.push_back(&SI); // NOTE: Careful to use OrigV rather than V. 
- migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, Store, - Store->getPointerOperand(), OrigV, DL); + migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI, + Store, Store->getPointerOperand(), OrigV, DL); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return true; } @@ -2923,8 +2836,9 @@ private: if (AATags) Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); - migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, Store, - Store->getPointerOperand(), Store->getValueOperand(), DL); + migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI, + Store, Store->getPointerOperand(), + Store->getValueOperand(), DL); Pass.DeadInsts.push_back(&SI); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); @@ -3002,8 +2916,9 @@ private: if (NewSI->isAtomic()) NewSI->setAlignment(SI.getAlign()); - migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, NewSI, - NewSI->getPointerOperand(), NewSI->getValueOperand(), DL); + migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI, + NewSI, NewSI->getPointerOperand(), + NewSI->getValueOperand(), DL); Pass.DeadInsts.push_back(&SI); deleteIfTriviallyDead(OldOp); @@ -3103,8 +3018,8 @@ private: if (AATags) New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); - migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New, - New->getRawDest(), nullptr, DL); + migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II, + New, New->getRawDest(), nullptr, DL); LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return false; @@ -3179,8 +3094,8 @@ private: if (AATags) New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); - migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New, - New->getPointerOperand(), V, DL); + migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II, + New, New->getPointerOperand(), V, DL); LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return !II.isVolatile(); @@ -3308,8 +3223,16 @@ private: if (AATags) New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); - migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New, - DestPtr, nullptr, DL); + APInt Offset(DL.getIndexTypeSizeInBits(DestPtr->getType()), 0); + if (IsDest) { + migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, + &II, New, DestPtr, nullptr, DL); + } else if (AllocaInst *Base = dyn_cast<AllocaInst>( + DestPtr->stripAndAccumulateConstantOffsets( + DL, Offset, /*AllowNonInbounds*/ true))) { + migrateDebugInfo(Base, IsSplit, Offset.getZExtValue() * 8, + SliceSize * 8, &II, New, DestPtr, nullptr, DL); + } LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return false; } @@ -3397,8 +3320,18 @@ private: if (AATags) Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); - migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, Store, - DstPtr, Src, DL); + APInt Offset(DL.getIndexTypeSizeInBits(DstPtr->getType()), 0); + if (IsDest) { + + migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II, + Store, DstPtr, Src, DL); + } else if (AllocaInst *Base = dyn_cast<AllocaInst>( + DstPtr->stripAndAccumulateConstantOffsets( + DL, Offset, /*AllowNonInbounds*/ true))) { + migrateDebugInfo(Base, IsSplit, Offset.getZExtValue() * 8, SliceSize * 8, + &II, Store, DstPtr, Src, DL); + } + LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return !II.isVolatile(); } @@ -3760,23 +3693,22 @@ private: APInt Offset( DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace()), 
0); - if (AATags && - GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset)) + GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset); + if (AATags) Store->setAAMetadata(AATags.shift(Offset.getZExtValue())); // migrateDebugInfo requires the base Alloca. Walk to it from this gep. // If we cannot (because there's an intervening non-const or unbounded // gep) then we wouldn't expect to see dbg.assign intrinsics linked to // this instruction. - APInt OffsetInBytes(DL.getTypeSizeInBits(Ptr->getType()), false); - Value *Base = InBoundsGEP->stripAndAccumulateInBoundsConstantOffsets( - DL, OffsetInBytes); + Value *Base = AggStore->getPointerOperand()->stripInBoundsOffsets(); if (auto *OldAI = dyn_cast<AllocaInst>(Base)) { uint64_t SizeInBits = DL.getTypeSizeInBits(Store->getValueOperand()->getType()); - migrateDebugInfo(OldAI, OffsetInBytes.getZExtValue() * 8, SizeInBits, - AggStore, Store, Store->getPointerOperand(), - Store->getValueOperand(), DL); + migrateDebugInfo(OldAI, /*IsSplit*/ true, Offset.getZExtValue() * 8, + SizeInBits, AggStore, Store, + Store->getPointerOperand(), Store->getValueOperand(), + DL); } else { assert(at::getAssignmentMarkers(Store).empty() && "AT: unexpected debug.assign linked to store through " @@ -3799,6 +3731,9 @@ private: getAdjustedAlignment(&SI, 0), DL, IRB); Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); Visited.erase(&SI); + // The stores replacing SI each have markers describing fragments of the + // assignment so delete the assignment markers linked to SI. + at::deleteAssignmentMarkers(&SI); SI.eraseFromParent(); return true; } @@ -4029,6 +3964,10 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, return nullptr; const StructLayout *SL = DL.getStructLayout(STy); + + if (SL->getSizeInBits().isScalable()) + return nullptr; + if (Offset >= SL->getSizeInBytes()) return nullptr; uint64_t EndOffset = Offset + Size; @@ -4869,11 +4808,13 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // Migrate debug information from the old alloca to the new alloca(s) // and the individual partitions. - TinyPtrVector<DbgVariableIntrinsic *> DbgDeclares = FindDbgAddrUses(&AI); + TinyPtrVector<DbgVariableIntrinsic *> DbgVariables; + for (auto *DbgDeclare : FindDbgDeclareUses(&AI)) + DbgVariables.push_back(DbgDeclare); for (auto *DbgAssign : at::getAssignmentMarkers(&AI)) - DbgDeclares.push_back(DbgAssign); - for (DbgVariableIntrinsic *DbgDeclare : DbgDeclares) { - auto *Expr = DbgDeclare->getExpression(); + DbgVariables.push_back(DbgAssign); + for (DbgVariableIntrinsic *DbgVariable : DbgVariables) { + auto *Expr = DbgVariable->getExpression(); DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedValue(); @@ -4905,7 +4846,7 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } // The alloca may be larger than the variable. - auto VarSize = DbgDeclare->getVariable()->getSizeInBits(); + auto VarSize = DbgVariable->getVariable()->getSizeInBits(); if (VarSize) { if (Size > *VarSize) Size = *VarSize; @@ -4925,18 +4866,18 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // Remove any existing intrinsics on the new alloca describing // the variable fragment. 
- for (DbgVariableIntrinsic *OldDII : FindDbgAddrUses(Fragment.Alloca)) { + for (DbgDeclareInst *OldDII : FindDbgDeclareUses(Fragment.Alloca)) { auto SameVariableFragment = [](const DbgVariableIntrinsic *LHS, const DbgVariableIntrinsic *RHS) { return LHS->getVariable() == RHS->getVariable() && LHS->getDebugLoc()->getInlinedAt() == RHS->getDebugLoc()->getInlinedAt(); }; - if (SameVariableFragment(OldDII, DbgDeclare)) + if (SameVariableFragment(OldDII, DbgVariable)) OldDII->eraseFromParent(); } - if (auto *DbgAssign = dyn_cast<DbgAssignIntrinsic>(DbgDeclare)) { + if (auto *DbgAssign = dyn_cast<DbgAssignIntrinsic>(DbgVariable)) { if (!Fragment.Alloca->hasMetadata(LLVMContext::MD_DIAssignID)) { Fragment.Alloca->setMetadata( LLVMContext::MD_DIAssignID, @@ -4950,8 +4891,8 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { LLVM_DEBUG(dbgs() << "Created new assign intrinsic: " << *NewAssign << "\n"); } else { - DIB.insertDeclare(Fragment.Alloca, DbgDeclare->getVariable(), - FragmentExpr, DbgDeclare->getDebugLoc(), &AI); + DIB.insertDeclare(Fragment.Alloca, DbgVariable->getVariable(), + FragmentExpr, DbgVariable->getDebugLoc(), &AI); } } } @@ -4996,8 +4937,9 @@ SROAPass::runOnAlloca(AllocaInst &AI) { // Skip alloca forms that this analysis can't handle. auto *AT = AI.getAllocatedType(); - if (AI.isArrayAllocation() || !AT->isSized() || isa<ScalableVectorType>(AT) || - DL.getTypeAllocSize(AT).getFixedValue() == 0) + TypeSize Size = DL.getTypeAllocSize(AT); + if (AI.isArrayAllocation() || !AT->isSized() || Size.isScalable() || + Size.getFixedValue() == 0) return {Changed, CFGChanged}; // First, split any FCA loads and stores touching this alloca to promote @@ -5074,7 +5016,7 @@ bool SROAPass::deleteDeadInstructions( // not be able to find it. if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { DeletedAllocas.insert(AI); - for (DbgVariableIntrinsic *OldDII : FindDbgAddrUses(AI)) + for (DbgDeclareInst *OldDII : FindDbgDeclareUses(AI)) OldDII->eraseFromParent(); } @@ -5107,8 +5049,13 @@ bool SROAPass::promoteAllocas(Function &F) { NumPromoted += PromotableAllocas.size(); - LLVM_DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); - PromoteMemToReg(PromotableAllocas, DTU->getDomTree(), AC); + if (SROASkipMem2Reg) { + LLVM_DEBUG(dbgs() << "Not promoting allocas with mem2reg!\n"); + } else { + LLVM_DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); + PromoteMemToReg(PromotableAllocas, DTU->getDomTree(), AC); + } + PromotableAllocas.clear(); return true; } @@ -5120,16 +5067,16 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU, DTU = &RunDTU; AC = &RunAC; + const DataLayout &DL = F.getParent()->getDataLayout(); BasicBlock &EntryBB = F.getEntryBlock(); for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end()); I != E; ++I) { if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { - if (isa<ScalableVectorType>(AI->getAllocatedType())) { - if (isAllocaPromotable(AI)) - PromotableAllocas.push_back(AI); - } else { + if (DL.getTypeAllocSize(AI->getAllocatedType()).isScalable() && + isAllocaPromotable(AI)) + PromotableAllocas.push_back(AI); + else Worklist.insert(AI); - } } } @@ -5172,6 +5119,11 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU, if (!Changed) return PreservedAnalyses::all(); + if (isAssignmentTrackingEnabled(*F.getParent())) { + for (auto &BB : F) + RemoveRedundantDbgInstrs(&BB); + } + PreservedAnalyses PA; if (!CFGChanged) PA.preserveSet<CFGAnalyses>(); @@ -5186,8 +5138,9 @@ PreservedAnalyses SROAPass::runImpl(Function 
&F, DominatorTree &RunDT, } PreservedAnalyses SROAPass::run(Function &F, FunctionAnalysisManager &AM) { - return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F), - AM.getResult<AssumptionAnalysis>(F)); + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); + return runImpl(F, DT, AC); } void SROAPass::printPipeline( diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp index 8aee8d140a29..37b032e4d7c7 100644 --- a/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -12,76 +12,38 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" -#include "llvm-c/Initialization.h" -#include "llvm-c/Transforms/Scalar.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/ScopedNoAliasAA.h" -#include "llvm/Analysis/TypeBasedAliasAnalysis.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" -#include "llvm/Transforms/Scalar/GVN.h" -#include "llvm/Transforms/Scalar/Scalarizer.h" -#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" using namespace llvm; /// initializeScalarOptsPasses - Initialize all passes linked into the /// ScalarOpts library. void llvm::initializeScalarOpts(PassRegistry &Registry) { - initializeADCELegacyPassPass(Registry); - initializeBDCELegacyPassPass(Registry); - initializeAlignmentFromAssumptionsPass(Registry); - initializeCallSiteSplittingLegacyPassPass(Registry); initializeConstantHoistingLegacyPassPass(Registry); - initializeCorrelatedValuePropagationPass(Registry); initializeDCELegacyPassPass(Registry); - initializeDivRemPairsLegacyPassPass(Registry); initializeScalarizerLegacyPassPass(Registry); - initializeDSELegacyPassPass(Registry); initializeGuardWideningLegacyPassPass(Registry); initializeLoopGuardWideningLegacyPassPass(Registry); initializeGVNLegacyPassPass(Registry); - initializeNewGVNLegacyPassPass(Registry); initializeEarlyCSELegacyPassPass(Registry); initializeEarlyCSEMemSSALegacyPassPass(Registry); initializeMakeGuardsExplicitLegacyPassPass(Registry); - initializeGVNHoistLegacyPassPass(Registry); - initializeGVNSinkLegacyPassPass(Registry); initializeFlattenCFGLegacyPassPass(Registry); - initializeIRCELegacyPassPass(Registry); - initializeIndVarSimplifyLegacyPassPass(Registry); initializeInferAddressSpacesPass(Registry); initializeInstSimplifyLegacyPassPass(Registry); - initializeJumpThreadingPass(Registry); - initializeDFAJumpThreadingLegacyPassPass(Registry); initializeLegacyLICMPassPass(Registry); initializeLegacyLoopSinkPassPass(Registry); - initializeLoopFuseLegacyPass(Registry); initializeLoopDataPrefetchLegacyPassPass(Registry); - initializeLoopDeletionLegacyPassPass(Registry); - initializeLoopAccessLegacyAnalysisPass(Registry); initializeLoopInstSimplifyLegacyPassPass(Registry); - initializeLoopInterchangeLegacyPassPass(Registry); - initializeLoopFlattenLegacyPassPass(Registry); initializeLoopPredicationLegacyPassPass(Registry); initializeLoopRotateLegacyPassPass(Registry); initializeLoopStrengthReducePass(Registry); - initializeLoopRerollLegacyPassPass(Registry); initializeLoopUnrollPass(Registry); - initializeLoopUnrollAndJamPass(Registry); - initializeWarnMissedTransformationsLegacyPass(Registry); - initializeLoopVersioningLICMLegacyPassPass(Registry); - initializeLoopIdiomRecognizeLegacyPassPass(Registry); initializeLowerAtomicLegacyPassPass(Registry); 
initializeLowerConstantIntrinsicsPass(Registry); initializeLowerExpectIntrinsicPass(Registry); initializeLowerGuardIntrinsicLegacyPassPass(Registry); - initializeLowerMatrixIntrinsicsLegacyPassPass(Registry); - initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(Registry); initializeLowerWidenableConditionLegacyPassPass(Registry); - initializeMemCpyOptLegacyPassPass(Registry); initializeMergeICmpsLegacyPassPass(Registry); initializeMergedLoadStoreMotionLegacyPassPass(Registry); initializeNaryReassociateLegacyPassPass(Registry); @@ -89,9 +51,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeReassociateLegacyPassPass(Registry); initializeRedundantDbgInstEliminationPass(Registry); initializeRegToMemLegacyPass(Registry); - initializeRewriteStatepointsForGCLegacyPassPass(Registry); initializeScalarizeMaskedMemIntrinLegacyPassPass(Registry); - initializeSCCPLegacyPassPass(Registry); initializeSROALegacyPassPass(Registry); initializeCFGSimplifyPassPass(Registry); initializeStructurizeCFGLegacyPassPass(Registry); @@ -102,196 +62,6 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeSeparateConstOffsetFromGEPLegacyPassPass(Registry); initializeSpeculativeExecutionLegacyPassPass(Registry); initializeStraightLineStrengthReduceLegacyPassPass(Registry); - initializePlaceBackedgeSafepointsImplPass(Registry); - initializePlaceSafepointsPass(Registry); - initializeFloat2IntLegacyPassPass(Registry); - initializeLoopDistributeLegacyPass(Registry); - initializeLoopLoadEliminationPass(Registry); + initializePlaceBackedgeSafepointsLegacyPassPass(Registry); initializeLoopSimplifyCFGLegacyPassPass(Registry); - initializeLoopVersioningLegacyPassPass(Registry); -} - -void LLVMAddLoopSimplifyCFGPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopSimplifyCFGPass()); -} - -void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) { - initializeScalarOpts(*unwrap(R)); -} - -void LLVMAddAggressiveDCEPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createAggressiveDCEPass()); -} - -void LLVMAddDCEPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createDeadCodeEliminationPass()); -} - -void LLVMAddBitTrackingDCEPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createBitTrackingDCEPass()); -} - -void LLVMAddAlignmentFromAssumptionsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createAlignmentFromAssumptionsPass()); -} - -void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCFGSimplificationPass()); -} - -void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createDeadStoreEliminationPass()); -} - -void LLVMAddScalarizerPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createScalarizerPass()); -} - -void LLVMAddGVNPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createGVNPass()); -} - -void LLVMAddNewGVNPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createNewGVNPass()); -} - -void LLVMAddMergedLoadStoreMotionPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createMergedLoadStoreMotionPass()); -} - -void LLVMAddIndVarSimplifyPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createIndVarSimplifyPass()); -} - -void LLVMAddInstructionSimplifyPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createInstSimplifyLegacyPass()); -} - -void LLVMAddJumpThreadingPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createJumpThreadingPass()); -} - -void LLVMAddLoopSinkPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopSinkPass()); -} - -void LLVMAddLICMPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLICMPass()); -} - -void 
LLVMAddLoopDeletionPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopDeletionPass()); -} - -void LLVMAddLoopFlattenPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopFlattenPass()); -} - -void LLVMAddLoopIdiomPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopIdiomPass()); -} - -void LLVMAddLoopRotatePass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopRotatePass()); -} - -void LLVMAddLoopRerollPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopRerollPass()); -} - -void LLVMAddLoopUnrollPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopUnrollPass()); -} - -void LLVMAddLoopUnrollAndJamPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopUnrollAndJamPass()); -} - -void LLVMAddLowerAtomicPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLowerAtomicPass()); -} - -void LLVMAddMemCpyOptPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createMemCpyOptPass()); -} - -void LLVMAddPartiallyInlineLibCallsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createPartiallyInlineLibCallsPass()); -} - -void LLVMAddReassociatePass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createReassociatePass()); -} - -void LLVMAddSCCPPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createSCCPPass()); -} - -void LLVMAddScalarReplAggregatesPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createSROAPass()); -} - -void LLVMAddScalarReplAggregatesPassSSA(LLVMPassManagerRef PM) { - unwrap(PM)->add(createSROAPass()); -} - -void LLVMAddScalarReplAggregatesPassWithThreshold(LLVMPassManagerRef PM, - int Threshold) { - unwrap(PM)->add(createSROAPass()); -} - -void LLVMAddSimplifyLibCallsPass(LLVMPassManagerRef PM) { - // NOTE: The simplify-libcalls pass has been removed. -} - -void LLVMAddTailCallEliminationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createTailCallEliminationPass()); -} - -void LLVMAddDemoteMemoryToRegisterPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createDemoteRegisterToMemoryPass()); -} - -void LLVMAddVerifierPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createVerifierPass()); -} - -void LLVMAddCorrelatedValuePropagationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCorrelatedValuePropagationPass()); -} - -void LLVMAddEarlyCSEPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createEarlyCSEPass(false/*=UseMemorySSA*/)); -} - -void LLVMAddEarlyCSEMemSSAPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createEarlyCSEPass(true/*=UseMemorySSA*/)); -} - -void LLVMAddGVNHoistLegacyPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createGVNHoistPass()); -} - -void LLVMAddTypeBasedAliasAnalysisPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createTypeBasedAAWrapperPass()); -} - -void LLVMAddScopedNoAliasAAPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createScopedNoAliasAAWrapperPass()); -} - -void LLVMAddBasicAliasAnalysisPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createBasicAAWrapperPass()); -} - -void LLVMAddLowerConstantIntrinsicsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLowerConstantIntrinsicsPass()); -} - -void LLVMAddLowerExpectIntrinsicPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLowerExpectIntrinsicPass()); -} - -void LLVMAddUnifyFunctionExitNodesPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createUnifyFunctionExitNodesPass()); } diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index 1c8e4e3512dc..c01d03f64472 100644 --- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -125,7 
+125,7 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, // br label %else // // else: ; preds = %0, %cond.load -// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ] +// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ] // %6 = extractelement <16 x i1> %mask, i32 1 // br i1 %6, label %cond.load1, label %else2 // @@ -170,10 +170,6 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI, // Adjust alignment for the scalar instruction. const Align AdjustedAlignVal = commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); - // Bitcast %addr from i8* to EltTy* - Type *NewPtrType = - EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); - Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); // The result vector @@ -183,7 +179,7 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI, for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) continue; - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx); LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); VResult = Builder.CreateInsertElement(VResult, Load, Idx); } @@ -232,7 +228,7 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI, CondBlock->setName("cond.load"); Builder.SetInsertPoint(CondBlock->getTerminator()); - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx); LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); @@ -309,10 +305,6 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, // Adjust alignment for the scalar instruction. 
const Align AdjustedAlignVal = commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); - // Bitcast %addr from i8* to EltTy* - Type *NewPtrType = - EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); - Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); if (isConstantIntVector(Mask)) { @@ -320,7 +312,7 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) continue; Value *OneElt = Builder.CreateExtractElement(Src, Idx); - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx); Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); } CI->eraseFromParent(); @@ -367,7 +359,7 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, Builder.SetInsertPoint(CondBlock->getTerminator()); Value *OneElt = Builder.CreateExtractElement(Src, Idx); - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx); Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); // Create "else" block, fill it in the next iteration @@ -394,11 +386,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, // cond.load: // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 // %Load0 = load i32, i32* %Ptr0, align 4 -// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0 +// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0 // br label %else // // else: -// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0] +// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0] // %Mask1 = extractelement <16 x i1> %Mask, i32 1 // br i1 %Mask1, label %cond.load1, label %else2 // @@ -653,16 +645,16 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, Value *VResult = PassThru; // Shorten the way if the mask is a vector of constants. - // Create a build_vector pattern, with loads/undefs as necessary and then + // Create a build_vector pattern, with loads/poisons as necessary and then // shuffle blend with the pass through value. if (isConstantIntVector(Mask)) { unsigned MemIndex = 0; VResult = PoisonValue::get(VecType); - SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem); + SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem); for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { Value *InsertElt; if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) { - InsertElt = UndefValue::get(EltTy); + InsertElt = PoisonValue::get(EltTy); ShuffleMask[Idx] = Idx + VectorWidth; } else { Value *NewPtr = diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 4aab88b74f10..86b55dfd304a 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -6,8 +6,9 @@ // //===----------------------------------------------------------------------===// // -// This pass converts vector operations into scalar operations, in order -// to expose optimization opportunities on the individual scalar operations. +// This pass converts vector operations into scalar operations (or, optionally, +// operations on smaller vector widths), in order to expose optimization +// opportunities on the individual scalar operations. 
// It is mainly intended for targets that do not have vector units, but it // may also be useful for revectorizing code to different vector widths. // @@ -62,6 +63,16 @@ static cl::opt<bool> ClScalarizeLoadStore( "scalarize-load-store", cl::init(false), cl::Hidden, cl::desc("Allow the scalarizer pass to scalarize loads and store")); +// Split vectors larger than this size into fragments, where each fragment is +// either a vector no larger than this size or a scalar. +// +// Instructions with operands or results of different sizes that would be split +// into a different number of fragments are currently left as-is. +static cl::opt<unsigned> ClScalarizeMinBits( + "scalarize-min-bits", cl::init(0), cl::Hidden, + cl::desc("Instruct the scalarizer pass to attempt to keep values of a " + "minimum number of bits")); + namespace { BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) { @@ -88,6 +99,29 @@ using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>; // along with a pointer to their scattered forms. using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>; +struct VectorSplit { + // The type of the vector. + FixedVectorType *VecTy = nullptr; + + // The number of elements packed in a fragment (other than the remainder). + unsigned NumPacked = 0; + + // The number of fragments (scalars or smaller vectors) into which the vector + // shall be split. + unsigned NumFragments = 0; + + // The type of each complete fragment. + Type *SplitTy = nullptr; + + // The type of the remainder (last) fragment; null if all fragments are + // complete. + Type *RemainderTy = nullptr; + + Type *getFragmentType(unsigned I) const { + return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy; + } +}; + // Provides a very limited vector-like interface for lazily accessing one // component of a scattered vector or vector pointer. class Scatterer { @@ -97,23 +131,23 @@ public: // Scatter V into Size components. If new instructions are needed, // insert them before BBI in BB. If Cache is nonnull, use it to cache // the results. - Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, Type *PtrElemTy, - ValueVector *cachePtr = nullptr); + Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, + const VectorSplit &VS, ValueVector *cachePtr = nullptr); // Return component I, creating a new Value for it if necessary. Value *operator[](unsigned I); // Return the number of components. - unsigned size() const { return Size; } + unsigned size() const { return VS.NumFragments; } private: BasicBlock *BB; BasicBlock::iterator BBI; Value *V; - Type *PtrElemTy; + VectorSplit VS; + bool IsPointer; ValueVector *CachePtr; ValueVector Tmp; - unsigned Size; }; // FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp @@ -171,24 +205,74 @@ struct BinarySplitter { struct VectorLayout { VectorLayout() = default; - // Return the alignment of element I. - Align getElemAlign(unsigned I) { - return commonAlignment(VecAlign, I * ElemSize); + // Return the alignment of fragment Frag. + Align getFragmentAlign(unsigned Frag) { + return commonAlignment(VecAlign, Frag * SplitSize); } - // The type of the vector. - FixedVectorType *VecTy = nullptr; - - // The type of each element. - Type *ElemTy = nullptr; + // The split of the underlying vector type. + VectorSplit VS; // The alignment of the vector. Align VecAlign; - // The size of each element. - uint64_t ElemSize = 0; + // The size of each (non-remainder) fragment in bytes. 
+ uint64_t SplitSize = 0; }; +/// Concatenate the given fragments to a single vector value of the type +/// described in @p VS. +static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments, + const VectorSplit &VS, Twine Name) { + unsigned NumElements = VS.VecTy->getNumElements(); + SmallVector<int> ExtendMask; + SmallVector<int> InsertMask; + + if (VS.NumPacked > 1) { + // Prepare the shufflevector masks once and re-use them for all + // fragments. + ExtendMask.resize(NumElements, -1); + for (unsigned I = 0; I < VS.NumPacked; ++I) + ExtendMask[I] = I; + + InsertMask.resize(NumElements); + for (unsigned I = 0; I < NumElements; ++I) + InsertMask[I] = I; + } + + Value *Res = PoisonValue::get(VS.VecTy); + for (unsigned I = 0; I < VS.NumFragments; ++I) { + Value *Fragment = Fragments[I]; + + unsigned NumPacked = VS.NumPacked; + if (I == VS.NumFragments - 1 && VS.RemainderTy) { + if (auto *RemVecTy = dyn_cast<FixedVectorType>(VS.RemainderTy)) + NumPacked = RemVecTy->getNumElements(); + else + NumPacked = 1; + } + + if (NumPacked == 1) { + Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked, + Name + ".upto" + Twine(I)); + } else { + Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask); + if (I == 0) { + Res = Fragment; + } else { + for (unsigned J = 0; J < NumPacked; ++J) + InsertMask[I * VS.NumPacked + J] = NumElements + J; + Res = Builder.CreateShuffleVector(Res, Fragment, InsertMask, + Name + ".upto" + Twine(I)); + for (unsigned J = 0; J < NumPacked; ++J) + InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J; + } + } + } + + return Res; +} + template <typename T> T getWithDefaultOverride(const cl::opt<T> &ClOption, const std::optional<T> &DefaultOverride) { @@ -205,8 +289,9 @@ public: getWithDefaultOverride(ClScalarizeVariableInsertExtract, Options.ScalarizeVariableInsertExtract)), ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore, - Options.ScalarizeLoadStore)) { - } + Options.ScalarizeLoadStore)), + ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits, + Options.ScalarizeMinBits)) {} bool visit(Function &F); @@ -228,13 +313,15 @@ public: bool visitLoadInst(LoadInst &LI); bool visitStoreInst(StoreInst &SI); bool visitCallInst(CallInst &ICI); + bool visitFreezeInst(FreezeInst &FI); private: - Scatterer scatter(Instruction *Point, Value *V, Type *PtrElemTy = nullptr); - void gather(Instruction *Op, const ValueVector &CV); + Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS); + void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS); void replaceUses(Instruction *Op, Value *CV); bool canTransferMetadata(unsigned Kind); void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV); + std::optional<VectorSplit> getVectorSplit(Type *Ty); std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment, const DataLayout &DL); bool finish(); @@ -256,6 +343,7 @@ private: const bool ScalarizeVariableInsertExtract; const bool ScalarizeLoadStore; + const unsigned ScalarizeMinBits; }; class ScalarizerLegacyPass : public FunctionPass { @@ -284,42 +372,47 @@ INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer", "Scalarize vector operations", false, false) Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, - Type *PtrElemTy, ValueVector *cachePtr) - : BB(bb), BBI(bbi), V(v), PtrElemTy(PtrElemTy), CachePtr(cachePtr) { - Type *Ty = V->getType(); - if (Ty->isPointerTy()) { - assert(cast<PointerType>(Ty)->isOpaqueOrPointeeTypeMatches(PtrElemTy) && - "Pointer 
element type mismatch"); - Ty = PtrElemTy; + const VectorSplit &VS, ValueVector *cachePtr) + : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) { + IsPointer = V->getType()->isPointerTy(); + if (!CachePtr) { + Tmp.resize(VS.NumFragments, nullptr); + } else { + assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() || + IsPointer) && + "Inconsistent vector sizes"); + if (VS.NumFragments > CachePtr->size()) + CachePtr->resize(VS.NumFragments, nullptr); } - Size = cast<FixedVectorType>(Ty)->getNumElements(); - if (!CachePtr) - Tmp.resize(Size, nullptr); - else if (CachePtr->empty()) - CachePtr->resize(Size, nullptr); - else - assert(Size == CachePtr->size() && "Inconsistent vector sizes"); } -// Return component I, creating a new Value for it if necessary. -Value *Scatterer::operator[](unsigned I) { - ValueVector &CV = (CachePtr ? *CachePtr : Tmp); +// Return fragment Frag, creating a new Value for it if necessary. +Value *Scatterer::operator[](unsigned Frag) { + ValueVector &CV = CachePtr ? *CachePtr : Tmp; // Try to reuse a previous value. - if (CV[I]) - return CV[I]; + if (CV[Frag]) + return CV[Frag]; IRBuilder<> Builder(BB, BBI); - if (PtrElemTy) { - Type *VectorElemTy = cast<VectorType>(PtrElemTy)->getElementType(); - if (!CV[0]) { - Type *NewPtrTy = PointerType::get( - VectorElemTy, V->getType()->getPointerAddressSpace()); - CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0"); - } - if (I != 0) - CV[I] = Builder.CreateConstGEP1_32(VectorElemTy, CV[0], I, - V->getName() + ".i" + Twine(I)); + if (IsPointer) { + if (Frag == 0) + CV[Frag] = V; + else + CV[Frag] = Builder.CreateConstGEP1_32(VS.SplitTy, V, Frag, + V->getName() + ".i" + Twine(Frag)); + return CV[Frag]; + } + + Type *FragmentTy = VS.getFragmentType(Frag); + + if (auto *VecTy = dyn_cast<FixedVectorType>(FragmentTy)) { + SmallVector<int> Mask; + for (unsigned J = 0; J < VecTy->getNumElements(); ++J) + Mask.push_back(Frag * VS.NumPacked + J); + CV[Frag] = + Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), Mask, + V->getName() + ".i" + Twine(Frag)); } else { - // Search through a chain of InsertElementInsts looking for element I. + // Search through a chain of InsertElementInsts looking for element Frag. // Record other elements in the cache. The new V is still suitable // for all uncached indices. while (true) { @@ -331,20 +424,23 @@ Value *Scatterer::operator[](unsigned I) { break; unsigned J = Idx->getZExtValue(); V = Insert->getOperand(0); - if (I == J) { - CV[J] = Insert->getOperand(1); - return CV[J]; - } else if (!CV[J]) { + if (Frag * VS.NumPacked == J) { + CV[Frag] = Insert->getOperand(1); + return CV[Frag]; + } + + if (VS.NumPacked == 1 && !CV[J]) { // Only cache the first entry we find for each index we're not actively // searching for. This prevents us from going too far up the chain and // caching incorrect entries. CV[J] = Insert->getOperand(1); } } - CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I), - V->getName() + ".i" + Twine(I)); + CV[Frag] = Builder.CreateExtractElement(V, Frag * VS.NumPacked, + V->getName() + ".i" + Twine(Frag)); } - return CV[I]; + + return CV[Frag]; } bool ScalarizerLegacyPass::runOnFunction(Function &F) { @@ -386,13 +482,13 @@ bool ScalarizerVisitor::visit(Function &F) { // Return a scattered form of V that can be accessed by Point. V must be a // vector or a pointer to a vector. 
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V, - Type *PtrElemTy) { + const VectorSplit &VS) { if (Argument *VArg = dyn_cast<Argument>(V)) { // Put the scattered form of arguments in the entry block, // so that it can be used everywhere. Function *F = VArg->getParent(); BasicBlock *BB = &F->getEntryBlock(); - return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[{V, PtrElemTy}]); + return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]); } if (Instruction *VOp = dyn_cast<Instruction>(V)) { // When scalarizing PHI nodes we might try to examine/rewrite InsertElement @@ -403,29 +499,30 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V, // need to analyse them further. if (!DT->isReachableFromEntry(VOp->getParent())) return Scatterer(Point->getParent(), Point->getIterator(), - PoisonValue::get(V->getType()), PtrElemTy); + PoisonValue::get(V->getType()), VS); // Put the scattered form of an instruction directly after the // instruction, skipping over PHI nodes and debug intrinsics. BasicBlock *BB = VOp->getParent(); return Scatterer( - BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, - PtrElemTy, &Scattered[{V, PtrElemTy}]); + BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, VS, + &Scattered[{V, VS.SplitTy}]); } // In the fallback case, just put the scattered before Point and // keep the result local to Point. - return Scatterer(Point->getParent(), Point->getIterator(), V, PtrElemTy); + return Scatterer(Point->getParent(), Point->getIterator(), V, VS); } // Replace Op with the gathered form of the components in CV. Defer the // deletion of Op and creation of the gathered form to the end of the pass, // so that we can avoid creating the gathered form if all uses of Op are // replaced with uses of CV. -void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { +void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV, + const VectorSplit &VS) { transferMetadataAndIRFlags(Op, CV); // If we already have a scattered form of Op (created from ExtractElements // of Op itself), replace them with the new form. - ValueVector &SV = Scattered[{Op, nullptr}]; + ValueVector &SV = Scattered[{Op, VS.SplitTy}]; if (!SV.empty()) { for (unsigned I = 0, E = SV.size(); I != E; ++I) { Value *V = SV[I]; @@ -483,23 +580,57 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op, } } +// Determine how Ty is split, if at all. +std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) { + VectorSplit Split; + Split.VecTy = dyn_cast<FixedVectorType>(Ty); + if (!Split.VecTy) + return {}; + + unsigned NumElems = Split.VecTy->getNumElements(); + Type *ElemTy = Split.VecTy->getElementType(); + + if (NumElems == 1 || ElemTy->isPointerTy() || + 2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) { + Split.NumPacked = 1; + Split.NumFragments = NumElems; + Split.SplitTy = ElemTy; + } else { + Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits(); + if (Split.NumPacked >= NumElems) + return {}; + + Split.NumFragments = divideCeil(NumElems, Split.NumPacked); + Split.SplitTy = FixedVectorType::get(ElemTy, Split.NumPacked); + + unsigned RemainderElems = NumElems % Split.NumPacked; + if (RemainderElems > 1) + Split.RemainderTy = FixedVectorType::get(ElemTy, RemainderElems); + else if (RemainderElems == 1) + Split.RemainderTy = ElemTy; + } + + return Split; +} + // Try to fill in Layout from Ty, returning true on success. 
Alignment is // the alignment of the vector, or std::nullopt if the ABI default should be // used. std::optional<VectorLayout> ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment, const DataLayout &DL) { + std::optional<VectorSplit> VS = getVectorSplit(Ty); + if (!VS) + return {}; + VectorLayout Layout; - // Make sure we're dealing with a vector. - Layout.VecTy = dyn_cast<FixedVectorType>(Ty); - if (!Layout.VecTy) - return std::nullopt; - // Check that we're dealing with full-byte elements. - Layout.ElemTy = Layout.VecTy->getElementType(); - if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy)) - return std::nullopt; + Layout.VS = *VS; + // Check that we're dealing with full-byte fragments. + if (!DL.typeSizeEqualsStoreSize(VS->SplitTy) || + (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(VS->RemainderTy))) + return {}; Layout.VecAlign = Alignment; - Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy); + Layout.SplitSize = DL.getTypeStoreSize(VS->SplitTy); return Layout; } @@ -507,19 +638,27 @@ ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment, // to create an instruction like I with operand X and name Name. template<typename Splitter> bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) { - auto *VT = dyn_cast<FixedVectorType>(I.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(I.getType()); + if (!VS) return false; - unsigned NumElems = VT->getNumElements(); + std::optional<VectorSplit> OpVS; + if (I.getOperand(0)->getType() == I.getType()) { + OpVS = VS; + } else { + OpVS = getVectorSplit(I.getOperand(0)->getType()); + if (!OpVS || VS->NumPacked != OpVS->NumPacked) + return false; + } + IRBuilder<> Builder(&I); - Scatterer Op = scatter(&I, I.getOperand(0)); - assert(Op.size() == NumElems && "Mismatched unary operation"); + Scatterer Op = scatter(&I, I.getOperand(0), *OpVS); + assert(Op.size() == VS->NumFragments && "Mismatched unary operation"); ValueVector Res; - Res.resize(NumElems); - for (unsigned Elem = 0; Elem < NumElems; ++Elem) - Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem)); - gather(&I, Res); + Res.resize(VS->NumFragments); + for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) + Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag)); + gather(&I, Res, *VS); return true; } @@ -527,24 +666,32 @@ bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) { // to create an instruction like I with operands X and Y and name Name. 
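splitUnary above keeps the per-fragment loop in one place; each visitor only supplies a callable that builds a single fragment of the result. As a hedged sketch of what that amounts to for the freeze support added later in this patch, written without the Scatterer machinery (freezePerFragment is an invented name and Fragments stands for the already-scattered operand):

#include "llvm/IR/IRBuilder.h"
#include <vector>
using namespace llvm;

static std::vector<Value *> freezePerFragment(Instruction &I,
                                              ArrayRef<Value *> Fragments) {
  IRBuilder<> Builder(&I);
  std::vector<Value *> Res;
  for (unsigned Frag = 0; Frag < Fragments.size(); ++Frag)
    Res.push_back(Builder.CreateFreeze(Fragments[Frag],
                                       I.getName() + ".i" + Twine(Frag)));
  return Res; // the real pass hands this vector to gather()
}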
template<typename Splitter> bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { - auto *VT = dyn_cast<FixedVectorType>(I.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(I.getType()); + if (!VS) return false; - unsigned NumElems = VT->getNumElements(); + std::optional<VectorSplit> OpVS; + if (I.getOperand(0)->getType() == I.getType()) { + OpVS = VS; + } else { + OpVS = getVectorSplit(I.getOperand(0)->getType()); + if (!OpVS || VS->NumPacked != OpVS->NumPacked) + return false; + } + IRBuilder<> Builder(&I); - Scatterer VOp0 = scatter(&I, I.getOperand(0)); - Scatterer VOp1 = scatter(&I, I.getOperand(1)); - assert(VOp0.size() == NumElems && "Mismatched binary operation"); - assert(VOp1.size() == NumElems && "Mismatched binary operation"); + Scatterer VOp0 = scatter(&I, I.getOperand(0), *OpVS); + Scatterer VOp1 = scatter(&I, I.getOperand(1), *OpVS); + assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation"); + assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation"); ValueVector Res; - Res.resize(NumElems); - for (unsigned Elem = 0; Elem < NumElems; ++Elem) { - Value *Op0 = VOp0[Elem]; - Value *Op1 = VOp1[Elem]; - Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem)); + Res.resize(VS->NumFragments); + for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) { + Value *Op0 = VOp0[Frag]; + Value *Op1 = VOp1[Frag]; + Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag)); } - gather(&I, Res); + gather(&I, Res, *VS); return true; } @@ -552,18 +699,11 @@ static bool isTriviallyScalariable(Intrinsic::ID ID) { return isTriviallyVectorizable(ID); } -// All of the current scalarizable intrinsics only have one mangled type. -static Function *getScalarIntrinsicDeclaration(Module *M, - Intrinsic::ID ID, - ArrayRef<Type*> Tys) { - return Intrinsic::getDeclaration(M, ID, Tys); -} - /// If a call to a vector typed intrinsic function, split into a scalar call per /// element if possible for the intrinsic. bool ScalarizerVisitor::splitCall(CallInst &CI) { - auto *VT = dyn_cast<FixedVectorType>(CI.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(CI.getType()); + if (!VS) return false; Function *F = CI.getCalledFunction(); @@ -574,26 +714,41 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID)) return false; - unsigned NumElems = VT->getNumElements(); + // unsigned NumElems = VT->getNumElements(); unsigned NumArgs = CI.arg_size(); ValueVector ScalarOperands(NumArgs); SmallVector<Scatterer, 8> Scattered(NumArgs); - - Scattered.resize(NumArgs); + SmallVector<int> OverloadIdx(NumArgs, -1); SmallVector<llvm::Type *, 3> Tys; - Tys.push_back(VT->getScalarType()); + // Add return type if intrinsic is overloaded on it. + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)) + Tys.push_back(VS->SplitTy); // Assumes that any vector type has the same number of elements as the return // vector type, which is true for all current intrinsics. 
for (unsigned I = 0; I != NumArgs; ++I) { Value *OpI = CI.getOperand(I); - if (OpI->getType()->isVectorTy()) { - Scattered[I] = scatter(&CI, OpI); - assert(Scattered[I].size() == NumElems && "mismatched call operands"); - if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) - Tys.push_back(OpI->getType()->getScalarType()); + if (auto *OpVecTy = dyn_cast<FixedVectorType>(OpI->getType())) { + assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements()); + std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType()); + if (!OpVS || OpVS->NumPacked != VS->NumPacked) { + // The natural split of the operand doesn't match the result. This could + // happen if the vector elements are different and the ScalarizeMinBits + // option is used. + // + // We could in principle handle this case as well, at the cost of + // complicating the scattering machinery to support multiple scattering + // granularities for a single value. + return false; + } + + Scattered[I] = scatter(&CI, OpI, *OpVS); + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) { + OverloadIdx[I] = Tys.size(); + Tys.push_back(OpVS->SplitTy); + } } else { ScalarOperands[I] = OpI; if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) @@ -601,49 +756,67 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { } } - ValueVector Res(NumElems); + ValueVector Res(VS->NumFragments); ValueVector ScalarCallOps(NumArgs); - Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, Tys); + Function *NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys); IRBuilder<> Builder(&CI); // Perform actual scalarization, taking care to preserve any scalar operands. - for (unsigned Elem = 0; Elem < NumElems; ++Elem) { + for (unsigned I = 0; I < VS->NumFragments; ++I) { + bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy; ScalarCallOps.clear(); + if (IsRemainder) + Tys[0] = VS->RemainderTy; + for (unsigned J = 0; J != NumArgs; ++J) { - if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) + if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) { ScalarCallOps.push_back(ScalarOperands[J]); - else - ScalarCallOps.push_back(Scattered[J][Elem]); + } else { + ScalarCallOps.push_back(Scattered[J][I]); + if (IsRemainder && OverloadIdx[J] >= 0) + Tys[OverloadIdx[J]] = Scattered[J][I]->getType(); + } } - Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps, - CI.getName() + ".i" + Twine(Elem)); + if (IsRemainder) + NewIntrin = Intrinsic::getDeclaration(F->getParent(), ID, Tys); + + Res[I] = Builder.CreateCall(NewIntrin, ScalarCallOps, + CI.getName() + ".i" + Twine(I)); } - gather(&CI, Res); + gather(&CI, Res, *VS); return true; } bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) { - auto *VT = dyn_cast<FixedVectorType>(SI.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(SI.getType()); + if (!VS) return false; - unsigned NumElems = VT->getNumElements(); + std::optional<VectorSplit> CondVS; + if (isa<FixedVectorType>(SI.getCondition()->getType())) { + CondVS = getVectorSplit(SI.getCondition()->getType()); + if (!CondVS || CondVS->NumPacked != VS->NumPacked) { + // This happens when ScalarizeMinBits is used. 
+ return false; + } + } + IRBuilder<> Builder(&SI); - Scatterer VOp1 = scatter(&SI, SI.getOperand(1)); - Scatterer VOp2 = scatter(&SI, SI.getOperand(2)); - assert(VOp1.size() == NumElems && "Mismatched select"); - assert(VOp2.size() == NumElems && "Mismatched select"); + Scatterer VOp1 = scatter(&SI, SI.getOperand(1), *VS); + Scatterer VOp2 = scatter(&SI, SI.getOperand(2), *VS); + assert(VOp1.size() == VS->NumFragments && "Mismatched select"); + assert(VOp2.size() == VS->NumFragments && "Mismatched select"); ValueVector Res; - Res.resize(NumElems); + Res.resize(VS->NumFragments); - if (SI.getOperand(0)->getType()->isVectorTy()) { - Scatterer VOp0 = scatter(&SI, SI.getOperand(0)); - assert(VOp0.size() == NumElems && "Mismatched select"); - for (unsigned I = 0; I < NumElems; ++I) { + if (CondVS) { + Scatterer VOp0 = scatter(&SI, SI.getOperand(0), *CondVS); + assert(VOp0.size() == CondVS->NumFragments && "Mismatched select"); + for (unsigned I = 0; I < VS->NumFragments; ++I) { Value *Op0 = VOp0[I]; Value *Op1 = VOp1[I]; Value *Op2 = VOp2[I]; @@ -652,14 +825,14 @@ bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) { } } else { Value *Op0 = SI.getOperand(0); - for (unsigned I = 0; I < NumElems; ++I) { + for (unsigned I = 0; I < VS->NumFragments; ++I) { Value *Op1 = VOp1[I]; Value *Op2 = VOp2[I]; Res[I] = Builder.CreateSelect(Op0, Op1, Op2, SI.getName() + ".i" + Twine(I)); } } - gather(&SI, Res); + gather(&SI, Res, *VS); return true; } @@ -680,146 +853,194 @@ bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) { } bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { - auto *VT = dyn_cast<FixedVectorType>(GEPI.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(GEPI.getType()); + if (!VS) return false; IRBuilder<> Builder(&GEPI); - unsigned NumElems = VT->getNumElements(); unsigned NumIndices = GEPI.getNumIndices(); - // The base pointer might be scalar even if it's a vector GEP. In those cases, - // splat the pointer into a vector value, and scatter that vector. - Value *Op0 = GEPI.getOperand(0); - if (!Op0->getType()->isVectorTy()) - Op0 = Builder.CreateVectorSplat(NumElems, Op0); - Scatterer Base = scatter(&GEPI, Op0); - - SmallVector<Scatterer, 8> Ops; - Ops.resize(NumIndices); - for (unsigned I = 0; I < NumIndices; ++I) { - Value *Op = GEPI.getOperand(I + 1); - - // The indices might be scalars even if it's a vector GEP. In those cases, - // splat the scalar into a vector value, and scatter that vector. - if (!Op->getType()->isVectorTy()) - Op = Builder.CreateVectorSplat(NumElems, Op); - - Ops[I] = scatter(&GEPI, Op); + // The base pointer and indices might be scalar even if it's a vector GEP. + SmallVector<Value *, 8> ScalarOps{1 + NumIndices}; + SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices}; + + for (unsigned I = 0; I < 1 + NumIndices; ++I) { + if (auto *VecTy = + dyn_cast<FixedVectorType>(GEPI.getOperand(I)->getType())) { + std::optional<VectorSplit> OpVS = getVectorSplit(VecTy); + if (!OpVS || OpVS->NumPacked != VS->NumPacked) { + // This can happen when ScalarizeMinBits is used. 
+ return false; + } + ScatterOps[I] = scatter(&GEPI, GEPI.getOperand(I), *OpVS); + } else { + ScalarOps[I] = GEPI.getOperand(I); + } } ValueVector Res; - Res.resize(NumElems); - for (unsigned I = 0; I < NumElems; ++I) { - SmallVector<Value *, 8> Indices; - Indices.resize(NumIndices); - for (unsigned J = 0; J < NumIndices; ++J) - Indices[J] = Ops[J][I]; - Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices, + Res.resize(VS->NumFragments); + for (unsigned I = 0; I < VS->NumFragments; ++I) { + SmallVector<Value *, 8> SplitOps; + SplitOps.resize(1 + NumIndices); + for (unsigned J = 0; J < 1 + NumIndices; ++J) { + if (ScalarOps[J]) + SplitOps[J] = ScalarOps[J]; + else + SplitOps[J] = ScatterOps[J][I]; + } + Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), SplitOps[0], + ArrayRef(SplitOps).drop_front(), GEPI.getName() + ".i" + Twine(I)); if (GEPI.isInBounds()) if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I])) NewGEPI->setIsInBounds(); } - gather(&GEPI, Res); + gather(&GEPI, Res, *VS); return true; } bool ScalarizerVisitor::visitCastInst(CastInst &CI) { - auto *VT = dyn_cast<FixedVectorType>(CI.getDestTy()); - if (!VT) + std::optional<VectorSplit> DestVS = getVectorSplit(CI.getDestTy()); + if (!DestVS) + return false; + + std::optional<VectorSplit> SrcVS = getVectorSplit(CI.getSrcTy()); + if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked) return false; - unsigned NumElems = VT->getNumElements(); IRBuilder<> Builder(&CI); - Scatterer Op0 = scatter(&CI, CI.getOperand(0)); - assert(Op0.size() == NumElems && "Mismatched cast"); + Scatterer Op0 = scatter(&CI, CI.getOperand(0), *SrcVS); + assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast"); ValueVector Res; - Res.resize(NumElems); - for (unsigned I = 0; I < NumElems; ++I) - Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(), - CI.getName() + ".i" + Twine(I)); - gather(&CI, Res); + Res.resize(DestVS->NumFragments); + for (unsigned I = 0; I < DestVS->NumFragments; ++I) + Res[I] = + Builder.CreateCast(CI.getOpcode(), Op0[I], DestVS->getFragmentType(I), + CI.getName() + ".i" + Twine(I)); + gather(&CI, Res, *DestVS); return true; } bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { - auto *DstVT = dyn_cast<FixedVectorType>(BCI.getDestTy()); - auto *SrcVT = dyn_cast<FixedVectorType>(BCI.getSrcTy()); - if (!DstVT || !SrcVT) + std::optional<VectorSplit> DstVS = getVectorSplit(BCI.getDestTy()); + std::optional<VectorSplit> SrcVS = getVectorSplit(BCI.getSrcTy()); + if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy) return false; - unsigned DstNumElems = DstVT->getNumElements(); - unsigned SrcNumElems = SrcVT->getNumElements(); + const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy(); + + // Vectors of pointers are always fully scalarized. 
+ assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1)); + IRBuilder<> Builder(&BCI); - Scatterer Op0 = scatter(&BCI, BCI.getOperand(0)); + Scatterer Op0 = scatter(&BCI, BCI.getOperand(0), *SrcVS); ValueVector Res; - Res.resize(DstNumElems); + Res.resize(DstVS->NumFragments); + + unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits(); + unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits(); - if (DstNumElems == SrcNumElems) { - for (unsigned I = 0; I < DstNumElems; ++I) - Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(), + if (isPointerTy || DstSplitBits == SrcSplitBits) { + assert(DstVS->NumFragments == SrcVS->NumFragments); + for (unsigned I = 0; I < DstVS->NumFragments; ++I) { + Res[I] = Builder.CreateBitCast(Op0[I], DstVS->getFragmentType(I), BCI.getName() + ".i" + Twine(I)); - } else if (DstNumElems > SrcNumElems) { - // <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the - // individual elements to the destination. - unsigned FanOut = DstNumElems / SrcNumElems; - auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut); + } + } else if (SrcSplitBits % DstSplitBits == 0) { + // Convert each source fragment to the same-sized destination vector and + // then scatter the result to the destination. + VectorSplit MidVS; + MidVS.NumPacked = DstVS->NumPacked; + MidVS.NumFragments = SrcSplitBits / DstSplitBits; + MidVS.VecTy = FixedVectorType::get(DstVS->VecTy->getElementType(), + MidVS.NumPacked * MidVS.NumFragments); + MidVS.SplitTy = DstVS->SplitTy; + unsigned ResI = 0; - for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) { - Value *V = Op0[Op0I]; - Instruction *VI; + for (unsigned I = 0; I < SrcVS->NumFragments; ++I) { + Value *V = Op0[I]; + // Look through any existing bitcasts before converting to <N x t2>. // In the best case, the resulting conversion might be a no-op. + Instruction *VI; while ((VI = dyn_cast<Instruction>(V)) && VI->getOpcode() == Instruction::BitCast) V = VI->getOperand(0); - V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast"); - Scatterer Mid = scatter(&BCI, V); - for (unsigned MidI = 0; MidI < FanOut; ++MidI) - Res[ResI++] = Mid[MidI]; + + V = Builder.CreateBitCast(V, MidVS.VecTy, V->getName() + ".cast"); + + Scatterer Mid = scatter(&BCI, V, MidVS); + for (unsigned J = 0; J < MidVS.NumFragments; ++J) + Res[ResI++] = Mid[J]; } - } else { - // <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2. - unsigned FanIn = SrcNumElems / DstNumElems; - auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn); - unsigned Op0I = 0; - for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) { - Value *V = PoisonValue::get(MidTy); - for (unsigned MidI = 0; MidI < FanIn; ++MidI) - V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI), - BCI.getName() + ".i" + Twine(ResI) - + ".upto" + Twine(MidI)); - Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(), - BCI.getName() + ".i" + Twine(ResI)); + } else if (DstSplitBits % SrcSplitBits == 0) { + // Gather enough source fragments to make up a destination fragment and + // then convert to the destination type. 
+ VectorSplit MidVS; + MidVS.NumFragments = DstSplitBits / SrcSplitBits; + MidVS.NumPacked = SrcVS->NumPacked; + MidVS.VecTy = FixedVectorType::get(SrcVS->VecTy->getElementType(), + MidVS.NumPacked * MidVS.NumFragments); + MidVS.SplitTy = SrcVS->SplitTy; + + unsigned SrcI = 0; + SmallVector<Value *, 8> ConcatOps; + ConcatOps.resize(MidVS.NumFragments); + for (unsigned I = 0; I < DstVS->NumFragments; ++I) { + for (unsigned J = 0; J < MidVS.NumFragments; ++J) + ConcatOps[J] = Op0[SrcI++]; + Value *V = concatenate(Builder, ConcatOps, MidVS, + BCI.getName() + ".i" + Twine(I)); + Res[I] = Builder.CreateBitCast(V, DstVS->getFragmentType(I), + BCI.getName() + ".i" + Twine(I)); } + } else { + return false; } - gather(&BCI, Res); + + gather(&BCI, Res, *DstVS); return true; } bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { - auto *VT = dyn_cast<FixedVectorType>(IEI.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(IEI.getType()); + if (!VS) return false; - unsigned NumElems = VT->getNumElements(); IRBuilder<> Builder(&IEI); - Scatterer Op0 = scatter(&IEI, IEI.getOperand(0)); + Scatterer Op0 = scatter(&IEI, IEI.getOperand(0), *VS); Value *NewElt = IEI.getOperand(1); Value *InsIdx = IEI.getOperand(2); ValueVector Res; - Res.resize(NumElems); + Res.resize(VS->NumFragments); if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) { - for (unsigned I = 0; I < NumElems; ++I) - Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I]; + unsigned Idx = CI->getZExtValue(); + unsigned Fragment = Idx / VS->NumPacked; + for (unsigned I = 0; I < VS->NumFragments; ++I) { + if (I == Fragment) { + bool IsPacked = VS->NumPacked > 1; + if (Fragment == VS->NumFragments - 1 && VS->RemainderTy && + !VS->RemainderTy->isVectorTy()) + IsPacked = false; + if (IsPacked) { + Res[I] = + Builder.CreateInsertElement(Op0[I], NewElt, Idx % VS->NumPacked); + } else { + Res[I] = NewElt; + } + } else { + Res[I] = Op0[I]; + } + } } else { - if (!ScalarizeVariableInsertExtract) + // Never split a variable insertelement that isn't fully scalarized. + if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1) return false; - for (unsigned I = 0; I < NumElems; ++I) { + for (unsigned I = 0; I < VS->NumFragments; ++I) { Value *ShouldReplace = Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I), InsIdx->getName() + ".is." 
+ Twine(I)); @@ -829,31 +1050,39 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { } } - gather(&IEI, Res); + gather(&IEI, Res, *VS); return true; } bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { - auto *VT = dyn_cast<FixedVectorType>(EEI.getOperand(0)->getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType()); + if (!VS) return false; - unsigned NumSrcElems = VT->getNumElements(); IRBuilder<> Builder(&EEI); - Scatterer Op0 = scatter(&EEI, EEI.getOperand(0)); + Scatterer Op0 = scatter(&EEI, EEI.getOperand(0), *VS); Value *ExtIdx = EEI.getOperand(1); if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) { - Value *Res = Op0[CI->getValue().getZExtValue()]; + unsigned Idx = CI->getZExtValue(); + unsigned Fragment = Idx / VS->NumPacked; + Value *Res = Op0[Fragment]; + bool IsPacked = VS->NumPacked > 1; + if (Fragment == VS->NumFragments - 1 && VS->RemainderTy && + !VS->RemainderTy->isVectorTy()) + IsPacked = false; + if (IsPacked) + Res = Builder.CreateExtractElement(Res, Idx % VS->NumPacked); replaceUses(&EEI, Res); return true; } - if (!ScalarizeVariableInsertExtract) + // Never split a variable extractelement that isn't fully scalarized. + if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1) return false; - Value *Res = PoisonValue::get(VT->getElementType()); - for (unsigned I = 0; I < NumSrcElems; ++I) { + Value *Res = PoisonValue::get(VS->VecTy->getElementType()); + for (unsigned I = 0; I < VS->NumFragments; ++I) { Value *ShouldExtract = Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I), ExtIdx->getName() + ".is." + Twine(I)); @@ -866,51 +1095,52 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { } bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) { - auto *VT = dyn_cast<FixedVectorType>(SVI.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(SVI.getType()); + std::optional<VectorSplit> VSOp = + getVectorSplit(SVI.getOperand(0)->getType()); + if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1) return false; - unsigned NumElems = VT->getNumElements(); - Scatterer Op0 = scatter(&SVI, SVI.getOperand(0)); - Scatterer Op1 = scatter(&SVI, SVI.getOperand(1)); + Scatterer Op0 = scatter(&SVI, SVI.getOperand(0), *VSOp); + Scatterer Op1 = scatter(&SVI, SVI.getOperand(1), *VSOp); ValueVector Res; - Res.resize(NumElems); + Res.resize(VS->NumFragments); - for (unsigned I = 0; I < NumElems; ++I) { + for (unsigned I = 0; I < VS->NumFragments; ++I) { int Selector = SVI.getMaskValue(I); if (Selector < 0) - Res[I] = UndefValue::get(VT->getElementType()); + Res[I] = PoisonValue::get(VS->VecTy->getElementType()); else if (unsigned(Selector) < Op0.size()) Res[I] = Op0[Selector]; else Res[I] = Op1[Selector - Op0.size()]; } - gather(&SVI, Res); + gather(&SVI, Res, *VS); return true; } bool ScalarizerVisitor::visitPHINode(PHINode &PHI) { - auto *VT = dyn_cast<FixedVectorType>(PHI.getType()); - if (!VT) + std::optional<VectorSplit> VS = getVectorSplit(PHI.getType()); + if (!VS) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); IRBuilder<> Builder(&PHI); ValueVector Res; - Res.resize(NumElems); + Res.resize(VS->NumFragments); unsigned NumOps = PHI.getNumOperands(); - for (unsigned I = 0; I < NumElems; ++I) - Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps, + for (unsigned I = 0; I < VS->NumFragments; ++I) { + Res[I] = Builder.CreatePHI(VS->getFragmentType(I), NumOps, PHI.getName() + ".i" + 
Twine(I)); + } for (unsigned I = 0; I < NumOps; ++I) { - Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I)); + Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I), *VS); BasicBlock *IncomingBlock = PHI.getIncomingBlock(I); - for (unsigned J = 0; J < NumElems; ++J) + for (unsigned J = 0; J < VS->NumFragments; ++J) cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock); } - gather(&PHI, Res); + gather(&PHI, Res, *VS); return true; } @@ -925,17 +1155,17 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { if (!Layout) return false; - unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&LI); - Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), LI.getType()); + Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), Layout->VS); ValueVector Res; - Res.resize(NumElems); + Res.resize(Layout->VS.NumFragments); - for (unsigned I = 0; I < NumElems; ++I) - Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I], - Align(Layout->getElemAlign(I)), + for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) { + Res[I] = Builder.CreateAlignedLoad(Layout->VS.getFragmentType(I), Ptr[I], + Align(Layout->getFragmentAlign(I)), LI.getName() + ".i" + Twine(I)); - gather(&LI, Res); + } + gather(&LI, Res, Layout->VS); return true; } @@ -951,17 +1181,17 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { if (!Layout) return false; - unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&SI); - Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), FullValue->getType()); - Scatterer VVal = scatter(&SI, FullValue); + Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), Layout->VS); + Scatterer VVal = scatter(&SI, FullValue, Layout->VS); ValueVector Stores; - Stores.resize(NumElems); - for (unsigned I = 0; I < NumElems; ++I) { + Stores.resize(Layout->VS.NumFragments); + for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) { Value *Val = VVal[I]; Value *Ptr = VPtr[I]; - Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I)); + Stores[I] = + Builder.CreateAlignedStore(Val, Ptr, Layout->getFragmentAlign(I)); } transferMetadataAndIRFlags(&SI, Stores); return true; @@ -971,6 +1201,12 @@ bool ScalarizerVisitor::visitCallInst(CallInst &CI) { return splitCall(CI); } +bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) { + return splitUnary(FI, [](IRBuilder<> &Builder, Value *Op, const Twine &Name) { + return Builder.CreateFreeze(Op, Name); + }); +} + // Delete the instructions that we scalarized. If a full vector result // is still needed, recreate it using InsertElements. bool ScalarizerVisitor::finish() { @@ -983,17 +1219,19 @@ bool ScalarizerVisitor::finish() { ValueVector &CV = *GMI.second; if (!Op->use_empty()) { // The value is still needed, so recreate it using a series of - // InsertElements. - Value *Res = PoisonValue::get(Op->getType()); + // insertelements and/or shufflevectors. 
+ Value *Res; if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) { BasicBlock *BB = Op->getParent(); - unsigned Count = Ty->getNumElements(); IRBuilder<> Builder(Op); if (isa<PHINode>(Op)) Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); - for (unsigned I = 0; I < Count; ++I) - Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), - Op->getName() + ".upto" + Twine(I)); + + VectorSplit VS = *getVectorSplit(Ty); + assert(VS.NumFragments == CV.size()); + + Res = concatenate(Builder, CV, VS, Op->getName()); + Res->takeName(Op); } else { assert(CV.size() == 1 && Op->getType() == CV[0]->getType()); diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 4fb90bcea4f0..89d0b7c33e0d 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -162,7 +162,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -355,7 +354,6 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); AU.setPreservesCFG(); @@ -374,14 +372,23 @@ private: class SeparateConstOffsetFromGEP { public: SeparateConstOffsetFromGEP( - DominatorTree *DT, ScalarEvolution *SE, LoopInfo *LI, - TargetLibraryInfo *TLI, + DominatorTree *DT, LoopInfo *LI, TargetLibraryInfo *TLI, function_ref<TargetTransformInfo &(Function &)> GetTTI, bool LowerGEP) - : DT(DT), SE(SE), LI(LI), TLI(TLI), GetTTI(GetTTI), LowerGEP(LowerGEP) {} + : DT(DT), LI(LI), TLI(TLI), GetTTI(GetTTI), LowerGEP(LowerGEP) {} bool run(Function &F); private: + /// Track the operands of an add or sub. + using ExprKey = std::pair<Value *, Value *>; + + /// Create a pair for use as a map key for a commutable operation. + static ExprKey createNormalizedCommutablePair(Value *A, Value *B) { + if (A < B) + return {A, B}; + return {B, A}; + } + /// Tries to split the given GEP into a variadic base and a constant offset, /// and returns true if the splitting succeeds. bool splitGEP(GetElementPtrInst *GEP); @@ -428,7 +435,7 @@ private: /// Returns true if the module changes. /// /// Verified in @i32_add in split-gep.ll - bool canonicalizeArrayIndicesToPointerSize(GetElementPtrInst *GEP); + bool canonicalizeArrayIndicesToIndexSize(GetElementPtrInst *GEP); /// Optimize sext(a)+sext(b) to sext(a+b) when a+b can't sign overflow. /// SeparateConstOffsetFromGEP distributes a sext to leaves before extracting @@ -446,8 +453,8 @@ private: /// Find the closest dominator of <Dominatee> that is equivalent to <Key>. Instruction *findClosestMatchingDominator( - const SCEV *Key, Instruction *Dominatee, - DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs); + ExprKey Key, Instruction *Dominatee, + DenseMap<ExprKey, SmallVector<Instruction *, 2>> &DominatingExprs); /// Verify F is free of dead code. void verifyNoDeadCode(Function &F); @@ -463,7 +470,6 @@ private: const DataLayout *DL = nullptr; DominatorTree *DT = nullptr; - ScalarEvolution *SE; LoopInfo *LI; TargetLibraryInfo *TLI; // Retrieved lazily since not always used. 
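The switch from SCEV-based keys to plain ExprKey operand pairs relies on createNormalizedCommutablePair above: for a commutative add, the two operands are ordered by pointer value so that sext(a) + sext(b) and sext(b) + sext(a) land in the same DominatingAdds bucket, while subtraction keys keep their operand order because a - b and b - a are different expressions. A hedged standalone sketch of that dedup property (std::map and plain pointers stand in for the pass's DenseMap of Value pointers):

#include <cassert>
#include <map>
#include <utility>
#include <vector>

using ExprKey = std::pair<void *, void *>;

static ExprKey normalized(void *A, void *B) {
  return A < B ? ExprKey{A, B} : ExprKey{B, A}; // same key for (A,B) and (B,A)
}

int main() {
  int a = 0, b = 0; // stand-ins for two llvm::Value pointers
  std::map<ExprKey, std::vector<int>> DominatingAdds;
  DominatingAdds[normalized(&a, &b)].push_back(1);
  DominatingAdds[normalized(&b, &a)].push_back(2);
  assert(DominatingAdds.size() == 1); // both operand orders share one bucket
  return 0;
}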
@@ -473,8 +479,8 @@ private: /// multiple GEPs with a single index. bool LowerGEP; - DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingAdds; - DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingSubs; + DenseMap<ExprKey, SmallVector<Instruction *, 2>> DominatingAdds; + DenseMap<ExprKey, SmallVector<Instruction *, 2>> DominatingSubs; }; } // end anonymous namespace @@ -521,6 +527,12 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, !haveNoCommonBitsSet(LHS, RHS, DL, nullptr, BO, DT)) return false; + // FIXME: We don't currently support constants from the RHS of subs, + // when we are zero-extended, because we need a way to zero-extended + // them before they are negated. + if (ZeroExtended && !SignExtended && BO->getOpcode() == Instruction::Sub) + return false; + // In addition, tracing into BO requires that its surrounding s/zext (if // any) is distributable to both operands. // @@ -791,17 +803,17 @@ int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP, .getSExtValue(); } -bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToPointerSize( +bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToIndexSize( GetElementPtrInst *GEP) { bool Changed = false; - Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); + Type *PtrIdxTy = DL->getIndexType(GEP->getType()); gep_type_iterator GTI = gep_type_begin(*GEP); for (User::op_iterator I = GEP->op_begin() + 1, E = GEP->op_end(); I != E; ++I, ++GTI) { // Skip struct member indices which must be i32. if (GTI.isSequential()) { - if ((*I)->getType() != IntPtrTy) { - *I = CastInst::CreateIntegerCast(*I, IntPtrTy, true, "idxprom", GEP); + if ((*I)->getType() != PtrIdxTy) { + *I = CastInst::CreateIntegerCast(*I, PtrIdxTy, true, "idxprom", GEP); Changed = true; } } @@ -849,10 +861,8 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset) { IRBuilder<> Builder(Variadic); - Type *IntPtrTy = DL->getIntPtrType(Variadic->getType()); + Type *PtrIndexTy = DL->getIndexType(Variadic->getType()); - Type *I8PtrTy = - Builder.getInt8PtrTy(Variadic->getType()->getPointerAddressSpace()); Value *ResultPtr = Variadic->getOperand(0); Loop *L = LI->getLoopFor(Variadic->getParent()); // Check if the base is not loop invariant or used more than once. @@ -861,9 +871,6 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( !hasMoreThanOneUseInLoop(ResultPtr, L); Value *FirstResult = nullptr; - if (ResultPtr->getType() != I8PtrTy) - ResultPtr = Builder.CreateBitCast(ResultPtr, I8PtrTy); - gep_type_iterator GTI = gep_type_begin(*Variadic); // Create an ugly GEP for each sequential index. We don't create GEPs for // structure indices, as they are accumulated in the constant offset index. @@ -875,15 +882,16 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( if (CI->isZero()) continue; - APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(), + APInt ElementSize = APInt(PtrIndexTy->getIntegerBitWidth(), DL->getTypeAllocSize(GTI.getIndexedType())); // Scale the index by element size. 
if (ElementSize != 1) { if (ElementSize.isPowerOf2()) { Idx = Builder.CreateShl( - Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2())); + Idx, ConstantInt::get(PtrIndexTy, ElementSize.logBase2())); } else { - Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize)); + Idx = + Builder.CreateMul(Idx, ConstantInt::get(PtrIndexTy, ElementSize)); } } // Create an ugly GEP with a single index for each index. @@ -896,7 +904,7 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( // Create a GEP with the constant offset index. if (AccumulativeByteOffset != 0) { - Value *Offset = ConstantInt::get(IntPtrTy, AccumulativeByteOffset); + Value *Offset = ConstantInt::get(PtrIndexTy, AccumulativeByteOffset); ResultPtr = Builder.CreateGEP(Builder.getInt8Ty(), ResultPtr, Offset, "uglygep"); } else @@ -910,9 +918,6 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( if (isSwapCandidate && isLegalToSwapOperand(FirstGEP, SecondGEP, L)) swapGEPOperand(FirstGEP, SecondGEP); - if (ResultPtr->getType() != Variadic->getType()) - ResultPtr = Builder.CreateBitCast(ResultPtr, Variadic->getType()); - Variadic->replaceAllUsesWith(ResultPtr); Variadic->eraseFromParent(); } @@ -922,6 +927,9 @@ SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset) { IRBuilder<> Builder(Variadic); Type *IntPtrTy = DL->getIntPtrType(Variadic->getType()); + assert(IntPtrTy == DL->getIndexType(Variadic->getType()) && + "Pointer type must match index type for arithmetic-based lowering of " + "split GEPs"); Value *ResultPtr = Builder.CreatePtrToInt(Variadic->getOperand(0), IntPtrTy); gep_type_iterator GTI = gep_type_begin(*Variadic); @@ -973,7 +981,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { if (GEP->hasAllConstantIndices()) return false; - bool Changed = canonicalizeArrayIndicesToPointerSize(GEP); + bool Changed = canonicalizeArrayIndicesToIndexSize(GEP); bool NeedsExtraction; int64_t AccumulativeByteOffset = accumulateByteOffset(GEP, NeedsExtraction); @@ -1057,7 +1065,15 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { if (LowerGEP) { // As currently BasicAA does not analyze ptrtoint/inttoptr, do not lower to // arithmetic operations if the target uses alias analysis in codegen. - if (TTI.useAA()) + // Additionally, pointers that aren't integral (and so can't be safely + // converted to integers) or those whose offset size is different from their + // pointer size (which means that doing integer arithmetic on them could + // affect that data) can't be lowered in this way. + unsigned AddrSpace = GEP->getPointerAddressSpace(); + bool PointerHasExtraData = DL->getPointerSizeInBits(AddrSpace) != + DL->getIndexSizeInBits(AddrSpace); + if (TTI.useAA() || DL->isNonIntegralAddressSpace(AddrSpace) || + PointerHasExtraData) lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset); else lowerToArithmetics(GEP, AccumulativeByteOffset); @@ -1104,13 +1120,13 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // used with unsigned integers later. int64_t ElementTypeSizeOfGEP = static_cast<int64_t>( DL->getTypeAllocSize(GEP->getResultElementType())); - Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); + Type *PtrIdxTy = DL->getIndexType(GEP->getType()); if (AccumulativeByteOffset % ElementTypeSizeOfGEP == 0) { // Very likely. As long as %gep is naturally aligned, the byte offset we // extracted should be a multiple of sizeof(*%gep). 
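The new PointerHasExtraData check in splitGEP above compares the pointer width with the index width of the address space and keeps the GEP-based lowering whenever they differ. A hedged, self-contained illustration of a data layout where that is the case (the address-space number and layout string are invented for this sketch; fat-pointer and capability targets are the kind of configuration this guards against):

#include "llvm/IR/DataLayout.h"
using namespace llvm;

int main() {
  // Address space 200: 128-bit pointers whose GEP index type is only 64 bits.
  DataLayout DL("e-p200:128:128:128:64");
  bool PointerHasExtraData =
      DL.getPointerSizeInBits(200) != DL.getIndexSizeInBits(200); // 128 != 64
  return PointerHasExtraData ? 0 : 1;
}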
int64_t Index = AccumulativeByteOffset / ElementTypeSizeOfGEP; NewGEP = GetElementPtrInst::Create(GEP->getResultElementType(), NewGEP, - ConstantInt::get(IntPtrTy, Index, true), + ConstantInt::get(PtrIdxTy, Index, true), GEP->getName(), GEP); NewGEP->copyMetadata(*GEP); // Inherit the inbounds attribute of the original GEP. @@ -1131,16 +1147,11 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // // Emit an uglygep in this case. IRBuilder<> Builder(GEP); - Type *I8PtrTy = - Builder.getInt8Ty()->getPointerTo(GEP->getPointerAddressSpace()); - NewGEP = cast<Instruction>(Builder.CreateGEP( - Builder.getInt8Ty(), Builder.CreateBitCast(NewGEP, I8PtrTy), - {ConstantInt::get(IntPtrTy, AccumulativeByteOffset, true)}, "uglygep", + Builder.getInt8Ty(), NewGEP, + {ConstantInt::get(PtrIdxTy, AccumulativeByteOffset, true)}, "uglygep", GEPWasInBounds)); - NewGEP->copyMetadata(*GEP); - NewGEP = cast<Instruction>(Builder.CreateBitCast(NewGEP, GEP->getType())); } GEP->replaceAllUsesWith(NewGEP); @@ -1153,13 +1164,12 @@ bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) return false; auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto GetTTI = [this](Function &F) -> TargetTransformInfo & { return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); }; - SeparateConstOffsetFromGEP Impl(DT, SE, LI, TLI, GetTTI, LowerGEP); + SeparateConstOffsetFromGEP Impl(DT, LI, TLI, GetTTI, LowerGEP); return Impl.run(F); } @@ -1189,8 +1199,8 @@ bool SeparateConstOffsetFromGEP::run(Function &F) { } Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator( - const SCEV *Key, Instruction *Dominatee, - DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs) { + ExprKey Key, Instruction *Dominatee, + DenseMap<ExprKey, SmallVector<Instruction *, 2>> &DominatingExprs) { auto Pos = DominatingExprs.find(Key); if (Pos == DominatingExprs.end()) return nullptr; @@ -1210,7 +1220,7 @@ Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator( } bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { - if (!SE->isSCEVable(I->getType())) + if (!I->getType()->isIntOrIntVectorTy()) return false; // Dom: LHS+RHS @@ -1220,8 +1230,7 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { Value *LHS = nullptr, *RHS = nullptr; if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { if (LHS->getType() == RHS->getType()) { - const SCEV *Key = - SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); + ExprKey Key = createNormalizedCommutablePair(LHS, RHS); if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingAdds)) { Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); NewSExt->takeName(I); @@ -1232,9 +1241,8 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { } } else if (match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { if (LHS->getType() == RHS->getType()) { - const SCEV *Key = - SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); - if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingSubs)) { + if (auto *Dom = + findClosestMatchingDominator({LHS, RHS}, I, DominatingSubs)) { Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); NewSExt->takeName(I); I->replaceAllUsesWith(NewSExt); @@ -1247,16 +1255,12 @@ bool 
SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { // Add I to DominatingExprs if it's an add/sub that can't sign overflow. if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS)))) { if (programUndefinedIfPoison(I)) { - const SCEV *Key = - SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); + ExprKey Key = createNormalizedCommutablePair(LHS, RHS); DominatingAdds[Key].push_back(I); } } else if (match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) { - if (programUndefinedIfPoison(I)) { - const SCEV *Key = - SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); - DominatingSubs[Key].push_back(I); - } + if (programUndefinedIfPoison(I)) + DominatingSubs[{LHS, RHS}].push_back(I); } return false; } @@ -1376,16 +1380,25 @@ void SeparateConstOffsetFromGEP::swapGEPOperand(GetElementPtrInst *First, First->setIsInBounds(true); } +void SeparateConstOffsetFromGEPPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<SeparateConstOffsetFromGEPPass> *>(this) + ->printPipeline(OS, MapClassName2PassName); + OS << '<'; + if (LowerGEP) + OS << "lower-gep"; + OS << '>'; +} + PreservedAnalyses SeparateConstOffsetFromGEPPass::run(Function &F, FunctionAnalysisManager &AM) { auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); - auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); auto *LI = &AM.getResult<LoopAnalysis>(F); auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F); auto GetTTI = [&AM](Function &F) -> TargetTransformInfo & { return AM.getResult<TargetIRAnalysis>(F); }; - SeparateConstOffsetFromGEP Impl(DT, SE, LI, TLI, GetTTI, LowerGEP); + SeparateConstOffsetFromGEP Impl(DT, LI, TLI, GetTTI, LowerGEP); if (!Impl.run(F)) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 7e08120f923d..ad7d34b61470 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" @@ -42,6 +43,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" @@ -73,11 +75,14 @@ using namespace llvm::PatternMatch; STATISTIC(NumBranches, "Number of branches unswitched"); STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumSelects, "Number of selects turned into branches for unswitching"); STATISTIC(NumGuards, "Number of guards turned into branches for unswitching"); STATISTIC(NumTrivial, "Number of unswitches that are trivial"); STATISTIC( NumCostMultiplierSkipped, "Number of unswitch candidates that had their cost multiplier skipped"); +STATISTIC(NumInvariantConditionsInjected, + "Number of invariant conditions injected and unswitched"); static cl::opt<bool> EnableNonTrivialUnswitch( "enable-nontrivial-unswitch", cl::init(false), cl::Hidden, @@ -118,15 +123,53 @@ static cl::opt<bool> FreezeLoopUnswitchCond( cl::desc("If enabled, the freeze instruction will be added to condition " "of loop unswitch to prevent miscompilation.")); +static cl::opt<bool> InjectInvariantConditions( + 
"simple-loop-unswitch-inject-invariant-conditions", cl::Hidden, + cl::desc("Whether we should inject new invariants and unswitch them to " + "eliminate some existing (non-invariant) conditions."), + cl::init(true)); + +static cl::opt<unsigned> InjectInvariantConditionHotnesThreshold( + "simple-loop-unswitch-inject-invariant-condition-hotness-threshold", + cl::Hidden, cl::desc("Only try to inject loop invariant conditions and " + "unswitch on them to eliminate branches that are " + "not-taken 1/<this option> times or less."), + cl::init(16)); + namespace { +struct CompareDesc { + BranchInst *Term; + Value *Invariant; + BasicBlock *InLoopSucc; + + CompareDesc(BranchInst *Term, Value *Invariant, BasicBlock *InLoopSucc) + : Term(Term), Invariant(Invariant), InLoopSucc(InLoopSucc) {} +}; + +struct InjectedInvariant { + ICmpInst::Predicate Pred; + Value *LHS; + Value *RHS; + BasicBlock *InLoopSucc; + + InjectedInvariant(ICmpInst::Predicate Pred, Value *LHS, Value *RHS, + BasicBlock *InLoopSucc) + : Pred(Pred), LHS(LHS), RHS(RHS), InLoopSucc(InLoopSucc) {} +}; + struct NonTrivialUnswitchCandidate { Instruction *TI = nullptr; TinyPtrVector<Value *> Invariants; std::optional<InstructionCost> Cost; + std::optional<InjectedInvariant> PendingInjection; NonTrivialUnswitchCandidate( Instruction *TI, ArrayRef<Value *> Invariants, - std::optional<InstructionCost> Cost = std::nullopt) - : TI(TI), Invariants(Invariants), Cost(Cost){}; + std::optional<InstructionCost> Cost = std::nullopt, + std::optional<InjectedInvariant> PendingInjection = std::nullopt) + : TI(TI), Invariants(Invariants), Cost(Cost), + PendingInjection(PendingInjection) {}; + + bool hasPendingInjection() const { return PendingInjection.has_value(); } }; } // end anonymous namespace. @@ -434,10 +477,10 @@ static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader, // Return the top-most loop containing ExitBB and having ExitBB as exiting block // or the loop containing ExitBB, if there is no parent loop containing ExitBB // as exiting block. -static const Loop *getTopMostExitingLoop(const BasicBlock *ExitBB, - const LoopInfo &LI) { - const Loop *TopMost = LI.getLoopFor(ExitBB); - const Loop *Current = TopMost; +static Loop *getTopMostExitingLoop(const BasicBlock *ExitBB, + const LoopInfo &LI) { + Loop *TopMost = LI.getLoopFor(ExitBB); + Loop *Current = TopMost; while (Current) { if (Current->isLoopExiting(ExitBB)) TopMost = Current; @@ -750,15 +793,32 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, Loop *OuterL = &L; if (DefaultExitBB) { - // Clear out the default destination temporarily to allow accurate - // predecessor lists to be examined below. - SI.setDefaultDest(nullptr); // Check the loop containing this exit. - Loop *ExitL = LI.getLoopFor(DefaultExitBB); + Loop *ExitL = getTopMostExitingLoop(DefaultExitBB, LI); + if (!ExitL || ExitL->contains(OuterL)) + OuterL = ExitL; + } + for (unsigned Index : ExitCaseIndices) { + auto CaseI = SI.case_begin() + Index; + // Compute the outer loop from this exit. + Loop *ExitL = getTopMostExitingLoop(CaseI->getCaseSuccessor(), LI); if (!ExitL || ExitL->contains(OuterL)) OuterL = ExitL; } + if (SE) { + if (OuterL) + SE->forgetLoop(OuterL); + else + SE->forgetTopmostLoop(&L); + } + + if (DefaultExitBB) { + // Clear out the default destination temporarily to allow accurate + // predecessor lists to be examined below. + SI.setDefaultDest(nullptr); + } + // Store the exit cases into a separate data structure and remove them from // the switch. 
SmallVector<std::tuple<ConstantInt *, BasicBlock *, @@ -770,10 +830,6 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, // and don't disrupt the earlier indices. for (unsigned Index : reverse(ExitCaseIndices)) { auto CaseI = SI.case_begin() + Index; - // Compute the outer loop from this exit. - Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor()); - if (!ExitL || ExitL->contains(OuterL)) - OuterL = ExitL; // Save the value of this case. auto W = SIW.getSuccessorWeight(CaseI->getSuccessorIndex()); ExitCases.emplace_back(CaseI->getCaseValue(), CaseI->getCaseSuccessor(), W); @@ -781,13 +837,6 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, SIW.removeCase(CaseI); } - if (SE) { - if (OuterL) - SE->forgetLoop(OuterL); - else - SE->forgetTopmostLoop(&L); - } - // Check if after this all of the remaining cases point at the same // successor. BasicBlock *CommonSuccBB = nullptr; @@ -2079,7 +2128,7 @@ static void unswitchNontrivialInvariants( AssumptionCache &AC, function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - function_ref<void(Loop &, StringRef)> DestroyLoopCB) { + function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze) { auto *ParentBB = TI.getParent(); BranchInst *BI = dyn_cast<BranchInst>(&TI); SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI); @@ -2160,7 +2209,9 @@ static void unswitchNontrivialInvariants( SmallVector<BasicBlock *, 4> ExitBlocks; L.getUniqueExitBlocks(ExitBlocks); for (auto *ExitBB : ExitBlocks) { - Loop *NewOuterExitL = LI.getLoopFor(ExitBB); + // ExitBB can be an exit block for several levels in the loop nest. Make + // sure we find the top most. + Loop *NewOuterExitL = getTopMostExitingLoop(ExitBB, LI); if (!NewOuterExitL) { // We exited the entire nest with this block, so we're done. OuterExitL = nullptr; @@ -2181,25 +2232,6 @@ static void unswitchNontrivialInvariants( SE->forgetBlockAndLoopDispositions(); } - bool InsertFreeze = false; - if (FreezeLoopUnswitchCond) { - ICFLoopSafetyInfo SafetyInfo; - SafetyInfo.computeLoopSafetyInfo(&L); - InsertFreeze = !SafetyInfo.isGuaranteedToExecute(TI, &DT, &L); - } - - // Perform the isGuaranteedNotToBeUndefOrPoison() query before the transform, - // otherwise the branch instruction will have been moved outside the loop - // already, and may imply that a poison condition is always UB. - Value *FullUnswitchCond = nullptr; - if (FullUnswitch) { - FullUnswitchCond = - BI ? skipTrivialSelect(BI->getCondition()) : SI->getCondition(); - if (InsertFreeze) - InsertFreeze = !isGuaranteedNotToBeUndefOrPoison( - FullUnswitchCond, &AC, L.getLoopPreheader()->getTerminator(), &DT); - } - // If the edge from this terminator to a successor dominates that successor, // store a map from each block in its dominator subtree to it. 
This lets us // tell when cloning for a particular successor if a block is dominated by @@ -2274,10 +2306,11 @@ static void unswitchNontrivialInvariants( BasicBlock *ClonedPH = ClonedPHs.begin()->second; BI->setSuccessor(ClonedSucc, ClonedPH); BI->setSuccessor(1 - ClonedSucc, LoopPH); + Value *Cond = skipTrivialSelect(BI->getCondition()); if (InsertFreeze) - FullUnswitchCond = new FreezeInst( - FullUnswitchCond, FullUnswitchCond->getName() + ".fr", BI); - BI->setCondition(FullUnswitchCond); + Cond = new FreezeInst( + Cond, Cond->getName() + ".fr", BI); + BI->setCondition(Cond); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); } else { assert(SI && "Must either be a branch or switch!"); @@ -2294,7 +2327,7 @@ static void unswitchNontrivialInvariants( if (InsertFreeze) SI->setCondition(new FreezeInst( - FullUnswitchCond, FullUnswitchCond->getName() + ".fr", SI)); + SI->getCondition(), SI->getCondition()->getName() + ".fr", SI)); // We need to use the set to populate domtree updates as even when there // are multiple cases pointing at the same successor we only want to @@ -2593,6 +2626,57 @@ static InstructionCost computeDomSubtreeCost( return Cost; } +/// Turns a select instruction into implicit control flow branch, +/// making the following replacement: +/// +/// head: +/// --code before select-- +/// select %cond, %trueval, %falseval +/// --code after select-- +/// +/// into +/// +/// head: +/// --code before select-- +/// br i1 %cond, label %then, label %tail +/// +/// then: +/// br %tail +/// +/// tail: +/// phi [ %trueval, %then ], [ %falseval, %head] +/// unreachable +/// +/// It also makes all relevant DT and LI updates, so that all structures are in +/// valid state after this transform. +static BranchInst *turnSelectIntoBranch(SelectInst *SI, DominatorTree &DT, + LoopInfo &LI, MemorySSAUpdater *MSSAU, + AssumptionCache *AC) { + LLVM_DEBUG(dbgs() << "Turning " << *SI << " into a branch.\n"); + BasicBlock *HeadBB = SI->getParent(); + + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + SplitBlockAndInsertIfThen(SI->getCondition(), SI, false, + SI->getMetadata(LLVMContext::MD_prof), &DTU, &LI); + auto *CondBr = cast<BranchInst>(HeadBB->getTerminator()); + BasicBlock *ThenBB = CondBr->getSuccessor(0), + *TailBB = CondBr->getSuccessor(1); + if (MSSAU) + MSSAU->moveAllAfterSpliceBlocks(HeadBB, TailBB, SI); + + PHINode *Phi = PHINode::Create(SI->getType(), 2, "unswitched.select", SI); + Phi->addIncoming(SI->getTrueValue(), ThenBB); + Phi->addIncoming(SI->getFalseValue(), HeadBB); + SI->replaceAllUsesWith(Phi); + SI->eraseFromParent(); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + + ++NumSelects; + return CondBr; +} + /// Turns a llvm.experimental.guard intrinsic into implicit control flow branch, /// making the following replacement: /// @@ -2624,15 +2708,10 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); - // Remove all CheckBB's successors from DomTree. A block can be seen among - // successors more than once, but for DomTree it should be added only once. 
- SmallPtrSet<BasicBlock *, 4> Successors; - for (auto *Succ : successors(CheckBB)) - if (Successors.insert(Succ).second) - DTUpdates.push_back({DominatorTree::Delete, CheckBB, Succ}); - + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); Instruction *DeoptBlockTerm = - SplitBlockAndInsertIfThen(GI->getArgOperand(0), GI, true); + SplitBlockAndInsertIfThen(GI->getArgOperand(0), GI, true, + GI->getMetadata(LLVMContext::MD_prof), &DTU, &LI); BranchInst *CheckBI = cast<BranchInst>(CheckBB->getTerminator()); // SplitBlockAndInsertIfThen inserts control flow that branches to // DeoptBlockTerm if the condition is true. We want the opposite. @@ -2649,20 +2728,6 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, GI->moveBefore(DeoptBlockTerm); GI->setArgOperand(0, ConstantInt::getFalse(GI->getContext())); - // Add new successors of CheckBB into DomTree. - for (auto *Succ : successors(CheckBB)) - DTUpdates.push_back({DominatorTree::Insert, CheckBB, Succ}); - - // Now the blocks that used to be CheckBB's successors are GuardedBlock's - // successors. - for (auto *Succ : Successors) - DTUpdates.push_back({DominatorTree::Insert, GuardedBlock, Succ}); - - // Make proper changes to DT. - DT.applyUpdates(DTUpdates); - // Inform LI of a new loop block. - L.addBasicBlockToLoop(GuardedBlock, LI); - if (MSSAU) { MemoryDef *MD = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(GI)); MSSAU->moveToPlace(MD, DeoptBlock, MemorySSA::BeforeTerminator); @@ -2670,6 +2735,8 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, MSSAU->getMemorySSA()->verifyMemorySSA(); } + if (VerifyLoopInfo) + LI.verify(DT); ++NumGuards; return CheckBI; } @@ -2700,9 +2767,10 @@ static int CalculateUnswitchCostMultiplier( const BasicBlock *CondBlock = TI.getParent(); if (DT.dominates(CondBlock, Latch) && (isGuard(&TI) || - llvm::count_if(successors(&TI), [&L](const BasicBlock *SuccBB) { - return L.contains(SuccBB); - }) <= 1)) { + (TI.isTerminator() && + llvm::count_if(successors(&TI), [&L](const BasicBlock *SuccBB) { + return L.contains(SuccBB); + }) <= 1))) { NumCostMultiplierSkipped++; return 1; } @@ -2711,12 +2779,17 @@ static int CalculateUnswitchCostMultiplier( int SiblingsCount = (ParentL ? ParentL->getSubLoopsVector().size() : std::distance(LI.begin(), LI.end())); // Count amount of clones that all the candidates might cause during - // unswitching. Branch/guard counts as 1, switch counts as log2 of its cases. + // unswitching. Branch/guard/select counts as 1, switch counts as log2 of its + // cases. 
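The clone counting described just above feeds an exponential cost multiplier, so each candidate is weighted by how many copies of the loop unswitching it could create: a branch, guard, or newly supported select can at most double the loop and counts as 1, while a switch that may produce roughly one clone per case contributes the log2 of its case count. A hedged standalone recomputation for one concrete mix of candidates, using llvm::Log2_32 as the pass does:

#include "llvm/Support/MathExtras.h"

int main() {
  unsigned Clones = 0;
  Clones += 1;                 // conditional branch on an invariant
  Clones += 1;                 // guard turned into a branch
  Clones += 1;                 // select turned into a branch
  Clones += llvm::Log2_32(8);  // 8-case switch contributes 3
  return Clones == 6 ? 0 : 1;  // this count scales the unswitch cost threshold
}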
int UnswitchedClones = 0; - for (auto Candidate : UnswitchCandidates) { + for (const auto &Candidate : UnswitchCandidates) { const Instruction *CI = Candidate.TI; const BasicBlock *CondBlock = CI->getParent(); bool SkipExitingSuccessors = DT.dominates(CondBlock, Latch); + if (isa<SelectInst>(CI)) { + UnswitchedClones++; + continue; + } if (isGuard(CI)) { if (!SkipExitingSuccessors) UnswitchedClones++; @@ -2766,6 +2839,24 @@ static bool collectUnswitchCandidates( const Loop &L, const LoopInfo &LI, AAResults &AA, const MemorySSAUpdater *MSSAU) { assert(UnswitchCandidates.empty() && "Should be!"); + + auto AddUnswitchCandidatesForInst = [&](Instruction *I, Value *Cond) { + Cond = skipTrivialSelect(Cond); + if (isa<Constant>(Cond)) + return; + if (L.isLoopInvariant(Cond)) { + UnswitchCandidates.push_back({I, {Cond}}); + return; + } + if (match(Cond, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) { + TinyPtrVector<Value *> Invariants = + collectHomogenousInstGraphLoopInvariants( + L, *static_cast<Instruction *>(Cond), LI); + if (!Invariants.empty()) + UnswitchCandidates.push_back({I, std::move(Invariants)}); + } + }; + // Whether or not we should also collect guards in the loop. bool CollectGuards = false; if (UnswitchGuards) { @@ -2779,15 +2870,20 @@ static bool collectUnswitchCandidates( if (LI.getLoopFor(BB) != &L) continue; - if (CollectGuards) - for (auto &I : *BB) - if (isGuard(&I)) { - auto *Cond = - skipTrivialSelect(cast<IntrinsicInst>(&I)->getArgOperand(0)); - // TODO: Support AND, OR conditions and partial unswitching. - if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond)) - UnswitchCandidates.push_back({&I, {Cond}}); - } + for (auto &I : *BB) { + if (auto *SI = dyn_cast<SelectInst>(&I)) { + auto *Cond = SI->getCondition(); + // Do not unswitch vector selects and logical and/or selects + if (Cond->getType()->isIntegerTy(1) && !SI->getType()->isIntegerTy(1)) + AddUnswitchCandidatesForInst(SI, Cond); + } else if (CollectGuards && isGuard(&I)) { + auto *Cond = + skipTrivialSelect(cast<IntrinsicInst>(&I)->getArgOperand(0)); + // TODO: Support AND, OR conditions and partial unswitching. 
+ if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond)) + UnswitchCandidates.push_back({&I, {Cond}}); + } + } if (auto *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { // We can only consider fully loop-invariant switch conditions as we need @@ -2799,29 +2895,11 @@ static bool collectUnswitchCandidates( } auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); - if (!BI || !BI->isConditional() || isa<Constant>(BI->getCondition()) || + if (!BI || !BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1)) continue; - Value *Cond = skipTrivialSelect(BI->getCondition()); - if (isa<Constant>(Cond)) - continue; - - if (L.isLoopInvariant(Cond)) { - UnswitchCandidates.push_back({BI, {Cond}}); - continue; - } - - Instruction &CondI = *cast<Instruction>(Cond); - if (match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) { - TinyPtrVector<Value *> Invariants = - collectHomogenousInstGraphLoopInvariants(L, CondI, LI); - if (Invariants.empty()) - continue; - - UnswitchCandidates.push_back({BI, std::move(Invariants)}); - continue; - } + AddUnswitchCandidatesForInst(BI, BI->getCondition()); } if (MSSAU && !findOptionMDForLoop(&L, "llvm.loop.unswitch.partial.disable") && @@ -2844,6 +2922,303 @@ static bool collectUnswitchCandidates( return !UnswitchCandidates.empty(); } +/// Tries to canonicalize condition described by: +/// +/// br (LHS pred RHS), label IfTrue, label IfFalse +/// +/// into its equivalent where `Pred` is something that we support for injected +/// invariants (so far it is limited to ult), LHS in canonicalized form is +/// non-invariant and RHS is an invariant. +static void canonicalizeForInvariantConditionInjection( + ICmpInst::Predicate &Pred, Value *&LHS, Value *&RHS, BasicBlock *&IfTrue, + BasicBlock *&IfFalse, const Loop &L) { + if (!L.contains(IfTrue)) { + Pred = ICmpInst::getInversePredicate(Pred); + std::swap(IfTrue, IfFalse); + } + + // Move loop-invariant argument to RHS position. + if (L.isLoopInvariant(LHS)) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(LHS, RHS); + } + + if (Pred == ICmpInst::ICMP_SGE && match(RHS, m_Zero())) { + // Turn "x >=s 0" into "x <u UMIN_INT" + Pred = ICmpInst::ICMP_ULT; + RHS = ConstantInt::get( + RHS->getContext(), + APInt::getSignedMinValue(RHS->getType()->getIntegerBitWidth())); + } +} + +/// Returns true, if predicate described by ( \p Pred, \p LHS, \p RHS ) +/// succeeding into blocks ( \p IfTrue, \p IfFalse) can be optimized by +/// injecting a loop-invariant condition. +static bool shouldTryInjectInvariantCondition( + const ICmpInst::Predicate Pred, const Value *LHS, const Value *RHS, + const BasicBlock *IfTrue, const BasicBlock *IfFalse, const Loop &L) { + if (L.isLoopInvariant(LHS) || !L.isLoopInvariant(RHS)) + return false; + // TODO: Support other predicates. + if (Pred != ICmpInst::ICMP_ULT) + return false; + // TODO: Support non-loop-exiting branches? + if (!L.contains(IfTrue) || L.contains(IfFalse)) + return false; + // FIXME: For some reason this causes problems with MSSA updates, need to + // investigate why. So far, just don't unswitch latch. + if (L.getHeader() == IfTrue) + return false; + return true; +} + +/// Returns true, if metadata on \p BI allows us to optimize branching into \p +/// TakenSucc via injection of invariant conditions. The branch should be not +/// enough and not previously unswitched, the information about this comes from +/// the metadata. 
+bool shouldTryInjectBasingOnMetadata(const BranchInst *BI, + const BasicBlock *TakenSucc) { + // Skip branches that have already been unswithed this way. After successful + // unswitching of injected condition, we will still have a copy of this loop + // which looks exactly the same as original one. To prevent the 2nd attempt + // of unswitching it in the same pass, mark this branch as "nothing to do + // here". + if (BI->hasMetadata("llvm.invariant.condition.injection.disabled")) + return false; + SmallVector<uint32_t> Weights; + if (!extractBranchWeights(*BI, Weights)) + return false; + unsigned T = InjectInvariantConditionHotnesThreshold; + BranchProbability LikelyTaken(T - 1, T); + + assert(Weights.size() == 2 && "Unexpected profile data!"); + size_t Idx = BI->getSuccessor(0) == TakenSucc ? 0 : 1; + auto Num = Weights[Idx]; + auto Denom = Weights[0] + Weights[1]; + // Degenerate or overflowed metadata. + if (Denom == 0 || Num > Denom) + return false; + BranchProbability ActualTaken(Num, Denom); + if (LikelyTaken > ActualTaken) + return false; + return true; +} + +/// Materialize pending invariant condition of the given candidate into IR. The +/// injected loop-invariant condition implies the original loop-variant branch +/// condition, so the materialization turns +/// +/// loop_block: +/// ... +/// br i1 %variant_cond, label InLoopSucc, label OutOfLoopSucc +/// +/// into +/// +/// preheader: +/// %invariant_cond = LHS pred RHS +/// ... +/// loop_block: +/// br i1 %invariant_cond, label InLoopSucc, label OriginalCheck +/// OriginalCheck: +/// br i1 %variant_cond, label InLoopSucc, label OutOfLoopSucc +/// ... +static NonTrivialUnswitchCandidate +injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L, + DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, MemorySSAUpdater *MSSAU) { + assert(Candidate.hasPendingInjection() && "Nothing to inject!"); + BasicBlock *Preheader = L.getLoopPreheader(); + assert(Preheader && "Loop is not in simplified form?"); + assert(LI.getLoopFor(Candidate.TI->getParent()) == &L && + "Unswitching branch of inner loop!"); + + auto Pred = Candidate.PendingInjection->Pred; + auto *LHS = Candidate.PendingInjection->LHS; + auto *RHS = Candidate.PendingInjection->RHS; + auto *InLoopSucc = Candidate.PendingInjection->InLoopSucc; + auto *TI = cast<BranchInst>(Candidate.TI); + auto *BB = Candidate.TI->getParent(); + auto *OutOfLoopSucc = InLoopSucc == TI->getSuccessor(0) ? TI->getSuccessor(1) + : TI->getSuccessor(0); + // FIXME: Remove this once limitation on successors is lifted. + assert(L.contains(InLoopSucc) && "Not supported yet!"); + assert(!L.contains(OutOfLoopSucc) && "Not supported yet!"); + auto &Ctx = BB->getContext(); + + IRBuilder<> Builder(Preheader->getTerminator()); + assert(ICmpInst::isUnsigned(Pred) && "Not supported yet!"); + if (LHS->getType() != RHS->getType()) { + if (LHS->getType()->getIntegerBitWidth() < + RHS->getType()->getIntegerBitWidth()) + LHS = Builder.CreateZExt(LHS, RHS->getType(), LHS->getName() + ".wide"); + else + RHS = Builder.CreateZExt(RHS, LHS->getType(), RHS->getName() + ".wide"); + } + // Do not use builder here: CreateICmp may simplify this into a constant and + // unswitching will break. Better optimize it away later. 
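// Illustrative sketch (standalone arithmetic only, not from this patch): the
// hotness test shouldTryInjectBasingOnMetadata above performs on branch
// weights. The threshold T = 16 is an assumed value here; with it, the taken
// edge must carry at least 15/16 of the recorded weight before condition
// injection is considered.
#include <cstdio>

bool hotEnough(unsigned TakenWeight, unsigned OtherWeight, unsigned T = 16) {
  unsigned Denom = TakenWeight + OtherWeight;
  if (Denom == 0 || TakenWeight > Denom)
    return false;                                  // degenerate metadata
  // Accept iff TakenWeight/Denom >= (T-1)/T, computed without floating point.
  return (unsigned long long)TakenWeight * T >=
         (unsigned long long)(T - 1) * Denom;
}

int main() {
  std::printf("%d %d\n", hotEnough(999, 1), hotEnough(3, 1)); // prints "1 0"
  return 0;
}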
+ auto *InjectedCond = + ICmpInst::Create(Instruction::ICmp, Pred, LHS, RHS, "injected.cond", + Preheader->getTerminator()); + auto *OldCond = TI->getCondition(); + + BasicBlock *CheckBlock = BasicBlock::Create(Ctx, BB->getName() + ".check", + BB->getParent(), InLoopSucc); + Builder.SetInsertPoint(TI); + auto *InvariantBr = + Builder.CreateCondBr(InjectedCond, InLoopSucc, CheckBlock); + + Builder.SetInsertPoint(CheckBlock); + auto *NewTerm = Builder.CreateCondBr(OldCond, InLoopSucc, OutOfLoopSucc); + + TI->eraseFromParent(); + // Prevent infinite unswitching. + NewTerm->setMetadata("llvm.invariant.condition.injection.disabled", + MDNode::get(BB->getContext(), {})); + + // Fixup phis. + for (auto &I : *InLoopSucc) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) + break; + auto *Inc = PN->getIncomingValueForBlock(BB); + PN->addIncoming(Inc, CheckBlock); + } + OutOfLoopSucc->replacePhiUsesWith(BB, CheckBlock); + + SmallVector<DominatorTree::UpdateType, 4> DTUpdates = { + { DominatorTree::Insert, BB, CheckBlock }, + { DominatorTree::Insert, CheckBlock, InLoopSucc }, + { DominatorTree::Insert, CheckBlock, OutOfLoopSucc }, + { DominatorTree::Delete, BB, OutOfLoopSucc } + }; + + DT.applyUpdates(DTUpdates); + if (MSSAU) + MSSAU->applyUpdates(DTUpdates, DT); + L.addBasicBlockToLoop(CheckBlock, LI); + +#ifndef NDEBUG + DT.verify(); + LI.verify(DT); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); +#endif + + // TODO: In fact, cost of unswitching a new invariant candidate is *slightly* + // higher because we have just inserted a new block. Need to think how to + // adjust the cost of injected candidates when it was first computed. + LLVM_DEBUG(dbgs() << "Injected a new loop-invariant branch " << *InvariantBr + << " and considering it for unswitching."); + ++NumInvariantConditionsInjected; + return NonTrivialUnswitchCandidate(InvariantBr, { InjectedCond }, + Candidate.Cost); +} + +/// Given chain of loop branch conditions looking like: +/// br (Variant < Invariant1) +/// br (Variant < Invariant2) +/// br (Variant < Invariant3) +/// ... +/// collect set of invariant conditions on which we want to unswitch, which +/// look like: +/// Invariant1 <= Invariant2 +/// Invariant2 <= Invariant3 +/// ... +/// Though they might not immediately exist in the IR, we can still inject them. +static bool insertCandidatesWithPendingInjections( + SmallVectorImpl<NonTrivialUnswitchCandidate> &UnswitchCandidates, Loop &L, + ICmpInst::Predicate Pred, ArrayRef<CompareDesc> Compares, + const DominatorTree &DT) { + + assert(ICmpInst::isRelational(Pred)); + assert(ICmpInst::isStrictPredicate(Pred)); + if (Compares.size() < 2) + return false; + ICmpInst::Predicate NonStrictPred = ICmpInst::getNonStrictPredicate(Pred); + for (auto Prev = Compares.begin(), Next = Compares.begin() + 1; + Next != Compares.end(); ++Prev, ++Next) { + Value *LHS = Next->Invariant; + Value *RHS = Prev->Invariant; + BasicBlock *InLoopSucc = Prev->InLoopSucc; + InjectedInvariant ToInject(NonStrictPred, LHS, RHS, InLoopSucc); + NonTrivialUnswitchCandidate Candidate(Prev->Term, { LHS, RHS }, + std::nullopt, std::move(ToInject)); + UnswitchCandidates.push_back(std::move(Candidate)); + } + return true; +} + +/// Collect unswitch candidates by invariant conditions that are not immediately +/// present in the loop. However, they can be injected into the code if we +/// decide it's profitable. +/// An example of such conditions is following: +/// +/// for (...) { +/// x = load ... +/// if (! x <u C1) break; +/// if (! 
x <u C2) break; +/// <do something> +/// } +/// +/// We can unswitch by condition "C1 <=u C2". If that is true, then "x <u C1 <= +/// C2" automatically implies "x <u C2", so we can get rid of one of +/// loop-variant checks in unswitched loop version. +static bool collectUnswitchCandidatesWithInjections( + SmallVectorImpl<NonTrivialUnswitchCandidate> &UnswitchCandidates, + IVConditionInfo &PartialIVInfo, Instruction *&PartialIVCondBranch, Loop &L, + const DominatorTree &DT, const LoopInfo &LI, AAResults &AA, + const MemorySSAUpdater *MSSAU) { + if (!InjectInvariantConditions) + return false; + + if (!DT.isReachableFromEntry(L.getHeader())) + return false; + auto *Latch = L.getLoopLatch(); + // Need to have a single latch and a preheader. + if (!Latch) + return false; + assert(L.getLoopPreheader() && "Must have a preheader!"); + + DenseMap<Value *, SmallVector<CompareDesc, 4> > CandidatesULT; + // Traverse the conditions that dominate latch (and therefore dominate each + // other). + for (auto *DTN = DT.getNode(Latch); L.contains(DTN->getBlock()); + DTN = DTN->getIDom()) { + ICmpInst::Predicate Pred; + Value *LHS = nullptr, *RHS = nullptr; + BasicBlock *IfTrue = nullptr, *IfFalse = nullptr; + auto *BB = DTN->getBlock(); + // Ignore inner loops. + if (LI.getLoopFor(BB) != &L) + continue; + auto *Term = BB->getTerminator(); + if (!match(Term, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), + m_BasicBlock(IfTrue), m_BasicBlock(IfFalse)))) + continue; + if (!LHS->getType()->isIntegerTy()) + continue; + canonicalizeForInvariantConditionInjection(Pred, LHS, RHS, IfTrue, IfFalse, + L); + if (!shouldTryInjectInvariantCondition(Pred, LHS, RHS, IfTrue, IfFalse, L)) + continue; + if (!shouldTryInjectBasingOnMetadata(cast<BranchInst>(Term), IfTrue)) + continue; + // Strip ZEXT for unsigned predicate. + // TODO: once signed predicates are supported, also strip SEXT. + CompareDesc Desc(cast<BranchInst>(Term), RHS, IfTrue); + while (auto *Zext = dyn_cast<ZExtInst>(LHS)) + LHS = Zext->getOperand(0); + CandidatesULT[LHS].push_back(Desc); + } + + bool Found = false; + for (auto &It : CandidatesULT) + Found |= insertCandidatesWithPendingInjections( + UnswitchCandidates, L, ICmpInst::ICMP_ULT, It.second, DT); + return Found; +} + static bool isSafeForNoNTrivialUnswitching(Loop &L, LoopInfo &LI) { if (!L.isSafeToClone()) return false; @@ -2943,6 +3318,10 @@ static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate( // cost for that terminator. auto ComputeUnswitchedCost = [&](Instruction &TI, bool FullUnswitch) -> InstructionCost { + // Unswitching selects unswitches the entire loop. + if (isa<SelectInst>(TI)) + return LoopCost; + BasicBlock &BB = *TI.getParent(); SmallPtrSet<BasicBlock *, 4> Visited; @@ -3003,10 +3382,11 @@ static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate( Instruction &TI = *Candidate.TI; ArrayRef<Value *> Invariants = Candidate.Invariants; BranchInst *BI = dyn_cast<BranchInst>(&TI); - InstructionCost CandidateCost = ComputeUnswitchedCost( - TI, /*FullUnswitch*/ !BI || - (Invariants.size() == 1 && - Invariants[0] == skipTrivialSelect(BI->getCondition()))); + bool FullUnswitch = + !BI || Candidate.hasPendingInjection() || + (Invariants.size() == 1 && + Invariants[0] == skipTrivialSelect(BI->getCondition())); + InstructionCost CandidateCost = ComputeUnswitchedCost(TI, FullUnswitch); // Calculate cost multiplier which is a tool to limit potentially // exponential behavior of loop-unswitch. 
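// Illustrative check (added for this review, exhaustive over 8-bit values):
// the implication that makes the injected candidates from
// collectUnswitchCandidatesWithInjections sound. Whenever C1 <=u C2 holds,
// every x with x <u C1 also satisfies x <u C2, so the second loop-variant
// check can be dropped in the unswitched copy.
#include <cassert>

int main() {
  for (unsigned C1 = 0; C1 < 256; ++C1)
    for (unsigned C2 = C1; C2 < 256; ++C2)   // C1 <=u C2
      for (unsigned X = 0; X < 256; ++X)
        if (X < C1)
          assert(X < C2 && "x <u C1 must imply x <u C2");
  return 0;
}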
if (EnableUnswitchCostMultiplier) { @@ -3033,6 +3413,32 @@ static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate( return *Best; } +// Insert a freeze on an unswitched branch if all is true: +// 1. freeze-loop-unswitch-cond option is true +// 2. The branch may not execute in the loop pre-transformation. If a branch may +// not execute and could cause UB, it would always cause UB if it is hoisted outside +// of the loop. Insert a freeze to prevent this case. +// 3. The branch condition may be poison or undef +static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT, + AssumptionCache &AC) { + assert(isa<BranchInst>(TI) || isa<SwitchInst>(TI)); + if (!FreezeLoopUnswitchCond) + return false; + + ICFLoopSafetyInfo SafetyInfo; + SafetyInfo.computeLoopSafetyInfo(&L); + if (SafetyInfo.isGuaranteedToExecute(TI, &DT, &L)) + return false; + + Value *Cond; + if (BranchInst *BI = dyn_cast<BranchInst>(&TI)) + Cond = skipTrivialSelect(BI->getCondition()); + else + Cond = skipTrivialSelect(cast<SwitchInst>(&TI)->getCondition()); + return !isGuaranteedNotToBeUndefOrPoison( + Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT); +} + static bool unswitchBestCondition( Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, AAResults &AA, TargetTransformInfo &TTI, @@ -3044,9 +3450,13 @@ static bool unswitchBestCondition( SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates; IVConditionInfo PartialIVInfo; Instruction *PartialIVCondBranch = nullptr; + collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo, + PartialIVCondBranch, L, LI, AA, MSSAU); + collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo, + PartialIVCondBranch, L, DT, LI, AA, + MSSAU); // If we didn't find any candidates, we're done. - if (!collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo, - PartialIVCondBranch, L, LI, AA, MSSAU)) + if (UnswitchCandidates.empty()) return false; LLVM_DEBUG( @@ -3065,18 +3475,36 @@ static bool unswitchBestCondition( return false; } + if (Best.hasPendingInjection()) + Best = injectPendingInvariantConditions(Best, L, DT, LI, AC, MSSAU); + assert(!Best.hasPendingInjection() && + "All injections should have been done by now!"); + if (Best.TI != PartialIVCondBranch) PartialIVInfo.InstToDuplicate.clear(); - // If the best candidate is a guard, turn it into a branch. - if (isGuard(Best.TI)) - Best.TI = - turnGuardIntoBranch(cast<IntrinsicInst>(Best.TI), L, DT, LI, MSSAU); + bool InsertFreeze; + if (auto *SI = dyn_cast<SelectInst>(Best.TI)) { + // If the best candidate is a select, turn it into a branch. Select + // instructions with a poison conditional do not propagate poison, but + // branching on poison causes UB. Insert a freeze on the select + // conditional to prevent UB after turning the select into a branch. + InsertFreeze = !isGuaranteedNotToBeUndefOrPoison( + SI->getCondition(), &AC, L.getLoopPreheader()->getTerminator(), &DT); + Best.TI = turnSelectIntoBranch(SI, DT, LI, MSSAU, &AC); + } else { + // If the best candidate is a guard, turn it into a branch. 
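// Illustrative sketch (simplified mirror, not from this patch) of the decision
// shouldInsertFreeze encodes above: freeze only when the option is enabled,
// the branch is not already guaranteed to execute on every iteration, and the
// condition cannot be proven free of undef/poison.
bool wouldFreeze(bool FreezeOptionOn, bool GuaranteedToExecute,
                 bool KnownNotUndefOrPoison) {
  if (!FreezeOptionOn)
    return false;                // freeze-loop-unswitch-cond disabled
  if (GuaranteedToExecute)
    return false;                // hoisting the branch adds no new UB
  return !KnownNotUndefOrPoison; // otherwise branching on poison would be UB
}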
+ if (isGuard(Best.TI)) + Best.TI = + turnGuardIntoBranch(cast<IntrinsicInst>(Best.TI), L, DT, LI, MSSAU); + InsertFreeze = shouldInsertFreeze(L, *Best.TI, DT, AC); + } LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << Best.Cost << ") terminator: " << *Best.TI << "\n"); unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT, - LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB); + LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB, + InsertFreeze); return true; } @@ -3124,6 +3552,8 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, return true; } + const Function *F = L.getHeader()->getParent(); + // Check whether we should continue with non-trivial conditions. // EnableNonTrivialUnswitch: Global variable that forces non-trivial // unswitching for testing and debugging. @@ -3136,18 +3566,41 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, // branches even on targets that have divergence. // https://bugs.llvm.org/show_bug.cgi?id=48819 bool ContinueWithNonTrivial = - EnableNonTrivialUnswitch || (NonTrivial && !TTI.hasBranchDivergence()); + EnableNonTrivialUnswitch || (NonTrivial && !TTI.hasBranchDivergence(F)); if (!ContinueWithNonTrivial) return false; // Skip non-trivial unswitching for optsize functions. - if (L.getHeader()->getParent()->hasOptSize()) + if (F->hasOptSize()) return false; - // Skip cold loops, as unswitching them brings little benefit - // but increases the code size - if (PSI && PSI->hasProfileSummary() && BFI && - PSI->isFunctionColdInCallGraph(L.getHeader()->getParent(), *BFI)) { + // Returns true if Loop L's loop nest is cold, i.e. if the headers of L, + // of the loops L is nested in, and of the loops nested in L are all cold. + auto IsLoopNestCold = [&](const Loop *L) { + // Check L and all of its parent loops. + auto *Parent = L; + while (Parent) { + if (!PSI->isColdBlock(Parent->getHeader(), BFI)) + return false; + Parent = Parent->getParentLoop(); + } + // Next check all loops nested within L. + SmallVector<const Loop *, 4> Worklist; + Worklist.insert(Worklist.end(), L->getSubLoops().begin(), + L->getSubLoops().end()); + while (!Worklist.empty()) { + auto *CurLoop = Worklist.pop_back_val(); + if (!PSI->isColdBlock(CurLoop->getHeader(), BFI)) + return false; + Worklist.insert(Worklist.end(), CurLoop->getSubLoops().begin(), + CurLoop->getSubLoops().end()); + } + return true; + }; + + // Skip cold loops in cold loop nests, as unswitching them brings little + // benefit but increases the code size + if (PSI && PSI->hasProfileSummary() && BFI && IsLoopNestCold(&L)) { LLVM_DEBUG(dbgs() << " Skip cold loop: " << L << "\n"); return false; } @@ -3249,10 +3702,10 @@ void SimpleLoopUnswitchPass::printPipeline( static_cast<PassInfoMixin<SimpleLoopUnswitchPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; OS << (NonTrivial ? "" : "no-") << "nontrivial;"; OS << (Trivial ? "" : "no-") << "trivial"; - OS << ">"; + OS << '>'; } namespace { diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index e014f5d1eb04..7017f6adf3a2 100644 --- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -121,7 +121,7 @@ performBlockTailMerging(Function &F, ArrayRef<BasicBlock *> BBs, // Now, go through each block (with the current terminator type) // we've recorded, and rewrite it to branch to the new common block. 
- const DILocation *CommonDebugLoc = nullptr; + DILocation *CommonDebugLoc = nullptr; for (BasicBlock *BB : BBs) { auto *Term = BB->getTerminator(); assert(Term->getOpcode() == CanonicalTerm->getOpcode() && @@ -228,8 +228,8 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 32> Edges; FindFunctionBackedges(F, Edges); SmallPtrSet<BasicBlock *, 16> UniqueLoopHeaders; - for (unsigned i = 0, e = Edges.size(); i != e; ++i) - UniqueLoopHeaders.insert(const_cast<BasicBlock *>(Edges[i].second)); + for (const auto &Edge : Edges) + UniqueLoopHeaders.insert(const_cast<BasicBlock *>(Edge.second)); SmallVector<WeakVH, 16> LoopHeaders(UniqueLoopHeaders.begin(), UniqueLoopHeaders.end()); @@ -338,8 +338,8 @@ void SimplifyCFGPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<SimplifyCFGPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; - OS << "bonus-inst-threshold=" << Options.BonusInstThreshold << ";"; + OS << '<'; + OS << "bonus-inst-threshold=" << Options.BonusInstThreshold << ';'; OS << (Options.ForwardSwitchCondToPhi ? "" : "no-") << "forward-switch-cond;"; OS << (Options.ConvertSwitchRangeToICmp ? "" : "no-") << "switch-range-to-icmp;"; @@ -347,8 +347,10 @@ void SimplifyCFGPass::printPipeline( << "switch-to-lookup;"; OS << (Options.NeedCanonicalLoop ? "" : "no-") << "keep-loops;"; OS << (Options.HoistCommonInsts ? "" : "no-") << "hoist-common-insts;"; - OS << (Options.SinkCommonInsts ? "" : "no-") << "sink-common-insts"; - OS << ">"; + OS << (Options.SinkCommonInsts ? "" : "no-") << "sink-common-insts;"; + OS << (Options.SpeculateBlocks ? "" : "no-") << "speculate-blocks;"; + OS << (Options.SimplifyCondBranch ? 
"" : "no-") << "simplify-cond-branch"; + OS << '>'; } PreservedAnalyses SimplifyCFGPass::run(Function &F, @@ -358,11 +360,6 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F, DominatorTree *DT = nullptr; if (RequireAndPreserveDomTree) DT = &AM.getResult<DominatorTreeAnalysis>(F); - if (F.hasFnAttribute(Attribute::OptForFuzzing)) { - Options.setSimplifyCondBranch(false).setFoldTwoEntryPHINode(false); - } else { - Options.setSimplifyCondBranch(true).setFoldTwoEntryPHINode(true); - } if (!simplifyFunctionCFG(F, TTI, DT, Options)) return PreservedAnalyses::all(); PreservedAnalyses PA; @@ -395,13 +392,6 @@ struct CFGSimplifyPass : public FunctionPass { DominatorTree *DT = nullptr; if (RequireAndPreserveDomTree) DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - if (F.hasFnAttribute(Attribute::OptForFuzzing)) { - Options.setSimplifyCondBranch(false) - .setFoldTwoEntryPHINode(false); - } else { - Options.setSimplifyCondBranch(true) - .setFoldTwoEntryPHINode(true); - } auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); return simplifyFunctionCFG(F, TTI, DT, Options); diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index 65f8d760ede3..e866fe681127 100644 --- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -152,7 +152,7 @@ bool SpeculativeExecutionLegacyPass::runOnFunction(Function &F) { namespace llvm { bool SpeculativeExecutionPass::runImpl(Function &F, TargetTransformInfo *TTI) { - if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence()) { + if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence(&F)) { LLVM_DEBUG(dbgs() << "Not running SpeculativeExecution because " "TTI->hasBranchDivergence() is false.\n"); return false; diff --git a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 70df0cec0dca..fdb41cb415df 100644 --- a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -484,9 +484,9 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( // = B + (sext(Idx) * sext(S)) * ElementSize // = B + (sext(Idx) * ElementSize) * sext(S) // Casting to IntegerType is safe because we skipped vector GEPs. - IntegerType *IntPtrTy = cast<IntegerType>(DL->getIntPtrType(I->getType())); + IntegerType *PtrIdxTy = cast<IntegerType>(DL->getIndexType(I->getType())); ConstantInt *ScaledIdx = ConstantInt::get( - IntPtrTy, Idx->getSExtValue() * (int64_t)ElementSize, true); + PtrIdxTy, Idx->getSExtValue() * (int64_t)ElementSize, true); allocateCandidatesAndFindBasis(Candidate::GEP, B, ScaledIdx, S, I); } @@ -549,18 +549,18 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( Value *ArrayIdx = GEP->getOperand(I); uint64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType()); if (ArrayIdx->getType()->getIntegerBitWidth() <= - DL->getPointerSizeInBits(GEP->getAddressSpace())) { - // Skip factoring if ArrayIdx is wider than the pointer size, because - // ArrayIdx is implicitly truncated to the pointer size. + DL->getIndexSizeInBits(GEP->getAddressSpace())) { + // Skip factoring if ArrayIdx is wider than the index size, because + // ArrayIdx is implicitly truncated to the index size. factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP); } // When ArrayIdx is the sext of a value, we try to factor that value as // well. 
Handling this case is important because array indices are - // typically sign-extended to the pointer size. + // typically sign-extended to the pointer index size. Value *TruncatedArrayIdx = nullptr; if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) && TruncatedArrayIdx->getType()->getIntegerBitWidth() <= - DL->getPointerSizeInBits(GEP->getAddressSpace())) { + DL->getIndexSizeInBits(GEP->getAddressSpace())) { // Skip factoring if TruncatedArrayIdx is wider than the pointer size, // because TruncatedArrayIdx is implicitly truncated to the pointer size. factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP); @@ -675,24 +675,24 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( } case Candidate::GEP: { - Type *IntPtrTy = DL->getIntPtrType(C.Ins->getType()); - bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds(); - if (BumpWithUglyGEP) { - // C = (char *)Basis + Bump - unsigned AS = Basis.Ins->getType()->getPointerAddressSpace(); - Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS); - Reduced = Builder.CreateBitCast(Basis.Ins, CharTy); - Reduced = - Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump, "", InBounds); - Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType()); - } else { - // C = gep Basis, Bump - // Canonicalize bump to pointer size. - Bump = Builder.CreateSExtOrTrunc(Bump, IntPtrTy); - Reduced = Builder.CreateGEP( - cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), - Basis.Ins, Bump, "", InBounds); - } + Type *OffsetTy = DL->getIndexType(C.Ins->getType()); + bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds(); + if (BumpWithUglyGEP) { + // C = (char *)Basis + Bump + unsigned AS = Basis.Ins->getType()->getPointerAddressSpace(); + Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS); + Reduced = Builder.CreateBitCast(Basis.Ins, CharTy); + Reduced = + Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump, "", InBounds); + Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType()); + } else { + // C = gep Basis, Bump + // Canonicalize bump to pointer size. 
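// Illustrative sketch (assumed example target layout, not from this patch):
// why the code above now consults the index width rather than the pointer
// width. A target may declare pointers wider than the offsets used to index
// them, and GEP indices are implicitly truncated to the index width, so
// factoring an index wider than that would be unsound. Compiles against
// LLVM's Core/Support libraries.
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // 64-bit pointers whose GEP index type is only 32 bits wide.
  llvm::DataLayout DL("e-p:64:64:64:32");
  llvm::outs() << "pointer bits: " << DL.getPointerSizeInBits(0)
               << ", index bits: " << DL.getIndexSizeInBits(0) << "\n";
  return 0;
}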
+ Bump = Builder.CreateSExtOrTrunc(Bump, OffsetTy); + Reduced = Builder.CreateGEP( + cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), Basis.Ins, + Bump, "", InBounds); + } break; } default: diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index 81d151c2904e..fac5695c7bea 100644 --- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -15,10 +15,10 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/LegacyDivergenceAnalysis.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" #include "llvm/Analysis/RegionPass.h" +#include "llvm/Analysis/UniformityAnalysis.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -239,12 +239,12 @@ class StructurizeCFG { Type *Boolean; ConstantInt *BoolTrue; ConstantInt *BoolFalse; - UndefValue *BoolUndef; + Value *BoolPoison; Function *Func; Region *ParentRegion; - LegacyDivergenceAnalysis *DA = nullptr; + UniformityInfo *UA = nullptr; DominatorTree *DT; SmallVector<RegionNode *, 8> Order; @@ -319,7 +319,7 @@ class StructurizeCFG { public: void init(Region *R); bool run(Region *R, DominatorTree *DT); - bool makeUniformRegion(Region *R, LegacyDivergenceAnalysis *DA); + bool makeUniformRegion(Region *R, UniformityInfo &UA); }; class StructurizeCFGLegacyPass : public RegionPass { @@ -339,8 +339,9 @@ public: StructurizeCFG SCFG; SCFG.init(R); if (SkipUniformRegions) { - LegacyDivergenceAnalysis *DA = &getAnalysis<LegacyDivergenceAnalysis>(); - if (SCFG.makeUniformRegion(R, DA)) + UniformityInfo &UA = + getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo(); + if (SCFG.makeUniformRegion(R, UA)) return false; } DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -351,7 +352,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { if (SkipUniformRegions) - AU.addRequired<LegacyDivergenceAnalysis>(); + AU.addRequired<UniformityInfoWrapperPass>(); AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); @@ -366,7 +367,7 @@ char StructurizeCFGLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(StructurizeCFGLegacyPass, "structurizecfg", "Structurize the CFG", false, false) -INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis) +INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) @@ -798,8 +799,6 @@ void StructurizeCFG::killTerminator(BasicBlock *BB) { for (BasicBlock *Succ : successors(BB)) delPhiValues(BB, Succ); - if (DA) - DA->removeValue(Term); Term->eraseFromParent(); } @@ -957,7 +956,7 @@ void StructurizeCFG::wireFlow(bool ExitUseAllowed, BasicBlock *Next = needPostfix(Flow, ExitUseAllowed); // let it point to entry and next block - BranchInst *Br = BranchInst::Create(Entry, Next, BoolUndef, Flow); + BranchInst *Br = BranchInst::Create(Entry, Next, BoolPoison, Flow); Br->setDebugLoc(TermDL[Flow]); Conditions.push_back(Br); addPhiValues(Flow, Entry); @@ -998,7 +997,7 @@ void StructurizeCFG::handleLoops(bool ExitUseAllowed, // Create an extra loop end node LoopEnd = needPrefix(false); BasicBlock *Next = needPostfix(LoopEnd, ExitUseAllowed); - BranchInst *Br = BranchInst::Create(Next, LoopStart, BoolUndef, LoopEnd); + BranchInst *Br = BranchInst::Create(Next, LoopStart, BoolPoison, LoopEnd); 
Br->setDebugLoc(TermDL[LoopEnd]); LoopConds.push_back(Br); addPhiValues(LoopEnd, LoopStart); @@ -1064,7 +1063,7 @@ void StructurizeCFG::rebuildSSA() { } static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, - const LegacyDivergenceAnalysis &DA) { + const UniformityInfo &UA) { // Bool for if all sub-regions are uniform. bool SubRegionsAreUniform = true; // Count of how many direct children are conditional. @@ -1076,7 +1075,7 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, if (!Br || !Br->isConditional()) continue; - if (!DA.isUniform(Br)) + if (!UA.isUniform(Br)) return false; // One of our direct children is conditional. @@ -1086,7 +1085,7 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, << " has uniform terminator\n"); } else { // Explicitly refuse to treat regions as uniform if they have non-uniform - // subregions. We cannot rely on DivergenceAnalysis for branches in + // subregions. We cannot rely on UniformityAnalysis for branches in // subregions because those branches may have been removed and re-created, // so we look for our metadata instead. // @@ -1126,17 +1125,17 @@ void StructurizeCFG::init(Region *R) { Boolean = Type::getInt1Ty(Context); BoolTrue = ConstantInt::getTrue(Context); BoolFalse = ConstantInt::getFalse(Context); - BoolUndef = UndefValue::get(Boolean); + BoolPoison = PoisonValue::get(Boolean); - this->DA = nullptr; + this->UA = nullptr; } -bool StructurizeCFG::makeUniformRegion(Region *R, - LegacyDivergenceAnalysis *DA) { +bool StructurizeCFG::makeUniformRegion(Region *R, UniformityInfo &UA) { if (R->isTopLevelRegion()) return false; - this->DA = DA; + this->UA = &UA; + // TODO: We could probably be smarter here with how we handle sub-regions. // We currently rely on the fact that metadata is set by earlier invocations // of the pass on sub-regions, and that this metadata doesn't get lost -- @@ -1144,7 +1143,7 @@ bool StructurizeCFG::makeUniformRegion(Region *R, unsigned UniformMDKindID = R->getEntry()->getContext().getMDKindID("structurizecfg.uniform"); - if (hasOnlyUniformBranches(R, UniformMDKindID, *DA)) { + if (hasOnlyUniformBranches(R, UniformMDKindID, UA)) { LLVM_DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R << '\n'); diff --git a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp index 9e08954ef643..e53019768e88 100644 --- a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp +++ b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp @@ -13,7 +13,6 @@ #include "llvm/Transforms/Scalar/WarnMissedTransforms.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "llvm/InitializePasses.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -104,47 +103,3 @@ WarnMissedTransformationsPass::run(Function &F, FunctionAnalysisManager &AM) { return PreservedAnalyses::all(); } - -// Legacy pass manager boilerplate -namespace { -class WarnMissedTransformationsLegacy : public FunctionPass { -public: - static char ID; - - explicit WarnMissedTransformationsLegacy() : FunctionPass(ID) { - initializeWarnMissedTransformationsLegacyPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - - warnAboutLeftoverTransformations(&F, &LI, &ORE); - return false; - } 
- - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - - AU.setPreservesAll(); - } -}; -} // end anonymous namespace - -char WarnMissedTransformationsLegacy::ID = 0; - -INITIALIZE_PASS_BEGIN(WarnMissedTransformationsLegacy, "transform-warning", - "Warn about non-applied transformations", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_END(WarnMissedTransformationsLegacy, "transform-warning", - "Warn about non-applied transformations", false, false) - -Pass *llvm::createWarnMissedTransformationsPass() { - return new WarnMissedTransformationsLegacy(); -} diff --git a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp index 24972db404be..2195406c144c 100644 --- a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp +++ b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp @@ -16,7 +16,11 @@ #include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h" #include "llvm/ADT/SparseBitVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Support/DataExtractor.h" +#include "llvm/Support/MD5.h" +#include "llvm/Support/MathExtras.h" using namespace llvm; @@ -179,11 +183,7 @@ static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg, // Scan the format string to locate all specifiers, and mark the ones that // specify a string, i.e, the "%s" specifier with optional '*' characters. -static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) { - StringRef Str; - if (!getConstantStringInfo(Fmt, Str) || Str.empty()) - return; - +static void locateCStrings(SparseBitVector<8> &BV, StringRef Str) { static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn"; size_t SpecPos = 0; // Skip the first argument, the format string. @@ -207,14 +207,320 @@ static void locateCStrings(SparseBitVector<8> &BV, Value *Fmt) { } } -Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, - ArrayRef<Value *> Args) { +// helper struct to package the string related data +struct StringData { + StringRef Str; + Value *RealSize = nullptr; + Value *AlignedSize = nullptr; + bool IsConst = true; + + StringData(StringRef ST, Value *RS, Value *AS, bool IC) + : Str(ST), RealSize(RS), AlignedSize(AS), IsConst(IC) {} +}; + +// Calculates frame size required for current printf expansion and allocates +// space on printf buffer. 
Printf frame includes following contents +// [ ControlDWord , format string/Hash , Arguments (each aligned to 8 byte) ] +static Value *callBufferedPrintfStart( + IRBuilder<> &Builder, ArrayRef<Value *> Args, Value *Fmt, + bool isConstFmtStr, SparseBitVector<8> &SpecIsCString, + SmallVectorImpl<StringData> &StringContents, Value *&ArgSize) { + Module *M = Builder.GetInsertBlock()->getModule(); + Value *NonConstStrLen = nullptr; + Value *LenWithNull = nullptr; + Value *LenWithNullAligned = nullptr; + Value *TempAdd = nullptr; + + // First 4 bytes to be reserved for control dword + size_t BufSize = 4; + if (isConstFmtStr) + // First 8 bytes of MD5 hash + BufSize += 8; + else { + LenWithNull = getStrlenWithNull(Builder, Fmt); + + // Align the computed length to next 8 byte boundary + TempAdd = Builder.CreateAdd(LenWithNull, + ConstantInt::get(LenWithNull->getType(), 7U)); + NonConstStrLen = Builder.CreateAnd( + TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U)); + + StringContents.push_back( + StringData(StringRef(), LenWithNull, NonConstStrLen, false)); + } + + for (size_t i = 1; i < Args.size(); i++) { + if (SpecIsCString.test(i)) { + StringRef ArgStr; + if (getConstantStringInfo(Args[i], ArgStr)) { + auto alignedLen = alignTo(ArgStr.size() + 1, 8); + StringContents.push_back(StringData( + ArgStr, + /*RealSize*/ nullptr, /*AlignedSize*/ nullptr, /*IsConst*/ true)); + BufSize += alignedLen; + } else { + LenWithNull = getStrlenWithNull(Builder, Args[i]); + + // Align the computed length to next 8 byte boundary + TempAdd = Builder.CreateAdd( + LenWithNull, ConstantInt::get(LenWithNull->getType(), 7U)); + LenWithNullAligned = Builder.CreateAnd( + TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U)); + + if (NonConstStrLen) { + auto Val = Builder.CreateAdd(LenWithNullAligned, NonConstStrLen, + "cumulativeAdd"); + NonConstStrLen = Val; + } else + NonConstStrLen = LenWithNullAligned; + + StringContents.push_back( + StringData(StringRef(), LenWithNull, LenWithNullAligned, false)); + } + } else { + int AllocSize = M->getDataLayout().getTypeAllocSize(Args[i]->getType()); + // We end up expanding non string arguments to 8 bytes + // (args smaller than 8 bytes) + BufSize += std::max(AllocSize, 8); + } + } + + // calculate final size value to be passed to printf_alloc + Value *SizeToReserve = ConstantInt::get(Builder.getInt64Ty(), BufSize, false); + SmallVector<Value *, 1> Alloc_args; + if (NonConstStrLen) + SizeToReserve = Builder.CreateAdd(NonConstStrLen, SizeToReserve); + + ArgSize = Builder.CreateTrunc(SizeToReserve, Builder.getInt32Ty()); + Alloc_args.push_back(ArgSize); + + // call the printf_alloc function + AttributeList Attr = AttributeList::get( + Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind); + + Type *Tys_alloc[1] = {Builder.getInt32Ty()}; + Type *I8Ptr = + Builder.getInt8PtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace()); + FunctionType *FTy_alloc = FunctionType::get(I8Ptr, Tys_alloc, false); + auto PrintfAllocFn = + M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr); + + return Builder.CreateCall(PrintfAllocFn, Alloc_args, "printf_alloc_fn"); +} + +// Prepare constant string argument to push onto the buffer +static void processConstantStringArg(StringData *SD, IRBuilder<> &Builder, + SmallVectorImpl<Value *> &WhatToStore) { + std::string Str(SD->Str.str() + '\0'); + + DataExtractor Extractor(Str, /*IsLittleEndian=*/true, 8); + DataExtractor::Cursor Offset(0); + while (Offset && Offset.tell() < Str.size()) { + const uint64_t 
ReadSize = 4; + uint64_t ReadNow = std::min(ReadSize, Str.size() - Offset.tell()); + uint64_t ReadBytes = 0; + switch (ReadNow) { + default: + llvm_unreachable("min(4, X) > 4?"); + case 1: + ReadBytes = Extractor.getU8(Offset); + break; + case 2: + ReadBytes = Extractor.getU16(Offset); + break; + case 3: + ReadBytes = Extractor.getU24(Offset); + break; + case 4: + ReadBytes = Extractor.getU32(Offset); + break; + } + cantFail(Offset.takeError(), "failed to read bytes from constant array"); + + APInt IntVal(8 * ReadSize, ReadBytes); + + // TODO: Should not bother aligning up. + if (ReadNow < ReadSize) + IntVal = IntVal.zext(8 * ReadSize); + + Type *IntTy = Type::getIntNTy(Builder.getContext(), IntVal.getBitWidth()); + WhatToStore.push_back(ConstantInt::get(IntTy, IntVal)); + } + // Additional padding for 8 byte alignment + int Rem = (Str.size() % 8); + if (Rem > 0 && Rem <= 4) + WhatToStore.push_back(ConstantInt::get(Builder.getInt32Ty(), 0)); +} + +static Value *processNonStringArg(Value *Arg, IRBuilder<> &Builder) { + const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout(); + auto Ty = Arg->getType(); + + if (auto IntTy = dyn_cast<IntegerType>(Ty)) { + if (IntTy->getBitWidth() < 64) { + return Builder.CreateZExt(Arg, Builder.getInt64Ty()); + } + } + + if (Ty->isFloatingPointTy()) { + if (DL.getTypeAllocSize(Ty) < 8) { + return Builder.CreateFPExt(Arg, Builder.getDoubleTy()); + } + } + + return Arg; +} + +static void +callBufferedPrintfArgPush(IRBuilder<> &Builder, ArrayRef<Value *> Args, + Value *PtrToStore, SparseBitVector<8> &SpecIsCString, + SmallVectorImpl<StringData> &StringContents, + bool IsConstFmtStr) { + Module *M = Builder.GetInsertBlock()->getModule(); + const DataLayout &DL = M->getDataLayout(); + auto StrIt = StringContents.begin(); + size_t i = IsConstFmtStr ? 1 : 0; + for (; i < Args.size(); i++) { + SmallVector<Value *, 32> WhatToStore; + if ((i == 0) || SpecIsCString.test(i)) { + if (StrIt->IsConst) { + processConstantStringArg(StrIt, Builder, WhatToStore); + StrIt++; + } else { + // This copies the contents of the string, however the next offset + // is at aligned length, the extra space that might be created due + // to alignment padding is not populated with any specific value + // here. This would be safe as long as runtime is sync with + // the offsets. 
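// Illustrative sketch (host-side arithmetic only, not from this patch): the
// frame size callBufferedPrintfStart computes for an assumed call like
// printf("%s = %d\n", "x", 42) when both strings are compile-time constants.
#include <cstdio>

static unsigned long alignTo8(unsigned long N) { return (N + 7) & ~7UL; }

int main() {
  unsigned long Size = 4;   // control dword
  Size += 8;                // low 64 bits of the MD5 hash of the format string
  Size += alignTo8(1 + 1);  // the "%s" argument "x" plus its NUL, 8-byte aligned
  Size += 8;                // the i32 argument 42 is widened to 8 bytes
  std::printf("frame bytes: %lu\n", Size); // 28
  return 0;
}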
+ Builder.CreateMemCpy(PtrToStore, /*DstAlign*/ Align(1), Args[i], + /*SrcAlign*/ Args[i]->getPointerAlignment(DL), + StrIt->RealSize); + + PtrToStore = + Builder.CreateInBoundsGEP(Builder.getInt8Ty(), PtrToStore, + {StrIt->AlignedSize}, "PrintBuffNextPtr"); + LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:" + << *PtrToStore << '\n'); + + // done with current argument, move to next + StrIt++; + continue; + } + } else { + WhatToStore.push_back(processNonStringArg(Args[i], Builder)); + } + + for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) { + Value *toStore = WhatToStore[I]; + + StoreInst *StBuff = Builder.CreateStore(toStore, PtrToStore); + LLVM_DEBUG(dbgs() << "inserting store to printf buffer:" << *StBuff + << '\n'); + (void)StBuff; + PtrToStore = Builder.CreateConstInBoundsGEP1_32( + Builder.getInt8Ty(), PtrToStore, + M->getDataLayout().getTypeAllocSize(toStore->getType()), + "PrintBuffNextPtr"); + LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:" << *PtrToStore + << '\n'); + } + } +} + +Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, ArrayRef<Value *> Args, + bool IsBuffered) { auto NumOps = Args.size(); assert(NumOps >= 1); auto Fmt = Args[0]; SparseBitVector<8> SpecIsCString; - locateCStrings(SpecIsCString, Fmt); + StringRef FmtStr; + + if (getConstantStringInfo(Fmt, FmtStr)) + locateCStrings(SpecIsCString, FmtStr); + + if (IsBuffered) { + SmallVector<StringData, 8> StringContents; + Module *M = Builder.GetInsertBlock()->getModule(); + LLVMContext &Ctx = Builder.getContext(); + auto Int8Ty = Builder.getInt8Ty(); + auto Int32Ty = Builder.getInt32Ty(); + bool IsConstFmtStr = !FmtStr.empty(); + + Value *ArgSize = nullptr; + Value *Ptr = + callBufferedPrintfStart(Builder, Args, Fmt, IsConstFmtStr, + SpecIsCString, StringContents, ArgSize); + + // The buffered version still follows OpenCL printf standards for + // printf return value, i.e 0 on success, -1 on failure. + ConstantPointerNull *zeroIntPtr = + ConstantPointerNull::get(cast<PointerType>(Ptr->getType())); + + auto *Cmp = cast<ICmpInst>(Builder.CreateICmpNE(Ptr, zeroIntPtr, "")); + + BasicBlock *End = BasicBlock::Create(Ctx, "end.block", + Builder.GetInsertBlock()->getParent()); + BasicBlock *ArgPush = BasicBlock::Create( + Ctx, "argpush.block", Builder.GetInsertBlock()->getParent()); + + BranchInst::Create(ArgPush, End, Cmp, Builder.GetInsertBlock()); + Builder.SetInsertPoint(ArgPush); + + // Create controlDWord and store as the first entry, format as follows + // Bit 0 (LSB) -> stream (1 if stderr, 0 if stdout, printf always outputs to + // stdout) Bit 1 -> constant format string (1 if constant) Bits 2-31 -> size + // of printf data frame + auto ConstantTwo = Builder.getInt32(2); + auto ControlDWord = Builder.CreateShl(ArgSize, ConstantTwo); + if (IsConstFmtStr) + ControlDWord = Builder.CreateOr(ControlDWord, ConstantTwo); + + Builder.CreateStore(ControlDWord, Ptr); + + Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 4); + + // Create MD5 hash for costant format string, push low 64 bits of the + // same onto buffer and metadata. 
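// Illustrative sketch (not from this patch): the control dword layout described
// above, using the 28-byte frame from the previous example and a constant
// format string.
#include <cstdio>

int main() {
  unsigned ArgSize = 28;
  unsigned ControlDWord = (ArgSize << 2) | 0x2; // bit 1: constant format string
  // Bit 0 stays 0 because this lowering always prints to stdout.
  std::printf("control dword: 0x%08x\n", ControlDWord); // 0x00000072
  return 0;
}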
+ NamedMDNode *metaD = M->getOrInsertNamedMetadata("llvm.printf.fmts"); + if (IsConstFmtStr) { + MD5 Hasher; + MD5::MD5Result Hash; + Hasher.update(FmtStr); + Hasher.final(Hash); + + // Try sticking to llvm.printf.fmts format, although we are not going to + // use the ID and argument size fields while printing, + std::string MetadataStr = + "0:0:" + llvm::utohexstr(Hash.low(), /*LowerCase=*/true) + "," + + FmtStr.str(); + MDString *fmtStrArray = MDString::get(Ctx, MetadataStr); + MDNode *myMD = MDNode::get(Ctx, fmtStrArray); + metaD->addOperand(myMD); + + Builder.CreateStore(Builder.getInt64(Hash.low()), Ptr); + Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 8); + } else { + // Include a dummy metadata instance in case of only non constant + // format string usage, This might be an absurd usecase but needs to + // be done for completeness + if (metaD->getNumOperands() == 0) { + MDString *fmtStrArray = + MDString::get(Ctx, "0:0:ffffffff,\"Non const format string\""); + MDNode *myMD = MDNode::get(Ctx, fmtStrArray); + metaD->addOperand(myMD); + } + } + + // Push The printf arguments onto buffer + callBufferedPrintfArgPush(Builder, Args, Ptr, SpecIsCString, StringContents, + IsConstFmtStr); + + // End block, returns -1 on failure + BranchInst::Create(End, ArgPush); + Builder.SetInsertPoint(End); + return Builder.CreateSExt(Builder.CreateNot(Cmp), Int32Ty, "printf_result"); + } auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0)); Desc = appendString(Builder, Desc, Fmt, NumOps == 1); diff --git a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp index 56acdcc0bc3c..7d127400651e 100644 --- a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp @@ -85,33 +85,6 @@ static cl::opt<bool> NoDiscriminators( "no-discriminators", cl::init(false), cl::desc("Disable generation of discriminator information.")); -namespace { - -// The legacy pass of AddDiscriminators. -struct AddDiscriminatorsLegacyPass : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - - AddDiscriminatorsLegacyPass() : FunctionPass(ID) { - initializeAddDiscriminatorsLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; -}; - -} // end anonymous namespace - -char AddDiscriminatorsLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(AddDiscriminatorsLegacyPass, "add-discriminators", - "Add DWARF path discriminators", false, false) -INITIALIZE_PASS_END(AddDiscriminatorsLegacyPass, "add-discriminators", - "Add DWARF path discriminators", false, false) - -// Create the legacy AddDiscriminatorsPass. 
-FunctionPass *llvm::createAddDiscriminatorsPass() { - return new AddDiscriminatorsLegacyPass(); -} - static bool shouldHaveDiscriminator(const Instruction *I) { return !isa<IntrinsicInst>(I) || isa<MemIntrinsic>(I); } @@ -269,10 +242,6 @@ static bool addDiscriminators(Function &F) { return Changed; } -bool AddDiscriminatorsLegacyPass::runOnFunction(Function &F) { - return addDiscriminators(F); -} - PreservedAnalyses AddDiscriminatorsPass::run(Function &F, FunctionAnalysisManager &AM) { if (!addDiscriminators(F)) diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp index d17c399ba798..45cf98e65a5a 100644 --- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp +++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp @@ -290,17 +290,20 @@ AssumeInst *llvm::buildAssumeFromInst(Instruction *I) { return Builder.build(); } -void llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC, +bool llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC, DominatorTree *DT) { if (!EnableKnowledgeRetention || I->isTerminator()) - return; + return false; + bool Changed = false; AssumeBuilderState Builder(I->getModule(), I, AC, DT); Builder.addInstruction(I); if (auto *Intr = Builder.build()) { Intr->insertBefore(I); + Changed = true; if (AC) AC->registerAssumption(Intr); } + return Changed; } AssumeInst * @@ -563,57 +566,26 @@ PreservedAnalyses AssumeSimplifyPass::run(Function &F, FunctionAnalysisManager &AM) { if (!EnableKnowledgeRetention) return PreservedAnalyses::all(); - simplifyAssumes(F, &AM.getResult<AssumptionAnalysis>(F), - AM.getCachedResult<DominatorTreeAnalysis>(F)); - return PreservedAnalyses::all(); -} - -namespace { -class AssumeSimplifyPassLegacyPass : public FunctionPass { -public: - static char ID; - - AssumeSimplifyPassLegacyPass() : FunctionPass(ID) { - initializeAssumeSimplifyPassLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - bool runOnFunction(Function &F) override { - if (skipFunction(F) || !EnableKnowledgeRetention) - return false; - AssumptionCache &AC = - getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - DominatorTreeWrapperPass *DTWP = - getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - return simplifyAssumes(F, &AC, DTWP ? 
&DTWP->getDomTree() : nullptr); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - - AU.setPreservesAll(); - } -}; -} // namespace - -char AssumeSimplifyPassLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(AssumeSimplifyPassLegacyPass, "assume-simplify", - "Assume Simplify", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(AssumeSimplifyPassLegacyPass, "assume-simplify", - "Assume Simplify", false, false) - -FunctionPass *llvm::createAssumeSimplifyPass() { - return new AssumeSimplifyPassLegacyPass(); + if (!simplifyAssumes(F, &AM.getResult<AssumptionAnalysis>(F), + AM.getCachedResult<DominatorTreeAnalysis>(F))) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } PreservedAnalyses AssumeBuilderPass::run(Function &F, FunctionAnalysisManager &AM) { AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); DominatorTree* DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + bool Changed = false; for (Instruction &I : instructions(F)) - salvageKnowledge(&I, AC, DT); - return PreservedAnalyses::all(); + Changed |= salvageKnowledge(&I, AC, DT); + if (!Changed) + PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } namespace { diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index 58a226fc601c..f06ea89cc61d 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -32,6 +32,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" @@ -379,8 +380,8 @@ bool llvm::MergeBlockSuccessorsIntoGivenBlocks( /// /// Possible improvements: /// - Check fully overlapping fragments and not only identical fragments. -/// - Support dbg.addr, dbg.declare. dbg.label, and possibly other meta -/// instructions being part of the sequence of consecutive instructions. +/// - Support dbg.declare. dbg.label, and possibly other meta instructions being +/// part of the sequence of consecutive instructions. 
static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { SmallVector<DbgValueInst *, 8> ToBeRemoved; SmallDenseSet<DebugVariable> VariableSet; @@ -599,8 +600,8 @@ bool llvm::IsBlockFollowedByDeoptOrUnreachable(const BasicBlock *BB) { unsigned Depth = 0; while (BB && Depth++ < MaxDeoptOrUnreachableSuccessorCheckDepth && VisitedBlocks.insert(BB).second) { - if (BB->getTerminatingDeoptimizeCall() || - isa<UnreachableInst>(BB->getTerminator())) + if (isa<UnreachableInst>(BB->getTerminator()) || + BB->getTerminatingDeoptimizeCall()) return true; BB = BB->getUniqueSuccessor(); } @@ -1470,133 +1471,198 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, return cast<ReturnInst>(NewRet); } -static Instruction * -SplitBlockAndInsertIfThenImpl(Value *Cond, Instruction *SplitBefore, - bool Unreachable, MDNode *BranchWeights, - DomTreeUpdater *DTU, DominatorTree *DT, - LoopInfo *LI, BasicBlock *ThenBlock) { - SmallVector<DominatorTree::UpdateType, 8> Updates; - BasicBlock *Head = SplitBefore->getParent(); - BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); - if (DTU) { - SmallPtrSet<BasicBlock *, 8> UniqueSuccessorsOfHead; - Updates.push_back({DominatorTree::Insert, Head, Tail}); - Updates.reserve(Updates.size() + 2 * succ_size(Tail)); - for (BasicBlock *SuccessorOfHead : successors(Tail)) - if (UniqueSuccessorsOfHead.insert(SuccessorOfHead).second) { - Updates.push_back({DominatorTree::Insert, Tail, SuccessorOfHead}); - Updates.push_back({DominatorTree::Delete, Head, SuccessorOfHead}); - } - } - Instruction *HeadOldTerm = Head->getTerminator(); - LLVMContext &C = Head->getContext(); - Instruction *CheckTerm; - bool CreateThenBlock = (ThenBlock == nullptr); - if (CreateThenBlock) { - ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); - if (Unreachable) - CheckTerm = new UnreachableInst(C, ThenBlock); - else { - CheckTerm = BranchInst::Create(Tail, ThenBlock); - if (DTU) - Updates.push_back({DominatorTree::Insert, ThenBlock, Tail}); - } - CheckTerm->setDebugLoc(SplitBefore->getDebugLoc()); - } else - CheckTerm = ThenBlock->getTerminator(); - BranchInst *HeadNewTerm = - BranchInst::Create(/*ifTrue*/ ThenBlock, /*ifFalse*/ Tail, Cond); - if (DTU) - Updates.push_back({DominatorTree::Insert, Head, ThenBlock}); - HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights); - ReplaceInstWithInst(HeadOldTerm, HeadNewTerm); - - if (DTU) - DTU->applyUpdates(Updates); - else if (DT) { - if (DomTreeNode *OldNode = DT->getNode(Head)) { - std::vector<DomTreeNode *> Children(OldNode->begin(), OldNode->end()); - - DomTreeNode *NewNode = DT->addNewBlock(Tail, Head); - for (DomTreeNode *Child : Children) - DT->changeImmediateDominator(Child, NewNode); - - // Head dominates ThenBlock. 
- if (CreateThenBlock) - DT->addNewBlock(ThenBlock, Head); - else - DT->changeImmediateDominator(ThenBlock, Head); - } - } - - if (LI) { - if (Loop *L = LI->getLoopFor(Head)) { - L->addBasicBlockToLoop(ThenBlock, *LI); - L->addBasicBlockToLoop(Tail, *LI); - } - } - - return CheckTerm; -} - Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore, bool Unreachable, MDNode *BranchWeights, - DominatorTree *DT, LoopInfo *LI, + DomTreeUpdater *DTU, LoopInfo *LI, BasicBlock *ThenBlock) { - return SplitBlockAndInsertIfThenImpl(Cond, SplitBefore, Unreachable, - BranchWeights, - /*DTU=*/nullptr, DT, LI, ThenBlock); + SplitBlockAndInsertIfThenElse( + Cond, SplitBefore, &ThenBlock, /* ElseBlock */ nullptr, + /* UnreachableThen */ Unreachable, + /* UnreachableElse */ false, BranchWeights, DTU, LI); + return ThenBlock->getTerminator(); } -Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, + +Instruction *llvm::SplitBlockAndInsertIfElse(Value *Cond, Instruction *SplitBefore, bool Unreachable, MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI, - BasicBlock *ThenBlock) { - return SplitBlockAndInsertIfThenImpl(Cond, SplitBefore, Unreachable, - BranchWeights, DTU, /*DT=*/nullptr, LI, - ThenBlock); + BasicBlock *ElseBlock) { + SplitBlockAndInsertIfThenElse( + Cond, SplitBefore, /* ThenBlock */ nullptr, &ElseBlock, + /* UnreachableThen */ false, + /* UnreachableElse */ Unreachable, BranchWeights, DTU, LI); + return ElseBlock->getTerminator(); } void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, Instruction **ThenTerm, Instruction **ElseTerm, MDNode *BranchWeights, - DomTreeUpdater *DTU) { - BasicBlock *Head = SplitBefore->getParent(); + DomTreeUpdater *DTU, LoopInfo *LI) { + BasicBlock *ThenBlock = nullptr; + BasicBlock *ElseBlock = nullptr; + SplitBlockAndInsertIfThenElse( + Cond, SplitBefore, &ThenBlock, &ElseBlock, /* UnreachableThen */ false, + /* UnreachableElse */ false, BranchWeights, DTU, LI); + + *ThenTerm = ThenBlock->getTerminator(); + *ElseTerm = ElseBlock->getTerminator(); +} + +void llvm::SplitBlockAndInsertIfThenElse( + Value *Cond, Instruction *SplitBefore, BasicBlock **ThenBlock, + BasicBlock **ElseBlock, bool UnreachableThen, bool UnreachableElse, + MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI) { + assert((ThenBlock || ElseBlock) && + "At least one branch block must be created"); + assert((!UnreachableThen || !UnreachableElse) && + "Split block tail must be reachable"); + SmallVector<DominatorTree::UpdateType, 8> Updates; SmallPtrSet<BasicBlock *, 8> UniqueOrigSuccessors; - if (DTU) + BasicBlock *Head = SplitBefore->getParent(); + if (DTU) { UniqueOrigSuccessors.insert(succ_begin(Head), succ_end(Head)); + Updates.reserve(4 + 2 * UniqueOrigSuccessors.size()); + } + LLVMContext &C = Head->getContext(); BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); + BasicBlock *TrueBlock = Tail; + BasicBlock *FalseBlock = Tail; + bool ThenToTailEdge = false; + bool ElseToTailEdge = false; + + // Encapsulate the logic around creation/insertion/etc of a new block. + auto handleBlock = [&](BasicBlock **PBB, bool Unreachable, BasicBlock *&BB, + bool &ToTailEdge) { + if (PBB == nullptr) + return; // Do not create/insert a block. + + if (*PBB) + BB = *PBB; // Caller supplied block, use it. + else { + // Create a new block. 
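// Illustrative usage sketch (hypothetical caller, not from this patch) of the
// reworked SplitBlockAndInsertIfThen above, which now takes a DomTreeUpdater
// and LoopInfo directly; this is the overload turnGuardIntoBranch relies on.
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;

// Emit "if (Cond) store Val, Ptr" before InsertPt while keeping DT and LI valid.
void emitGuardedStore(Value *Cond, Value *Val, Value *Ptr,
                      Instruction *InsertPt, DominatorTree &DT, LoopInfo &LI) {
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Instruction *ThenTerm =
      SplitBlockAndInsertIfThen(Cond, InsertPt, /*Unreachable=*/false,
                                /*BranchWeights=*/nullptr, &DTU, &LI);
  IRBuilder<> Builder(ThenTerm);
  Builder.CreateStore(Val, Ptr);
}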
+ BB = BasicBlock::Create(C, "", Head->getParent(), Tail); + if (Unreachable) + (void)new UnreachableInst(C, BB); + else { + (void)BranchInst::Create(Tail, BB); + ToTailEdge = true; + } + BB->getTerminator()->setDebugLoc(SplitBefore->getDebugLoc()); + // Pass the new block back to the caller. + *PBB = BB; + } + }; + + handleBlock(ThenBlock, UnreachableThen, TrueBlock, ThenToTailEdge); + handleBlock(ElseBlock, UnreachableElse, FalseBlock, ElseToTailEdge); + Instruction *HeadOldTerm = Head->getTerminator(); - LLVMContext &C = Head->getContext(); - BasicBlock *ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); - BasicBlock *ElseBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); - *ThenTerm = BranchInst::Create(Tail, ThenBlock); - (*ThenTerm)->setDebugLoc(SplitBefore->getDebugLoc()); - *ElseTerm = BranchInst::Create(Tail, ElseBlock); - (*ElseTerm)->setDebugLoc(SplitBefore->getDebugLoc()); BranchInst *HeadNewTerm = - BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/ElseBlock, Cond); + BranchInst::Create(/*ifTrue*/ TrueBlock, /*ifFalse*/ FalseBlock, Cond); HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights); ReplaceInstWithInst(HeadOldTerm, HeadNewTerm); + if (DTU) { - SmallVector<DominatorTree::UpdateType, 8> Updates; - Updates.reserve(4 + 2 * UniqueOrigSuccessors.size()); - for (BasicBlock *Succ : successors(Head)) { - Updates.push_back({DominatorTree::Insert, Head, Succ}); - Updates.push_back({DominatorTree::Insert, Succ, Tail}); - } + Updates.emplace_back(DominatorTree::Insert, Head, TrueBlock); + Updates.emplace_back(DominatorTree::Insert, Head, FalseBlock); + if (ThenToTailEdge) + Updates.emplace_back(DominatorTree::Insert, TrueBlock, Tail); + if (ElseToTailEdge) + Updates.emplace_back(DominatorTree::Insert, FalseBlock, Tail); for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors) - Updates.push_back({DominatorTree::Insert, Tail, UniqueOrigSuccessor}); + Updates.emplace_back(DominatorTree::Insert, Tail, UniqueOrigSuccessor); for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors) - Updates.push_back({DominatorTree::Delete, Head, UniqueOrigSuccessor}); + Updates.emplace_back(DominatorTree::Delete, Head, UniqueOrigSuccessor); DTU->applyUpdates(Updates); } + + if (LI) { + if (Loop *L = LI->getLoopFor(Head); L) { + if (ThenToTailEdge) + L->addBasicBlockToLoop(TrueBlock, *LI); + if (ElseToTailEdge) + L->addBasicBlockToLoop(FalseBlock, *LI); + L->addBasicBlockToLoop(Tail, *LI); + } + } +} + +std::pair<Instruction*, Value*> +llvm::SplitBlockAndInsertSimpleForLoop(Value *End, Instruction *SplitBefore) { + BasicBlock *LoopPred = SplitBefore->getParent(); + BasicBlock *LoopBody = SplitBlock(SplitBefore->getParent(), SplitBefore); + BasicBlock *LoopExit = SplitBlock(SplitBefore->getParent(), SplitBefore); + + auto *Ty = End->getType(); + auto &DL = SplitBefore->getModule()->getDataLayout(); + const unsigned Bitwidth = DL.getTypeSizeInBits(Ty); + + IRBuilder<> Builder(LoopBody->getTerminator()); + auto *IV = Builder.CreatePHI(Ty, 2, "iv"); + auto *IVNext = + Builder.CreateAdd(IV, ConstantInt::get(Ty, 1), IV->getName() + ".next", + /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2); + auto *IVCheck = Builder.CreateICmpEQ(IVNext, End, + IV->getName() + ".check"); + Builder.CreateCondBr(IVCheck, LoopExit, LoopBody); + LoopBody->getTerminator()->eraseFromParent(); + + // Populate the IV PHI. 
+ IV->addIncoming(ConstantInt::get(Ty, 0), LoopPred); + IV->addIncoming(IVNext, LoopBody); + + return std::make_pair(LoopBody->getFirstNonPHI(), IV); +} + +void llvm::SplitBlockAndInsertForEachLane(ElementCount EC, + Type *IndexTy, Instruction *InsertBefore, + std::function<void(IRBuilderBase&, Value*)> Func) { + + IRBuilder<> IRB(InsertBefore); + + if (EC.isScalable()) { + Value *NumElements = IRB.CreateElementCount(IndexTy, EC); + + auto [BodyIP, Index] = + SplitBlockAndInsertSimpleForLoop(NumElements, InsertBefore); + + IRB.SetInsertPoint(BodyIP); + Func(IRB, Index); + return; + } + + unsigned Num = EC.getFixedValue(); + for (unsigned Idx = 0; Idx < Num; ++Idx) { + IRB.SetInsertPoint(InsertBefore); + Func(IRB, ConstantInt::get(IndexTy, Idx)); + } +} + +void llvm::SplitBlockAndInsertForEachLane( + Value *EVL, Instruction *InsertBefore, + std::function<void(IRBuilderBase &, Value *)> Func) { + + IRBuilder<> IRB(InsertBefore); + Type *Ty = EVL->getType(); + + if (!isa<ConstantInt>(EVL)) { + auto [BodyIP, Index] = SplitBlockAndInsertSimpleForLoop(EVL, InsertBefore); + IRB.SetInsertPoint(BodyIP); + Func(IRB, Index); + return; + } + + unsigned Num = cast<ConstantInt>(EVL)->getZExtValue(); + for (unsigned Idx = 0; Idx < Num; ++Idx) { + IRB.SetInsertPoint(InsertBefore); + Func(IRB, ConstantInt::get(Ty, Idx)); + } } BranchInst *llvm::GetIfCondition(BasicBlock *BB, BasicBlock *&IfTrue, @@ -1997,3 +2063,17 @@ BasicBlock *llvm::CreateControlFlowHub( return FirstGuardBlock; } + +void llvm::InvertBranch(BranchInst *PBI, IRBuilderBase &Builder) { + Value *NewCond = PBI->getCondition(); + // If this is a "cmp" instruction, only used for branching (and nowhere + // else), then we can simply invert the predicate. + if (NewCond->hasOneUse() && isa<CmpInst>(NewCond)) { + CmpInst *CI = cast<CmpInst>(NewCond); + CI->setPredicate(CI->getInversePredicate()); + } else + NewCond = Builder.CreateNot(NewCond, NewCond->getName() + ".not"); + + PBI->setCondition(NewCond); + PBI->swapSuccessors(); +} diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index 1e21a2f85446..5de8ff84de77 100644 --- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -478,6 +478,8 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_modfl: Changed |= setDoesNotThrow(F); Changed |= setWillReturn(F); + Changed |= setOnlyAccessesArgMemory(F); + Changed |= setOnlyWritesMemory(F); Changed |= setDoesNotCapture(F, 1); break; case LibFunc_memcpy: @@ -725,6 +727,8 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_frexpl: Changed |= setDoesNotThrow(F); Changed |= setWillReturn(F); + Changed |= setOnlyAccessesArgMemory(F); + Changed |= setOnlyWritesMemory(F); Changed |= setDoesNotCapture(F, 1); break; case LibFunc_fstatvfs: @@ -1937,3 +1941,87 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, IRBuilderBase &B, return CI; } + +Value *llvm::emitHotColdNew(Value *Num, IRBuilderBase &B, + const TargetLibraryInfo *TLI, LibFunc NewFunc, + uint8_t HotCold) { + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, NewFunc)) + return nullptr; + + StringRef Name = TLI->getName(NewFunc); + FunctionCallee Func = M->getOrInsertFunction(Name, B.getInt8PtrTy(), + Num->getType(), B.getInt8Ty()); + inferNonMandatoryLibFuncAttrs(M, Name, *TLI); + CallInst *CI = B.CreateCall(Func, {Num, B.getInt8(HotCold)}, Name); + + if (const Function *F = + dyn_cast<Function>(Func.getCallee()->stripPointerCasts())) + 
CI->setCallingConv(F->getCallingConv()); + + return CI; +} + +Value *llvm::emitHotColdNewNoThrow(Value *Num, Value *NoThrow, IRBuilderBase &B, + const TargetLibraryInfo *TLI, + LibFunc NewFunc, uint8_t HotCold) { + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, NewFunc)) + return nullptr; + + StringRef Name = TLI->getName(NewFunc); + FunctionCallee Func = + M->getOrInsertFunction(Name, B.getInt8PtrTy(), Num->getType(), + NoThrow->getType(), B.getInt8Ty()); + inferNonMandatoryLibFuncAttrs(M, Name, *TLI); + CallInst *CI = B.CreateCall(Func, {Num, NoThrow, B.getInt8(HotCold)}, Name); + + if (const Function *F = + dyn_cast<Function>(Func.getCallee()->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + + return CI; +} + +Value *llvm::emitHotColdNewAligned(Value *Num, Value *Align, IRBuilderBase &B, + const TargetLibraryInfo *TLI, + LibFunc NewFunc, uint8_t HotCold) { + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, NewFunc)) + return nullptr; + + StringRef Name = TLI->getName(NewFunc); + FunctionCallee Func = M->getOrInsertFunction( + Name, B.getInt8PtrTy(), Num->getType(), Align->getType(), B.getInt8Ty()); + inferNonMandatoryLibFuncAttrs(M, Name, *TLI); + CallInst *CI = B.CreateCall(Func, {Num, Align, B.getInt8(HotCold)}, Name); + + if (const Function *F = + dyn_cast<Function>(Func.getCallee()->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + + return CI; +} + +Value *llvm::emitHotColdNewAlignedNoThrow(Value *Num, Value *Align, + Value *NoThrow, IRBuilderBase &B, + const TargetLibraryInfo *TLI, + LibFunc NewFunc, uint8_t HotCold) { + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, NewFunc)) + return nullptr; + + StringRef Name = TLI->getName(NewFunc); + FunctionCallee Func = M->getOrInsertFunction( + Name, B.getInt8PtrTy(), Num->getType(), Align->getType(), + NoThrow->getType(), B.getInt8Ty()); + inferNonMandatoryLibFuncAttrs(M, Name, *TLI); + CallInst *CI = + B.CreateCall(Func, {Num, Align, NoThrow, B.getInt8(HotCold)}, Name); + + if (const Function *F = + dyn_cast<Function>(Func.getCallee()->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + + return CI; +} diff --git a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp index 930a0bcbfac5..73a50b793e6d 100644 --- a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -202,7 +202,7 @@ bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) { ConstantInt *C = dyn_cast<ConstantInt>(Op1); if (!C && isa<BitCastInst>(Op1)) C = dyn_cast<ConstantInt>(cast<BitCastInst>(Op1)->getOperand(0)); - return C && C->getValue().getMinSignedBits() > BypassType->getBitWidth(); + return C && C->getValue().getSignificantBits() > BypassType->getBitWidth(); } case Instruction::PHI: // Stop IR traversal in case of a crazy input code. 
This limits recursion diff --git a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp index d0b89ba2606e..d0b9884aa909 100644 --- a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp +++ b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp @@ -120,6 +120,8 @@ void CallGraphUpdater::removeFunction(Function &DeadFn) { DeadCGN->removeAllCalledFunctions(); CGSCC->DeleteNode(DeadCGN); } + if (FAM) + FAM->clear(DeadFn, DeadFn.getName()); } void CallGraphUpdater::replaceFunctionWith(Function &OldFn, Function &NewFn) { diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index 4a82f9606d3f..b488e3bb0cbd 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" diff --git a/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp b/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp index 4d622679dbdb..c24b6ed70405 100644 --- a/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp +++ b/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp @@ -31,8 +31,6 @@ #include "llvm/Transforms/Utils/CanonicalizeAliases.h" #include "llvm/IR/Constants.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" using namespace llvm; diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp index 87822ee85c2b..d55208602b71 100644 --- a/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -470,9 +470,8 @@ void PruningFunctionCloner::CloneBlock( // Nope, clone it now. BasicBlock *NewBB; - BBEntry = NewBB = BasicBlock::Create(BB->getContext()); - if (BB->hasName()) - NewBB->setName(BB->getName() + NameSuffix); + Twine NewName(BB->hasName() ? Twine(BB->getName()) + NameSuffix : ""); + BBEntry = NewBB = BasicBlock::Create(BB->getContext(), NewName, NewFunc); // It is only legal to clone a function if a block address within that // function is never referenced outside of the function. Given that, we @@ -498,6 +497,7 @@ void PruningFunctionCloner::CloneBlock( ++II) { Instruction *NewInst = cloneInstruction(II); + NewInst->insertInto(NewBB, NewBB->end()); if (HostFuncIsStrictFP) { // All function calls in the inlined function must get 'strictfp' @@ -526,7 +526,7 @@ void PruningFunctionCloner::CloneBlock( if (!NewInst->mayHaveSideEffects()) { VMap[&*II] = V; - NewInst->deleteValue(); + NewInst->eraseFromParent(); continue; } } @@ -535,7 +535,6 @@ void PruningFunctionCloner::CloneBlock( if (II->hasName()) NewInst->setName(II->getName() + NameSuffix); VMap[&*II] = NewInst; // Add instruction map to value. - NewInst->insertInto(NewBB, NewBB->end()); if (isa<CallInst>(II) && !II->isDebugOrPseudoInst()) { hasCalls = true; hasMemProfMetadata |= II->hasMetadata(LLVMContext::MD_memprof); @@ -683,8 +682,8 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, if (!NewBB) continue; // Dead block. - // Add the new block to the new function. - NewFunc->insert(NewFunc->end(), NewBB); + // Move the new block to preserve the order in the original function. + NewBB->moveBefore(NewFunc->end()); // Handle PHI nodes specially, as we have to remove references to dead // blocks. 
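The CloneBlock hunk above links each cloned instruction into the new block before the fold check runs, which is why a clone that folds away must now be removed with eraseFromParent() rather than deleteValue(). The following is a minimal sketch of that discipline, not code from this commit: the helper name cloneAndFold and its parameters are hypothetical, operand remapping (VMap/RemapInstruction) is deliberately omitted, and ConstantFoldInstruction merely stands in for whatever simplification the cloner actually performs; only the clone/insertInto/eraseFromParent calls mirror the hunk itself.

#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Hypothetical helper (sketch only): append a clone of I to NewBB and try to
// fold it. Returns the value that uses of I should be remapped to. Operand
// remapping is left out to keep the example short.
static Value *cloneAndFold(const Instruction &I, BasicBlock *NewBB,
                           const DataLayout &DL) {
  Instruction *NewInst = I.clone();
  // Link the clone into the block first, mirroring the updated CloneBlock.
  NewInst->insertInto(NewBB, NewBB->end());
  // If the clone folds to a constant and has no side effects, it already has a
  // parent block, so it must be erased (not deleteValue'd) before being dropped.
  if (Constant *C = ConstantFoldInstruction(NewInst, DL)) {
    if (!NewInst->mayHaveSideEffects()) {
      NewInst->eraseFromParent();
      return C;
    }
  }
  return NewInst;
}

The erase-versus-delete distinction only matters because insertion now happens before the fold check; the pre-patch code deleted a clone that had never been parented, which was equally valid for that ordering.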
@@ -937,8 +936,8 @@ void llvm::CloneAndPruneFunctionInto( } /// Remaps instructions in \p Blocks using the mapping in \p VMap. -void llvm::remapInstructionsInBlocks( - const SmallVectorImpl<BasicBlock *> &Blocks, ValueToValueMapTy &VMap) { +void llvm::remapInstructionsInBlocks(ArrayRef<BasicBlock *> Blocks, + ValueToValueMapTy &VMap) { // Rewrite the code to refer to itself. for (auto *BB : Blocks) for (auto &Inst : *BB) diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index c1fe10504e45..c390af351a69 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -918,6 +918,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::AllocKind: case Attribute::PresplitCoroutine: case Attribute::Memory: + case Attribute::NoFPClass: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: @@ -1091,32 +1092,20 @@ static void insertLifetimeMarkersSurroundingCall( Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd, CallInst *TheCall) { LLVMContext &Ctx = M->getContext(); - auto Int8PtrTy = Type::getInt8PtrTy(Ctx); auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1); Instruction *Term = TheCall->getParent()->getTerminator(); - // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts - // needed to satisfy this requirement so they may be reused. - DenseMap<Value *, Value *> Bitcasts; - // Emit lifetime markers for the pointers given in \p Objects. Insert the // markers before the call if \p InsertBefore, and after the call otherwise. - auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects, + auto insertMarkers = [&](Intrinsic::ID MarkerFunc, ArrayRef<Value *> Objects, bool InsertBefore) { for (Value *Mem : Objects) { assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) && "Input memory not defined in original function"); - Value *&MemAsI8Ptr = Bitcasts[Mem]; - if (!MemAsI8Ptr) { - if (Mem->getType() == Int8PtrTy) - MemAsI8Ptr = Mem; - else - MemAsI8Ptr = - CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall); - } - auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr}); + Function *Func = Intrinsic::getDeclaration(M, MarkerFunc, Mem->getType()); + auto Marker = CallInst::Create(Func, {NegativeOne, Mem}); if (InsertBefore) Marker->insertBefore(TheCall); else @@ -1125,15 +1114,13 @@ static void insertLifetimeMarkersSurroundingCall( }; if (!LifetimesStart.empty()) { - auto StartFn = llvm::Intrinsic::getDeclaration( - M, llvm::Intrinsic::lifetime_start, Int8PtrTy); - insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true); + insertMarkers(Intrinsic::lifetime_start, LifetimesStart, + /*InsertBefore=*/true); } if (!LifetimesEnd.empty()) { - auto EndFn = llvm::Intrinsic::getDeclaration( - M, llvm::Intrinsic::lifetime_end, Int8PtrTy); - insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false); + insertMarkers(Intrinsic::lifetime_end, LifetimesEnd, + /*InsertBefore=*/false); } } @@ -1663,14 +1650,14 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, } } - // Remove CondGuardInsts that will be moved to the new function from the old - // function's assumption cache. + // Remove @llvm.assume calls that will be moved to the new function from the + // old function's assumption cache. 
for (BasicBlock *Block : Blocks) { for (Instruction &I : llvm::make_early_inc_range(*Block)) { - if (auto *CI = dyn_cast<CondGuardInst>(&I)) { + if (auto *AI = dyn_cast<AssumeInst>(&I)) { if (AC) - AC->unregisterAssumption(CI); - CI->eraseFromParent(); + AC->unregisterAssumption(AI); + AI->eraseFromParent(); } } } @@ -1864,7 +1851,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, const Function &NewFunc, AssumptionCache *AC) { for (auto AssumeVH : AC->assumptions()) { - auto *I = dyn_cast_or_null<CondGuardInst>(AssumeVH); + auto *I = dyn_cast_or_null<CallInst>(AssumeVH); if (!I) continue; @@ -1876,7 +1863,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, // that were previously in the old function, but that have now been moved // to the new function. for (auto AffectedValVH : AC->assumptionsFor(I->getOperand(0))) { - auto *AffectedCI = dyn_cast_or_null<CondGuardInst>(AffectedValVH); + auto *AffectedCI = dyn_cast_or_null<CallInst>(AffectedValVH); if (!AffectedCI) continue; if (AffectedCI->getFunction() != &OldFunc) diff --git a/llvm/lib/Transforms/Utils/CodeLayout.cpp b/llvm/lib/Transforms/Utils/CodeLayout.cpp index 9eb3aff3ffe8..ac74a1c116cc 100644 --- a/llvm/lib/Transforms/Utils/CodeLayout.cpp +++ b/llvm/lib/Transforms/Utils/CodeLayout.cpp @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// ExtTSP - layout of basic blocks with i-cache optimization. +// The file implements "cache-aware" layout algorithms of basic blocks and +// functions in a binary. // // The algorithm tries to find a layout of nodes (basic blocks) of a given CFG // optimizing jump locality and thus processor I-cache utilization. This is @@ -41,12 +42,14 @@ #include "llvm/Transforms/Utils/CodeLayout.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include <cmath> using namespace llvm; #define DEBUG_TYPE "code-layout" +namespace llvm { cl::opt<bool> EnableExtTspBlockPlacement( "enable-ext-tsp-block-placement", cl::Hidden, cl::init(false), cl::desc("Enable machine block placement based on the ext-tsp model, " @@ -56,6 +59,7 @@ cl::opt<bool> ApplyExtTspWithoutProfile( "ext-tsp-apply-without-profile", cl::desc("Whether to apply ext-tsp placement for instances w/o profile"), cl::init(true), cl::Hidden); +} // namespace llvm // Algorithm-specific params. The values are tuned for the best performance // of large-scale front-end bound binaries. @@ -69,11 +73,11 @@ static cl::opt<double> ForwardWeightUncond( static cl::opt<double> BackwardWeightCond( "ext-tsp-backward-weight-cond", cl::ReallyHidden, cl::init(0.1), - cl::desc("The weight of conditonal backward jumps for ExtTSP value")); + cl::desc("The weight of conditional backward jumps for ExtTSP value")); static cl::opt<double> BackwardWeightUncond( "ext-tsp-backward-weight-uncond", cl::ReallyHidden, cl::init(0.1), - cl::desc("The weight of unconditonal backward jumps for ExtTSP value")); + cl::desc("The weight of unconditional backward jumps for ExtTSP value")); static cl::opt<double> FallthroughWeightCond( "ext-tsp-fallthrough-weight-cond", cl::ReallyHidden, cl::init(1.0), @@ -149,29 +153,30 @@ double extTSPScore(uint64_t SrcAddr, uint64_t SrcSize, uint64_t DstAddr, /// A type of merging two chains, X and Y. The former chain is split into /// X1 and X2 and then concatenated with Y in the order specified by the type. 
-enum class MergeTypeTy : int { X_Y, X1_Y_X2, Y_X2_X1, X2_X1_Y }; +enum class MergeTypeT : int { X_Y, Y_X, X1_Y_X2, Y_X2_X1, X2_X1_Y }; /// The gain of merging two chains, that is, the Ext-TSP score of the merge -/// together with the corresponfiding merge 'type' and 'offset'. -class MergeGainTy { -public: - explicit MergeGainTy() = default; - explicit MergeGainTy(double Score, size_t MergeOffset, MergeTypeTy MergeType) +/// together with the corresponding merge 'type' and 'offset'. +struct MergeGainT { + explicit MergeGainT() = default; + explicit MergeGainT(double Score, size_t MergeOffset, MergeTypeT MergeType) : Score(Score), MergeOffset(MergeOffset), MergeType(MergeType) {} double score() const { return Score; } size_t mergeOffset() const { return MergeOffset; } - MergeTypeTy mergeType() const { return MergeType; } + MergeTypeT mergeType() const { return MergeType; } + + void setMergeType(MergeTypeT Ty) { MergeType = Ty; } // Returns 'true' iff Other is preferred over this. - bool operator<(const MergeGainTy &Other) const { + bool operator<(const MergeGainT &Other) const { return (Other.Score > EPS && Other.Score > Score + EPS); } // Update the current gain if Other is preferred over this. - void updateIfLessThan(const MergeGainTy &Other) { + void updateIfLessThan(const MergeGainT &Other) { if (*this < Other) *this = Other; } @@ -179,106 +184,102 @@ public: private: double Score{-1.0}; size_t MergeOffset{0}; - MergeTypeTy MergeType{MergeTypeTy::X_Y}; + MergeTypeT MergeType{MergeTypeT::X_Y}; }; -class Jump; -class Chain; -class ChainEdge; +struct JumpT; +struct ChainT; +struct ChainEdge; -/// A node in the graph, typically corresponding to a basic block in CFG. -class Block { -public: - Block(const Block &) = delete; - Block(Block &&) = default; - Block &operator=(const Block &) = delete; - Block &operator=(Block &&) = default; +/// A node in the graph, typically corresponding to a basic block in the CFG or +/// a function in the call graph. +struct NodeT { + NodeT(const NodeT &) = delete; + NodeT(NodeT &&) = default; + NodeT &operator=(const NodeT &) = delete; + NodeT &operator=(NodeT &&) = default; + + explicit NodeT(size_t Index, uint64_t Size, uint64_t EC) + : Index(Index), Size(Size), ExecutionCount(EC) {} + + bool isEntry() const { return Index == 0; } + + // The total execution count of outgoing jumps. + uint64_t outCount() const; + + // The total execution count of incoming jumps. + uint64_t inCount() const; - // The original index of the block in CFG. + // The original index of the node in graph. size_t Index{0}; - // The index of the block in the current chain. + // The index of the node in the current chain. size_t CurIndex{0}; - // Size of the block in the binary. + // The size of the node in the binary. uint64_t Size{0}; - // Execution count of the block in the profile data. + // The execution count of the node in the profile data. uint64_t ExecutionCount{0}; - // Current chain of the node. - Chain *CurChain{nullptr}; - // An offset of the block in the current chain. + // The current chain of the node. + ChainT *CurChain{nullptr}; + // The offset of the node in the current chain. mutable uint64_t EstimatedAddr{0}; - // Forced successor of the block in CFG. - Block *ForcedSucc{nullptr}; - // Forced predecessor of the block in CFG. - Block *ForcedPred{nullptr}; - // Outgoing jumps from the block. - std::vector<Jump *> OutJumps; - // Incoming jumps to the block. 
- std::vector<Jump *> InJumps; - -public: - explicit Block(size_t Index, uint64_t Size, uint64_t EC) - : Index(Index), Size(Size), ExecutionCount(EC) {} - bool isEntry() const { return Index == 0; } + // Forced successor of the node in the graph. + NodeT *ForcedSucc{nullptr}; + // Forced predecessor of the node in the graph. + NodeT *ForcedPred{nullptr}; + // Outgoing jumps from the node. + std::vector<JumpT *> OutJumps; + // Incoming jumps to the node. + std::vector<JumpT *> InJumps; }; -/// An arc in the graph, typically corresponding to a jump between two blocks. -class Jump { -public: - Jump(const Jump &) = delete; - Jump(Jump &&) = default; - Jump &operator=(const Jump &) = delete; - Jump &operator=(Jump &&) = default; - - // Source block of the jump. - Block *Source; - // Target block of the jump. - Block *Target; +/// An arc in the graph, typically corresponding to a jump between two nodes. +struct JumpT { + JumpT(const JumpT &) = delete; + JumpT(JumpT &&) = default; + JumpT &operator=(const JumpT &) = delete; + JumpT &operator=(JumpT &&) = default; + + explicit JumpT(NodeT *Source, NodeT *Target, uint64_t ExecutionCount) + : Source(Source), Target(Target), ExecutionCount(ExecutionCount) {} + + // Source node of the jump. + NodeT *Source; + // Target node of the jump. + NodeT *Target; // Execution count of the arc in the profile data. uint64_t ExecutionCount{0}; // Whether the jump corresponds to a conditional branch. bool IsConditional{false}; - -public: - explicit Jump(Block *Source, Block *Target, uint64_t ExecutionCount) - : Source(Source), Target(Target), ExecutionCount(ExecutionCount) {} + // The offset of the jump from the source node. + uint64_t Offset{0}; }; -/// A chain (ordered sequence) of blocks. -class Chain { -public: - Chain(const Chain &) = delete; - Chain(Chain &&) = default; - Chain &operator=(const Chain &) = delete; - Chain &operator=(Chain &&) = default; +/// A chain (ordered sequence) of nodes in the graph. 
+struct ChainT { + ChainT(const ChainT &) = delete; + ChainT(ChainT &&) = default; + ChainT &operator=(const ChainT &) = delete; + ChainT &operator=(ChainT &&) = default; + + explicit ChainT(uint64_t Id, NodeT *Node) + : Id(Id), ExecutionCount(Node->ExecutionCount), Size(Node->Size), + Nodes(1, Node) {} - explicit Chain(uint64_t Id, Block *Block) - : Id(Id), Score(0), Blocks(1, Block) {} + size_t numBlocks() const { return Nodes.size(); } - uint64_t id() const { return Id; } + double density() const { return static_cast<double>(ExecutionCount) / Size; } - bool isEntry() const { return Blocks[0]->Index == 0; } + bool isEntry() const { return Nodes[0]->Index == 0; } bool isCold() const { - for (auto *Block : Blocks) { - if (Block->ExecutionCount > 0) + for (NodeT *Node : Nodes) { + if (Node->ExecutionCount > 0) return false; } return true; } - double score() const { return Score; } - - void setScore(double NewScore) { Score = NewScore; } - - const std::vector<Block *> &blocks() const { return Blocks; } - - size_t numBlocks() const { return Blocks.size(); } - - const std::vector<std::pair<Chain *, ChainEdge *>> &edges() const { - return Edges; - } - - ChainEdge *getEdge(Chain *Other) const { + ChainEdge *getEdge(ChainT *Other) const { for (auto It : Edges) { if (It.first == Other) return It.second; @@ -286,7 +287,7 @@ public: return nullptr; } - void removeEdge(Chain *Other) { + void removeEdge(ChainT *Other) { auto It = Edges.begin(); while (It != Edges.end()) { if (It->first == Other) { @@ -297,63 +298,68 @@ public: } } - void addEdge(Chain *Other, ChainEdge *Edge) { + void addEdge(ChainT *Other, ChainEdge *Edge) { Edges.push_back(std::make_pair(Other, Edge)); } - void merge(Chain *Other, const std::vector<Block *> &MergedBlocks) { - Blocks = MergedBlocks; - // Update the block's chains - for (size_t Idx = 0; Idx < Blocks.size(); Idx++) { - Blocks[Idx]->CurChain = this; - Blocks[Idx]->CurIndex = Idx; + void merge(ChainT *Other, const std::vector<NodeT *> &MergedBlocks) { + Nodes = MergedBlocks; + // Update the chain's data + ExecutionCount += Other->ExecutionCount; + Size += Other->Size; + Id = Nodes[0]->Index; + // Update the node's data + for (size_t Idx = 0; Idx < Nodes.size(); Idx++) { + Nodes[Idx]->CurChain = this; + Nodes[Idx]->CurIndex = Idx; } } - void mergeEdges(Chain *Other); + void mergeEdges(ChainT *Other); void clear() { - Blocks.clear(); - Blocks.shrink_to_fit(); + Nodes.clear(); + Nodes.shrink_to_fit(); Edges.clear(); Edges.shrink_to_fit(); } -private: // Unique chain identifier. uint64_t Id; // Cached ext-tsp score for the chain. - double Score; - // Blocks of the chain. - std::vector<Block *> Blocks; + double Score{0}; + // The total execution count of the chain. + uint64_t ExecutionCount{0}; + // The total size of the chain. + uint64_t Size{0}; + // Nodes of the chain. + std::vector<NodeT *> Nodes; // Adjacent chains and corresponding edges (lists of jumps). - std::vector<std::pair<Chain *, ChainEdge *>> Edges; + std::vector<std::pair<ChainT *, ChainEdge *>> Edges; }; -/// An edge in CFG representing jumps between two chains. -/// When blocks are merged into chains, the edges are combined too so that +/// An edge in the graph representing jumps between two chains. 
+/// When nodes are merged into chains, the edges are combined too so that /// there is always at most one edge between a pair of chains -class ChainEdge { -public: +struct ChainEdge { ChainEdge(const ChainEdge &) = delete; ChainEdge(ChainEdge &&) = default; ChainEdge &operator=(const ChainEdge &) = delete; - ChainEdge &operator=(ChainEdge &&) = default; + ChainEdge &operator=(ChainEdge &&) = delete; - explicit ChainEdge(Jump *Jump) + explicit ChainEdge(JumpT *Jump) : SrcChain(Jump->Source->CurChain), DstChain(Jump->Target->CurChain), Jumps(1, Jump) {} - const std::vector<Jump *> &jumps() const { return Jumps; } + ChainT *srcChain() const { return SrcChain; } - void changeEndpoint(Chain *From, Chain *To) { - if (From == SrcChain) - SrcChain = To; - if (From == DstChain) - DstChain = To; - } + ChainT *dstChain() const { return DstChain; } + + bool isSelfEdge() const { return SrcChain == DstChain; } - void appendJump(Jump *Jump) { Jumps.push_back(Jump); } + const std::vector<JumpT *> &jumps() const { return Jumps; } + + void appendJump(JumpT *Jump) { Jumps.push_back(Jump); } void moveJumps(ChainEdge *Other) { Jumps.insert(Jumps.end(), Other->Jumps.begin(), Other->Jumps.end()); @@ -361,15 +367,22 @@ public: Other->Jumps.shrink_to_fit(); } - bool hasCachedMergeGain(Chain *Src, Chain *Dst) const { + void changeEndpoint(ChainT *From, ChainT *To) { + if (From == SrcChain) + SrcChain = To; + if (From == DstChain) + DstChain = To; + } + + bool hasCachedMergeGain(ChainT *Src, ChainT *Dst) const { return Src == SrcChain ? CacheValidForward : CacheValidBackward; } - MergeGainTy getCachedMergeGain(Chain *Src, Chain *Dst) const { + MergeGainT getCachedMergeGain(ChainT *Src, ChainT *Dst) const { return Src == SrcChain ? CachedGainForward : CachedGainBackward; } - void setCachedMergeGain(Chain *Src, Chain *Dst, MergeGainTy MergeGain) { + void setCachedMergeGain(ChainT *Src, ChainT *Dst, MergeGainT MergeGain) { if (Src == SrcChain) { CachedGainForward = MergeGain; CacheValidForward = true; @@ -384,31 +397,55 @@ public: CacheValidBackward = false; } + void setMergeGain(MergeGainT Gain) { CachedGain = Gain; } + + MergeGainT getMergeGain() const { return CachedGain; } + + double gain() const { return CachedGain.score(); } + private: // Source chain. - Chain *SrcChain{nullptr}; + ChainT *SrcChain{nullptr}; // Destination chain. - Chain *DstChain{nullptr}; - // Original jumps in the binary with correspinding execution counts. - std::vector<Jump *> Jumps; - // Cached ext-tsp value for merging the pair of chains. - // Since the gain of merging (Src, Dst) and (Dst, Src) might be different, - // we store both values here. - MergeGainTy CachedGainForward; - MergeGainTy CachedGainBackward; + ChainT *DstChain{nullptr}; + // Original jumps in the binary with corresponding execution counts. + std::vector<JumpT *> Jumps; + // Cached gain value for merging the pair of chains. + MergeGainT CachedGain; + + // Cached gain values for merging the pair of chains. Since the gain of + // merging (Src, Dst) and (Dst, Src) might be different, we store both values + // here and a flag indicating which of the options results in a higher gain. + // Cached gain values. + MergeGainT CachedGainForward; + MergeGainT CachedGainBackward; // Whether the cached value must be recomputed. 
bool CacheValidForward{false}; bool CacheValidBackward{false}; }; -void Chain::mergeEdges(Chain *Other) { - assert(this != Other && "cannot merge a chain with itself"); +uint64_t NodeT::outCount() const { + uint64_t Count = 0; + for (JumpT *Jump : OutJumps) { + Count += Jump->ExecutionCount; + } + return Count; +} +uint64_t NodeT::inCount() const { + uint64_t Count = 0; + for (JumpT *Jump : InJumps) { + Count += Jump->ExecutionCount; + } + return Count; +} + +void ChainT::mergeEdges(ChainT *Other) { // Update edges adjacent to chain Other for (auto EdgeIt : Other->Edges) { - Chain *DstChain = EdgeIt.first; + ChainT *DstChain = EdgeIt.first; ChainEdge *DstEdge = EdgeIt.second; - Chain *TargetChain = DstChain == Other ? this : DstChain; + ChainT *TargetChain = DstChain == Other ? this : DstChain; ChainEdge *CurEdge = getEdge(TargetChain); if (CurEdge == nullptr) { DstEdge->changeEndpoint(Other, this); @@ -426,15 +463,14 @@ void Chain::mergeEdges(Chain *Other) { } } -using BlockIter = std::vector<Block *>::const_iterator; +using NodeIter = std::vector<NodeT *>::const_iterator; -/// A wrapper around three chains of blocks; it is used to avoid extra +/// A wrapper around three chains of nodes; it is used to avoid extra /// instantiation of the vectors. -class MergedChain { -public: - MergedChain(BlockIter Begin1, BlockIter End1, BlockIter Begin2 = BlockIter(), - BlockIter End2 = BlockIter(), BlockIter Begin3 = BlockIter(), - BlockIter End3 = BlockIter()) +struct MergedChain { + MergedChain(NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(), + NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(), + NodeIter End3 = NodeIter()) : Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3), End3(End3) {} @@ -447,8 +483,8 @@ public: Func(*It); } - std::vector<Block *> getBlocks() const { - std::vector<Block *> Result; + std::vector<NodeT *> getNodes() const { + std::vector<NodeT *> Result; Result.reserve(std::distance(Begin1, End1) + std::distance(Begin2, End2) + std::distance(Begin3, End3)); Result.insert(Result.end(), Begin1, End1); @@ -457,42 +493,71 @@ public: return Result; } - const Block *getFirstBlock() const { return *Begin1; } + const NodeT *getFirstNode() const { return *Begin1; } private: - BlockIter Begin1; - BlockIter End1; - BlockIter Begin2; - BlockIter End2; - BlockIter Begin3; - BlockIter End3; + NodeIter Begin1; + NodeIter End1; + NodeIter Begin2; + NodeIter End2; + NodeIter Begin3; + NodeIter End3; }; +/// Merge two chains of nodes respecting a given 'type' and 'offset'. +/// +/// If MergeType == 0, then the result is a concatenation of two chains. +/// Otherwise, the first chain is cut into two sub-chains at the offset, +/// and merged using all possible ways of concatenating three chains. 
+MergedChain mergeNodes(const std::vector<NodeT *> &X, + const std::vector<NodeT *> &Y, size_t MergeOffset, + MergeTypeT MergeType) { + // Split the first chain, X, into X1 and X2 + NodeIter BeginX1 = X.begin(); + NodeIter EndX1 = X.begin() + MergeOffset; + NodeIter BeginX2 = X.begin() + MergeOffset; + NodeIter EndX2 = X.end(); + NodeIter BeginY = Y.begin(); + NodeIter EndY = Y.end(); + + // Construct a new chain from the three existing ones + switch (MergeType) { + case MergeTypeT::X_Y: + return MergedChain(BeginX1, EndX2, BeginY, EndY); + case MergeTypeT::Y_X: + return MergedChain(BeginY, EndY, BeginX1, EndX2); + case MergeTypeT::X1_Y_X2: + return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2); + case MergeTypeT::Y_X2_X1: + return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1); + case MergeTypeT::X2_X1_Y: + return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY); + } + llvm_unreachable("unexpected chain merge type"); +} + /// The implementation of the ExtTSP algorithm. class ExtTSPImpl { - using EdgeT = std::pair<uint64_t, uint64_t>; - using EdgeCountMap = std::vector<std::pair<EdgeT, uint64_t>>; - public: - ExtTSPImpl(size_t NumNodes, const std::vector<uint64_t> &NodeSizes, + ExtTSPImpl(const std::vector<uint64_t> &NodeSizes, const std::vector<uint64_t> &NodeCounts, - const EdgeCountMap &EdgeCounts) - : NumNodes(NumNodes) { + const std::vector<EdgeCountT> &EdgeCounts) + : NumNodes(NodeSizes.size()) { initialize(NodeSizes, NodeCounts, EdgeCounts); } - /// Run the algorithm and return an optimized ordering of blocks. + /// Run the algorithm and return an optimized ordering of nodes. void run(std::vector<uint64_t> &Result) { - // Pass 1: Merge blocks with their mutually forced successors + // Pass 1: Merge nodes with their mutually forced successors mergeForcedPairs(); // Pass 2: Merge pairs of chains while improving the ExtTSP objective mergeChainPairs(); - // Pass 3: Merge cold blocks to reduce code size + // Pass 3: Merge cold nodes to reduce code size mergeColdChains(); - // Collect blocks from all chains + // Collect nodes from all chains concatChains(Result); } @@ -500,26 +565,26 @@ private: /// Initialize the algorithm's data structures. 
void initialize(const std::vector<uint64_t> &NodeSizes, const std::vector<uint64_t> &NodeCounts, - const EdgeCountMap &EdgeCounts) { - // Initialize blocks - AllBlocks.reserve(NumNodes); - for (uint64_t Node = 0; Node < NumNodes; Node++) { - uint64_t Size = std::max<uint64_t>(NodeSizes[Node], 1ULL); - uint64_t ExecutionCount = NodeCounts[Node]; - // The execution count of the entry block is set to at least 1 - if (Node == 0 && ExecutionCount == 0) + const std::vector<EdgeCountT> &EdgeCounts) { + // Initialize nodes + AllNodes.reserve(NumNodes); + for (uint64_t Idx = 0; Idx < NumNodes; Idx++) { + uint64_t Size = std::max<uint64_t>(NodeSizes[Idx], 1ULL); + uint64_t ExecutionCount = NodeCounts[Idx]; + // The execution count of the entry node is set to at least one + if (Idx == 0 && ExecutionCount == 0) ExecutionCount = 1; - AllBlocks.emplace_back(Node, Size, ExecutionCount); + AllNodes.emplace_back(Idx, Size, ExecutionCount); } - // Initialize jumps between blocks + // Initialize jumps between nodes SuccNodes.resize(NumNodes); PredNodes.resize(NumNodes); std::vector<uint64_t> OutDegree(NumNodes, 0); AllJumps.reserve(EdgeCounts.size()); for (auto It : EdgeCounts) { - auto Pred = It.first.first; - auto Succ = It.first.second; + uint64_t Pred = It.first.first; + uint64_t Succ = It.first.second; OutDegree[Pred]++; // Ignore self-edges if (Pred == Succ) @@ -527,16 +592,16 @@ private: SuccNodes[Pred].push_back(Succ); PredNodes[Succ].push_back(Pred); - auto ExecutionCount = It.second; + uint64_t ExecutionCount = It.second; if (ExecutionCount > 0) { - auto &Block = AllBlocks[Pred]; - auto &SuccBlock = AllBlocks[Succ]; - AllJumps.emplace_back(&Block, &SuccBlock, ExecutionCount); - SuccBlock.InJumps.push_back(&AllJumps.back()); - Block.OutJumps.push_back(&AllJumps.back()); + NodeT &PredNode = AllNodes[Pred]; + NodeT &SuccNode = AllNodes[Succ]; + AllJumps.emplace_back(&PredNode, &SuccNode, ExecutionCount); + SuccNode.InJumps.push_back(&AllJumps.back()); + PredNode.OutJumps.push_back(&AllJumps.back()); } } - for (auto &Jump : AllJumps) { + for (JumpT &Jump : AllJumps) { assert(OutDegree[Jump.Source->Index] > 0); Jump.IsConditional = OutDegree[Jump.Source->Index] > 1; } @@ -544,78 +609,78 @@ private: // Initialize chains AllChains.reserve(NumNodes); HotChains.reserve(NumNodes); - for (Block &Block : AllBlocks) { - AllChains.emplace_back(Block.Index, &Block); - Block.CurChain = &AllChains.back(); - if (Block.ExecutionCount > 0) { + for (NodeT &Node : AllNodes) { + AllChains.emplace_back(Node.Index, &Node); + Node.CurChain = &AllChains.back(); + if (Node.ExecutionCount > 0) { HotChains.push_back(&AllChains.back()); } } // Initialize chain edges AllEdges.reserve(AllJumps.size()); - for (Block &Block : AllBlocks) { - for (auto &Jump : Block.OutJumps) { - auto SuccBlock = Jump->Target; - ChainEdge *CurEdge = Block.CurChain->getEdge(SuccBlock->CurChain); + for (NodeT &PredNode : AllNodes) { + for (JumpT *Jump : PredNode.OutJumps) { + NodeT *SuccNode = Jump->Target; + ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain); // this edge is already present in the graph if (CurEdge != nullptr) { - assert(SuccBlock->CurChain->getEdge(Block.CurChain) != nullptr); + assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr); CurEdge->appendJump(Jump); continue; } // this is a new edge AllEdges.emplace_back(Jump); - Block.CurChain->addEdge(SuccBlock->CurChain, &AllEdges.back()); - SuccBlock->CurChain->addEdge(Block.CurChain, &AllEdges.back()); + PredNode.CurChain->addEdge(SuccNode->CurChain, 
&AllEdges.back()); + SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back()); } } } - /// For a pair of blocks, A and B, block B is the forced successor of A, + /// For a pair of nodes, A and B, node B is the forced successor of A, /// if (i) all jumps (based on profile) from A goes to B and (ii) all jumps - /// to B are from A. Such blocks should be adjacent in the optimal ordering; - /// the method finds and merges such pairs of blocks. + /// to B are from A. Such nodes should be adjacent in the optimal ordering; + /// the method finds and merges such pairs of nodes. void mergeForcedPairs() { // Find fallthroughs based on edge weights - for (auto &Block : AllBlocks) { - if (SuccNodes[Block.Index].size() == 1 && - PredNodes[SuccNodes[Block.Index][0]].size() == 1 && - SuccNodes[Block.Index][0] != 0) { - size_t SuccIndex = SuccNodes[Block.Index][0]; - Block.ForcedSucc = &AllBlocks[SuccIndex]; - AllBlocks[SuccIndex].ForcedPred = &Block; + for (NodeT &Node : AllNodes) { + if (SuccNodes[Node.Index].size() == 1 && + PredNodes[SuccNodes[Node.Index][0]].size() == 1 && + SuccNodes[Node.Index][0] != 0) { + size_t SuccIndex = SuccNodes[Node.Index][0]; + Node.ForcedSucc = &AllNodes[SuccIndex]; + AllNodes[SuccIndex].ForcedPred = &Node; } } // There might be 'cycles' in the forced dependencies, since profile // data isn't 100% accurate. Typically this is observed in loops, when the // loop edges are the hottest successors for the basic blocks of the loop. - // Break the cycles by choosing the block with the smallest index as the + // Break the cycles by choosing the node with the smallest index as the // head. This helps to keep the original order of the loops, which likely // have already been rotated in the optimized manner. - for (auto &Block : AllBlocks) { - if (Block.ForcedSucc == nullptr || Block.ForcedPred == nullptr) + for (NodeT &Node : AllNodes) { + if (Node.ForcedSucc == nullptr || Node.ForcedPred == nullptr) continue; - auto SuccBlock = Block.ForcedSucc; - while (SuccBlock != nullptr && SuccBlock != &Block) { - SuccBlock = SuccBlock->ForcedSucc; + NodeT *SuccNode = Node.ForcedSucc; + while (SuccNode != nullptr && SuccNode != &Node) { + SuccNode = SuccNode->ForcedSucc; } - if (SuccBlock == nullptr) + if (SuccNode == nullptr) continue; // Break the cycle - AllBlocks[Block.ForcedPred->Index].ForcedSucc = nullptr; - Block.ForcedPred = nullptr; + AllNodes[Node.ForcedPred->Index].ForcedSucc = nullptr; + Node.ForcedPred = nullptr; } - // Merge blocks with their fallthrough successors - for (auto &Block : AllBlocks) { - if (Block.ForcedPred == nullptr && Block.ForcedSucc != nullptr) { - auto CurBlock = &Block; + // Merge nodes with their fallthrough successors + for (NodeT &Node : AllNodes) { + if (Node.ForcedPred == nullptr && Node.ForcedSucc != nullptr) { + const NodeT *CurBlock = &Node; while (CurBlock->ForcedSucc != nullptr) { - const auto NextBlock = CurBlock->ForcedSucc; - mergeChains(Block.CurChain, NextBlock->CurChain, 0, MergeTypeTy::X_Y); + const NodeT *NextBlock = CurBlock->ForcedSucc; + mergeChains(Node.CurChain, NextBlock->CurChain, 0, MergeTypeT::X_Y); CurBlock = NextBlock; } } @@ -625,23 +690,23 @@ private: /// Merge pairs of chains while improving the ExtTSP objective. 
void mergeChainPairs() { /// Deterministically compare pairs of chains - auto compareChainPairs = [](const Chain *A1, const Chain *B1, - const Chain *A2, const Chain *B2) { + auto compareChainPairs = [](const ChainT *A1, const ChainT *B1, + const ChainT *A2, const ChainT *B2) { if (A1 != A2) - return A1->id() < A2->id(); - return B1->id() < B2->id(); + return A1->Id < A2->Id; + return B1->Id < B2->Id; }; while (HotChains.size() > 1) { - Chain *BestChainPred = nullptr; - Chain *BestChainSucc = nullptr; - auto BestGain = MergeGainTy(); + ChainT *BestChainPred = nullptr; + ChainT *BestChainSucc = nullptr; + MergeGainT BestGain; // Iterate over all pairs of chains - for (Chain *ChainPred : HotChains) { + for (ChainT *ChainPred : HotChains) { // Get candidates for merging with the current chain - for (auto EdgeIter : ChainPred->edges()) { - Chain *ChainSucc = EdgeIter.first; - class ChainEdge *ChainEdge = EdgeIter.second; + for (auto EdgeIt : ChainPred->Edges) { + ChainT *ChainSucc = EdgeIt.first; + ChainEdge *Edge = EdgeIt.second; // Ignore loop edges if (ChainPred == ChainSucc) continue; @@ -651,8 +716,7 @@ private: continue; // Compute the gain of merging the two chains - MergeGainTy CurGain = - getBestMergeGain(ChainPred, ChainSucc, ChainEdge); + MergeGainT CurGain = getBestMergeGain(ChainPred, ChainSucc, Edge); if (CurGain.score() <= EPS) continue; @@ -677,43 +741,43 @@ private: } } - /// Merge remaining blocks into chains w/o taking jump counts into - /// consideration. This allows to maintain the original block order in the - /// absense of profile data + /// Merge remaining nodes into chains w/o taking jump counts into + /// consideration. This allows to maintain the original node order in the + /// absence of profile data void mergeColdChains() { for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) { // Iterating in reverse order to make sure original fallthrough jumps are // merged first; this might be beneficial for code size. size_t NumSuccs = SuccNodes[SrcBB].size(); for (size_t Idx = 0; Idx < NumSuccs; Idx++) { - auto DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1]; - auto SrcChain = AllBlocks[SrcBB].CurChain; - auto DstChain = AllBlocks[DstBB].CurChain; + size_t DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1]; + ChainT *SrcChain = AllNodes[SrcBB].CurChain; + ChainT *DstChain = AllNodes[DstBB].CurChain; if (SrcChain != DstChain && !DstChain->isEntry() && - SrcChain->blocks().back()->Index == SrcBB && - DstChain->blocks().front()->Index == DstBB && + SrcChain->Nodes.back()->Index == SrcBB && + DstChain->Nodes.front()->Index == DstBB && SrcChain->isCold() == DstChain->isCold()) { - mergeChains(SrcChain, DstChain, 0, MergeTypeTy::X_Y); + mergeChains(SrcChain, DstChain, 0, MergeTypeT::X_Y); } } } } - /// Compute the Ext-TSP score for a given block order and a list of jumps. + /// Compute the Ext-TSP score for a given node order and a list of jumps. 
double extTSPScore(const MergedChain &MergedBlocks, - const std::vector<Jump *> &Jumps) const { + const std::vector<JumpT *> &Jumps) const { if (Jumps.empty()) return 0.0; uint64_t CurAddr = 0; - MergedBlocks.forEach([&](const Block *BB) { - BB->EstimatedAddr = CurAddr; - CurAddr += BB->Size; + MergedBlocks.forEach([&](const NodeT *Node) { + Node->EstimatedAddr = CurAddr; + CurAddr += Node->Size; }); double Score = 0; - for (auto &Jump : Jumps) { - const Block *SrcBlock = Jump->Source; - const Block *DstBlock = Jump->Target; + for (JumpT *Jump : Jumps) { + const NodeT *SrcBlock = Jump->Source; + const NodeT *DstBlock = Jump->Target; Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size, DstBlock->EstimatedAddr, Jump->ExecutionCount, Jump->IsConditional); @@ -727,8 +791,8 @@ private: /// computes the one having the largest increase in ExtTSP objective. The /// result is a pair with the first element being the gain and the second /// element being the corresponding merging type. - MergeGainTy getBestMergeGain(Chain *ChainPred, Chain *ChainSucc, - ChainEdge *Edge) const { + MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc, + ChainEdge *Edge) const { if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) { return Edge->getCachedMergeGain(ChainPred, ChainSucc); } @@ -742,22 +806,22 @@ private: assert(!Jumps.empty() && "trying to merge chains w/o jumps"); // The object holds the best currently chosen gain of merging the two chains - MergeGainTy Gain = MergeGainTy(); + MergeGainT Gain = MergeGainT(); /// Given a merge offset and a list of merge types, try to merge two chains /// and update Gain with a better alternative auto tryChainMerging = [&](size_t Offset, - const std::vector<MergeTypeTy> &MergeTypes) { + const std::vector<MergeTypeT> &MergeTypes) { // Skip merging corresponding to concatenation w/o splitting - if (Offset == 0 || Offset == ChainPred->blocks().size()) + if (Offset == 0 || Offset == ChainPred->Nodes.size()) return; // Skip merging if it breaks Forced successors - auto BB = ChainPred->blocks()[Offset - 1]; - if (BB->ForcedSucc != nullptr) + NodeT *Node = ChainPred->Nodes[Offset - 1]; + if (Node->ForcedSucc != nullptr) return; // Apply the merge, compute the corresponding gain, and update the best // value, if the merge is beneficial - for (const auto &MergeType : MergeTypes) { + for (const MergeTypeT &MergeType : MergeTypes) { Gain.updateIfLessThan( computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType)); } @@ -765,36 +829,36 @@ private: // Try to concatenate two chains w/o splitting Gain.updateIfLessThan( - computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeTy::X_Y)); + computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeT::X_Y)); if (EnableChainSplitAlongJumps) { - // Attach (a part of) ChainPred before the first block of ChainSucc - for (auto &Jump : ChainSucc->blocks().front()->InJumps) { - const auto SrcBlock = Jump->Source; + // Attach (a part of) ChainPred before the first node of ChainSucc + for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) { + const NodeT *SrcBlock = Jump->Source; if (SrcBlock->CurChain != ChainPred) continue; size_t Offset = SrcBlock->CurIndex + 1; - tryChainMerging(Offset, {MergeTypeTy::X1_Y_X2, MergeTypeTy::X2_X1_Y}); + tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y}); } - // Attach (a part of) ChainPred after the last block of ChainSucc - for (auto &Jump : ChainSucc->blocks().back()->OutJumps) { - const auto DstBlock = Jump->Source; + // Attach (a part of) ChainPred after the last 
node of ChainSucc + for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) { + const NodeT *DstBlock = Jump->Source; if (DstBlock->CurChain != ChainPred) continue; size_t Offset = DstBlock->CurIndex; - tryChainMerging(Offset, {MergeTypeTy::X1_Y_X2, MergeTypeTy::Y_X2_X1}); + tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1}); } } // Try to break ChainPred in various ways and concatenate with ChainSucc - if (ChainPred->blocks().size() <= ChainSplitThreshold) { - for (size_t Offset = 1; Offset < ChainPred->blocks().size(); Offset++) { + if (ChainPred->Nodes.size() <= ChainSplitThreshold) { + for (size_t Offset = 1; Offset < ChainPred->Nodes.size(); Offset++) { // Try to split the chain in different ways. In practice, applying // X2_Y_X1 merging is almost never provides benefits; thus, we exclude // it from consideration to reduce the search space - tryChainMerging(Offset, {MergeTypeTy::X1_Y_X2, MergeTypeTy::Y_X2_X1, - MergeTypeTy::X2_X1_Y}); + tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1, + MergeTypeT::X2_X1_Y}); } } Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain); @@ -805,96 +869,66 @@ private: /// merge 'type' and 'offset'. /// /// The two chains are not modified in the method. - MergeGainTy computeMergeGain(const Chain *ChainPred, const Chain *ChainSucc, - const std::vector<Jump *> &Jumps, - size_t MergeOffset, - MergeTypeTy MergeType) const { - auto MergedBlocks = mergeBlocks(ChainPred->blocks(), ChainSucc->blocks(), - MergeOffset, MergeType); - - // Do not allow a merge that does not preserve the original entry block + MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc, + const std::vector<JumpT *> &Jumps, + size_t MergeOffset, MergeTypeT MergeType) const { + auto MergedBlocks = + mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType); + + // Do not allow a merge that does not preserve the original entry point if ((ChainPred->isEntry() || ChainSucc->isEntry()) && - !MergedBlocks.getFirstBlock()->isEntry()) - return MergeGainTy(); + !MergedBlocks.getFirstNode()->isEntry()) + return MergeGainT(); // The gain for the new chain - auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->score(); - return MergeGainTy(NewGainScore, MergeOffset, MergeType); - } - - /// Merge two chains of blocks respecting a given merge 'type' and 'offset'. - /// - /// If MergeType == 0, then the result is a concatenation of two chains. - /// Otherwise, the first chain is cut into two sub-chains at the offset, - /// and merged using all possible ways of concatenating three chains. 
- MergedChain mergeBlocks(const std::vector<Block *> &X, - const std::vector<Block *> &Y, size_t MergeOffset, - MergeTypeTy MergeType) const { - // Split the first chain, X, into X1 and X2 - BlockIter BeginX1 = X.begin(); - BlockIter EndX1 = X.begin() + MergeOffset; - BlockIter BeginX2 = X.begin() + MergeOffset; - BlockIter EndX2 = X.end(); - BlockIter BeginY = Y.begin(); - BlockIter EndY = Y.end(); - - // Construct a new chain from the three existing ones - switch (MergeType) { - case MergeTypeTy::X_Y: - return MergedChain(BeginX1, EndX2, BeginY, EndY); - case MergeTypeTy::X1_Y_X2: - return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2); - case MergeTypeTy::Y_X2_X1: - return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1); - case MergeTypeTy::X2_X1_Y: - return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY); - } - llvm_unreachable("unexpected chain merge type"); + auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->Score; + return MergeGainT(NewGainScore, MergeOffset, MergeType); } /// Merge chain From into chain Into, update the list of active chains, /// adjacency information, and the corresponding cached values. - void mergeChains(Chain *Into, Chain *From, size_t MergeOffset, - MergeTypeTy MergeType) { + void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset, + MergeTypeT MergeType) { assert(Into != From && "a chain cannot be merged with itself"); - // Merge the blocks - MergedChain MergedBlocks = - mergeBlocks(Into->blocks(), From->blocks(), MergeOffset, MergeType); - Into->merge(From, MergedBlocks.getBlocks()); + // Merge the nodes + MergedChain MergedNodes = + mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType); + Into->merge(From, MergedNodes.getNodes()); + + // Merge the edges Into->mergeEdges(From); From->clear(); // Update cached ext-tsp score for the new chain ChainEdge *SelfEdge = Into->getEdge(Into); if (SelfEdge != nullptr) { - MergedBlocks = MergedChain(Into->blocks().begin(), Into->blocks().end()); - Into->setScore(extTSPScore(MergedBlocks, SelfEdge->jumps())); + MergedNodes = MergedChain(Into->Nodes.begin(), Into->Nodes.end()); + Into->Score = extTSPScore(MergedNodes, SelfEdge->jumps()); } - // Remove chain From from the list of active chains + // Remove the chain from the list of active chains llvm::erase_value(HotChains, From); // Invalidate caches - for (auto EdgeIter : Into->edges()) { - EdgeIter.second->invalidateCache(); - } + for (auto EdgeIt : Into->Edges) + EdgeIt.second->invalidateCache(); } - /// Concatenate all chains into a final order of blocks. + /// Concatenate all chains into the final order. 
void concatChains(std::vector<uint64_t> &Order) { - // Collect chains and calculate some stats for their sorting - std::vector<Chain *> SortedChains; - DenseMap<const Chain *, double> ChainDensity; - for (auto &Chain : AllChains) { - if (!Chain.blocks().empty()) { + // Collect chains and calculate density stats for their sorting + std::vector<const ChainT *> SortedChains; + DenseMap<const ChainT *, double> ChainDensity; + for (ChainT &Chain : AllChains) { + if (!Chain.Nodes.empty()) { SortedChains.push_back(&Chain); - // Using doubles to avoid overflow of ExecutionCount + // Using doubles to avoid overflow of ExecutionCounts double Size = 0; double ExecutionCount = 0; - for (auto *Block : Chain.blocks()) { - Size += static_cast<double>(Block->Size); - ExecutionCount += static_cast<double>(Block->ExecutionCount); + for (NodeT *Node : Chain.Nodes) { + Size += static_cast<double>(Node->Size); + ExecutionCount += static_cast<double>(Node->ExecutionCount); } assert(Size > 0 && "a chain of zero size"); ChainDensity[&Chain] = ExecutionCount / Size; @@ -903,24 +937,23 @@ private: // Sorting chains by density in the decreasing order std::stable_sort(SortedChains.begin(), SortedChains.end(), - [&](const Chain *C1, const Chain *C2) { - // Make sure the original entry block is at the + [&](const ChainT *L, const ChainT *R) { + // Make sure the original entry point is at the // beginning of the order - if (C1->isEntry() != C2->isEntry()) { - return C1->isEntry(); - } + if (L->isEntry() != R->isEntry()) + return L->isEntry(); - const double D1 = ChainDensity[C1]; - const double D2 = ChainDensity[C2]; + const double DL = ChainDensity[L]; + const double DR = ChainDensity[R]; // Compare by density and break ties by chain identifiers - return (D1 != D2) ? (D1 > D2) : (C1->id() < C2->id()); + return (DL != DR) ? (DL > DR) : (L->Id < R->Id); }); - // Collect the blocks in the order specified by their chains + // Collect the nodes in the order specified by their chains Order.reserve(NumNodes); - for (Chain *Chain : SortedChains) { - for (Block *Block : Chain->blocks()) { - Order.push_back(Block->Index); + for (const ChainT *Chain : SortedChains) { + for (NodeT *Node : Chain->Nodes) { + Order.push_back(Node->Index); } } } @@ -935,49 +968,47 @@ private: /// Predecessors of each node. std::vector<std::vector<uint64_t>> PredNodes; - /// All basic blocks. - std::vector<Block> AllBlocks; + /// All nodes (basic blocks) in the graph. + std::vector<NodeT> AllNodes; - /// All jumps between blocks. - std::vector<Jump> AllJumps; + /// All jumps between the nodes. + std::vector<JumpT> AllJumps; - /// All chains of basic blocks. - std::vector<Chain> AllChains; + /// All chains of nodes. + std::vector<ChainT> AllChains; - /// All edges between chains. + /// All edges between the chains. std::vector<ChainEdge> AllEdges; /// Active chains. The vector gets updated at runtime when chains are merged. - std::vector<Chain *> HotChains; + std::vector<ChainT *> HotChains; }; } // end of anonymous namespace -std::vector<uint64_t> llvm::applyExtTspLayout( - const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) { - size_t NumNodes = NodeSizes.size(); - - // Verify correctness of the input data. 
+std::vector<uint64_t> +llvm::applyExtTspLayout(const std::vector<uint64_t> &NodeSizes, + const std::vector<uint64_t> &NodeCounts, + const std::vector<EdgeCountT> &EdgeCounts) { + // Verify correctness of the input data assert(NodeCounts.size() == NodeSizes.size() && "Incorrect input"); - assert(NumNodes > 2 && "Incorrect input"); + assert(NodeSizes.size() > 2 && "Incorrect input"); - // Apply the reordering algorithm. - auto Alg = ExtTSPImpl(NumNodes, NodeSizes, NodeCounts, EdgeCounts); + // Apply the reordering algorithm + ExtTSPImpl Alg(NodeSizes, NodeCounts, EdgeCounts); std::vector<uint64_t> Result; Alg.run(Result); - // Verify correctness of the output. + // Verify correctness of the output assert(Result.front() == 0 && "Original entry point is not preserved"); - assert(Result.size() == NumNodes && "Incorrect size of reordered layout"); + assert(Result.size() == NodeSizes.size() && "Incorrect size of layout"); return Result; } -double llvm::calcExtTspScore( - const std::vector<uint64_t> &Order, const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) { +double llvm::calcExtTspScore(const std::vector<uint64_t> &Order, + const std::vector<uint64_t> &NodeSizes, + const std::vector<uint64_t> &NodeCounts, + const std::vector<EdgeCountT> &EdgeCounts) { // Estimate addresses of the blocks in memory std::vector<uint64_t> Addr(NodeSizes.size(), 0); for (size_t Idx = 1; Idx < Order.size(); Idx++) { @@ -985,15 +1016,15 @@ double llvm::calcExtTspScore( } std::vector<uint64_t> OutDegree(NodeSizes.size(), 0); for (auto It : EdgeCounts) { - auto Pred = It.first.first; + uint64_t Pred = It.first.first; OutDegree[Pred]++; } // Increase the score for each jump double Score = 0; for (auto It : EdgeCounts) { - auto Pred = It.first.first; - auto Succ = It.first.second; + uint64_t Pred = It.first.first; + uint64_t Succ = It.first.second; uint64_t Count = It.second; bool IsConditional = OutDegree[Pred] > 1; Score += ::extTSPScore(Addr[Pred], NodeSizes[Pred], Addr[Succ], Count, @@ -1002,10 +1033,9 @@ double llvm::calcExtTspScore( return Score; } -double llvm::calcExtTspScore( - const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) { +double llvm::calcExtTspScore(const std::vector<uint64_t> &NodeSizes, + const std::vector<uint64_t> &NodeCounts, + const std::vector<EdgeCountT> &EdgeCounts) { std::vector<uint64_t> Order(NodeSizes.size()); for (size_t Idx = 0; Idx < NodeSizes.size(); Idx++) { Order[Idx] = Idx; diff --git a/llvm/lib/Transforms/Utils/CountVisits.cpp b/llvm/lib/Transforms/Utils/CountVisits.cpp new file mode 100644 index 000000000000..4faded8fc656 --- /dev/null +++ b/llvm/lib/Transforms/Utils/CountVisits.cpp @@ -0,0 +1,25 @@ +//===- CountVisits.cpp ----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/CountVisits.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/IR/PassManager.h" + +using namespace llvm; + +#define DEBUG_TYPE "count-visits" + +STATISTIC(MaxVisited, "Max number of times we visited a function"); + +PreservedAnalyses CountVisitsPass::run(Function &F, FunctionAnalysisManager &) { + uint32_t Count = Counts[F.getName()] + 1; + Counts[F.getName()] = Count; + if (Count > MaxVisited) + MaxVisited = Count; + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Transforms/Utils/CtorUtils.cpp b/llvm/lib/Transforms/Utils/CtorUtils.cpp index c997f39508e3..e07c92df2265 100644 --- a/llvm/lib/Transforms/Utils/CtorUtils.cpp +++ b/llvm/lib/Transforms/Utils/CtorUtils.cpp @@ -48,7 +48,7 @@ static void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemov GlobalVariable *NGV = new GlobalVariable(CA->getType(), GCL->isConstant(), GCL->getLinkage(), CA, "", GCL->getThreadLocalMode()); - GCL->getParent()->getGlobalList().insert(GCL->getIterator(), NGV); + GCL->getParent()->insertGlobalVariable(GCL->getIterator(), NGV); NGV->takeName(GCL); // Nuke the old list, replacing any uses with the new one. diff --git a/llvm/lib/Transforms/Utils/Debugify.cpp b/llvm/lib/Transforms/Utils/Debugify.cpp index 989473693a0b..93cad0888a56 100644 --- a/llvm/lib/Transforms/Utils/Debugify.cpp +++ b/llvm/lib/Transforms/Utils/Debugify.cpp @@ -979,7 +979,9 @@ PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) { collectDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass, "ModuleDebugify (original debuginfo)", NameOfWrappedPass); - return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } ModulePass *createCheckDebugifyModulePass( @@ -1027,45 +1029,58 @@ static bool isIgnoredPass(StringRef PassID) { } void DebugifyEachInstrumentation::registerCallbacks( - PassInstrumentationCallbacks &PIC) { - PIC.registerBeforeNonSkippedPassCallback([this](StringRef P, Any IR) { - if (isIgnoredPass(P)) - return; - if (const auto **F = any_cast<const Function *>(&IR)) - applyDebugify(*const_cast<Function *>(*F), - Mode, DebugInfoBeforePass, P); - else if (const auto **M = any_cast<const Module *>(&IR)) - applyDebugify(*const_cast<Module *>(*M), - Mode, DebugInfoBeforePass, P); - }); - PIC.registerAfterPassCallback([this](StringRef P, Any IR, - const PreservedAnalyses &PassPA) { + PassInstrumentationCallbacks &PIC, ModuleAnalysisManager &MAM) { + PIC.registerBeforeNonSkippedPassCallback([this, &MAM](StringRef P, Any IR) { if (isIgnoredPass(P)) return; + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); if (const auto **CF = any_cast<const Function *>(&IR)) { - auto &F = *const_cast<Function *>(*CF); - Module &M = *F.getParent(); - auto It = F.getIterator(); - if (Mode == DebugifyMode::SyntheticDebugInfo) - checkDebugifyMetadata(M, make_range(It, std::next(It)), P, - "CheckFunctionDebugify", /*Strip=*/true, DIStatsMap); - else - checkDebugInfoMetadata( - M, make_range(It, std::next(It)), *DebugInfoBeforePass, - "CheckModuleDebugify (original debuginfo)", - P, OrigDIVerifyBugsReportFilePath); + Function &F = *const_cast<Function *>(*CF); + applyDebugify(F, Mode, DebugInfoBeforePass, P); + MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent()) + .getManager() + .invalidate(F, PA); } else if (const auto **CM = any_cast<const Module *>(&IR)) { - auto 
&M = *const_cast<Module *>(*CM); - if (Mode == DebugifyMode::SyntheticDebugInfo) - checkDebugifyMetadata(M, M.functions(), P, "CheckModuleDebugify", - /*Strip=*/true, DIStatsMap); - else - checkDebugInfoMetadata( - M, M.functions(), *DebugInfoBeforePass, - "CheckModuleDebugify (original debuginfo)", - P, OrigDIVerifyBugsReportFilePath); + Module &M = *const_cast<Module *>(*CM); + applyDebugify(M, Mode, DebugInfoBeforePass, P); + MAM.invalidate(M, PA); } }); + PIC.registerAfterPassCallback( + [this, &MAM](StringRef P, Any IR, const PreservedAnalyses &PassPA) { + if (isIgnoredPass(P)) + return; + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + if (const auto **CF = any_cast<const Function *>(&IR)) { + auto &F = *const_cast<Function *>(*CF); + Module &M = *F.getParent(); + auto It = F.getIterator(); + if (Mode == DebugifyMode::SyntheticDebugInfo) + checkDebugifyMetadata(M, make_range(It, std::next(It)), P, + "CheckFunctionDebugify", /*Strip=*/true, + DIStatsMap); + else + checkDebugInfoMetadata(M, make_range(It, std::next(It)), + *DebugInfoBeforePass, + "CheckModuleDebugify (original debuginfo)", + P, OrigDIVerifyBugsReportFilePath); + MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent()) + .getManager() + .invalidate(F, PA); + } else if (const auto **CM = any_cast<const Module *>(&IR)) { + Module &M = *const_cast<Module *>(*CM); + if (Mode == DebugifyMode::SyntheticDebugInfo) + checkDebugifyMetadata(M, M.functions(), P, "CheckModuleDebugify", + /*Strip=*/true, DIStatsMap); + else + checkDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass, + "CheckModuleDebugify (original debuginfo)", + P, OrigDIVerifyBugsReportFilePath); + MAM.invalidate(M, PA); + } + }); } char DebugifyModulePass::ID = 0; diff --git a/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp b/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp index 086ea088dc5e..c894afee68a2 100644 --- a/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -74,6 +74,7 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, V = new LoadInst(I.getType(), Slot, I.getName() + ".reload", VolatileLoads, PN->getIncomingBlock(i)->getTerminator()); + Loads[PN->getIncomingBlock(i)] = V; } PN->setIncomingValue(i, V); } diff --git a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 53af1b1969c2..d424ebbef99d 100644 --- a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/EntryExitInstrumenter.h" -#include "llvm/ADT/Triple.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" @@ -16,9 +15,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/Utils.h" +#include "llvm/TargetParser/Triple.h" using namespace llvm; @@ -83,6 +80,13 @@ static void insertCall(Function &CurFn, StringRef Func, } static bool runOnFunction(Function &F, bool PostInlining) { + // The asm in a naked function may reasonably expect the argument registers + // and the return address register (if present) to be live. An inserted + // function call will clobber these registers. Simply skip naked functions for + // all targets. 
+ if (F.hasFnAttribute(Attribute::Naked)) + return false; + StringRef EntryAttr = PostInlining ? "instrument-function-entry-inlined" : "instrument-function-entry"; @@ -145,8 +149,8 @@ void llvm::EntryExitInstrumenterPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<llvm::EntryExitInstrumenterPass> *>(this) ->printPipeline(OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (PostInlining) OS << "post-inline"; - OS << ">"; + OS << '>'; } diff --git a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp index 91053338df5f..88c838685bca 100644 --- a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp +++ b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -12,9 +12,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/EscapeEnumerator.h" -#include "llvm/ADT/Triple.h" -#include "llvm/Analysis/EHPersonalities.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Module.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; diff --git a/llvm/lib/Transforms/Utils/Evaluator.cpp b/llvm/lib/Transforms/Utils/Evaluator.cpp index dc58bebd724b..23c1ca366a44 100644 --- a/llvm/lib/Transforms/Utils/Evaluator.cpp +++ b/llvm/lib/Transforms/Utils/Evaluator.cpp @@ -121,7 +121,7 @@ isSimpleEnoughValueToCommit(Constant *C, } void Evaluator::MutableValue::clear() { - if (auto *Agg = Val.dyn_cast<MutableAggregate *>()) + if (auto *Agg = dyn_cast_if_present<MutableAggregate *>(Val)) delete Agg; Val = nullptr; } @@ -130,7 +130,7 @@ Constant *Evaluator::MutableValue::read(Type *Ty, APInt Offset, const DataLayout &DL) const { TypeSize TySize = DL.getTypeStoreSize(Ty); const MutableValue *V = this; - while (const auto *Agg = V->Val.dyn_cast<MutableAggregate *>()) { + while (const auto *Agg = dyn_cast_if_present<MutableAggregate *>(V->Val)) { Type *AggTy = Agg->Ty; std::optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset); if (!Index || Index->uge(Agg->Elements.size()) || @@ -140,11 +140,11 @@ Constant *Evaluator::MutableValue::read(Type *Ty, APInt Offset, V = &Agg->Elements[Index->getZExtValue()]; } - return ConstantFoldLoadFromConst(V->Val.get<Constant *>(), Ty, Offset, DL); + return ConstantFoldLoadFromConst(cast<Constant *>(V->Val), Ty, Offset, DL); } bool Evaluator::MutableValue::makeMutable() { - Constant *C = Val.get<Constant *>(); + Constant *C = cast<Constant *>(Val); Type *Ty = C->getType(); unsigned NumElements; if (auto *VT = dyn_cast<FixedVectorType>(Ty)) { @@ -171,10 +171,10 @@ bool Evaluator::MutableValue::write(Constant *V, APInt Offset, MutableValue *MV = this; while (Offset != 0 || !CastInst::isBitOrNoopPointerCastable(Ty, MV->getType(), DL)) { - if (MV->Val.is<Constant *>() && !MV->makeMutable()) + if (isa<Constant *>(MV->Val) && !MV->makeMutable()) return false; - MutableAggregate *Agg = MV->Val.get<MutableAggregate *>(); + MutableAggregate *Agg = cast<MutableAggregate *>(MV->Val); Type *AggTy = Agg->Ty; std::optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset); if (!Index || Index->uge(Agg->Elements.size()) || @@ -413,16 +413,28 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB, } Constant *Val = getVal(MSI->getValue()); - APInt Len = LenC->getValue(); - while (Len != 0) { - Constant *DestVal = ComputeLoadResult(GV, Val->getType(), Offset); - if (DestVal != Val) { - LLVM_DEBUG(dbgs() << "Memset is not a no-op at 
offset " - << Offset << " of " << *GV << ".\n"); + // Avoid the byte-per-byte scan if we're memseting a zeroinitializer + // to zero. + if (!Val->isNullValue() || MutatedMemory.contains(GV) || + !GV->hasDefinitiveInitializer() || + !GV->getInitializer()->isNullValue()) { + APInt Len = LenC->getValue(); + if (Len.ugt(64 * 1024)) { + LLVM_DEBUG(dbgs() << "Not evaluating large memset of size " + << Len << "\n"); return false; } - ++Offset; - --Len; + + while (Len != 0) { + Constant *DestVal = ComputeLoadResult(GV, Val->getType(), Offset); + if (DestVal != Val) { + LLVM_DEBUG(dbgs() << "Memset is not a no-op at offset " + << Offset << " of " << *GV << ".\n"); + return false; + } + ++Offset; + --Len; + } } LLVM_DEBUG(dbgs() << "Ignoring no-op memset.\n"); diff --git a/llvm/lib/Transforms/Utils/FlattenCFG.cpp b/llvm/lib/Transforms/Utils/FlattenCFG.cpp index 2fb2ab82e41a..1925b91c4da7 100644 --- a/llvm/lib/Transforms/Utils/FlattenCFG.cpp +++ b/llvm/lib/Transforms/Utils/FlattenCFG.cpp @@ -487,17 +487,10 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { BasicBlock::iterator SaveInsertPt = Builder.GetInsertPoint(); Builder.SetInsertPoint(PBI); if (InvertCond2) { - // If this is a "cmp" instruction, only used for branching (and nowhere - // else), then we can simply invert the predicate. - auto Cmp2 = dyn_cast<CmpInst>(CInst2); - if (Cmp2 && Cmp2->hasOneUse()) - Cmp2->setPredicate(Cmp2->getInversePredicate()); - else - CInst2 = cast<Instruction>(Builder.CreateNot(CInst2)); - PBI->swapSuccessors(); + InvertBranch(PBI, Builder); } - Value *NC = Builder.CreateBinOp(CombineOp, CInst1, CInst2); - PBI->replaceUsesOfWith(CInst2, NC); + Value *NC = Builder.CreateBinOp(CombineOp, CInst1, PBI->getCondition()); + PBI->replaceUsesOfWith(PBI->getCondition(), NC); Builder.SetInsertPoint(SaveInsertBB, SaveInsertPt); // Handle PHI node to replace its predecessors to FirstEntryBlock. diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp index 3fa61ec68cd3..8daeb92130ba 100644 --- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -157,16 +157,31 @@ int FunctionComparator::cmpAttrs(const AttributeList L, return 0; } -int FunctionComparator::cmpRangeMetadata(const MDNode *L, - const MDNode *R) const { +int FunctionComparator::cmpMetadata(const Metadata *L, + const Metadata *R) const { + // TODO: the following routine coerce the metadata contents into constants + // before comparison. + // It ignores any other cases, so that the metadata nodes are considered + // equal even though this is not correct. + // We should structurally compare the metadata nodes to be perfect here. + auto *CL = dyn_cast<ConstantAsMetadata>(L); + auto *CR = dyn_cast<ConstantAsMetadata>(R); + if (CL == CR) + return 0; + if (!CL) + return -1; + if (!CR) + return 1; + return cmpConstants(CL->getValue(), CR->getValue()); +} + +int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const { if (L == R) return 0; if (!L) return -1; if (!R) return 1; - // Range metadata is a sequence of numbers. Make sure they are the same - // sequence. // TODO: Note that as this is metadata, it is possible to drop and/or merge // this data when considering functions to merge. Thus this comparison would // return 0 (i.e. equivalent), but merging would become more complicated @@ -175,10 +190,30 @@ int FunctionComparator::cmpRangeMetadata(const MDNode *L, // function semantically. 
if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) return Res; - for (size_t I = 0; I < L->getNumOperands(); ++I) { - ConstantInt *LLow = mdconst::extract<ConstantInt>(L->getOperand(I)); - ConstantInt *RLow = mdconst::extract<ConstantInt>(R->getOperand(I)); - if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue())) + for (size_t I = 0; I < L->getNumOperands(); ++I) + if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I))) + return Res; + return 0; +} + +int FunctionComparator::cmpInstMetadata(Instruction const *L, + Instruction const *R) const { + /// These metadata affects the other optimization passes by making assertions + /// or constraints. + /// Values that carry different expectations should be considered different. + SmallVector<std::pair<unsigned, MDNode *>> MDL, MDR; + L->getAllMetadataOtherThanDebugLoc(MDL); + R->getAllMetadataOtherThanDebugLoc(MDR); + if (MDL.size() > MDR.size()) + return 1; + else if (MDL.size() < MDR.size()) + return -1; + for (size_t I = 0, N = MDL.size(); I < N; ++I) { + auto const [KeyL, ML] = MDL[I]; + auto const [KeyR, MR] = MDR[I]; + if (int Res = cmpNumbers(KeyL, KeyR)) + return Res; + if (int Res = cmpMDNode(ML, MR)) return Res; } return 0; @@ -586,9 +621,7 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(LI->getSyncScopeID(), cast<LoadInst>(R)->getSyncScopeID())) return Res; - return cmpRangeMetadata( - LI->getMetadata(LLVMContext::MD_range), - cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)); + return cmpInstMetadata(L, R); } if (const StoreInst *SI = dyn_cast<StoreInst>(L)) { if (int Res = @@ -616,8 +649,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(CI->getTailCallKind(), cast<CallInst>(R)->getTailCallKind())) return Res; - return cmpRangeMetadata(L->getMetadata(LLVMContext::MD_range), - R->getMetadata(LLVMContext::MD_range)); + return cmpMDNode(L->getMetadata(LLVMContext::MD_range), + R->getMetadata(LLVMContext::MD_range)); } if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) { ArrayRef<unsigned> LIndices = IVI->getIndices(); @@ -715,8 +748,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, // When we have target data, we can reduce the GEP down to the value in bytes // added to the address. const DataLayout &DL = FnL->getParent()->getDataLayout(); - unsigned BitWidth = DL.getPointerSizeInBits(ASL); - APInt OffsetL(BitWidth, 0), OffsetR(BitWidth, 0); + unsigned OffsetBitWidth = DL.getIndexSizeInBits(ASL); + APInt OffsetL(OffsetBitWidth, 0), OffsetR(OffsetBitWidth, 0); if (GEPL->accumulateConstantOffset(DL, OffsetL) && GEPR->accumulateConstantOffset(DL, OffsetR)) return cmpAPInts(OffsetL, OffsetR); diff --git a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp index 55bcb6f3b121..dab0be3a9fde 100644 --- a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp +++ b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp @@ -19,7 +19,6 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/InstIterator.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; @@ -40,7 +39,7 @@ STATISTIC(NumCompUsedAdded, /// CI (other than void) need to be widened to a VectorType of VF /// lanes. static void addVariantDeclaration(CallInst &CI, const ElementCount &VF, - const StringRef VFName) { + bool Predicate, const StringRef VFName) { Module *M = CI.getModule(); // Add function declaration. 
@@ -50,6 +49,8 @@ static void addVariantDeclaration(CallInst &CI, const ElementCount &VF, Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); assert(!CI.getFunctionType()->isVarArg() && "VarArg functions are not supported."); + if (Predicate) + Tys.push_back(ToVectorTy(Type::getInt1Ty(RetTy->getContext()), VF)); FunctionType *FTy = FunctionType::get(RetTy, Tys, /*isVarArg=*/false); Function *VectorF = Function::Create(FTy, Function::ExternalLinkage, VFName, M); @@ -89,19 +90,19 @@ static void addMappingsFromTLI(const TargetLibraryInfo &TLI, CallInst &CI) { const SetVector<StringRef> OriginalSetOfMappings(Mappings.begin(), Mappings.end()); - auto AddVariantDecl = [&](const ElementCount &VF) { + auto AddVariantDecl = [&](const ElementCount &VF, bool Predicate) { const std::string TLIName = - std::string(TLI.getVectorizedFunction(ScalarName, VF)); + std::string(TLI.getVectorizedFunction(ScalarName, VF, Predicate)); if (!TLIName.empty()) { - std::string MangledName = - VFABI::mangleTLIVectorName(TLIName, ScalarName, CI.arg_size(), VF); + std::string MangledName = VFABI::mangleTLIVectorName( + TLIName, ScalarName, CI.arg_size(), VF, Predicate); if (!OriginalSetOfMappings.count(MangledName)) { Mappings.push_back(MangledName); ++NumCallInjected; } Function *VariantF = M->getFunction(TLIName); if (!VariantF) - addVariantDeclaration(CI, VF, TLIName); + addVariantDeclaration(CI, VF, Predicate, TLIName); } }; @@ -109,13 +110,15 @@ static void addMappingsFromTLI(const TargetLibraryInfo &TLI, CallInst &CI) { ElementCount WidestFixedVF, WidestScalableVF; TLI.getWidestVF(ScalarName, WidestFixedVF, WidestScalableVF); - for (ElementCount VF = ElementCount::getFixed(2); - ElementCount::isKnownLE(VF, WidestFixedVF); VF *= 2) - AddVariantDecl(VF); + for (bool Predicated : {false, true}) { + for (ElementCount VF = ElementCount::getFixed(2); + ElementCount::isKnownLE(VF, WidestFixedVF); VF *= 2) + AddVariantDecl(VF, Predicated); - // TODO: Add scalable variants once we're able to test them. - assert(WidestScalableVF.isZero() && - "Scalable vector mappings not yet supported"); + for (ElementCount VF = ElementCount::getScalable(2); + ElementCount::isKnownLE(VF, WidestScalableVF); VF *= 2) + AddVariantDecl(VF, Predicated); + } VFABI::setVectorVariantNames(&CI, Mappings); } @@ -138,39 +141,3 @@ PreservedAnalyses InjectTLIMappings::run(Function &F, // Even if the pass adds IR attributes, the analyses are preserved. return PreservedAnalyses::all(); } - -//////////////////////////////////////////////////////////////////////////////// -// Legacy PM Implementation. 
-//////////////////////////////////////////////////////////////////////////////// -bool InjectTLIMappingsLegacy::runOnFunction(Function &F) { - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - return runImpl(TLI, F); -} - -void InjectTLIMappingsLegacy::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesCFG(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<LoopAccessLegacyAnalysis>(); - AU.addPreserved<DemandedBitsWrapperPass>(); - AU.addPreserved<OptimizationRemarkEmitterWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); -} - -//////////////////////////////////////////////////////////////////////////////// -// Legacy Pass manager initialization -//////////////////////////////////////////////////////////////////////////////// -char InjectTLIMappingsLegacy::ID = 0; - -INITIALIZE_PASS_BEGIN(InjectTLIMappingsLegacy, DEBUG_TYPE, - "Inject TLI Mappings", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(InjectTLIMappingsLegacy, DEBUG_TYPE, "Inject TLI Mappings", - false, false) - -FunctionPass *llvm::createInjectTLIMappingsLegacyPass() { - return new InjectTLIMappingsLegacy(); -} diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 399c9a43793f..f7b93fc8fd06 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -23,7 +23,6 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CaptureTracking.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryProfileInfo.h" #include "llvm/Analysis/ObjCARCAnalysisUtils.h" @@ -42,6 +41,7 @@ #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" @@ -99,10 +99,6 @@ PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining", cl::init(false), cl::Hidden, cl::desc("Convert align attributes to assumptions during inlining.")); -static cl::opt<bool> UpdateReturnAttributes( - "update-return-attrs", cl::init(true), cl::Hidden, - cl::desc("Update return attributes on calls within inlined body")); - static cl::opt<unsigned> InlinerAttributeWindow( "max-inst-checked-for-throw-during-inlining", cl::Hidden, cl::desc("the maximum number of instructions analyzed for may throw during " @@ -879,9 +875,6 @@ static void propagateMemProfHelper(const CallBase *OrigCall, // inlined callee's callsite metadata with that of the inlined call, // and moving the subset of any memprof contexts to the inlined callee // allocations if they match the new inlined call stack. -// FIXME: Replace memprof metadata with function attribute if all MIB end up -// having the same behavior. Do other context trimming/merging optimizations -// too. 
static void propagateMemProfMetadata(Function *Callee, CallBase &CB, bool ContainsMemProfMetadata, @@ -1368,9 +1361,6 @@ static AttrBuilder IdentifyValidAttributes(CallBase &CB) { } static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { - if (!UpdateReturnAttributes) - return; - AttrBuilder Valid = IdentifyValidAttributes(CB); if (!Valid.hasAttributes()) return; @@ -1460,84 +1450,10 @@ static void AddAlignmentAssumptions(CallBase &CB, InlineFunctionInfo &IFI) { } } -/// Once we have cloned code over from a callee into the caller, -/// update the specified callgraph to reflect the changes we made. -/// Note that it's possible that not all code was copied over, so only -/// some edges of the callgraph may remain. -static void UpdateCallGraphAfterInlining(CallBase &CB, - Function::iterator FirstNewBlock, - ValueToValueMapTy &VMap, - InlineFunctionInfo &IFI) { - CallGraph &CG = *IFI.CG; - const Function *Caller = CB.getCaller(); - const Function *Callee = CB.getCalledFunction(); - CallGraphNode *CalleeNode = CG[Callee]; - CallGraphNode *CallerNode = CG[Caller]; - - // Since we inlined some uninlined call sites in the callee into the caller, - // add edges from the caller to all of the callees of the callee. - CallGraphNode::iterator I = CalleeNode->begin(), E = CalleeNode->end(); - - // Consider the case where CalleeNode == CallerNode. - CallGraphNode::CalledFunctionsVector CallCache; - if (CalleeNode == CallerNode) { - CallCache.assign(I, E); - I = CallCache.begin(); - E = CallCache.end(); - } - - for (; I != E; ++I) { - // Skip 'refererence' call records. - if (!I->first) - continue; - - const Value *OrigCall = *I->first; - - ValueToValueMapTy::iterator VMI = VMap.find(OrigCall); - // Only copy the edge if the call was inlined! - if (VMI == VMap.end() || VMI->second == nullptr) - continue; - - // If the call was inlined, but then constant folded, there is no edge to - // add. Check for this case. - auto *NewCall = dyn_cast<CallBase>(VMI->second); - if (!NewCall) - continue; - - // We do not treat intrinsic calls like real function calls because we - // expect them to become inline code; do not add an edge for an intrinsic. - if (NewCall->getCalledFunction() && - NewCall->getCalledFunction()->isIntrinsic()) - continue; - - // Remember that this call site got inlined for the client of - // InlineFunction. - IFI.InlinedCalls.push_back(NewCall); - - // It's possible that inlining the callsite will cause it to go from an - // indirect to a direct call by resolving a function pointer. If this - // happens, set the callee of the new call site to a more precise - // destination. This can also happen if the call graph node of the caller - // was just unnecessarily imprecise. - if (!I->second->getFunction()) - if (Function *F = NewCall->getCalledFunction()) { - // Indirect call site resolved to direct call. - CallerNode->addCalledFunction(NewCall, CG[F]); - - continue; - } - - CallerNode->addCalledFunction(NewCall, I->second); - } - - // Update the call graph by deleting the edge from Callee to Caller. We must - // do this after the loop above in case Caller and Callee are the same. 
- CallerNode->removeCallEdgeFor(*cast<CallBase>(&CB)); -} - static void HandleByValArgumentInit(Type *ByValType, Value *Dst, Value *Src, Module *M, BasicBlock *InsertBlock, - InlineFunctionInfo &IFI) { + InlineFunctionInfo &IFI, + Function *CalledFunc) { IRBuilder<> Builder(InsertBlock, InsertBlock->begin()); Value *Size = @@ -1546,8 +1462,15 @@ static void HandleByValArgumentInit(Type *ByValType, Value *Dst, Value *Src, // Always generate a memcpy of alignment 1 here because we don't know // the alignment of the src pointer. Other optimizations can infer // better alignment. - Builder.CreateMemCpy(Dst, /*DstAlign*/ Align(1), Src, - /*SrcAlign*/ Align(1), Size); + CallInst *CI = Builder.CreateMemCpy(Dst, /*DstAlign*/ Align(1), Src, + /*SrcAlign*/ Align(1), Size); + + // The verifier requires that all calls of debug-info-bearing functions + // from debug-info-bearing functions have a debug location (for inlining + // purposes). Assign a dummy location to satisfy the constraint. + if (!CI->getDebugLoc() && InsertBlock->getParent()->getSubprogram()) + if (DISubprogram *SP = CalledFunc->getSubprogram()) + CI->setDebugLoc(DILocation::get(SP->getContext(), 0, 0, SP)); } /// When inlining a call site that has a byval argument, @@ -1557,8 +1480,6 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, const Function *CalledFunc, InlineFunctionInfo &IFI, MaybeAlign ByValAlignment) { - assert(cast<PointerType>(Arg->getType()) - ->isOpaqueOrPointeeTypeMatches(ByValType)); Function *Caller = TheCall->getFunction(); const DataLayout &DL = Caller->getParent()->getDataLayout(); @@ -1710,6 +1631,12 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, if (allocaWouldBeStaticInEntry(AI)) continue; + // Do not force a debug loc for pseudo probes, since they do not need to + // be debuggable, and also they are expected to have a zero/null dwarf + // discriminator at this point which could be violated otherwise. + if (isa<PseudoProbeInst>(BI)) + continue; + BI->setDebugLoc(TheCallDL); } @@ -2242,7 +2169,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Inject byval arguments initialization. for (ByValInit &Init : ByValInits) HandleByValArgumentInit(Init.Ty, Init.Dst, Init.Src, Caller->getParent(), - &*FirstNewBlock, IFI); + &*FirstNewBlock, IFI, CalledFunc); std::optional<OperandBundleUse> ParentDeopt = CB.getOperandBundle(LLVMContext::OB_deopt); @@ -2292,10 +2219,6 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, } } - // Update the callgraph if requested. - if (IFI.CG) - UpdateCallGraphAfterInlining(CB, FirstNewBlock, VMap, IFI); - // For 'nodebug' functions, the associated DISubprogram is always null. // Conservatively avoid propagating the callsite debug location to // instructions inlined from a function whose DISubprogram is not null. @@ -2333,7 +2256,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, for (BasicBlock &NewBlock : make_range(FirstNewBlock->getIterator(), Caller->end())) for (Instruction &I : NewBlock) - if (auto *II = dyn_cast<CondGuardInst>(&I)) + if (auto *II = dyn_cast<AssumeInst>(&I)) IFI.GetAssumptionCache(*Caller).registerAssumption(II); } @@ -2701,7 +2624,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // call graph updates weren't requested, as those provide value handle based // tracking of inlined call sites instead. Calls to intrinsics are not // collected because they are not inlineable. 
- if (InlinedFunctionInfo.ContainsCalls && !IFI.CG) { + if (InlinedFunctionInfo.ContainsCalls) { // Otherwise just collect the raw call sites that were inlined. for (BasicBlock &NewBB : make_range(FirstNewBlock->getIterator(), Caller->end())) @@ -2734,7 +2657,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (!CB.use_empty()) { ReturnInst *R = Returns[0]; if (&CB == R->getReturnValue()) - CB.replaceAllUsesWith(UndefValue::get(CB.getType())); + CB.replaceAllUsesWith(PoisonValue::get(CB.getType())); else CB.replaceAllUsesWith(R->getReturnValue()); } @@ -2846,7 +2769,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // using the return value of the call with the computed value. if (!CB.use_empty()) { if (&CB == Returns[0]->getReturnValue()) - CB.replaceAllUsesWith(UndefValue::get(CB.getType())); + CB.replaceAllUsesWith(PoisonValue::get(CB.getType())); else CB.replaceAllUsesWith(Returns[0]->getReturnValue()); } diff --git a/llvm/lib/Transforms/Utils/InstructionNamer.cpp b/llvm/lib/Transforms/Utils/InstructionNamer.cpp index f3499c9c8aed..3ae570cfeb77 100644 --- a/llvm/lib/Transforms/Utils/InstructionNamer.cpp +++ b/llvm/lib/Transforms/Utils/InstructionNamer.cpp @@ -17,9 +17,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/Utils.h" using namespace llvm; @@ -41,35 +38,7 @@ void nameInstructions(Function &F) { } } -struct InstNamer : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - InstNamer() : FunctionPass(ID) { - initializeInstNamerPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &Info) const override { - Info.setPreservesAll(); - } - - bool runOnFunction(Function &F) override { - nameInstructions(F); - return true; - } -}; - - char InstNamer::ID = 0; - } // namespace - -INITIALIZE_PASS(InstNamer, "instnamer", - "Assign names to anonymous instructions", false, false) -char &llvm::InstructionNamerID = InstNamer::ID; -//===----------------------------------------------------------------------===// -// -// InstructionNamer - Give any unnamed non-void instructions "tmp" names. -// -FunctionPass *llvm::createInstructionNamerPass() { - return new InstNamer(); -} +} // namespace PreservedAnalyses InstructionNamerPass::run(Function &F, FunctionAnalysisManager &FAM) { diff --git a/llvm/lib/Transforms/Utils/LCSSA.cpp b/llvm/lib/Transforms/Utils/LCSSA.cpp index af79dc456ea6..c36b0533580b 100644 --- a/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -40,7 +40,6 @@ #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PredIteratorCache.h" @@ -77,15 +76,14 @@ static bool isExitBlock(BasicBlock *BB, /// rewrite the uses. 
bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, const DominatorTree &DT, const LoopInfo &LI, - ScalarEvolution *SE, IRBuilderBase &Builder, - SmallVectorImpl<PHINode *> *PHIsToRemove) { + ScalarEvolution *SE, + SmallVectorImpl<PHINode *> *PHIsToRemove, + SmallVectorImpl<PHINode *> *InsertedPHIs) { SmallVector<Use *, 16> UsesToRewrite; SmallSetVector<PHINode *, 16> LocalPHIsToRemove; PredIteratorCache PredCache; bool Changed = false; - IRBuilderBase::InsertPointGuard InsertPtGuard(Builder); - // Cache the Loop ExitBlocks across this loop. We expect to get a lot of // instructions within the same loops, computing the exit blocks is // expensive, and we're not mutating the loop structure. @@ -146,17 +144,14 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, SmallVector<PHINode *, 16> AddedPHIs; SmallVector<PHINode *, 8> PostProcessPHIs; - SmallVector<PHINode *, 4> InsertedPHIs; - SSAUpdater SSAUpdate(&InsertedPHIs); + SmallVector<PHINode *, 4> LocalInsertedPHIs; + SSAUpdater SSAUpdate(&LocalInsertedPHIs); SSAUpdate.Initialize(I->getType(), I->getName()); - // Force re-computation of I, as some users now need to use the new PHI - // node. - if (SE) - SE->forgetValue(I); - // Insert the LCSSA phi's into all of the exit blocks dominated by the // value, and add them to the Phi's map. + bool HasSCEV = SE && SE->isSCEVable(I->getType()) && + SE->getExistingSCEV(I) != nullptr; for (BasicBlock *ExitBB : ExitBlocks) { if (!DT.dominates(DomNode, DT.getNode(ExitBB))) continue; @@ -164,9 +159,10 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, // If we already inserted something for this BB, don't reprocess it. if (SSAUpdate.HasValueForBlock(ExitBB)) continue; - Builder.SetInsertPoint(&ExitBB->front()); - PHINode *PN = Builder.CreatePHI(I->getType(), PredCache.size(ExitBB), - I->getName() + ".lcssa"); + PHINode *PN = PHINode::Create(I->getType(), PredCache.size(ExitBB), + I->getName() + ".lcssa", &ExitBB->front()); + if (InsertedPHIs) + InsertedPHIs->push_back(PN); // Get the debug location from the original instruction. PN->setDebugLoc(I->getDebugLoc()); @@ -203,6 +199,13 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, if (auto *OtherLoop = LI.getLoopFor(ExitBB)) if (!L->contains(OtherLoop)) PostProcessPHIs.push_back(PN); + + // If we have a cached SCEV for the original instruction, make sure the + // new LCSSA phi node is also cached. This makes sures that BECounts + // based on it will be invalidated when the LCSSA phi node is invalidated, + // which some passes rely on. + if (HasSCEV) + SE->getSCEV(PN); } // Rewrite all uses outside the loop in terms of the new PHIs we just @@ -256,10 +259,12 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, // SSAUpdater might have inserted phi-nodes inside other loops. We'll need // to post-process them to keep LCSSA form. 
- for (PHINode *InsertedPN : InsertedPHIs) { + for (PHINode *InsertedPN : LocalInsertedPHIs) { if (auto *OtherLoop = LI.getLoopFor(InsertedPN->getParent())) if (!L->contains(OtherLoop)) PostProcessPHIs.push_back(InsertedPN); + if (InsertedPHIs) + InsertedPHIs->push_back(InsertedPN); } // Post process PHI instructions that were inserted into another disjoint @@ -392,14 +397,7 @@ bool llvm::formLCSSA(Loop &L, const DominatorTree &DT, const LoopInfo *LI, } } - IRBuilder<> Builder(L.getHeader()->getContext()); - Changed = formLCSSAForInstructions(Worklist, DT, *LI, SE, Builder); - - // If we modified the code, remove any caches about the loop from SCEV to - // avoid dangling entries. - // FIXME: This is a big hammer, can we clear the cache more selectively? - if (SE && Changed) - SE->forgetLoop(&L); + Changed = formLCSSAForInstructions(Worklist, DT, *LI, SE); assert(L.isLCSSAForm(DT)); diff --git a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 5dd469c7af4b..cdcfb5050bff 100644 --- a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -28,6 +28,7 @@ #include "llvm/Transforms/Utils/LibCallsShrinkWrap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Constants.h" @@ -37,8 +38,6 @@ #include "llvm/IR/InstVisitor.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <cmath> @@ -51,31 +50,10 @@ STATISTIC(NumWrappedOneCond, "Number of One-Condition Wrappers Inserted"); STATISTIC(NumWrappedTwoCond, "Number of Two-Condition Wrappers Inserted"); namespace { -class LibCallsShrinkWrapLegacyPass : public FunctionPass { -public: - static char ID; // Pass identification, replacement for typeid - explicit LibCallsShrinkWrapLegacyPass() : FunctionPass(ID) { - initializeLibCallsShrinkWrapLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; -}; -} - -char LibCallsShrinkWrapLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LibCallsShrinkWrapLegacyPass, "libcalls-shrinkwrap", - "Conditionally eliminate dead library calls", false, - false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(LibCallsShrinkWrapLegacyPass, "libcalls-shrinkwrap", - "Conditionally eliminate dead library calls", false, false) - -namespace { class LibCallsShrinkWrap : public InstVisitor<LibCallsShrinkWrap> { public: - LibCallsShrinkWrap(const TargetLibraryInfo &TLI, DominatorTree *DT) - : TLI(TLI), DT(DT){}; + LibCallsShrinkWrap(const TargetLibraryInfo &TLI, DomTreeUpdater &DTU) + : TLI(TLI), DTU(DTU){}; void visitCallInst(CallInst &CI) { checkCandidate(CI); } bool perform() { bool Changed = false; @@ -101,14 +79,21 @@ private: Value *generateTwoRangeCond(CallInst *CI, const LibFunc &Func); Value *generateCondForPow(CallInst *CI, const LibFunc &Func); + // Create an OR of two conditions with given Arg and Arg2. 
+ Value *createOrCond(CallInst *CI, Value *Arg, CmpInst::Predicate Cmp, + float Val, Value *Arg2, CmpInst::Predicate Cmp2, + float Val2) { + IRBuilder<> BBBuilder(CI); + auto Cond2 = createCond(BBBuilder, Arg2, Cmp2, Val2); + auto Cond1 = createCond(BBBuilder, Arg, Cmp, Val); + return BBBuilder.CreateOr(Cond1, Cond2); + } + // Create an OR of two conditions. Value *createOrCond(CallInst *CI, CmpInst::Predicate Cmp, float Val, CmpInst::Predicate Cmp2, float Val2) { - IRBuilder<> BBBuilder(CI); Value *Arg = CI->getArgOperand(0); - auto Cond2 = createCond(BBBuilder, Arg, Cmp2, Val2); - auto Cond1 = createCond(BBBuilder, Arg, Cmp, Val); - return BBBuilder.CreateOr(Cond1, Cond2); + return createOrCond(CI, Arg, Cmp, Val, Arg, Cmp2, Val2); } // Create a single condition using IRBuilder. @@ -117,18 +102,26 @@ private: Constant *V = ConstantFP::get(BBBuilder.getContext(), APFloat(Val)); if (!Arg->getType()->isFloatTy()) V = ConstantExpr::getFPExtend(V, Arg->getType()); + if (BBBuilder.GetInsertBlock()->getParent()->hasFnAttribute(Attribute::StrictFP)) + BBBuilder.setIsFPConstrained(true); return BBBuilder.CreateFCmp(Cmp, Arg, V); } + // Create a single condition with given Arg. + Value *createCond(CallInst *CI, Value *Arg, CmpInst::Predicate Cmp, + float Val) { + IRBuilder<> BBBuilder(CI); + return createCond(BBBuilder, Arg, Cmp, Val); + } + // Create a single condition. Value *createCond(CallInst *CI, CmpInst::Predicate Cmp, float Val) { - IRBuilder<> BBBuilder(CI); Value *Arg = CI->getArgOperand(0); - return createCond(BBBuilder, Arg, Cmp, Val); + return createCond(CI, Arg, Cmp, Val); } const TargetLibraryInfo &TLI; - DominatorTree *DT; + DomTreeUpdater &DTU; SmallVector<CallInst *, 16> WorkList; }; } // end anonymous namespace @@ -428,7 +421,6 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, Value *Base = CI->getArgOperand(0); Value *Exp = CI->getArgOperand(1); - IRBuilder<> BBBuilder(CI); // Constant Base case. if (ConstantFP *CF = dyn_cast<ConstantFP>(Base)) { @@ -439,10 +431,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, } ++NumWrappedOneCond; - Constant *V = ConstantFP::get(CI->getContext(), APFloat(127.0f)); - if (!Exp->getType()->isFloatTy()) - V = ConstantExpr::getFPExtend(V, Exp->getType()); - return BBBuilder.CreateFCmp(CmpInst::FCMP_OGT, Exp, V); + return createCond(CI, Exp, CmpInst::FCMP_OGT, 127.0f); } // If the Base value coming from an integer type. 
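Aside (illustrative, not part of the patch): a minimal C++ sketch of the guarded shape that the refactored createOrCond/createCond helpers above are building for a dead-result pow call whose base comes from an integer conversion; the hunk that follows shows the corresponding createOrCond call. The function name shrinkWrappedPow and the UpperV parameter are placeholders here; in the pass, UpperV is derived from the width of the integer operand.

// Illustrative sketch only. The library call is kept solely for its possible
// errno side effect, so it is executed only on the errno-prone part of the
// domain and skipped in the common in-range case.
#include <cmath>

void shrinkWrappedPow(float Base, float Exp, float UpperV) {
  // Mirrors createOrCond(CI, Base, FCMP_OLE, 0.0f, Exp, FCMP_OGT, UpperV).
  if (Base <= 0.0f || Exp > UpperV)
    (void)std::pow(Base, Exp); // only observable effect is a possible errno
}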
@@ -467,16 +456,8 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, } ++NumWrappedTwoCond; - Constant *V = ConstantFP::get(CI->getContext(), APFloat(UpperV)); - Constant *V0 = ConstantFP::get(CI->getContext(), APFloat(0.0f)); - if (!Exp->getType()->isFloatTy()) - V = ConstantExpr::getFPExtend(V, Exp->getType()); - if (!Base->getType()->isFloatTy()) - V0 = ConstantExpr::getFPExtend(V0, Exp->getType()); - - Value *Cond = BBBuilder.CreateFCmp(CmpInst::FCMP_OGT, Exp, V); - Value *Cond0 = BBBuilder.CreateFCmp(CmpInst::FCMP_OLE, Base, V0); - return BBBuilder.CreateOr(Cond0, Cond); + return createOrCond(CI, Base, CmpInst::FCMP_OLE, 0.0f, Exp, + CmpInst::FCMP_OGT, UpperV); } LLVM_DEBUG(dbgs() << "Not handled pow(): base not from integer convert\n"); return nullptr; @@ -489,7 +470,7 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { MDBuilder(CI->getContext()).createBranchWeights(1, 2000); Instruction *NewInst = - SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, DT); + SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, &DTU); BasicBlock *CallBB = NewInst->getParent(); CallBB->setName("cdce.call"); BasicBlock *SuccBB = CallBB->getSingleSuccessor(); @@ -515,40 +496,21 @@ bool LibCallsShrinkWrap::perform(CallInst *CI) { return performCallErrors(CI, Func); } -void LibCallsShrinkWrapLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); -} - static bool runImpl(Function &F, const TargetLibraryInfo &TLI, DominatorTree *DT) { if (F.hasFnAttribute(Attribute::OptimizeForSize)) return false; - LibCallsShrinkWrap CCDCE(TLI, DT); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + LibCallsShrinkWrap CCDCE(TLI, DTU); CCDCE.visit(F); bool Changed = CCDCE.perform(); -// Verify the dominator after we've updated it locally. - assert(!DT || DT->verify(DominatorTree::VerificationLevel::Fast)); + // Verify the dominator after we've updated it locally. + assert(!DT || + DTU.getDomTree().verify(DominatorTree::VerificationLevel::Fast)); return Changed; } -bool LibCallsShrinkWrapLegacyPass::runOnFunction(Function &F) { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; - return runImpl(F, TLI, DT); -} - -namespace llvm { -char &LibCallsShrinkWrapPassID = LibCallsShrinkWrapLegacyPass::ID; - -// Public interface to LibCallsShrinkWrap pass. 
-FunctionPass *createLibCallsShrinkWrapPass() { - return new LibCallsShrinkWrapLegacyPass(); -} - PreservedAnalyses LibCallsShrinkWrapPass::run(Function &F, FunctionAnalysisManager &FAM) { auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); @@ -559,4 +521,3 @@ PreservedAnalyses LibCallsShrinkWrapPass::run(Function &F, PA.preserve<DominatorTreeAnalysis>(); return PA; } -} diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 31cdd2ee56b9..f153ace5d3fc 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -25,7 +25,6 @@ #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/DomTreeUpdater.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemorySSAUpdater.h" @@ -47,6 +46,7 @@ #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalObject.h" @@ -201,16 +201,16 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, bool Changed = false; // Figure out which case it goes to. - for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) { + for (auto It = SI->case_begin(), End = SI->case_end(); It != End;) { // Found case matching a constant operand? - if (i->getCaseValue() == CI) { - TheOnlyDest = i->getCaseSuccessor(); + if (It->getCaseValue() == CI) { + TheOnlyDest = It->getCaseSuccessor(); break; } // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. - if (i->getCaseSuccessor() == DefaultDest) { + if (It->getCaseSuccessor() == DefaultDest) { MDNode *MD = getValidBranchWeightMDNode(*SI); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches @@ -221,11 +221,11 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, extractBranchWeights(MD, Weights); // Merge weight of this case to the default weight. - unsigned idx = i->getCaseIndex(); + unsigned Idx = It->getCaseIndex(); // TODO: Add overflow check. - Weights[0] += Weights[idx+1]; + Weights[0] += Weights[Idx + 1]; // Remove weight for this case. - std::swap(Weights[idx+1], Weights.back()); + std::swap(Weights[Idx + 1], Weights.back()); Weights.pop_back(); SI->setMetadata(LLVMContext::MD_prof, MDBuilder(BB->getContext()). @@ -234,14 +234,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Remove this entry. BasicBlock *ParentBB = SI->getParent(); DefaultDest->removePredecessor(ParentBB); - i = SI->removeCase(i); - e = SI->case_end(); + It = SI->removeCase(It); + End = SI->case_end(); // Removing this case may have made the condition constant. In that // case, update CI and restart iteration through the cases. if (auto *NewCI = dyn_cast<ConstantInt>(SI->getCondition())) { CI = NewCI; - i = SI->case_begin(); + It = SI->case_begin(); } Changed = true; @@ -251,11 +251,11 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Otherwise, check to see if the switch only branches to one destination. // We do this by reseting "TheOnlyDest" to null when we find two non-equal // destinations. 
- if (i->getCaseSuccessor() != TheOnlyDest) + if (It->getCaseSuccessor() != TheOnlyDest) TheOnlyDest = nullptr; // Increment this iterator as we haven't removed the case. - ++i; + ++It; } if (CI && !TheOnlyDest) { @@ -424,18 +424,10 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, if (I->isEHPad()) return false; - // We don't want debug info removed by anything this general, unless - // debug info is empty. - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(I)) { - if (DDI->getAddress()) - return false; - return true; - } - if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(I)) { - if (DVI->hasArgList() || DVI->getValue(0)) - return false; - return true; - } + // We don't want debug info removed by anything this general. + if (isa<DbgVariableIntrinsic>(I)) + return false; + if (DbgLabelInst *DLI = dyn_cast<DbgLabelInst>(I)) { if (DLI->getLabel()) return false; @@ -555,7 +547,7 @@ bool llvm::RecursivelyDeleteTriviallyDeadInstructionsPermissive( std::function<void(Value *)> AboutToDeleteCallback) { unsigned S = 0, E = DeadInsts.size(), Alive = 0; for (; S != E; ++S) { - auto *I = dyn_cast<Instruction>(DeadInsts[S]); + auto *I = dyn_cast_or_null<Instruction>(DeadInsts[S]); if (!I || !isInstructionTriviallyDead(I)) { DeadInsts[S] = nullptr; ++Alive; @@ -1231,12 +1223,10 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, // If the unconditional branch we replaced contains llvm.loop metadata, we // add the metadata to the branch instructions in the predecessors. - unsigned LoopMDKind = BB->getContext().getMDKindID("llvm.loop"); - Instruction *TI = BB->getTerminator(); - if (TI) - if (MDNode *LoopMD = TI->getMetadata(LoopMDKind)) + if (Instruction *TI = BB->getTerminator()) + if (MDNode *LoopMD = TI->getMetadata(LLVMContext::MD_loop)) for (BasicBlock *Pred : predecessors(BB)) - Pred->getTerminator()->setMetadata(LoopMDKind, LoopMD); + Pred->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopMD); // Everything that jumped to BB now goes to Succ. BB->replaceAllUsesWith(Succ); @@ -1423,6 +1413,12 @@ static Align tryEnforceAlignment(Value *V, Align PrefAlign, if (!GO->canIncreaseAlignment()) return CurrentAlign; + if (GO->isThreadLocal()) { + unsigned MaxTLSAlign = GO->getParent()->getMaxTLSAlignment() / CHAR_BIT; + if (MaxTLSAlign && PrefAlign > Align(MaxTLSAlign)) + PrefAlign = Align(MaxTLSAlign); + } + GO->setAlignment(PrefAlign); return PrefAlign; } @@ -1480,19 +1476,16 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, /// (or fragment of the variable) described by \p DII. /// /// This is primarily intended as a helper for the different -/// ConvertDebugDeclareToDebugValue functions. The dbg.declare/dbg.addr that is -/// converted describes an alloca'd variable, so we need to use the -/// alloc size of the value when doing the comparison. E.g. an i1 value will be -/// identified as covering an n-bit fragment, if the store size of i1 is at -/// least n bits. +/// ConvertDebugDeclareToDebugValue functions. The dbg.declare that is converted +/// describes an alloca'd variable, so we need to use the alloc size of the +/// value when doing the comparison. E.g. an i1 value will be identified as +/// covering an n-bit fragment, if the store size of i1 is at least n bits. 
static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { const DataLayout &DL = DII->getModule()->getDataLayout(); TypeSize ValueSize = DL.getTypeAllocSizeInBits(ValTy); - if (std::optional<uint64_t> FragmentSize = DII->getFragmentSizeInBits()) { - assert(!ValueSize.isScalable() && - "Fragments don't work on scalable types."); - return ValueSize.getFixedValue() >= *FragmentSize; - } + if (std::optional<uint64_t> FragmentSize = DII->getFragmentSizeInBits()) + return TypeSize::isKnownGE(ValueSize, TypeSize::getFixed(*FragmentSize)); + // We can't always calculate the size of the DI variable (e.g. if it is a // VLA). Try to use the size of the alloca that the dbg intrinsic describes // intead. @@ -1513,7 +1506,7 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { } /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value -/// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. +/// that has an associated llvm.dbg.declare intrinsic. void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, StoreInst *SI, DIBuilder &Builder) { assert(DII->isAddressOfVariable() || isa<DbgAssignIntrinsic>(DII)); @@ -1524,24 +1517,39 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, DebugLoc NewLoc = getDebugValueLoc(DII); - if (!valueCoversEntireFragment(DV->getType(), DII)) { - // FIXME: If storing to a part of the variable described by the dbg.declare, - // then we want to insert a dbg.value for the corresponding fragment. - LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " - << *DII << '\n'); - // For now, when there is a store to parts of the variable (but we do not - // know which part) we insert an dbg.value intrinsic to indicate that we - // know nothing about the variable's content. - DV = UndefValue::get(DV->getType()); + // If the alloca describes the variable itself, i.e. the expression in the + // dbg.declare doesn't start with a dereference, we can perform the + // conversion if the value covers the entire fragment of DII. + // If the alloca describes the *address* of DIVar, i.e. DIExpr is + // *just* a DW_OP_deref, we use DV as is for the dbg.value. + // We conservatively ignore other dereferences, because the following two are + // not equivalent: + // dbg.declare(alloca, ..., !Expr(deref, plus_uconstant, 2)) + // dbg.value(DV, ..., !Expr(deref, plus_uconstant, 2)) + // The former is adding 2 to the address of the variable, whereas the latter + // is adding 2 to the value of the variable. As such, we insist on just a + // deref expression. + bool CanConvert = + DIExpr->isDeref() || (!DIExpr->startsWithDeref() && + valueCoversEntireFragment(DV->getType(), DII)); + if (CanConvert) { Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI); return; } + // FIXME: If storing to a part of the variable described by the dbg.declare, + // then we want to insert a dbg.value for the corresponding fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " << *DII + << '\n'); + // For now, when there is a store to parts of the variable (but we do not + // know which part) we insert an dbg.value intrinsic to indicate that we + // know nothing about the variable's content. + DV = UndefValue::get(DV->getType()); Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI); } /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value -/// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. 
+/// that has an associated llvm.dbg.declare intrinsic. void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, LoadInst *LI, DIBuilder &Builder) { auto *DIVar = DII->getVariable(); @@ -1569,7 +1577,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, } /// Inserts a llvm.dbg.value intrinsic after a phi that has an associated -/// llvm.dbg.declare or llvm.dbg.addr intrinsic. +/// llvm.dbg.declare intrinsic. void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, PHINode *APN, DIBuilder &Builder) { auto *DIVar = DII->getVariable(); @@ -1752,8 +1760,8 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB, bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, DIBuilder &Builder, uint8_t DIExprFlags, int Offset) { - auto DbgAddrs = FindDbgAddrUses(Address); - for (DbgVariableIntrinsic *DII : DbgAddrs) { + auto DbgDeclares = FindDbgDeclareUses(Address); + for (DbgVariableIntrinsic *DII : DbgDeclares) { const DebugLoc &Loc = DII->getDebugLoc(); auto *DIVar = DII->getVariable(); auto *DIExpr = DII->getExpression(); @@ -1764,7 +1772,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, Builder.insertDeclare(NewAddress, DIVar, DIExpr, Loc, DII); DII->eraseFromParent(); } - return !DbgAddrs.empty(); + return !DbgDeclares.empty(); } static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, @@ -1860,9 +1868,8 @@ void llvm::salvageDebugInfoForDbgValues( continue; } - // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they - // are implicitly pointing out the value as a DWARF memory location - // description. + // Do not add DW_OP_stack_value for DbgDeclare, because they are implicitly + // pointing out the value as a DWARF memory location description. bool StackValue = isa<DbgValueInst>(DII); auto DIILocation = DII->location_ops(); assert( @@ -1896,17 +1903,14 @@ void llvm::salvageDebugInfoForDbgValues( bool IsValidSalvageExpr = SalvagedExpr->getNumElements() <= MaxExpressionSize; if (AdditionalValues.empty() && IsValidSalvageExpr) { DII->setExpression(SalvagedExpr); - } else if (isa<DbgValueInst>(DII) && !isa<DbgAssignIntrinsic>(DII) && - IsValidSalvageExpr && + } else if (isa<DbgValueInst>(DII) && IsValidSalvageExpr && DII->getNumVariableLocationOps() + AdditionalValues.size() <= MaxDebugArgs) { DII->addVariableLocationOps(AdditionalValues, SalvagedExpr); } else { - // Do not salvage using DIArgList for dbg.addr/dbg.declare, as it is - // not currently supported in those instructions. Do not salvage using - // DIArgList for dbg.assign yet. FIXME: support this. - // Also do not salvage if the resulting DIArgList would contain an - // unreasonably large number of values. + // Do not salvage using DIArgList for dbg.declare, as it is not currently + // supported in those instructions. Also do not salvage if the resulting + // DIArgList would contain an unreasonably large number of values. 
DII->setKillLocation(); } LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); @@ -1934,7 +1938,7 @@ Value *getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL, Opcodes.insert(Opcodes.begin(), {dwarf::DW_OP_LLVM_arg, 0}); CurrentLocOps = 1; } - for (auto Offset : VariableOffsets) { + for (const auto &Offset : VariableOffsets) { AdditionalValues.push_back(Offset.first); assert(Offset.second.isStrictlyPositive() && "Expected strictly positive multiplier for offset."); @@ -1976,6 +1980,18 @@ uint64_t getDwarfOpForBinOp(Instruction::BinaryOps Opcode) { } } +static void handleSSAValueOperands(uint64_t CurrentLocOps, + SmallVectorImpl<uint64_t> &Opcodes, + SmallVectorImpl<Value *> &AdditionalValues, + Instruction *I) { + if (!CurrentLocOps) { + Opcodes.append({dwarf::DW_OP_LLVM_arg, 0}); + CurrentLocOps = 1; + } + Opcodes.append({dwarf::DW_OP_LLVM_arg, CurrentLocOps}); + AdditionalValues.push_back(I->getOperand(1)); +} + Value *getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps, SmallVectorImpl<uint64_t> &Opcodes, SmallVectorImpl<Value *> &AdditionalValues) { @@ -1998,12 +2014,7 @@ Value *getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps, } Opcodes.append({dwarf::DW_OP_constu, Val}); } else { - if (!CurrentLocOps) { - Opcodes.append({dwarf::DW_OP_LLVM_arg, 0}); - CurrentLocOps = 1; - } - Opcodes.append({dwarf::DW_OP_LLVM_arg, CurrentLocOps}); - AdditionalValues.push_back(BI->getOperand(1)); + handleSSAValueOperands(CurrentLocOps, Opcodes, AdditionalValues, BI); } // Add salvaged binary operator to expression stack, if it has a valid @@ -2015,6 +2026,60 @@ Value *getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps, return BI->getOperand(0); } +uint64_t getDwarfOpForIcmpPred(CmpInst::Predicate Pred) { + // The signedness of the operation is implicit in the typed stack, signed and + // unsigned instructions map to the same DWARF opcode. + switch (Pred) { + case CmpInst::ICMP_EQ: + return dwarf::DW_OP_eq; + case CmpInst::ICMP_NE: + return dwarf::DW_OP_ne; + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_SGT: + return dwarf::DW_OP_gt; + case CmpInst::ICMP_UGE: + case CmpInst::ICMP_SGE: + return dwarf::DW_OP_ge; + case CmpInst::ICMP_ULT: + case CmpInst::ICMP_SLT: + return dwarf::DW_OP_lt; + case CmpInst::ICMP_ULE: + case CmpInst::ICMP_SLE: + return dwarf::DW_OP_le; + default: + return 0; + } +} + +Value *getSalvageOpsForIcmpOp(ICmpInst *Icmp, uint64_t CurrentLocOps, + SmallVectorImpl<uint64_t> &Opcodes, + SmallVectorImpl<Value *> &AdditionalValues) { + // Handle icmp operations with constant integer operands as a special case. + auto *ConstInt = dyn_cast<ConstantInt>(Icmp->getOperand(1)); + // Values wider than 64 bits cannot be represented within a DIExpression. + if (ConstInt && ConstInt->getBitWidth() > 64) + return nullptr; + // Push any Constant Int operand onto the expression stack. + if (ConstInt) { + if (Icmp->isSigned()) + Opcodes.push_back(dwarf::DW_OP_consts); + else + Opcodes.push_back(dwarf::DW_OP_constu); + uint64_t Val = ConstInt->getSExtValue(); + Opcodes.push_back(Val); + } else { + handleSSAValueOperands(CurrentLocOps, Opcodes, AdditionalValues, Icmp); + } + + // Add salvaged binary operator to expression stack, if it has a valid + // representation in a DIExpression. 
+ uint64_t DwarfIcmpOp = getDwarfOpForIcmpPred(Icmp->getPredicate()); + if (!DwarfIcmpOp) + return nullptr; + Opcodes.push_back(DwarfIcmpOp); + return Icmp->getOperand(0); +} + Value *llvm::salvageDebugInfoImpl(Instruction &I, uint64_t CurrentLocOps, SmallVectorImpl<uint64_t> &Ops, SmallVectorImpl<Value *> &AdditionalValues) { @@ -2054,6 +2119,8 @@ Value *llvm::salvageDebugInfoImpl(Instruction &I, uint64_t CurrentLocOps, return getSalvageOpsForGEP(GEP, DL, CurrentLocOps, Ops, AdditionalValues); if (auto *BI = dyn_cast<BinaryOperator>(&I)) return getSalvageOpsForBinOp(BI, CurrentLocOps, Ops, AdditionalValues); + if (auto *IC = dyn_cast<ICmpInst>(&I)) + return getSalvageOpsForIcmpOp(IC, CurrentLocOps, Ops, AdditionalValues); // *Not* to do: we should not attempt to salvage load instructions, // because the validity and lifetime of a dbg.value containing @@ -2661,43 +2728,52 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, intersectAccessGroups(K, J)); break; case LLVMContext::MD_range: - - // If K does move, use most generic range. Otherwise keep the range of - // K. - if (DoesKMove) - // FIXME: If K does move, we should drop the range info and nonnull. - // Currently this function is used with DoesKMove in passes - // doing hoisting/sinking and the current behavior of using the - // most generic range is correct in those cases. + if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)) K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD)); break; case LLVMContext::MD_fpmath: K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD)); break; case LLVMContext::MD_invariant_load: - // Only set the !invariant.load if it is present in both instructions. - K->setMetadata(Kind, JMD); + // If K moves, only set the !invariant.load if it is present in both + // instructions. + if (DoesKMove) + K->setMetadata(Kind, JMD); break; case LLVMContext::MD_nonnull: - // If K does move, keep nonull if it is present in both instructions. - if (DoesKMove) + if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)) K->setMetadata(Kind, JMD); break; case LLVMContext::MD_invariant_group: // Preserve !invariant.group in K. break; case LLVMContext::MD_align: - K->setMetadata(Kind, - MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); + if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)) + K->setMetadata( + Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); break; case LLVMContext::MD_dereferenceable: case LLVMContext::MD_dereferenceable_or_null: - K->setMetadata(Kind, - MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); + if (DoesKMove) + K->setMetadata(Kind, + MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); break; case LLVMContext::MD_preserve_access_index: // Preserve !preserve.access.index in K. break; + case LLVMContext::MD_noundef: + // If K does move, keep noundef if it is present in both instructions. + if (DoesKMove) + K->setMetadata(Kind, JMD); + break; + case LLVMContext::MD_nontemporal: + // Preserve !nontemporal if it is present on both instructions. + K->setMetadata(Kind, JMD); + break; + case LLVMContext::MD_prof: + if (DoesKMove) + K->setMetadata(Kind, MDNode::getMergedProfMetadata(KMD, JMD, K, J)); + break; } } // Set !invariant.group from J if J has it. 
If both instructions have it @@ -2713,14 +2789,22 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J, bool KDominatesJ) { - unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_range, - LLVMContext::MD_invariant_load, LLVMContext::MD_nonnull, - LLVMContext::MD_invariant_group, LLVMContext::MD_align, - LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index}; + unsigned KnownIDs[] = {LLVMContext::MD_tbaa, + LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, + LLVMContext::MD_range, + LLVMContext::MD_fpmath, + LLVMContext::MD_invariant_load, + LLVMContext::MD_nonnull, + LLVMContext::MD_invariant_group, + LLVMContext::MD_align, + LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_access_group, + LLVMContext::MD_preserve_access_index, + LLVMContext::MD_prof, + LLVMContext::MD_nontemporal, + LLVMContext::MD_noundef}; combineMetadata(K, J, KnownIDs, KDominatesJ); } @@ -2799,13 +2883,7 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) { // In general, GVN unifies expressions over different control-flow // regions, and so we need a conservative combination of the noalias // scopes. - static const unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_range, - LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, - LLVMContext::MD_invariant_group, LLVMContext::MD_nonnull, - LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index}; - combineMetadata(ReplInst, I, KnownIDs, false); + combineMetadataForCSE(ReplInst, I, false); } template <typename RootType, typename DominatesFn> @@ -2930,7 +3008,8 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, return; unsigned BitWidth = DL.getPointerTypeSizeInBits(NewTy); - if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { + if (BitWidth == OldLI.getType()->getScalarSizeInBits() && + !getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { MDNode *NN = MDNode::get(OldLI.getContext(), std::nullopt); NewLI.setMetadata(LLVMContext::MD_nonnull, NN); } @@ -2969,7 +3048,7 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { Instruction *I = &*II; - I->dropUndefImplyingAttrsAndUnknownMetadata(); + I->dropUBImplyingAttrsAndMetadata(); if (I->isUsedByMetadata()) dropDebugUsers(*I); if (I->isDebugOrPseudoInst()) { @@ -3125,7 +3204,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // Check that the mask allows a multiple of 8 bits for a bswap, for an // early exit. 
- unsigned NumMaskedBits = AndMask.countPopulation(); + unsigned NumMaskedBits = AndMask.popcount(); if (!MatchBitReversals && (NumMaskedBits % 8) != 0) return Result; diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp index 2acbe9002309..d701cf110154 100644 --- a/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -345,20 +345,20 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, assert(L.isLoopSimplifyForm() && "Loop needs to be in loop simplify form"); unsigned DesiredPeelCount = 0; - for (auto *BB : L.blocks()) { - auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); - if (!BI || BI->isUnconditional()) - continue; - - // Ignore loop exit condition. - if (L.getLoopLatch() == BB) - continue; + // Do not peel the entire loop. + const SCEV *BE = SE.getConstantMaxBackedgeTakenCount(&L); + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(BE)) + MaxPeelCount = + std::min((unsigned)SC->getAPInt().getLimitedValue() - 1, MaxPeelCount); + + auto ComputePeelCount = [&](Value *Condition) -> void { + if (!Condition->getType()->isIntegerTy()) + return; - Value *Condition = BI->getCondition(); Value *LeftVal, *RightVal; CmpInst::Predicate Pred; if (!match(Condition, m_ICmp(Pred, m_Value(LeftVal), m_Value(RightVal)))) - continue; + return; const SCEV *LeftSCEV = SE.getSCEV(LeftVal); const SCEV *RightSCEV = SE.getSCEV(RightVal); @@ -366,7 +366,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, // Do not consider predicates that are known to be true or false // independently of the loop iteration. if (SE.evaluatePredicate(Pred, LeftSCEV, RightSCEV)) - continue; + return; // Check if we have a condition with one AddRec and one non AddRec // expression. Normalize LeftSCEV to be the AddRec. @@ -375,7 +375,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, std::swap(LeftSCEV, RightSCEV); Pred = ICmpInst::getSwappedPredicate(Pred); } else - continue; + return; } const SCEVAddRecExpr *LeftAR = cast<SCEVAddRecExpr>(LeftSCEV); @@ -383,10 +383,10 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, // Avoid huge SCEV computations in the loop below, make sure we only // consider AddRecs of the loop we are trying to peel. if (!LeftAR->isAffine() || LeftAR->getLoop() != &L) - continue; + return; if (!(ICmpInst::isEquality(Pred) && LeftAR->hasNoSelfWrap()) && !SE.getMonotonicPredicateType(LeftAR, Pred)) - continue; + return; // Check if extending the current DesiredPeelCount lets us evaluate Pred // or !Pred in the loop body statically. @@ -422,7 +422,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, // first iteration of the loop body after peeling? if (!SE.isKnownPredicate(ICmpInst::getInversePredicate(Pred), IterVal, RightSCEV)) - continue; // If not, give up. + return; // If not, give up. // However, for equality comparisons, that isn't always sufficient to // eliminate the comparsion in loop body, we may need to peel one more @@ -433,11 +433,28 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, !SE.isKnownPredicate(Pred, IterVal, RightSCEV) && SE.isKnownPredicate(Pred, NextIterVal, RightSCEV)) { if (!CanPeelOneMoreIteration()) - continue; // Need to peel one more iteration, but can't. Give up. + return; // Need to peel one more iteration, but can't. Give up. PeelOneMoreIteration(); // Great! 
} DesiredPeelCount = std::max(DesiredPeelCount, NewPeelCount); + }; + + for (BasicBlock *BB : L.blocks()) { + for (Instruction &I : *BB) { + if (SelectInst *SI = dyn_cast<SelectInst>(&I)) + ComputePeelCount(SI->getCondition()); + } + + auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BI || BI->isUnconditional()) + continue; + + // Ignore loop exit condition. + if (L.getLoopLatch() == BB) + continue; + + ComputePeelCount(BI->getCondition()); } return DesiredPeelCount; @@ -1025,6 +1042,7 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, // We modified the loop, update SE. SE->forgetTopmostLoop(L); + SE->forgetBlockAndLoopDispositions(); #ifdef EXPENSIVE_CHECKS // Finally DomtTree must be correct. diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index 1a9eaf242190..d81db5647c60 100644 --- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -435,6 +435,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // Otherwise, create a duplicate of the instruction. Instruction *C = Inst->clone(); + C->insertBefore(LoopEntryBranch); + ++NumInstrsDuplicated; // Eagerly remap the operands of the instruction. @@ -444,7 +446,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // Avoid inserting the same intrinsic twice. if (auto *DII = dyn_cast<DbgVariableIntrinsic>(C)) if (DbgIntrinsics.count(makeHash(DII))) { - C->deleteValue(); + C->eraseFromParent(); continue; } @@ -457,7 +459,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // in the map. InsertNewValueIntoMap(ValueMap, Inst, V); if (!C->mayHaveSideEffects()) { - C->deleteValue(); + C->eraseFromParent(); C = nullptr; } } else { @@ -466,7 +468,6 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { if (C) { // Otherwise, stick the new instruction into the new block! C->setName(Inst->getName()); - C->insertBefore(LoopEntryBranch); if (auto *II = dyn_cast<AssumeInst>(C)) AC->registerAssumption(II); diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 87a0e54e2704..3e604fdf2e11 100644 --- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -448,16 +448,15 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, // backedge blocks to jump to the BEBlock instead of the header. // If one of the backedges has llvm.loop metadata attached, we remove // it from the backedge and add it to BEBlock. - unsigned LoopMDKind = BEBlock->getContext().getMDKindID("llvm.loop"); MDNode *LoopMD = nullptr; for (BasicBlock *BB : BackedgeBlocks) { Instruction *TI = BB->getTerminator(); if (!LoopMD) - LoopMD = TI->getMetadata(LoopMDKind); - TI->setMetadata(LoopMDKind, nullptr); + LoopMD = TI->getMetadata(LLVMContext::MD_loop); + TI->setMetadata(LLVMContext::MD_loop, nullptr); TI->replaceSuccessorWith(Header, BEBlock); } - BEBlock->getTerminator()->setMetadata(LoopMDKind, LoopMD); + BEBlock->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopMD); //===--- Update all analyses which we must preserve now -----------------===// @@ -693,12 +692,6 @@ ReprocessLoop: } } - // Changing exit conditions for blocks may affect exit counts of this loop and - // any of its paretns, so we must invalidate the entire subtree if we've made - // any changes. 
- if (Changed && SE) - SE->forgetTopmostLoop(L); - if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -737,6 +730,13 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Changed |= simplifyOneLoop(Worklist.pop_back_val(), Worklist, DT, LI, SE, AC, MSSAU, PreserveLCSSA); + // Changing exit conditions for blocks may affect exit counts of this loop and + // any of its parents, so we must invalidate the entire subtree if we've made + // any changes. Do this here rather than in simplifyOneLoop() as the top-most + // loop is going to be the same for all child loops. + if (Changed && SE) + SE->forgetTopmostLoop(L); + return Changed; } diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index e8f585b4a94d..511dd61308f9 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -45,6 +45,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/ValueHandle.h" @@ -216,6 +217,8 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC, const TargetTransformInfo *TTI) { + using namespace llvm::PatternMatch; + // Simplify any new induction variables in the partially unrolled loop. if (SE && SimplifyIVs) { SmallVector<WeakTrackingVH, 16> DeadInsts; @@ -241,6 +244,30 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, Inst.replaceAllUsesWith(V); if (isInstructionTriviallyDead(&Inst)) DeadInsts.emplace_back(&Inst); + + // Fold ((add X, C1), C2) to (add X, C1+C2). This is very common in + // unrolled loops, and handling this early allows following code to + // identify the IV as a "simple recurrence" without first folding away + // a long chain of adds. + { + Value *X; + const APInt *C1, *C2; + if (match(&Inst, m_Add(m_Add(m_Value(X), m_APInt(C1)), m_APInt(C2)))) { + auto *InnerI = dyn_cast<Instruction>(Inst.getOperand(0)); + auto *InnerOBO = cast<OverflowingBinaryOperator>(Inst.getOperand(0)); + bool SignedOverflow; + APInt NewC = C1->sadd_ov(*C2, SignedOverflow); + Inst.setOperand(0, X); + Inst.setOperand(1, ConstantInt::get(Inst.getType(), NewC)); + Inst.setHasNoUnsignedWrap(Inst.hasNoUnsignedWrap() && + InnerOBO->hasNoUnsignedWrap()); + Inst.setHasNoSignedWrap(Inst.hasNoSignedWrap() && + InnerOBO->hasNoSignedWrap() && + !SignedOverflow); + if (InnerI && isInstructionTriviallyDead(InnerI)) + DeadInsts.emplace_back(InnerI); + } + } } // We can't do recursive deletion until we're done iterating, as we might // have a phi which (potentially indirectly) uses instructions later in @@ -310,6 +337,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, const unsigned MaxTripCount = SE->getSmallConstantMaxTripCount(L); const bool MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L); + unsigned EstimatedLoopInvocationWeight = 0; + std::optional<unsigned> OriginalTripCount = + llvm::getLoopEstimatedTripCount(L, &EstimatedLoopInvocationWeight); // Effectively "DCE" unrolled iterations that are beyond the max tripcount // and will never be executed. 
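(Aside on the LoopUnroll.cpp hunk just above: the new fold in simplifyLoopAfterUnroll collapses chained constant adds left behind by unrolling, e.g. add(add(X, 1), 1) becomes add(X, 2), so the induction variable is still recognizable as a simple recurrence without first folding a long add chain. Read in isolation, the transformation amounts to roughly the sketch below; the helper name is illustrative only, and the authoritative version is the hunk above.)

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"

static bool foldChainedConstantAdd(llvm::Instruction &Inst) {
  using namespace llvm;
  using namespace llvm::PatternMatch;
  Value *X;
  const APInt *C1, *C2;
  // Match (add (add X, C1), C2) and rewrite it in place as (add X, C1+C2).
  if (!match(&Inst, m_Add(m_Add(m_Value(X), m_APInt(C1)), m_APInt(C2))))
    return false;
  auto *Inner = cast<OverflowingBinaryOperator>(Inst.getOperand(0));
  bool SignedOverflow;
  APInt NewC = C1->sadd_ov(*C2, SignedOverflow);
  Inst.setOperand(0, X);
  Inst.setOperand(1, ConstantInt::get(Inst.getType(), NewC));
  // nuw/nsw survive only if both adds carried the flag; nsw additionally
  // requires that folding the two constants did not overflow.
  Inst.setHasNoUnsignedWrap(Inst.hasNoUnsignedWrap() && Inner->hasNoUnsignedWrap());
  Inst.setHasNoSignedWrap(Inst.hasNoSignedWrap() && Inner->hasNoSignedWrap() &&
                          !SignedOverflow);
  return true;
}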
@@ -513,7 +543,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, !EnableFSDiscriminator) for (BasicBlock *BB : L->getBlocks()) for (Instruction &I : *BB) - if (!isa<DbgInfoIntrinsic>(&I)) + if (!I.isDebugOrPseudoInst()) if (const DILocation *DIL = I.getDebugLoc()) { auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(ULO.Count); if (NewDIL) @@ -830,8 +860,16 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, Loop *OuterL = L->getParentLoop(); // Update LoopInfo if the loop is completely removed. - if (CompletelyUnroll) + if (CompletelyUnroll) { LI->erase(L); + // We shouldn't try to use `L` anymore. + L = nullptr; + } else if (OriginalTripCount) { + // Update the trip count. Note that the remainder has already logic + // computing it in `UnrollRuntimeLoopRemainder`. + setLoopEstimatedTripCount(L, *OriginalTripCount / ULO.Count, + EstimatedLoopInvocationWeight); + } // LoopInfo should not be valid, confirm that. if (UnrollVerifyLoopInfo) diff --git a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp index b125e952ec94..31b8cd34eb24 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -347,7 +347,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, !EnableFSDiscriminator) for (BasicBlock *BB : L->getBlocks()) for (Instruction &I : *BB) - if (!isa<DbgInfoIntrinsic>(&I)) + if (!I.isDebugOrPseudoInst()) if (const DILocation *DIL = I.getDebugLoc()) { auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count); if (NewDIL) @@ -757,11 +757,11 @@ checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks, DependenceInfo &DI, LoopInfo &LI) { SmallVector<BasicBlockSet, 8> AllBlocks; for (Loop *L : Root.getLoopsInPreorder()) - if (ForeBlocksMap.find(L) != ForeBlocksMap.end()) + if (ForeBlocksMap.contains(L)) AllBlocks.push_back(ForeBlocksMap.lookup(L)); AllBlocks.push_back(SubLoopBlocks); for (Loop *L : Root.getLoopsInPreorder()) - if (AftBlocksMap.find(L) != AftBlocksMap.end()) + if (AftBlocksMap.contains(L)) AllBlocks.push_back(AftBlocksMap.lookup(L)); unsigned LoopDepth = Root.getLoopDepth(); diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index b19156bcb420..1e22eca30d2d 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -457,7 +457,7 @@ static bool canProfitablyUnrollMultiExitLoop( // call. return (OtherExits.size() == 1 && (UnrollRuntimeOtherExitPredictable || - OtherExits[0]->getTerminatingDeoptimizeCall())); + OtherExits[0]->getPostdominatingDeoptimizeCall())); // TODO: These can be fine-tuned further to consider code size or deopt states // that are captured by the deoptimize exit block. 
// Also, we can extend this to support more cases, if we actually diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 7df8651ede15..7d6662c44f07 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -466,6 +466,19 @@ llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) { return Worklist; } +bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) { + int LatchIdx = PN->getBasicBlockIndex(LatchBlock); + Value *IncV = PN->getIncomingValue(LatchIdx); + + for (User *U : PN->users()) + if (U != Cond && U != IncV) return false; + + for (User *U : IncV->users()) + if (U != Cond && U != PN) return false; + return true; +} + + void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, LoopInfo *LI, MemorySSA *MSSA) { assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!"); @@ -628,18 +641,17 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, } // After the loop has been deleted all the values defined and modified - // inside the loop are going to be unavailable. - // Since debug values in the loop have been deleted, inserting an undef - // dbg.value truncates the range of any dbg.value before the loop where the - // loop used to be. This is particularly important for constant values. + // inside the loop are going to be unavailable. Values computed in the + // loop will have been deleted, automatically causing their debug uses + // be be replaced with undef. Loop invariant values will still be available. + // Move dbg.values out the loop so that earlier location ranges are still + // terminated and loop invariant assignments are preserved. Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI(); assert(InsertDbgValueBefore && "There should be a non-PHI instruction in exit block, else these " "instructions will have no parent."); - for (auto *DVI : DeadDebugInst) { - DVI->setKillLocation(); + for (auto *DVI : DeadDebugInst) DVI->moveBefore(InsertDbgValueBefore); - } } // Remove the block from the reference counting scheme, so that we can @@ -880,6 +892,29 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, return true; } +Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) { + switch (RK) { + default: + llvm_unreachable("Unknown min/max recurrence kind"); + case RecurKind::UMin: + return Intrinsic::umin; + case RecurKind::UMax: + return Intrinsic::umax; + case RecurKind::SMin: + return Intrinsic::smin; + case RecurKind::SMax: + return Intrinsic::smax; + case RecurKind::FMin: + return Intrinsic::minnum; + case RecurKind::FMax: + return Intrinsic::maxnum; + case RecurKind::FMinimum: + return Intrinsic::minimum; + case RecurKind::FMaximum: + return Intrinsic::maximum; + } +} + CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { switch (RK) { default: @@ -896,6 +931,9 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { return CmpInst::FCMP_OLT; case RecurKind::FMax: return CmpInst::FCMP_OGT; + // We do not add FMinimum/FMaximum recurrence kind here since there is no + // equivalent predicate which compares signed zeroes according to the + // semantics of the intrinsics (llvm.minimum/maximum). 
} } @@ -910,6 +948,14 @@ Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal, Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, Value *Right) { + Type *Ty = Left->getType(); + if (Ty->isIntOrIntVectorTy() || + (RK == RecurKind::FMinimum || RK == RecurKind::FMaximum)) { + // TODO: Add float minnum/maxnum support when FMF nnan is set. + Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RK); + return Builder.CreateIntrinsic(Ty, Id, {Left, Right}, nullptr, + "rdx.minmax"); + } CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK); Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp"); Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select"); @@ -1055,6 +1101,10 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, return Builder.CreateFPMaxReduce(Src); case RecurKind::FMin: return Builder.CreateFPMinReduce(Src); + case RecurKind::FMinimum: + return Builder.CreateFPMinimumReduce(Src); + case RecurKind::FMaximum: + return Builder.CreateFPMaximumReduce(Src); default: llvm_unreachable("Unhandled opcode"); } @@ -1123,6 +1173,20 @@ bool llvm::isKnownNonNegativeInLoop(const SCEV *S, const Loop *L, SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGE, S, Zero); } +bool llvm::isKnownPositiveInLoop(const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(S->getType()); + return SE.isAvailableAtLoopEntry(S, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGT, S, Zero); +} + +bool llvm::isKnownNonPositiveInLoop(const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(S->getType()); + return SE.isAvailableAtLoopEntry(S, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLE, S, Zero); +} + bool llvm::cannotBeMinInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE, bool Signed) { unsigned BitWidth = cast<IntegerType>(S->getType())->getBitWidth(); diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 17e71cf5a6c4..78ebe75c121b 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -23,7 +23,6 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -31,6 +30,8 @@ using namespace llvm; +#define DEBUG_TYPE "loop-versioning" + static cl::opt<bool> AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true), cl::Hidden, @@ -208,7 +209,7 @@ void LoopVersioning::prepareNoAliasMetadata() { // Finally, transform the above to actually map to scope list which is what // the metadata uses. - for (auto Pair : GroupToNonAliasingScopes) + for (const auto &Pair : GroupToNonAliasingScopes) GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second); } @@ -290,56 +291,6 @@ bool runImpl(LoopInfo *LI, LoopAccessInfoManager &LAIs, DominatorTree *DT, return Changed; } - -/// Also expose this is a pass. Currently this is only used for -/// unit-testing. It adds all memchecks necessary to remove all may-aliasing -/// array accesses from the loop. 
-class LoopVersioningLegacyPass : public FunctionPass { -public: - LoopVersioningLegacyPass() : FunctionPass(ID) { - initializeLoopVersioningLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - - return runImpl(LI, LAIs, DT, SE); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<LoopAccessLegacyAnalysis>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - } - - static char ID; -}; -} - -#define LVER_OPTION "loop-versioning" -#define DEBUG_TYPE LVER_OPTION - -char LoopVersioningLegacyPass::ID; -static const char LVer_name[] = "Loop Versioning"; - -INITIALIZE_PASS_BEGIN(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false, - false) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_END(LoopVersioningLegacyPass, LVER_OPTION, LVer_name, false, - false) - -namespace llvm { -FunctionPass *createLoopVersioningLegacyPass() { - return new LoopVersioningLegacyPass(); } PreservedAnalyses LoopVersioningPass::run(Function &F, @@ -353,4 +304,3 @@ PreservedAnalyses LoopVersioningPass::run(Function &F, return PreservedAnalyses::none(); return PreservedAnalyses::all(); } -} // namespace llvm diff --git a/llvm/lib/Transforms/Utils/LowerAtomic.cpp b/llvm/lib/Transforms/Utils/LowerAtomic.cpp index b6f40de0daa6..b203970ef9c5 100644 --- a/llvm/lib/Transforms/Utils/LowerAtomic.cpp +++ b/llvm/lib/Transforms/Utils/LowerAtomic.cpp @@ -14,8 +14,7 @@ #include "llvm/Transforms/Utils/LowerAtomic.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" + using namespace llvm; #define DEBUG_TYPE "loweratomic" @@ -102,6 +101,9 @@ Value *llvm::buildAtomicRMWValue(AtomicRMWInst::BinOp Op, bool llvm::lowerAtomicRMWInst(AtomicRMWInst *RMWI) { IRBuilder<> Builder(RMWI); + Builder.setIsFPConstrained( + RMWI->getFunction()->hasFnAttribute(Attribute::StrictFP)); + Value *Ptr = RMWI->getPointerOperand(); Value *Val = RMWI->getValOperand(); diff --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 165740b55298..906eb71fc2d9 100644 --- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -12,9 +12,12 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <optional> +#define DEBUG_TYPE "lower-mem-intrinsics" + using namespace llvm; void llvm::createMemCpyLoopKnownSize( @@ -376,19 +379,14 @@ void llvm::createMemCpyLoopUnknownSize( static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, Value *CopyLen, Align SrcAlign, Align DstAlign, bool SrcIsVolatile, - bool DstIsVolatile) { + bool DstIsVolatile, + const TargetTransformInfo &TTI) { Type *TypeOfCopyLen = CopyLen->getType(); BasicBlock *OrigBB = 
InsertBefore->getParent(); Function *F = OrigBB->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); - // TODO: Use different element type if possible? - IRBuilder<> CastBuilder(InsertBefore); - Type *EltTy = CastBuilder.getInt8Ty(); - Type *PtrTy = - CastBuilder.getInt8PtrTy(SrcAddr->getType()->getPointerAddressSpace()); - SrcAddr = CastBuilder.CreateBitCast(SrcAddr, PtrTy); - DstAddr = CastBuilder.CreateBitCast(DstAddr, PtrTy); + Type *EltTy = Type::getInt8Ty(F->getContext()); // Create the a comparison of src and dst, based on which we jump to either // the forward-copy part of the function (if src >= dst) or the backwards-copy @@ -428,6 +426,7 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, BasicBlock *LoopBB = BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB); IRBuilder<> LoopBuilder(LoopBB); + PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); Value *IndexPtr = LoopBuilder.CreateSub( LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr"); @@ -552,15 +551,57 @@ void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, } } -void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) { - createMemMoveLoop(/* InsertBefore */ Memmove, - /* SrcAddr */ Memmove->getRawSource(), - /* DstAddr */ Memmove->getRawDest(), - /* CopyLen */ Memmove->getLength(), - /* SrcAlign */ Memmove->getSourceAlign().valueOrOne(), - /* DestAlign */ Memmove->getDestAlign().valueOrOne(), - /* SrcIsVolatile */ Memmove->isVolatile(), - /* DstIsVolatile */ Memmove->isVolatile()); +bool llvm::expandMemMoveAsLoop(MemMoveInst *Memmove, + const TargetTransformInfo &TTI) { + Value *CopyLen = Memmove->getLength(); + Value *SrcAddr = Memmove->getRawSource(); + Value *DstAddr = Memmove->getRawDest(); + Align SrcAlign = Memmove->getSourceAlign().valueOrOne(); + Align DstAlign = Memmove->getDestAlign().valueOrOne(); + bool SrcIsVolatile = Memmove->isVolatile(); + bool DstIsVolatile = SrcIsVolatile; + IRBuilder<> CastBuilder(Memmove); + + unsigned SrcAS = SrcAddr->getType()->getPointerAddressSpace(); + unsigned DstAS = DstAddr->getType()->getPointerAddressSpace(); + if (SrcAS != DstAS) { + if (!TTI.addrspacesMayAlias(SrcAS, DstAS)) { + // We may not be able to emit a pointer comparison, but we don't have + // to. Expand as memcpy. + if (ConstantInt *CI = dyn_cast<ConstantInt>(CopyLen)) { + createMemCpyLoopKnownSize(/*InsertBefore=*/Memmove, SrcAddr, DstAddr, + CI, SrcAlign, DstAlign, SrcIsVolatile, + DstIsVolatile, + /*CanOverlap=*/false, TTI); + } else { + createMemCpyLoopUnknownSize(/*InsertBefore=*/Memmove, SrcAddr, DstAddr, + CopyLen, SrcAlign, DstAlign, SrcIsVolatile, + DstIsVolatile, + /*CanOverlap=*/false, TTI); + } + + return true; + } + + if (TTI.isValidAddrSpaceCast(DstAS, SrcAS)) + DstAddr = CastBuilder.CreateAddrSpaceCast(DstAddr, SrcAddr->getType()); + else if (TTI.isValidAddrSpaceCast(SrcAS, DstAS)) + SrcAddr = CastBuilder.CreateAddrSpaceCast(SrcAddr, DstAddr->getType()); + else { + // We don't know generically if it's legal to introduce an + // addrspacecast. We need to know either if it's legal to insert an + // addrspacecast, or if the address spaces cannot alias. 
+ LLVM_DEBUG( + dbgs() << "Do not know how to expand memmove between different " + "address spaces\n"); + return false; + } + } + + createMemMoveLoop( + /*InsertBefore=*/Memmove, SrcAddr, DstAddr, CopyLen, SrcAlign, DstAlign, + SrcIsVolatile, DstIsVolatile, TTI); + return true; } void llvm::expandMemSetAsLoop(MemSetInst *Memset) { diff --git a/llvm/lib/Transforms/Utils/Mem2Reg.cpp b/llvm/lib/Transforms/Utils/Mem2Reg.cpp index 5ad7aeb463ec..fbc6dd7613de 100644 --- a/llvm/lib/Transforms/Utils/Mem2Reg.cpp +++ b/llvm/lib/Transforms/Utils/Mem2Reg.cpp @@ -74,15 +74,19 @@ namespace { struct PromoteLegacyPass : public FunctionPass { // Pass identification, replacement for typeid static char ID; + bool ForcePass; /// If true, forces pass to execute, instead of skipping. - PromoteLegacyPass() : FunctionPass(ID) { + PromoteLegacyPass() : FunctionPass(ID), ForcePass(false) { + initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry()); + } + PromoteLegacyPass(bool IsForced) : FunctionPass(ID), ForcePass(IsForced) { initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry()); } // runOnFunction - To run this pass, first we calculate the alloca // instructions that are safe for promotion, then we promote each one. bool runOnFunction(Function &F) override { - if (skipFunction(F)) + if (!ForcePass && skipFunction(F)) return false; DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -111,6 +115,6 @@ INITIALIZE_PASS_END(PromoteLegacyPass, "mem2reg", "Promote Memory to Register", false, false) // createPromoteMemoryToRegister - Provide an entry point to create this pass. -FunctionPass *llvm::createPromoteMemoryToRegisterPass() { - return new PromoteLegacyPass(); +FunctionPass *llvm::createPromoteMemoryToRegisterPass(bool IsForced) { + return new PromoteLegacyPass(IsForced); } diff --git a/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp b/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp index 899928c085c6..531b0a624daf 100644 --- a/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp +++ b/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/MemoryOpRemark.h" +#include "llvm/ADT/SmallString.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DebugInfo.h" @@ -321,7 +322,7 @@ void MemoryOpRemark::visitVariable(const Value *V, // Try to get an llvm.dbg.declare, which has a DILocalVariable giving us the // real debug info name and size of the variable. 
for (const DbgVariableIntrinsic *DVI : - FindDbgAddrUses(const_cast<Value *>(V))) { + FindDbgDeclareUses(const_cast<Value *>(V))) { if (DILocalVariable *DILV = DVI->getVariable()) { std::optional<uint64_t> DISize = getSizeInBytes(DILV->getSizeInBits()); VariableInfo Var{DILV->getName(), DISize}; @@ -387,7 +388,8 @@ bool AutoInitRemark::canHandle(const Instruction *I) { return false; return any_of(I->getMetadata(LLVMContext::MD_annotation)->operands(), [](const MDOperand &Op) { - return cast<MDString>(Op.get())->getString() == "auto-init"; + return isa<MDString>(Op.get()) && + cast<MDString>(Op.get())->getString() == "auto-init"; }); } diff --git a/llvm/lib/Transforms/Utils/MetaRenamer.cpp b/llvm/lib/Transforms/Utils/MetaRenamer.cpp index 0ea210671b93..44ac65f265f0 100644 --- a/llvm/lib/Transforms/Utils/MetaRenamer.cpp +++ b/llvm/lib/Transforms/Utils/MetaRenamer.cpp @@ -26,14 +26,12 @@ #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/TypeFinder.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Transforms/Utils.h" using namespace llvm; @@ -62,6 +60,11 @@ static cl::opt<std::string> RenameExcludeStructPrefixes( "by a comma"), cl::Hidden); +static cl::opt<bool> + RenameOnlyInst("rename-only-inst", cl::init(false), + cl::desc("only rename the instructions in the function"), + cl::Hidden); + static const char *const metaNames[] = { // See http://en.wikipedia.org/wiki/Metasyntactic_variable "foo", "bar", "baz", "quux", "barney", "snork", "zot", "blam", "hoge", @@ -105,6 +108,12 @@ parseExcludedPrefixes(StringRef PrefixesStr, } } +void MetaRenameOnlyInstructions(Function &F) { + for (auto &I : instructions(F)) + if (!I.getType()->isVoidTy() && I.getName().empty()) + I.setName(I.getOpcodeName()); +} + void MetaRename(Function &F) { for (Argument &Arg : F.args()) if (!Arg.getType()->isVoidTy()) @@ -115,7 +124,7 @@ void MetaRename(Function &F) { for (auto &I : BB) if (!I.getType()->isVoidTy()) - I.setName("tmp"); + I.setName(I.getOpcodeName()); } } @@ -145,6 +154,26 @@ void MetaRename(Module &M, [&Name](auto &Prefix) { return Name.startswith(Prefix); }); }; + // Leave library functions alone because their presence or absence could + // affect the behavior of other passes. + auto ExcludeLibFuncs = [&](Function &F) { + LibFunc Tmp; + StringRef Name = F.getName(); + return Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || + GetTLI(F).getLibFunc(F, Tmp) || + IsNameExcluded(Name, ExcludedFuncPrefixes); + }; + + if (RenameOnlyInst) { + // Rename all functions + for (auto &F : M) { + if (ExcludeLibFuncs(F)) + continue; + MetaRenameOnlyInstructions(F); + } + return; + } + // Rename all aliases for (GlobalAlias &GA : M.aliases()) { StringRef Name = GA.getName(); @@ -181,64 +210,20 @@ void MetaRename(Module &M, // Rename all functions for (auto &F : M) { - StringRef Name = F.getName(); - LibFunc Tmp; - // Leave library functions alone because their presence or absence could - // affect the behavior of other passes. - if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || - GetTLI(F).getLibFunc(F, Tmp) || - IsNameExcluded(Name, ExcludedFuncPrefixes)) + if (ExcludeLibFuncs(F)) continue; // Leave @main alone. The output of -metarenamer might be passed to // lli for execution and the latter needs a main entry point. 
- if (Name != "main") + if (F.getName() != "main") F.setName(renamer.newName()); MetaRename(F); } } -struct MetaRenamer : public ModulePass { - // Pass identification, replacement for typeid - static char ID; - - MetaRenamer() : ModulePass(ID) { - initializeMetaRenamerPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.setPreservesAll(); - } - - bool runOnModule(Module &M) override { - auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - MetaRename(M, GetTLI); - return true; - } -}; - } // end anonymous namespace -char MetaRenamer::ID = 0; - -INITIALIZE_PASS_BEGIN(MetaRenamer, "metarenamer", - "Assign new names to everything", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(MetaRenamer, "metarenamer", - "Assign new names to everything", false, false) - -//===----------------------------------------------------------------------===// -// -// MetaRenamer - Rename everything with metasyntactic names. -// -ModulePass *llvm::createMetaRenamerPass() { - return new MetaRenamer(); -} - PreservedAnalyses MetaRenamerPass::run(Module &M, ModuleAnalysisManager &AM) { FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 6d17a466957e..1e243ef74df7 100644 --- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -12,6 +12,7 @@ #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/ADT/SmallString.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -19,6 +20,7 @@ #include "llvm/IR/Module.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Support/xxhash.h" + using namespace llvm; #define DEBUG_TYPE "moduleutils" @@ -31,11 +33,9 @@ static void appendToGlobalArray(StringRef ArrayName, Module &M, Function *F, // Get the current set of static global constructors and add the new ctor // to the list. SmallVector<Constant *, 16> CurrentCtors; - StructType *EltTy = StructType::get( - IRB.getInt32Ty(), PointerType::get(FnTy, F->getAddressSpace()), - IRB.getInt8PtrTy()); - + StructType *EltTy; if (GlobalVariable *GVCtor = M.getNamedGlobal(ArrayName)) { + EltTy = cast<StructType>(GVCtor->getValueType()->getArrayElementType()); if (Constant *Init = GVCtor->getInitializer()) { unsigned n = Init->getNumOperands(); CurrentCtors.reserve(n + 1); @@ -43,6 +43,10 @@ static void appendToGlobalArray(StringRef ArrayName, Module &M, Function *F, CurrentCtors.push_back(cast<Constant>(Init->getOperand(i))); } GVCtor->eraseFromParent(); + } else { + EltTy = StructType::get( + IRB.getInt32Ty(), PointerType::get(FnTy, F->getAddressSpace()), + IRB.getInt8PtrTy()); } // Build a 3 field global_ctor entry. We don't take a comdat key. @@ -390,9 +394,7 @@ bool llvm::lowerGlobalIFuncUsersAsGlobalCtor( const DataLayout &DL = M.getDataLayout(); PointerType *TableEntryTy = - Ctx.supportsTypedPointers() - ? 
PointerType::get(Type::getInt8Ty(Ctx), DL.getProgramAddressSpace()) - : PointerType::get(Ctx, DL.getProgramAddressSpace()); + PointerType::get(Ctx, DL.getProgramAddressSpace()); ArrayType *FuncPtrTableTy = ArrayType::get(TableEntryTy, IFuncsToLower.size()); @@ -462,9 +464,7 @@ bool llvm::lowerGlobalIFuncUsersAsGlobalCtor( InitBuilder.CreateRetVoid(); - PointerType *ConstantDataTy = Ctx.supportsTypedPointers() - ? PointerType::get(Type::getInt8Ty(Ctx), 0) - : PointerType::get(Ctx, 0); + PointerType *ConstantDataTy = PointerType::get(Ctx, 0); // TODO: Is this the right priority? Probably should be before any other // constructors? diff --git a/llvm/lib/Transforms/Utils/MoveAutoInit.cpp b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp new file mode 100644 index 000000000000..b0ca0b15c08e --- /dev/null +++ b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp @@ -0,0 +1,231 @@ +//===-- MoveAutoInit.cpp - move auto-init inst closer to their use site----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass moves instruction maked as auto-init closer to the basic block that +// use it, eventually removing it from some control path of the function. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/MoveAutoInit.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "move-auto-init" + +STATISTIC(NumMoved, "Number of instructions moved"); + +static cl::opt<unsigned> MoveAutoInitThreshold( + "move-auto-init-threshold", cl::Hidden, cl::init(128), + cl::desc("Maximum instructions to analyze per moved initialization")); + +static bool hasAutoInitMetadata(const Instruction &I) { + return I.hasMetadata(LLVMContext::MD_annotation) && + any_of(I.getMetadata(LLVMContext::MD_annotation)->operands(), + [](const MDOperand &Op) { return Op.equalsStr("auto-init"); }); +} + +static std::optional<MemoryLocation> writeToAlloca(const Instruction &I) { + MemoryLocation ML; + if (auto *MI = dyn_cast<MemIntrinsic>(&I)) + ML = MemoryLocation::getForDest(MI); + else if (auto *SI = dyn_cast<StoreInst>(&I)) + ML = MemoryLocation::get(SI); + else + assert(false && "memory location set"); + + if (isa<AllocaInst>(getUnderlyingObject(ML.Ptr))) + return ML; + else + return {}; +} + +/// Finds a BasicBlock in the CFG where instruction `I` can be moved to while +/// not changing the Memory SSA ordering and being guarded by at least one +/// condition. 
+static BasicBlock *usersDominator(const MemoryLocation &ML, Instruction *I, + DominatorTree &DT, MemorySSA &MSSA) { + BasicBlock *CurrentDominator = nullptr; + MemoryUseOrDef &IMA = *MSSA.getMemoryAccess(I); + BatchAAResults AA(MSSA.getAA()); + + SmallPtrSet<MemoryAccess *, 8> Visited; + + auto AsMemoryAccess = [](User *U) { return cast<MemoryAccess>(U); }; + SmallVector<MemoryAccess *> WorkList(map_range(IMA.users(), AsMemoryAccess)); + + while (!WorkList.empty()) { + MemoryAccess *MA = WorkList.pop_back_val(); + if (!Visited.insert(MA).second) + continue; + + if (Visited.size() > MoveAutoInitThreshold) + return nullptr; + + bool FoundClobberingUser = false; + if (auto *M = dyn_cast<MemoryUseOrDef>(MA)) { + Instruction *MI = M->getMemoryInst(); + + // If this memory instruction may not clobber `I`, we can skip it. + // LifetimeEnd is a valid user, but we do not want it in the user + // dominator. + if (AA.getModRefInfo(MI, ML) != ModRefInfo::NoModRef && + !MI->isLifetimeStartOrEnd() && MI != I) { + FoundClobberingUser = true; + CurrentDominator = CurrentDominator + ? DT.findNearestCommonDominator(CurrentDominator, + MI->getParent()) + : MI->getParent(); + } + } + if (!FoundClobberingUser) { + auto UsersAsMemoryAccesses = map_range(MA->users(), AsMemoryAccess); + append_range(WorkList, UsersAsMemoryAccesses); + } + } + return CurrentDominator; +} + +static bool runMoveAutoInit(Function &F, DominatorTree &DT, MemorySSA &MSSA) { + BasicBlock &EntryBB = F.getEntryBlock(); + SmallVector<std::pair<Instruction *, BasicBlock *>> JobList; + + // + // Compute movable instructions. + // + for (Instruction &I : EntryBB) { + if (!hasAutoInitMetadata(I)) + continue; + + std::optional<MemoryLocation> ML = writeToAlloca(I); + if (!ML) + continue; + + if (I.isVolatile()) + continue; + + BasicBlock *UsersDominator = usersDominator(ML.value(), &I, DT, MSSA); + if (!UsersDominator) + continue; + + if (UsersDominator == &EntryBB) + continue; + + // Traverse the CFG to detect cycles `UsersDominator` would be part of. + SmallPtrSet<BasicBlock *, 8> TransitiveSuccessors; + SmallVector<BasicBlock *> WorkList(successors(UsersDominator)); + bool HasCycle = false; + while (!WorkList.empty()) { + BasicBlock *CurrBB = WorkList.pop_back_val(); + if (CurrBB == UsersDominator) + // No early exit because we want to compute the full set of transitive + // successors. + HasCycle = true; + for (BasicBlock *Successor : successors(CurrBB)) { + if (!TransitiveSuccessors.insert(Successor).second) + continue; + WorkList.push_back(Successor); + } + } + + // Don't insert if that could create multiple execution of I, + // but we can insert it in the non back-edge predecessors, if it exists. + if (HasCycle) { + BasicBlock *UsersDominatorHead = UsersDominator; + while (BasicBlock *UniquePredecessor = + UsersDominatorHead->getUniquePredecessor()) + UsersDominatorHead = UniquePredecessor; + + if (UsersDominatorHead == &EntryBB) + continue; + + BasicBlock *DominatingPredecessor = nullptr; + for (BasicBlock *Pred : predecessors(UsersDominatorHead)) { + // If one of the predecessor of the dominator also transitively is a + // successor, moving to the dominator would do the inverse of loop + // hoisting, and we don't want that. + if (TransitiveSuccessors.count(Pred)) + continue; + + DominatingPredecessor = + DominatingPredecessor + ? 
DT.findNearestCommonDominator(DominatingPredecessor, Pred) + : Pred; + } + + if (!DominatingPredecessor || DominatingPredecessor == &EntryBB) + continue; + + UsersDominator = DominatingPredecessor; + } + + // CatchSwitchInst blocks can only have one instruction, so they are not + // good candidates for insertion. + while (isa<CatchSwitchInst>(UsersDominator->getFirstInsertionPt())) { + for (BasicBlock *Pred : predecessors(UsersDominator)) + UsersDominator = DT.findNearestCommonDominator(UsersDominator, Pred); + } + + // We finally found a place where I can be moved while not introducing extra + // execution, and guarded by at least one condition. + if (UsersDominator != &EntryBB) + JobList.emplace_back(&I, UsersDominator); + } + + // + // Perform the actual substitution. + // + if (JobList.empty()) + return false; + + MemorySSAUpdater MSSAU(&MSSA); + + // Reverse insertion to respect relative order between instructions: + // if two instructions are moved from the same BB to the same BB, we insert + // the second one in the front, then the first on top of it. + for (auto &Job : reverse(JobList)) { + Job.first->moveBefore(&*Job.second->getFirstInsertionPt()); + MSSAU.moveToPlace(MSSA.getMemoryAccess(Job.first), Job.first->getParent(), + MemorySSA::InsertionPlace::Beginning); + } + + if (VerifyMemorySSA) + MSSA.verifyMemorySSA(); + + NumMoved += JobList.size(); + + return true; +} + +PreservedAnalyses MoveAutoInitPass::run(Function &F, + FunctionAnalysisManager &AM) { + + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); + if (!runMoveAutoInit(F, DT, MSSA)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<MemorySSAAnalysis>(); + PA.preserveSet<CFGAnalyses>(); + return PA; +} diff --git a/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp b/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp index d4ab4504064f..f41a14cdfbec 100644 --- a/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp +++ b/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp @@ -14,8 +14,6 @@ #include "llvm/Transforms/Utils/NameAnonGlobals.h" #include "llvm/ADT/SmallString.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/MD5.h" #include "llvm/Transforms/Utils/ModuleUtils.h" diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 75ea9dc5dfc0..2e5f40d39912 100644 --- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -118,19 +118,28 @@ public: /// Update assignment tracking debug info given for the to-be-deleted store /// \p ToDelete that stores to this alloca. - void updateForDeletedStore(StoreInst *ToDelete, DIBuilder &DIB) const { + void updateForDeletedStore( + StoreInst *ToDelete, DIBuilder &DIB, + SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete) const { // There's nothing to do if the alloca doesn't have any variables using // assignment tracking. - if (DbgAssigns.empty()) { - assert(at::getAssignmentMarkers(ToDelete).empty()); + if (DbgAssigns.empty()) return; - } - // Just leave dbg.assign intrinsics in place and remember that we've seen - // one for each variable fragment. 
- SmallSet<DebugVariable, 2> VarHasDbgAssignForStore; - for (DbgAssignIntrinsic *DAI : at::getAssignmentMarkers(ToDelete)) - VarHasDbgAssignForStore.insert(DebugVariable(DAI)); + // Insert a dbg.value where the linked dbg.assign is and remember to delete + // the dbg.assign later. Demoting to dbg.value isn't necessary for + // correctness but does reduce compile time and memory usage by reducing + // unnecessary function-local metadata. Remember that we've seen a + // dbg.assign for each variable fragment for the untracked store handling + // (after this loop). + SmallSet<DebugVariableAggregate, 2> VarHasDbgAssignForStore; + for (DbgAssignIntrinsic *DAI : at::getAssignmentMarkers(ToDelete)) { + VarHasDbgAssignForStore.insert(DebugVariableAggregate(DAI)); + DbgAssignsToDelete->insert(DAI); + DIB.insertDbgValueIntrinsic(DAI->getValue(), DAI->getVariable(), + DAI->getExpression(), DAI->getDebugLoc(), + DAI); + } // It's possible for variables using assignment tracking to have no // dbg.assign linked to this store. These are variables in DbgAssigns that @@ -141,7 +150,7 @@ public: // size) or one that is trackable but has had its DIAssignID attachment // dropped accidentally. for (auto *DAI : DbgAssigns) { - if (VarHasDbgAssignForStore.contains(DebugVariable(DAI))) + if (VarHasDbgAssignForStore.contains(DebugVariableAggregate(DAI))) continue; ConvertDebugDeclareToDebugValue(DAI, ToDelete, DIB); } @@ -324,6 +333,9 @@ struct PromoteMem2Reg { /// For each alloca, keep an instance of a helper class that gives us an easy /// way to update assignment tracking debug info if the alloca is promoted. SmallVector<AssignmentTrackingInfo, 8> AllocaATInfo; + /// A set of dbg.assigns to delete because they've been demoted to + /// dbg.values. Call cleanUpDbgAssigns to delete them. + SmallSet<DbgAssignIntrinsic *, 8> DbgAssignsToDelete; /// The set of basic blocks the renamer has already visited. SmallPtrSet<BasicBlock *, 16> Visited; @@ -367,6 +379,13 @@ private: RenamePassData::LocationVector &IncLocs, std::vector<RenamePassData> &Worklist); bool QueuePhiNode(BasicBlock *BB, unsigned AllocaIdx, unsigned &Version); + + /// Delete dbg.assigns that have been demoted to dbg.values. + void cleanUpDbgAssigns() { + for (auto *DAI : DbgAssignsToDelete) + DAI->eraseFromParent(); + DbgAssignsToDelete.clear(); + } }; } // end anonymous namespace @@ -438,9 +457,10 @@ static void removeIntrinsicUsers(AllocaInst *AI) { /// false there were some loads which were not dominated by the single store /// and thus must be phi-ed with undef. We fall back to the standard alloca /// promotion algorithm in that case. -static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, - LargeBlockInfo &LBI, const DataLayout &DL, - DominatorTree &DT, AssumptionCache *AC) { +static bool rewriteSingleStoreAlloca( + AllocaInst *AI, AllocaInfo &Info, LargeBlockInfo &LBI, const DataLayout &DL, + DominatorTree &DT, AssumptionCache *AC, + SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete) { StoreInst *OnlyStore = Info.OnlyStore; bool StoringGlobalVal = !isa<Instruction>(OnlyStore->getOperand(0)); BasicBlock *StoreBB = OnlyStore->getParent(); @@ -500,7 +520,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); // Update assignment tracking info for the store we're going to delete. 
- Info.AssignmentTracking.updateForDeletedStore(Info.OnlyStore, DIB); + Info.AssignmentTracking.updateForDeletedStore(Info.OnlyStore, DIB, + DbgAssignsToDelete); // Record debuginfo for the store and remove the declaration's // debuginfo. @@ -540,11 +561,10 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, /// use(t); /// *A = 42; /// } -static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, - LargeBlockInfo &LBI, - const DataLayout &DL, - DominatorTree &DT, - AssumptionCache *AC) { +static bool promoteSingleBlockAlloca( + AllocaInst *AI, const AllocaInfo &Info, LargeBlockInfo &LBI, + const DataLayout &DL, DominatorTree &DT, AssumptionCache *AC, + SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete) { // The trickiest case to handle is when we have large blocks. Because of this, // this code is optimized assuming that large blocks happen. This does not // significantly pessimize the small block case. This uses LargeBlockInfo to @@ -608,7 +628,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, while (!AI->use_empty()) { StoreInst *SI = cast<StoreInst>(AI->user_back()); // Update assignment tracking info for the store we're going to delete. - Info.AssignmentTracking.updateForDeletedStore(SI, DIB); + Info.AssignmentTracking.updateForDeletedStore(SI, DIB, DbgAssignsToDelete); // Record debuginfo for the store before removing it. for (DbgVariableIntrinsic *DII : Info.DbgUsers) { if (DII->isAddressOfVariable()) { @@ -668,7 +688,8 @@ void PromoteMem2Reg::run() { // If there is only a single store to this value, replace any loads of // it that are directly dominated by the definition with the value stored. if (Info.DefiningBlocks.size() == 1) { - if (rewriteSingleStoreAlloca(AI, Info, LBI, SQ.DL, DT, AC)) { + if (rewriteSingleStoreAlloca(AI, Info, LBI, SQ.DL, DT, AC, + &DbgAssignsToDelete)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); ++NumSingleStore; @@ -679,7 +700,8 @@ void PromoteMem2Reg::run() { // If the alloca is only read and written in one basic block, just perform a // linear sweep over the block to eliminate it. if (Info.OnlyUsedInOneBlock && - promoteSingleBlockAlloca(AI, Info, LBI, SQ.DL, DT, AC)) { + promoteSingleBlockAlloca(AI, Info, LBI, SQ.DL, DT, AC, + &DbgAssignsToDelete)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); continue; @@ -728,9 +750,10 @@ void PromoteMem2Reg::run() { QueuePhiNode(BB, AllocaNum, CurrentVersion); } - if (Allocas.empty()) + if (Allocas.empty()) { + cleanUpDbgAssigns(); return; // All of the allocas must have been trivial! - + } LBI.clear(); // Set the incoming values for the basic block to be null values for all of @@ -812,7 +835,7 @@ void PromoteMem2Reg::run() { // code. Unfortunately, there may be unreachable blocks which the renamer // hasn't traversed. If this is the case, the PHI nodes may not // have incoming values for all predecessors. Loop over all PHI nodes we have - // created, inserting undef values if they are missing any incoming values. + // created, inserting poison values if they are missing any incoming values. 
for (DenseMap<std::pair<unsigned, unsigned>, PHINode *>::iterator I = NewPhiNodes.begin(), E = NewPhiNodes.end(); @@ -862,13 +885,14 @@ void PromoteMem2Reg::run() { BasicBlock::iterator BBI = BB->begin(); while ((SomePHI = dyn_cast<PHINode>(BBI++)) && SomePHI->getNumIncomingValues() == NumBadPreds) { - Value *UndefVal = UndefValue::get(SomePHI->getType()); + Value *PoisonVal = PoisonValue::get(SomePHI->getType()); for (BasicBlock *Pred : Preds) - SomePHI->addIncoming(UndefVal, Pred); + SomePHI->addIncoming(PoisonVal, Pred); } } NewPhiNodes.clear(); + cleanUpDbgAssigns(); } /// Determine which blocks the value is live in. @@ -1072,7 +1096,8 @@ NextIteration: // Record debuginfo for the store before removing it. IncomingLocs[AllocaNo] = SI->getDebugLoc(); - AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB); + AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB, + &DbgAssignsToDelete); for (DbgVariableIntrinsic *DII : AllocaDbgUsers[ai->second]) if (DII->isAddressOfVariable()) ConvertDebugDeclareToDebugValue(DII, SI, DIB); diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 8d03a0d8a2c4..de3626a24212 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/InstVisitor.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -41,6 +42,14 @@ static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() { MaxNumRangeExtensions); } +static ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty, + bool UndefAllowed = true) { + assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector"); + if (LV.isConstantRange(UndefAllowed)) + return LV.getConstantRange(); + return ConstantRange::getFull(Ty->getScalarSizeInBits()); +} + namespace llvm { bool SCCPSolver::isConstant(const ValueLatticeElement &LV) { @@ -65,30 +74,9 @@ static bool canRemoveInstruction(Instruction *I) { } bool SCCPSolver::tryToReplaceWithConstant(Value *V) { - Constant *Const = nullptr; - if (V->getType()->isStructTy()) { - std::vector<ValueLatticeElement> IVs = getStructLatticeValueFor(V); - if (llvm::any_of(IVs, isOverdefined)) - return false; - std::vector<Constant *> ConstVals; - auto *ST = cast<StructType>(V->getType()); - for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { - ValueLatticeElement V = IVs[i]; - ConstVals.push_back(SCCPSolver::isConstant(V) - ? getConstant(V) - : UndefValue::get(ST->getElementType(i))); - } - Const = ConstantStruct::get(ST, ConstVals); - } else { - const ValueLatticeElement &IV = getLatticeValueFor(V); - if (isOverdefined(IV)) - return false; - - Const = SCCPSolver::isConstant(IV) ? getConstant(IV) - : UndefValue::get(V->getType()); - } - assert(Const && "Constant is nullptr here!"); - + Constant *Const = getConstantOrNull(V); + if (!Const) + return false; // Replacing `musttail` instructions with constant breaks `musttail` invariant // unless the call itself can be removed. // Calls with "clang.arc.attachedcall" implicitly use the return value and @@ -115,6 +103,47 @@ bool SCCPSolver::tryToReplaceWithConstant(Value *V) { return true; } +/// Try to use \p Inst's value range from \p Solver to infer the NUW flag. 
+static bool refineInstruction(SCCPSolver &Solver, + const SmallPtrSetImpl<Value *> &InsertedValues, + Instruction &Inst) { + if (!isa<OverflowingBinaryOperator>(Inst)) + return false; + + auto GetRange = [&Solver, &InsertedValues](Value *Op) { + if (auto *Const = dyn_cast<ConstantInt>(Op)) + return ConstantRange(Const->getValue()); + if (isa<Constant>(Op) || InsertedValues.contains(Op)) { + unsigned Bitwidth = Op->getType()->getScalarSizeInBits(); + return ConstantRange::getFull(Bitwidth); + } + return getConstantRange(Solver.getLatticeValueFor(Op), Op->getType(), + /*UndefAllowed=*/false); + }; + auto RangeA = GetRange(Inst.getOperand(0)); + auto RangeB = GetRange(Inst.getOperand(1)); + bool Changed = false; + if (!Inst.hasNoUnsignedWrap()) { + auto NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::BinaryOps(Inst.getOpcode()), RangeB, + OverflowingBinaryOperator::NoUnsignedWrap); + if (NUWRange.contains(RangeA)) { + Inst.setHasNoUnsignedWrap(); + Changed = true; + } + } + if (!Inst.hasNoSignedWrap()) { + auto NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::BinaryOps(Inst.getOpcode()), RangeB, OverflowingBinaryOperator::NoSignedWrap); + if (NSWRange.contains(RangeA)) { + Inst.setHasNoSignedWrap(); + Changed = true; + } + } + + return Changed; +} + /// Try to replace signed instructions with their unsigned equivalent. static bool replaceSignedInst(SCCPSolver &Solver, SmallPtrSetImpl<Value *> &InsertedValues, @@ -195,6 +224,8 @@ bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB, } else if (replaceSignedInst(*this, InsertedValues, Inst)) { MadeChanges = true; ++InstReplacedStat; + } else if (refineInstruction(*this, InsertedValues, Inst)) { + MadeChanges = true; } } return MadeChanges; @@ -322,6 +353,10 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> { MapVector<std::pair<Function *, unsigned>, ValueLatticeElement> TrackedMultipleRetVals; + /// The set of values whose lattice has been invalidated. + /// Populated by resetLatticeValueFor(), cleared after resolving undefs. + DenseSet<Value *> Invalidated; + /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is /// represented here for efficient lookup. SmallPtrSet<Function *, 16> MRVFunctionsTracked; @@ -352,14 +387,15 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> { using Edge = std::pair<BasicBlock *, BasicBlock *>; DenseSet<Edge> KnownFeasibleEdges; - DenseMap<Function *, AnalysisResultsForFn> AnalysisResults; + DenseMap<Function *, std::unique_ptr<PredicateInfo>> FnPredicateInfo; + DenseMap<Value *, SmallPtrSet<User *, 2>> AdditionalUsers; LLVMContext &Ctx; private: - ConstantInt *getConstantInt(const ValueLatticeElement &IV) const { - return dyn_cast_or_null<ConstantInt>(getConstant(IV)); + ConstantInt *getConstantInt(const ValueLatticeElement &IV, Type *Ty) const { + return dyn_cast_or_null<ConstantInt>(getConstant(IV, Ty)); } // pushToWorkList - Helper for markConstant/markOverdefined @@ -447,6 +483,64 @@ private: return LV; } + /// Traverse the use-def chain of \p Call, marking itself and its users as + /// "unknown" on the way. + void invalidate(CallBase *Call) { + SmallVector<Instruction *, 64> ToInvalidate; + ToInvalidate.push_back(Call); + + while (!ToInvalidate.empty()) { + Instruction *Inst = ToInvalidate.pop_back_val(); + + if (!Invalidated.insert(Inst).second) + continue; + + if (!BBExecutable.count(Inst->getParent())) + continue; + + Value *V = nullptr; + // For return instructions we need to invalidate the tracked returns map. 
+ // Anything else has its lattice in the value map. + if (auto *RetInst = dyn_cast<ReturnInst>(Inst)) { + Function *F = RetInst->getParent()->getParent(); + if (auto It = TrackedRetVals.find(F); It != TrackedRetVals.end()) { + It->second = ValueLatticeElement(); + V = F; + } else if (MRVFunctionsTracked.count(F)) { + auto *STy = cast<StructType>(F->getReturnType()); + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) + TrackedMultipleRetVals[{F, I}] = ValueLatticeElement(); + V = F; + } + } else if (auto *STy = dyn_cast<StructType>(Inst->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + if (auto It = StructValueState.find({Inst, I}); + It != StructValueState.end()) { + It->second = ValueLatticeElement(); + V = Inst; + } + } + } else if (auto It = ValueState.find(Inst); It != ValueState.end()) { + It->second = ValueLatticeElement(); + V = Inst; + } + + if (V) { + LLVM_DEBUG(dbgs() << "Invalidated lattice for " << *V << "\n"); + + for (User *U : V->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + ToInvalidate.push_back(UI); + + auto It = AdditionalUsers.find(V); + if (It != AdditionalUsers.end()) + for (User *U : It->second) + if (auto *UI = dyn_cast<Instruction>(U)) + ToInvalidate.push_back(UI); + } + } + } + /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest); @@ -520,6 +614,7 @@ private: void visitCastInst(CastInst &I); void visitSelectInst(SelectInst &I); void visitUnaryOperator(Instruction &I); + void visitFreezeInst(FreezeInst &I); void visitBinaryOperator(Instruction &I); void visitCmpInst(CmpInst &I); void visitExtractValueInst(ExtractValueInst &EVI); @@ -557,8 +652,8 @@ private: void visitInstruction(Instruction &I); public: - void addAnalysis(Function &F, AnalysisResultsForFn A) { - AnalysisResults.insert({&F, std::move(A)}); + void addPredicateInfo(Function &F, DominatorTree &DT, AssumptionCache &AC) { + FnPredicateInfo.insert({&F, std::make_unique<PredicateInfo>(F, DT, AC)}); } void visitCallInst(CallInst &I) { visitCallBase(I); } @@ -566,23 +661,10 @@ public: bool markBlockExecutable(BasicBlock *BB); const PredicateBase *getPredicateInfoFor(Instruction *I) { - auto A = AnalysisResults.find(I->getParent()->getParent()); - if (A == AnalysisResults.end()) + auto It = FnPredicateInfo.find(I->getParent()->getParent()); + if (It == FnPredicateInfo.end()) return nullptr; - return A->second.PredInfo->getPredicateInfoFor(I); - } - - const LoopInfo &getLoopInfo(Function &F) { - auto A = AnalysisResults.find(&F); - assert(A != AnalysisResults.end() && A->second.LI && - "Need LoopInfo analysis results for function."); - return *A->second.LI; - } - - DomTreeUpdater getDTU(Function &F) { - auto A = AnalysisResults.find(&F); - assert(A != AnalysisResults.end() && "Need analysis results for function."); - return {A->second.DT, A->second.PDT, DomTreeUpdater::UpdateStrategy::Lazy}; + return It->second->getPredicateInfoFor(I); } SCCPInstVisitor(const DataLayout &DL, @@ -627,6 +709,8 @@ public: void solve(); + bool resolvedUndef(Instruction &I); + bool resolvedUndefsIn(Function &F); bool isBlockExecutable(BasicBlock *BB) const { @@ -649,6 +733,19 @@ public: void removeLatticeValueFor(Value *V) { ValueState.erase(V); } + /// Invalidate the Lattice Value of \p Call and its users after specializing + /// the call. Then recompute it. 
+ void resetLatticeValueFor(CallBase *Call) { + // Calls to void returning functions do not need invalidation. + Function *F = Call->getCalledFunction(); + (void)F; + assert(!F->getReturnType()->isVoidTy() && + (TrackedRetVals.count(F) || MRVFunctionsTracked.count(F)) && + "All non void specializations should be tracked"); + invalidate(Call); + handleCallResult(*Call); + } + const ValueLatticeElement &getLatticeValueFor(Value *V) const { assert(!V->getType()->isStructTy() && "Should use getStructLatticeValueFor"); @@ -681,15 +778,16 @@ public: bool isStructLatticeConstant(Function *F, StructType *STy); - Constant *getConstant(const ValueLatticeElement &LV) const; - ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty) const; + Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const; + + Constant *getConstantOrNull(Value *V) const; SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions() { return TrackingIncomingArguments; } - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl<ArgInfo> &Args); + void setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl<ArgInfo> &Args); void markFunctionUnreachable(Function *F) { for (auto &BB : *F) @@ -715,6 +813,18 @@ public: ResolvedUndefs |= resolvedUndefsIn(*F); } } + + void solveWhileResolvedUndefs() { + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + solve(); + ResolvedUndefs = false; + for (Value *V : Invalidated) + if (auto *I = dyn_cast<Instruction>(V)) + ResolvedUndefs |= resolvedUndef(*I); + } + Invalidated.clear(); + } }; } // namespace llvm @@ -728,9 +838,13 @@ bool SCCPInstVisitor::markBlockExecutable(BasicBlock *BB) { } void SCCPInstVisitor::pushToWorkList(ValueLatticeElement &IV, Value *V) { - if (IV.isOverdefined()) - return OverdefinedInstWorkList.push_back(V); - InstWorkList.push_back(V); + if (IV.isOverdefined()) { + if (OverdefinedInstWorkList.empty() || OverdefinedInstWorkList.back() != V) + OverdefinedInstWorkList.push_back(V); + return; + } + if (InstWorkList.empty() || InstWorkList.back() != V) + InstWorkList.push_back(V); } void SCCPInstVisitor::pushToWorkListMsg(ValueLatticeElement &IV, Value *V) { @@ -771,57 +885,84 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) { return true; } -Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const { - if (LV.isConstant()) - return LV.getConstant(); +Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV, + Type *Ty) const { + if (LV.isConstant()) { + Constant *C = LV.getConstant(); + assert(C->getType() == Ty && "Type mismatch"); + return C; + } if (LV.isConstantRange()) { const auto &CR = LV.getConstantRange(); if (CR.getSingleElement()) - return ConstantInt::get(Ctx, *CR.getSingleElement()); + return ConstantInt::get(Ty, *CR.getSingleElement()); } return nullptr; } -ConstantRange -SCCPInstVisitor::getConstantRange(const ValueLatticeElement &LV, - Type *Ty) const { - assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector"); - if (LV.isConstantRange()) - return LV.getConstantRange(); - return ConstantRange::getFull(Ty->getScalarSizeInBits()); +Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const { + Constant *Const = nullptr; + if (V->getType()->isStructTy()) { + std::vector<ValueLatticeElement> LVs = getStructLatticeValueFor(V); + if (any_of(LVs, SCCPSolver::isOverdefined)) + return nullptr; + std::vector<Constant *> ConstVals; + auto *ST = cast<StructType>(V->getType()); + for (unsigned I = 0, E = ST->getNumElements(); I != 
E; ++I) { + ValueLatticeElement LV = LVs[I]; + ConstVals.push_back(SCCPSolver::isConstant(LV) + ? getConstant(LV, ST->getElementType(I)) + : UndefValue::get(ST->getElementType(I))); + } + Const = ConstantStruct::get(ST, ConstVals); + } else { + const ValueLatticeElement &LV = getLatticeValueFor(V); + if (SCCPSolver::isOverdefined(LV)) + return nullptr; + Const = SCCPSolver::isConstant(LV) ? getConstant(LV, V->getType()) + : UndefValue::get(V->getType()); + } + assert(Const && "Constant is nullptr here!"); + return Const; } -void SCCPInstVisitor::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl<ArgInfo> &Args) { +void SCCPInstVisitor::setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl<ArgInfo> &Args) { assert(!Args.empty() && "Specialization without arguments"); assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() && "Functions should have the same number of arguments"); auto Iter = Args.begin(); - Argument *NewArg = F->arg_begin(); - Argument *OldArg = Args[0].Formal->getParent()->arg_begin(); + Function::arg_iterator NewArg = F->arg_begin(); + Function::arg_iterator OldArg = Args[0].Formal->getParent()->arg_begin(); for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) { LLVM_DEBUG(dbgs() << "SCCP: Marking argument " << NewArg->getNameOrAsOperand() << "\n"); - if (Iter != Args.end() && OldArg == Iter->Formal) { - // Mark the argument constants in the new function. - markConstant(NewArg, Iter->Actual); + // Mark the argument constants in the new function + // or copy the lattice state over from the old function. + if (Iter != Args.end() && Iter->Formal == &*OldArg) { + if (auto *STy = dyn_cast<StructType>(NewArg->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}]; + NewValue.markConstant(Iter->Actual->getAggregateElement(I)); + } + } else { + ValueState[&*NewArg].markConstant(Iter->Actual); + } ++Iter; - } else if (ValueState.count(OldArg)) { - // For the remaining arguments in the new function, copy the lattice state - // over from the old function. - // - // Note: This previously looked like this: - // ValueState[NewArg] = ValueState[OldArg]; - // This is incorrect because the DenseMap class may resize the underlying - // memory when inserting `NewArg`, which will invalidate the reference to - // `OldArg`. Instead, we make sure `NewArg` exists before setting it. - auto &NewValue = ValueState[NewArg]; - NewValue = ValueState[OldArg]; - pushToWorkList(NewValue, NewArg); + } else { + if (auto *STy = dyn_cast<StructType>(NewArg->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}]; + NewValue = StructValueState[{&*OldArg, I}]; + } + } else { + ValueLatticeElement &NewValue = ValueState[&*NewArg]; + NewValue = ValueState[&*OldArg]; + } } } } @@ -874,7 +1015,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, } ValueLatticeElement BCValue = getValueState(BI->getCondition()); - ConstantInt *CI = getConstantInt(BCValue); + ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType()); if (!CI) { // Overdefined condition variables, and branches on unfoldable constant // conditions, mean the branch could go either way. 
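The refineInstruction() helper added above infers nuw/nsw flags by testing whether one operand's known range lies entirely inside the guaranteed-no-wrap region computed from the other operand's range. Below is a minimal standalone sketch of that check built directly against LLVM's ConstantRange; the i8 ranges and the main() driver are invented for illustration and are not part of the patch.

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  // Suppose the solver has proven A is in [0, 100) and B is in [0, 50), i8.
  ConstantRange RangeA(APInt(8, 0), APInt(8, 100));
  ConstantRange RangeB(APInt(8, 0), APInt(8, 50));

  // All LHS values for which "add LHS, B" cannot wrap unsigned, for any B in
  // RangeB.
  ConstantRange NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
      Instruction::Add, RangeB, OverflowingBinaryOperator::NoUnsignedWrap);

  // If every possible A lies in that region, the add may be tagged nuw.
  outs() << (NUWRegion.contains(RangeA) ? "nuw provable\n" : "nuw unknown\n");
  return 0;
}

With A below 100 and B below 50, no 8-bit unsigned addition of the two can exceed 255, so the region test succeeds and the flag can be set.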
@@ -900,7 +1041,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, return; } const ValueLatticeElement &SCValue = getValueState(SI->getCondition()); - if (ConstantInt *CI = getConstantInt(SCValue)) { + if (ConstantInt *CI = + getConstantInt(SCValue, SI->getCondition()->getType())) { Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true; return; } @@ -931,7 +1073,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) { // Casts are folded by visitCastInst. ValueLatticeElement IBRValue = getValueState(IBR->getAddress()); - BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue)); + BlockAddress *Addr = dyn_cast_or_null<BlockAddress>( + getConstant(IBRValue, IBR->getAddress()->getType())); if (!Addr) { // Overdefined or unknown condition? // All destinations are executable! if (!IBRValue.isUnknownOrUndef()) @@ -1086,7 +1229,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { if (OpSt.isUnknownOrUndef()) return; - if (Constant *OpC = getConstant(OpSt)) { + if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) { // Fold the constant as we build. Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL); markConstant(&I, C); @@ -1221,7 +1364,8 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) { if (CondValue.isUnknownOrUndef()) return; - if (ConstantInt *CondCB = getConstantInt(CondValue)) { + if (ConstantInt *CondCB = + getConstantInt(CondValue, I.getCondition()->getType())) { Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue(); mergeInValue(&I, getValueState(OpVal)); return; @@ -1254,13 +1398,37 @@ void SCCPInstVisitor::visitUnaryOperator(Instruction &I) { return; if (SCCPSolver::isConstant(V0State)) - if (Constant *C = ConstantFoldUnaryOpOperand(I.getOpcode(), - getConstant(V0State), DL)) + if (Constant *C = ConstantFoldUnaryOpOperand( + I.getOpcode(), getConstant(V0State, I.getType()), DL)) return (void)markConstant(IV, &I, C); markOverdefined(&I); } +void SCCPInstVisitor::visitFreezeInst(FreezeInst &I) { + // If this freeze returns a struct, just mark the result overdefined. + // TODO: We could do a lot better than this. + if (I.getType()->isStructTy()) + return (void)markOverdefined(&I); + + ValueLatticeElement V0State = getValueState(I.getOperand(0)); + ValueLatticeElement &IV = ValueState[&I]; + // resolvedUndefsIn might mark I as overdefined. Bail out, even if we would + // discover a concrete value later. + if (SCCPSolver::isOverdefined(IV)) + return (void)markOverdefined(&I); + + // If something is unknown/undef, wait for it to resolve. + if (V0State.isUnknownOrUndef()) + return; + + if (SCCPSolver::isConstant(V0State) && + isGuaranteedNotToBeUndefOrPoison(getConstant(V0State, I.getType()))) + return (void)markConstant(IV, &I, getConstant(V0State, I.getType())); + + markOverdefined(&I); +} + // Handle Binary Operators. void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { ValueLatticeElement V1State = getValueState(I.getOperand(0)); @@ -1280,10 +1448,12 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { // If either of the operands is a constant, try to fold it to a constant. // TODO: Use information from notconstant better. if ((V1State.isConstant() || V2State.isConstant())) { - Value *V1 = SCCPSolver::isConstant(V1State) ? getConstant(V1State) - : I.getOperand(0); - Value *V2 = SCCPSolver::isConstant(V2State) ? getConstant(V2State) - : I.getOperand(1); + Value *V1 = SCCPSolver::isConstant(V1State) + ? 
getConstant(V1State, I.getOperand(0)->getType()) + : I.getOperand(0); + Value *V2 = SCCPSolver::isConstant(V2State) + ? getConstant(V2State, I.getOperand(1)->getType()) + : I.getOperand(1); Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL)); auto *C = dyn_cast_or_null<Constant>(R); if (C) { @@ -1361,7 +1531,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&I); - if (Constant *C = getConstant(State)) { + if (Constant *C = getConstant(State, I.getOperand(i)->getType())) { Operands.push_back(C); continue; } @@ -1427,7 +1597,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { ValueLatticeElement &IV = ValueState[&I]; if (SCCPSolver::isConstant(PtrVal)) { - Constant *Ptr = getConstant(PtrVal); + Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType()); // load null is undefined. if (isa<ConstantPointerNull>(Ptr)) { @@ -1490,7 +1660,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&CB); assert(SCCPSolver::isConstant(State) && "Unknown state!"); - Operands.push_back(getConstant(State)); + Operands.push_back(getConstant(State, A->getType())); } if (SCCPSolver::isOverdefined(getValueState(&CB))) @@ -1622,6 +1792,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { SmallVector<ConstantRange, 2> OpRanges; for (Value *Op : II->args()) { const ValueLatticeElement &State = getValueState(Op); + if (State.isUnknownOrUndef()) + return; OpRanges.push_back(getConstantRange(State, Op->getType())); } @@ -1666,6 +1838,7 @@ void SCCPInstVisitor::solve() { // things to overdefined more quickly. while (!OverdefinedInstWorkList.empty()) { Value *I = OverdefinedInstWorkList.pop_back_val(); + Invalidated.erase(I); LLVM_DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n'); @@ -1682,6 +1855,7 @@ void SCCPInstVisitor::solve() { // Process the instruction work list. while (!InstWorkList.empty()) { Value *I = InstWorkList.pop_back_val(); + Invalidated.erase(I); LLVM_DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n'); @@ -1709,6 +1883,61 @@ void SCCPInstVisitor::solve() { } } +bool SCCPInstVisitor::resolvedUndef(Instruction &I) { + // Look for instructions which produce undef values. + if (I.getType()->isVoidTy()) + return false; + + if (auto *STy = dyn_cast<StructType>(I.getType())) { + // Only a few things that can be structs matter for undef. + + // Tracked calls must never be marked overdefined in resolvedUndefsIn. + if (auto *CB = dyn_cast<CallBase>(&I)) + if (Function *F = CB->getCalledFunction()) + if (MRVFunctionsTracked.count(F)) + return false; + + // extractvalue and insertvalue don't need to be marked; they are + // tracked as precisely as their operands. + if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I)) + return false; + // Send the results of everything else to overdefined. We could be + // more precise than this but it isn't worth bothering. + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + ValueLatticeElement &LV = getStructValueState(&I, i); + if (LV.isUnknown()) { + markOverdefined(LV, &I); + return true; + } + } + return false; + } + + ValueLatticeElement &LV = getValueState(&I); + if (!LV.isUnknown()) + return false; + + // There are two reasons a call can have an undef result + // 1. It could be tracked. + // 2. It could be constant-foldable. + // Because of the way we solve return values, tracked calls must + // never be marked overdefined in resolvedUndefsIn. 
+ if (auto *CB = dyn_cast<CallBase>(&I)) + if (Function *F = CB->getCalledFunction()) + if (TrackedRetVals.count(F)) + return false; + + if (isa<LoadInst>(I)) { + // A load here means one of two things: a load of undef from a global, + // a load from an unknown pointer. Either way, having it return undef + // is okay. + return false; + } + + markOverdefined(&I); + return true; +} + /// While solving the dataflow for a function, we don't compute a result for /// operations with an undef operand, to allow undef to be lowered to a /// constant later. For example, constant folding of "zext i8 undef to i16" @@ -1728,60 +1957,8 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) { if (!BBExecutable.count(&BB)) continue; - for (Instruction &I : BB) { - // Look for instructions which produce undef values. - if (I.getType()->isVoidTy()) - continue; - - if (auto *STy = dyn_cast<StructType>(I.getType())) { - // Only a few things that can be structs matter for undef. - - // Tracked calls must never be marked overdefined in resolvedUndefsIn. - if (auto *CB = dyn_cast<CallBase>(&I)) - if (Function *F = CB->getCalledFunction()) - if (MRVFunctionsTracked.count(F)) - continue; - - // extractvalue and insertvalue don't need to be marked; they are - // tracked as precisely as their operands. - if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I)) - continue; - // Send the results of everything else to overdefined. We could be - // more precise than this but it isn't worth bothering. - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - ValueLatticeElement &LV = getStructValueState(&I, i); - if (LV.isUnknown()) { - markOverdefined(LV, &I); - MadeChange = true; - } - } - continue; - } - - ValueLatticeElement &LV = getValueState(&I); - if (!LV.isUnknown()) - continue; - - // There are two reasons a call can have an undef result - // 1. It could be tracked. - // 2. It could be constant-foldable. - // Because of the way we solve return values, tracked calls must - // never be marked overdefined in resolvedUndefsIn. - if (auto *CB = dyn_cast<CallBase>(&I)) - if (Function *F = CB->getCalledFunction()) - if (TrackedRetVals.count(F)) - continue; - - if (isa<LoadInst>(I)) { - // A load here means one of two things: a load of undef from a global, - // a load from an unknown pointer. Either way, having it return undef - // is okay. 
- continue; - } - - markOverdefined(&I); - MadeChange = true; - } + for (Instruction &I : BB) + MadeChange |= resolvedUndef(I); } LLVM_DEBUG(if (MadeChange) dbgs() @@ -1802,8 +1979,9 @@ SCCPSolver::SCCPSolver( SCCPSolver::~SCCPSolver() = default; -void SCCPSolver::addAnalysis(Function &F, AnalysisResultsForFn A) { - return Visitor->addAnalysis(F, std::move(A)); +void SCCPSolver::addPredicateInfo(Function &F, DominatorTree &DT, + AssumptionCache &AC) { + Visitor->addPredicateInfo(F, DT, AC); } bool SCCPSolver::markBlockExecutable(BasicBlock *BB) { @@ -1814,12 +1992,6 @@ const PredicateBase *SCCPSolver::getPredicateInfoFor(Instruction *I) { return Visitor->getPredicateInfoFor(I); } -const LoopInfo &SCCPSolver::getLoopInfo(Function &F) { - return Visitor->getLoopInfo(F); -} - -DomTreeUpdater SCCPSolver::getDTU(Function &F) { return Visitor->getDTU(F); } - void SCCPSolver::trackValueOfGlobalVariable(GlobalVariable *GV) { Visitor->trackValueOfGlobalVariable(GV); } @@ -1859,6 +2031,10 @@ SCCPSolver::solveWhileResolvedUndefsIn(SmallVectorImpl<Function *> &WorkList) { Visitor->solveWhileResolvedUndefsIn(WorkList); } +void SCCPSolver::solveWhileResolvedUndefs() { + Visitor->solveWhileResolvedUndefs(); +} + bool SCCPSolver::isBlockExecutable(BasicBlock *BB) const { return Visitor->isBlockExecutable(BB); } @@ -1876,6 +2052,10 @@ void SCCPSolver::removeLatticeValueFor(Value *V) { return Visitor->removeLatticeValueFor(V); } +void SCCPSolver::resetLatticeValueFor(CallBase *Call) { + Visitor->resetLatticeValueFor(Call); +} + const ValueLatticeElement &SCCPSolver::getLatticeValueFor(Value *V) const { return Visitor->getLatticeValueFor(V); } @@ -1900,17 +2080,22 @@ bool SCCPSolver::isStructLatticeConstant(Function *F, StructType *STy) { return Visitor->isStructLatticeConstant(F, STy); } -Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const { - return Visitor->getConstant(LV); +Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV, + Type *Ty) const { + return Visitor->getConstant(LV, Ty); +} + +Constant *SCCPSolver::getConstantOrNull(Value *V) const { + return Visitor->getConstantOrNull(V); } SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() { return Visitor->getArgumentTrackedFunctions(); } -void SCCPSolver::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl<ArgInfo> &Args) { - Visitor->markArgInFuncSpecialization(F, Args); +void SCCPSolver::setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl<ArgInfo> &Args) { + Visitor->setLatticeValueForSpecializationArguments(F, Args); } void SCCPSolver::markFunctionUnreachable(Function *F) { diff --git a/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/llvm/lib/Transforms/Utils/SSAUpdater.cpp index 2520aa5d9db0..ebe9cb27f5ab 100644 --- a/llvm/lib/Transforms/Utils/SSAUpdater.cpp +++ b/llvm/lib/Transforms/Utils/SSAUpdater.cpp @@ -19,6 +19,7 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" @@ -195,6 +196,33 @@ void SSAUpdater::RewriteUse(Use &U) { U.set(V); } +void SSAUpdater::UpdateDebugValues(Instruction *I) { + SmallVector<DbgValueInst *, 4> DbgValues; + llvm::findDbgValues(DbgValues, I); + for (auto &DbgValue : DbgValues) { + if (DbgValue->getParent() == I->getParent()) + continue; + UpdateDebugValue(I, DbgValue); + } +} + +void SSAUpdater::UpdateDebugValues(Instruction *I, + SmallVectorImpl<DbgValueInst *> 
&DbgValues) { + for (auto &DbgValue : DbgValues) { + UpdateDebugValue(I, DbgValue); + } +} + +void SSAUpdater::UpdateDebugValue(Instruction *I, DbgValueInst *DbgValue) { + BasicBlock *UserBB = DbgValue->getParent(); + if (HasValueForBlock(UserBB)) { + Value *NewVal = GetValueAtEndOfBlock(UserBB); + DbgValue->replaceVariableLocationOp(I, NewVal); + } + else + DbgValue->setKillLocation(); +} + void SSAUpdater::RewriteUseAfterInsertions(Use &U) { Instruction *User = cast<Instruction>(U.getUser()); diff --git a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp index 691ee00bd831..31d62fbf0618 100644 --- a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp +++ b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp @@ -20,6 +20,7 @@ #include <queue> #include <set> #include <stack> +#include <unordered_set> using namespace llvm; #define DEBUG_TYPE "sample-profile-inference" @@ -1218,10 +1219,23 @@ void extractWeights(const ProfiParams &Params, MinCostMaxFlow &Network, #ifndef NDEBUG /// Verify that the provided block/jump weights are as expected. void verifyInput(const FlowFunction &Func) { - // Verify the entry block + // Verify entry and exit blocks assert(Func.Entry == 0 && Func.Blocks[0].isEntry()); + size_t NumExitBlocks = 0; for (size_t I = 1; I < Func.Blocks.size(); I++) { assert(!Func.Blocks[I].isEntry() && "multiple entry blocks"); + if (Func.Blocks[I].isExit()) + NumExitBlocks++; + } + assert(NumExitBlocks > 0 && "cannot find exit blocks"); + + // Verify that there are no parallel edges + for (auto &Block : Func.Blocks) { + std::unordered_set<uint64_t> UniqueSuccs; + for (auto &Jump : Block.SuccJumps) { + auto It = UniqueSuccs.insert(Jump->Target); + assert(It.second && "input CFG contains parallel edges"); + } } // Verify CFG jumps for (auto &Block : Func.Blocks) { @@ -1304,8 +1318,26 @@ void verifyOutput(const FlowFunction &Func) { } // end of anonymous namespace -/// Apply the profile inference algorithm for a given function +/// Apply the profile inference algorithm for a given function and provided +/// profi options void llvm::applyFlowInference(const ProfiParams &Params, FlowFunction &Func) { + // Check if the function has samples and assign initial flow values + bool HasSamples = false; + for (FlowBlock &Block : Func.Blocks) { + if (Block.Weight > 0) + HasSamples = true; + Block.Flow = Block.Weight; + } + for (FlowJump &Jump : Func.Jumps) { + if (Jump.Weight > 0) + HasSamples = true; + Jump.Flow = Jump.Weight; + } + + // Quit early for functions with a single block or ones w/o samples + if (Func.Blocks.size() <= 1 || !HasSamples) + return; + #ifndef NDEBUG // Verify the input data verifyInput(Func); diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 24f1966edd37..20844271b943 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -163,7 +163,7 @@ Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) { "InsertNoopCastOfTo cannot change sizes!"); // inttoptr only works for integral pointers. For non-integral pointers, we - // can create a GEP on i8* null with the integral value as index. Note that + // can create a GEP on null with the integral value as index. Note that // it is safe to use GEP of null instead of inttoptr here, because only // expressions already based on a GEP of null should be converted to pointers // during expansion. 
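The comment above describes materializing integer-to-pointer arithmetic as a GEP over an i8 element type instead of going through ptrtoint/add/inttoptr, which is also the shape the rewritten expandAddToGEP() emits under the "scevgep" name. The following is a minimal IRBuilder sketch of that form; the module name, function signature, and value names are made up for illustration.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("scevgep-demo", Ctx);

  // A toy function "ptr @f(ptr %base, i64 %offset)".
  auto *PtrTy = PointerType::get(Ctx, 0);
  auto *I64Ty = Type::getInt64Ty(Ctx);
  Function *F =
      Function::Create(FunctionType::get(PtrTy, {PtrTy, I64Ty}, false),
                       Function::ExternalLinkage, "f", M);

  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));
  Value *Base = F->getArg(0);
  Value *Offset = F->getArg(1);

  // Address arithmetic expressed as a byte-wise GEP over i8, rather than a
  // ptrtoint/add/inttoptr chain or a type-driven multi-index GEP.
  Value *GEP = B.CreateGEP(B.getInt8Ty(), Base, Offset, "scevgep");
  B.CreateRet(GEP);

  M.print(outs(), nullptr);
  return 0;
}

Printing the module shows the offset applied as "getelementptr i8, ptr %base, i64 %offset", the single canonical form the expander now reuses when scanning nearby instructions.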
@@ -173,9 +173,8 @@ Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) { auto *Int8PtrTy = Builder.getInt8PtrTy(PtrTy->getAddressSpace()); assert(DL.getTypeAllocSize(Builder.getInt8Ty()) == 1 && "alloc size of i8 must by 1 byte for the GEP to be correct"); - auto *GEP = Builder.CreateGEP( - Builder.getInt8Ty(), Constant::getNullValue(Int8PtrTy), V, "uglygep"); - return Builder.CreateBitCast(GEP, Ty); + return Builder.CreateGEP( + Builder.getInt8Ty(), Constant::getNullValue(Int8PtrTy), V, "scevgep"); } } // Short-circuit unnecessary bitcasts. @@ -287,142 +286,6 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, return BO; } -/// FactorOutConstant - Test if S is divisible by Factor, using signed -/// division. If so, update S with Factor divided out and return true. -/// S need not be evenly divisible if a reasonable remainder can be -/// computed. -static bool FactorOutConstant(const SCEV *&S, const SCEV *&Remainder, - const SCEV *Factor, ScalarEvolution &SE, - const DataLayout &DL) { - // Everything is divisible by one. - if (Factor->isOne()) - return true; - - // x/x == 1. - if (S == Factor) { - S = SE.getConstant(S->getType(), 1); - return true; - } - - // For a Constant, check for a multiple of the given factor. - if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { - // 0/x == 0. - if (C->isZero()) - return true; - // Check for divisibility. - if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor)) { - ConstantInt *CI = - ConstantInt::get(SE.getContext(), C->getAPInt().sdiv(FC->getAPInt())); - // If the quotient is zero and the remainder is non-zero, reject - // the value at this scale. It will be considered for subsequent - // smaller scales. - if (!CI->isZero()) { - const SCEV *Div = SE.getConstant(CI); - S = Div; - Remainder = SE.getAddExpr( - Remainder, SE.getConstant(C->getAPInt().srem(FC->getAPInt()))); - return true; - } - } - } - - // In a Mul, check if there is a constant operand which is a multiple - // of the given factor. - if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { - // Size is known, check if there is a constant operand which is a multiple - // of the given factor. If so, we can factor it. - if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor)) - if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0))) - if (!C->getAPInt().srem(FC->getAPInt())) { - SmallVector<const SCEV *, 4> NewMulOps(M->operands()); - NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt())); - S = SE.getMulExpr(NewMulOps); - return true; - } - } - - // In an AddRec, check if both start and step are divisible. - if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { - const SCEV *Step = A->getStepRecurrence(SE); - const SCEV *StepRem = SE.getConstant(Step->getType(), 0); - if (!FactorOutConstant(Step, StepRem, Factor, SE, DL)) - return false; - if (!StepRem->isZero()) - return false; - const SCEV *Start = A->getStart(); - if (!FactorOutConstant(Start, Remainder, Factor, SE, DL)) - return false; - S = SE.getAddRecExpr(Start, Step, A->getLoop(), - A->getNoWrapFlags(SCEV::FlagNW)); - return true; - } - - return false; -} - -/// SimplifyAddOperands - Sort and simplify a list of add operands. NumAddRecs -/// is the number of SCEVAddRecExprs present, which are kept at the end of -/// the list. 
-/// -static void SimplifyAddOperands(SmallVectorImpl<const SCEV *> &Ops, - Type *Ty, - ScalarEvolution &SE) { - unsigned NumAddRecs = 0; - for (unsigned i = Ops.size(); i > 0 && isa<SCEVAddRecExpr>(Ops[i-1]); --i) - ++NumAddRecs; - // Group Ops into non-addrecs and addrecs. - SmallVector<const SCEV *, 8> NoAddRecs(Ops.begin(), Ops.end() - NumAddRecs); - SmallVector<const SCEV *, 8> AddRecs(Ops.end() - NumAddRecs, Ops.end()); - // Let ScalarEvolution sort and simplify the non-addrecs list. - const SCEV *Sum = NoAddRecs.empty() ? - SE.getConstant(Ty, 0) : - SE.getAddExpr(NoAddRecs); - // If it returned an add, use the operands. Otherwise it simplified - // the sum into a single value, so just use that. - Ops.clear(); - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Sum)) - append_range(Ops, Add->operands()); - else if (!Sum->isZero()) - Ops.push_back(Sum); - // Then append the addrecs. - Ops.append(AddRecs.begin(), AddRecs.end()); -} - -/// SplitAddRecs - Flatten a list of add operands, moving addrec start values -/// out to the top level. For example, convert {a + b,+,c} to a, b, {0,+,d}. -/// This helps expose more opportunities for folding parts of the expressions -/// into GEP indices. -/// -static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops, - Type *Ty, - ScalarEvolution &SE) { - // Find the addrecs. - SmallVector<const SCEV *, 8> AddRecs; - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Ops[i])) { - const SCEV *Start = A->getStart(); - if (Start->isZero()) break; - const SCEV *Zero = SE.getConstant(Ty, 0); - AddRecs.push_back(SE.getAddRecExpr(Zero, - A->getStepRecurrence(SE), - A->getLoop(), - A->getNoWrapFlags(SCEV::FlagNW))); - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Start)) { - Ops[i] = Zero; - append_range(Ops, Add->operands()); - e += Add->getNumOperands(); - } else { - Ops[i] = Start; - } - } - if (!AddRecs.empty()) { - // Add the addrecs onto the end of the list. - Ops.append(AddRecs.begin(), AddRecs.end()); - // Resort the operand list, moving any constants to the front. - SimplifyAddOperands(Ops, Ty, SE); - } -} - /// expandAddToGEP - Expand an addition expression with a pointer type into /// a GEP instead of using ptrtoint+arithmetic+inttoptr. This helps /// BasicAliasAnalysis and other passes analyze the result. See the rules @@ -450,210 +313,53 @@ static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops, /// loop-invariant portions of expressions, after considering what /// can be folded using target addressing modes. /// -Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, - const SCEV *const *op_end, - PointerType *PTy, - Type *Ty, - Value *V) { - SmallVector<Value *, 4> GepIndices; - SmallVector<const SCEV *, 8> Ops(op_begin, op_end); - bool AnyNonZeroIndices = false; - - // Split AddRecs up into parts as either of the parts may be usable - // without the other. - SplitAddRecs(Ops, Ty, SE); - - Type *IntIdxTy = DL.getIndexType(PTy); - - // For opaque pointers, always generate i8 GEP. - if (!PTy->isOpaque()) { - // Descend down the pointer's type and attempt to convert the other - // operands into GEP indices, at each level. The first index in a GEP - // indexes into the array implied by the pointer operand; the rest of - // the indices index into the element or field type selected by the - // preceding index. - Type *ElTy = PTy->getNonOpaquePointerElementType(); - for (;;) { - // If the scale size is not 0, attempt to factor out a scale for - // array indexing. 
- SmallVector<const SCEV *, 8> ScaledOps; - if (ElTy->isSized()) { - const SCEV *ElSize = SE.getSizeOfExpr(IntIdxTy, ElTy); - if (!ElSize->isZero()) { - SmallVector<const SCEV *, 8> NewOps; - for (const SCEV *Op : Ops) { - const SCEV *Remainder = SE.getConstant(Ty, 0); - if (FactorOutConstant(Op, Remainder, ElSize, SE, DL)) { - // Op now has ElSize factored out. - ScaledOps.push_back(Op); - if (!Remainder->isZero()) - NewOps.push_back(Remainder); - AnyNonZeroIndices = true; - } else { - // The operand was not divisible, so add it to the list of - // operands we'll scan next iteration. - NewOps.push_back(Op); - } - } - // If we made any changes, update Ops. - if (!ScaledOps.empty()) { - Ops = NewOps; - SimplifyAddOperands(Ops, Ty, SE); - } - } - } +Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Type *Ty, Value *V) { + assert(!isa<Instruction>(V) || + SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint())); - // Record the scaled array index for this level of the type. If - // we didn't find any operands that could be factored, tentatively - // assume that element zero was selected (since the zero offset - // would obviously be folded away). - Value *Scaled = - ScaledOps.empty() - ? Constant::getNullValue(Ty) - : expandCodeForImpl(SE.getAddExpr(ScaledOps), Ty); - GepIndices.push_back(Scaled); - - // Collect struct field index operands. - while (StructType *STy = dyn_cast<StructType>(ElTy)) { - bool FoundFieldNo = false; - // An empty struct has no fields. - if (STy->getNumElements() == 0) break; - // Field offsets are known. See if a constant offset falls within any of - // the struct fields. - if (Ops.empty()) - break; - if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[0])) - if (SE.getTypeSizeInBits(C->getType()) <= 64) { - const StructLayout &SL = *DL.getStructLayout(STy); - uint64_t FullOffset = C->getValue()->getZExtValue(); - if (FullOffset < SL.getSizeInBytes()) { - unsigned ElIdx = SL.getElementContainingOffset(FullOffset); - GepIndices.push_back( - ConstantInt::get(Type::getInt32Ty(Ty->getContext()), ElIdx)); - ElTy = STy->getTypeAtIndex(ElIdx); - Ops[0] = - SE.getConstant(Ty, FullOffset - SL.getElementOffset(ElIdx)); - AnyNonZeroIndices = true; - FoundFieldNo = true; - } - } - // If no struct field offsets were found, tentatively assume that - // field zero was selected (since the zero offset would obviously - // be folded away). - if (!FoundFieldNo) { - ElTy = STy->getTypeAtIndex(0u); - GepIndices.push_back( - Constant::getNullValue(Type::getInt32Ty(Ty->getContext()))); - } - } + Value *Idx = expandCodeForImpl(Offset, Ty); - if (ArrayType *ATy = dyn_cast<ArrayType>(ElTy)) - ElTy = ATy->getElementType(); - else - // FIXME: Handle VectorType. - // E.g., If ElTy is scalable vector, then ElSize is not a compile-time - // constant, therefore can not be factored out. The generated IR is less - // ideal with base 'V' cast to i8* and do ugly getelementptr over that. - break; - } - } - - // If none of the operands were convertible to proper GEP indices, cast - // the base to i8* and do an ugly getelementptr with that. It's still - // better than ptrtoint+arithmetic+inttoptr at least. - if (!AnyNonZeroIndices) { - // Cast the base to i8*. - if (!PTy->isOpaque()) - V = InsertNoopCastOfTo(V, - Type::getInt8PtrTy(Ty->getContext(), PTy->getAddressSpace())); - - assert(!isa<Instruction>(V) || - SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint())); - - // Expand the operands for a plain byte offset. 
- Value *Idx = expandCodeForImpl(SE.getAddExpr(Ops), Ty); - - // Fold a GEP with constant operands. - if (Constant *CLHS = dyn_cast<Constant>(V)) - if (Constant *CRHS = dyn_cast<Constant>(Idx)) - return Builder.CreateGEP(Builder.getInt8Ty(), CLHS, CRHS); - - // Do a quick scan to see if we have this GEP nearby. If so, reuse it. - unsigned ScanLimit = 6; - BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin(); - // Scanning starts from the last instruction before the insertion point. - BasicBlock::iterator IP = Builder.GetInsertPoint(); - if (IP != BlockBegin) { - --IP; - for (; ScanLimit; --IP, --ScanLimit) { - // Don't count dbg.value against the ScanLimit, to avoid perturbing the - // generated code. - if (isa<DbgInfoIntrinsic>(IP)) - ScanLimit++; - if (IP->getOpcode() == Instruction::GetElementPtr && - IP->getOperand(0) == V && IP->getOperand(1) == Idx && - cast<GEPOperator>(&*IP)->getSourceElementType() == - Type::getInt8Ty(Ty->getContext())) - return &*IP; - if (IP == BlockBegin) break; - } - } + // Fold a GEP with constant operands. + if (Constant *CLHS = dyn_cast<Constant>(V)) + if (Constant *CRHS = dyn_cast<Constant>(Idx)) + return Builder.CreateGEP(Builder.getInt8Ty(), CLHS, CRHS); - // Save the original insertion point so we can restore it when we're done. - SCEVInsertPointGuard Guard(Builder, this); - - // Move the insertion point out of as many loops as we can. - while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { - if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break; - BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) break; - - // Ok, move up a level. - Builder.SetInsertPoint(Preheader->getTerminator()); + // Do a quick scan to see if we have this GEP nearby. If so, reuse it. + unsigned ScanLimit = 6; + BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin(); + // Scanning starts from the last instruction before the insertion point. + BasicBlock::iterator IP = Builder.GetInsertPoint(); + if (IP != BlockBegin) { + --IP; + for (; ScanLimit; --IP, --ScanLimit) { + // Don't count dbg.value against the ScanLimit, to avoid perturbing the + // generated code. + if (isa<DbgInfoIntrinsic>(IP)) + ScanLimit++; + if (IP->getOpcode() == Instruction::GetElementPtr && + IP->getOperand(0) == V && IP->getOperand(1) == Idx && + cast<GEPOperator>(&*IP)->getSourceElementType() == + Type::getInt8Ty(Ty->getContext())) + return &*IP; + if (IP == BlockBegin) break; } - - // Emit a GEP. - return Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "uglygep"); } - { - SCEVInsertPointGuard Guard(Builder, this); - - // Move the insertion point out of as many loops as we can. - while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { - if (!L->isLoopInvariant(V)) break; - - bool AnyIndexNotLoopInvariant = any_of( - GepIndices, [L](Value *Op) { return !L->isLoopInvariant(Op); }); - - if (AnyIndexNotLoopInvariant) - break; + // Save the original insertion point so we can restore it when we're done. + SCEVInsertPointGuard Guard(Builder, this); - BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) break; + // Move the insertion point out of as many loops as we can. + while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { + if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break; + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) break; - // Ok, move up a level. - Builder.SetInsertPoint(Preheader->getTerminator()); - } - - // Insert a pretty getelementptr. 
Note that this GEP is not marked inbounds, - // because ScalarEvolution may have changed the address arithmetic to - // compute a value which is beyond the end of the allocated object. - Value *Casted = V; - if (V->getType() != PTy) - Casted = InsertNoopCastOfTo(Casted, PTy); - Value *GEP = Builder.CreateGEP(PTy->getNonOpaquePointerElementType(), - Casted, GepIndices, "scevgep"); - Ops.push_back(SE.getUnknown(GEP)); + // Ok, move up a level. + Builder.SetInsertPoint(Preheader->getTerminator()); } - return expand(SE.getAddExpr(Ops)); -} - -Value *SCEVExpander::expandAddToGEP(const SCEV *Op, PointerType *PTy, Type *Ty, - Value *V) { - const SCEV *const Ops[1] = {Op}; - return expandAddToGEP(Ops, Ops + 1, PTy, Ty, V); + // Emit a GEP. + return Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "scevgep"); } /// PickMostRelevantLoop - Given two loops pick the one that's most relevant for @@ -680,6 +386,7 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return nullptr; // A constant has no relevant loops. case scTruncate: case scZeroExtend: @@ -778,7 +485,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { } assert(!Op->getType()->isPointerTy() && "Only first op can be pointer"); - if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) { + if (isa<PointerType>(Sum->getType())) { // The running sum expression is a pointer. Try to form a getelementptr // at this level with that as the base. SmallVector<const SCEV *, 4> NewOps; @@ -791,7 +498,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { X = SE.getSCEV(U->getValue()); NewOps.push_back(X); } - Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum); + Sum = expandAddToGEP(SE.getAddExpr(NewOps), Ty, Sum); } else if (Op->isNonConstantNegative()) { // Instead of doing a negate and add, just do a subtract. Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty); @@ -995,15 +702,8 @@ Instruction *SCEVExpander::getIVIncOperand(Instruction *IncV, // allow any kind of GEP as long as it can be hoisted. continue; } - // This must be a pointer addition of constants (pretty), which is already - // handled, or some number of address-size elements (ugly). Ugly geps - // have 2 operands. i1* is used by the expander to represent an - // address-size element. - if (IncV->getNumOperands() != 2) - return nullptr; - unsigned AS = cast<PointerType>(IncV->getType())->getAddressSpace(); - if (IncV->getType() != Type::getInt1PtrTy(SE.getContext(), AS) - && IncV->getType() != Type::getInt8PtrTy(SE.getContext(), AS)) + // GEPs produced by SCEVExpander use i8 element type. + if (!cast<GEPOperator>(IncV)->getSourceElementType()->isIntegerTy(8)) return nullptr; break; } @@ -1108,15 +808,7 @@ Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L, Value *IncV; // If the PHI is a pointer, use a GEP, otherwise use an add or sub. if (ExpandTy->isPointerTy()) { - PointerType *GEPPtrTy = cast<PointerType>(ExpandTy); - // If the step isn't constant, don't use an implicitly scaled GEP, because - // that would require a multiply inside the loop. - if (!isa<ConstantInt>(StepV)) - GEPPtrTy = PointerType::get(Type::getInt1Ty(SE.getContext()), - GEPPtrTy->getAddressSpace()); - IncV = expandAddToGEP(SE.getSCEV(StepV), GEPPtrTy, IntTy, PN); - if (IncV->getType() != PN->getType()) - IncV = Builder.CreateBitCast(IncV, PN->getType()); + IncV = expandAddToGEP(SE.getSCEV(StepV), IntTy, PN); } else { IncV = useSubtract ? 
Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") : @@ -1388,7 +1080,8 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { if (PostIncLoops.count(L)) { PostIncLoopSet Loops; Loops.insert(L); - Normalized = cast<SCEVAddRecExpr>(normalizeForPostIncUse(S, Loops, SE)); + Normalized = cast<SCEVAddRecExpr>( + normalizeForPostIncUse(S, Loops, SE, /*CheckInvertible=*/false)); } // Strip off any non-loop-dominating component from the addrec start. @@ -1515,12 +1208,12 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { // Re-apply any non-loop-dominating offset. if (PostLoopOffset) { - if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) { + if (isa<PointerType>(ExpandTy)) { if (Result->getType()->isIntegerTy()) { Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy); - Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base); + Result = expandAddToGEP(SE.getUnknown(Result), IntTy, Base); } else { - Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result); + Result = expandAddToGEP(PostLoopOffset, IntTy, Result); } } else { Result = InsertNoopCastOfTo(Result, IntTy); @@ -1574,10 +1267,9 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { // {X,+,F} --> X + {0,+,F} if (!S->getStart()->isZero()) { - if (PointerType *PTy = dyn_cast<PointerType>(S->getType())) { + if (isa<PointerType>(S->getType())) { Value *StartV = expand(SE.getPointerBase(S)); - assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!"); - return expandAddToGEP(SE.removePointerBase(S), PTy, Ty, StartV); + return expandAddToGEP(SE.removePointerBase(S), Ty, StartV); } SmallVector<const SCEV *, 4> NewOps(S->operands()); @@ -1744,6 +1436,10 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true); } +Value *SCEVExpander::visitVScale(const SCEVVScale *S) { + return Builder.CreateVScale(ConstantInt::get(S->getType(), 1)); +} + Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *IP) { setInsertPoint(IP); @@ -1956,11 +1652,17 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, OrigPhiRef = Phi; if (Phi->getType()->isIntegerTy() && TTI && TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) { - // This phi can be freely truncated to the narrowest phi type. Map the - // truncated expression to it so it will be reused for narrow types. - const SCEV *TruncExpr = - SE.getTruncateExpr(SE.getSCEV(Phi), Phis.back()->getType()); - ExprToIVMap[TruncExpr] = Phi; + // Make sure we only rewrite using simple induction variables; + // otherwise, we can make the trip count of a loop unanalyzable + // to SCEV. + const SCEV *PhiExpr = SE.getSCEV(Phi); + if (isa<SCEVAddRecExpr>(PhiExpr)) { + // This phi can be freely truncated to the narrowest phi type. Map the + // truncated expression to it so it will be reused for narrow types. 
+ const SCEV *TruncExpr = + SE.getTruncateExpr(PhiExpr, Phis.back()->getType()); + ExprToIVMap[TruncExpr] = Phi; + } } continue; } @@ -2124,6 +1826,7 @@ template<typename T> static InstructionCost costAndCollectOperands( llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); case scUnknown: case scConstant: + case scVScale: return 0; case scPtrToInt: Cost = CastCost(Instruction::PtrToInt); @@ -2260,6 +1963,7 @@ bool SCEVExpander::isHighCostExpansionHelper( case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); case scUnknown: + case scVScale: // Assume to be zero-cost. return false; case scConstant: { @@ -2551,7 +2255,11 @@ Value *SCEVExpander::fixupLCSSAFormFor(Value *V) { SmallVector<Instruction *, 1> ToUpdate; ToUpdate.push_back(DefI); SmallVector<PHINode *, 16> PHIsToRemove; - formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE, Builder, &PHIsToRemove); + SmallVector<PHINode *, 16> InsertedPHIs; + formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE, &PHIsToRemove, + &InsertedPHIs); + for (PHINode *PN : InsertedPHIs) + rememberInstruction(PN); for (PHINode *PN : PHIsToRemove) { if (!PN->use_empty()) continue; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 9e0483966d3e..d3a9a41aef15 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -271,10 +271,8 @@ class SimplifyCFGOpt { bool tryToSimplifyUncondBranchWithICmpInIt(ICmpInst *ICI, IRBuilder<> &Builder); - bool HoistThenElseCodeToIf(BranchInst *BI, const TargetTransformInfo &TTI, - bool EqTermsOnly); - bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, - const TargetTransformInfo &TTI); + bool HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly); + bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB); bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond, BasicBlock *TrueBB, BasicBlock *FalseBB, uint32_t TrueWeight, uint32_t FalseWeight); @@ -1086,7 +1084,7 @@ static void GetBranchWeights(Instruction *TI, static void FitWeights(MutableArrayRef<uint64_t> Weights) { uint64_t Max = *std::max_element(Weights.begin(), Weights.end()); if (Max > UINT_MAX) { - unsigned Offset = 32 - countLeadingZeros(Max); + unsigned Offset = 32 - llvm::countl_zero(Max); for (uint64_t &I : Weights) I >>= Offset; } @@ -1117,16 +1115,12 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses( RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); VMap[&BonusInst] = NewBonusInst; - // If we moved a load, we cannot any longer claim any knowledge about - // its potential value. The previous information might have been valid + // If we speculated an instruction, we need to drop any metadata that may + // result in undefined behavior, as the metadata might have been valid // only given the branch precondition. - // For an analogous reason, we must also drop all the metadata whose - // semantics we don't understand. We *can* preserve !annotation, because - // it is tied to the instruction itself, not the value or position. // Similarly strip attributes on call parameters that may cause UB in // location the call is moved to. 
- NewBonusInst->dropUndefImplyingAttrsAndUnknownMetadata( - LLVMContext::MD_annotation); + NewBonusInst->dropUBImplyingAttrsAndMetadata(); NewBonusInst->insertInto(PredBlock, PTI->getIterator()); NewBonusInst->takeName(&BonusInst); @@ -1462,7 +1456,7 @@ static bool isSafeToHoistInstr(Instruction *I, unsigned Flags) { // If we have seen an instruction with side effects, it's unsafe to reorder an // instruction which reads memory or itself has side effects. if ((Flags & SkipSideEffect) && - (I->mayReadFromMemory() || I->mayHaveSideEffects())) + (I->mayReadFromMemory() || I->mayHaveSideEffects() || isa<AllocaInst>(I))) return false; // Reordering across an instruction which does not necessarily transfer @@ -1490,14 +1484,43 @@ static bool isSafeToHoistInstr(Instruction *I, unsigned Flags) { static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValueMayBeModified = false); +/// Helper function for HoistThenElseCodeToIf. Return true if identical +/// instructions \p I1 and \p I2 can and should be hoisted. +static bool shouldHoistCommonInstructions(Instruction *I1, Instruction *I2, + const TargetTransformInfo &TTI) { + // If we're going to hoist a call, make sure that the two instructions + // we're commoning/hoisting are both marked with musttail, or neither of + // them is marked as such. Otherwise, we might end up in a situation where + // we hoist from a block where the terminator is a `ret` to a block where + // the terminator is a `br`, and `musttail` calls expect to be followed by + // a return. + auto *C1 = dyn_cast<CallInst>(I1); + auto *C2 = dyn_cast<CallInst>(I2); + if (C1 && C2) + if (C1->isMustTailCall() != C2->isMustTailCall()) + return false; + + if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2)) + return false; + + // If any of the two call sites has nomerge or convergent attribute, stop + // hoisting. + if (const auto *CB1 = dyn_cast<CallBase>(I1)) + if (CB1->cannotMerge() || CB1->isConvergent()) + return false; + if (const auto *CB2 = dyn_cast<CallBase>(I2)) + if (CB2->cannotMerge() || CB2->isConvergent()) + return false; + + return true; +} + /// Given a conditional branch that goes to BB1 and BB2, hoist any common code /// in the two blocks up into the branch block. The caller of this function /// guarantees that BI's block dominates BB1 and BB2. If EqTermsOnly is given, /// only perform hoisting in case both blocks only contain a terminator. In that /// case, only the original BI will be replaced and selects for PHIs are added. -bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, - const TargetTransformInfo &TTI, - bool EqTermsOnly) { +bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly) { // This does very trivial matching, with limited scanning, to find identical // instructions in the two blocks. In particular, we don't want to get into // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As @@ -1572,37 +1595,13 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, goto HoistTerminator; } - if (I1->isIdenticalToWhenDefined(I2)) { - // Even if the instructions are identical, it may not be safe to hoist - // them if we have skipped over instructions with side effects or their - // operands weren't hoisted. 
- if (!isSafeToHoistInstr(I1, SkipFlagsBB1) || - !isSafeToHoistInstr(I2, SkipFlagsBB2)) - return Changed; - - // If we're going to hoist a call, make sure that the two instructions - // we're commoning/hoisting are both marked with musttail, or neither of - // them is marked as such. Otherwise, we might end up in a situation where - // we hoist from a block where the terminator is a `ret` to a block where - // the terminator is a `br`, and `musttail` calls expect to be followed by - // a return. - auto *C1 = dyn_cast<CallInst>(I1); - auto *C2 = dyn_cast<CallInst>(I2); - if (C1 && C2) - if (C1->isMustTailCall() != C2->isMustTailCall()) - return Changed; - - if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2)) - return Changed; - - // If any of the two call sites has nomerge attribute, stop hoisting. - if (const auto *CB1 = dyn_cast<CallBase>(I1)) - if (CB1->cannotMerge()) - return Changed; - if (const auto *CB2 = dyn_cast<CallBase>(I2)) - if (CB2->cannotMerge()) - return Changed; - + if (I1->isIdenticalToWhenDefined(I2) && + // Even if the instructions are identical, it may not be safe to hoist + // them if we have skipped over instructions with side effects or their + // operands weren't hoisted. + isSafeToHoistInstr(I1, SkipFlagsBB1) && + isSafeToHoistInstr(I2, SkipFlagsBB2) && + shouldHoistCommonInstructions(I1, I2, TTI)) { if (isa<DbgInfoIntrinsic>(I1) || isa<DbgInfoIntrinsic>(I2)) { assert(isa<DbgInfoIntrinsic>(I1) && isa<DbgInfoIntrinsic>(I2)); // The debug location is an integral part of a debug info intrinsic @@ -1618,19 +1617,7 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->andIRFlags(I2); - unsigned KnownIDs[] = {LLVMContext::MD_tbaa, - LLVMContext::MD_range, - LLVMContext::MD_fpmath, - LLVMContext::MD_invariant_load, - LLVMContext::MD_nonnull, - LLVMContext::MD_invariant_group, - LLVMContext::MD_align, - LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_mem_parallel_loop_access, - LLVMContext::MD_access_group, - LLVMContext::MD_preserve_access_index}; - combineMetadata(I1, I2, KnownIDs, true); + combineMetadataForCSE(I1, I2, true); // I1 and I2 are being combined into a single instruction. Its debug // location is the merged locations of the original instructions. @@ -1808,9 +1795,9 @@ static bool canSinkInstructions( // Conservatively return false if I is an inline-asm instruction. Sinking // and merging inline-asm instructions can potentially create arguments // that cannot satisfy the inline-asm constraints. - // If the instruction has nomerge attribute, return false. + // If the instruction has nomerge or convergent attribute, return false. if (const auto *C = dyn_cast<CallBase>(I)) - if (C->isInlineAsm() || C->cannotMerge()) + if (C->isInlineAsm() || C->cannotMerge() || C->isConvergent()) return false; // Each instruction must have zero or one use. @@ -2455,9 +2442,13 @@ bool CompatibleSets::shouldBelongToSameSet(ArrayRef<InvokeInst *> Invokes) { // Can we theoretically form the data operands for the merged `invoke`? 
auto IsIllegalToMergeArguments = [](auto Ops) { - Type *Ty = std::get<0>(Ops)->getType(); - assert(Ty == std::get<1>(Ops)->getType() && "Incompatible types?"); - return Ty->isTokenTy() && std::get<0>(Ops) != std::get<1>(Ops); + Use &U0 = std::get<0>(Ops); + Use &U1 = std::get<1>(Ops); + if (U0 == U1) + return false; + return U0->getType()->isTokenTy() || + !canReplaceOperandWithVariable(cast<Instruction>(U0.getUser()), + U0.getOperandNo()); }; assert(Invokes.size() == 2 && "Always called with exactly two candidates."); if (any_of(zip(Invokes[0]->data_ops(), Invokes[1]->data_ops()), @@ -2571,7 +2562,7 @@ static void MergeCompatibleInvokesImpl(ArrayRef<InvokeInst *> Invokes, // And finally, replace the original `invoke`s with an unconditional branch // to the block with the merged `invoke`. Also, give that merged `invoke` // the merged debugloc of all the original `invoke`s. - const DILocation *MergedDebugLoc = nullptr; + DILocation *MergedDebugLoc = nullptr; for (InvokeInst *II : Invokes) { // Compute the debug location common to all the original `invoke`s. if (!MergedDebugLoc) @@ -2849,8 +2840,11 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB, /// \endcode /// /// \returns true if the conditional block is removed. -bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, - const TargetTransformInfo &TTI) { +bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, + BasicBlock *ThenBB) { + if (!Options.SpeculateBlocks) + return false; + // Be conservative for now. FP select instruction can often be expensive. Value *BrCond = BI->getCondition(); if (isa<FCmpInst>(BrCond)) @@ -3021,7 +3015,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, } // Metadata can be dependent on the condition we are hoisting above. - // Conservatively strip all metadata on the instruction. Drop the debug loc + // Strip all UB-implying metadata on the instruction. Drop the debug loc // to avoid making it appear as if the condition is a constant, which would // be misleading while debugging. // Similarly strip attributes that maybe dependent on condition we are @@ -3032,7 +3026,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, if (!isa<DbgAssignIntrinsic>(&I)) I.setDebugLoc(DebugLoc()); } - I.dropUndefImplyingAttrsAndUnknownMetadata(); + I.dropUBImplyingAttrsAndMetadata(); // Drop ephemeral values. if (EphTracker.contains(&I)) { @@ -3220,6 +3214,9 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, } // Clone the instruction. Instruction *N = BBI->clone(); + // Insert the new instruction into its new home. + N->insertInto(EdgeBB, InsertPt); + if (BBI->hasName()) N->setName(BBI->getName() + ".c"); @@ -3235,7 +3232,8 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, if (!BBI->use_empty()) TranslateMap[&*BBI] = V; if (!N->mayHaveSideEffects()) { - N->deleteValue(); // Instruction folded away, don't need actual inst + N->eraseFromParent(); // Instruction folded away, don't need actual + // inst N = nullptr; } } else { @@ -3243,9 +3241,6 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, TranslateMap[&*BBI] = N; } if (N) { - // Insert the new instruction into its new home. - N->insertInto(EdgeBB, InsertPt); - // Register the new instruction with the assumption cache if necessary. 
if (auto *Assume = dyn_cast<AssumeInst>(N)) if (AC) @@ -3591,17 +3586,7 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, // If we need to invert the condition in the pred block to match, do so now. if (InvertPredCond) { - Value *NewCond = PBI->getCondition(); - if (NewCond->hasOneUse() && isa<CmpInst>(NewCond)) { - CmpInst *CI = cast<CmpInst>(NewCond); - CI->setPredicate(CI->getInversePredicate()); - } else { - NewCond = - Builder.CreateNot(NewCond, PBI->getCondition()->getName() + ".not"); - } - - PBI->setCondition(NewCond); - PBI->swapSuccessors(); + InvertBranch(PBI, Builder); } BasicBlock *UniqueSucc = @@ -3887,7 +3872,7 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB, for (BasicBlock *PredBB : predecessors(Succ)) if (PredBB != BB) PHI->addIncoming( - AlternativeV ? AlternativeV : UndefValue::get(V->getType()), PredBB); + AlternativeV ? AlternativeV : PoisonValue::get(V->getType()), PredBB); return PHI; } @@ -5150,14 +5135,18 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { Value* Cond = BI->getCondition(); assert(BI->getSuccessor(0) != BI->getSuccessor(1) && "The destinations are guaranteed to be different here."); + CallInst *Assumption; if (BI->getSuccessor(0) == BB) { - Builder.CreateAssumption(Builder.CreateNot(Cond)); + Assumption = Builder.CreateAssumption(Builder.CreateNot(Cond)); Builder.CreateBr(BI->getSuccessor(1)); } else { assert(BI->getSuccessor(1) == BB && "Incorrect CFG"); - Builder.CreateAssumption(Cond); + Assumption = Builder.CreateAssumption(Cond); Builder.CreateBr(BI->getSuccessor(0)); } + if (Options.AC) + Options.AC->registerAssumption(cast<AssumeInst>(Assumption)); + EraseTerminatorAndDCECond(BI); Changed = true; } @@ -5453,7 +5442,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, } const APInt &CaseVal = Case.getCaseValue()->getValue(); if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) || - (CaseVal.getMinSignedBits() > MaxSignificantBitsInCond)) { + (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) { DeadCases.push_back(Case.getCaseValue()); if (DTU) --NumPerSuccessorCases[Successor]; @@ -5469,7 +5458,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, bool HasDefault = !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); const unsigned NumUnknownBits = - Known.getBitWidth() - (Known.Zero | Known.One).countPopulation(); + Known.getBitWidth() - (Known.Zero | Known.One).popcount(); assert(NumUnknownBits <= Known.getBitWidth()); if (HasDefault && DeadCases.empty() && NumUnknownBits < 64 /* avoid overflow */ && @@ -5860,7 +5849,7 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, // Check if cases with the same result can cover all number // in touched bits. - if (BitMask.countPopulation() == Log2_32(CaseCount)) { + if (BitMask.popcount() == Log2_32(CaseCount)) { if (!MinCaseVal->isNullValue()) Condition = Builder.CreateSub(Condition, MinCaseVal); Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and"); @@ -6001,6 +5990,7 @@ private: // For LinearMapKind, these are the constants used to derive the value. ConstantInt *LinearOffset = nullptr; ConstantInt *LinearMultiplier = nullptr; + bool LinearMapValWrapped = false; // For ArrayKind, this is the array. 
GlobalVariable *Array = nullptr; @@ -6061,6 +6051,8 @@ SwitchLookupTable::SwitchLookupTable( bool LinearMappingPossible = true; APInt PrevVal; APInt DistToPrev; + // When linear map is monotonic, we can attach nsw. + bool Wrapped = false; assert(TableSize >= 2 && "Should be a SingleValue table."); // Check if there is the same distance between two consecutive values. for (uint64_t I = 0; I < TableSize; ++I) { @@ -6080,12 +6072,15 @@ SwitchLookupTable::SwitchLookupTable( LinearMappingPossible = false; break; } + Wrapped |= + Dist.isStrictlyPositive() ? Val.sle(PrevVal) : Val.sgt(PrevVal); } PrevVal = Val; } if (LinearMappingPossible) { LinearOffset = cast<ConstantInt>(TableContents[0]); LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev); + LinearMapValWrapped = Wrapped; Kind = LinearMapKind; ++NumLinearMaps; return; @@ -6134,9 +6129,14 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { Value *Result = Builder.CreateIntCast(Index, LinearMultiplier->getType(), false, "switch.idx.cast"); if (!LinearMultiplier->isOne()) - Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult"); + Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult", + /*HasNUW = */ false, + /*HasNSW = */ !LinearMapValWrapped); + if (!LinearOffset->isZero()) - Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset"); + Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset", + /*HasNUW = */ false, + /*HasNSW = */ !LinearMapValWrapped); return Result; } case BitMapKind: { @@ -6148,10 +6148,12 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { // truncating it to the width of the bitmask is safe. Value *ShiftAmt = Builder.CreateZExtOrTrunc(Index, MapTy, "switch.cast"); - // Multiply the shift amount by the element width. + // Multiply the shift amount by the element width. NUW/NSW can always be + // set, because WouldFitInRegister guarantees Index * ShiftAmt is in + // BitMap's bit width. ShiftAmt = Builder.CreateMul( ShiftAmt, ConstantInt::get(MapTy, BitMapElementTy->getBitWidth()), - "switch.shiftamt"); + "switch.shiftamt",/*HasNUW =*/true,/*HasNSW =*/true); // Shift down. Value *DownShifted = @@ -6490,6 +6492,21 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, std::vector<DominatorTree::UpdateType> Updates; + // Compute the maximum table size representable by the integer type we are + // switching upon. + unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits(); + uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize; + assert(MaxTableSize >= TableSize && + "It is impossible for a switch to have more entries than the max " + "representable value of its input integer type's size."); + + // If the default destination is unreachable, or if the lookup table covers + // all values of the conditional variable, branch directly to the lookup table + // BB. Otherwise, check that the condition is within the case range. + const bool DefaultIsReachable = + !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); + const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize); + // Create the BB that does the lookups. 
Module &Mod = *CommonDest->getParent()->getParent(); BasicBlock *LookupBB = BasicBlock::Create( @@ -6504,24 +6521,19 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, TableIndex = SI->getCondition(); } else { TableIndexOffset = MinCaseVal; - TableIndex = - Builder.CreateSub(SI->getCondition(), TableIndexOffset, "switch.tableidx"); - } + // If the default is unreachable, all case values are s>= MinCaseVal. Then + // we can try to attach nsw. + bool MayWrap = true; + if (!DefaultIsReachable) { + APInt Res = MaxCaseVal->getValue().ssub_ov(MinCaseVal->getValue(), MayWrap); + (void)Res; + } - // Compute the maximum table size representable by the integer type we are - // switching upon. - unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits(); - uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize; - assert(MaxTableSize >= TableSize && - "It is impossible for a switch to have more entries than the max " - "representable value of its input integer type's size."); + TableIndex = Builder.CreateSub(SI->getCondition(), TableIndexOffset, + "switch.tableidx", /*HasNUW =*/false, + /*HasNSW =*/!MayWrap); + } - // If the default destination is unreachable, or if the lookup table covers - // all values of the conditional variable, branch directly to the lookup table - // BB. Otherwise, check that the condition is within the case range. - const bool DefaultIsReachable = - !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); - const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize); BranchInst *RangeCheckBranch = nullptr; if (!DefaultIsReachable || GeneratingCoveredLookupTable) { @@ -6694,7 +6706,7 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, // less than 64. unsigned Shift = 64; for (auto &V : Values) - Shift = std::min(Shift, countTrailingZeros((uint64_t)V)); + Shift = std::min(Shift, (unsigned)llvm::countr_zero((uint64_t)V)); assert(Shift < 64); if (Shift > 0) for (auto &V : Values) @@ -6990,7 +7002,8 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { "Tautological conditional branch should have been eliminated already."); BasicBlock *BB = BI->getParent(); - if (!Options.SimplifyCondBranch) + if (!Options.SimplifyCondBranch || + BI->getFunction()->hasFnAttribute(Attribute::OptForFuzzing)) return false; // Conditional branch @@ -7045,8 +7058,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // can hoist it up to the branching block. 
if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { - if (HoistCommon && - HoistThenElseCodeToIf(BI, TTI, !Options.HoistCommonInsts)) + if (HoistCommon && HoistThenElseCodeToIf(BI, !Options.HoistCommonInsts)) return requestResimplify(); } else { // If Successor #1 has multiple preds, we may be able to conditionally @@ -7054,7 +7066,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { Instruction *Succ0TI = BI->getSuccessor(0)->getTerminator(); if (Succ0TI->getNumSuccessors() == 1 && Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) - if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), TTI)) + if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0))) return requestResimplify(); } } else if (BI->getSuccessor(1)->getSinglePredecessor()) { @@ -7063,7 +7075,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { Instruction *Succ1TI = BI->getSuccessor(1)->getTerminator(); if (Succ1TI->getNumSuccessors() == 1 && Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) - if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), TTI)) + if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1))) return requestResimplify(); } @@ -7179,7 +7191,8 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu /// If BB has an incoming value that will always trigger undefined behavior /// (eg. null pointer dereference), remove the branch leading here. static bool removeUndefIntroducingPredecessor(BasicBlock *BB, - DomTreeUpdater *DTU) { + DomTreeUpdater *DTU, + AssumptionCache *AC) { for (PHINode &PHI : BB->phis()) for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i) if (passingValueIsAlwaysUndefined(PHI.getIncomingValue(i), &PHI)) { @@ -7196,10 +7209,13 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB, // Preserve guarding condition in assume, because it might not be // inferrable from any dominating condition. Value *Cond = BI->getCondition(); + CallInst *Assumption; if (BI->getSuccessor(0) == BB) - Builder.CreateAssumption(Builder.CreateNot(Cond)); + Assumption = Builder.CreateAssumption(Builder.CreateNot(Cond)); else - Builder.CreateAssumption(Cond); + Assumption = Builder.CreateAssumption(Cond); + if (AC) + AC->registerAssumption(cast<AssumeInst>(Assumption)); Builder.CreateBr(BI->getSuccessor(0) == BB ? BI->getSuccessor(1) : BI->getSuccessor(0)); } @@ -7260,7 +7276,7 @@ bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) { Changed |= EliminateDuplicatePHINodes(BB); // Check for and remove branches that will always cause undefined behavior. - if (removeUndefIntroducingPredecessor(BB, DTU)) + if (removeUndefIntroducingPredecessor(BB, DTU, Options.AC)) return requestResimplify(); // Merge basic blocks into their predecessor if there is only one distinct @@ -7282,7 +7298,8 @@ bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) { IRBuilder<> Builder(BB); - if (Options.FoldTwoEntryPHINode) { + if (Options.SpeculateBlocks && + !BB->getParent()->hasFnAttribute(Attribute::OptForFuzzing)) { // If there is a trivial two-entry PHI node in this basic block, and we can // eliminate it, do so now. 
if (auto *PN = dyn_cast<PHINode>(BB->begin())) diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index 4e83d2f6e3c6..a28916bc9baf 100644 --- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -93,6 +93,7 @@ namespace { void replaceRemWithNumeratorOrZero(BinaryOperator *Rem); void replaceSRemWithURem(BinaryOperator *Rem); bool eliminateSDiv(BinaryOperator *SDiv); + bool strengthenBinaryOp(BinaryOperator *BO, Instruction *IVOperand); bool strengthenOverflowingOperation(BinaryOperator *OBO, Instruction *IVOperand); bool strengthenRightShift(BinaryOperator *BO, Instruction *IVOperand); @@ -216,8 +217,10 @@ bool SimplifyIndvar::makeIVComparisonInvariant(ICmpInst *ICmp, // Do not generate something ridiculous. auto *PHTerm = Preheader->getTerminator(); - if (Rewriter.isHighCostExpansion({ InvariantLHS, InvariantRHS }, L, - 2 * SCEVCheapExpansionBudget, TTI, PHTerm)) + if (Rewriter.isHighCostExpansion({InvariantLHS, InvariantRHS}, L, + 2 * SCEVCheapExpansionBudget, TTI, PHTerm) || + !Rewriter.isSafeToExpandAt(InvariantLHS, PHTerm) || + !Rewriter.isSafeToExpandAt(InvariantRHS, PHTerm)) return false; auto *NewLHS = Rewriter.expandCodeFor(InvariantLHS, IVOperand->getType(), PHTerm); @@ -747,6 +750,13 @@ bool SimplifyIndvar::eliminateIdentitySCEV(Instruction *UseInst, return true; } +bool SimplifyIndvar::strengthenBinaryOp(BinaryOperator *BO, + Instruction *IVOperand) { + return (isa<OverflowingBinaryOperator>(BO) && + strengthenOverflowingOperation(BO, IVOperand)) || + (isa<ShlOperator>(BO) && strengthenRightShift(BO, IVOperand)); +} + /// Annotate BO with nsw / nuw if it provably does not signed-overflow / /// unsigned-overflow. Returns true if anything changed, false otherwise. bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, @@ -898,6 +908,14 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { if (replaceIVUserWithLoopInvariant(UseInst)) continue; + // Go further for the bitcast ''prtoint ptr to i64' + if (isa<PtrToIntInst>(UseInst)) + for (Use &U : UseInst->uses()) { + Instruction *User = cast<Instruction>(U.getUser()); + if (replaceIVUserWithLoopInvariant(User)) + break; // done replacing + } + Instruction *IVOperand = UseOper.second; for (unsigned N = 0; IVOperand; ++N) { assert(N <= Simplified.size() && "runaway iteration"); @@ -917,9 +935,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { } if (BinaryOperator *BO = dyn_cast<BinaryOperator>(UseInst)) { - if ((isa<OverflowingBinaryOperator>(BO) && - strengthenOverflowingOperation(BO, IVOperand)) || - (isa<ShlOperator>(BO) && strengthenRightShift(BO, IVOperand))) { + if (strengthenBinaryOp(BO, IVOperand)) { // re-queue uses of the now modified binary operator and fall // through to the checks that remain. 
pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 20f18322d43c..5b0951252c07 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -14,11 +14,12 @@ #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallString.h" -#include "llvm/ADT/Triple.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -29,6 +30,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SizeOpts.h" @@ -44,6 +46,45 @@ static cl::opt<bool> cl::desc("Enable unsafe double to float " "shrinking for math lib calls")); +// Enable conversion of operator new calls with a MemProf hot or cold hint +// to an operator new call that takes a hot/cold hint. Off by default since +// not all allocators currently support this extension. +static cl::opt<bool> + OptimizeHotColdNew("optimize-hot-cold-new", cl::Hidden, cl::init(false), + cl::desc("Enable hot/cold operator new library calls")); + +namespace { + +// Specialized parser to ensure the hint is an 8 bit value (we can't specify +// uint8_t to opt<> as that is interpreted to mean that we are passing a char +// option with a specific set of values. +struct HotColdHintParser : public cl::parser<unsigned> { + HotColdHintParser(cl::Option &O) : cl::parser<unsigned>(O) {} + + bool parse(cl::Option &O, StringRef ArgName, StringRef Arg, unsigned &Value) { + if (Arg.getAsInteger(0, Value)) + return O.error("'" + Arg + "' value invalid for uint argument!"); + + if (Value > 255) + return O.error("'" + Arg + "' value must be in the range [0, 255]!"); + + return false; + } +}; + +} // end anonymous namespace + +// Hot/cold operator new takes an 8 bit hotness hint, where 0 is the coldest +// and 255 is the hottest. Default to 1 value away from the coldest and hottest +// hints, so that the compiler hinted allocations are slightly less strong than +// manually inserted hints at the two extremes. +static cl::opt<unsigned, false, HotColdHintParser> ColdNewHintValue( + "cold-new-hint-value", cl::Hidden, cl::init(1), + cl::desc("Value to pass to hot/cold operator new for cold allocation")); +static cl::opt<unsigned, false, HotColdHintParser> HotNewHintValue( + "hot-new-hint-value", cl::Hidden, cl::init(254), + cl::desc("Value to pass to hot/cold operator new for hot allocation")); + //===----------------------------------------------------------------------===// // Helper Functions //===----------------------------------------------------------------------===// @@ -186,21 +227,9 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr, return ConstantInt::get(RetTy, Result); } -static bool isOnlyUsedInComparisonWithZero(Value *V) { - for (User *U : V->users()) { - if (ICmpInst *IC = dyn_cast<ICmpInst>(U)) - if (Constant *C = dyn_cast<Constant>(IC->getOperand(1))) - if (C->isNullValue()) - continue; - // Unknown instruction. 
- return false; - } - return true; -} - static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len, const DataLayout &DL) { - if (!isOnlyUsedInComparisonWithZero(CI)) + if (!isOnlyUsedInZeroComparison(CI)) return false; if (!isDereferenceableAndAlignedPointer(Str, Align(1), APInt(64, Len), DL)) @@ -1358,6 +1387,10 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { return nullptr; } + bool OptForSize = CI->getFunction()->hasOptSize() || + llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI, + PGSOQueryType::IRPass); + // If the char is variable but the input str and length are not we can turn // this memchr call into a simple bit field test. Of course this only works // when the return value is only checked against null. @@ -1368,7 +1401,7 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { // memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n'))) // != 0 // after bounds check. - if (Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI)) + if (OptForSize || Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI)) return nullptr; unsigned char Max = @@ -1380,8 +1413,34 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { // FIXME: On a 64 bit architecture this prevents us from using the // interesting range of alpha ascii chars. We could do better by emitting // two bitfields or shifting the range by 64 if no lower chars are used. - if (!DL.fitsInLegalInteger(Max + 1)) - return nullptr; + if (!DL.fitsInLegalInteger(Max + 1)) { + // Build chain of ORs + // Transform: + // memchr("abcd", C, 4) != nullptr + // to: + // (C == 'a' || C == 'b' || C == 'c' || C == 'd') != 0 + std::string SortedStr = Str.str(); + llvm::sort(SortedStr); + // Compute the number of of non-contiguous ranges. + unsigned NonContRanges = 1; + for (size_t i = 1; i < SortedStr.size(); ++i) { + if (SortedStr[i] > SortedStr[i - 1] + 1) { + NonContRanges++; + } + } + + // Restrict this optimization to profitable cases with one or two range + // checks. + if (NonContRanges > 2) + return nullptr; + + SmallVector<Value *> CharCompares; + for (unsigned char C : SortedStr) + CharCompares.push_back( + B.CreateICmpEQ(CharVal, ConstantInt::get(CharVal->getType(), C))); + + return B.CreateIntToPtr(B.CreateOr(CharCompares), CI->getType()); + } // For the bit field use a power-of-2 type with at least 8 bits to avoid // creating unnecessary illegal types. @@ -1481,30 +1540,21 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, // First, see if we can fold either argument to a constant. Value *LHSV = nullptr; - if (auto *LHSC = dyn_cast<Constant>(LHS)) { - LHSC = ConstantExpr::getBitCast(LHSC, IntType->getPointerTo()); + if (auto *LHSC = dyn_cast<Constant>(LHS)) LHSV = ConstantFoldLoadFromConstPtr(LHSC, IntType, DL); - } + Value *RHSV = nullptr; - if (auto *RHSC = dyn_cast<Constant>(RHS)) { - RHSC = ConstantExpr::getBitCast(RHSC, IntType->getPointerTo()); + if (auto *RHSC = dyn_cast<Constant>(RHS)) RHSV = ConstantFoldLoadFromConstPtr(RHSC, IntType, DL); - } // Don't generate unaligned loads. If either source is constant data, // alignment doesn't matter for that source because there is no load. 
if ((LHSV || getKnownAlignment(LHS, DL, CI) >= PrefAlignment) && (RHSV || getKnownAlignment(RHS, DL, CI) >= PrefAlignment)) { - if (!LHSV) { - Type *LHSPtrTy = - IntType->getPointerTo(LHS->getType()->getPointerAddressSpace()); - LHSV = B.CreateLoad(IntType, B.CreateBitCast(LHS, LHSPtrTy), "lhsv"); - } - if (!RHSV) { - Type *RHSPtrTy = - IntType->getPointerTo(RHS->getType()->getPointerAddressSpace()); - RHSV = B.CreateLoad(IntType, B.CreateBitCast(RHS, RHSPtrTy), "rhsv"); - } + if (!LHSV) + LHSV = B.CreateLoad(IntType, LHS, "lhsv"); + if (!RHSV) + RHSV = B.CreateLoad(IntType, RHS, "rhsv"); return B.CreateZExt(B.CreateICmpNE(LHSV, RHSV), CI->getType(), "memcmp"); } } @@ -1653,6 +1703,59 @@ Value *LibCallSimplifier::optimizeRealloc(CallInst *CI, IRBuilderBase &B) { return nullptr; } +// When enabled, replace operator new() calls marked with a hot or cold memprof +// attribute with an operator new() call that takes a __hot_cold_t parameter. +// Currently this is supported by the open source version of tcmalloc, see: +// https://github.com/google/tcmalloc/blob/master/tcmalloc/new_extension.h +Value *LibCallSimplifier::optimizeNew(CallInst *CI, IRBuilderBase &B, + LibFunc &Func) { + if (!OptimizeHotColdNew) + return nullptr; + + uint8_t HotCold; + if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "cold") + HotCold = ColdNewHintValue; + else if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "hot") + HotCold = HotNewHintValue; + else + return nullptr; + + switch (Func) { + case LibFunc_Znwm: + return emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znwm12__hot_cold_t, HotCold); + case LibFunc_Znam: + return emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znam12__hot_cold_t, HotCold); + case LibFunc_ZnwmRKSt9nothrow_t: + return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B, + TLI, LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, + HotCold); + case LibFunc_ZnamRKSt9nothrow_t: + return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B, + TLI, LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, + HotCold); + case LibFunc_ZnwmSt11align_val_t: + return emitHotColdNewAligned(CI->getArgOperand(0), CI->getArgOperand(1), B, + TLI, LibFunc_ZnwmSt11align_val_t12__hot_cold_t, + HotCold); + case LibFunc_ZnamSt11align_val_t: + return emitHotColdNewAligned(CI->getArgOperand(0), CI->getArgOperand(1), B, + TLI, LibFunc_ZnamSt11align_val_t12__hot_cold_t, + HotCold); + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: + return emitHotColdNewAlignedNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, + TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: + return emitHotColdNewAlignedNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, + TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); + default: + return nullptr; + } +} + //===----------------------------------------------------------------------===// // Math Library Optimizations //===----------------------------------------------------------------------===// @@ -1939,7 +2042,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { AttributeList NoAttrs; // Attributes are only meaningful on the original call // pow(2.0, itofp(x)) -> ldexp(1.0, x) - if (match(Base, m_SpecificFP(2.0)) && + // TODO: This does not work for vectors because there is no ldexp intrinsic. 
+ if (!Ty->isVectorTy() && match(Base, m_SpecificFP(2.0)) && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) && hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize())) @@ -2056,7 +2160,7 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { // pow(-Inf, 0.5) is optionally required to have a result of +Inf (not setting // errno), but sqrt(-Inf) is required by various standards to set errno. if (!Pow->doesNotAccessMemory() && !Pow->hasNoInfs() && - !isKnownNeverInfinity(Base, TLI)) + !isKnownNeverInfinity(Base, DL, TLI, 0, AC, Pow)) return nullptr; Sqrt = getSqrtCall(Base, AttributeList(), Pow->doesNotAccessMemory(), Mod, B, @@ -2217,17 +2321,25 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { hasFloatVersion(M, Name)) Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); + // Bail out for vectors because the code below only expects scalars. + // TODO: This could be allowed if we had a ldexp intrinsic (D14327). Type *Ty = CI->getType(); - Value *Op = CI->getArgOperand(0); + if (Ty->isVectorTy()) + return Ret; // exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= IntSize // exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < IntSize + Value *Op = CI->getArgOperand(0); if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) && hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { - if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) - return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI, - LibFunc_ldexp, LibFunc_ldexpf, - LibFunc_ldexpl, B, AttributeList()); + if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) { + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + return copyFlags( + *CI, emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI, + LibFunc_ldexp, LibFunc_ldexpf, + LibFunc_ldexpl, B, AttributeList())); + } } return Ret; @@ -2579,7 +2691,7 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, return true; } -Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) { +Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) { // Make sure the prototype is as expected, otherwise the rest of the // function is probably invalid and likely to abort. if (!isTrigLibCall(CI)) @@ -2618,7 +2730,7 @@ Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) { replaceTrigInsts(CosCalls, Cos); replaceTrigInsts(SinCosCalls, SinCos); - return nullptr; + return IsSin ? 
Sin : Cos; } void LibCallSimplifier::classifyArgUse( @@ -3439,6 +3551,15 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeWcslen(CI, Builder); case LibFunc_bcopy: return optimizeBCopy(CI, Builder); + case LibFunc_Znwm: + case LibFunc_ZnwmRKSt9nothrow_t: + case LibFunc_ZnwmSt11align_val_t: + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: + case LibFunc_Znam: + case LibFunc_ZnamRKSt9nothrow_t: + case LibFunc_ZnamSt11align_val_t: + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: + return optimizeNew(CI, Builder, Func); default: break; } @@ -3461,9 +3582,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, switch (Func) { case LibFunc_sinpif: case LibFunc_sinpi: + return optimizeSinCosPi(CI, /*IsSin*/true, Builder); case LibFunc_cospif: case LibFunc_cospi: - return optimizeSinCosPi(CI, Builder); + return optimizeSinCosPi(CI, /*IsSin*/false, Builder); case LibFunc_powf: case LibFunc_pow: case LibFunc_powl: @@ -3696,13 +3818,13 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) { } LibCallSimplifier::LibCallSimplifier( - const DataLayout &DL, const TargetLibraryInfo *TLI, - OptimizationRemarkEmitter &ORE, - BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, + const DataLayout &DL, const TargetLibraryInfo *TLI, AssumptionCache *AC, + OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, function_ref<void(Instruction *, Value *)> Replacer, function_ref<void(Instruction *)> Eraser) - : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), BFI(BFI), PSI(PSI), - Replacer(Replacer), Eraser(Eraser) {} + : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), AC(AC), ORE(ORE), BFI(BFI), + PSI(PSI), Replacer(Replacer), Eraser(Eraser) {} void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // Indirect through the replacer used in this instance. diff --git a/llvm/lib/Transforms/Utils/SizeOpts.cpp b/llvm/lib/Transforms/Utils/SizeOpts.cpp index 1242380f73c1..1ca2e0e6ebb9 100644 --- a/llvm/lib/Transforms/Utils/SizeOpts.cpp +++ b/llvm/lib/Transforms/Utils/SizeOpts.cpp @@ -98,14 +98,12 @@ struct BasicBlockBFIAdapter { bool llvm::shouldOptimizeForSize(const Function *F, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, PGSOQueryType QueryType) { - return shouldFuncOptimizeForSizeImpl<BasicBlockBFIAdapter>(F, PSI, BFI, - QueryType); + return shouldFuncOptimizeForSizeImpl(F, PSI, BFI, QueryType); } bool llvm::shouldOptimizeForSize(const BasicBlock *BB, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, PGSOQueryType QueryType) { assert(BB); - return shouldOptimizeForSizeImpl<BasicBlockBFIAdapter>(BB, PSI, BFI, - QueryType); + return shouldOptimizeForSizeImpl(BB, PSI, BFI, QueryType); } diff --git a/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp b/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp index 10fda4df51ba..618c6bab3a8f 100644 --- a/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp +++ b/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp @@ -8,44 +8,13 @@ #include "llvm/Transforms/Utils/StripNonLineTableDebugInfo.h" #include "llvm/IR/DebugInfo.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/Utils.h" -using namespace llvm; - -namespace { - -/// This pass strips all debug info that is not related line tables. -/// The result will be the same as if the program where compiled with -/// -gline-tables-only. 
-struct StripNonLineTableDebugLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - StripNonLineTableDebugLegacyPass() : ModulePass(ID) { - initializeStripNonLineTableDebugLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - - bool runOnModule(Module &M) override { - return llvm::stripNonLineTableDebugInfo(M); - } -}; -} - -char StripNonLineTableDebugLegacyPass::ID = 0; -INITIALIZE_PASS(StripNonLineTableDebugLegacyPass, - "strip-nonlinetable-debuginfo", - "Strip all debug info except linetables", false, false) - -ModulePass *llvm::createStripNonLineTableDebugLegacyPass() { - return new StripNonLineTableDebugLegacyPass(); -} +using namespace llvm; PreservedAnalyses StripNonLineTableDebugInfoPass::run(Module &M, ModuleAnalysisManager &AM) { llvm::stripNonLineTableDebugInfo(M); - return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } diff --git a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp index 4ad16d622e8d..c3ae43e567b0 100644 --- a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp @@ -517,37 +517,6 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, return true; } -namespace { - -class RewriteSymbolsLegacyPass : public ModulePass { -public: - static char ID; // Pass identification, replacement for typeid - - RewriteSymbolsLegacyPass(); - RewriteSymbolsLegacyPass(SymbolRewriter::RewriteDescriptorList &DL); - - bool runOnModule(Module &M) override; - -private: - RewriteSymbolPass Impl; -}; - -} // end anonymous namespace - -char RewriteSymbolsLegacyPass::ID = 0; - -RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID) { - initializeRewriteSymbolsLegacyPassPass(*PassRegistry::getPassRegistry()); -} - -RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass( - SymbolRewriter::RewriteDescriptorList &DL) - : ModulePass(ID), Impl(DL) {} - -bool RewriteSymbolsLegacyPass::runOnModule(Module &M) { - return Impl.runImpl(M); -} - PreservedAnalyses RewriteSymbolPass::run(Module &M, ModuleAnalysisManager &AM) { if (!runImpl(M)) return PreservedAnalyses::all(); @@ -572,15 +541,3 @@ void RewriteSymbolPass::loadAndParseMapFiles() { for (const auto &MapFile : MapFiles) Parser.parse(MapFile, &Descriptors); } - -INITIALIZE_PASS(RewriteSymbolsLegacyPass, "rewrite-symbols", "Rewrite Symbols", - false, false) - -ModulePass *llvm::createRewriteSymbolsPass() { - return new RewriteSymbolsLegacyPass(); -} - -ModulePass * -llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) { - return new RewriteSymbolsLegacyPass(DL); -} diff --git a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp index 3be96ebc93a2..8c781f59ff5a 100644 --- a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp +++ b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp @@ -113,7 +113,7 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L, } } - for (auto II : ExternalUsers) { + for (const auto &II : ExternalUsers) { // For each Def used outside the loop, create NewPhi in // LoopExitBlock. 
NewPhi receives Def only along exiting blocks that // dominate it, while the remaining values are undefined since those paths @@ -130,7 +130,7 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L, NewPhi->addIncoming(Def, In); } else { LLVM_DEBUG(dbgs() << "not dominated\n"); - NewPhi->addIncoming(UndefValue::get(Def->getType()), In); + NewPhi->addIncoming(PoisonValue::get(Def->getType()), In); } } diff --git a/llvm/lib/Transforms/Utils/Utils.cpp b/llvm/lib/Transforms/Utils/Utils.cpp index d002922cfd30..91c743f17764 100644 --- a/llvm/lib/Transforms/Utils/Utils.cpp +++ b/llvm/lib/Transforms/Utils/Utils.cpp @@ -12,9 +12,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils.h" -#include "llvm-c/Initialization.h" -#include "llvm-c/Transforms/Utils.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/PassRegistry.h" @@ -24,42 +21,18 @@ using namespace llvm; /// initializeTransformUtils - Initialize all passes in the TransformUtils /// library. void llvm::initializeTransformUtils(PassRegistry &Registry) { - initializeAddDiscriminatorsLegacyPassPass(Registry); - initializeAssumeSimplifyPassLegacyPassPass(Registry); initializeAssumeBuilderPassLegacyPassPass(Registry); initializeBreakCriticalEdgesPass(Registry); initializeCanonicalizeFreezeInLoopsPass(Registry); - initializeInstNamerPass(Registry); initializeLCSSAWrapperPassPass(Registry); - initializeLibCallsShrinkWrapLegacyPassPass(Registry); initializeLoopSimplifyPass(Registry); initializeLowerGlobalDtorsLegacyPassPass(Registry); initializeLowerInvokeLegacyPassPass(Registry); initializeLowerSwitchLegacyPassPass(Registry); initializePromoteLegacyPassPass(Registry); - initializeStripNonLineTableDebugLegacyPassPass(Registry); initializeUnifyFunctionExitNodesLegacyPassPass(Registry); - initializeMetaRenamerPass(Registry); initializeStripGCRelocatesLegacyPass(Registry); initializePredicateInfoPrinterLegacyPassPass(Registry); - initializeInjectTLIMappingsLegacyPass(Registry); initializeFixIrreduciblePass(Registry); initializeUnifyLoopExitsLegacyPassPass(Registry); } - -/// LLVMInitializeTransformUtils - C binding for initializeTransformUtilsPasses. -void LLVMInitializeTransformUtils(LLVMPassRegistryRef R) { - initializeTransformUtils(*unwrap(R)); -} - -void LLVMAddLowerSwitchPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLowerSwitchPass()); -} - -void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createPromoteMemoryToRegisterPass()); -} - -void LLVMAddAddDiscriminatorsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createAddDiscriminatorsPass()); -} diff --git a/llvm/lib/Transforms/Utils/VNCoercion.cpp b/llvm/lib/Transforms/Utils/VNCoercion.cpp index f295a7e312b6..7a597da2bc51 100644 --- a/llvm/lib/Transforms/Utils/VNCoercion.cpp +++ b/llvm/lib/Transforms/Utils/VNCoercion.cpp @@ -226,91 +226,6 @@ int analyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr, DL); } -/// Looks at a memory location for a load (specified by MemLocBase, Offs, and -/// Size) and compares it against a load. -/// -/// If the specified load could be safely widened to a larger integer load -/// that is 1) still efficient, 2) safe for the target, and 3) would provide -/// the specified memory location value, then this function returns the size -/// in bytes of the load width to use. If not, this returns zero. 
-static unsigned getLoadLoadClobberFullWidthSize(const Value *MemLocBase, - int64_t MemLocOffs, - unsigned MemLocSize, - const LoadInst *LI) { - // We can only extend simple integer loads. - if (!isa<IntegerType>(LI->getType()) || !LI->isSimple()) - return 0; - - // Load widening is hostile to ThreadSanitizer: it may cause false positives - // or make the reports more cryptic (access sizes are wrong). - if (LI->getParent()->getParent()->hasFnAttribute(Attribute::SanitizeThread)) - return 0; - - const DataLayout &DL = LI->getModule()->getDataLayout(); - - // Get the base of this load. - int64_t LIOffs = 0; - const Value *LIBase = - GetPointerBaseWithConstantOffset(LI->getPointerOperand(), LIOffs, DL); - - // If the two pointers are not based on the same pointer, we can't tell that - // they are related. - if (LIBase != MemLocBase) - return 0; - - // Okay, the two values are based on the same pointer, but returned as - // no-alias. This happens when we have things like two byte loads at "P+1" - // and "P+3". Check to see if increasing the size of the "LI" load up to its - // alignment (or the largest native integer type) will allow us to load all - // the bits required by MemLoc. - - // If MemLoc is before LI, then no widening of LI will help us out. - if (MemLocOffs < LIOffs) - return 0; - - // Get the alignment of the load in bytes. We assume that it is safe to load - // any legal integer up to this size without a problem. For example, if we're - // looking at an i8 load on x86-32 that is known 1024 byte aligned, we can - // widen it up to an i32 load. If it is known 2-byte aligned, we can widen it - // to i16. - unsigned LoadAlign = LI->getAlign().value(); - - int64_t MemLocEnd = MemLocOffs + MemLocSize; - - // If no amount of rounding up will let MemLoc fit into LI, then bail out. - if (LIOffs + LoadAlign < MemLocEnd) - return 0; - - // This is the size of the load to try. Start with the next larger power of - // two. - unsigned NewLoadByteSize = LI->getType()->getPrimitiveSizeInBits() / 8U; - NewLoadByteSize = NextPowerOf2(NewLoadByteSize); - - while (true) { - // If this load size is bigger than our known alignment or would not fit - // into a native integer register, then we fail. - if (NewLoadByteSize > LoadAlign || - !DL.fitsInLegalInteger(NewLoadByteSize * 8)) - return 0; - - if (LIOffs + NewLoadByteSize > MemLocEnd && - (LI->getParent()->getParent()->hasFnAttribute( - Attribute::SanitizeAddress) || - LI->getParent()->getParent()->hasFnAttribute( - Attribute::SanitizeHWAddress))) - // We will be reading past the location accessed by the original program. - // While this is safe in a regular build, Address Safety analysis tools - // may start reporting false warnings. So, don't do widening. - return 0; - - // If a load of this width would include all of MemLoc, then we succeed. - if (LIOffs + NewLoadByteSize >= MemLocEnd) - return NewLoadByteSize; - - NewLoadByteSize <<= 1; - } -} - /// This function is called when we have a /// memdep query of a load that ends up being clobbered by another load. See if /// the other load can feed into the second load. 
@@ -325,28 +240,7 @@ int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI, Value *DepPtr = DepLI->getPointerOperand(); uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()).getFixedValue(); - int R = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL); - if (R != -1) - return R; - - // If we have a load/load clobber an DepLI can be widened to cover this load, - // then we should widen it! - int64_t LoadOffs = 0; - const Value *LoadBase = - GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL); - unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue(); - - unsigned Size = - getLoadLoadClobberFullWidthSize(LoadBase, LoadOffs, LoadSize, DepLI); - if (Size == 0) - return -1; - - // Check non-obvious conditions enforced by MDA which we rely on for being - // able to materialize this potentially available value - assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!"); - assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load"); - - return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size * 8, DL); + return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL); } int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, @@ -438,83 +332,27 @@ static Value *getStoreValueForLoadHelper(Value *SrcVal, unsigned Offset, return SrcVal; } -/// This function is called when we have a memdep query of a load that ends up -/// being a clobbering store. This means that the store provides bits used by -/// the load but the pointers don't must-alias. Check this case to see if -/// there is anything more we can do before we give up. -Value *getStoreValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy, - Instruction *InsertPt, const DataLayout &DL) { +Value *getValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy, + Instruction *InsertPt, const DataLayout &DL) { +#ifndef NDEBUG + unsigned SrcValSize = DL.getTypeStoreSize(SrcVal->getType()).getFixedValue(); + unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue(); + assert(Offset + LoadSize <= SrcValSize); +#endif IRBuilder<> Builder(InsertPt); SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, Builder, DL); return coerceAvailableValueToLoadType(SrcVal, LoadTy, Builder, DL); } -Constant *getConstantStoreValueForLoad(Constant *SrcVal, unsigned Offset, - Type *LoadTy, const DataLayout &DL) { - return ConstantFoldLoadFromConst(SrcVal, LoadTy, APInt(32, Offset), DL); -} - -/// This function is called when we have a memdep query of a load that ends up -/// being a clobbering load. This means that the load *may* provide bits used -/// by the load but we can't be sure because the pointers don't must-alias. -/// Check this case to see if there is anything more we can do before we give -/// up. -Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy, - Instruction *InsertPt, const DataLayout &DL) { - // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to - // widen SrcVal out to a larger load. 
- unsigned SrcValStoreSize = - DL.getTypeStoreSize(SrcVal->getType()).getFixedValue(); +Constant *getConstantValueForLoad(Constant *SrcVal, unsigned Offset, + Type *LoadTy, const DataLayout &DL) { +#ifndef NDEBUG + unsigned SrcValSize = DL.getTypeStoreSize(SrcVal->getType()).getFixedValue(); unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue(); - if (Offset + LoadSize > SrcValStoreSize) { - assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!"); - assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load"); - // If we have a load/load clobber an DepLI can be widened to cover this - // load, then we should widen it to the next power of 2 size big enough! - unsigned NewLoadSize = Offset + LoadSize; - if (!isPowerOf2_32(NewLoadSize)) - NewLoadSize = NextPowerOf2(NewLoadSize); - - Value *PtrVal = SrcVal->getPointerOperand(); - // Insert the new load after the old load. This ensures that subsequent - // memdep queries will find the new load. We can't easily remove the old - // load completely because it is already in the value numbering table. - IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal)); - Type *DestTy = IntegerType::get(LoadTy->getContext(), NewLoadSize * 8); - Type *DestPTy = - PointerType::get(DestTy, PtrVal->getType()->getPointerAddressSpace()); - Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc()); - PtrVal = Builder.CreateBitCast(PtrVal, DestPTy); - LoadInst *NewLoad = Builder.CreateLoad(DestTy, PtrVal); - NewLoad->takeName(SrcVal); - NewLoad->setAlignment(SrcVal->getAlign()); - - LLVM_DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n"); - LLVM_DEBUG(dbgs() << "TO: " << *NewLoad << "\n"); - - // Replace uses of the original load with the wider load. On a big endian - // system, we need to shift down to get the relevant bits. - Value *RV = NewLoad; - if (DL.isBigEndian()) - RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8); - RV = Builder.CreateTrunc(RV, SrcVal->getType()); - SrcVal->replaceAllUsesWith(RV); - - SrcVal = NewLoad; - } - - return getStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL); -} - -Constant *getConstantLoadValueForLoad(Constant *SrcVal, unsigned Offset, - Type *LoadTy, const DataLayout &DL) { - unsigned SrcValStoreSize = - DL.getTypeStoreSize(SrcVal->getType()).getFixedValue(); - unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue(); - if (Offset + LoadSize > SrcValStoreSize) - return nullptr; - return getConstantStoreValueForLoad(SrcVal, Offset, LoadTy, DL); + assert(Offset + LoadSize <= SrcValSize); +#endif + return ConstantFoldLoadFromConst(SrcVal, LoadTy, APInt(32, Offset), DL); } /// This function is called when we have a diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp index a5edbb2acc6d..3446e31cc2ef 100644 --- a/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -523,10 +523,14 @@ Value *Mapper::mapValue(const Value *V) { if (isa<ConstantVector>(C)) return getVM()[V] = ConstantVector::get(Ops); // If this is a no-operand constant, it must be because the type was remapped. 
+ if (isa<PoisonValue>(C)) + return getVM()[V] = PoisonValue::get(NewTy); if (isa<UndefValue>(C)) return getVM()[V] = UndefValue::get(NewTy); if (isa<ConstantAggregateZero>(C)) return getVM()[V] = ConstantAggregateZero::get(NewTy); + if (isa<ConstantTargetNone>(C)) + return getVM()[V] = Constant::getNullValue(NewTy); assert(isa<ConstantPointerNull>(C)); return getVM()[V] = ConstantPointerNull::get(cast<PointerType>(NewTy)); } @@ -1030,7 +1034,7 @@ void Mapper::mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, if (IsOldCtorDtor) { // FIXME: This upgrade is done during linking to support the C API. See // also IRLinker::linkAppendingVarProto() in IRMover.cpp. - VoidPtrTy = Type::getInt8Ty(GV.getContext())->getPointerTo(); + VoidPtrTy = PointerType::getUnqual(GV.getContext()); auto &ST = *cast<StructType>(NewMembers.front()->getType()); Type *Tys[3] = {ST.getElementType(0), ST.getElementType(1), VoidPtrTy}; EltTy = StructType::get(GV.getContext(), Tys, false); @@ -1179,6 +1183,10 @@ void ValueMapper::remapFunction(Function &F) { FlushingMapper(pImpl)->remapFunction(F); } +void ValueMapper::remapGlobalObjectMetadata(GlobalObject &GO) { + FlushingMapper(pImpl)->remapGlobalObjectMetadata(GO); +} + void ValueMapper::scheduleMapGlobalInitializer(GlobalVariable &GV, Constant &Init, unsigned MCID) { diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 0b7fc853dc1b..260d7889906b 100644 --- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -37,13 +37,34 @@ // multiple scalar registers, similar to a GPU vectorized load. In theory ARM // could use this pass (with some modifications), but currently it implements // its own pass to do something similar to what we do here. +// +// Overview of the algorithm and terminology in this pass: +// +// - Break up each basic block into pseudo-BBs, composed of instructions which +// are guaranteed to transfer control to their successors. +// - Within a single pseudo-BB, find all loads, and group them into +// "equivalence classes" according to getUnderlyingObject() and loaded +// element size. Do the same for stores. +// - For each equivalence class, greedily build "chains". Each chain has a +// leader instruction, and every other member of the chain has a known +// constant offset from the first instr in the chain. +// - Break up chains so that they contain only contiguous accesses of legal +// size with no intervening may-alias instrs. +// - Convert each chain to vector instructions. +// +// The O(n^2) behavior of this pass comes from initially building the chains. +// In the worst case we have to compare each new instruction to all of those +// that came before. To limit this, we only calculate the offset to the leaders +// of the N most recently-used chains. 
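Concretely, the chain-gathering step described in this overview can be pictured with a small standalone sketch. It is not part of the patch: it uses plain C++ containers instead of LLVM IR, and the names Access, ToyChain, constantOffset and gatherChainsSketch are all made up for illustration. constantOffset stands in for the pass's offset query, which can fail, which is why only the leaders of the most recently used chains are consulted.

#include <algorithm>
#include <cstdint>
#include <deque>
#include <optional>
#include <vector>

// Toy stand-ins for an instruction and a chain.
struct Access { int BaseId; int64_t Addr; };
struct ToyChain { Access Leader; std::vector<int64_t> OffsetsFromLeader; };

// Pretend offset query: succeeds only when both accesses share a base object.
static std::optional<int64_t> constantOffset(const Access &A, const Access &B) {
  if (A.BaseId != B.BaseId)
    return std::nullopt;
  return B.Addr - A.Addr;
}

// Greedy chain building: each access is compared only against the leaders of
// the most recently used chains, which is what keeps the quadratic worst case
// in check.
static std::vector<ToyChain>
gatherChainsSketch(const std::vector<Access> &Instrs, size_t MaxChainsToTry = 64) {
  std::vector<ToyChain> Chains;
  std::deque<size_t> MRU; // indices into Chains, most recently used first
  for (const Access &I : Instrs) {
    bool Placed = false;
    for (size_t K = 0; K < MRU.size() && K < MaxChainsToTry; ++K) {
      ToyChain &C = Chains[MRU[K]];
      if (std::optional<int64_t> Off = constantOffset(C.Leader, I)) {
        C.OffsetsFromLeader.push_back(*Off); // known constant offset: same chain
        std::rotate(MRU.begin(), MRU.begin() + K, MRU.begin() + K + 1);
        Placed = true;
        break;
      }
    }
    if (!Placed) { // no chain matched: I starts a new chain and leads it
      Chains.push_back({I, {0}});
      MRU.push_front(Chains.size() - 1);
    }
  }
  return Chains;
}

With accesses to two different underlying objects interleaved, this produces one chain per object, each element's offset recorded relative to its chain's first access.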
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -57,6 +78,7 @@ #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" @@ -67,23 +89,33 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Support/Alignment.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ModRef.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Vectorize.h" #include <algorithm> #include <cassert> +#include <cstdint> #include <cstdlib> +#include <iterator> +#include <limits> +#include <numeric> +#include <optional> #include <tuple> +#include <type_traits> #include <utility> +#include <vector> using namespace llvm; @@ -92,21 +124,115 @@ using namespace llvm; STATISTIC(NumVectorInstructions, "Number of vector accesses generated"); STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized"); +namespace { + +// Equivalence class key, the initial tuple by which we group loads/stores. +// Loads/stores with different EqClassKeys are never merged. +// +// (We could in theory remove element-size from the this tuple. We'd just need +// to fix up the vector packing/unpacking code.) +using EqClassKey = + std::tuple<const Value * /* result of getUnderlyingObject() */, + unsigned /* AddrSpace */, + unsigned /* Load/Store element size bits */, + char /* IsLoad; char b/c bool can't be a DenseMap key */ + >; +[[maybe_unused]] llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + const EqClassKey &K) { + const auto &[UnderlyingObject, AddrSpace, ElementSize, IsLoad] = K; + return OS << (IsLoad ? "load" : "store") << " of " << *UnderlyingObject + << " of element size " << ElementSize << " bits in addrspace " + << AddrSpace; +} + +// A Chain is a set of instructions such that: +// - All instructions have the same equivalence class, so in particular all are +// loads, or all are stores. +// - We know the address accessed by the i'th chain elem relative to the +// chain's leader instruction, which is the first instr of the chain in BB +// order. +// +// Chains have two canonical orderings: +// - BB order, sorted by Instr->comesBefore. +// - Offset order, sorted by OffsetFromLeader. +// This pass switches back and forth between these orders. 
+struct ChainElem { + Instruction *Inst; + APInt OffsetFromLeader; +}; +using Chain = SmallVector<ChainElem, 1>; + +void sortChainInBBOrder(Chain &C) { + sort(C, [](auto &A, auto &B) { return A.Inst->comesBefore(B.Inst); }); +} + +void sortChainInOffsetOrder(Chain &C) { + sort(C, [](const auto &A, const auto &B) { + if (A.OffsetFromLeader != B.OffsetFromLeader) + return A.OffsetFromLeader.slt(B.OffsetFromLeader); + return A.Inst->comesBefore(B.Inst); // stable tiebreaker + }); +} + +[[maybe_unused]] void dumpChain(ArrayRef<ChainElem> C) { + for (const auto &E : C) { + dbgs() << " " << *E.Inst << " (offset " << E.OffsetFromLeader << ")\n"; + } +} + +using EquivalenceClassMap = + MapVector<EqClassKey, SmallVector<Instruction *, 8>>; + // FIXME: Assuming stack alignment of 4 is always good enough -static const unsigned StackAdjustedAlignment = 4; +constexpr unsigned StackAdjustedAlignment = 4; -namespace { +Instruction *propagateMetadata(Instruction *I, const Chain &C) { + SmallVector<Value *, 8> Values; + for (const ChainElem &E : C) + Values.push_back(E.Inst); + return propagateMetadata(I, Values); +} -/// ChainID is an arbitrary token that is allowed to be different only for the -/// accesses that are guaranteed to be considered non-consecutive by -/// Vectorizer::isConsecutiveAccess. It's used for grouping instructions -/// together and reducing the number of instructions the main search operates on -/// at a time, i.e. this is to reduce compile time and nothing else as the main -/// search has O(n^2) time complexity. The underlying type of ChainID should not -/// be relied upon. -using ChainID = const Value *; -using InstrList = SmallVector<Instruction *, 8>; -using InstrListMap = MapVector<ChainID, InstrList>; +bool isInvariantLoad(const Instruction *I) { + const LoadInst *LI = dyn_cast<LoadInst>(I); + return LI != nullptr && LI->hasMetadata(LLVMContext::MD_invariant_load); +} + +/// Reorders the instructions that I depends on (the instructions defining its +/// operands), to ensure they dominate I. +void reorder(Instruction *I) { + SmallPtrSet<Instruction *, 16> InstructionsToMove; + SmallVector<Instruction *, 16> Worklist; + + Worklist.push_back(I); + while (!Worklist.empty()) { + Instruction *IW = Worklist.pop_back_val(); + int NumOperands = IW->getNumOperands(); + for (int i = 0; i < NumOperands; i++) { + Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i)); + if (!IM || IM->getOpcode() == Instruction::PHI) + continue; + + // If IM is in another BB, no need to move it, because this pass only + // vectorizes instructions within one BB. + if (IM->getParent() != I->getParent()) + continue; + + if (!IM->comesBefore(I)) { + InstructionsToMove.insert(IM); + Worklist.push_back(IM); + } + } + } + + // All instructions to move should follow I. Start from I, not from begin(). + for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;) { + Instruction *IM = &*(BBI++); + if (!InstructionsToMove.count(IM)) + continue; + IM->moveBefore(I); + } +} class Vectorizer { Function &F; @@ -118,6 +244,12 @@ class Vectorizer { const DataLayout &DL; IRBuilder<> Builder; + // We could erase instrs right after vectorizing them, but that can mess up + // our BB iterators, and also can make the equivalence class keys point to + // freed memory. This is fixable, but it's simpler just to wait until we're + // done with the BB and erase all at once. 
+ SmallVector<Instruction *, 128> ToErase; + public: Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC, DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI) @@ -127,70 +259,83 @@ public: bool run(); private: - unsigned getPointerAddressSpace(Value *I); - static const unsigned MaxDepth = 3; - bool isConsecutiveAccess(Value *A, Value *B); - bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta, - unsigned Depth = 0) const; - bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta, - unsigned Depth) const; - bool lookThroughSelects(Value *PtrA, Value *PtrB, const APInt &PtrDelta, - unsigned Depth) const; - - /// After vectorization, reorder the instructions that I depends on - /// (the instructions defining its operands), to ensure they dominate I. - void reorder(Instruction *I); - - /// Returns the first and the last instructions in Chain. - std::pair<BasicBlock::iterator, BasicBlock::iterator> - getBoundaryInstrs(ArrayRef<Instruction *> Chain); - - /// Erases the original instructions after vectorizing. - void eraseInstructions(ArrayRef<Instruction *> Chain); - - /// "Legalize" the vector type that would be produced by combining \p - /// ElementSizeBits elements in \p Chain. Break into two pieces such that the - /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is - /// expected to have more than 4 elements. - std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>> - splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBits); - - /// Finds the largest prefix of Chain that's vectorizable, checking for - /// intervening instructions which may affect the memory accessed by the - /// instructions within Chain. + /// Runs the vectorizer on a "pseudo basic block", which is a range of + /// instructions [Begin, End) within one BB all of which have + /// isGuaranteedToTransferExecutionToSuccessor(I) == true. + bool runOnPseudoBB(BasicBlock::iterator Begin, BasicBlock::iterator End); + + /// Runs the vectorizer on one equivalence class, i.e. one set of loads/stores + /// in the same BB with the same value for getUnderlyingObject() etc. + bool runOnEquivalenceClass(const EqClassKey &EqClassKey, + ArrayRef<Instruction *> EqClass); + + /// Runs the vectorizer on one chain, i.e. a subset of an equivalence class + /// where all instructions access a known, constant offset from the first + /// instruction. + bool runOnChain(Chain &C); + + /// Splits the chain into subchains of instructions which read/write a + /// contiguous block of memory. Discards any length-1 subchains (because + /// there's nothing to vectorize in there). + std::vector<Chain> splitChainByContiguity(Chain &C); + + /// Splits the chain into subchains where it's safe to hoist loads up to the + /// beginning of the sub-chain and it's safe to sink loads up to the end of + /// the sub-chain. Discards any length-1 subchains. + std::vector<Chain> splitChainByMayAliasInstrs(Chain &C); + + /// Splits the chain into subchains that make legal, aligned accesses. + /// Discards any length-1 subchains. + std::vector<Chain> splitChainByAlignment(Chain &C); + + /// Converts the instrs in the chain into a single vectorized load or store. + /// Adds the old scalar loads/stores to ToErase. + bool vectorizeChain(Chain &C); + + /// Tries to compute the offset in bytes PtrB - PtrA. 
+ std::optional<APInt> getConstantOffset(Value *PtrA, Value *PtrB, + Instruction *ContextInst, + unsigned Depth = 0); + std::optional<APInt> getConstantOffsetComplexAddrs(Value *PtrA, Value *PtrB, + Instruction *ContextInst, + unsigned Depth); + std::optional<APInt> getConstantOffsetSelects(Value *PtrA, Value *PtrB, + Instruction *ContextInst, + unsigned Depth); + + /// Gets the element type of the vector that the chain will load or store. + /// This is nontrivial because the chain may contain elements of different + /// types; e.g. it's legal to have a chain that contains both i32 and float. + Type *getChainElemTy(const Chain &C); + + /// Determines whether ChainElem can be moved up (if IsLoad) or down (if + /// !IsLoad) to ChainBegin -- i.e. there are no intervening may-alias + /// instructions. + /// + /// The map ChainElemOffsets must contain all of the elements in + /// [ChainBegin, ChainElem] and their offsets from some arbitrary base + /// address. It's ok if it contains additional entries. + template <bool IsLoadChain> + bool isSafeToMove( + Instruction *ChainElem, Instruction *ChainBegin, + const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets); + + /// Collects loads and stores grouped by "equivalence class", where: + /// - all elements in an eq class are a load or all are a store, + /// - they all load/store the same element size (it's OK to have e.g. i8 and + /// <4 x i8> in the same class, but not i32 and <4 x i8>), and + /// - they all have the same value for getUnderlyingObject(). + EquivalenceClassMap collectEquivalenceClasses(BasicBlock::iterator Begin, + BasicBlock::iterator End); + + /// Partitions Instrs into "chains" where every instruction has a known + /// constant offset from the first instr in the chain. /// - /// The elements of \p Chain must be all loads or all stores and must be in - /// address order. - ArrayRef<Instruction *> getVectorizablePrefix(ArrayRef<Instruction *> Chain); - - /// Collects load and store instructions to vectorize. - std::pair<InstrListMap, InstrListMap> collectInstructions(BasicBlock *BB); - - /// Processes the collected instructions, the \p Map. The values of \p Map - /// should be all loads or all stores. - bool vectorizeChains(InstrListMap &Map); - - /// Finds the load/stores to consecutive memory addresses and vectorizes them. - bool vectorizeInstructions(ArrayRef<Instruction *> Instrs); - - /// Vectorizes the load instructions in Chain. - bool - vectorizeLoadChain(ArrayRef<Instruction *> Chain, - SmallPtrSet<Instruction *, 16> *InstructionsProcessed); - - /// Vectorizes the store instructions in Chain. - bool - vectorizeStoreChain(ArrayRef<Instruction *> Chain, - SmallPtrSet<Instruction *, 16> *InstructionsProcessed); - - /// Check if this load/store access is misaligned accesses. - /// Returns a \p RelativeSpeed of an operation if allowed suitable to - /// compare to another result for the same \p AddressSpace and potentially - /// different \p Alignment and \p SzInBytes. - bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, - Align Alignment, unsigned &RelativeSpeed); + /// Postcondition: For all i, ret[i][0].second == 0, because the first instr + /// in the chain is the leader, and an instr touches distance 0 from itself. 
+ std::vector<Chain> gatherChains(ArrayRef<Instruction *> Instrs); }; class LoadStoreVectorizerLegacyPass : public FunctionPass { @@ -198,7 +343,8 @@ public: static char ID; LoadStoreVectorizerLegacyPass() : FunctionPass(ID) { - initializeLoadStoreVectorizerLegacyPassPass(*PassRegistry::getPassRegistry()); + initializeLoadStoreVectorizerLegacyPassPass( + *PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; @@ -250,11 +396,11 @@ bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) { AssumptionCache &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - Vectorizer V(F, AA, AC, DT, SE, TTI); - return V.run(); + return Vectorizer(F, AA, AC, DT, SE, TTI).run(); } -PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) { +PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, + FunctionAnalysisManager &AM) { // Don't vectorize when the attribute NoImplicitFloat is used. if (F.hasFnAttribute(Attribute::NoImplicitFloat)) return PreservedAnalyses::all(); @@ -265,125 +411,681 @@ PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, FunctionAnalysisMana TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); - Vectorizer V(F, AA, AC, DT, SE, TTI); - bool Changed = V.run(); + bool Changed = Vectorizer(F, AA, AC, DT, SE, TTI).run(); PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); return Changed ? PA : PreservedAnalyses::all(); } -// The real propagateMetadata expects a SmallVector<Value*>, but we deal in -// vectors of Instructions. -static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) { - SmallVector<Value *, 8> VL(IL.begin(), IL.end()); - propagateMetadata(I, VL); -} - -// Vectorizer Implementation bool Vectorizer::run() { bool Changed = false; - - // Scan the blocks in the function in post order. + // Break up the BB if there are any instrs which aren't guaranteed to transfer + // execution to their successor. + // + // Consider, for example: + // + // def assert_arr_len(int n) { if (n < 2) exit(); } + // + // load arr[0] + // call assert_array_len(arr.length) + // load arr[1] + // + // Even though assert_arr_len does not read or write any memory, we can't + // speculate the second load before the call. More info at + // https://github.com/llvm/llvm-project/issues/52950. for (BasicBlock *BB : post_order(&F)) { - InstrListMap LoadRefs, StoreRefs; - std::tie(LoadRefs, StoreRefs) = collectInstructions(BB); - Changed |= vectorizeChains(LoadRefs); - Changed |= vectorizeChains(StoreRefs); + // BB must at least have a terminator. 
+ assert(!BB->empty()); + + SmallVector<BasicBlock::iterator, 8> Barriers; + Barriers.push_back(BB->begin()); + for (Instruction &I : *BB) + if (!isGuaranteedToTransferExecutionToSuccessor(&I)) + Barriers.push_back(I.getIterator()); + Barriers.push_back(BB->end()); + + for (auto It = Barriers.begin(), End = std::prev(Barriers.end()); It != End; + ++It) + Changed |= runOnPseudoBB(*It, *std::next(It)); + + for (Instruction *I : ToErase) { + auto *PtrOperand = getLoadStorePointerOperand(I); + if (I->use_empty()) + I->eraseFromParent(); + RecursivelyDeleteTriviallyDeadInstructions(PtrOperand); + } + ToErase.clear(); } return Changed; } -unsigned Vectorizer::getPointerAddressSpace(Value *I) { - if (LoadInst *L = dyn_cast<LoadInst>(I)) - return L->getPointerAddressSpace(); - if (StoreInst *S = dyn_cast<StoreInst>(I)) - return S->getPointerAddressSpace(); - return -1; +bool Vectorizer::runOnPseudoBB(BasicBlock::iterator Begin, + BasicBlock::iterator End) { + LLVM_DEBUG({ + dbgs() << "LSV: Running on pseudo-BB [" << *Begin << " ... "; + if (End != Begin->getParent()->end()) + dbgs() << *End; + else + dbgs() << "<BB end>"; + dbgs() << ")\n"; + }); + + bool Changed = false; + for (const auto &[EqClassKey, EqClass] : + collectEquivalenceClasses(Begin, End)) + Changed |= runOnEquivalenceClass(EqClassKey, EqClass); + + return Changed; } -// FIXME: Merge with llvm::isConsecutiveAccess -bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { - Value *PtrA = getLoadStorePointerOperand(A); - Value *PtrB = getLoadStorePointerOperand(B); - unsigned ASA = getPointerAddressSpace(A); - unsigned ASB = getPointerAddressSpace(B); +bool Vectorizer::runOnEquivalenceClass(const EqClassKey &EqClassKey, + ArrayRef<Instruction *> EqClass) { + bool Changed = false; - // Check that the address spaces match and that the pointers are valid. - if (!PtrA || !PtrB || (ASA != ASB)) - return false; + LLVM_DEBUG({ + dbgs() << "LSV: Running on equivalence class of size " << EqClass.size() + << " keyed on " << EqClassKey << ":\n"; + for (Instruction *I : EqClass) + dbgs() << " " << *I << "\n"; + }); - // Make sure that A and B are different pointers of the same size type. - Type *PtrATy = getLoadStoreType(A); - Type *PtrBTy = getLoadStoreType(B); - if (PtrA == PtrB || - PtrATy->isVectorTy() != PtrBTy->isVectorTy() || - DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) || - DL.getTypeStoreSize(PtrATy->getScalarType()) != - DL.getTypeStoreSize(PtrBTy->getScalarType())) - return false; + std::vector<Chain> Chains = gatherChains(EqClass); + LLVM_DEBUG(dbgs() << "LSV: Got " << Chains.size() + << " nontrivial chains.\n";); + for (Chain &C : Chains) + Changed |= runOnChain(C); + return Changed; +} - unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA); - APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy)); +bool Vectorizer::runOnChain(Chain &C) { + LLVM_DEBUG({ + dbgs() << "LSV: Running on chain with " << C.size() << " instructions:\n"; + dumpChain(C); + }); - return areConsecutivePointers(PtrA, PtrB, Size); + // Split up the chain into increasingly smaller chains, until we can finally + // vectorize the chains. + // + // (Don't be scared by the depth of the loop nest here. These operations are + // all at worst O(n lg n) in the number of instructions, and splitting chains + // doesn't change the number of instrs. So the whole loop nest is O(n lg n).) 
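Of the three splits applied below, splitChainByContiguity is the easiest to picture in isolation. The following standalone sketch (not the pass's code; ToyElem and splitByContiguitySketch are illustrative names) shows the idea on plain (offset, size) pairs that are already sorted by offset, mirroring the sortChainInOffsetOrder precondition:

#include <algorithm>
#include <cstdint>
#include <vector>

struct ToyElem { int64_t Offset; int64_t SizeBytes; }; // stand-in for ChainElem

// Split an offset-sorted chain into maximal runs in which each element starts
// exactly where the previous one ends, then drop length-1 runs, since a single
// access has nothing to be vectorized with.
static std::vector<std::vector<ToyElem>>
splitByContiguitySketch(const std::vector<ToyElem> &Chain) {
  std::vector<std::vector<ToyElem>> Runs;
  for (const ToyElem &E : Chain) {
    if (!Runs.empty()) {
      const ToyElem &Prev = Runs.back().back();
      if (E.Offset == Prev.Offset + Prev.SizeBytes) { // contiguous with Prev
        Runs.back().push_back(E);
        continue;
      }
    }
    Runs.push_back({E}); // first element, or a gap: start a new run
  }
  Runs.erase(std::remove_if(Runs.begin(), Runs.end(),
                            [](const std::vector<ToyElem> &R) {
                              return R.size() <= 1;
                            }),
             Runs.end());
  return Runs;
}

For example, 4-byte elements at offsets {0, 4, 8, 16, 20} split into the runs {0, 4, 8} and {16, 20}.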
+ bool Changed = false; + for (auto &C : splitChainByMayAliasInstrs(C)) + for (auto &C : splitChainByContiguity(C)) + for (auto &C : splitChainByAlignment(C)) + Changed |= vectorizeChain(C); + return Changed; } -bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, - APInt PtrDelta, unsigned Depth) const { - unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType()); - APInt OffsetA(PtrBitWidth, 0); - APInt OffsetB(PtrBitWidth, 0); - PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); - PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); +std::vector<Chain> Vectorizer::splitChainByMayAliasInstrs(Chain &C) { + if (C.empty()) + return {}; - unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType()); + sortChainInBBOrder(C); - if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType())) + LLVM_DEBUG({ + dbgs() << "LSV: splitChainByMayAliasInstrs considering chain:\n"; + dumpChain(C); + }); + + // We know that elements in the chain with nonverlapping offsets can't + // alias, but AA may not be smart enough to figure this out. Use a + // hashtable. + DenseMap<Instruction *, APInt /*OffsetFromLeader*/> ChainOffsets; + for (const auto &E : C) + ChainOffsets.insert({&*E.Inst, E.OffsetFromLeader}); + + // Loads get hoisted up to the first load in the chain. Stores get sunk + // down to the last store in the chain. Our algorithm for loads is: + // + // - Take the first element of the chain. This is the start of a new chain. + // - Take the next element of `Chain` and check for may-alias instructions + // up to the start of NewChain. If no may-alias instrs, add it to + // NewChain. Otherwise, start a new NewChain. + // + // For stores it's the same except in the reverse direction. + // + // We expect IsLoad to be an std::bool_constant. + auto Impl = [&](auto IsLoad) { + // MSVC is unhappy if IsLoad is a capture, so pass it as an arg. + auto [ChainBegin, ChainEnd] = [&](auto IsLoad) { + if constexpr (IsLoad()) + return std::make_pair(C.begin(), C.end()); + else + return std::make_pair(C.rbegin(), C.rend()); + }(IsLoad); + assert(ChainBegin != ChainEnd); + + std::vector<Chain> Chains; + SmallVector<ChainElem, 1> NewChain; + NewChain.push_back(*ChainBegin); + for (auto ChainIt = std::next(ChainBegin); ChainIt != ChainEnd; ++ChainIt) { + if (isSafeToMove<IsLoad>(ChainIt->Inst, NewChain.front().Inst, + ChainOffsets)) { + LLVM_DEBUG(dbgs() << "LSV: No intervening may-alias instrs; can merge " + << *ChainIt->Inst << " into " << *ChainBegin->Inst + << "\n"); + NewChain.push_back(*ChainIt); + } else { + LLVM_DEBUG( + dbgs() << "LSV: Found intervening may-alias instrs; cannot merge " + << *ChainIt->Inst << " into " << *ChainBegin->Inst << "\n"); + if (NewChain.size() > 1) { + LLVM_DEBUG({ + dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n"; + dumpChain(NewChain); + }); + Chains.push_back(std::move(NewChain)); + } + + // Start a new chain. 
+ NewChain = SmallVector<ChainElem, 1>({*ChainIt}); + } + } + if (NewChain.size() > 1) { + LLVM_DEBUG({ + dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n"; + dumpChain(NewChain); + }); + Chains.push_back(std::move(NewChain)); + } + return Chains; + }; + + if (isa<LoadInst>(C[0].Inst)) + return Impl(/*IsLoad=*/std::bool_constant<true>()); + + assert(isa<StoreInst>(C[0].Inst)); + return Impl(/*IsLoad=*/std::bool_constant<false>()); +} + +std::vector<Chain> Vectorizer::splitChainByContiguity(Chain &C) { + if (C.empty()) + return {}; + + sortChainInOffsetOrder(C); + + LLVM_DEBUG({ + dbgs() << "LSV: splitChainByContiguity considering chain:\n"; + dumpChain(C); + }); + + std::vector<Chain> Ret; + Ret.push_back({C.front()}); + + for (auto It = std::next(C.begin()), End = C.end(); It != End; ++It) { + // `prev` accesses offsets [PrevDistFromBase, PrevReadEnd). + auto &CurChain = Ret.back(); + const ChainElem &Prev = CurChain.back(); + unsigned SzBits = DL.getTypeSizeInBits(getLoadStoreType(&*Prev.Inst)); + assert(SzBits % 8 == 0 && "Non-byte sizes should have been filtered out by " + "collectEquivalenceClass"); + APInt PrevReadEnd = Prev.OffsetFromLeader + SzBits / 8; + + // Add this instruction to the end of the current chain, or start a new one. + bool AreContiguous = It->OffsetFromLeader == PrevReadEnd; + LLVM_DEBUG(dbgs() << "LSV: Instructions are " + << (AreContiguous ? "" : "not ") << "contiguous: " + << *Prev.Inst << " (ends at offset " << PrevReadEnd + << ") -> " << *It->Inst << " (starts at offset " + << It->OffsetFromLeader << ")\n"); + if (AreContiguous) + CurChain.push_back(*It); + else + Ret.push_back({*It}); + } + + // Filter out length-1 chains, these are uninteresting. + llvm::erase_if(Ret, [](const auto &Chain) { return Chain.size() <= 1; }); + return Ret; +} + +Type *Vectorizer::getChainElemTy(const Chain &C) { + assert(!C.empty()); + // The rules are: + // - If there are any pointer types in the chain, use an integer type. + // - Prefer an integer type if it appears in the chain. + // - Otherwise, use the first type in the chain. + // + // The rule about pointer types is a simplification when we merge e.g. a load + // of a ptr and a double. There's no direct conversion from a ptr to a + // double; it requires a ptrtoint followed by a bitcast. + // + // It's unclear to me if the other rules have any practical effect, but we do + // it to match this pass's previous behavior. + if (any_of(C, [](const ChainElem &E) { + return getLoadStoreType(E.Inst)->getScalarType()->isPointerTy(); + })) { + return Type::getIntNTy( + F.getContext(), + DL.getTypeSizeInBits(getLoadStoreType(C[0].Inst)->getScalarType())); + } + + for (const ChainElem &E : C) + if (Type *T = getLoadStoreType(E.Inst)->getScalarType(); T->isIntegerTy()) + return T; + return getLoadStoreType(C[0].Inst)->getScalarType(); +} + +std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) { + // We use a simple greedy algorithm. + // - Given a chain of length N, find all prefixes that + // (a) are not longer than the max register length, and + // (b) are a power of 2. + // - Starting from the longest prefix, try to create a vector of that length. + // - If one of them works, great. Repeat the algorithm on any remaining + // elements in the chain. + // - If none of them work, discard the first element and repeat on a chain + // of length N-1. 
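As a rough model of the greedy strategy just described, the sketch below works over a contiguous chain of N equally sized elements. MaxElemsPerReg and the CanVectorize callback are stand-ins for the vector register width and the TTI/alignment legality checks, so this is only an approximation of what splitChainByAlignment actually does, not its interface:

#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Returns (start index, length) of each group chosen for vectorization.
static std::vector<std::pair<size_t, size_t>>
splitByAlignmentSketch(size_t N, size_t MaxElemsPerReg,
                       const std::function<bool(size_t, size_t)> &CanVectorize) {
  std::vector<std::pair<size_t, size_t>> Groups;
  for (size_t Begin = 0; Begin < N; ++Begin) {
    // Candidate lengths: powers of two, at most MaxElemsPerReg, fitting in the
    // remaining chain. Try the longest first.
    size_t Len = 1;
    while (Len * 2 <= MaxElemsPerReg && Begin + Len * 2 <= N)
      Len *= 2;
    for (; Len >= 2; Len /= 2) {
      if (CanVectorize(Begin, Len)) {
        Groups.push_back({Begin, Len});
        Begin += Len - 1; // skip past the elements just vectorized
        break;
      }
    }
    // If no candidate worked, the element at Begin stays scalar and we retry
    // from the next element, i.e. on a chain of length N-1.
  }
  return Groups;
}

With N = 7 and MaxElemsPerReg = 4, the first attempt at index 0 is a 4-wide group, then a 2-wide one if that is rejected, in the spirit of the candidate list built below.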
+ if (C.empty()) + return {}; + + sortChainInOffsetOrder(C); + + LLVM_DEBUG({ + dbgs() << "LSV: splitChainByAlignment considering chain:\n"; + dumpChain(C); + }); + + bool IsLoadChain = isa<LoadInst>(C[0].Inst); + auto getVectorFactor = [&](unsigned VF, unsigned LoadStoreSize, + unsigned ChainSizeBytes, VectorType *VecTy) { + return IsLoadChain ? TTI.getLoadVectorFactor(VF, LoadStoreSize, + ChainSizeBytes, VecTy) + : TTI.getStoreVectorFactor(VF, LoadStoreSize, + ChainSizeBytes, VecTy); + }; + +#ifndef NDEBUG + for (const auto &E : C) { + Type *Ty = getLoadStoreType(E.Inst)->getScalarType(); + assert(isPowerOf2_32(DL.getTypeSizeInBits(Ty)) && + "Should have filtered out non-power-of-two elements in " + "collectEquivalenceClasses."); + } +#endif + + unsigned AS = getLoadStoreAddressSpace(C[0].Inst); + unsigned VecRegBytes = TTI.getLoadStoreVecRegBitWidth(AS) / 8; + + std::vector<Chain> Ret; + for (unsigned CBegin = 0; CBegin < C.size(); ++CBegin) { + // Find candidate chains of size not greater than the largest vector reg. + // These chains are over the closed interval [CBegin, CEnd]. + SmallVector<std::pair<unsigned /*CEnd*/, unsigned /*SizeBytes*/>, 8> + CandidateChains; + for (unsigned CEnd = CBegin + 1, Size = C.size(); CEnd < Size; ++CEnd) { + APInt Sz = C[CEnd].OffsetFromLeader + + DL.getTypeStoreSize(getLoadStoreType(C[CEnd].Inst)) - + C[CBegin].OffsetFromLeader; + if (Sz.sgt(VecRegBytes)) + break; + CandidateChains.push_back( + {CEnd, static_cast<unsigned>(Sz.getLimitedValue())}); + } + + // Consider the longest chain first. + for (auto It = CandidateChains.rbegin(), End = CandidateChains.rend(); + It != End; ++It) { + auto [CEnd, SizeBytes] = *It; + LLVM_DEBUG( + dbgs() << "LSV: splitChainByAlignment considering candidate chain [" + << *C[CBegin].Inst << " ... " << *C[CEnd].Inst << "]\n"); + + Type *VecElemTy = getChainElemTy(C); + // Note, VecElemTy is a power of 2, but might be less than one byte. For + // example, we can vectorize 2 x <2 x i4> to <4 x i4>, and in this case + // VecElemTy would be i4. + unsigned VecElemBits = DL.getTypeSizeInBits(VecElemTy); + + // SizeBytes and VecElemBits are powers of 2, so they divide evenly. + assert((8 * SizeBytes) % VecElemBits == 0); + unsigned NumVecElems = 8 * SizeBytes / VecElemBits; + FixedVectorType *VecTy = FixedVectorType::get(VecElemTy, NumVecElems); + unsigned VF = 8 * VecRegBytes / VecElemBits; + + // Check that TTI is happy with this vectorization factor. + unsigned TargetVF = getVectorFactor(VF, VecElemBits, + VecElemBits * NumVecElems / 8, VecTy); + if (TargetVF != VF && TargetVF < NumVecElems) { + LLVM_DEBUG( + dbgs() << "LSV: splitChainByAlignment discarding candidate chain " + "because TargetVF=" + << TargetVF << " != VF=" << VF + << " and TargetVF < NumVecElems=" << NumVecElems << "\n"); + continue; + } + + // Is a load/store with this alignment allowed by TTI and at least as fast + // as an unvectorized load/store? + // + // TTI and F are passed as explicit captures to WAR an MSVC misparse (??). 
+ auto IsAllowedAndFast = [&, SizeBytes = SizeBytes, &TTI = TTI, + &F = F](Align Alignment) { + if (Alignment.value() % SizeBytes == 0) + return true; + unsigned VectorizedSpeed = 0; + bool AllowsMisaligned = TTI.allowsMisalignedMemoryAccesses( + F.getContext(), SizeBytes * 8, AS, Alignment, &VectorizedSpeed); + if (!AllowsMisaligned) { + LLVM_DEBUG(dbgs() + << "LSV: Access of " << SizeBytes << "B in addrspace " + << AS << " with alignment " << Alignment.value() + << " is misaligned, and therefore can't be vectorized.\n"); + return false; + } + + unsigned ElementwiseSpeed = 0; + (TTI).allowsMisalignedMemoryAccesses((F).getContext(), VecElemBits, AS, + Alignment, &ElementwiseSpeed); + if (VectorizedSpeed < ElementwiseSpeed) { + LLVM_DEBUG(dbgs() + << "LSV: Access of " << SizeBytes << "B in addrspace " + << AS << " with alignment " << Alignment.value() + << " has relative speed " << VectorizedSpeed + << ", which is lower than the elementwise speed of " + << ElementwiseSpeed + << ". Therefore this access won't be vectorized.\n"); + return false; + } + return true; + }; + + // If we're loading/storing from an alloca, align it if possible. + // + // FIXME: We eagerly upgrade the alignment, regardless of whether TTI + // tells us this is beneficial. This feels a bit odd, but it matches + // existing tests. This isn't *so* bad, because at most we align to 4 + // bytes (current value of StackAdjustedAlignment). + // + // FIXME: We will upgrade the alignment of the alloca even if it turns out + // we can't vectorize for some other reason. + Value *PtrOperand = getLoadStorePointerOperand(C[CBegin].Inst); + bool IsAllocaAccess = AS == DL.getAllocaAddrSpace() && + isa<AllocaInst>(PtrOperand->stripPointerCasts()); + Align Alignment = getLoadStoreAlignment(C[CBegin].Inst); + Align PrefAlign = Align(StackAdjustedAlignment); + if (IsAllocaAccess && Alignment.value() % SizeBytes != 0 && + IsAllowedAndFast(PrefAlign)) { + Align NewAlign = getOrEnforceKnownAlignment( + PtrOperand, PrefAlign, DL, C[CBegin].Inst, nullptr, &DT); + if (NewAlign >= Alignment) { + LLVM_DEBUG(dbgs() + << "LSV: splitByChain upgrading alloca alignment from " + << Alignment.value() << " to " << NewAlign.value() + << "\n"); + Alignment = NewAlign; + } + } + + if (!IsAllowedAndFast(Alignment)) { + LLVM_DEBUG( + dbgs() << "LSV: splitChainByAlignment discarding candidate chain " + "because its alignment is not AllowedAndFast: " + << Alignment.value() << "\n"); + continue; + } + + if ((IsLoadChain && + !TTI.isLegalToVectorizeLoadChain(SizeBytes, Alignment, AS)) || + (!IsLoadChain && + !TTI.isLegalToVectorizeStoreChain(SizeBytes, Alignment, AS))) { + LLVM_DEBUG( + dbgs() << "LSV: splitChainByAlignment discarding candidate chain " + "because !isLegalToVectorizeLoad/StoreChain."); + continue; + } + + // Hooray, we can vectorize this chain! + Chain &NewChain = Ret.emplace_back(); + for (unsigned I = CBegin; I <= CEnd; ++I) + NewChain.push_back(C[I]); + CBegin = CEnd; // Skip over the instructions we've added to the chain. + break; + } + } + return Ret; +} + +bool Vectorizer::vectorizeChain(Chain &C) { + if (C.size() < 2) return false; - // In case if we have to shrink the pointer - // stripAndAccumulateInBoundsConstantOffsets should properly handle a - // possible overflow and the value should fit into a smallest data type - // used in the cast/gep chain. 
- assert(OffsetA.getMinSignedBits() <= NewPtrBitWidth && - OffsetB.getMinSignedBits() <= NewPtrBitWidth); + sortChainInOffsetOrder(C); - OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth); - OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth); - PtrDelta = PtrDelta.sextOrTrunc(NewPtrBitWidth); + LLVM_DEBUG({ + dbgs() << "LSV: Vectorizing chain of " << C.size() << " instructions:\n"; + dumpChain(C); + }); - APInt OffsetDelta = OffsetB - OffsetA; + Type *VecElemTy = getChainElemTy(C); + bool IsLoadChain = isa<LoadInst>(C[0].Inst); + unsigned AS = getLoadStoreAddressSpace(C[0].Inst); + unsigned ChainBytes = std::accumulate( + C.begin(), C.end(), 0u, [&](unsigned Bytes, const ChainElem &E) { + return Bytes + DL.getTypeStoreSize(getLoadStoreType(E.Inst)); + }); + assert(ChainBytes % DL.getTypeStoreSize(VecElemTy) == 0); + // VecTy is a power of 2 and 1 byte at smallest, but VecElemTy may be smaller + // than 1 byte (e.g. VecTy == <32 x i1>). + Type *VecTy = FixedVectorType::get( + VecElemTy, 8 * ChainBytes / DL.getTypeSizeInBits(VecElemTy)); + + Align Alignment = getLoadStoreAlignment(C[0].Inst); + // If this is a load/store of an alloca, we might have upgraded the alloca's + // alignment earlier. Get the new alignment. + if (AS == DL.getAllocaAddrSpace()) { + Alignment = std::max( + Alignment, + getOrEnforceKnownAlignment(getLoadStorePointerOperand(C[0].Inst), + MaybeAlign(), DL, C[0].Inst, nullptr, &DT)); + } - // Check if they are based on the same pointer. That makes the offsets - // sufficient. - if (PtrA == PtrB) - return OffsetDelta == PtrDelta; - - // Compute the necessary base pointer delta to have the necessary final delta - // equal to the pointer delta requested. - APInt BaseDelta = PtrDelta - OffsetDelta; - - // Compute the distance with SCEV between the base pointers. - const SCEV *PtrSCEVA = SE.getSCEV(PtrA); - const SCEV *PtrSCEVB = SE.getSCEV(PtrB); - const SCEV *C = SE.getConstant(BaseDelta); - const SCEV *X = SE.getAddExpr(PtrSCEVA, C); - if (X == PtrSCEVB) + // All elements of the chain must have the same scalar-type size. +#ifndef NDEBUG + for (const ChainElem &E : C) + assert(DL.getTypeStoreSize(getLoadStoreType(E.Inst)->getScalarType()) == + DL.getTypeStoreSize(VecElemTy)); +#endif + + Instruction *VecInst; + if (IsLoadChain) { + // Loads get hoisted to the location of the first load in the chain. We may + // also need to hoist the (transitive) operands of the loads. + Builder.SetInsertPoint( + std::min_element(C.begin(), C.end(), [](const auto &A, const auto &B) { + return A.Inst->comesBefore(B.Inst); + })->Inst); + + // Chain is in offset order, so C[0] is the instr with the lowest offset, + // i.e. the root of the vector. 
+ Value *Bitcast = Builder.CreateBitCast( + getLoadStorePointerOperand(C[0].Inst), VecTy->getPointerTo(AS)); + VecInst = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment); + + unsigned VecIdx = 0; + for (const ChainElem &E : C) { + Instruction *I = E.Inst; + Value *V; + Type *T = getLoadStoreType(I); + if (auto *VT = dyn_cast<FixedVectorType>(T)) { + auto Mask = llvm::to_vector<8>( + llvm::seq<int>(VecIdx, VecIdx + VT->getNumElements())); + V = Builder.CreateShuffleVector(VecInst, Mask, I->getName()); + VecIdx += VT->getNumElements(); + } else { + V = Builder.CreateExtractElement(VecInst, Builder.getInt32(VecIdx), + I->getName()); + ++VecIdx; + } + if (V->getType() != I->getType()) + V = Builder.CreateBitOrPointerCast(V, I->getType()); + I->replaceAllUsesWith(V); + } + + // Finally, we need to reorder the instrs in the BB so that the (transitive) + // operands of VecInst appear before it. To see why, suppose we have + // vectorized the following code: + // + // ptr1 = gep a, 1 + // load1 = load i32 ptr1 + // ptr0 = gep a, 0 + // load0 = load i32 ptr0 + // + // We will put the vectorized load at the location of the earliest load in + // the BB, i.e. load1. We get: + // + // ptr1 = gep a, 1 + // loadv = load <2 x i32> ptr0 + // load0 = extractelement loadv, 0 + // load1 = extractelement loadv, 1 + // ptr0 = gep a, 0 + // + // Notice that loadv uses ptr0, which is defined *after* it! + reorder(VecInst); + } else { + // Stores get sunk to the location of the last store in the chain. + Builder.SetInsertPoint( + std::max_element(C.begin(), C.end(), [](auto &A, auto &B) { + return A.Inst->comesBefore(B.Inst); + })->Inst); + + // Build the vector to store. + Value *Vec = PoisonValue::get(VecTy); + unsigned VecIdx = 0; + auto InsertElem = [&](Value *V) { + if (V->getType() != VecElemTy) + V = Builder.CreateBitOrPointerCast(V, VecElemTy); + Vec = Builder.CreateInsertElement(Vec, V, Builder.getInt32(VecIdx++)); + }; + for (const ChainElem &E : C) { + auto I = cast<StoreInst>(E.Inst); + if (FixedVectorType *VT = + dyn_cast<FixedVectorType>(getLoadStoreType(I))) { + for (int J = 0, JE = VT->getNumElements(); J < JE; ++J) { + InsertElem(Builder.CreateExtractElement(I->getValueOperand(), + Builder.getInt32(J))); + } + } else { + InsertElem(I->getValueOperand()); + } + } + + // Chain is in offset order, so C[0] is the instr with the lowest offset, + // i.e. the root of the vector. + VecInst = Builder.CreateAlignedStore( + Vec, + Builder.CreateBitCast(getLoadStorePointerOperand(C[0].Inst), + VecTy->getPointerTo(AS)), + Alignment); + } + + propagateMetadata(VecInst, C); + + for (const ChainElem &E : C) + ToErase.push_back(E.Inst); + + ++NumVectorInstructions; + NumScalarsVectorized += C.size(); + return true; +} + +template <bool IsLoadChain> +bool Vectorizer::isSafeToMove( + Instruction *ChainElem, Instruction *ChainBegin, + const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets) { + LLVM_DEBUG(dbgs() << "LSV: isSafeToMove(" << *ChainElem << " -> " + << *ChainBegin << ")\n"); + + assert(isa<LoadInst>(ChainElem) == IsLoadChain); + if (ChainElem == ChainBegin) return true; - // The above check will not catch the cases where one of the pointers is - // factorized but the other one is not, such as (C + (S * (A + B))) vs - // (AS + BS). Get the minus scev. That will allow re-combining the expresions - // and getting the simplified difference. 
- const SCEV *Dist = SE.getMinusSCEV(PtrSCEVB, PtrSCEVA); - if (C == Dist) + // Invariant loads can always be reordered; by definition they are not + // clobbered by stores. + if (isInvariantLoad(ChainElem)) return true; - // Sometimes even this doesn't work, because SCEV can't always see through - // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking - // things the hard way. - return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth); + auto BBIt = std::next([&] { + if constexpr (IsLoadChain) + return BasicBlock::reverse_iterator(ChainElem); + else + return BasicBlock::iterator(ChainElem); + }()); + auto BBItEnd = std::next([&] { + if constexpr (IsLoadChain) + return BasicBlock::reverse_iterator(ChainBegin); + else + return BasicBlock::iterator(ChainBegin); + }()); + + const APInt &ChainElemOffset = ChainOffsets.at(ChainElem); + const unsigned ChainElemSize = + DL.getTypeStoreSize(getLoadStoreType(ChainElem)); + + for (; BBIt != BBItEnd; ++BBIt) { + Instruction *I = &*BBIt; + + if (!I->mayReadOrWriteMemory()) + continue; + + // Loads can be reordered with other loads. + if (IsLoadChain && isa<LoadInst>(I)) + continue; + + // Stores can be sunk below invariant loads. + if (!IsLoadChain && isInvariantLoad(I)) + continue; + + // If I is in the chain, we can tell whether it aliases ChainIt by checking + // what offset ChainIt accesses. This may be better than AA is able to do. + // + // We should really only have duplicate offsets for stores (the duplicate + // loads should be CSE'ed), but in case we have a duplicate load, we'll + // split the chain so we don't have to handle this case specially. + if (auto OffsetIt = ChainOffsets.find(I); OffsetIt != ChainOffsets.end()) { + // I and ChainElem overlap if: + // - I and ChainElem have the same offset, OR + // - I's offset is less than ChainElem's, but I touches past the + // beginning of ChainElem, OR + // - ChainElem's offset is less than I's, but ChainElem touches past the + // beginning of I. + const APInt &IOffset = OffsetIt->second; + unsigned IElemSize = DL.getTypeStoreSize(getLoadStoreType(I)); + if (IOffset == ChainElemOffset || + (IOffset.sle(ChainElemOffset) && + (IOffset + IElemSize).sgt(ChainElemOffset)) || + (ChainElemOffset.sle(IOffset) && + (ChainElemOffset + ChainElemSize).sgt(OffsetIt->second))) { + LLVM_DEBUG({ + // Double check that AA also sees this alias. If not, we probably + // have a bug. + ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem)); + assert(IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)); + dbgs() << "LSV: Found alias in chain: " << *I << "\n"; + }); + return false; // We found an aliasing instruction; bail. + } + + continue; // We're confident there's no alias. + } + + LLVM_DEBUG(dbgs() << "LSV: Querying AA for " << *I << "\n"); + ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem)); + if (IsLoadChain ? 
isModSet(MR) : isModOrRefSet(MR)) { + LLVM_DEBUG(dbgs() << "LSV: Found alias in chain:\n" + << " Aliasing instruction:\n" + << " " << *I << '\n' + << " Aliased instruction and pointer:\n" + << " " << *ChainElem << '\n' + << " " << *getLoadStorePointerOperand(ChainElem) + << '\n'); + + return false; + } + } + return true; } static bool checkNoWrapFlags(Instruction *I, bool Signed) { @@ -395,10 +1097,14 @@ static bool checkNoWrapFlags(Instruction *I, bool Signed) { static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA, unsigned MatchingOpIdxA, Instruction *AddOpB, unsigned MatchingOpIdxB, bool Signed) { - // If both OpA and OpB is an add with NSW/NUW and with - // one of the operands being the same, we can guarantee that the - // transformation is safe if we can prove that OpA won't overflow when - // IdxDiff added to the other operand of OpA. + LLVM_DEBUG(dbgs() << "LSV: checkIfSafeAddSequence IdxDiff=" << IdxDiff + << ", AddOpA=" << *AddOpA << ", MatchingOpIdxA=" + << MatchingOpIdxA << ", AddOpB=" << *AddOpB + << ", MatchingOpIdxB=" << MatchingOpIdxB + << ", Signed=" << Signed << "\n"); + // If both OpA and OpB are adds with NSW/NUW and with one of the operands + // being the same, we can guarantee that the transformation is safe if we can + // prove that OpA won't overflow when Ret added to the other operand of OpA. // For example: // %tmp7 = add nsw i32 %tmp2, %v0 // %tmp8 = sext i32 %tmp7 to i64 @@ -407,10 +1113,9 @@ static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA, // %tmp12 = add nsw i32 %tmp2, %tmp11 // %tmp13 = sext i32 %tmp12 to i64 // - // Both %tmp7 and %tmp2 has the nsw flag and the first operand - // is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow - // because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 has the - // nsw flag. + // Both %tmp7 and %tmp12 have the nsw flag and the first operand is %tmp2. + // It's guaranteed that adding 1 to %tmp7 won't overflow because %tmp11 adds + // 1 to %v0 and both %tmp11 and %tmp12 have the nsw flag. assert(AddOpA->getOpcode() == Instruction::Add && AddOpB->getOpcode() == Instruction::Add && checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed)); @@ -461,24 +1166,26 @@ static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA, return false; } -bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB, - APInt PtrDelta, - unsigned Depth) const { +std::optional<APInt> Vectorizer::getConstantOffsetComplexAddrs( + Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) { + LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs PtrA=" << *PtrA + << " PtrB=" << *PtrB << " ContextInst=" << *ContextInst + << " Depth=" << Depth << "\n"); auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA); auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB); if (!GEPA || !GEPB) - return lookThroughSelects(PtrA, PtrB, PtrDelta, Depth); + return getConstantOffsetSelects(PtrA, PtrB, ContextInst, Depth); // Look through GEPs after checking they're the same except for the last // index. 
if (GEPA->getNumOperands() != GEPB->getNumOperands() || GEPA->getPointerOperand() != GEPB->getPointerOperand()) - return false; + return std::nullopt; gep_type_iterator GTIA = gep_type_begin(GEPA); gep_type_iterator GTIB = gep_type_begin(GEPB); for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) { if (GTIA.getOperand() != GTIB.getOperand()) - return false; + return std::nullopt; ++GTIA; ++GTIB; } @@ -487,23 +1194,13 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB, Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand()); if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() || OpA->getType() != OpB->getType()) - return false; + return std::nullopt; - if (PtrDelta.isNegative()) { - if (PtrDelta.isMinSignedValue()) - return false; - PtrDelta.negate(); - std::swap(OpA, OpB); - } uint64_t Stride = DL.getTypeAllocSize(GTIA.getIndexedType()); - if (PtrDelta.urem(Stride) != 0) - return false; - unsigned IdxBitWidth = OpA->getType()->getScalarSizeInBits(); - APInt IdxDiff = PtrDelta.udiv(Stride).zext(IdxBitWidth); // Only look through a ZExt/SExt. if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA)) - return false; + return std::nullopt; bool Signed = isa<SExtInst>(OpA); @@ -511,7 +1208,21 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB, Value *ValA = OpA->getOperand(0); OpB = dyn_cast<Instruction>(OpB->getOperand(0)); if (!OpB || ValA->getType() != OpB->getType()) - return false; + return std::nullopt; + + const SCEV *OffsetSCEVA = SE.getSCEV(ValA); + const SCEV *OffsetSCEVB = SE.getSCEV(OpB); + const SCEV *IdxDiffSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA); + if (IdxDiffSCEV == SE.getCouldNotCompute()) + return std::nullopt; + + ConstantRange IdxDiffRange = SE.getSignedRange(IdxDiffSCEV); + if (!IdxDiffRange.isSingleElement()) + return std::nullopt; + APInt IdxDiff = *IdxDiffRange.getSingleElement(); + + LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs IdxDiff=" << IdxDiff + << "\n"); // Now we need to prove that adding IdxDiff to ValA won't overflow. bool Safe = false; @@ -530,10 +1241,9 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB, if (!Safe && OpA && OpA->getOpcode() == Instruction::Add && OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) && checkNoWrapFlags(OpB, Signed)) { - // In the checks below a matching operand in OpA and OpB is - // an operand which is the same in those two instructions. - // Below we account for possible orders of the operands of - // these add instructions. + // In the checks below a matching operand in OpA and OpB is an operand which + // is the same in those two instructions. Below we account for possible + // orders of the operands of these add instructions. for (unsigned MatchingOpIdxA : {0, 1}) for (unsigned MatchingOpIdxB : {0, 1}) if (!Safe) @@ -544,802 +1254,267 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB, unsigned BitWidth = ValA->getType()->getScalarSizeInBits(); // Third attempt: - // If all set bits of IdxDiff or any higher order bit other than the sign bit - // are known to be zero in ValA, we can add Diff to it while guaranteeing no - // overflow of any sort. + // + // Assuming IdxDiff is positive: If all set bits of IdxDiff or any higher + // order bit other than the sign bit are known to be zero in ValA, we can add + // Diff to it while guaranteeing no overflow of any sort. + // + // If IdxDiff is negative, do the same, but swap ValA and ValB. 
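The bit-level fact behind this third attempt, in its unsigned form, is that if IdxDiff is no larger (as an unsigned number) than the mask of bits known to be zero in ValA, then ValA + IdxDiff cannot wrap: ValA can only have set bits where the mask is clear, so ValA <= ~KnownZero, and IdxDiff <= KnownZero, so the sum fits in the bit width. The standalone program below (not LLVM code) checks this exhaustively for 8-bit values; it deliberately ignores the signed/NSW refinement, where the pass also clears the sign bit from the allowed mask.

#include <cstdio>

int main() {
  for (unsigned KnownZero = 0; KnownZero < 256; ++KnownZero) {
    for (unsigned ValA = 0; ValA < 256; ++ValA) {
      if (ValA & KnownZero)
        continue; // ValA must be 0 in every known-zero bit
      for (unsigned Diff = 0; Diff <= KnownZero; ++Diff) {
        if (ValA + Diff > 0xFF) { // would wrap an 8-bit unsigned add
          std::printf("counterexample: ValA=%u Diff=%u KnownZero=%u\n",
                      ValA, Diff, KnownZero);
          return 1;
        }
      }
    }
  }
  std::printf("no 8-bit counterexample: the addition never wraps\n");
  return 0;
}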
if (!Safe) { + // When computing known bits, use the GEPs as context instructions, since + // they likely are in the same BB as the load/store. KnownBits Known(BitWidth); - computeKnownBits(ValA, Known, DL, 0, &AC, OpB, &DT); + computeKnownBits((IdxDiff.sge(0) ? ValA : OpB), Known, DL, 0, &AC, + ContextInst, &DT); APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth()); if (Signed) BitsAllowedToBeSet.clearBit(BitWidth - 1); - if (BitsAllowedToBeSet.ult(IdxDiff)) - return false; + if (BitsAllowedToBeSet.ult(IdxDiff.abs())) + return std::nullopt; + Safe = true; } - const SCEV *OffsetSCEVA = SE.getSCEV(ValA); - const SCEV *OffsetSCEVB = SE.getSCEV(OpB); - const SCEV *C = SE.getConstant(IdxDiff.trunc(BitWidth)); - const SCEV *X = SE.getAddExpr(OffsetSCEVA, C); - return X == OffsetSCEVB; + if (Safe) + return IdxDiff * Stride; + return std::nullopt; } -bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB, - const APInt &PtrDelta, - unsigned Depth) const { +std::optional<APInt> Vectorizer::getConstantOffsetSelects( + Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) { if (Depth++ == MaxDepth) - return false; + return std::nullopt; if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) { if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) { - return SelectA->getCondition() == SelectB->getCondition() && - areConsecutivePointers(SelectA->getTrueValue(), - SelectB->getTrueValue(), PtrDelta, Depth) && - areConsecutivePointers(SelectA->getFalseValue(), - SelectB->getFalseValue(), PtrDelta, Depth); + if (SelectA->getCondition() != SelectB->getCondition()) + return std::nullopt; + LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetSelects, PtrA=" << *PtrA + << ", PtrB=" << *PtrB << ", ContextInst=" + << *ContextInst << ", Depth=" << Depth << "\n"); + std::optional<APInt> TrueDiff = getConstantOffset( + SelectA->getTrueValue(), SelectB->getTrueValue(), ContextInst, Depth); + if (!TrueDiff.has_value()) + return std::nullopt; + std::optional<APInt> FalseDiff = + getConstantOffset(SelectA->getFalseValue(), SelectB->getFalseValue(), + ContextInst, Depth); + if (TrueDiff == FalseDiff) + return TrueDiff; } } - return false; + return std::nullopt; } -void Vectorizer::reorder(Instruction *I) { - SmallPtrSet<Instruction *, 16> InstructionsToMove; - SmallVector<Instruction *, 16> Worklist; - - Worklist.push_back(I); - while (!Worklist.empty()) { - Instruction *IW = Worklist.pop_back_val(); - int NumOperands = IW->getNumOperands(); - for (int i = 0; i < NumOperands; i++) { - Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i)); - if (!IM || IM->getOpcode() == Instruction::PHI) - continue; - - // If IM is in another BB, no need to move it, because this pass only - // vectorizes instructions within one BB. - if (IM->getParent() != I->getParent()) - continue; - - if (!IM->comesBefore(I)) { - InstructionsToMove.insert(IM); - Worklist.push_back(IM); - } +EquivalenceClassMap +Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin, + BasicBlock::iterator End) { + EquivalenceClassMap Ret; + + auto getUnderlyingObject = [](const Value *Ptr) -> const Value * { + const Value *ObjPtr = llvm::getUnderlyingObject(Ptr); + if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) { + // The select's themselves are distinct instructions even if they share + // the same condition and evaluate to consecutive pointers for true and + // false values of the condition. 
Therefore using the select's themselves + // for grouping instructions would put consecutive accesses into different + // lists and they won't be even checked for being consecutive, and won't + // be vectorized. + return Sel->getCondition(); } - } + return ObjPtr; + }; - // All instructions to move should follow I. Start from I, not from begin(). - for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E; - ++BBI) { - if (!InstructionsToMove.count(&*BBI)) + for (Instruction &I : make_range(Begin, End)) { + auto *LI = dyn_cast<LoadInst>(&I); + auto *SI = dyn_cast<StoreInst>(&I); + if (!LI && !SI) continue; - Instruction *IM = &*BBI; - --BBI; - IM->removeFromParent(); - IM->insertBefore(I); - } -} - -std::pair<BasicBlock::iterator, BasicBlock::iterator> -Vectorizer::getBoundaryInstrs(ArrayRef<Instruction *> Chain) { - Instruction *C0 = Chain[0]; - BasicBlock::iterator FirstInstr = C0->getIterator(); - BasicBlock::iterator LastInstr = C0->getIterator(); - BasicBlock *BB = C0->getParent(); - unsigned NumFound = 0; - for (Instruction &I : *BB) { - if (!is_contained(Chain, &I)) + if ((LI && !LI->isSimple()) || (SI && !SI->isSimple())) continue; - ++NumFound; - if (NumFound == 1) { - FirstInstr = I.getIterator(); - } - if (NumFound == Chain.size()) { - LastInstr = I.getIterator(); - break; - } - } - - // Range is [first, last). - return std::make_pair(FirstInstr, ++LastInstr); -} - -void Vectorizer::eraseInstructions(ArrayRef<Instruction *> Chain) { - SmallVector<Instruction *, 16> Instrs; - for (Instruction *I : Chain) { - Value *PtrOperand = getLoadStorePointerOperand(I); - assert(PtrOperand && "Instruction must have a pointer operand."); - Instrs.push_back(I); - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand)) - Instrs.push_back(GEP); - } - - // Erase instructions. - for (Instruction *I : Instrs) - if (I->use_empty()) - I->eraseFromParent(); -} - -std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>> -Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain, - unsigned ElementSizeBits) { - unsigned ElementSizeBytes = ElementSizeBits / 8; - unsigned SizeBytes = ElementSizeBytes * Chain.size(); - unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes; - if (NumLeft == Chain.size()) { - if ((NumLeft & 1) == 0) - NumLeft /= 2; // Split even in half - else - --NumLeft; // Split off last element - } else if (NumLeft == 0) - NumLeft = 1; - return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft)); -} - -ArrayRef<Instruction *> -Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { - // These are in BB order, unlike Chain, which is in address order. 
- SmallVector<Instruction *, 16> MemoryInstrs; - SmallVector<Instruction *, 16> ChainInstrs; - - bool IsLoadChain = isa<LoadInst>(Chain[0]); - LLVM_DEBUG({ - for (Instruction *I : Chain) { - if (IsLoadChain) - assert(isa<LoadInst>(I) && - "All elements of Chain must be loads, or all must be stores."); - else - assert(isa<StoreInst>(I) && - "All elements of Chain must be loads, or all must be stores."); - } - }); - - for (Instruction &I : make_range(getBoundaryInstrs(Chain))) { - if ((isa<LoadInst>(I) || isa<StoreInst>(I)) && is_contained(Chain, &I)) { - ChainInstrs.push_back(&I); + if ((LI && !TTI.isLegalToVectorizeLoad(LI)) || + (SI && !TTI.isLegalToVectorizeStore(SI))) continue; - } - if (!isGuaranteedToTransferExecutionToSuccessor(&I)) { - LLVM_DEBUG(dbgs() << "LSV: Found instruction may not transfer execution: " - << I << '\n'); - break; - } - if (I.mayReadOrWriteMemory()) - MemoryInstrs.push_back(&I); - } - - // Loop until we find an instruction in ChainInstrs that we can't vectorize. - unsigned ChainInstrIdx = 0; - Instruction *BarrierMemoryInstr = nullptr; - - for (unsigned E = ChainInstrs.size(); ChainInstrIdx < E; ++ChainInstrIdx) { - Instruction *ChainInstr = ChainInstrs[ChainInstrIdx]; - - // If a barrier memory instruction was found, chain instructions that follow - // will not be added to the valid prefix. - if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(ChainInstr)) - break; - // Check (in BB order) if any instruction prevents ChainInstr from being - // vectorized. Find and store the first such "conflicting" instruction. - for (Instruction *MemInstr : MemoryInstrs) { - // If a barrier memory instruction was found, do not check past it. - if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(MemInstr)) - break; - - auto *MemLoad = dyn_cast<LoadInst>(MemInstr); - auto *ChainLoad = dyn_cast<LoadInst>(ChainInstr); - if (MemLoad && ChainLoad) - continue; - - // We can ignore the alias if the we have a load store pair and the load - // is known to be invariant. The load cannot be clobbered by the store. - auto IsInvariantLoad = [](const LoadInst *LI) -> bool { - return LI->hasMetadata(LLVMContext::MD_invariant_load); - }; - - if (IsLoadChain) { - // We can ignore the alias as long as the load comes before the store, - // because that means we won't be moving the load past the store to - // vectorize it (the vectorized load is inserted at the location of the - // first load in the chain). - if (ChainInstr->comesBefore(MemInstr) || - (ChainLoad && IsInvariantLoad(ChainLoad))) - continue; - } else { - // Same case, but in reverse. - if (MemInstr->comesBefore(ChainInstr) || - (MemLoad && IsInvariantLoad(MemLoad))) - continue; - } - - ModRefInfo MR = - AA.getModRefInfo(MemInstr, MemoryLocation::get(ChainInstr)); - if (IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)) { - LLVM_DEBUG({ - dbgs() << "LSV: Found alias:\n" - " Aliasing instruction:\n" - << " " << *MemInstr << '\n' - << " Aliased instruction and pointer:\n" - << " " << *ChainInstr << '\n' - << " " << *getLoadStorePointerOperand(ChainInstr) << '\n'; - }); - // Save this aliasing memory instruction as a barrier, but allow other - // instructions that precede the barrier to be vectorized with this one. - BarrierMemoryInstr = MemInstr; - break; - } - } - // Continue the search only for store chains, since vectorizing stores that - // precede an aliasing load is valid. Conversely, vectorizing loads is valid - // up to an aliasing store, but should not pull loads from further down in - // the basic block. 
- if (IsLoadChain && BarrierMemoryInstr) { - // The BarrierMemoryInstr is a store that precedes ChainInstr. - assert(BarrierMemoryInstr->comesBefore(ChainInstr)); - break; - } - } - - // Find the largest prefix of Chain whose elements are all in - // ChainInstrs[0, ChainInstrIdx). This is the largest vectorizable prefix of - // Chain. (Recall that Chain is in address order, but ChainInstrs is in BB - // order.) - SmallPtrSet<Instruction *, 8> VectorizableChainInstrs( - ChainInstrs.begin(), ChainInstrs.begin() + ChainInstrIdx); - unsigned ChainIdx = 0; - for (unsigned ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) { - if (!VectorizableChainInstrs.count(Chain[ChainIdx])) - break; - } - return Chain.slice(0, ChainIdx); -} - -static ChainID getChainID(const Value *Ptr) { - const Value *ObjPtr = getUnderlyingObject(Ptr); - if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) { - // The select's themselves are distinct instructions even if they share the - // same condition and evaluate to consecutive pointers for true and false - // values of the condition. Therefore using the select's themselves for - // grouping instructions would put consecutive accesses into different lists - // and they won't be even checked for being consecutive, and won't be - // vectorized. - return Sel->getCondition(); - } - return ObjPtr; -} - -std::pair<InstrListMap, InstrListMap> -Vectorizer::collectInstructions(BasicBlock *BB) { - InstrListMap LoadRefs; - InstrListMap StoreRefs; - - for (Instruction &I : *BB) { - if (!I.mayReadOrWriteMemory()) + Type *Ty = getLoadStoreType(&I); + if (!VectorType::isValidElementType(Ty->getScalarType())) continue; - if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { - if (!LI->isSimple()) - continue; - - // Skip if it's not legal. - if (!TTI.isLegalToVectorizeLoad(LI)) - continue; - - Type *Ty = LI->getType(); - if (!VectorType::isValidElementType(Ty->getScalarType())) - continue; - - // Skip weird non-byte sizes. They probably aren't worth the effort of - // handling correctly. - unsigned TySize = DL.getTypeSizeInBits(Ty); - if ((TySize % 8) != 0) - continue; - - // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain - // functions are currently using an integer type for the vectorized - // load/store, and does not support casting between the integer type and a - // vector of pointers (e.g. i64 to <2 x i16*>) - if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy()) - continue; - - Value *Ptr = LI->getPointerOperand(); - unsigned AS = Ptr->getType()->getPointerAddressSpace(); - unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); - - unsigned VF = VecRegSize / TySize; - VectorType *VecTy = dyn_cast<VectorType>(Ty); - - // No point in looking at these if they're too big to vectorize. - if (TySize > VecRegSize / 2 || - (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0)) - continue; - - // Save the load locations. - const ChainID ID = getChainID(Ptr); - LoadRefs[ID].push_back(LI); - } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { - if (!SI->isSimple()) - continue; - - // Skip if it's not legal. - if (!TTI.isLegalToVectorizeStore(SI)) - continue; - - Type *Ty = SI->getValueOperand()->getType(); - if (!VectorType::isValidElementType(Ty->getScalarType())) - continue; - - // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain - // functions are currently using an integer type for the vectorized - // load/store, and does not support casting between the integer type and a - // vector of pointers (e.g. 
i64 to <2 x i16*>) - if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy()) - continue; - - // Skip weird non-byte sizes. They probably aren't worth the effort of - // handling correctly. - unsigned TySize = DL.getTypeSizeInBits(Ty); - if ((TySize % 8) != 0) - continue; - - Value *Ptr = SI->getPointerOperand(); - unsigned AS = Ptr->getType()->getPointerAddressSpace(); - unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); - - unsigned VF = VecRegSize / TySize; - VectorType *VecTy = dyn_cast<VectorType>(Ty); - - // No point in looking at these if they're too big to vectorize. - if (TySize > VecRegSize / 2 || - (VecTy && TTI.getStoreVectorFactor(VF, TySize, TySize / 8, VecTy) == 0)) - continue; - - // Save store location. - const ChainID ID = getChainID(Ptr); - StoreRefs[ID].push_back(SI); - } - } - - return {LoadRefs, StoreRefs}; -} - -bool Vectorizer::vectorizeChains(InstrListMap &Map) { - bool Changed = false; - - for (const std::pair<ChainID, InstrList> &Chain : Map) { - unsigned Size = Chain.second.size(); - if (Size < 2) + // Skip weird non-byte sizes. They probably aren't worth the effort of + // handling correctly. + unsigned TySize = DL.getTypeSizeInBits(Ty); + if ((TySize % 8) != 0) continue; - LLVM_DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n"); - - // Process the stores in chunks of 64. - for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) { - unsigned Len = std::min<unsigned>(CE - CI, 64); - ArrayRef<Instruction *> Chunk(&Chain.second[CI], Len); - Changed |= vectorizeInstructions(Chunk); - } - } - - return Changed; -} - -bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) { - LLVM_DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() - << " instructions.\n"); - SmallVector<int, 16> Heads, Tails; - int ConsecutiveChain[64]; - - // Do a quadratic search on all of the given loads/stores and find all of the - // pairs of loads/stores that follow each other. - for (int i = 0, e = Instrs.size(); i < e; ++i) { - ConsecutiveChain[i] = -1; - for (int j = e - 1; j >= 0; --j) { - if (i == j) - continue; - - if (isConsecutiveAccess(Instrs[i], Instrs[j])) { - if (ConsecutiveChain[i] != -1) { - int CurDistance = std::abs(ConsecutiveChain[i] - i); - int NewDistance = std::abs(ConsecutiveChain[i] - j); - if (j < i || NewDistance > CurDistance) - continue; // Should not insert. - } + // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain + // functions are currently using an integer type for the vectorized + // load/store, and does not support casting between the integer type and a + // vector of pointers (e.g. i64 to <2 x i16*>) + if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy()) + continue; - Tails.push_back(j); - Heads.push_back(i); - ConsecutiveChain[i] = j; - } - } - } + Value *Ptr = getLoadStorePointerOperand(&I); + unsigned AS = Ptr->getType()->getPointerAddressSpace(); + unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); - bool Changed = false; - SmallPtrSet<Instruction *, 16> InstructionsProcessed; + unsigned VF = VecRegSize / TySize; + VectorType *VecTy = dyn_cast<VectorType>(Ty); - for (int Head : Heads) { - if (InstructionsProcessed.count(Instrs[Head])) + // Only handle power-of-two sized elements. 
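    // For example, a plain i24 access (byte-sized, but 24 is not a power of
    // two) is rejected here, while i32 and <4 x i8> accesses are kept; for
    // vector types only the scalar element width needs to be a power of two.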
+ if ((!VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(Ty))) || + (VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(VecTy->getScalarType())))) continue; - bool LongerChainExists = false; - for (unsigned TIt = 0; TIt < Tails.size(); TIt++) - if (Head == Tails[TIt] && - !InstructionsProcessed.count(Instrs[Heads[TIt]])) { - LongerChainExists = true; - break; - } - if (LongerChainExists) - continue; - - // We found an instr that starts a chain. Now follow the chain and try to - // vectorize it. - SmallVector<Instruction *, 16> Operands; - int I = Head; - while (I != -1 && (is_contained(Tails, I) || is_contained(Heads, I))) { - if (InstructionsProcessed.count(Instrs[I])) - break; - - Operands.push_back(Instrs[I]); - I = ConsecutiveChain[I]; - } - bool Vectorized = false; - if (isa<LoadInst>(*Operands.begin())) - Vectorized = vectorizeLoadChain(Operands, &InstructionsProcessed); - else - Vectorized = vectorizeStoreChain(Operands, &InstructionsProcessed); + // No point in looking at these if they're too big to vectorize. + if (TySize > VecRegSize / 2 || + (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0)) + continue; - Changed |= Vectorized; + Ret[{getUnderlyingObject(Ptr), AS, + DL.getTypeSizeInBits(getLoadStoreType(&I)->getScalarType()), + /*IsLoad=*/LI != nullptr}] + .push_back(&I); } - return Changed; + return Ret; } -bool Vectorizer::vectorizeStoreChain( - ArrayRef<Instruction *> Chain, - SmallPtrSet<Instruction *, 16> *InstructionsProcessed) { - StoreInst *S0 = cast<StoreInst>(Chain[0]); - - // If the vector has an int element, default to int for the whole store. - Type *StoreTy = nullptr; - for (Instruction *I : Chain) { - StoreTy = cast<StoreInst>(I)->getValueOperand()->getType(); - if (StoreTy->isIntOrIntVectorTy()) - break; - - if (StoreTy->isPtrOrPtrVectorTy()) { - StoreTy = Type::getIntNTy(F.getParent()->getContext(), - DL.getTypeSizeInBits(StoreTy)); - break; - } - } - assert(StoreTy && "Failed to find store type"); +std::vector<Chain> Vectorizer::gatherChains(ArrayRef<Instruction *> Instrs) { + if (Instrs.empty()) + return {}; - unsigned Sz = DL.getTypeSizeInBits(StoreTy); - unsigned AS = S0->getPointerAddressSpace(); - unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); - unsigned VF = VecRegSize / Sz; - unsigned ChainSize = Chain.size(); - Align Alignment = S0->getAlign(); + unsigned AS = getLoadStoreAddressSpace(Instrs[0]); + unsigned ASPtrBits = DL.getIndexSizeInBits(AS); - if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { - InstructionsProcessed->insert(Chain.begin(), Chain.end()); - return false; +#ifndef NDEBUG + // Check that Instrs is in BB order and all have the same addr space. + for (size_t I = 1; I < Instrs.size(); ++I) { + assert(Instrs[I - 1]->comesBefore(Instrs[I])); + assert(getLoadStoreAddressSpace(Instrs[I]) == AS); } +#endif - ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain); - if (NewChain.empty()) { - // No vectorization possible. - InstructionsProcessed->insert(Chain.begin(), Chain.end()); - return false; - } - if (NewChain.size() == 1) { - // Failed after the first instruction. Discard it and try the smaller chain. - InstructionsProcessed->insert(NewChain.front()); - return false; - } - - // Update Chain to the valid vectorizable subchain. - Chain = NewChain; - ChainSize = Chain.size(); - - // Check if it's legal to vectorize this chain. If not, split the chain and - // try again. 
- unsigned EltSzInBytes = Sz / 8; - unsigned SzInBytes = EltSzInBytes * ChainSize; - - FixedVectorType *VecTy; - auto *VecStoreTy = dyn_cast<FixedVectorType>(StoreTy); - if (VecStoreTy) - VecTy = FixedVectorType::get(StoreTy->getScalarType(), - Chain.size() * VecStoreTy->getNumElements()); - else - VecTy = FixedVectorType::get(StoreTy, Chain.size()); - - // If it's more than the max vector size or the target has a better - // vector factor, break it into two pieces. - unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy); - if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { - LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." - " Creating two separate arrays.\n"); - bool Vectorized = false; - Vectorized |= - vectorizeStoreChain(Chain.slice(0, TargetVF), InstructionsProcessed); - Vectorized |= - vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed); - return Vectorized; - } - - LLVM_DEBUG({ - dbgs() << "LSV: Stores to vectorize:\n"; - for (Instruction *I : Chain) - dbgs() << " " << *I << "\n"; - }); - - // We won't try again to vectorize the elements of the chain, regardless of - // whether we succeed below. - InstructionsProcessed->insert(Chain.begin(), Chain.end()); - - // If the store is going to be misaligned, don't vectorize it. - unsigned RelativeSpeed; - if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) { - if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { - unsigned SpeedBefore; - accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore); - if (SpeedBefore > RelativeSpeed) - return false; - - auto Chains = splitOddVectorElts(Chain, Sz); - bool Vectorized = false; - Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed); - Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed); - return Vectorized; + // Machinery to build an MRU-hashtable of Chains. + // + // (Ideally this could be done with MapVector, but as currently implemented, + // moving an element to the front of a MapVector is O(n).) 
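  // Instead, the node type below is hooked into an intrusive list, which
  // provides the MRU order and O(1) move-to-front via remove() followed by
  // push_front(), and pairs a chain's leader instruction with the chain
  // itself. The accompanying DenseSet hashes nodes by that leader
  // instruction, and the nodes are bump-allocated.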
+ struct InstrListElem : ilist_node<InstrListElem>, + std::pair<Instruction *, Chain> { + explicit InstrListElem(Instruction *I) + : std::pair<Instruction *, Chain>(I, {}) {} + }; + struct InstrListElemDenseMapInfo { + using PtrInfo = DenseMapInfo<InstrListElem *>; + using IInfo = DenseMapInfo<Instruction *>; + static InstrListElem *getEmptyKey() { return PtrInfo::getEmptyKey(); } + static InstrListElem *getTombstoneKey() { + return PtrInfo::getTombstoneKey(); } - - Align NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(), - Align(StackAdjustedAlignment), - DL, S0, nullptr, &DT); - if (NewAlign >= Alignment) - Alignment = NewAlign; - else - return false; - } - - if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) { - auto Chains = splitOddVectorElts(Chain, Sz); - bool Vectorized = false; - Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed); - Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed); - return Vectorized; - } - - BasicBlock::iterator First, Last; - std::tie(First, Last) = getBoundaryInstrs(Chain); - Builder.SetInsertPoint(&*Last); - - Value *Vec = PoisonValue::get(VecTy); - - if (VecStoreTy) { - unsigned VecWidth = VecStoreTy->getNumElements(); - for (unsigned I = 0, E = Chain.size(); I != E; ++I) { - StoreInst *Store = cast<StoreInst>(Chain[I]); - for (unsigned J = 0, NE = VecStoreTy->getNumElements(); J != NE; ++J) { - unsigned NewIdx = J + I * VecWidth; - Value *Extract = Builder.CreateExtractElement(Store->getValueOperand(), - Builder.getInt32(J)); - if (Extract->getType() != StoreTy->getScalarType()) - Extract = Builder.CreateBitCast(Extract, StoreTy->getScalarType()); - - Value *Insert = - Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(NewIdx)); - Vec = Insert; - } + static unsigned getHashValue(const InstrListElem *E) { + return IInfo::getHashValue(E->first); } - } else { - for (unsigned I = 0, E = Chain.size(); I != E; ++I) { - StoreInst *Store = cast<StoreInst>(Chain[I]); - Value *Extract = Store->getValueOperand(); - if (Extract->getType() != StoreTy->getScalarType()) - Extract = - Builder.CreateBitOrPointerCast(Extract, StoreTy->getScalarType()); - - Value *Insert = - Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(I)); - Vec = Insert; + static bool isEqual(const InstrListElem *A, const InstrListElem *B) { + if (A == getEmptyKey() || B == getEmptyKey()) + return A == getEmptyKey() && B == getEmptyKey(); + if (A == getTombstoneKey() || B == getTombstoneKey()) + return A == getTombstoneKey() && B == getTombstoneKey(); + return IInfo::isEqual(A->first, B->first); } - } - - StoreInst *SI = Builder.CreateAlignedStore( - Vec, - Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)), - Alignment); - propagateMetadata(SI, Chain); - - eraseInstructions(Chain); - ++NumVectorInstructions; - NumScalarsVectorized += Chain.size(); - return true; -} - -bool Vectorizer::vectorizeLoadChain( - ArrayRef<Instruction *> Chain, - SmallPtrSet<Instruction *, 16> *InstructionsProcessed) { - LoadInst *L0 = cast<LoadInst>(Chain[0]); - - // If the vector has an int element, default to int for the whole load. 
- Type *LoadTy = nullptr; - for (const auto &V : Chain) { - LoadTy = cast<LoadInst>(V)->getType(); - if (LoadTy->isIntOrIntVectorTy()) - break; - - if (LoadTy->isPtrOrPtrVectorTy()) { - LoadTy = Type::getIntNTy(F.getParent()->getContext(), - DL.getTypeSizeInBits(LoadTy)); - break; + }; + SpecificBumpPtrAllocator<InstrListElem> Allocator; + simple_ilist<InstrListElem> MRU; + DenseSet<InstrListElem *, InstrListElemDenseMapInfo> Chains; + + // Compare each instruction in `instrs` to leader of the N most recently-used + // chains. This limits the O(n^2) behavior of this pass while also allowing + // us to build arbitrarily long chains. + for (Instruction *I : Instrs) { + constexpr int MaxChainsToTry = 64; + + bool MatchFound = false; + auto ChainIter = MRU.begin(); + for (size_t J = 0; J < MaxChainsToTry && ChainIter != MRU.end(); + ++J, ++ChainIter) { + std::optional<APInt> Offset = getConstantOffset( + getLoadStorePointerOperand(ChainIter->first), + getLoadStorePointerOperand(I), + /*ContextInst=*/ + (ChainIter->first->comesBefore(I) ? I : ChainIter->first)); + if (Offset.has_value()) { + // `Offset` might not have the expected number of bits, if e.g. AS has a + // different number of bits than opaque pointers. + ChainIter->second.push_back(ChainElem{I, Offset.value()}); + // Move ChainIter to the front of the MRU list. + MRU.remove(*ChainIter); + MRU.push_front(*ChainIter); + MatchFound = true; + break; + } } - } - assert(LoadTy && "Can't determine LoadInst type from chain"); - - unsigned Sz = DL.getTypeSizeInBits(LoadTy); - unsigned AS = L0->getPointerAddressSpace(); - unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); - unsigned VF = VecRegSize / Sz; - unsigned ChainSize = Chain.size(); - Align Alignment = L0->getAlign(); - - if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { - InstructionsProcessed->insert(Chain.begin(), Chain.end()); - return false; - } - - ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain); - if (NewChain.empty()) { - // No vectorization possible. - InstructionsProcessed->insert(Chain.begin(), Chain.end()); - return false; - } - if (NewChain.size() == 1) { - // Failed after the first instruction. Discard it and try the smaller chain. - InstructionsProcessed->insert(NewChain.front()); - return false; - } - // Update Chain to the valid vectorizable subchain. - Chain = NewChain; - ChainSize = Chain.size(); - - // Check if it's legal to vectorize this chain. If not, split the chain and - // try again. - unsigned EltSzInBytes = Sz / 8; - unsigned SzInBytes = EltSzInBytes * ChainSize; - VectorType *VecTy; - auto *VecLoadTy = dyn_cast<FixedVectorType>(LoadTy); - if (VecLoadTy) - VecTy = FixedVectorType::get(LoadTy->getScalarType(), - Chain.size() * VecLoadTy->getNumElements()); - else - VecTy = FixedVectorType::get(LoadTy, Chain.size()); - - // If it's more than the max vector size or the target has a better - // vector factor, break it into two pieces. - unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy); - if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { - LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." - " Creating two separate arrays.\n"); - bool Vectorized = false; - Vectorized |= - vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed); - Vectorized |= - vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed); - return Vectorized; - } - - // We won't try again to vectorize the elements of the chain, regardless of - // whether we succeed below. 
- InstructionsProcessed->insert(Chain.begin(), Chain.end()); - - // If the load is going to be misaligned, don't vectorize it. - unsigned RelativeSpeed; - if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) { - if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { - unsigned SpeedBefore; - accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore); - if (SpeedBefore > RelativeSpeed) - return false; - - auto Chains = splitOddVectorElts(Chain, Sz); - bool Vectorized = false; - Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed); - Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed); - return Vectorized; + if (!MatchFound) { + APInt ZeroOffset(ASPtrBits, 0); + InstrListElem *E = new (Allocator.Allocate()) InstrListElem(I); + E->second.push_back(ChainElem{I, ZeroOffset}); + MRU.push_front(*E); + Chains.insert(E); } - - Align NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(), - Align(StackAdjustedAlignment), - DL, L0, nullptr, &DT); - if (NewAlign >= Alignment) - Alignment = NewAlign; - else - return false; } - if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) { - auto Chains = splitOddVectorElts(Chain, Sz); - bool Vectorized = false; - Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed); - Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed); - return Vectorized; - } + std::vector<Chain> Ret; + Ret.reserve(Chains.size()); + // Iterate over MRU rather than Chains so the order is deterministic. + for (auto &E : MRU) + if (E.second.size() > 1) + Ret.push_back(std::move(E.second)); + return Ret; +} - LLVM_DEBUG({ - dbgs() << "LSV: Loads to vectorize:\n"; - for (Instruction *I : Chain) - I->dump(); - }); +std::optional<APInt> Vectorizer::getConstantOffset(Value *PtrA, Value *PtrB, + Instruction *ContextInst, + unsigned Depth) { + LLVM_DEBUG(dbgs() << "LSV: getConstantOffset, PtrA=" << *PtrA + << ", PtrB=" << *PtrB << ", ContextInst= " << *ContextInst + << ", Depth=" << Depth << "\n"); + // We'll ultimately return a value of this bit width, even if computations + // happen in a different width. + unsigned OrigBitWidth = DL.getIndexTypeSizeInBits(PtrA->getType()); + APInt OffsetA(OrigBitWidth, 0); + APInt OffsetB(OrigBitWidth, 0); + PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); + PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); + unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType()); + if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType())) + return std::nullopt; - // getVectorizablePrefix already computed getBoundaryInstrs. The value of - // Last may have changed since then, but the value of First won't have. If it - // matters, we could compute getBoundaryInstrs only once and reuse it here. - BasicBlock::iterator First, Last; - std::tie(First, Last) = getBoundaryInstrs(Chain); - Builder.SetInsertPoint(&*First); - - Value *Bitcast = - Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS)); - LoadInst *LI = - Builder.CreateAlignedLoad(VecTy, Bitcast, MaybeAlign(Alignment)); - propagateMetadata(LI, Chain); - - for (unsigned I = 0, E = Chain.size(); I != E; ++I) { - Value *CV = Chain[I]; - Value *V; - if (VecLoadTy) { - // Extract a subvector using shufflevector. 
- unsigned VecWidth = VecLoadTy->getNumElements(); - auto Mask = - llvm::to_vector<8>(llvm::seq<int>(I * VecWidth, (I + 1) * VecWidth)); - V = Builder.CreateShuffleVector(LI, Mask, CV->getName()); - } else { - V = Builder.CreateExtractElement(LI, Builder.getInt32(I), CV->getName()); - } + // If we have to shrink the pointer, stripAndAccumulateInBoundsConstantOffsets + // should properly handle a possible overflow and the value should fit into + // the smallest data type used in the cast/gep chain. + assert(OffsetA.getSignificantBits() <= NewPtrBitWidth && + OffsetB.getSignificantBits() <= NewPtrBitWidth); - if (V->getType() != CV->getType()) { - V = Builder.CreateBitOrPointerCast(V, CV->getType()); + OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth); + OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth); + if (PtrA == PtrB) + return (OffsetB - OffsetA).sextOrTrunc(OrigBitWidth); + + // Try to compute B - A. + const SCEV *DistScev = SE.getMinusSCEV(SE.getSCEV(PtrB), SE.getSCEV(PtrA)); + if (DistScev != SE.getCouldNotCompute()) { + LLVM_DEBUG(dbgs() << "LSV: SCEV PtrB - PtrA =" << *DistScev << "\n"); + ConstantRange DistRange = SE.getSignedRange(DistScev); + if (DistRange.isSingleElement()) { + // Handle index width (the width of Dist) != pointer width (the width of + // the Offset*s at this point). + APInt Dist = DistRange.getSingleElement()->sextOrTrunc(NewPtrBitWidth); + return (OffsetB - OffsetA + Dist).sextOrTrunc(OrigBitWidth); } - - // Replace the old instruction. - CV->replaceAllUsesWith(V); } - - // Since we might have opaque pointers we might end up using the pointer - // operand of the first load (wrt. memory loaded) for the vector load. Since - // this first load might not be the first in the block we potentially need to - // reorder the pointer operand (and its operands). If we have a bitcast though - // it might be before the load and should be the reorder start instruction. - // "Might" because for opaque pointers the "bitcast" is just the first loads - // pointer operand, as oppposed to something we inserted at the right position - // ourselves. - Instruction *BCInst = dyn_cast<Instruction>(Bitcast); - reorder((BCInst && BCInst != L0->getPointerOperand()) ? BCInst : LI); - - eraseInstructions(Chain); - - ++NumVectorInstructions; - NumScalarsVectorized += Chain.size(); - return true; -} - -bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, - Align Alignment, unsigned &RelativeSpeed) { - RelativeSpeed = 0; - if (Alignment.value() % SzInBytes == 0) - return false; - - bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(), - SzInBytes * 8, AddressSpace, - Alignment, &RelativeSpeed); - LLVM_DEBUG(dbgs() << "LSV: Target said misaligned is allowed? 
" << Allows - << " with relative speed = " << RelativeSpeed << '\n';); - return !Allows || !RelativeSpeed; + std::optional<APInt> Diff = + getConstantOffsetComplexAddrs(PtrA, PtrB, ContextInst, Depth); + if (Diff.has_value()) + return (OffsetB - OffsetA + Diff->sext(OffsetB.getBitWidth())) + .sextOrTrunc(OrigBitWidth); + return std::nullopt; } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index cd48c0d57eb3..f923f0be6621 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -37,6 +37,11 @@ static cl::opt<bool> EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, cl::desc("Enable if-conversion during vectorization.")); +static cl::opt<bool> +AllowStridedPointerIVs("lv-strided-pointer-ivs", cl::init(false), cl::Hidden, + cl::desc("Enable recognition of non-constant strided " + "pointer induction variables.")); + namespace llvm { cl::opt<bool> HintsAllowReordering("hints-allow-reordering", cl::init(true), cl::Hidden, @@ -447,8 +452,12 @@ static bool storeToSameAddress(ScalarEvolution *SE, StoreInst *A, int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy, Value *Ptr) const { - const ValueToValueMap &Strides = - getSymbolicStrides() ? *getSymbolicStrides() : ValueToValueMap(); + // FIXME: Currently, the set of symbolic strides is sometimes queried before + // it's collected. This happens from canVectorizeWithIfConvert, when the + // pointer is checked to reference consecutive elements suitable for a + // masked access. + const auto &Strides = + LAI ? LAI->getSymbolicStrides() : DenseMap<Value *, const SCEV *>(); Function *F = TheLoop->getHeader()->getParent(); bool OptForSize = F->hasOptSize() || @@ -462,11 +471,135 @@ int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy, return 0; } -bool LoopVectorizationLegality::isUniform(Value *V) const { - return LAI->isUniform(V); +bool LoopVectorizationLegality::isInvariant(Value *V) const { + return LAI->isInvariant(V); +} + +namespace { +/// A rewriter to build the SCEVs for each of the VF lanes in the expected +/// vectorized loop, which can then be compared to detect their uniformity. This +/// is done by replacing the AddRec SCEVs of the original scalar loop (TheLoop) +/// with new AddRecs where the step is multiplied by StepMultiplier and Offset * +/// Step is added. Also checks if all sub-expressions are analyzable w.r.t. +/// uniformity. +class SCEVAddRecForUniformityRewriter + : public SCEVRewriteVisitor<SCEVAddRecForUniformityRewriter> { + /// Multiplier to be applied to the step of AddRecs in TheLoop. + unsigned StepMultiplier; + + /// Offset to be added to the AddRecs in TheLoop. + unsigned Offset; + + /// Loop for which to rewrite AddRecsFor. + Loop *TheLoop; + + /// Is any sub-expressions not analyzable w.r.t. uniformity? + bool CannotAnalyze = false; + + bool canAnalyze() const { return !CannotAnalyze; } + +public: + SCEVAddRecForUniformityRewriter(ScalarEvolution &SE, unsigned StepMultiplier, + unsigned Offset, Loop *TheLoop) + : SCEVRewriteVisitor(SE), StepMultiplier(StepMultiplier), Offset(Offset), + TheLoop(TheLoop) {} + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + assert(Expr->getLoop() == TheLoop && + "addrec outside of TheLoop must be invariant and should have been " + "handled earlier"); + // Build a new AddRec by multiplying the step by StepMultiplier and + // incrementing the start by Offset * step. 
+ Type *Ty = Expr->getType(); + auto *Step = Expr->getStepRecurrence(SE); + if (!SE.isLoopInvariant(Step, TheLoop)) { + CannotAnalyze = true; + return Expr; + } + auto *NewStep = SE.getMulExpr(Step, SE.getConstant(Ty, StepMultiplier)); + auto *ScaledOffset = SE.getMulExpr(Step, SE.getConstant(Ty, Offset)); + auto *NewStart = SE.getAddExpr(Expr->getStart(), ScaledOffset); + return SE.getAddRecExpr(NewStart, NewStep, TheLoop, SCEV::FlagAnyWrap); + } + + const SCEV *visit(const SCEV *S) { + if (CannotAnalyze || SE.isLoopInvariant(S, TheLoop)) + return S; + return SCEVRewriteVisitor<SCEVAddRecForUniformityRewriter>::visit(S); + } + + const SCEV *visitUnknown(const SCEVUnknown *S) { + if (SE.isLoopInvariant(S, TheLoop)) + return S; + // The value could vary across iterations. + CannotAnalyze = true; + return S; + } + + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *S) { + // Could not analyze the expression. + CannotAnalyze = true; + return S; + } + + static const SCEV *rewrite(const SCEV *S, ScalarEvolution &SE, + unsigned StepMultiplier, unsigned Offset, + Loop *TheLoop) { + /// Bail out if the expression does not contain an UDiv expression. + /// Uniform values which are not loop invariant require operations to strip + /// out the lowest bits. For now just look for UDivs and use it to avoid + /// re-writing UDIV-free expressions for other lanes to limit compile time. + if (!SCEVExprContains(S, + [](const SCEV *S) { return isa<SCEVUDivExpr>(S); })) + return SE.getCouldNotCompute(); + + SCEVAddRecForUniformityRewriter Rewriter(SE, StepMultiplier, Offset, + TheLoop); + const SCEV *Result = Rewriter.visit(S); + + if (Rewriter.canAnalyze()) + return Result; + return SE.getCouldNotCompute(); + } +}; + +} // namespace + +bool LoopVectorizationLegality::isUniform(Value *V, ElementCount VF) const { + if (isInvariant(V)) + return true; + if (VF.isScalable()) + return false; + if (VF.isScalar()) + return true; + + // Since we rely on SCEV for uniformity, if the type is not SCEVable, it is + // never considered uniform. + auto *SE = PSE.getSE(); + if (!SE->isSCEVable(V->getType())) + return false; + const SCEV *S = SE->getSCEV(V); + + // Rewrite AddRecs in TheLoop to step by VF and check if the expression for + // lane 0 matches the expressions for all other lanes. + unsigned FixedVF = VF.getKnownMinValue(); + const SCEV *FirstLaneExpr = + SCEVAddRecForUniformityRewriter::rewrite(S, *SE, FixedVF, 0, TheLoop); + if (isa<SCEVCouldNotCompute>(FirstLaneExpr)) + return false; + + // Make sure the expressions for lanes FixedVF-1..1 match the expression for + // lane 0. We check lanes in reverse order for compile-time, as frequently + // checking the last lane is sufficient to rule out uniformity. + return all_of(reverse(seq<unsigned>(1, FixedVF)), [&](unsigned I) { + const SCEV *IthLaneExpr = + SCEVAddRecForUniformityRewriter::rewrite(S, *SE, FixedVF, I, TheLoop); + return FirstLaneExpr == IthLaneExpr; + }); } -bool LoopVectorizationLegality::isUniformMemOp(Instruction &I) const { +bool LoopVectorizationLegality::isUniformMemOp(Instruction &I, + ElementCount VF) const { Value *Ptr = getLoadStorePointerOperand(&I); if (!Ptr) return false; @@ -474,7 +607,7 @@ bool LoopVectorizationLegality::isUniformMemOp(Instruction &I) const { // stores from being uniform. The current lowering simply doesn't handle // it; in particular, the cost model distinguishes scatter/gather from // scalar w/predication, and we currently rely on the scalar path. 
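  // Note that the VF-aware isUniform() below is intended to also accept
  // addresses that are uniform across one vector iteration without being
  // loop-invariant, e.g. something of the form Base + 4 * (i / 4) with
  // VF == 4, which evaluates to the same value for all four lanes of a
  // vector iteration (the UDiv-based pattern the SCEV rewriter above is
  // limited to).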
- return isUniform(Ptr) && !blockNeedsPredication(I.getParent()); + return isUniform(Ptr, VF) && !blockNeedsPredication(I.getParent()); } bool LoopVectorizationLegality::canVectorizeOuterLoop() { @@ -700,6 +833,18 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } + // We prevent matching non-constant strided pointer IVS to preserve + // historical vectorizer behavior after a generalization of the + // IVDescriptor code. The intent is to remove this check, but we + // have to fix issues around code quality for such loops first. + auto isDisallowedStridedPointerInduction = + [](const InductionDescriptor &ID) { + if (AllowStridedPointerIVs) + return false; + return ID.getKind() == InductionDescriptor::IK_PtrInduction && + ID.getConstIntStepValue() == nullptr; + }; + // TODO: Instead of recording the AllowedExit, it would be good to // record the complementary set: NotAllowedExit. These include (but may // not be limited to): @@ -715,14 +860,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // By recording these, we can then reason about ways to vectorize each // of these NotAllowedExit. InductionDescriptor ID; - if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) { + if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID) && + !isDisallowedStridedPointerInduction(ID)) { addInductionPhi(Phi, ID, AllowedExit); Requirements->addExactFPMathInst(ID.getExactFPMathInst()); continue; } - if (RecurrenceDescriptor::isFixedOrderRecurrence(Phi, TheLoop, - SinkAfter, DT)) { + if (RecurrenceDescriptor::isFixedOrderRecurrence(Phi, TheLoop, DT)) { AllowedExit.insert(Phi); FixedOrderRecurrences.insert(Phi); continue; @@ -730,7 +875,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // As a last resort, coerce the PHI to a AddRec expression // and re-try classifying it a an induction PHI. - if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true)) { + if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true) && + !isDisallowedStridedPointerInduction(ID)) { addInductionPhi(Phi, ID, AllowedExit); continue; } @@ -894,18 +1040,6 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } } - // For fixed order recurrences, we use the previous value (incoming value from - // the latch) to check if it dominates all users of the recurrence. Bail out - // if we have to sink such an instruction for another recurrence, as the - // dominance requirement may not hold after sinking. - BasicBlock *LoopLatch = TheLoop->getLoopLatch(); - if (any_of(FixedOrderRecurrences, [LoopLatch, this](const PHINode *Phi) { - Instruction *V = - cast<Instruction>(Phi->getIncomingValueForBlock(LoopLatch)); - return SinkAfter.find(V) != SinkAfter.end(); - })) - return false; - // Now we know the widest induction type, check if our found induction // is the same size. If it's not, unset it here and InnerLoopVectorizer // will create another. @@ -1124,6 +1258,16 @@ bool LoopVectorizationLegality::blockCanBePredicated( if (isa<NoAliasScopeDeclInst>(&I)) continue; + // We can allow masked calls if there's at least one vector variant, even + // if we end up scalarizing due to the cost model calculations. + // TODO: Allow other calls if they have appropriate attributes... readonly + // and argmemonly? + if (CallInst *CI = dyn_cast<CallInst>(&I)) + if (VFDatabase::hasMaskedVariant(*CI)) { + MaskedOp.insert(CI); + continue; + } + // Loads are handled via masking (or speculated if safe to do so.) 
if (auto *LI = dyn_cast<LoadInst>(&I)) { if (!SafePtrs.count(LI->getPointerOperand())) diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 8990a65afdb4..13357cb06c55 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -25,6 +25,7 @@ #define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONPLANNER_H #include "VPlan.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/Support/InstructionCost.h" namespace llvm { @@ -217,6 +218,16 @@ struct VectorizationFactor { } }; +/// ElementCountComparator creates a total ordering for ElementCount +/// for the purposes of using it in a set structure. +struct ElementCountComparator { + bool operator()(const ElementCount &LHS, const ElementCount &RHS) const { + return std::make_tuple(LHS.isScalable(), LHS.getKnownMinValue()) < + std::make_tuple(RHS.isScalable(), RHS.getKnownMinValue()); + } +}; +using ElementCountSet = SmallSet<ElementCount, 16, ElementCountComparator>; + /// A class that represents two vectorization factors (initialized with 0 by /// default). One for fixed-width vectorization and one for scalable /// vectorization. This can be used by the vectorizer to choose from a range of @@ -261,7 +272,7 @@ class LoopVectorizationPlanner { const TargetLibraryInfo *TLI; /// Target Transform Info. - const TargetTransformInfo *TTI; + const TargetTransformInfo &TTI; /// The legality analysis. LoopVectorizationLegality *Legal; @@ -280,12 +291,15 @@ class LoopVectorizationPlanner { SmallVector<VPlanPtr, 4> VPlans; + /// Profitable vector factors. + SmallVector<VectorizationFactor, 8> ProfitableVFs; + /// A builder used to construct the current plan. VPBuilder Builder; public: LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, + const TargetTransformInfo &TTI, LoopVectorizationLegality *Legal, LoopVectorizationCostModel &CM, InterleavedAccessInfo &IAI, @@ -311,16 +325,22 @@ public: /// TODO: \p IsEpilogueVectorization is needed to avoid issues due to epilogue /// vectorization re-using plans for both the main and epilogue vector loops. /// It should be removed once the re-use issue has been fixed. - void executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan, - InnerLoopVectorizer &LB, DominatorTree *DT, - bool IsEpilogueVectorization); + /// \p ExpandedSCEVs is passed during execution of the plan for epilogue loop + /// to re-use expansion results generated during main plan execution. Returns + /// a mapping of SCEVs to their expanded IR values. Note that this is a + /// temporary workaround needed due to the current epilogue handling. + DenseMap<const SCEV *, Value *> + executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan, + InnerLoopVectorizer &LB, DominatorTree *DT, + bool IsEpilogueVectorization, + DenseMap<const SCEV *, Value *> *ExpandedSCEVs = nullptr); #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printPlans(raw_ostream &O); #endif - /// Look through the existing plans and return true if we have one with all - /// the vectorization factors in question. + /// Look through the existing plans and return true if we have one with + /// vectorization factor \p VF. 
bool hasPlanWithVF(ElementCount VF) const { return any_of(VPlans, [&](const VPlanPtr &Plan) { return Plan->hasVF(VF); }); @@ -333,8 +353,11 @@ public: getDecisionAndClampRange(const std::function<bool(ElementCount)> &Predicate, VFRange &Range); - /// Check if the number of runtime checks exceeds the threshold. - bool requiresTooManyRuntimeChecks() const; + /// \return The most profitable vectorization factor and the cost of that VF + /// for vectorizing the epilogue. Returns VectorizationFactor::Disabled if + /// epilogue vectorization is not supported for the loop. + VectorizationFactor + selectEpilogueVectorizationFactor(const ElementCount MaxVF, unsigned IC); protected: /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive, @@ -350,9 +373,12 @@ private: /// Build a VPlan using VPRecipes according to the information gather by /// Legal. This method is only used for the legacy inner loop vectorizer. - VPlanPtr buildVPlanWithVPRecipes( - VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions, - const MapVector<Instruction *, Instruction *> &SinkAfter); + /// \p Range's largest included VF is restricted to the maximum VF the + /// returned VPlan is valid for. If no VPlan can be built for the input range, + /// set the largest included VF to the maximum VF for which no plan could be + /// built. + std::optional<VPlanPtr> tryToBuildVPlanWithVPRecipes( + VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions); /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive, /// according to the information gathered by Legal when it checked if it is @@ -367,6 +393,20 @@ private: void adjustRecipesForReductions(VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF); + + /// \return The most profitable vectorization factor and the cost of that VF. + /// This method checks every VF in \p CandidateVFs. + VectorizationFactor + selectVectorizationFactor(const ElementCountSet &CandidateVFs); + + /// Returns true if the per-lane cost of VectorizationFactor A is lower than + /// that of B. + bool isMoreProfitable(const VectorizationFactor &A, + const VectorizationFactor &B) const; + + /// Determines if we have the infrastructure to vectorize the loop and its + /// epilogue, assuming the main loop is vectorized by \p VF. 
+ bool isCandidateForEpilogueVectorization(const ElementCount VF) const; }; } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index a28099d8ba7d..d7e40e8ef978 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -98,6 +98,7 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" @@ -120,8 +121,6 @@ #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/IR/Verifier.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" @@ -231,6 +230,25 @@ static cl::opt<PreferPredicateTy::Option> PreferPredicateOverEpilogue( "prefers tail-folding, don't attempt vectorization if " "tail-folding fails."))); +static cl::opt<TailFoldingStyle> ForceTailFoldingStyle( + "force-tail-folding-style", cl::desc("Force the tail folding style"), + cl::init(TailFoldingStyle::None), + cl::values( + clEnumValN(TailFoldingStyle::None, "none", "Disable tail folding"), + clEnumValN( + TailFoldingStyle::Data, "data", + "Create lane mask for data only, using active.lane.mask intrinsic"), + clEnumValN(TailFoldingStyle::DataWithoutLaneMask, + "data-without-lane-mask", + "Create lane mask with compare/stepvector"), + clEnumValN(TailFoldingStyle::DataAndControlFlow, "data-and-control", + "Create lane mask using active.lane.mask intrinsic, and use " + "it for both data and control flow"), + clEnumValN( + TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck, + "data-and-control-without-rt-check", + "Similar to data-and-control, but remove the runtime check"))); + static cl::opt<bool> MaximizeBandwidth( "vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, cl::desc("Maximize bandwidth when selecting vectorization factor which " @@ -338,10 +356,12 @@ static cl::opt<bool> PreferPredicatedReductionSelect( cl::desc( "Prefer predicating a reduction operation over an after loop select.")); +namespace llvm { cl::opt<bool> EnableVPlanNativePath( - "enable-vplan-native-path", cl::init(false), cl::Hidden, + "enable-vplan-native-path", cl::Hidden, cl::desc("Enable VPlan-native vectorization path with " "support for outer loop vectorization.")); +} // This flag enables the stress testing of the VPlan H-CFG construction in the // VPlan-native vectorization path. It must be used in conjuction with @@ -419,9 +439,42 @@ static std::optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE, return std::nullopt; } +/// Return a vector containing interleaved elements from multiple +/// smaller input vectors. +static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals, + const Twine &Name) { + unsigned Factor = Vals.size(); + assert(Factor > 1 && "Tried to interleave invalid number of vectors"); + + VectorType *VecTy = cast<VectorType>(Vals[0]->getType()); +#ifndef NDEBUG + for (Value *Val : Vals) + assert(Val->getType() == VecTy && "Tried to interleave mismatched types"); +#endif + + // Scalable vectors cannot use arbitrary shufflevectors (only splats), so + // must use intrinsics to interleave. 
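  // In both cases the inputs are interleaved element-wise: for two fixed
  // inputs [a0,a1,a2,a3] and [b0,b1,b2,b3] the result is
  // [a0,b0,a1,b1,a2,b2,a3,b3]; the fixed-length path below produces this
  // with a single shufflevector using the mask from createInterleaveMask().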
+ if (VecTy->isScalableTy()) { + VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy); + return Builder.CreateIntrinsic( + WideVecTy, Intrinsic::experimental_vector_interleave2, Vals, + /*FMFSource=*/nullptr, Name); + } + + // Fixed length. Start by concatenating all vectors into a wide vector. + Value *WideVec = concatenateVectors(Builder, Vals); + + // Interleave the elements into the wide vector. + const unsigned NumElts = VecTy->getElementCount().getFixedValue(); + return Builder.CreateShuffleVector( + WideVec, createInterleaveMask(NumElts, Factor), Name); +} + namespace { // Forward declare GeneratedRTChecks. class GeneratedRTChecks; + +using SCEV2ValueTy = DenseMap<const SCEV *, Value *>; } // namespace namespace llvm { @@ -477,8 +530,10 @@ public: /// loop and the start value for the canonical induction, if it is != 0. The /// latter is the case when vectorizing the epilogue loop. In the case of /// epilogue vectorization, this function is overriden to handle the more - /// complex control flow around the loops. - virtual std::pair<BasicBlock *, Value *> createVectorizedLoopSkeleton(); + /// complex control flow around the loops. \p ExpandedSCEVs is used to + /// look up SCEV expansions for expressions needed during skeleton creation. + virtual std::pair<BasicBlock *, Value *> + createVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs); /// Fix the vectorized code, taking care of header phi's, live-outs, and more. void fixVectorizedLoop(VPTransformState &State, VPlan &Plan); @@ -498,7 +553,7 @@ public: /// Instr's operands. void scalarizeInstruction(const Instruction *Instr, VPReplicateRecipe *RepRecipe, - const VPIteration &Instance, bool IfPredicateInstr, + const VPIteration &Instance, VPTransformState &State); /// Construct the vector value of a scalarized value \p V one lane at a time. @@ -513,7 +568,7 @@ public: ArrayRef<VPValue *> VPDefs, VPTransformState &State, VPValue *Addr, ArrayRef<VPValue *> StoredValues, - VPValue *BlockInMask = nullptr); + VPValue *BlockInMask, bool NeedsMaskForGaps); /// Fix the non-induction PHIs in \p Plan. void fixNonInductionPHIs(VPlan &Plan, VPTransformState &State); @@ -522,28 +577,30 @@ public: /// able to vectorize with strict in-order reductions for the given RdxDesc. bool useOrderedReductions(const RecurrenceDescriptor &RdxDesc); - /// Create a broadcast instruction. This method generates a broadcast - /// instruction (shuffle) for loop invariant values and for the induction - /// value. If this is the induction variable then we extend it to N, N+1, ... - /// this is needed because each iteration in the loop corresponds to a SIMD - /// element. - virtual Value *getBroadcastInstrs(Value *V); - // Returns the resume value (bc.merge.rdx) for a reduction as // generated by fixReduction. PHINode *getReductionResumeValue(const RecurrenceDescriptor &RdxDesc); /// Create a new phi node for the induction variable \p OrigPhi to resume /// iteration count in the scalar epilogue, from where the vectorized loop - /// left off. In cases where the loop skeleton is more complicated (eg. - /// epilogue vectorization) and the resume values can come from an additional - /// bypass block, the \p AdditionalBypass pair provides information about the - /// bypass block and the end value on the edge from bypass to this loop. + /// left off. \p Step is the SCEV-expanded induction step to use. 
In cases + /// where the loop skeleton is more complicated (i.e., epilogue vectorization) + /// and the resume values can come from an additional bypass block, the \p + /// AdditionalBypass pair provides information about the bypass block and the + /// end value on the edge from bypass to this loop. PHINode *createInductionResumeValue( - PHINode *OrigPhi, const InductionDescriptor &ID, + PHINode *OrigPhi, const InductionDescriptor &ID, Value *Step, ArrayRef<BasicBlock *> BypassBlocks, std::pair<BasicBlock *, Value *> AdditionalBypass = {nullptr, nullptr}); + /// Returns the original loop trip count. + Value *getTripCount() const { return TripCount; } + + /// Used to set the trip count after ILV's construction and after the + /// preheader block has been executed. Note that this always holds the trip + /// count of the original loop for both main loop and epilogue vectorization. + void setTripCount(Value *TC) { TripCount = TC; } + protected: friend class LoopVectorizationPlanner; @@ -560,7 +617,7 @@ protected: void fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, Value *VectorTripCount, Value *EndValue, BasicBlock *MiddleBlock, BasicBlock *VectorHeader, - VPlan &Plan); + VPlan &Plan, VPTransformState &State); /// Handle all cross-iteration phis in the header. void fixCrossIterationPHIs(VPTransformState &State); @@ -573,10 +630,6 @@ protected: /// Create code for the loop exit value of the reduction. void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State); - /// Clear NSW/NUW flags from reduction instructions if necessary. - void clearReductionWrapFlags(VPReductionPHIRecipe *PhiR, - VPTransformState &State); - /// Iteratively sink the scalarized operands of a predicated instruction into /// the block that was created for it. void sinkScalarOperands(Instruction *PredInst); @@ -585,9 +638,6 @@ protected: /// represented as. void truncateToMinimalBitwidths(VPTransformState &State); - /// Returns (and creates if needed) the original loop trip count. - Value *getOrCreateTripCount(BasicBlock *InsertBlock); - /// Returns (and creates if needed) the trip count of the widened loop. Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock); @@ -621,6 +671,7 @@ protected: /// block, the \p AdditionalBypass pair provides information about the bypass /// block and the end value on the edge from bypass to this loop. void createInductionResumeValues( + const SCEV2ValueTy &ExpandedSCEVs, std::pair<BasicBlock *, Value *> AdditionalBypass = {nullptr, nullptr}); /// Complete the loop skeleton by adding debug MDs, creating appropriate @@ -758,9 +809,6 @@ public: ElementCount::getFixed(1), ElementCount::getFixed(1), UnrollFactor, LVL, CM, BFI, PSI, Check) {} - -private: - Value *getBroadcastInstrs(Value *V) override; }; /// Encapsulate information regarding vectorization of a loop and its epilogue. @@ -810,15 +858,16 @@ public: // Override this function to handle the more complex control flow around the // three loops. - std::pair<BasicBlock *, Value *> createVectorizedLoopSkeleton() final { - return createEpilogueVectorizedLoopSkeleton(); + std::pair<BasicBlock *, Value *> createVectorizedLoopSkeleton( + const SCEV2ValueTy &ExpandedSCEVs) final { + return createEpilogueVectorizedLoopSkeleton(ExpandedSCEVs); } /// The interface for creating a vectorized skeleton using one of two /// different strategies, each corresponding to one execution of the vplan /// as described above. 
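  /// (The SCEV2ValueTy argument added to these hooks carries the SCEV-to-IR
  /// expansions produced while executing the main-loop plan, so the epilogue
  /// skeleton can reuse them instead of expanding the same expressions a
  /// second time.)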
virtual std::pair<BasicBlock *, Value *> - createEpilogueVectorizedLoopSkeleton() = 0; + createEpilogueVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs) = 0; /// Holds and updates state information required to vectorize the main loop /// and its epilogue in two separate passes. This setup helps us avoid @@ -846,7 +895,8 @@ public: EPI, LVL, CM, BFI, PSI, Check) {} /// Implements the interface for creating a vectorized skeleton using the /// *main loop* strategy (ie the first pass of vplan execution). - std::pair<BasicBlock *, Value *> createEpilogueVectorizedLoopSkeleton() final; + std::pair<BasicBlock *, Value *> + createEpilogueVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs) final; protected: /// Emits an iteration count bypass check once for the main loop (when \p @@ -876,7 +926,8 @@ public: } /// Implements the interface for creating a vectorized skeleton using the /// *epilogue loop* strategy (ie the second pass of vplan execution). - std::pair<BasicBlock *, Value *> createEpilogueVectorizedLoopSkeleton() final; + std::pair<BasicBlock *, Value *> + createEpilogueVectorizedLoopSkeleton(const SCEV2ValueTy &ExpandedSCEVs) final; protected: /// Emits an iteration count bypass check after the main vector loop has @@ -953,35 +1004,21 @@ namespace llvm { Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF, int64_t Step) { assert(Ty->isIntegerTy() && "Expected an integer step"); - Constant *StepVal = ConstantInt::get(Ty, Step * VF.getKnownMinValue()); - return VF.isScalable() ? B.CreateVScale(StepVal) : StepVal; + return B.CreateElementCount(Ty, VF.multiplyCoefficientBy(Step)); } /// Return the runtime value for VF. Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) { - Constant *EC = ConstantInt::get(Ty, VF.getKnownMinValue()); - return VF.isScalable() ? B.CreateVScale(EC) : EC; + return B.CreateElementCount(Ty, VF); } -const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE) { +const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE, + Loop *OrigLoop) { const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count"); ScalarEvolution &SE = *PSE.getSE(); - - // The exit count might have the type of i64 while the phi is i32. This can - // happen if we have an induction variable that is sign extended before the - // compare. The only way that we get a backedge taken count is that the - // induction variable was signed and as such will not overflow. In such a case - // truncation is legal. - if (SE.getTypeSizeInBits(BackedgeTakenCount->getType()) > - IdxTy->getPrimitiveSizeInBits()) - BackedgeTakenCount = SE.getTruncateOrNoop(BackedgeTakenCount, IdxTy); - BackedgeTakenCount = SE.getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); - - // Get the total trip count from the count by adding 1. - return SE.getAddExpr(BackedgeTakenCount, - SE.getOne(BackedgeTakenCount->getType())); + return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop); } static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy, @@ -1062,11 +1099,17 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes( continue; // This recipe contributes to the address computation of a widen - // load/store. Collect recipe if its underlying instruction has - // poison-generating flags. - Instruction *Instr = CurRec->getUnderlyingInstr(); - if (Instr && Instr->hasPoisonGeneratingFlags()) - State.MayGeneratePoisonRecipes.insert(CurRec); + // load/store. 
If the underlying instruction has poison-generating flags, + // drop them directly. + if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) { + RecWithFlags->dropPoisonGeneratingFlags(); + } else { + Instruction *Instr = CurRec->getUnderlyingInstr(); + (void)Instr; + assert((!Instr || !Instr->hasPoisonGeneratingFlags()) && + "found instruction with poison generating flags not covered by " + "VPRecipeWithIRFlags"); + } // Add new definitions to the worklist. for (VPValue *operand : CurRec->operands()) @@ -1143,15 +1186,7 @@ enum ScalarEpilogueLowering { CM_ScalarEpilogueNotAllowedUsePredicate }; -/// ElementCountComparator creates a total ordering for ElementCount -/// for the purposes of using it in a set structure. -struct ElementCountComparator { - bool operator()(const ElementCount &LHS, const ElementCount &RHS) const { - return std::make_tuple(LHS.isScalable(), LHS.getKnownMinValue()) < - std::make_tuple(RHS.isScalable(), RHS.getKnownMinValue()); - } -}; -using ElementCountSet = SmallSet<ElementCount, 16, ElementCountComparator>; +using InstructionVFPair = std::pair<Instruction *, ElementCount>; /// LoopVectorizationCostModel - estimates the expected speedups due to /// vectorization. @@ -1184,17 +1219,6 @@ public: /// otherwise. bool runtimeChecksRequired(); - /// \return The most profitable vectorization factor and the cost of that VF. - /// This method checks every VF in \p CandidateVFs. If UserVF is not ZERO - /// then this vectorization factor will be selected if vectorization is - /// possible. - VectorizationFactor - selectVectorizationFactor(const ElementCountSet &CandidateVFs); - - VectorizationFactor - selectEpilogueVectorizationFactor(const ElementCount MaxVF, - const LoopVectorizationPlanner &LVP); - /// Setup cost-based decisions for user vectorization factor. /// \return true if the UserVF is a feasible VF to be chosen. bool selectUserVectorizationFactor(ElementCount UserVF) { @@ -1278,11 +1302,17 @@ public: auto Scalars = InstsToScalarize.find(VF); assert(Scalars != InstsToScalarize.end() && "VF not yet analyzed for scalarization profitability"); - return Scalars->second.find(I) != Scalars->second.end(); + return Scalars->second.contains(I); } /// Returns true if \p I is known to be uniform after vectorization. bool isUniformAfterVectorization(Instruction *I, ElementCount VF) const { + // Pseudo probe needs to be duplicated for each unrolled iteration and + // vector lane so that profiled loop trip count can be accurately + // accumulated instead of being under counted. + if (isa<PseudoProbeInst>(I)) + return false; + if (VF.isScalar()) return true; @@ -1316,7 +1346,7 @@ public: /// \returns True if instruction \p I can be truncated to a smaller bitwidth /// for vectorization factor \p VF. bool canTruncateToMinimalBitwidth(Instruction *I, ElementCount VF) const { - return VF.isVector() && MinBWs.find(I) != MinBWs.end() && + return VF.isVector() && MinBWs.contains(I) && !isProfitableToScalarize(I, VF) && !isScalarAfterVectorization(I, VF); } @@ -1379,7 +1409,7 @@ public: InstructionCost getWideningCost(Instruction *I, ElementCount VF) { assert(VF.isVector() && "Expected VF >=2"); std::pair<Instruction *, ElementCount> InstOnVF = std::make_pair(I, VF); - assert(WideningDecisions.find(InstOnVF) != WideningDecisions.end() && + assert(WideningDecisions.contains(InstOnVF) && "The cost is not calculated"); return WideningDecisions[InstOnVF].second; } @@ -1419,7 +1449,7 @@ public: /// that may be vectorized as interleave, gather-scatter or scalarized. 
void collectUniformsAndScalars(ElementCount VF) { // Do the analysis once. - if (VF.isScalar() || Uniforms.find(VF) != Uniforms.end()) + if (VF.isScalar() || Uniforms.contains(VF)) return; setCostBasedWideningDecision(VF); collectLoopUniforms(VF); @@ -1442,8 +1472,7 @@ public: /// Returns true if the target machine can represent \p V as a masked gather /// or scatter operation. - bool isLegalGatherOrScatter(Value *V, - ElementCount VF = ElementCount::getFixed(1)) { + bool isLegalGatherOrScatter(Value *V, ElementCount VF) { bool LI = isa<LoadInst>(V); bool SI = isa<StoreInst>(V); if (!LI && !SI) @@ -1522,14 +1551,29 @@ public: /// Returns true if we're required to use a scalar epilogue for at least /// the final iteration of the original loop. - bool requiresScalarEpilogue(ElementCount VF) const { + bool requiresScalarEpilogue(bool IsVectorizing) const { if (!isScalarEpilogueAllowed()) return false; // If we might exit from anywhere but the latch, must run the exiting // iteration in scalar form. if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) return true; - return VF.isVector() && InterleaveInfo.requiresScalarEpilogue(); + return IsVectorizing && InterleaveInfo.requiresScalarEpilogue(); + } + + /// Returns true if we're required to use a scalar epilogue for at least + /// the final iteration of the original loop for all VFs in \p Range. + /// A scalar epilogue must either be required for all VFs in \p Range or for + /// none. + bool requiresScalarEpilogue(VFRange Range) const { + auto RequiresScalarEpilogue = [this](ElementCount VF) { + return requiresScalarEpilogue(VF.isVector()); + }; + bool IsRequired = all_of(Range, RequiresScalarEpilogue); + assert( + (IsRequired || none_of(Range, RequiresScalarEpilogue)) && + "all VFs in range must agree on whether a scalar epilogue is required"); + return IsRequired; } /// Returns true if a scalar epilogue is not allowed due to optsize or a @@ -1538,14 +1582,21 @@ public: return ScalarEpilogueStatus == CM_ScalarEpilogueAllowed; } - /// Returns true if all loop blocks should be masked to fold tail loop. - bool foldTailByMasking() const { return FoldTailByMasking; } + /// Returns the TailFoldingStyle that is best for the current loop. + TailFoldingStyle + getTailFoldingStyle(bool IVUpdateMayOverflow = true) const { + if (!CanFoldTailByMasking) + return TailFoldingStyle::None; + + if (ForceTailFoldingStyle.getNumOccurrences()) + return ForceTailFoldingStyle; + + return TTI.getPreferredTailFoldingStyle(IVUpdateMayOverflow); + } - /// Returns true if were tail-folding and want to use the active lane mask - /// for vector loop control flow. - bool useActiveLaneMaskForControlFlow() const { - return FoldTailByMasking && - TTI.emitGetActiveLaneMask() == PredicationStyle::DataAndControlFlow; + /// Returns true if all loop blocks should be masked to fold tail loop. + bool foldTailByMasking() const { + return getTailFoldingStyle() != TailFoldingStyle::None; } /// Returns true if the instructions in this block requires predication @@ -1582,12 +1633,8 @@ public: /// scalarized - /// i.e. either vector version isn't available, or is too expensive. InstructionCost getVectorCallCost(CallInst *CI, ElementCount VF, - bool &NeedToScalarize) const; - - /// Returns true if the per-lane cost of VectorizationFactor A is lower than - /// that of B. - bool isMoreProfitable(const VectorizationFactor &A, - const VectorizationFactor &B) const; + Function **Variant, + bool *NeedsMask = nullptr) const; /// Invalidates decisions already taken by the cost model. 
void invalidateCostModelingDecisions() { @@ -1596,10 +1643,29 @@ public: Scalars.clear(); } - /// Convenience function that returns the value of vscale_range iff - /// vscale_range.min == vscale_range.max or otherwise returns the value - /// returned by the corresponding TLI method. - std::optional<unsigned> getVScaleForTuning() const; + /// The vectorization cost is a combination of the cost itself and a boolean + /// indicating whether any of the contributing operations will actually + /// operate on vector values after type legalization in the backend. If this + /// latter value is false, then all operations will be scalarized (i.e. no + /// vectorization has actually taken place). + using VectorizationCostTy = std::pair<InstructionCost, bool>; + + /// Returns the expected execution cost. The unit of the cost does + /// not matter because we use the 'cost' units to compare different + /// vector widths. The cost that is returned is *not* normalized by + /// the factor width. If \p Invalid is not nullptr, this function + /// will add a pair(Instruction*, ElementCount) to \p Invalid for + /// each instruction that has an Invalid cost for the given VF. + VectorizationCostTy + expectedCost(ElementCount VF, + SmallVectorImpl<InstructionVFPair> *Invalid = nullptr); + + bool hasPredStores() const { return NumPredStores > 0; } + + /// Returns true if epilogue vectorization is considered profitable, and + /// false otherwise. + /// \p VF is the vectorization factor chosen for the original loop. + bool isEpilogueVectorizationProfitable(const ElementCount VF) const; private: unsigned NumPredStores = 0; @@ -1626,24 +1692,6 @@ private: /// of elements. ElementCount getMaxLegalScalableVF(unsigned MaxSafeElements); - /// The vectorization cost is a combination of the cost itself and a boolean - /// indicating whether any of the contributing operations will actually - /// operate on vector values after type legalization in the backend. If this - /// latter value is false, then all operations will be scalarized (i.e. no - /// vectorization has actually taken place). - using VectorizationCostTy = std::pair<InstructionCost, bool>; - - /// Returns the expected execution cost. The unit of the cost does - /// not matter because we use the 'cost' units to compare different - /// vector widths. The cost that is returned is *not* normalized by - /// the factor width. If \p Invalid is not nullptr, this function - /// will add a pair(Instruction*, ElementCount) to \p Invalid for - /// each instruction that has an Invalid cost for the given VF. - using InstructionVFPair = std::pair<Instruction *, ElementCount>; - VectorizationCostTy - expectedCost(ElementCount VF, - SmallVectorImpl<InstructionVFPair> *Invalid = nullptr); - /// Returns the execution time cost of an instruction for a given vector /// width. Vector width of one means scalar. VectorizationCostTy getInstructionCost(Instruction *I, ElementCount VF); @@ -1715,7 +1763,7 @@ private: ScalarEpilogueLowering ScalarEpilogueStatus = CM_ScalarEpilogueAllowed; /// All blocks of loop are to be masked to fold tail of scalar iterations. - bool FoldTailByMasking = false; + bool CanFoldTailByMasking = false; /// A map holding scalar costs for different vectorization factors. The /// presence of a cost for an instruction in the mapping indicates that the @@ -1796,8 +1844,7 @@ private: // the scalars are collected. That should be a safe assumption in most // cases, because we check if the operands have vectorizable types // beforehand in LoopVectorizationLegality. 
- return Scalars.find(VF) == Scalars.end() || - !isScalarAfterVectorization(I, VF); + return !Scalars.contains(VF) || !isScalarAfterVectorization(I, VF); }; /// Returns a range containing only operands needing to be extracted. @@ -1807,16 +1854,6 @@ private: Ops, [this, VF](Value *V) { return this->needsExtract(V, VF); })); } - /// Determines if we have the infrastructure to vectorize loop \p L and its - /// epilogue, assuming the main loop is vectorized by \p VF. - bool isCandidateForEpilogueVectorization(const Loop &L, - const ElementCount VF) const; - - /// Returns true if epilogue vectorization is considered profitable, and - /// false otherwise. - /// \p VF is the vectorization factor chosen for the original loop. - bool isEpilogueVectorizationProfitable(const ElementCount VF) const; - public: /// The loop that we evaluate. Loop *TheLoop; @@ -1862,9 +1899,6 @@ public: /// All element types found in the loop. SmallPtrSet<Type *, 16> ElementTypesInLoop; - - /// Profitable vector factors. - SmallVector<VectorizationFactor, 8> ProfitableVFs; }; } // end namespace llvm @@ -2135,6 +2169,17 @@ public: }; } // namespace +static bool useActiveLaneMask(TailFoldingStyle Style) { + return Style == TailFoldingStyle::Data || + Style == TailFoldingStyle::DataAndControlFlow || + Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck; +} + +static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) { + return Style == TailFoldingStyle::DataAndControlFlow || + Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck; +} + // Return true if \p OuterLp is an outer loop annotated with hints for explicit // vectorization. The loop needs to be annotated with #pragma omp simd // simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the @@ -2202,97 +2247,11 @@ static void collectSupportedLoops(Loop &L, LoopInfo *LI, collectSupportedLoops(*InnerL, LI, ORE, V); } -namespace { - -/// The LoopVectorize Pass. -struct LoopVectorize : public FunctionPass { - /// Pass identification, replacement for typeid - static char ID; - - LoopVectorizePass Impl; - - explicit LoopVectorize(bool InterleaveOnlyWhenForced = false, - bool VectorizeOnlyWhenForced = false) - : FunctionPass(ID), - Impl({InterleaveOnlyWhenForced, VectorizeOnlyWhenForced}) { - initializeLoopVectorizePass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - auto *TLI = TLIP ? 
&TLIP->getTLI(F) : nullptr; - auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); - auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); - auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - - return Impl - .runImpl(F, *SE, *LI, *TTI, *DT, *BFI, TLI, *DB, *AC, LAIs, *ORE, PSI) - .MadeAnyChange; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<LoopAccessLegacyAnalysis>(); - AU.addRequired<DemandedBitsWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addRequired<InjectTLIMappingsLegacy>(); - - // We currently do not preserve loopinfo/dominator analyses with outer loop - // vectorization. Until this is addressed, mark these analyses as preserved - // only for non-VPlan-native path. - // TODO: Preserve Loop and Dominator analyses for VPlan-native path. - if (!EnableVPlanNativePath) { - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - } - - AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addRequired<ProfileSummaryInfoWrapperPass>(); - } -}; - -} // end anonymous namespace - //===----------------------------------------------------------------------===// // Implementation of LoopVectorizationLegality, InnerLoopVectorizer and // LoopVectorizationCostModel and LoopVectorizationPlanner. //===----------------------------------------------------------------------===// -Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { - // We need to place the broadcast of invariant variables outside the loop, - // but only if it's proven safe to do so. Else, broadcast will be inside - // vector loop body. - Instruction *Instr = dyn_cast<Instruction>(V); - bool SafeToHoist = OrigLoop->isLoopInvariant(V) && - (!Instr || - DT->dominates(Instr->getParent(), LoopVectorPreHeader)); - // Place the code for broadcasting invariant variables in the new preheader. - IRBuilder<>::InsertPointGuard Guard(Builder); - if (SafeToHoist) - Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); - - // Broadcast the scalar into all locations in the vector. - Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast"); - - return Shuf; -} - /// This function adds /// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...) /// to each vector element of Val. The sequence starts at StartIndex. @@ -2435,21 +2394,6 @@ static void buildScalarSteps(Value *ScalarIV, Value *Step, } } -// Generate code for the induction step. 
Note that induction steps are -// required to be loop-invariant -static Value *CreateStepValue(const SCEV *Step, ScalarEvolution &SE, - Instruction *InsertBefore, - Loop *OrigLoop = nullptr) { - const DataLayout &DL = SE.getDataLayout(); - assert((!OrigLoop || SE.isLoopInvariant(Step, OrigLoop)) && - "Induction step should be loop invariant"); - if (auto *E = dyn_cast<SCEVUnknown>(Step)) - return E->getValue(); - - SCEVExpander Exp(SE, DL, "induction"); - return Exp.expandCodeFor(Step, Step->getType(), InsertBefore); -} - /// Compute the transformed value of Index at offset StartValue using step /// StepValue. /// For integer induction, returns StartValue + Index * StepValue. @@ -2514,9 +2458,7 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, return CreateAdd(StartValue, Offset); } case InductionDescriptor::IK_PtrInduction: { - assert(isa<Constant>(Step) && - "Expected constant step for pointer induction"); - return B.CreateGEP(ID.getElementType(), StartValue, CreateMul(Index, Step)); + return B.CreateGEP(B.getInt8Ty(), StartValue, CreateMul(Index, Step)); } case InductionDescriptor::IK_FpInduction: { assert(!isa<VectorType>(Index->getType()) && @@ -2538,6 +2480,50 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, llvm_unreachable("invalid enum"); } +std::optional<unsigned> getMaxVScale(const Function &F, + const TargetTransformInfo &TTI) { + if (std::optional<unsigned> MaxVScale = TTI.getMaxVScale()) + return MaxVScale; + + if (F.hasFnAttribute(Attribute::VScaleRange)) + return F.getFnAttribute(Attribute::VScaleRange).getVScaleRangeMax(); + + return std::nullopt; +} + +/// For the given VF and UF and maximum trip count computed for the loop, return +/// whether the induction variable might overflow in the vectorized loop. If not, +/// then we know a runtime overflow check always evaluates to false and can be +/// removed. +static bool isIndvarOverflowCheckKnownFalse( + const LoopVectorizationCostModel *Cost, + ElementCount VF, std::optional<unsigned> UF = std::nullopt) { + // Always be conservative if we don't know the exact unroll factor. + unsigned MaxUF = UF ? *UF : Cost->TTI.getMaxInterleaveFactor(VF); + + Type *IdxTy = Cost->Legal->getWidestInductionType(); + APInt MaxUIntTripCount = cast<IntegerType>(IdxTy)->getMask(); + + // We know the runtime overflow check is known false iff the (max) trip-count + // is known and (max) trip-count + (VF * UF) does not overflow in the type of + // the vector loop induction variable. + if (unsigned TC = + Cost->PSE.getSE()->getSmallConstantMaxTripCount(Cost->TheLoop)) { + uint64_t MaxVF = VF.getKnownMinValue(); + if (VF.isScalable()) { + std::optional<unsigned> MaxVScale = + getMaxVScale(*Cost->TheFunction, Cost->TTI); + if (!MaxVScale) + return false; + MaxVF *= *MaxVScale; + } + + return (MaxUIntTripCount - TC).ugt(MaxVF * MaxUF); + } + + return false; +} + void InnerLoopVectorizer::packScalarIntoVectorValue(VPValue *Def, const VPIteration &Instance, VPTransformState &State) { @@ -2591,14 +2577,13 @@ static bool useMaskedInterleavedAccesses(const TargetTransformInfo &TTI) { void InnerLoopVectorizer::vectorizeInterleaveGroup( const InterleaveGroup<Instruction> *Group, ArrayRef<VPValue *> VPDefs, VPTransformState &State, VPValue *Addr, ArrayRef<VPValue *> StoredValues, - VPValue *BlockInMask) { + VPValue *BlockInMask, bool NeedsMaskForGaps) { Instruction *Instr = Group->getInsertPos(); const DataLayout &DL = Instr->getModule()->getDataLayout(); // Prepare for the vector type of the interleaved load/store. 
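// Minimal standalone sketch of the wrap check in isIndvarOverflowCheckKnownFalse
// above, using a 16-bit induction type purely for illustration: the runtime
// overflow check can be dropped when (max) trip count + VF * UF cannot wrap.
#include <cstdint>
#include <cstdio>

static bool indvarOverflowCheckKnownFalse(uint16_t TC, uint64_t VF, uint64_t UF) {
  const uint64_t MaxUIntTripCount = UINT16_MAX; // mask of the IV type
  return (MaxUIntTripCount - TC) > VF * UF;     // mirrors the ugt() comparison
}

int main() {
  std::printf("%d\n", indvarOverflowCheckKnownFalse(1000, 4, 2));  // 1: cannot wrap
  std::printf("%d\n", indvarOverflowCheckKnownFalse(65532, 4, 2)); // 0: may wrap
  return 0;
}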
Type *ScalarTy = getLoadStoreType(Instr); unsigned InterleaveFactor = Group->getFactor(); - assert(!VF.isScalable() && "scalable vectors not yet supported."); auto *VecTy = VectorType::get(ScalarTy, VF * InterleaveFactor); // Prepare for the new pointers. @@ -2609,14 +2594,21 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( assert((!BlockInMask || !Group->isReverse()) && "Reversed masked interleave-group not supported."); + Value *Idx; // If the group is reverse, adjust the index to refer to the last vector lane // instead of the first. We adjust the index from the first vector lane, // rather than directly getting the pointer for lane VF - 1, because the // pointer operand of the interleaved access is supposed to be uniform. For // uniform instructions, we're only required to generate a value for the // first vector lane in each unroll iteration. - if (Group->isReverse()) - Index += (VF.getKnownMinValue() - 1) * Group->getFactor(); + if (Group->isReverse()) { + Value *RuntimeVF = getRuntimeVF(Builder, Builder.getInt32Ty(), VF); + Idx = Builder.CreateSub(RuntimeVF, Builder.getInt32(1)); + Idx = Builder.CreateMul(Idx, Builder.getInt32(Group->getFactor())); + Idx = Builder.CreateAdd(Idx, Builder.getInt32(Index)); + Idx = Builder.CreateNeg(Idx); + } else + Idx = Builder.getInt32(-Index); for (unsigned Part = 0; Part < UF; Part++) { Value *AddrPart = State.get(Addr, VPIteration(Part, 0)); @@ -2637,8 +2629,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( bool InBounds = false; if (auto *gep = dyn_cast<GetElementPtrInst>(AddrPart->stripPointerCasts())) InBounds = gep->isInBounds(); - AddrPart = Builder.CreateGEP(ScalarTy, AddrPart, Builder.getInt32(-Index)); - cast<GetElementPtrInst>(AddrPart)->setIsInBounds(InBounds); + AddrPart = Builder.CreateGEP(ScalarTy, AddrPart, Idx, "", InBounds); // Cast to the vector pointer type. unsigned AddressSpace = AddrPart->getType()->getPointerAddressSpace(); @@ -2649,14 +2640,43 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( State.setDebugLocFromInst(Instr); Value *PoisonVec = PoisonValue::get(VecTy); - Value *MaskForGaps = nullptr; - if (Group->requiresScalarEpilogue() && !Cost->isScalarEpilogueAllowed()) { - MaskForGaps = createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group); - assert(MaskForGaps && "Mask for Gaps is required but it is null"); - } + auto CreateGroupMask = [this, &BlockInMask, &State, &InterleaveFactor]( + unsigned Part, Value *MaskForGaps) -> Value * { + if (VF.isScalable()) { + assert(!MaskForGaps && "Interleaved groups with gaps are not supported."); + assert(InterleaveFactor == 2 && + "Unsupported deinterleave factor for scalable vectors"); + auto *BlockInMaskPart = State.get(BlockInMask, Part); + SmallVector<Value *, 2> Ops = {BlockInMaskPart, BlockInMaskPart}; + auto *MaskTy = + VectorType::get(Builder.getInt1Ty(), VF.getKnownMinValue() * 2, true); + return Builder.CreateIntrinsic( + MaskTy, Intrinsic::experimental_vector_interleave2, Ops, + /*FMFSource=*/nullptr, "interleaved.mask"); + } + + if (!BlockInMask) + return MaskForGaps; + + Value *BlockInMaskPart = State.get(BlockInMask, Part); + Value *ShuffledMask = Builder.CreateShuffleVector( + BlockInMaskPart, + createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()), + "interleaved.mask"); + return MaskForGaps ? Builder.CreateBinOp(Instruction::And, ShuffledMask, + MaskForGaps) + : ShuffledMask; + }; // Vectorize the interleaved load group. 
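// Illustrative sketch of the replicated block mask built by CreateGroupMask
// above for fixed-width VFs: each per-iteration mask bit is assumed to be
// repeated InterleaveFactor times so one bit guards every member of the group.
#include <cstdio>
#include <vector>

static std::vector<int> replicatedMask(unsigned Factor, unsigned VF) {
  std::vector<int> Mask;
  for (unsigned I = 0; I < VF; ++I)
    for (unsigned J = 0; J < Factor; ++J)
      Mask.push_back(I); // mask lane I covers all Factor members of group I
  return Mask;
}

int main() {
  for (int M : replicatedMask(2, 4)) // prints: 0 0 1 1 2 2 3 3
    std::printf("%d ", M);
  return 0;
}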
if (isa<LoadInst>(Instr)) { + Value *MaskForGaps = nullptr; + if (NeedsMaskForGaps) { + MaskForGaps = + createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group); + assert(MaskForGaps && "Mask for Gaps is required but it is null"); + } + // For each unroll part, create a wide load for the group. SmallVector<Value *, 2> NewLoads; for (unsigned Part = 0; Part < UF; Part++) { @@ -2664,18 +2684,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( if (BlockInMask || MaskForGaps) { assert(useMaskedInterleavedAccesses(*TTI) && "masked interleaved groups are not allowed."); - Value *GroupMask = MaskForGaps; - if (BlockInMask) { - Value *BlockInMaskPart = State.get(BlockInMask, Part); - Value *ShuffledMask = Builder.CreateShuffleVector( - BlockInMaskPart, - createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()), - "interleaved.mask"); - GroupMask = MaskForGaps - ? Builder.CreateBinOp(Instruction::And, ShuffledMask, - MaskForGaps) - : ShuffledMask; - } + Value *GroupMask = CreateGroupMask(Part, MaskForGaps); NewLoad = Builder.CreateMaskedLoad(VecTy, AddrParts[Part], Group->getAlign(), GroupMask, PoisonVec, "wide.masked.vec"); @@ -2687,6 +2696,41 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( NewLoads.push_back(NewLoad); } + if (VecTy->isScalableTy()) { + assert(InterleaveFactor == 2 && + "Unsupported deinterleave factor for scalable vectors"); + + for (unsigned Part = 0; Part < UF; ++Part) { + // Scalable vectors cannot use arbitrary shufflevectors (only splats), + // so must use intrinsics to deinterleave. + Value *DI = Builder.CreateIntrinsic( + Intrinsic::experimental_vector_deinterleave2, VecTy, NewLoads[Part], + /*FMFSource=*/nullptr, "strided.vec"); + unsigned J = 0; + for (unsigned I = 0; I < InterleaveFactor; ++I) { + Instruction *Member = Group->getMember(I); + + if (!Member) + continue; + + Value *StridedVec = Builder.CreateExtractValue(DI, I); + // If this member has different type, cast the result type. + if (Member->getType() != ScalarTy) { + VectorType *OtherVTy = VectorType::get(Member->getType(), VF); + StridedVec = createBitOrPointerCast(StridedVec, OtherVTy, DL); + } + + if (Group->isReverse()) + StridedVec = Builder.CreateVectorReverse(StridedVec, "reverse"); + + State.set(VPDefs[J], StridedVec, Part); + ++J; + } + } + + return; + } + // For each member in the group, shuffle out the appropriate data from the // wide loads. unsigned J = 0; @@ -2724,7 +2768,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( auto *SubVT = VectorType::get(ScalarTy, VF); // Vectorize the interleaved store group. - MaskForGaps = createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group); + Value *MaskForGaps = + createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group); assert((!MaskForGaps || useMaskedInterleavedAccesses(*TTI)) && "masked interleaved groups are not allowed."); assert((!MaskForGaps || !VF.isScalable()) && @@ -2759,27 +2804,11 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( StoredVecs.push_back(StoredVec); } - // Concatenate all vectors into a wide vector. - Value *WideVec = concatenateVectors(Builder, StoredVecs); - - // Interleave the elements in the wide vector. - Value *IVec = Builder.CreateShuffleVector( - WideVec, createInterleaveMask(VF.getKnownMinValue(), InterleaveFactor), - "interleaved.vec"); - + // Interleave all the smaller vectors into one wider vector. 
+ Value *IVec = interleaveVectors(Builder, StoredVecs, "interleaved.vec"); Instruction *NewStoreInstr; if (BlockInMask || MaskForGaps) { - Value *GroupMask = MaskForGaps; - if (BlockInMask) { - Value *BlockInMaskPart = State.get(BlockInMask, Part); - Value *ShuffledMask = Builder.CreateShuffleVector( - BlockInMaskPart, - createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()), - "interleaved.mask"); - GroupMask = MaskForGaps ? Builder.CreateBinOp(Instruction::And, - ShuffledMask, MaskForGaps) - : ShuffledMask; - } + Value *GroupMask = CreateGroupMask(Part, MaskForGaps); NewStoreInstr = Builder.CreateMaskedStore(IVec, AddrParts[Part], Group->getAlign(), GroupMask); } else @@ -2793,7 +2822,6 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr, VPReplicateRecipe *RepRecipe, const VPIteration &Instance, - bool IfPredicateInstr, VPTransformState &State) { assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); @@ -2810,14 +2838,7 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr, if (!IsVoidRetTy) Cloned->setName(Instr->getName() + ".cloned"); - // If the scalarized instruction contributes to the address computation of a - // widen masked load/store which was in a basic block that needed predication - // and is not predicated after vectorization, we can't propagate - // poison-generating flags (nuw/nsw, exact, inbounds, etc.). The scalarized - // instruction could feed a poison value to the base address of the widen - // load/store. - if (State.MayGeneratePoisonRecipes.contains(RepRecipe)) - Cloned->dropPoisonGeneratingFlags(); + RepRecipe->setFlags(Cloned); if (Instr->getDebugLoc()) State.setDebugLocFromInst(Instr); @@ -2843,45 +2864,17 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr, AC->registerAssumption(II); // End if-block. + bool IfPredicateInstr = RepRecipe->getParent()->getParent()->isReplicator(); if (IfPredicateInstr) PredicatedInstructions.push_back(Cloned); } -Value *InnerLoopVectorizer::getOrCreateTripCount(BasicBlock *InsertBlock) { - if (TripCount) - return TripCount; - - assert(InsertBlock); - IRBuilder<> Builder(InsertBlock->getTerminator()); - // Find the loop boundaries. - Type *IdxTy = Legal->getWidestInductionType(); - assert(IdxTy && "No type for induction"); - const SCEV *ExitCount = createTripCountSCEV(IdxTy, PSE); - - const DataLayout &DL = InsertBlock->getModule()->getDataLayout(); - - // Expand the trip count and place the new instructions in the preheader. - // Notice that the pre-header does not change, only the loop body. - SCEVExpander Exp(*PSE.getSE(), DL, "induction"); - - // Count holds the overall loop count (N). - TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(), - InsertBlock->getTerminator()); - - if (TripCount->getType()->isPointerTy()) - TripCount = - CastInst::CreatePointerCast(TripCount, IdxTy, "exitcount.ptrcnt.to.int", - InsertBlock->getTerminator()); - - return TripCount; -} - Value * InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) { if (VectorTripCount) return VectorTripCount; - Value *TC = getOrCreateTripCount(InsertBlock); + Value *TC = getTripCount(); IRBuilder<> Builder(InsertBlock->getTerminator()); Type *Ty = TC->getType(); @@ -2917,7 +2910,7 @@ InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) { // the step does not evenly divide the trip count, no adjustment is necessary // since there will already be scalar iterations. 
Note that the minimum // iterations check ensures that N >= Step. - if (Cost->requiresScalarEpilogue(VF)) { + if (Cost->requiresScalarEpilogue(VF.isVector())) { auto *IsZero = Builder.CreateICmpEQ(R, ConstantInt::get(R->getType(), 0)); R = Builder.CreateSelect(IsZero, Step, R); } @@ -2930,10 +2923,10 @@ InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) { Value *InnerLoopVectorizer::createBitOrPointerCast(Value *V, VectorType *DstVTy, const DataLayout &DL) { // Verify that V is a vector type with same number of elements as DstVTy. - auto *DstFVTy = cast<FixedVectorType>(DstVTy); - unsigned VF = DstFVTy->getNumElements(); - auto *SrcVecTy = cast<FixedVectorType>(V->getType()); - assert((VF == SrcVecTy->getNumElements()) && "Vector dimensions do not match"); + auto *DstFVTy = cast<VectorType>(DstVTy); + auto VF = DstFVTy->getElementCount(); + auto *SrcVecTy = cast<VectorType>(V->getType()); + assert(VF == SrcVecTy->getElementCount() && "Vector dimensions do not match"); Type *SrcElemTy = SrcVecTy->getElementType(); Type *DstElemTy = DstFVTy->getElementType(); assert((DL.getTypeSizeInBits(SrcElemTy) == DL.getTypeSizeInBits(DstElemTy)) && @@ -2953,13 +2946,13 @@ Value *InnerLoopVectorizer::createBitOrPointerCast(Value *V, VectorType *DstVTy, "Only one type should be a floating point type"); Type *IntTy = IntegerType::getIntNTy(V->getContext(), DL.getTypeSizeInBits(SrcElemTy)); - auto *VecIntTy = FixedVectorType::get(IntTy, VF); + auto *VecIntTy = VectorType::get(IntTy, VF); Value *CastVal = Builder.CreateBitOrPointerCast(V, VecIntTy); return Builder.CreateBitOrPointerCast(CastVal, DstFVTy); } void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { - Value *Count = getOrCreateTripCount(LoopVectorPreHeader); + Value *Count = getTripCount(); // Reuse existing vector loop preheader for TC checks. // Note that new preheader block is generated for vector loop. BasicBlock *const TCCheckBlock = LoopVectorPreHeader; @@ -2970,8 +2963,8 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { // vector trip count is zero. This check also covers the case where adding one // to the backedge-taken count overflowed leading to an incorrect trip count // of zero. In this case we will also jump to the scalar loop. - auto P = Cost->requiresScalarEpilogue(VF) ? ICmpInst::ICMP_ULE - : ICmpInst::ICMP_ULT; + auto P = Cost->requiresScalarEpilogue(VF.isVector()) ? ICmpInst::ICMP_ULE + : ICmpInst::ICMP_ULT; // If tail is to be folded, vector loop takes care of all iterations. Type *CountTy = Count->getType(); @@ -2989,10 +2982,13 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { Intrinsic::umax, MinProfTC, createStepForVF(Builder, CountTy, VF, UF)); }; - if (!Cost->foldTailByMasking()) + TailFoldingStyle Style = Cost->getTailFoldingStyle(); + if (Style == TailFoldingStyle::None) CheckMinIters = Builder.CreateICmp(P, Count, CreateStep(), "min.iters.check"); - else if (VF.isScalable()) { + else if (VF.isScalable() && + !isIndvarOverflowCheckKnownFalse(Cost, VF, UF) && + Style != TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) { // vscale is not necessarily a power-of-2, which means we cannot guarantee // an overflow to zero when updating induction variables and so an // additional overflow check is required before entering the vector loop. @@ -3017,7 +3013,7 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { // Update dominator for Bypass & LoopExit (if needed). 
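// Self-contained sketch of the vector trip count computation above: round the
// trip count down to a multiple of Step = VF * UF, and when a scalar epilogue
// is required, keep a full Step's worth of iterations for it even if Step
// divides the trip count evenly.
#include <cstdint>
#include <cstdio>

static uint64_t vectorTripCount(uint64_t TC, uint64_t Step,
                                bool RequiresScalarEpilogue) {
  uint64_t R = TC % Step;
  if (RequiresScalarEpilogue && R == 0)
    R = Step; // force the last Step iterations into the scalar remainder loop
  return TC - R;
}

int main() {
  std::printf("%llu\n", (unsigned long long)vectorTripCount(100, 8, false)); // 96
  std::printf("%llu\n", (unsigned long long)vectorTripCount(96, 8, true));   // 88
  return 0;
}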
DT->changeImmediateDominator(Bypass, TCCheckBlock); - if (!Cost->requiresScalarEpilogue(VF)) + if (!Cost->requiresScalarEpilogue(VF.isVector())) // If there is an epilogue which must run, there's no edge from the // middle block to exit blocks and thus no need to update the immediate // dominator of the exit blocks. @@ -3044,7 +3040,7 @@ BasicBlock *InnerLoopVectorizer::emitSCEVChecks(BasicBlock *Bypass) { // Update dominator only if this is first RT check. if (LoopBypassBlocks.empty()) { DT->changeImmediateDominator(Bypass, SCEVCheckBlock); - if (!Cost->requiresScalarEpilogue(VF)) + if (!Cost->requiresScalarEpilogue(VF.isVector())) // If there is an epilogue which must run, there's no edge from the // middle block to exit blocks and thus no need to update the immediate // dominator of the exit blocks. @@ -3097,7 +3093,7 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { LoopVectorPreHeader = OrigLoop->getLoopPreheader(); assert(LoopVectorPreHeader && "Invalid loop structure"); LoopExitBlock = OrigLoop->getUniqueExitBlock(); // may be nullptr - assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF)) && + assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) && "multiple exit loop without required epilogue?"); LoopMiddleBlock = @@ -3117,17 +3113,18 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { // branch from the middle block to the loop scalar preheader, and the // exit block. completeLoopSkeleton will update the condition to use an // iteration check, if required to decide whether to execute the remainder. - BranchInst *BrInst = Cost->requiresScalarEpilogue(VF) ? - BranchInst::Create(LoopScalarPreHeader) : - BranchInst::Create(LoopExitBlock, LoopScalarPreHeader, - Builder.getTrue()); + BranchInst *BrInst = + Cost->requiresScalarEpilogue(VF.isVector()) + ? BranchInst::Create(LoopScalarPreHeader) + : BranchInst::Create(LoopExitBlock, LoopScalarPreHeader, + Builder.getTrue()); BrInst->setDebugLoc(ScalarLatchTerm->getDebugLoc()); ReplaceInstWithInst(LoopMiddleBlock->getTerminator(), BrInst); // Update dominator for loop exit. During skeleton creation, only the vector // pre-header and the middle block are created. The vector loop is entirely // created during VPlan exection. - if (!Cost->requiresScalarEpilogue(VF)) + if (!Cost->requiresScalarEpilogue(VF.isVector())) // If there is an epilogue which must run, there's no edge from the // middle block to exit blocks and thus no need to update the immediate // dominator of the exit blocks. @@ -3135,7 +3132,7 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { } PHINode *InnerLoopVectorizer::createInductionResumeValue( - PHINode *OrigPhi, const InductionDescriptor &II, + PHINode *OrigPhi, const InductionDescriptor &II, Value *Step, ArrayRef<BasicBlock *> BypassBlocks, std::pair<BasicBlock *, Value *> AdditionalBypass) { Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); @@ -3154,8 +3151,6 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue( if (II.getInductionBinOp() && isa<FPMathOperator>(II.getInductionBinOp())) B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags()); - Value *Step = - CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint()); EndValue = emitTransformedIndex(B, VectorTripCount, II.getStartValue(), Step, II); EndValue->setName("ind.end"); @@ -3163,8 +3158,6 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue( // Compute the end value for the additional bypass (if applicable). 
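// Minimal sketch of the induction resume value computed above for an integer
// induction (per the emitTransformedIndex contract Start + Index * Step): the
// scalar remainder loop resumes the IV where the vector loop left off. The
// concrete values below are made up for illustration.
#include <cstdint>
#include <cstdio>

static int64_t inductionResumeValue(int64_t Start, int64_t Step,
                                    uint64_t VectorTripCount) {
  return Start + (int64_t)VectorTripCount * Step; // the "ind.end" value
}

int main() {
  // IV starting at 5 with step 3; the vector loop retires 96 iterations.
  std::printf("%lld\n", (long long)inductionResumeValue(5, 3, 96)); // 293
  return 0;
}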
if (AdditionalBypass.first) { B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt())); - Value *Step = - CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint()); EndValueFromAdditionalBypass = emitTransformedIndex( B, AdditionalBypass.second, II.getStartValue(), Step, II); EndValueFromAdditionalBypass->setName("ind.end"); @@ -3193,7 +3186,22 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue( return BCResumeVal; } +/// Return the expanded step for \p ID using \p ExpandedSCEVs to look up SCEV +/// expansion results. +static Value *getExpandedStep(const InductionDescriptor &ID, + const SCEV2ValueTy &ExpandedSCEVs) { + const SCEV *Step = ID.getStep(); + if (auto *C = dyn_cast<SCEVConstant>(Step)) + return C->getValue(); + if (auto *U = dyn_cast<SCEVUnknown>(Step)) + return U->getValue(); + auto I = ExpandedSCEVs.find(Step); + assert(I != ExpandedSCEVs.end() && "SCEV must be expanded at this point"); + return I->second; +} + void InnerLoopVectorizer::createInductionResumeValues( + const SCEV2ValueTy &ExpandedSCEVs, std::pair<BasicBlock *, Value *> AdditionalBypass) { assert(((AdditionalBypass.first && AdditionalBypass.second) || (!AdditionalBypass.first && !AdditionalBypass.second)) && @@ -3209,14 +3217,15 @@ void InnerLoopVectorizer::createInductionResumeValues( PHINode *OrigPhi = InductionEntry.first; const InductionDescriptor &II = InductionEntry.second; PHINode *BCResumeVal = createInductionResumeValue( - OrigPhi, II, LoopBypassBlocks, AdditionalBypass); + OrigPhi, II, getExpandedStep(II, ExpandedSCEVs), LoopBypassBlocks, + AdditionalBypass); OrigPhi->setIncomingValueForBlock(LoopScalarPreHeader, BCResumeVal); } } BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() { // The trip counts should be cached by now. - Value *Count = getOrCreateTripCount(LoopVectorPreHeader); + Value *Count = getTripCount(); Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); auto *ScalarLatchTerm = OrigLoop->getLoopLatch()->getTerminator(); @@ -3229,7 +3238,8 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() { // Thus if tail is to be folded, we know we don't need to run the // remainder and we can use the previous value for the condition (true). // 3) Otherwise, construct a runtime check. - if (!Cost->requiresScalarEpilogue(VF) && !Cost->foldTailByMasking()) { + if (!Cost->requiresScalarEpilogue(VF.isVector()) && + !Cost->foldTailByMasking()) { Instruction *CmpN = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, Count, VectorTripCount, "cmp.n", LoopMiddleBlock->getTerminator()); @@ -3250,14 +3260,16 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() { } std::pair<BasicBlock *, Value *> -InnerLoopVectorizer::createVectorizedLoopSkeleton() { +InnerLoopVectorizer::createVectorizedLoopSkeleton( + const SCEV2ValueTy &ExpandedSCEVs) { /* In this function we generate a new loop. The new loop will contain the vectorized instructions while the old loop will continue to run the scalar remainder. - [ ] <-- loop iteration number check. - / | + [ ] <-- old preheader - loop iteration number check and SCEVs in Plan's + / | preheader are expanded here. Eventually all required SCEV + / | expansion should happen here. / v | [ ] <-- vector loop bypass (may consist of multiple blocks). | / | @@ -3304,7 +3316,7 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton() { emitMemRuntimeChecks(LoopScalarPreHeader); // Emit phis for the new starting index of the scalar loop. 
- createInductionResumeValues(); + createInductionResumeValues(ExpandedSCEVs); return {completeLoopSkeleton(), nullptr}; } @@ -3317,7 +3329,8 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, Value *VectorTripCount, Value *EndValue, BasicBlock *MiddleBlock, - BasicBlock *VectorHeader, VPlan &Plan) { + BasicBlock *VectorHeader, VPlan &Plan, + VPTransformState &State) { // There are two kinds of external IV usages - those that use the value // computed in the last iteration (the PHI) and those that use the penultimate // value (the value that feeds into the phi from the loop latch). @@ -3345,7 +3358,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, auto *UI = cast<Instruction>(U); if (!OrigLoop->contains(UI)) { assert(isa<PHINode>(UI) && "Expected LCSSA form"); - IRBuilder<> B(MiddleBlock->getTerminator()); // Fast-math-flags propagate from the original induction instruction. @@ -3355,8 +3367,11 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, Value *CountMinusOne = B.CreateSub( VectorTripCount, ConstantInt::get(VectorTripCount->getType(), 1)); CountMinusOne->setName("cmo"); - Value *Step = CreateStepValue(II.getStep(), *PSE.getSE(), - VectorHeader->getTerminator()); + + VPValue *StepVPV = Plan.getSCEVExpansion(II.getStep()); + assert(StepVPV && "step must have been expanded during VPlan execution"); + Value *Step = StepVPV->isLiveIn() ? StepVPV->getLiveInIRValue() + : State.get(StepVPV, {0, 0}); Value *Escape = emitTransformedIndex(B, CountMinusOne, II.getStartValue(), Step, II); Escape->setName("ind.escape"); @@ -3430,12 +3445,12 @@ static void cse(BasicBlock *BB) { } } -InstructionCost -LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF, - bool &NeedToScalarize) const { +InstructionCost LoopVectorizationCostModel::getVectorCallCost( + CallInst *CI, ElementCount VF, Function **Variant, bool *NeedsMask) const { Function *F = CI->getCalledFunction(); Type *ScalarRetTy = CI->getType(); SmallVector<Type *, 4> Tys, ScalarTys; + bool MaskRequired = Legal->isMaskRequired(CI); for (auto &ArgOp : CI->args()) ScalarTys.push_back(ArgOp->getType()); @@ -3464,18 +3479,39 @@ LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF, // If we can't emit a vector call for this function, then the currently found // cost is the cost we need to return. - NeedToScalarize = true; - VFShape Shape = VFShape::get(*CI, VF, false /*HasGlobalPred*/); + InstructionCost MaskCost = 0; + VFShape Shape = VFShape::get(*CI, VF, MaskRequired); + if (NeedsMask) + *NeedsMask = MaskRequired; Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); + // If we want an unmasked vector function but can't find one matching the VF, + // maybe we can find vector function that does use a mask and synthesize + // an all-true mask. + if (!VecFunc && !MaskRequired) { + Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true); + VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); + // If we found one, add in the cost of creating a mask + if (VecFunc) { + if (NeedsMask) + *NeedsMask = true; + MaskCost = TTI.getShuffleCost( + TargetTransformInfo::SK_Broadcast, + VectorType::get( + IntegerType::getInt1Ty(VecFunc->getFunctionType()->getContext()), + VF)); + } + } + // We don't support masked function calls yet, but we can scalarize a + // masked call with branches (unless VF is scalable). if (!TLI || CI->isNoBuiltin() || !VecFunc) - return Cost; + return VF.isScalable() ? 
InstructionCost::getInvalid() : Cost; // If the corresponding vector cost is cheaper, return its cost. InstructionCost VectorCallCost = - TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind); + TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost; if (VectorCallCost < Cost) { - NeedToScalarize = false; + *Variant = VecFunc; Cost = VectorCallCost; } return Cost; @@ -3675,14 +3711,25 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, // Forget the original basic block. PSE.getSE()->forgetLoop(OrigLoop); + // After vectorization, the exit blocks of the original loop will have + // additional predecessors. Invalidate SCEVs for the exit phis in case SE + // looked through single-entry phis. + SmallVector<BasicBlock *> ExitBlocks; + OrigLoop->getExitBlocks(ExitBlocks); + for (BasicBlock *Exit : ExitBlocks) + for (PHINode &PN : Exit->phis()) + PSE.getSE()->forgetValue(&PN); + VPBasicBlock *LatchVPBB = Plan.getVectorLoopRegion()->getExitingBasicBlock(); Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]); - if (Cost->requiresScalarEpilogue(VF)) { + if (Cost->requiresScalarEpilogue(VF.isVector())) { // No edge from the middle block to the unique exit block has been inserted // and there is nothing to fix from vector loop; phis should have incoming // from scalar loop only. - Plan.clearLiveOuts(); } else { + // TODO: Check VPLiveOuts to see if IV users need fixing instead of checking + // the cost model. + // If we inserted an edge from the middle block to the unique exit block, // update uses outside the loop (phis) to account for the newly inserted // edge. @@ -3692,7 +3739,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, fixupIVUsers(Entry.first, Entry.second, getOrCreateVectorTripCount(VectorLoop->getLoopPreheader()), IVEndValues[Entry.first], LoopMiddleBlock, - VectorLoop->getHeader(), Plan); + VectorLoop->getHeader(), Plan, State); } // Fix LCSSA phis not already fixed earlier. Extracts may need to be generated @@ -3799,31 +3846,53 @@ void InnerLoopVectorizer::fixFixedOrderRecurrence( Value *Incoming = State.get(PreviousDef, UF - 1); auto *ExtractForScalar = Incoming; auto *IdxTy = Builder.getInt32Ty(); + Value *RuntimeVF = nullptr; if (VF.isVector()) { auto *One = ConstantInt::get(IdxTy, 1); Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); - auto *RuntimeVF = getRuntimeVF(Builder, IdxTy, VF); + RuntimeVF = getRuntimeVF(Builder, IdxTy, VF); auto *LastIdx = Builder.CreateSub(RuntimeVF, One); - ExtractForScalar = Builder.CreateExtractElement(ExtractForScalar, LastIdx, - "vector.recur.extract"); - } - // Extract the second last element in the middle block if the - // Phi is used outside the loop. We need to extract the phi itself - // and not the last element (the phi update in the current iteration). This - // will be the value when jumping to the exit block from the LoopMiddleBlock, - // when the scalar loop is not run at all. - Value *ExtractForPhiUsedOutsideLoop = nullptr; - if (VF.isVector()) { - auto *RuntimeVF = getRuntimeVF(Builder, IdxTy, VF); - auto *Idx = Builder.CreateSub(RuntimeVF, ConstantInt::get(IdxTy, 2)); - ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement( - Incoming, Idx, "vector.recur.extract.for.phi"); - } else if (UF > 1) - // When loop is unrolled without vectorizing, initialize - // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled value - // of `Incoming`. This is analogous to the vectorized case above: extracting - // the second last element when VF > 1. 
- ExtractForPhiUsedOutsideLoop = State.get(PreviousDef, UF - 2); + ExtractForScalar = + Builder.CreateExtractElement(Incoming, LastIdx, "vector.recur.extract"); + } + + auto RecurSplice = cast<VPInstruction>(*PhiR->user_begin()); + assert(PhiR->getNumUsers() == 1 && + RecurSplice->getOpcode() == + VPInstruction::FirstOrderRecurrenceSplice && + "recurrence phi must have a single user: FirstOrderRecurrenceSplice"); + SmallVector<VPLiveOut *> LiveOuts; + for (VPUser *U : RecurSplice->users()) + if (auto *LiveOut = dyn_cast<VPLiveOut>(U)) + LiveOuts.push_back(LiveOut); + + if (!LiveOuts.empty()) { + // Extract the second last element in the middle block if the + // Phi is used outside the loop. We need to extract the phi itself + // and not the last element (the phi update in the current iteration). This + // will be the value when jumping to the exit block from the + // LoopMiddleBlock, when the scalar loop is not run at all. + Value *ExtractForPhiUsedOutsideLoop = nullptr; + if (VF.isVector()) { + auto *Idx = Builder.CreateSub(RuntimeVF, ConstantInt::get(IdxTy, 2)); + ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement( + Incoming, Idx, "vector.recur.extract.for.phi"); + } else { + assert(UF > 1 && "VF and UF cannot both be 1"); + // When loop is unrolled without vectorizing, initialize + // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled + // value of `Incoming`. This is analogous to the vectorized case above: + // extracting the second last element when VF > 1. + ExtractForPhiUsedOutsideLoop = State.get(PreviousDef, UF - 2); + } + + for (VPLiveOut *LiveOut : LiveOuts) { + assert(!Cost->requiresScalarEpilogue(VF.isVector())); + PHINode *LCSSAPhi = LiveOut->getPhi(); + LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); + State.Plan->removeLiveOut(LCSSAPhi); + } + } // Fix the initial value of the original recurrence in the scalar loop. Builder.SetInsertPoint(&*LoopScalarPreHeader->begin()); @@ -3837,22 +3906,6 @@ void InnerLoopVectorizer::fixFixedOrderRecurrence( Phi->setIncomingValueForBlock(LoopScalarPreHeader, Start); Phi->setName("scalar.recur"); - - // Finally, fix users of the recurrence outside the loop. The users will need - // either the last value of the scalar recurrence or the last value of the - // vector recurrence we extracted in the middle block. Since the loop is in - // LCSSA form, we just need to find all the phi nodes for the original scalar - // recurrence in the exit block, and then add an edge for the middle block. - // Note that LCSSA does not imply single entry when the original scalar loop - // had multiple exiting edges (as we always run the last iteration in the - // scalar epilogue); in that case, there is no edge from middle to exit and - // and thus no phis which needed updated. - if (!Cost->requiresScalarEpilogue(VF)) - for (PHINode &LCSSAPhi : LoopExitBlock->phis()) - if (llvm::is_contained(LCSSAPhi.incoming_values(), Phi)) { - LCSSAPhi.addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); - State.Plan->removeLiveOut(&LCSSAPhi); - } } void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, @@ -3872,9 +3925,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // This is the vector-clone of the value that leaves the loop. Type *VecTy = State.get(LoopExitInstDef, 0)->getType(); - // Wrap flags are in general invalid after vectorization, clear them. 
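// Illustrative sketch of the two extracts performed for a fixed-order
// recurrence above (example values are made up): lane VF-1 of the final
// recurrence vector seeds the scalar remainder loop, while lane VF-2 is what a
// phi user outside the loop observes when the scalar loop does not run.
#include <cstdio>

int main() {
  const int VF = 4;
  int Recur[VF] = {10, 11, 12, 13};         // last widened recurrence vector
  int ResumeForScalarLoop = Recur[VF - 1];  // "vector.recur.extract"
  int ValueForOutsidePhi  = Recur[VF - 2];  // "vector.recur.extract.for.phi"
  std::printf("%d %d\n", ResumeForScalarLoop, ValueForOutsidePhi); // 13 12
  return 0;
}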
- clearReductionWrapFlags(PhiR, State); - // Before each round, move the insertion point right between // the PHIs and the values we are going to write. // This allows us to write both PHINodes and the extractelement @@ -4036,7 +4086,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // We know that the loop is in LCSSA form. We need to update the PHI nodes // in the exit blocks. See comment on analogous loop in // fixFixedOrderRecurrence for a more complete explaination of the logic. - if (!Cost->requiresScalarEpilogue(VF)) + if (!Cost->requiresScalarEpilogue(VF.isVector())) for (PHINode &LCSSAPhi : LoopExitBlock->phis()) if (llvm::is_contained(LCSSAPhi.incoming_values(), LoopExitInst)) { LCSSAPhi.addIncoming(ReducedPartRdx, LoopMiddleBlock); @@ -4054,38 +4104,6 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, OrigPhi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); } -void InnerLoopVectorizer::clearReductionWrapFlags(VPReductionPHIRecipe *PhiR, - VPTransformState &State) { - const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor(); - RecurKind RK = RdxDesc.getRecurrenceKind(); - if (RK != RecurKind::Add && RK != RecurKind::Mul) - return; - - SmallVector<VPValue *, 8> Worklist; - SmallPtrSet<VPValue *, 8> Visited; - Worklist.push_back(PhiR); - Visited.insert(PhiR); - - while (!Worklist.empty()) { - VPValue *Cur = Worklist.pop_back_val(); - for (unsigned Part = 0; Part < UF; ++Part) { - Value *V = State.get(Cur, Part); - if (!isa<OverflowingBinaryOperator>(V)) - break; - cast<Instruction>(V)->dropPoisonGeneratingFlags(); - } - - for (VPUser *U : Cur->users()) { - auto *UserRecipe = dyn_cast<VPRecipeBase>(U); - if (!UserRecipe) - continue; - for (VPValue *V : UserRecipe->definedValues()) - if (Visited.insert(V).second) - Worklist.push_back(V); - } - } -} - void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { // The basic block and loop containing the predicated instruction. auto *PredBB = PredInst->getParent(); @@ -4125,10 +4143,11 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { auto *I = dyn_cast<Instruction>(Worklist.pop_back_val()); // We can't sink an instruction if it is a phi node, is not in the loop, - // or may have side effects. + // may have side effects or may read from memory. + // TODO Could dor more granular checking to allow sinking a load past non-store instructions. if (!I || isa<PHINode>(I) || !VectorLoop->contains(I) || - I->mayHaveSideEffects()) - continue; + I->mayHaveSideEffects() || I->mayReadFromMemory()) + continue; // If the instruction is already in PredBB, check if we can sink its // operands. In that case, VPlan's sinkScalarOperands() succeeded in @@ -4189,7 +4208,7 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { // We should not collect Scalars more than once per VF. Right now, this // function is called from collectUniformsAndScalars(), which already does // this check. Collecting Scalars for VF=1 does not make any sense. 
- assert(VF.isVector() && Scalars.find(VF) == Scalars.end() && + assert(VF.isVector() && !Scalars.contains(VF) && "This function should not be visited twice for the same VF"); // This avoids any chances of creating a REPLICATE recipe during planning @@ -4382,6 +4401,8 @@ bool LoopVectorizationCostModel::isScalarWithPredication( switch(I->getOpcode()) { default: return true; + case Instruction::Call: + return !VFDatabase::hasMaskedVariant(*(cast<CallInst>(I)), VF); case Instruction::Load: case Instruction::Store: { auto *Ptr = getLoadStorePointerOperand(I); @@ -4430,10 +4451,10 @@ bool LoopVectorizationCostModel::isPredicatedInst(Instruction *I) const { // both speculation safety (which follows from the same argument as loads), // but also must prove the value being stored is correct. The easiest // form of the later is to require that all values stored are the same. - if (Legal->isUniformMemOp(*I) && - (isa<LoadInst>(I) || - (isa<StoreInst>(I) && - TheLoop->isLoopInvariant(cast<StoreInst>(I)->getValueOperand()))) && + if (Legal->isInvariant(getLoadStorePointerOperand(I)) && + (isa<LoadInst>(I) || + (isa<StoreInst>(I) && + TheLoop->isLoopInvariant(cast<StoreInst>(I)->getValueOperand()))) && !Legal->blockNeedsPredication(I->getParent())) return false; return true; @@ -4445,6 +4466,8 @@ bool LoopVectorizationCostModel::isPredicatedInst(Instruction *I) const { // TODO: We can use the loop-preheader as context point here and get // context sensitive reasoning return !isSafeToSpeculativelyExecute(I); + case Instruction::Call: + return Legal->isMaskRequired(I); } } @@ -4502,7 +4525,8 @@ LoopVectorizationCostModel::getDivRemSpeculationCost(Instruction *I, // second vector operand. One example of this are shifts on x86. Value *Op2 = I->getOperand(1); auto Op2Info = TTI.getOperandInfo(Op2); - if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2)) + if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && + Legal->isInvariant(Op2)) Op2Info.Kind = TargetTransformInfo::OK_UniformValue; SmallVector<const Value *, 4> Operands(I->operand_values()); @@ -4614,7 +4638,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { // already does this check. Collecting Uniforms for VF=1 does not make any // sense. - assert(VF.isVector() && Uniforms.find(VF) == Uniforms.end() && + assert(VF.isVector() && !Uniforms.contains(VF) && "This function should not be visited twice for the same VF"); // Visit the list of Uniforms. If we'll not find any uniform value, we'll @@ -4663,10 +4687,18 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse()) addToWorklistIfAllowed(Cmp); + auto PrevVF = VF.divideCoefficientBy(2); // Return true if all lanes perform the same memory operation, and we can // thus chose to execute only one. auto isUniformMemOpUse = [&](Instruction *I) { - if (!Legal->isUniformMemOp(*I)) + // If the value was already known to not be uniform for the previous + // (smaller VF), it cannot be uniform for the larger VF. 
+ if (PrevVF.isVector()) { + auto Iter = Uniforms.find(PrevVF); + if (Iter != Uniforms.end() && !Iter->second.contains(I)) + return false; + } + if (!Legal->isUniformMemOp(*I, VF)) return false; if (isa<LoadInst>(I)) // Loading the same address always produces the same result - at least @@ -4689,11 +4721,14 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { WideningDecision == CM_Interleave); }; - // Returns true if Ptr is the pointer operand of a memory access instruction - // I, and I is known to not require scalarization. + // I, I is known to not require scalarization, and the pointer is not also + // stored. auto isVectorizedMemAccessUse = [&](Instruction *I, Value *Ptr) -> bool { - return getLoadStorePointerOperand(I) == Ptr && isUniformDecision(I, VF); + if (isa<StoreInst>(I) && I->getOperand(0) == Ptr) + return false; + return getLoadStorePointerOperand(I) == Ptr && + (isUniformDecision(I, VF) || Legal->isInvariant(Ptr)); }; // Holds a list of values which are known to have at least one uniform use. @@ -4739,10 +4774,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { if (isUniformMemOpUse(&I)) addToWorklistIfAllowed(&I); - if (isUniformDecision(&I, VF)) { - assert(isVectorizedMemAccessUse(&I, Ptr) && "consistency check"); + if (isVectorizedMemAccessUse(&I, Ptr)) HasUniformUse.insert(Ptr); - } } // Add to the worklist any operands which have *only* uniform (e.g. lane 0 @@ -4906,12 +4939,11 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { return MaxScalableVF; // Limit MaxScalableVF by the maximum safe dependence distance. - std::optional<unsigned> MaxVScale = TTI.getMaxVScale(); - if (!MaxVScale && TheFunction->hasFnAttribute(Attribute::VScaleRange)) - MaxVScale = - TheFunction->getFnAttribute(Attribute::VScaleRange).getVScaleRangeMax(); - MaxScalableVF = - ElementCount::getScalable(MaxVScale ? (MaxSafeElements / *MaxVScale) : 0); + if (std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI)) + MaxScalableVF = ElementCount::getScalable(MaxSafeElements / *MaxVScale); + else + MaxScalableVF = ElementCount::getScalable(0); + if (!MaxScalableVF) reportVectorizationInfo( "Max legal vector width too small, scalable vectorization " @@ -4932,7 +4964,7 @@ FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF( // the memory accesses that is most restrictive (involved in the smallest // dependence distance). unsigned MaxSafeElements = - PowerOf2Floor(Legal->getMaxSafeVectorWidthInBits() / WidestType); + llvm::bit_floor(Legal->getMaxSafeVectorWidthInBits() / WidestType); auto MaxSafeFixedVF = ElementCount::getFixed(MaxSafeElements); auto MaxSafeScalableVF = getMaxLegalScalableVF(MaxSafeElements); @@ -5105,16 +5137,26 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { } FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(TC, UserVF, true); + // Avoid tail folding if the trip count is known to be a multiple of any VF - // we chose. - // FIXME: The condition below pessimises the case for fixed-width vectors, - // when scalable VFs are also candidates for vectorization. - if (MaxFactors.FixedVF.isVector() && !MaxFactors.ScalableVF) { - ElementCount MaxFixedVF = MaxFactors.FixedVF; - assert((UserVF.isNonZero() || isPowerOf2_32(MaxFixedVF.getFixedValue())) && + // we choose. 
+ std::optional<unsigned> MaxPowerOf2RuntimeVF = + MaxFactors.FixedVF.getFixedValue(); + if (MaxFactors.ScalableVF) { + std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI); + if (MaxVScale && TTI.isVScaleKnownToBeAPowerOfTwo()) { + MaxPowerOf2RuntimeVF = std::max<unsigned>( + *MaxPowerOf2RuntimeVF, + *MaxVScale * MaxFactors.ScalableVF.getKnownMinValue()); + } else + MaxPowerOf2RuntimeVF = std::nullopt; // Stick with tail-folding for now. + } + + if (MaxPowerOf2RuntimeVF && *MaxPowerOf2RuntimeVF > 0) { + assert((UserVF.isNonZero() || isPowerOf2_32(*MaxPowerOf2RuntimeVF)) && "MaxFixedVF must be a power of 2"); - unsigned MaxVFtimesIC = UserIC ? MaxFixedVF.getFixedValue() * UserIC - : MaxFixedVF.getFixedValue(); + unsigned MaxVFtimesIC = + UserIC ? *MaxPowerOf2RuntimeVF * UserIC : *MaxPowerOf2RuntimeVF; ScalarEvolution *SE = PSE.getSE(); const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); const SCEV *ExitCount = SE->getAddExpr( @@ -5134,7 +5176,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { // by masking. // FIXME: look for a smaller MaxVF that does divide TC rather than masking. if (Legal->prepareToFoldTailByMasking()) { - FoldTailByMasking = true; + CanFoldTailByMasking = true; return MaxFactors; } @@ -5187,7 +5229,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( // Ensure MaxVF is a power of 2; the dependence distance bound may not be. // Note that both WidestRegister and WidestType may not be a powers of 2. auto MaxVectorElementCount = ElementCount::get( - PowerOf2Floor(WidestRegister.getKnownMinValue() / WidestType), + llvm::bit_floor(WidestRegister.getKnownMinValue() / WidestType), ComputeScalableMaxVF); MaxVectorElementCount = MinVF(MaxVectorElementCount, MaxSafeVF); LLVM_DEBUG(dbgs() << "LV: The Widest register safe to use is: " @@ -5207,6 +5249,13 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( auto Min = Attr.getVScaleRangeMin(); WidestRegisterMinEC *= Min; } + + // When a scalar epilogue is required, at least one iteration of the scalar + // loop has to execute. Adjust ConstTripCount accordingly to avoid picking a + // max VF that results in a dead vector loop. + if (ConstTripCount > 0 && requiresScalarEpilogue(true)) + ConstTripCount -= 1; + if (ConstTripCount && ConstTripCount <= WidestRegisterMinEC && (!FoldTailByMasking || isPowerOf2_32(ConstTripCount))) { // If loop trip count (TC) is known at compile time there is no point in @@ -5214,7 +5263,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( // power of two which doesn't exceed TC. // If MaxVectorElementCount is scalable, we only fall back on a fixed VF // when the TC is less than or equal to the known number of lanes. 
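llvm::bit_floor, used just below in place of PowerOf2Floor, behaves like C++20's std::bit_floor: it rounds down to the largest power of two not exceeding its argument. A minimal standalone sketch (illustrative numbers, not from this commit) of the trip-count clamping described in the comment above:

#include <algorithm>
#include <bit>
#include <cstdio>

int main() {
  // If the trip count is a compile-time constant there is no point in a VF
  // larger than it; clamp to the largest power of two not exceeding TC.
  unsigned ConstTripCount = 11;
  unsigned MaxVectorElementCount = 16;

  unsigned ClampedVF =
      std::min(MaxVectorElementCount, std::bit_floor(ConstTripCount));

  std::printf("bit_floor(%u) = %u, clamped VF = %u\n", ConstTripCount,
              std::bit_floor(ConstTripCount), ClampedVF); // 8, VF = 8
  return 0;
}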
- auto ClampedConstTripCount = PowerOf2Floor(ConstTripCount); + auto ClampedConstTripCount = llvm::bit_floor(ConstTripCount); LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not " "exceeding the constant trip count: " << ClampedConstTripCount << "\n"); @@ -5228,7 +5277,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( if (MaximizeBandwidth || (MaximizeBandwidth.getNumOccurrences() == 0 && TTI.shouldMaximizeVectorBandwidth(RegKind))) { auto MaxVectorElementCountMaxBW = ElementCount::get( - PowerOf2Floor(WidestRegister.getKnownMinValue() / SmallestType), + llvm::bit_floor(WidestRegister.getKnownMinValue() / SmallestType), ComputeScalableMaxVF); MaxVectorElementCountMaxBW = MinVF(MaxVectorElementCountMaxBW, MaxSafeVF); @@ -5273,9 +5322,14 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( return MaxVF; } -std::optional<unsigned> LoopVectorizationCostModel::getVScaleForTuning() const { - if (TheFunction->hasFnAttribute(Attribute::VScaleRange)) { - auto Attr = TheFunction->getFnAttribute(Attribute::VScaleRange); +/// Convenience function that returns the value of vscale_range iff +/// vscale_range.min == vscale_range.max or otherwise returns the value +/// returned by the corresponding TTI method. +static std::optional<unsigned> +getVScaleForTuning(const Loop *L, const TargetTransformInfo &TTI) { + const Function *Fn = L->getHeader()->getParent(); + if (Fn->hasFnAttribute(Attribute::VScaleRange)) { + auto Attr = Fn->getFnAttribute(Attribute::VScaleRange); auto Min = Attr.getVScaleRangeMin(); auto Max = Attr.getVScaleRangeMax(); if (Max && Min == Max) @@ -5285,31 +5339,39 @@ std::optional<unsigned> LoopVectorizationCostModel::getVScaleForTuning() const { return TTI.getVScaleForTuning(); } -bool LoopVectorizationCostModel::isMoreProfitable( +bool LoopVectorizationPlanner::isMoreProfitable( const VectorizationFactor &A, const VectorizationFactor &B) const { InstructionCost CostA = A.Cost; InstructionCost CostB = B.Cost; - unsigned MaxTripCount = PSE.getSE()->getSmallConstantMaxTripCount(TheLoop); - - if (!A.Width.isScalable() && !B.Width.isScalable() && FoldTailByMasking && - MaxTripCount) { - // If we are folding the tail and the trip count is a known (possibly small) - // constant, the trip count will be rounded up to an integer number of - // iterations. The total cost will be PerIterationCost*ceil(TripCount/VF), - // which we compare directly. When not folding the tail, the total cost will - // be PerIterationCost*floor(TC/VF) + Scalar remainder cost, and so is - // approximated with the per-lane cost below instead of using the tripcount - // as here. - auto RTCostA = CostA * divideCeil(MaxTripCount, A.Width.getFixedValue()); - auto RTCostB = CostB * divideCeil(MaxTripCount, B.Width.getFixedValue()); + unsigned MaxTripCount = PSE.getSE()->getSmallConstantMaxTripCount(OrigLoop); + + if (!A.Width.isScalable() && !B.Width.isScalable() && MaxTripCount) { + // If the trip count is a known (possibly small) constant, the trip count + // will be rounded up to an integer number of iterations under + // FoldTailByMasking. The total cost in that case will be + // VecCost*ceil(TripCount/VF). When not folding the tail, the total + // cost will be VecCost*floor(TC/VF) + ScalarCost*(TC%VF). There will be + // some extra overheads, but for the purpose of comparing the costs of + // different VFs we can use this to compare the total loop-body cost + // expected after vectorization. 
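The comment above describes the per-trip-count cost model that the new GetCostForTC lambda just below implements. A small standalone sketch (illustrative costs, not taken from this commit) of the two formulas:

#include <cstdio>

// Total loop-body cost for a known trip count TC.
//  - With tail folding, the vector loop runs ceil(TC / VF) iterations.
//  - Without it, it runs TC / VF full vector iterations and the remaining
//    TC % VF iterations execute in the scalar epilogue.
static unsigned costForTC(bool FoldTailByMasking, unsigned TC, unsigned VF,
                          unsigned VectorCost, unsigned ScalarCost) {
  if (FoldTailByMasking)
    return VectorCost * ((TC + VF - 1) / VF); // divideCeil(TC, VF)
  return VectorCost * (TC / VF) + ScalarCost * (TC % VF);
}

int main() {
  // Example: TC = 10, per-iteration vector cost 12, scalar cost 5.
  std::printf("VF=4, folded tail:     %u\n", costForTC(true, 10, 4, 12, 5));  // 12 * 3 = 36
  std::printf("VF=4, scalar epilogue: %u\n", costForTC(false, 10, 4, 12, 5)); // 12*2 + 5*2 = 34
  return 0;
}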
+ auto GetCostForTC = [MaxTripCount, this](unsigned VF, + InstructionCost VectorCost, + InstructionCost ScalarCost) { + return CM.foldTailByMasking() ? VectorCost * divideCeil(MaxTripCount, VF) + : VectorCost * (MaxTripCount / VF) + + ScalarCost * (MaxTripCount % VF); + }; + auto RTCostA = GetCostForTC(A.Width.getFixedValue(), CostA, A.ScalarCost); + auto RTCostB = GetCostForTC(B.Width.getFixedValue(), CostB, B.ScalarCost); + return RTCostA < RTCostB; } // Improve estimate for the vector width if it is scalable. unsigned EstimatedWidthA = A.Width.getKnownMinValue(); unsigned EstimatedWidthB = B.Width.getKnownMinValue(); - if (std::optional<unsigned> VScale = getVScaleForTuning()) { + if (std::optional<unsigned> VScale = getVScaleForTuning(OrigLoop, TTI)) { if (A.Width.isScalable()) EstimatedWidthA *= *VScale; if (B.Width.isScalable()) @@ -5328,9 +5390,74 @@ bool LoopVectorizationCostModel::isMoreProfitable( return (CostA * EstimatedWidthB) < (CostB * EstimatedWidthA); } -VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( +static void emitInvalidCostRemarks(SmallVector<InstructionVFPair> InvalidCosts, + OptimizationRemarkEmitter *ORE, + Loop *TheLoop) { + if (InvalidCosts.empty()) + return; + + // Emit a report of VFs with invalid costs in the loop. + + // Group the remarks per instruction, keeping the instruction order from + // InvalidCosts. + std::map<Instruction *, unsigned> Numbering; + unsigned I = 0; + for (auto &Pair : InvalidCosts) + if (!Numbering.count(Pair.first)) + Numbering[Pair.first] = I++; + + // Sort the list, first on instruction(number) then on VF. + sort(InvalidCosts, [&Numbering](InstructionVFPair &A, InstructionVFPair &B) { + if (Numbering[A.first] != Numbering[B.first]) + return Numbering[A.first] < Numbering[B.first]; + ElementCountComparator ECC; + return ECC(A.second, B.second); + }); + + // For a list of ordered instruction-vf pairs: + // [(load, vf1), (load, vf2), (store, vf1)] + // Group the instructions together to emit separate remarks for: + // load (vf1, vf2) + // store (vf1) + auto Tail = ArrayRef<InstructionVFPair>(InvalidCosts); + auto Subset = ArrayRef<InstructionVFPair>(); + do { + if (Subset.empty()) + Subset = Tail.take_front(1); + + Instruction *I = Subset.front().first; + + // If the next instruction is different, or if there are no other pairs, + // emit a remark for the collated subset. e.g. + // [(load, vf1), (load, vf2))] + // to emit: + // remark: invalid costs for 'load' at VF=(vf, vf2) + if (Subset == Tail || Tail[Subset.size()].first != I) { + std::string OutString; + raw_string_ostream OS(OutString); + assert(!Subset.empty() && "Unexpected empty range"); + OS << "Instruction with invalid costs prevented vectorization at VF=("; + for (const auto &Pair : Subset) + OS << (Pair.second == Subset.front().second ? 
"" : ", ") << Pair.second; + OS << "):"; + if (auto *CI = dyn_cast<CallInst>(I)) + OS << " call to " << CI->getCalledFunction()->getName(); + else + OS << " " << I->getOpcodeName(); + OS.flush(); + reportVectorizationInfo(OutString, "InvalidCost", ORE, TheLoop, I); + Tail = Tail.drop_front(Subset.size()); + Subset = {}; + } else + // Grow the subset by one element + Subset = Tail.take_front(Subset.size() + 1); + } while (!Tail.empty()); +} + +VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor( const ElementCountSet &VFCandidates) { - InstructionCost ExpectedCost = expectedCost(ElementCount::getFixed(1)).first; + InstructionCost ExpectedCost = + CM.expectedCost(ElementCount::getFixed(1)).first; LLVM_DEBUG(dbgs() << "LV: Scalar loop costs: " << ExpectedCost << ".\n"); assert(ExpectedCost.isValid() && "Unexpected invalid cost for scalar loop"); assert(VFCandidates.count(ElementCount::getFixed(1)) && @@ -5340,7 +5467,7 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( ExpectedCost); VectorizationFactor ChosenFactor = ScalarCost; - bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled; + bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled; if (ForceVectorization && VFCandidates.size() > 1) { // Ignore scalar width, because the user explicitly wants vectorization. // Initialize cost to max so that VF = 2 is, at least, chosen during cost @@ -5354,12 +5481,13 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( if (i.isScalar()) continue; - VectorizationCostTy C = expectedCost(i, &InvalidCosts); + LoopVectorizationCostModel::VectorizationCostTy C = + CM.expectedCost(i, &InvalidCosts); VectorizationFactor Candidate(i, C.first, ScalarCost.ScalarCost); #ifndef NDEBUG unsigned AssumedMinimumVscale = 1; - if (std::optional<unsigned> VScale = getVScaleForTuning()) + if (std::optional<unsigned> VScale = getVScaleForTuning(OrigLoop, TTI)) AssumedMinimumVscale = *VScale; unsigned Width = Candidate.Width.isScalable() @@ -5388,70 +5516,13 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( ChosenFactor = Candidate; } - // Emit a report of VFs with invalid costs in the loop. - if (!InvalidCosts.empty()) { - // Group the remarks per instruction, keeping the instruction order from - // InvalidCosts. - std::map<Instruction *, unsigned> Numbering; - unsigned I = 0; - for (auto &Pair : InvalidCosts) - if (!Numbering.count(Pair.first)) - Numbering[Pair.first] = I++; - - // Sort the list, first on instruction(number) then on VF. - llvm::sort(InvalidCosts, - [&Numbering](InstructionVFPair &A, InstructionVFPair &B) { - if (Numbering[A.first] != Numbering[B.first]) - return Numbering[A.first] < Numbering[B.first]; - ElementCountComparator ECC; - return ECC(A.second, B.second); - }); - - // For a list of ordered instruction-vf pairs: - // [(load, vf1), (load, vf2), (store, vf1)] - // Group the instructions together to emit separate remarks for: - // load (vf1, vf2) - // store (vf1) - auto Tail = ArrayRef<InstructionVFPair>(InvalidCosts); - auto Subset = ArrayRef<InstructionVFPair>(); - do { - if (Subset.empty()) - Subset = Tail.take_front(1); - - Instruction *I = Subset.front().first; - - // If the next instruction is different, or if there are no other pairs, - // emit a remark for the collated subset. e.g. 
- // [(load, vf1), (load, vf2))] - // to emit: - // remark: invalid costs for 'load' at VF=(vf, vf2) - if (Subset == Tail || Tail[Subset.size()].first != I) { - std::string OutString; - raw_string_ostream OS(OutString); - assert(!Subset.empty() && "Unexpected empty range"); - OS << "Instruction with invalid costs prevented vectorization at VF=("; - for (const auto &Pair : Subset) - OS << (Pair.second == Subset.front().second ? "" : ", ") - << Pair.second; - OS << "):"; - if (auto *CI = dyn_cast<CallInst>(I)) - OS << " call to " << CI->getCalledFunction()->getName(); - else - OS << " " << I->getOpcodeName(); - OS.flush(); - reportVectorizationInfo(OutString, "InvalidCost", ORE, TheLoop, I); - Tail = Tail.drop_front(Subset.size()); - Subset = {}; - } else - // Grow the subset by one element - Subset = Tail.take_front(Subset.size() + 1); - } while (!Tail.empty()); - } + emitInvalidCostRemarks(InvalidCosts, ORE, OrigLoop); - if (!EnableCondStoresVectorization && NumPredStores) { - reportVectorizationFailure("There are conditional stores.", + if (!EnableCondStoresVectorization && CM.hasPredStores()) { + reportVectorizationFailure( + "There are conditional stores.", "store that is conditionally executed prevents vectorization", - "ConditionalStore", ORE, TheLoop); + "ConditionalStore", ORE, OrigLoop); ChosenFactor = ScalarCost; } @@ -5463,11 +5534,11 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( return ChosenFactor; } -bool LoopVectorizationCostModel::isCandidateForEpilogueVectorization( - const Loop &L, ElementCount VF) const { +bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization( + ElementCount VF) const { // Cross iteration phis such as reductions need special handling and are // currently unsupported. - if (any_of(L.getHeader()->phis(), + if (any_of(OrigLoop->getHeader()->phis(), [&](PHINode &Phi) { return Legal->isFixedOrderRecurrence(&Phi); })) return false; @@ -5475,20 +5546,21 @@ bool LoopVectorizationCostModel::isCandidateForEpilogueVectorization( // currently unsupported. for (const auto &Entry : Legal->getInductionVars()) { // Look for uses of the value of the induction at the last iteration. - Value *PostInc = Entry.first->getIncomingValueForBlock(L.getLoopLatch()); + Value *PostInc = + Entry.first->getIncomingValueForBlock(OrigLoop->getLoopLatch()); for (User *U : PostInc->users()) - if (!L.contains(cast<Instruction>(U))) + if (!OrigLoop->contains(cast<Instruction>(U))) return false; // Look for uses of penultimate value of the induction. for (User *U : Entry.first->users()) - if (!L.contains(cast<Instruction>(U))) + if (!OrigLoop->contains(cast<Instruction>(U))) return false; } // Epilogue vectorization code has not been auditted to ensure it handles // non-latch exits properly. It may be fine, but it needs auditted and // tested. - if (L.getExitingBlock() != L.getLoopLatch()) + if (OrigLoop->getExitingBlock() != OrigLoop->getLoopLatch()) return false; return true; @@ -5507,62 +5579,59 @@ bool LoopVectorizationCostModel::isEpilogueVectorizationProfitable( // We also consider epilogue vectorization unprofitable for targets that don't // consider interleaving beneficial (eg. MVE). - if (TTI.getMaxInterleaveFactor(VF.getKnownMinValue()) <= 1) + if (TTI.getMaxInterleaveFactor(VF) <= 1) return false; - // FIXME: We should consider changing the threshold for scalable - // vectors to take VScaleForTuning into account. 
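The FIXME removed above is addressed just below: for scalable main-loop VFs, the epilogue profitability threshold is now compared against the known minimum lane count scaled by the tuning vscale. A minimal standalone sketch (illustrative threshold and vscale values) of that check:

#include <cstdio>
#include <optional>

// Hypothetical threshold standing in for the EpilogueVectorizationMinVF
// command-line option.
static constexpr unsigned EpilogueVectorizationMinVF = 16;

static bool profitableToVectorizeEpilogue(unsigned KnownMinVF, bool Scalable,
                                          std::optional<unsigned> VScaleForTuning) {
  unsigned Multiplier = 1;
  if (Scalable)
    Multiplier = VScaleForTuning.value_or(1);
  return Multiplier * KnownMinVF >= EpilogueVectorizationMinVF;
}

int main() {
  // A scalable VF with a known minimum of 4 lanes and a tuning vscale of 4
  // behaves like an effective width of 16 lanes and passes the threshold.
  std::printf("fixed VF 8:              %d\n",
              profitableToVectorizeEpilogue(8, false, std::nullopt)); // 0
  std::printf("scalable VF 4, vscale 4: %d\n",
              profitableToVectorizeEpilogue(4, true, 4));             // 1
  return 0;
}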
- if (VF.getKnownMinValue() >= EpilogueVectorizationMinVF) + + unsigned Multiplier = 1; + if (VF.isScalable()) + Multiplier = getVScaleForTuning(TheLoop, TTI).value_or(1); + if ((Multiplier * VF.getKnownMinValue()) >= EpilogueVectorizationMinVF) return true; return false; } -VectorizationFactor -LoopVectorizationCostModel::selectEpilogueVectorizationFactor( - const ElementCount MainLoopVF, const LoopVectorizationPlanner &LVP) { +VectorizationFactor LoopVectorizationPlanner::selectEpilogueVectorizationFactor( + const ElementCount MainLoopVF, unsigned IC) { VectorizationFactor Result = VectorizationFactor::Disabled(); if (!EnableEpilogueVectorization) { - LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization is disabled.\n";); + LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization is disabled.\n"); return Result; } - if (!isScalarEpilogueAllowed()) { - LLVM_DEBUG( - dbgs() << "LEV: Unable to vectorize epilogue because no epilogue is " - "allowed.\n";); + if (!CM.isScalarEpilogueAllowed()) { + LLVM_DEBUG(dbgs() << "LEV: Unable to vectorize epilogue because no " + "epilogue is allowed.\n"); return Result; } // Not really a cost consideration, but check for unsupported cases here to // simplify the logic. - if (!isCandidateForEpilogueVectorization(*TheLoop, MainLoopVF)) { - LLVM_DEBUG( - dbgs() << "LEV: Unable to vectorize epilogue because the loop is " - "not a supported candidate.\n";); + if (!isCandidateForEpilogueVectorization(MainLoopVF)) { + LLVM_DEBUG(dbgs() << "LEV: Unable to vectorize epilogue because the loop " + "is not a supported candidate.\n"); return Result; } if (EpilogueVectorizationForceVF > 1) { - LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization factor is forced.\n";); + LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization factor is forced.\n"); ElementCount ForcedEC = ElementCount::getFixed(EpilogueVectorizationForceVF); - if (LVP.hasPlanWithVF(ForcedEC)) + if (hasPlanWithVF(ForcedEC)) return {ForcedEC, 0, 0}; else { - LLVM_DEBUG( - dbgs() - << "LEV: Epilogue vectorization forced factor is not viable.\n";); + LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization forced factor is not " + "viable.\n"); return Result; } } - if (TheLoop->getHeader()->getParent()->hasOptSize() || - TheLoop->getHeader()->getParent()->hasMinSize()) { + if (OrigLoop->getHeader()->getParent()->hasOptSize() || + OrigLoop->getHeader()->getParent()->hasMinSize()) { LLVM_DEBUG( - dbgs() - << "LEV: Epilogue vectorization skipped due to opt for size.\n";); + dbgs() << "LEV: Epilogue vectorization skipped due to opt for size.\n"); return Result; } - if (!isEpilogueVectorizationProfitable(MainLoopVF)) { + if (!CM.isEpilogueVectorizationProfitable(MainLoopVF)) { LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization is not profitable for " "this loop\n"); return Result; @@ -5574,21 +5643,48 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor( ElementCount EstimatedRuntimeVF = MainLoopVF; if (MainLoopVF.isScalable()) { EstimatedRuntimeVF = ElementCount::getFixed(MainLoopVF.getKnownMinValue()); - if (std::optional<unsigned> VScale = getVScaleForTuning()) + if (std::optional<unsigned> VScale = getVScaleForTuning(OrigLoop, TTI)) EstimatedRuntimeVF *= *VScale; } - for (auto &NextVF : ProfitableVFs) - if (((!NextVF.Width.isScalable() && MainLoopVF.isScalable() && - ElementCount::isKnownLT(NextVF.Width, EstimatedRuntimeVF)) || - ElementCount::isKnownLT(NextVF.Width, MainLoopVF)) && - (Result.Width.isScalar() || isMoreProfitable(NextVF, Result)) && - LVP.hasPlanWithVF(NextVF.Width)) + ScalarEvolution &SE = *PSE.getSE(); + 
Type *TCType = Legal->getWidestInductionType(); + const SCEV *RemainingIterations = nullptr; + for (auto &NextVF : ProfitableVFs) { + // Skip candidate VFs without a corresponding VPlan. + if (!hasPlanWithVF(NextVF.Width)) + continue; + + // Skip candidate VFs with widths >= the estimate runtime VF (scalable + // vectors) or the VF of the main loop (fixed vectors). + if ((!NextVF.Width.isScalable() && MainLoopVF.isScalable() && + ElementCount::isKnownGE(NextVF.Width, EstimatedRuntimeVF)) || + ElementCount::isKnownGE(NextVF.Width, MainLoopVF)) + continue; + + // If NextVF is greater than the number of remaining iterations, the + // epilogue loop would be dead. Skip such factors. + if (!MainLoopVF.isScalable() && !NextVF.Width.isScalable()) { + // TODO: extend to support scalable VFs. + if (!RemainingIterations) { + const SCEV *TC = createTripCountSCEV(TCType, PSE, OrigLoop); + RemainingIterations = SE.getURemExpr( + TC, SE.getConstant(TCType, MainLoopVF.getKnownMinValue() * IC)); + } + if (SE.isKnownPredicate( + CmpInst::ICMP_UGT, + SE.getConstant(TCType, NextVF.Width.getKnownMinValue()), + RemainingIterations)) + continue; + } + + if (Result.Width.isScalar() || isMoreProfitable(NextVF, Result)) Result = NextVF; + } if (Result != VectorizationFactor::Disabled()) LLVM_DEBUG(dbgs() << "LEV: Vectorizing epilogue loop with VF = " - << Result.Width << "\n";); + << Result.Width << "\n"); return Result; } @@ -5688,7 +5784,7 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, return 1; // We used the distance for the interleave count. - if (Legal->getMaxSafeDepDistBytes() != -1U) + if (!Legal->isSafeForAnyVectorWidth()) return 1; auto BestKnownTC = getSmallBestKnownTC(*PSE.getSE(), TheLoop); @@ -5750,20 +5846,19 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, if (R.LoopInvariantRegs.find(pair.first) != R.LoopInvariantRegs.end()) LoopInvariantRegs = R.LoopInvariantRegs[pair.first]; - unsigned TmpIC = PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs) / MaxLocalUsers); + unsigned TmpIC = llvm::bit_floor((TargetNumRegisters - LoopInvariantRegs) / + MaxLocalUsers); // Don't count the induction variable as interleaved. if (EnableIndVarRegisterHeur) { - TmpIC = - PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs - 1) / - std::max(1U, (MaxLocalUsers - 1))); + TmpIC = llvm::bit_floor((TargetNumRegisters - LoopInvariantRegs - 1) / + std::max(1U, (MaxLocalUsers - 1))); } IC = std::min(IC, TmpIC); } // Clamp the interleave ranges to reasonable counts. - unsigned MaxInterleaveCount = - TTI.getMaxInterleaveFactor(VF.getKnownMinValue()); + unsigned MaxInterleaveCount = TTI.getMaxInterleaveFactor(VF); // Check if the user has overridden the max. if (VF.isScalar()) { @@ -5834,8 +5929,8 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, // We assume that the cost overhead is 1 and we use the cost model // to estimate the cost of the loop and interleave until the cost of the // loop overhead is about 5% of the cost of the loop. - unsigned SmallIC = std::min( - IC, (unsigned)PowerOf2Floor(SmallLoopCost / *LoopCost.getValue())); + unsigned SmallIC = std::min(IC, (unsigned)llvm::bit_floor<uint64_t>( + SmallLoopCost / *LoopCost.getValue())); // Interleave until store/load ports (estimated by max interleave count) are // saturated. 
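The interleave-count hunks above replace PowerOf2Floor with llvm::bit_floor but keep the same register-pressure heuristic. A standalone sketch (illustrative register counts and costs, not from this commit) of how the candidate interleave count is derived:

#include <algorithm>
#include <bit>
#include <cstdio>

int main() {
  // Registers available in the target class, registers tied up by
  // loop-invariant values, and the maximum number of values live at once
  // inside the loop body (all illustrative numbers).
  unsigned TargetNumRegisters = 32;
  unsigned LoopInvariantRegs = 3;
  unsigned MaxLocalUsers = 7;

  // Interleave as far as register pressure allows, rounded down to a power
  // of two: bit_floor((32 - 3) / 7) = bit_floor(4) = 4.
  unsigned IC =
      std::bit_floor((TargetNumRegisters - LoopInvariantRegs) / MaxLocalUsers);

  // For small loops, additionally cap the count so the loop overhead stays a
  // small fraction of the body cost: bit_floor(SmallLoopCost / LoopCost).
  unsigned SmallLoopCost = 20, LoopCost = 5;
  unsigned SmallIC = std::min(IC, std::bit_floor(SmallLoopCost / LoopCost));

  std::printf("IC = %u, SmallIC = %u\n", IC, SmallIC); // 4, 4
  return 0;
}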
@@ -5953,7 +6048,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { // Saves the list of values that are used in the loop but are defined outside // the loop (not including non-instruction values such as arguments and // constants). - SmallPtrSet<Value *, 8> LoopInvariants; + SmallSetVector<Instruction *, 8> LoopInvariants; for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { for (Instruction &I : BB->instructionsWithoutDebug()) { @@ -6079,11 +6174,16 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { for (auto *Inst : LoopInvariants) { // FIXME: The target might use more than one register for the type // even in the scalar case. - unsigned Usage = - VFs[i].isScalar() ? 1 : GetRegUsage(Inst->getType(), VFs[i]); + bool IsScalar = all_of(Inst->users(), [&](User *U) { + auto *I = cast<Instruction>(U); + return TheLoop != LI->getLoopFor(I->getParent()) || + isScalarAfterVectorization(I, VFs[i]); + }); + + ElementCount VF = IsScalar ? ElementCount::getFixed(1) : VFs[i]; unsigned ClassID = - TTI.getRegisterClassForType(VFs[i].isVector(), Inst->getType()); - Invariant[ClassID] += Usage; + TTI.getRegisterClassForType(VF.isVector(), Inst->getType()); + Invariant[ClassID] += GetRegUsage(Inst->getType(), VF); } LLVM_DEBUG({ @@ -6134,8 +6234,7 @@ void LoopVectorizationCostModel::collectInstsToScalarize(ElementCount VF) { // instructions to scalarize, there's nothing to do. Collection may already // have occurred if we have a user-selected VF and are now computing the // expected cost for interleaving. - if (VF.isScalar() || VF.isZero() || - InstsToScalarize.find(VF) != InstsToScalarize.end()) + if (VF.isScalar() || VF.isZero() || InstsToScalarize.contains(VF)) return; // Initialize a mapping for VF in InstsToScalalarize. If we find that it's @@ -6224,7 +6323,7 @@ InstructionCost LoopVectorizationCostModel::computePredInstDiscount( Instruction *I = Worklist.pop_back_val(); // If we've already analyzed the instruction, there's nothing to do. - if (ScalarCosts.find(I) != ScalarCosts.end()) + if (ScalarCosts.contains(I)) continue; // Compute the cost of the vector instruction. 
Note that this cost already @@ -6362,11 +6461,6 @@ static const SCEV *getAddressAccessSCEV( return PSE.getSCEV(Ptr); } -static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) { - return Legal->hasStride(I->getOperand(0)) || - Legal->hasStride(I->getOperand(1)); -} - InstructionCost LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, ElementCount VF) { @@ -6460,7 +6554,7 @@ LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, InstructionCost LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I, ElementCount VF) { - assert(Legal->isUniformMemOp(*I)); + assert(Legal->isUniformMemOp(*I, VF)); Type *ValTy = getLoadStoreType(I); auto *VectorTy = cast<VectorType>(ToVectorTy(ValTy, VF)); @@ -6475,7 +6569,7 @@ LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I, } StoreInst *SI = cast<StoreInst>(I); - bool isLoopInvariantStoreValue = Legal->isUniform(SI->getValueOperand()); + bool isLoopInvariantStoreValue = Legal->isInvariant(SI->getValueOperand()); return TTI.getAddressComputationCost(ValTy) + TTI.getMemoryOpCost(Instruction::Store, ValTy, Alignment, AS, CostKind) + @@ -6502,11 +6596,6 @@ LoopVectorizationCostModel::getGatherScatterCost(Instruction *I, InstructionCost LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, ElementCount VF) { - // TODO: Once we have support for interleaving with scalable vectors - // we can calculate the cost properly here. - if (VF.isScalable()) - return InstructionCost::getInvalid(); - Type *ValTy = getLoadStoreType(I); auto *VectorTy = cast<VectorType>(ToVectorTy(ValTy, VF)); unsigned AS = getLoadStoreAddressSpace(I); @@ -6836,7 +6925,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) { if (isa<StoreInst>(&I) && isScalarWithPredication(&I, VF)) NumPredStores++; - if (Legal->isUniformMemOp(I)) { + if (Legal->isUniformMemOp(I, VF)) { auto isLegalToScalarize = [&]() { if (!VF.isScalable()) // Scalarization of fixed length vectors "just works". @@ -7134,8 +7223,12 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, case Instruction::And: case Instruction::Or: case Instruction::Xor: { - // Since we will replace the stride by 1 the multiplication should go away. - if (I->getOpcode() == Instruction::Mul && isStrideMul(I, Legal)) + // If we're speculating on the stride being 1, the multiplication may + // fold away. We can generalize this for all operations using the notion + // of neutral elements. (TODO) + if (I->getOpcode() == Instruction::Mul && + (PSE.getSCEV(I->getOperand(0))->isOne() || + PSE.getSCEV(I->getOperand(1))->isOne())) return 0; // Detect reduction patterns @@ -7146,7 +7239,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, // second vector operand. One example of this are shifts on x86. 
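For the uniform-memory-op costing touched above, the model charges a single scalar access plus, for loads, a broadcast and, for stores of a non-invariant value, an extract of the last lane. A simplified standalone sketch (made-up unit costs standing in for the TTI queries) of that accounting:

#include <cstdio>

// Made-up unit costs standing in for the TTI queries used by the cost model.
static constexpr unsigned AddrCost = 1, ScalarLoadCost = 4, ScalarStoreCost = 4,
                          BroadcastCost = 2, ExtractLastLaneCost = 2;

// A uniform load is executed once and splatted; a uniform store is executed
// once, but if the stored value varies per lane the last lane is extracted.
static unsigned uniformLoadCost() {
  return AddrCost + ScalarLoadCost + BroadcastCost;
}
static unsigned uniformStoreCost(bool StoredValueIsLoopInvariant) {
  return AddrCost + ScalarStoreCost +
         (StoredValueIsLoopInvariant ? 0 : ExtractLastLaneCost);
}

int main() {
  std::printf("uniform load:              %u\n", uniformLoadCost());        // 7
  std::printf("uniform store (invariant): %u\n", uniformStoreCost(true));   // 5
  std::printf("uniform store (varying):   %u\n", uniformStoreCost(false));  // 7
  return 0;
}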
Value *Op2 = I->getOperand(1); auto Op2Info = TTI.getOperandInfo(Op2); - if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2)) + if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && + Legal->isInvariant(Op2)) Op2Info.Kind = TargetTransformInfo::OK_UniformValue; SmallVector<const Value *, 4> Operands(I->operand_values()); @@ -7304,7 +7398,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, VectorTy = largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); } else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) { - SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy); + // Leave SrcVecTy unchanged - we only shrink the destination element + // type. VectorTy = smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); } @@ -7316,9 +7411,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, if (RecurrenceDescriptor::isFMulAddIntrinsic(I)) if (auto RedCost = getReductionPatternCost(I, VF, VectorTy, CostKind)) return *RedCost; - bool NeedToScalarize; + Function *Variant; CallInst *CI = cast<CallInst>(I); - InstructionCost CallCost = getVectorCallCost(CI, VF, NeedToScalarize); + InstructionCost CallCost = getVectorCallCost(CI, VF, &Variant); if (getVectorIntrinsicIDForCall(CI, TLI)) { InstructionCost IntrinsicCost = getVectorIntrinsicCost(CI, VF); return std::min(CallCost, IntrinsicCost); @@ -7339,37 +7434,6 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, } // end of switch. } -char LoopVectorize::ID = 0; - -static const char lv_name[] = "Loop Vectorization"; - -INITIALIZE_PASS_BEGIN(LoopVectorize, LV_NAME, lv_name, false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) -INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(InjectTLIMappingsLegacy) -INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false) - -namespace llvm { - -Pass *createLoopVectorizePass() { return new LoopVectorize(); } - -Pass *createLoopVectorizePass(bool InterleaveOnlyWhenForced, - bool VectorizeOnlyWhenForced) { - return new LoopVectorize(InterleaveOnlyWhenForced, VectorizeOnlyWhenForced); -} - -} // end namespace llvm - void LoopVectorizationCostModel::collectValuesToIgnore() { // Ignore ephemeral values. CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); @@ -7462,7 +7526,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { // reasonable one. 
if (UserVF.isZero()) { VF = ElementCount::getFixed(determineVPlanVF( - TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) + TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) .getFixedValue(), CM)); LLVM_DEBUG(dbgs() << "LV: VPlan computed VF " << VF << ".\n"); @@ -7497,13 +7561,16 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { std::optional<VectorizationFactor> LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { assert(OrigLoop->isInnermost() && "Inner loop expected."); + CM.collectValuesToIgnore(); + CM.collectElementTypesForWidening(); + FixedScalableVFPair MaxFactors = CM.computeMaxVF(UserVF, UserIC); if (!MaxFactors) // Cases that should not to be vectorized nor interleaved. return std::nullopt; // Invalidate interleave groups if all blocks of loop will be predicated. if (CM.blockNeedsPredicationForAnyReason(OrigLoop->getHeader()) && - !useMaskedInterleavedAccesses(*TTI)) { + !useMaskedInterleavedAccesses(TTI)) { LLVM_DEBUG( dbgs() << "LV: Invalidate all interleaved groups due to fold-tail by masking " @@ -7527,6 +7594,12 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); CM.collectInLoopReductions(); buildVPlansWithVPRecipes(UserVF, UserVF); + if (!hasPlanWithVF(UserVF)) { + LLVM_DEBUG(dbgs() << "LV: No VPlan could be built for " << UserVF + << ".\n"); + return std::nullopt; + } + LLVM_DEBUG(printPlans(dbgs())); return {{UserVF, 0, 0}}; } else @@ -7562,8 +7635,13 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { return VectorizationFactor::Disabled(); // Select the optimal vectorization factor. - VectorizationFactor VF = CM.selectVectorizationFactor(VFCandidates); + VectorizationFactor VF = selectVectorizationFactor(VFCandidates); assert((VF.Width.isScalar() || VF.ScalarCost > 0) && "when vectorizing, the scalar cost must be non-zero."); + if (!hasPlanWithVF(VF.Width)) { + LLVM_DEBUG(dbgs() << "LV: No VPlan could be built for " << VF.Width + << ".\n"); + return std::nullopt; + } return VF; } @@ -7614,43 +7692,51 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) { } } -void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, - VPlan &BestVPlan, - InnerLoopVectorizer &ILV, - DominatorTree *DT, - bool IsEpilogueVectorization) { +SCEV2ValueTy LoopVectorizationPlanner::executePlan( + ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan, + InnerLoopVectorizer &ILV, DominatorTree *DT, bool IsEpilogueVectorization, + DenseMap<const SCEV *, Value *> *ExpandedSCEVs) { assert(BestVPlan.hasVF(BestVF) && "Trying to execute plan with unsupported VF"); assert(BestVPlan.hasUF(BestUF) && "Trying to execute plan with unsupported UF"); + assert( + (IsEpilogueVectorization || !ExpandedSCEVs) && + "expanded SCEVs to reuse can only be used during epilogue vectorization"); LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << BestVF << ", UF=" << BestUF << '\n'); - // Workaround! Compute the trip count of the original loop and cache it - // before we start modifying the CFG. This code has a systemic problem - // wherein it tries to run analysis over partially constructed IR; this is - // wrong, and not simply for SCEV. The trip count of the original loop - // simply happens to be prone to hitting this in practice. In theory, we - // can hit the same issue for any SCEV, or ValueTracking query done during - // mutation. See PR49900. 
- ILV.getOrCreateTripCount(OrigLoop->getLoopPreheader()); - if (!IsEpilogueVectorization) VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE); // Perform the actual loop transformation. + VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan}; + + // 0. Generate SCEV-dependent code into the preheader, including TripCount, + // before making any changes to the CFG. + if (!BestVPlan.getPreheader()->empty()) { + State.CFG.PrevBB = OrigLoop->getLoopPreheader(); + State.Builder.SetInsertPoint(OrigLoop->getLoopPreheader()->getTerminator()); + BestVPlan.getPreheader()->execute(&State); + } + if (!ILV.getTripCount()) + ILV.setTripCount(State.get(BestVPlan.getTripCount(), {0, 0})); + else + assert(IsEpilogueVectorization && "should only re-use the existing trip " + "count during epilogue vectorization"); // 1. Set up the skeleton for vectorization, including vector pre-header and // middle block. The vector loop is created during VPlan execution. - VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan}; Value *CanonicalIVStartValue; std::tie(State.CFG.PrevBB, CanonicalIVStartValue) = - ILV.createVectorizedLoopSkeleton(); + ILV.createVectorizedLoopSkeleton(ExpandedSCEVs ? *ExpandedSCEVs + : State.ExpandedSCEVs); // Only use noalias metadata when using memory checks guaranteeing no overlap // across all iterations. const LoopAccessInfo *LAI = ILV.Legal->getLAI(); + std::unique_ptr<LoopVersioning> LVer = nullptr; if (LAI && !LAI->getRuntimePointerChecking()->getChecks().empty() && !LAI->getRuntimePointerChecking()->getDiffChecks()) { @@ -7658,9 +7744,10 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, // still use it to add the noalias metadata. // TODO: Find a better way to re-use LoopVersioning functionality to add // metadata. - State.LVer = std::make_unique<LoopVersioning>( + LVer = std::make_unique<LoopVersioning>( *LAI, LAI->getRuntimePointerChecking()->getChecks(), OrigLoop, LI, DT, PSE.getSE()); + State.LVer = &*LVer; State.LVer->prepareNoAliasMetadata(); } @@ -7677,10 +7764,9 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, //===------------------------------------------------===// // 2. Copy and widen instructions from the old loop into the new loop. - BestVPlan.prepareToExecute(ILV.getOrCreateTripCount(nullptr), - ILV.getOrCreateVectorTripCount(nullptr), - CanonicalIVStartValue, State, - IsEpilogueVectorization); + BestVPlan.prepareToExecute( + ILV.getTripCount(), ILV.getOrCreateVectorTripCount(nullptr), + CanonicalIVStartValue, State, IsEpilogueVectorization); BestVPlan.execute(&State); @@ -7706,13 +7792,18 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, LoopVectorizeHints Hints(L, true, *ORE); Hints.setAlreadyVectorized(); } - AddRuntimeUnrollDisableMetaData(L); + TargetTransformInfo::UnrollingPreferences UP; + TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE); + if (!UP.UnrollVectorizedLoop || CanonicalIVStartValue) + AddRuntimeUnrollDisableMetaData(L); // 3. Fix the vectorized code: take care of header phi's, live-outs, // predication, updating analyses. 
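The hunk above now consults the target's unrolling preferences before tagging the vector loop with runtime-unroll-disable metadata. A simplified sketch (not the helper defined in this file) of how such loop metadata is typically attached, assuming the usual self-referential loop-ID convention:

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"

using namespace llvm;

// Append "llvm.loop.unroll.runtime.disable" to L's loop ID, preserving any
// existing metadata operands. The first operand of a loop ID refers to the
// node itself, so it is patched in after creation.
static void addRuntimeUnrollDisable(Loop *L, LLVMContext &Ctx) {
  SmallVector<Metadata *, 4> Ops;
  Ops.push_back(nullptr); // Placeholder for the self-reference.
  if (MDNode *LoopID = L->getLoopID())
    for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I)
      Ops.push_back(LoopID->getOperand(I));
  Ops.push_back(
      MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.runtime.disable")));

  MDNode *NewLoopID = MDNode::getDistinct(Ctx, Ops);
  NewLoopID->replaceOperandWith(0, NewLoopID);
  L->setLoopID(NewLoopID);
}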
ILV.fixVectorizedLoop(State, BestVPlan); ILV.printDebugTracesAtEnd(); + + return State.ExpandedSCEVs; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -7725,8 +7816,6 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) { } #endif -Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } - //===--------------------------------------------------------------------===// // EpilogueVectorizerMainLoop //===--------------------------------------------------------------------===// @@ -7734,7 +7823,8 @@ Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } /// This function is partially responsible for generating the control flow /// depicted in https://llvm.org/docs/Vectorizers.html#epilogue-vectorization. std::pair<BasicBlock *, Value *> -EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() { +EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton( + const SCEV2ValueTy &ExpandedSCEVs) { createVectorLoopSkeleton(""); // Generate the code to check the minimum iteration count of the vector @@ -7795,7 +7885,7 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, assert(Bypass && "Expected valid bypass basic block."); ElementCount VFactor = ForEpilogue ? EPI.EpilogueVF : VF; unsigned UFactor = ForEpilogue ? EPI.EpilogueUF : UF; - Value *Count = getOrCreateTripCount(LoopVectorPreHeader); + Value *Count = getTripCount(); // Reuse existing vector loop preheader for TC checks. // Note that new preheader block is generated for vector loop. BasicBlock *const TCCheckBlock = LoopVectorPreHeader; @@ -7803,8 +7893,10 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, // Generate code to check if the loop's trip count is less than VF * UF of the // main vector loop. - auto P = Cost->requiresScalarEpilogue(ForEpilogue ? EPI.EpilogueVF : VF) ? - ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT; + auto P = Cost->requiresScalarEpilogue(ForEpilogue ? EPI.EpilogueVF.isVector() + : VF.isVector()) + ? ICmpInst::ICMP_ULE + : ICmpInst::ICMP_ULT; Value *CheckMinIters = Builder.CreateICmp( P, Count, createStepForVF(Builder, Count->getType(), VFactor, UFactor), @@ -7824,7 +7916,7 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, // Update dominator for Bypass & LoopExit. DT->changeImmediateDominator(Bypass, TCCheckBlock); - if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF)) + if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF.isVector())) // For loops with multiple exits, there's no edge from the middle block // to exit blocks (as the epilogue must run) and thus no need to update // the immediate dominator of the exit blocks. @@ -7852,7 +7944,8 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, /// This function is partially responsible for generating the control flow /// depicted in https://llvm.org/docs/Vectorizers.html#epilogue-vectorization. 
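The iteration-count checks above now key requiresScalarEpilogue() on whether the VF is a vector, but the predicate choice is unchanged: when a scalar epilogue must run, the vector loop is bypassed even when the trip count exactly equals VF * UF, so at least one iteration is left for the scalar loop. A standalone sketch (illustrative trip counts) of that decision:

#include <cstdint>
#include <cstdio>

// True when we must skip the vector loop and go straight to the scalar loop.
static bool bypassVectorLoop(std::uint64_t TripCount, unsigned VF, unsigned UF,
                             bool RequiresScalarEpilogue) {
  std::uint64_t Step = std::uint64_t(VF) * UF;
  // ICMP_ULE when an epilogue iteration must remain, ICMP_ULT otherwise.
  return RequiresScalarEpilogue ? TripCount <= Step : TripCount < Step;
}

int main() {
  // TC == VF * UF: without a required epilogue the vector loop covers
  // everything; with one, it must be skipped so iterations remain.
  std::printf("TC=8, VF*UF=8, no epilogue:  %d\n",
              bypassVectorLoop(8, 4, 2, false)); // 0 -> run vector loop
  std::printf("TC=8, VF*UF=8, epilogue req: %d\n",
              bypassVectorLoop(8, 4, 2, true));  // 1 -> bypass vector loop
  return 0;
}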
std::pair<BasicBlock *, Value *> -EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { +EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton( + const SCEV2ValueTy &ExpandedSCEVs) { createVectorLoopSkeleton("vec.epilog."); // Now, compare the remaining count and if there aren't enough iterations to @@ -7891,7 +7984,7 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { DT->changeImmediateDominator(LoopScalarPreHeader, EPI.EpilogueIterationCountCheck); - if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF)) + if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF.isVector())) // If there is an epilogue which must run, there's no edge from the // middle block to exit blocks and thus no need to update the immediate // dominator of the exit blocks. @@ -7950,7 +8043,8 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { // check, then the resume value for the induction variable comes from // the trip count of the main vector loop, hence passing the AdditionalBypass // argument. - createInductionResumeValues({VecEpilogueIterationCountCheck, + createInductionResumeValues(ExpandedSCEVs, + {VecEpilogueIterationCountCheck, EPI.VectorTripCount} /* AdditionalBypass */); return {completeLoopSkeleton(), EPResumeVal}; @@ -7972,8 +8066,9 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck( // Generate code to check if the loop's trip count is less than VF * UF of the // vector epilogue loop. - auto P = Cost->requiresScalarEpilogue(EPI.EpilogueVF) ? - ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT; + auto P = Cost->requiresScalarEpilogue(EPI.EpilogueVF.isVector()) + ? ICmpInst::ICMP_ULE + : ICmpInst::ICMP_ULT; Value *CheckMinIters = Builder.CreateICmp(P, Count, @@ -8008,8 +8103,7 @@ bool LoopVectorizationPlanner::getDecisionAndClampRange( assert(!Range.isEmpty() && "Trying to test an empty VF range."); bool PredicateAtRangeStart = Predicate(Range.Start); - for (ElementCount TmpVF = Range.Start * 2; - ElementCount::isKnownLT(TmpVF, Range.End); TmpVF *= 2) + for (ElementCount TmpVF : VFRange(Range.Start * 2, Range.End)) if (Predicate(TmpVF) != PredicateAtRangeStart) { Range.End = TmpVF; break; @@ -8025,16 +8119,16 @@ bool LoopVectorizationPlanner::getDecisionAndClampRange( /// buildVPlan(). void LoopVectorizationPlanner::buildVPlans(ElementCount MinVF, ElementCount MaxVF) { - auto MaxVFPlusOne = MaxVF.getWithIncrement(1); - for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFPlusOne);) { - VFRange SubRange = {VF, MaxVFPlusOne}; + auto MaxVFTimes2 = MaxVF * 2; + for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) { + VFRange SubRange = {VF, MaxVFTimes2}; VPlans.push_back(buildVPlan(SubRange)); VF = SubRange.End; } } VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, - VPlanPtr &Plan) { + VPlan &Plan) { assert(is_contained(predecessors(Dst), Src) && "Invalid edge"); // Look for cached value. @@ -8058,7 +8152,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, if (OrigLoop->isLoopExiting(Src)) return EdgeMaskCache[Edge] = SrcMask; - VPValue *EdgeMask = Plan->getOrAddVPValue(BI->getCondition()); + VPValue *EdgeMask = Plan.getVPValueOrAddLiveIn(BI->getCondition()); assert(EdgeMask && "No Edge Mask found for condition"); if (BI->getSuccessor(0) != Dst) @@ -8069,7 +8163,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, // 'select i1 SrcMask, i1 EdgeMask, i1 false'. 
// The select version does not introduce new UB if SrcMask is false and // EdgeMask is poison. Using 'and' here introduces undefined behavior. - VPValue *False = Plan->getOrAddVPValue( + VPValue *False = Plan.getVPValueOrAddLiveIn( ConstantInt::getFalse(BI->getCondition()->getType())); EdgeMask = Builder.createSelect(SrcMask, EdgeMask, False, BI->getDebugLoc()); @@ -8078,7 +8172,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, return EdgeMaskCache[Edge] = EdgeMask; } -VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { +VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) { assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); // Look for cached value. @@ -8098,29 +8192,28 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { // If we're using the active lane mask for control flow, then we get the // mask from the active lane mask PHI that is cached in the VPlan. - PredicationStyle EmitGetActiveLaneMask = CM.TTI.emitGetActiveLaneMask(); - if (EmitGetActiveLaneMask == PredicationStyle::DataAndControlFlow) - return BlockMaskCache[BB] = Plan->getActiveLaneMaskPhi(); + TailFoldingStyle TFStyle = CM.getTailFoldingStyle(); + if (useActiveLaneMaskForControlFlow(TFStyle)) + return BlockMaskCache[BB] = Plan.getActiveLaneMaskPhi(); // Introduce the early-exit compare IV <= BTC to form header block mask. // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by // constructing the desired canonical IV in the header block as its first // non-phi instructions. - VPBasicBlock *HeaderVPBB = - Plan->getVectorLoopRegion()->getEntryBasicBlock(); + VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi(); - auto *IV = new VPWidenCanonicalIVRecipe(Plan->getCanonicalIV()); + auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV()); HeaderVPBB->insert(IV, HeaderVPBB->getFirstNonPhi()); VPBuilder::InsertPointGuard Guard(Builder); Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint); - if (EmitGetActiveLaneMask != PredicationStyle::None) { - VPValue *TC = Plan->getOrCreateTripCount(); + if (useActiveLaneMask(TFStyle)) { + VPValue *TC = Plan.getTripCount(); BlockMask = Builder.createNaryOp(VPInstruction::ActiveLaneMask, {IV, TC}, nullptr, "active.lane.mask"); } else { - VPValue *BTC = Plan->getOrCreateBackedgeTakenCount(); + VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC}); } return BlockMaskCache[BB] = BlockMask; @@ -8168,7 +8261,7 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, VPValue *Mask = nullptr; if (Legal->isMaskRequired(I)) - Mask = createBlockInMask(I->getParent(), Plan); + Mask = createBlockInMask(I->getParent(), *Plan); // Determine if the pointer operand of the access is either consecutive or // reverse consecutive. @@ -8189,22 +8282,11 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, /// Creates a VPWidenIntOrFpInductionRecpipe for \p Phi. If needed, it will also /// insert a recipe to expand the step for the induction recipe. 
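The edge-mask construction above combines the source-block mask and the edge condition with a select rather than an and, so that a poison edge condition on a not-taken edge does not poison the result. A minimal IRBuilder-based sketch (not the recipe-builder code itself) of that choice:

#include "llvm/IR/IRBuilder.h"

using namespace llvm;

// SrcMask: i1 mask of the predecessor block; RawEdgeMask: the branch
// condition, which may be poison on lanes where the edge is not taken.
static Value *combineEdgeMask(IRBuilder<> &Builder, Value *SrcMask,
                              Value *RawEdgeMask) {
  // 'and SrcMask, RawEdgeMask' would be poison whenever RawEdgeMask is
  // poison, even with SrcMask == false. The select instead forces the
  // result to false on masked-off lanes.
  return Builder.CreateSelect(SrcMask, RawEdgeMask, Builder.getFalse(),
                              "edge.mask");
}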
-static VPWidenIntOrFpInductionRecipe *createWidenInductionRecipes( - PHINode *Phi, Instruction *PhiOrTrunc, VPValue *Start, - const InductionDescriptor &IndDesc, LoopVectorizationCostModel &CM, - VPlan &Plan, ScalarEvolution &SE, Loop &OrigLoop, VFRange &Range) { - // Returns true if an instruction \p I should be scalarized instead of - // vectorized for the chosen vectorization factor. - auto ShouldScalarizeInstruction = [&CM](Instruction *I, ElementCount VF) { - return CM.isScalarAfterVectorization(I, VF) || - CM.isProfitableToScalarize(I, VF); - }; - - bool NeedsScalarIVOnly = LoopVectorizationPlanner::getDecisionAndClampRange( - [&](ElementCount VF) { - return ShouldScalarizeInstruction(PhiOrTrunc, VF); - }, - Range); +static VPWidenIntOrFpInductionRecipe * +createWidenInductionRecipes(PHINode *Phi, Instruction *PhiOrTrunc, + VPValue *Start, const InductionDescriptor &IndDesc, + VPlan &Plan, ScalarEvolution &SE, Loop &OrigLoop, + VFRange &Range) { assert(IndDesc.getStartValue() == Phi->getIncomingValueForBlock(OrigLoop.getLoopPreheader())); assert(SE.isLoopInvariant(IndDesc.getStep(), &OrigLoop) && @@ -8213,12 +8295,10 @@ static VPWidenIntOrFpInductionRecipe *createWidenInductionRecipes( VPValue *Step = vputils::getOrCreateVPValueForSCEVExpr(Plan, IndDesc.getStep(), SE); if (auto *TruncI = dyn_cast<TruncInst>(PhiOrTrunc)) { - return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc, TruncI, - !NeedsScalarIVOnly); + return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc, TruncI); } assert(isa<PHINode>(PhiOrTrunc) && "must be a phi node here"); - return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc, - !NeedsScalarIVOnly); + return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc); } VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI( @@ -8227,14 +8307,13 @@ VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI( // Check if this is an integer or fp induction. If so, build the recipe that // produces its scalar and vector values. if (auto *II = Legal->getIntOrFpInductionDescriptor(Phi)) - return createWidenInductionRecipes(Phi, Phi, Operands[0], *II, CM, Plan, + return createWidenInductionRecipes(Phi, Phi, Operands[0], *II, Plan, *PSE.getSE(), *OrigLoop, Range); // Check if this is pointer induction. If so, build the recipe for it. 
if (auto *II = Legal->getPointerInductionDescriptor(Phi)) { VPValue *Step = vputils::getOrCreateVPValueForSCEVExpr(Plan, II->getStep(), *PSE.getSE()); - assert(isa<SCEVConstant>(II->getStep())); return new VPWidenPointerInductionRecipe( Phi, Operands[0], Step, *II, LoopVectorizationPlanner::getDecisionAndClampRange( @@ -8267,9 +8346,9 @@ VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionTruncate( auto *Phi = cast<PHINode>(I->getOperand(0)); const InductionDescriptor &II = *Legal->getIntOrFpInductionDescriptor(Phi); - VPValue *Start = Plan.getOrAddVPValue(II.getStartValue()); - return createWidenInductionRecipes(Phi, I, Start, II, CM, Plan, - *PSE.getSE(), *OrigLoop, Range); + VPValue *Start = Plan.getVPValueOrAddLiveIn(II.getStartValue()); + return createWidenInductionRecipes(Phi, I, Start, II, Plan, *PSE.getSE(), + *OrigLoop, Range); } return nullptr; } @@ -8309,7 +8388,7 @@ VPRecipeOrVPValueTy VPRecipeBuilder::tryToBlend(PHINode *Phi, for (unsigned In = 0; In < NumIncoming; In++) { VPValue *EdgeMask = - createEdgeMask(Phi->getIncomingBlock(In), Phi->getParent(), Plan); + createEdgeMask(Phi->getIncomingBlock(In), Phi->getParent(), *Plan); assert((EdgeMask || NumIncoming == 1) && "Multiple predecessors with one having a full mask"); OperandsWithMask.push_back(Operands[In]); @@ -8321,8 +8400,8 @@ VPRecipeOrVPValueTy VPRecipeBuilder::tryToBlend(PHINode *Phi, VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, ArrayRef<VPValue *> Operands, - VFRange &Range) const { - + VFRange &Range, + VPlanPtr &Plan) { bool IsPredicated = LoopVectorizationPlanner::getDecisionAndClampRange( [this, CI](ElementCount VF) { return CM.isScalarWithPredication(CI, VF); @@ -8339,17 +8418,17 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, ID == Intrinsic::experimental_noalias_scope_decl)) return nullptr; - ArrayRef<VPValue *> Ops = Operands.take_front(CI->arg_size()); + SmallVector<VPValue *, 4> Ops(Operands.take_front(CI->arg_size())); // Is it beneficial to perform intrinsic call compared to lib call? bool ShouldUseVectorIntrinsic = ID && LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) -> bool { - bool NeedToScalarize = false; + Function *Variant; // Is it beneficial to perform intrinsic call compared to lib // call? InstructionCost CallCost = - CM.getVectorCallCost(CI, VF, NeedToScalarize); + CM.getVectorCallCost(CI, VF, &Variant); InstructionCost IntrinsicCost = CM.getVectorIntrinsicCost(CI, VF); return IntrinsicCost <= CallCost; @@ -8358,6 +8437,9 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, if (ShouldUseVectorIntrinsic) return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), ID); + Function *Variant = nullptr; + ElementCount VariantVF; + bool NeedsMask = false; // Is better to call a vectorized version of the function than to to scalarize // the call? auto ShouldUseVectorCall = LoopVectorizationPlanner::getDecisionAndClampRange( @@ -8365,14 +8447,57 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, // The following case may be scalarized depending on the VF. // The flag shows whether we can use a usual Call for vectorized // version of the instruction. - bool NeedToScalarize = false; - CM.getVectorCallCost(CI, VF, NeedToScalarize); - return !NeedToScalarize; + + // If we've found a variant at a previous VF, then stop looking. 
A + // vectorized variant of a function expects input in a certain shape + // -- basically the number of input registers, the number of lanes + // per register, and whether there's a mask required. + // We store a pointer to the variant in the VPWidenCallRecipe, so + // once we have an appropriate variant it's only valid for that VF. + // This will force a different vplan to be generated for each VF that + // finds a valid variant. + if (Variant) + return false; + CM.getVectorCallCost(CI, VF, &Variant, &NeedsMask); + // If we found a valid vector variant at this VF, then store the VF + // in case we need to generate a mask. + if (Variant) + VariantVF = VF; + return Variant != nullptr; }, Range); - if (ShouldUseVectorCall) + if (ShouldUseVectorCall) { + if (NeedsMask) { + // We have 2 cases that would require a mask: + // 1) The block needs to be predicated, either due to a conditional + // in the scalar loop or use of an active lane mask with + // tail-folding, and we use the appropriate mask for the block. + // 2) No mask is required for the block, but the only available + // vector variant at this VF requires a mask, so we synthesize an + // all-true mask. + VPValue *Mask = nullptr; + if (Legal->isMaskRequired(CI)) + Mask = createBlockInMask(CI->getParent(), *Plan); + else + Mask = Plan->getVPValueOrAddLiveIn(ConstantInt::getTrue( + IntegerType::getInt1Ty(Variant->getFunctionType()->getContext()))); + + VFShape Shape = VFShape::get(*CI, VariantVF, /*HasGlobalPred=*/true); + unsigned MaskPos = 0; + + for (const VFInfo &Info : VFDatabase::getMappings(*CI)) + if (Info.Shape == Shape) { + assert(Info.isMasked() && "Vector function info shape mismatch"); + MaskPos = Info.getParamIndexForOptionalMask().value(); + break; + } + + Ops.insert(Ops.begin() + MaskPos, Mask); + } + return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), - Intrinsic::not_intrinsic); + Intrinsic::not_intrinsic, Variant); + } return nullptr; } @@ -8405,9 +8530,9 @@ VPRecipeBase *VPRecipeBuilder::tryToWiden(Instruction *I, // div/rem operation itself. Otherwise fall through to general handling below. 
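The added code above has to splice a mask operand into the call at whatever parameter position the chosen vector variant expects, synthesizing an all-true mask when the block itself is unpredicated. A rough standalone sketch of that operand surgery; the VariantInfo struct and the foo_v4_masked name are invented stand-ins for what VFDatabase actually provides:

```cpp
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Hypothetical description of one vectorized variant of a scalar call.
struct VariantInfo {
  std::string Name;
  unsigned VF;
  std::optional<unsigned> MaskParamPos; // position of the mask operand, if any
};

// Insert an all-true mask operand at the position the variant expects.
// Operands are plain strings here; in the vectorizer they are VPValues.
static std::vector<std::string> addMaskOperand(std::vector<std::string> Ops,
                                               const VariantInfo &Info) {
  if (!Info.MaskParamPos)
    return Ops; // unmasked variant, nothing to do
  Ops.insert(Ops.begin() + *Info.MaskParamPos, "all-true-mask");
  return Ops;
}

int main() {
  // E.g. a masked 4-wide variant of foo(x, y) taking the mask as the last
  // parameter: foo_v4_masked(x, y, mask).
  VariantInfo Variant{"foo_v4_masked", 4, 2};
  std::vector<std::string> Ops = {"x.vec", "y.vec"};
  for (const std::string &Op : addMaskOperand(Ops, Variant))
    std::cout << Op << '\n';
  return 0;
}
```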
if (CM.isPredicatedInst(I)) { SmallVector<VPValue *> Ops(Operands.begin(), Operands.end()); - VPValue *Mask = createBlockInMask(I->getParent(), Plan); - VPValue *One = - Plan->getOrAddExternalDef(ConstantInt::get(I->getType(), 1u, false)); + VPValue *Mask = createBlockInMask(I->getParent(), *Plan); + VPValue *One = Plan->getVPValueOrAddLiveIn( + ConstantInt::get(I->getType(), 1u, false)); auto *SafeRHS = new VPInstruction(Instruction::Select, {Mask, Ops[1], One}, I->getDebugLoc()); @@ -8415,38 +8540,26 @@ VPRecipeBase *VPRecipeBuilder::tryToWiden(Instruction *I, Ops[1] = SafeRHS; return new VPWidenRecipe(*I, make_range(Ops.begin(), Ops.end())); } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Instruction::Add: case Instruction::And: case Instruction::AShr: - case Instruction::BitCast: case Instruction::FAdd: case Instruction::FCmp: case Instruction::FDiv: case Instruction::FMul: case Instruction::FNeg: - case Instruction::FPExt: - case Instruction::FPToSI: - case Instruction::FPToUI: - case Instruction::FPTrunc: case Instruction::FRem: case Instruction::FSub: case Instruction::ICmp: - case Instruction::IntToPtr: case Instruction::LShr: case Instruction::Mul: case Instruction::Or: - case Instruction::PtrToInt: case Instruction::Select: - case Instruction::SExt: case Instruction::Shl: - case Instruction::SIToFP: case Instruction::Sub: - case Instruction::Trunc: - case Instruction::UIToFP: case Instruction::Xor: - case Instruction::ZExt: case Instruction::Freeze: return new VPWidenRecipe(*I, make_range(Operands.begin(), Operands.end())); }; @@ -8462,9 +8575,9 @@ void VPRecipeBuilder::fixHeaderPhis() { } } -VPBasicBlock *VPRecipeBuilder::handleReplication( - Instruction *I, VFRange &Range, VPBasicBlock *VPBB, - VPlanPtr &Plan) { +VPRecipeOrVPValueTy VPRecipeBuilder::handleReplication(Instruction *I, + VFRange &Range, + VPlan &Plan) { bool IsUniform = LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { return CM.isUniformAfterVectorization(I, VF); }, Range); @@ -8501,83 +8614,22 @@ VPBasicBlock *VPRecipeBuilder::handleReplication( break; } } - - auto *Recipe = new VPReplicateRecipe(I, Plan->mapToVPValues(I->operands()), - IsUniform, IsPredicated); - - // Find if I uses a predicated instruction. If so, it will use its scalar - // value. Avoid hoisting the insert-element which packs the scalar value into - // a vector value, as that happens iff all users use the vector value. - for (VPValue *Op : Recipe->operands()) { - auto *PredR = - dyn_cast_or_null<VPPredInstPHIRecipe>(Op->getDefiningRecipe()); - if (!PredR) - continue; - auto *RepR = cast<VPReplicateRecipe>( - PredR->getOperand(0)->getDefiningRecipe()); - assert(RepR->isPredicated() && - "expected Replicate recipe to be predicated"); - RepR->setAlsoPack(false); - } - - // Finalize the recipe for Instr, first if it is not predicated. + VPValue *BlockInMask = nullptr; if (!IsPredicated) { + // Finalize the recipe for Instr, first if it is not predicated. LLVM_DEBUG(dbgs() << "LV: Scalarizing:" << *I << "\n"); - setRecipe(I, Recipe); - Plan->addVPValue(I, Recipe); - VPBB->appendRecipe(Recipe); - return VPBB; - } - LLVM_DEBUG(dbgs() << "LV: Scalarizing and predicating:" << *I << "\n"); - - VPBlockBase *SingleSucc = VPBB->getSingleSuccessor(); - assert(SingleSucc && "VPBB must have a single successor when handling " - "predicated replication."); - VPBlockUtils::disconnectBlocks(VPBB, SingleSucc); - // Record predicated instructions for above packing optimizations. 
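The predicated div/rem handling above keeps the widened operation from faulting by selecting a divisor of 1 in the masked-off lanes. A scalar model of why that is safe (the masked-off quotients are still computed, but discarded):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Per-lane inputs for a 4-wide predicated "a / b": lane 2 is masked off
  // and happens to carry a zero divisor.
  std::vector<int> A = {10, 9, 7, 12};
  std::vector<int> B = {2, 3, 0, 4};
  std::vector<bool> Mask = {true, true, false, true};

  std::vector<int> Result(A.size(), 0);
  for (size_t Lane = 0; Lane < A.size(); ++Lane) {
    // select(mask, b, 1): masked-off lanes divide by 1 instead of trapping.
    int SafeB = Mask[Lane] ? B[Lane] : 1;
    int Q = A[Lane] / SafeB;
    // Only lanes with an active mask keep their result.
    if (Mask[Lane])
      Result[Lane] = Q;
  }

  for (size_t Lane = 0; Lane < Result.size(); ++Lane)
    std::printf("lane %zu: %d\n", Lane, Result[Lane]); // 5 3 0 3
  return 0;
}
```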
- VPBlockBase *Region = createReplicateRegion(Recipe, Plan); - VPBlockUtils::insertBlockAfter(Region, VPBB); - auto *RegSucc = new VPBasicBlock(); - VPBlockUtils::insertBlockAfter(RegSucc, Region); - VPBlockUtils::connectBlocks(RegSucc, SingleSucc); - return RegSucc; -} - -VPRegionBlock * -VPRecipeBuilder::createReplicateRegion(VPReplicateRecipe *PredRecipe, - VPlanPtr &Plan) { - Instruction *Instr = PredRecipe->getUnderlyingInstr(); - // Instructions marked for predication are replicated and placed under an - // if-then construct to prevent side-effects. - // Generate recipes to compute the block mask for this region. - VPValue *BlockInMask = createBlockInMask(Instr->getParent(), Plan); - - // Build the triangular if-then region. - std::string RegionName = (Twine("pred.") + Instr->getOpcodeName()).str(); - assert(Instr->getParent() && "Predicated instruction not in any basic block"); - auto *BOMRecipe = new VPBranchOnMaskRecipe(BlockInMask); - auto *Entry = new VPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe); - auto *PHIRecipe = Instr->getType()->isVoidTy() - ? nullptr - : new VPPredInstPHIRecipe(PredRecipe); - if (PHIRecipe) { - setRecipe(Instr, PHIRecipe); - Plan->addVPValue(Instr, PHIRecipe); } else { - setRecipe(Instr, PredRecipe); - Plan->addVPValue(Instr, PredRecipe); + LLVM_DEBUG(dbgs() << "LV: Scalarizing and predicating:" << *I << "\n"); + // Instructions marked for predication are replicated and a mask operand is + // added initially. Masked replicate recipes will later be placed under an + // if-then construct to prevent side-effects. Generate recipes to compute + // the block mask for this region. + BlockInMask = createBlockInMask(I->getParent(), Plan); } - auto *Exiting = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); - auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", PredRecipe); - VPRegionBlock *Region = new VPRegionBlock(Entry, Exiting, RegionName, true); - - // Note: first set Entry as region entry and then connect successors starting - // from it in order, to propagate the "parent" of each VPBasicBlock. 
- VPBlockUtils::insertTwoBlocksAfter(Pred, Exiting, Entry); - VPBlockUtils::connectBlocks(Pred, Exiting); - - return Region; + auto *Recipe = new VPReplicateRecipe(I, Plan.mapToVPValues(I->operands()), + IsUniform, BlockInMask); + return toVPRecipeResult(Recipe); } VPRecipeOrVPValueTy @@ -8643,7 +8695,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, return nullptr; if (auto *CI = dyn_cast<CallInst>(Instr)) - return toVPRecipeResult(tryToWidenCall(CI, Operands, Range)); + return toVPRecipeResult(tryToWidenCall(CI, Operands, Range, Plan)); if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr)) return toVPRecipeResult(tryToWidenMemory(Instr, Operands, Range, Plan)); @@ -8653,13 +8705,16 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, if (auto GEP = dyn_cast<GetElementPtrInst>(Instr)) return toVPRecipeResult(new VPWidenGEPRecipe( - GEP, make_range(Operands.begin(), Operands.end()), OrigLoop)); + GEP, make_range(Operands.begin(), Operands.end()))); if (auto *SI = dyn_cast<SelectInst>(Instr)) { - bool InvariantCond = - PSE.getSE()->isLoopInvariant(PSE.getSCEV(SI->getOperand(0)), OrigLoop); return toVPRecipeResult(new VPWidenSelectRecipe( - *SI, make_range(Operands.begin(), Operands.end()), InvariantCond)); + *SI, make_range(Operands.begin(), Operands.end()))); + } + + if (auto *CI = dyn_cast<CastInst>(Instr)) { + return toVPRecipeResult( + new VPWidenCastRecipe(CI->getOpcode(), Operands[0], CI->getType(), CI)); } return toVPRecipeResult(tryToWiden(Instr, Operands, VPBB, Plan)); @@ -8677,34 +8732,11 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, auto &ConditionalAssumes = Legal->getConditionalAssumes(); DeadInstructions.insert(ConditionalAssumes.begin(), ConditionalAssumes.end()); - MapVector<Instruction *, Instruction *> &SinkAfter = Legal->getSinkAfter(); - // Dead instructions do not need sinking. Remove them from SinkAfter. - for (Instruction *I : DeadInstructions) - SinkAfter.erase(I); - - // Cannot sink instructions after dead instructions (there won't be any - // recipes for them). Instead, find the first non-dead previous instruction. - for (auto &P : Legal->getSinkAfter()) { - Instruction *SinkTarget = P.second; - Instruction *FirstInst = &*SinkTarget->getParent()->begin(); - (void)FirstInst; - while (DeadInstructions.contains(SinkTarget)) { - assert( - SinkTarget != FirstInst && - "Must find a live instruction (at least the one feeding the " - "fixed-order recurrence PHI) before reaching beginning of the block"); - SinkTarget = SinkTarget->getPrevNode(); - assert(SinkTarget != P.first && - "sink source equals target, no sinking required"); - } - P.second = SinkTarget; - } - - auto MaxVFPlusOne = MaxVF.getWithIncrement(1); - for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFPlusOne);) { - VFRange SubRange = {VF, MaxVFPlusOne}; - VPlans.push_back( - buildVPlanWithVPRecipes(SubRange, DeadInstructions, SinkAfter)); + auto MaxVFTimes2 = MaxVF * 2; + for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) { + VFRange SubRange = {VF, MaxVFTimes2}; + if (auto Plan = tryToBuildVPlanWithVPRecipes(SubRange, DeadInstructions)) + VPlans.push_back(std::move(*Plan)); VF = SubRange.End; } } @@ -8712,10 +8744,9 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, // Add the necessary canonical IV and branch recipes required to control the // loop. 
static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, - bool HasNUW, - bool UseLaneMaskForLoopControlFlow) { + TailFoldingStyle Style) { Value *StartIdx = ConstantInt::get(IdxTy, 0); - auto *StartV = Plan.getOrAddVPValue(StartIdx); + auto *StartV = Plan.getVPValueOrAddLiveIn(StartIdx); // Add a VPCanonicalIVPHIRecipe starting at 0 to the header. auto *CanonicalIVPHI = new VPCanonicalIVPHIRecipe(StartV, DL); @@ -8725,6 +8756,7 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, // Add a CanonicalIVIncrement{NUW} VPInstruction to increment the scalar // IV by VF * UF. + bool HasNUW = Style == TailFoldingStyle::None; auto *CanonicalIVIncrement = new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementNUW : VPInstruction::CanonicalIVIncrement, @@ -8732,11 +8764,10 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, CanonicalIVPHI->addOperand(CanonicalIVIncrement); VPBasicBlock *EB = TopRegion->getExitingBasicBlock(); - EB->appendRecipe(CanonicalIVIncrement); - - if (UseLaneMaskForLoopControlFlow) { + if (useActiveLaneMaskForControlFlow(Style)) { // Create the active lane mask instruction in the vplan preheader. - VPBasicBlock *Preheader = Plan.getEntry()->getEntryBasicBlock(); + VPBasicBlock *VecPreheader = + cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSinglePredecessor()); // We can't use StartV directly in the ActiveLaneMask VPInstruction, since // we have to take unrolling into account. Each part needs to start at @@ -8745,14 +8776,34 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementForPartNUW : VPInstruction::CanonicalIVIncrementForPart, {StartV}, DL, "index.part.next"); - Preheader->appendRecipe(CanonicalIVIncrementParts); + VecPreheader->appendRecipe(CanonicalIVIncrementParts); // Create the ActiveLaneMask instruction using the correct start values. - VPValue *TC = Plan.getOrCreateTripCount(); + VPValue *TC = Plan.getTripCount(); + + VPValue *TripCount, *IncrementValue; + if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) { + // When avoiding a runtime check, the active.lane.mask inside the loop + // uses a modified trip count and the induction variable increment is + // done after the active.lane.mask intrinsic is called. + auto *TCMinusVF = + new VPInstruction(VPInstruction::CalculateTripCountMinusVF, {TC}, DL); + VecPreheader->appendRecipe(TCMinusVF); + IncrementValue = CanonicalIVPHI; + TripCount = TCMinusVF; + } else { + // When the loop is guarded by a runtime overflow check for the loop + // induction variable increment by VF, we can increment the value before + // the get.active.lane mask and use the unmodified tripcount. + EB->appendRecipe(CanonicalIVIncrement); + IncrementValue = CanonicalIVIncrement; + TripCount = TC; + } + auto *EntryALM = new VPInstruction(VPInstruction::ActiveLaneMask, {CanonicalIVIncrementParts, TC}, DL, "active.lane.mask.entry"); - Preheader->appendRecipe(EntryALM); + VecPreheader->appendRecipe(EntryALM); // Now create the ActiveLaneMaskPhi recipe in the main loop using the // preheader ActiveLaneMask instruction. @@ -8763,15 +8814,21 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, CanonicalIVIncrementParts = new VPInstruction(HasNUW ? 
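The entry lane masks created in the vector preheader above are per unrolled part: part P starts at index P * VF and is compared against the trip count, which is the semantics of llvm.get.active.lane.mask. A scalar model of those entry masks (plain C++ sketch; UF, VF and the trip count are chosen arbitrarily):

```cpp
#include <cstdio>
#include <vector>

// Scalar model of llvm.get.active.lane.mask(base, limit) for one VF-wide
// part: lane i is active iff base + i < limit.
static std::vector<bool> activeLaneMask(unsigned Base, unsigned Limit,
                                        unsigned VF) {
  std::vector<bool> Mask(VF);
  for (unsigned Lane = 0; Lane < VF; ++Lane)
    Mask[Lane] = Base + Lane < Limit;
  return Mask;
}

int main() {
  const unsigned VF = 4, UF = 2, TripCount = 6;
  // Entry masks for the first vector iteration: each unrolled part starts
  // at its own offset (Part * VF), like CanonicalIVIncrementForPart above.
  for (unsigned Part = 0; Part < UF; ++Part) {
    std::vector<bool> Mask = activeLaneMask(Part * VF, TripCount, VF);
    std::printf("part %u entry mask:", Part);
    for (bool Active : Mask)
      std::printf(" %d", Active ? 1 : 0);
    std::printf("\n"); // part 0: 1 1 1 1   part 1: 1 1 0 0
  }
  return 0;
}
```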
VPInstruction::CanonicalIVIncrementForPartNUW : VPInstruction::CanonicalIVIncrementForPart, - {CanonicalIVIncrement}, DL); + {IncrementValue}, DL); EB->appendRecipe(CanonicalIVIncrementParts); auto *ALM = new VPInstruction(VPInstruction::ActiveLaneMask, - {CanonicalIVIncrementParts, TC}, DL, + {CanonicalIVIncrementParts, TripCount}, DL, "active.lane.mask.next"); EB->appendRecipe(ALM); LaneMaskPhi->addOperand(ALM); + if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) { + // Do the increment of the canonical IV after the active.lane.mask, because + // that value is still based off %CanonicalIVPHI + EB->appendRecipe(CanonicalIVIncrement); + } + // We have to invert the mask here because a true condition means jumping // to the exit block. auto *NotMask = new VPInstruction(VPInstruction::Not, ALM, DL); @@ -8781,6 +8838,8 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, new VPInstruction(VPInstruction::BranchOnCond, {NotMask}, DL); EB->appendRecipe(BranchBack); } else { + EB->appendRecipe(CanonicalIVIncrement); + // Add the BranchOnCount VPInstruction to the latch. VPInstruction *BranchBack = new VPInstruction( VPInstruction::BranchOnCount, @@ -8804,14 +8863,13 @@ static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, for (PHINode &ExitPhi : ExitBB->phis()) { Value *IncomingValue = ExitPhi.getIncomingValueForBlock(ExitingBB); - VPValue *V = Plan.getOrAddVPValue(IncomingValue, true); + VPValue *V = Plan.getVPValueOrAddLiveIn(IncomingValue); Plan.addLiveOut(&ExitPhi, V); } } -VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( - VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions, - const MapVector<Instruction *, Instruction *> &SinkAfter) { +std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( + VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions) { SmallPtrSet<const InterleaveGroup<Instruction> *, 1> InterleaveGroups; @@ -8822,12 +8880,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // process after constructing the initial VPlan. // --------------------------------------------------------------------------- - // Mark instructions we'll need to sink later and their targets as - // ingredients whose recipe we'll need to record. - for (const auto &Entry : SinkAfter) { - RecipeBuilder.recordRecipeOf(Entry.first); - RecipeBuilder.recordRecipeOf(Entry.second); - } for (const auto &Reduction : CM.getInLoopReductionChains()) { PHINode *Phi = Reduction.first; RecurKind Kind = @@ -8852,9 +8904,15 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // single VPInterleaveRecipe. for (InterleaveGroup<Instruction> *IG : IAI.getInterleaveGroups()) { auto applyIG = [IG, this](ElementCount VF) -> bool { - return (VF.isVector() && // Query is illegal for VF == 1 - CM.getWideningDecision(IG->getInsertPos(), VF) == - LoopVectorizationCostModel::CM_Interleave); + bool Result = (VF.isVector() && // Query is illegal for VF == 1 + CM.getWideningDecision(IG->getInsertPos(), VF) == + LoopVectorizationCostModel::CM_Interleave); + // For scalable vectors, the only interleave factor currently supported + // is 2 since we require the (de)interleave2 intrinsics instead of + // shufflevectors. 
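The two tail-folding styles handled above differ in where the canonical IV increment sits relative to the active-lane-mask computation: either increment first and compare the new index against the full trip count, or compare the old index against trip-count-minus-VF and increment afterwards, which is what avoids the runtime overflow check. A scalar sketch suggesting both schemes execute the same number of vector iterations; it models only lane 0 of the mask and assumes UF = 1:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  const unsigned VF = 4;
  for (unsigned TC : {1u, 4u, 7u, 13u}) {
    // Style A: increment first, compare the new index against TC.
    unsigned IterationsA = 0;
    for (unsigned IV = 0; /*entry mask*/ IV < TC;) {
      ++IterationsA;
      IV += VF;       // CanonicalIVIncrement before the mask
      if (!(IV < TC)) // next active.lane.mask is all-false
        break;
    }
    // Style B (DataAndControlFlowWithoutRuntimeCheck): compare the old index
    // against TC - VF, then increment. TC - VF is clamped at zero, matching
    // CalculateTripCountMinusVF.
    unsigned TCMinusVF = TC > VF ? TC - VF : 0;
    unsigned IterationsB = 0;
    for (unsigned IV = 0; /*entry mask*/ IV < TC;) {
      ++IterationsB;
      bool NextMaskNonEmpty = IV < TCMinusVF; // mask built from the old IV
      IV += VF;                               // increment after the mask
      if (!NextMaskNonEmpty)
        break;
    }
    unsigned Expected = (TC + VF - 1) / VF;
    assert(IterationsA == Expected && IterationsB == Expected);
    std::printf("TC=%u -> %u vector iterations (both styles)\n", TC,
                IterationsA);
  }
  return 0;
}
```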
+ assert((!Result || !VF.isScalable() || IG->getFactor() == 2) && + "Unsupported interleave factor for scalable vectors"); + return Result; }; if (!getDecisionAndClampRange(applyIG, Range)) continue; @@ -8869,26 +8927,34 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // visit each basic block after having visited its predecessor basic blocks. // --------------------------------------------------------------------------- - // Create initial VPlan skeleton, starting with a block for the pre-header, - // followed by a region for the vector loop, followed by the middle block. The - // skeleton vector loop region contains a header and latch block. - VPBasicBlock *Preheader = new VPBasicBlock("vector.ph"); - auto Plan = std::make_unique<VPlan>(Preheader); - + // Create initial VPlan skeleton, having a basic block for the pre-header + // which contains SCEV expansions that need to happen before the CFG is + // modified; a basic block for the vector pre-header, followed by a region for + // the vector loop, followed by the middle basic block. The skeleton vector + // loop region contains a header and latch basic blocks. + VPlanPtr Plan = VPlan::createInitialVPlan( + createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop), + *PSE.getSE()); VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body"); VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch"); VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB); auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop"); - VPBlockUtils::insertBlockAfter(TopRegion, Preheader); + VPBlockUtils::insertBlockAfter(TopRegion, Plan->getEntry()); VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block"); VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion); + // Don't use getDecisionAndClampRange here, because we don't know the UF + // so this function is better to be conservative, rather than to split + // it up into different VPlans. + bool IVUpdateMayOverflow = false; + for (ElementCount VF : Range) + IVUpdateMayOverflow |= !isIndvarOverflowCheckKnownFalse(&CM, VF); + Instruction *DLInst = getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()); addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), DLInst ? DLInst->getDebugLoc() : DebugLoc(), - !CM.foldTailByMasking(), - CM.useActiveLaneMaskForControlFlow()); + CM.getTailFoldingStyle(IVUpdateMayOverflow)); // Scan the body of the loop in a topological order to visit each basic block // after having visited its predecessor basic blocks. @@ -8896,18 +8962,16 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( DFS.perform(LI); VPBasicBlock *VPBB = HeaderVPBB; - SmallVector<VPWidenIntOrFpInductionRecipe *> InductionsToMove; for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { // Relevant instructions from basic block BB will be grouped into VPRecipe // ingredients and fill a new VPBasicBlock. - unsigned VPBBsForBB = 0; if (VPBB != HeaderVPBB) VPBB->setName(BB->getName()); Builder.setInsertPoint(VPBB); // Introduce each ingredient into VPlan. // TODO: Model and preserve debug intrinsics in VPlan. 
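The new assert above reflects that scalable interleave groups are only supported with factor 2, because they are lowered via the (de)interleave2 intrinsics rather than shufflevector. A scalar model of what deinterleave2 does to one wide interleaved load:

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Scalar model of llvm.vector.deinterleave2: split an interleaved vector
// {a0, b0, a1, b1, ...} into its even-lane and odd-lane halves.
static std::pair<std::vector<int>, std::vector<int>>
deinterleave2(const std::vector<int> &Wide) {
  std::vector<int> Even, Odd;
  for (size_t I = 0; I < Wide.size(); I += 2) {
    Even.push_back(Wide[I]);
    Odd.push_back(Wide[I + 1]);
  }
  return {Even, Odd};
}

int main() {
  // Memory layout of an interleave group with factor 2: the x and y members
  // of consecutive structs, loaded as one wide vector.
  std::vector<int> Wide = {10, 100, 11, 101, 12, 102, 13, 103};
  auto [X, Y] = deinterleave2(Wide);
  for (int V : X)
    std::printf("%d ", V); // 10 11 12 13
  std::printf("\n");
  for (int V : Y)
    std::printf("%d ", V); // 100 101 102 103
  std::printf("\n");
  return 0;
}
```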
- for (Instruction &I : BB->instructionsWithoutDebug()) { + for (Instruction &I : BB->instructionsWithoutDebug(false)) { Instruction *Instr = &I; // First filter out irrelevant instructions, to ensure no recipes are @@ -8918,7 +8982,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( SmallVector<VPValue *, 4> Operands; auto *Phi = dyn_cast<PHINode>(Instr); if (Phi && Phi->getParent() == OrigLoop->getHeader()) { - Operands.push_back(Plan->getOrAddVPValue( + Operands.push_back(Plan->getVPValueOrAddLiveIn( Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()))); } else { auto OpRange = Plan->mapToVPValues(Instr->operands()); @@ -8932,50 +8996,36 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( Legal->isInvariantAddressOfReduction(SI->getPointerOperand())) continue; - if (auto RecipeOrValue = RecipeBuilder.tryToCreateWidenRecipe( - Instr, Operands, Range, VPBB, Plan)) { - // If Instr can be simplified to an existing VPValue, use it. - if (RecipeOrValue.is<VPValue *>()) { - auto *VPV = RecipeOrValue.get<VPValue *>(); - Plan->addVPValue(Instr, VPV); - // If the re-used value is a recipe, register the recipe for the - // instruction, in case the recipe for Instr needs to be recorded. - if (VPRecipeBase *R = VPV->getDefiningRecipe()) - RecipeBuilder.setRecipe(Instr, R); - continue; - } - // Otherwise, add the new recipe. - VPRecipeBase *Recipe = RecipeOrValue.get<VPRecipeBase *>(); - for (auto *Def : Recipe->definedValues()) { - auto *UV = Def->getUnderlyingValue(); - Plan->addVPValue(UV, Def); - } - - if (isa<VPWidenIntOrFpInductionRecipe>(Recipe) && - HeaderVPBB->getFirstNonPhi() != VPBB->end()) { - // Keep track of VPWidenIntOrFpInductionRecipes not in the phi section - // of the header block. That can happen for truncates of induction - // variables. Those recipes are moved to the phi section of the header - // block after applying SinkAfter, which relies on the original - // position of the trunc. - assert(isa<TruncInst>(Instr)); - InductionsToMove.push_back( - cast<VPWidenIntOrFpInductionRecipe>(Recipe)); - } - RecipeBuilder.setRecipe(Instr, Recipe); - VPBB->appendRecipe(Recipe); + auto RecipeOrValue = RecipeBuilder.tryToCreateWidenRecipe( + Instr, Operands, Range, VPBB, Plan); + if (!RecipeOrValue) + RecipeOrValue = RecipeBuilder.handleReplication(Instr, Range, *Plan); + // If Instr can be simplified to an existing VPValue, use it. + if (isa<VPValue *>(RecipeOrValue)) { + auto *VPV = cast<VPValue *>(RecipeOrValue); + Plan->addVPValue(Instr, VPV); + // If the re-used value is a recipe, register the recipe for the + // instruction, in case the recipe for Instr needs to be recorded. + if (VPRecipeBase *R = VPV->getDefiningRecipe()) + RecipeBuilder.setRecipe(Instr, R); continue; } - - // Otherwise, if all widening options failed, Instruction is to be - // replicated. This may create a successor for VPBB. - VPBasicBlock *NextVPBB = - RecipeBuilder.handleReplication(Instr, Range, VPBB, Plan); - if (NextVPBB != VPBB) { - VPBB = NextVPBB; - VPBB->setName(BB->hasName() ? BB->getName() + "." + Twine(VPBBsForBB++) - : ""); + // Otherwise, add the new recipe. 
+ VPRecipeBase *Recipe = cast<VPRecipeBase *>(RecipeOrValue); + for (auto *Def : Recipe->definedValues()) { + auto *UV = Def->getUnderlyingValue(); + Plan->addVPValue(UV, Def); } + + RecipeBuilder.setRecipe(Instr, Recipe); + if (isa<VPWidenIntOrFpInductionRecipe>(Recipe) && + HeaderVPBB->getFirstNonPhi() != VPBB->end()) { + // Move VPWidenIntOrFpInductionRecipes for optimized truncates to the + // phi section of HeaderVPBB. + assert(isa<TruncInst>(Instr)); + Recipe->insertBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi()); + } else + VPBB->appendRecipe(Recipe); } VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB); @@ -8985,7 +9035,12 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // After here, VPBB should not be used. VPBB = nullptr; - addUsersInExitBlock(HeaderVPBB, MiddleVPBB, OrigLoop, *Plan); + if (CM.requiresScalarEpilogue(Range)) { + // No edge from the middle block to the unique exit block has been inserted + // and there is nothing to fix from vector loop; phis should have incoming + // from scalar loop only. + } else + addUsersInExitBlock(HeaderVPBB, MiddleVPBB, OrigLoop, *Plan); assert(isa<VPRegionBlock>(Plan->getVectorLoopRegion()) && !Plan->getVectorLoopRegion()->getEntryBasicBlock()->empty() && @@ -8998,116 +9053,10 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // bring the VPlan to its final state. // --------------------------------------------------------------------------- - // Apply Sink-After legal constraints. - auto GetReplicateRegion = [](VPRecipeBase *R) -> VPRegionBlock * { - auto *Region = dyn_cast_or_null<VPRegionBlock>(R->getParent()->getParent()); - if (Region && Region->isReplicator()) { - assert(Region->getNumSuccessors() == 1 && - Region->getNumPredecessors() == 1 && "Expected SESE region!"); - assert(R->getParent()->size() == 1 && - "A recipe in an original replicator region must be the only " - "recipe in its block"); - return Region; - } - return nullptr; - }; - for (const auto &Entry : SinkAfter) { - VPRecipeBase *Sink = RecipeBuilder.getRecipe(Entry.first); - VPRecipeBase *Target = RecipeBuilder.getRecipe(Entry.second); - - auto *TargetRegion = GetReplicateRegion(Target); - auto *SinkRegion = GetReplicateRegion(Sink); - if (!SinkRegion) { - // If the sink source is not a replicate region, sink the recipe directly. - if (TargetRegion) { - // The target is in a replication region, make sure to move Sink to - // the block after it, not into the replication region itself. - VPBasicBlock *NextBlock = - cast<VPBasicBlock>(TargetRegion->getSuccessors().front()); - Sink->moveBefore(*NextBlock, NextBlock->getFirstNonPhi()); - } else - Sink->moveAfter(Target); - continue; - } - - // The sink source is in a replicate region. Unhook the region from the CFG. - auto *SinkPred = SinkRegion->getSinglePredecessor(); - auto *SinkSucc = SinkRegion->getSingleSuccessor(); - VPBlockUtils::disconnectBlocks(SinkPred, SinkRegion); - VPBlockUtils::disconnectBlocks(SinkRegion, SinkSucc); - VPBlockUtils::connectBlocks(SinkPred, SinkSucc); - - if (TargetRegion) { - // The target recipe is also in a replicate region, move the sink region - // after the target region. 
- auto *TargetSucc = TargetRegion->getSingleSuccessor(); - VPBlockUtils::disconnectBlocks(TargetRegion, TargetSucc); - VPBlockUtils::connectBlocks(TargetRegion, SinkRegion); - VPBlockUtils::connectBlocks(SinkRegion, TargetSucc); - } else { - // The sink source is in a replicate region, we need to move the whole - // replicate region, which should only contain a single recipe in the - // main block. - auto *SplitBlock = - Target->getParent()->splitAt(std::next(Target->getIterator())); - - auto *SplitPred = SplitBlock->getSinglePredecessor(); - - VPBlockUtils::disconnectBlocks(SplitPred, SplitBlock); - VPBlockUtils::connectBlocks(SplitPred, SinkRegion); - VPBlockUtils::connectBlocks(SinkRegion, SplitBlock); - } - } - - VPlanTransforms::removeRedundantCanonicalIVs(*Plan); - VPlanTransforms::removeRedundantInductionCasts(*Plan); - - // Now that sink-after is done, move induction recipes for optimized truncates - // to the phi section of the header block. - for (VPWidenIntOrFpInductionRecipe *Ind : InductionsToMove) - Ind->moveBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi()); - // Adjust the recipes for any inloop reductions. adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExiting()), Plan, RecipeBuilder, Range.Start); - // Introduce a recipe to combine the incoming and previous values of a - // fixed-order recurrence. - for (VPRecipeBase &R : - Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { - auto *RecurPhi = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R); - if (!RecurPhi) - continue; - - VPRecipeBase *PrevRecipe = &RecurPhi->getBackedgeRecipe(); - // Fixed-order recurrences do not contain cycles, so this loop is guaranteed - // to terminate. - while (auto *PrevPhi = - dyn_cast<VPFirstOrderRecurrencePHIRecipe>(PrevRecipe)) - PrevRecipe = &PrevPhi->getBackedgeRecipe(); - VPBasicBlock *InsertBlock = PrevRecipe->getParent(); - auto *Region = GetReplicateRegion(PrevRecipe); - if (Region) - InsertBlock = dyn_cast<VPBasicBlock>(Region->getSingleSuccessor()); - if (!InsertBlock) { - InsertBlock = new VPBasicBlock(Region->getName() + ".succ"); - VPBlockUtils::insertBlockAfter(InsertBlock, Region); - } - if (Region || PrevRecipe->isPhi()) - Builder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi()); - else - Builder.setInsertPoint(InsertBlock, std::next(PrevRecipe->getIterator())); - - auto *RecurSplice = cast<VPInstruction>( - Builder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice, - {RecurPhi, RecurPhi->getBackedgeValue()})); - - RecurPhi->replaceAllUsesWith(RecurSplice); - // Set the first operand of RecurSplice to RecurPhi again, after replacing - // all users. - RecurSplice->setOperand(0, RecurPhi); - } - // Interleave memory: for each Interleave Group we marked earlier as relevant // for this VPlan, replace the Recipes widening its memory instructions with a // single VPInterleaveRecipe at its insertion point. 
@@ -9122,48 +9071,66 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( StoredValues.push_back(StoreR->getStoredValue()); } + bool NeedsMaskForGaps = + IG->requiresScalarEpilogue() && !CM.isScalarEpilogueAllowed(); auto *VPIG = new VPInterleaveRecipe(IG, Recipe->getAddr(), StoredValues, - Recipe->getMask()); + Recipe->getMask(), NeedsMaskForGaps); VPIG->insertBefore(Recipe); unsigned J = 0; for (unsigned i = 0; i < IG->getFactor(); ++i) if (Instruction *Member = IG->getMember(i)) { + VPRecipeBase *MemberR = RecipeBuilder.getRecipe(Member); if (!Member->getType()->isVoidTy()) { - VPValue *OriginalV = Plan->getVPValue(Member); - Plan->removeVPValueFor(Member); - Plan->addVPValue(Member, VPIG->getVPValue(J)); + VPValue *OriginalV = MemberR->getVPSingleValue(); OriginalV->replaceAllUsesWith(VPIG->getVPValue(J)); J++; } - RecipeBuilder.getRecipe(Member)->eraseFromParent(); + MemberR->eraseFromParent(); } } - for (ElementCount VF = Range.Start; ElementCount::isKnownLT(VF, Range.End); - VF *= 2) + for (ElementCount VF : Range) Plan->addVF(VF); Plan->setName("Initial VPlan"); + // Replace VPValues for known constant strides guaranteed by predicate scalar + // evolution. + for (auto [_, Stride] : Legal->getLAI()->getSymbolicStrides()) { + auto *StrideV = cast<SCEVUnknown>(Stride)->getValue(); + auto *ScevStride = dyn_cast<SCEVConstant>(PSE.getSCEV(StrideV)); + // Only handle constant strides for now. + if (!ScevStride) + continue; + Constant *CI = ConstantInt::get(Stride->getType(), ScevStride->getAPInt()); + + auto *ConstVPV = Plan->getVPValueOrAddLiveIn(CI); + // The versioned value may not be used in the loop directly, so just add a + // new live-in in those cases. + Plan->getVPValueOrAddLiveIn(StrideV)->replaceAllUsesWith(ConstVPV); + } + // From this point onwards, VPlan-to-VPlan transformations may change the plan // in ways that accessing values using original IR values is incorrect. Plan->disableValue2VPValue(); + // Sink users of fixed-order recurrence past the recipe defining the previous + // value and introduce FirstOrderRecurrenceSplice VPInstructions. 
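NeedsMaskForGaps above marks interleave groups that have missing members and cannot fall back to a scalar epilogue; the wide access then has to be masked so the gap lanes are never touched. A small sketch of building such a gap mask, assuming the usual member-position-within-factor lane layout:

```cpp
#include <cstdio>
#include <vector>

// Build a lane mask for a wide access covering VF groups of `Factor`
// contiguous elements: a lane is active only if the interleave group
// actually has a member at that position (i.e. it is not a gap).
static std::vector<bool> gapMask(const std::vector<bool> &HasMember,
                                 unsigned VF) {
  const unsigned Factor = (unsigned)HasMember.size();
  std::vector<bool> Mask(VF * Factor);
  for (unsigned Group = 0; Group < VF; ++Group)
    for (unsigned Pos = 0; Pos < Factor; ++Pos)
      Mask[Group * Factor + Pos] = HasMember[Pos];
  return Mask;
}

int main() {
  // Factor-3 group where only members 0 and 2 exist (member 1 is a gap),
  // widened by VF = 4.
  std::vector<bool> HasMember = {true, false, true};
  for (bool Active : gapMask(HasMember, 4))
    std::printf("%d ", Active ? 1 : 0); // 1 0 1  1 0 1  1 0 1  1 0 1
  std::printf("\n");
  return 0;
}
```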
+ if (!VPlanTransforms::adjustFixedOrderRecurrences(*Plan, Builder)) + return std::nullopt; + + VPlanTransforms::removeRedundantCanonicalIVs(*Plan); + VPlanTransforms::removeRedundantInductionCasts(*Plan); + VPlanTransforms::optimizeInductions(*Plan, *PSE.getSE()); VPlanTransforms::removeDeadRecipes(*Plan); - bool ShouldSimplify = true; - while (ShouldSimplify) { - ShouldSimplify = VPlanTransforms::sinkScalarOperands(*Plan); - ShouldSimplify |= - VPlanTransforms::mergeReplicateRegionsIntoSuccessors(*Plan); - ShouldSimplify |= VPlanTransforms::mergeBlocksIntoPredecessors(*Plan); - } + VPlanTransforms::createAndOptimizeReplicateRegions(*Plan); VPlanTransforms::removeRedundantExpandSCEVRecipes(*Plan); VPlanTransforms::mergeBlocksIntoPredecessors(*Plan); assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid"); - return Plan; + return std::make_optional(std::move(Plan)); } VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { @@ -9175,21 +9142,21 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { assert(EnableVPlanNativePath && "VPlan-native path is not enabled."); // Create new empty VPlan - auto Plan = std::make_unique<VPlan>(); + auto Plan = VPlan::createInitialVPlan( + createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop), + *PSE.getSE()); // Build hierarchical CFG VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan); HCFGBuilder.buildHierarchicalCFG(); - for (ElementCount VF = Range.Start; ElementCount::isKnownLT(VF, Range.End); - VF *= 2) + for (ElementCount VF : Range) Plan->addVF(VF); - SmallPtrSet<Instruction *, 1> DeadInstructions; VPlanTransforms::VPInstructionsToVPRecipes( - OrigLoop, Plan, + Plan, [this](PHINode *P) { return Legal->getIntOrFpInductionDescriptor(P); }, - DeadInstructions, *PSE.getSE(), *TLI); + *PSE.getSE(), *TLI); // Remove the existing terminator of the exiting block of the top-most region. // A BranchOnCount will be added instead when adding the canonical IV recipes. 
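adjustFixedOrderRecurrences introduces FirstOrderRecurrenceSplice instructions, which build each iteration's "previous value" vector from the last lane of the prior vector value followed by all but the last lane of the current one. A scalar model of that splice:

```cpp
#include <cstdio>
#include <vector>

// Scalar model of VPInstruction::FirstOrderRecurrenceSplice: concatenate the
// previous vector value and the current one, then take a VF-wide window that
// starts one element before the current value.
static std::vector<int> recurrenceSplice(const std::vector<int> &Prev,
                                         const std::vector<int> &Cur) {
  std::vector<int> Result;
  Result.push_back(Prev.back());
  Result.insert(Result.end(), Cur.begin(), Cur.end() - 1);
  return Result;
}

int main() {
  // for (i) { use(last); last = a[i]; }  vectorized with VF = 4:
  // the "last" seen by lane L of this iteration is a[i + L - 1].
  std::vector<int> PrevA = {0, 1, 2, 3}; // a[0..3] from the prior iteration
  std::vector<int> CurA = {4, 5, 6, 7};  // a[4..7] from this iteration
  for (int V : recurrenceSplice(PrevA, CurA))
    std::printf("%d ", V); // 3 4 5 6
  std::printf("\n");
  return 0;
}
```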
@@ -9198,7 +9165,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { Term->eraseFromParent(); addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), DebugLoc(), - true, CM.useActiveLaneMaskForControlFlow()); + CM.getTailFoldingStyle()); return Plan; } @@ -9255,7 +9222,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( VPBuilder::InsertPointGuard Guard(Builder); Builder.setInsertPoint(WidenRecipe->getParent(), WidenRecipe->getIterator()); - CondOp = RecipeBuilder.createBlockInMask(R->getParent(), Plan); + CondOp = RecipeBuilder.createBlockInMask(R->getParent(), *Plan); } if (IsFMulAdd) { @@ -9270,7 +9237,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( VecOp = FMulRecipe; } VPReductionRecipe *RedRecipe = - new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, TTI); + new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, &TTI); WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe); Plan->removeVPValueFor(R); Plan->addVPValue(R, RedRecipe); @@ -9304,13 +9271,15 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( if (!PhiR || PhiR->isInLoop()) continue; VPValue *Cond = - RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan); + RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), *Plan); VPValue *Red = PhiR->getBackedgeValue(); assert(Red->getDefiningRecipe()->getParent() != LatchVPBB && "reduction recipe must be defined before latch"); Builder.createNaryOp(Instruction::Select, {Cond, Red, PhiR}); } } + + VPlanTransforms::clearReductionWrapFlags(*Plan); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -9475,7 +9444,7 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { PartStart, ConstantInt::get(PtrInd->getType(), Lane)); Value *GlobalIdx = State.Builder.CreateAdd(PtrInd, Idx); - Value *Step = State.get(getOperand(1), VPIteration(0, Part)); + Value *Step = State.get(getOperand(1), VPIteration(Part, Lane)); Value *SclrGep = emitTransformedIndex( State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, IndDesc); SclrGep->setName("next.gep"); @@ -9485,8 +9454,6 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { return; } - assert(isa<SCEVConstant>(IndDesc.getStep()) && - "Induction step not a SCEV constant!"); Type *PhiType = IndDesc.getStep()->getType(); // Build a pointer phi @@ -9506,7 +9473,7 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { Value *NumUnrolledElems = State.Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, State.UF)); Value *InductionGEP = GetElementPtrInst::Create( - IndDesc.getElementType(), NewPointerPhi, + State.Builder.getInt8Ty(), NewPointerPhi, State.Builder.CreateMul(ScalarStepValue, NumUnrolledElems), "ptr.ind", InductionLoc); // Add induction update using an incorrect block temporarily. 
The phi node @@ -9529,10 +9496,10 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { StartOffset = State.Builder.CreateAdd( StartOffset, State.Builder.CreateStepVector(VecPhiType)); - assert(ScalarStepValue == State.get(getOperand(1), VPIteration(0, Part)) && + assert(ScalarStepValue == State.get(getOperand(1), VPIteration(Part, 0)) && "scalar step must be the same across all parts"); Value *GEP = State.Builder.CreateGEP( - IndDesc.getElementType(), NewPointerPhi, + State.Builder.getInt8Ty(), NewPointerPhi, State.Builder.CreateMul( StartOffset, State.Builder.CreateVectorSplat(State.VF, ScalarStepValue), @@ -9584,7 +9551,8 @@ void VPScalarIVStepsRecipe::execute(VPTransformState &State) { void VPInterleaveRecipe::execute(VPTransformState &State) { assert(!State.Instance && "Interleave group being replicated."); State.ILV->vectorizeInterleaveGroup(IG, definedValues(), State, getAddr(), - getStoredValues(), getMask()); + getStoredValues(), getMask(), + NeedsMaskForGaps); } void VPReductionRecipe::execute(VPTransformState &State) { @@ -9640,10 +9608,9 @@ void VPReplicateRecipe::execute(VPTransformState &State) { Instruction *UI = getUnderlyingInstr(); if (State.Instance) { // Generate a single instance. assert(!State.VF.isScalable() && "Can't scalarize a scalable vector"); - State.ILV->scalarizeInstruction(UI, this, *State.Instance, - IsPredicated, State); + State.ILV->scalarizeInstruction(UI, this, *State.Instance, State); // Insert scalar instance packing it into a vector. - if (AlsoPack && State.VF.isVector()) { + if (State.VF.isVector() && shouldPack()) { // If we're constructing lane 0, initialize to start from poison. if (State.Instance->Lane.isFirstLane()) { assert(!State.VF.isScalable() && "VF is assumed to be non scalable."); @@ -9663,8 +9630,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) { all_of(operands(), [](VPValue *Op) { return Op->isDefinedOutsideVectorRegions(); })) { - State.ILV->scalarizeInstruction(UI, this, VPIteration(0, 0), IsPredicated, - State); + State.ILV->scalarizeInstruction(UI, this, VPIteration(0, 0), State); if (user_begin() != user_end()) { for (unsigned Part = 1; Part < State.UF; ++Part) State.set(this, State.get(this, VPIteration(0, 0)), @@ -9676,16 +9642,16 @@ void VPReplicateRecipe::execute(VPTransformState &State) { // Uniform within VL means we need to generate lane 0 only for each // unrolled copy. for (unsigned Part = 0; Part < State.UF; ++Part) - State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, 0), - IsPredicated, State); + State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, 0), State); return; } - // A store of a loop varying value to a loop invariant address only - // needs only the last copy of the store. - if (isa<StoreInst>(UI) && !getOperand(1)->hasDefiningRecipe()) { + // A store of a loop varying value to a uniform address only needs the last + // copy of the store. 
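With the element-type operand gone from the pointer-induction GEPs above, the pointer IV is advanced in raw bytes: lane L of part P addresses Base + (P * VF + L) * Step bytes, and the phi advances by VF * UF * Step per vector iteration. A small address-math sketch; the concrete base address and byte stride below are made up purely for illustration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const unsigned VF = 4, UF = 2;
  const std::intptr_t Base = 0x1000; // hypothetical start pointer
  const std::intptr_t Step = 8;      // byte stride per scalar iteration

  // Per-part vectors of pointers for one unrolled vector iteration:
  // lane L of part P addresses Base + (P * VF + L) * Step bytes, which is
  // what the i8-typed GEP of (step-vector * splat(Step)) computes.
  for (unsigned Part = 0; Part < UF; ++Part) {
    std::printf("part %u:", Part);
    for (unsigned Lane = 0; Lane < VF; ++Lane) {
      std::intptr_t Addr = Base + (std::intptr_t)(Part * VF + Lane) * Step;
      std::printf(" 0x%llx", (unsigned long long)Addr);
    }
    std::printf("\n");
  }
  // The pointer phi for the next vector iteration advances by VF*UF*Step
  // bytes, mirroring the NumUnrolledElems * ScalarStepValue GEP above.
  std::printf("next phi value: 0x%llx\n",
              (unsigned long long)(Base + (std::intptr_t)VF * UF * Step));
  return 0;
}
```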
+ if (isa<StoreInst>(UI) && + vputils::isUniformAfterVectorization(getOperand(1))) { auto Lane = VPLane::getLastLaneForVF(State.VF); - State.ILV->scalarizeInstruction(UI, this, VPIteration(State.UF - 1, Lane), IsPredicated, + State.ILV->scalarizeInstruction(UI, this, VPIteration(State.UF - 1, Lane), State); return; } @@ -9695,8 +9661,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) { const unsigned EndLane = State.VF.getKnownMinValue(); for (unsigned Part = 0; Part < State.UF; ++Part) for (unsigned Lane = 0; Lane < EndLane; ++Lane) - State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, Lane), - IsPredicated, State); + State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, Lane), State); } void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { @@ -9714,7 +9679,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { auto *DataTy = VectorType::get(ScalarDataTy, State.VF); const Align Alignment = getLoadStoreAlignment(&Ingredient); - bool CreateGatherScatter = !Consecutive; + bool CreateGatherScatter = !isConsecutive(); auto &Builder = State.Builder; InnerLoopVectorizer::VectorParts BlockInMaskParts(State.UF); @@ -9725,36 +9690,39 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { const auto CreateVecPtr = [&](unsigned Part, Value *Ptr) -> Value * { // Calculate the pointer for the specific unroll-part. - GetElementPtrInst *PartPtr = nullptr; - + Value *PartPtr = nullptr; + + // Use i32 for the gep index type when the value is constant, + // or query DataLayout for a more suitable index type otherwise. + const DataLayout &DL = + Builder.GetInsertBlock()->getModule()->getDataLayout(); + Type *IndexTy = State.VF.isScalable() && (isReverse() || Part > 0) + ? DL.getIndexType(ScalarDataTy->getPointerTo()) + : Builder.getInt32Ty(); bool InBounds = false; if (auto *gep = dyn_cast<GetElementPtrInst>(Ptr->stripPointerCasts())) InBounds = gep->isInBounds(); - if (Reverse) { + if (isReverse()) { // If the address is consecutive but reversed, then the // wide store needs to start at the last vector element. // RunTimeVF = VScale * VF.getKnownMinValue() // For fixed-width VScale is 1, then RunTimeVF = VF.getKnownMinValue() - Value *RunTimeVF = getRuntimeVF(Builder, Builder.getInt32Ty(), State.VF); + Value *RunTimeVF = getRuntimeVF(Builder, IndexTy, State.VF); // NumElt = -Part * RunTimeVF - Value *NumElt = Builder.CreateMul(Builder.getInt32(-Part), RunTimeVF); + Value *NumElt = + Builder.CreateMul(ConstantInt::get(IndexTy, -(int64_t)Part), RunTimeVF); // LastLane = 1 - RunTimeVF - Value *LastLane = Builder.CreateSub(Builder.getInt32(1), RunTimeVF); + Value *LastLane = + Builder.CreateSub(ConstantInt::get(IndexTy, 1), RunTimeVF); + PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, NumElt, "", InBounds); PartPtr = - cast<GetElementPtrInst>(Builder.CreateGEP(ScalarDataTy, Ptr, NumElt)); - PartPtr->setIsInBounds(InBounds); - PartPtr = cast<GetElementPtrInst>( - Builder.CreateGEP(ScalarDataTy, PartPtr, LastLane)); - PartPtr->setIsInBounds(InBounds); + Builder.CreateGEP(ScalarDataTy, PartPtr, LastLane, "", InBounds); if (isMaskRequired) // Reverse of a null all-one mask is a null mask. 
BlockInMaskParts[Part] = Builder.CreateVectorReverse(BlockInMaskParts[Part], "reverse"); } else { - Value *Increment = - createStepForVF(Builder, Builder.getInt32Ty(), State.VF, Part); - PartPtr = cast<GetElementPtrInst>( - Builder.CreateGEP(ScalarDataTy, Ptr, Increment)); - PartPtr->setIsInBounds(InBounds); + Value *Increment = createStepForVF(Builder, IndexTy, State.VF, Part); + PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, Increment, "", InBounds); } unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); @@ -9774,7 +9742,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, MaskPart); } else { - if (Reverse) { + if (isReverse()) { // If we store to reverse consecutive memory locations, then we need // to reverse the order of elements in the stored value. StoredVal = Builder.CreateVectorReverse(StoredVal, "reverse"); @@ -9833,7 +9801,6 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { static ScalarEpilogueLowering getScalarEpilogueLowering( Function *F, Loop *L, LoopVectorizeHints &Hints, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, TargetTransformInfo *TTI, TargetLibraryInfo *TLI, - AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, LoopVectorizationLegality &LVL, InterleavedAccessInfo *IAI) { // 1) OptSize takes precedence over all other options, i.e. if this is set, // don't look at hints or options, and don't request a scalar epilogue. @@ -9869,7 +9836,8 @@ static ScalarEpilogueLowering getScalarEpilogueLowering( }; // 4) if the TTI hook indicates this is profitable, request predication. - if (TTI->preferPredicateOverEpilogue(L, LI, *SE, *AC, TLI, DT, &LVL, IAI)) + TailFoldingInfo TFI(TLI, &LVL, IAI); + if (TTI->preferPredicateOverEpilogue(&TFI)) return CM_ScalarEpilogueNotNeededUsePredicate; return CM_ScalarEpilogueAllowed; @@ -9880,9 +9848,29 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) { if (hasVectorValue(Def, Part)) return Data.PerPartOutput[Def][Part]; + auto GetBroadcastInstrs = [this, Def](Value *V) { + bool SafeToHoist = Def->isDefinedOutsideVectorRegions(); + if (VF.isScalar()) + return V; + // Place the code for broadcasting invariant variables in the new preheader. + IRBuilder<>::InsertPointGuard Guard(Builder); + if (SafeToHoist) { + BasicBlock *LoopVectorPreHeader = CFG.VPBB2IRBB[cast<VPBasicBlock>( + Plan->getVectorLoopRegion()->getSinglePredecessor())]; + if (LoopVectorPreHeader) + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + } + + // Place the code for broadcasting invariant variables in the new preheader. + // Broadcast the scalar into all locations in the vector. + Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast"); + + return Shuf; + }; + if (!hasScalarValue(Def, {Part, 0})) { Value *IRV = Def->getLiveInIRValue(); - Value *B = ILV->getBroadcastInstrs(IRV); + Value *B = GetBroadcastInstrs(IRV); set(Def, B, Part); return B; } @@ -9900,9 +9888,11 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) { unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1; // Check if there is a scalar value for the selected lane. if (!hasScalarValue(Def, {Part, LastLane})) { - // At the moment, VPWidenIntOrFpInductionRecipes and VPScalarIVStepsRecipes can also be uniform. + // At the moment, VPWidenIntOrFpInductionRecipes, VPScalarIVStepsRecipes and + // VPExpandSCEVRecipes can also be uniform. 
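For the reversed case handled just above, the wide access for part P starts at Ptr + (-P * RunTimeVF) + (1 - RunTimeVF) and the loaded (or stored) vector is then reversed, so lane L ends up holding the element the scalar loop would have touched at step P * VF + L. A standalone check of that pointer arithmetic with a fixed-width VF:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  // Reverse consecutive access: the scalar loop walks Ptr[0], Ptr[-1], ...
  // For unroll part P, the wide load starts at Ptr + (-P * VF) + (1 - VF)
  // and the loaded vector is reversed afterwards.
  int Data[32];
  for (int I = 0; I < 32; ++I)
    Data[I] = I;

  const int VF = 4, UF = 2;
  int *Ptr = &Data[20]; // address the scalar loop would start from

  for (int Part = 0; Part < UF; ++Part) {
    int NumElt = -Part * VF;
    int LastLane = 1 - VF;
    int *PartPtr = Ptr + NumElt + LastLane; // start of the wide load
    std::printf("part %d:", Part);
    for (int Lane = 0; Lane < VF; ++Lane) {
      int Reversed = PartPtr[VF - 1 - Lane]; // reverse after loading
      assert(Reversed == *(Ptr - (Part * VF + Lane)));
      std::printf(" %d", Reversed);
    }
    std::printf("\n"); // part 0: 20 19 18 17   part 1: 16 15 14 13
  }
  return 0;
}
```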
assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) || - isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe())) && + isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe()) || + isa<VPExpandSCEVRecipe>(Def->getDefiningRecipe())) && "unexpected recipe found to be invariant"); IsUniform = true; LastLane = 0; @@ -9927,7 +9917,7 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) { // State, we will only generate the insertelements once. Value *VectorValue = nullptr; if (IsUniform) { - VectorValue = ILV->getBroadcastInstrs(ScalarValue); + VectorValue = GetBroadcastInstrs(ScalarValue); set(Def, VectorValue, Part); } else { // Initialize packing with insertelements to start from undef. @@ -9962,15 +9952,15 @@ static bool processLoopInVPlanNativePath( Function *F = L->getHeader()->getParent(); InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL->getLAI()); - ScalarEpilogueLowering SEL = getScalarEpilogueLowering( - F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, *LVL, &IAI); + ScalarEpilogueLowering SEL = + getScalarEpilogueLowering(F, L, Hints, PSI, BFI, TTI, TLI, *LVL, &IAI); LoopVectorizationCostModel CM(SEL, L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F, &Hints, IAI); // Use the planner for outer loop vectorization. // TODO: CM is not used at this point inside the planner. Turn CM into an // optional argument if we don't need it in the future. - LoopVectorizationPlanner LVP(L, LI, TLI, TTI, LVL, CM, IAI, PSE, Hints, ORE); + LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, LVL, CM, IAI, PSE, Hints, ORE); // Get user vectorization factor. ElementCount UserVF = Hints.getWidth(); @@ -10231,8 +10221,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Check the function attributes and profiles to find out if this function // should be optimized for size. - ScalarEpilogueLowering SEL = getScalarEpilogueLowering( - F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, LVL, &IAI); + ScalarEpilogueLowering SEL = + getScalarEpilogueLowering(F, L, Hints, PSI, BFI, TTI, TLI, LVL, &IAI); // Check the loop for a trip count threshold: vectorize loops with a tiny trip // count by optimizing for size, to minimize overheads. @@ -10309,11 +10299,9 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Use the cost model. LoopVectorizationCostModel CM(SEL, L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, &Hints, IAI); - CM.collectValuesToIgnore(); - CM.collectElementTypesForWidening(); - // Use the planner for vectorization. - LoopVectorizationPlanner LVP(L, LI, TLI, TTI, &LVL, CM, IAI, PSE, Hints, ORE); + LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, &LVL, CM, IAI, PSE, Hints, + ORE); // Get user vectorization factor and interleave count. ElementCount UserVF = Hints.getWidth(); @@ -10342,7 +10330,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled; if (!ForceVectorization && - !areRuntimeChecksProfitable(Checks, VF, CM.getVScaleForTuning(), L, + !areRuntimeChecksProfitable(Checks, VF, getVScaleForTuning(L, *TTI), L, *PSE.getSE())) { ORE->emit([&]() { return OptimizationRemarkAnalysisAliasing( @@ -10464,7 +10452,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Consider vectorizing the epilogue too if it's profitable. 
VectorizationFactor EpilogueVF = - CM.selectEpilogueVectorizationFactor(VF.Width, LVP); + LVP.selectEpilogueVectorizationFactor(VF.Width, IC); if (EpilogueVF.Width.isVector()) { // The first pass vectorizes the main loop and creates a scalar epilogue @@ -10475,8 +10463,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { EPI, &LVL, &CM, BFI, PSI, Checks); VPlan &BestMainPlan = LVP.getBestPlanFor(EPI.MainLoopVF); - LVP.executePlan(EPI.MainLoopVF, EPI.MainLoopUF, BestMainPlan, MainILV, - DT, true); + auto ExpandedSCEVs = LVP.executePlan(EPI.MainLoopVF, EPI.MainLoopUF, + BestMainPlan, MainILV, DT, true); ++LoopsVectorized; // Second pass vectorizes the epilogue and adjusts the control flow @@ -10492,6 +10480,21 @@ bool LoopVectorizePass::processLoop(Loop *L) { VPBasicBlock *Header = VectorLoop->getEntryBasicBlock(); Header->setName("vec.epilog.vector.body"); + // Re-use the trip count and steps expanded for the main loop, as + // skeleton creation needs it as a value that dominates both the scalar + // and vector epilogue loops + // TODO: This is a workaround needed for epilogue vectorization and it + // should be removed once induction resume value creation is done + // directly in VPlan. + EpilogILV.setTripCount(MainILV.getTripCount()); + for (auto &R : make_early_inc_range(*BestEpiPlan.getPreheader())) { + auto *ExpandR = cast<VPExpandSCEVRecipe>(&R); + auto *ExpandedVal = BestEpiPlan.getVPValueOrAddLiveIn( + ExpandedSCEVs.find(ExpandR->getSCEV())->second); + ExpandR->replaceAllUsesWith(ExpandedVal); + ExpandR->eraseFromParent(); + } + // Ensure that the start values for any VPWidenIntOrFpInductionRecipe, // VPWidenPointerInductionRecipe and VPReductionPHIRecipes are updated // before vectorizing the epilogue loop. @@ -10520,15 +10523,16 @@ bool LoopVectorizePass::processLoop(Loop *L) { } ResumeV = MainILV.createInductionResumeValue( - IndPhi, *ID, {EPI.MainLoopIterationCountCheck}); + IndPhi, *ID, getExpandedStep(*ID, ExpandedSCEVs), + {EPI.MainLoopIterationCountCheck}); } assert(ResumeV && "Must have a resume value"); - VPValue *StartVal = BestEpiPlan.getOrAddExternalDef(ResumeV); + VPValue *StartVal = BestEpiPlan.getVPValueOrAddLiveIn(ResumeV); cast<VPHeaderPHIRecipe>(&R)->setStartValue(StartVal); } LVP.executePlan(EPI.EpilogueVF, EPI.EpilogueUF, BestEpiPlan, EpilogILV, - DT, true); + DT, true, &ExpandedSCEVs); ++LoopsEpilogueVectorized; if (!MainILV.areSafetyChecksAdded()) @@ -10581,14 +10585,14 @@ bool LoopVectorizePass::processLoop(Loop *L) { LoopVectorizeResult LoopVectorizePass::runImpl( Function &F, ScalarEvolution &SE_, LoopInfo &LI_, TargetTransformInfo &TTI_, - DominatorTree &DT_, BlockFrequencyInfo &BFI_, TargetLibraryInfo *TLI_, + DominatorTree &DT_, BlockFrequencyInfo *BFI_, TargetLibraryInfo *TLI_, DemandedBits &DB_, AssumptionCache &AC_, LoopAccessInfoManager &LAIs_, OptimizationRemarkEmitter &ORE_, ProfileSummaryInfo *PSI_) { SE = &SE_; LI = &LI_; TTI = &TTI_; DT = &DT_; - BFI = &BFI_; + BFI = BFI_; TLI = TLI_; AC = &AC_; LAIs = &LAIs_; @@ -10604,7 +10608,7 @@ LoopVectorizeResult LoopVectorizePass::runImpl( // vector registers, loop vectorization may still enable scalar // interleaving. 
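The epilogue-vectorization flow above runs a wide main vector loop, then a narrower epilogue vector loop that resumes where the main loop stopped (hence the reuse of the expanded trip count and the induction resume values), and finally a scalar tail. A scalar driver sketch of that structure; the trip count and both VFs are arbitrary example values:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  const unsigned TripCount = 29;
  const unsigned MainVF = 8, EpilogueVF = 4;

  unsigned Processed = 0;

  // Main vector loop: as many full MainVF chunks as fit.
  unsigned MainIters = TripCount / MainVF;
  Processed += MainIters * MainVF;

  // The resume value fed to the epilogue loop is how far the main loop got;
  // in the pass it is carried over via induction resume values built from
  // the same expanded trip-count SCEV.
  unsigned ResumeIdx = Processed;

  // Epilogue vector loop: narrower chunks over what remains.
  unsigned EpiIters = (TripCount - ResumeIdx) / EpilogueVF;
  Processed += EpiIters * EpilogueVF;

  // Scalar remainder loop handles the final tail.
  unsigned ScalarIters = TripCount - Processed;
  Processed += ScalarIters;

  assert(Processed == TripCount);
  std::printf("main: %u x %u, epilogue: %u x %u, scalar tail: %u\n", MainIters,
              MainVF, EpiIters, EpilogueVF, ScalarIters);
  // main: 3 x 8, epilogue: 1 x 4, scalar tail: 1
  return 0;
}
```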
if (!TTI->getNumberOfRegisters(TTI->getRegisterClassForType(true)) && - TTI->getMaxInterleaveFactor(1) < 2) + TTI->getMaxInterleaveFactor(ElementCount::getFixed(1)) < 2) return LoopVectorizeResult(false, false); bool Changed = false, CFGChanged = false; @@ -10656,7 +10660,6 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &BFI = AM.getResult<BlockFrequencyAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DB = AM.getResult<DemandedBitsAnalysis>(F); @@ -10666,12 +10669,20 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); ProfileSummaryInfo *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + BlockFrequencyInfo *BFI = nullptr; + if (PSI && PSI->hasProfileSummary()) + BFI = &AM.getResult<BlockFrequencyAnalysis>(F); LoopVectorizeResult Result = runImpl(F, SE, LI, TTI, DT, BFI, &TLI, DB, AC, LAIs, ORE, PSI); if (!Result.MadeAnyChange) return PreservedAnalyses::all(); PreservedAnalyses PA; + if (isAssignmentTrackingEnabled(*F.getParent())) { + for (auto &BB : F) + RemoveRedundantDbgInstrs(&BB); + } + // We currently do not preserve loopinfo/dominator analyses with outer loop // vectorization. Until this is addressed, mark these analyses as preserved // only for non-VPlan-native path. @@ -10679,6 +10690,11 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, if (!EnableVPlanNativePath) { PA.preserve<LoopAnalysis>(); PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<ScalarEvolutionAnalysis>(); + +#ifdef EXPENSIVE_CHECKS + SE.verify(); +#endif } if (Result.MadeCFGChange) { @@ -10699,8 +10715,8 @@ void LoopVectorizePass::printPipeline( static_cast<PassInfoMixin<LoopVectorizePass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; OS << (InterleaveOnlyWhenForced ? "" : "no-") << "interleave-forced-only;"; OS << (VectorizeOnlyWhenForced ? "" : "no-") << "vectorize-forced-only;"; - OS << ">"; + OS << '>'; } diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index e3eb6b1804e7..821a3fa22a85 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -87,7 +87,6 @@ #include "llvm/Transforms/Utils/InjectTLIMappings.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Vectorize.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -126,6 +125,13 @@ static cl::opt<bool> ShouldStartVectorizeHorAtStore( cl::desc( "Attempt to vectorize horizontal reductions feeding into a store")); +// NOTE: If AllowHorRdxIdenityOptimization is true, the optimization will run +// even if we match a reduction but do not vectorize in the end. 
+static cl::opt<bool> AllowHorRdxIdenityOptimization( + "slp-optimize-identity-hor-reduction-ops", cl::init(true), cl::Hidden, + cl::desc("Allow optimization of original scalar identity operations on " + "matched horizontal reductions.")); + static cl::opt<int> MaxVectorRegSizeOption("slp-max-reg-size", cl::init(128), cl::Hidden, cl::desc("Attempt to vectorize for this register size in bits")); @@ -287,7 +293,7 @@ static bool isCommutative(Instruction *I) { /// \returns inserting index of InsertElement or InsertValue instruction, /// using Offset as base offset for index. static std::optional<unsigned> getInsertIndex(const Value *InsertInst, - unsigned Offset = 0) { + unsigned Offset = 0) { int Index = Offset; if (const auto *IE = dyn_cast<InsertElementInst>(InsertInst)) { const auto *VT = dyn_cast<FixedVectorType>(IE->getType()); @@ -342,16 +348,16 @@ enum class UseMask { static SmallBitVector buildUseMask(int VF, ArrayRef<int> Mask, UseMask MaskArg) { SmallBitVector UseMask(VF, true); - for (auto P : enumerate(Mask)) { - if (P.value() == UndefMaskElem) { + for (auto [Idx, Value] : enumerate(Mask)) { + if (Value == PoisonMaskElem) { if (MaskArg == UseMask::UndefsAsMask) - UseMask.reset(P.index()); + UseMask.reset(Idx); continue; } - if (MaskArg == UseMask::FirstArg && P.value() < VF) - UseMask.reset(P.value()); - else if (MaskArg == UseMask::SecondArg && P.value() >= VF) - UseMask.reset(P.value() - VF); + if (MaskArg == UseMask::FirstArg && Value < VF) + UseMask.reset(Value); + else if (MaskArg == UseMask::SecondArg && Value >= VF) + UseMask.reset(Value - VF); } return UseMask; } @@ -374,9 +380,9 @@ static SmallBitVector isUndefVector(const Value *V, if (!UseMask.empty()) { const Value *Base = V; while (auto *II = dyn_cast<InsertElementInst>(Base)) { + Base = II->getOperand(0); if (isa<T>(II->getOperand(1))) continue; - Base = II->getOperand(0); std::optional<unsigned> Idx = getInsertIndex(II); if (!Idx) continue; @@ -461,7 +467,7 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { Value *Vec2 = nullptr; enum ShuffleMode { Unknown, Select, Permute }; ShuffleMode CommonShuffleMode = Unknown; - Mask.assign(VL.size(), UndefMaskElem); + Mask.assign(VL.size(), PoisonMaskElem); for (unsigned I = 0, E = VL.size(); I < E; ++I) { // Undef can be represented as an undef element in a vector. if (isa<UndefValue>(VL[I])) @@ -533,6 +539,117 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) { return *EI->idx_begin(); } +/// Tries to find extractelement instructions with constant indices from fixed +/// vector type and gather such instructions into a bunch, which highly likely +/// might be detected as a shuffle of 1 or 2 input vectors. If this attempt was +/// successful, the matched scalars are replaced by poison values in \p VL for +/// future analysis. +static std::optional<TTI::ShuffleKind> +tryToGatherExtractElements(SmallVectorImpl<Value *> &VL, + SmallVectorImpl<int> &Mask) { + // Scan list of gathered scalars for extractelements that can be represented + // as shuffles. 
+ MapVector<Value *, SmallVector<int>> VectorOpToIdx; + SmallVector<int> UndefVectorExtracts; + for (int I = 0, E = VL.size(); I < E; ++I) { + auto *EI = dyn_cast<ExtractElementInst>(VL[I]); + if (!EI) { + if (isa<UndefValue>(VL[I])) + UndefVectorExtracts.push_back(I); + continue; + } + auto *VecTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType()); + if (!VecTy || !isa<ConstantInt, UndefValue>(EI->getIndexOperand())) + continue; + std::optional<unsigned> Idx = getExtractIndex(EI); + // Undefined index. + if (!Idx) { + UndefVectorExtracts.push_back(I); + continue; + } + SmallBitVector ExtractMask(VecTy->getNumElements(), true); + ExtractMask.reset(*Idx); + if (isUndefVector(EI->getVectorOperand(), ExtractMask).all()) { + UndefVectorExtracts.push_back(I); + continue; + } + VectorOpToIdx[EI->getVectorOperand()].push_back(I); + } + // Sort the vector operands by the maximum number of uses in extractelements. + MapVector<unsigned, SmallVector<Value *>> VFToVector; + for (const auto &Data : VectorOpToIdx) + VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()] + .push_back(Data.first); + for (auto &Data : VFToVector) { + stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) { + return VectorOpToIdx.find(V1)->second.size() > + VectorOpToIdx.find(V2)->second.size(); + }); + } + // Find the best pair of the vectors with the same number of elements or a + // single vector. + const int UndefSz = UndefVectorExtracts.size(); + unsigned SingleMax = 0; + Value *SingleVec = nullptr; + unsigned PairMax = 0; + std::pair<Value *, Value *> PairVec(nullptr, nullptr); + for (auto &Data : VFToVector) { + Value *V1 = Data.second.front(); + if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) { + SingleMax = VectorOpToIdx[V1].size() + UndefSz; + SingleVec = V1; + } + Value *V2 = nullptr; + if (Data.second.size() > 1) + V2 = *std::next(Data.second.begin()); + if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + + UndefSz) { + PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz; + PairVec = std::make_pair(V1, V2); + } + } + if (SingleMax == 0 && PairMax == 0 && UndefSz == 0) + return std::nullopt; + // Check if better to perform a shuffle of 2 vectors or just of a single + // vector. + SmallVector<Value *> SavedVL(VL.begin(), VL.end()); + SmallVector<Value *> GatheredExtracts( + VL.size(), PoisonValue::get(VL.front()->getType())); + if (SingleMax >= PairMax && SingleMax) { + for (int Idx : VectorOpToIdx[SingleVec]) + std::swap(GatheredExtracts[Idx], VL[Idx]); + } else { + for (Value *V : {PairVec.first, PairVec.second}) + for (int Idx : VectorOpToIdx[V]) + std::swap(GatheredExtracts[Idx], VL[Idx]); + } + // Add extracts from undefs too. + for (int Idx : UndefVectorExtracts) + std::swap(GatheredExtracts[Idx], VL[Idx]); + // Check that gather of extractelements can be represented as just a + // shuffle of a single/two vectors the scalars are extracted from. + std::optional<TTI::ShuffleKind> Res = + isFixedVectorShuffle(GatheredExtracts, Mask); + if (!Res) { + // TODO: try to check other subsets if possible. + // Restore the original VL if attempt was not successful. + VL.swap(SavedVL); + return std::nullopt; + } + // Restore unused scalars from mask, if some of the extractelements were not + // selected for shuffle. 
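Aside (not part of the patch): the source-grouping step of tryToGatherExtractElements above can be pictured with a short standalone C++ sketch. It uses plain std containers and invented names, counts undef elements toward every candidate, and deliberately ignores the equal-vector-width requirement the real code enforces; it only illustrates "bucket extracts by source vector, then prefer the best single source or the best pair".

#include <cstdio>
#include <map>
#include <optional>
#include <utility>
#include <vector>

// One gathered scalar: either an extract from (SourceId, Lane) or an undef.
struct GatheredScalar {
  std::optional<int> SourceId; // nullopt == undef/poison element
  int Lane = -1;
};

// Pick the source vector (or pair of sources) covering the most elements.
// Undef elements are compatible with any choice, so they count for all.
static void pickSources(const std::vector<GatheredScalar> &VL) {
  std::map<int, int> CountPerSource;
  int Undefs = 0;
  for (const GatheredScalar &G : VL) {
    if (!G.SourceId)
      ++Undefs;
    else
      ++CountPerSource[*G.SourceId];
  }
  int BestSingle = -1, BestSingleCount = 0;
  for (auto &[Src, Count] : CountPerSource)
    if (Count + Undefs > BestSingleCount) {
      BestSingle = Src;
      BestSingleCount = Count + Undefs;
    }
  std::pair<int, int> BestPair{-1, -1};
  int BestPairCount = 0;
  for (auto &[S1, C1] : CountPerSource)
    for (auto &[S2, C2] : CountPerSource)
      if (S1 < S2 && C1 + C2 + Undefs > BestPairCount) {
        BestPair = {S1, S2};
        BestPairCount = C1 + C2 + Undefs;
      }
  if (BestPairCount > BestSingleCount)
    std::printf("shuffle of sources %d and %d covers %d of %zu elements\n",
                BestPair.first, BestPair.second, BestPairCount, VL.size());
  else
    std::printf("shuffle of single source %d covers %d of %zu elements\n",
                BestSingle, BestSingleCount, VL.size());
}

int main() {
  // <v0[0], v1[3], undef, v0[2]> -- a shuffle of v0 and v1 covers everything.
  pickSources({{0, 0}, {1, 3}, {std::nullopt, -1}, {0, 2}});
}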
+ for (int I = 0, E = GatheredExtracts.size(); I < E; ++I) { + auto *EI = dyn_cast<ExtractElementInst>(VL[I]); + if (!EI || !isa<FixedVectorType>(EI->getVectorOperandType()) || + !isa<ConstantInt, UndefValue>(EI->getIndexOperand()) || + is_contained(UndefVectorExtracts, I)) + continue; + if (Mask[I] == PoisonMaskElem && !isa<PoisonValue>(GatheredExtracts[I])) + std::swap(VL[I], GatheredExtracts[I]); + } + return Res; +} + namespace { /// Main data required for vectorization of instructions. @@ -829,18 +946,29 @@ static bool isSimple(Instruction *I) { } /// Shuffles \p Mask in accordance with the given \p SubMask. -static void addMask(SmallVectorImpl<int> &Mask, ArrayRef<int> SubMask) { +/// \param ExtendingManyInputs Supports reshuffling of the mask with not only +/// one but two input vectors. +static void addMask(SmallVectorImpl<int> &Mask, ArrayRef<int> SubMask, + bool ExtendingManyInputs = false) { if (SubMask.empty()) return; + assert( + (!ExtendingManyInputs || SubMask.size() > Mask.size() || + // Check if input scalars were extended to match the size of other node. + (SubMask.size() == Mask.size() && + std::all_of(std::next(Mask.begin(), Mask.size() / 2), Mask.end(), + [](int Idx) { return Idx == PoisonMaskElem; }))) && + "SubMask with many inputs support must be larger than the mask."); if (Mask.empty()) { Mask.append(SubMask.begin(), SubMask.end()); return; } - SmallVector<int> NewMask(SubMask.size(), UndefMaskElem); + SmallVector<int> NewMask(SubMask.size(), PoisonMaskElem); int TermValue = std::min(Mask.size(), SubMask.size()); for (int I = 0, E = SubMask.size(); I < E; ++I) { - if (SubMask[I] >= TermValue || SubMask[I] == UndefMaskElem || - Mask[SubMask[I]] >= TermValue) + if (SubMask[I] == PoisonMaskElem || + (!ExtendingManyInputs && + (SubMask[I] >= TermValue || Mask[SubMask[I]] >= TermValue))) continue; NewMask[I] = Mask[SubMask[I]]; } @@ -887,7 +1015,7 @@ static void inversePermutation(ArrayRef<unsigned> Indices, SmallVectorImpl<int> &Mask) { Mask.clear(); const unsigned E = Indices.size(); - Mask.resize(E, UndefMaskElem); + Mask.resize(E, PoisonMaskElem); for (unsigned I = 0; I < E; ++I) Mask[Indices[I]] = I; } @@ -900,7 +1028,7 @@ static void reorderScalars(SmallVectorImpl<Value *> &Scalars, UndefValue::get(Scalars.front()->getType())); Prev.swap(Scalars); for (unsigned I = 0, E = Prev.size(); I < E; ++I) - if (Mask[I] != UndefMaskElem) + if (Mask[I] != PoisonMaskElem) Scalars[Mask[I]] = Prev[I]; } @@ -962,6 +1090,7 @@ namespace slpvectorizer { class BoUpSLP { struct TreeEntry; struct ScheduleData; + class ShuffleCostEstimator; class ShuffleInstructionBuilder; public: @@ -1006,8 +1135,12 @@ public: /// Vectorize the tree but with the list of externally used values \p /// ExternallyUsedValues. Values in this MapVector can be replaced but the /// generated extractvalue instructions. - Value *vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, - Instruction *ReductionRoot = nullptr); + /// \param ReplacedExternals containd list of replaced external values + /// {scalar, replace} after emitting extractelement for external uses. + Value * + vectorizeTree(const ExtraValueToDebugLocsMap &ExternallyUsedValues, + SmallVectorImpl<std::pair<Value *, Value *>> &ReplacedExternals, + Instruction *ReductionRoot = nullptr); /// \returns the cost incurred by unwanted spills and fills, caused by /// holding live values over call sites. @@ -1025,24 +1158,18 @@ public: /// Construct a vectorizable tree that starts at \p Roots. 
void buildTree(ArrayRef<Value *> Roots); - /// Checks if the very first tree node is going to be vectorized. - bool isVectorizedFirstNode() const { - return !VectorizableTree.empty() && - VectorizableTree.front()->State == TreeEntry::Vectorize; - } - - /// Returns the main instruction for the very first node. - Instruction *getFirstNodeMainOp() const { - assert(!VectorizableTree.empty() && "No tree to get the first node from"); - return VectorizableTree.front()->getMainOp(); - } - /// Returns whether the root node has in-tree uses. bool doesRootHaveInTreeUses() const { return !VectorizableTree.empty() && !VectorizableTree.front()->UserTreeIndices.empty(); } + /// Return the scalars of the root node. + ArrayRef<Value *> getRootNodeScalars() const { + assert(!VectorizableTree.empty() && "No graph to get the first node from"); + return VectorizableTree.front()->Scalars; + } + /// Builds external uses of the vectorized scalars, i.e. the list of /// vectorized scalars to be extracted, their lanes and their scalar users. \p /// ExternallyUsedValues contains additional list of external uses to handle @@ -1064,6 +1191,8 @@ public: MinBWs.clear(); InstrElementSize.clear(); UserIgnoreList = nullptr; + PostponedGathers.clear(); + ValueToGatherNodes.clear(); } unsigned getTreeSize() const { return VectorizableTree.size(); } @@ -1083,9 +1212,12 @@ public: /// Gets reordering data for the given tree entry. If the entry is vectorized /// - just return ReorderIndices, otherwise check if the scalars can be /// reordered and return the most optimal order. + /// \return std::nullopt if ordering is not important, empty order, if + /// identity order is important, or the actual order. /// \param TopToBottom If true, include the order of vectorized stores and /// insertelement nodes, otherwise skip them. - std::optional<OrdersType> getReorderingData(const TreeEntry &TE, bool TopToBottom); + std::optional<OrdersType> getReorderingData(const TreeEntry &TE, + bool TopToBottom); /// Reorders the current graph to the most profitable order starting from the /// root node to the leaf nodes. The best order is chosen only from the nodes @@ -1328,8 +1460,14 @@ public: ConstantInt *Ex1Idx; if (match(V1, m_ExtractElt(m_Value(EV1), m_ConstantInt(Ex1Idx)))) { // Undefs are always profitable for extractelements. + // Compiler can easily combine poison and extractelement <non-poison> or + // undef and extractelement <poison>. But combining undef + + // extractelement <non-poison-but-may-produce-poison> requires some + // extra operations. if (isa<UndefValue>(V2)) - return LookAheadHeuristics::ScoreConsecutiveExtracts; + return (isa<PoisonValue>(V2) || isUndefVector(EV1).all()) + ? LookAheadHeuristics::ScoreConsecutiveExtracts + : LookAheadHeuristics::ScoreSameOpcode; Value *EV2 = nullptr; ConstantInt *Ex2Idx = nullptr; if (match(V2, @@ -1683,9 +1821,10 @@ public: // Search all operands in Ops[*][Lane] for the one that matches best // Ops[OpIdx][LastLane] and return its opreand index. // If no good match can be found, return std::nullopt. - std::optional<unsigned> getBestOperand(unsigned OpIdx, int Lane, int LastLane, - ArrayRef<ReorderingMode> ReorderingModes, - ArrayRef<Value *> MainAltOps) { + std::optional<unsigned> + getBestOperand(unsigned OpIdx, int Lane, int LastLane, + ArrayRef<ReorderingMode> ReorderingModes, + ArrayRef<Value *> MainAltOps) { unsigned NumOperands = getNumOperands(); // The operand of the previous lane at OpIdx. @@ -2299,7 +2438,8 @@ private: /// \returns the cost of the vectorizable entry. 
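Aside (not part of the patch): the order-to-mask bookkeeping that inversePermutation and reorderScalars implement above boils down to inverting an ordering into a shuffle mask and then scattering the scalars through it. A minimal standalone sketch with plain std containers, where -1 stands in for PoisonMaskElem and all names are invented for the illustration:

#include <cstdio>
#include <string>
#include <vector>

static constexpr int kPoison = -1; // stand-in for PoisonMaskElem

// Build a shuffle mask from an ordering: source element Indices[I] ends up
// in lane I of the result, i.e. Mask[Indices[I]] = I.
std::vector<int> invertOrder(const std::vector<unsigned> &Indices) {
  std::vector<int> Mask(Indices.size(), kPoison);
  for (size_t I = 0; I < Indices.size(); ++I)
    Mask[Indices[I]] = static_cast<int>(I);
  return Mask;
}

// Apply the mask to a list of scalars: Out[Mask[I]] = Prev[I], skipping
// poison lanes.
std::vector<std::string> applyMask(const std::vector<std::string> &Prev,
                                   const std::vector<int> &Mask) {
  std::vector<std::string> Out(Prev.size(), "poison");
  for (size_t I = 0; I < Prev.size(); ++I)
    if (Mask[I] != kPoison)
      Out[Mask[I]] = Prev[I];
  return Out;
}

int main() {
  // Order {2, 0, 3, 1}: source element 2 goes to lane 0, element 0 to lane 1,
  // element 3 to lane 2, element 1 to lane 3.
  std::vector<unsigned> Order = {2, 0, 3, 1};
  std::vector<int> Mask = invertOrder(Order);
  for (const std::string &S : applyMask({"a", "b", "c", "d"}, Mask))
    std::printf("%s ", S.c_str());
  std::printf("\n"); // prints: c a d b
}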
InstructionCost getEntryCost(const TreeEntry *E, - ArrayRef<Value *> VectorizedVals); + ArrayRef<Value *> VectorizedVals, + SmallPtrSetImpl<Value *> &CheckedExtracts); /// This is the recursive part of buildTree. void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, @@ -2323,15 +2463,13 @@ private: /// Create a new vector from a list of scalar values. Produces a sequence /// which exploits values reused across lanes, and arranges the inserts /// for ease of later optimization. - Value *createBuildVector(const TreeEntry *E); + template <typename BVTy, typename ResTy, typename... Args> + ResTy processBuildVector(const TreeEntry *E, Args &...Params); - /// \returns the scalarization cost for this type. Scalarization in this - /// context means the creation of vectors from a group of scalars. If \p - /// NeedToShuffle is true, need to add a cost of reshuffling some of the - /// vector elements. - InstructionCost getGatherCost(FixedVectorType *Ty, - const APInt &ShuffledIndices, - bool NeedToShuffle) const; + /// Create a new vector from a list of scalar values. Produces a sequence + /// which exploits values reused across lanes, and arranges the inserts + /// for ease of later optimization. + Value *createBuildVector(const TreeEntry *E); /// Returns the instruction in the bundle, which can be used as a base point /// for scheduling. Usually it is the last instruction in the bundle, except @@ -2354,14 +2492,16 @@ private: /// \returns the scalarization cost for this list of values. Assuming that /// this subtree gets vectorized, we may need to extract the values from the /// roots. This method calculates the cost of extracting the values. - InstructionCost getGatherCost(ArrayRef<Value *> VL) const; + /// \param ForPoisonSrc true if initial vector is poison, false otherwise. + InstructionCost getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc) const; /// Set the Builder insert point to one after the last instruction in /// the bundle void setInsertPointAfterBundle(const TreeEntry *E); - /// \returns a vector from a collection of scalars in \p VL. - Value *gather(ArrayRef<Value *> VL); + /// \returns a vector from a collection of scalars in \p VL. if \p Root is not + /// specified, the starting vector value is poison. + Value *gather(ArrayRef<Value *> VL, Value *Root); /// \returns whether the VectorizableTree is fully vectorizable and will /// be beneficial even the tree height is tiny. @@ -2400,6 +2540,14 @@ private: using VecTreeTy = SmallVector<std::unique_ptr<TreeEntry>, 8>; TreeEntry(VecTreeTy &Container) : Container(Container) {} + /// \returns Common mask for reorder indices and reused scalars. + SmallVector<int> getCommonMask() const { + SmallVector<int> Mask; + inversePermutation(ReorderIndices, Mask); + ::addMask(Mask, ReuseShuffleIndices); + return Mask; + } + /// \returns true if the scalars in VL are equal to this entry. bool isSame(ArrayRef<Value *> VL) const { auto &&IsSame = [VL](ArrayRef<Value *> Scalars, ArrayRef<int> Mask) { @@ -2409,8 +2557,8 @@ private: std::equal(VL.begin(), VL.end(), Mask.begin(), [Scalars](Value *V, int Idx) { return (isa<UndefValue>(V) && - Idx == UndefMaskElem) || - (Idx != UndefMaskElem && V == Scalars[Idx]); + Idx == PoisonMaskElem) || + (Idx != PoisonMaskElem && V == Scalars[Idx]); }); }; if (!ReorderIndices.empty()) { @@ -2471,7 +2619,7 @@ private: ValueList Scalars; /// The Scalars are vectorized into this value. It is initialized to Null. 
- Value *VectorizedValue = nullptr; + WeakTrackingVH VectorizedValue = nullptr; /// Do we need to gather this sequence or vectorize it /// (either with vector instruction or with scatter/gather @@ -2684,20 +2832,22 @@ private: #ifndef NDEBUG void dumpTreeCosts(const TreeEntry *E, InstructionCost ReuseShuffleCost, - InstructionCost VecCost, - InstructionCost ScalarCost) const { - dbgs() << "SLP: Calculated costs for Tree:\n"; E->dump(); + InstructionCost VecCost, InstructionCost ScalarCost, + StringRef Banner) const { + dbgs() << "SLP: " << Banner << ":\n"; + E->dump(); dbgs() << "SLP: Costs:\n"; dbgs() << "SLP: ReuseShuffleCost = " << ReuseShuffleCost << "\n"; dbgs() << "SLP: VectorCost = " << VecCost << "\n"; dbgs() << "SLP: ScalarCost = " << ScalarCost << "\n"; - dbgs() << "SLP: ReuseShuffleCost + VecCost - ScalarCost = " << - ReuseShuffleCost + VecCost - ScalarCost << "\n"; + dbgs() << "SLP: ReuseShuffleCost + VecCost - ScalarCost = " + << ReuseShuffleCost + VecCost - ScalarCost << "\n"; } #endif /// Create a new VectorizableTree entry. - TreeEntry *newTreeEntry(ArrayRef<Value *> VL, std::optional<ScheduleData *> Bundle, + TreeEntry *newTreeEntry(ArrayRef<Value *> VL, + std::optional<ScheduleData *> Bundle, const InstructionsState &S, const EdgeInfo &UserTreeIdx, ArrayRef<int> ReuseShuffleIndices = std::nullopt, @@ -2791,8 +2941,14 @@ private: return ScalarToTreeEntry.lookup(V); } + /// Checks if the specified list of the instructions/values can be vectorized + /// and fills required data before actual scheduling of the instructions. + TreeEntry::EntryState getScalarsVectorizationState( + InstructionsState &S, ArrayRef<Value *> VL, bool IsScatterVectorizeUserTE, + OrdersType &CurrentOrder, SmallVectorImpl<Value *> &PointerOps) const; + /// Maps a specific scalar to its tree entry. - SmallDenseMap<Value*, TreeEntry *> ScalarToTreeEntry; + SmallDenseMap<Value *, TreeEntry *> ScalarToTreeEntry; /// Maps a value to the proposed vectorizable size. SmallDenseMap<Value *, unsigned> InstrElementSize; @@ -2808,6 +2964,15 @@ private: /// pre-gather them before. DenseMap<const TreeEntry *, Instruction *> EntryToLastInstruction; + /// List of gather nodes, depending on other gather/vector nodes, which should + /// be emitted after the vector instruction emission process to correctly + /// handle order of the vector instructions and shuffles. + SetVector<const TreeEntry *> PostponedGathers; + + using ValueToGatherNodesMap = + DenseMap<Value *, SmallPtrSet<const TreeEntry *, 4>>; + ValueToGatherNodesMap ValueToGatherNodes; + /// This POD struct describes one external user in the vectorized tree. 
struct ExternalUser { ExternalUser(Value *S, llvm::User *U, int L) @@ -3235,7 +3400,6 @@ private: << "SLP: gets ready (ctl): " << *DepBundle << "\n"); } } - } } @@ -3579,7 +3743,7 @@ static void reorderReuses(SmallVectorImpl<int> &Reuses, ArrayRef<int> Mask) { SmallVector<int> Prev(Reuses.begin(), Reuses.end()); Prev.swap(Reuses); for (unsigned I = 0, E = Prev.size(); I < E; ++I) - if (Mask[I] != UndefMaskElem) + if (Mask[I] != PoisonMaskElem) Reuses[Mask[I]] = Prev[I]; } @@ -3603,7 +3767,7 @@ static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) { } Order.assign(Mask.size(), Mask.size()); for (unsigned I = 0, E = Mask.size(); I < E; ++I) - if (MaskOrder[I] != UndefMaskElem) + if (MaskOrder[I] != PoisonMaskElem) Order[MaskOrder[I]] = I; fixupOrderingIndices(Order); } @@ -3653,10 +3817,8 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { return false; return true; }; - if (IsIdentityOrder(CurrentOrder)) { - CurrentOrder.clear(); - return CurrentOrder; - } + if (IsIdentityOrder(CurrentOrder)) + return OrdersType(); auto *It = CurrentOrder.begin(); for (unsigned I = 0; I < NumScalars;) { if (UsedPositions.test(I)) { @@ -3669,7 +3831,7 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { } ++It; } - return CurrentOrder; + return std::move(CurrentOrder); } return std::nullopt; } @@ -3779,9 +3941,9 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, return LoadsState::Gather; } -bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, - const DataLayout &DL, ScalarEvolution &SE, - SmallVectorImpl<unsigned> &SortedIndices) { +static bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, + const DataLayout &DL, ScalarEvolution &SE, + SmallVectorImpl<unsigned> &SortedIndices) { assert(llvm::all_of( VL, [](const Value *V) { return V->getType()->isPointerTy(); }) && "Expected list of pointer operands."); @@ -3825,7 +3987,7 @@ bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, return std::get<1>(X) < std::get<1>(Y); }); int InitialOffset = std::get<1>(Vec[0]); - AnyConsecutive |= all_of(enumerate(Vec), [InitialOffset](auto &P) { + AnyConsecutive |= all_of(enumerate(Vec), [InitialOffset](const auto &P) { return std::get<1>(P.value()) == int(P.index()) + InitialOffset; }); } @@ -3862,7 +4024,7 @@ BoUpSLP::findPartiallyOrderedLoads(const BoUpSLP::TreeEntry &TE) { BoUpSLP::OrdersType Order; if (clusterSortPtrAccesses(Ptrs, ScalarTy, *DL, *SE, Order)) - return Order; + return std::move(Order); return std::nullopt; } @@ -3888,31 +4050,35 @@ static bool areTwoInsertFromSameBuildVector( // Go through the vector operand of insertelement instructions trying to find // either VU as the original vector for IE2 or V as the original vector for // IE1. 
+ SmallSet<int, 8> ReusedIdx; + bool IsReusedIdx = false; do { - if (IE2 == VU) + if (IE2 == VU && !IE1) return VU->hasOneUse(); - if (IE1 == V) + if (IE1 == V && !IE2) return V->hasOneUse(); - if (IE1) { - if ((IE1 != VU && !IE1->hasOneUse()) || - getInsertIndex(IE1).value_or(*Idx2) == *Idx2) + if (IE1 && IE1 != V) { + IsReusedIdx |= + !ReusedIdx.insert(getInsertIndex(IE1).value_or(*Idx2)).second; + if ((IE1 != VU && !IE1->hasOneUse()) || IsReusedIdx) IE1 = nullptr; else IE1 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE1)); } - if (IE2) { - if ((IE2 != V && !IE2->hasOneUse()) || - getInsertIndex(IE2).value_or(*Idx1) == *Idx1) + if (IE2 && IE2 != VU) { + IsReusedIdx |= + !ReusedIdx.insert(getInsertIndex(IE2).value_or(*Idx1)).second; + if ((IE2 != V && !IE2->hasOneUse()) || IsReusedIdx) IE2 = nullptr; else IE2 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE2)); } - } while (IE1 || IE2); + } while (!IsReusedIdx && (IE1 || IE2)); return false; } -std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE, - bool TopToBottom) { +std::optional<BoUpSLP::OrdersType> +BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { // No need to reorder if need to shuffle reuses, still need to shuffle the // node. if (!TE.ReuseShuffleIndices.empty()) { @@ -3936,14 +4102,14 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T std::optional<unsigned> Idx = getExtractIndex(cast<Instruction>(V)); return Idx && *Idx < Sz; })) { - SmallVector<int> ReorderMask(Sz, UndefMaskElem); + SmallVector<int> ReorderMask(Sz, PoisonMaskElem); if (TE.ReorderIndices.empty()) std::iota(ReorderMask.begin(), ReorderMask.end(), 0); else inversePermutation(TE.ReorderIndices, ReorderMask); for (unsigned I = 0; I < VF; ++I) { int &Idx = ReusedMask[I]; - if (Idx == UndefMaskElem) + if (Idx == PoisonMaskElem) continue; Value *V = TE.Scalars[ReorderMask[Idx]]; std::optional<unsigned> EI = getExtractIndex(cast<Instruction>(V)); @@ -3958,7 +4124,7 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T for (unsigned K = 0; K < VF; K += Sz) { OrdersType CurrentOrder(TE.ReorderIndices); SmallVector<int> SubMask{ArrayRef(ReusedMask).slice(K, Sz)}; - if (SubMask.front() == UndefMaskElem) + if (SubMask.front() == PoisonMaskElem) std::iota(SubMask.begin(), SubMask.end(), 0); reorderOrder(CurrentOrder, SubMask); transform(CurrentOrder, It, [K](unsigned Pos) { return Pos + K; }); @@ -3966,8 +4132,8 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T } if (all_of(enumerate(ResOrder), [](const auto &Data) { return Data.index() == Data.value(); })) - return {}; // Use identity order. - return ResOrder; + return std::nullopt; // No need to reorder. 
+ return std::move(ResOrder); } if (TE.State == TreeEntry::Vectorize && (isa<LoadInst, ExtractElementInst, ExtractValueInst>(TE.getMainOp()) || @@ -3976,6 +4142,8 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T return TE.ReorderIndices; if (TE.State == TreeEntry::Vectorize && TE.getOpcode() == Instruction::PHI) { auto PHICompare = [](llvm::Value *V1, llvm::Value *V2) { + if (V1 == V2) + return false; if (!V1->hasOneUse() || !V2->hasOneUse()) return false; auto *FirstUserOfPhi1 = cast<Instruction>(*V1->user_begin()); @@ -4023,8 +4191,8 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T for (unsigned Id = 0, Sz = Phis.size(); Id < Sz; ++Id) ResOrder[Id] = PhiToId[Phis[Id]]; if (IsIdentityOrder(ResOrder)) - return {}; - return ResOrder; + return std::nullopt; // No need to reorder. + return std::move(ResOrder); } if (TE.State == TreeEntry::NeedToGather) { // TODO: add analysis of other gather nodes with extractelement @@ -4050,7 +4218,42 @@ std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &T if (Reuse || !CurrentOrder.empty()) { if (!CurrentOrder.empty()) fixupOrderingIndices(CurrentOrder); - return CurrentOrder; + return std::move(CurrentOrder); + } + } + // If the gather node is <undef, v, .., poison> and + // insertelement poison, v, 0 [+ permute] + // is cheaper than + // insertelement poison, v, n - try to reorder. + // If rotating the whole graph, exclude the permute cost, the whole graph + // might be transformed. + int Sz = TE.Scalars.size(); + if (isSplat(TE.Scalars) && !allConstant(TE.Scalars) && + count_if(TE.Scalars, UndefValue::classof) == Sz - 1) { + const auto *It = + find_if(TE.Scalars, [](Value *V) { return !isConstant(V); }); + if (It == TE.Scalars.begin()) + return OrdersType(); + auto *Ty = FixedVectorType::get(TE.Scalars.front()->getType(), Sz); + if (It != TE.Scalars.end()) { + OrdersType Order(Sz, Sz); + unsigned Idx = std::distance(TE.Scalars.begin(), It); + Order[Idx] = 0; + fixupOrderingIndices(Order); + SmallVector<int> Mask; + inversePermutation(Order, Mask); + InstructionCost PermuteCost = + TopToBottom + ? 0 + : TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, Mask); + InstructionCost InsertFirstCost = TTI->getVectorInstrCost( + Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, 0, + PoisonValue::get(Ty), *It); + InstructionCost InsertIdxCost = TTI->getVectorInstrCost( + Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, Idx, + PoisonValue::get(Ty), *It); + if (InsertFirstCost + PermuteCost < InsertIdxCost) + return std::move(Order); } } if (std::optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE)) @@ -4260,7 +4463,7 @@ void BoUpSLP::reorderTopToBottom() { unsigned E = Order.size(); OrdersType CurrentOrder(E, E); transform(Mask, CurrentOrder.begin(), [E](int Idx) { - return Idx == UndefMaskElem ? E : static_cast<unsigned>(Idx); + return Idx == PoisonMaskElem ? E : static_cast<unsigned>(Idx); }); fixupOrderingIndices(CurrentOrder); ++OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second; @@ -4285,10 +4488,10 @@ void BoUpSLP::reorderTopToBottom() { continue; SmallVector<int> Mask; inversePermutation(BestOrder, Mask); - SmallVector<int> MaskOrder(BestOrder.size(), UndefMaskElem); + SmallVector<int> MaskOrder(BestOrder.size(), PoisonMaskElem); unsigned E = BestOrder.size(); transform(BestOrder, MaskOrder.begin(), [E](unsigned I) { - return I < E ? static_cast<int>(I) : UndefMaskElem; + return I < E ? 
static_cast<int>(I) : PoisonMaskElem; }); // Do an actual reordering, if profitable. for (std::unique_ptr<TreeEntry> &TE : VectorizableTree) { @@ -4384,7 +4587,7 @@ bool BoUpSLP::canReorderOperands( } return false; }) > 1 && - !all_of(UserTE->getOperand(I), isConstant)) + !allConstant(UserTE->getOperand(I))) return false; if (Gather) GatherOps.push_back(Gather); @@ -4499,7 +4702,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { unsigned E = Order.size(); OrdersType CurrentOrder(E, E); transform(Mask, CurrentOrder.begin(), [E](int Idx) { - return Idx == UndefMaskElem ? E : static_cast<unsigned>(Idx); + return Idx == PoisonMaskElem ? E : static_cast<unsigned>(Idx); }); fixupOrderingIndices(CurrentOrder); OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second += @@ -4578,10 +4781,10 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { VisitedOps.clear(); SmallVector<int> Mask; inversePermutation(BestOrder, Mask); - SmallVector<int> MaskOrder(BestOrder.size(), UndefMaskElem); + SmallVector<int> MaskOrder(BestOrder.size(), PoisonMaskElem); unsigned E = BestOrder.size(); transform(BestOrder, MaskOrder.begin(), [E](unsigned I) { - return I < E ? static_cast<int>(I) : UndefMaskElem; + return I < E ? static_cast<int>(I) : PoisonMaskElem; }); for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) { TreeEntry *TE = Op.second; @@ -4779,7 +4982,7 @@ bool BoUpSLP::canFormVector(const SmallVector<StoreInst *, 4> &StoresVec, // Check if the stores are consecutive by checking if their difference is 1. for (unsigned Idx : seq<unsigned>(1, StoreOffsetVec.size())) - if (StoreOffsetVec[Idx].second != StoreOffsetVec[Idx-1].second + 1) + if (StoreOffsetVec[Idx].second != StoreOffsetVec[Idx - 1].second + 1) return false; // Calculate the shuffle indices according to their offset against the sorted @@ -4976,6 +5179,309 @@ static bool isAlternateInstruction(const Instruction *I, const Instruction *AltOp, const TargetLibraryInfo &TLI); +BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( + InstructionsState &S, ArrayRef<Value *> VL, bool IsScatterVectorizeUserTE, + OrdersType &CurrentOrder, SmallVectorImpl<Value *> &PointerOps) const { + assert(S.MainOp && "Expected instructions with same/alternate opcodes only."); + + unsigned ShuffleOrOp = + S.isAltShuffle() ? (unsigned)Instruction::ShuffleVector : S.getOpcode(); + auto *VL0 = cast<Instruction>(S.OpValue); + switch (ShuffleOrOp) { + case Instruction::PHI: { + // Check for terminator values (e.g. invoke). + for (Value *V : VL) + for (Value *Incoming : cast<PHINode>(V)->incoming_values()) { + Instruction *Term = dyn_cast<Instruction>(Incoming); + if (Term && Term->isTerminator()) { + LLVM_DEBUG(dbgs() + << "SLP: Need to swizzle PHINodes (terminator use).\n"); + return TreeEntry::NeedToGather; + } + } + + return TreeEntry::Vectorize; + } + case Instruction::ExtractValue: + case Instruction::ExtractElement: { + bool Reuse = canReuseExtract(VL, VL0, CurrentOrder); + if (Reuse || !CurrentOrder.empty()) + return TreeEntry::Vectorize; + LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n"); + return TreeEntry::NeedToGather; + } + case Instruction::InsertElement: { + // Check that we have a buildvector and not a shuffle of 2 or more + // different vectors. 
+ ValueSet SourceVectors; + for (Value *V : VL) { + SourceVectors.insert(cast<Instruction>(V)->getOperand(0)); + assert(getInsertIndex(V) != std::nullopt && + "Non-constant or undef index?"); + } + + if (count_if(VL, [&SourceVectors](Value *V) { + return !SourceVectors.contains(V); + }) >= 2) { + // Found 2nd source vector - cancel. + LLVM_DEBUG(dbgs() << "SLP: Gather of insertelement vectors with " + "different source vectors.\n"); + return TreeEntry::NeedToGather; + } + + return TreeEntry::Vectorize; + } + case Instruction::Load: { + // Check that a vectorized load would load the same memory as a scalar + // load. For example, we don't want to vectorize loads that are smaller + // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM + // treats loading/storing it as an i8 struct. If we vectorize loads/stores + // from such a struct, we read/write packed bits disagreeing with the + // unvectorized version. + switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, CurrentOrder, + PointerOps)) { + case LoadsState::Vectorize: + return TreeEntry::Vectorize; + case LoadsState::ScatterVectorize: + return TreeEntry::ScatterVectorize; + case LoadsState::Gather: +#ifndef NDEBUG + Type *ScalarTy = VL0->getType(); + if (DL->getTypeSizeInBits(ScalarTy) != + DL->getTypeAllocSizeInBits(ScalarTy)) + LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); + else if (any_of(VL, + [](Value *V) { return !cast<LoadInst>(V)->isSimple(); })) + LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); + else + LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); +#endif // NDEBUG + return TreeEntry::NeedToGather; + } + llvm_unreachable("Unexpected state of loads"); + } + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: { + Type *SrcTy = VL0->getOperand(0)->getType(); + for (Value *V : VL) { + Type *Ty = cast<Instruction>(V)->getOperand(0)->getType(); + if (Ty != SrcTy || !isValidElementType(Ty)) { + LLVM_DEBUG( + dbgs() << "SLP: Gathering casts with different src types.\n"); + return TreeEntry::NeedToGather; + } + } + return TreeEntry::Vectorize; + } + case Instruction::ICmp: + case Instruction::FCmp: { + // Check that all of the compares have the same predicate. 
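Aside (not part of the patch): the compare-bundle check that follows reduces to "every compare uses the common predicate or its swapped form, on operands of the same type". A standalone sketch with a toy predicate enum; this is not LLVM's CmpInst API, only an illustration of the rule:

#include <cstdio>
#include <vector>

// Toy predicate set; swapping mirrors the idea of a swapped comparison
// (a < b is the same relation as b > a).
enum class Pred { EQ, NE, LT, GT, LE, GE };

Pred swapped(Pred P) {
  switch (P) {
  case Pred::LT: return Pred::GT;
  case Pred::GT: return Pred::LT;
  case Pred::LE: return Pred::GE;
  case Pred::GE: return Pred::LE;
  default: return P; // EQ/NE are symmetric
  }
}

struct Cmp {
  Pred P;
  int TypeId; // stand-in for the type of the compared operands
};

// A bundle passes this check if every compare uses P0 or swapped(P0) and
// compares values of the same type as the first compare.
bool sameOrSwappedPredicate(const std::vector<Cmp> &Bundle) {
  Pred P0 = Bundle.front().P;
  Pred SwapP0 = swapped(P0);
  int Ty = Bundle.front().TypeId;
  for (const Cmp &C : Bundle)
    if ((C.P != P0 && C.P != SwapP0) || C.TypeId != Ty)
      return false;
  return true;
}

int main() {
  std::printf("%d\n", sameOrSwappedPredicate(
                          {{Pred::LT, 0}, {Pred::GT, 0}, {Pred::LT, 0}})); // 1
  std::printf("%d\n", sameOrSwappedPredicate(
                          {{Pred::LT, 0}, {Pred::LE, 0}})); // 0
}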
+ CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); + CmpInst::Predicate SwapP0 = CmpInst::getSwappedPredicate(P0); + Type *ComparedTy = VL0->getOperand(0)->getType(); + for (Value *V : VL) { + CmpInst *Cmp = cast<CmpInst>(V); + if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) || + Cmp->getOperand(0)->getType() != ComparedTy) { + LLVM_DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n"); + return TreeEntry::NeedToGather; + } + } + return TreeEntry::Vectorize; + } + case Instruction::Select: + case Instruction::FNeg: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + return TreeEntry::Vectorize; + case Instruction::GetElementPtr: { + // We don't combine GEPs with complicated (nested) indexing. + for (Value *V : VL) { + auto *I = dyn_cast<GetElementPtrInst>(V); + if (!I) + continue; + if (I->getNumOperands() != 2) { + LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); + return TreeEntry::NeedToGather; + } + } + + // We can't combine several GEPs into one vector if they operate on + // different types. + Type *Ty0 = cast<GEPOperator>(VL0)->getSourceElementType(); + for (Value *V : VL) { + auto *GEP = dyn_cast<GEPOperator>(V); + if (!GEP) + continue; + Type *CurTy = GEP->getSourceElementType(); + if (Ty0 != CurTy) { + LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n"); + return TreeEntry::NeedToGather; + } + } + + // We don't combine GEPs with non-constant indexes. + Type *Ty1 = VL0->getOperand(1)->getType(); + for (Value *V : VL) { + auto *I = dyn_cast<GetElementPtrInst>(V); + if (!I) + continue; + auto *Op = I->getOperand(1); + if ((!IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) || + (Op->getType() != Ty1 && + ((IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) || + Op->getType()->getScalarSizeInBits() > + DL->getIndexSizeInBits( + V->getType()->getPointerAddressSpace())))) { + LLVM_DEBUG( + dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n"); + return TreeEntry::NeedToGather; + } + } + + return TreeEntry::Vectorize; + } + case Instruction::Store: { + // Check if the stores are consecutive or if we need to swizzle them. + llvm::Type *ScalarTy = cast<StoreInst>(VL0)->getValueOperand()->getType(); + // Avoid types that are padded when being allocated as scalars, while + // being packed together in a vector (such as i1). + if (DL->getTypeSizeInBits(ScalarTy) != + DL->getTypeAllocSizeInBits(ScalarTy)) { + LLVM_DEBUG(dbgs() << "SLP: Gathering stores of non-packed type.\n"); + return TreeEntry::NeedToGather; + } + // Make sure all stores in the bundle are simple - we can't vectorize + // atomic or volatile stores. + for (Value *V : VL) { + auto *SI = cast<StoreInst>(V); + if (!SI->isSimple()) { + LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple stores.\n"); + return TreeEntry::NeedToGather; + } + PointerOps.push_back(SI->getPointerOperand()); + } + + // Check the order of pointer operands. 
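Aside (not part of the patch): the store check that follows amounts to "sort the pointers by offset and require a gap-free run of consecutive elements". The real code goes through sortPtrAccesses and getPointersDiff; the standalone sketch below works directly on element offsets relative to an assumed common base, which is an abstraction, not the SLP implementation:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Offsets are in units of the scalar element size. Returns the sorted-to-
// original order if the accesses are consecutive, or an empty vector if not.
std::vector<unsigned> consecutiveOrder(const std::vector<int> &Offsets) {
  std::vector<unsigned> Order(Offsets.size());
  std::iota(Order.begin(), Order.end(), 0u);
  std::sort(Order.begin(), Order.end(),
            [&](unsigned A, unsigned B) { return Offsets[A] < Offsets[B]; });
  for (size_t I = 1; I < Order.size(); ++I)
    if (Offsets[Order[I]] != Offsets[Order[I - 1]] + 1)
      return {}; // a gap or duplicate: gather instead of vectorizing
  return Order;
}

int main() {
  // Stores to p+2, p+0, p+3, p+1: consecutive once sorted, so they can become
  // one wide store plus a reordering of the stored values.
  for (unsigned Idx : consecutiveOrder({2, 0, 3, 1}))
    std::printf("%u ", Idx); // prints: 1 3 0 2
  std::printf("\n");
}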
+ if (llvm::sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, CurrentOrder)) { + Value *Ptr0; + Value *PtrN; + if (CurrentOrder.empty()) { + Ptr0 = PointerOps.front(); + PtrN = PointerOps.back(); + } else { + Ptr0 = PointerOps[CurrentOrder.front()]; + PtrN = PointerOps[CurrentOrder.back()]; + } + std::optional<int> Dist = + getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE); + // Check that the sorted pointer operands are consecutive. + if (static_cast<unsigned>(*Dist) == VL.size() - 1) + return TreeEntry::Vectorize; + } + + LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); + return TreeEntry::NeedToGather; + } + case Instruction::Call: { + // Check if the calls are all to the same vectorizable intrinsic or + // library function. + CallInst *CI = cast<CallInst>(VL0); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + + VFShape Shape = VFShape::get( + *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())), + false /*HasGlobalPred*/); + Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); + + if (!VecFunc && !isTriviallyVectorizable(ID)) { + LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); + return TreeEntry::NeedToGather; + } + Function *F = CI->getCalledFunction(); + unsigned NumArgs = CI->arg_size(); + SmallVector<Value *, 4> ScalarArgs(NumArgs, nullptr); + for (unsigned J = 0; J != NumArgs; ++J) + if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) + ScalarArgs[J] = CI->getArgOperand(J); + for (Value *V : VL) { + CallInst *CI2 = dyn_cast<CallInst>(V); + if (!CI2 || CI2->getCalledFunction() != F || + getVectorIntrinsicIDForCall(CI2, TLI) != ID || + (VecFunc && + VecFunc != VFDatabase(*CI2).getVectorizedFunction(Shape)) || + !CI->hasIdenticalOperandBundleSchema(*CI2)) { + LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *V + << "\n"); + return TreeEntry::NeedToGather; + } + // Some intrinsics have scalar arguments and should be same in order for + // them to be vectorized. + for (unsigned J = 0; J != NumArgs; ++J) { + if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) { + Value *A1J = CI2->getArgOperand(J); + if (ScalarArgs[J] != A1J) { + LLVM_DEBUG(dbgs() + << "SLP: mismatched arguments in call:" << *CI + << " argument " << ScalarArgs[J] << "!=" << A1J << "\n"); + return TreeEntry::NeedToGather; + } + } + } + // Verify that the bundle operands are identical between the two calls. + if (CI->hasOperandBundles() && + !std::equal(CI->op_begin() + CI->getBundleOperandsStartIndex(), + CI->op_begin() + CI->getBundleOperandsEndIndex(), + CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { + LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI + << "!=" << *V << '\n'); + return TreeEntry::NeedToGather; + } + } + + return TreeEntry::Vectorize; + } + case Instruction::ShuffleVector: { + // If this is not an alternate sequence of opcode like add-sub + // then do not vectorize this instruction. + if (!S.isAltShuffle()) { + LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); + return TreeEntry::NeedToGather; + } + return TreeEntry::Vectorize; + } + default: + LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); + return TreeEntry::NeedToGather; + } +} + void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, const EdgeInfo &UserTreeIdx) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); @@ -4990,7 +5496,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *V : VL) { if (isConstant(V)) { ReuseShuffleIndicies.emplace_back( - isa<UndefValue>(V) ? 
UndefMaskElem : UniqueValues.size()); + isa<UndefValue>(V) ? PoisonMaskElem : UniqueValues.size()); UniqueValues.emplace_back(V); continue; } @@ -5010,7 +5516,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return isa<UndefValue>(V) || !isConstant(V); })) || - !llvm::isPowerOf2_32(NumUniqueScalarValues)) { + !llvm::has_single_bit<uint32_t>(NumUniqueScalarValues)) { LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return false; @@ -5257,6 +5763,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!TryToFindDuplicates(S)) return; + // Perform specific checks for each particular instruction kind. + OrdersType CurrentOrder; + SmallVector<Value *> PointerOps; + TreeEntry::EntryState State = getScalarsVectorizationState( + S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps); + if (State == TreeEntry::NeedToGather) { + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); + return; + } + auto &BSRef = BlocksSchedules[BB]; if (!BSRef) BSRef = std::make_unique<BlockScheduling>(BB); @@ -5285,20 +5802,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::PHI: { auto *PH = cast<PHINode>(VL0); - // Check for terminator values (e.g. invoke). - for (Value *V : VL) - for (Value *Incoming : cast<PHINode>(V)->incoming_values()) { - Instruction *Term = dyn_cast<Instruction>(Incoming); - if (Term && Term->isTerminator()) { - LLVM_DEBUG(dbgs() - << "SLP: Need to swizzle PHINodes (terminator use).\n"); - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - return; - } - } - TreeEntry *TE = newTreeEntry(VL, Bundle, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); @@ -5326,9 +5829,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } case Instruction::ExtractValue: case Instruction::ExtractElement: { - OrdersType CurrentOrder; - bool Reuse = canReuseExtract(VL, VL0, CurrentOrder); - if (Reuse) { + if (CurrentOrder.empty()) { LLVM_DEBUG(dbgs() << "SLP: Reusing or shuffling extract sequence.\n"); newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); @@ -5339,55 +5840,28 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, VectorizableTree.back()->setOperand(0, Op0); return; } - if (!CurrentOrder.empty()) { - LLVM_DEBUG({ - dbgs() << "SLP: Reusing or shuffling of reordered extract sequence " - "with order"; - for (unsigned Idx : CurrentOrder) - dbgs() << " " << Idx; - dbgs() << "\n"; - }); - fixupOrderingIndices(CurrentOrder); - // Insert new order with initial value 0, if it does not exist, - // otherwise return the iterator to the existing one. - newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies, CurrentOrder); - // This is a special case, as it does not gather, but at the same time - // we are not extending buildTree_rec() towards the operands. 
- ValueList Op0; - Op0.assign(VL.size(), VL0->getOperand(0)); - VectorizableTree.back()->setOperand(0, Op0); - return; - } - LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n"); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - BS.cancelScheduling(VL, VL0); + LLVM_DEBUG({ + dbgs() << "SLP: Reusing or shuffling of reordered extract sequence " + "with order"; + for (unsigned Idx : CurrentOrder) + dbgs() << " " << Idx; + dbgs() << "\n"; + }); + fixupOrderingIndices(CurrentOrder); + // Insert new order with initial value 0, if it does not exist, + // otherwise return the iterator to the existing one. + newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies, CurrentOrder); + // This is a special case, as it does not gather, but at the same time + // we are not extending buildTree_rec() towards the operands. + ValueList Op0; + Op0.assign(VL.size(), VL0->getOperand(0)); + VectorizableTree.back()->setOperand(0, Op0); return; } case Instruction::InsertElement: { assert(ReuseShuffleIndicies.empty() && "All inserts should be unique"); - // Check that we have a buildvector and not a shuffle of 2 or more - // different vectors. - ValueSet SourceVectors; - for (Value *V : VL) { - SourceVectors.insert(cast<Instruction>(V)->getOperand(0)); - assert(getInsertIndex(V) != std::nullopt && - "Non-constant or undef index?"); - } - - if (count_if(VL, [&SourceVectors](Value *V) { - return !SourceVectors.contains(V); - }) >= 2) { - // Found 2nd source vector - cancel. - LLVM_DEBUG(dbgs() << "SLP: Gather of insertelement vectors with " - "different source vectors.\n"); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); - BS.cancelScheduling(VL, VL0); - return; - } - auto OrdCompare = [](const std::pair<int, int> &P1, const std::pair<int, int> &P2) { return P1.first > P2.first; @@ -5430,12 +5904,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // treats loading/storing it as an i8 struct. If we vectorize loads/stores // from such a struct, we read/write packed bits disagreeing with the // unvectorized version. - SmallVector<Value *> PointerOps; - OrdersType CurrentOrder; TreeEntry *TE = nullptr; - switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, - CurrentOrder, PointerOps)) { - case LoadsState::Vectorize: + switch (State) { + case TreeEntry::Vectorize: if (CurrentOrder.empty()) { // Original loads are consecutive and does not require reordering. TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, @@ -5450,7 +5921,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } TE->setOperandsInOrder(); break; - case LoadsState::ScatterVectorize: + case TreeEntry::ScatterVectorize: // Vectorizing non-consecutive loads with `llvm.masked.gather`. 
TE = newTreeEntry(VL, TreeEntry::ScatterVectorize, Bundle, S, UserTreeIdx, ReuseShuffleIndicies); @@ -5458,23 +5929,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, buildTree_rec(PointerOps, Depth + 1, {TE, 0}); LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n"); break; - case LoadsState::Gather: - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); -#ifndef NDEBUG - Type *ScalarTy = VL0->getType(); - if (DL->getTypeSizeInBits(ScalarTy) != - DL->getTypeAllocSizeInBits(ScalarTy)) - LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); - else if (any_of(VL, [](Value *V) { - return !cast<LoadInst>(V)->isSimple(); - })) - LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); - else - LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); -#endif // NDEBUG - break; + case TreeEntry::NeedToGather: + llvm_unreachable("Unexpected loads state."); } return; } @@ -5490,18 +5946,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - Type *SrcTy = VL0->getOperand(0)->getType(); - for (Value *V : VL) { - Type *Ty = cast<Instruction>(V)->getOperand(0)->getType(); - if (Ty != SrcTy || !isValidElementType(Ty)) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() - << "SLP: Gathering casts with different src types.\n"); - return; - } - } TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n"); @@ -5521,21 +5965,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::FCmp: { // Check that all of the compares have the same predicate. CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); - CmpInst::Predicate SwapP0 = CmpInst::getSwappedPredicate(P0); - Type *ComparedTy = VL0->getOperand(0)->getType(); - for (Value *V : VL) { - CmpInst *Cmp = cast<CmpInst>(V); - if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) || - Cmp->getOperand(0)->getType() != ComparedTy) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() - << "SLP: Gathering cmp with different predicate.\n"); - return; - } - } - TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n"); @@ -5544,7 +5973,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (cast<CmpInst>(VL0)->isCommutative()) { // Commutative predicate - collect + sort operands of the instructions // so that each side is more likely to have the same opcode. - assert(P0 == SwapP0 && "Commutative Predicate mismatch"); + assert(P0 == CmpInst::getSwappedPredicate(P0) && + "Commutative Predicate mismatch"); reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, *this); } else { // Collect operands - commute if it uses the swapped predicate. @@ -5612,60 +6042,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } case Instruction::GetElementPtr: { - // We don't combine GEPs with complicated (nested) indexing. 
- for (Value *V : VL) { - auto *I = dyn_cast<GetElementPtrInst>(V); - if (!I) - continue; - if (I->getNumOperands() != 2) { - LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - return; - } - } - - // We can't combine several GEPs into one vector if they operate on - // different types. - Type *Ty0 = cast<GEPOperator>(VL0)->getSourceElementType(); - for (Value *V : VL) { - auto *GEP = dyn_cast<GEPOperator>(V); - if (!GEP) - continue; - Type *CurTy = GEP->getSourceElementType(); - if (Ty0 != CurTy) { - LLVM_DEBUG(dbgs() - << "SLP: not-vectorizable GEP (different types).\n"); - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - return; - } - } - - // We don't combine GEPs with non-constant indexes. - Type *Ty1 = VL0->getOperand(1)->getType(); - for (Value *V : VL) { - auto *I = dyn_cast<GetElementPtrInst>(V); - if (!I) - continue; - auto *Op = I->getOperand(1); - if ((!IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) || - (Op->getType() != Ty1 && - ((IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) || - Op->getType()->getScalarSizeInBits() > - DL->getIndexSizeInBits( - V->getType()->getPointerAddressSpace())))) { - LLVM_DEBUG(dbgs() - << "SLP: not-vectorizable GEP (non-constant indexes).\n"); - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - return; - } - } - TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); @@ -5722,78 +6098,29 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } case Instruction::Store: { // Check if the stores are consecutive or if we need to swizzle them. - llvm::Type *ScalarTy = cast<StoreInst>(VL0)->getValueOperand()->getType(); - // Avoid types that are padded when being allocated as scalars, while - // being packed together in a vector (such as i1). - if (DL->getTypeSizeInBits(ScalarTy) != - DL->getTypeAllocSizeInBits(ScalarTy)) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Gathering stores of non-packed type.\n"); - return; - } - // Make sure all stores in the bundle are simple - we can't vectorize - // atomic or volatile stores. - SmallVector<Value *, 4> PointerOps(VL.size()); ValueList Operands(VL.size()); - auto POIter = PointerOps.begin(); - auto OIter = Operands.begin(); + auto *OIter = Operands.begin(); for (Value *V : VL) { auto *SI = cast<StoreInst>(V); - if (!SI->isSimple()) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple stores.\n"); - return; - } - *POIter = SI->getPointerOperand(); *OIter = SI->getValueOperand(); - ++POIter; ++OIter; } - - OrdersType CurrentOrder; - // Check the order of pointer operands. 
- if (llvm::sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, CurrentOrder)) { - Value *Ptr0; - Value *PtrN; - if (CurrentOrder.empty()) { - Ptr0 = PointerOps.front(); - PtrN = PointerOps.back(); - } else { - Ptr0 = PointerOps[CurrentOrder.front()]; - PtrN = PointerOps[CurrentOrder.back()]; - } - std::optional<int> Dist = - getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE); - // Check that the sorted pointer operands are consecutive. - if (static_cast<unsigned>(*Dist) == VL.size() - 1) { - if (CurrentOrder.empty()) { - // Original stores are consecutive and does not require reordering. - TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, - UserTreeIdx, ReuseShuffleIndicies); - TE->setOperandsInOrder(); - buildTree_rec(Operands, Depth + 1, {TE, 0}); - LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); - } else { - fixupOrderingIndices(CurrentOrder); - TreeEntry *TE = - newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies, CurrentOrder); - TE->setOperandsInOrder(); - buildTree_rec(Operands, Depth + 1, {TE, 0}); - LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled stores.\n"); - } - return; - } + // Check that the sorted pointer operands are consecutive. + if (CurrentOrder.empty()) { + // Original stores are consecutive and does not require reordering. + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); + TE->setOperandsInOrder(); + buildTree_rec(Operands, Depth + 1, {TE, 0}); + LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); + } else { + fixupOrderingIndices(CurrentOrder); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies, CurrentOrder); + TE->setOperandsInOrder(); + buildTree_rec(Operands, Depth + 1, {TE, 0}); + LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled stores.\n"); } - - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); return; } case Instruction::Call: { @@ -5802,68 +6129,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CallInst *CI = cast<CallInst>(VL0); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - VFShape Shape = VFShape::get( - *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())), - false /*HasGlobalPred*/); - Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); - - if (!VecFunc && !isTriviallyVectorizable(ID)) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); - return; - } - Function *F = CI->getCalledFunction(); - unsigned NumArgs = CI->arg_size(); - SmallVector<Value*, 4> ScalarArgs(NumArgs, nullptr); - for (unsigned j = 0; j != NumArgs; ++j) - if (isVectorIntrinsicWithScalarOpAtArg(ID, j)) - ScalarArgs[j] = CI->getArgOperand(j); - for (Value *V : VL) { - CallInst *CI2 = dyn_cast<CallInst>(V); - if (!CI2 || CI2->getCalledFunction() != F || - getVectorIntrinsicIDForCall(CI2, TLI) != ID || - (VecFunc && - VecFunc != VFDatabase(*CI2).getVectorizedFunction(Shape)) || - !CI->hasIdenticalOperandBundleSchema(*CI2)) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *V - << "\n"); - return; - } - // Some intrinsics have scalar arguments and should be same in order for - 
// them to be vectorized. - for (unsigned j = 0; j != NumArgs; ++j) { - if (isVectorIntrinsicWithScalarOpAtArg(ID, j)) { - Value *A1J = CI2->getArgOperand(j); - if (ScalarArgs[j] != A1J) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI - << " argument " << ScalarArgs[j] << "!=" << A1J - << "\n"); - return; - } - } - } - // Verify that the bundle operands are identical between the two calls. - if (CI->hasOperandBundles() && - !std::equal(CI->op_begin() + CI->getBundleOperandsStartIndex(), - CI->op_begin() + CI->getBundleOperandsEndIndex(), - CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" - << *CI << "!=" << *V << '\n'); - return; - } - } - TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); TE->setOperandsInOrder(); @@ -5883,15 +6148,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } case Instruction::ShuffleVector: { - // If this is not an alternate sequence of opcode like add-sub - // then do not vectorize this instruction. - if (!S.isAltShuffle()) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); - return; - } TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); @@ -5949,19 +6205,16 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } default: - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); - return; + break; } + llvm_unreachable("Unexpected vectorization of the instructions."); } unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { unsigned N = 1; Type *EltTy = T; - while (isa<StructType, ArrayType, VectorType>(EltTy)) { + while (isa<StructType, ArrayType, FixedVectorType>(EltTy)) { if (auto *ST = dyn_cast<StructType>(EltTy)) { // Check that struct is homogeneous. for (const auto *Ty : ST->elements()) @@ -5982,7 +6235,8 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { if (!isValidElementType(EltTy)) return 0; uint64_t VTSize = DL.getTypeStoreSizeInBits(FixedVectorType::get(EltTy, N)); - if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize || VTSize != DL.getTypeStoreSizeInBits(T)) + if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize || + VTSize != DL.getTypeStoreSizeInBits(T)) return 0; return N; } @@ -6111,68 +6365,6 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy, return {IntrinsicCost, LibCost}; } -/// Compute the cost of creating a vector of type \p VecTy containing the -/// extracted values from \p VL. 
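Aside (not part of the patch): the computeExtractCost helper being deleted below walked the extracts in blocks of EltsPerVector lanes and charged a per-register shuffle for every block whose extract indices were not consecutive and block-aligned. The standalone sketch below models only that blocking decision with invented names; it skips the undef handling and the TTI cost queries of the original:

#include <algorithm>
#include <cstdio>
#include <vector>

// For each block of EltsPerVector lanes, decide whether the extract indices
// are already consecutive and block-aligned (the source register can be
// reused directly) or whether an extra shuffle per block would be needed.
unsigned countBlocksNeedingShuffle(const std::vector<int> &ExtractIdx,
                                   unsigned EltsPerVector) {
  unsigned Blocks = 0;
  for (size_t Begin = 0; Begin < ExtractIdx.size(); Begin += EltsPerVector) {
    size_t End = std::min(Begin + EltsPerVector, ExtractIdx.size());
    bool Consecutive = true;
    for (size_t I = Begin; I < End; ++I) {
      bool Aligned = ExtractIdx[I] % EltsPerVector == I % EltsPerVector;
      bool FollowsPrev = I == Begin || ExtractIdx[I] == ExtractIdx[I - 1] + 1;
      if (!Aligned || !FollowsPrev) {
        Consecutive = false;
        break;
      }
    }
    if (!Consecutive)
      ++Blocks; // this block needs a per-register shuffle
  }
  return Blocks;
}

int main() {
  // 8 extracts, 4 lanes per register: the first block {0,1,2,3} reuses its
  // source register, the second block {9,8,11,10} is permuted and needs one
  // shuffle.
  std::printf("%u\n", countBlocksNeedingShuffle({0, 1, 2, 3, 9, 8, 11, 10}, 4));
}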
-static InstructionCost -computeExtractCost(ArrayRef<Value *> VL, FixedVectorType *VecTy, - TargetTransformInfo::ShuffleKind ShuffleKind, - ArrayRef<int> Mask, TargetTransformInfo &TTI) { - unsigned NumOfParts = TTI.getNumberOfParts(VecTy); - - if (ShuffleKind != TargetTransformInfo::SK_PermuteSingleSrc || !NumOfParts || - VecTy->getNumElements() < NumOfParts) - return TTI.getShuffleCost(ShuffleKind, VecTy, Mask); - - bool AllConsecutive = true; - unsigned EltsPerVector = VecTy->getNumElements() / NumOfParts; - unsigned Idx = -1; - InstructionCost Cost = 0; - - // Process extracts in blocks of EltsPerVector to check if the source vector - // operand can be re-used directly. If not, add the cost of creating a shuffle - // to extract the values into a vector register. - SmallVector<int> RegMask(EltsPerVector, UndefMaskElem); - for (auto *V : VL) { - ++Idx; - - // Reached the start of a new vector registers. - if (Idx % EltsPerVector == 0) { - RegMask.assign(EltsPerVector, UndefMaskElem); - AllConsecutive = true; - continue; - } - - // Need to exclude undefs from analysis. - if (isa<UndefValue>(V) || Mask[Idx] == UndefMaskElem) - continue; - - // Check all extracts for a vector register on the target directly - // extract values in order. - unsigned CurrentIdx = *getExtractIndex(cast<Instruction>(V)); - if (!isa<UndefValue>(VL[Idx - 1]) && Mask[Idx - 1] != UndefMaskElem) { - unsigned PrevIdx = *getExtractIndex(cast<Instruction>(VL[Idx - 1])); - AllConsecutive &= PrevIdx + 1 == CurrentIdx && - CurrentIdx % EltsPerVector == Idx % EltsPerVector; - RegMask[Idx % EltsPerVector] = CurrentIdx % EltsPerVector; - } - - if (AllConsecutive) - continue; - - // Skip all indices, except for the last index per vector block. - if ((Idx + 1) % EltsPerVector != 0 && Idx + 1 != VL.size()) - continue; - - // If we have a series of extracts which are not consecutive and hence - // cannot re-use the source vector register directly, compute the shuffle - // cost to extract the vector with EltsPerVector elements. - Cost += TTI.getShuffleCost( - TargetTransformInfo::SK_PermuteSingleSrc, - FixedVectorType::get(VecTy->getElementType(), EltsPerVector), RegMask); - } - return Cost; -} - /// Build shuffle mask for shuffle graph entries and lists of main and alternate /// operations operands. static void @@ -6183,7 +6375,7 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, SmallVectorImpl<Value *> *OpScalars = nullptr, SmallVectorImpl<Value *> *AltScalars = nullptr) { unsigned Sz = VL.size(); - Mask.assign(Sz, UndefMaskElem); + Mask.assign(Sz, PoisonMaskElem); SmallVector<int> OrderMask; if (!ReorderIndices.empty()) inversePermutation(ReorderIndices, OrderMask); @@ -6203,9 +6395,9 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, } } if (!ReusesIndices.empty()) { - SmallVector<int> NewMask(ReusesIndices.size(), UndefMaskElem); + SmallVector<int> NewMask(ReusesIndices.size(), PoisonMaskElem); transform(ReusesIndices, NewMask.begin(), [&Mask](int Idx) { - return Idx != UndefMaskElem ? Mask[Idx] : UndefMaskElem; + return Idx != PoisonMaskElem ? 
Mask[Idx] : PoisonMaskElem; }); Mask.swap(NewMask); } @@ -6325,13 +6517,13 @@ protected: static void combineMasks(unsigned LocalVF, SmallVectorImpl<int> &Mask, ArrayRef<int> ExtMask) { unsigned VF = Mask.size(); - SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem); + SmallVector<int> NewMask(ExtMask.size(), PoisonMaskElem); for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) { - if (ExtMask[I] == UndefMaskElem) + if (ExtMask[I] == PoisonMaskElem) continue; int MaskedIdx = Mask[ExtMask[I] % VF]; NewMask[I] = - MaskedIdx == UndefMaskElem ? UndefMaskElem : MaskedIdx % LocalVF; + MaskedIdx == PoisonMaskElem ? PoisonMaskElem : MaskedIdx % LocalVF; } Mask.swap(NewMask); } @@ -6418,11 +6610,12 @@ protected: if (auto *SVOpTy = dyn_cast<FixedVectorType>(SV->getOperand(0)->getType())) LocalVF = SVOpTy->getNumElements(); - SmallVector<int> ExtMask(Mask.size(), UndefMaskElem); + SmallVector<int> ExtMask(Mask.size(), PoisonMaskElem); for (auto [Idx, I] : enumerate(Mask)) { - if (I == UndefMaskElem) - continue; - ExtMask[Idx] = SV->getMaskValue(I); + if (I == PoisonMaskElem || + static_cast<unsigned>(I) >= SV->getShuffleMask().size()) + continue; + ExtMask[Idx] = SV->getMaskValue(I); } bool IsOp1Undef = isUndefVector(SV->getOperand(0), @@ -6435,11 +6628,11 @@ protected: if (!IsOp1Undef && !IsOp2Undef) { // Update mask and mark undef elems. for (int &I : Mask) { - if (I == UndefMaskElem) + if (I == PoisonMaskElem) continue; if (SV->getMaskValue(I % SV->getShuffleMask().size()) == - UndefMaskElem) - I = UndefMaskElem; + PoisonMaskElem) + I = PoisonMaskElem; } break; } @@ -6453,15 +6646,16 @@ protected: Op = SV->getOperand(1); } if (auto *OpTy = dyn_cast<FixedVectorType>(Op->getType()); - !OpTy || !isIdentityMask(Mask, OpTy, SinglePermute)) { + !OpTy || !isIdentityMask(Mask, OpTy, SinglePermute) || + ShuffleVectorInst::isZeroEltSplatMask(Mask)) { if (IdentityOp) { V = IdentityOp; assert(Mask.size() == IdentityMask.size() && "Expected masks of same sizes."); // Clear known poison elements. for (auto [I, Idx] : enumerate(Mask)) - if (Idx == UndefMaskElem) - IdentityMask[I] = UndefMaskElem; + if (Idx == PoisonMaskElem) + IdentityMask[I] = PoisonMaskElem; Mask.swap(IdentityMask); auto *Shuffle = dyn_cast<ShuffleVectorInst>(V); return SinglePermute && @@ -6481,10 +6675,12 @@ protected: /// Smart shuffle instruction emission, walks through shuffles trees and /// tries to find the best matching vector for the actual shuffle /// instruction. - template <typename ShuffleBuilderTy> - static Value *createShuffle(Value *V1, Value *V2, ArrayRef<int> Mask, - ShuffleBuilderTy &Builder) { + template <typename T, typename ShuffleBuilderTy> + static T createShuffle(Value *V1, Value *V2, ArrayRef<int> Mask, + ShuffleBuilderTy &Builder) { assert(V1 && "Expected at least one vector value."); + if (V2) + Builder.resizeToMatch(V1, V2); int VF = Mask.size(); if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType())) VF = FTy->getNumElements(); @@ -6495,8 +6691,8 @@ protected: Value *Op2 = V2; int VF = cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue(); - SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem); - SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem); + SmallVector<int> CombinedMask1(Mask.size(), PoisonMaskElem); + SmallVector<int> CombinedMask2(Mask.size(), PoisonMaskElem); for (int I = 0, E = Mask.size(); I < E; ++I) { if (Mask[I] < VF) CombinedMask1[I] = Mask[I]; @@ -6514,9 +6710,9 @@ protected: // again. 
if (auto *SV1 = dyn_cast<ShuffleVectorInst>(Op1)) if (auto *SV2 = dyn_cast<ShuffleVectorInst>(Op2)) { - SmallVector<int> ExtMask1(Mask.size(), UndefMaskElem); + SmallVector<int> ExtMask1(Mask.size(), PoisonMaskElem); for (auto [Idx, I] : enumerate(CombinedMask1)) { - if (I == UndefMaskElem) + if (I == PoisonMaskElem) continue; ExtMask1[Idx] = SV1->getMaskValue(I); } @@ -6524,9 +6720,9 @@ protected: cast<FixedVectorType>(SV1->getOperand(1)->getType()) ->getNumElements(), ExtMask1, UseMask::SecondArg); - SmallVector<int> ExtMask2(CombinedMask2.size(), UndefMaskElem); + SmallVector<int> ExtMask2(CombinedMask2.size(), PoisonMaskElem); for (auto [Idx, I] : enumerate(CombinedMask2)) { - if (I == UndefMaskElem) + if (I == PoisonMaskElem) continue; ExtMask2[Idx] = SV2->getMaskValue(I); } @@ -6566,64 +6762,360 @@ protected: ->getElementCount() .getKnownMinValue()); for (int I = 0, E = Mask.size(); I < E; ++I) { - if (CombinedMask2[I] != UndefMaskElem) { - assert(CombinedMask1[I] == UndefMaskElem && + if (CombinedMask2[I] != PoisonMaskElem) { + assert(CombinedMask1[I] == PoisonMaskElem && "Expected undefined mask element"); CombinedMask1[I] = CombinedMask2[I] + (Op1 == Op2 ? 0 : VF); } } + const int Limit = CombinedMask1.size() * 2; + if (Op1 == Op2 && Limit == 2 * VF && + all_of(CombinedMask1, [=](int Idx) { return Idx < Limit; }) && + (ShuffleVectorInst::isIdentityMask(CombinedMask1) || + (ShuffleVectorInst::isZeroEltSplatMask(CombinedMask1) && + isa<ShuffleVectorInst>(Op1) && + cast<ShuffleVectorInst>(Op1)->getShuffleMask() == + ArrayRef(CombinedMask1)))) + return Builder.createIdentity(Op1); return Builder.createShuffleVector( Op1, Op1 == Op2 ? PoisonValue::get(Op1->getType()) : Op2, CombinedMask1); } if (isa<PoisonValue>(V1)) - return PoisonValue::get(FixedVectorType::get( - cast<VectorType>(V1->getType())->getElementType(), Mask.size())); + return Builder.createPoison( + cast<VectorType>(V1->getType())->getElementType(), Mask.size()); SmallVector<int> NewMask(Mask.begin(), Mask.end()); bool IsIdentity = peekThroughShuffles(V1, NewMask, /*SinglePermute=*/true); assert(V1 && "Expected non-null value after looking through shuffles."); if (!IsIdentity) return Builder.createShuffleVector(V1, NewMask); - return V1; + return Builder.createIdentity(V1); } }; } // namespace -InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, - ArrayRef<Value *> VectorizedVals) { - ArrayRef<Value *> VL = E->Scalars; +/// Merges shuffle masks and emits final shuffle instruction, if required. It +/// supports shuffling of 2 input vectors. It implements lazy shuffles emission, +/// when the actual shuffle instruction is generated only if this is actually +/// required. Otherwise, the shuffle instruction emission is delayed till the +/// end of the process, to reduce the number of emitted instructions and further +/// analysis/transformations. 
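The mask plumbing in the hunks above (combineMasks, peekThroughShuffles, createShuffle) repeatedly composes an outer shuffle mask with an inner one and checks whether the result collapses to an identity permutation. Below is a minimal standalone sketch of that composition idea, not part of the patch: plain std::vector<int> stands in for LLVM mask vectors, -1 stands in for PoisonMaskElem, and the widths are illustrative.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Standalone sketch: -1 plays the role of PoisonMaskElem.
constexpr int Poison = -1;

// Compose an outer shuffle mask (ExtMask) with an inner one (Mask):
// element I of the result is whatever the inner shuffle would have placed
// at position ExtMask[I]. This mirrors the idea behind combineMasks above.
std::vector<int> composeMasks(const std::vector<int> &Mask,
                              const std::vector<int> &ExtMask) {
  std::vector<int> Result(ExtMask.size(), Poison);
  for (std::size_t I = 0; I < ExtMask.size(); ++I) {
    if (ExtMask[I] == Poison)
      continue;
    int Inner = Mask[static_cast<std::size_t>(ExtMask[I]) % Mask.size()];
    Result[I] = Inner; // stays Poison when the inner element was poison
  }
  return Result;
}

// True if the mask leaves a vector of the same width unchanged.
bool isIdentityMask(const std::vector<int> &Mask) {
  for (std::size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I] != Poison && Mask[I] != static_cast<int>(I))
      return false;
  return true;
}

int main() {
  // Inner shuffle reverses a 4-wide vector, outer shuffle reverses it back.
  std::vector<int> Inner = {3, 2, 1, 0};
  std::vector<int> Outer = {3, 2, 1, 0};
  std::vector<int> Combined = composeMasks(Inner, Outer);
  assert(isIdentityMask(Combined)); // the two reversals cancel out
  for (int I : Combined)
    std::printf("%d ", I); // prints: 0 1 2 3
  std::printf("\n");
  return 0;
}

When the composed mask is an identity (as here), no shuffle needs to be emitted at all, which is exactly why the estimator treats such masks as free.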
+class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { + bool IsFinalized = false; + SmallVector<int> CommonMask; + SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors; + const TargetTransformInfo &TTI; + InstructionCost Cost = 0; + ArrayRef<Value *> VectorizedVals; + BoUpSLP &R; + SmallPtrSetImpl<Value *> &CheckedExtracts; + constexpr static TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + + InstructionCost getBuildVectorCost(ArrayRef<Value *> VL, Value *Root) { + if ((!Root && allConstant(VL)) || all_of(VL, UndefValue::classof)) + return TTI::TCC_Free; + auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size()); + InstructionCost GatherCost = 0; + SmallVector<Value *> Gathers(VL.begin(), VL.end()); + // Improve gather cost for gather of loads, if we can group some of the + // loads into vector loads. + InstructionsState S = getSameOpcode(VL, *R.TLI); + if (VL.size() > 2 && S.getOpcode() == Instruction::Load && + !S.isAltShuffle() && + !all_of(Gathers, [&](Value *V) { return R.getTreeEntry(V); }) && + !isSplat(Gathers)) { + BoUpSLP::ValueSet VectorizedLoads; + unsigned StartIdx = 0; + unsigned VF = VL.size() / 2; + unsigned VectorizedCnt = 0; + unsigned ScatterVectorizeCnt = 0; + const unsigned Sz = R.DL->getTypeSizeInBits(S.MainOp->getType()); + for (unsigned MinVF = R.getMinVF(2 * Sz); VF >= MinVF; VF /= 2) { + for (unsigned Cnt = StartIdx, End = VL.size(); Cnt + VF <= End; + Cnt += VF) { + ArrayRef<Value *> Slice = VL.slice(Cnt, VF); + if (!VectorizedLoads.count(Slice.front()) && + !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) { + SmallVector<Value *> PointerOps; + OrdersType CurrentOrder; + LoadsState LS = + canVectorizeLoads(Slice, Slice.front(), TTI, *R.DL, *R.SE, + *R.LI, *R.TLI, CurrentOrder, PointerOps); + switch (LS) { + case LoadsState::Vectorize: + case LoadsState::ScatterVectorize: + // Mark the vectorized loads so that we don't vectorize them + // again. + if (LS == LoadsState::Vectorize) + ++VectorizedCnt; + else + ++ScatterVectorizeCnt; + VectorizedLoads.insert(Slice.begin(), Slice.end()); + // If we vectorized initial block, no need to try to vectorize + // it again. + if (Cnt == StartIdx) + StartIdx += VF; + break; + case LoadsState::Gather: + break; + } + } + } + // Check if the whole array was vectorized already - exit. + if (StartIdx >= VL.size()) + break; + // Found vectorizable parts - exit. + if (!VectorizedLoads.empty()) + break; + } + if (!VectorizedLoads.empty()) { + unsigned NumParts = TTI.getNumberOfParts(VecTy); + bool NeedInsertSubvectorAnalysis = + !NumParts || (VL.size() / VF) > NumParts; + // Get the cost for gathered loads. + for (unsigned I = 0, End = VL.size(); I < End; I += VF) { + if (VectorizedLoads.contains(VL[I])) + continue; + GatherCost += getBuildVectorCost(VL.slice(I, VF), Root); + } + // Exclude potentially vectorized loads from list of gathered + // scalars. + auto *LI = cast<LoadInst>(S.MainOp); + Gathers.assign(Gathers.size(), PoisonValue::get(LI->getType())); + // The cost for vectorized loads. 
+ InstructionCost ScalarsCost = 0; + for (Value *V : VectorizedLoads) { + auto *LI = cast<LoadInst>(V); + ScalarsCost += + TTI.getMemoryOpCost(Instruction::Load, LI->getType(), + LI->getAlign(), LI->getPointerAddressSpace(), + CostKind, TTI::OperandValueInfo(), LI); + } + auto *LoadTy = FixedVectorType::get(LI->getType(), VF); + Align Alignment = LI->getAlign(); + GatherCost += + VectorizedCnt * + TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, + LI->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo(), LI); + GatherCost += ScatterVectorizeCnt * + TTI.getGatherScatterOpCost( + Instruction::Load, LoadTy, LI->getPointerOperand(), + /*VariableMask=*/false, Alignment, CostKind, LI); + if (NeedInsertSubvectorAnalysis) { + // Add the cost for the subvectors insert. + for (int I = VF, E = VL.size(); I < E; I += VF) + GatherCost += TTI.getShuffleCost(TTI::SK_InsertSubvector, VecTy, + std::nullopt, CostKind, I, LoadTy); + } + GatherCost -= ScalarsCost; + } + } else if (!Root && isSplat(VL)) { + // Found the broadcasting of the single scalar, calculate the cost as + // the broadcast. + const auto *It = + find_if(VL, [](Value *V) { return !isa<UndefValue>(V); }); + assert(It != VL.end() && "Expected at least one non-undef value."); + // Add broadcast for non-identity shuffle only. + bool NeedShuffle = + count(VL, *It) > 1 && + (VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof)); + InstructionCost InsertCost = TTI.getVectorInstrCost( + Instruction::InsertElement, VecTy, CostKind, + NeedShuffle ? 0 : std::distance(VL.begin(), It), + PoisonValue::get(VecTy), *It); + return InsertCost + + (NeedShuffle ? TTI.getShuffleCost( + TargetTransformInfo::SK_Broadcast, VecTy, + /*Mask=*/std::nullopt, CostKind, /*Index=*/0, + /*SubTp=*/nullptr, /*Args=*/*It) + : TTI::TCC_Free); + } + return GatherCost + + (all_of(Gathers, UndefValue::classof) + ? TTI::TCC_Free + : R.getGatherCost(Gathers, !Root && VL.equals(Gathers))); + }; - Type *ScalarTy = VL[0]->getType(); - if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) - ScalarTy = SI->getValueOperand()->getType(); - else if (CmpInst *CI = dyn_cast<CmpInst>(VL[0])) - ScalarTy = CI->getOperand(0)->getType(); - else if (auto *IE = dyn_cast<InsertElementInst>(VL[0])) - ScalarTy = IE->getOperand(1)->getType(); - auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); - TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + /// Compute the cost of creating a vector of type \p VecTy containing the + /// extracted values from \p VL. + InstructionCost computeExtractCost(ArrayRef<Value *> VL, ArrayRef<int> Mask, + TTI::ShuffleKind ShuffleKind) { + auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size()); + unsigned NumOfParts = TTI.getNumberOfParts(VecTy); - // If we have computed a smaller type for the expression, update VecTy so - // that the costs will be accurate. - if (MinBWs.count(VL[0])) - VecTy = FixedVectorType::get( - IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size()); - unsigned EntryVF = E->getVectorFactor(); - auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF); + if (ShuffleKind != TargetTransformInfo::SK_PermuteSingleSrc || + !NumOfParts || VecTy->getNumElements() < NumOfParts) + return TTI.getShuffleCost(ShuffleKind, VecTy, Mask); - bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); - // FIXME: it tries to fix a problem with MSVC buildbots. 
- TargetTransformInfo *TTI = this->TTI; - auto AdjustExtractsCost = [=](InstructionCost &Cost) { + bool AllConsecutive = true; + unsigned EltsPerVector = VecTy->getNumElements() / NumOfParts; + unsigned Idx = -1; + InstructionCost Cost = 0; + + // Process extracts in blocks of EltsPerVector to check if the source vector + // operand can be re-used directly. If not, add the cost of creating a + // shuffle to extract the values into a vector register. + SmallVector<int> RegMask(EltsPerVector, PoisonMaskElem); + for (auto *V : VL) { + ++Idx; + + // Reached the start of a new vector registers. + if (Idx % EltsPerVector == 0) { + RegMask.assign(EltsPerVector, PoisonMaskElem); + AllConsecutive = true; + continue; + } + + // Need to exclude undefs from analysis. + if (isa<UndefValue>(V) || Mask[Idx] == PoisonMaskElem) + continue; + + // Check all extracts for a vector register on the target directly + // extract values in order. + unsigned CurrentIdx = *getExtractIndex(cast<Instruction>(V)); + if (!isa<UndefValue>(VL[Idx - 1]) && Mask[Idx - 1] != PoisonMaskElem) { + unsigned PrevIdx = *getExtractIndex(cast<Instruction>(VL[Idx - 1])); + AllConsecutive &= PrevIdx + 1 == CurrentIdx && + CurrentIdx % EltsPerVector == Idx % EltsPerVector; + RegMask[Idx % EltsPerVector] = CurrentIdx % EltsPerVector; + } + + if (AllConsecutive) + continue; + + // Skip all indices, except for the last index per vector block. + if ((Idx + 1) % EltsPerVector != 0 && Idx + 1 != VL.size()) + continue; + + // If we have a series of extracts which are not consecutive and hence + // cannot re-use the source vector register directly, compute the shuffle + // cost to extract the vector with EltsPerVector elements. + Cost += TTI.getShuffleCost( + TargetTransformInfo::SK_PermuteSingleSrc, + FixedVectorType::get(VecTy->getElementType(), EltsPerVector), + RegMask); + } + return Cost; + } + + class ShuffleCostBuilder { + const TargetTransformInfo &TTI; + + static bool isEmptyOrIdentity(ArrayRef<int> Mask, unsigned VF) { + int Limit = 2 * VF; + return Mask.empty() || + (VF == Mask.size() && + all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) && + ShuffleVectorInst::isIdentityMask(Mask)); + } + + public: + ShuffleCostBuilder(const TargetTransformInfo &TTI) : TTI(TTI) {} + ~ShuffleCostBuilder() = default; + InstructionCost createShuffleVector(Value *V1, Value *, + ArrayRef<int> Mask) const { + // Empty mask or identity mask are free. + unsigned VF = + cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue(); + if (isEmptyOrIdentity(Mask, VF)) + return TTI::TCC_Free; + return TTI.getShuffleCost( + TTI::SK_PermuteTwoSrc, + FixedVectorType::get( + cast<VectorType>(V1->getType())->getElementType(), Mask.size()), + Mask); + } + InstructionCost createShuffleVector(Value *V1, ArrayRef<int> Mask) const { + // Empty mask or identity mask are free. + if (isEmptyOrIdentity(Mask, Mask.size())) + return TTI::TCC_Free; + return TTI.getShuffleCost( + TTI::SK_PermuteSingleSrc, + FixedVectorType::get( + cast<VectorType>(V1->getType())->getElementType(), Mask.size()), + Mask); + } + InstructionCost createIdentity(Value *) const { return TTI::TCC_Free; } + InstructionCost createPoison(Type *Ty, unsigned VF) const { + return TTI::TCC_Free; + } + void resizeToMatch(Value *&, Value *&) const {} + }; + + /// Smart shuffle instruction emission, walks through shuffles trees and + /// tries to find the best matching vector for the actual shuffle + /// instruction. 
+ InstructionCost + createShuffle(const PointerUnion<Value *, const TreeEntry *> &P1, + const PointerUnion<Value *, const TreeEntry *> &P2, + ArrayRef<int> Mask) { + ShuffleCostBuilder Builder(TTI); + Value *V1 = P1.dyn_cast<Value *>(), *V2 = P2.dyn_cast<Value *>(); + unsigned CommonVF = 0; + if (!V1) { + const TreeEntry *E = P1.get<const TreeEntry *>(); + unsigned VF = E->getVectorFactor(); + if (V2) { + unsigned V2VF = cast<FixedVectorType>(V2->getType())->getNumElements(); + if (V2VF != VF && V2VF == E->Scalars.size()) + VF = E->Scalars.size(); + } else if (!P2.isNull()) { + const TreeEntry *E2 = P2.get<const TreeEntry *>(); + if (E->Scalars.size() == E2->Scalars.size()) + CommonVF = VF = E->Scalars.size(); + } else { + // P2 is empty, check that we have same node + reshuffle (if any). + if (E->Scalars.size() == Mask.size() && VF != Mask.size()) { + VF = E->Scalars.size(); + SmallVector<int> CommonMask(Mask.begin(), Mask.end()); + ::addMask(CommonMask, E->getCommonMask()); + V1 = Constant::getNullValue( + FixedVectorType::get(E->Scalars.front()->getType(), VF)); + return BaseShuffleAnalysis::createShuffle<InstructionCost>( + V1, nullptr, CommonMask, Builder); + } + } + V1 = Constant::getNullValue( + FixedVectorType::get(E->Scalars.front()->getType(), VF)); + } + if (!V2 && !P2.isNull()) { + const TreeEntry *E = P2.get<const TreeEntry *>(); + unsigned VF = E->getVectorFactor(); + unsigned V1VF = cast<FixedVectorType>(V1->getType())->getNumElements(); + if (!CommonVF && V1VF == E->Scalars.size()) + CommonVF = E->Scalars.size(); + if (CommonVF) + VF = CommonVF; + V2 = Constant::getNullValue( + FixedVectorType::get(E->Scalars.front()->getType(), VF)); + } + return BaseShuffleAnalysis::createShuffle<InstructionCost>(V1, V2, Mask, + Builder); + } + +public: + ShuffleCostEstimator(TargetTransformInfo &TTI, + ArrayRef<Value *> VectorizedVals, BoUpSLP &R, + SmallPtrSetImpl<Value *> &CheckedExtracts) + : TTI(TTI), VectorizedVals(VectorizedVals), R(R), + CheckedExtracts(CheckedExtracts) {} + Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask, + TTI::ShuffleKind ShuffleKind) { + if (Mask.empty()) + return nullptr; + Value *VecBase = nullptr; + ArrayRef<Value *> VL = E->Scalars; + auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size()); // If the resulting type is scalarized, do not adjust the cost. - unsigned VecNumParts = TTI->getNumberOfParts(VecTy); + unsigned VecNumParts = TTI.getNumberOfParts(VecTy); if (VecNumParts == VecTy->getNumElements()) - return; + return nullptr; DenseMap<Value *, int> ExtractVectorsTys; - SmallPtrSet<Value *, 4> CheckedExtracts; - for (auto *V : VL) { - if (isa<UndefValue>(V)) + for (auto [I, V] : enumerate(VL)) { + // Ignore non-extractelement scalars. + if (isa<UndefValue>(V) || (!Mask.empty() && Mask[I] == PoisonMaskElem)) continue; // If all users of instruction are going to be vectorized and this // instruction itself is not going to be vectorized, consider this @@ -6631,17 +7123,18 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // vectorized tree. // Also, avoid adjusting the cost for extractelements with multiple uses // in different graph entries. 
- const TreeEntry *VE = getTreeEntry(V); + const TreeEntry *VE = R.getTreeEntry(V); if (!CheckedExtracts.insert(V).second || - !areAllUsersVectorized(cast<Instruction>(V), VectorizedVals) || + !R.areAllUsersVectorized(cast<Instruction>(V), VectorizedVals) || (VE && VE != E)) continue; auto *EE = cast<ExtractElementInst>(V); + VecBase = EE->getVectorOperand(); std::optional<unsigned> EEIdx = getExtractIndex(EE); if (!EEIdx) continue; unsigned Idx = *EEIdx; - if (VecNumParts != TTI->getNumberOfParts(EE->getVectorOperandType())) { + if (VecNumParts != TTI.getNumberOfParts(EE->getVectorOperandType())) { auto It = ExtractVectorsTys.try_emplace(EE->getVectorOperand(), Idx).first; It->getSecond() = std::min<int>(It->second, Idx); @@ -6654,18 +7147,17 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, })) { // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. - Cost -= - TTI->getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), - EE->getVectorOperandType(), Idx); + Cost -= TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), + EE->getVectorOperandType(), Idx); // Add back the cost of s|zext which is subtracted separately. - Cost += TTI->getCastInstrCost( + Cost += TTI.getCastInstrCost( Ext->getOpcode(), Ext->getType(), EE->getType(), TTI::getCastContextHint(Ext), CostKind, Ext); continue; } } - Cost -= TTI->getVectorInstrCost(*EE, EE->getVectorOperandType(), CostKind, - Idx); + Cost -= TTI.getVectorInstrCost(*EE, EE->getVectorOperandType(), CostKind, + Idx); } // Add a cost for subvector extracts/inserts if required. for (const auto &Data : ExtractVectorsTys) { @@ -6673,34 +7165,148 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, unsigned NumElts = VecTy->getNumElements(); if (Data.second % NumElts == 0) continue; - if (TTI->getNumberOfParts(EEVTy) > VecNumParts) { + if (TTI.getNumberOfParts(EEVTy) > VecNumParts) { unsigned Idx = (Data.second / NumElts) * NumElts; unsigned EENumElts = EEVTy->getNumElements(); + if (Idx % NumElts == 0) + continue; if (Idx + NumElts <= EENumElts) { - Cost += - TTI->getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, - EEVTy, std::nullopt, CostKind, Idx, VecTy); + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, + EEVTy, std::nullopt, CostKind, Idx, VecTy); } else { // Need to round up the subvector type vectorization factor to avoid a // crash in cost model functions. Make SubVT so that Idx + VF of SubVT // <= EENumElts. auto *SubVT = FixedVectorType::get(VecTy->getElementType(), EENumElts - Idx); - Cost += - TTI->getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, - EEVTy, std::nullopt, CostKind, Idx, SubVT); + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, + EEVTy, std::nullopt, CostKind, Idx, SubVT); } } else { - Cost += TTI->getShuffleCost(TargetTransformInfo::SK_InsertSubvector, - VecTy, std::nullopt, CostKind, 0, EEVTy); + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_InsertSubvector, + VecTy, std::nullopt, CostKind, 0, EEVTy); } } - }; + // Check that gather of extractelements can be represented as just a + // shuffle of a single/two vectors the scalars are extracted from. + // Found the bunch of extractelement instructions that must be gathered + // into a vector and can be represented as a permutation elements in a + // single input vector or of 2 input vectors. 
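The adjustExtracts/computeExtractCost path above models a gather of extractelement instructions as a single shufflevector over at most two source vectors, so the individual extracts no longer have to be costed. A rough standalone sketch of that classification follows; the Extract struct, its integer source ids, and buildShuffleMask are purely illustrative stand-ins for real IR values, not APIs from the patch.

#include <cstddef>
#include <cstdio>
#include <optional>
#include <vector>

constexpr int Poison = -1;

struct Extract {
  int Source; // which vector the element comes from (hypothetical id)
  int Index;  // lane within that vector
};

// If every gathered scalar is an extract from at most two source vectors,
// the whole gather can be modelled as one two-source shuffle mask.
std::optional<std::vector<int>>
buildShuffleMask(const std::vector<Extract> &Gather, int Width) {
  std::vector<int> Mask(Gather.size(), Poison);
  int Src0 = -1, Src1 = -1;
  for (std::size_t I = 0; I < Gather.size(); ++I) {
    int S = Gather[I].Source;
    if (Src0 == -1 || S == Src0) {
      Src0 = S;
      Mask[I] = Gather[I].Index; // first operand: lanes [0, Width)
    } else if (Src1 == -1 || S == Src1) {
      Src1 = S;
      Mask[I] = Gather[I].Index + Width; // second operand: lanes [Width, 2*Width)
    } else {
      return std::nullopt; // more than two sources: fall back to a plain gather
    }
  }
  return Mask;
}

int main() {
  // Four scalars extracted from two 4-wide vectors with ids 0 and 1.
  std::vector<Extract> Gather = {{0, 2}, {1, 1}, {0, 0}, {1, 3}};
  if (auto Mask = buildShuffleMask(Gather, 4)) {
    for (int M : *Mask)
      std::printf("%d ", M); // prints: 2 5 0 7 -> a two-source shuffle
    std::printf("\n");
  }
  return 0;
}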
+ Cost += computeExtractCost(VL, Mask, ShuffleKind); + return VecBase; + } + void add(const TreeEntry *E1, const TreeEntry *E2, ArrayRef<int> Mask) { + CommonMask.assign(Mask.begin(), Mask.end()); + InVectors.assign({E1, E2}); + } + void add(const TreeEntry *E1, ArrayRef<int> Mask) { + CommonMask.assign(Mask.begin(), Mask.end()); + InVectors.assign(1, E1); + } + /// Adds another one input vector and the mask for the shuffling. + void add(Value *V1, ArrayRef<int> Mask) { + assert(CommonMask.empty() && InVectors.empty() && + "Expected empty input mask/vectors."); + CommonMask.assign(Mask.begin(), Mask.end()); + InVectors.assign(1, V1); + } + Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) { + Cost += getBuildVectorCost(VL, Root); + if (!Root) { + assert(InVectors.empty() && "Unexpected input vectors for buildvector."); + // FIXME: Need to find a way to avoid use of getNullValue here. + SmallVector<Constant *> Vals; + for (Value *V : VL) { + if (isa<UndefValue>(V)) { + Vals.push_back(cast<Constant>(V)); + continue; + } + Vals.push_back(Constant::getNullValue(V->getType())); + } + return ConstantVector::get(Vals); + } + return ConstantVector::getSplat( + ElementCount::getFixed(VL.size()), + Constant::getNullValue(VL.front()->getType())); + } + /// Finalize emission of the shuffles. + InstructionCost + finalize(ArrayRef<int> ExtMask, unsigned VF = 0, + function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) { + IsFinalized = true; + if (Action) { + const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front(); + if (InVectors.size() == 2) { + Cost += createShuffle(Vec, InVectors.back(), CommonMask); + InVectors.pop_back(); + } else { + Cost += createShuffle(Vec, nullptr, CommonMask); + } + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (CommonMask[Idx] != PoisonMaskElem) + CommonMask[Idx] = Idx; + assert(VF > 0 && + "Expected vector length for the final value before action."); + Value *V = Vec.dyn_cast<Value *>(); + if (!Vec.isNull() && !V) + V = Constant::getNullValue(FixedVectorType::get( + Vec.get<const TreeEntry *>()->Scalars.front()->getType(), + CommonMask.size())); + Action(V, CommonMask); + } + ::addMask(CommonMask, ExtMask, /*ExtendingManyInputs=*/true); + if (CommonMask.empty()) + return Cost; + int Limit = CommonMask.size() * 2; + if (all_of(CommonMask, [=](int Idx) { return Idx < Limit; }) && + ShuffleVectorInst::isIdentityMask(CommonMask)) + return Cost; + return Cost + + createShuffle(InVectors.front(), + InVectors.size() == 2 ? InVectors.back() : nullptr, + CommonMask); + } + + ~ShuffleCostEstimator() { + assert((IsFinalized || CommonMask.empty()) && + "Shuffle construction must be finalized."); + } +}; + +InstructionCost +BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, + SmallPtrSetImpl<Value *> &CheckedExtracts) { + ArrayRef<Value *> VL = E->Scalars; + + Type *ScalarTy = VL[0]->getType(); + if (auto *SI = dyn_cast<StoreInst>(VL[0])) + ScalarTy = SI->getValueOperand()->getType(); + else if (auto *CI = dyn_cast<CmpInst>(VL[0])) + ScalarTy = CI->getOperand(0)->getType(); + else if (auto *IE = dyn_cast<InsertElementInst>(VL[0])) + ScalarTy = IE->getOperand(1)->getType(); + auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + + // If we have computed a smaller type for the expression, update VecTy so + // that the costs will be accurate. 
+ if (MinBWs.count(VL[0])) + VecTy = FixedVectorType::get( + IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size()); + unsigned EntryVF = E->getVectorFactor(); + auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF); + + bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); if (E->State == TreeEntry::NeedToGather) { if (allConstant(VL)) return 0; if (isa<InsertElementInst>(VL[0])) return InstructionCost::getInvalid(); + ShuffleCostEstimator Estimator(*TTI, VectorizedVals, *this, + CheckedExtracts); + unsigned VF = E->getVectorFactor(); + SmallVector<int> ReuseShuffleIndicies(E->ReuseShuffleIndices.begin(), + E->ReuseShuffleIndices.end()); SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end()); // Build a mask out of the reorder indices and reorder scalars per this // mask. @@ -6709,195 +7315,104 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, if (!ReorderMask.empty()) reorderScalars(GatheredScalars, ReorderMask); SmallVector<int> Mask; + SmallVector<int> ExtractMask; + std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle; std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle; SmallVector<const TreeEntry *> Entries; + Type *ScalarTy = GatheredScalars.front()->getType(); + // Check for gathered extracts. + ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask); + SmallVector<Value *> IgnoredVals; + if (UserIgnoreList) + IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end()); + + bool Resized = false; + if (Value *VecBase = Estimator.adjustExtracts( + E, ExtractMask, ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc))) + if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType())) + if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) { + Resized = true; + GatheredScalars.append(VF - GatheredScalars.size(), + PoisonValue::get(ScalarTy)); + } + // Do not try to look for reshuffled loads for gathered loads (they will be // handled later), for vectorized scalars, and cases, which are definitely // not profitable (splats and small gather nodes.) - if (E->getOpcode() != Instruction::Load || E->isAltShuffle() || + if (ExtractShuffle || E->getOpcode() != Instruction::Load || + E->isAltShuffle() || all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) || isSplat(E->Scalars) || (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries); if (GatherShuffle) { - // Remove shuffled elements from list of gathers. - for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { - if (Mask[I] != UndefMaskElem) - GatheredScalars[I] = PoisonValue::get(ScalarTy); - } assert((Entries.size() == 1 || Entries.size() == 2) && "Expected shuffle of 1 or 2 entries."); - InstructionCost GatherCost = 0; - int Limit = Mask.size() * 2; - if (all_of(Mask, [=](int Idx) { return Idx < Limit; }) && - ShuffleVectorInst::isIdentityMask(Mask)) { + if (*GatherShuffle == TTI::SK_PermuteSingleSrc && + Entries.front()->isSame(E->Scalars)) { // Perfect match in the graph, will reuse the previously vectorized // node. Cost is 0. 
LLVM_DEBUG( dbgs() << "SLP: perfect diamond match for gather bundle that starts with " << *VL.front() << ".\n"); - if (NeedToShuffleReuses) - GatherCost = - TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, - FinalVecTy, E->ReuseShuffleIndices); - } else { - LLVM_DEBUG(dbgs() << "SLP: shuffled " << Entries.size() - << " entries for bundle that starts with " - << *VL.front() << ".\n"); - // Detected that instead of gather we can emit a shuffle of single/two - // previously vectorized nodes. Add the cost of the permutation rather - // than gather. - ::addMask(Mask, E->ReuseShuffleIndices); - GatherCost = TTI->getShuffleCost(*GatherShuffle, FinalVecTy, Mask); - } - if (!all_of(GatheredScalars, UndefValue::classof)) - GatherCost += getGatherCost(GatheredScalars); - return GatherCost; - } - if ((E->getOpcode() == Instruction::ExtractElement || - all_of(E->Scalars, - [](Value *V) { - return isa<ExtractElementInst, UndefValue>(V); - })) && - allSameType(VL)) { - // Check that gather of extractelements can be represented as just a - // shuffle of a single/two vectors the scalars are extracted from. - SmallVector<int> Mask; - std::optional<TargetTransformInfo::ShuffleKind> ShuffleKind = - isFixedVectorShuffle(VL, Mask); - if (ShuffleKind) { - // Found the bunch of extractelement instructions that must be gathered - // into a vector and can be represented as a permutation elements in a - // single input vector or of 2 input vectors. - InstructionCost Cost = - computeExtractCost(VL, VecTy, *ShuffleKind, Mask, *TTI); - AdjustExtractsCost(Cost); - if (NeedToShuffleReuses) - Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, - FinalVecTy, E->ReuseShuffleIndices); - return Cost; - } - } - if (isSplat(VL)) { - // Found the broadcasting of the single scalar, calculate the cost as the - // broadcast. - assert(VecTy == FinalVecTy && - "No reused scalars expected for broadcast."); - const auto *It = - find_if(VL, [](Value *V) { return !isa<UndefValue>(V); }); - // If all values are undefs - consider cost free. - if (It == VL.end()) - return TTI::TCC_Free; - // Add broadcast for non-identity shuffle only. - bool NeedShuffle = - VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof); - InstructionCost InsertCost = - TTI->getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, - /*Index=*/0, PoisonValue::get(VecTy), *It); - return InsertCost + (NeedShuffle - ? TTI->getShuffleCost( - TargetTransformInfo::SK_Broadcast, VecTy, - /*Mask=*/std::nullopt, CostKind, - /*Index=*/0, - /*SubTp=*/nullptr, /*Args=*/VL[0]) - : TTI::TCC_Free); - } - InstructionCost ReuseShuffleCost = 0; - if (NeedToShuffleReuses) - ReuseShuffleCost = TTI->getShuffleCost( - TTI::SK_PermuteSingleSrc, FinalVecTy, E->ReuseShuffleIndices); - // Improve gather cost for gather of loads, if we can group some of the - // loads into vector loads. 
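The load-grouping logic that this patch moves into ShuffleCostEstimator::getBuildVectorCost (the old copy is removed in the hunk below) tries to cover a gather of loads with progressively narrower vector loads before falling back to scalar gathering. A simplified sketch of that slicing loop is shown here; CanVectorizeSlice is a stand-in for canVectorizeLoads, and the inputs in main are purely hypothetical.

#include <cstdio>
#include <functional>

// Try to cover NumScalars gathered loads with vector loads of width VF,
// halving VF until MinVF is reached or some slice vectorizes.
void coverWithVectorLoads(
    int NumScalars, int MinVF,
    const std::function<bool(int Begin, int VF)> &CanVectorizeSlice) {
  int StartIdx = 0;
  for (int VF = NumScalars / 2; VF >= MinVF; VF /= 2) {
    bool FoundAny = false;
    for (int Cnt = StartIdx; Cnt + VF <= NumScalars; Cnt += VF) {
      if (CanVectorizeSlice(Cnt, VF)) {
        FoundAny = true;
        if (Cnt == StartIdx)
          StartIdx += VF; // leading block is covered, skip it next round
      }
    }
    if (StartIdx >= NumScalars || FoundAny)
      break; // everything covered, or vectorizable parts found at this VF
  }
  std::printf("leading covered prefix ends at scalar %d\n", StartIdx);
}

int main() {
  // Hypothetical: 8 gathered loads where only the first 4 are consecutive.
  coverWithVectorLoads(8, 2, [](int Begin, int VF) {
    return Begin == 0 && VF <= 4; // pretend loads 0..3 form a contiguous run
  });
  return 0;
}

Slices that do vectorize are then costed as real vector or scatter loads, and only the leftovers pay the per-element insert cost, which is what makes the gather estimate cheaper.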
- if (VL.size() > 2 && E->getOpcode() == Instruction::Load && - !E->isAltShuffle()) { - BoUpSLP::ValueSet VectorizedLoads; - unsigned StartIdx = 0; - unsigned VF = VL.size() / 2; - unsigned VectorizedCnt = 0; - unsigned ScatterVectorizeCnt = 0; - const unsigned Sz = DL->getTypeSizeInBits(E->getMainOp()->getType()); - for (unsigned MinVF = getMinVF(2 * Sz); VF >= MinVF; VF /= 2) { - for (unsigned Cnt = StartIdx, End = VL.size(); Cnt + VF <= End; - Cnt += VF) { - ArrayRef<Value *> Slice = VL.slice(Cnt, VF); - if (!VectorizedLoads.count(Slice.front()) && - !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) { - SmallVector<Value *> PointerOps; - OrdersType CurrentOrder; - LoadsState LS = - canVectorizeLoads(Slice, Slice.front(), *TTI, *DL, *SE, *LI, - *TLI, CurrentOrder, PointerOps); - switch (LS) { - case LoadsState::Vectorize: - case LoadsState::ScatterVectorize: - // Mark the vectorized loads so that we don't vectorize them - // again. - if (LS == LoadsState::Vectorize) - ++VectorizedCnt; - else - ++ScatterVectorizeCnt; - VectorizedLoads.insert(Slice.begin(), Slice.end()); - // If we vectorized initial block, no need to try to vectorize it - // again. - if (Cnt == StartIdx) - StartIdx += VF; - break; - case LoadsState::Gather: - break; - } + // Restore the mask for previous partially matched values. + for (auto [I, V] : enumerate(E->Scalars)) { + if (isa<PoisonValue>(V)) { + Mask[I] = PoisonMaskElem; + continue; } + if (Mask[I] == PoisonMaskElem) + Mask[I] = Entries.front()->findLaneForValue(V); } - // Check if the whole array was vectorized already - exit. - if (StartIdx >= VL.size()) - break; - // Found vectorizable parts - exit. - if (!VectorizedLoads.empty()) - break; + Estimator.add(Entries.front(), Mask); + return Estimator.finalize(E->ReuseShuffleIndices); } - if (!VectorizedLoads.empty()) { - InstructionCost GatherCost = 0; - unsigned NumParts = TTI->getNumberOfParts(VecTy); - bool NeedInsertSubvectorAnalysis = - !NumParts || (VL.size() / VF) > NumParts; - // Get the cost for gathered loads. - for (unsigned I = 0, End = VL.size(); I < End; I += VF) { - if (VectorizedLoads.contains(VL[I])) - continue; - GatherCost += getGatherCost(VL.slice(I, VF)); - } - // The cost for vectorized loads. - InstructionCost ScalarsCost = 0; - for (Value *V : VectorizedLoads) { - auto *LI = cast<LoadInst>(V); - ScalarsCost += - TTI->getMemoryOpCost(Instruction::Load, LI->getType(), - LI->getAlign(), LI->getPointerAddressSpace(), - CostKind, TTI::OperandValueInfo(), LI); - } - auto *LI = cast<LoadInst>(E->getMainOp()); - auto *LoadTy = FixedVectorType::get(LI->getType(), VF); - Align Alignment = LI->getAlign(); - GatherCost += - VectorizedCnt * - TTI->getMemoryOpCost(Instruction::Load, LoadTy, Alignment, - LI->getPointerAddressSpace(), CostKind, - TTI::OperandValueInfo(), LI); - GatherCost += ScatterVectorizeCnt * - TTI->getGatherScatterOpCost( - Instruction::Load, LoadTy, LI->getPointerOperand(), - /*VariableMask=*/false, Alignment, CostKind, LI); - if (NeedInsertSubvectorAnalysis) { - // Add the cost for the subvectors insert. 
- for (int I = VF, E = VL.size(); I < E; I += VF) - GatherCost += - TTI->getShuffleCost(TTI::SK_InsertSubvector, VecTy, - std::nullopt, CostKind, I, LoadTy); - } - return ReuseShuffleCost + GatherCost - ScalarsCost; + if (!Resized) { + unsigned VF1 = Entries.front()->getVectorFactor(); + unsigned VF2 = Entries.back()->getVectorFactor(); + if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF) + GatheredScalars.append(VF - GatheredScalars.size(), + PoisonValue::get(ScalarTy)); } + // Remove shuffled elements from list of gathers. + for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { + if (Mask[I] != PoisonMaskElem) + GatheredScalars[I] = PoisonValue::get(ScalarTy); + } + LLVM_DEBUG(dbgs() << "SLP: shuffled " << Entries.size() + << " entries for bundle that starts with " + << *VL.front() << ".\n";); + if (Entries.size() == 1) + Estimator.add(Entries.front(), Mask); + else + Estimator.add(Entries.front(), Entries.back(), Mask); + if (all_of(GatheredScalars, PoisonValue ::classof)) + return Estimator.finalize(E->ReuseShuffleIndices); + return Estimator.finalize( + E->ReuseShuffleIndices, E->Scalars.size(), + [&](Value *&Vec, SmallVectorImpl<int> &Mask) { + Vec = Estimator.gather(GatheredScalars, + Constant::getNullValue(FixedVectorType::get( + GatheredScalars.front()->getType(), + GatheredScalars.size()))); + }); } - return ReuseShuffleCost + getGatherCost(VL); + if (!all_of(GatheredScalars, PoisonValue::classof)) { + auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size()); + bool SameGathers = VL.equals(Gathers); + Value *BV = Estimator.gather( + Gathers, SameGathers ? nullptr + : Constant::getNullValue(FixedVectorType::get( + GatheredScalars.front()->getType(), + GatheredScalars.size()))); + SmallVector<int> ReuseMask(Gathers.size(), PoisonMaskElem); + std::iota(ReuseMask.begin(), ReuseMask.end(), 0); + Estimator.add(BV, ReuseMask); + } + if (ExtractShuffle) + Estimator.add(E, std::nullopt); + return Estimator.finalize(E->ReuseShuffleIndices); } InstructionCost CommonCost = 0; SmallVector<int> Mask; @@ -6945,48 +7460,89 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, } InstructionCost VecCost = VectorCost(CommonCost); - LLVM_DEBUG( - dumpTreeCosts(E, CommonCost, VecCost - CommonCost, ScalarCost)); - // Disable warnings for `this` and `E` are unused. Required for - // `dumpTreeCosts`. - (void)this; - (void)E; + LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost - CommonCost, + ScalarCost, "Calculated costs for Tree")); return VecCost - ScalarCost; }; // Calculate cost difference from vectorizing set of GEPs. // Negative value means vectorizing is profitable. auto GetGEPCostDiff = [=](ArrayRef<Value *> Ptrs, Value *BasePtr) { - InstructionCost CostSavings = 0; - for (Value *V : Ptrs) { - if (V == BasePtr) - continue; - auto *Ptr = dyn_cast<GetElementPtrInst>(V); - // GEPs may contain just addresses without instructions, considered free. - // GEPs with all constant indices also considered to have zero cost. - if (!Ptr || Ptr->hasAllConstantIndices()) - continue; - - // Here we differentiate two cases: when GEPs represent a regular - // vectorization tree node (and hence vectorized) and when the set is - // arguments of a set of loads or stores being vectorized. In the former - // case all the scalar GEPs will be removed as a result of vectorization. 
+ InstructionCost ScalarCost = 0; + InstructionCost VecCost = 0; + // Here we differentiate two cases: (1) when Ptrs represent a regular + // vectorization tree node (as they are pointer arguments of scattered + // loads) or (2) when Ptrs are the arguments of loads or stores being + // vectorized as plane wide unit-stride load/store since all the + // loads/stores are known to be from/to adjacent locations. + assert(E->State == TreeEntry::Vectorize && + "Entry state expected to be Vectorize here."); + if (isa<LoadInst, StoreInst>(VL0)) { + // Case 2: estimate costs for pointer related costs when vectorizing to + // a wide load/store. + // Scalar cost is estimated as a set of pointers with known relationship + // between them. + // For vector code we will use BasePtr as argument for the wide load/store + // but we also need to account all the instructions which are going to + // stay in vectorized code due to uses outside of these scalar + // loads/stores. + ScalarCost = TTI->getPointersChainCost( + Ptrs, BasePtr, TTI::PointersChainInfo::getUnitStride(), ScalarTy, + CostKind); + + SmallVector<const Value *> PtrsRetainedInVecCode; + for (Value *V : Ptrs) { + if (V == BasePtr) { + PtrsRetainedInVecCode.push_back(V); + continue; + } + auto *Ptr = dyn_cast<GetElementPtrInst>(V); + // For simplicity assume Ptr to stay in vectorized code if it's not a + // GEP instruction. We don't care since it's cost considered free. + // TODO: We should check for any uses outside of vectorizable tree + // rather than just single use. + if (!Ptr || !Ptr->hasOneUse()) + PtrsRetainedInVecCode.push_back(V); + } + + if (PtrsRetainedInVecCode.size() == Ptrs.size()) { + // If all pointers stay in vectorized code then we don't have + // any savings on that. + LLVM_DEBUG(dumpTreeCosts(E, 0, ScalarCost, ScalarCost, + "Calculated GEPs cost for Tree")); + return InstructionCost{TTI::TCC_Free}; + } + VecCost = TTI->getPointersChainCost( + PtrsRetainedInVecCode, BasePtr, + TTI::PointersChainInfo::getKnownStride(), VecTy, CostKind); + } else { + // Case 1: Ptrs are the arguments of loads that we are going to transform + // into masked gather load intrinsic. + // All the scalar GEPs will be removed as a result of vectorization. // For any external uses of some lanes extract element instructions will - // be generated (which cost is estimated separately). For the latter case - // since the set of GEPs itself is not vectorized those used more than - // once will remain staying in vectorized code as well. So we should not - // count them as savings. - if (!Ptr->hasOneUse() && isa<LoadInst, StoreInst>(VL0)) - continue; - - // TODO: it is target dependent, so need to implement and then use a TTI - // interface. - CostSavings += TTI->getArithmeticInstrCost(Instruction::Add, - Ptr->getType(), CostKind); - } - LLVM_DEBUG(dbgs() << "SLP: Calculated GEPs cost savings or Tree:\n"; - E->dump()); - LLVM_DEBUG(dbgs() << "SLP: GEP cost saving = " << CostSavings << "\n"); - return InstructionCost() - CostSavings; + // be generated (which cost is estimated separately). + TTI::PointersChainInfo PtrsInfo = + all_of(Ptrs, + [](const Value *V) { + auto *Ptr = dyn_cast<GetElementPtrInst>(V); + return Ptr && !Ptr->hasAllConstantIndices(); + }) + ? 
TTI::PointersChainInfo::getUnknownStride() + : TTI::PointersChainInfo::getKnownStride(); + + ScalarCost = TTI->getPointersChainCost(Ptrs, BasePtr, PtrsInfo, ScalarTy, + CostKind); + if (auto *BaseGEP = dyn_cast<GEPOperator>(BasePtr)) { + SmallVector<const Value *> Indices(BaseGEP->indices()); + VecCost = TTI->getGEPCost(BaseGEP->getSourceElementType(), + BaseGEP->getPointerOperand(), Indices, VecTy, + CostKind); + } + } + + LLVM_DEBUG(dumpTreeCosts(E, 0, VecCost, ScalarCost, + "Calculated GEPs cost for Tree")); + + return VecCost - ScalarCost; }; switch (ShuffleOrOp) { @@ -7062,7 +7618,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy); - SmallVector<int> InsertMask(NumElts, UndefMaskElem); + SmallVector<int> InsertMask(NumElts, PoisonMaskElem); unsigned OffsetBeg = *getInsertIndex(VL.front()); unsigned OffsetEnd = OffsetBeg; InsertMask[OffsetBeg] = 0; @@ -7099,13 +7655,13 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, SmallVector<int> Mask; if (!E->ReorderIndices.empty()) { inversePermutation(E->ReorderIndices, Mask); - Mask.append(InsertVecSz - Mask.size(), UndefMaskElem); + Mask.append(InsertVecSz - Mask.size(), PoisonMaskElem); } else { - Mask.assign(VecSz, UndefMaskElem); + Mask.assign(VecSz, PoisonMaskElem); std::iota(Mask.begin(), std::next(Mask.begin(), InsertVecSz), 0); } bool IsIdentity = true; - SmallVector<int> PrevMask(InsertVecSz, UndefMaskElem); + SmallVector<int> PrevMask(InsertVecSz, PoisonMaskElem); Mask.swap(PrevMask); for (unsigned I = 0; I < NumScalars; ++I) { unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]); @@ -7148,14 +7704,14 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, InsertVecTy); } else { for (unsigned I = 0, End = OffsetBeg - Offset; I < End; ++I) - Mask[I] = InMask.test(I) ? UndefMaskElem : I; + Mask[I] = InMask.test(I) ? PoisonMaskElem : I; for (unsigned I = OffsetBeg - Offset, End = OffsetEnd - Offset; I <= End; ++I) - if (Mask[I] != UndefMaskElem) + if (Mask[I] != PoisonMaskElem) Mask[I] = I + VecSz; for (unsigned I = OffsetEnd + 1 - Offset; I < VecSz; ++I) Mask[I] = - ((I >= InMask.size()) || InMask.test(I)) ? UndefMaskElem : I; + ((I >= InMask.size()) || InMask.test(I)) ? PoisonMaskElem : I; Cost += TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, InsertVecTy, Mask); } } @@ -7422,11 +7978,11 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, VecCost += TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy, CostKind); } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) { - VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, - Builder.getInt1Ty(), + auto *MaskTy = FixedVectorType::get(Builder.getInt1Ty(), VL.size()); + VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, CI0->getPredicate(), CostKind, VL0); VecCost += TTI->getCmpSelInstrCost( - E->getOpcode(), ScalarTy, Builder.getInt1Ty(), + E->getOpcode(), VecTy, MaskTy, cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind, E->getAltOp()); } else { @@ -7615,7 +8171,7 @@ InstructionCost BoUpSLP::getSpillCost() const { unsigned BundleWidth = VectorizableTree.front()->Scalars.size(); InstructionCost Cost = 0; - SmallPtrSet<Instruction*, 4> LiveValues; + SmallPtrSet<Instruction *, 4> LiveValues; Instruction *PrevInst = nullptr; // The entries in VectorizableTree are not necessarily ordered by their @@ -7626,6 +8182,8 @@ InstructionCost BoUpSLP::getSpillCost() const { // are grouped together. Using dominance ensures a deterministic order. 
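getSpillCost above orders the vectorizable scalars by the dominator-tree DFS numbers of their blocks (note the comparison flipped to '>' in this patch), falling back to in-block position for ties. A small sketch of that comparator on hypothetical precomputed values; DFSNumIn and PosInBlock stand in for DT->getNode(...)->getDFSNumIn() and Instruction::comesBefore and are not taken from the patch.

#include <algorithm>
#include <cstdio>
#include <vector>

struct ScalarInfo {
  const char *Name;
  unsigned DFSNumIn;   // dominator-tree DFS-in number of the parent block
  unsigned PosInBlock; // position of the instruction inside its block
};

int main() {
  std::vector<ScalarInfo> Scalars = {
      {"c", 4, 0}, {"a", 2, 1}, {"b", 2, 0}, {"d", 7, 3}};
  std::sort(Scalars.begin(), Scalars.end(),
            [](const ScalarInfo &A, const ScalarInfo &B) {
              if (A.DFSNumIn != B.DFSNumIn)
                return A.DFSNumIn > B.DFSNumIn; // descending, as in the patch
              return B.PosInBlock < A.PosInBlock; // mirrors B->comesBefore(A)
            });
  for (const ScalarInfo &S : Scalars)
    std::printf("%s ", S.Name); // prints: d c a b
  std::printf("\n");
  return 0;
}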
SmallVector<Instruction *, 16> OrderedScalars; for (const auto &TEPtr : VectorizableTree) { + if (TEPtr->State != TreeEntry::Vectorize) + continue; Instruction *Inst = dyn_cast<Instruction>(TEPtr->Scalars[0]); if (!Inst) continue; @@ -7639,7 +8197,7 @@ InstructionCost BoUpSLP::getSpillCost() const { assert((NodeA == NodeB) == (NodeA->getDFSNumIn() == NodeB->getDFSNumIn()) && "Different nodes should have different DFS numbers"); if (NodeA != NodeB) - return NodeA->getDFSNumIn() < NodeB->getDFSNumIn(); + return NodeA->getDFSNumIn() > NodeB->getDFSNumIn(); return B->comesBefore(A); }); @@ -7698,7 +8256,7 @@ InstructionCost BoUpSLP::getSpillCost() const { }; // Debug information does not impact spill cost. - if (isa<CallInst>(&*PrevInstIt) && !NoCallIntrinsic(&*PrevInstIt) && + if (isa<CallBase>(&*PrevInstIt) && !NoCallIntrinsic(&*PrevInstIt) && &*PrevInstIt != PrevInst) NumCalls++; @@ -7706,7 +8264,7 @@ InstructionCost BoUpSLP::getSpillCost() const { } if (NumCalls) { - SmallVector<Type*, 4> V; + SmallVector<Type *, 4> V; for (auto *II : LiveValues) { auto *ScalarTy = II->getType(); if (auto *VectorTy = dyn_cast<FixedVectorType>(ScalarTy)) @@ -7797,8 +8355,8 @@ static T *performExtractsShuffleAction( ResizeAction(ShuffleMask.begin()->first, Mask, /*ForSingleMask=*/false); SmallBitVector IsBasePoison = isUndefVector<true>(Base, UseMask); for (unsigned Idx = 0, VF = Mask.size(); Idx < VF; ++Idx) { - if (Mask[Idx] == UndefMaskElem) - Mask[Idx] = IsBasePoison.test(Idx) ? UndefMaskElem : Idx; + if (Mask[Idx] == PoisonMaskElem) + Mask[Idx] = IsBasePoison.test(Idx) ? PoisonMaskElem : Idx; else Mask[Idx] = (Res.second ? Idx : Mask[Idx]) + VF; } @@ -7827,8 +8385,8 @@ static T *performExtractsShuffleAction( // can shuffle them directly. ArrayRef<int> SecMask = VMIt->second; for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { - if (SecMask[I] != UndefMaskElem) { - assert(Mask[I] == UndefMaskElem && "Multiple uses of scalars."); + if (SecMask[I] != PoisonMaskElem) { + assert(Mask[I] == PoisonMaskElem && "Multiple uses of scalars."); Mask[I] = SecMask[I] + Vec1VF; } } @@ -7841,12 +8399,12 @@ static T *performExtractsShuffleAction( ResizeAction(VMIt->first, VMIt->second, /*ForSingleMask=*/false); ArrayRef<int> SecMask = VMIt->second; for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { - if (Mask[I] != UndefMaskElem) { - assert(SecMask[I] == UndefMaskElem && "Multiple uses of scalars."); + if (Mask[I] != PoisonMaskElem) { + assert(SecMask[I] == PoisonMaskElem && "Multiple uses of scalars."); if (Res1.second) Mask[I] = I; - } else if (SecMask[I] != UndefMaskElem) { - assert(Mask[I] == UndefMaskElem && "Multiple uses of scalars."); + } else if (SecMask[I] != PoisonMaskElem) { + assert(Mask[I] == PoisonMaskElem && "Multiple uses of scalars."); Mask[I] = (Res2.second ? I : SecMask[I]) + VF; } } @@ -7863,11 +8421,11 @@ static T *performExtractsShuffleAction( ResizeAction(VMIt->first, VMIt->second, /*ForSingleMask=*/false); ArrayRef<int> SecMask = VMIt->second; for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { - if (SecMask[I] != UndefMaskElem) { - assert((Mask[I] == UndefMaskElem || IsBaseNotUndef) && + if (SecMask[I] != PoisonMaskElem) { + assert((Mask[I] == PoisonMaskElem || IsBaseNotUndef) && "Multiple uses of scalars."); Mask[I] = (Res.second ? 
I : SecMask[I]) + VF; - } else if (Mask[I] != UndefMaskElem) { + } else if (Mask[I] != PoisonMaskElem) { Mask[I] = I; } } @@ -7877,12 +8435,23 @@ static T *performExtractsShuffleAction( } InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { + // Build a map for gathered scalars to the nodes where they are used. + ValueToGatherNodes.clear(); + for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) { + if (EntryPtr->State != TreeEntry::NeedToGather) + continue; + for (Value *V : EntryPtr->Scalars) + if (!isConstant(V)) + ValueToGatherNodes.try_emplace(V).first->getSecond().insert( + EntryPtr.get()); + } InstructionCost Cost = 0; LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size " << VectorizableTree.size() << ".\n"); unsigned BundleWidth = VectorizableTree[0]->Scalars.size(); + SmallPtrSet<Value *, 4> CheckedExtracts; for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) { TreeEntry &TE = *VectorizableTree[I]; if (TE.State == TreeEntry::NeedToGather) { @@ -7898,7 +8467,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { } } - InstructionCost C = getEntryCost(&TE, VectorizedVals); + InstructionCost C = getEntryCost(&TE, VectorizedVals, CheckedExtracts); Cost += C; LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle that starts with " << *TE.Scalars[0] @@ -7951,7 +8520,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { (void)ShuffleMasks.emplace_back(); SmallVectorImpl<int> &Mask = ShuffleMasks.back()[ScalarTE]; if (Mask.empty()) - Mask.assign(FTy->getNumElements(), UndefMaskElem); + Mask.assign(FTy->getNumElements(), PoisonMaskElem); // Find the insertvector, vectorized in tree, if any. Value *Base = VU; while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) { @@ -7965,7 +8534,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { do { IEBase = cast<InsertElementInst>(Base); int Idx = *getInsertIndex(IEBase); - assert(Mask[Idx] == UndefMaskElem && + assert(Mask[Idx] == PoisonMaskElem && "InsertElementInstruction used already."); Mask[Idx] = Idx; Base = IEBase->getOperand(0); @@ -7985,7 +8554,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { int InIdx = *InsertIdx; SmallVectorImpl<int> &Mask = ShuffleMasks[VecId][ScalarTE]; if (Mask.empty()) - Mask.assign(FTy->getNumElements(), UndefMaskElem); + Mask.assign(FTy->getNumElements(), PoisonMaskElem); Mask[InIdx] = EU.Lane; DemandedElts[VecId].setBit(InIdx); continue; @@ -8024,7 +8593,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { (all_of(Mask, [VF](int Idx) { return Idx < 2 * static_cast<int>(VF); }) && !ShuffleVectorInst::isIdentityMask(Mask)))) { - SmallVector<int> OrigMask(VecVF, UndefMaskElem); + SmallVector<int> OrigMask(VecVF, PoisonMaskElem); std::copy(Mask.begin(), std::next(Mask.begin(), std::min(VF, VecVF)), OrigMask.begin()); C = TTI->getShuffleCost( @@ -8110,17 +8679,23 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, // No need to check for the topmost gather node. if (TE == VectorizableTree.front().get()) return std::nullopt; - Mask.assign(VL.size(), UndefMaskElem); + Mask.assign(VL.size(), PoisonMaskElem); assert(TE->UserTreeIndices.size() == 1 && "Expected only single user of the gather node."); // TODO: currently checking only for Scalars in the tree entry, need to count // reused elements too for better cost estimation. 
Instruction &UserInst = getLastInstructionInBundle(TE->UserTreeIndices.front().UserTE); - auto *PHI = dyn_cast<PHINode>(&UserInst); - auto *NodeUI = DT->getNode( - PHI ? PHI->getIncomingBlock(TE->UserTreeIndices.front().EdgeIdx) - : UserInst.getParent()); + BasicBlock *ParentBB = nullptr; + // Main node of PHI entries keeps the correct order of operands/incoming + // blocks. + if (auto *PHI = + dyn_cast<PHINode>(TE->UserTreeIndices.front().UserTE->getMainOp())) { + ParentBB = PHI->getIncomingBlock(TE->UserTreeIndices.front().EdgeIdx); + } else { + ParentBB = UserInst.getParent(); + } + auto *NodeUI = DT->getNode(ParentBB); assert(NodeUI && "Should only process reachable instructions"); SmallPtrSet<Value *, 4> GatheredScalars(VL.begin(), VL.end()); auto CheckOrdering = [&](Instruction *LastEI) { @@ -8147,45 +8722,6 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, return false; return true; }; - // Build a lists of values to tree entries. - DenseMap<Value *, SmallPtrSet<const TreeEntry *, 4>> ValueToTEs; - for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) { - if (EntryPtr.get() == TE) - continue; - if (EntryPtr->State != TreeEntry::NeedToGather) - continue; - if (!any_of(EntryPtr->Scalars, [&GatheredScalars](Value *V) { - return GatheredScalars.contains(V); - })) - continue; - assert(EntryPtr->UserTreeIndices.size() == 1 && - "Expected only single user of the gather node."); - Instruction &EntryUserInst = - getLastInstructionInBundle(EntryPtr->UserTreeIndices.front().UserTE); - if (&UserInst == &EntryUserInst) { - // If 2 gathers are operands of the same entry, compare operands indices, - // use the earlier one as the base. - if (TE->UserTreeIndices.front().UserTE == - EntryPtr->UserTreeIndices.front().UserTE && - TE->UserTreeIndices.front().EdgeIdx < - EntryPtr->UserTreeIndices.front().EdgeIdx) - continue; - } - // Check if the user node of the TE comes after user node of EntryPtr, - // otherwise EntryPtr depends on TE. - auto *EntryPHI = dyn_cast<PHINode>(&EntryUserInst); - auto *EntryI = - EntryPHI - ? EntryPHI - ->getIncomingBlock(EntryPtr->UserTreeIndices.front().EdgeIdx) - ->getTerminator() - : &EntryUserInst; - if (!CheckOrdering(EntryI)) - continue; - for (Value *V : EntryPtr->Scalars) - if (!isConstant(V)) - ValueToTEs.try_emplace(V).first->getSecond().insert(EntryPtr.get()); - } // Find all tree entries used by the gathered values. If no common entries // found - not a shuffle. // Here we build a set of tree nodes for each gathered value and trying to @@ -8195,16 +8731,58 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, // have a permutation of 2 input vectors. SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs; DenseMap<Value *, int> UsedValuesEntry; - for (Value *V : TE->Scalars) { + for (Value *V : VL) { if (isConstant(V)) continue; // Build a list of tree entries where V is used. 
SmallPtrSet<const TreeEntry *, 4> VToTEs; - auto It = ValueToTEs.find(V); - if (It != ValueToTEs.end()) - VToTEs = It->second; - if (const TreeEntry *VTE = getTreeEntry(V)) + for (const TreeEntry *TEPtr : ValueToGatherNodes.find(V)->second) { + if (TEPtr == TE) + continue; + assert(any_of(TEPtr->Scalars, + [&](Value *V) { return GatheredScalars.contains(V); }) && + "Must contain at least single gathered value."); + assert(TEPtr->UserTreeIndices.size() == 1 && + "Expected only single user of the gather node."); + PHINode *EntryPHI = + dyn_cast<PHINode>(TEPtr->UserTreeIndices.front().UserTE->getMainOp()); + Instruction *EntryUserInst = + EntryPHI ? nullptr + : &getLastInstructionInBundle( + TEPtr->UserTreeIndices.front().UserTE); + if (&UserInst == EntryUserInst) { + assert(!EntryPHI && "Unexpected phi node entry."); + // If 2 gathers are operands of the same entry, compare operands + // indices, use the earlier one as the base. + if (TE->UserTreeIndices.front().UserTE == + TEPtr->UserTreeIndices.front().UserTE && + TE->UserTreeIndices.front().EdgeIdx < + TEPtr->UserTreeIndices.front().EdgeIdx) + continue; + } + // Check if the user node of the TE comes after user node of EntryPtr, + // otherwise EntryPtr depends on TE. + auto *EntryI = + EntryPHI + ? EntryPHI + ->getIncomingBlock(TEPtr->UserTreeIndices.front().EdgeIdx) + ->getTerminator() + : EntryUserInst; + if ((ParentBB != EntryI->getParent() || + TE->UserTreeIndices.front().EdgeIdx < + TEPtr->UserTreeIndices.front().EdgeIdx || + TE->UserTreeIndices.front().UserTE != + TEPtr->UserTreeIndices.front().UserTE) && + !CheckOrdering(EntryI)) + continue; + VToTEs.insert(TEPtr); + } + if (const TreeEntry *VTE = getTreeEntry(V)) { + Instruction &EntryUserInst = getLastInstructionInBundle(VTE); + if (&EntryUserInst == &UserInst || !CheckOrdering(&EntryUserInst)) + continue; VToTEs.insert(VTE); + } if (VToTEs.empty()) continue; if (UsedTEs.empty()) { @@ -8260,13 +8838,13 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, auto *It = find_if(FirstEntries, [=](const TreeEntry *EntryPtr) { return EntryPtr->isSame(VL) || EntryPtr->isSame(TE->Scalars); }); - if (It != FirstEntries.end()) { + if (It != FirstEntries.end() && (*It)->getVectorFactor() == VL.size()) { Entries.push_back(*It); std::iota(Mask.begin(), Mask.end(), 0); // Clear undef scalars. for (int I = 0, Sz = VL.size(); I < Sz; ++I) - if (isa<PoisonValue>(TE->Scalars[I])) - Mask[I] = UndefMaskElem; + if (isa<PoisonValue>(VL[I])) + Mask[I] = PoisonMaskElem; return TargetTransformInfo::SK_PermuteSingleSrc; } // No perfect match, just shuffle, so choose the first tree node from the @@ -8302,10 +8880,18 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, break; } } - // No 2 source vectors with the same vector factor - give up and do regular - // gather. - if (Entries.empty()) - return std::nullopt; + // No 2 source vectors with the same vector factor - just choose 2 with max + // index. 
+ if (Entries.empty()) { + Entries.push_back( + *std::max_element(UsedTEs.front().begin(), UsedTEs.front().end(), + [](const TreeEntry *TE1, const TreeEntry *TE2) { + return TE1->Idx < TE2->Idx; + })); + Entries.push_back(SecondEntries.front()); + VF = std::max(Entries.front()->getVectorFactor(), + Entries.back()->getVectorFactor()); + } } bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, UndefValue::classof); @@ -8427,19 +9013,8 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, return std::nullopt; } -InstructionCost BoUpSLP::getGatherCost(FixedVectorType *Ty, - const APInt &ShuffledIndices, - bool NeedToShuffle) const { - TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - InstructionCost Cost = - TTI->getScalarizationOverhead(Ty, ~ShuffledIndices, /*Insert*/ true, - /*Extract*/ false, CostKind); - if (NeedToShuffle) - Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty); - return Cost; -} - -InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const { +InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, + bool ForPoisonSrc) const { // Find the type of the operands in VL. Type *ScalarTy = VL[0]->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) @@ -8451,20 +9026,36 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const { // shuffle candidates. APInt ShuffledElements = APInt::getZero(VL.size()); DenseSet<Value *> UniqueElements; - // Iterate in reverse order to consider insert elements with the high cost. - for (unsigned I = VL.size(); I > 0; --I) { - unsigned Idx = I - 1; + constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + InstructionCost Cost; + auto EstimateInsertCost = [&](unsigned I, Value *V) { + if (!ForPoisonSrc) + Cost += + TTI->getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, + I, Constant::getNullValue(VecTy), V); + }; + for (unsigned I = 0, E = VL.size(); I < E; ++I) { + Value *V = VL[I]; // No need to shuffle duplicates for constants. - if (isConstant(VL[Idx])) { - ShuffledElements.setBit(Idx); + if ((ForPoisonSrc && isConstant(V)) || isa<UndefValue>(V)) { + ShuffledElements.setBit(I); continue; } - if (!UniqueElements.insert(VL[Idx]).second) { + if (!UniqueElements.insert(V).second) { DuplicateNonConst = true; - ShuffledElements.setBit(Idx); + ShuffledElements.setBit(I); + continue; } + EstimateInsertCost(I, V); } - return getGatherCost(VecTy, ShuffledElements, DuplicateNonConst); + if (ForPoisonSrc) + Cost = + TTI->getScalarizationOverhead(VecTy, ~ShuffledElements, /*Insert*/ true, + /*Extract*/ false, CostKind); + if (DuplicateNonConst) + Cost += + TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy); + return Cost; } // Perform operand reordering on the instructions in VL and return the reordered @@ -8483,6 +9074,9 @@ void BoUpSLP::reorderInputsAccordingToOpcode( } Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { + auto &Res = EntryToLastInstruction.FindAndConstruct(E); + if (Res.second) + return *Res.second; // Get the basic block this bundle is in. All instructions in the bundle // should be in this block (except for extractelement-like instructions with // constant indeces). 
@@ -8497,7 +9091,7 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { isVectorLikeInstWithConstOps(I); })); - auto &&FindLastInst = [E, Front, this, &BB]() { + auto FindLastInst = [&]() { Instruction *LastInst = Front; for (Value *V : E->Scalars) { auto *I = dyn_cast<Instruction>(V); @@ -8508,9 +9102,11 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { LastInst = I; continue; } - assert(isVectorLikeInstWithConstOps(LastInst) && - isVectorLikeInstWithConstOps(I) && - "Expected vector-like insts only."); + assert(((E->getOpcode() == Instruction::GetElementPtr && + !isa<GetElementPtrInst>(I)) || + (isVectorLikeInstWithConstOps(LastInst) && + isVectorLikeInstWithConstOps(I))) && + "Expected vector-like or non-GEP in GEP node insts only."); if (!DT->isReachableFromEntry(LastInst->getParent())) { LastInst = I; continue; @@ -8531,7 +9127,7 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { return LastInst; }; - auto &&FindFirstInst = [E, Front, this]() { + auto FindFirstInst = [&]() { Instruction *FirstInst = Front; for (Value *V : E->Scalars) { auto *I = dyn_cast<Instruction>(V); @@ -8542,9 +9138,11 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { FirstInst = I; continue; } - assert(isVectorLikeInstWithConstOps(FirstInst) && - isVectorLikeInstWithConstOps(I) && - "Expected vector-like insts only."); + assert(((E->getOpcode() == Instruction::GetElementPtr && + !isa<GetElementPtrInst>(I)) || + (isVectorLikeInstWithConstOps(FirstInst) && + isVectorLikeInstWithConstOps(I))) && + "Expected vector-like or non-GEP in GEP node insts only."); if (!DT->isReachableFromEntry(FirstInst->getParent())) { FirstInst = I; continue; @@ -8566,22 +9164,23 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { // Set the insert point to the beginning of the basic block if the entry // should not be scheduled. - if (E->State != TreeEntry::NeedToGather && - (doesNotNeedToSchedule(E->Scalars) || + if (doesNotNeedToSchedule(E->Scalars) || + (E->State != TreeEntry::NeedToGather && all_of(E->Scalars, isVectorLikeInstWithConstOps))) { - Instruction *InsertInst; - if (all_of(E->Scalars, [](Value *V) { + if ((E->getOpcode() == Instruction::GetElementPtr && + any_of(E->Scalars, + [](Value *V) { + return !isa<GetElementPtrInst>(V) && isa<Instruction>(V); + })) || + all_of(E->Scalars, [](Value *V) { return !isVectorLikeInstWithConstOps(V) && isUsedOutsideBlock(V); })) - InsertInst = FindLastInst(); + Res.second = FindLastInst(); else - InsertInst = FindFirstInst(); - return *InsertInst; + Res.second = FindFirstInst(); + return *Res.second; } - // The last instruction in the bundle in program order. - Instruction *LastInst = nullptr; - // Find the last instruction. The common case should be that BB has been // scheduled, and the last instruction is VL.back(). So we start with // VL.back() and iterate over schedule data until we reach the end of the @@ -8594,7 +9193,7 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { if (Bundle && Bundle->isPartOfBundle()) for (; Bundle; Bundle = Bundle->NextInBundle) if (Bundle->OpValue == Bundle->Inst) - LastInst = Bundle->Inst; + Res.second = Bundle->Inst; } // LastInst can still be null at this point if there's either not an entry @@ -8615,15 +9214,15 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { // not ideal. 
However, this should be exceedingly rare since it requires that // we both exit early from buildTree_rec and that the bundle be out-of-order // (causing us to iterate all the way to the end of the block). - if (!LastInst) - LastInst = FindLastInst(); - assert(LastInst && "Failed to find last instruction in bundle"); - return *LastInst; + if (!Res.second) + Res.second = FindLastInst(); + assert(Res.second && "Failed to find last instruction in bundle"); + return *Res.second; } void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { auto *Front = E->getMainOp(); - Instruction *LastInst = EntryToLastInstruction.lookup(E); + Instruction *LastInst = &getLastInstructionInBundle(E); assert(LastInst && "Failed to find last instruction in bundle"); // If the instruction is PHI, set the insert point after all the PHIs. bool IsPHI = isa<PHINode>(LastInst); @@ -8641,7 +9240,7 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { Builder.SetCurrentDebugLocation(Front->getDebugLoc()); } -Value *BoUpSLP::gather(ArrayRef<Value *> VL) { +Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) { // List of instructions/lanes from current block and/or the blocks which are // part of the current loop. These instructions will be inserted at the end to // make it possible to optimize loops and hoist invariant instructions out of @@ -8658,7 +9257,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) { for (int I = 0, E = VL.size(); I < E; ++I) { if (auto *Inst = dyn_cast<Instruction>(VL[I])) if ((CheckPredecessor(Inst->getParent(), Builder.GetInsertBlock()) || - getTreeEntry(Inst) || (L && (L->contains(Inst)))) && + getTreeEntry(Inst) || + (L && (!Root || L->isLoopInvariant(Root)) && L->contains(Inst))) && PostponedIndices.insert(I).second) PostponedInsts.emplace_back(Inst, I); } @@ -8681,7 +9281,7 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) { Value *Val0 = isa<StoreInst>(VL[0]) ? cast<StoreInst>(VL[0])->getValueOperand() : VL[0]; FixedVectorType *VecTy = FixedVectorType::get(Val0->getType(), VL.size()); - Value *Vec = PoisonValue::get(VecTy); + Value *Vec = Root ? Root : PoisonValue::get(VecTy); SmallVector<int> NonConsts; // Insert constant values at first. for (int I = 0, E = VL.size(); I < E; ++I) { @@ -8691,6 +9291,18 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) { NonConsts.push_back(I); continue; } + if (Root) { + if (!isa<UndefValue>(VL[I])) { + NonConsts.push_back(I); + continue; + } + if (isa<PoisonValue>(VL[I])) + continue; + if (auto *SV = dyn_cast<ShuffleVectorInst>(Root)) { + if (SV->getMaskValue(I) == PoisonMaskElem) + continue; + } + } Vec = CreateInsertElement(Vec, VL[I], I); } // Insert non-constant values. @@ -8789,6 +9401,10 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { } return Vec; } + Value *createIdentity(Value *V) { return V; } + Value *createPoison(Type *Ty, unsigned VF) { + return PoisonValue::get(FixedVectorType::get(Ty, VF)); + } /// Resizes 2 input vector to match the sizes, if the they are not equal /// yet. The smallest vector is resized to the size of the larger vector. 
void resizeToMatch(Value *&V1, Value *&V2) { @@ -8798,7 +9414,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { int V2VF = cast<FixedVectorType>(V2->getType())->getNumElements(); int VF = std::max(V1VF, V2VF); int MinVF = std::min(V1VF, V2VF); - SmallVector<int> IdentityMask(VF, UndefMaskElem); + SmallVector<int> IdentityMask(VF, PoisonMaskElem); std::iota(IdentityMask.begin(), std::next(IdentityMask.begin(), MinVF), 0); Value *&Op = MinVF == V1VF ? V1 : V2; @@ -8821,7 +9437,8 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { assert(V1 && "Expected at least one vector value."); ShuffleIRBuilder ShuffleBuilder(Builder, R.GatherShuffleExtractSeq, R.CSEBlocks); - return BaseShuffleAnalysis::createShuffle(V1, V2, Mask, ShuffleBuilder); + return BaseShuffleAnalysis::createShuffle<Value *>(V1, V2, Mask, + ShuffleBuilder); } /// Transforms mask \p CommonMask per given \p Mask to make proper set after @@ -8829,7 +9446,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { static void transformMaskAfterShuffle(MutableArrayRef<int> CommonMask, ArrayRef<int> Mask) { for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) - if (Mask[Idx] != UndefMaskElem) + if (Mask[Idx] != PoisonMaskElem) CommonMask[Idx] = Idx; } @@ -8837,6 +9454,39 @@ public: ShuffleInstructionBuilder(IRBuilderBase &Builder, BoUpSLP &R) : Builder(Builder), R(R) {} + /// Adjusts extractelements after reusing them. + Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask) { + Value *VecBase = nullptr; + for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { + int Idx = Mask[I]; + if (Idx == PoisonMaskElem) + continue; + auto *EI = cast<ExtractElementInst>(E->Scalars[I]); + VecBase = EI->getVectorOperand(); + // If the only one use is vectorized - can delete the extractelement + // itself. + if (!EI->hasOneUse() || any_of(EI->users(), [&](User *U) { + return !R.ScalarToTreeEntry.count(U); + })) + continue; + R.eraseInstruction(EI); + } + return VecBase; + } + /// Checks if the specified entry \p E needs to be delayed because of its + /// dependency nodes. + Value *needToDelay(const TreeEntry *E, ArrayRef<const TreeEntry *> Deps) { + // No need to delay emission if all deps are ready. + if (all_of(Deps, [](const TreeEntry *TE) { return TE->VectorizedValue; })) + return nullptr; + // Postpone gather emission, will be emitted after the end of the + // process to keep correct order. + auto *VecTy = FixedVectorType::get(E->Scalars.front()->getType(), + E->getVectorFactor()); + return Builder.CreateAlignedLoad( + VecTy, PoisonValue::get(PointerType::getUnqual(VecTy->getContext())), + MaybeAlign()); + } /// Adds 2 input vectors and the mask for their shuffling. 
void add(Value *V1, Value *V2, ArrayRef<int> Mask) { assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors."); @@ -8849,15 +9499,15 @@ public: Value *Vec = InVectors.front(); if (InVectors.size() == 2) { Vec = createShuffle(Vec, InVectors.back(), CommonMask); - transformMaskAfterShuffle(CommonMask, Mask); + transformMaskAfterShuffle(CommonMask, CommonMask); } else if (cast<FixedVectorType>(Vec->getType())->getNumElements() != Mask.size()) { Vec = createShuffle(Vec, nullptr, CommonMask); - transformMaskAfterShuffle(CommonMask, Mask); + transformMaskAfterShuffle(CommonMask, CommonMask); } V1 = createShuffle(V1, V2, Mask); for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) - if (Mask[Idx] != UndefMaskElem) + if (Mask[Idx] != PoisonMaskElem) CommonMask[Idx] = Idx + Sz; InVectors.front() = Vec; if (InVectors.size() == 2) @@ -8870,7 +9520,7 @@ public: if (InVectors.empty()) { if (!isa<FixedVectorType>(V1->getType())) { V1 = createShuffle(V1, nullptr, CommonMask); - CommonMask.assign(Mask.size(), UndefMaskElem); + CommonMask.assign(Mask.size(), PoisonMaskElem); transformMaskAfterShuffle(CommonMask, Mask); } InVectors.push_back(V1); @@ -8892,7 +9542,7 @@ public: transformMaskAfterShuffle(CommonMask, CommonMask); } for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) - if (CommonMask[Idx] == UndefMaskElem && Mask[Idx] != UndefMaskElem) + if (CommonMask[Idx] == PoisonMaskElem && Mask[Idx] != PoisonMaskElem) CommonMask[Idx] = V->getType() != V1->getType() ? Idx + Sz @@ -8910,7 +9560,7 @@ public: // Check if second vector is required if the used elements are already // used from the first one. for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) - if (Mask[Idx] != UndefMaskElem && CommonMask[Idx] == UndefMaskElem) { + if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem) { InVectors.push_back(V1); break; } @@ -8919,7 +9569,7 @@ public: if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType())) VF = FTy->getNumElements(); for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) - if (Mask[Idx] != UndefMaskElem && CommonMask[Idx] == UndefMaskElem) + if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem) CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF); } /// Adds another one input vector and the mask for the shuffling. @@ -8928,17 +9578,46 @@ public: inversePermutation(Order, NewMask); add(V1, NewMask); } + Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) { + return R.gather(VL, Root); + } + Value *createFreeze(Value *V) { return Builder.CreateFreeze(V); } /// Finalize emission of the shuffles. + /// \param Action the action (if any) to be performed before final applying of + /// the \p ExtMask mask. 
Value * - finalize(ArrayRef<int> ExtMask = std::nullopt) { + finalize(ArrayRef<int> ExtMask, unsigned VF = 0, + function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) { IsFinalized = true; + if (Action) { + Value *Vec = InVectors.front(); + if (InVectors.size() == 2) { + Vec = createShuffle(Vec, InVectors.back(), CommonMask); + InVectors.pop_back(); + } else { + Vec = createShuffle(Vec, nullptr, CommonMask); + } + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (CommonMask[Idx] != PoisonMaskElem) + CommonMask[Idx] = Idx; + assert(VF > 0 && + "Expected vector length for the final value before action."); + unsigned VecVF = cast<FixedVectorType>(Vec->getType())->getNumElements(); + if (VecVF < VF) { + SmallVector<int> ResizeMask(VF, PoisonMaskElem); + std::iota(ResizeMask.begin(), std::next(ResizeMask.begin(), VecVF), 0); + Vec = createShuffle(Vec, nullptr, ResizeMask); + } + Action(Vec, CommonMask); + InVectors.front() = Vec; + } if (!ExtMask.empty()) { if (CommonMask.empty()) { CommonMask.assign(ExtMask.begin(), ExtMask.end()); } else { - SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem); + SmallVector<int> NewMask(ExtMask.size(), PoisonMaskElem); for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) { - if (ExtMask[I] == UndefMaskElem) + if (ExtMask[I] == PoisonMaskElem) continue; NewMask[I] = CommonMask[ExtMask[I]]; } @@ -9009,18 +9688,18 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) { // ... (use %2) // %shuffle = shuffle <2 x> %2, poison, <2 x> {2, 0} // br %block - SmallVector<int> UniqueIdxs(VF, UndefMaskElem); + SmallVector<int> UniqueIdxs(VF, PoisonMaskElem); SmallSet<int, 4> UsedIdxs; int Pos = 0; for (int Idx : VE->ReuseShuffleIndices) { - if (Idx != static_cast<int>(VF) && Idx != UndefMaskElem && + if (Idx != static_cast<int>(VF) && Idx != PoisonMaskElem && UsedIdxs.insert(Idx).second) UniqueIdxs[Idx] = Pos; ++Pos; } assert(VF >= UsedIdxs.size() && "Expected vectorization factor " "less than original vector size."); - UniqueIdxs.append(VF - UsedIdxs.size(), UndefMaskElem); + UniqueIdxs.append(VF - UsedIdxs.size(), PoisonMaskElem); V = FinalShuffle(V, UniqueIdxs); } else { assert(VF < cast<FixedVectorType>(V->getType())->getNumElements() && @@ -9031,6 +9710,21 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) { V = FinalShuffle(V, UniformMask); } } + // Need to update the operand gather node, if actually the operand is not a + // vectorized node, but the buildvector/gather node, which matches one of + // the vectorized nodes. 
+ if (find_if(VE->UserTreeIndices, [&](const EdgeInfo &EI) { + return EI.UserTE == E && EI.EdgeIdx == NodeIdx; + }) == VE->UserTreeIndices.end()) { + auto *It = find_if( + VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { + return TE->State == TreeEntry::NeedToGather && + TE->UserTreeIndices.front().UserTE == E && + TE->UserTreeIndices.front().EdgeIdx == NodeIdx; + }); + assert(It != VectorizableTree.end() && "Expected gather node operand."); + (*It)->VectorizedValue = V; + } return V; } } @@ -9049,108 +9743,370 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) { IRBuilder<>::InsertPointGuard Guard(Builder); if (E->getOpcode() != Instruction::InsertElement && E->getOpcode() != Instruction::PHI) { - Instruction *LastInst = EntryToLastInstruction.lookup(E); + Instruction *LastInst = &getLastInstructionInBundle(E); assert(LastInst && "Failed to find last instruction in bundle"); Builder.SetInsertPoint(LastInst); } return vectorizeTree(I->get()); } -Value *BoUpSLP::createBuildVector(const TreeEntry *E) { +template <typename BVTy, typename ResTy, typename... Args> +ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { assert(E->State == TreeEntry::NeedToGather && "Expected gather node."); unsigned VF = E->getVectorFactor(); - ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); - SmallVector<Value *> Gathered( - VF, PoisonValue::get(E->Scalars.front()->getType())); bool NeedFreeze = false; - SmallVector<Value *> VL(E->Scalars.begin(), E->Scalars.end()); - // Build a mask out of the redorder indices and reorder scalars per this mask. + SmallVector<int> ReuseShuffleIndicies(E->ReuseShuffleIndices.begin(), + E->ReuseShuffleIndices.end()); + SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end()); + // Build a mask out of the reorder indices and reorder scalars per this + // mask. 
SmallVector<int> ReorderMask; inversePermutation(E->ReorderIndices, ReorderMask); if (!ReorderMask.empty()) - reorderScalars(VL, ReorderMask); - SmallVector<int> ReuseMask(VF, UndefMaskElem); - if (!allConstant(VL)) { + reorderScalars(GatheredScalars, ReorderMask); + auto FindReusedSplat = [&](SmallVectorImpl<int> &Mask) { + if (!isSplat(E->Scalars) || none_of(E->Scalars, [](Value *V) { + return isa<UndefValue>(V) && !isa<PoisonValue>(V); + })) + return false; + TreeEntry *UserTE = E->UserTreeIndices.back().UserTE; + unsigned EdgeIdx = E->UserTreeIndices.back().EdgeIdx; + if (UserTE->getNumOperands() != 2) + return false; + auto *It = + find_if(VectorizableTree, [=](const std::unique_ptr<TreeEntry> &TE) { + return find_if(TE->UserTreeIndices, [=](const EdgeInfo &EI) { + return EI.UserTE == UserTE && EI.EdgeIdx != EdgeIdx; + }) != TE->UserTreeIndices.end(); + }); + if (It == VectorizableTree.end()) + return false; + unsigned I = + *find_if_not(Mask, [](int Idx) { return Idx == PoisonMaskElem; }); + int Sz = Mask.size(); + if (all_of(Mask, [Sz](int Idx) { return Idx < 2 * Sz; }) && + ShuffleVectorInst::isIdentityMask(Mask)) + std::iota(Mask.begin(), Mask.end(), 0); + else + std::fill(Mask.begin(), Mask.end(), I); + return true; + }; + BVTy ShuffleBuilder(Params...); + ResTy Res = ResTy(); + SmallVector<int> Mask; + SmallVector<int> ExtractMask; + std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle; + std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle; + SmallVector<const TreeEntry *> Entries; + Type *ScalarTy = GatheredScalars.front()->getType(); + if (!all_of(GatheredScalars, UndefValue::classof)) { + // Check for gathered extracts. + ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask); + SmallVector<Value *> IgnoredVals; + if (UserIgnoreList) + IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end()); + bool Resized = false; + if (Value *VecBase = ShuffleBuilder.adjustExtracts(E, ExtractMask)) + if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType())) + if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) { + Resized = true; + GatheredScalars.append(VF - GatheredScalars.size(), + PoisonValue::get(ScalarTy)); + } + // Gather extracts after we check for full matched gathers only. + if (ExtractShuffle || E->getOpcode() != Instruction::Load || + E->isAltShuffle() || + all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) || + isSplat(E->Scalars) || + (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) { + GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries); + } + if (GatherShuffle) { + if (Value *Delayed = ShuffleBuilder.needToDelay(E, Entries)) { + // Delay emission of gathers which are not ready yet. + PostponedGathers.insert(E); + // Postpone gather emission, will be emitted after the end of the + // process to keep correct order. + return Delayed; + } + assert((Entries.size() == 1 || Entries.size() == 2) && + "Expected shuffle of 1 or 2 entries."); + if (*GatherShuffle == TTI::SK_PermuteSingleSrc && + Entries.front()->isSame(E->Scalars)) { + // Perfect match in the graph, will reuse the previously vectorized + // node. Cost is 0. + LLVM_DEBUG( + dbgs() + << "SLP: perfect diamond match for gather bundle that starts with " + << *E->Scalars.front() << ".\n"); + // Restore the mask for previous partially matched values. 
+ if (Entries.front()->ReorderIndices.empty() && + ((Entries.front()->ReuseShuffleIndices.empty() && + E->Scalars.size() == Entries.front()->Scalars.size()) || + (E->Scalars.size() == + Entries.front()->ReuseShuffleIndices.size()))) { + std::iota(Mask.begin(), Mask.end(), 0); + } else { + for (auto [I, V] : enumerate(E->Scalars)) { + if (isa<PoisonValue>(V)) { + Mask[I] = PoisonMaskElem; + continue; + } + Mask[I] = Entries.front()->findLaneForValue(V); + } + } + ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask); + Res = ShuffleBuilder.finalize(E->getCommonMask()); + return Res; + } + if (!Resized) { + unsigned VF1 = Entries.front()->getVectorFactor(); + unsigned VF2 = Entries.back()->getVectorFactor(); + if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF) + GatheredScalars.append(VF - GatheredScalars.size(), + PoisonValue::get(ScalarTy)); + } + // Remove shuffled elements from list of gathers. + for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { + if (Mask[I] != PoisonMaskElem) + GatheredScalars[I] = PoisonValue::get(ScalarTy); + } + } + } + auto TryPackScalars = [&](SmallVectorImpl<Value *> &Scalars, + SmallVectorImpl<int> &ReuseMask, + bool IsRootPoison) { // For splats with can emit broadcasts instead of gathers, so try to find // such sequences. - bool IsSplat = isSplat(VL) && (VL.size() > 2 || VL.front() == VL.back()); + bool IsSplat = IsRootPoison && isSplat(Scalars) && + (Scalars.size() > 2 || Scalars.front() == Scalars.back()); + Scalars.append(VF - Scalars.size(), PoisonValue::get(ScalarTy)); SmallVector<int> UndefPos; DenseMap<Value *, unsigned> UniquePositions; // Gather unique non-const values and all constant values. // For repeated values, just shuffle them. - for (auto [I, V] : enumerate(VL)) { + int NumNonConsts = 0; + int SinglePos = 0; + for (auto [I, V] : enumerate(Scalars)) { if (isa<UndefValue>(V)) { if (!isa<PoisonValue>(V)) { - Gathered[I] = V; ReuseMask[I] = I; UndefPos.push_back(I); } continue; } if (isConstant(V)) { - Gathered[I] = V; ReuseMask[I] = I; continue; } + ++NumNonConsts; + SinglePos = I; + Value *OrigV = V; + Scalars[I] = PoisonValue::get(ScalarTy); if (IsSplat) { - Gathered.front() = V; + Scalars.front() = OrigV; ReuseMask[I] = 0; } else { - const auto Res = UniquePositions.try_emplace(V, I); - Gathered[Res.first->second] = V; + const auto Res = UniquePositions.try_emplace(OrigV, I); + Scalars[Res.first->second] = OrigV; ReuseMask[I] = Res.first->second; } } - if (!UndefPos.empty() && IsSplat) { + if (NumNonConsts == 1) { + // Restore single insert element. + if (IsSplat) { + ReuseMask.assign(VF, PoisonMaskElem); + std::swap(Scalars.front(), Scalars[SinglePos]); + if (!UndefPos.empty() && UndefPos.front() == 0) + Scalars.front() = UndefValue::get(ScalarTy); + } + ReuseMask[SinglePos] = SinglePos; + } else if (!UndefPos.empty() && IsSplat) { // For undef values, try to replace them with the simple broadcast. // We can do it if the broadcasted value is guaranteed to be // non-poisonous, or by freezing the incoming scalar value first. - auto *It = find_if(Gathered, [this, E](Value *V) { + auto *It = find_if(Scalars, [this, E](Value *V) { return !isa<UndefValue>(V) && (getTreeEntry(V) || isGuaranteedNotToBePoison(V) || - any_of(V->uses(), [E](const Use &U) { - // Check if the value already used in the same operation in - // one of the nodes already. 
- return E->UserTreeIndices.size() == 1 && - is_contained( - E->UserTreeIndices.front().UserTE->Scalars, - U.getUser()) && - E->UserTreeIndices.front().EdgeIdx != U.getOperandNo(); - })); + (E->UserTreeIndices.size() == 1 && + any_of(V->uses(), [E](const Use &U) { + // Check if the value already used in the same operation in + // one of the nodes already. + return E->UserTreeIndices.front().EdgeIdx != + U.getOperandNo() && + is_contained( + E->UserTreeIndices.front().UserTE->Scalars, + U.getUser()); + }))); }); - if (It != Gathered.end()) { + if (It != Scalars.end()) { // Replace undefs by the non-poisoned scalars and emit broadcast. - int Pos = std::distance(Gathered.begin(), It); + int Pos = std::distance(Scalars.begin(), It); for_each(UndefPos, [&](int I) { // Set the undef position to the non-poisoned scalar. ReuseMask[I] = Pos; - // Replace the undef by the poison, in the mask it is replaced by non-poisoned scalar already. + // Replace the undef by the poison, in the mask it is replaced by + // non-poisoned scalar already. if (I != Pos) - Gathered[I] = PoisonValue::get(Gathered[I]->getType()); + Scalars[I] = PoisonValue::get(ScalarTy); }); } else { // Replace undefs by the poisons, emit broadcast and then emit // freeze. for_each(UndefPos, [&](int I) { - ReuseMask[I] = UndefMaskElem; - if (isa<UndefValue>(Gathered[I])) - Gathered[I] = PoisonValue::get(Gathered[I]->getType()); + ReuseMask[I] = PoisonMaskElem; + if (isa<UndefValue>(Scalars[I])) + Scalars[I] = PoisonValue::get(ScalarTy); }); NeedFreeze = true; } } + }; + if (ExtractShuffle || GatherShuffle) { + bool IsNonPoisoned = true; + bool IsUsedInExpr = false; + Value *Vec1 = nullptr; + if (ExtractShuffle) { + // Gather of extractelements can be represented as just a shuffle of + // a single/two vectors the scalars are extracted from. + // Find input vectors. + Value *Vec2 = nullptr; + for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) { + if (ExtractMask[I] == PoisonMaskElem || + (!Mask.empty() && Mask[I] != PoisonMaskElem)) { + ExtractMask[I] = PoisonMaskElem; + continue; + } + if (isa<UndefValue>(E->Scalars[I])) + continue; + auto *EI = cast<ExtractElementInst>(E->Scalars[I]); + if (!Vec1) { + Vec1 = EI->getVectorOperand(); + } else if (Vec1 != EI->getVectorOperand()) { + assert((!Vec2 || Vec2 == EI->getVectorOperand()) && + "Expected only 1 or 2 vectors shuffle."); + Vec2 = EI->getVectorOperand(); + } + } + if (Vec2) { + IsNonPoisoned &= + isGuaranteedNotToBePoison(Vec1) && isGuaranteedNotToBePoison(Vec2); + ShuffleBuilder.add(Vec1, Vec2, ExtractMask); + } else if (Vec1) { + IsUsedInExpr = FindReusedSplat(ExtractMask); + ShuffleBuilder.add(Vec1, ExtractMask); + IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1); + } else { + ShuffleBuilder.add(PoisonValue::get(FixedVectorType::get( + ScalarTy, GatheredScalars.size())), + ExtractMask); + } + } + if (GatherShuffle) { + if (Entries.size() == 1) { + IsUsedInExpr = FindReusedSplat(Mask); + ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask); + IsNonPoisoned &= + isGuaranteedNotToBePoison(Entries.front()->VectorizedValue); + } else { + ShuffleBuilder.add(Entries.front()->VectorizedValue, + Entries.back()->VectorizedValue, Mask); + IsNonPoisoned &= + isGuaranteedNotToBePoison(Entries.front()->VectorizedValue) && + isGuaranteedNotToBePoison(Entries.back()->VectorizedValue); + } + } + // Try to figure out best way to combine values: build a shuffle and insert + // elements or just build several shuffles. + // Insert non-constant scalars. 
+ SmallVector<Value *> NonConstants(GatheredScalars); + int EMSz = ExtractMask.size(); + int MSz = Mask.size(); + // Try to build constant vector and shuffle with it only if currently we + // have a single permutation and more than 1 scalar constants. + bool IsSingleShuffle = !ExtractShuffle || !GatherShuffle; + bool IsIdentityShuffle = + (ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc) == + TTI::SK_PermuteSingleSrc && + none_of(ExtractMask, [&](int I) { return I >= EMSz; }) && + ShuffleVectorInst::isIdentityMask(ExtractMask)) || + (GatherShuffle.value_or(TTI::SK_PermuteTwoSrc) == + TTI::SK_PermuteSingleSrc && + none_of(Mask, [&](int I) { return I >= MSz; }) && + ShuffleVectorInst::isIdentityMask(Mask)); + bool EnoughConstsForShuffle = + IsSingleShuffle && + (none_of(GatheredScalars, + [](Value *V) { + return isa<UndefValue>(V) && !isa<PoisonValue>(V); + }) || + any_of(GatheredScalars, + [](Value *V) { + return isa<Constant>(V) && !isa<UndefValue>(V); + })) && + (!IsIdentityShuffle || + (GatheredScalars.size() == 2 && + any_of(GatheredScalars, + [](Value *V) { return !isa<UndefValue>(V); })) || + count_if(GatheredScalars, [](Value *V) { + return isa<Constant>(V) && !isa<PoisonValue>(V); + }) > 1); + // NonConstants array contains just non-constant values, GatheredScalars + // contains only constant to build final vector and then shuffle. + for (int I = 0, Sz = GatheredScalars.size(); I < Sz; ++I) { + if (EnoughConstsForShuffle && isa<Constant>(GatheredScalars[I])) + NonConstants[I] = PoisonValue::get(ScalarTy); + else + GatheredScalars[I] = PoisonValue::get(ScalarTy); + } + // Generate constants for final shuffle and build a mask for them. + if (!all_of(GatheredScalars, PoisonValue::classof)) { + SmallVector<int> BVMask(GatheredScalars.size(), PoisonMaskElem); + TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true); + Value *BV = ShuffleBuilder.gather(GatheredScalars); + ShuffleBuilder.add(BV, BVMask); + } + if (all_of(NonConstants, [=](Value *V) { + return isa<PoisonValue>(V) || + (IsSingleShuffle && ((IsIdentityShuffle && + IsNonPoisoned) || IsUsedInExpr) && isa<UndefValue>(V)); + })) + Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices); + else + Res = ShuffleBuilder.finalize( + E->ReuseShuffleIndices, E->Scalars.size(), + [&](Value *&Vec, SmallVectorImpl<int> &Mask) { + TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false); + Vec = ShuffleBuilder.gather(NonConstants, Vec); + }); + } else if (!allConstant(GatheredScalars)) { + // Gather unique scalars and all constants. + SmallVector<int> ReuseMask(GatheredScalars.size(), PoisonMaskElem); + TryPackScalars(GatheredScalars, ReuseMask, /*IsRootPoison=*/true); + Value *BV = ShuffleBuilder.gather(GatheredScalars); + ShuffleBuilder.add(BV, ReuseMask); + Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices); } else { - ReuseMask.clear(); - copy(VL, Gathered.begin()); + // Gather all constants. + SmallVector<int> Mask(E->Scalars.size(), PoisonMaskElem); + for (auto [I, V] : enumerate(E->Scalars)) { + if (!isa<PoisonValue>(V)) + Mask[I] = I; + } + Value *BV = ShuffleBuilder.gather(E->Scalars); + ShuffleBuilder.add(BV, Mask); + Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices); } - // Gather unique scalars and all constants. 
- Value *Vec = gather(Gathered); - ShuffleBuilder.add(Vec, ReuseMask); - Vec = ShuffleBuilder.finalize(E->ReuseShuffleIndices); + if (NeedFreeze) - Vec = Builder.CreateFreeze(Vec); - return Vec; + Res = ShuffleBuilder.createFreeze(Res); + return Res; +} + +Value *BoUpSLP::createBuildVector(const TreeEntry *E) { + return processBuildVector<ShuffleInstructionBuilder, Value *>(E, Builder, + *this); } Value *BoUpSLP::vectorizeTree(TreeEntry *E) { @@ -9161,10 +10117,17 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return E->VectorizedValue; } + if (E->State == TreeEntry::NeedToGather) { + if (E->getMainOp() && E->Idx == 0) + setInsertPointAfterBundle(E); + Value *Vec = createBuildVector(E); + E->VectorizedValue = Vec; + return Vec; + } + auto FinalShuffle = [&](Value *V, const TreeEntry *E) { ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); - if (E->State != TreeEntry::NeedToGather && - E->getOpcode() == Instruction::Store) { + if (E->getOpcode() == Instruction::Store) { ArrayRef<int> Mask = ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()), E->ReorderIndices.size()); @@ -9175,45 +10138,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return ShuffleBuilder.finalize(E->ReuseShuffleIndices); }; - if (E->State == TreeEntry::NeedToGather) { - if (E->Idx > 0) { - // We are in the middle of a vectorizable chain. We need to gather the - // scalars from the users. - Value *Vec = createBuildVector(E); - E->VectorizedValue = Vec; - return Vec; - } - if (E->getMainOp()) - setInsertPointAfterBundle(E); - SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end()); - // Build a mask out of the reorder indices and reorder scalars per this - // mask. - SmallVector<int> ReorderMask; - inversePermutation(E->ReorderIndices, ReorderMask); - if (!ReorderMask.empty()) - reorderScalars(GatheredScalars, ReorderMask); - Value *Vec; - SmallVector<int> Mask; - SmallVector<const TreeEntry *> Entries; - std::optional<TargetTransformInfo::ShuffleKind> Shuffle = - isGatherShuffledEntry(E, GatheredScalars, Mask, Entries); - if (Shuffle) { - assert((Entries.size() == 1 || Entries.size() == 2) && - "Expected shuffle of 1 or 2 entries."); - Vec = Builder.CreateShuffleVector(Entries.front()->VectorizedValue, - Entries.back()->VectorizedValue, Mask); - if (auto *I = dyn_cast<Instruction>(Vec)) { - GatherShuffleExtractSeq.insert(I); - CSEBlocks.insert(I->getParent()); - } - } else { - Vec = gather(E->Scalars); - } - Vec = FinalShuffle(Vec, E); - E->VectorizedValue = Vec; - return Vec; - } - assert((E->State == TreeEntry::Vectorize || E->State == TreeEntry::ScatterVectorize) && "Unhandled state"); @@ -9248,7 +10172,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // PHINodes may have multiple entries from the same block. We want to // visit every block once. 
- SmallPtrSet<BasicBlock*, 4> VisitedBBs; + SmallPtrSet<BasicBlock *, 4> VisitedBBs; for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { ValueList Operands; @@ -9314,14 +10238,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { SmallVector<int> Mask; if (!E->ReorderIndices.empty()) { inversePermutation(E->ReorderIndices, Mask); - Mask.append(NumElts - NumScalars, UndefMaskElem); + Mask.append(NumElts - NumScalars, PoisonMaskElem); } else { - Mask.assign(NumElts, UndefMaskElem); + Mask.assign(NumElts, PoisonMaskElem); std::iota(Mask.begin(), std::next(Mask.begin(), NumScalars), 0); } // Create InsertVector shuffle if necessary bool IsIdentity = true; - SmallVector<int> PrevMask(NumElts, UndefMaskElem); + SmallVector<int> PrevMask(NumElts, PoisonMaskElem); Mask.swap(PrevMask); for (unsigned I = 0; I < NumScalars; ++I) { Value *Scalar = E->Scalars[PrevMask[I]]; @@ -9337,9 +10261,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } } - SmallVector<int> InsertMask(NumElts, UndefMaskElem); + SmallVector<int> InsertMask(NumElts, PoisonMaskElem); for (unsigned I = 0; I < NumElts; I++) { - if (Mask[I] != UndefMaskElem) + if (Mask[I] != PoisonMaskElem) InsertMask[Offset + I] = I; } SmallBitVector UseMask = @@ -9354,10 +10278,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { isUndefVector<true>(FirstInsert->getOperand(0), UseMask); if (!IsFirstPoison.all()) { for (unsigned I = 0; I < NumElts; I++) { - if (InsertMask[I] == UndefMaskElem && !IsFirstPoison.test(I)) + if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I)) InsertMask[I] = I + NumElts; } - } + } V = Builder.CreateShuffleVector( V, IsFirstPoison.all() ? PoisonValue::get(V->getType()) @@ -9372,8 +10296,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { SmallBitVector IsFirstPoison = isUndefVector<true>(FirstInsert->getOperand(0), UseMask); for (unsigned I = 0; I < NumElts; I++) { - if (InsertMask[I] == UndefMaskElem) - InsertMask[I] = IsFirstPoison.test(I) ? UndefMaskElem : I; + if (InsertMask[I] == PoisonMaskElem) + InsertMask[I] = IsFirstPoison.test(I) ? PoisonMaskElem : I; else InsertMask[I] += NumElts; } @@ -9544,20 +10468,17 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { LoadInst *LI = cast<LoadInst>(VL0); Instruction *NewLI; - unsigned AS = LI->getPointerAddressSpace(); Value *PO = LI->getPointerOperand(); if (E->State == TreeEntry::Vectorize) { - Value *VecPtr = Builder.CreateBitCast(PO, VecTy->getPointerTo(AS)); - NewLI = Builder.CreateAlignedLoad(VecTy, VecPtr, LI->getAlign()); + NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign()); - // The pointer operand uses an in-tree scalar so we add the new BitCast - // or LoadInst to ExternalUses list to make sure that an extract will + // The pointer operand uses an in-tree scalar so we add the new + // LoadInst to ExternalUses list to make sure that an extract will // be generated in the future. if (TreeEntry *Entry = getTreeEntry(PO)) { // Find which lane we need to extract. unsigned FoundLane = Entry->findLaneForValue(PO); - ExternalUses.emplace_back( - PO, PO != VecPtr ? 
cast<User>(VecPtr) : NewLI, FoundLane); + ExternalUses.emplace_back(PO, NewLI, FoundLane); } } else { assert(E->State == TreeEntry::ScatterVectorize && "Unhandled state"); @@ -9653,7 +10574,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { CallInst *CI = cast<CallInst>(VL0); setInsertPointAfterBundle(E); - Intrinsic::ID IID = Intrinsic::not_intrinsic; + Intrinsic::ID IID = Intrinsic::not_intrinsic; if (Function *FI = CI->getCalledFunction()) IID = FI->getIntrinsicID(); @@ -9665,8 +10586,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *ScalarArg = nullptr; std::vector<Value *> OpVecs; - SmallVector<Type *, 2> TysForDecl = - {FixedVectorType::get(CI->getType(), E->Scalars.size())}; + SmallVector<Type *, 2> TysForDecl; + // Add return type if intrinsic is overloaded on it. + if (isVectorIntrinsicWithOverloadTypeAtArg(IID, -1)) + TysForDecl.push_back( + FixedVectorType::get(CI->getType(), E->Scalars.size())); for (int j = 0, e = CI->arg_size(); j < e; ++j) { ValueList OpVL; // Some intrinsics have scalar arguments. This argument should not be @@ -9808,14 +10732,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } default: - llvm_unreachable("unknown inst"); + llvm_unreachable("unknown inst"); } return nullptr; } Value *BoUpSLP::vectorizeTree() { ExtraValueToDebugLocsMap ExternallyUsedValues; - return vectorizeTree(ExternallyUsedValues); + SmallVector<std::pair<Value *, Value *>> ReplacedExternals; + return vectorizeTree(ExternallyUsedValues, ReplacedExternals); } namespace { @@ -9829,28 +10754,51 @@ struct ShuffledInsertData { }; } // namespace -Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, - Instruction *ReductionRoot) { +Value *BoUpSLP::vectorizeTree( + const ExtraValueToDebugLocsMap &ExternallyUsedValues, + SmallVectorImpl<std::pair<Value *, Value *>> &ReplacedExternals, + Instruction *ReductionRoot) { // All blocks must be scheduled before any instructions are inserted. for (auto &BSIter : BlocksSchedules) { scheduleBlock(BSIter.second.get()); } - - // Pre-gather last instructions. - for (const std::unique_ptr<TreeEntry> &E : VectorizableTree) { - if ((E->State == TreeEntry::NeedToGather && - (!E->getMainOp() || E->Idx > 0)) || - (E->State != TreeEntry::NeedToGather && - E->getOpcode() == Instruction::ExtractValue) || - E->getOpcode() == Instruction::InsertElement) - continue; - Instruction *LastInst = &getLastInstructionInBundle(E.get()); - EntryToLastInstruction.try_emplace(E.get(), LastInst); - } + // Clean Entry-to-LastInstruction table. It can be affected after scheduling, + // need to rebuild it. + EntryToLastInstruction.clear(); Builder.SetInsertPoint(ReductionRoot ? ReductionRoot : &F->getEntryBlock().front()); auto *VectorRoot = vectorizeTree(VectorizableTree[0].get()); + // Run through the list of postponed gathers and emit them, replacing the temp + // emitted allocas with actual vector instructions. + ArrayRef<const TreeEntry *> PostponedNodes = PostponedGathers.getArrayRef(); + DenseMap<Value *, SmallVector<TreeEntry *>> PostponedValues; + for (const TreeEntry *E : PostponedNodes) { + auto *TE = const_cast<TreeEntry *>(E); + if (auto *VecTE = getTreeEntry(TE->Scalars.front())) + if (VecTE->isSame(TE->UserTreeIndices.front().UserTE->getOperand( + TE->UserTreeIndices.front().EdgeIdx))) + // Found gather node which is absolutely the same as one of the + // vectorized nodes. It may happen after reordering. 
+ continue; + auto *PrevVec = cast<Instruction>(TE->VectorizedValue); + TE->VectorizedValue = nullptr; + auto *UserI = + cast<Instruction>(TE->UserTreeIndices.front().UserTE->VectorizedValue); + Builder.SetInsertPoint(PrevVec); + Builder.SetCurrentDebugLocation(UserI->getDebugLoc()); + Value *Vec = vectorizeTree(TE); + PrevVec->replaceAllUsesWith(Vec); + PostponedValues.try_emplace(Vec).first->second.push_back(TE); + // Replace the stub vector node, if it was used before for one of the + // buildvector nodes already. + auto It = PostponedValues.find(PrevVec); + if (It != PostponedValues.end()) { + for (TreeEntry *VTE : It->getSecond()) + VTE->VectorizedValue = Vec; + } + eraseInstruction(PrevVec); + } // If the vectorized tree can be rewritten in a smaller type, we truncate the // vectorized root. InstCombine will then rewrite the entire expression. We @@ -9968,14 +10916,9 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, Builder.SetInsertPoint(&F->getEntryBlock().front()); } Value *NewInst = ExtractAndExtendIfNeeded(Vec); - auto &NewInstLocs = ExternallyUsedValues[NewInst]; - auto It = ExternallyUsedValues.find(Scalar); - assert(It != ExternallyUsedValues.end() && - "Externally used scalar is not found in ExternallyUsedValues"); - NewInstLocs.append(It->second); - ExternallyUsedValues.erase(Scalar); // Required to update internally referenced instructions. Scalar->replaceAllUsesWith(NewInst); + ReplacedExternals.emplace_back(Scalar, NewInst); continue; } @@ -10004,7 +10947,7 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, ShuffledInserts.size() - 1); SmallVectorImpl<int> &Mask = It->ValueMasks[Vec]; if (Mask.empty()) - Mask.assign(FTy->getNumElements(), UndefMaskElem); + Mask.assign(FTy->getNumElements(), PoisonMaskElem); // Find the insertvector, vectorized in tree, if any. 
Value *Base = VU; while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) { @@ -10017,7 +10960,7 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, do { IEBase = cast<InsertElementInst>(Base); int IEIdx = *getInsertIndex(IEBase); - assert(Mask[Idx] == UndefMaskElem && + assert(Mask[Idx] == PoisonMaskElem && "InsertElementInstruction used already."); Mask[IEIdx] = IEIdx; Base = IEBase->getOperand(0); @@ -10035,7 +10978,7 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, } SmallVectorImpl<int> &Mask = It->ValueMasks[Vec]; if (Mask.empty()) - Mask.assign(FTy->getNumElements(), UndefMaskElem); + Mask.assign(FTy->getNumElements(), PoisonMaskElem); Mask[Idx] = ExternalUse.Lane; It->InsertElements.push_back(cast<InsertElementInst>(User)); continue; @@ -10077,8 +11020,8 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, } auto CreateShuffle = [&](Value *V1, Value *V2, ArrayRef<int> Mask) { - SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem); - SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem); + SmallVector<int> CombinedMask1(Mask.size(), PoisonMaskElem); + SmallVector<int> CombinedMask2(Mask.size(), PoisonMaskElem); int VF = cast<FixedVectorType>(V1->getType())->getNumElements(); for (int I = 0, E = Mask.size(); I < E; ++I) { if (Mask[I] < VF) @@ -10103,9 +11046,9 @@ Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, return std::make_pair(Vec, true); } if (!ForSingleMask) { - SmallVector<int> ResizeMask(VF, UndefMaskElem); + SmallVector<int> ResizeMask(VF, PoisonMaskElem); for (unsigned I = 0; I < VF; ++I) { - if (Mask[I] != UndefMaskElem) + if (Mask[I] != PoisonMaskElem) ResizeMask[Mask[I]] = Mask[I]; } Vec = CreateShuffle(Vec, nullptr, ResizeMask); @@ -10308,14 +11251,14 @@ void BoUpSLP::optimizeGatherSequence() { // registers. unsigned LastUndefsCnt = 0; for (int I = 0, E = NewMask.size(); I < E; ++I) { - if (SM1[I] == UndefMaskElem) + if (SM1[I] == PoisonMaskElem) ++LastUndefsCnt; else LastUndefsCnt = 0; - if (NewMask[I] != UndefMaskElem && SM1[I] != UndefMaskElem && + if (NewMask[I] != PoisonMaskElem && SM1[I] != PoisonMaskElem && NewMask[I] != SM1[I]) return false; - if (NewMask[I] == UndefMaskElem) + if (NewMask[I] == PoisonMaskElem) NewMask[I] = SM1[I]; } // Check if the last undefs actually change the final number of used vector @@ -10590,11 +11533,20 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, } // Search up and down at the same time, because we don't know if the new // instruction is above or below the existing scheduling region. + // Ignore debug info (and other "AssumeLike" intrinsics) so that's not counted + // against the budget. Otherwise debug info could affect codegen. 
BasicBlock::reverse_iterator UpIter = ++ScheduleStart->getIterator().getReverse(); BasicBlock::reverse_iterator UpperEnd = BB->rend(); BasicBlock::iterator DownIter = ScheduleEnd->getIterator(); BasicBlock::iterator LowerEnd = BB->end(); + auto IsAssumeLikeIntr = [](const Instruction &I) { + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + return II->isAssumeLikeIntrinsic(); + return false; + }; + UpIter = std::find_if_not(UpIter, UpperEnd, IsAssumeLikeIntr); + DownIter = std::find_if_not(DownIter, LowerEnd, IsAssumeLikeIntr); while (UpIter != UpperEnd && DownIter != LowerEnd && &*UpIter != I && &*DownIter != I) { if (++ScheduleRegionSize > ScheduleRegionSizeLimit) { @@ -10604,6 +11556,9 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, ++UpIter; ++DownIter; + + UpIter = std::find_if_not(UpIter, UpperEnd, IsAssumeLikeIntr); + DownIter = std::find_if_not(DownIter, LowerEnd, IsAssumeLikeIntr); } if (DownIter == LowerEnd || (UpIter != UpperEnd && &*UpIter == I)) { assert(I->getParent() == ScheduleStart->getParent() && @@ -10804,7 +11759,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, unsigned numAliased = 0; unsigned DistToSrc = 1; - for ( ; DepDest; DepDest = DepDest->NextLoadStore) { + for (; DepDest; DepDest = DepDest->NextLoadStore) { assert(isInSchedulingRegion(DepDest)); // We have two limits to reduce the complexity: @@ -11163,8 +12118,8 @@ void BoUpSLP::computeMinimumValueSizes() { // we can truncate the roots to this narrower type. for (auto *Root : TreeRoot) { auto Mask = DB->getDemandedBits(cast<Instruction>(Root)); - MaxBitWidth = std::max<unsigned>( - Mask.getBitWidth() - Mask.countLeadingZeros(), MaxBitWidth); + MaxBitWidth = std::max<unsigned>(Mask.getBitWidth() - Mask.countl_zero(), + MaxBitWidth); } // True if the roots can be zero-extended back to their original type, rather @@ -11223,8 +12178,7 @@ void BoUpSLP::computeMinimumValueSizes() { } // Round MaxBitWidth up to the next power-of-two. - if (!isPowerOf2_64(MaxBitWidth)) - MaxBitWidth = NextPowerOf2(MaxBitWidth); + MaxBitWidth = llvm::bit_ceil(MaxBitWidth); // If the maximum bit width we compute is less than the with of the roots' // type, we can proceed with the narrowing. Otherwise, do nothing. @@ -11242,60 +12196,6 @@ void BoUpSLP::computeMinimumValueSizes() { MinBWs[Scalar] = std::make_pair(MaxBitWidth, !IsKnownPositive); } -namespace { - -/// The SLPVectorizer Pass. -struct SLPVectorizer : public FunctionPass { - SLPVectorizerPass Impl; - - /// Pass identification, replacement for typeid - static char ID; - - explicit SLPVectorizer() : FunctionPass(ID) { - initializeSLPVectorizerPass(*PassRegistry::getPassRegistry()); - } - - bool doInitialization(Module &M) override { return false; } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - auto *TLI = TLIP ? 
&TLIP->getTLI(F) : nullptr; - auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); - auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - - return Impl.runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, ORE); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - FunctionPass::getAnalysisUsage(AU); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<DemandedBitsWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addRequired<InjectTLIMappingsLegacy>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.setPreservesCFG(); - } -}; - -} // end anonymous namespace - PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) { auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); auto *TTI = &AM.getResult<TargetIRAnalysis>(F); @@ -11536,7 +12436,7 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, unsigned MaxVecRegSize = R.getMaxVecRegSize(); unsigned EltSize = R.getVectorElementSize(Operands[0]); - unsigned MaxElts = llvm::PowerOf2Floor(MaxVecRegSize / EltSize); + unsigned MaxElts = llvm::bit_floor(MaxVecRegSize / EltSize); unsigned MaxVF = std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts); @@ -11618,17 +12518,8 @@ void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) { } } -bool SLPVectorizerPass::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) { - if (!A || !B) - return false; - if (isa<InsertElementInst>(A) || isa<InsertElementInst>(B)) - return false; - Value *VL[] = {A, B}; - return tryToVectorizeList(VL, R); -} - bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, - bool LimitForRegisterSize) { + bool MaxVFOnly) { if (VL.size() < 2) return false; @@ -11663,7 +12554,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, unsigned Sz = R.getVectorElementSize(I0); unsigned MinVF = R.getMinVF(Sz); - unsigned MaxVF = std::max<unsigned>(PowerOf2Floor(VL.size()), MinVF); + unsigned MaxVF = std::max<unsigned>(llvm::bit_floor(VL.size()), MinVF); MaxVF = std::min(R.getMaximumVF(Sz, S.getOpcode()), MaxVF); if (MaxVF < 2) { R.getORE()->emit([&]() { @@ -11690,21 +12581,17 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, if (TTI->getNumberOfParts(VecTy) == VF) continue; for (unsigned I = NextInst; I < MaxInst; ++I) { - unsigned OpsWidth = 0; + unsigned ActualVF = std::min(MaxInst - I, VF); - if (I + VF > MaxInst) - OpsWidth = MaxInst - I; - else - OpsWidth = VF; - - if (!isPowerOf2_32(OpsWidth)) + if (!isPowerOf2_32(ActualVF)) continue; - if ((LimitForRegisterSize && OpsWidth < MaxVF) || - (VF > MinVF && OpsWidth <= VF / 2) || (VF == MinVF && OpsWidth < 2)) + if (MaxVFOnly && ActualVF < MaxVF) + break; + if ((VF > MinVF && ActualVF <= VF / 2) || (VF == MinVF && ActualVF < 2)) break; - ArrayRef<Value *> Ops = VL.slice(I, OpsWidth); + ArrayRef<Value 
*> Ops = VL.slice(I, ActualVF); // Check that a previous iteration of this loop did not delete the Value. if (llvm::any_of(Ops, [&R](Value *V) { auto *I = dyn_cast<Instruction>(V); @@ -11712,7 +12599,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, })) continue; - LLVM_DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations " + LLVM_DEBUG(dbgs() << "SLP: Analyzing " << ActualVF << " operations " << "\n"); R.buildTree(Ops); @@ -11730,7 +12617,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, MinCost = std::min(MinCost, Cost); LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost - << " for VF=" << OpsWidth << "\n"); + << " for VF=" << ActualVF << "\n"); if (Cost < -SLPCostThreshold) { LLVM_DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n"); R.getORE()->emit(OptimizationRemark(SV_NAME, "VectorizedList", @@ -11806,14 +12693,14 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { } if (Candidates.size() == 1) - return tryToVectorizePair(Op0, Op1, R); + return tryToVectorizeList({Op0, Op1}, R); // We have multiple options. Try to pick the single best. std::optional<int> BestCandidate = R.findBestRootPair(Candidates); if (!BestCandidate) return false; - return tryToVectorizePair(Candidates[*BestCandidate].first, - Candidates[*BestCandidate].second, R); + return tryToVectorizeList( + {Candidates[*BestCandidate].first, Candidates[*BestCandidate].second}, R); } namespace { @@ -11857,6 +12744,9 @@ class HorizontalReduction { WeakTrackingVH ReductionRoot; /// The type of reduction operation. RecurKind RdxKind; + /// Checks if the optimization of original scalar identity operations on + /// matched horizontal reductions is enabled and allowed. + bool IsSupportedHorRdxIdentityOp = false; static bool isCmpSelMinMax(Instruction *I) { return match(I, m_Select(m_Cmp(), m_Value(), m_Value())) && @@ -11888,6 +12778,9 @@ class HorizontalReduction { return I->getFastMathFlags().noNaNs(); } + if (Kind == RecurKind::FMaximum || Kind == RecurKind::FMinimum) + return true; + return I->isAssociative(); } @@ -11905,6 +12798,7 @@ class HorizontalReduction { static Value *createOp(IRBuilder<> &Builder, RecurKind Kind, Value *LHS, Value *RHS, const Twine &Name, bool UseSelect) { unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); + bool IsConstant = isConstant(LHS) && isConstant(RHS); switch (Kind) { case RecurKind::Or: if (UseSelect && @@ -11926,29 +12820,49 @@ class HorizontalReduction { return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, Name); case RecurKind::FMax: + if (IsConstant) + return ConstantFP::get(LHS->getType(), + maxnum(cast<ConstantFP>(LHS)->getValueAPF(), + cast<ConstantFP>(RHS)->getValueAPF())); return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS); case RecurKind::FMin: + if (IsConstant) + return ConstantFP::get(LHS->getType(), + minnum(cast<ConstantFP>(LHS)->getValueAPF(), + cast<ConstantFP>(RHS)->getValueAPF())); return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS); + case RecurKind::FMaximum: + if (IsConstant) + return ConstantFP::get(LHS->getType(), + maximum(cast<ConstantFP>(LHS)->getValueAPF(), + cast<ConstantFP>(RHS)->getValueAPF())); + return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS); + case RecurKind::FMinimum: + if (IsConstant) + return ConstantFP::get(LHS->getType(), + minimum(cast<ConstantFP>(LHS)->getValueAPF(), + cast<ConstantFP>(RHS)->getValueAPF())); + return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, 
RHS); case RecurKind::SMax: - if (UseSelect) { + if (IsConstant || UseSelect) { Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } return Builder.CreateBinaryIntrinsic(Intrinsic::smax, LHS, RHS); case RecurKind::SMin: - if (UseSelect) { + if (IsConstant || UseSelect) { Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } return Builder.CreateBinaryIntrinsic(Intrinsic::smin, LHS, RHS); case RecurKind::UMax: - if (UseSelect) { + if (IsConstant || UseSelect) { Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } return Builder.CreateBinaryIntrinsic(Intrinsic::umax, LHS, RHS); case RecurKind::UMin: - if (UseSelect) { + if (IsConstant || UseSelect) { Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } @@ -11984,6 +12898,7 @@ class HorizontalReduction { return Op; } +public: static RecurKind getRdxKind(Value *V) { auto *I = dyn_cast<Instruction>(V); if (!I) @@ -12010,6 +12925,10 @@ class HorizontalReduction { if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value()))) return RecurKind::FMin; + if (match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(), m_Value()))) + return RecurKind::FMaximum; + if (match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(), m_Value()))) + return RecurKind::FMinimum; // This matches either cmp+select or intrinsics. SLP is expected to handle // either form. // TODO: If we are canonicalizing to intrinsics, we can remove several @@ -12086,6 +13005,7 @@ class HorizontalReduction { return isCmpSelMinMax(I) ? 1 : 0; } +private: /// Total number of operands in the reduction operation. static unsigned getNumberOfOperands(Instruction *I) { return isCmpSelMinMax(I) ? 3 : 2; @@ -12134,17 +13054,6 @@ class HorizontalReduction { } } - static Value *getLHS(RecurKind Kind, Instruction *I) { - if (Kind == RecurKind::None) - return nullptr; - return I->getOperand(getFirstOperandIndex(I)); - } - static Value *getRHS(RecurKind Kind, Instruction *I) { - if (Kind == RecurKind::None) - return nullptr; - return I->getOperand(getFirstOperandIndex(I) + 1); - } - static bool isGoodForReduction(ArrayRef<Value *> Data) { int Sz = Data.size(); auto *I = dyn_cast<Instruction>(Data.front()); @@ -12156,65 +13065,39 @@ public: HorizontalReduction() = default; /// Try to find a reduction tree. - bool matchAssociativeReduction(PHINode *Phi, Instruction *Inst, + bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root, ScalarEvolution &SE, const DataLayout &DL, const TargetLibraryInfo &TLI) { - assert((!Phi || is_contained(Phi->operands(), Inst)) && - "Phi needs to use the binary operator"); - assert((isa<BinaryOperator>(Inst) || isa<SelectInst>(Inst) || - isa<IntrinsicInst>(Inst)) && - "Expected binop, select, or intrinsic for reduction matching"); - RdxKind = getRdxKind(Inst); - - // We could have a initial reductions that is not an add. - // r *= v1 + v2 + v3 + v4 - // In such a case start looking for a tree rooted in the first '+'. 
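A note on the createOp changes above: the new FMaximum/FMinimum reduction kinds fold to llvm.maximum/llvm.minimum rather than llvm.maxnum/llvm.minnum, and the difference matters for the constant-folding branches. maxnum/minnum return the non-NaN operand when exactly one input is NaN, while maximum/minimum propagate NaN (and order -0.0 before +0.0). A small standalone sketch of that NaN behaviour, using plain C++ stand-ins instead of the APFloat helpers:

#include <cmath>
#include <cstdio>
#include <limits>

// maxnum-style (RecurKind::FMax): if exactly one operand is NaN, return the
// other operand, like libm fmax.
static double maxnum_like(double A, double B) {
  if (std::isnan(A)) return B;
  if (std::isnan(B)) return A;
  return A > B ? A : B;
}

// maximum-style (RecurKind::FMaximum): NaN is propagated if either operand is
// NaN. (Signed-zero ordering is omitted from this sketch.)
static double maximum_like(double A, double B) {
  if (std::isnan(A) || std::isnan(B))
    return std::numeric_limits<double>::quiet_NaN();
  return A > B ? A : B;
}

int main() {
  double NaN = std::numeric_limits<double>::quiet_NaN();
  std::printf("maxnum(NaN, 1.0)  = %f\n", maxnum_like(NaN, 1.0));  // 1.000000
  std::printf("maximum(NaN, 1.0) = %f\n", maximum_like(NaN, 1.0)); // nan
}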
- if (Phi) { - if (getLHS(RdxKind, Inst) == Phi) { - Phi = nullptr; - Inst = dyn_cast<Instruction>(getRHS(RdxKind, Inst)); - if (!Inst) - return false; - RdxKind = getRdxKind(Inst); - } else if (getRHS(RdxKind, Inst) == Phi) { - Phi = nullptr; - Inst = dyn_cast<Instruction>(getLHS(RdxKind, Inst)); - if (!Inst) - return false; - RdxKind = getRdxKind(Inst); - } - } - - if (!isVectorizable(RdxKind, Inst)) + RdxKind = HorizontalReduction::getRdxKind(Root); + if (!isVectorizable(RdxKind, Root)) return false; // Analyze "regular" integer/FP types for reductions - no target-specific // types or pointers. - Type *Ty = Inst->getType(); + Type *Ty = Root->getType(); if (!isValidElementType(Ty) || Ty->isPointerTy()) return false; // Though the ultimate reduction may have multiple uses, its condition must // have only single use. - if (auto *Sel = dyn_cast<SelectInst>(Inst)) + if (auto *Sel = dyn_cast<SelectInst>(Root)) if (!Sel->getCondition()->hasOneUse()) return false; - ReductionRoot = Inst; + ReductionRoot = Root; // Iterate through all the operands of the possible reduction tree and // gather all the reduced values, sorting them by their value id. - BasicBlock *BB = Inst->getParent(); - bool IsCmpSelMinMax = isCmpSelMinMax(Inst); - SmallVector<Instruction *> Worklist(1, Inst); + BasicBlock *BB = Root->getParent(); + bool IsCmpSelMinMax = isCmpSelMinMax(Root); + SmallVector<Instruction *> Worklist(1, Root); // Checks if the operands of the \p TreeN instruction are also reduction // operations or should be treated as reduced values or an extra argument, // which is not part of the reduction. - auto &&CheckOperands = [this, IsCmpSelMinMax, - BB](Instruction *TreeN, - SmallVectorImpl<Value *> &ExtraArgs, - SmallVectorImpl<Value *> &PossibleReducedVals, - SmallVectorImpl<Instruction *> &ReductionOps) { + auto CheckOperands = [&](Instruction *TreeN, + SmallVectorImpl<Value *> &ExtraArgs, + SmallVectorImpl<Value *> &PossibleReducedVals, + SmallVectorImpl<Instruction *> &ReductionOps) { for (int I = getFirstOperandIndex(TreeN), End = getNumberOfOperands(TreeN); I < End; ++I) { @@ -12229,10 +13112,14 @@ public: } // If the edge is not an instruction, or it is different from the main // reduction opcode or has too many uses - possible reduced value. + // Also, do not try to reduce const values, if the operation is not + // foldable. if (!EdgeInst || getRdxKind(EdgeInst) != RdxKind || IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) || !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) || - !isVectorizable(getRdxKind(EdgeInst), EdgeInst)) { + !isVectorizable(RdxKind, EdgeInst) || + (R.isAnalyzedReductionRoot(EdgeInst) && + all_of(EdgeInst->operands(), Constant::classof))) { PossibleReducedVals.push_back(EdgeVal); continue; } @@ -12246,10 +13133,43 @@ public: // instructions (grouping them by the predicate). 
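matchAssociativeReduction now takes the root instruction directly and walks the tree with the CheckOperands lambda above: an operand that is the same kind of vectorizable reduction op (with the required number of uses) is expanded further, and anything else is recorded as a possible reduced value or extra argument. A much-reduced model of that walk over a plain binary tree, with hand-rolled stand-in types rather than the LLVM ones:

#include <cstdio>
#include <vector>

struct Node {
  char Op;                       // '+' for a reduction node, 0 for a leaf
  int Leaf = 0;                  // leaf payload
  Node *L = nullptr, *R = nullptr;
};

// Collect the leaves of a same-opcode reduction tree, expanding only nodes
// whose opcode matches the reduction kind (the role CheckOperands plays).
static std::vector<int> gatherReducedValues(Node *Root, char RdxOp) {
  std::vector<int> ReducedVals;
  std::vector<Node *> Worklist{Root};
  while (!Worklist.empty()) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    Node *Operands[2] = {N->L, N->R};
    for (Node *Operand : Operands) {
      if (Operand->Op == RdxOp)
        Worklist.push_back(Operand);          // still part of the reduction tree
      else
        ReducedVals.push_back(Operand->Leaf); // possible reduced value
    }
  }
  return ReducedVals;
}

int main() {
  Node A{0, 1}, B{0, 2}, C{0, 3}, D{0, 4};
  Node AB{'+', 0, &A, &B}, ABC{'+', 0, &AB, &C}, Root{'+', 0, &ABC, &D};
  for (int V : gatherReducedValues(&Root, '+'))
    std::printf("%d ", V);                    // the four leaf values: 4 3 1 2
  std::printf("\n");
}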
MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>> PossibleReducedVals; - initReductionOps(Inst); + initReductionOps(Root); DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap; SmallSet<size_t, 2> LoadKeyUsed; SmallPtrSet<Value *, 4> DoNotReverseVals; + + auto GenerateLoadsSubkey = [&](size_t Key, LoadInst *LI) { + Value *Ptr = getUnderlyingObject(LI->getPointerOperand()); + if (LoadKeyUsed.contains(Key)) { + auto LIt = LoadsMap.find(Ptr); + if (LIt != LoadsMap.end()) { + for (LoadInst *RLI : LIt->second) { + if (getPointersDiff(RLI->getType(), RLI->getPointerOperand(), + LI->getType(), LI->getPointerOperand(), DL, SE, + /*StrictCheck=*/true)) + return hash_value(RLI->getPointerOperand()); + } + for (LoadInst *RLI : LIt->second) { + if (arePointersCompatible(RLI->getPointerOperand(), + LI->getPointerOperand(), TLI)) { + hash_code SubKey = hash_value(RLI->getPointerOperand()); + DoNotReverseVals.insert(RLI); + return SubKey; + } + } + if (LIt->second.size() > 2) { + hash_code SubKey = + hash_value(LIt->second.back()->getPointerOperand()); + DoNotReverseVals.insert(LIt->second.back()); + return SubKey; + } + } + } + LoadKeyUsed.insert(Key); + LoadsMap.try_emplace(Ptr).first->second.push_back(LI); + return hash_value(LI->getPointerOperand()); + }; + while (!Worklist.empty()) { Instruction *TreeN = Worklist.pop_back_val(); SmallVector<Value *> Args; @@ -12269,41 +13189,8 @@ public: // results. for (Value *V : PossibleRedVals) { size_t Key, Idx; - std::tie(Key, Idx) = generateKeySubkey( - V, &TLI, - [&](size_t Key, LoadInst *LI) { - Value *Ptr = getUnderlyingObject(LI->getPointerOperand()); - if (LoadKeyUsed.contains(Key)) { - auto LIt = LoadsMap.find(Ptr); - if (LIt != LoadsMap.end()) { - for (LoadInst *RLI: LIt->second) { - if (getPointersDiff( - RLI->getType(), RLI->getPointerOperand(), - LI->getType(), LI->getPointerOperand(), DL, SE, - /*StrictCheck=*/true)) - return hash_value(RLI->getPointerOperand()); - } - for (LoadInst *RLI : LIt->second) { - if (arePointersCompatible(RLI->getPointerOperand(), - LI->getPointerOperand(), TLI)) { - hash_code SubKey = hash_value(RLI->getPointerOperand()); - DoNotReverseVals.insert(RLI); - return SubKey; - } - } - if (LIt->second.size() > 2) { - hash_code SubKey = - hash_value(LIt->second.back()->getPointerOperand()); - DoNotReverseVals.insert(LIt->second.back()); - return SubKey; - } - } - } - LoadKeyUsed.insert(Key); - LoadsMap.try_emplace(Ptr).first->second.push_back(LI); - return hash_value(LI->getPointerOperand()); - }, - /*AllowAlternate=*/false); + std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey, + /*AllowAlternate=*/false); ++PossibleReducedVals[Key][Idx] .insert(std::make_pair(V, 0)) .first->second; @@ -12312,40 +13199,8 @@ public: PossibleReductionOps.rend()); } else { size_t Key, Idx; - std::tie(Key, Idx) = generateKeySubkey( - TreeN, &TLI, - [&](size_t Key, LoadInst *LI) { - Value *Ptr = getUnderlyingObject(LI->getPointerOperand()); - if (LoadKeyUsed.contains(Key)) { - auto LIt = LoadsMap.find(Ptr); - if (LIt != LoadsMap.end()) { - for (LoadInst *RLI: LIt->second) { - if (getPointersDiff(RLI->getType(), - RLI->getPointerOperand(), LI->getType(), - LI->getPointerOperand(), DL, SE, - /*StrictCheck=*/true)) - return hash_value(RLI->getPointerOperand()); - } - for (LoadInst *RLI : LIt->second) { - if (arePointersCompatible(RLI->getPointerOperand(), - LI->getPointerOperand(), TLI)) { - hash_code SubKey = hash_value(RLI->getPointerOperand()); - DoNotReverseVals.insert(RLI); - return SubKey; - } - } - if 
(LIt->second.size() > 2) { - hash_code SubKey = hash_value(LIt->second.back()->getPointerOperand()); - DoNotReverseVals.insert(LIt->second.back()); - return SubKey; - } - } - } - LoadKeyUsed.insert(Key); - LoadsMap.try_emplace(Ptr).first->second.push_back(LI); - return hash_value(LI->getPointerOperand()); - }, - /*AllowAlternate=*/false); + std::tie(Key, Idx) = generateKeySubkey(TreeN, &TLI, GenerateLoadsSubkey, + /*AllowAlternate=*/false); ++PossibleReducedVals[Key][Idx] .insert(std::make_pair(TreeN, 0)) .first->second; @@ -12407,14 +13262,18 @@ public: // If there are a sufficient number of reduction values, reduce // to a nearby power-of-2. We can safely generate oversized // vectors and rely on the backend to split them to legal sizes. - size_t NumReducedVals = + unsigned NumReducedVals = std::accumulate(ReducedVals.begin(), ReducedVals.end(), 0, - [](size_t Num, ArrayRef<Value *> Vals) { + [](unsigned Num, ArrayRef<Value *> Vals) -> unsigned { if (!isGoodForReduction(Vals)) return Num; return Num + Vals.size(); }); - if (NumReducedVals < ReductionLimit) { + if (NumReducedVals < ReductionLimit && + (!AllowHorRdxIdenityOptimization || + all_of(ReducedVals, [](ArrayRef<Value *> RedV) { + return RedV.size() < 2 || !allConstant(RedV) || !isSplat(RedV); + }))) { for (ReductionOpsType &RdxOps : ReductionOps) for (Value *RdxOp : RdxOps) V.analyzedReductionRoot(cast<Instruction>(RdxOp)); @@ -12428,6 +13287,7 @@ public: DenseMap<Value *, WeakTrackingVH> TrackedVals( ReducedVals.size() * ReducedVals.front().size() + ExtraArgs.size()); BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues; + SmallVector<std::pair<Value *, Value *>> ReplacedExternals; ExternallyUsedValues.reserve(ExtraArgs.size() + 1); // The same extra argument may be used several times, so log each attempt // to use it. @@ -12448,6 +13308,18 @@ public: return cast<Instruction>(ScalarCond); }; + // Return new VectorizedTree, based on previous value. + auto GetNewVectorizedTree = [&](Value *VectorizedTree, Value *Res) { + if (VectorizedTree) { + // Update the final value in the reduction. + Builder.SetCurrentDebugLocation( + cast<Instruction>(ReductionOps.front().front())->getDebugLoc()); + return createOp(Builder, RdxKind, VectorizedTree, Res, "op.rdx", + ReductionOps); + } + // Initialize the final value in the reduction. + return Res; + }; // The reduction root is used as the insertion point for new instructions, // so set it as externally used to prevent it from being deleted. ExternallyUsedValues[ReductionRoot]; @@ -12459,6 +13331,12 @@ public: continue; IgnoreList.insert(RdxOp); } + // Intersect the fast-math-flags from all reduction operations. + FastMathFlags RdxFMF; + RdxFMF.set(); + for (Value *U : IgnoreList) + if (auto *FPMO = dyn_cast<FPMathOperator>(U)) + RdxFMF &= FPMO->getFastMathFlags(); bool IsCmpSelMinMax = isCmpSelMinMax(cast<Instruction>(ReductionRoot)); // Need to track reduced vals, they may be changed during vectorization of @@ -12519,16 +13397,82 @@ public: } } } + + // Emit code for constant values. 
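Before the constant-value emission that follows, a note on the NumReducedVals computation above: the profitability check is folded into the std::accumulate callback, so only groups that look "good for reduction" contribute their size, and the relaxed early-exit still lets splat/constant groups through. A self-contained equivalent over plain vectors (the predicate here is a made-up stand-in for isGoodForReduction):

#include <numeric>
#include <vector>

static unsigned countReducibleValues(const std::vector<std::vector<int>> &Groups) {
  auto IsGoodForReduction = [](const std::vector<int> &G) {
    return G.size() > 1;   // stand-in for the real profitability check
  };
  return std::accumulate(Groups.begin(), Groups.end(), 0u,
                         [&](unsigned Num, const std::vector<int> &G) -> unsigned {
                           return IsGoodForReduction(G) ? Num + G.size() : Num;
                         });
}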
+ if (AllowHorRdxIdenityOptimization && Candidates.size() > 1 && + allConstant(Candidates)) { + Value *Res = Candidates.front(); + ++VectorizedVals.try_emplace(Candidates.front(), 0).first->getSecond(); + for (Value *VC : ArrayRef(Candidates).drop_front()) { + Res = createOp(Builder, RdxKind, Res, VC, "const.rdx", ReductionOps); + ++VectorizedVals.try_emplace(VC, 0).first->getSecond(); + if (auto *ResI = dyn_cast<Instruction>(Res)) + V.analyzedReductionRoot(ResI); + } + VectorizedTree = GetNewVectorizedTree(VectorizedTree, Res); + continue; + } + unsigned NumReducedVals = Candidates.size(); - if (NumReducedVals < ReductionLimit) + if (NumReducedVals < ReductionLimit && + (NumReducedVals < 2 || !AllowHorRdxIdenityOptimization || + !isSplat(Candidates))) continue; + // Check if we support repeated scalar values processing (optimization of + // original scalar identity operations on matched horizontal reductions). + IsSupportedHorRdxIdentityOp = + AllowHorRdxIdenityOptimization && RdxKind != RecurKind::Mul && + RdxKind != RecurKind::FMul && RdxKind != RecurKind::FMulAdd; + // Gather same values. + MapVector<Value *, unsigned> SameValuesCounter; + if (IsSupportedHorRdxIdentityOp) + for (Value *V : Candidates) + ++SameValuesCounter.insert(std::make_pair(V, 0)).first->second; + // Used to check if the reduced values used same number of times. In this + // case the compiler may produce better code. E.g. if reduced values are + // aabbccdd (8 x values), then the first node of the tree will have a node + // for 4 x abcd + shuffle <4 x abcd>, <0, 0, 1, 1, 2, 2, 3, 3>. + // Plus, the final reduction will be performed on <8 x aabbccdd>. + // Instead compiler may build <4 x abcd> tree immediately, + reduction (4 + // x abcd) * 2. + // Currently it only handles add/fadd/xor. and/or/min/max do not require + // this analysis, other operations may require an extra estimation of + // the profitability. + bool SameScaleFactor = false; + bool OptReusedScalars = IsSupportedHorRdxIdentityOp && + SameValuesCounter.size() != Candidates.size(); + if (OptReusedScalars) { + SameScaleFactor = + (RdxKind == RecurKind::Add || RdxKind == RecurKind::FAdd || + RdxKind == RecurKind::Xor) && + all_of(drop_begin(SameValuesCounter), + [&SameValuesCounter](const std::pair<Value *, unsigned> &P) { + return P.second == SameValuesCounter.front().second; + }); + Candidates.resize(SameValuesCounter.size()); + transform(SameValuesCounter, Candidates.begin(), + [](const auto &P) { return P.first; }); + NumReducedVals = Candidates.size(); + // Have a reduction of the same element. + if (NumReducedVals == 1) { + Value *OrigV = TrackedToOrig.find(Candidates.front())->second; + unsigned Cnt = SameValuesCounter.lookup(OrigV); + Value *RedVal = + emitScaleForReusedOps(Candidates.front(), Builder, Cnt); + VectorizedTree = GetNewVectorizedTree(VectorizedTree, RedVal); + VectorizedVals.try_emplace(OrigV, Cnt); + continue; + } + } + unsigned MaxVecRegSize = V.getMaxVecRegSize(); unsigned EltSize = V.getVectorElementSize(Candidates[0]); - unsigned MaxElts = RegMaxNumber * PowerOf2Floor(MaxVecRegSize / EltSize); + unsigned MaxElts = + RegMaxNumber * llvm::bit_floor(MaxVecRegSize / EltSize); unsigned ReduxWidth = std::min<unsigned>( - PowerOf2Floor(NumReducedVals), std::max(RedValsMaxNumber, MaxElts)); + llvm::bit_floor(NumReducedVals), std::max(RedValsMaxNumber, MaxElts)); unsigned Start = 0; unsigned Pos = Start; // Restarts vectorization attempt with lower vector factor. 
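The ReduxWidth computation just above switches from PowerOf2Floor to llvm::bit_floor, which matches C++20 std::bit_floor: the largest power of two not greater than the input, and 0 for an input of 0, i.e. the same rounding as before under the standard name. A simplified version of the width calculation (the MaxElts value is assumed for the sketch):

#include <algorithm>
#include <bit>
#include <cstdio>

int main() {
  unsigned NumReducedVals = 13;
  unsigned MaxElts = 16;   // assumed register-derived limit for this sketch
  unsigned ReduxWidth = std::min(std::bit_floor(NumReducedVals), MaxElts);
  std::printf("bit_floor(13) = %u, ReduxWidth = %u\n",
              std::bit_floor(NumReducedVals), ReduxWidth);   // 8 and 8
}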
@@ -12551,6 +13495,7 @@ public: ReduxWidth /= 2; return IsAnyRedOpGathered; }; + bool AnyVectorized = false; while (Pos < NumReducedVals - ReduxWidth + 1 && ReduxWidth >= ReductionLimit) { // Dependency in tree of the reduction ops - drop this attempt, try @@ -12603,34 +13548,24 @@ public: LocalExternallyUsedValues[TrackedVals[V]]; }); } - // Number of uses of the candidates in the vector of values. - SmallDenseMap<Value *, unsigned> NumUses(Candidates.size()); - for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) { - Value *V = Candidates[Cnt]; - ++NumUses.try_emplace(V, 0).first->getSecond(); - } - for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) { - Value *V = Candidates[Cnt]; - ++NumUses.try_emplace(V, 0).first->getSecond(); + if (!IsSupportedHorRdxIdentityOp) { + // Number of uses of the candidates in the vector of values. + assert(SameValuesCounter.empty() && + "Reused values counter map is not empty"); + for (unsigned Cnt = 0; Cnt < NumReducedVals; ++Cnt) { + if (Cnt >= Pos && Cnt < Pos + ReduxWidth) + continue; + Value *V = Candidates[Cnt]; + Value *OrigV = TrackedToOrig.find(V)->second; + ++SameValuesCounter[OrigV]; + } } SmallPtrSet<Value *, 4> VLScalars(VL.begin(), VL.end()); // Gather externally used values. SmallPtrSet<Value *, 4> Visited; - for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) { - Value *RdxVal = Candidates[Cnt]; - if (!Visited.insert(RdxVal).second) + for (unsigned Cnt = 0; Cnt < NumReducedVals; ++Cnt) { + if (Cnt >= Pos && Cnt < Pos + ReduxWidth) continue; - // Check if the scalar was vectorized as part of the vectorization - // tree but not the top node. - if (!VLScalars.contains(RdxVal) && V.isVectorized(RdxVal)) { - LocalExternallyUsedValues[RdxVal]; - continue; - } - unsigned NumOps = VectorizedVals.lookup(RdxVal) + NumUses[RdxVal]; - if (NumOps != ReducedValsToOps.find(RdxVal)->second.size()) - LocalExternallyUsedValues[RdxVal]; - } - for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) { Value *RdxVal = Candidates[Cnt]; if (!Visited.insert(RdxVal).second) continue; @@ -12640,42 +13575,34 @@ public: LocalExternallyUsedValues[RdxVal]; continue; } - unsigned NumOps = VectorizedVals.lookup(RdxVal) + NumUses[RdxVal]; - if (NumOps != ReducedValsToOps.find(RdxVal)->second.size()) + Value *OrigV = TrackedToOrig.find(RdxVal)->second; + unsigned NumOps = + VectorizedVals.lookup(RdxVal) + SameValuesCounter[OrigV]; + if (NumOps != ReducedValsToOps.find(OrigV)->second.size()) LocalExternallyUsedValues[RdxVal]; } + // Do not need the list of reused scalars in regular mode anymore. + if (!IsSupportedHorRdxIdentityOp) + SameValuesCounter.clear(); for (Value *RdxVal : VL) if (RequiredExtract.contains(RdxVal)) LocalExternallyUsedValues[RdxVal]; + // Update LocalExternallyUsedValues for the scalar, replaced by + // extractelement instructions. + for (const std::pair<Value *, Value *> &Pair : ReplacedExternals) { + auto It = ExternallyUsedValues.find(Pair.first); + if (It == ExternallyUsedValues.end()) + continue; + LocalExternallyUsedValues[Pair.second].append(It->second); + } V.buildExternalUses(LocalExternallyUsedValues); V.computeMinimumValueSizes(); - // Intersect the fast-math-flags from all reduction operations. - FastMathFlags RdxFMF; - RdxFMF.set(); - for (Value *U : IgnoreList) - if (auto *FPMO = dyn_cast<FPMathOperator>(U)) - RdxFMF &= FPMO->getFastMathFlags(); // Estimate cost. 
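The fast-math-flag block deleted at the end of the hunk above was not dropped but hoisted earlier in tryToReduce, out of the per-slice loop: the reduction-wide flags start fully set and are intersected with the flags of every reduction op in IgnoreList, so the emitted vector reduction only keeps flags common to all scalar operations. The cost estimation itself continues below; the intersection idea, over a plain bitmask:

#include <cstdint>
#include <cstdio>
#include <vector>

enum : uint8_t { NoNaNs = 1 << 0, NoInfs = 1 << 1, Reassoc = 1 << 2, AllFlags = 0x7 };

int main() {
  std::vector<uint8_t> OpFlags = {AllFlags, NoNaNs | Reassoc,
                                  NoNaNs | NoInfs | Reassoc};
  uint8_t RdxFMF = AllFlags;      // corresponds to RdxFMF.set()
  for (uint8_t F : OpFlags)
    RdxFMF &= F;                  // corresponds to RdxFMF &= FPMO->getFastMathFlags()
  std::printf("common flags = 0x%x\n", RdxFMF);   // NoNaNs | Reassoc = 0x5
}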
InstructionCost TreeCost = V.getTreeCost(VL); InstructionCost ReductionCost = - getReductionCost(TTI, VL, ReduxWidth, RdxFMF); - if (V.isVectorizedFirstNode() && isa<LoadInst>(VL.front())) { - Instruction *MainOp = V.getFirstNodeMainOp(); - for (Value *V : VL) { - auto *VI = dyn_cast<LoadInst>(V); - // Add the costs of scalar GEP pointers, to be removed from the - // code. - if (!VI || VI == MainOp) - continue; - auto *Ptr = dyn_cast<GetElementPtrInst>(VI->getPointerOperand()); - if (!Ptr || !Ptr->hasOneUse() || Ptr->hasAllConstantIndices()) - continue; - TreeCost -= TTI->getArithmeticInstrCost( - Instruction::Add, Ptr->getType(), TTI::TCK_RecipThroughput); - } - } + getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF); InstructionCost Cost = TreeCost + ReductionCost; LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n"); if (!Cost.isValid()) @@ -12716,8 +13643,8 @@ public: InsertPt = GetCmpForMinMaxReduction(RdxRootInst); // Vectorize a tree. - Value *VectorizedRoot = - V.vectorizeTree(LocalExternallyUsedValues, InsertPt); + Value *VectorizedRoot = V.vectorizeTree(LocalExternallyUsedValues, + ReplacedExternals, InsertPt); Builder.SetInsertPoint(InsertPt); @@ -12727,29 +13654,48 @@ public: if (isBoolLogicOp(RdxRootInst)) VectorizedRoot = Builder.CreateFreeze(VectorizedRoot); + // Emit code to correctly handle reused reduced values, if required. + if (OptReusedScalars && !SameScaleFactor) { + VectorizedRoot = + emitReusedOps(VectorizedRoot, Builder, V.getRootNodeScalars(), + SameValuesCounter, TrackedToOrig); + } + Value *ReducedSubTree = emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI); - if (!VectorizedTree) { - // Initialize the final value in the reduction. - VectorizedTree = ReducedSubTree; - } else { - // Update the final value in the reduction. - Builder.SetCurrentDebugLocation( - cast<Instruction>(ReductionOps.front().front())->getDebugLoc()); - VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, - ReducedSubTree, "op.rdx", ReductionOps); - } + // Improved analysis for add/fadd/xor reductions with same scale factor + // for all operands of reductions. We can emit scalar ops for them + // instead. + if (OptReusedScalars && SameScaleFactor) + ReducedSubTree = emitScaleForReusedOps( + ReducedSubTree, Builder, SameValuesCounter.front().second); + + VectorizedTree = GetNewVectorizedTree(VectorizedTree, ReducedSubTree); // Count vectorized reduced values to exclude them from final reduction. for (Value *RdxVal : VL) { - ++VectorizedVals.try_emplace(TrackedToOrig.find(RdxVal)->second, 0) - .first->getSecond(); + Value *OrigV = TrackedToOrig.find(RdxVal)->second; + if (IsSupportedHorRdxIdentityOp) { + VectorizedVals.try_emplace(OrigV, SameValuesCounter[RdxVal]); + continue; + } + ++VectorizedVals.try_emplace(OrigV, 0).first->getSecond(); if (!V.isVectorized(RdxVal)) RequiredExtract.insert(RdxVal); } Pos += ReduxWidth; Start = Pos; - ReduxWidth = PowerOf2Floor(NumReducedVals - Pos); + ReduxWidth = llvm::bit_floor(NumReducedVals - Pos); + AnyVectorized = true; + } + if (OptReusedScalars && !AnyVectorized) { + for (const std::pair<Value *, unsigned> &P : SameValuesCounter) { + Value *RedVal = emitScaleForReusedOps(P.first, Builder, P.second); + VectorizedTree = GetNewVectorizedTree(VectorizedTree, RedVal); + Value *OrigV = TrackedToOrig.find(P.first)->second; + VectorizedVals.try_emplace(OrigV, P.second); + } + continue; } } if (VectorizedTree) { @@ -12757,7 +13703,7 @@ public: // possible problem with poison propagation. 
If not possible to reorder // (both operands are originally RHS), emit an extra freeze instruction // for the LHS operand. - //I.e., if we have original code like this: + // I.e., if we have original code like this: // RedOp1 = select i1 ?, i1 LHS, i1 false // RedOp2 = select i1 RHS, i1 ?, i1 false @@ -12892,7 +13838,8 @@ private: /// Calculate the cost of a reduction. InstructionCost getReductionCost(TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals, - unsigned ReduxWidth, FastMathFlags FMF) { + bool IsCmpSelMinMax, unsigned ReduxWidth, + FastMathFlags FMF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Value *FirstReducedVal = ReducedVals.front(); Type *ScalarTy = FirstReducedVal->getType(); @@ -12900,7 +13847,36 @@ private: InstructionCost VectorCost = 0, ScalarCost; // If all of the reduced values are constant, the vector cost is 0, since // the reduction value can be calculated at the compile time. - bool AllConsts = all_of(ReducedVals, isConstant); + bool AllConsts = allConstant(ReducedVals); + auto EvaluateScalarCost = [&](function_ref<InstructionCost()> GenCostFn) { + InstructionCost Cost = 0; + // Scalar cost is repeated for N-1 elements. + int Cnt = ReducedVals.size(); + for (Value *RdxVal : ReducedVals) { + if (Cnt == 1) + break; + --Cnt; + if (RdxVal->hasNUsesOrMore(IsCmpSelMinMax ? 3 : 2)) { + Cost += GenCostFn(); + continue; + } + InstructionCost ScalarCost = 0; + for (User *U : RdxVal->users()) { + auto *RdxOp = cast<Instruction>(U); + if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) { + ScalarCost += TTI->getInstructionCost(RdxOp, CostKind); + continue; + } + ScalarCost = InstructionCost::getInvalid(); + break; + } + if (ScalarCost.isValid()) + Cost += ScalarCost; + else + Cost += GenCostFn(); + } + return Cost; + }; switch (RdxKind) { case RecurKind::Add: case RecurKind::Mul: @@ -12913,52 +13889,32 @@ private: if (!AllConsts) VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind); - ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind); + ScalarCost = EvaluateScalarCost([&]() { + return TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind); + }); break; } case RecurKind::FMax: - case RecurKind::FMin: { - auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); - if (!AllConsts) { - auto *VecCondTy = - cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); - VectorCost = - TTI->getMinMaxReductionCost(VectorTy, VecCondTy, - /*IsUnsigned=*/false, CostKind); - } - CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); - ScalarCost = TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy, - SclCondTy, RdxPred, CostKind) + - TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, - SclCondTy, RdxPred, CostKind); - break; - } + case RecurKind::FMin: + case RecurKind::FMaximum: + case RecurKind::FMinimum: case RecurKind::SMax: case RecurKind::SMin: case RecurKind::UMax: case RecurKind::UMin: { - auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); - if (!AllConsts) { - auto *VecCondTy = - cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); - bool IsUnsigned = - RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin; - VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, - IsUnsigned, CostKind); - } - CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); - ScalarCost = TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy, - SclCondTy, RdxPred, CostKind) + - TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, - SclCondTy, RdxPred, CostKind); + Intrinsic::ID Id = 
getMinMaxReductionIntrinsicOp(RdxKind); + if (!AllConsts) + VectorCost = TTI->getMinMaxReductionCost(Id, VectorTy, FMF, CostKind); + ScalarCost = EvaluateScalarCost([&]() { + IntrinsicCostAttributes ICA(Id, ScalarTy, {ScalarTy, ScalarTy}, FMF); + return TTI->getIntrinsicInstrCost(ICA, CostKind); + }); break; } default: llvm_unreachable("Expected arithmetic or min/max reduction operation"); } - // Scalar cost is repeated for N-1 elements. - ScalarCost *= (ReduxWidth - 1); LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost << " for reduction that starts with " << *FirstReducedVal << " (It is a splitting reduction)\n"); @@ -12977,8 +13933,148 @@ private: ++NumVectorInstructions; return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind); } -}; + /// Emits optimized code for unique scalar value reused \p Cnt times. + Value *emitScaleForReusedOps(Value *VectorizedValue, IRBuilderBase &Builder, + unsigned Cnt) { + assert(IsSupportedHorRdxIdentityOp && + "The optimization of matched scalar identity horizontal reductions " + "must be supported."); + switch (RdxKind) { + case RecurKind::Add: { + // res = mul vv, n + Value *Scale = ConstantInt::get(VectorizedValue->getType(), Cnt); + LLVM_DEBUG(dbgs() << "SLP: Add (to-mul) " << Cnt << "of " + << VectorizedValue << ". (HorRdx)\n"); + return Builder.CreateMul(VectorizedValue, Scale); + } + case RecurKind::Xor: { + // res = n % 2 ? 0 : vv + LLVM_DEBUG(dbgs() << "SLP: Xor " << Cnt << "of " << VectorizedValue + << ". (HorRdx)\n"); + if (Cnt % 2 == 0) + return Constant::getNullValue(VectorizedValue->getType()); + return VectorizedValue; + } + case RecurKind::FAdd: { + // res = fmul v, n + Value *Scale = ConstantFP::get(VectorizedValue->getType(), Cnt); + LLVM_DEBUG(dbgs() << "SLP: FAdd (to-fmul) " << Cnt << "of " + << VectorizedValue << ". (HorRdx)\n"); + return Builder.CreateFMul(VectorizedValue, Scale); + } + case RecurKind::And: + case RecurKind::Or: + case RecurKind::SMax: + case RecurKind::SMin: + case RecurKind::UMax: + case RecurKind::UMin: + case RecurKind::FMax: + case RecurKind::FMin: + case RecurKind::FMaximum: + case RecurKind::FMinimum: + // res = vv + return VectorizedValue; + case RecurKind::Mul: + case RecurKind::FMul: + case RecurKind::FMulAdd: + case RecurKind::SelectICmp: + case RecurKind::SelectFCmp: + case RecurKind::None: + llvm_unreachable("Unexpected reduction kind for repeated scalar."); + } + return nullptr; + } + + /// Emits actual operation for the scalar identity values, found during + /// horizontal reduction analysis. + Value *emitReusedOps(Value *VectorizedValue, IRBuilderBase &Builder, + ArrayRef<Value *> VL, + const MapVector<Value *, unsigned> &SameValuesCounter, + const DenseMap<Value *, Value *> &TrackedToOrig) { + assert(IsSupportedHorRdxIdentityOp && + "The optimization of matched scalar identity horizontal reductions " + "must be supported."); + switch (RdxKind) { + case RecurKind::Add: { + // root = mul prev_root, <1, 1, n, 1> + SmallVector<Constant *> Vals; + for (Value *V : VL) { + unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second); + Vals.push_back(ConstantInt::get(V->getType(), Cnt, /*IsSigned=*/false)); + } + auto *Scale = ConstantVector::get(Vals); + LLVM_DEBUG(dbgs() << "SLP: Add (to-mul) " << Scale << "of " + << VectorizedValue << ". (HorRdx)\n"); + return Builder.CreateMul(VectorizedValue, Scale); + } + case RecurKind::And: + case RecurKind::Or: + // No need for multiple or/and(s). + LLVM_DEBUG(dbgs() << "SLP: And/or of same " << VectorizedValue + << ". 
(HorRdx)\n"); + return VectorizedValue; + case RecurKind::SMax: + case RecurKind::SMin: + case RecurKind::UMax: + case RecurKind::UMin: + case RecurKind::FMax: + case RecurKind::FMin: + case RecurKind::FMaximum: + case RecurKind::FMinimum: + // No need for multiple min/max(s) of the same value. + LLVM_DEBUG(dbgs() << "SLP: Max/min of same " << VectorizedValue + << ". (HorRdx)\n"); + return VectorizedValue; + case RecurKind::Xor: { + // Replace values with even number of repeats with 0, since + // x xor x = 0. + // root = shuffle prev_root, zeroinitalizer, <0, 1, 2, vf, 4, vf, 5, 6, + // 7>, if elements 4th and 6th elements have even number of repeats. + SmallVector<int> Mask( + cast<FixedVectorType>(VectorizedValue->getType())->getNumElements(), + PoisonMaskElem); + std::iota(Mask.begin(), Mask.end(), 0); + bool NeedShuffle = false; + for (unsigned I = 0, VF = VL.size(); I < VF; ++I) { + Value *V = VL[I]; + unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second); + if (Cnt % 2 == 0) { + Mask[I] = VF; + NeedShuffle = true; + } + } + LLVM_DEBUG(dbgs() << "SLP: Xor <"; for (int I + : Mask) dbgs() + << I << " "; + dbgs() << "> of " << VectorizedValue << ". (HorRdx)\n"); + if (NeedShuffle) + VectorizedValue = Builder.CreateShuffleVector( + VectorizedValue, + ConstantVector::getNullValue(VectorizedValue->getType()), Mask); + return VectorizedValue; + } + case RecurKind::FAdd: { + // root = fmul prev_root, <1.0, 1.0, n.0, 1.0> + SmallVector<Constant *> Vals; + for (Value *V : VL) { + unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second); + Vals.push_back(ConstantFP::get(V->getType(), Cnt)); + } + auto *Scale = ConstantVector::get(Vals); + return Builder.CreateFMul(VectorizedValue, Scale); + } + case RecurKind::Mul: + case RecurKind::FMul: + case RecurKind::FMulAdd: + case RecurKind::SelectICmp: + case RecurKind::SelectFCmp: + case RecurKind::None: + llvm_unreachable("Unexpected reduction kind for reused scalars."); + } + return nullptr; + } +}; } // end anonymous namespace static std::optional<unsigned> getAggregateSize(Instruction *InsertInst) { @@ -13075,15 +14171,15 @@ static bool findBuildAggregate(Instruction *LastInsertInst, return false; } -/// Try and get a reduction value from a phi node. +/// Try and get a reduction instruction from a phi node. /// /// Given a phi node \p P in a block \p ParentBB, consider possible reductions /// if they come from either \p ParentBB or a containing loop latch. /// /// \returns A candidate reduction value if possible, or \code nullptr \endcode /// if not possible. -static Value *getReductionValue(const DominatorTree *DT, PHINode *P, - BasicBlock *ParentBB, LoopInfo *LI) { +static Instruction *getReductionInstr(const DominatorTree *DT, PHINode *P, + BasicBlock *ParentBB, LoopInfo *LI) { // There are situations where the reduction value is not dominated by the // reduction phi. Vectorizing such cases has been reported to cause // miscompiles. See PR25787. @@ -13092,13 +14188,13 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P, DT->dominates(P->getParent(), cast<Instruction>(R)->getParent()); }; - Value *Rdx = nullptr; + Instruction *Rdx = nullptr; // Return the incoming value if it comes from the same BB as the phi node. 
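Before continuing with getReductionInstr below, a note on the two emit helpers above: emitScaleForReusedOps and emitReusedOps rest on simple algebraic identities for a reduced value that occurs Cnt times. Under add it contributes Cnt * x (hence the mul, or the per-lane mul by a count vector); under xor an even count cancels to zero and an odd count leaves x (hence the parity check and the zero-lane shuffle); and under and/or/min/max repeats change nothing. A scalar check of those identities:

#include <cassert>
#include <cstdint>

static int64_t addRepeated(int64_t X, unsigned Cnt) {
  int64_t R = 0;
  for (unsigned I = 0; I < Cnt; ++I) R += X;
  return R;
}
static int64_t xorRepeated(int64_t X, unsigned Cnt) {
  int64_t R = 0;
  for (unsigned I = 0; I < Cnt; ++I) R ^= X;
  return R;
}

int main() {
  assert(addRepeated(7, 5) == 7 * 5);   // add  -> mul by the count
  assert(xorRepeated(7, 4) == 0);       // xor, even count -> 0
  assert(xorRepeated(7, 5) == 7);       // xor, odd count  -> x
  // min/max/and/or of a repeated value are just the value itself, which is
  // why those kinds simply return VectorizedValue above.
  return 0;
}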
if (P->getIncomingBlock(0) == ParentBB) { - Rdx = P->getIncomingValue(0); + Rdx = dyn_cast<Instruction>(P->getIncomingValue(0)); } else if (P->getIncomingBlock(1) == ParentBB) { - Rdx = P->getIncomingValue(1); + Rdx = dyn_cast<Instruction>(P->getIncomingValue(1)); } if (Rdx && DominatedReduxValue(Rdx)) @@ -13115,9 +14211,9 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P, // There is a loop latch, return the incoming value if it comes from // that. This reduction pattern occasionally turns up. if (P->getIncomingBlock(0) == BBLatch) { - Rdx = P->getIncomingValue(0); + Rdx = dyn_cast<Instruction>(P->getIncomingValue(0)); } else if (P->getIncomingBlock(1) == BBLatch) { - Rdx = P->getIncomingValue(1); + Rdx = dyn_cast<Instruction>(P->getIncomingValue(1)); } if (Rdx && DominatedReduxValue(Rdx)) @@ -13133,6 +14229,10 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) { return true; if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(V0), m_Value(V1)))) return true; + if (match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(V0), m_Value(V1)))) + return true; + if (match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(V0), m_Value(V1)))) + return true; if (match(I, m_Intrinsic<Intrinsic::smax>(m_Value(V0), m_Value(V1)))) return true; if (match(I, m_Intrinsic<Intrinsic::smin>(m_Value(V0), m_Value(V1)))) @@ -13144,21 +14244,63 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) { return false; } +/// We could have an initial reduction that is not an add. +/// r *= v1 + v2 + v3 + v4 +/// In such a case start looking for a tree rooted in the first '+'. +/// \Returns the new root if found, which may be nullptr if not an instruction. +static Instruction *tryGetSecondaryReductionRoot(PHINode *Phi, + Instruction *Root) { + assert((isa<BinaryOperator>(Root) || isa<SelectInst>(Root) || + isa<IntrinsicInst>(Root)) && + "Expected binop, select, or intrinsic for reduction matching"); + Value *LHS = + Root->getOperand(HorizontalReduction::getFirstOperandIndex(Root)); + Value *RHS = + Root->getOperand(HorizontalReduction::getFirstOperandIndex(Root) + 1); + if (LHS == Phi) + return dyn_cast<Instruction>(RHS); + if (RHS == Phi) + return dyn_cast<Instruction>(LHS); + return nullptr; +} + +/// \p Returns the first operand of \p I that does not match \p Phi. If +/// operand is not an instruction it returns nullptr. +static Instruction *getNonPhiOperand(Instruction *I, PHINode *Phi) { + Value *Op0 = nullptr; + Value *Op1 = nullptr; + if (!matchRdxBop(I, Op0, Op1)) + return nullptr; + return dyn_cast<Instruction>(Op0 == Phi ? Op1 : Op0); +} + +/// \Returns true if \p I is a candidate instruction for reduction vectorization. +static bool isReductionCandidate(Instruction *I) { + bool IsSelect = match(I, m_Select(m_Value(), m_Value(), m_Value())); + Value *B0 = nullptr, *B1 = nullptr; + bool IsBinop = matchRdxBop(I, B0, B1); + return IsBinop || IsSelect; +} + bool SLPVectorizerPass::vectorizeHorReduction( - PHINode *P, Value *V, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI, + PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI, SmallVectorImpl<WeakTrackingVH> &PostponedInsts) { if (!ShouldVectorizeHor) return false; + bool TryOperandsAsNewSeeds = P && isa<BinaryOperator>(Root); - auto *Root = dyn_cast_or_null<Instruction>(V); - if (!Root) + if (Root->getParent() != BB || isa<PHINode>(Root)) return false; - if (!isa<BinaryOperator>(Root)) - P = nullptr; + // If we can find a secondary reduction root, use that instead. 
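tryGetSecondaryReductionRoot recovers the case the old phi-based matching handled inline: the loop-carried operation rooted at the phi is not itself the reduction to vectorize, but one of its operands is. An illustrative scalar source pattern of that shape (not taken from the patch, just the "r *= v1 + v2 + v3 + v4" comment made concrete):

// The multiply is the phi's recurrence, but the SLP seed should be the add
// tree feeding it, i.e. "a tree rooted in the first '+'".
float productOfBlockSums(const float *V, int N) {
  float R = 1.0f;
  for (int I = 0; I + 3 < N; I += 4)
    R *= V[I] + V[I + 1] + V[I + 2] + V[I + 3];
  return R;
}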
+ auto SelectRoot = [&]() { + if (TryOperandsAsNewSeeds && isReductionCandidate(Root) && + HorizontalReduction::getRdxKind(Root) != RecurKind::None) + if (Instruction *NewRoot = tryGetSecondaryReductionRoot(P, Root)) + return NewRoot; + return Root; + }; - if (Root->getParent() != BB || isa<PHINode>(Root)) - return false; // Start analysis starting from Root instruction. If horizontal reduction is // found, try to vectorize it. If it is not a horizontal reduction or // vectorization is not possible or not effective, and currently analyzed @@ -13171,22 +14313,32 @@ bool SLPVectorizerPass::vectorizeHorReduction( // If a horizintal reduction was not matched or vectorized we collect // instructions for possible later attempts for vectorization. std::queue<std::pair<Instruction *, unsigned>> Stack; - Stack.emplace(Root, 0); + Stack.emplace(SelectRoot(), 0); SmallPtrSet<Value *, 8> VisitedInstrs; bool Res = false; - auto &&TryToReduce = [this, TTI, &P, &R](Instruction *Inst, Value *&B0, - Value *&B1) -> Value * { + auto &&TryToReduce = [this, TTI, &R](Instruction *Inst) -> Value * { if (R.isAnalyzedReductionRoot(Inst)) return nullptr; - bool IsBinop = matchRdxBop(Inst, B0, B1); - bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value())); - if (IsBinop || IsSelect) { - HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, Inst, *SE, *DL, *TLI)) - return HorRdx.tryToReduce(R, TTI, *TLI); + if (!isReductionCandidate(Inst)) + return nullptr; + HorizontalReduction HorRdx; + if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI)) + return nullptr; + return HorRdx.tryToReduce(R, TTI, *TLI); + }; + auto TryAppendToPostponedInsts = [&](Instruction *FutureSeed) { + if (TryOperandsAsNewSeeds && FutureSeed == Root) { + FutureSeed = getNonPhiOperand(Root, P); + if (!FutureSeed) + return false; } - return nullptr; + // Do not collect CmpInst or InsertElementInst/InsertValueInst as their + // analysis is done separately. + if (!isa<CmpInst, InsertElementInst, InsertValueInst>(FutureSeed)) + PostponedInsts.push_back(FutureSeed); + return true; }; + while (!Stack.empty()) { Instruction *Inst; unsigned Level; @@ -13197,37 +14349,19 @@ bool SLPVectorizerPass::vectorizeHorReduction( // iteration while stack was populated before that happened. if (R.isDeleted(Inst)) continue; - Value *B0 = nullptr, *B1 = nullptr; - if (Value *V = TryToReduce(Inst, B0, B1)) { + if (Value *VectorizedV = TryToReduce(Inst)) { Res = true; - // Set P to nullptr to avoid re-analysis of phi node in - // matchAssociativeReduction function unless this is the root node. - P = nullptr; - if (auto *I = dyn_cast<Instruction>(V)) { + if (auto *I = dyn_cast<Instruction>(VectorizedV)) { // Try to find another reduction. Stack.emplace(I, Level); continue; } } else { - bool IsBinop = B0 && B1; - if (P && IsBinop) { - Inst = dyn_cast<Instruction>(B0); - if (Inst == P) - Inst = dyn_cast<Instruction>(B1); - if (!Inst) { - // Set P to nullptr to avoid re-analysis of phi node in - // matchAssociativeReduction function unless this is the root node. - P = nullptr; - continue; - } + // We could not vectorize `Inst` so try to use it as a future seed. + if (!TryAppendToPostponedInsts(Inst)) { + assert(Stack.empty() && "Expected empty stack"); + break; } - // Set P to nullptr to avoid re-analysis of phi node in - // matchAssociativeReduction function unless this is the root node. - P = nullptr; - // Do not collect CmpInst or InsertElementInst/InsertValueInst as their - // analysis is done separately. 
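The rewritten loop in this hunk (TryToReduce plus TryAppendToPostponedInsts) separates the two outcomes per candidate root: either a horizontal reduction is matched and its result may seed another attempt, or the instruction is kept as a postponed seed for later list vectorization, except for cmps and inserts, which are analyzed separately. A stripped-down model of that control flow with stand-in types (the real loop also walks operands and bounds the recursion depth):

#include <functional>
#include <vector>

struct FakeInst { int Id; bool IsCmpOrInsert; };

static bool processRoots(std::vector<FakeInst *> Stack,
                         const std::function<bool(FakeInst *)> &TryToReduce,
                         std::vector<FakeInst *> &PostponedSeeds) {
  bool Changed = false;
  while (!Stack.empty()) {
    FakeInst *I = Stack.back();
    Stack.pop_back();
    if (TryToReduce(I)) {          // matched and vectorized a reduction
      Changed = true;
      continue;
    }
    // Not reducible here: remember it as a seed for tryToVectorize later.
    if (!I->IsCmpOrInsert)
      PostponedSeeds.push_back(I);
  }
  return Changed;
}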
- if (!isa<CmpInst, InsertElementInst, InsertValueInst>(Inst)) - PostponedInsts.push_back(Inst); } // Try to vectorize operands. @@ -13246,11 +14380,11 @@ bool SLPVectorizerPass::vectorizeHorReduction( return Res; } -bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V, +bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI) { SmallVector<WeakTrackingVH> PostponedInsts; - bool Res = vectorizeHorReduction(P, V, BB, R, TTI, PostponedInsts); + bool Res = vectorizeHorReduction(P, Root, BB, R, TTI, PostponedInsts); Res |= tryToVectorize(PostponedInsts, R); return Res; } @@ -13297,13 +14431,11 @@ bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI, } template <typename T> -static bool -tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, - function_ref<unsigned(T *)> Limit, - function_ref<bool(T *, T *)> Comparator, - function_ref<bool(T *, T *)> AreCompatible, - function_ref<bool(ArrayRef<T *>, bool)> TryToVectorizeHelper, - bool LimitForRegisterSize) { +static bool tryToVectorizeSequence( + SmallVectorImpl<T *> &Incoming, function_ref<bool(T *, T *)> Comparator, + function_ref<bool(T *, T *)> AreCompatible, + function_ref<bool(ArrayRef<T *>, bool)> TryToVectorizeHelper, + bool MaxVFOnly, BoUpSLP &R) { bool Changed = false; // Sort by type, parent, operands. stable_sort(Incoming, Comparator); @@ -13331,21 +14463,29 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, // same/alternate ops only, this may result in some extra final // vectorization. if (NumElts > 1 && - TryToVectorizeHelper(ArrayRef(IncIt, NumElts), LimitForRegisterSize)) { + TryToVectorizeHelper(ArrayRef(IncIt, NumElts), MaxVFOnly)) { // Success start over because instructions might have been changed. Changed = true; - } else if (NumElts < Limit(*IncIt) && - (Candidates.empty() || - Candidates.front()->getType() == (*IncIt)->getType())) { - Candidates.append(IncIt, std::next(IncIt, NumElts)); + } else { + /// \Returns the minimum number of elements that we will attempt to + /// vectorize. + auto GetMinNumElements = [&R](Value *V) { + unsigned EltSize = R.getVectorElementSize(V); + return std::max(2U, R.getMaxVecRegSize() / EltSize); + }; + if (NumElts < GetMinNumElements(*IncIt) && + (Candidates.empty() || + Candidates.front()->getType() == (*IncIt)->getType())) { + Candidates.append(IncIt, std::next(IncIt, NumElts)); + } } // Final attempt to vectorize instructions with the same types. if (Candidates.size() > 1 && (SameTypeIt == E || (*SameTypeIt)->getType() != (*IncIt)->getType())) { - if (TryToVectorizeHelper(Candidates, /*LimitForRegisterSize=*/false)) { + if (TryToVectorizeHelper(Candidates, /*MaxVFOnly=*/false)) { // Success start over because instructions might have been changed. Changed = true; - } else if (LimitForRegisterSize) { + } else if (MaxVFOnly) { // Try to vectorize using small vectors. for (auto *It = Candidates.begin(), *End = Candidates.end(); It != End;) { @@ -13353,9 +14493,8 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, while (SameTypeIt != End && AreCompatible(*SameTypeIt, *It)) ++SameTypeIt; unsigned NumElts = (SameTypeIt - It); - if (NumElts > 1 && - TryToVectorizeHelper(ArrayRef(It, NumElts), - /*LimitForRegisterSize=*/false)) + if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(It, NumElts), + /*MaxVFOnly=*/false)) Changed = true; It = SameTypeIt; } @@ -13378,11 +14517,12 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, /// of the second cmp instruction. 
template <bool IsCompatibility> static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI, - function_ref<bool(Instruction *)> IsDeleted) { + const DominatorTree &DT) { + assert(isValidElementType(V->getType()) && + isValidElementType(V2->getType()) && + "Expected valid element types only."); auto *CI1 = cast<CmpInst>(V); auto *CI2 = cast<CmpInst>(V2); - if (IsDeleted(CI2) || !isValidElementType(CI2->getType())) - return false; if (CI1->getOperand(0)->getType()->getTypeID() < CI2->getOperand(0)->getType()->getTypeID()) return !IsCompatibility; @@ -13411,31 +14551,102 @@ static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI, return false; if (auto *I1 = dyn_cast<Instruction>(Op1)) if (auto *I2 = dyn_cast<Instruction>(Op2)) { - if (I1->getParent() != I2->getParent()) - return false; + if (IsCompatibility) { + if (I1->getParent() != I2->getParent()) + return false; + } else { + // Try to compare nodes with same parent. + DomTreeNodeBase<BasicBlock> *NodeI1 = DT.getNode(I1->getParent()); + DomTreeNodeBase<BasicBlock> *NodeI2 = DT.getNode(I2->getParent()); + if (!NodeI1) + return NodeI2 != nullptr; + if (!NodeI2) + return false; + assert((NodeI1 == NodeI2) == + (NodeI1->getDFSNumIn() == NodeI2->getDFSNumIn()) && + "Different nodes should have different DFS numbers"); + if (NodeI1 != NodeI2) + return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); + } InstructionsState S = getSameOpcode({I1, I2}, TLI); - if (S.getOpcode()) + if (S.getOpcode() && (IsCompatibility || !S.isAltShuffle())) continue; - return false; + return !IsCompatibility && I1->getOpcode() < I2->getOpcode(); } } return IsCompatibility; } -bool SLPVectorizerPass::vectorizeSimpleInstructions(InstSetVector &Instructions, - BasicBlock *BB, BoUpSLP &R, - bool AtTerminator) { +template <typename ItT> +bool SLPVectorizerPass::vectorizeCmpInsts(iterator_range<ItT> CmpInsts, + BasicBlock *BB, BoUpSLP &R) { + bool Changed = false; + // Try to find reductions first. + for (CmpInst *I : CmpInsts) { + if (R.isDeleted(I)) + continue; + for (Value *Op : I->operands()) + if (auto *RootOp = dyn_cast<Instruction>(Op)) + Changed |= vectorizeRootInstruction(nullptr, RootOp, BB, R, TTI); + } + // Try to vectorize operands as vector bundles. + for (CmpInst *I : CmpInsts) { + if (R.isDeleted(I)) + continue; + Changed |= tryToVectorize(I, R); + } + // Try to vectorize list of compares. + // Sort by type, compare predicate, etc. + auto CompareSorter = [&](Value *V, Value *V2) { + if (V == V2) + return false; + return compareCmp<false>(V, V2, *TLI, *DT); + }; + + auto AreCompatibleCompares = [&](Value *V1, Value *V2) { + if (V1 == V2) + return true; + return compareCmp<true>(V1, V2, *TLI, *DT); + }; + + SmallVector<Value *> Vals; + for (Instruction *V : CmpInsts) + if (!R.isDeleted(V) && isValidElementType(V->getType())) + Vals.push_back(V); + if (Vals.size() <= 1) + return Changed; + Changed |= tryToVectorizeSequence<Value>( + Vals, CompareSorter, AreCompatibleCompares, + [this, &R](ArrayRef<Value *> Candidates, bool MaxVFOnly) { + // Exclude possible reductions from other blocks. 
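Stepping back to the compareCmp changes above before the filter below: the template now distinguishes its two uses. For compatibility checks, operand instructions must share a parent block as before, but for the sorting comparator it orders instructions in different blocks by the dominator-tree DFS-in numbers of their parents (and then by opcode), giving stable_sort a deterministic, dominance-aware order instead of bailing out. A reduced sketch of that comparator over precomputed DFS numbers (stand-in types, much of the real operand comparison omitted):

#include <algorithm>
#include <vector>

struct FakeCmp {
  unsigned BlockDFSIn;   // DFS-in number of the parent block in the dom tree
  unsigned Opcode;       // e.g. icmp vs fcmp
};

static bool orderCmps(const FakeCmp &A, const FakeCmp &B) {
  if (A.BlockDFSIn != B.BlockDFSIn)
    return A.BlockDFSIn < B.BlockDFSIn;   // different blocks: dominance order
  return A.Opcode < B.Opcode;             // same block: fall back to opcode
}

int main() {
  std::vector<FakeCmp> Cmps = {{4, 53}, {2, 52}, {4, 52}};
  std::stable_sort(Cmps.begin(), Cmps.end(), orderCmps);
  // Sorted: {2,52}, {4,52}, {4,53}
}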
+ bool ArePossiblyReducedInOtherBlock = any_of(Candidates, [](Value *V) { + return any_of(V->users(), [V](User *U) { + auto *Select = dyn_cast<SelectInst>(U); + return Select && + Select->getParent() != cast<Instruction>(V)->getParent(); + }); + }); + if (ArePossiblyReducedInOtherBlock) + return false; + return tryToVectorizeList(Candidates, R, MaxVFOnly); + }, + /*MaxVFOnly=*/true, R); + return Changed; +} + +bool SLPVectorizerPass::vectorizeInserts(InstSetVector &Instructions, + BasicBlock *BB, BoUpSLP &R) { + assert(all_of(Instructions, + [](auto *I) { + return isa<InsertElementInst, InsertValueInst>(I); + }) && + "This function only accepts Insert instructions"); bool OpsChanged = false; - SmallVector<Instruction *, 4> PostponedCmps; SmallVector<WeakTrackingVH> PostponedInsts; // pass1 - try to vectorize reductions only for (auto *I : reverse(Instructions)) { if (R.isDeleted(I)) continue; - if (isa<CmpInst>(I)) { - PostponedCmps.push_back(I); - continue; - } OpsChanged |= vectorizeHorReduction(nullptr, I, BB, R, TTI, PostponedInsts); } // pass2 - try to match and vectorize a buildvector sequence. @@ -13451,63 +14662,7 @@ bool SLPVectorizerPass::vectorizeSimpleInstructions(InstSetVector &Instructions, // Now try to vectorize postponed instructions. OpsChanged |= tryToVectorize(PostponedInsts, R); - if (AtTerminator) { - // Try to find reductions first. - for (Instruction *I : PostponedCmps) { - if (R.isDeleted(I)) - continue; - for (Value *Op : I->operands()) - OpsChanged |= vectorizeRootInstruction(nullptr, Op, BB, R, TTI); - } - // Try to vectorize operands as vector bundles. - for (Instruction *I : PostponedCmps) { - if (R.isDeleted(I)) - continue; - OpsChanged |= tryToVectorize(I, R); - } - // Try to vectorize list of compares. - // Sort by type, compare predicate, etc. - auto CompareSorter = [&](Value *V, Value *V2) { - return compareCmp<false>(V, V2, *TLI, - [&R](Instruction *I) { return R.isDeleted(I); }); - }; - - auto AreCompatibleCompares = [&](Value *V1, Value *V2) { - if (V1 == V2) - return true; - return compareCmp<true>(V1, V2, *TLI, - [&R](Instruction *I) { return R.isDeleted(I); }); - }; - auto Limit = [&R](Value *V) { - unsigned EltSize = R.getVectorElementSize(V); - return std::max(2U, R.getMaxVecRegSize() / EltSize); - }; - - SmallVector<Value *> Vals(PostponedCmps.begin(), PostponedCmps.end()); - OpsChanged |= tryToVectorizeSequence<Value>( - Vals, Limit, CompareSorter, AreCompatibleCompares, - [this, &R](ArrayRef<Value *> Candidates, bool LimitForRegisterSize) { - // Exclude possible reductions from other blocks. - bool ArePossiblyReducedInOtherBlock = - any_of(Candidates, [](Value *V) { - return any_of(V->users(), [V](User *U) { - return isa<SelectInst>(U) && - cast<SelectInst>(U)->getParent() != - cast<Instruction>(V)->getParent(); - }); - }); - if (ArePossiblyReducedInOtherBlock) - return false; - return tryToVectorizeList(Candidates, R, LimitForRegisterSize); - }, - /*LimitForRegisterSize=*/true); - Instructions.clear(); - } else { - Instructions.clear(); - // Insert in reverse order since the PostponedCmps vector was filled in - // reverse order. 
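Returning to the ArePossiblyReducedInOtherBlock filter earlier in this hunk: compare candidates are skipped when any user is a select in a different block, since such a cmp is likely the condition of a min/max reduction rooted in that other block and is better left for the pass over that block. The shape of the check, over stand-in types:

#include <algorithm>
#include <vector>

struct FakeUser { bool IsSelect; int BlockId; };
struct FakeCmp  { int BlockId; std::vector<FakeUser> Users; };

static bool possiblyReducedInOtherBlock(const FakeCmp &C) {
  return std::any_of(C.Users.begin(), C.Users.end(), [&](const FakeUser &U) {
    return U.IsSelect && U.BlockId != C.BlockId;
  });
}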
- Instructions.insert(PostponedCmps.rbegin(), PostponedCmps.rend()); - } + Instructions.clear(); return OpsChanged; } @@ -13603,10 +14758,6 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } return true; }; - auto Limit = [&R](Value *V) { - unsigned EltSize = R.getVectorElementSize(V); - return std::max(2U, R.getMaxVecRegSize() / EltSize); - }; bool HaveVectorizedPhiNodes = false; do { @@ -13648,19 +14799,44 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } HaveVectorizedPhiNodes = tryToVectorizeSequence<Value>( - Incoming, Limit, PHICompare, AreCompatiblePHIs, - [this, &R](ArrayRef<Value *> Candidates, bool LimitForRegisterSize) { - return tryToVectorizeList(Candidates, R, LimitForRegisterSize); + Incoming, PHICompare, AreCompatiblePHIs, + [this, &R](ArrayRef<Value *> Candidates, bool MaxVFOnly) { + return tryToVectorizeList(Candidates, R, MaxVFOnly); }, - /*LimitForRegisterSize=*/true); + /*MaxVFOnly=*/true, R); Changed |= HaveVectorizedPhiNodes; VisitedInstrs.insert(Incoming.begin(), Incoming.end()); } while (HaveVectorizedPhiNodes); VisitedInstrs.clear(); - InstSetVector PostProcessInstructions; - SmallDenseSet<Instruction *, 4> KeyNodes; + InstSetVector PostProcessInserts; + SmallSetVector<CmpInst *, 8> PostProcessCmps; + // Vectorizes Inserts in `PostProcessInserts` and if `VecctorizeCmps` is true + // also vectorizes `PostProcessCmps`. + auto VectorizeInsertsAndCmps = [&](bool VectorizeCmps) { + bool Changed = vectorizeInserts(PostProcessInserts, BB, R); + if (VectorizeCmps) { + Changed |= vectorizeCmpInsts(reverse(PostProcessCmps), BB, R); + PostProcessCmps.clear(); + } + PostProcessInserts.clear(); + return Changed; + }; + // Returns true if `I` is in `PostProcessInserts` or `PostProcessCmps`. + auto IsInPostProcessInstrs = [&](Instruction *I) { + if (auto *Cmp = dyn_cast<CmpInst>(I)) + return PostProcessCmps.contains(Cmp); + return isa<InsertElementInst, InsertValueInst>(I) && + PostProcessInserts.contains(I); + }; + // Returns true if `I` is an instruction without users, like terminator, or + // function call with ignored return value, store. Ignore unused instructions + // (basing on instruction type, except for CallInst and InvokeInst). + auto HasNoUsers = [](Instruction *I) { + return I->use_empty() && + (I->getType()->isVoidTy() || isa<CallInst, InvokeInst>(I)); + }; for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { // Skip instructions with scalable type. The num of elements is unknown at // compile-time for scalable type. @@ -13672,9 +14848,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { continue; // We may go through BB multiple times so skip the one we have checked. if (!VisitedInstrs.insert(&*it).second) { - if (it->use_empty() && KeyNodes.contains(&*it) && - vectorizeSimpleInstructions(PostProcessInstructions, BB, R, - it->isTerminator())) { + if (HasNoUsers(&*it) && + VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator())) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. Changed = true; @@ -13692,8 +14867,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // Check that the PHI is a reduction PHI. if (P->getNumIncomingValues() == 2) { // Try to match and vectorize a horizontal reduction. 
- if (vectorizeRootInstruction(P, getReductionValue(DT, P, BB, LI), BB, R, - TTI)) { + Instruction *Root = getReductionInstr(DT, P, BB, LI); + if (Root && vectorizeRootInstruction(P, Root, BB, R, TTI)) { Changed = true; it = BB->begin(); e = BB->end(); @@ -13714,19 +14889,14 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *PI = dyn_cast<Instruction>(P->getIncomingValue(I)); - PI && !PostProcessInstructions.contains(PI)) - Changed |= vectorizeRootInstruction(nullptr, P->getIncomingValue(I), + PI && !IsInPostProcessInstrs(PI)) + Changed |= vectorizeRootInstruction(nullptr, PI, P->getIncomingBlock(I), R, TTI); } continue; } - // Ran into an instruction without users, like terminator, or function call - // with ignored return value, store. Ignore unused instructions (basing on - // instruction type, except for CallInst and InvokeInst). - if (it->use_empty() && - (it->getType()->isVoidTy() || isa<CallInst, InvokeInst>(it))) { - KeyNodes.insert(&*it); + if (HasNoUsers(&*it)) { bool OpsChanged = false; auto *SI = dyn_cast<StoreInst>(it); bool TryToVectorizeRoot = ShouldStartVectorizeHorAtStore || !SI; @@ -13746,16 +14916,16 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *VI = dyn_cast<Instruction>(V); - VI && !PostProcessInstructions.contains(VI)) + VI && !IsInPostProcessInstrs(VI)) // Try to match and vectorize a horizontal reduction. - OpsChanged |= vectorizeRootInstruction(nullptr, V, BB, R, TTI); + OpsChanged |= vectorizeRootInstruction(nullptr, VI, BB, R, TTI); } } // Start vectorization of post-process list of instructions from the // top-tree instructions to try to vectorize as many instructions as // possible. - OpsChanged |= vectorizeSimpleInstructions(PostProcessInstructions, BB, R, - it->isTerminator()); + OpsChanged |= + VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator()); if (OpsChanged) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. @@ -13766,8 +14936,10 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } } - if (isa<CmpInst, InsertElementInst, InsertValueInst>(it)) - PostProcessInstructions.insert(&*it); + if (isa<InsertElementInst, InsertValueInst>(it)) + PostProcessInserts.insert(&*it); + else if (isa<CmpInst>(it)) + PostProcessCmps.insert(cast<CmpInst>(&*it)); } return Changed; @@ -13928,10 +15100,6 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { return V1->getValueOperand()->getValueID() == V2->getValueOperand()->getValueID(); }; - auto Limit = [&R, this](StoreInst *SI) { - unsigned EltSize = DL->getTypeSizeInBits(SI->getValueOperand()->getType()); - return R.getMinVF(EltSize); - }; // Attempt to sort and vectorize each of the store-groups. 
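The loop that follows hands each store group to tryToVectorizeSequence, whose overall shape (after the Limit-callback removal above) is: sort the candidates with a strict weak order, walk maximal runs of mutually compatible elements, and pass each run to a vectorization callback, optionally retrying with smaller vectors. A stripped-down version of that run-grouping loop, with the Candidates fallback and MaxVFOnly retry omitted:

#include <algorithm>
#include <functional>
#include <vector>

template <typename T>
static bool vectorizeRuns(
    std::vector<T> &Incoming,
    const std::function<bool(const T &, const T &)> &Less,
    const std::function<bool(const T &, const T &)> &AreCompatible,
    const std::function<bool(const std::vector<T> &)> &TryToVectorize) {
  bool Changed = false;
  std::stable_sort(Incoming.begin(), Incoming.end(), Less);
  for (size_t I = 0, E = Incoming.size(); I < E;) {
    size_t J = I + 1;
    while (J < E && AreCompatible(Incoming[J], Incoming[I]))
      ++J;                                   // extend the run of compatible elements
    if (J - I > 1)
      Changed |= TryToVectorize(
          std::vector<T>(Incoming.begin() + I, Incoming.begin() + J));
    I = J;                                   // continue with the next run
  }
  return Changed;
}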
for (auto &Pair : Stores) { @@ -13945,28 +15113,11 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { continue; Changed |= tryToVectorizeSequence<StoreInst>( - Pair.second, Limit, StoreSorter, AreCompatibleStores, + Pair.second, StoreSorter, AreCompatibleStores, [this, &R](ArrayRef<StoreInst *> Candidates, bool) { return vectorizeStores(Candidates, R); }, - /*LimitForRegisterSize=*/false); + /*MaxVFOnly=*/false, R); } return Changed; } - -char SLPVectorizer::ID = 0; - -static const char lv_name[] = "SLP Vectorizer"; - -INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_DEPENDENCY(InjectTLIMappingsLegacy) -INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false) - -Pass *llvm::createSLPVectorizerPass() { return new SLPVectorizer(); } diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h index 733d2e1c667b..1271d1424c03 100644 --- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -95,7 +95,7 @@ class VPRecipeBuilder { /// return a new VPWidenCallRecipe. Range.End may be decreased to ensure same /// decision from \p Range.Start to \p Range.End. VPWidenCallRecipe *tryToWidenCall(CallInst *CI, ArrayRef<VPValue *> Operands, - VFRange &Range) const; + VFRange &Range, VPlanPtr &Plan); /// Check if \p I has an opcode that can be widened and return a VPWidenRecipe /// if it can. The function should only be called if the cost-model indicates @@ -136,11 +136,11 @@ public: /// A helper function that computes the predicate of the block BB, assuming /// that the header block of the loop is set to True. It returns the *entry* /// mask for the block BB. - VPValue *createBlockInMask(BasicBlock *BB, VPlanPtr &Plan); + VPValue *createBlockInMask(BasicBlock *BB, VPlan &Plan); /// A helper function that computes the predicate of the edge between SRC /// and DST. - VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlanPtr &Plan); + VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlan &Plan); /// Mark given ingredient for recording its recipe once one is created for /// it. @@ -159,19 +159,11 @@ public: return Ingredient2Recipe[I]; } - /// Create a replicating region for \p PredRecipe. - VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, - VPlanPtr &Plan); - - /// Build a VPReplicationRecipe for \p I and enclose it within a Region if it - /// is predicated. \return \p VPBB augmented with this new recipe if \p I is - /// not predicated, otherwise \return a new VPBasicBlock that succeeds the new - /// Region. Update the packing decision of predicated instructions if they - /// feed \p I. Range.End may be decreased to ensure same recipe behavior from - /// \p Range.Start to \p Range.End. - VPBasicBlock *handleReplication( - Instruction *I, VFRange &Range, VPBasicBlock *VPBB, - VPlanPtr &Plan); + /// Build a VPReplicationRecipe for \p I. If it is predicated, add the mask as + /// last operand. Range.End may be decreased to ensure same recipe behavior + /// from \p Range.Start to \p Range.End. 
+ VPRecipeOrVPValueTy handleReplication(Instruction *I, VFRange &Range, + VPlan &Plan); /// Add the incoming values from the backedge to reduction & first-order /// recurrence cross-iteration phis. diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index d554f438c804..e81b88fd8099 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/BasicBlock.h" @@ -46,7 +47,10 @@ #include <vector> using namespace llvm; + +namespace llvm { extern cl::opt<bool> EnableVPlanNativePath; +} #define DEBUG_TYPE "vplan" @@ -160,8 +164,9 @@ VPBasicBlock *VPBlockBase::getEntryBasicBlock() { } void VPBlockBase::setPlan(VPlan *ParentPlan) { - assert(ParentPlan->getEntry() == this && - "Can only set plan on its entry block."); + assert( + (ParentPlan->getEntry() == this || ParentPlan->getPreheader() == this) && + "Can only set plan on its entry or preheader block."); Plan = ParentPlan; } @@ -209,7 +214,7 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() { } Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) { - if (!Def->hasDefiningRecipe()) + if (Def->isLiveIn()) return Def->getLiveInIRValue(); if (hasScalarValue(Def, Instance)) { @@ -243,11 +248,19 @@ void VPTransformState::addNewMetadata(Instruction *To, } void VPTransformState::addMetadata(Instruction *To, Instruction *From) { + // No source instruction to transfer metadata from? + if (!From) + return; + propagateMetadata(To, From); addNewMetadata(To, From); } void VPTransformState::addMetadata(ArrayRef<Value *> To, Instruction *From) { + // No source instruction to transfer metadata from? + if (!From) + return; + for (Value *V : To) { if (Instruction *I = dyn_cast<Instruction>(V)) addMetadata(I, From); @@ -265,7 +278,7 @@ void VPTransformState::setDebugLocFromInst(const Value *V) { // When a FSDiscriminator is enabled, we don't need to add the multiply // factors to the discriminators. if (DIL && Inst->getFunction()->shouldEmitDebugInfoForProfiling() && - !isa<DbgInfoIntrinsic>(Inst) && !EnableFSDiscriminator) { + !Inst->isDebugOrPseudoInst() && !EnableFSDiscriminator) { // FIXME: For scalable vectors, assume vscale=1. 
auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(UF * VF.getKnownMinValue()); @@ -577,7 +590,9 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent, #endif VPlan::~VPlan() { - clearLiveOuts(); + for (auto &KV : LiveOuts) + delete KV.second; + LiveOuts.clear(); if (Entry) { VPValue DummyValue; @@ -585,15 +600,23 @@ VPlan::~VPlan() { Block->dropAllReferences(&DummyValue); VPBlockBase::deleteCFG(Entry); + + Preheader->dropAllReferences(&DummyValue); + delete Preheader; } - for (VPValue *VPV : VPValuesToFree) + for (VPValue *VPV : VPLiveInsToFree) delete VPV; - if (TripCount) - delete TripCount; if (BackedgeTakenCount) delete BackedgeTakenCount; - for (auto &P : VPExternalDefs) - delete P.second; +} + +VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE) { + VPBasicBlock *Preheader = new VPBasicBlock("ph"); + VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph"); + auto Plan = std::make_unique<VPlan>(Preheader, VecPreheader); + Plan->TripCount = + vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE); + return Plan; } VPActiveLaneMaskPHIRecipe *VPlan::getActiveLaneMaskPhi() { @@ -609,13 +632,6 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, Value *CanonicalIVStartValue, VPTransformState &State, bool IsEpilogueVectorization) { - - // Check if the trip count is needed, and if so build it. - if (TripCount && TripCount->getNumUsers()) { - for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) - State.set(TripCount, TripCountV, Part); - } - // Check if the backedge taken count is needed, and if so build it. if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) { IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); @@ -636,7 +652,7 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, // needs to be changed from zero to the value after the main vector loop. // FIXME: Improve modeling for canonical IV start values in the epilogue loop. 
if (CanonicalIVStartValue) { - VPValue *VPV = getOrAddExternalDef(CanonicalIVStartValue); + VPValue *VPV = getVPValueOrAddLiveIn(CanonicalIVStartValue); auto *IV = getCanonicalIV(); assert(all_of(IV->users(), [](const VPUser *U) { @@ -650,8 +666,7 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, VPInstruction::CanonicalIVIncrementNUW; }) && "the canonical IV should only be used by its increments or " - "ScalarIVSteps when " - "resetting the start value"); + "ScalarIVSteps when resetting the start value"); IV->setOperand(0, VPV); } } @@ -748,13 +763,25 @@ void VPlan::print(raw_ostream &O) const { if (VectorTripCount.getNumUsers() > 0) { O << "\nLive-in "; VectorTripCount.printAsOperand(O, SlotTracker); - O << " = vector-trip-count\n"; + O << " = vector-trip-count"; } if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) { O << "\nLive-in "; BackedgeTakenCount->printAsOperand(O, SlotTracker); - O << " = backedge-taken count\n"; + O << " = backedge-taken count"; + } + + O << "\n"; + if (TripCount->isLiveIn()) + O << "Live-in "; + TripCount->printAsOperand(O, SlotTracker); + O << " = original trip-count"; + O << "\n"; + + if (!getPreheader()->empty()) { + O << "\n"; + getPreheader()->print(O, "", SlotTracker); } for (const VPBlockBase *Block : vp_depth_first_shallow(getEntry())) { @@ -765,11 +792,7 @@ void VPlan::print(raw_ostream &O) const { if (!LiveOuts.empty()) O << "\n"; for (const auto &KV : LiveOuts) { - O << "Live-out "; - KV.second->getPhi()->printAsOperand(O); - O << " = "; - KV.second->getOperand(0)->printAsOperand(O, SlotTracker); - O << "\n"; + KV.second->print(O, SlotTracker); } O << "}\n"; @@ -882,6 +905,8 @@ void VPlanPrinter::dump() { OS << "edge [fontname=Courier, fontsize=30]\n"; OS << "compound=true\n"; + dumpBlock(Plan.getPreheader()); + for (const VPBlockBase *Block : vp_depth_first_shallow(Plan.getEntry())) dumpBlock(Block); @@ -1086,26 +1111,27 @@ VPInterleavedAccessInfo::VPInterleavedAccessInfo(VPlan &Plan, } void VPSlotTracker::assignSlot(const VPValue *V) { - assert(Slots.find(V) == Slots.end() && "VPValue already has a slot!"); + assert(!Slots.contains(V) && "VPValue already has a slot!"); Slots[V] = NextSlot++; } void VPSlotTracker::assignSlots(const VPlan &Plan) { - - for (const auto &P : Plan.VPExternalDefs) - assignSlot(P.second); - assignSlot(&Plan.VectorTripCount); if (Plan.BackedgeTakenCount) assignSlot(Plan.BackedgeTakenCount); + assignSlots(Plan.getPreheader()); ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<const VPBlockBase *>> RPOT(VPBlockDeepTraversalWrapper<const VPBlockBase *>(Plan.getEntry())); for (const VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<const VPBasicBlock>(RPOT)) - for (const VPRecipeBase &Recipe : *VPBB) - for (VPValue *Def : Recipe.definedValues()) - assignSlot(Def); + assignSlots(VPBB); +} + +void VPSlotTracker::assignSlots(const VPBasicBlock *VPBB) { + for (const VPRecipeBase &Recipe : *VPBB) + for (VPValue *Def : Recipe.definedValues()) + assignSlot(Def); } bool vputils::onlyFirstLaneUsed(VPValue *Def) { @@ -1115,13 +1141,17 @@ bool vputils::onlyFirstLaneUsed(VPValue *Def) { VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, ScalarEvolution &SE) { + if (auto *Expanded = Plan.getSCEVExpansion(Expr)) + return Expanded; + VPValue *Expanded = nullptr; if (auto *E = dyn_cast<SCEVConstant>(Expr)) - return Plan.getOrAddExternalDef(E->getValue()); - if (auto *E = dyn_cast<SCEVUnknown>(Expr)) - return Plan.getOrAddExternalDef(E->getValue()); - - VPBasicBlock *Preheader = 
Plan.getEntry()->getEntryBasicBlock(); - VPExpandSCEVRecipe *Step = new VPExpandSCEVRecipe(Expr, SE); - Preheader->appendRecipe(Step); - return Step; + Expanded = Plan.getVPValueOrAddLiveIn(E->getValue()); + else if (auto *E = dyn_cast<SCEVUnknown>(Expr)) + Expanded = Plan.getVPValueOrAddLiveIn(E->getValue()); + else { + Expanded = new VPExpandSCEVRecipe(Expr, SE); + Plan.getPreheader()->appendRecipe(Expanded->getDefiningRecipe()); + } + Plan.addSCEVExpansion(Expr, Expanded); + return Expanded; } diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 986faaf99664..73313465adea 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -25,7 +25,6 @@ #include "VPlanValue.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -33,11 +32,12 @@ #include "llvm/ADT/Twine.h" #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" +#include "llvm/Analysis/IVDescriptors.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/FMF.h" -#include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/IR/Operator.h" #include <algorithm> #include <cassert> #include <cstddef> @@ -47,11 +47,9 @@ namespace llvm { class BasicBlock; class DominatorTree; -class InductionDescriptor; class InnerLoopVectorizer; class IRBuilderBase; class LoopInfo; -class PredicateScalarEvolution; class raw_ostream; class RecurrenceDescriptor; class SCEV; @@ -62,6 +60,7 @@ class VPlan; class VPReplicateRecipe; class VPlanSlp; class Value; +class LoopVersioning; namespace Intrinsic { typedef unsigned ID; @@ -76,16 +75,17 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF); Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF, int64_t Step); -const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE); +const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE, + Loop *CurLoop = nullptr); /// A range of powers-of-2 vectorization factors with fixed start and /// adjustable end. The range includes start and excludes end, e.g.,: -/// [1, 9) = {1, 2, 4, 8} +/// [1, 16) = {1, 2, 4, 8} struct VFRange { // A power of 2. const ElementCount Start; - // Need not be a power of 2. If End <= Start range is empty. + // A power of 2. If End <= Start range is empty. ElementCount End; bool isEmpty() const { @@ -98,6 +98,33 @@ struct VFRange { "Both Start and End should have the same scalable flag"); assert(isPowerOf2_32(Start.getKnownMinValue()) && "Expected Start to be a power of 2"); + assert(isPowerOf2_32(End.getKnownMinValue()) && + "Expected End to be a power of 2"); + } + + /// Iterator to iterate over vectorization factors in a VFRange. 
+ class iterator + : public iterator_facade_base<iterator, std::forward_iterator_tag, + ElementCount> { + ElementCount VF; + + public: + iterator(ElementCount VF) : VF(VF) {} + + bool operator==(const iterator &Other) const { return VF == Other.VF; } + + ElementCount operator*() const { return VF; } + + iterator &operator++() { + VF *= 2; + return *this; + } + }; + + iterator begin() { return iterator(Start); } + iterator end() { + assert(isPowerOf2_32(End.getKnownMinValue())); + return iterator(End); } }; @@ -248,7 +275,7 @@ struct VPTransformState { } bool hasAnyVectorValue(VPValue *Def) const { - return Data.PerPartOutput.find(Def) != Data.PerPartOutput.end(); + return Data.PerPartOutput.contains(Def); } bool hasScalarValue(VPValue *Def, VPIteration Instance) { @@ -370,10 +397,6 @@ struct VPTransformState { /// Pointer to the VPlan code is generated for. VPlan *Plan; - /// Holds recipes that may generate a poison value that is used after - /// vectorization, even when their operands are not poison. - SmallPtrSet<VPRecipeBase *, 16> MayGeneratePoisonRecipes; - /// The loop object for the current parent region, or nullptr. Loop *CurrentVectorLoop = nullptr; @@ -382,7 +405,11 @@ struct VPTransformState { /// /// This is currently only used to add no-alias metadata based on the /// memchecks. The actually versioning is performed manually. - std::unique_ptr<LoopVersioning> LVer; + LoopVersioning *LVer = nullptr; + + /// Map SCEVs to their expanded values. Populated when executing + /// VPExpandSCEVRecipes. + DenseMap<const SCEV *, Value *> ExpandedSCEVs; }; /// VPBlockBase is the building block of the Hierarchical Control-Flow Graph. @@ -639,6 +666,10 @@ public: VPLiveOut(PHINode *Phi, VPValue *Op) : VPUser({Op}, VPUser::VPUserID::LiveOut), Phi(Phi) {} + static inline bool classof(const VPUser *U) { + return U->getVPUserID() == VPUser::VPUserID::LiveOut; + } + /// Fixup the wrapped LCSSA phi node in the unique exit block. This simply /// means we need to add the appropriate incoming value from the middle /// block as exiting edges from the scalar epilogue loop (if present) are @@ -654,6 +685,11 @@ public: } PHINode *getPhi() const { return Phi; } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the VPLiveOut to \p O. + void print(raw_ostream &O, VPSlotTracker &SlotTracker) const; +#endif }; /// VPRecipeBase is a base class modeling a sequence of one or more output IR @@ -790,6 +826,7 @@ public: SLPLoad, SLPStore, ActiveLaneMask, + CalculateTripCountMinusVF, CanonicalIVIncrement, CanonicalIVIncrementNUW, // The next two are similar to the above, but instead increment the @@ -810,8 +847,10 @@ private: const std::string Name; /// Utility method serving execute(): generates a single instance of the - /// modeled instruction. - void generateInstruction(VPTransformState &State, unsigned Part); + /// modeled instruction. \returns the generated value for \p Part. + /// In some cases an existing value is returned rather than a generated + /// one. + Value *generateInstruction(VPTransformState &State, unsigned Part); protected: void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } @@ -892,6 +931,7 @@ public: default: return false; case VPInstruction::ActiveLaneMask: + case VPInstruction::CalculateTripCountMinusVF: case VPInstruction::CanonicalIVIncrement: case VPInstruction::CanonicalIVIncrementNUW: case VPInstruction::CanonicalIVIncrementForPart: @@ -903,14 +943,169 @@ public: } }; +/// Class to record LLVM IR flag for a recipe along with it. 
+class VPRecipeWithIRFlags : public VPRecipeBase { + enum class OperationType : unsigned char { + OverflowingBinOp, + PossiblyExactOp, + GEPOp, + FPMathOp, + Other + }; + struct WrapFlagsTy { + char HasNUW : 1; + char HasNSW : 1; + }; + struct ExactFlagsTy { + char IsExact : 1; + }; + struct GEPFlagsTy { + char IsInBounds : 1; + }; + struct FastMathFlagsTy { + char AllowReassoc : 1; + char NoNaNs : 1; + char NoInfs : 1; + char NoSignedZeros : 1; + char AllowReciprocal : 1; + char AllowContract : 1; + char ApproxFunc : 1; + }; + + OperationType OpType; + + union { + WrapFlagsTy WrapFlags; + ExactFlagsTy ExactFlags; + GEPFlagsTy GEPFlags; + FastMathFlagsTy FMFs; + unsigned char AllFlags; + }; + +public: + template <typename IterT> + VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands) + : VPRecipeBase(SC, Operands) { + OpType = OperationType::Other; + AllFlags = 0; + } + + template <typename IterT> + VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands, + Instruction &I) + : VPRecipeWithIRFlags(SC, Operands) { + if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) { + OpType = OperationType::OverflowingBinOp; + WrapFlags.HasNUW = Op->hasNoUnsignedWrap(); + WrapFlags.HasNSW = Op->hasNoSignedWrap(); + } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) { + OpType = OperationType::PossiblyExactOp; + ExactFlags.IsExact = Op->isExact(); + } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { + OpType = OperationType::GEPOp; + GEPFlags.IsInBounds = GEP->isInBounds(); + } else if (auto *Op = dyn_cast<FPMathOperator>(&I)) { + OpType = OperationType::FPMathOp; + FastMathFlags FMF = Op->getFastMathFlags(); + FMFs.AllowReassoc = FMF.allowReassoc(); + FMFs.NoNaNs = FMF.noNaNs(); + FMFs.NoInfs = FMF.noInfs(); + FMFs.NoSignedZeros = FMF.noSignedZeros(); + FMFs.AllowReciprocal = FMF.allowReciprocal(); + FMFs.AllowContract = FMF.allowContract(); + FMFs.ApproxFunc = FMF.approxFunc(); + } + } + + static inline bool classof(const VPRecipeBase *R) { + return R->getVPDefID() == VPRecipeBase::VPWidenSC || + R->getVPDefID() == VPRecipeBase::VPWidenGEPSC || + R->getVPDefID() == VPRecipeBase::VPReplicateSC; + } + + /// Drop all poison-generating flags. + void dropPoisonGeneratingFlags() { + // NOTE: This needs to be kept in-sync with + // Instruction::dropPoisonGeneratingFlags. + switch (OpType) { + case OperationType::OverflowingBinOp: + WrapFlags.HasNUW = false; + WrapFlags.HasNSW = false; + break; + case OperationType::PossiblyExactOp: + ExactFlags.IsExact = false; + break; + case OperationType::GEPOp: + GEPFlags.IsInBounds = false; + break; + case OperationType::FPMathOp: + FMFs.NoNaNs = false; + FMFs.NoInfs = false; + break; + case OperationType::Other: + break; + } + } + + /// Set the IR flags for \p I. 
+ void setFlags(Instruction *I) const { + switch (OpType) { + case OperationType::OverflowingBinOp: + I->setHasNoUnsignedWrap(WrapFlags.HasNUW); + I->setHasNoSignedWrap(WrapFlags.HasNSW); + break; + case OperationType::PossiblyExactOp: + I->setIsExact(ExactFlags.IsExact); + break; + case OperationType::GEPOp: + cast<GetElementPtrInst>(I)->setIsInBounds(GEPFlags.IsInBounds); + break; + case OperationType::FPMathOp: + I->setHasAllowReassoc(FMFs.AllowReassoc); + I->setHasNoNaNs(FMFs.NoNaNs); + I->setHasNoInfs(FMFs.NoInfs); + I->setHasNoSignedZeros(FMFs.NoSignedZeros); + I->setHasAllowReciprocal(FMFs.AllowReciprocal); + I->setHasAllowContract(FMFs.AllowContract); + I->setHasApproxFunc(FMFs.ApproxFunc); + break; + case OperationType::Other: + break; + } + } + + bool isInBounds() const { + assert(OpType == OperationType::GEPOp && + "recipe doesn't have inbounds flag"); + return GEPFlags.IsInBounds; + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + FastMathFlags getFastMathFlags() const { + FastMathFlags Res; + Res.setAllowReassoc(FMFs.AllowReassoc); + Res.setNoNaNs(FMFs.NoNaNs); + Res.setNoInfs(FMFs.NoInfs); + Res.setNoSignedZeros(FMFs.NoSignedZeros); + Res.setAllowReciprocal(FMFs.AllowReciprocal); + Res.setAllowContract(FMFs.AllowContract); + Res.setApproxFunc(FMFs.ApproxFunc); + return Res; + } + + void printFlags(raw_ostream &O) const; +#endif +}; + /// VPWidenRecipe is a recipe for producing a copy of vector type its /// ingredient. This recipe covers most of the traditional vectorization cases /// where each ingredient transforms into a vectorized version of itself. -class VPWidenRecipe : public VPRecipeBase, public VPValue { +class VPWidenRecipe : public VPRecipeWithIRFlags, public VPValue { + public: template <typename IterT> VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands) - : VPRecipeBase(VPDef::VPWidenSC, Operands), VPValue(this, &I) {} + : VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPValue(this, &I) {} ~VPWidenRecipe() override = default; @@ -926,18 +1121,62 @@ public: #endif }; +/// VPWidenCastRecipe is a recipe to create vector cast instructions. +class VPWidenCastRecipe : public VPRecipeBase, public VPValue { + /// Cast instruction opcode. + Instruction::CastOps Opcode; + + /// Result type for the cast. + Type *ResultTy; + +public: + VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy, + CastInst *UI = nullptr) + : VPRecipeBase(VPDef::VPWidenCastSC, Op), VPValue(this, UI), + Opcode(Opcode), ResultTy(ResultTy) { + assert((!UI || UI->getOpcode() == Opcode) && + "opcode of underlying cast doesn't match"); + assert((!UI || UI->getType() == ResultTy) && + "result type of underlying cast doesn't match"); + } + + ~VPWidenCastRecipe() override = default; + + VP_CLASSOF_IMPL(VPDef::VPWidenCastSC) + + /// Produce widened copies of the cast. + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + Instruction::CastOps getOpcode() const { return Opcode; } + + /// Returns the result type of the cast. + Type *getResultType() const { return ResultTy; } +}; + /// A recipe for widening Call instructions. class VPWidenCallRecipe : public VPRecipeBase, public VPValue { /// ID of the vector intrinsic to call when widening the call. If set the /// Intrinsic::not_intrinsic, a library call will be used instead. 
Intrinsic::ID VectorIntrinsicID; + /// If this recipe represents a library call, Variant stores a pointer to + /// the chosen function. There is a 1:1 mapping between a given VF and the + /// chosen vectorized variant, so there will be a different vplan for each + /// VF with a valid variant. + Function *Variant; public: template <typename IterT> VPWidenCallRecipe(CallInst &I, iterator_range<IterT> CallArguments, - Intrinsic::ID VectorIntrinsicID) + Intrinsic::ID VectorIntrinsicID, + Function *Variant = nullptr) : VPRecipeBase(VPDef::VPWidenCallSC, CallArguments), VPValue(this, &I), - VectorIntrinsicID(VectorIntrinsicID) {} + VectorIntrinsicID(VectorIntrinsicID), Variant(Variant) {} ~VPWidenCallRecipe() override = default; @@ -954,17 +1193,10 @@ public: }; /// A recipe for widening select instructions. -class VPWidenSelectRecipe : public VPRecipeBase, public VPValue { - - /// Is the condition of the select loop invariant? - bool InvariantCond; - -public: +struct VPWidenSelectRecipe : public VPRecipeBase, public VPValue { template <typename IterT> - VPWidenSelectRecipe(SelectInst &I, iterator_range<IterT> Operands, - bool InvariantCond) - : VPRecipeBase(VPDef::VPWidenSelectSC, Operands), VPValue(this, &I), - InvariantCond(InvariantCond) {} + VPWidenSelectRecipe(SelectInst &I, iterator_range<IterT> Operands) + : VPRecipeBase(VPDef::VPWidenSelectSC, Operands), VPValue(this, &I) {} ~VPWidenSelectRecipe() override = default; @@ -978,29 +1210,38 @@ public: void print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const override; #endif + + VPValue *getCond() const { + return getOperand(0); + } + + bool isInvariantCond() const { + return getCond()->isDefinedOutsideVectorRegions(); + } }; /// A recipe for handling GEP instructions. -class VPWidenGEPRecipe : public VPRecipeBase, public VPValue { - bool IsPtrLoopInvariant; - SmallBitVector IsIndexLoopInvariant; +class VPWidenGEPRecipe : public VPRecipeWithIRFlags, public VPValue { + bool isPointerLoopInvariant() const { + return getOperand(0)->isDefinedOutsideVectorRegions(); + } + + bool isIndexLoopInvariant(unsigned I) const { + return getOperand(I + 1)->isDefinedOutsideVectorRegions(); + } + + bool areAllOperandsInvariant() const { + return all_of(operands(), [](VPValue *Op) { + return Op->isDefinedOutsideVectorRegions(); + }); + } public: template <typename IterT> VPWidenGEPRecipe(GetElementPtrInst *GEP, iterator_range<IterT> Operands) - : VPRecipeBase(VPDef::VPWidenGEPSC, Operands), VPValue(this, GEP), - IsIndexLoopInvariant(GEP->getNumIndices(), false) {} + : VPRecipeWithIRFlags(VPDef::VPWidenGEPSC, Operands, *GEP), + VPValue(this, GEP) {} - template <typename IterT> - VPWidenGEPRecipe(GetElementPtrInst *GEP, iterator_range<IterT> Operands, - Loop *OrigLoop) - : VPRecipeBase(VPDef::VPWidenGEPSC, Operands), VPValue(this, GEP), - IsIndexLoopInvariant(GEP->getNumIndices(), false) { - IsPtrLoopInvariant = OrigLoop->isLoopInvariant(GEP->getPointerOperand()); - for (auto Index : enumerate(GEP->indices())) - IsIndexLoopInvariant[Index.index()] = - OrigLoop->isLoopInvariant(Index.value().get()); - } ~VPWidenGEPRecipe() override = default; VP_CLASSOF_IMPL(VPDef::VPWidenGEPSC) @@ -1015,78 +1256,6 @@ public: #endif }; -/// A recipe for handling phi nodes of integer and floating-point inductions, -/// producing their vector values. 
-class VPWidenIntOrFpInductionRecipe : public VPRecipeBase, public VPValue { - PHINode *IV; - const InductionDescriptor &IndDesc; - bool NeedsVectorIV; - -public: - VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, - const InductionDescriptor &IndDesc, - bool NeedsVectorIV) - : VPRecipeBase(VPDef::VPWidenIntOrFpInductionSC, {Start, Step}), - VPValue(this, IV), IV(IV), IndDesc(IndDesc), - NeedsVectorIV(NeedsVectorIV) {} - - VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, - const InductionDescriptor &IndDesc, - TruncInst *Trunc, bool NeedsVectorIV) - : VPRecipeBase(VPDef::VPWidenIntOrFpInductionSC, {Start, Step}), - VPValue(this, Trunc), IV(IV), IndDesc(IndDesc), - NeedsVectorIV(NeedsVectorIV) {} - - ~VPWidenIntOrFpInductionRecipe() override = default; - - VP_CLASSOF_IMPL(VPDef::VPWidenIntOrFpInductionSC) - - /// Generate the vectorized and scalarized versions of the phi node as - /// needed by their users. - void execute(VPTransformState &State) override; - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - /// Print the recipe. - void print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const override; -#endif - - /// Returns the start value of the induction. - VPValue *getStartValue() { return getOperand(0); } - const VPValue *getStartValue() const { return getOperand(0); } - - /// Returns the step value of the induction. - VPValue *getStepValue() { return getOperand(1); } - const VPValue *getStepValue() const { return getOperand(1); } - - /// Returns the first defined value as TruncInst, if it is one or nullptr - /// otherwise. - TruncInst *getTruncInst() { - return dyn_cast_or_null<TruncInst>(getVPValue(0)->getUnderlyingValue()); - } - const TruncInst *getTruncInst() const { - return dyn_cast_or_null<TruncInst>(getVPValue(0)->getUnderlyingValue()); - } - - PHINode *getPHINode() { return IV; } - - /// Returns the induction descriptor for the recipe. - const InductionDescriptor &getInductionDescriptor() const { return IndDesc; } - - /// Returns true if the induction is canonical, i.e. starting at 0 and - /// incremented by UF * VF (= the original IV is incremented by 1). - bool isCanonical() const; - - /// Returns the scalar type of the induction. - const Type *getScalarType() const { - const TruncInst *TruncI = getTruncInst(); - return TruncI ? TruncI->getType() : IV->getType(); - } - - /// Returns true if a vector phi needs to be created for the induction. - bool needsVectorIV() const { return NeedsVectorIV; } -}; - /// A pure virtual base class for all recipes modeling header phis, including /// phis for first order recurrences, pointer inductions and reductions. The /// start value is the first operand of the recipe and the incoming value from @@ -1112,9 +1281,9 @@ public: /// per-lane based on the canonical induction. class VPHeaderPHIRecipe : public VPRecipeBase, public VPValue { protected: - VPHeaderPHIRecipe(unsigned char VPDefID, PHINode *Phi, + VPHeaderPHIRecipe(unsigned char VPDefID, Instruction *UnderlyingInstr, VPValue *Start = nullptr) - : VPRecipeBase(VPDefID, {}), VPValue(this, Phi) { + : VPRecipeBase(VPDefID, {}), VPValue(this, UnderlyingInstr) { if (Start) addOperand(Start); } @@ -1125,12 +1294,12 @@ public: /// Method to support type inquiry through isa, cast, and dyn_cast. 
static inline bool classof(const VPRecipeBase *B) { return B->getVPDefID() >= VPDef::VPFirstHeaderPHISC && - B->getVPDefID() <= VPDef::VPLastPHISC; + B->getVPDefID() <= VPDef::VPLastHeaderPHISC; } static inline bool classof(const VPValue *V) { auto *B = V->getDefiningRecipe(); return B && B->getVPDefID() >= VPRecipeBase::VPFirstHeaderPHISC && - B->getVPDefID() <= VPRecipeBase::VPLastPHISC; + B->getVPDefID() <= VPRecipeBase::VPLastHeaderPHISC; } /// Generate the phi nodes. @@ -1154,17 +1323,92 @@ public: void setStartValue(VPValue *V) { setOperand(0, V); } /// Returns the incoming value from the loop backedge. - VPValue *getBackedgeValue() { + virtual VPValue *getBackedgeValue() { return getOperand(1); } /// Returns the backedge value as a recipe. The backedge value is guaranteed /// to be a recipe. - VPRecipeBase &getBackedgeRecipe() { + virtual VPRecipeBase &getBackedgeRecipe() { return *getBackedgeValue()->getDefiningRecipe(); } }; +/// A recipe for handling phi nodes of integer and floating-point inductions, +/// producing their vector values. +class VPWidenIntOrFpInductionRecipe : public VPHeaderPHIRecipe { + PHINode *IV; + TruncInst *Trunc; + const InductionDescriptor &IndDesc; + +public: + VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, + const InductionDescriptor &IndDesc) + : VPHeaderPHIRecipe(VPDef::VPWidenIntOrFpInductionSC, IV, Start), IV(IV), + Trunc(nullptr), IndDesc(IndDesc) { + addOperand(Step); + } + + VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, + const InductionDescriptor &IndDesc, + TruncInst *Trunc) + : VPHeaderPHIRecipe(VPDef::VPWidenIntOrFpInductionSC, Trunc, Start), + IV(IV), Trunc(Trunc), IndDesc(IndDesc) { + addOperand(Step); + } + + ~VPWidenIntOrFpInductionRecipe() override = default; + + VP_CLASSOF_IMPL(VPDef::VPWidenIntOrFpInductionSC) + + /// Generate the vectorized and scalarized versions of the phi node as + /// needed by their users. + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + VPValue *getBackedgeValue() override { + // TODO: All operands of base recipe must exist and be at same index in + // derived recipe. + llvm_unreachable( + "VPWidenIntOrFpInductionRecipe generates its own backedge value"); + } + + VPRecipeBase &getBackedgeRecipe() override { + // TODO: All operands of base recipe must exist and be at same index in + // derived recipe. + llvm_unreachable( + "VPWidenIntOrFpInductionRecipe generates its own backedge value"); + } + + /// Returns the step value of the induction. + VPValue *getStepValue() { return getOperand(1); } + const VPValue *getStepValue() const { return getOperand(1); } + + /// Returns the first defined value as TruncInst, if it is one or nullptr + /// otherwise. + TruncInst *getTruncInst() { return Trunc; } + const TruncInst *getTruncInst() const { return Trunc; } + + PHINode *getPHINode() { return IV; } + + /// Returns the induction descriptor for the recipe. + const InductionDescriptor &getInductionDescriptor() const { return IndDesc; } + + /// Returns true if the induction is canonical, i.e. starting at 0 and + /// incremented by UF * VF (= the original IV is incremented by 1). + bool isCanonical() const; + + /// Returns the scalar type of the induction. + const Type *getScalarType() const { + return Trunc ? 
Trunc->getType() : IV->getType(); + } +}; + class VPWidenPointerInductionRecipe : public VPHeaderPHIRecipe { const InductionDescriptor &IndDesc; @@ -1374,12 +1618,20 @@ public: class VPInterleaveRecipe : public VPRecipeBase { const InterleaveGroup<Instruction> *IG; + /// Indicates if the interleave group is in a conditional block and requires a + /// mask. bool HasMask = false; + /// Indicates if gaps between members of the group need to be masked out or if + /// unused gaps can be loaded speculatively. + bool NeedsMaskForGaps = false; + public: VPInterleaveRecipe(const InterleaveGroup<Instruction> *IG, VPValue *Addr, - ArrayRef<VPValue *> StoredValues, VPValue *Mask) - : VPRecipeBase(VPDef::VPInterleaveSC, {Addr}), IG(IG) { + ArrayRef<VPValue *> StoredValues, VPValue *Mask, + bool NeedsMaskForGaps) + : VPRecipeBase(VPDef::VPInterleaveSC, {Addr}), IG(IG), + NeedsMaskForGaps(NeedsMaskForGaps) { for (unsigned i = 0; i < IG->getFactor(); ++i) if (Instruction *I = IG->getMember(i)) { if (I->getType()->isVoidTy()) @@ -1490,28 +1742,21 @@ public: /// copies of the original scalar type, one per lane, instead of producing a /// single copy of widened type for all lanes. If the instruction is known to be /// uniform only one copy, per lane zero, will be generated. -class VPReplicateRecipe : public VPRecipeBase, public VPValue { +class VPReplicateRecipe : public VPRecipeWithIRFlags, public VPValue { /// Indicator if only a single replica per lane is needed. bool IsUniform; /// Indicator if the replicas are also predicated. bool IsPredicated; - /// Indicator if the scalar values should also be packed into a vector. - bool AlsoPack; - public: template <typename IterT> VPReplicateRecipe(Instruction *I, iterator_range<IterT> Operands, - bool IsUniform, bool IsPredicated = false) - : VPRecipeBase(VPDef::VPReplicateSC, Operands), VPValue(this, I), - IsUniform(IsUniform), IsPredicated(IsPredicated) { - // Retain the previous behavior of predicateInstructions(), where an - // insert-element of a predicated instruction got hoisted into the - // predicated basic block iff it was its only user. This is achieved by - // having predicated instructions also pack their values into a vector by - // default unless they have a replicated user which uses their scalar value. - AlsoPack = IsPredicated && !I->use_empty(); + bool IsUniform, VPValue *Mask = nullptr) + : VPRecipeWithIRFlags(VPDef::VPReplicateSC, Operands, *I), + VPValue(this, I), IsUniform(IsUniform), IsPredicated(Mask) { + if (Mask) + addOperand(Mask); } ~VPReplicateRecipe() override = default; @@ -1523,8 +1768,6 @@ public: /// the \p State. void execute(VPTransformState &State) override; - void setAlsoPack(bool Pack) { AlsoPack = Pack; } - #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print the recipe. void print(raw_ostream &O, const Twine &Indent, @@ -1533,8 +1776,6 @@ public: bool isUniform() const { return IsUniform; } - bool isPacked() const { return AlsoPack; } - bool isPredicated() const { return IsPredicated; } /// Returns true if the recipe only uses the first lane of operand \p Op. @@ -1550,6 +1791,17 @@ public: "Op must be an operand of the recipe"); return true; } + + /// Returns true if the recipe is used by a widened recipe via an intervening + /// VPPredInstPHIRecipe. In this case, the scalar values should also be packed + /// in a vector. + bool shouldPack() const; + + /// Return the mask of a predicated VPReplicateRecipe.
+ VPValue *getMask() { + assert(isPredicated() && "Trying to get the mask of a unpredicated recipe"); + return getOperand(getNumOperands() - 1); + } }; /// A recipe for generating conditional branches on the bits of a mask. @@ -1791,9 +2043,11 @@ public: return true; } - /// Check if the induction described by \p ID is canonical, i.e. has the same - /// start, step (of 1), and type as the canonical IV. - bool isCanonical(const InductionDescriptor &ID, Type *Ty) const; + /// Check if the induction described by \p Kind, /p Start and \p Step is + /// canonical, i.e. has the same start, step (of 1), and type as the + /// canonical IV. + bool isCanonical(InductionDescriptor::InductionKind Kind, VPValue *Start, + VPValue *Step, Type *Ty) const; }; /// A recipe for generating the active lane mask for the vector loop that is @@ -2156,13 +2410,19 @@ public: /// to produce efficient output IR, including which branches, basic-blocks and /// output IR instructions to generate, and their cost. VPlan holds a /// Hierarchical-CFG of VPBasicBlocks and VPRegionBlocks rooted at an Entry -/// VPBlock. +/// VPBasicBlock. class VPlan { friend class VPlanPrinter; friend class VPSlotTracker; - /// Hold the single entry to the Hierarchical CFG of the VPlan. - VPBlockBase *Entry; + /// Hold the single entry to the Hierarchical CFG of the VPlan, i.e. the + /// preheader of the vector loop. + VPBasicBlock *Entry; + + /// VPBasicBlock corresponding to the original preheader. Used to place + /// VPExpandSCEV recipes for expressions used during skeleton creation and the + /// rest of VPlan execution. + VPBasicBlock *Preheader; /// Holds the VFs applicable to this VPlan. SmallSetVector<ElementCount, 2> VFs; @@ -2174,10 +2434,6 @@ class VPlan { /// Holds the name of the VPlan, for printing. std::string Name; - /// Holds all the external definitions created for this VPlan. External - /// definitions must be immutable and hold a pointer to their underlying IR. - DenseMap<Value *, VPValue *> VPExternalDefs; - /// Represents the trip count of the original loop, for folding /// the tail. VPValue *TripCount = nullptr; @@ -2193,9 +2449,9 @@ class VPlan { /// VPlan. Value2VPValueTy Value2VPValue; - /// Contains all VPValues that been allocated by addVPValue directly and need - /// to be free when the plan's destructor is called. - SmallVector<VPValue *, 16> VPValuesToFree; + /// Contains all the external definitions created for this VPlan. External + /// definitions are VPValues that hold a pointer to their underlying IR. + SmallVector<VPValue *, 16> VPLiveInsToFree; /// Indicates whether it is safe use the Value2VPValue mapping or if the /// mapping cannot be used any longer, because it is stale. @@ -2204,14 +2460,41 @@ class VPlan { /// Values used outside the plan. MapVector<PHINode *, VPLiveOut *> LiveOuts; + /// Mapping from SCEVs to the VPValues representing their expansions. + /// NOTE: This mapping is temporary and will be removed once all users have + /// been modeled in VPlan directly. + DenseMap<const SCEV *, VPValue *> SCEVToExpansion; + public: - VPlan(VPBlockBase *Entry = nullptr) : Entry(Entry) { - if (Entry) - Entry->setPlan(this); + /// Construct a VPlan with original preheader \p Preheader, trip count \p TC + /// and \p Entry to the plan. At the moment, \p Preheader and \p Entry need to + /// be disconnected, as the bypass blocks between them are not yet modeled in + /// VPlan. 
+ VPlan(VPBasicBlock *Preheader, VPValue *TC, VPBasicBlock *Entry) + : VPlan(Preheader, Entry) { + TripCount = TC; + } + + /// Construct a VPlan with original preheader \p Preheader and \p Entry to + /// the plan. At the moment, \p Preheader and \p Entry need to be + /// disconnected, as the bypass blocks between them are not yet modeled in + /// VPlan. + VPlan(VPBasicBlock *Preheader, VPBasicBlock *Entry) + : Entry(Entry), Preheader(Preheader) { + Entry->setPlan(this); + Preheader->setPlan(this); + assert(Preheader->getNumSuccessors() == 0 && + Preheader->getNumPredecessors() == 0 && + "preheader must be disconnected"); } ~VPlan(); + /// Create an initial VPlan with preheader and entry blocks. Creates a + /// VPExpandSCEVRecipe for \p TripCount and uses it as plan's trip count. + static VPlanPtr createInitialVPlan(const SCEV *TripCount, + ScalarEvolution &PSE); + /// Prepare the plan for execution, setting up the required live-in values. void prepareToExecute(Value *TripCount, Value *VectorTripCount, Value *CanonicalIVStartValue, VPTransformState &State, @@ -2220,19 +2503,12 @@ public: /// Generate the IR code for this VPlan. void execute(VPTransformState *State); - VPBlockBase *getEntry() { return Entry; } - const VPBlockBase *getEntry() const { return Entry; } - - VPBlockBase *setEntry(VPBlockBase *Block) { - Entry = Block; - Block->setPlan(this); - return Entry; - } + VPBasicBlock *getEntry() { return Entry; } - const VPBasicBlock *getEntry() const { return Entry; } /// The trip count of the original loop. - VPValue *getOrCreateTripCount() { - if (!TripCount) - TripCount = new VPValue(); + VPValue *getTripCount() const { + assert(TripCount && "trip count needs to be set before accessing it"); return TripCount; } @@ -2275,50 +2551,35 @@ public: void setName(const Twine &newName) { Name = newName.str(); } - /// Get the existing or add a new external definition for \p V. - VPValue *getOrAddExternalDef(Value *V) { - auto I = VPExternalDefs.insert({V, nullptr}); - if (I.second) - I.first->second = new VPValue(V); - return I.first->second; - } - - void addVPValue(Value *V) { - assert(Value2VPValueEnabled && - "IR value to VPValue mapping may be out of date!"); - assert(V && "Trying to add a null Value to VPlan"); - assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); - VPValue *VPV = new VPValue(V); - Value2VPValue[V] = VPV; - VPValuesToFree.push_back(VPV); - } - void addVPValue(Value *V, VPValue *VPV) { - assert(Value2VPValueEnabled && "Value2VPValue mapping may be out of date!"); + assert((Value2VPValueEnabled || VPV->isLiveIn()) && + "Value2VPValue mapping may be out of date!"); assert(V && "Trying to add a null Value to VPlan"); assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); Value2VPValue[V] = VPV; } /// Returns the VPValue for \p V. \p OverrideAllowed can be used to disable - /// checking whether it is safe to query VPValues using IR Values. + /// checking whether it is safe to query VPValues using IR Values. VPValue *getVPValue(Value *V, bool OverrideAllowed = false) { - assert((OverrideAllowed || isa<Constant>(V) || Value2VPValueEnabled) && - "Value2VPValue mapping may be out of date!"); assert(V && "Trying to get the VPValue of a null Value"); assert(Value2VPValue.count(V) && "Value does not exist in VPlan"); assert((Value2VPValueEnabled || OverrideAllowed || + Value2VPValue[V]->isLiveIn()) && + "Value2VPValue mapping may be out of date!"); return Value2VPValue[V]; } - /// Gets the VPValue or adds a new one (if none exists yet) for \p V.
\p - /// OverrideAllowed can be used to disable checking whether it is safe to - /// query VPValues using IR Values. - VPValue *getOrAddVPValue(Value *V, bool OverrideAllowed = false) { - assert((OverrideAllowed || isa<Constant>(V) || Value2VPValueEnabled) && - "Value2VPValue mapping may be out of date!"); + /// Gets the VPValue for \p V or adds a new live-in (if none exists yet) for + /// \p V. + VPValue *getVPValueOrAddLiveIn(Value *V) { assert(V && "Trying to get or add the VPValue of a null Value"); - if (!Value2VPValue.count(V)) - addVPValue(V); + if (!Value2VPValue.count(V)) { + VPValue *VPV = new VPValue(V); + VPLiveInsToFree.push_back(VPV); + addVPValue(V, VPV); + } + return getVPValue(V); } @@ -2344,7 +2605,7 @@ public: iterator_range<mapped_iterator<Use *, std::function<VPValue *(Value *)>>> mapToVPValues(User::op_range Operands) { std::function<VPValue *(Value *)> Fn = [this](Value *Op) { - return getOrAddVPValue(Op); + return getVPValueOrAddLiveIn(Op); }; return map_range(Operands, Fn); } @@ -2373,12 +2634,6 @@ public: void addLiveOut(PHINode *PN, VPValue *V); - void clearLiveOuts() { - for (auto &KV : LiveOuts) - delete KV.second; - LiveOuts.clear(); - } - void removeLiveOut(PHINode *PN) { delete LiveOuts[PN]; LiveOuts.erase(PN); @@ -2388,6 +2643,19 @@ public: return LiveOuts; } + VPValue *getSCEVExpansion(const SCEV *S) const { + return SCEVToExpansion.lookup(S); + } + + void addSCEVExpansion(const SCEV *S, VPValue *V) { + assert(!SCEVToExpansion.contains(S) && "SCEV already expanded"); + SCEVToExpansion[S] = V; + } + + /// \return The block corresponding to the original preheader. + VPBasicBlock *getPreheader() { return Preheader; } + const VPBasicBlock *getPreheader() const { return Preheader; } + private: /// Add to the given dominator tree the header block and every new basic block /// that was created between it and the latch block, inclusive. @@ -2709,6 +2977,8 @@ inline bool isUniformAfterVectorization(VPValue *VPV) { assert(Def && "Must have definition for value defined inside vector region"); if (auto Rep = dyn_cast<VPReplicateRecipe>(Def)) return Rep->isUniform(); + if (auto *GEP = dyn_cast<VPWidenGEPRecipe>(Def)) + return all_of(GEP->operands(), isUniformAfterVectorization); return false; } } // end namespace vputils diff --git a/llvm/lib/Transforms/Vectorize/VPlanCFG.h b/llvm/lib/Transforms/Vectorize/VPlanCFG.h index f790f7e73e11..89e2e7514dac 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanCFG.h +++ b/llvm/lib/Transforms/Vectorize/VPlanCFG.h @@ -13,6 +13,7 @@ #define LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H #include "VPlan.h" +#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/SmallVector.h" diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index 952ce72e36c1..f6e3a2a16db8 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -73,9 +73,8 @@ public: PlainCFGBuilder(Loop *Lp, LoopInfo *LI, VPlan &P) : TheLoop(Lp), LI(LI), Plan(P) {} - /// Build plain CFG for TheLoop. Return the pre-header VPBasicBlock connected - /// to a new VPRegionBlock (TopRegion) enclosing the plain CFG. - VPBasicBlock *buildPlainCFG(); + /// Build plain CFG for TheLoop and connects it to Plan's entry. 
+ void buildPlainCFG(); }; } // anonymous namespace @@ -196,7 +195,7 @@ VPValue *PlainCFGBuilder::getOrCreateVPOperand(Value *IRVal) { // A and B: Create VPValue and add it to the pool of external definitions and // to the Value->VPValue map. - VPValue *NewVPVal = Plan.getOrAddExternalDef(IRVal); + VPValue *NewVPVal = Plan.getVPValueOrAddLiveIn(IRVal); IRDef2VPValue[IRVal] = NewVPVal; return NewVPVal; } @@ -254,7 +253,7 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB, } // Main interface to build the plain CFG. -VPBasicBlock *PlainCFGBuilder::buildPlainCFG() { +void PlainCFGBuilder::buildPlainCFG() { // 1. Scan the body of the loop in a topological order to visit each basic // block after having visited its predecessor basic blocks. Create a VPBB for // each BB and link it to its successor and predecessor VPBBs. Note that @@ -267,12 +266,13 @@ VPBasicBlock *PlainCFGBuilder::buildPlainCFG() { BasicBlock *ThePreheaderBB = TheLoop->getLoopPreheader(); assert((ThePreheaderBB->getTerminator()->getNumSuccessors() == 1) && "Unexpected loop preheader"); - VPBasicBlock *ThePreheaderVPBB = getOrCreateVPBB(ThePreheaderBB); + VPBasicBlock *ThePreheaderVPBB = Plan.getEntry(); + BB2VPBB[ThePreheaderBB] = ThePreheaderVPBB; ThePreheaderVPBB->setName("vector.ph"); for (auto &I : *ThePreheaderBB) { if (I.getType()->isVoidTy()) continue; - IRDef2VPValue[&I] = Plan.getOrAddExternalDef(&I); + IRDef2VPValue[&I] = Plan.getVPValueOrAddLiveIn(&I); } // Create empty VPBB for Loop H so that we can link PH->H. VPBlockBase *HeaderVPBB = getOrCreateVPBB(TheLoop->getHeader()); @@ -371,20 +371,17 @@ VPBasicBlock *PlainCFGBuilder::buildPlainCFG() { // have a VPlan counterpart. Fix VPlan phi nodes by adding their corresponding // VPlan operands. fixPhiNodes(); - - return ThePreheaderVPBB; } -VPBasicBlock *VPlanHCFGBuilder::buildPlainCFG() { +void VPlanHCFGBuilder::buildPlainCFG() { PlainCFGBuilder PCFGBuilder(TheLoop, LI, Plan); - return PCFGBuilder.buildPlainCFG(); + PCFGBuilder.buildPlainCFG(); } // Public interface to build a H-CFG. void VPlanHCFGBuilder::buildHierarchicalCFG() { - // Build Top Region enclosing the plain CFG and set it as VPlan entry. - VPBasicBlock *EntryVPBB = buildPlainCFG(); - Plan.setEntry(EntryVPBB); + // Build Top Region enclosing the plain CFG. + buildPlainCFG(); LLVM_DEBUG(Plan.setName("HCFGBuilder: Plain CFG\n"); dbgs() << Plan); VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h index 2d52990af268..299ae36155cb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h @@ -57,9 +57,8 @@ private: // are introduced. VPDominatorTree VPDomTree; - /// Build plain CFG for TheLoop. Return the pre-header VPBasicBlock connected - /// to a new VPRegionBlock (TopRegion) enclosing the plain CFG. - VPBasicBlock *buildPlainCFG(); + /// Build plain CFG for TheLoop and connects it to Plan's entry.
+ void buildPlainCFG(); public: VPlanHCFGBuilder(Loop *Lp, LoopInfo *LI, VPlan &P) diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 4e9be35001ad..26c309eed800 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -34,7 +34,9 @@ using namespace llvm; using VectorParts = SmallVector<Value *, 2>; +namespace llvm { extern cl::opt<bool> EnableVPlanNativePath; +} #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME @@ -50,14 +52,16 @@ bool VPRecipeBase::mayWriteToMemory() const { ->mayWriteToMemory(); case VPBranchOnMaskSC: case VPScalarIVStepsSC: + case VPPredInstPHISC: return false; - case VPWidenIntOrFpInductionSC: + case VPBlendSC: + case VPReductionSC: case VPWidenCanonicalIVSC: + case VPWidenCastSC: + case VPWidenGEPSC: + case VPWidenIntOrFpInductionSC: case VPWidenPHISC: - case VPBlendSC: case VPWidenSC: - case VPWidenGEPSC: - case VPReductionSC: case VPWidenSelectSC: { const Instruction *I = dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); @@ -82,14 +86,16 @@ bool VPRecipeBase::mayReadFromMemory() const { ->mayReadFromMemory(); case VPBranchOnMaskSC: case VPScalarIVStepsSC: + case VPPredInstPHISC: return false; - case VPWidenIntOrFpInductionSC: + case VPBlendSC: + case VPReductionSC: case VPWidenCanonicalIVSC: + case VPWidenCastSC: + case VPWidenGEPSC: + case VPWidenIntOrFpInductionSC: case VPWidenPHISC: - case VPBlendSC: case VPWidenSC: - case VPWidenGEPSC: - case VPReductionSC: case VPWidenSelectSC: { const Instruction *I = dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); @@ -108,16 +114,20 @@ bool VPRecipeBase::mayHaveSideEffects() const { case VPDerivedIVSC: case VPPredInstPHISC: return false; - case VPWidenIntOrFpInductionSC: - case VPWidenPointerInductionSC: + case VPWidenCallSC: + return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) + ->mayHaveSideEffects(); + case VPBlendSC: + case VPReductionSC: + case VPScalarIVStepsSC: case VPWidenCanonicalIVSC: + case VPWidenCastSC: + case VPWidenGEPSC: + case VPWidenIntOrFpInductionSC: case VPWidenPHISC: - case VPBlendSC: + case VPWidenPointerInductionSC: case VPWidenSC: - case VPWidenGEPSC: - case VPReductionSC: - case VPWidenSelectSC: - case VPScalarIVStepsSC: { + case VPWidenSelectSC: { const Instruction *I = dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); (void)I; @@ -125,6 +135,13 @@ bool VPRecipeBase::mayHaveSideEffects() const { "underlying instruction has side-effects"); return false; } + case VPWidenMemoryInstructionSC: + assert(cast<VPWidenMemoryInstructionRecipe>(this) + ->getIngredient() + .mayHaveSideEffects() == mayWriteToMemory() && + "mayHaveSideffects result for ingredient differs from this " + "implementation"); + return mayWriteToMemory(); case VPReplicateSC: { auto *R = cast<VPReplicateRecipe>(this); return R->getUnderlyingInstr()->mayHaveSideEffects(); @@ -143,6 +160,16 @@ void VPLiveOut::fixPhi(VPlan &Plan, VPTransformState &State) { State.Builder.GetInsertBlock()); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPLiveOut::print(raw_ostream &O, VPSlotTracker &SlotTracker) const { + O << "Live-out "; + getPhi()->printAsOperand(O); + O << " = "; + getOperand(0)->printAsOperand(O, SlotTracker); + O << "\n"; +} +#endif + void VPRecipeBase::insertBefore(VPRecipeBase *InsertPos) { assert(!Parent && "Recipe already in some VPBasicBlock"); assert(InsertPos->getParent() && @@ -189,55 +216,44 @@ void 
VPRecipeBase::moveBefore(VPBasicBlock &BB, insertBefore(BB, I); } -void VPInstruction::generateInstruction(VPTransformState &State, - unsigned Part) { +Value *VPInstruction::generateInstruction(VPTransformState &State, + unsigned Part) { IRBuilderBase &Builder = State.Builder; Builder.SetCurrentDebugLocation(DL); if (Instruction::isBinaryOp(getOpcode())) { Value *A = State.get(getOperand(0), Part); Value *B = State.get(getOperand(1), Part); - Value *V = - Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name); - State.set(this, V, Part); - return; + return Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name); } switch (getOpcode()) { case VPInstruction::Not: { Value *A = State.get(getOperand(0), Part); - Value *V = Builder.CreateNot(A, Name); - State.set(this, V, Part); - break; + return Builder.CreateNot(A, Name); } case VPInstruction::ICmpULE: { Value *IV = State.get(getOperand(0), Part); Value *TC = State.get(getOperand(1), Part); - Value *V = Builder.CreateICmpULE(IV, TC, Name); - State.set(this, V, Part); - break; + return Builder.CreateICmpULE(IV, TC, Name); } case Instruction::Select: { Value *Cond = State.get(getOperand(0), Part); Value *Op1 = State.get(getOperand(1), Part); Value *Op2 = State.get(getOperand(2), Part); - Value *V = Builder.CreateSelect(Cond, Op1, Op2, Name); - State.set(this, V, Part); - break; + return Builder.CreateSelect(Cond, Op1, Op2, Name); } case VPInstruction::ActiveLaneMask: { // Get first lane of vector induction variable. Value *VIVElem0 = State.get(getOperand(0), VPIteration(Part, 0)); // Get the original loop tripcount. - Value *ScalarTC = State.get(getOperand(1), Part); + Value *ScalarTC = State.get(getOperand(1), VPIteration(Part, 0)); auto *Int1Ty = Type::getInt1Ty(Builder.getContext()); auto *PredTy = VectorType::get(Int1Ty, State.VF); - Instruction *Call = Builder.CreateIntrinsic( - Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()}, - {VIVElem0, ScalarTC}, nullptr, Name); - State.set(this, Call, Part); - break; + return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, + {PredTy, ScalarTC->getType()}, + {VIVElem0, ScalarTC}, nullptr, Name); } case VPInstruction::FirstOrderRecurrenceSplice: { // Generate code to combine the previous and current values in vector v3. @@ -255,18 +271,22 @@ void VPInstruction::generateInstruction(VPTransformState &State, // For the first part, use the recurrence phi (v1), otherwise v2. auto *V1 = State.get(getOperand(0), 0); Value *PartMinus1 = Part == 0 ? 
V1 : State.get(getOperand(1), Part - 1); - if (!PartMinus1->getType()->isVectorTy()) { - State.set(this, PartMinus1, Part); - } else { - Value *V2 = State.get(getOperand(1), Part); - State.set(this, Builder.CreateVectorSplice(PartMinus1, V2, -1, Name), - Part); - } - break; + if (!PartMinus1->getType()->isVectorTy()) + return PartMinus1; + Value *V2 = State.get(getOperand(1), Part); + return Builder.CreateVectorSplice(PartMinus1, V2, -1, Name); + } + case VPInstruction::CalculateTripCountMinusVF: { + Value *ScalarTC = State.get(getOperand(0), {0, 0}); + Value *Step = + createStepForVF(Builder, ScalarTC->getType(), State.VF, State.UF); + Value *Sub = Builder.CreateSub(ScalarTC, Step); + Value *Cmp = Builder.CreateICmp(CmpInst::Predicate::ICMP_UGT, ScalarTC, Step); + Value *Zero = ConstantInt::get(ScalarTC->getType(), 0); + return Builder.CreateSelect(Cmp, Sub, Zero); } case VPInstruction::CanonicalIVIncrement: case VPInstruction::CanonicalIVIncrementNUW: { - Value *Next = nullptr; if (Part == 0) { bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW; auto *Phi = State.get(getOperand(0), 0); @@ -274,34 +294,26 @@ void VPInstruction::generateInstruction(VPTransformState &State, // elements) times the unroll factor (num of SIMD instructions). Value *Step = createStepForVF(Builder, Phi->getType(), State.VF, State.UF); - Next = Builder.CreateAdd(Phi, Step, Name, IsNUW, false); - } else { - Next = State.get(this, 0); + return Builder.CreateAdd(Phi, Step, Name, IsNUW, false); } - - State.set(this, Next, Part); - break; + return State.get(this, 0); } case VPInstruction::CanonicalIVIncrementForPart: case VPInstruction::CanonicalIVIncrementForPartNUW: { bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementForPartNUW; auto *IV = State.get(getOperand(0), VPIteration(0, 0)); - if (Part == 0) { - State.set(this, IV, Part); - break; - } + if (Part == 0) + return IV; // The canonical IV is incremented by the vectorization factor (num of SIMD // elements) times the unroll part. Value *Step = createStepForVF(Builder, IV->getType(), State.VF, Part); - Value *Next = Builder.CreateAdd(IV, Step, Name, IsNUW, false); - State.set(this, Next, Part); - break; + return Builder.CreateAdd(IV, Step, Name, IsNUW, false); } case VPInstruction::BranchOnCond: { if (Part != 0) - break; + return nullptr; Value *Cond = State.get(getOperand(0), VPIteration(Part, 0)); VPRegionBlock *ParentRegion = getParent()->getParent(); @@ -318,11 +330,11 @@ void VPInstruction::generateInstruction(VPTransformState &State, CondBr->setSuccessor(0, nullptr); Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); - break; + return CondBr; } case VPInstruction::BranchOnCount: { if (Part != 0) - break; + return nullptr; // First create the compare. 
Value *IV = State.get(getOperand(0), Part); Value *TC = State.get(getOperand(1), Part); @@ -342,7 +354,7 @@ void VPInstruction::generateInstruction(VPTransformState &State, State.CFG.VPBB2IRBB[Header]); CondBr->setSuccessor(0, nullptr); Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); - break; + return CondBr; } default: llvm_unreachable("Unsupported opcode for instruction"); @@ -353,8 +365,13 @@ void VPInstruction::execute(VPTransformState &State) { assert(!State.Instance && "VPInstruction executing an Instance"); IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); State.Builder.setFastMathFlags(FMF); - for (unsigned Part = 0; Part < State.UF; ++Part) - generateInstruction(State, Part); + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *GeneratedValue = generateInstruction(State, Part); + if (!hasResult()) + continue; + assert(GeneratedValue && "generateInstruction must produce a value"); + State.set(this, GeneratedValue, Part); + } } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -400,6 +417,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, case VPInstruction::BranchOnCond: O << "branch-on-cond"; break; + case VPInstruction::CalculateTripCountMinusVF: + O << "TC > VF ? TC - VF : 0"; + break; case VPInstruction::CanonicalIVIncrementForPart: O << "VF * Part + "; break; @@ -438,18 +458,19 @@ void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { } void VPWidenCallRecipe::execute(VPTransformState &State) { + assert(State.VF.isVector() && "not widening"); auto &CI = *cast<CallInst>(getUnderlyingInstr()); assert(!isa<DbgInfoIntrinsic>(CI) && "DbgInfoIntrinsic should have been dropped during VPlan construction"); State.setDebugLocFromInst(&CI); - SmallVector<Type *, 4> Tys; - for (Value *ArgOperand : CI.args()) - Tys.push_back( - ToVectorTy(ArgOperand->getType(), State.VF.getKnownMinValue())); - for (unsigned Part = 0; Part < State.UF; ++Part) { - SmallVector<Type *, 2> TysForDecl = {CI.getType()}; + SmallVector<Type *, 2> TysForDecl; + // Add return type if intrinsic is overloaded on it. + if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1)) { + TysForDecl.push_back( + VectorType::get(CI.getType()->getScalarType(), State.VF)); + } SmallVector<Value *, 4> Args; for (const auto &I : enumerate(operands())) { // Some intrinsics have a scalar argument - don't replace it with a @@ -468,21 +489,16 @@ void VPWidenCallRecipe::execute(VPTransformState &State) { Function *VectorF; if (VectorIntrinsicID != Intrinsic::not_intrinsic) { // Use vector version of the intrinsic. - if (State.VF.isVector()) - TysForDecl[0] = - VectorType::get(CI.getType()->getScalarType(), State.VF); Module *M = State.Builder.GetInsertBlock()->getModule(); VectorF = Intrinsic::getDeclaration(M, VectorIntrinsicID, TysForDecl); assert(VectorF && "Can't retrieve vector intrinsic."); } else { - // Use vector version of the function call. 
- const VFShape Shape = VFShape::get(CI, State.VF, false /*HasGlobalPred*/); #ifndef NDEBUG - assert(VFDatabase(CI).getVectorizedFunction(Shape) != nullptr && - "Can't create vector function."); + assert(Variant != nullptr && "Can't create vector function."); #endif - VectorF = VFDatabase(CI).getVectorizedFunction(Shape); + VectorF = Variant; } + SmallVector<OperandBundleDef, 1> OpBundles; CI.getOperandBundlesAsDefs(OpBundles); CallInst *V = State.Builder.CreateCall(VectorF, Args, OpBundles); @@ -514,8 +530,12 @@ void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent, if (VectorIntrinsicID) O << " (using vector intrinsic)"; - else - O << " (using library function)"; + else { + O << " (using library function"; + if (Variant->hasName()) + O << ": " << Variant->getName(); + O << ")"; + } } void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent, @@ -528,7 +548,7 @@ void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent, getOperand(1)->printAsOperand(O, SlotTracker); O << ", "; getOperand(2)->printAsOperand(O, SlotTracker); - O << (InvariantCond ? " (condition is loop invariant)" : ""); + O << (isInvariantCond() ? " (condition is loop invariant)" : ""); } #endif @@ -541,10 +561,10 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) { // We have to take the 'vectorized' value and pick the first lane. // Instcombine will make this a no-op. auto *InvarCond = - InvariantCond ? State.get(getOperand(0), VPIteration(0, 0)) : nullptr; + isInvariantCond() ? State.get(getCond(), VPIteration(0, 0)) : nullptr; for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *Cond = InvarCond ? InvarCond : State.get(getOperand(0), Part); + Value *Cond = InvarCond ? InvarCond : State.get(getCond(), Part); Value *Op0 = State.get(getOperand(1), Part); Value *Op1 = State.get(getOperand(2), Part); Value *Sel = State.Builder.CreateSelect(Cond, Op0, Op1); @@ -553,6 +573,33 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) { } } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const { + switch (OpType) { + case OperationType::PossiblyExactOp: + if (ExactFlags.IsExact) + O << " exact"; + break; + case OperationType::OverflowingBinOp: + if (WrapFlags.HasNUW) + O << " nuw"; + if (WrapFlags.HasNSW) + O << " nsw"; + break; + case OperationType::FPMathOp: + getFastMathFlags().print(O); + break; + case OperationType::GEPOp: + if (GEPFlags.IsInBounds) + O << " inbounds"; + break; + case OperationType::Other: + break; + } + O << " "; +} +#endif + void VPWidenRecipe::execute(VPTransformState &State) { auto &I = *cast<Instruction>(getUnderlyingValue()); auto &Builder = State.Builder; @@ -592,17 +639,8 @@ void VPWidenRecipe::execute(VPTransformState &State) { Value *V = Builder.CreateNAryOp(I.getOpcode(), Ops); - if (auto *VecOp = dyn_cast<Instruction>(V)) { - VecOp->copyIRFlags(&I); - - // If the instruction is vectorized and was in a basic block that needed - // predication, we can't propagate poison-generating flags (nuw/nsw, - // exact, etc.). The control flow has been linearized and the - // instruction is no longer guarded by the predicate, which could make - // the flag properties to no longer hold. - if (State.MayGeneratePoisonRecipes.contains(this)) - VecOp->dropPoisonGeneratingFlags(); - } + if (auto *VecOp = dyn_cast<Instruction>(V)) + setFlags(VecOp); // Use this vector value for all users of the original instruction. 
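The new VPRecipeWithIRFlags::printFlags above dispatches on the kind of operation to decide which IR flags (nuw/nsw, exact, inbounds, fast-math) are meaningful to print. A simplified standalone model of that dispatch pattern; the enum and struct here are hypothetical stand-ins, not the VPlan classes:

```cpp
#include <iostream>

// Hypothetical, trimmed-down stand-in for the flag categories that
// VPRecipeWithIRFlags::printFlags distinguishes in the patch.
enum class OperationType { OverflowingBinOp, PossiblyExactOp, GEPOp, Other };

struct RecipeFlags {
  OperationType OpType = OperationType::Other;
  bool HasNUW = false, HasNSW = false; // OverflowingBinOp
  bool IsExact = false;                // PossiblyExactOp
  bool IsInBounds = false;             // GEPOp

  // Print only the flags that apply to this operation kind.
  void print(std::ostream &OS) const {
    switch (OpType) {
    case OperationType::OverflowingBinOp:
      if (HasNUW) OS << " nuw";
      if (HasNSW) OS << " nsw";
      break;
    case OperationType::PossiblyExactOp:
      if (IsExact) OS << " exact";
      break;
    case OperationType::GEPOp:
      if (IsInBounds) OS << " inbounds";
      break;
    case OperationType::Other:
      break;
    }
    OS << ' ';
  }
};

int main() {
  RecipeFlags F;
  F.OpType = OperationType::OverflowingBinOp;
  F.HasNUW = true;
  F.print(std::cout); // prints " nuw "
  std::cout << '\n';
}
```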
State.set(this, V, Part); @@ -646,35 +684,6 @@ void VPWidenRecipe::execute(VPTransformState &State) { break; } - - case Instruction::ZExt: - case Instruction::SExt: - case Instruction::FPToUI: - case Instruction::FPToSI: - case Instruction::FPExt: - case Instruction::PtrToInt: - case Instruction::IntToPtr: - case Instruction::SIToFP: - case Instruction::UIToFP: - case Instruction::Trunc: - case Instruction::FPTrunc: - case Instruction::BitCast: { - auto *CI = cast<CastInst>(&I); - State.setDebugLocFromInst(CI); - - /// Vectorize casts. - Type *DestTy = (State.VF.isScalar()) - ? CI->getType() - : VectorType::get(CI->getType(), State.VF); - - for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *A = State.get(getOperand(0), Part); - Value *Cast = Builder.CreateCast(CI->getOpcode(), A, DestTy); - State.set(this, Cast, Part); - State.addMetadata(Cast, &I); - } - break; - } default: // This instruction is not vectorized by simple widening. LLVM_DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I); @@ -687,10 +696,39 @@ void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent, O << Indent << "WIDEN "; printAsOperand(O, SlotTracker); const Instruction *UI = getUnderlyingInstr(); - O << " = " << UI->getOpcodeName() << " "; + O << " = " << UI->getOpcodeName(); + printFlags(O); if (auto *Cmp = dyn_cast<CmpInst>(UI)) - O << CmpInst::getPredicateName(Cmp->getPredicate()) << " "; + O << Cmp->getPredicate() << " "; + printOperands(O, SlotTracker); +} +#endif + +void VPWidenCastRecipe::execute(VPTransformState &State) { + auto *I = cast_or_null<Instruction>(getUnderlyingValue()); + if (I) + State.setDebugLocFromInst(I); + auto &Builder = State.Builder; + /// Vectorize casts. + assert(State.VF.isVector() && "Not vectorizing?"); + Type *DestTy = VectorType::get(getResultType(), State.VF); + + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *A = State.get(getOperand(0), Part); + Value *Cast = Builder.CreateCast(Instruction::CastOps(Opcode), A, DestTy); + State.set(this, Cast, Part); + State.addMetadata(Cast, I); + } +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN-CAST "; + printAsOperand(O, SlotTracker); + O << " = " << Instruction::getOpcodeName(Opcode) << " "; printOperands(O, SlotTracker); + O << " to " << *getResultType(); } void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent, @@ -710,8 +748,13 @@ void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent, #endif bool VPWidenIntOrFpInductionRecipe::isCanonical() const { + // The step may be defined by a recipe in the preheader (e.g. if it requires + // SCEV expansion), but for the canonical induction the step is required to be + // 1, which is represented as live-in. 
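VPWidenCastRecipe::execute above turns one scalar cast into a vector cast of width VF, emitted once per unroll part. A rough standalone model of that widening, operating on std::vector lanes instead of LLVM vector values (types and names here are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Model: a "part" is one vector of VF lanes; UF parts cover the unrolled
// iteration. Widening a cast means applying the scalar conversion lane-wise
// to every part, mirroring one Builder.CreateCast per part on a VectorType
// of width VF.
static std::vector<std::vector<int64_t>>
widenSExtI32ToI64(const std::vector<std::vector<int32_t>> &Parts) {
  std::vector<std::vector<int64_t>> Result;
  for (const auto &Part : Parts) {  // for (Part = 0; Part < UF; ++Part)
    std::vector<int64_t> Widened;
    for (int32_t Lane : Part)       // one vector cast covers all VF lanes
      Widened.push_back(static_cast<int64_t>(Lane));
    Result.push_back(Widened);
  }
  return Result;
}

int main() {
  // VF = 4, UF = 2.
  auto Out = widenSExtI32ToI64({{1, -2, 3, -4}, {5, 6, 7, 8}});
  assert(Out.size() == 2 && Out[0][1] == -2);
}
```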
+ if (getStepValue()->getDefiningRecipe()) + return false; + auto *StepC = dyn_cast<ConstantInt>(getStepValue()->getLiveInIRValue()); auto *StartC = dyn_cast<ConstantInt>(getStartValue()->getLiveInIRValue()); - auto *StepC = dyn_cast<SCEVConstant>(getInductionDescriptor().getStep()); return StartC && StartC->isZero() && StepC && StepC->isOne(); } @@ -743,6 +786,7 @@ void VPScalarIVStepsRecipe::print(raw_ostream &O, const Twine &Indent, #endif void VPWidenGEPRecipe::execute(VPTransformState &State) { + assert(State.VF.isVector() && "not widening"); auto *GEP = cast<GetElementPtrInst>(getUnderlyingInstr()); // Construct a vector GEP by widening the operands of the scalar GEP as // necessary. We mark the vector GEP 'inbounds' if appropriate. A GEP @@ -750,7 +794,7 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) { // is vector-typed. Thus, to keep the representation compact, we only use // vector-typed operands for loop-varying values. - if (State.VF.isVector() && IsPtrLoopInvariant && IsIndexLoopInvariant.all()) { + if (areAllOperandsInvariant()) { // If we are vectorizing, but the GEP has only loop-invariant operands, // the GEP we build (by only using vector-typed operands for // loop-varying values) would be a scalar pointer. Thus, to ensure we @@ -763,9 +807,15 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) { // required. We would add the scalarization decision to // collectLoopScalars() and teach getVectorValue() to broadcast // the lane-zero scalar value. - auto *Clone = State.Builder.Insert(GEP->clone()); + SmallVector<Value *> Ops; + for (unsigned I = 0, E = getNumOperands(); I != E; I++) + Ops.push_back(State.get(getOperand(I), VPIteration(0, 0))); + + auto *NewGEP = + State.Builder.CreateGEP(GEP->getSourceElementType(), Ops[0], + ArrayRef(Ops).drop_front(), "", isInBounds()); for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *EntryPart = State.Builder.CreateVectorSplat(State.VF, Clone); + Value *EntryPart = State.Builder.CreateVectorSplat(State.VF, NewGEP); State.set(this, EntryPart, Part); State.addMetadata(EntryPart, GEP); } @@ -780,7 +830,7 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) { for (unsigned Part = 0; Part < State.UF; ++Part) { // The pointer operand of the new GEP. If it's loop-invariant, we // won't broadcast it. - auto *Ptr = IsPtrLoopInvariant + auto *Ptr = isPointerLoopInvariant() ? State.get(getOperand(0), VPIteration(0, 0)) : State.get(getOperand(0), Part); @@ -789,24 +839,16 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) { SmallVector<Value *, 4> Indices; for (unsigned I = 1, E = getNumOperands(); I < E; I++) { VPValue *Operand = getOperand(I); - if (IsIndexLoopInvariant[I - 1]) + if (isIndexLoopInvariant(I - 1)) Indices.push_back(State.get(Operand, VPIteration(0, 0))); else Indices.push_back(State.get(Operand, Part)); } - // If the GEP instruction is vectorized and was in a basic block that - // needed predication, we can't propagate the poison-generating 'inbounds' - // flag. The control flow has been linearized and the GEP is no longer - // guarded by the predicate, which could make the 'inbounds' properties to - // no longer hold. - bool IsInBounds = - GEP->isInBounds() && State.MayGeneratePoisonRecipes.count(this) == 0; - // Create the new GEP. Note that this GEP may be a scalar if VF == 1, // but it should be a vector, otherwise. 
auto *NewGEP = State.Builder.CreateGEP(GEP->getSourceElementType(), Ptr, - Indices, "", IsInBounds); + Indices, "", isInBounds()); assert((State.VF.isScalar() || NewGEP->getType()->isVectorTy()) && "NewGEP is not a pointer vector"); State.set(this, NewGEP, Part); @@ -819,14 +861,14 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) { void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "WIDEN-GEP "; - O << (IsPtrLoopInvariant ? "Inv" : "Var"); - size_t IndicesNumber = IsIndexLoopInvariant.size(); - for (size_t I = 0; I < IndicesNumber; ++I) - O << "[" << (IsIndexLoopInvariant[I] ? "Inv" : "Var") << "]"; + O << (isPointerLoopInvariant() ? "Inv" : "Var"); + for (size_t I = 0; I < getNumOperands() - 1; ++I) + O << "[" << (isIndexLoopInvariant(I) ? "Inv" : "Var") << "]"; O << " "; printAsOperand(O, SlotTracker); - O << " = getelementptr "; + O << " = getelementptr"; + printFlags(O); printOperands(O, SlotTracker); } #endif @@ -911,7 +953,21 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, O << " (with final reduction value stored in invariant address sank " "outside of loop)"; } +#endif + +bool VPReplicateRecipe::shouldPack() const { + // Find if the recipe is used by a widened recipe via an intervening + // VPPredInstPHIRecipe. In this case, also pack the scalar values in a vector. + return any_of(users(), [](const VPUser *U) { + if (auto *PredR = dyn_cast<VPPredInstPHIRecipe>(U)) + return any_of(PredR->users(), [PredR](const VPUser *U) { + return !U->usesScalars(PredR); + }); + return false; + }); +} +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << (IsUniform ? "CLONE " : "REPLICATE "); @@ -921,18 +977,21 @@ void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent, O << " = "; } if (auto *CB = dyn_cast<CallBase>(getUnderlyingInstr())) { - O << "call @" << CB->getCalledFunction()->getName() << "("; + O << "call"; + printFlags(O); + O << "@" << CB->getCalledFunction()->getName() << "("; interleaveComma(make_range(op_begin(), op_begin() + (getNumOperands() - 1)), O, [&O, &SlotTracker](VPValue *Op) { Op->printAsOperand(O, SlotTracker); }); O << ")"; } else { - O << Instruction::getOpcodeName(getUnderlyingInstr()->getOpcode()) << " "; + O << Instruction::getOpcodeName(getUnderlyingInstr()->getOpcode()); + printFlags(O); printOperands(O, SlotTracker); } - if (AlsoPack) + if (shouldPack()) O << " (S->V)"; } #endif @@ -1053,20 +1112,22 @@ void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent, } #endif -bool VPCanonicalIVPHIRecipe::isCanonical(const InductionDescriptor &ID, - Type *Ty) const { - if (Ty != getScalarType()) +bool VPCanonicalIVPHIRecipe::isCanonical( + InductionDescriptor::InductionKind Kind, VPValue *Start, VPValue *Step, + Type *Ty) const { + // The types must match and it must be an integer induction. + if (Ty != getScalarType() || Kind != InductionDescriptor::IK_IntInduction) return false; - // The start value of ID must match the start value of this canonical - // induction. - if (getStartValue()->getLiveInIRValue() != ID.getStartValue()) + // Start must match the start value of this canonical induction. + if (Start != getStartValue()) return false; - ConstantInt *Step = ID.getConstIntStepValue(); - // ID must also be incremented by one. IK_IntInduction always increment the - // induction by Step, but the binary op may not be set. 
- return ID.getKind() == InductionDescriptor::IK_IntInduction && Step && - Step->isOne(); + // If the step is defined by a recipe, it is not a ConstantInt. + if (Step->getDefiningRecipe()) + return false; + + ConstantInt *StepC = dyn_cast<ConstantInt>(Step->getLiveInIRValue()); + return StepC && StepC->isOne(); } bool VPWidenPointerInductionRecipe::onlyScalarsGenerated(ElementCount VF) { @@ -1092,9 +1153,11 @@ void VPExpandSCEVRecipe::execute(VPTransformState &State) { Value *Res = Exp.expandCodeFor(Expr, Expr->getType(), &*State.Builder.GetInsertPoint()); - + assert(!State.ExpandedSCEVs.contains(Expr) && + "Same SCEV expanded multiple times"); + State.ExpandedSCEVs[Expr] = Res; for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) - State.set(this, Res, Part); + State.set(this, Res, {Part, 0}); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index cbf111b00e3d..83bfdfd09d19 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "VPlanTransforms.h" +#include "VPlanDominatorTree.h" +#include "VPRecipeBuilder.h" #include "VPlanCFG.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" @@ -22,11 +24,10 @@ using namespace llvm; void VPlanTransforms::VPInstructionsToVPRecipes( - Loop *OrigLoop, VPlanPtr &Plan, + VPlanPtr &Plan, function_ref<const InductionDescriptor *(PHINode *)> GetIntOrFpInductionDescriptor, - SmallPtrSetImpl<Instruction *> &DeadInstructions, ScalarEvolution &SE, - const TargetLibraryInfo &TLI) { + ScalarEvolution &SE, const TargetLibraryInfo &TLI) { ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( Plan->getEntry()); @@ -39,22 +40,15 @@ void VPlanTransforms::VPInstructionsToVPRecipes( VPValue *VPV = Ingredient.getVPSingleValue(); Instruction *Inst = cast<Instruction>(VPV->getUnderlyingValue()); - if (DeadInstructions.count(Inst)) { - VPValue DummyValue; - VPV->replaceAllUsesWith(&DummyValue); - Ingredient.eraseFromParent(); - continue; - } VPRecipeBase *NewRecipe = nullptr; if (auto *VPPhi = dyn_cast<VPWidenPHIRecipe>(&Ingredient)) { auto *Phi = cast<PHINode>(VPPhi->getUnderlyingValue()); if (const auto *II = GetIntOrFpInductionDescriptor(Phi)) { - VPValue *Start = Plan->getOrAddVPValue(II->getStartValue()); + VPValue *Start = Plan->getVPValueOrAddLiveIn(II->getStartValue()); VPValue *Step = vputils::getOrCreateVPValueForSCEVExpr(*Plan, II->getStep(), SE); - NewRecipe = - new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, *II, true); + NewRecipe = new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, *II); } else { Plan->addVPValue(Phi, VPPhi); continue; @@ -66,28 +60,25 @@ void VPlanTransforms::VPInstructionsToVPRecipes( // Create VPWidenMemoryInstructionRecipe for loads and stores. 
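The rewritten isCanonical check above accepts an induction as canonical only when it is an integer induction whose start matches the canonical IV's start value and whose step is a live-in constant 1 (a step defined by a recipe, e.g. one requiring SCEV expansion, disqualifies it). A small standalone model of that predicate, using hypothetical stand-ins for the VPlan classes and eliding the scalar-type check:

```cpp
#include <cassert>

// Hypothetical stand-in: a "VPValue" is either a live-in constant or is
// defined by a recipe (in which case no constant is available).
struct ModelVPValue {
  bool DefinedByRecipe = false;
  long long LiveInConst = 0; // only meaningful when !DefinedByRecipe
};

enum class InductionKind { IntInduction, FpInduction, PtrInduction };

// Mirrors the shape of VPCanonicalIVPHIRecipe::isCanonical after the patch.
static bool isCanonical(InductionKind Kind, const ModelVPValue *Start,
                        const ModelVPValue *CanonicalStart,
                        const ModelVPValue &Step) {
  if (Kind != InductionKind::IntInduction)
    return false;                 // only integer inductions can be canonical
  if (Start != CanonicalStart)
    return false;                 // start must be the same VPValue
  if (Step.DefinedByRecipe)
    return false;                 // step produced by a recipe, not a constant
  return Step.LiveInConst == 1;   // step must be the live-in constant 1
}

int main() {
  ModelVPValue Start, StepOne, StepTwo, ExpandedStep;
  StepOne.LiveInConst = 1;
  StepTwo.LiveInConst = 2;
  ExpandedStep.DefinedByRecipe = true;
  assert(isCanonical(InductionKind::IntInduction, &Start, &Start, StepOne));
  assert(!isCanonical(InductionKind::IntInduction, &Start, &Start, StepTwo));
  assert(!isCanonical(InductionKind::IntInduction, &Start, &Start, ExpandedStep));
}
```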
if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) { NewRecipe = new VPWidenMemoryInstructionRecipe( - *Load, Plan->getOrAddVPValue(getLoadStorePointerOperand(Inst)), - nullptr /*Mask*/, false /*Consecutive*/, false /*Reverse*/); + *Load, Ingredient.getOperand(0), nullptr /*Mask*/, + false /*Consecutive*/, false /*Reverse*/); } else if (StoreInst *Store = dyn_cast<StoreInst>(Inst)) { NewRecipe = new VPWidenMemoryInstructionRecipe( - *Store, Plan->getOrAddVPValue(getLoadStorePointerOperand(Inst)), - Plan->getOrAddVPValue(Store->getValueOperand()), nullptr /*Mask*/, - false /*Consecutive*/, false /*Reverse*/); + *Store, Ingredient.getOperand(1), Ingredient.getOperand(0), + nullptr /*Mask*/, false /*Consecutive*/, false /*Reverse*/); } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { - NewRecipe = new VPWidenGEPRecipe( - GEP, Plan->mapToVPValues(GEP->operands()), OrigLoop); + NewRecipe = new VPWidenGEPRecipe(GEP, Ingredient.operands()); } else if (CallInst *CI = dyn_cast<CallInst>(Inst)) { NewRecipe = - new VPWidenCallRecipe(*CI, Plan->mapToVPValues(CI->args()), + new VPWidenCallRecipe(*CI, drop_end(Ingredient.operands()), getVectorIntrinsicIDForCall(CI, &TLI)); } else if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) { - bool InvariantCond = - SE.isLoopInvariant(SE.getSCEV(SI->getOperand(0)), OrigLoop); - NewRecipe = new VPWidenSelectRecipe( - *SI, Plan->mapToVPValues(SI->operands()), InvariantCond); + NewRecipe = new VPWidenSelectRecipe(*SI, Ingredient.operands()); + } else if (auto *CI = dyn_cast<CastInst>(Inst)) { + NewRecipe = new VPWidenCastRecipe( + CI->getOpcode(), Ingredient.getOperand(0), CI->getType(), CI); } else { - NewRecipe = - new VPWidenRecipe(*Inst, Plan->mapToVPValues(Inst->operands())); + NewRecipe = new VPWidenRecipe(*Inst, Ingredient.operands()); } } @@ -98,15 +89,11 @@ void VPlanTransforms::VPInstructionsToVPRecipes( assert(NewRecipe->getNumDefinedValues() == 0 && "Only recpies with zero or one defined values expected"); Ingredient.eraseFromParent(); - Plan->removeVPValueFor(Inst); - for (auto *Def : NewRecipe->definedValues()) { - Plan->addVPValue(Inst, Def); - } } } } -bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) { +static bool sinkScalarOperands(VPlan &Plan) { auto Iter = vp_depth_first_deep(Plan.getEntry()); bool Changed = false; // First, collect the operands of all recipes in replicate blocks as seeds for @@ -167,8 +154,7 @@ bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) { continue; Instruction *I = cast<Instruction>( cast<VPReplicateRecipe>(SinkCandidate)->getUnderlyingValue()); - auto *Clone = - new VPReplicateRecipe(I, SinkCandidate->operands(), true, false); + auto *Clone = new VPReplicateRecipe(I, SinkCandidate->operands(), true); // TODO: add ".cloned" suffix to name of Clone's VPValue. Clone->insertBefore(SinkCandidate); @@ -224,7 +210,10 @@ static VPBasicBlock *getPredicatedThenBlock(VPRegionBlock *R) { return nullptr; } -bool VPlanTransforms::mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { +// Merge replicate regions in their successor region, if a replicate region +// is connected to a successor replicate region with the same predicate by a +// single, empty VPBasicBlock. 
+static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { SetVector<VPRegionBlock *> DeletedRegions; // Collect replicate regions followed by an empty block, followed by another @@ -312,6 +301,81 @@ bool VPlanTransforms::mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { return !DeletedRegions.empty(); } +static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, + VPlan &Plan) { + Instruction *Instr = PredRecipe->getUnderlyingInstr(); + // Build the triangular if-then region. + std::string RegionName = (Twine("pred.") + Instr->getOpcodeName()).str(); + assert(Instr->getParent() && "Predicated instruction not in any basic block"); + auto *BlockInMask = PredRecipe->getMask(); + auto *BOMRecipe = new VPBranchOnMaskRecipe(BlockInMask); + auto *Entry = new VPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe); + + // Replace predicated replicate recipe with a replicate recipe without a + // mask but in the replicate region. + auto *RecipeWithoutMask = new VPReplicateRecipe( + PredRecipe->getUnderlyingInstr(), + make_range(PredRecipe->op_begin(), std::prev(PredRecipe->op_end())), + PredRecipe->isUniform()); + auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", RecipeWithoutMask); + + VPPredInstPHIRecipe *PHIRecipe = nullptr; + if (PredRecipe->getNumUsers() != 0) { + PHIRecipe = new VPPredInstPHIRecipe(RecipeWithoutMask); + PredRecipe->replaceAllUsesWith(PHIRecipe); + PHIRecipe->setOperand(0, RecipeWithoutMask); + } + PredRecipe->eraseFromParent(); + auto *Exiting = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); + VPRegionBlock *Region = new VPRegionBlock(Entry, Exiting, RegionName, true); + + // Note: first set Entry as region entry and then connect successors starting + // from it in order, to propagate the "parent" of each VPBasicBlock. + VPBlockUtils::insertTwoBlocksAfter(Pred, Exiting, Entry); + VPBlockUtils::connectBlocks(Pred, Exiting); + + return Region; +} + +static void addReplicateRegions(VPlan &Plan) { + SmallVector<VPReplicateRecipe *> WorkList; + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( + vp_depth_first_deep(Plan.getEntry()))) { + for (VPRecipeBase &R : *VPBB) + if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R)) { + if (RepR->isPredicated()) + WorkList.push_back(RepR); + } + } + + unsigned BBNum = 0; + for (VPReplicateRecipe *RepR : WorkList) { + VPBasicBlock *CurrentBlock = RepR->getParent(); + VPBasicBlock *SplitBlock = CurrentBlock->splitAt(RepR->getIterator()); + + BasicBlock *OrigBB = RepR->getUnderlyingInstr()->getParent(); + SplitBlock->setName( + OrigBB->hasName() ? OrigBB->getName() + "." + Twine(BBNum++) : ""); + // Record predicated instructions for above packing optimizations. + VPBlockBase *Region = createReplicateRegion(RepR, Plan); + Region->setParent(CurrentBlock->getParent()); + VPBlockUtils::disconnectBlocks(CurrentBlock, SplitBlock); + VPBlockUtils::connectBlocks(CurrentBlock, Region); + VPBlockUtils::connectBlocks(Region, SplitBlock); + } +} + +void VPlanTransforms::createAndOptimizeReplicateRegions(VPlan &Plan) { + // Convert masked VPReplicateRecipes to if-then region blocks. 
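createReplicateRegion above wraps a masked VPReplicateRecipe in an entry/if/continue triangle: the entry block branches on the mask, the if-block runs the now-unmasked instruction, and the continue block PHIs the result back together. A scalar, per-lane model of that control flow in plain C++ (nothing here is VPlan API):

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Per-lane model of the pred.*.entry / pred.*.if / pred.*.continue triangle:
// lanes whose mask bit is clear skip the (possibly trapping) operation, and
// the continue-block PHI simply keeps "no value" for them.
static std::vector<std::optional<int>>
predicatedDivide(const std::vector<int> &A, const std::vector<int> &B,
                 const std::vector<bool> &Mask) {
  std::vector<std::optional<int>> Result(A.size());
  for (size_t Lane = 0; Lane < A.size(); ++Lane) {
    if (Mask[Lane])                      // entry: branch-on-mask
      Result[Lane] = A[Lane] / B[Lane];  // if-block: unmasked replicate
    // continue: the PHI merges the taken and not-taken paths
  }
  return Result;
}

int main() {
  auto R = predicatedDivide({8, 9}, {2, 0}, {true, false});
  assert(R[0] == 4 && !R[1].has_value()); // lane 1 never divides by zero
}
```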
+ addReplicateRegions(Plan); + + bool ShouldSimplify = true; + while (ShouldSimplify) { + ShouldSimplify = sinkScalarOperands(Plan); + ShouldSimplify |= mergeReplicateRegionsIntoSuccessors(Plan); + ShouldSimplify |= VPlanTransforms::mergeBlocksIntoPredecessors(Plan); + } +} bool VPlanTransforms::mergeBlocksIntoPredecessors(VPlan &Plan) { SmallVector<VPBasicBlock *> WorkList; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( @@ -395,7 +459,10 @@ void VPlanTransforms::removeRedundantCanonicalIVs(VPlan &Plan) { // everything WidenNewIV's users need. That is, WidenOriginalIV will // generate a vector phi or all users of WidenNewIV demand the first lane // only. - if (WidenOriginalIV->needsVectorIV() || + if (any_of(WidenOriginalIV->users(), + [WidenOriginalIV](VPUser *U) { + return !U->usesScalars(WidenOriginalIV); + }) || vputils::onlyFirstLaneUsed(WidenNewIV)) { WidenNewIV->replaceAllUsesWith(WidenOriginalIV); WidenNewIV->eraseFromParent(); @@ -440,10 +507,10 @@ void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { if (Instruction *TruncI = WideIV->getTruncInst()) ResultTy = TruncI->getType(); const InductionDescriptor &ID = WideIV->getInductionDescriptor(); - VPValue *Step = - vputils::getOrCreateVPValueForSCEVExpr(Plan, ID.getStep(), SE); + VPValue *Step = WideIV->getStepValue(); VPValue *BaseIV = CanonicalIV; - if (!CanonicalIV->isCanonical(ID, ResultTy)) { + if (!CanonicalIV->isCanonical(ID.getKind(), WideIV->getStartValue(), Step, + ResultTy)) { BaseIV = new VPDerivedIVRecipe(ID, WideIV->getStartValue(), CanonicalIV, Step, ResultTy); HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP); @@ -522,9 +589,9 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, return; LLVMContext &Ctx = SE.getContext(); - auto *BOC = - new VPInstruction(VPInstruction::BranchOnCond, - {Plan.getOrAddExternalDef(ConstantInt::getTrue(Ctx))}); + auto *BOC = new VPInstruction( + VPInstruction::BranchOnCond, + {Plan.getVPValueOrAddLiveIn(ConstantInt::getTrue(Ctx))}); Term->eraseFromParent(); ExitingVPBB->appendRecipe(BOC); Plan.setVF(BestVF); @@ -533,3 +600,181 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, // 1. Replace inductions with constants. // 2. Replace vector loop region with VPBasicBlock. 
} + +#ifndef NDEBUG +static VPRegionBlock *GetReplicateRegion(VPRecipeBase *R) { + auto *Region = dyn_cast_or_null<VPRegionBlock>(R->getParent()->getParent()); + if (Region && Region->isReplicator()) { + assert(Region->getNumSuccessors() == 1 && + Region->getNumPredecessors() == 1 && "Expected SESE region!"); + assert(R->getParent()->size() == 1 && + "A recipe in an original replicator region must be the only " + "recipe in its block"); + return Region; + } + return nullptr; +} +#endif + +static bool properlyDominates(const VPRecipeBase *A, const VPRecipeBase *B, + VPDominatorTree &VPDT) { + if (A == B) + return false; + + auto LocalComesBefore = [](const VPRecipeBase *A, const VPRecipeBase *B) { + for (auto &R : *A->getParent()) { + if (&R == A) + return true; + if (&R == B) + return false; + } + llvm_unreachable("recipe not found"); + }; + const VPBlockBase *ParentA = A->getParent(); + const VPBlockBase *ParentB = B->getParent(); + if (ParentA == ParentB) + return LocalComesBefore(A, B); + + assert(!GetReplicateRegion(const_cast<VPRecipeBase *>(A)) && + "No replicate regions expected at this point"); + assert(!GetReplicateRegion(const_cast<VPRecipeBase *>(B)) && + "No replicate regions expected at this point"); + return VPDT.properlyDominates(ParentA, ParentB); +} + +/// Sink users of \p FOR after the recipe defining the previous value \p +/// Previous of the recurrence. \returns true if all users of \p FOR could be +/// re-arranged as needed or false if it is not possible. +static bool +sinkRecurrenceUsersAfterPrevious(VPFirstOrderRecurrencePHIRecipe *FOR, + VPRecipeBase *Previous, + VPDominatorTree &VPDT) { + // Collect recipes that need sinking. + SmallVector<VPRecipeBase *> WorkList; + SmallPtrSet<VPRecipeBase *, 8> Seen; + Seen.insert(Previous); + auto TryToPushSinkCandidate = [&](VPRecipeBase *SinkCandidate) { + // The previous value must not depend on the users of the recurrence phi. In + // that case, FOR is not a fixed order recurrence. + if (SinkCandidate == Previous) + return false; + + if (isa<VPHeaderPHIRecipe>(SinkCandidate) || + !Seen.insert(SinkCandidate).second || + properlyDominates(Previous, SinkCandidate, VPDT)) + return true; + + if (SinkCandidate->mayHaveSideEffects()) + return false; + + WorkList.push_back(SinkCandidate); + return true; + }; + + // Recursively sink users of FOR after Previous. + WorkList.push_back(FOR); + for (unsigned I = 0; I != WorkList.size(); ++I) { + VPRecipeBase *Current = WorkList[I]; + assert(Current->getNumDefinedValues() == 1 && + "only recipes with a single defined value expected"); + + for (VPUser *User : Current->getVPSingleValue()->users()) { + if (auto *R = dyn_cast<VPRecipeBase>(User)) + if (!TryToPushSinkCandidate(R)) + return false; + } + } + + // Keep recipes to sink ordered by dominance so earlier instructions are + // processed first. 
+ sort(WorkList, [&VPDT](const VPRecipeBase *A, const VPRecipeBase *B) { + return properlyDominates(A, B, VPDT); + }); + + for (VPRecipeBase *SinkCandidate : WorkList) { + if (SinkCandidate == FOR) + continue; + + SinkCandidate->moveAfter(Previous); + Previous = SinkCandidate; + } + return true; +} + +bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan, + VPBuilder &Builder) { + VPDominatorTree VPDT; + VPDT.recalculate(Plan); + + SmallVector<VPFirstOrderRecurrencePHIRecipe *> RecurrencePhis; + for (VPRecipeBase &R : + Plan.getVectorLoopRegion()->getEntry()->getEntryBasicBlock()->phis()) + if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) + RecurrencePhis.push_back(FOR); + + for (VPFirstOrderRecurrencePHIRecipe *FOR : RecurrencePhis) { + SmallPtrSet<VPFirstOrderRecurrencePHIRecipe *, 4> SeenPhis; + VPRecipeBase *Previous = FOR->getBackedgeValue()->getDefiningRecipe(); + // Fixed-order recurrences do not contain cycles, so this loop is guaranteed + // to terminate. + while (auto *PrevPhi = + dyn_cast_or_null<VPFirstOrderRecurrencePHIRecipe>(Previous)) { + assert(PrevPhi->getParent() == FOR->getParent()); + assert(SeenPhis.insert(PrevPhi).second); + Previous = PrevPhi->getBackedgeValue()->getDefiningRecipe(); + } + + if (!sinkRecurrenceUsersAfterPrevious(FOR, Previous, VPDT)) + return false; + + // Introduce a recipe to combine the incoming and previous values of a + // fixed-order recurrence. + VPBasicBlock *InsertBlock = Previous->getParent(); + if (isa<VPHeaderPHIRecipe>(Previous)) + Builder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi()); + else + Builder.setInsertPoint(InsertBlock, std::next(Previous->getIterator())); + + auto *RecurSplice = cast<VPInstruction>( + Builder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice, + {FOR, FOR->getBackedgeValue()})); + + FOR->replaceAllUsesWith(RecurSplice); + // Set the first operand of RecurSplice to FOR again, after replacing + // all users. + RecurSplice->setOperand(0, FOR); + } + return true; +} + +void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) { + for (VPRecipeBase &R : + Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) { + auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R); + if (!PhiR) + continue; + const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor(); + RecurKind RK = RdxDesc.getRecurrenceKind(); + if (RK != RecurKind::Add && RK != RecurKind::Mul) + continue; + + SmallSetVector<VPValue *, 8> Worklist; + Worklist.insert(PhiR); + + for (unsigned I = 0; I != Worklist.size(); ++I) { + VPValue *Cur = Worklist[I]; + if (auto *RecWithFlags = + dyn_cast<VPRecipeWithIRFlags>(Cur->getDefiningRecipe())) { + RecWithFlags->dropPoisonGeneratingFlags(); + } + + for (VPUser *U : Cur->users()) { + auto *UserRecipe = dyn_cast<VPRecipeBase>(U); + if (!UserRecipe) + continue; + for (VPValue *V : UserRecipe->definedValues()) + Worklist.insert(V); + } + } + } +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index be0d8e76d809..3eccf6e9600d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -25,23 +25,23 @@ class ScalarEvolution; class Loop; class PredicatedScalarEvolution; class TargetLibraryInfo; +class VPBuilder; +class VPRecipeBuilder; struct VPlanTransforms { /// Replaces the VPInstructions in \p Plan with corresponding /// widen recipes. 
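adjustFixedOrderRecurrences above introduces FirstOrderRecurrenceSplice VPInstructions which, like the CreateVectorSplice(..., -1) call earlier in this file, combine the last lane of the previous value with the first VF-1 lanes of the current one. A standalone model of that splice on plain vectors (names are illustrative):

```cpp
#include <cassert>
#include <vector>

// Models Builder.CreateVectorSplice(Prev, Cur, -1): conceptually concatenate
// Prev and Cur and take a VF-wide window starting one lane before the seam,
// i.e. { Prev[VF-1], Cur[0], ..., Cur[VF-2] }.
static std::vector<int> spliceMinusOne(const std::vector<int> &Prev,
                                       const std::vector<int> &Cur) {
  std::vector<int> Result;
  Result.push_back(Prev.back());
  Result.insert(Result.end(), Cur.begin(), Cur.end() - 1);
  return Result;
}

int main() {
  // VF = 4: lane i of the spliced value is the element produced one
  // iteration earlier, which lives either in Prev or in Cur.
  std::vector<int> Expected{3, 4, 5, 6};
  assert(spliceMinusOne({0, 1, 2, 3}, {4, 5, 6, 7}) == Expected);
}
```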
static void - VPInstructionsToVPRecipes(Loop *OrigLoop, VPlanPtr &Plan, + VPInstructionsToVPRecipes(VPlanPtr &Plan, function_ref<const InductionDescriptor *(PHINode *)> GetIntOrFpInductionDescriptor, - SmallPtrSetImpl<Instruction *> &DeadInstructions, ScalarEvolution &SE, const TargetLibraryInfo &TLI); - static bool sinkScalarOperands(VPlan &Plan); - - /// Merge replicate regions in their successor region, if a replicate region - /// is connected to a successor replicate region with the same predicate by a - /// single, empty VPBasicBlock. - static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan); + /// Wrap predicated VPReplicateRecipes with a mask operand in an if-then + /// region block and remove the mask operand. Optimize the created regions by + /// iteratively sinking scalar operands into the region, followed by merging + /// regions until no improvements are remaining. + static void createAndOptimizeReplicateRegions(VPlan &Plan); /// Remove redundant VPBasicBlocks by merging them into their predecessor if /// the predecessor has a single successor. @@ -71,6 +71,19 @@ struct VPlanTransforms { /// them with already existing recipes expanding the same SCEV expression. static void removeRedundantExpandSCEVRecipes(VPlan &Plan); + /// Sink users of fixed-order recurrences after the recipe defining their + /// previous value. Then introduce FirstOrderRecurrenceSplice VPInstructions + /// to combine the value from the recurrence phis and previous values. The + /// current implementation assumes all users can be sunk after the previous + /// value, which is enforced by earlier legality checks. + /// \returns true if all users of fixed-order recurrences could be re-arranged + /// as needed or false if it is not possible. In the latter case, \p Plan is + /// not valid. + static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder); + + /// Clear NSW/NUW flags from reduction instructions if necessary. + static void clearReductionWrapFlags(VPlan &Plan); + /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the /// resulting plan to \p BestVF and \p BestUF. static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h index 62ec65cbfe5d..ac110bb3b0ef 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -171,16 +171,19 @@ public: /// Returns true if this VPValue is defined by a recipe. bool hasDefiningRecipe() const { return getDefiningRecipe(); } + /// Returns true if this VPValue is a live-in, i.e. defined outside the VPlan. + bool isLiveIn() const { return !hasDefiningRecipe(); } + /// Returns the underlying IR value, if this VPValue is defined outside the /// scope of VPlan. Returns nullptr if the VPValue is defined by a VPDef /// inside a VPlan. Value *getLiveInIRValue() { - assert(!hasDefiningRecipe() && + assert(isLiveIn() && "VPValue is not a live-in; it is defined by a VPDef inside a VPlan"); return getUnderlyingValue(); } const Value *getLiveInIRValue() const { - assert(!hasDefiningRecipe() && + assert(isLiveIn() && "VPValue is not a live-in; it is defined by a VPDef inside a VPlan"); return getUnderlyingValue(); } @@ -342,15 +345,16 @@ public: VPScalarIVStepsSC, VPWidenCallSC, VPWidenCanonicalIVSC, + VPWidenCastSC, VPWidenGEPSC, VPWidenMemoryInstructionSC, VPWidenSC, VPWidenSelectSC, - - // Phi-like recipes. Need to be kept together. + // START: Phi-like recipes. Need to be kept together. 
VPBlendSC, VPPredInstPHISC, - // Header-phi recipes. Need to be kept together. + // START: SubclassID for recipes that inherit VPHeaderPHIRecipe. + // VPHeaderPHIRecipe need to be kept together. VPCanonicalIVPHISC, VPActiveLaneMaskPHISC, VPFirstOrderRecurrencePHISC, @@ -358,8 +362,11 @@ public: VPWidenIntOrFpInductionSC, VPWidenPointerInductionSC, VPReductionPHISC, + // END: SubclassID for recipes that inherit VPHeaderPHIRecipe + // END: Phi-like recipes VPFirstPHISC = VPBlendSC, VPFirstHeaderPHISC = VPCanonicalIVPHISC, + VPLastHeaderPHISC = VPReductionPHISC, VPLastPHISC = VPReductionPHISC, }; @@ -434,6 +441,7 @@ class VPSlotTracker { void assignSlot(const VPValue *V); void assignSlots(const VPlan &Plan); + void assignSlots(const VPBasicBlock *VPBB); public: VPSlotTracker(const VPlan *Plan = nullptr) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp index 18125cebed33..d6b81543dbc9 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -15,6 +15,7 @@ #include "VPlanVerifier.h" #include "VPlan.h" #include "VPlanCFG.h" +#include "VPlanDominatorTree.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/Support/CommandLine.h" @@ -189,9 +190,8 @@ static bool verifyPhiRecipes(const VPBasicBlock *VPBB) { return true; } -static bool -verifyVPBasicBlock(const VPBasicBlock *VPBB, - DenseMap<const VPBlockBase *, unsigned> &BlockNumbering) { +static bool verifyVPBasicBlock(const VPBasicBlock *VPBB, + VPDominatorTree &VPDT) { if (!verifyPhiRecipes(VPBB)) return false; @@ -206,7 +206,8 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB, for (const VPValue *V : R.definedValues()) { for (const VPUser *U : V->users()) { auto *UI = dyn_cast<VPRecipeBase>(U); - if (!UI || isa<VPHeaderPHIRecipe>(UI)) + // TODO: check dominance of incoming values for phis properly. + if (!UI || isa<VPHeaderPHIRecipe>(UI) || isa<VPPredInstPHIRecipe>(UI)) continue; // If the user is in the same block, check it comes after R in the @@ -219,27 +220,7 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB, continue; } - // Skip blocks outside any region for now and blocks outside - // replicate-regions. - auto *ParentR = VPBB->getParent(); - if (!ParentR || !ParentR->isReplicator()) - continue; - - // For replicators, verify that VPPRedInstPHIRecipe defs are only used - // in subsequent blocks. - if (isa<VPPredInstPHIRecipe>(&R)) { - auto I = BlockNumbering.find(UI->getParent()); - unsigned BlockNumber = I == BlockNumbering.end() ? std::numeric_limits<unsigned>::max() : I->second; - if (BlockNumber < BlockNumbering[ParentR]) { - errs() << "Use before def!\n"; - return false; - } - continue; - } - - // All non-VPPredInstPHIRecipe recipes in the block must be used in - // the replicate region only. 
- if (UI->getParent()->getParent() != ParentR) { + if (!VPDT.dominates(VPBB, UI->getParent())) { errs() << "Use before def!\n"; return false; } @@ -250,15 +231,13 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB, } bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { - DenseMap<const VPBlockBase *, unsigned> BlockNumbering; - unsigned Cnt = 0; + VPDominatorTree VPDT; + VPDT.recalculate(const_cast<VPlan &>(Plan)); + auto Iter = vp_depth_first_deep(Plan.getEntry()); - for (const VPBlockBase *VPB : Iter) { - BlockNumbering[VPB] = Cnt++; - auto *VPBB = dyn_cast<VPBasicBlock>(VPB); - if (!VPBB) - continue; - if (!verifyVPBasicBlock(VPBB, BlockNumbering)) + for (const VPBasicBlock *VPBB : + VPBlockUtils::blocksOnly<const VPBasicBlock>(Iter)) { + if (!verifyVPBasicBlock(VPBB, VPDT)) return false; } diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 2e489757ebc1..13464c9d3496 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -25,11 +25,8 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Vectorize.h" #include <numeric> #define DEBUG_TYPE "vector-combine" @@ -247,7 +244,7 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // still need a shuffle to change the vector size. auto *Ty = cast<FixedVectorType>(I.getType()); unsigned OutputNumElts = Ty->getNumElements(); - SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem); + SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem); assert(OffsetEltIndex < MinVecNumElts && "Address offset too big"); Mask[0] = OffsetEltIndex; if (OffsetEltIndex) @@ -460,9 +457,9 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, // If we are extracting from 2 different indexes, then one operand must be // shuffled before performing the vector operation. The shuffle mask is - // undefined except for 1 lane that is being translated to the remaining + // poison except for 1 lane that is being translated to the remaining // extraction lane. Therefore, it is a splat shuffle. Ex: - // ShufMask = { undef, undef, 0, undef } + // ShufMask = { poison, poison, 0, poison } // TODO: The cost model has an option for a "broadcast" shuffle // (splat-from-element-0), but no option for a more general splat. NewCost += @@ -479,11 +476,11 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, /// to a new element location. static Value *createShiftShuffle(Value *Vec, unsigned OldIndex, unsigned NewIndex, IRBuilder<> &Builder) { - // The shuffle mask is undefined except for 1 lane that is being translated + // The shuffle mask is poison except for 1 lane that is being translated // to the new element index. 
Example for OldIndex == 2 and NewIndex == 0: - // ShufMask = { 2, undef, undef, undef } + // ShufMask = { 2, poison, poison, poison } auto *VecTy = cast<FixedVectorType>(Vec->getType()); - SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem); + SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem); ShufMask[NewIndex] = OldIndex; return Builder.CreateShuffleVector(Vec, ShufMask, "shift"); } @@ -917,7 +914,7 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) { auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType())); InstructionCost NewCost = TTI.getCmpSelInstrCost( CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred); - SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem); + SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem); ShufMask[CheapIndex] = ExpensiveIndex; NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy, ShufMask); @@ -932,7 +929,7 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) { // Create a vector constant from the 2 scalar constants. SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(), - UndefValue::get(VecTy->getElementType())); + PoisonValue::get(VecTy->getElementType())); CmpC[Index0] = C0; CmpC[Index1] = C1; Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC)); @@ -1565,7 +1562,7 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { // Calculate our ReconstructMasks from the OrigReconstructMasks and the // modified order of the input shuffles. SmallVector<SmallVector<int>> ReconstructMasks; - for (auto Mask : OrigReconstructMasks) { + for (const auto &Mask : OrigReconstructMasks) { SmallVector<int> ReconstructMask; for (int M : Mask) { auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) { @@ -1596,12 +1593,12 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first)); } while (V1A.size() < NumElts) { - V1A.push_back(UndefMaskElem); - V1B.push_back(UndefMaskElem); + V1A.push_back(PoisonMaskElem); + V1B.push_back(PoisonMaskElem); } while (V2A.size() < NumElts) { - V2A.push_back(UndefMaskElem); - V2B.push_back(UndefMaskElem); + V2A.push_back(PoisonMaskElem); + V2B.push_back(PoisonMaskElem); } auto AddShuffleCost = [&](InstructionCost C, Instruction *I) { @@ -1660,16 +1657,16 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { return SSV->getOperand(Op); return SV->getOperand(Op); }; - Builder.SetInsertPoint(SVI0A->getNextNode()); + Builder.SetInsertPoint(SVI0A->getInsertionPointAfterDef()); Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0), GetShuffleOperand(SVI0A, 1), V1A); - Builder.SetInsertPoint(SVI0B->getNextNode()); + Builder.SetInsertPoint(SVI0B->getInsertionPointAfterDef()); Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0), GetShuffleOperand(SVI0B, 1), V1B); - Builder.SetInsertPoint(SVI1A->getNextNode()); + Builder.SetInsertPoint(SVI1A->getInsertionPointAfterDef()); Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0), GetShuffleOperand(SVI1A, 1), V2A); - Builder.SetInsertPoint(SVI1B->getNextNode()); + Builder.SetInsertPoint(SVI1B->getInsertionPointAfterDef()); Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0), GetShuffleOperand(SVI1B, 1), V2B); Builder.SetInsertPoint(Op0); @@ -1811,54 +1808,6 @@ bool VectorCombine::run() { return MadeChange; } -// Pass manager boilerplate 
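The shuffles in this hunk use masks that are poison in every lane except the one being translated, e.g. { 2, poison, poison, poison } to move lane 2 into lane 0. A standalone model of applying such a mask, using -1 as the poison sentinel much like PoisonMaskElem (the helper names are illustrative):

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Applies a shufflevector-style mask to a single source vector; a mask
// element of -1 stands for a poison lane, any other value selects that
// source lane.
static std::vector<std::optional<int>>
applyMask(const std::vector<int> &Src, const std::vector<int> &Mask) {
  std::vector<std::optional<int>> Result(Mask.size());
  for (size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I] >= 0)
      Result[I] = Src[Mask[I]];
  return Result;
}

int main() {
  // createShiftShuffle with OldIndex == 2, NewIndex == 0 on a 4-lane vector:
  // ShufMask = { 2, -1, -1, -1 } translates lane 2 into lane 0; all other
  // output lanes are poison.
  auto R = applyMask({10, 11, 12, 13}, {2, -1, -1, -1});
  assert(R[0] == 12 && !R[1].has_value());
}
```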
below here. - -namespace { -class VectorCombineLegacyPass : public FunctionPass { -public: - static char ID; - VectorCombineLegacyPass() : FunctionPass(ID) { - initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.setPreservesCFG(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<BasicAAWrapperPass>(); - FunctionPass::getAnalysisUsage(AU); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - VectorCombine Combiner(F, TTI, DT, AA, AC, false); - return Combiner.run(); - } -}; -} // namespace - -char VectorCombineLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine", - "Optimize scalar/vector ops", false, - false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine", - "Optimize scalar/vector ops", false, false) -Pass *llvm::createVectorCombinePass() { - return new VectorCombineLegacyPass(); -} - PreservedAnalyses VectorCombinePass::run(Function &F, FunctionAnalysisManager &FAM) { auto &AC = FAM.getResult<AssumptionAnalysis>(F); diff --git a/llvm/lib/Transforms/Vectorize/Vectorize.cpp b/llvm/lib/Transforms/Vectorize/Vectorize.cpp index 208e5eeea864..2f5048d2a664 100644 --- a/llvm/lib/Transforms/Vectorize/Vectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/Vectorize.cpp @@ -12,10 +12,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Vectorize.h" -#include "llvm-c/Initialization.h" -#include "llvm-c/Transforms/Vectorize.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/InitializePasses.h" #include "llvm/PassRegistry.h" @@ -23,20 +19,5 @@ using namespace llvm; /// Initialize all passes linked into the Vectorization library. void llvm::initializeVectorization(PassRegistry &Registry) { - initializeLoopVectorizePass(Registry); - initializeSLPVectorizerPass(Registry); initializeLoadStoreVectorizerLegacyPassPass(Registry); - initializeVectorCombineLegacyPassPass(Registry); -} - -void LLVMInitializeVectorization(LLVMPassRegistryRef R) { - initializeVectorization(*unwrap(R)); -} - -void LLVMAddLoopVectorizePass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopVectorizePass()); -} - -void LLVMAddSLPVectorizePass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createSLPVectorizerPass()); } |