diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
| commit | b1c73532ee8997fe5dfbeb7d223027bdf99758a0 (patch) | |
| tree | 7d6e51c294ab6719475d660217aa0c0ad0526292 /llvm/lib/Transforms | |
| parent | 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff) | |
Diffstat (limited to 'llvm/lib/Transforms')
206 files changed, 20790 insertions, 12882 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 34c8a380448e..d09ac1c099c1 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -19,7 +19,6 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -29,7 +28,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" @@ -373,7 +371,7 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) { InstructionCost SatCost = TTI.getIntrinsicInstrCost( IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}), TTI::TCK_RecipThroughput); - SatCost += TTI.getCastInstrCost(Instruction::SExt, SatTy, IntTy, + SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy, TTI::CastContextHint::None, TTI::TCK_RecipThroughput); @@ -398,6 +396,54 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) { return true; } +/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids +/// pessimistic codegen that has to account for setting errno and can enable +/// vectorization. +static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI, + TargetLibraryInfo &TLI, AssumptionCache &AC, + DominatorTree &DT) { + // Match a call to sqrt mathlib function. 
+ auto *Call = dyn_cast<CallInst>(&I); + if (!Call) + return false; + + Module *M = Call->getModule(); + LibFunc Func; + if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func)) + return false; + + if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl) + return false; + + // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created + // (because NNAN or the operand arg must not be less than -0.0) and (2) we + // would not end up lowering to a libcall anyway (which could change the value + // of errno), then: + // (1) errno won't be set. + // (2) it is safe to convert this to an intrinsic call. + Type *Ty = Call->getType(); + Value *Arg = Call->getArgOperand(0); + if (TTI.haveFastSqrt(Ty) && + (Call->hasNoNaNs() || + cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, &I, + &DT))) { + IRBuilder<> Builder(&I); + IRBuilderBase::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(Call->getFastMathFlags()); + + Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty); + Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt"); + I.replaceAllUsesWith(NewSqrt); + + // Explicitly erase the old call because a call with side effects is not + // trivially dead. + I.eraseFromParent(); + return true; + } + + return false; +} + // Check if this array of constants represents a cttz table. // Iterate over the elements from \p Table by trying to find/match all // the numbers from 0 to \p InputBits that should represent cttz results. 
@@ -447,7 +493,8 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, // %shr = lshr i32 %mul, 27 // %idxprom = zext i32 %shr to i64 // %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0, -// i64 %idxprom %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 +// i64 %idxprom +// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 // // CASE 2: // %sub = sub i32 0, %x @@ -455,8 +502,9 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, // %mul = mul i32 %and, 72416175 // %shr = lshr i32 %mul, 26 // %idxprom = zext i32 %shr to i64 -// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table, i64 -// 0, i64 %idxprom %0 = load i16, i16* %arrayidx, align 2, !tbaa !8 +// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table, +// i64 0, i64 %idxprom +// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8 // // CASE 3: // %sub = sub i32 0, %x @@ -464,16 +512,18 @@ static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, // %mul = mul i32 %and, 81224991 // %shr = lshr i32 %mul, 27 // %idxprom = zext i32 %shr to i64 -// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table, i64 -// 0, i64 %idxprom %0 = load i32, i32* %arrayidx, align 4, !tbaa !8 +// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table, +// i64 0, i64 %idxprom +// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8 // // CASE 4: // %sub = sub i64 0, %x // %and = and i64 %sub, %x // %mul = mul i64 %and, 283881067100198605 // %shr = lshr i64 %mul, 58 -// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0, i64 -// %shr %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 +// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0, +// i64 %shr +// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 // // All this can be lowered to @llvm.cttz.i32/64 intrinsic. 
static bool tryToRecognizeTableBasedCttz(Instruction &I) { @@ -656,7 +706,10 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, make_range(Start->getIterator(), End->getIterator())) { if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc))) return false; - if (++NumScanned > MaxInstrsToScan) + + // Ignore debug info so that's not counted against MaxInstrsToScan. + // Otherwise debug info could affect codegen. + if (!isa<DbgInfoIntrinsic>(Inst) && ++NumScanned > MaxInstrsToScan) return false; } @@ -869,159 +922,13 @@ static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) { return true; } -/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids -/// pessimistic codegen that has to account for setting errno and can enable -/// vectorization. -static bool foldSqrt(CallInst *Call, TargetTransformInfo &TTI, - TargetLibraryInfo &TLI, AssumptionCache &AC, - DominatorTree &DT) { - Module *M = Call->getModule(); - - // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created - // (because NNAN or the operand arg must not be less than -0.0) and (2) we - // would not end up lowering to a libcall anyway (which could change the value - // of errno), then: - // (1) errno won't be set. - // (2) it is safe to convert this to an intrinsic call. - Type *Ty = Call->getType(); - Value *Arg = Call->getArgOperand(0); - if (TTI.haveFastSqrt(Ty) && - (Call->hasNoNaNs() || - cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, Call, - &DT))) { - IRBuilder<> Builder(Call); - IRBuilderBase::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(Call->getFastMathFlags()); - - Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty); - Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt"); - Call->replaceAllUsesWith(NewSqrt); - - // Explicitly erase the old call because a call with side effects is not - // trivially dead. 
- Call->eraseFromParent(); - return true; - } - - return false; -} - -/// Try to expand strcmp(P, "x") calls. -static bool expandStrcmp(CallInst *CI, DominatorTree &DT, bool &MadeCFGChange) { - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); - - // Trivial cases are optimized during inst combine - if (Str1P == Str2P) - return false; - - StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); - - Value *NonConstantP = nullptr; - StringRef ConstantStr; - - if (!HasStr1 && HasStr2 && Str2.size() == 1) { - NonConstantP = Str1P; - ConstantStr = Str2; - } else if (!HasStr2 && HasStr1 && Str1.size() == 1) { - NonConstantP = Str2P; - ConstantStr = Str1; - } else { - return false; - } - - // Check if strcmp result is only used in a comparison with zero - if (!isOnlyUsedInZeroComparison(CI)) - return false; - - // For strcmp(P, "x") do the following transformation: - // - // (before) - // dst = strcmp(P, "x") - // - // (after) - // v0 = P[0] - 'x' - // [if v0 == 0] - // v1 = P[1] - // dst = phi(v0, v1) - // - - IRBuilder<> B(CI->getParent()); - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); - - Type *RetType = CI->getType(); - - B.SetInsertPoint(CI); - BasicBlock *InitialBB = B.GetInsertBlock(); - Value *Str1FirstCharacterValue = - B.CreateZExt(B.CreateLoad(B.getInt8Ty(), NonConstantP), RetType); - Value *Str2FirstCharacterValue = - ConstantInt::get(RetType, static_cast<unsigned char>(ConstantStr[0])); - Value *FirstCharacterSub = - B.CreateNSWSub(Str1FirstCharacterValue, Str2FirstCharacterValue); - Value *IsFirstCharacterSubZero = - B.CreateICmpEQ(FirstCharacterSub, ConstantInt::get(RetType, 0)); - Instruction *IsFirstCharacterSubZeroBBTerminator = SplitBlockAndInsertIfThen( - IsFirstCharacterSubZero, CI, /*Unreachable*/ false, - /*BranchWeights*/ nullptr, &DTU); - - B.SetInsertPoint(IsFirstCharacterSubZeroBBTerminator); - 
B.GetInsertBlock()->setName("strcmp_expand_sub_is_zero"); - BasicBlock *IsFirstCharacterSubZeroBB = B.GetInsertBlock(); - Value *Str1SecondCharacterValue = B.CreateZExt( - B.CreateLoad(B.getInt8Ty(), B.CreateConstInBoundsGEP1_64( - B.getInt8Ty(), NonConstantP, 1)), - RetType); - - B.SetInsertPoint(CI); - B.GetInsertBlock()->setName("strcmp_expand_sub_join"); - - PHINode *Result = B.CreatePHI(RetType, 2); - Result->addIncoming(FirstCharacterSub, InitialBB); - Result->addIncoming(Str1SecondCharacterValue, IsFirstCharacterSubZeroBB); - - CI->replaceAllUsesWith(Result); - CI->eraseFromParent(); - - MadeCFGChange = true; - - return true; -} - -static bool foldLibraryCalls(Instruction &I, TargetTransformInfo &TTI, - TargetLibraryInfo &TLI, DominatorTree &DT, - AssumptionCache &AC, bool &MadeCFGChange) { - CallInst *CI = dyn_cast<CallInst>(&I); - if (!CI) - return false; - - LibFunc Func; - Module *M = I.getModule(); - if (!TLI.getLibFunc(*CI, Func) || !isLibFuncEmittable(M, &TLI, Func)) - return false; - - switch (Func) { - case LibFunc_sqrt: - case LibFunc_sqrtf: - case LibFunc_sqrtl: - return foldSqrt(CI, TTI, TLI, AC, DT); - case LibFunc_strcmp: - return expandStrcmp(CI, DT, MadeCFGChange); - default: - break; - } - - return false; -} - /// This is the entry point for folds that could be implemented in regular /// InstCombine, but they are separated because they are not expected to /// occur frequently and/or have more than a constant-length pattern match. static bool foldUnusualPatterns(Function &F, DominatorTree &DT, TargetTransformInfo &TTI, TargetLibraryInfo &TLI, AliasAnalysis &AA, - AssumptionCache &AC, bool &MadeCFGChange) { + AssumptionCache &AC) { bool MadeChange = false; for (BasicBlock &BB : F) { // Ignore unreachable basic blocks. 
@@ -1046,7 +953,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, // NOTE: This function introduces erasing of the instruction `I`, so it // needs to be called at the end of this sequence, otherwise we may make // bugs. - MadeChange |= foldLibraryCalls(I, TTI, TLI, DT, AC, MadeCFGChange); + MadeChange |= foldSqrt(I, TTI, TLI, AC, DT); } } @@ -1062,12 +969,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, /// handled in the callers of this function. static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI, TargetLibraryInfo &TLI, DominatorTree &DT, - AliasAnalysis &AA, bool &ChangedCFG) { + AliasAnalysis &AA) { bool MadeChange = false; const DataLayout &DL = F.getParent()->getDataLayout(); TruncInstCombine TIC(AC, TLI, DL, DT); MadeChange |= TIC.run(F); - MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, ChangedCFG); + MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC); return MadeChange; } @@ -1078,21 +985,12 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); - - bool MadeCFGChange = false; - - if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) { + if (!runImpl(F, AC, TTI, TLI, DT, AA)) { // No changes, all analyses are preserved. return PreservedAnalyses::all(); } - // Mark all the analyses that instcombine updates as preserved. 
PreservedAnalyses PA; - - if (MadeCFGChange) - PA.preserve<DominatorTreeAnalysis>(); - else - PA.preserveSet<CFGAnalyses>(); - + PA.preserveSet<CFGAnalyses>(); return PA; } diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 6c62e84077ac..4d9050be5c55 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -366,7 +366,7 @@ static Type *getReducedType(Value *V, Type *Ty) { Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) { Type *Ty = getReducedType(V, SclTy); if (auto *C = dyn_cast<Constant>(V)) { - C = ConstantExpr::getIntegerCast(C, Ty, false); + C = ConstantExpr::getTrunc(C, Ty); // If we got a constantexpr back, try to simplify it with DL info. return ConstantFoldConstant(C, DL, &TLI); } diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp index bf823ac55497..387734358775 100644 --- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp +++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp @@ -177,8 +177,7 @@ void CFGuard::insertCFGuardCheck(CallBase *CB) { // Create new call instruction. The CFGuard check should always be a call, // even if the original CallBase is an Invoke or CallBr instruction. CallInst *GuardCheck = - B.CreateCall(GuardFnType, GuardCheckLoad, - {B.CreateBitCast(CalledOperand, B.getInt8PtrTy())}, Bundles); + B.CreateCall(GuardFnType, GuardCheckLoad, {CalledOperand}, Bundles); // Ensure that the first argument is passed in the correct register // (e.g. ECX on 32-bit X86 targets). @@ -196,11 +195,6 @@ void CFGuard::insertCFGuardDispatch(CallBase *CB) { Value *CalledOperand = CB->getCalledOperand(); Type *CalledOperandType = CalledOperand->getType(); - // Cast the guard dispatch global to the type of the called operand. 
- PointerType *PTy = PointerType::get(CalledOperandType, 0); - if (GuardFnGlobal->getType() != PTy) - GuardFnGlobal = ConstantExpr::getBitCast(GuardFnGlobal, PTy); - // Load the global as a pointer to a function of the same type. LoadInst *GuardDispatchLoad = B.CreateLoad(CalledOperandType, GuardFnGlobal); @@ -236,8 +230,9 @@ bool CFGuard::doInitialization(Module &M) { return false; // Set up prototypes for the guard check and dispatch functions. - GuardFnType = FunctionType::get(Type::getVoidTy(M.getContext()), - {Type::getInt8PtrTy(M.getContext())}, false); + GuardFnType = + FunctionType::get(Type::getVoidTy(M.getContext()), + {PointerType::getUnqual(M.getContext())}, false); GuardFnPtrType = PointerType::get(GuardFnType, 0); // Get or insert the guard check or dispatch global symbols. diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index 29978bef661c..3e3825fcd50e 100644 --- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -29,15 +29,13 @@ struct Lowerer : coro::LowererBase { static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) { Builder.SetInsertPoint(SubFn); - Value *FrameRaw = SubFn->getFrame(); + Value *FramePtr = SubFn->getFrame(); int Index = SubFn->getIndex(); - auto *FrameTy = StructType::get( - SubFn->getContext(), {Builder.getInt8PtrTy(), Builder.getInt8PtrTy()}); - PointerType *FramePtrTy = FrameTy->getPointerTo(); + auto *FrameTy = StructType::get(SubFn->getContext(), + {Builder.getPtrTy(), Builder.getPtrTy()}); Builder.SetInsertPoint(SubFn); - auto *FramePtr = Builder.CreateBitCast(FrameRaw, FramePtrTy); auto *Gep = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index); auto *Load = Builder.CreateLoad(FrameTy->getElementType(Index), Gep); diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp index d78ab1c1ea28..2f4083028ae0 100644 --- 
a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -165,7 +165,7 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize, auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); Frame->setAlignment(FrameAlign); auto *FrameVoidPtr = - new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt); + new BitCastInst(Frame, PointerType::getUnqual(C), "vFrame", InsertPt); for (auto *CB : CoroBegins) { CB->replaceAllUsesWith(FrameVoidPtr); @@ -194,12 +194,49 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB, for (auto *DA : It->second) Visited.insert(DA->getParent()); + SmallPtrSet<const BasicBlock *, 32> EscapingBBs; + for (auto *U : CB->users()) { + // The use from coroutine intrinsics are not a problem. + if (isa<CoroFreeInst, CoroSubFnInst, CoroSaveInst>(U)) + continue; + + // Think all other usages may be an escaping candidate conservatively. + // + // Note that the major user of switch ABI coroutine (the C++) will store + // resume.fn, destroy.fn and the index to the coroutine frame immediately. + // So the parent of the coro.begin in C++ will be always escaping. + // Then we can't get any performance benefits for C++ by improving the + // precision of the method. + // + // The reason why we still judge it is we want to make LLVM Coroutine in + // switch ABIs to be self contained as much as possible instead of a + // by-product of C++20 Coroutines. + EscapingBBs.insert(cast<Instruction>(U)->getParent()); + } + + bool PotentiallyEscaped = false; + do { const auto *BB = Worklist.pop_back_val(); if (!Visited.insert(BB).second) continue; - if (TIs.count(BB)) - return true; + + // A Path insensitive marker to test whether the coro.begin escapes. + // It is intentional to make it path insensitive while it may not be + // precise since we don't want the process to be too slow. 
+ PotentiallyEscaped |= EscapingBBs.count(BB); + + if (TIs.count(BB)) { + if (isa<ReturnInst>(BB->getTerminator()) || PotentiallyEscaped) + return true; + + // If the function ends with the exceptional terminator, the memory used + // by the coroutine frame can be released by stack unwinding + // automatically. So we can think the coro.begin doesn't escape if it + // exits the function by exceptional terminator. + + continue; + } // Conservatively say that there is potentially a path. if (!--Limit) @@ -236,36 +273,36 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { // memory location storing that value and not the virtual register. SmallPtrSet<BasicBlock *, 8> Terminators; - // First gather all of the non-exceptional terminators for the function. + // First gather all of the terminators for the function. // Consider the final coro.suspend as the real terminator when the current // function is a coroutine. - for (BasicBlock &B : *F) { - auto *TI = B.getTerminator(); - if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() && - !isa<UnreachableInst>(TI)) - Terminators.insert(&B); - } + for (BasicBlock &B : *F) { + auto *TI = B.getTerminator(); + + if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI)) + continue; + + Terminators.insert(&B); + } // Filter out the coro.destroy that lie along exceptional paths. SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; for (const auto &It : DestroyAddr) { - // If there is any coro.destroy dominates all of the terminators for the - // coro.begin, we could know the corresponding coro.begin wouldn't escape. - for (Instruction *DA : It.second) { - if (llvm::all_of(Terminators, [&](auto *TI) { - return DT.dominates(DA, TI->getTerminator()); - })) { - ReferencedCoroBegins.insert(It.first); - break; - } - } - - // Whether there is any paths from coro.begin to Terminators which not pass - // through any of the coro.destroys. 
+ // If every terminators is dominated by coro.destroy, we could know the + // corresponding coro.begin wouldn't escape. + // + // Otherwise hasEscapePath would decide whether there is any paths from + // coro.begin to Terminators which not pass through any of the + // coro.destroys. // // hasEscapePath is relatively slow, so we avoid to run it as much as // possible. - if (!ReferencedCoroBegins.count(It.first) && + if (llvm::all_of(Terminators, + [&](auto *TI) { + return llvm::any_of(It.second, [&](auto *DA) { + return DT.dominates(DA, TI->getTerminator()); + }); + }) || !hasEscapePath(It.first, Terminators)) ReferencedCoroBegins.insert(It.first); } diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index 1f373270f951..1134b20880f1 100644 --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -63,7 +63,7 @@ public: llvm::sort(V); } - size_t blockToIndex(BasicBlock *BB) const { + size_t blockToIndex(BasicBlock const *BB) const { auto *I = llvm::lower_bound(V, BB); assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block"); return I - V.begin(); @@ -112,10 +112,11 @@ class SuspendCrossingInfo { } /// Compute the BlockData for the current function in one iteration. - /// Returns whether the BlockData changes in this iteration. /// Initialize - Whether this is the first iteration, we can optimize /// the initial case a little bit by manual loop switch. - template <bool Initialize = false> bool computeBlockData(); + /// Returns whether the BlockData changes in this iteration. 
+ template <bool Initialize = false> + bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT); public: #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -223,12 +224,14 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const { } #endif -template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() { - const size_t N = Mapping.size(); +template <bool Initialize> +bool SuspendCrossingInfo::computeBlockData( + const ReversePostOrderTraversal<Function *> &RPOT) { bool Changed = false; - for (size_t I = 0; I < N; ++I) { - auto &B = Block[I]; + for (const BasicBlock *BB : RPOT) { + auto BBNo = Mapping.blockToIndex(BB); + auto &B = Block[BBNo]; // We don't need to count the predecessors when initialization. if constexpr (!Initialize) @@ -261,7 +264,7 @@ template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() { } if (B.Suspend) { - // If block S is a suspend block, it should kill all of the blocks it + // If block B is a suspend block, it should kill all of the blocks it // consumes. B.Kills |= B.Consumes; } else if (B.End) { @@ -273,8 +276,8 @@ template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() { } else { // This is reached when B block it not Suspend nor coro.end and it // need to make sure that it is not in the kill set. - B.KillLoop |= B.Kills[I]; - B.Kills.reset(I); + B.KillLoop |= B.Kills[BBNo]; + B.Kills.reset(BBNo); } if constexpr (!Initialize) { @@ -283,9 +286,6 @@ template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() { } } - if constexpr (Initialize) - return true; - return Changed; } @@ -325,9 +325,11 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) markSuspendBlock(Save); } - computeBlockData</*Initialize=*/true>(); - - while (computeBlockData()) + // It is considered to be faster to use RPO traversal for forward-edges + // dataflow analysis. 
+ ReversePostOrderTraversal<Function *> RPOT(&F); + computeBlockData</*Initialize=*/true>(RPOT); + while (computeBlockData</*Initialize*/ false>(RPOT)) ; LLVM_DEBUG(dump()); @@ -1073,7 +1075,7 @@ static DIType *solveDIType(DIBuilder &Builder, Type *Ty, RetType = CharSizeType; else { if (Size % 8 != 0) - Size = TypeSize::Fixed(Size + 8 - (Size % 8)); + Size = TypeSize::getFixed(Size + 8 - (Size % 8)); RetType = Builder.createArrayType( Size, Layout.getPrefTypeAlign(Ty).value(), CharSizeType, @@ -1290,10 +1292,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, std::optional<FieldIDType> SwitchIndexFieldId; if (Shape.ABI == coro::ABI::Switch) { - auto *FramePtrTy = FrameTy->getPointerTo(); - auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy, - /*IsVarArg=*/false); - auto *FnPtrTy = FnTy->getPointerTo(); + auto *FnPtrTy = PointerType::getUnqual(C); // Add header fields for the resume and destroy functions. // We can rely on these being perfectly packed. @@ -1680,15 +1679,6 @@ static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) { return CleanupRet; } -static void createFramePtr(coro::Shape &Shape) { - auto *CB = Shape.CoroBegin; - IRBuilder<> Builder(CB->getNextNode()); - StructType *FrameTy = Shape.FrameTy; - PointerType *FramePtrTy = FrameTy->getPointerTo(); - Shape.FramePtr = - cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr")); -} - // Replace all alloca and SSA values that are accessed across suspend points // with GetElementPointer from coroutine frame + loads and stores. Create an // AllocaSpillBB that will become the new entry block for the resume parts of @@ -1700,7 +1690,6 @@ static void createFramePtr(coro::Shape &Shape) { // becomes: // // %hdl = coro.begin(...) 
-// %FramePtr = bitcast i8* hdl to %f.frame* // br label %AllocaSpillBB // // AllocaSpillBB: @@ -1764,8 +1753,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { // Note: If we change the strategy dealing with alignment, we need to refine // this casting. if (GEP->getType() != Orig->getType()) - return Builder.CreateBitCast(GEP, Orig->getType(), - Orig->getName() + Twine(".cast")); + return Builder.CreateAddrSpaceCast(GEP, Orig->getType(), + Orig->getName() + Twine(".cast")); } return GEP; }; @@ -1775,13 +1764,12 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { auto SpillAlignment = Align(FrameData.getAlign(Def)); // Create a store instruction storing the value into the // coroutine frame. - Instruction *InsertPt = nullptr; + BasicBlock::iterator InsertPt; Type *ByValTy = nullptr; if (auto *Arg = dyn_cast<Argument>(Def)) { // For arguments, we will place the store instruction right after - // the coroutine frame pointer instruction, i.e. bitcast of - // coro.begin from i8* to %f.frame*. - InsertPt = Shape.getInsertPtAfterFramePtr(); + // the coroutine frame pointer instruction, i.e. coro.begin. + InsertPt = Shape.getInsertPtAfterFramePtr()->getIterator(); // If we're spilling an Argument, make sure we clear 'nocapture' // from the coroutine function. @@ -1792,35 +1780,35 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { } else if (auto *CSI = dyn_cast<AnyCoroSuspendInst>(Def)) { // Don't spill immediately after a suspend; splitting assumes // that the suspend will be followed by a branch. - InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHI(); + InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHIIt(); } else { auto *I = cast<Instruction>(Def); if (!DT.dominates(CB, I)) { // If it is not dominated by CoroBegin, then spill should be // inserted immediately after CoroFrame is computed. 
- InsertPt = Shape.getInsertPtAfterFramePtr(); + InsertPt = Shape.getInsertPtAfterFramePtr()->getIterator(); } else if (auto *II = dyn_cast<InvokeInst>(I)) { // If we are spilling the result of the invoke instruction, split // the normal edge and insert the spill in the new block. auto *NewBB = SplitEdge(II->getParent(), II->getNormalDest()); - InsertPt = NewBB->getTerminator(); + InsertPt = NewBB->getTerminator()->getIterator(); } else if (isa<PHINode>(I)) { // Skip the PHINodes and EH pads instructions. BasicBlock *DefBlock = I->getParent(); if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator())) - InsertPt = splitBeforeCatchSwitch(CSI); + InsertPt = splitBeforeCatchSwitch(CSI)->getIterator(); else - InsertPt = &*DefBlock->getFirstInsertionPt(); + InsertPt = DefBlock->getFirstInsertionPt(); } else { assert(!I->isTerminator() && "unexpected terminator"); // For all other values, the spill is placed immediately after // the definition. - InsertPt = I->getNextNode(); + InsertPt = I->getNextNode()->getIterator(); } } auto Index = FrameData.getFieldIndex(Def); - Builder.SetInsertPoint(InsertPt); + Builder.SetInsertPoint(InsertPt->getParent(), InsertPt); auto *G = Builder.CreateConstInBoundsGEP2_32( FrameTy, FramePtr, 0, Index, Def->getName() + Twine(".spill.addr")); if (ByValTy) { @@ -1840,7 +1828,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { // reference provided with the frame GEP. 
if (CurrentBlock != U->getParent()) { CurrentBlock = U->getParent(); - Builder.SetInsertPoint(&*CurrentBlock->getFirstInsertionPt()); + Builder.SetInsertPoint(CurrentBlock, + CurrentBlock->getFirstInsertionPt()); auto *GEP = GetFramePointer(E.first); GEP->setName(E.first->getName() + Twine(".reload.addr")); @@ -1863,6 +1852,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { if (LdInst->getPointerOperandType() != LdInst->getType()) break; CurDef = LdInst->getPointerOperand(); + if (!isa<AllocaInst, LoadInst>(CurDef)) + break; DIs = FindDbgDeclareUses(CurDef); } } @@ -1878,7 +1869,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { &*Builder.GetInsertPoint()); // This dbg.declare is for the main function entry point. It // will be deleted in all coro-split functions. - coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame); + coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame, + false /*UseEntryValue*/); } } @@ -1911,7 +1903,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { if (Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce || Shape.ABI == coro::ABI::Async) { // If we found any allocas, replace all of their remaining uses with Geps. - Builder.SetInsertPoint(&SpillBlock->front()); + Builder.SetInsertPoint(SpillBlock, SpillBlock->begin()); for (const auto &P : FrameData.Allocas) { AllocaInst *Alloca = P.Alloca; auto *G = GetFramePointer(Alloca); @@ -1930,7 +1922,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { // dbg.declares and dbg.values with the reload from the frame. // Note: We cannot replace the alloca with GEP instructions indiscriminately, // as some of the uses may not be dominated by CoroBegin. 
- Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front()); + Builder.SetInsertPoint(Shape.AllocaSpillBlock, + Shape.AllocaSpillBlock->begin()); SmallVector<Instruction *, 4> UsersToUpdate; for (const auto &A : FrameData.Allocas) { AllocaInst *Alloca = A.Alloca; @@ -1980,16 +1973,12 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { // to the pointer in the frame. for (const auto &Alias : A.Aliases) { auto *FramePtr = GetFramePointer(Alloca); - auto *FramePtrRaw = - Builder.CreateBitCast(FramePtr, Type::getInt8PtrTy(C)); auto &Value = *Alias.second; auto ITy = IntegerType::get(C, Value.getBitWidth()); - auto *AliasPtr = Builder.CreateGEP(Type::getInt8Ty(C), FramePtrRaw, + auto *AliasPtr = Builder.CreateGEP(Type::getInt8Ty(C), FramePtr, ConstantInt::get(ITy, Value)); - auto *AliasPtrTyped = - Builder.CreateBitCast(AliasPtr, Alias.first->getType()); Alias.first->replaceUsesWithIf( - AliasPtrTyped, [&](Use &U) { return DT.dominates(CB, U); }); + AliasPtr, [&](Use &U) { return DT.dominates(CB, U); }); } } @@ -2046,8 +2035,8 @@ static void movePHIValuesToInsertedBlock(BasicBlock *SuccBB, int Index = PN->getBasicBlockIndex(InsertedBB); Value *V = PN->getIncomingValue(Index); PHINode *InputV = PHINode::Create( - V->getType(), 1, V->getName() + Twine(".") + SuccBB->getName(), - &InsertedBB->front()); + V->getType(), 1, V->getName() + Twine(".") + SuccBB->getName()); + InputV->insertBefore(InsertedBB->begin()); InputV->addIncoming(V, PredBB); PN->setIncomingValue(Index, InputV); PN = dyn_cast<PHINode>(PN->getNextNode()); @@ -2193,7 +2182,8 @@ static void rewritePHIs(BasicBlock &BB) { // ehAwareSplitEdge will clone the LandingPad in all the edge blocks. // We replace the original landing pad with a PHINode that will collect the // results from all of them. 
- ReplPHI = PHINode::Create(LandingPad->getType(), 1, "", LandingPad); + ReplPHI = PHINode::Create(LandingPad->getType(), 1, ""); + ReplPHI->insertBefore(LandingPad->getIterator()); ReplPHI->takeName(LandingPad); LandingPad->replaceAllUsesWith(ReplPHI); // We will erase the original landing pad at the end of this function after @@ -2428,15 +2418,13 @@ static bool localAllocaNeedsStackSave(CoroAllocaAllocInst *AI) { static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas, SmallVectorImpl<Instruction*> &DeadInsts) { for (auto *AI : LocalAllocas) { - auto M = AI->getModule(); IRBuilder<> Builder(AI); // Save the stack depth. Try to avoid doing this if the stackrestore // is going to immediately precede a return or something. Value *StackSave = nullptr; if (localAllocaNeedsStackSave(AI)) - StackSave = Builder.CreateCall( - Intrinsic::getDeclaration(M, Intrinsic::stacksave)); + StackSave = Builder.CreateStackSave(); // Allocate memory. auto Alloca = Builder.CreateAlloca(Builder.getInt8Ty(), AI->getSize()); @@ -2454,9 +2442,7 @@ static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas, auto FI = cast<CoroAllocaFreeInst>(U); if (StackSave) { Builder.SetInsertPoint(FI); - Builder.CreateCall( - Intrinsic::getDeclaration(M, Intrinsic::stackrestore), - StackSave); + Builder.CreateStackRestore(StackSave); } } DeadInsts.push_back(cast<Instruction>(U)); @@ -2498,7 +2484,7 @@ static Value *emitGetSwiftErrorValue(IRBuilder<> &Builder, Type *ValueTy, coro::Shape &Shape) { // Make a fake function pointer as a sort of intrinsic. 
auto FnTy = FunctionType::get(ValueTy, {}, false); - auto Fn = ConstantPointerNull::get(FnTy->getPointerTo()); + auto Fn = ConstantPointerNull::get(Builder.getPtrTy()); auto Call = Builder.CreateCall(FnTy, Fn, {}); Shape.SwiftErrorOps.push_back(Call); @@ -2512,9 +2498,9 @@ static Value *emitGetSwiftErrorValue(IRBuilder<> &Builder, Type *ValueTy, static Value *emitSetSwiftErrorValue(IRBuilder<> &Builder, Value *V, coro::Shape &Shape) { // Make a fake function pointer as a sort of intrinsic. - auto FnTy = FunctionType::get(V->getType()->getPointerTo(), + auto FnTy = FunctionType::get(Builder.getPtrTy(), {V->getType()}, false); - auto Fn = ConstantPointerNull::get(FnTy->getPointerTo()); + auto Fn = ConstantPointerNull::get(Builder.getPtrTy()); auto Call = Builder.CreateCall(FnTy, Fn, { V }); Shape.SwiftErrorOps.push_back(Call); @@ -2765,17 +2751,8 @@ static void sinkLifetimeStartMarkers(Function &F, coro::Shape &Shape, // Sink lifetime.start markers to dominate block when they are // only used outside the region. if (Valid && Lifetimes.size() != 0) { - // May be AI itself, when the type of AI is i8* - auto *NewBitCast = [&](AllocaInst *AI) -> Value* { - if (isa<AllocaInst>(Lifetimes[0]->getOperand(1))) - return AI; - auto *Int8PtrTy = Type::getInt8PtrTy(F.getContext()); - return CastInst::Create(Instruction::BitCast, AI, Int8PtrTy, "", - DomBB->getTerminator()); - }(AI); - auto *NewLifetime = Lifetimes[0]->clone(); - NewLifetime->replaceUsesOfWith(NewLifetime->getOperand(1), NewBitCast); + NewLifetime->replaceUsesOfWith(NewLifetime->getOperand(1), AI); NewLifetime->insertBefore(DomBB->getTerminator()); // All the outsided lifetime.start markers are no longer necessary. @@ -2800,6 +2777,11 @@ static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape, if (AI == Shape.SwitchLowering.PromiseAlloca) return; + // The __coro_gro alloca should outlive the promise, make sure we + // keep it outside the frame. 
+ if (AI->hasMetadata(LLVMContext::MD_coro_outside_frame)) + return; + // The code that uses lifetime.start intrinsic does not work for functions // with loops without exit. Disable it on ABIs we know to generate such // code. @@ -2818,7 +2800,7 @@ static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape, void coro::salvageDebugInfo( SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, - DbgVariableIntrinsic *DVI, bool OptimizeFrame) { + DbgVariableIntrinsic *DVI, bool OptimizeFrame, bool UseEntryValue) { Function *F = DVI->getFunction(); IRBuilder<> Builder(F->getContext()); auto InsertPt = F->getEntryBlock().getFirstInsertionPt(); @@ -2870,7 +2852,9 @@ void coro::salvageDebugInfo( // Swift async arguments are described by an entry value of the ABI-defined // register containing the coroutine context. - if (IsSwiftAsyncArg && !Expr->isEntryValue()) + // Entry values in variadic expressions are not supported. + if (IsSwiftAsyncArg && UseEntryValue && !Expr->isEntryValue() && + Expr->isSingleLocationExpression()) Expr = DIExpression::prepend(Expr, DIExpression::EntryValue); // If the coroutine frame is an Argument, store it in an alloca to improve @@ -2902,13 +2886,13 @@ void coro::salvageDebugInfo( // dbg.value since it does not have the same function wide guarantees that // dbg.declare does. 
if (isa<DbgDeclareInst>(DVI)) { - Instruction *InsertPt = nullptr; + std::optional<BasicBlock::iterator> InsertPt; if (auto *I = dyn_cast<Instruction>(Storage)) InsertPt = I->getInsertionPointAfterDef(); else if (isa<Argument>(Storage)) - InsertPt = &*F->getEntryBlock().begin(); + InsertPt = F->getEntryBlock().begin(); if (InsertPt) - DVI->moveBefore(InsertPt); + DVI->moveBefore(*(*InsertPt)->getParent(), *InsertPt); } } @@ -3110,7 +3094,7 @@ void coro::buildCoroutineFrame( Shape.ABI == coro::ABI::Async) sinkSpillUsesAfterCoroBegin(F, FrameData, Shape.CoroBegin); Shape.FrameTy = buildFrameType(F, Shape, FrameData); - createFramePtr(Shape); + Shape.FramePtr = Shape.CoroBegin; // For now, this works for C++ programs only. buildFrameDebugInfo(F, Shape, FrameData); insertSpills(FrameData, Shape); diff --git a/llvm/lib/Transforms/Coroutines/CoroInstr.h b/llvm/lib/Transforms/Coroutines/CoroInstr.h index 014938c15a0a..f01aa58eb899 100644 --- a/llvm/lib/Transforms/Coroutines/CoroInstr.h +++ b/llvm/lib/Transforms/Coroutines/CoroInstr.h @@ -123,8 +123,8 @@ public: void clearPromise() { Value *Arg = getArgOperand(PromiseArg); - setArgOperand(PromiseArg, - ConstantPointerNull::get(Type::getInt8PtrTy(getContext()))); + setArgOperand(PromiseArg, ConstantPointerNull::get( + PointerType::getUnqual(getContext()))); if (isa<AllocaInst>(Arg)) return; assert((isa<BitCastInst>(Arg) || isa<GetElementPtrInst>(Arg)) && @@ -185,9 +185,7 @@ public: void setCoroutineSelf() { assert(isa<ConstantPointerNull>(getArgOperand(CoroutineArg)) && "Coroutine argument is already assigned"); - auto *const Int8PtrTy = Type::getInt8PtrTy(getContext()); - setArgOperand(CoroutineArg, - ConstantExpr::getBitCast(getFunction(), Int8PtrTy)); + setArgOperand(CoroutineArg, getFunction()); } // Methods to support type inquiry through isa, cast, and dyn_cast: @@ -611,8 +609,37 @@ public: } }; +/// This represents the llvm.end.results instruction. 
+class LLVM_LIBRARY_VISIBILITY CoroEndResults : public IntrinsicInst { +public: + op_iterator retval_begin() { return arg_begin(); } + const_op_iterator retval_begin() const { return arg_begin(); } + + op_iterator retval_end() { return arg_end(); } + const_op_iterator retval_end() const { return arg_end(); } + + iterator_range<op_iterator> return_values() { + return make_range(retval_begin(), retval_end()); + } + iterator_range<const_op_iterator> return_values() const { + return make_range(retval_begin(), retval_end()); + } + + unsigned numReturns() const { + return std::distance(retval_begin(), retval_end()); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_end_results; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + class LLVM_LIBRARY_VISIBILITY AnyCoroEndInst : public IntrinsicInst { - enum { FrameArg, UnwindArg }; + enum { FrameArg, UnwindArg, TokenArg }; public: bool isFallthrough() const { return !isUnwind(); } @@ -620,6 +647,15 @@ public: return cast<Constant>(getArgOperand(UnwindArg))->isOneValue(); } + bool hasResults() const { + return !isa<ConstantTokenNone>(getArgOperand(TokenArg)); + } + + CoroEndResults *getResults() const { + assert(hasResults()); + return cast<CoroEndResults>(getArgOperand(TokenArg)); + } + // Methods to support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I) { auto ID = I->getIntrinsicID(); diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h index 067fb6bba47e..0856c4925cc5 100644 --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -32,7 +32,7 @@ void replaceCoroFree(CoroIdInst *CoroId, bool Elide); /// OptimizeFrame is false. 
void salvageDebugInfo( SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, - DbgVariableIntrinsic *DVI, bool OptimizeFrame); + DbgVariableIntrinsic *DVI, bool OptimizeFrame, bool IsEntryPoint); // Keeps data and helper functions for lowering coroutine intrinsics. struct LowererBase { @@ -185,7 +185,8 @@ struct LLVM_LIBRARY_VISIBILITY Shape { switch (ABI) { case coro::ABI::Switch: return FunctionType::get(Type::getVoidTy(FrameTy->getContext()), - FrameTy->getPointerTo(), /*IsVarArg*/false); + PointerType::getUnqual(FrameTy->getContext()), + /*IsVarArg=*/false); case coro::ABI::Retcon: case coro::ABI::RetconOnce: return RetconLowering.ResumePrototype->getFunctionType(); diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 39e909bf3316..244580f503d5 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -234,6 +234,8 @@ static void replaceFallthroughCoroEnd(AnyCoroEndInst *End, switch (Shape.ABI) { // The cloned functions in switch-lowering always return void. case coro::ABI::Switch: + assert(!cast<CoroEndInst>(End)->hasResults() && + "switch coroutine should not return any values"); // coro.end doesn't immediately end the coroutine in the main function // in this lowering, because we need to deallocate the coroutine. if (!InResume) @@ -251,14 +253,45 @@ static void replaceFallthroughCoroEnd(AnyCoroEndInst *End, // In unique continuation lowering, the continuations always return void. // But we may have implicitly allocated storage. 
- case coro::ABI::RetconOnce: + case coro::ABI::RetconOnce: { maybeFreeRetconStorage(Builder, Shape, FramePtr, CG); - Builder.CreateRetVoid(); + auto *CoroEnd = cast<CoroEndInst>(End); + auto *RetTy = Shape.getResumeFunctionType()->getReturnType(); + + if (!CoroEnd->hasResults()) { + assert(RetTy->isVoidTy()); + Builder.CreateRetVoid(); + break; + } + + auto *CoroResults = CoroEnd->getResults(); + unsigned NumReturns = CoroResults->numReturns(); + + if (auto *RetStructTy = dyn_cast<StructType>(RetTy)) { + assert(RetStructTy->getNumElements() == NumReturns && + "numbers of returns should match resume function singature"); + Value *ReturnValue = UndefValue::get(RetStructTy); + unsigned Idx = 0; + for (Value *RetValEl : CoroResults->return_values()) + ReturnValue = Builder.CreateInsertValue(ReturnValue, RetValEl, Idx++); + Builder.CreateRet(ReturnValue); + } else if (NumReturns == 0) { + assert(RetTy->isVoidTy()); + Builder.CreateRetVoid(); + } else { + assert(NumReturns == 1); + Builder.CreateRet(*CoroResults->retval_begin()); + } + CoroResults->replaceAllUsesWith(ConstantTokenNone::get(CoroResults->getContext())); + CoroResults->eraseFromParent(); break; + } // In non-unique continuation lowering, we signal completion by returning // a null continuation. 
case coro::ABI::Retcon: { + assert(!cast<CoroEndInst>(End)->hasResults() && + "retcon coroutine should not return any values"); maybeFreeRetconStorage(Builder, Shape, FramePtr, CG); auto RetTy = Shape.getResumeFunctionType()->getReturnType(); auto RetStructTy = dyn_cast<StructType>(RetTy); @@ -457,7 +490,8 @@ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) { Switch->addCase(IndexVal, ResumeBB); cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB); - auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front()); + auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, ""); + PN->insertBefore(LandingBB->begin()); S->replaceAllUsesWith(PN); PN->addIncoming(Builder.getInt8(-1), SuspendBB); PN->addIncoming(S, ResumeBB); @@ -495,13 +529,20 @@ void CoroCloner::handleFinalSuspend() { BasicBlock *OldSwitchBB = Switch->getParent(); auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch"); Builder.SetInsertPoint(OldSwitchBB->getTerminator()); - auto *GepIndex = Builder.CreateStructGEP(Shape.FrameTy, NewFramePtr, - coro::Shape::SwitchFieldIndex::Resume, - "ResumeFn.addr"); - auto *Load = Builder.CreateLoad(Shape.getSwitchResumePointerType(), - GepIndex); - auto *Cond = Builder.CreateIsNull(Load); - Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB); + + if (NewF->isCoroOnlyDestroyWhenComplete()) { + // When the coroutine can only be destroyed when complete, we don't need + // to generate code for other cases. 
+ Builder.CreateBr(ResumeBB); + } else { + auto *GepIndex = Builder.CreateStructGEP( + Shape.FrameTy, NewFramePtr, coro::Shape::SwitchFieldIndex::Resume, + "ResumeFn.addr"); + auto *Load = + Builder.CreateLoad(Shape.getSwitchResumePointerType(), GepIndex); + auto *Cond = Builder.CreateIsNull(Load); + Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB); + } OldSwitchBB->getTerminator()->eraseFromParent(); } } @@ -701,8 +742,13 @@ void CoroCloner::salvageDebugInfo() { SmallVector<DbgVariableIntrinsic *, 8> Worklist = collectDbgVariableIntrinsics(*NewF); SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap; + + // Only 64-bit ABIs have a register we can refer to with the entry value. + bool UseEntryValue = + llvm::Triple(OrigF.getParent()->getTargetTriple()).isArch64Bit(); for (DbgVariableIntrinsic *DVI : Worklist) - coro::salvageDebugInfo(ArgToAllocaMap, DVI, Shape.OptimizeFrame); + coro::salvageDebugInfo(ArgToAllocaMap, DVI, Shape.OptimizeFrame, + UseEntryValue); // Remove all salvaged dbg.declare intrinsics that became // either unreachable or stale due to the CoroSplit transformation. @@ -811,7 +857,6 @@ Value *CoroCloner::deriveNewFramePointer() { auto *ActiveAsyncSuspend = cast<CoroSuspendAsyncInst>(ActiveSuspend); auto ContextIdx = ActiveAsyncSuspend->getStorageArgumentIndex() & 0xff; auto *CalleeContext = NewF->getArg(ContextIdx); - auto *FramePtrTy = Shape.FrameTy->getPointerTo(); auto *ProjectionFunc = ActiveAsyncSuspend->getAsyncContextProjectionFunction(); auto DbgLoc = @@ -831,22 +876,20 @@ Value *CoroCloner::deriveNewFramePointer() { auto InlineRes = InlineFunction(*CallerContext, InlineInfo); assert(InlineRes.isSuccess()); (void)InlineRes; - return Builder.CreateBitCast(FramePtrAddr, FramePtrTy); + return FramePtrAddr; } // In continuation-lowering, the argument is the opaque storage. 
case coro::ABI::Retcon: case coro::ABI::RetconOnce: { Argument *NewStorage = &*NewF->arg_begin(); - auto FramePtrTy = Shape.FrameTy->getPointerTo(); + auto FramePtrTy = PointerType::getUnqual(Shape.FrameTy->getContext()); // If the storage is inline, just bitcast to the storage to the frame type. if (Shape.RetconLowering.IsFrameInlineInStorage) - return Builder.CreateBitCast(NewStorage, FramePtrTy); + return NewStorage; // Otherwise, load the real frame from the opaque storage. - auto FramePtrPtr = - Builder.CreateBitCast(NewStorage, FramePtrTy->getPointerTo()); - return Builder.CreateLoad(FramePtrTy, FramePtrPtr); + return Builder.CreateLoad(FramePtrTy, NewStorage); } } llvm_unreachable("bad ABI"); @@ -940,9 +983,22 @@ void CoroCloner::create() { // abstract specification, since the DWARF backend expects the // abstract specification to contain the linkage name and asserts // that they are identical. - if (!SP->getDeclaration() && SP->getUnit() && - SP->getUnit()->getSourceLanguage() == dwarf::DW_LANG_Swift) + if (SP->getUnit() && + SP->getUnit()->getSourceLanguage() == dwarf::DW_LANG_Swift) { SP->replaceLinkageName(MDString::get(Context, NewF->getName())); + if (auto *Decl = SP->getDeclaration()) { + auto *NewDecl = DISubprogram::get( + Decl->getContext(), Decl->getScope(), Decl->getName(), + NewF->getName(), Decl->getFile(), Decl->getLine(), Decl->getType(), + Decl->getScopeLine(), Decl->getContainingType(), + Decl->getVirtualIndex(), Decl->getThisAdjustment(), + Decl->getFlags(), Decl->getSPFlags(), Decl->getUnit(), + Decl->getTemplateParams(), nullptr, Decl->getRetainedNodes(), + Decl->getThrownTypes(), Decl->getAnnotations(), + Decl->getTargetFuncName()); + SP->replaceDeclaration(NewDecl); + } + } } NewF->setLinkage(savedLinkage); @@ -1047,7 +1103,7 @@ void CoroCloner::create() { // Remap vFrame pointer. 
auto *NewVFrame = Builder.CreateBitCast( - NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame"); + NewFramePtr, PointerType::getUnqual(Builder.getContext()), "vFrame"); Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]); if (OldVFrame != NewVFrame) OldVFrame->replaceAllUsesWith(NewVFrame); @@ -1178,7 +1234,7 @@ static void setCoroInfo(Function &F, coro::Shape &Shape, // Update coro.begin instruction to refer to this constant. LLVMContext &C = F.getContext(); - auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C)); + auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C)); Shape.getSwitchCoroId()->setInfo(BC); } @@ -1425,10 +1481,9 @@ static void handleNoSuspendCoroutine(coro::Shape &Shape) { IRBuilder<> Builder(AllocInst); auto *Frame = Builder.CreateAlloca(Shape.FrameTy); Frame->setAlignment(Shape.FrameAlign); - auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy()); AllocInst->replaceAllUsesWith(Builder.getFalse()); AllocInst->eraseFromParent(); - CoroBegin->replaceAllUsesWith(VFrame); + CoroBegin->replaceAllUsesWith(Frame); } else { CoroBegin->replaceAllUsesWith(CoroBegin->getMem()); } @@ -1658,7 +1713,7 @@ static void replaceAsyncResumeFunction(CoroSuspendAsyncInst *Suspend, Value *Continuation) { auto *ResumeIntrinsic = Suspend->getResumeFunction(); auto &Context = Suspend->getParent()->getParent()->getContext(); - auto *Int8PtrTy = Type::getInt8PtrTy(Context); + auto *Int8PtrTy = PointerType::getUnqual(Context); IRBuilder<> Builder(ResumeIntrinsic); auto *Val = Builder.CreateBitOrPointerCast(Continuation, Int8PtrTy); @@ -1711,7 +1766,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, F.removeRetAttr(Attribute::NonNull); auto &Context = F.getContext(); - auto *Int8PtrTy = Type::getInt8PtrTy(Context); + auto *Int8PtrTy = PointerType::getUnqual(Context); auto *Id = cast<CoroIdAsyncInst>(Shape.CoroBegin->getId()); IRBuilder<> Builder(Id); @@ -1829,9 +1884,7 @@ static void 
splitRetconCoroutine(Function &F, coro::Shape &Shape, Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType()); // Stash the allocated frame pointer in the continuation storage. - auto Dest = Builder.CreateBitCast(Id->getStorage(), - RawFramePtr->getType()->getPointerTo()); - Builder.CreateStore(RawFramePtr, Dest); + Builder.CreateStore(RawFramePtr, Id->getStorage()); } // Map all uses of llvm.coro.begin to the allocated frame pointer. @@ -1987,7 +2040,8 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, // coroutine funclets. SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap; for (auto *DDI : collectDbgVariableIntrinsics(F)) - coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame); + coro::salvageDebugInfo(ArgToAllocaMap, DDI, Shape.OptimizeFrame, + false /*UseEntryValue*/); return Shape; } diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index cde74c5e693b..eef5543bae24 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -37,16 +37,15 @@ using namespace llvm; // Construct the lowerer base class and initialize its members. coro::LowererBase::LowererBase(Module &M) : TheModule(M), Context(M.getContext()), - Int8Ptr(Type::getInt8PtrTy(Context)), + Int8Ptr(PointerType::get(Context, 0)), ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr, /*isVarArg=*/false)), NullPtr(ConstantPointerNull::get(Int8Ptr)) {} -// Creates a sequence of instructions to obtain a resume function address using -// llvm.coro.subfn.addr. It generates the following sequence: +// Creates a call to llvm.coro.subfn.addr to obtain a resume function address. 
+// It generates the following: // -// call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index) -// bitcast i8* %2 to void(i8*)* +// call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index) Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt) { @@ -56,11 +55,7 @@ Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index, assert(Index >= CoroSubFnInst::IndexFirst && Index < CoroSubFnInst::IndexLast && "makeSubFnCall: Index value out of range"); - auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt); - - auto *Bitcast = - new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt); - return Bitcast; + return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt); } // NOTE: Must be sorted! @@ -137,8 +132,9 @@ void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) { return; Value *Replacement = - Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext())) - : CoroFrees.front()->getFrame(); + Elide + ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0)) + : CoroFrees.front()->getFrame(); for (CoroFreeInst *CF : CoroFrees) { CF->replaceAllUsesWith(Replacement); @@ -267,7 +263,7 @@ void coro::Shape::buildFrom(Function &F) { if (!CoroBegin) { // Replace coro.frame which are supposed to be lowered to the result of // coro.begin with undef. - auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext())); + auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0)); for (CoroFrameInst *CF : CoroFrames) { CF->replaceAllUsesWith(Undef); CF->eraseFromParent(); diff --git a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp new file mode 100644 index 000000000000..fb7cba9edbdb --- /dev/null +++ b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp @@ -0,0 +1,312 @@ +//===----- HipStdPar.cpp - HIP C++ Standard Parallelism Support Passes ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file implements two passes that enable HIP C++ Standard Parallelism +// Support: +// +// 1. AcceleratorCodeSelection (required): Given that only algorithms are +// accelerated, and that the accelerated implementation exists in the form of +// a compute kernel, we assume that only the kernel, and all functions +// reachable from it, constitute code that the user expects the accelerator +// to execute. Thus, we identify the set of all functions reachable from +// kernels, and then remove all unreachable ones. This last part is necessary +// because it is possible for code that the user did not expect to execute on +// an accelerator to contain constructs that cannot be handled by the target +// BE, which cannot be provably demonstrated to be dead code in general, and +// thus can lead to mis-compilation. The degenerate case of this is when a +// Module contains no kernels (the parent TU had no algorithm invocations fit +// for acceleration), which we handle by completely emptying said module. +// **NOTE**: The above does not handle indirectly reachable functions i.e. +// it is possible to obtain a case where the target of an indirect +// call is otherwise unreachable and thus is removed; this +// restriction is aligned with the current `-hipstdpar` limitations +// and will be relaxed in the future. +// +// 2. AllocationInterposition (required only when on-demand paging is +// unsupported): Some accelerators or operating systems might not support +// transparent on-demand paging. Thus, they would only be able to access +// memory that is allocated by an accelerator-aware mechanism. 
For such cases +// the user can opt into enabling allocation / deallocation interposition, +// whereby we replace calls to known allocation / deallocation functions with +// calls to runtime implemented equivalents that forward the requests to +// accelerator-aware interfaces. We also support freeing system allocated +// memory that ends up in one of the runtime equivalents, since this can +// happen if e.g. a library that was compiled without interposition returns +// an allocation that can be validly passed to `free`. +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/HipStdPar/HipStdPar.h" + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Module.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +#include <cassert> +#include <string> +#include <utility> + +using namespace llvm; + +template<typename T> +static inline void eraseFromModule(T &ToErase) { + ToErase.replaceAllUsesWith(PoisonValue::get(ToErase.getType())); + ToErase.eraseFromParent(); +} + +static inline bool checkIfSupported(GlobalVariable &G) { + if (!G.isThreadLocal()) + return true; + + G.dropDroppableUses(); + + if (!G.isConstantUsed()) + return true; + + std::string W; + raw_string_ostream OS(W); + + OS << "Accelerator does not support the thread_local variable " + << G.getName(); + + Instruction *I = nullptr; + SmallVector<User *> Tmp(G.user_begin(), G.user_end()); + SmallPtrSet<User *, 5> Visited; + do { + auto U = std::move(Tmp.back()); + Tmp.pop_back(); + + if (Visited.contains(U)) + continue; + + if (isa<Instruction>(U)) + I = cast<Instruction>(U); + else + Tmp.insert(Tmp.end(), U->user_begin(), U->user_end()); + + Visited.insert(U); + } while (!I && !Tmp.empty()); + 
+ assert(I && "thread_local global should have at least one non-constant use."); + + G.getContext().diagnose( + DiagnosticInfoUnsupported(*I->getParent()->getParent(), W, + I->getDebugLoc(), DS_Error)); + + return false; +} + +static inline void clearModule(Module &M) { // TODO: simplify. + while (!M.functions().empty()) + eraseFromModule(*M.begin()); + while (!M.globals().empty()) + eraseFromModule(*M.globals().begin()); + while (!M.aliases().empty()) + eraseFromModule(*M.aliases().begin()); + while (!M.ifuncs().empty()) + eraseFromModule(*M.ifuncs().begin()); +} + +static inline void maybeHandleGlobals(Module &M) { + unsigned GlobAS = M.getDataLayout().getDefaultGlobalsAddressSpace(); + for (auto &&G : M.globals()) { // TODO: should we handle these in the FE? + if (!checkIfSupported(G)) + return clearModule(M); + + if (G.isThreadLocal()) + continue; + if (G.isConstant()) + continue; + if (G.getAddressSpace() != GlobAS) + continue; + if (G.getLinkage() != GlobalVariable::ExternalLinkage) + continue; + + G.setLinkage(GlobalVariable::ExternalWeakLinkage); + G.setExternallyInitialized(true); + } +} + +template<unsigned N> +static inline void removeUnreachableFunctions( + const SmallPtrSet<const Function *, N>& Reachable, Module &M) { + removeFromUsedLists(M, [&](Constant *C) { + if (auto F = dyn_cast<Function>(C)) + return !Reachable.contains(F); + + return false; + }); + + SmallVector<std::reference_wrapper<Function>> ToRemove; + copy_if(M, std::back_inserter(ToRemove), [&](auto &&F) { + return !F.isIntrinsic() && !Reachable.contains(&F); + }); + + for_each(ToRemove, eraseFromModule<Function>); +} + +static inline bool isAcceleratorExecutionRoot(const Function *F) { + if (!F) + return false; + + return F->getCallingConv() == CallingConv::AMDGPU_KERNEL; +} + +static inline bool checkIfSupported(const Function *F, const CallBase *CB) { + const auto Dx = F->getName().rfind("__hipstdpar_unsupported"); + + if (Dx == StringRef::npos) + return true; + + const auto N = 
F->getName().substr(0, Dx); + + std::string W; + raw_string_ostream OS(W); + + if (N == "__ASM") + OS << "Accelerator does not support the ASM block:\n" + << cast<ConstantDataArray>(CB->getArgOperand(0))->getAsCString(); + else + OS << "Accelerator does not support the " << N << " function."; + + auto Caller = CB->getParent()->getParent(); + + Caller->getContext().diagnose( + DiagnosticInfoUnsupported(*Caller, W, CB->getDebugLoc(), DS_Error)); + + return false; +} + +PreservedAnalyses + HipStdParAcceleratorCodeSelectionPass::run(Module &M, + ModuleAnalysisManager &MAM) { + auto &CGA = MAM.getResult<CallGraphAnalysis>(M); + + SmallPtrSet<const Function *, 32> Reachable; + for (auto &&CGN : CGA) { + if (!isAcceleratorExecutionRoot(CGN.first)) + continue; + + Reachable.insert(CGN.first); + + SmallVector<const Function *> Tmp({CGN.first}); + do { + auto F = std::move(Tmp.back()); + Tmp.pop_back(); + + for (auto &&N : *CGA[F]) { + if (!N.second) + continue; + if (!N.second->getFunction()) + continue; + if (Reachable.contains(N.second->getFunction())) + continue; + + if (!checkIfSupported(N.second->getFunction(), + dyn_cast<CallBase>(*N.first))) + return PreservedAnalyses::none(); + + Reachable.insert(N.second->getFunction()); + Tmp.push_back(N.second->getFunction()); + } + } while (!std::empty(Tmp)); + } + + if (std::empty(Reachable)) + clearModule(M); + else + removeUnreachableFunctions(Reachable, M); + + maybeHandleGlobals(M); + + return PreservedAnalyses::none(); +} + +static constexpr std::pair<StringLiteral, StringLiteral> ReplaceMap[]{ + {"aligned_alloc", "__hipstdpar_aligned_alloc"}, + {"calloc", "__hipstdpar_calloc"}, + {"free", "__hipstdpar_free"}, + {"malloc", "__hipstdpar_malloc"}, + {"memalign", "__hipstdpar_aligned_alloc"}, + {"posix_memalign", "__hipstdpar_posix_aligned_alloc"}, + {"realloc", "__hipstdpar_realloc"}, + {"reallocarray", "__hipstdpar_realloc_array"}, + {"_ZdaPv", "__hipstdpar_operator_delete"}, + {"_ZdaPvm", 
"__hipstdpar_operator_delete_sized"}, + {"_ZdaPvSt11align_val_t", "__hipstdpar_operator_delete_aligned"}, + {"_ZdaPvmSt11align_val_t", "__hipstdpar_operator_delete_aligned_sized"}, + {"_ZdlPv", "__hipstdpar_operator_delete"}, + {"_ZdlPvm", "__hipstdpar_operator_delete_sized"}, + {"_ZdlPvSt11align_val_t", "__hipstdpar_operator_delete_aligned"}, + {"_ZdlPvmSt11align_val_t", "__hipstdpar_operator_delete_aligned_sized"}, + {"_Znam", "__hipstdpar_operator_new"}, + {"_ZnamRKSt9nothrow_t", "__hipstdpar_operator_new_nothrow"}, + {"_ZnamSt11align_val_t", "__hipstdpar_operator_new_aligned"}, + {"_ZnamSt11align_val_tRKSt9nothrow_t", + "__hipstdpar_operator_new_aligned_nothrow"}, + + {"_Znwm", "__hipstdpar_operator_new"}, + {"_ZnwmRKSt9nothrow_t", "__hipstdpar_operator_new_nothrow"}, + {"_ZnwmSt11align_val_t", "__hipstdpar_operator_new_aligned"}, + {"_ZnwmSt11align_val_tRKSt9nothrow_t", + "__hipstdpar_operator_new_aligned_nothrow"}, + {"__builtin_calloc", "__hipstdpar_calloc"}, + {"__builtin_free", "__hipstdpar_free"}, + {"__builtin_malloc", "__hipstdpar_malloc"}, + {"__builtin_operator_delete", "__hipstdpar_operator_delete"}, + {"__builtin_operator_new", "__hipstdpar_operator_new"}, + {"__builtin_realloc", "__hipstdpar_realloc"}, + {"__libc_calloc", "__hipstdpar_calloc"}, + {"__libc_free", "__hipstdpar_free"}, + {"__libc_malloc", "__hipstdpar_malloc"}, + {"__libc_memalign", "__hipstdpar_aligned_alloc"}, + {"__libc_realloc", "__hipstdpar_realloc"} +}; + +PreservedAnalyses +HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) { + SmallDenseMap<StringRef, StringRef> AllocReplacements(std::cbegin(ReplaceMap), + std::cend(ReplaceMap)); + + for (auto &&F : M) { + if (!F.hasName()) + continue; + if (!AllocReplacements.contains(F.getName())) + continue; + + if (auto R = M.getFunction(AllocReplacements[F.getName()])) { + F.replaceAllUsesWith(R); + } else { + std::string W; + raw_string_ostream OS(W); + + OS << "cannot be interposed, missing: " << 
AllocReplacements[F.getName()] + << ". Tried to run the allocation interposition pass without the " + << "replacement functions available."; + + F.getContext().diagnose(DiagnosticInfoUnsupported(F, W, + F.getSubprogram(), + DS_Warning)); + } + } + + if (auto F = M.getFunction("__hipstdpar_hidden_free")) { + auto LibcFree = M.getOrInsertFunction("__libc_free", F->getFunctionType(), + F->getAttributes()); + F->replaceAllUsesWith(LibcFree.getCallee()); + + eraseFromModule(*F); + } + + return PreservedAnalyses::none(); +} diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index 824da6395f2e..fb3fa8d23daa 100644 --- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -121,19 +121,24 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, // that we are *not* promoting. For the ones that we do promote, the parameter // attributes are lost SmallVector<AttributeSet, 8> ArgAttrVec; + // Mapping from old to new argument indices. -1 for promoted or removed + // arguments. 
+ SmallVector<unsigned> NewArgIndices; AttributeList PAL = F->getAttributes(); // First, determine the new argument list - unsigned ArgNo = 0; + unsigned ArgNo = 0, NewArgNo = 0; for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I, ++ArgNo) { if (!ArgsToPromote.count(&*I)) { // Unchanged argument Params.push_back(I->getType()); ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo)); + NewArgIndices.push_back(NewArgNo++); } else if (I->use_empty()) { // Dead argument (which are always marked as promotable) ++NumArgumentsDead; + NewArgIndices.push_back((unsigned)-1); } else { const auto &ArgParts = ArgsToPromote.find(&*I)->second; for (const auto &Pair : ArgParts) { @@ -141,6 +146,8 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, ArgAttrVec.push_back(AttributeSet()); } ++NumArgumentsPromoted; + NewArgIndices.push_back((unsigned)-1); + NewArgNo += ArgParts.size(); } } @@ -154,6 +161,7 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, F->getName()); NF->copyAttributesFrom(F); NF->copyMetadata(F, 0); + NF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat); // The new function will have the !dbg metadata copied from the original // function. The original function may not be deleted, and dbg metadata need @@ -173,6 +181,19 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, // the function. NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(), PAL.getRetAttrs(), ArgAttrVec)); + + // Remap argument indices in allocsize attribute. 
+ if (auto AllocSize = NF->getAttributes().getFnAttrs().getAllocSizeArgs()) { + unsigned Arg1 = NewArgIndices[AllocSize->first]; + assert(Arg1 != (unsigned)-1 && "allocsize cannot be promoted argument"); + std::optional<unsigned> Arg2; + if (AllocSize->second) { + Arg2 = NewArgIndices[*AllocSize->second]; + assert(Arg2 != (unsigned)-1 && "allocsize cannot be promoted argument"); + } + NF->addFnAttr(Attribute::getWithAllocSizeArgs(F->getContext(), Arg1, Arg2)); + } + AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth); ArgAttrVec.clear(); diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp index 847d07a49dee..d8e290cbc8a4 100644 --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CallGraph.h" @@ -50,6 +51,7 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include <cstdint> +#include <memory> #ifdef EXPENSIVE_CHECKS #include "llvm/IR/Verifier.h" @@ -93,6 +95,13 @@ static cl::opt<unsigned> cl::desc("Maximal number of fixpoint iterations."), cl::init(32)); +static cl::opt<unsigned> + MaxSpecializationPerCB("attributor-max-specializations-per-call-base", + cl::Hidden, + cl::desc("Maximal number of callees specialized for " + "a call base"), + cl::init(UINT32_MAX)); + static cl::opt<unsigned, true> MaxInitializationChainLengthX( "attributor-max-initialization-chain-length", cl::Hidden, cl::desc( @@ -166,6 +175,10 @@ static cl::opt<bool> SimplifyAllLoads("attributor-simplify-all-loads", cl::desc("Try to simplify all loads."), cl::init(true)); +static cl::opt<bool> CloseWorldAssumption( + "attributor-assume-closed-world", cl::Hidden, + cl::desc("Should a closed world be assumed, or not. 
Default if not set.")); + /// Logic operators for the change status enum class. /// ///{ @@ -226,10 +239,10 @@ bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA, return InstanceInfoAA && InstanceInfoAA->isAssumedUniqueForAnalysis(); } -Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty, - const TargetLibraryInfo *TLI, - const DataLayout &DL, - AA::RangeTy *RangePtr) { +Constant * +AA::getInitialValueForObj(Attributor &A, const AbstractAttribute &QueryingAA, + Value &Obj, Type &Ty, const TargetLibraryInfo *TLI, + const DataLayout &DL, AA::RangeTy *RangePtr) { if (isa<AllocaInst>(Obj)) return UndefValue::get(&Ty); if (Constant *Init = getInitialValueOfAllocation(&Obj, TLI, &Ty)) @@ -242,12 +255,13 @@ Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty, Constant *Initializer = nullptr; if (A.hasGlobalVariableSimplificationCallback(*GV)) { auto AssumedGV = A.getAssumedInitializerFromCallBack( - *GV, /* const AbstractAttribute *AA */ nullptr, UsedAssumedInformation); + *GV, &QueryingAA, UsedAssumedInformation); Initializer = *AssumedGV; if (!Initializer) return nullptr; } else { - if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer())) + if (!GV->hasLocalLinkage() && + (GV->isInterposable() || !(GV->isConstant() && GV->hasInitializer()))) return nullptr; if (!GV->hasInitializer()) return UndefValue::get(&Ty); @@ -316,7 +330,7 @@ Value *AA::getWithType(Value &V, Type &Ty) { if (C->getType()->isIntegerTy() && Ty.isIntegerTy()) return ConstantExpr::getTrunc(C, &Ty, /* OnlyIfReduced */ true); if (C->getType()->isFloatingPointTy() && Ty.isFloatingPointTy()) - return ConstantExpr::getFPTrunc(C, &Ty, /* OnlyIfReduced */ true); + return ConstantFoldCastInstruction(Instruction::FPTrunc, C, &Ty); } } return nullptr; @@ -350,7 +364,7 @@ AA::combineOptionalValuesInAAValueLatice(const std::optional<Value *> &A, template <bool IsLoad, typename Ty> static bool getPotentialCopiesOfMemoryValue( 
Attributor &A, Ty &I, SmallSetVector<Value *, 4> &PotentialCopies, - SmallSetVector<Instruction *, 4> &PotentialValueOrigins, + SmallSetVector<Instruction *, 4> *PotentialValueOrigins, const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation, bool OnlyExact) { LLVM_DEBUG(dbgs() << "Trying to determine the potential copies of " << I @@ -361,8 +375,8 @@ static bool getPotentialCopiesOfMemoryValue( // sure that we can find all of them. If we abort we want to avoid spurious // dependences and potential copies in the provided container. SmallVector<const AAPointerInfo *> PIs; - SmallVector<Value *> NewCopies; - SmallVector<Instruction *> NewCopyOrigins; + SmallSetVector<Value *, 8> NewCopies; + SmallSetVector<Instruction *, 8> NewCopyOrigins; const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*I.getFunction()); @@ -425,6 +439,30 @@ static bool getPotentialCopiesOfMemoryValue( return AdjV; }; + auto SkipCB = [&](const AAPointerInfo::Access &Acc) { + if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead())) + return true; + if (IsLoad) { + if (Acc.isWrittenValueYetUndetermined()) + return true; + if (PotentialValueOrigins && !isa<AssumeInst>(Acc.getRemoteInst())) + return false; + if (!Acc.isWrittenValueUnknown()) + if (Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue())) + if (NewCopies.count(V)) { + NewCopyOrigins.insert(Acc.getRemoteInst()); + return true; + } + if (auto *SI = dyn_cast<StoreInst>(Acc.getRemoteInst())) + if (Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand())) + if (NewCopies.count(V)) { + NewCopyOrigins.insert(Acc.getRemoteInst()); + return true; + } + } + return false; + }; + auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) { if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead())) return true; @@ -449,8 +487,9 @@ static bool getPotentialCopiesOfMemoryValue( Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue()); if (!V) return false; - 
NewCopies.push_back(V); - NewCopyOrigins.push_back(Acc.getRemoteInst()); + NewCopies.insert(V); + if (PotentialValueOrigins) + NewCopyOrigins.insert(Acc.getRemoteInst()); return true; } auto *SI = dyn_cast<StoreInst>(Acc.getRemoteInst()); @@ -463,8 +502,9 @@ static bool getPotentialCopiesOfMemoryValue( Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand()); if (!V) return false; - NewCopies.push_back(V); - NewCopyOrigins.push_back(SI); + NewCopies.insert(V); + if (PotentialValueOrigins) + NewCopyOrigins.insert(SI); } else { assert(isa<StoreInst>(I) && "Expected load or store instruction only!"); auto *LI = dyn_cast<LoadInst>(Acc.getRemoteInst()); @@ -474,7 +514,7 @@ static bool getPotentialCopiesOfMemoryValue( << *Acc.getRemoteInst() << "\n";); return false; } - NewCopies.push_back(Acc.getRemoteInst()); + NewCopies.insert(Acc.getRemoteInst()); } return true; }; @@ -486,11 +526,11 @@ static bool getPotentialCopiesOfMemoryValue( AA::RangeTy Range; auto *PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj), DepClassTy::NONE); - if (!PI || - !PI->forallInterferingAccesses(A, QueryingAA, I, - /* FindInterferingWrites */ IsLoad, - /* FindInterferingReads */ !IsLoad, - CheckAccess, HasBeenWrittenTo, Range)) { + if (!PI || !PI->forallInterferingAccesses( + A, QueryingAA, I, + /* FindInterferingWrites */ IsLoad, + /* FindInterferingReads */ !IsLoad, CheckAccess, + HasBeenWrittenTo, Range, SkipCB)) { LLVM_DEBUG( dbgs() << "Failed to verify all interfering accesses for underlying object: " @@ -500,8 +540,8 @@ static bool getPotentialCopiesOfMemoryValue( if (IsLoad && !HasBeenWrittenTo && !Range.isUnassigned()) { const DataLayout &DL = A.getDataLayout(); - Value *InitialValue = - AA::getInitialValueForObj(A, Obj, *I.getType(), TLI, DL, &Range); + Value *InitialValue = AA::getInitialValueForObj( + A, QueryingAA, Obj, *I.getType(), TLI, DL, &Range); if (!InitialValue) { LLVM_DEBUG(dbgs() << "Could not determine required initial value of " "underlying object, 
abort!\n"); @@ -514,8 +554,9 @@ static bool getPotentialCopiesOfMemoryValue( return false; } - NewCopies.push_back(InitialValue); - NewCopyOrigins.push_back(nullptr); + NewCopies.insert(InitialValue); + if (PotentialValueOrigins) + NewCopyOrigins.insert(nullptr); } PIs.push_back(PI); @@ -540,7 +581,8 @@ static bool getPotentialCopiesOfMemoryValue( A.recordDependence(*PI, QueryingAA, DepClassTy::OPTIONAL); } PotentialCopies.insert(NewCopies.begin(), NewCopies.end()); - PotentialValueOrigins.insert(NewCopyOrigins.begin(), NewCopyOrigins.end()); + if (PotentialValueOrigins) + PotentialValueOrigins->insert(NewCopyOrigins.begin(), NewCopyOrigins.end()); return true; } @@ -551,7 +593,7 @@ bool AA::getPotentiallyLoadedValues( const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation, bool OnlyExact) { return getPotentialCopiesOfMemoryValue</* IsLoad */ true>( - A, LI, PotentialValues, PotentialValueOrigins, QueryingAA, + A, LI, PotentialValues, &PotentialValueOrigins, QueryingAA, UsedAssumedInformation, OnlyExact); } @@ -559,10 +601,9 @@ bool AA::getPotentialCopiesOfStoredValue( Attributor &A, StoreInst &SI, SmallSetVector<Value *, 4> &PotentialCopies, const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation, bool OnlyExact) { - SmallSetVector<Instruction *, 4> PotentialValueOrigins; return getPotentialCopiesOfMemoryValue</* IsLoad */ false>( - A, SI, PotentialCopies, PotentialValueOrigins, QueryingAA, - UsedAssumedInformation, OnlyExact); + A, SI, PotentialCopies, nullptr, QueryingAA, UsedAssumedInformation, + OnlyExact); } static bool isAssumedReadOnlyOrReadNone(Attributor &A, const IRPosition &IRP, @@ -723,7 +764,7 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, // Check if we can reach returns. 
bool UsedAssumedInformation = false; - if (A.checkForAllInstructions(ReturnInstCB, FromFn, QueryingAA, + if (A.checkForAllInstructions(ReturnInstCB, FromFn, &QueryingAA, {Instruction::Ret}, UsedAssumedInformation)) { LLVM_DEBUG(dbgs() << "[AA] No return is reachable, done\n"); continue; @@ -1021,6 +1062,23 @@ ChangeStatus AbstractAttribute::update(Attributor &A) { return HasChanged; } +Attributor::Attributor(SetVector<Function *> &Functions, + InformationCache &InfoCache, + AttributorConfig Configuration) + : Allocator(InfoCache.Allocator), Functions(Functions), + InfoCache(InfoCache), Configuration(Configuration) { + if (!isClosedWorldModule()) + return; + for (Function *Fn : Functions) + if (Fn->hasAddressTaken(/*PutOffender=*/nullptr, + /*IgnoreCallbackUses=*/false, + /*IgnoreAssumeLikeCalls=*/true, + /*IgnoreLLVMUsed=*/true, + /*IgnoreARCAttachedCall=*/false, + /*IgnoreCastedDirectCall=*/true)) + InfoCache.IndirectlyCallableFunctions.push_back(Fn); +} + bool Attributor::getAttrsFromAssumes(const IRPosition &IRP, Attribute::AttrKind AK, SmallVectorImpl<Attribute> &Attrs) { @@ -1053,8 +1111,7 @@ bool Attributor::getAttrsFromAssumes(const IRPosition &IRP, template <typename DescTy> ChangeStatus -Attributor::updateAttrMap(const IRPosition &IRP, - const ArrayRef<DescTy> &AttrDescs, +Attributor::updateAttrMap(const IRPosition &IRP, ArrayRef<DescTy> AttrDescs, function_ref<bool(const DescTy &, AttributeSet, AttributeMask &, AttrBuilder &)> CB) { @@ -1161,9 +1218,8 @@ void Attributor::getAttrs(const IRPosition &IRP, getAttrsFromAssumes(IRP, AK, Attrs); } -ChangeStatus -Attributor::removeAttrs(const IRPosition &IRP, - const ArrayRef<Attribute::AttrKind> &AttrKinds) { +ChangeStatus Attributor::removeAttrs(const IRPosition &IRP, + ArrayRef<Attribute::AttrKind> AttrKinds) { auto RemoveAttrCB = [&](const Attribute::AttrKind &Kind, AttributeSet AttrSet, AttributeMask &AM, AttrBuilder &) { if (!AttrSet.hasAttribute(Kind)) @@ -1174,8 +1230,21 @@ Attributor::removeAttrs(const 
IRPosition &IRP, return updateAttrMap<Attribute::AttrKind>(IRP, AttrKinds, RemoveAttrCB); } +ChangeStatus Attributor::removeAttrs(const IRPosition &IRP, + ArrayRef<StringRef> Attrs) { + auto RemoveAttrCB = [&](StringRef Attr, AttributeSet AttrSet, + AttributeMask &AM, AttrBuilder &) -> bool { + if (!AttrSet.hasAttribute(Attr)) + return false; + AM.addAttribute(Attr); + return true; + }; + + return updateAttrMap<StringRef>(IRP, Attrs, RemoveAttrCB); +} + ChangeStatus Attributor::manifestAttrs(const IRPosition &IRP, - const ArrayRef<Attribute> &Attrs, + ArrayRef<Attribute> Attrs, bool ForceReplace) { LLVMContext &Ctx = IRP.getAnchorValue().getContext(); auto AddAttrCB = [&](const Attribute &Attr, AttributeSet AttrSet, @@ -1665,6 +1734,21 @@ bool Attributor::isAssumedDead(const BasicBlock &BB, return false; } +bool Attributor::checkForAllCallees( + function_ref<bool(ArrayRef<const Function *>)> Pred, + const AbstractAttribute &QueryingAA, const CallBase &CB) { + if (const Function *Callee = dyn_cast<Function>(CB.getCalledOperand())) + return Pred(Callee); + + const auto *CallEdgesAA = getAAFor<AACallEdges>( + QueryingAA, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); + if (!CallEdgesAA || CallEdgesAA->hasUnknownCallee()) + return false; + + const auto &Callees = CallEdgesAA->getOptimisticEdges(); + return Pred(Callees.getArrayRef()); +} + bool Attributor::checkForAllUses( function_ref<bool(const Use &, bool &)> Pred, const AbstractAttribute &QueryingAA, const Value &V, @@ -1938,7 +2022,7 @@ bool Attributor::checkForAllReturnedValues(function_ref<bool(Value &)> Pred, static bool checkForAllInstructionsImpl( Attributor *A, InformationCache::OpcodeInstMapTy &OpcodeInstMap, function_ref<bool(Instruction &)> Pred, const AbstractAttribute *QueryingAA, - const AAIsDead *LivenessAA, const ArrayRef<unsigned> &Opcodes, + const AAIsDead *LivenessAA, ArrayRef<unsigned> Opcodes, bool &UsedAssumedInformation, bool CheckBBLivenessOnly = false, bool CheckPotentiallyDead = 
false) { for (unsigned Opcode : Opcodes) { @@ -1967,8 +2051,8 @@ static bool checkForAllInstructionsImpl( bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, const Function *Fn, - const AbstractAttribute &QueryingAA, - const ArrayRef<unsigned> &Opcodes, + const AbstractAttribute *QueryingAA, + ArrayRef<unsigned> Opcodes, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, bool CheckPotentiallyDead) { @@ -1978,12 +2062,12 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, const IRPosition &QueryIRP = IRPosition::function(*Fn); const auto *LivenessAA = - CheckPotentiallyDead - ? nullptr - : (getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE)); + CheckPotentiallyDead && QueryingAA + ? (getAAFor<AAIsDead>(*QueryingAA, QueryIRP, DepClassTy::NONE)) + : nullptr; auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*Fn); - if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, &QueryingAA, + if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, QueryingAA, LivenessAA, Opcodes, UsedAssumedInformation, CheckBBLivenessOnly, CheckPotentiallyDead)) return false; @@ -1993,13 +2077,13 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, const AbstractAttribute &QueryingAA, - const ArrayRef<unsigned> &Opcodes, + ArrayRef<unsigned> Opcodes, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, bool CheckPotentiallyDead) { const IRPosition &IRP = QueryingAA.getIRPosition(); const Function *AssociatedFunction = IRP.getAssociatedFunction(); - return checkForAllInstructions(Pred, AssociatedFunction, QueryingAA, Opcodes, + return checkForAllInstructions(Pred, AssociatedFunction, &QueryingAA, Opcodes, UsedAssumedInformation, CheckBBLivenessOnly, CheckPotentiallyDead); } @@ -2964,6 +3048,18 @@ ChangeStatus Attributor::rewriteFunctionSignatures( NewArgumentAttributes)); 
AttributeFuncs::updateMinLegalVectorWidthAttr(*NewFn, LargestVectorWidth); + // Remove argmem from the memory effects if we have no more pointer + // arguments, or they are readnone. + MemoryEffects ME = NewFn->getMemoryEffects(); + int ArgNo = -1; + if (ME.doesAccessArgPointees() && all_of(NewArgumentTypes, [&](Type *T) { + ++ArgNo; + return !T->isPtrOrPtrVectorTy() || + NewFn->hasParamAttribute(ArgNo, Attribute::ReadNone); + })) { + NewFn->setMemoryEffects(ME - MemoryEffects::argMemOnly()); + } + // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the // function empty. @@ -3203,6 +3299,12 @@ InformationCache::FunctionInfo::~FunctionInfo() { It.getSecond()->~InstructionVectorTy(); } +const ArrayRef<Function *> +InformationCache::getIndirectlyCallableFunctions(Attributor &A) const { + assert(A.isClosedWorldModule() && "Cannot see all indirect callees!"); + return IndirectlyCallableFunctions; +} + void Attributor::recordDependence(const AbstractAttribute &FromAA, const AbstractAttribute &ToAA, DepClassTy DepClass) { @@ -3236,9 +3338,10 @@ void Attributor::checkAndQueryIRAttr(const IRPosition &IRP, AttributeSet Attrs) { bool IsKnown; if (!Attrs.hasAttribute(AK)) - if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE, - IsKnown)) - getOrCreateAAFor<AAType>(IRP); + if (!Configuration.Allowed || Configuration.Allowed->count(&AAType::ID)) + if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE, + IsKnown)) + getOrCreateAAFor<AAType>(IRP); } void Attributor::identifyDefaultAbstractAttributes(Function &F) { @@ -3285,6 +3388,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Every function might be "will-return". 
checkAndQueryIRAttr<Attribute::WillReturn, AAWillReturn>(FPos, FnAttrs); + // Every function might be marked "nosync" + checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs); + // Everything that is visible from the outside (=function, argument, return // positions), cannot be changed if the function is not IPO amendable. We can // however analyse the code inside. @@ -3293,9 +3399,6 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Every function can be nounwind. checkAndQueryIRAttr<Attribute::NoUnwind, AANoUnwind>(FPos, FnAttrs); - // Every function might be marked "nosync" - checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs); - // Every function might be "no-return". checkAndQueryIRAttr<Attribute::NoReturn, AANoReturn>(FPos, FnAttrs); @@ -3315,6 +3418,14 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Every function can track active assumptions. getOrCreateAAFor<AAAssumptionInfo>(FPos); + // If we're not using a dynamic mode for float, there's nothing worthwhile + // to infer. This misses the edge case denormal-fp-math="dynamic" and + // denormal-fp-math-f32=something, but that likely has no real world use. + DenormalMode Mode = F.getDenormalMode(APFloat::IEEEsingle()); + if (Mode.Input == DenormalMode::Dynamic || + Mode.Output == DenormalMode::Dynamic) + getOrCreateAAFor<AADenormalFPMath>(FPos); + // Return attributes are only appropriate if the return type is non void. Type *ReturnType = F.getReturnType(); if (!ReturnType->isVoidTy()) { @@ -3420,8 +3531,10 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { Function *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand()); // TODO: Even if the callee is not known now we might be able to simplify // the call/callee. - if (!Callee) + if (!Callee) { + getOrCreateAAFor<AAIndirectCallInfo>(CBFnPos); return true; + } // Every call site can track active assumptions. 
getOrCreateAAFor<AAAssumptionInfo>(CBFnPos); @@ -3498,14 +3611,13 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { }; auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); - bool Success; + [[maybe_unused]] bool Success; bool UsedAssumedInformation = false; Success = checkForAllInstructionsImpl( nullptr, OpcodeInstMap, CallSitePred, nullptr, nullptr, {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, (unsigned)Instruction::Call}, UsedAssumedInformation); - (void)Success; assert(Success && "Expected the check call to be successful!"); auto LoadStorePred = [&](Instruction &I) -> bool { @@ -3531,10 +3643,26 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { nullptr, OpcodeInstMap, LoadStorePred, nullptr, nullptr, {(unsigned)Instruction::Load, (unsigned)Instruction::Store}, UsedAssumedInformation); - (void)Success; + assert(Success && "Expected the check call to be successful!"); + + // AllocaInstPredicate + auto AAAllocationInfoPred = [&](Instruction &I) -> bool { + getOrCreateAAFor<AAAllocationInfo>(IRPosition::value(I)); + return true; + }; + + Success = checkForAllInstructionsImpl( + nullptr, OpcodeInstMap, AAAllocationInfoPred, nullptr, nullptr, + {(unsigned)Instruction::Alloca}, UsedAssumedInformation); assert(Success && "Expected the check call to be successful!"); } +bool Attributor::isClosedWorldModule() const { + if (CloseWorldAssumption.getNumOccurrences()) + return CloseWorldAssumption; + return isModulePass() && Configuration.IsClosedWorldModule; +} + /// Helpers to ease debugging through output streams and print calls. /// ///{ @@ -3696,6 +3824,26 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, AttributorConfig AC(CGUpdater); AC.IsModulePass = IsModulePass; AC.DeleteFns = DeleteFns; + + /// Tracking callback for specialization of indirect calls. 
+ DenseMap<CallBase *, std::unique_ptr<SmallPtrSet<Function *, 8>>> + IndirectCalleeTrackingMap; + if (MaxSpecializationPerCB.getNumOccurrences()) { + AC.IndirectCalleeSpecializationCallback = + [&](Attributor &, const AbstractAttribute &AA, CallBase &CB, + Function &Callee) { + if (MaxSpecializationPerCB == 0) + return false; + auto &Set = IndirectCalleeTrackingMap[&CB]; + if (!Set) + Set = std::make_unique<SmallPtrSet<Function *, 8>>(); + if (Set->size() >= MaxSpecializationPerCB) + return Set->contains(&Callee); + Set->insert(&Callee); + return true; + }; + } + Attributor A(Functions, InfoCache, AC); // Create shallow wrappers for all functions that are not IPO amendable @@ -3759,6 +3907,88 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, return Changed == ChangeStatus::CHANGED; } +static bool runAttributorLightOnFunctions(InformationCache &InfoCache, + SetVector<Function *> &Functions, + AnalysisGetter &AG, + CallGraphUpdater &CGUpdater, + FunctionAnalysisManager &FAM, + bool IsModulePass) { + if (Functions.empty()) + return false; + + LLVM_DEBUG({ + dbgs() << "[AttributorLight] Run on module with " << Functions.size() + << " functions:\n"; + for (Function *Fn : Functions) + dbgs() << " - " << Fn->getName() << "\n"; + }); + + // Create an Attributor and initially empty information cache that is filled + // while we identify default attribute opportunities. 
+ AttributorConfig AC(CGUpdater); + AC.IsModulePass = IsModulePass; + AC.DeleteFns = false; + DenseSet<const char *> Allowed( + {&AAWillReturn::ID, &AANoUnwind::ID, &AANoRecurse::ID, &AANoSync::ID, + &AANoFree::ID, &AANoReturn::ID, &AAMemoryLocation::ID, + &AAMemoryBehavior::ID, &AAUnderlyingObjects::ID, &AANoCapture::ID, + &AAInterFnReachability::ID, &AAIntraFnReachability::ID, &AACallEdges::ID, + &AANoFPClass::ID, &AAMustProgress::ID, &AANonNull::ID}); + AC.Allowed = &Allowed; + AC.UseLiveness = false; + + Attributor A(Functions, InfoCache, AC); + + for (Function *F : Functions) { + if (F->hasExactDefinition()) + NumFnWithExactDefinition++; + else + NumFnWithoutExactDefinition++; + + // We look at internal functions only on-demand but if any use is not a + // direct call or outside the current set of analyzed functions, we have + // to do it eagerly. + if (F->hasLocalLinkage()) { + if (llvm::all_of(F->uses(), [&Functions](const Use &U) { + const auto *CB = dyn_cast<CallBase>(U.getUser()); + return CB && CB->isCallee(&U) && + Functions.count(const_cast<Function *>(CB->getCaller())); + })) + continue; + } + + // Populate the Attributor with abstract attribute opportunities in the + // function and the information cache with IR information. + A.identifyDefaultAbstractAttributes(*F); + } + + ChangeStatus Changed = A.run(); + + if (Changed == ChangeStatus::CHANGED) { + // Invalidate analyses for modified functions so that we don't have to + // invalidate all analyses for all functions in this SCC. + PreservedAnalyses FuncPA; + // We haven't changed the CFG for modified functions. + FuncPA.preserveSet<CFGAnalyses>(); + for (Function *Changed : A.getModifiedFunctions()) { + FAM.invalidate(*Changed, FuncPA); + // Also invalidate any direct callers of changed functions since analyses + // may care about attributes of direct callees. For example, MemorySSA + // cares about whether or not a call's callee modifies memory and queries + // that through function attributes. 
+ for (auto *U : Changed->users()) { + if (auto *Call = dyn_cast<CallBase>(U)) { + if (Call->getCalledFunction() == Changed) + FAM.invalidate(*Call->getFunction(), FuncPA); + } + } + } + } + LLVM_DEBUG(dbgs() << "[Attributor] Done with " << Functions.size() + << " functions, result: " << Changed << ".\n"); + return Changed == ChangeStatus::CHANGED; +} + void AADepGraph::viewGraph() { llvm::ViewGraph(this, "Dependency Graph"); } void AADepGraph::dumpGraph() { @@ -3839,6 +4069,62 @@ PreservedAnalyses AttributorCGSCCPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); } +PreservedAnalyses AttributorLightPass::run(Module &M, + ModuleAnalysisManager &AM) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + AnalysisGetter AG(FAM, /* CachedOnly */ true); + + SetVector<Function *> Functions; + for (Function &F : M) + Functions.insert(&F); + + CallGraphUpdater CGUpdater; + BumpPtrAllocator Allocator; + InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr); + if (runAttributorLightOnFunctions(InfoCache, Functions, AG, CGUpdater, FAM, + /* IsModulePass */ true)) { + PreservedAnalyses PA; + // We have not added or removed functions. + PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + // We already invalidated all relevant function analyses above. 
+ PA.preserveSet<AllAnalysesOn<Function>>(); + return PA; + } + return PreservedAnalyses::all(); +} + +PreservedAnalyses AttributorLightCGSCCPass::run(LazyCallGraph::SCC &C, + CGSCCAnalysisManager &AM, + LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); + AnalysisGetter AG(FAM); + + SetVector<Function *> Functions; + for (LazyCallGraph::Node &N : C) + Functions.insert(&N.getFunction()); + + if (Functions.empty()) + return PreservedAnalyses::all(); + + Module &M = *Functions.back()->getParent(); + CallGraphUpdater CGUpdater; + CGUpdater.initialize(CG, C, AM, UR); + BumpPtrAllocator Allocator; + InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions); + if (runAttributorLightOnFunctions(InfoCache, Functions, AG, CGUpdater, FAM, + /* IsModulePass */ false)) { + PreservedAnalyses PA; + // We have not added or removed functions. + PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + // We already invalidated all relevant function analyses above. 
+ PA.preserveSet<AllAnalysesOn<Function>>(); + return PA; + } + return PreservedAnalyses::all(); +} namespace llvm { template <> struct GraphTraits<AADepGraphNode *> { diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 3a9a89d61355..889ebd7438bd 100644 --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -55,6 +55,7 @@ #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/NoFolder.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" @@ -64,12 +65,16 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GraphWriter.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/TypeSize.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <cassert> #include <numeric> #include <optional> +#include <string> using namespace llvm; @@ -188,6 +193,10 @@ PIPE_OPERATOR(AAPointerInfo) PIPE_OPERATOR(AAAssumptionInfo) PIPE_OPERATOR(AAUnderlyingObjects) PIPE_OPERATOR(AAAddressSpace) +PIPE_OPERATOR(AAAllocationInfo) +PIPE_OPERATOR(AAIndirectCallInfo) +PIPE_OPERATOR(AAGlobalValueInfo) +PIPE_OPERATOR(AADenormalFPMath) #undef PIPE_OPERATOR @@ -313,7 +322,6 @@ static Value *constructPointer(Type *ResTy, Type *PtrElemTy, Value *Ptr, // If an offset is left we use byte-wise adjustment. 
if (IntOffset != 0) { - Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy()); Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(IntOffset), GEPName + ".b" + Twine(IntOffset.getZExtValue())); } @@ -377,7 +385,7 @@ getMinimalBaseOfPointer(Attributor &A, const AbstractAttribute &QueryingAA, /// Clamp the information known for all returned values of a function /// (identified by \p QueryingAA) into \p S. template <typename AAType, typename StateType = typename AAType::StateType, - Attribute::AttrKind IRAttributeKind = Attribute::None, + Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind, bool RecurseForSelectAndPHI = true> static void clampReturnedValueStates( Attributor &A, const AAType &QueryingAA, StateType &S, @@ -400,7 +408,7 @@ static void clampReturnedValueStates( auto CheckReturnValue = [&](Value &RV) -> bool { const IRPosition &RVPos = IRPosition::value(RV, CBContext); // If possible, use the hasAssumedIRAttr interface. - if (IRAttributeKind != Attribute::None) { + if (Attribute::isEnumAttrKind(IRAttributeKind)) { bool IsKnown; return AA::hasAssumedIRAttr<IRAttributeKind>( A, &QueryingAA, RVPos, DepClassTy::REQUIRED, IsKnown); @@ -434,7 +442,7 @@ namespace { template <typename AAType, typename BaseType, typename StateType = typename BaseType::StateType, bool PropagateCallBaseContext = false, - Attribute::AttrKind IRAttributeKind = Attribute::None, + Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind, bool RecurseForSelectAndPHI = true> struct AAReturnedFromReturnedValues : public BaseType { AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A) @@ -455,7 +463,7 @@ struct AAReturnedFromReturnedValues : public BaseType { /// Clamp the information known at all call sites for a given argument /// (identified by \p QueryingAA) into \p S. 
template <typename AAType, typename StateType = typename AAType::StateType, - Attribute::AttrKind IRAttributeKind = Attribute::None> + Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind> static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, StateType &S) { LLVM_DEBUG(dbgs() << "[Attributor] Clamp call site argument states for " @@ -480,7 +488,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, return false; // If possible, use the hasAssumedIRAttr interface. - if (IRAttributeKind != Attribute::None) { + if (Attribute::isEnumAttrKind(IRAttributeKind)) { bool IsKnown; return AA::hasAssumedIRAttr<IRAttributeKind>( A, &QueryingAA, ACSArgPos, DepClassTy::REQUIRED, IsKnown); @@ -514,7 +522,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, /// context. template <typename AAType, typename BaseType, typename StateType = typename AAType::StateType, - Attribute::AttrKind IRAttributeKind = Attribute::None> + Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind> bool getArgumentStateFromCallBaseContext(Attributor &A, BaseType &QueryingAttribute, IRPosition &Pos, StateType &State) { @@ -529,7 +537,7 @@ bool getArgumentStateFromCallBaseContext(Attributor &A, const IRPosition CBArgPos = IRPosition::callsite_argument(*CBContext, ArgNo); // If possible, use the hasAssumedIRAttr interface. 
- if (IRAttributeKind != Attribute::None) { + if (Attribute::isEnumAttrKind(IRAttributeKind)) { bool IsKnown; return AA::hasAssumedIRAttr<IRAttributeKind>( A, &QueryingAttribute, CBArgPos, DepClassTy::REQUIRED, IsKnown); @@ -555,7 +563,7 @@ bool getArgumentStateFromCallBaseContext(Attributor &A, template <typename AAType, typename BaseType, typename StateType = typename AAType::StateType, bool BridgeCallBaseContext = false, - Attribute::AttrKind IRAttributeKind = Attribute::None> + Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind> struct AAArgumentFromCallSiteArguments : public BaseType { AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -585,45 +593,55 @@ struct AAArgumentFromCallSiteArguments : public BaseType { template <typename AAType, typename BaseType, typename StateType = typename BaseType::StateType, bool IntroduceCallBaseContext = false, - Attribute::AttrKind IRAttributeKind = Attribute::None> -struct AACallSiteReturnedFromReturned : public BaseType { - AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A) - : BaseType(IRP, A) {} + Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind> +struct AACalleeToCallSite : public BaseType { + AACalleeToCallSite(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} /// See AbstractAttribute::updateImpl(...). 
ChangeStatus updateImpl(Attributor &A) override { - assert(this->getIRPosition().getPositionKind() == - IRPosition::IRP_CALL_SITE_RETURNED && - "Can only wrap function returned positions for call site returned " - "positions!"); + auto IRPKind = this->getIRPosition().getPositionKind(); + assert((IRPKind == IRPosition::IRP_CALL_SITE_RETURNED || + IRPKind == IRPosition::IRP_CALL_SITE) && + "Can only wrap function returned positions for call site " + "returned positions!"); auto &S = this->getState(); - const Function *AssociatedFunction = - this->getIRPosition().getAssociatedFunction(); - if (!AssociatedFunction) - return S.indicatePessimisticFixpoint(); - - CallBase &CBContext = cast<CallBase>(this->getAnchorValue()); + CallBase &CB = cast<CallBase>(this->getAnchorValue()); if (IntroduceCallBaseContext) - LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" - << CBContext << "\n"); - - IRPosition FnPos = IRPosition::returned( - *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr); + LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" << CB + << "\n"); - // If possible, use the hasAssumedIRAttr interface. - if (IRAttributeKind != Attribute::None) { - bool IsKnown; - if (!AA::hasAssumedIRAttr<IRAttributeKind>(A, this, FnPos, - DepClassTy::REQUIRED, IsKnown)) - return S.indicatePessimisticFixpoint(); - return ChangeStatus::UNCHANGED; - } + ChangeStatus Changed = ChangeStatus::UNCHANGED; + auto CalleePred = [&](ArrayRef<const Function *> Callees) { + for (const Function *Callee : Callees) { + IRPosition FnPos = + IRPKind == llvm::IRPosition::IRP_CALL_SITE_RETURNED + ? IRPosition::returned(*Callee, + IntroduceCallBaseContext ? &CB : nullptr) + : IRPosition::function( + *Callee, IntroduceCallBaseContext ? &CB : nullptr); + // If possible, use the hasAssumedIRAttr interface. 
+ if (Attribute::isEnumAttrKind(IRAttributeKind)) { + bool IsKnown; + if (!AA::hasAssumedIRAttr<IRAttributeKind>( + A, this, FnPos, DepClassTy::REQUIRED, IsKnown)) + return false; + continue; + } - const AAType *AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED); - if (!AA) + const AAType *AA = + A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED); + if (!AA) + return false; + Changed |= clampStateAndIndicateChange(S, AA->getState()); + if (S.isAtFixpoint()) + return S.isValidState(); + } + return true; + }; + if (!A.checkForAllCallees(CalleePred, *this, CB)) return S.indicatePessimisticFixpoint(); - return clampStateAndIndicateChange(S, AA->getState()); + return Changed; } }; @@ -865,11 +883,9 @@ struct AA::PointerInfo::State : public AbstractState { AAPointerInfo::AccessKind Kind, Type *Ty, Instruction *RemoteI = nullptr); - using OffsetBinsTy = DenseMap<RangeTy, SmallSet<unsigned, 4>>; - - using const_bin_iterator = OffsetBinsTy::const_iterator; - const_bin_iterator begin() const { return OffsetBins.begin(); } - const_bin_iterator end() const { return OffsetBins.end(); } + AAPointerInfo::const_bin_iterator begin() const { return OffsetBins.begin(); } + AAPointerInfo::const_bin_iterator end() const { return OffsetBins.end(); } + int64_t numOffsetBins() const { return OffsetBins.size(); } const AAPointerInfo::Access &getAccess(unsigned Index) const { return AccessList[Index]; @@ -889,7 +905,7 @@ protected: // are all combined into a single Access object. This may result in loss of // information in RangeTy in the Access object. SmallVector<AAPointerInfo::Access> AccessList; - OffsetBinsTy OffsetBins; + AAPointerInfo::OffsetBinsTy OffsetBins; DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap; /// See AAPointerInfo::forallInterferingAccesses. 
@@ -1093,6 +1109,12 @@ struct AAPointerInfoImpl return AAPointerInfo::manifest(A); } + virtual const_bin_iterator begin() const override { return State::begin(); } + virtual const_bin_iterator end() const override { return State::end(); } + virtual int64_t numOffsetBins() const override { + return State::numOffsetBins(); + } + bool forallInterferingAccesses( AA::RangeTy Range, function_ref<bool(const AAPointerInfo::Access &, bool)> CB) @@ -1104,7 +1126,8 @@ struct AAPointerInfoImpl Attributor &A, const AbstractAttribute &QueryingAA, Instruction &I, bool FindInterferingWrites, bool FindInterferingReads, function_ref<bool(const Access &, bool)> UserCB, bool &HasBeenWrittenTo, - AA::RangeTy &Range) const override { + AA::RangeTy &Range, + function_ref<bool(const Access &)> SkipCB) const override { HasBeenWrittenTo = false; SmallPtrSet<const Access *, 8> DominatingWrites; @@ -1183,6 +1206,11 @@ struct AAPointerInfoImpl A, this, IRPosition::function(Scope), DepClassTy::OPTIONAL, IsKnownNoRecurse); + // TODO: Use reaching kernels from AAKernelInfo (or move it to + // AAExecutionDomain) such that we allow scopes other than kernels as long + // as the reaching kernels are disjoint. + bool InstInKernel = Scope.hasFnAttribute("kernel"); + bool ObjHasKernelLifetime = false; const bool UseDominanceReasoning = FindInterferingWrites && IsKnownNoRecurse; const DominatorTree *DT = @@ -1215,6 +1243,7 @@ struct AAPointerInfoImpl // If the alloca containing function is not recursive the alloca // must be dead in the callee. const Function *AIFn = AI->getFunction(); + ObjHasKernelLifetime = AIFn->hasFnAttribute("kernel"); bool IsKnownNoRecurse; if (AA::hasAssumedIRAttr<Attribute::NoRecurse>( A, this, IRPosition::function(*AIFn), DepClassTy::OPTIONAL, @@ -1224,7 +1253,8 @@ struct AAPointerInfoImpl } else if (auto *GV = dyn_cast<GlobalValue>(&getAssociatedValue())) { // If the global has kernel lifetime we can stop if we reach a kernel // as it is "dead" in the (unknown) callees. 
- if (HasKernelLifetime(GV, *GV->getParent())) + ObjHasKernelLifetime = HasKernelLifetime(GV, *GV->getParent()); + if (ObjHasKernelLifetime) IsLiveInCalleeCB = [](const Function &Fn) { return !Fn.hasFnAttribute("kernel"); }; @@ -1235,6 +1265,15 @@ struct AAPointerInfoImpl AA::InstExclusionSetTy ExclusionSet; auto AccessCB = [&](const Access &Acc, bool Exact) { + Function *AccScope = Acc.getRemoteInst()->getFunction(); + bool AccInSameScope = AccScope == &Scope; + + // If the object has kernel lifetime we can ignore accesses only reachable + // by other kernels. For now we only skip accesses *in* other kernels. + if (InstInKernel && ObjHasKernelLifetime && !AccInSameScope && + AccScope->hasFnAttribute("kernel")) + return true; + if (Exact && Acc.isMustAccess() && Acc.getRemoteInst() != &I) { if (Acc.isWrite() || (isa<LoadInst>(I) && Acc.isWriteOrAssumption())) ExclusionSet.insert(Acc.getRemoteInst()); @@ -1245,8 +1284,7 @@ struct AAPointerInfoImpl return true; bool Dominates = FindInterferingWrites && DT && Exact && - Acc.isMustAccess() && - (Acc.getRemoteInst()->getFunction() == &Scope) && + Acc.isMustAccess() && AccInSameScope && DT->dominates(Acc.getRemoteInst(), &I); if (Dominates) DominatingWrites.insert(&Acc); @@ -1276,6 +1314,8 @@ struct AAPointerInfoImpl // Helper to determine if we can skip a specific write access. 
auto CanSkipAccess = [&](const Access &Acc, bool Exact) { + if (SkipCB && SkipCB(Acc)) + return true; if (!CanIgnoreThreading(Acc)) return false; @@ -1817,9 +1857,14 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { LLVM_DEBUG(dbgs() << "[AAPointerInfo] Assumption found " << *Assumption.second << ": " << *LoadI << " == " << *Assumption.first << "\n"); - + bool UsedAssumedInformation = false; + std::optional<Value *> Content = nullptr; + if (Assumption.first) + Content = + A.getAssumedSimplified(*Assumption.first, *this, + UsedAssumedInformation, AA::Interprocedural); return handleAccess( - A, *Assumption.second, Assumption.first, AccessKind::AK_ASSUMPTION, + A, *Assumption.second, Content, AccessKind::AK_ASSUMPTION, OffsetInfoMap[CurPtr].Offsets, Changed, *LoadI->getType()); } @@ -2083,24 +2128,10 @@ struct AANoUnwindFunction final : public AANoUnwindImpl { }; /// NoUnwind attribute deduction for a call sites. -struct AANoUnwindCallSite final : AANoUnwindImpl { +struct AANoUnwindCallSite final + : AACalleeToCallSite<AANoUnwind, AANoUnwindImpl> { AANoUnwindCallSite(const IRPosition &IRP, Attributor &A) - : AANoUnwindImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. 
- Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - bool IsKnownNoUnwind; - if (AA::hasAssumedIRAttr<Attribute::NoUnwind>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoUnwind)) - return ChangeStatus::UNCHANGED; - return indicatePessimisticFixpoint(); - } + : AACalleeToCallSite<AANoUnwind, AANoUnwindImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); } @@ -2200,8 +2231,15 @@ ChangeStatus AANoSyncImpl::updateImpl(Attributor &A) { if (I.mayReadOrWriteMemory()) return true; + bool IsKnown; + CallBase &CB = cast<CallBase>(I); + if (AA::hasAssumedIRAttr<Attribute::NoSync>( + A, this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL, + IsKnown)) + return true; + // non-convergent and readnone imply nosync. - return !cast<CallBase>(I).isConvergent(); + return !CB.isConvergent(); }; bool UsedAssumedInformation = false; @@ -2223,24 +2261,9 @@ struct AANoSyncFunction final : public AANoSyncImpl { }; /// NoSync attribute deduction for a call sites. -struct AANoSyncCallSite final : AANoSyncImpl { +struct AANoSyncCallSite final : AACalleeToCallSite<AANoSync, AANoSyncImpl> { AANoSyncCallSite(const IRPosition &IRP, Attributor &A) - : AANoSyncImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. 
- Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - bool IsKnownNoSycn; - if (AA::hasAssumedIRAttr<Attribute::NoSync>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoSycn)) - return ChangeStatus::UNCHANGED; - return indicatePessimisticFixpoint(); - } + : AACalleeToCallSite<AANoSync, AANoSyncImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nosync); } @@ -2292,24 +2315,9 @@ struct AANoFreeFunction final : public AANoFreeImpl { }; /// NoFree attribute deduction for a call sites. -struct AANoFreeCallSite final : AANoFreeImpl { +struct AANoFreeCallSite final : AACalleeToCallSite<AANoFree, AANoFreeImpl> { AANoFreeCallSite(const IRPosition &IRP, Attributor &A) - : AANoFreeImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. 
- Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - bool IsKnown; - if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, FnPos, - DepClassTy::REQUIRED, IsKnown)) - return ChangeStatus::UNCHANGED; - return indicatePessimisticFixpoint(); - } + : AACalleeToCallSite<AANoFree, AANoFreeImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nofree); } @@ -2450,9 +2458,6 @@ bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP, if (A.hasAttr(IRP, AttrKinds, IgnoreSubsumingPositions, Attribute::NonNull)) return true; - if (IRP.getPositionKind() == IRP_RETURNED) - return false; - DominatorTree *DT = nullptr; AssumptionCache *AC = nullptr; InformationCache &InfoCache = A.getInfoCache(); @@ -2463,9 +2468,27 @@ bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP, } } - if (!isKnownNonZero(&IRP.getAssociatedValue(), A.getDataLayout(), 0, AC, - IRP.getCtxI(), DT)) + SmallVector<AA::ValueAndContext> Worklist; + if (IRP.getPositionKind() != IRP_RETURNED) { + Worklist.push_back({IRP.getAssociatedValue(), IRP.getCtxI()}); + } else { + bool UsedAssumedInformation = false; + if (!A.checkForAllInstructions( + [&](Instruction &I) { + Worklist.push_back({*cast<ReturnInst>(I).getReturnValue(), &I}); + return true; + }, + IRP.getAssociatedFunction(), nullptr, {Instruction::Ret}, + UsedAssumedInformation)) + return false; + } + + if (llvm::any_of(Worklist, [&](AA::ValueAndContext VAC) { + return !isKnownNonZero(VAC.getValue(), A.getDataLayout(), 0, AC, + VAC.getCtxI(), DT); + })) return false; + A.manifestAttrs(IRP, {Attribute::get(IRP.getAnchorValue().getContext(), Attribute::NonNull)}); return true; @@ -2529,7 +2552,8 @@ static int64_t getKnownNonNullAndDerefBytesForUse( } std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I); - if (!Loc || Loc->Ptr != UseV || !Loc->Size.isPrecise() || I->isVolatile()) + if (!Loc || Loc->Ptr != UseV || 
!Loc->Size.isPrecise() || + Loc->Size.isScalable() || I->isVolatile()) return 0; int64_t Offset; @@ -2610,6 +2634,23 @@ struct AANonNullFloating : public AANonNullImpl { Values.size() != 1 || Values.front().getValue() != AssociatedValue; if (!Stripped) { + bool IsKnown; + if (auto *PHI = dyn_cast<PHINode>(AssociatedValue)) + if (llvm::all_of(PHI->incoming_values(), [&](Value *Op) { + return AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, IRPosition::value(*Op), DepClassTy::OPTIONAL, + IsKnown); + })) + return ChangeStatus::UNCHANGED; + if (auto *Select = dyn_cast<SelectInst>(AssociatedValue)) + if (AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, IRPosition::value(*Select->getFalseValue()), + DepClassTy::OPTIONAL, IsKnown) && + AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, IRPosition::value(*Select->getTrueValue()), + DepClassTy::OPTIONAL, IsKnown)) + return ChangeStatus::UNCHANGED; + // If we haven't stripped anything we might still be able to use a // different AA, but only if the IRP changes. Effectively when we // interpret this not as a call site value but as a floating/argument @@ -2634,10 +2675,11 @@ struct AANonNullFloating : public AANonNullImpl { /// NonNull attribute for function return value. struct AANonNullReturned final : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType, - false, AANonNull::IRAttributeKind> { + false, AANonNull::IRAttributeKind, false> { AANonNullReturned(const IRPosition &IRP, Attributor &A) : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType, - false, Attribute::NonNull>(IRP, A) {} + false, Attribute::NonNull, false>(IRP, A) { + } /// See AbstractAttribute::getAsStr(). const std::string getAsStr(Attributor *A) const override { @@ -2650,13 +2692,9 @@ struct AANonNullReturned final /// NonNull attribute for function argument. 
struct AANonNullArgument final - : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl, - AANonNull::StateType, false, - AANonNull::IRAttributeKind> { + : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl> { AANonNullArgument(const IRPosition &IRP, Attributor &A) - : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl, - AANonNull::StateType, false, - AANonNull::IRAttributeKind>(IRP, A) {} + : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nonnull) } @@ -2672,13 +2710,9 @@ struct AANonNullCallSiteArgument final : AANonNullFloating { /// NonNull attribute for a call site return position. struct AANonNullCallSiteReturned final - : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl, - AANonNull::StateType, false, - AANonNull::IRAttributeKind> { + : AACalleeToCallSite<AANonNull, AANonNullImpl> { AANonNullCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl, - AANonNull::StateType, false, - AANonNull::IRAttributeKind>(IRP, A) {} + : AACalleeToCallSite<AANonNull, AANonNullImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) } @@ -2830,24 +2864,10 @@ struct AANoRecurseFunction final : AANoRecurseImpl { }; /// NoRecurse attribute deduction for a call sites. -struct AANoRecurseCallSite final : AANoRecurseImpl { +struct AANoRecurseCallSite final + : AACalleeToCallSite<AANoRecurse, AANoRecurseImpl> { AANoRecurseCallSite(const IRPosition &IRP, Attributor &A) - : AANoRecurseImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). 
- ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - bool IsKnownNoRecurse; - if (!AA::hasAssumedIRAttr<Attribute::NoRecurse>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoRecurse)) - return indicatePessimisticFixpoint(); - return ChangeStatus::UNCHANGED; - } + : AACalleeToCallSite<AANoRecurse, AANoRecurseImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(norecurse); } @@ -3355,26 +3375,17 @@ struct AAWillReturnFunction final : AAWillReturnImpl { }; /// WillReturn attribute deduction for a call sites. -struct AAWillReturnCallSite final : AAWillReturnImpl { +struct AAWillReturnCallSite final + : AACalleeToCallSite<AAWillReturn, AAWillReturnImpl> { AAWillReturnCallSite(const IRPosition &IRP, Attributor &A) - : AAWillReturnImpl(IRP, A) {} + : AACalleeToCallSite<AAWillReturn, AAWillReturnImpl>(IRP, A) {} /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ false)) return ChangeStatus::UNCHANGED; - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. 
- Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - bool IsKnown; - if (AA::hasAssumedIRAttr<Attribute::WillReturn>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnown)) - return ChangeStatus::UNCHANGED; - return indicatePessimisticFixpoint(); + return AACalleeToCallSite::updateImpl(A); } /// See AbstractAttribute::trackStatistics() @@ -3402,6 +3413,18 @@ template <typename ToTy> struct ReachabilityQueryInfo { /// and remember if it worked: Reachable Result = Reachable::No; + /// Precomputed hash for this RQI. + unsigned Hash = 0; + + unsigned computeHashValue() const { + assert(Hash == 0 && "Computed hash twice!"); + using InstSetDMI = DenseMapInfo<const AA::InstExclusionSetTy *>; + using PairDMI = DenseMapInfo<std::pair<const Instruction *, const ToTy *>>; + return const_cast<ReachabilityQueryInfo<ToTy> *>(this)->Hash = + detail::combineHashValue(PairDMI ::getHashValue({From, To}), + InstSetDMI::getHashValue(ExclusionSet)); + } + ReachabilityQueryInfo(const Instruction *From, const ToTy *To) : From(From), To(To) {} @@ -3435,9 +3458,7 @@ template <typename ToTy> struct DenseMapInfo<ReachabilityQueryInfo<ToTy> *> { return &TombstoneKey; } static unsigned getHashValue(const ReachabilityQueryInfo<ToTy> *RQI) { - unsigned H = PairDMI ::getHashValue({RQI->From, RQI->To}); - H += InstSetDMI::getHashValue(RQI->ExclusionSet); - return H; + return RQI->Hash ? RQI->Hash : RQI->computeHashValue(); } static bool isEqual(const ReachabilityQueryInfo<ToTy> *LHS, const ReachabilityQueryInfo<ToTy> *RHS) { @@ -3480,24 +3501,24 @@ struct CachedReachabilityAA : public BaseTy { /// See AbstractAttribute::updateImpl(...). 
ChangeStatus updateImpl(Attributor &A) override { ChangeStatus Changed = ChangeStatus::UNCHANGED; - InUpdate = true; for (unsigned u = 0, e = QueryVector.size(); u < e; ++u) { RQITy *RQI = QueryVector[u]; - if (RQI->Result == RQITy::Reachable::No && isReachableImpl(A, *RQI)) + if (RQI->Result == RQITy::Reachable::No && + isReachableImpl(A, *RQI, /*IsTemporaryRQI=*/false)) Changed = ChangeStatus::CHANGED; } - InUpdate = false; return Changed; } - virtual bool isReachableImpl(Attributor &A, RQITy &RQI) = 0; + virtual bool isReachableImpl(Attributor &A, RQITy &RQI, + bool IsTemporaryRQI) = 0; bool rememberResult(Attributor &A, typename RQITy::Reachable Result, - RQITy &RQI, bool UsedExclusionSet) { + RQITy &RQI, bool UsedExclusionSet, bool IsTemporaryRQI) { RQI.Result = Result; // Remove the temporary RQI from the cache. - if (!InUpdate) + if (IsTemporaryRQI) QueryCache.erase(&RQI); // Insert a plain RQI (w/o exclusion set) if that makes sense. Two options: @@ -3515,7 +3536,7 @@ struct CachedReachabilityAA : public BaseTy { } // Check if we need to insert a new permanent RQI with the exclusion set. 
- if (!InUpdate && Result != RQITy::Reachable::Yes && UsedExclusionSet) { + if (IsTemporaryRQI && Result != RQITy::Reachable::Yes && UsedExclusionSet) { assert((!RQI.ExclusionSet || !RQI.ExclusionSet->empty()) && "Did not expect empty set!"); RQITy *RQIPtr = new (A.Allocator) @@ -3527,7 +3548,7 @@ struct CachedReachabilityAA : public BaseTy { QueryCache.insert(RQIPtr); } - if (Result == RQITy::Reachable::No && !InUpdate) + if (Result == RQITy::Reachable::No && IsTemporaryRQI) A.registerForUpdate(*this); return Result == RQITy::Reachable::Yes; } @@ -3568,7 +3589,6 @@ struct CachedReachabilityAA : public BaseTy { } private: - bool InUpdate = false; SmallVector<RQITy *> QueryVector; DenseSet<RQITy *> QueryCache; }; @@ -3577,7 +3597,10 @@ struct AAIntraFnReachabilityFunction final : public CachedReachabilityAA<AAIntraFnReachability, Instruction> { using Base = CachedReachabilityAA<AAIntraFnReachability, Instruction>; AAIntraFnReachabilityFunction(const IRPosition &IRP, Attributor &A) - : Base(IRP, A) {} + : Base(IRP, A) { + DT = A.getInfoCache().getAnalysisResultForFunction<DominatorTreeAnalysis>( + *IRP.getAssociatedFunction()); + } bool isAssumedReachable( Attributor &A, const Instruction &From, const Instruction &To, @@ -3589,7 +3612,8 @@ struct AAIntraFnReachabilityFunction final RQITy StackRQI(A, From, To, ExclusionSet, false); typename RQITy::Reachable Result; if (!NonConstThis->checkQueryCache(A, StackRQI, Result)) - return NonConstThis->isReachableImpl(A, StackRQI); + return NonConstThis->isReachableImpl(A, StackRQI, + /*IsTemporaryRQI=*/true); return Result == RQITy::Reachable::Yes; } @@ -3598,16 +3622,24 @@ struct AAIntraFnReachabilityFunction final // of them changed. 
auto *LivenessAA = A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); - if (LivenessAA && llvm::all_of(DeadEdges, [&](const auto &DeadEdge) { - return LivenessAA->isEdgeDead(DeadEdge.first, DeadEdge.second); + if (LivenessAA && + llvm::all_of(DeadEdges, + [&](const auto &DeadEdge) { + return LivenessAA->isEdgeDead(DeadEdge.first, + DeadEdge.second); + }) && + llvm::all_of(DeadBlocks, [&](const BasicBlock *BB) { + return LivenessAA->isAssumedDead(BB); })) { return ChangeStatus::UNCHANGED; } DeadEdges.clear(); + DeadBlocks.clear(); return Base::updateImpl(A); } - bool isReachableImpl(Attributor &A, RQITy &RQI) override { + bool isReachableImpl(Attributor &A, RQITy &RQI, + bool IsTemporaryRQI) override { const Instruction *Origin = RQI.From; bool UsedExclusionSet = false; @@ -3633,31 +3665,41 @@ struct AAIntraFnReachabilityFunction final // possible. if (FromBB == ToBB && WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); // Check if reaching the ToBB block is sufficient or if even that would not // ensure reaching the target. In the latter case we are done. if (!WillReachInBlock(ToBB->front(), *RQI.To, RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); + const Function *Fn = FromBB->getParent(); SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks; if (RQI.ExclusionSet) for (auto *I : *RQI.ExclusionSet) - ExclusionBlocks.insert(I->getParent()); + if (I->getFunction() == Fn) + ExclusionBlocks.insert(I->getParent()); // Check if we make it out of the FromBB block at all. 
if (ExclusionBlocks.count(FromBB) && !WillReachInBlock(*RQI.From, *FromBB->getTerminator(), RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, true, IsTemporaryRQI); + + auto *LivenessAA = + A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (LivenessAA && LivenessAA->isAssumedDead(ToBB)) { + DeadBlocks.insert(ToBB); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); + } SmallPtrSet<const BasicBlock *, 16> Visited; SmallVector<const BasicBlock *, 16> Worklist; Worklist.push_back(FromBB); DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> LocalDeadEdges; - auto *LivenessAA = - A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); while (!Worklist.empty()) { const BasicBlock *BB = Worklist.pop_back_val(); if (!Visited.insert(BB).second) @@ -3669,8 +3711,12 @@ struct AAIntraFnReachabilityFunction final } // We checked before if we just need to reach the ToBB block. if (SuccBB == ToBB) - return rememberResult(A, RQITy::Reachable::Yes, RQI, - UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); + if (DT && ExclusionBlocks.empty() && DT->dominates(BB, ToBB)) + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); + if (ExclusionBlocks.count(SuccBB)) { UsedExclusionSet = true; continue; @@ -3680,16 +3726,24 @@ struct AAIntraFnReachabilityFunction final } DeadEdges.insert(LocalDeadEdges.begin(), LocalDeadEdges.end()); - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); } /// See AbstractAttribute::trackStatistics() void trackStatistics() const override {} private: + // Set of assumed dead blocks we used in the last query. If any changes we + // update the state. 
+ DenseSet<const BasicBlock *> DeadBlocks; + // Set of assumed dead edges we used in the last query. If any changes we // update the state. DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> DeadEdges; + + /// The dominator tree of the function to short-circuit reasoning. + const DominatorTree *DT = nullptr; }; } // namespace @@ -3754,12 +3808,8 @@ struct AANoAliasFloating final : AANoAliasImpl { /// NoAlias attribute for an argument. struct AANoAliasArgument final - : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl, - AANoAlias::StateType, false, - Attribute::NoAlias> { - using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl, - AANoAlias::StateType, false, - Attribute::NoAlias>; + : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl> { + using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl>; AANoAliasArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} /// See AbstractAttribute::update(...). @@ -4027,24 +4077,10 @@ struct AANoAliasReturned final : AANoAliasImpl { }; /// NoAlias attribute deduction for a call site return value. -struct AANoAliasCallSiteReturned final : AANoAliasImpl { +struct AANoAliasCallSiteReturned final + : AACalleeToCallSite<AANoAlias, AANoAliasImpl> { AANoAliasCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AANoAliasImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. 
- Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::returned(*F); - bool IsKnownNoAlias; - if (!AA::hasAssumedIRAttr<Attribute::NoAlias>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoAlias)) - return indicatePessimisticFixpoint(); - return ChangeStatus::UNCHANGED; - } + : AACalleeToCallSite<AANoAlias, AANoAliasImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noalias); } @@ -4696,23 +4732,53 @@ identifyAliveSuccessors(Attributor &A, const SwitchInst &SI, AbstractAttribute &AA, SmallVectorImpl<const Instruction *> &AliveSuccessors) { bool UsedAssumedInformation = false; - std::optional<Constant *> C = - A.getAssumedConstant(*SI.getCondition(), AA, UsedAssumedInformation); - if (!C || isa_and_nonnull<UndefValue>(*C)) { - // No value yet, assume all edges are dead. - } else if (isa_and_nonnull<ConstantInt>(*C)) { - for (const auto &CaseIt : SI.cases()) { - if (CaseIt.getCaseValue() == *C) { - AliveSuccessors.push_back(&CaseIt.getCaseSuccessor()->front()); - return UsedAssumedInformation; - } - } - AliveSuccessors.push_back(&SI.getDefaultDest()->front()); + SmallVector<AA::ValueAndContext> Values; + if (!A.getAssumedSimplifiedValues(IRPosition::value(*SI.getCondition()), &AA, + Values, AA::AnyScope, + UsedAssumedInformation)) { + // Something went wrong, assume all successors are live. + for (const BasicBlock *SuccBB : successors(SI.getParent())) + AliveSuccessors.push_back(&SuccBB->front()); + return false; + } + + if (Values.empty() || + (Values.size() == 1 && + isa_and_nonnull<UndefValue>(Values.front().getValue()))) { + // No valid value yet, assume all edges are dead. 
return UsedAssumedInformation; - } else { + } + + Type &Ty = *SI.getCondition()->getType(); + SmallPtrSet<ConstantInt *, 8> Constants; + auto CheckForConstantInt = [&](Value *V) { + if (auto *CI = dyn_cast_if_present<ConstantInt>(AA::getWithType(*V, Ty))) { + Constants.insert(CI); + return true; + } + return false; + }; + + if (!all_of(Values, [&](AA::ValueAndContext &VAC) { + return CheckForConstantInt(VAC.getValue()); + })) { for (const BasicBlock *SuccBB : successors(SI.getParent())) AliveSuccessors.push_back(&SuccBB->front()); + return UsedAssumedInformation; + } + + unsigned MatchedCases = 0; + for (const auto &CaseIt : SI.cases()) { + if (Constants.count(CaseIt.getCaseValue())) { + ++MatchedCases; + AliveSuccessors.push_back(&CaseIt.getCaseSuccessor()->front()); + } } + + // If all potential values have been matched, we will not visit the default + // case. + if (MatchedCases < Constants.size()) + AliveSuccessors.push_back(&SI.getDefaultDest()->front()); return UsedAssumedInformation; } @@ -5103,9 +5169,8 @@ struct AADereferenceableCallSiteArgument final : AADereferenceableFloating { /// Dereferenceable attribute deduction for a call site return value. struct AADereferenceableCallSiteReturned final - : AACallSiteReturnedFromReturned<AADereferenceable, AADereferenceableImpl> { - using Base = - AACallSiteReturnedFromReturned<AADereferenceable, AADereferenceableImpl>; + : AACalleeToCallSite<AADereferenceable, AADereferenceableImpl> { + using Base = AACalleeToCallSite<AADereferenceable, AADereferenceableImpl>; AADereferenceableCallSiteReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -5400,8 +5465,8 @@ struct AAAlignCallSiteArgument final : AAAlignFloating { /// Align attribute deduction for a call site return value. 
struct AAAlignCallSiteReturned final - : AACallSiteReturnedFromReturned<AAAlign, AAAlignImpl> { - using Base = AACallSiteReturnedFromReturned<AAAlign, AAAlignImpl>; + : AACalleeToCallSite<AAAlign, AAAlignImpl> { + using Base = AACalleeToCallSite<AAAlign, AAAlignImpl>; AAAlignCallSiteReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -5449,24 +5514,10 @@ struct AANoReturnFunction final : AANoReturnImpl { }; /// NoReturn attribute deduction for a call sites. -struct AANoReturnCallSite final : AANoReturnImpl { +struct AANoReturnCallSite final + : AACalleeToCallSite<AANoReturn, AANoReturnImpl> { AANoReturnCallSite(const IRPosition &IRP, Attributor &A) - : AANoReturnImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - bool IsKnownNoReturn; - if (!AA::hasAssumedIRAttr<Attribute::NoReturn>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoReturn)) - return indicatePessimisticFixpoint(); - return ChangeStatus::UNCHANGED; - } + : AACalleeToCallSite<AANoReturn, AANoReturnImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(noreturn); } @@ -5805,8 +5856,8 @@ struct AANoCaptureImpl : public AANoCapture { // For stores we already checked if we can follow them, if they make it // here we give up. if (isa<StoreInst>(UInst)) - return isCapturedIn(State, /* Memory */ true, /* Integer */ false, - /* Return */ false); + return isCapturedIn(State, /* Memory */ true, /* Integer */ true, + /* Return */ true); // Explicitly catch return instructions. 
if (isa<ReturnInst>(UInst)) { @@ -6476,7 +6527,7 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - return indicatePessimisticFixpoint(); + return indicatePessimisticFixpoint(); } void trackStatistics() const override { @@ -6937,13 +6988,17 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { << **DI->PotentialAllocationCalls.begin() << "\n"); return false; } - Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode(); - if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) { - LLVM_DEBUG( - dbgs() - << "[H2S] unique free call might not be executed with the allocation " - << *UniqueFree << "\n"); - return false; + + // __kmpc_alloc_shared and __kmpc_alloc_free are by construction matched. + if (AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared) { + Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode(); + if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) { + LLVM_DEBUG( + dbgs() + << "[H2S] unique free call might not be executed with the allocation " + << *UniqueFree << "\n"); + return false; + } } return true; }; @@ -7796,6 +7851,9 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior { // Clear existing attributes. A.removeAttrs(IRP, AttrKinds); + // Clear conflicting writable attribute. + if (isAssumedReadOnly()) + A.removeAttrs(IRP, Attribute::Writable); // Use the generic manifest method. return IRAttribute::manifest(A); @@ -7983,6 +8041,10 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl { ME = MemoryEffects::writeOnly(); A.removeAttrs(getIRPosition(), AttrKinds); + // Clear conflicting writable attribute. 
+ if (ME.onlyReadsMemory()) + for (Argument &Arg : F.args()) + A.removeAttrs(IRPosition::argument(Arg), Attribute::Writable); return A.manifestAttrs(getIRPosition(), Attribute::getWithMemoryEffects(F.getContext(), ME)); } @@ -7999,24 +8061,10 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl { }; /// AAMemoryBehavior attribute for call sites. -struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { +struct AAMemoryBehaviorCallSite final + : AACalleeToCallSite<AAMemoryBehavior, AAMemoryBehaviorImpl> { AAMemoryBehaviorCallSite(const IRPosition &IRP, Attributor &A) - : AAMemoryBehaviorImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto *FnAA = - A.getAAFor<AAMemoryBehavior>(*this, FnPos, DepClassTy::REQUIRED); - if (!FnAA) - return indicatePessimisticFixpoint(); - return clampStateAndIndicateChange(getState(), FnAA->getState()); - } + : AACalleeToCallSite<AAMemoryBehavior, AAMemoryBehaviorImpl>(IRP, A) {} /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { @@ -8031,6 +8079,11 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { ME = MemoryEffects::writeOnly(); A.removeAttrs(getIRPosition(), AttrKinds); + // Clear conflicting writable attribute. 
+ if (ME.onlyReadsMemory()) + for (Use &U : CB.args()) + A.removeAttrs(IRPosition::callsite_argument(CB, U.getOperandNo()), + Attribute::Writable); return A.manifestAttrs( getIRPosition(), Attribute::getWithMemoryEffects(CB.getContext(), ME)); } @@ -8821,6 +8874,108 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl { }; } // namespace +/// ------------------ denormal-fp-math Attribute ------------------------- + +namespace { +struct AADenormalFPMathImpl : public AADenormalFPMath { + AADenormalFPMathImpl(const IRPosition &IRP, Attributor &A) + : AADenormalFPMath(IRP, A) {} + + const std::string getAsStr(Attributor *A) const override { + std::string Str("AADenormalFPMath["); + raw_string_ostream OS(Str); + + DenormalState Known = getKnown(); + if (Known.Mode.isValid()) + OS << "denormal-fp-math=" << Known.Mode; + else + OS << "invalid"; + + if (Known.ModeF32.isValid()) + OS << " denormal-fp-math-f32=" << Known.ModeF32; + OS << ']'; + return OS.str(); + } +}; + +struct AADenormalFPMathFunction final : AADenormalFPMathImpl { + AADenormalFPMathFunction(const IRPosition &IRP, Attributor &A) + : AADenormalFPMathImpl(IRP, A) {} + + void initialize(Attributor &A) override { + const Function *F = getAnchorScope(); + DenormalMode Mode = F->getDenormalModeRaw(); + DenormalMode ModeF32 = F->getDenormalModeF32Raw(); + + // TODO: Handling this here prevents handling the case where a callee has a + // fixed denormal-fp-math with dynamic denormal-fp-math-f32, but called from + // a function with a fully fixed mode. 
+ if (ModeF32 == DenormalMode::getInvalid()) + ModeF32 = Mode; + Known = DenormalState{Mode, ModeF32}; + if (isModeFixed()) + indicateFixpoint(); + } + + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Change = ChangeStatus::UNCHANGED; + + auto CheckCallSite = [=, &Change, &A](AbstractCallSite CS) { + Function *Caller = CS.getInstruction()->getFunction(); + LLVM_DEBUG(dbgs() << "[AADenormalFPMath] Call " << Caller->getName() + << "->" << getAssociatedFunction()->getName() << '\n'); + + const auto *CallerInfo = A.getAAFor<AADenormalFPMath>( + *this, IRPosition::function(*Caller), DepClassTy::REQUIRED); + if (!CallerInfo) + return false; + + Change = Change | clampStateAndIndicateChange(this->getState(), + CallerInfo->getState()); + return true; + }; + + bool AllCallSitesKnown = true; + if (!A.checkForAllCallSites(CheckCallSite, *this, true, AllCallSitesKnown)) + return indicatePessimisticFixpoint(); + + if (Change == ChangeStatus::CHANGED && isModeFixed()) + indicateFixpoint(); + return Change; + } + + ChangeStatus manifest(Attributor &A) override { + LLVMContext &Ctx = getAssociatedFunction()->getContext(); + + SmallVector<Attribute, 2> AttrToAdd; + SmallVector<StringRef, 2> AttrToRemove; + if (Known.Mode == DenormalMode::getDefault()) { + AttrToRemove.push_back("denormal-fp-math"); + } else { + AttrToAdd.push_back( + Attribute::get(Ctx, "denormal-fp-math", Known.Mode.str())); + } + + if (Known.ModeF32 != Known.Mode) { + AttrToAdd.push_back( + Attribute::get(Ctx, "denormal-fp-math-f32", Known.ModeF32.str())); + } else { + AttrToRemove.push_back("denormal-fp-math-f32"); + } + + auto &IRP = getIRPosition(); + + // TODO: There should be a combined add and remove API. 
+ return A.removeAttrs(IRP, AttrToRemove) | + A.manifestAttrs(IRP, AttrToAdd, /*ForceReplace=*/true); + } + + void trackStatistics() const override { + STATS_DECLTRACK_FN_ATTR(denormal_fp_math) + } +}; +} // namespace + /// ------------------ Value Constant Range Attribute ------------------------- namespace { @@ -9427,17 +9582,13 @@ struct AAValueConstantRangeCallSite : AAValueConstantRangeFunction { }; struct AAValueConstantRangeCallSiteReturned - : AACallSiteReturnedFromReturned<AAValueConstantRange, - AAValueConstantRangeImpl, - AAValueConstantRangeImpl::StateType, - /* IntroduceCallBaseContext */ true> { + : AACalleeToCallSite<AAValueConstantRange, AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true> { AAValueConstantRangeCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AAValueConstantRange, - AAValueConstantRangeImpl, - AAValueConstantRangeImpl::StateType, - /* IntroduceCallBaseContext */ true>(IRP, - A) { - } + : AACalleeToCallSite<AAValueConstantRange, AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true>(IRP, A) {} /// See AbstractAttribute::initialize(...). 
void initialize(Attributor &A) override { @@ -9956,12 +10107,12 @@ struct AAPotentialConstantValuesCallSite : AAPotentialConstantValuesFunction { }; struct AAPotentialConstantValuesCallSiteReturned - : AACallSiteReturnedFromReturned<AAPotentialConstantValues, - AAPotentialConstantValuesImpl> { + : AACalleeToCallSite<AAPotentialConstantValues, + AAPotentialConstantValuesImpl> { AAPotentialConstantValuesCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AAPotentialConstantValues, - AAPotentialConstantValuesImpl>(IRP, A) {} + : AACalleeToCallSite<AAPotentialConstantValues, + AAPotentialConstantValuesImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -10101,7 +10252,8 @@ struct AANoUndefFloating : public AANoUndefImpl { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { AANoUndefImpl::initialize(A); - if (!getState().isAtFixpoint()) + if (!getState().isAtFixpoint() && getAnchorScope() && + !getAnchorScope()->isDeclaration()) if (Instruction *CtxI = getCtxI()) followUsesInMBEC(*this, A, getState(), *CtxI); } @@ -10148,26 +10300,18 @@ struct AANoUndefFloating : public AANoUndefImpl { }; struct AANoUndefReturned final - : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef> { + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl> { AANoUndefReturned(const IRPosition &IRP, Attributor &A) - : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef>(IRP, A) {} + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noundef) } }; struct AANoUndefArgument final - : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef> { + : 
AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl> { AANoUndefArgument(const IRPosition &IRP, Attributor &A) - : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef>(IRP, A) {} + : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noundef) } @@ -10182,13 +10326,9 @@ struct AANoUndefCallSiteArgument final : AANoUndefFloating { }; struct AANoUndefCallSiteReturned final - : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef> { + : AACalleeToCallSite<AANoUndef, AANoUndefImpl> { AANoUndefCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef>(IRP, A) {} + : AACalleeToCallSite<AANoUndef, AANoUndefImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noundef) } @@ -10212,7 +10352,6 @@ struct AANoFPClassImpl : AANoFPClass { A.getAttrs(getIRPosition(), {Attribute::NoFPClass}, Attrs, false); for (const auto &Attr : Attrs) { addKnownBits(Attr.getNoFPClass()); - return; } const DataLayout &DL = A.getDataLayout(); @@ -10248,8 +10387,22 @@ struct AANoFPClassImpl : AANoFPClass { /*Depth=*/0, TLI, AC, I, DT); State.addKnownBits(~KnownFPClass.KnownFPClasses); - bool TrackUse = false; - return TrackUse; + if (auto *CI = dyn_cast<CallInst>(UseV)) { + // Special case FP intrinsic with struct return type. 
+ switch (CI->getIntrinsicID()) { + case Intrinsic::frexp: + return true; + case Intrinsic::not_intrinsic: + // TODO: Could recognize math libcalls + return false; + default: + break; + } + } + + if (!UseV->getType()->isFPOrFPVectorTy()) + return false; + return !isa<LoadInst, AtomicRMWInst>(UseV); } const std::string getAsStr(Attributor *A) const override { @@ -10339,9 +10492,9 @@ struct AANoFPClassCallSiteArgument final : AANoFPClassFloating { }; struct AANoFPClassCallSiteReturned final - : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl> { + : AACalleeToCallSite<AANoFPClass, AANoFPClassImpl> { AANoFPClassCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl>(IRP, A) {} + : AACalleeToCallSite<AANoFPClass, AANoFPClassImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -10446,15 +10599,12 @@ struct AACallEdgesCallSite : public AACallEdgesImpl { return Change; } - // Process callee metadata if available. - if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) { - for (const auto &Op : MD->operands()) { - Function *Callee = mdconst::dyn_extract_or_null<Function>(Op); - if (Callee) - addCalledFunction(Callee, Change); - } - return Change; - } + if (CB->isIndirectCall()) + if (auto *IndirectCallAA = A.getAAFor<AAIndirectCallInfo>( + *this, getIRPosition(), DepClassTy::OPTIONAL)) + if (IndirectCallAA->foreachCallee( + [&](Function *Fn) { return VisitValue(*Fn, CB); })) + return Change; // The most simple case. 
ProcessCalledOperand(CB->getCalledOperand(), CB); @@ -10519,28 +10669,26 @@ struct AAInterFnReachabilityFunction bool instructionCanReach( Attributor &A, const Instruction &From, const Function &To, - const AA::InstExclusionSetTy *ExclusionSet, - SmallPtrSet<const Function *, 16> *Visited) const override { + const AA::InstExclusionSetTy *ExclusionSet) const override { assert(From.getFunction() == getAnchorScope() && "Queried the wrong AA!"); auto *NonConstThis = const_cast<AAInterFnReachabilityFunction *>(this); RQITy StackRQI(A, From, To, ExclusionSet, false); typename RQITy::Reachable Result; if (!NonConstThis->checkQueryCache(A, StackRQI, Result)) - return NonConstThis->isReachableImpl(A, StackRQI); + return NonConstThis->isReachableImpl(A, StackRQI, + /*IsTemporaryRQI=*/true); return Result == RQITy::Reachable::Yes; } - bool isReachableImpl(Attributor &A, RQITy &RQI) override { - return isReachableImpl(A, RQI, nullptr); - } - bool isReachableImpl(Attributor &A, RQITy &RQI, - SmallPtrSet<const Function *, 16> *Visited) { - - SmallPtrSet<const Function *, 16> LocalVisited; - if (!Visited) - Visited = &LocalVisited; + bool IsTemporaryRQI) override { + const Instruction *EntryI = + &RQI.From->getFunction()->getEntryBlock().front(); + if (EntryI != RQI.From && + !instructionCanReach(A, *EntryI, *RQI.To, nullptr)) + return rememberResult(A, RQITy::Reachable::No, RQI, false, + IsTemporaryRQI); auto CheckReachableCallBase = [&](CallBase *CB) { auto *CBEdges = A.getAAFor<AACallEdges>( @@ -10554,8 +10702,7 @@ struct AAInterFnReachabilityFunction for (Function *Fn : CBEdges->getOptimisticEdges()) { if (Fn == RQI.To) return false; - if (!Visited->insert(Fn).second) - continue; + if (Fn->isDeclaration()) { if (Fn->hasFnAttribute(Attribute::NoCallback)) continue; @@ -10563,15 +10710,20 @@ struct AAInterFnReachabilityFunction return false; } - const AAInterFnReachability *InterFnReachability = this; - if (Fn != getAnchorScope()) - InterFnReachability = 
A.getAAFor<AAInterFnReachability>( - *this, IRPosition::function(*Fn), DepClassTy::OPTIONAL); + if (Fn == getAnchorScope()) { + if (EntryI == RQI.From) + continue; + return false; + } + + const AAInterFnReachability *InterFnReachability = + A.getAAFor<AAInterFnReachability>(*this, IRPosition::function(*Fn), + DepClassTy::OPTIONAL); const Instruction &FnFirstInst = Fn->getEntryBlock().front(); if (!InterFnReachability || InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To, - RQI.ExclusionSet, Visited)) + RQI.ExclusionSet)) return false; } return true; @@ -10583,10 +10735,12 @@ struct AAInterFnReachabilityFunction // Determine call like instructions that we can reach from the inst. auto CheckCallBase = [&](Instruction &CBInst) { - if (!IntraFnReachability || !IntraFnReachability->isAssumedReachable( - A, *RQI.From, CBInst, RQI.ExclusionSet)) + // There are usually less nodes in the call graph, check inter function + // reachability first. + if (CheckReachableCallBase(cast<CallBase>(&CBInst))) return true; - return CheckReachableCallBase(cast<CallBase>(&CBInst)); + return IntraFnReachability && !IntraFnReachability->isAssumedReachable( + A, *RQI.From, CBInst, RQI.ExclusionSet); }; bool UsedExclusionSet = /* conservative */ true; @@ -10594,16 +10748,14 @@ struct AAInterFnReachabilityFunction if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this, UsedAssumedInformation, /* CheckBBLivenessOnly */ true)) - return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); } void trackStatistics() const override {} - -private: - SmallVector<RQITy *> QueryVector; - DenseSet<RQITy *> QueryCache; }; } // namespace @@ -10880,64 +11032,104 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { // 
Simplify the operands first. bool UsedAssumedInformation = false; - const auto &SimplifiedLHS = A.getAssumedSimplified( - IRPosition::value(*LHS, getCallBaseContext()), *this, - UsedAssumedInformation, AA::Intraprocedural); - if (!SimplifiedLHS.has_value()) + SmallVector<AA::ValueAndContext> LHSValues, RHSValues; + auto GetSimplifiedValues = [&](Value &V, + SmallVector<AA::ValueAndContext> &Values) { + if (!A.getAssumedSimplifiedValues( + IRPosition::value(V, getCallBaseContext()), this, Values, + AA::Intraprocedural, UsedAssumedInformation)) { + Values.clear(); + Values.push_back(AA::ValueAndContext{V, II.I.getCtxI()}); + } + return Values.empty(); + }; + if (GetSimplifiedValues(*LHS, LHSValues)) return true; - if (!*SimplifiedLHS) - return false; - LHS = *SimplifiedLHS; - - const auto &SimplifiedRHS = A.getAssumedSimplified( - IRPosition::value(*RHS, getCallBaseContext()), *this, - UsedAssumedInformation, AA::Intraprocedural); - if (!SimplifiedRHS.has_value()) + if (GetSimplifiedValues(*RHS, RHSValues)) return true; - if (!*SimplifiedRHS) - return false; - RHS = *SimplifiedRHS; LLVMContext &Ctx = LHS->getContext(); - // Handle the trivial case first in which we don't even need to think about - // null or non-null. - if (LHS == RHS && - (CmpInst::isTrueWhenEqual(Pred) || CmpInst::isFalseWhenEqual(Pred))) { - Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx), - CmpInst::isTrueWhenEqual(Pred)); - addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, - getAnchorScope()); - return true; - } - // From now on we only handle equalities (==, !=). - if (!CmpInst::isEquality(Pred)) - return false; + InformationCache &InfoCache = A.getInfoCache(); + Instruction *CmpI = dyn_cast<Instruction>(&Cmp); + Function *F = CmpI ? CmpI->getFunction() : nullptr; + const auto *DT = + F ? InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F) + : nullptr; + const auto *TLI = + F ? A.getInfoCache().getTargetLibraryInfoForFunction(*F) : nullptr; + auto *AC = + F ? 
InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F) + : nullptr; - bool LHSIsNull = isa<ConstantPointerNull>(LHS); - bool RHSIsNull = isa<ConstantPointerNull>(RHS); - if (!LHSIsNull && !RHSIsNull) - return false; + const DataLayout &DL = A.getDataLayout(); + SimplifyQuery Q(DL, TLI, DT, AC, CmpI); - // Left is the nullptr ==/!= non-nullptr case. We'll use AANonNull on the - // non-nullptr operand and if we assume it's non-null we can conclude the - // result of the comparison. - assert((LHSIsNull || RHSIsNull) && - "Expected nullptr versus non-nullptr comparison at this point"); + auto CheckPair = [&](Value &LHSV, Value &RHSV) { + if (isa<UndefValue>(LHSV) || isa<UndefValue>(RHSV)) { + addValue(A, getState(), *UndefValue::get(Cmp.getType()), + /* CtxI */ nullptr, II.S, getAnchorScope()); + return true; + } - // The index is the operand that we assume is not null. - unsigned PtrIdx = LHSIsNull; - bool IsKnownNonNull; - bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( - A, this, IRPosition::value(*(PtrIdx ? RHS : LHS)), DepClassTy::REQUIRED, - IsKnownNonNull); - if (!IsAssumedNonNull) - return false; + // Handle the trivial case first in which we don't even need to think + // about null or non-null. + if (&LHSV == &RHSV && + (CmpInst::isTrueWhenEqual(Pred) || CmpInst::isFalseWhenEqual(Pred))) { + Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx), + CmpInst::isTrueWhenEqual(Pred)); + addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, + getAnchorScope()); + return true; + } + + auto *TypedLHS = AA::getWithType(LHSV, *LHS->getType()); + auto *TypedRHS = AA::getWithType(RHSV, *RHS->getType()); + if (TypedLHS && TypedRHS) { + Value *NewV = simplifyCmpInst(Pred, TypedLHS, TypedRHS, Q); + if (NewV && NewV != &Cmp) { + addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, + getAnchorScope()); + return true; + } + } + + // From now on we only handle equalities (==, !=). 
+ if (!CmpInst::isEquality(Pred)) + return false; + + bool LHSIsNull = isa<ConstantPointerNull>(LHSV); + bool RHSIsNull = isa<ConstantPointerNull>(RHSV); + if (!LHSIsNull && !RHSIsNull) + return false; + + // Left is the nullptr ==/!= non-nullptr case. We'll use AANonNull on the + // non-nullptr operand and if we assume it's non-null we can conclude the + // result of the comparison. + assert((LHSIsNull || RHSIsNull) && + "Expected nullptr versus non-nullptr comparison at this point"); + + // The index is the operand that we assume is not null. + unsigned PtrIdx = LHSIsNull; + bool IsKnownNonNull; + bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, IRPosition::value(*(PtrIdx ? &RHSV : &LHSV)), + DepClassTy::REQUIRED, IsKnownNonNull); + if (!IsAssumedNonNull) + return false; + + // The new value depends on the predicate, true for != and false for ==. + Constant *NewV = + ConstantInt::get(Type::getInt1Ty(Ctx), Pred == CmpInst::ICMP_NE); + addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, + getAnchorScope()); + return true; + }; - // The new value depends on the predicate, true for != and false for ==. 
- Constant *NewV = - ConstantInt::get(Type::getInt1Ty(Ctx), Pred == CmpInst::ICMP_NE); - addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, getAnchorScope()); + for (auto &LHSValue : LHSValues) + for (auto &RHSValue : RHSValues) + if (!CheckPair(*LHSValue.getValue(), *RHSValue.getValue())) + return false; return true; } @@ -11152,9 +11344,8 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { SmallVectorImpl<ItemInfo> &Worklist, SmallMapVector<const Function *, LivenessInfo, 4> &LivenessAAs) { if (auto *CI = dyn_cast<CmpInst>(&I)) - if (handleCmp(A, *CI, CI->getOperand(0), CI->getOperand(1), - CI->getPredicate(), II, Worklist)) - return true; + return handleCmp(A, *CI, CI->getOperand(0), CI->getOperand(1), + CI->getPredicate(), II, Worklist); switch (I.getOpcode()) { case Instruction::Select: @@ -11272,12 +11463,12 @@ struct AAPotentialValuesArgument final : AAPotentialValuesImpl { ChangeStatus updateImpl(Attributor &A) override { auto AssumedBefore = getAssumed(); - unsigned CSArgNo = getCallSiteArgNo(); + unsigned ArgNo = getCalleeArgNo(); bool UsedAssumedInformation = false; SmallVector<AA::ValueAndContext> Values; auto CallSitePred = [&](AbstractCallSite ACS) { - const auto CSArgIRP = IRPosition::callsite_argument(ACS, CSArgNo); + const auto CSArgIRP = IRPosition::callsite_argument(ACS, ArgNo); if (CSArgIRP.getPositionKind() == IRP_INVALID) return false; @@ -11889,6 +12080,455 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl { }; } // namespace +/// ------------------------ Global Value Info ------------------------------- +namespace { +struct AAGlobalValueInfoFloating : public AAGlobalValueInfo { + AAGlobalValueInfoFloating(const IRPosition &IRP, Attributor &A) + : AAGlobalValueInfo(IRP, A) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override {} + + bool checkUse(Attributor &A, const Use &U, bool &Follow, + SmallVectorImpl<const Value *> &Worklist) { + Instruction *UInst = dyn_cast<Instruction>(U.getUser()); + if (!UInst) { + Follow = true; + return true; + } + + LLVM_DEBUG(dbgs() << "[AAGlobalValueInfo] Check use: " << *U.get() << " in " + << *UInst << "\n"); + + if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) { + int Idx = &Cmp->getOperandUse(0) == &U; + if (isa<Constant>(Cmp->getOperand(Idx))) + return true; + return U == &getAnchorValue(); + } + + // Explicitly catch return instructions. + if (isa<ReturnInst>(UInst)) { + auto CallSitePred = [&](AbstractCallSite ACS) { + Worklist.push_back(ACS.getInstruction()); + return true; + }; + bool UsedAssumedInformation = false; + // TODO: We should traverse the uses or add a "non-call-site" CB. + if (!A.checkForAllCallSites(CallSitePred, *UInst->getFunction(), + /*RequireAllCallSites=*/true, this, + UsedAssumedInformation)) + return false; + return true; + } + + // For now we only use special logic for call sites. However, the tracker + // itself knows about a lot of other non-capturing cases already. + auto *CB = dyn_cast<CallBase>(UInst); + if (!CB) + return false; + // Direct calls are OK uses. + if (CB->isCallee(&U)) + return true; + // Non-argument uses are scary. + if (!CB->isArgOperand(&U)) + return false; + // TODO: Iterate callees. 
+ auto *Fn = dyn_cast<Function>(CB->getCalledOperand()); + if (!Fn || !A.isFunctionIPOAmendable(*Fn)) + return false; + + unsigned ArgNo = CB->getArgOperandNo(&U); + Worklist.push_back(Fn->getArg(ArgNo)); + return true; + } + + ChangeStatus updateImpl(Attributor &A) override { + unsigned NumUsesBefore = Uses.size(); + + SmallPtrSet<const Value *, 8> Visited; + SmallVector<const Value *> Worklist; + Worklist.push_back(&getAnchorValue()); + + auto UsePred = [&](const Use &U, bool &Follow) -> bool { + Uses.insert(&U); + switch (DetermineUseCaptureKind(U, nullptr)) { + case UseCaptureKind::NO_CAPTURE: + return checkUse(A, U, Follow, Worklist); + case UseCaptureKind::MAY_CAPTURE: + return checkUse(A, U, Follow, Worklist); + case UseCaptureKind::PASSTHROUGH: + Follow = true; + return true; + } + return true; + }; + auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) { + Uses.insert(&OldU); + return true; + }; + + while (!Worklist.empty()) { + const Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + if (!A.checkForAllUses(UsePred, *this, *V, + /* CheckBBLivenessOnly */ true, + DepClassTy::OPTIONAL, + /* IgnoreDroppableUses */ true, EquivalentUseCB)) { + return indicatePessimisticFixpoint(); + } + } + + return Uses.size() == NumUsesBefore ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + bool isPotentialUse(const Use &U) const override { + return !isValidState() || Uses.contains(&U); + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr(Attributor *A) const override { + return "[" + std::to_string(Uses.size()) + " uses]"; + } + + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(GlobalValuesTracked); + } + +private: + /// Set of (transitive) uses of this GlobalValue. 
+ SmallPtrSet<const Use *, 8> Uses; +}; +} // namespace + +/// ------------------------ Indirect Call Info ------------------------------- +namespace { +struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { + AAIndirectCallInfoCallSite(const IRPosition &IRP, Attributor &A) + : AAIndirectCallInfo(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees); + if (!MD && !A.isClosedWorldModule()) + return; + + if (MD) { + for (const auto &Op : MD->operands()) + if (Function *Callee = mdconst::dyn_extract_or_null<Function>(Op)) + PotentialCallees.insert(Callee); + } else if (A.isClosedWorldModule()) { + ArrayRef<Function *> IndirectlyCallableFunctions = + A.getInfoCache().getIndirectlyCallableFunctions(A); + PotentialCallees.insert(IndirectlyCallableFunctions.begin(), + IndirectlyCallableFunctions.end()); + } + + if (PotentialCallees.empty()) + indicateOptimisticFixpoint(); + } + + ChangeStatus updateImpl(Attributor &A) override { + CallBase *CB = cast<CallBase>(getCtxI()); + const Use &CalleeUse = CB->getCalledOperandUse(); + Value *FP = CB->getCalledOperand(); + + SmallSetVector<Function *, 4> AssumedCalleesNow; + bool AllCalleesKnownNow = AllCalleesKnown; + + auto CheckPotentialCalleeUse = [&](Function &PotentialCallee, + bool &UsedAssumedInformation) { + const auto *GIAA = A.getAAFor<AAGlobalValueInfo>( + *this, IRPosition::value(PotentialCallee), DepClassTy::OPTIONAL); + if (!GIAA || GIAA->isPotentialUse(CalleeUse)) + return true; + UsedAssumedInformation = !GIAA->isAtFixpoint(); + return false; + }; + + auto AddPotentialCallees = [&]() { + for (auto *PotentialCallee : PotentialCallees) { + bool UsedAssumedInformation = false; + if (CheckPotentialCalleeUse(*PotentialCallee, UsedAssumedInformation)) + AssumedCalleesNow.insert(PotentialCallee); + } + }; + + // Use simplification to find potential callees, if !callees was present, + // fallback to that 
set if necessary. + bool UsedAssumedInformation = false; + SmallVector<AA::ValueAndContext> Values; + if (!A.getAssumedSimplifiedValues(IRPosition::value(*FP), this, Values, + AA::ValueScope::AnyScope, + UsedAssumedInformation)) { + if (PotentialCallees.empty()) + return indicatePessimisticFixpoint(); + AddPotentialCallees(); + } + + // Try to find a reason for \p Fn not to be a potential callee. If none was + // found, add it to the assumed callees set. + auto CheckPotentialCallee = [&](Function &Fn) { + if (!PotentialCallees.empty() && !PotentialCallees.count(&Fn)) + return false; + + auto &CachedResult = FilterResults[&Fn]; + if (CachedResult.has_value()) + return CachedResult.value(); + + bool UsedAssumedInformation = false; + if (!CheckPotentialCalleeUse(Fn, UsedAssumedInformation)) { + if (!UsedAssumedInformation) + CachedResult = false; + return false; + } + + int NumFnArgs = Fn.arg_size(); + int NumCBArgs = CB->arg_size(); + + // Check if any excess argument (which we fill up with poison) is known to + // be UB on undef. + for (int I = NumCBArgs; I < NumFnArgs; ++I) { + bool IsKnown = false; + if (AA::hasAssumedIRAttr<Attribute::NoUndef>( + A, this, IRPosition::argument(*Fn.getArg(I)), + DepClassTy::OPTIONAL, IsKnown)) { + if (IsKnown) + CachedResult = false; + return false; + } + } + + CachedResult = true; + return true; + }; + + // Check simplification result, prune known UB callees, also restrict it to + // the !callees set, if present. + for (auto &VAC : Values) { + if (isa<UndefValue>(VAC.getValue())) + continue; + if (isa<ConstantPointerNull>(VAC.getValue()) && + VAC.getValue()->getType()->getPointerAddressSpace() == 0) + continue; + // TODO: Check for known UB, e.g., poison + noundef. 
+ if (auto *VACFn = dyn_cast<Function>(VAC.getValue())) { + if (CheckPotentialCallee(*VACFn)) + AssumedCalleesNow.insert(VACFn); + continue; + } + if (!PotentialCallees.empty()) { + AddPotentialCallees(); + break; + } + AllCalleesKnownNow = false; + } + + if (AssumedCalleesNow == AssumedCallees && + AllCalleesKnown == AllCalleesKnownNow) + return ChangeStatus::UNCHANGED; + + std::swap(AssumedCallees, AssumedCalleesNow); + AllCalleesKnown = AllCalleesKnownNow; + return ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + // If we can't specialize at all, give up now. + if (!AllCalleesKnown && AssumedCallees.empty()) + return ChangeStatus::UNCHANGED; + + CallBase *CB = cast<CallBase>(getCtxI()); + bool UsedAssumedInformation = false; + if (A.isAssumedDead(*CB, this, /*LivenessAA=*/nullptr, + UsedAssumedInformation)) + return ChangeStatus::UNCHANGED; + + ChangeStatus Changed = ChangeStatus::UNCHANGED; + Value *FP = CB->getCalledOperand(); + if (FP->getType()->getPointerAddressSpace()) + FP = new AddrSpaceCastInst(FP, PointerType::get(FP->getType(), 0), + FP->getName() + ".as0", CB); + + bool CBIsVoid = CB->getType()->isVoidTy(); + Instruction *IP = CB; + FunctionType *CSFT = CB->getFunctionType(); + SmallVector<Value *> CSArgs(CB->arg_begin(), CB->arg_end()); + + // If we know all callees and there are none, the call site is (effectively) + // dead (or UB). + if (AssumedCallees.empty()) { + assert(AllCalleesKnown && + "Expected all callees to be known if there are none."); + A.changeToUnreachableAfterManifest(CB); + return ChangeStatus::CHANGED; + } + + // Special handling for the single callee case. 
+ if (AllCalleesKnown && AssumedCallees.size() == 1) { + auto *NewCallee = AssumedCallees.front(); + if (isLegalToPromote(*CB, NewCallee)) { + promoteCall(*CB, NewCallee, nullptr); + return ChangeStatus::CHANGED; + } + Instruction *NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), + CSArgs, CB->getName(), CB); + if (!CBIsVoid) + A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewCall); + A.deleteAfterManifest(*CB); + return ChangeStatus::CHANGED; + } + + // For each potential value we create a conditional + // + // ``` + // if (ptr == value) value(args); + // else ... + // ``` + // + bool SpecializedForAnyCallees = false; + bool SpecializedForAllCallees = AllCalleesKnown; + ICmpInst *LastCmp = nullptr; + SmallVector<Function *, 8> SkippedAssumedCallees; + SmallVector<std::pair<CallInst *, Instruction *>> NewCalls; + for (Function *NewCallee : AssumedCallees) { + if (!A.shouldSpecializeCallSiteForCallee(*this, *CB, *NewCallee)) { + SkippedAssumedCallees.push_back(NewCallee); + SpecializedForAllCallees = false; + continue; + } + SpecializedForAnyCallees = true; + + LastCmp = new ICmpInst(IP, llvm::CmpInst::ICMP_EQ, FP, NewCallee); + Instruction *ThenTI = + SplitBlockAndInsertIfThen(LastCmp, IP, /* Unreachable */ false); + BasicBlock *CBBB = CB->getParent(); + A.registerManifestAddedBasicBlock(*ThenTI->getParent()); + A.registerManifestAddedBasicBlock(*CBBB); + auto *SplitTI = cast<BranchInst>(LastCmp->getNextNode()); + BasicBlock *ElseBB; + if (IP == CB) { + ElseBB = BasicBlock::Create(ThenTI->getContext(), "", + ThenTI->getFunction(), CBBB); + A.registerManifestAddedBasicBlock(*ElseBB); + IP = BranchInst::Create(CBBB, ElseBB); + SplitTI->replaceUsesOfWith(CBBB, ElseBB); + } else { + ElseBB = IP->getParent(); + ThenTI->replaceUsesOfWith(ElseBB, CBBB); + } + CastInst *RetBC = nullptr; + CallInst *NewCall = nullptr; + if (isLegalToPromote(*CB, NewCallee)) { + auto *CBClone = cast<CallBase>(CB->clone()); + CBClone->insertBefore(ThenTI); + NewCall = 
&cast<CallInst>(promoteCall(*CBClone, NewCallee, &RetBC)); + } else { + NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs, + CB->getName(), ThenTI); + } + NewCalls.push_back({NewCall, RetBC}); + } + + auto AttachCalleeMetadata = [&](CallBase &IndirectCB) { + if (!AllCalleesKnown) + return ChangeStatus::UNCHANGED; + MDBuilder MDB(IndirectCB.getContext()); + MDNode *Callees = MDB.createCallees(SkippedAssumedCallees); + IndirectCB.setMetadata(LLVMContext::MD_callees, Callees); + return ChangeStatus::CHANGED; + }; + + if (!SpecializedForAnyCallees) + return AttachCalleeMetadata(*CB); + + // Check if we need the fallback indirect call still. + if (SpecializedForAllCallees) { + LastCmp->replaceAllUsesWith(ConstantInt::getTrue(LastCmp->getContext())); + LastCmp->eraseFromParent(); + new UnreachableInst(IP->getContext(), IP); + IP->eraseFromParent(); + } else { + auto *CBClone = cast<CallInst>(CB->clone()); + CBClone->setName(CB->getName()); + CBClone->insertBefore(IP); + NewCalls.push_back({CBClone, nullptr}); + AttachCalleeMetadata(*CBClone); + } + + // Check if we need a PHI to merge the results. + if (!CBIsVoid) { + auto *PHI = PHINode::Create(CB->getType(), NewCalls.size(), + CB->getName() + ".phi", + &*CB->getParent()->getFirstInsertionPt()); + for (auto &It : NewCalls) { + CallBase *NewCall = It.first; + Instruction *CallRet = It.second ? It.second : It.first; + if (CallRet->getType() == CB->getType()) + PHI->addIncoming(CallRet, CallRet->getParent()); + else if (NewCall->getType()->isVoidTy()) + PHI->addIncoming(PoisonValue::get(CB->getType()), + NewCall->getParent()); + else + llvm_unreachable("Call return should match or be void!"); + } + A.changeAfterManifest(IRPosition::callsite_returned(*CB), *PHI); + } + + A.deleteAfterManifest(*CB); + Changed = ChangeStatus::CHANGED; + + return Changed; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr(Attributor *A) const override { + return std::string(AllCalleesKnown ? 
"eliminate" : "specialize") + + " indirect call site with " + std::to_string(AssumedCallees.size()) + + " functions"; + } + + void trackStatistics() const override { + if (AllCalleesKnown) { + STATS_DECLTRACK( + Eliminated, CallSites, + "Number of indirect call sites eliminated via specialization") + } else { + STATS_DECLTRACK(Specialized, CallSites, + "Number of indirect call sites specialized") + } + } + + bool foreachCallee(function_ref<bool(Function *)> CB) const override { + return isValidState() && AllCalleesKnown && all_of(AssumedCallees, CB); + } + +private: + /// Map to remember filter results. + DenseMap<Function *, std::optional<bool>> FilterResults; + + /// If the !callee metadata was present, this set will contain all potential + /// callees (superset). + SmallSetVector<Function *, 4> PotentialCallees; + + /// This set contains all currently assumed calllees, which might grow over + /// time. + SmallSetVector<Function *, 4> AssumedCallees; + + /// Flag to indicate if all possible callees are in the AssumedCallees set or + /// if there could be others. + bool AllCalleesKnown = true; +}; +} // namespace + /// ------------------------ Address Space ------------------------------------ namespace { struct AAAddressSpaceImpl : public AAAddressSpace { @@ -11961,8 +12601,13 @@ struct AAAddressSpaceImpl : public AAAddressSpace { // CGSCC if the AA is run on CGSCC instead of the entire module. if (!A.isRunOn(Inst->getFunction())) return true; - if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) + if (isa<LoadInst>(Inst)) MakeChange(Inst, const_cast<Use &>(U)); + if (isa<StoreInst>(Inst)) { + // We only make changes if the use is the pointer operand. 
+ if (U.getOperandNo() == 1) + MakeChange(Inst, const_cast<Use &>(U)); + } return true; }; @@ -12064,6 +12709,224 @@ struct AAAddressSpaceCallSiteArgument final : AAAddressSpaceImpl { }; } // namespace +/// ----------- Allocation Info ---------- +namespace { +struct AAAllocationInfoImpl : public AAAllocationInfo { + AAAllocationInfoImpl(const IRPosition &IRP, Attributor &A) + : AAAllocationInfo(IRP, A) {} + + std::optional<TypeSize> getAllocatedSize() const override { + assert(isValidState() && "the AA is invalid"); + return AssumedAllocatedSize; + } + + std::optional<TypeSize> findInitialAllocationSize(Instruction *I, + const DataLayout &DL) { + + // TODO: implement case for malloc like instructions + switch (I->getOpcode()) { + case Instruction::Alloca: { + AllocaInst *AI = cast<AllocaInst>(I); + return AI->getAllocationSize(DL); + } + default: + return std::nullopt; + } + } + + ChangeStatus updateImpl(Attributor &A) override { + + const IRPosition &IRP = getIRPosition(); + Instruction *I = IRP.getCtxI(); + + // TODO: update check for malloc like calls + if (!isa<AllocaInst>(I)) + return indicatePessimisticFixpoint(); + + bool IsKnownNoCapture; + if (!AA::hasAssumedIRAttr<Attribute::NoCapture>( + A, this, IRP, DepClassTy::OPTIONAL, IsKnownNoCapture)) + return indicatePessimisticFixpoint(); + + const AAPointerInfo *PI = + A.getOrCreateAAFor<AAPointerInfo>(IRP, *this, DepClassTy::REQUIRED); + + if (!PI) + return indicatePessimisticFixpoint(); + + if (!PI->getState().isValidState()) + return indicatePessimisticFixpoint(); + + const DataLayout &DL = A.getDataLayout(); + const auto AllocationSize = findInitialAllocationSize(I, DL); + + // If allocation size is nullopt, we give up. + if (!AllocationSize) + return indicatePessimisticFixpoint(); + + // For zero sized allocations, we give up. 
+ // Since we can't reduce further + if (*AllocationSize == 0) + return indicatePessimisticFixpoint(); + + int64_t BinSize = PI->numOffsetBins(); + + // TODO: implement for multiple bins + if (BinSize > 1) + return indicatePessimisticFixpoint(); + + if (BinSize == 0) { + auto NewAllocationSize = std::optional<TypeSize>(TypeSize(0, false)); + if (!changeAllocationSize(NewAllocationSize)) + return ChangeStatus::UNCHANGED; + return ChangeStatus::CHANGED; + } + + // TODO: refactor this to be part of multiple bin case + const auto &It = PI->begin(); + + // TODO: handle if Offset is not zero + if (It->first.Offset != 0) + return indicatePessimisticFixpoint(); + + uint64_t SizeOfBin = It->first.Offset + It->first.Size; + + if (SizeOfBin >= *AllocationSize) + return indicatePessimisticFixpoint(); + + auto NewAllocationSize = + std::optional<TypeSize>(TypeSize(SizeOfBin * 8, false)); + + if (!changeAllocationSize(NewAllocationSize)) + return ChangeStatus::UNCHANGED; + + return ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::manifest(...). 
+ ChangeStatus manifest(Attributor &A) override { + + assert(isValidState() && + "Manifest should only be called if the state is valid."); + + Instruction *I = getIRPosition().getCtxI(); + + auto FixedAllocatedSizeInBits = getAllocatedSize()->getFixedValue(); + + unsigned long NumBytesToAllocate = (FixedAllocatedSizeInBits + 7) / 8; + + switch (I->getOpcode()) { + // TODO: add case for malloc like calls + case Instruction::Alloca: { + + AllocaInst *AI = cast<AllocaInst>(I); + + Type *CharType = Type::getInt8Ty(I->getContext()); + + auto *NumBytesToValue = + ConstantInt::get(I->getContext(), APInt(32, NumBytesToAllocate)); + + AllocaInst *NewAllocaInst = + new AllocaInst(CharType, AI->getAddressSpace(), NumBytesToValue, + AI->getAlign(), AI->getName(), AI->getNextNode()); + + if (A.changeAfterManifest(IRPosition::inst(*AI), *NewAllocaInst)) + return ChangeStatus::CHANGED; + + break; + } + default: + break; + } + + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr(Attributor *A) const override { + if (!isValidState()) + return "allocationinfo(<invalid>)"; + return "allocationinfo(" + + (AssumedAllocatedSize == HasNoAllocationSize + ? "none" + : std::to_string(AssumedAllocatedSize->getFixedValue())) + + ")"; + } + +private: + std::optional<TypeSize> AssumedAllocatedSize = HasNoAllocationSize; + + // Maintain the computed allocation size of the object. + // Returns (bool) weather the size of the allocation was modified or not. 
+ bool changeAllocationSize(std::optional<TypeSize> Size) { + if (AssumedAllocatedSize == HasNoAllocationSize || + AssumedAllocatedSize != Size) { + AssumedAllocatedSize = Size; + return true; + } + return false; + } +}; + +struct AAAllocationInfoFloating : AAAllocationInfoImpl { + AAAllocationInfoFloating(const IRPosition &IRP, Attributor &A) + : AAAllocationInfoImpl(IRP, A) {} + + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(allocationinfo); + } +}; + +struct AAAllocationInfoReturned : AAAllocationInfoImpl { + AAAllocationInfoReturned(const IRPosition &IRP, Attributor &A) + : AAAllocationInfoImpl(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + // TODO: we don't rewrite function argument for now because it will need to + // rewrite the function signature and all call sites + (void)indicatePessimisticFixpoint(); + } + + void trackStatistics() const override { + STATS_DECLTRACK_FNRET_ATTR(allocationinfo); + } +}; + +struct AAAllocationInfoCallSiteReturned : AAAllocationInfoImpl { + AAAllocationInfoCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AAAllocationInfoImpl(IRP, A) {} + + void trackStatistics() const override { + STATS_DECLTRACK_CSRET_ATTR(allocationinfo); + } +}; + +struct AAAllocationInfoArgument : AAAllocationInfoImpl { + AAAllocationInfoArgument(const IRPosition &IRP, Attributor &A) + : AAAllocationInfoImpl(IRP, A) {} + + void trackStatistics() const override { + STATS_DECLTRACK_ARG_ATTR(allocationinfo); + } +}; + +struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl { + AAAllocationInfoCallSiteArgument(const IRPosition &IRP, Attributor &A) + : AAAllocationInfoImpl(IRP, A) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + + (void)indicatePessimisticFixpoint(); + } + + void trackStatistics() const override { + STATS_DECLTRACK_CSARG_ATTR(allocationinfo); + } +}; +} // namespace + const char AANoUnwind::ID = 0; const char AANoSync::ID = 0; const char AANoFree::ID = 0; @@ -12097,6 +12960,10 @@ const char AAPointerInfo::ID = 0; const char AAAssumptionInfo::ID = 0; const char AAUnderlyingObjects::ID = 0; const char AAAddressSpace::ID = 0; +const char AAAllocationInfo::ID = 0; +const char AAIndirectCallInfo::ID = 0; +const char AAGlobalValueInfo::ID = 0; +const char AADenormalFPMath::ID = 0; // Macro magic to create the static generator function for attributes that // follow the naming scheme. @@ -12143,6 +13010,18 @@ const char AAAddressSpace::ID = 0; return *AA; \ } +#define CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(POS, SUFFIX, CLASS) \ + CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \ + CLASS *AA = nullptr; \ + switch (IRP.getPositionKind()) { \ + SWITCH_PK_CREATE(CLASS, IRP, POS, SUFFIX) \ + default: \ + llvm_unreachable("Cannot create " #CLASS " for position otherthan " #POS \ + " position!"); \ + } \ + return *AA; \ + } + #define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \ CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \ CLASS *AA = nullptr; \ @@ -12215,17 +13094,24 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUndef) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFPClass) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPointerInfo) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAddressSpace) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAllocationInfo) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects) +CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_CALL_SITE, CallSite, + AAIndirectCallInfo) 
+CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_FLOAT, Floating, + AAGlobalValueInfo) + CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIntraFnReachability) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAInterFnReachability) +CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADenormalFPMath) CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior) @@ -12234,5 +13120,6 @@ CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior) #undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION #undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION #undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION +#undef CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION #undef SWITCH_PK_CREATE #undef SWITCH_PK_INV diff --git a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp index 93d15f59a036..5cc8258a495a 100644 --- a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -85,7 +85,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) { LLVMContext &Ctx = M.getContext(); FunctionCallee C = M.getOrInsertFunction( "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx), - Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx)); + PointerType::getUnqual(Ctx), PointerType::getUnqual(Ctx)); Function *F = cast<Function>(C.getCallee()); // Take over the existing function. The frontend emits a weak stub so that the // linker knows about the symbol; this pass replaces the function body. 
@@ -110,9 +110,9 @@ void CrossDSOCFI::buildCFICheck(Module &M) { BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F); IRBuilder<> IRBFail(TrapBB); - FunctionCallee CFICheckFailFn = - M.getOrInsertFunction("__cfi_check_fail", Type::getVoidTy(Ctx), - Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx)); + FunctionCallee CFICheckFailFn = M.getOrInsertFunction( + "__cfi_check_fail", Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx), + PointerType::getUnqual(Ctx)); IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr}); IRBFail.CreateBr(ExitBB); diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index 01834015f3fd..4f65748c19e6 100644 --- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -174,6 +174,7 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) { NF->setComdat(F.getComdat()); F.getParent()->getFunctionList().insert(F.getIterator(), NF); NF->takeName(&F); + NF->IsNewDbgInfoFormat = F.IsNewDbgInfoFormat; // Loop over all the callers of the function, transforming the call sites // to pass in a smaller number of arguments into the new function. @@ -248,7 +249,7 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) { NF->addMetadata(KindID, *Node); // Fix up any BlockAddresses that refer to the function. - F.replaceAllUsesWith(ConstantExpr::getBitCast(NF, F.getType())); + F.replaceAllUsesWith(NF); // Delete the bitcast that we just created, so that NF does not // appear to be address-taken. NF->removeDeadConstantUsers(); @@ -877,6 +878,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { // it again. F->getParent()->getFunctionList().insert(F->getIterator(), NF); NF->takeName(F); + NF->IsNewDbgInfoFormat = F->IsNewDbgInfoFormat; // Loop over all the callers of the function, transforming the call sites to // pass in a smaller number of arguments into the new function. 
diff --git a/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp index fa56a5b564ae..48ef0772e800 100644 --- a/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp +++ b/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp @@ -7,8 +7,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/EmbedBitcodePass.h" -#include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/Bitcode/BitcodeWriterPass.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" @@ -16,10 +14,8 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h" -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" -#include <memory> #include <string> using namespace llvm; @@ -34,19 +30,9 @@ PreservedAnalyses EmbedBitcodePass::run(Module &M, ModuleAnalysisManager &AM) { report_fatal_error( "EmbedBitcode pass currently only supports ELF object format", /*gen_crash_diag=*/false); - - std::unique_ptr<Module> NewModule = CloneModule(M); - MPM.run(*NewModule, AM); - std::string Data; raw_string_ostream OS(Data); - if (IsThinLTO) - ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(*NewModule, AM); - else - BitcodeWriterPass(OS, /*ShouldPreserveUseListOrder=*/false, EmitLTOSummary) - .run(*NewModule, AM); - + ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(M, AM); embedBufferInModule(M, MemoryBufferRef(Data, "ModuleData"), ".llvm.lto"); - return PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 74931e1032d1..9cf4e448c9b6 100644 --- a/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -11,38 +11,57 @@ #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include 
"llvm/Support/LineIterator.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; #define DEBUG_TYPE "forceattrs" -static cl::list<std::string> - ForceAttributes("force-attribute", cl::Hidden, - cl::desc("Add an attribute to a function. This should be a " - "pair of 'function-name:attribute-name', for " - "example -force-attribute=foo:noinline. This " - "option can be specified multiple times.")); +static cl::list<std::string> ForceAttributes( + "force-attribute", cl::Hidden, + cl::desc( + "Add an attribute to a function. This can be a " + "pair of 'function-name:attribute-name', to apply an attribute to a " + "specific function. For " + "example -force-attribute=foo:noinline. Specifying only an attribute " + "will apply the attribute to every function in the module. This " + "option can be specified multiple times.")); static cl::list<std::string> ForceRemoveAttributes( "force-remove-attribute", cl::Hidden, - cl::desc("Remove an attribute from a function. This should be a " - "pair of 'function-name:attribute-name', for " - "example -force-remove-attribute=foo:noinline. This " + cl::desc("Remove an attribute from a function. This can be a " + "pair of 'function-name:attribute-name' to remove an attribute " + "from a specific function. For " + "example -force-remove-attribute=foo:noinline. Specifying only an " + "attribute will remove the attribute from all functions in the " + "module. This " "option can be specified multiple times.")); +static cl::opt<std::string> CSVFilePath( + "forceattrs-csv-path", cl::Hidden, + cl::desc( + "Path to CSV file containing lines of function names and attributes to " + "add to them in the form of `f1,attr1` or `f2,attr2=str`.")); + /// If F has any forced attributes given on the command line, add them. /// If F has any forced remove attributes given on the command line, remove /// them. When both force and force-remove are given to a function, the latter /// takes precedence. 
static void forceAttributes(Function &F) { auto ParseFunctionAndAttr = [&](StringRef S) { - auto Kind = Attribute::None; - auto KV = StringRef(S).split(':'); - if (KV.first != F.getName()) - return Kind; - Kind = Attribute::getAttrKindFromName(KV.second); + StringRef AttributeText; + if (S.contains(':')) { + auto KV = StringRef(S).split(':'); + if (KV.first != F.getName()) + return Attribute::None; + AttributeText = KV.second; + } else { + AttributeText = S; + } + auto Kind = Attribute::getAttrKindFromName(AttributeText); if (Kind == Attribute::None || !Attribute::canUseAsFnAttr(Kind)) { - LLVM_DEBUG(dbgs() << "ForcedAttribute: " << KV.second + LLVM_DEBUG(dbgs() << "ForcedAttribute: " << AttributeText << " unknown or not a function attribute!\n"); } return Kind; @@ -69,12 +88,52 @@ static bool hasForceAttributes() { PreservedAnalyses ForceFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &) { - if (!hasForceAttributes()) - return PreservedAnalyses::all(); - - for (Function &F : M.functions()) - forceAttributes(F); - - // Just conservatively invalidate analyses, this isn't likely to be important. 
- return PreservedAnalyses::none(); + bool Changed = false; + if (!CSVFilePath.empty()) { + auto BufferOrError = MemoryBuffer::getFileOrSTDIN(CSVFilePath); + if (!BufferOrError) + report_fatal_error("Cannot open CSV file."); + StringRef Buffer = BufferOrError.get()->getBuffer(); + auto MemoryBuffer = MemoryBuffer::getMemBuffer(Buffer); + line_iterator It(*MemoryBuffer); + for (; !It.is_at_end(); ++It) { + auto SplitPair = It->split(','); + if (SplitPair.second.empty()) + continue; + Function *Func = M.getFunction(SplitPair.first); + if (Func) { + if (Func->isDeclaration()) + continue; + auto SecondSplitPair = SplitPair.second.split('='); + if (!SecondSplitPair.second.empty()) { + Func->addFnAttr(SecondSplitPair.first, SecondSplitPair.second); + Changed = true; + } else { + auto AttrKind = Attribute::getAttrKindFromName(SplitPair.second); + if (AttrKind != Attribute::None && + Attribute::canUseAsFnAttr(AttrKind)) { + // TODO: There could be string attributes without a value, we should + // support those, too. + Func->addFnAttr(AttrKind); + Changed = true; + } else + errs() << "Cannot add " << SplitPair.second + << " as an attribute name.\n"; + } + } else { + errs() << "Function in CSV file at line " << It.line_number() + << " does not exist.\n"; + // TODO: `report_fatal_error at end of pass for missing functions. + continue; + } + } + } + if (hasForceAttributes()) { + for (Function &F : M.functions()) + forceAttributes(F); + Changed = true; + } + // Just conservatively invalidate analyses if we've made any changes, this + // isn't likely to be important. + return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 34299f9dbb23..7c277518b21d 100644 --- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -110,6 +110,39 @@ using SCCNodeSet = SmallSetVector<Function *, 8>; } // end anonymous namespace +static void addLocAccess(MemoryEffects &ME, const MemoryLocation &Loc, + ModRefInfo MR, AAResults &AAR) { + // Ignore accesses to known-invariant or local memory. + MR &= AAR.getModRefInfoMask(Loc, /*IgnoreLocal=*/true); + if (isNoModRef(MR)) + return; + + const Value *UO = getUnderlyingObject(Loc.Ptr); + assert(!isa<AllocaInst>(UO) && + "Should have been handled by getModRefInfoMask()"); + if (isa<Argument>(UO)) { + ME |= MemoryEffects::argMemOnly(MR); + return; + } + + // If it's not an identified object, it might be an argument. + if (!isIdentifiedObject(UO)) + ME |= MemoryEffects::argMemOnly(MR); + ME |= MemoryEffects(IRMemLocation::Other, MR); +} + +static void addArgLocs(MemoryEffects &ME, const CallBase *Call, + ModRefInfo ArgMR, AAResults &AAR) { + for (const Value *Arg : Call->args()) { + if (!Arg->getType()->isPtrOrPtrVectorTy()) + continue; + + addLocAccess(ME, + MemoryLocation::getBeforeOrAfter(Arg, Call->getAAMetadata()), + ArgMR, AAR); + } +} + /// Returns the memory access attribute for function F using AAR for AA results, /// where SCCNodes is the current SCC. /// @@ -118,54 +151,48 @@ using SCCNodeSet = SmallSetVector<Function *, 8>; /// result will be based only on AA results for the function declaration; it /// will be assumed that some other (perhaps less optimized) version of the /// function may be selected at link time. 
-static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, - AAResults &AAR, - const SCCNodeSet &SCCNodes) { +/// +/// The return value is split into two parts: Memory effects that always apply, +/// and additional memory effects that apply if any of the functions in the SCC +/// can access argmem. +static std::pair<MemoryEffects, MemoryEffects> +checkFunctionMemoryAccess(Function &F, bool ThisBody, AAResults &AAR, + const SCCNodeSet &SCCNodes) { MemoryEffects OrigME = AAR.getMemoryEffects(&F); if (OrigME.doesNotAccessMemory()) // Already perfect! - return OrigME; + return {OrigME, MemoryEffects::none()}; if (!ThisBody) - return OrigME; + return {OrigME, MemoryEffects::none()}; MemoryEffects ME = MemoryEffects::none(); + // Additional locations accessed if the SCC accesses argmem. + MemoryEffects RecursiveArgME = MemoryEffects::none(); + // Inalloca and preallocated arguments are always clobbered by the call. if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) || F.getAttributes().hasAttrSomewhere(Attribute::Preallocated)) ME |= MemoryEffects::argMemOnly(ModRefInfo::ModRef); - auto AddLocAccess = [&](const MemoryLocation &Loc, ModRefInfo MR) { - // Ignore accesses to known-invariant or local memory. - MR &= AAR.getModRefInfoMask(Loc, /*IgnoreLocal=*/true); - if (isNoModRef(MR)) - return; - - const Value *UO = getUnderlyingObject(Loc.Ptr); - assert(!isa<AllocaInst>(UO) && - "Should have been handled by getModRefInfoMask()"); - if (isa<Argument>(UO)) { - ME |= MemoryEffects::argMemOnly(MR); - return; - } - - // If it's not an identified object, it might be an argument. - if (!isIdentifiedObject(UO)) - ME |= MemoryEffects::argMemOnly(MR); - ME |= MemoryEffects(IRMemLocation::Other, MR); - }; // Scan the function body for instructions that may read or write memory. for (Instruction &I : instructions(F)) { // Some instructions can be ignored even if they read or write memory. // Detect these now, skipping to the next instruction if one is found. 
if (auto *Call = dyn_cast<CallBase>(&I)) { - // Ignore calls to functions in the same SCC, as long as the call sites - // don't have operand bundles. Calls with operand bundles are allowed to - // have memory effects not described by the memory effects of the call - // target. + // We can optimistically ignore calls to functions in the same SCC, with + // two caveats: + // * Calls with operand bundles may have additional effects. + // * Argument memory accesses may imply additional effects depending on + // what the argument location is. if (!Call->hasOperandBundles() && Call->getCalledFunction() && - SCCNodes.count(Call->getCalledFunction())) + SCCNodes.count(Call->getCalledFunction())) { + // Keep track of which additional locations are accessed if the SCC + // turns out to access argmem. + addArgLocs(RecursiveArgME, Call, ModRefInfo::ModRef, AAR); continue; + } + MemoryEffects CallME = AAR.getMemoryEffects(Call); // If the call doesn't access memory, we're done. @@ -190,15 +217,8 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, // Check whether all pointer arguments point to local memory, and // ignore calls that only access local memory. 
ModRefInfo ArgMR = CallME.getModRef(IRMemLocation::ArgMem); - if (ArgMR != ModRefInfo::NoModRef) { - for (const Use &U : Call->args()) { - const Value *Arg = U; - if (!Arg->getType()->isPtrOrPtrVectorTy()) - continue; - - AddLocAccess(MemoryLocation::getBeforeOrAfter(Arg, I.getAAMetadata()), ArgMR); - } - } + if (ArgMR != ModRefInfo::NoModRef) + addArgLocs(ME, Call, ArgMR, AAR); continue; } @@ -222,15 +242,15 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, if (I.isVolatile()) ME |= MemoryEffects::inaccessibleMemOnly(MR); - AddLocAccess(*Loc, MR); + addLocAccess(ME, *Loc, MR, AAR); } - return OrigME & ME; + return {OrigME & ME, RecursiveArgME}; } MemoryEffects llvm::computeFunctionBodyMemoryAccess(Function &F, AAResults &AAR) { - return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}); + return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}).first; } /// Deduce readonly/readnone/writeonly attributes for the SCC. @@ -238,24 +258,37 @@ template <typename AARGetterT> static void addMemoryAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, SmallSet<Function *, 8> &Changed) { MemoryEffects ME = MemoryEffects::none(); + MemoryEffects RecursiveArgME = MemoryEffects::none(); for (Function *F : SCCNodes) { // Call the callable parameter to look up AA results for this function. AAResults &AAR = AARGetter(*F); // Non-exact function definitions may not be selected at link time, and an // alternative version that writes to memory may be selected. See the // comment on GlobalValue::isDefinitionExact for more details. - ME |= checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes); + auto [FnME, FnRecursiveArgME] = + checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes); + ME |= FnME; + RecursiveArgME |= FnRecursiveArgME; // Reached bottom of the lattice, we will not be able to improve the result. 
if (ME == MemoryEffects::unknown()) return; } + // If the SCC accesses argmem, add recursive accesses resulting from that. + ModRefInfo ArgMR = ME.getModRef(IRMemLocation::ArgMem); + if (ArgMR != ModRefInfo::NoModRef) + ME |= RecursiveArgME & MemoryEffects(ArgMR); + for (Function *F : SCCNodes) { MemoryEffects OldME = F->getMemoryEffects(); MemoryEffects NewME = ME & OldME; if (NewME != OldME) { ++NumMemoryAttr; F->setMemoryEffects(NewME); + // Remove conflicting writable attributes. + if (!isModSet(NewME.getModRef(IRMemLocation::ArgMem))) + for (Argument &A : F->args()) + A.removeAttr(Attribute::Writable); Changed.insert(F); } } @@ -625,7 +658,15 @@ determinePointerAccessAttrs(Argument *A, // must be a data operand (e.g. argument or operand bundle) const unsigned UseIndex = CB.getDataOperandNo(U); - if (!CB.doesNotCapture(UseIndex)) { + // Some intrinsics (for instance ptrmask) do not capture their results, + // but return results thas alias their pointer argument, and thus should + // be handled like GEP or addrspacecast above. + if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing( + &CB, /*MustPreserveNullness=*/false)) { + for (Use &UU : CB.uses()) + if (Visited.insert(&UU).second) + Worklist.push_back(&UU); + } else if (!CB.doesNotCapture(UseIndex)) { if (!CB.onlyReadsMemory()) // If the callee can save a copy into other memory, then simply // scanning uses of the call is insufficient. We have no way @@ -639,7 +680,8 @@ determinePointerAccessAttrs(Argument *A, Worklist.push_back(&UU); } - if (CB.doesNotAccessMemory()) + ModRefInfo ArgMR = CB.getMemoryEffects().getModRef(IRMemLocation::ArgMem); + if (isNoModRef(ArgMR)) continue; if (Function *F = CB.getCalledFunction()) @@ -654,9 +696,9 @@ determinePointerAccessAttrs(Argument *A, // invokes with operand bundles. 
if (CB.doesNotAccessMemory(UseIndex)) { /* nop */ - } else if (CB.onlyReadsMemory() || CB.onlyReadsMemory(UseIndex)) { + } else if (!isModSet(ArgMR) || CB.onlyReadsMemory(UseIndex)) { IsRead = true; - } else if (CB.hasFnAttr(Attribute::WriteOnly) || + } else if (!isRefSet(ArgMR) || CB.dataOperandHasImpliedAttr(UseIndex, Attribute::WriteOnly)) { IsWrite = true; } else { @@ -810,6 +852,9 @@ static bool addAccessAttr(Argument *A, Attribute::AttrKind R) { A->removeAttr(Attribute::WriteOnly); A->removeAttr(Attribute::ReadOnly); A->removeAttr(Attribute::ReadNone); + // Remove conflicting writable attribute. + if (R == Attribute::ReadNone || R == Attribute::ReadOnly) + A->removeAttr(Attribute::Writable); A->addAttr(R); if (R == Attribute::ReadOnly) ++NumReadOnlyArg; @@ -1720,7 +1765,8 @@ static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) { template <typename AARGetterT> static SmallSet<Function *, 8> -deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) { +deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter, + bool ArgAttrsOnly) { SCCNodesResult Nodes = createSCCNodeSet(Functions); // Bail if the SCC only contains optnone functions. @@ -1728,6 +1774,10 @@ deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) { return {}; SmallSet<Function *, 8> Changed; + if (ArgAttrsOnly) { + addArgumentAttrs(Nodes.SCCNodes, Changed); + return Changed; + } addArgumentReturnedAttrs(Nodes.SCCNodes, Changed); addMemoryAttrs(Nodes.SCCNodes, AARGetter, Changed); @@ -1762,10 +1812,13 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, LazyCallGraph &CG, CGSCCUpdateResult &) { // Skip non-recursive functions if requested. + // Only infer argument attributes for non-recursive functions, because + // it can affect optimization behavior in conjunction with noalias. 
+ bool ArgAttrsOnly = false; if (C.size() == 1 && SkipNonRecursive) { LazyCallGraph::Node &N = *C.begin(); if (!N->lookup(N)) - return PreservedAnalyses::all(); + ArgAttrsOnly = true; } FunctionAnalysisManager &FAM = @@ -1782,7 +1835,8 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, Functions.push_back(&N.getFunction()); } - auto ChangedFunctions = deriveAttrsInPostOrder(Functions, AARGetter); + auto ChangedFunctions = + deriveAttrsInPostOrder(Functions, AARGetter, ArgAttrsOnly); if (ChangedFunctions.empty()) return PreservedAnalyses::all(); @@ -1818,7 +1872,7 @@ void PostOrderFunctionAttrsPass::printPipeline( static_cast<PassInfoMixin<PostOrderFunctionAttrsPass> *>(this)->printPipeline( OS, MapClassName2PassName); if (SkipNonRecursive) - OS << "<skip-non-recursive>"; + OS << "<skip-non-recursive-function-attrs>"; } template <typename AARGetterT> diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp index f635b14cd2a9..9c546b531dff 100644 --- a/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -16,7 +16,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/AutoUpgrade.h" @@ -272,7 +271,7 @@ class GlobalsImporter final { function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> IsPrevailing; FunctionImporter::ImportMapTy &ImportList; - StringMap<FunctionImporter::ExportSetTy> *const ExportLists; + DenseMap<StringRef, FunctionImporter::ExportSetTy> *const ExportLists; bool shouldImportGlobal(const ValueInfo &VI) { const auto &GVS = DefinedGVSummaries.find(VI.getGUID()); @@ -357,7 +356,7 @@ public: function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> IsPrevailing, FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> 
*ExportLists) + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists) : Index(Index), DefinedGVSummaries(DefinedGVSummaries), IsPrevailing(IsPrevailing), ImportList(ImportList), ExportLists(ExportLists) {} @@ -370,6 +369,29 @@ public: } }; +/// Determine the list of imports and exports for each module. +class ModuleImportsManager final { + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing; + const ModuleSummaryIndex &Index; + DenseMap<StringRef, FunctionImporter::ExportSetTy> *const ExportLists; + +public: + ModuleImportsManager( + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing, + const ModuleSummaryIndex &Index, + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists = nullptr) + : IsPrevailing(IsPrevailing), Index(Index), ExportLists(ExportLists) {} + + /// Given the list of globals defined in a module, compute the list of imports + /// as well as the list of "exports", i.e. the list of symbols referenced from + /// another module (that may require promotion). 
+ void computeImportForModule(const GVSummaryMapTy &DefinedGVSummaries, + StringRef ModName, + FunctionImporter::ImportMapTy &ImportList); +}; + static const char * getFailureName(FunctionImporter::ImportFailureReason Reason) { switch (Reason) { @@ -403,7 +425,7 @@ static void computeImportForFunction( isPrevailing, SmallVectorImpl<EdgeInfo> &Worklist, GlobalsImporter &GVImporter, FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists, + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists, FunctionImporter::ImportThresholdsTy &ImportThresholds) { GVImporter.onImportingSummary(Summary); static int ImportCount = 0; @@ -482,7 +504,7 @@ static void computeImportForFunction( continue; } - FunctionImporter::ImportFailureReason Reason; + FunctionImporter::ImportFailureReason Reason{}; CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, Summary.modulePath(), Reason); if (!CalleeSummary) { @@ -567,20 +589,13 @@ static void computeImportForFunction( } } -/// Given the list of globals defined in a module, compute the list of imports -/// as well as the list of "exports", i.e. the list of symbols referenced from -/// another module (that may require promotion). -static void ComputeImportForModule( - const GVSummaryMapTy &DefinedGVSummaries, - function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> - isPrevailing, - const ModuleSummaryIndex &Index, StringRef ModName, - FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { +void ModuleImportsManager::computeImportForModule( + const GVSummaryMapTy &DefinedGVSummaries, StringRef ModName, + FunctionImporter::ImportMapTy &ImportList) { // Worklist contains the list of function imported in this module, for which // we will analyse the callees and may import further down the callgraph. 
SmallVector<EdgeInfo, 128> Worklist; - GlobalsImporter GVI(Index, DefinedGVSummaries, isPrevailing, ImportList, + GlobalsImporter GVI(Index, DefinedGVSummaries, IsPrevailing, ImportList, ExportLists); FunctionImporter::ImportThresholdsTy ImportThresholds; @@ -603,7 +618,7 @@ static void ComputeImportForModule( continue; LLVM_DEBUG(dbgs() << "Initialize import for " << VI << "\n"); computeImportForFunction(*FuncSummary, Index, ImportInstrLimit, - DefinedGVSummaries, isPrevailing, Worklist, GVI, + DefinedGVSummaries, IsPrevailing, Worklist, GVI, ImportList, ExportLists, ImportThresholds); } @@ -615,7 +630,7 @@ static void ComputeImportForModule( if (auto *FS = dyn_cast<FunctionSummary>(Summary)) computeImportForFunction(*FS, Index, Threshold, DefinedGVSummaries, - isPrevailing, Worklist, GVI, ImportList, + IsPrevailing, Worklist, GVI, ImportList, ExportLists, ImportThresholds); } @@ -671,10 +686,10 @@ static unsigned numGlobalVarSummaries(const ModuleSummaryIndex &Index, #endif #ifndef NDEBUG -static bool -checkVariableImport(const ModuleSummaryIndex &Index, - StringMap<FunctionImporter::ImportMapTy> &ImportLists, - StringMap<FunctionImporter::ExportSetTy> &ExportLists) { +static bool checkVariableImport( + const ModuleSummaryIndex &Index, + DenseMap<StringRef, FunctionImporter::ImportMapTy> &ImportLists, + DenseMap<StringRef, FunctionImporter::ExportSetTy> &ExportLists) { DenseSet<GlobalValue::GUID> FlattenedImports; @@ -702,7 +717,7 @@ checkVariableImport(const ModuleSummaryIndex &Index, for (auto &ExportPerModule : ExportLists) for (auto &VI : ExportPerModule.second) if (!FlattenedImports.count(VI.getGUID()) && - IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first(), VI)) + IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first, VI)) return false; return true; @@ -712,19 +727,19 @@ checkVariableImport(const ModuleSummaryIndex &Index, /// Compute all the import and export for every module using the Index. 
void llvm::ComputeCrossModuleImport( const ModuleSummaryIndex &Index, - const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, + const DenseMap<StringRef, GVSummaryMapTy> &ModuleToDefinedGVSummaries, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> isPrevailing, - StringMap<FunctionImporter::ImportMapTy> &ImportLists, - StringMap<FunctionImporter::ExportSetTy> &ExportLists) { + DenseMap<StringRef, FunctionImporter::ImportMapTy> &ImportLists, + DenseMap<StringRef, FunctionImporter::ExportSetTy> &ExportLists) { + ModuleImportsManager MIS(isPrevailing, Index, &ExportLists); // For each module that has function defined, compute the import/export lists. for (const auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { - auto &ImportList = ImportLists[DefinedGVSummaries.first()]; + auto &ImportList = ImportLists[DefinedGVSummaries.first]; LLVM_DEBUG(dbgs() << "Computing import for Module '" - << DefinedGVSummaries.first() << "'\n"); - ComputeImportForModule(DefinedGVSummaries.second, isPrevailing, Index, - DefinedGVSummaries.first(), ImportList, - &ExportLists); + << DefinedGVSummaries.first << "'\n"); + MIS.computeImportForModule(DefinedGVSummaries.second, + DefinedGVSummaries.first, ImportList); } // When computing imports we only added the variables and functions being @@ -735,7 +750,7 @@ void llvm::ComputeCrossModuleImport( for (auto &ELI : ExportLists) { FunctionImporter::ExportSetTy NewExports; const auto &DefinedGVSummaries = - ModuleToDefinedGVSummaries.lookup(ELI.first()); + ModuleToDefinedGVSummaries.lookup(ELI.first); for (auto &EI : ELI.second) { // Find the copy defined in the exporting module so that we can mark the // values it references in that specific definition as exported. 
@@ -783,7 +798,7 @@ void llvm::ComputeCrossModuleImport( LLVM_DEBUG(dbgs() << "Import/Export lists for " << ImportLists.size() << " modules:\n"); for (auto &ModuleImports : ImportLists) { - auto ModName = ModuleImports.first(); + auto ModName = ModuleImports.first; auto &Exports = ExportLists[ModName]; unsigned NumGVS = numGlobalVarSummaries(Index, Exports); LLVM_DEBUG(dbgs() << "* Module " << ModName << " exports " @@ -791,7 +806,7 @@ void llvm::ComputeCrossModuleImport( << " vars. Imports from " << ModuleImports.second.size() << " modules.\n"); for (auto &Src : ModuleImports.second) { - auto SrcModName = Src.first(); + auto SrcModName = Src.first; unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod << " functions imported from " << SrcModName << "\n"); @@ -809,7 +824,7 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index, LLVM_DEBUG(dbgs() << "* Module " << ModulePath << " imports from " << ImportList.size() << " modules.\n"); for (auto &Src : ImportList) { - auto SrcModName = Src.first(); + auto SrcModName = Src.first; unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod << " functions imported from " << SrcModName << "\n"); @@ -819,8 +834,15 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index, } #endif -/// Compute all the imports for the given module in the Index. -void llvm::ComputeCrossModuleImportForModule( +/// Compute all the imports for the given module using the Index. +/// +/// \p isPrevailing is a callback that will be called with a global value's GUID +/// and summary and should return whether the module corresponding to the +/// summary contains the linker-prevailing copy of that value. +/// +/// \p ImportList will be populated with a map that can be passed to +/// FunctionImporter::importFunctions() above (see description there). 
+static void ComputeCrossModuleImportForModuleForTest( StringRef ModulePath, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> isPrevailing, @@ -833,17 +855,20 @@ void llvm::ComputeCrossModuleImportForModule( // Compute the import list for this module. LLVM_DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n"); - ComputeImportForModule(FunctionSummaryMap, isPrevailing, Index, ModulePath, - ImportList); + ModuleImportsManager MIS(isPrevailing, Index); + MIS.computeImportForModule(FunctionSummaryMap, ModulePath, ImportList); #ifndef NDEBUG dumpImportListForModule(Index, ModulePath, ImportList); #endif } -// Mark all external summaries in Index for import into the given module. -// Used for distributed builds using a distributed index. -void llvm::ComputeCrossModuleImportForModuleFromIndex( +/// Mark all external summaries in \p Index for import into the given module. +/// Used for testing the case of distributed builds using a distributed index. +/// +/// \p ImportList will be populated with a map that can be passed to +/// FunctionImporter::importFunctions() above (see description there). +static void ComputeCrossModuleImportForModuleFromIndexForTest( StringRef ModulePath, const ModuleSummaryIndex &Index, FunctionImporter::ImportMapTy &ImportList) { for (const auto &GlobalList : Index) { @@ -1041,7 +1066,7 @@ void llvm::computeDeadSymbolsWithConstProp( /// \p ModulePath. void llvm::gatherImportedSummariesForModule( StringRef ModulePath, - const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, + const DenseMap<StringRef, GVSummaryMapTy> &ModuleToDefinedGVSummaries, const FunctionImporter::ImportMapTy &ImportList, std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) { // Include all summaries from the importing module. @@ -1049,10 +1074,9 @@ void llvm::gatherImportedSummariesForModule( ModuleToDefinedGVSummaries.lookup(ModulePath); // Include summaries for imports. 
for (const auto &ILI : ImportList) { - auto &SummariesForIndex = - ModuleToSummariesForIndex[std::string(ILI.first())]; + auto &SummariesForIndex = ModuleToSummariesForIndex[std::string(ILI.first)]; const auto &DefinedGVSummaries = - ModuleToDefinedGVSummaries.lookup(ILI.first()); + ModuleToDefinedGVSummaries.lookup(ILI.first); for (const auto &GI : ILI.second) { const auto &DS = DefinedGVSummaries.find(GI); assert(DS != DefinedGVSummaries.end() && @@ -1298,7 +1322,7 @@ static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) { // ensure all uses of alias instead use the new clone (casted if necessary). NewFn->setLinkage(GA->getLinkage()); NewFn->setVisibility(GA->getVisibility()); - GA->replaceAllUsesWith(ConstantExpr::getBitCast(NewFn, GA->getType())); + GA->replaceAllUsesWith(NewFn); NewFn->takeName(GA); return NewFn; } @@ -1327,7 +1351,7 @@ Expected<bool> FunctionImporter::importFunctions( // Do the actual import of functions now, one Module at a time std::set<StringRef> ModuleNameOrderedList; for (const auto &FunctionsToImportPerModule : ImportList) { - ModuleNameOrderedList.insert(FunctionsToImportPerModule.first()); + ModuleNameOrderedList.insert(FunctionsToImportPerModule.first); } for (const auto &Name : ModuleNameOrderedList) { // Get the module for the import @@ -1461,7 +1485,7 @@ Expected<bool> FunctionImporter::importFunctions( return ImportedCount; } -static bool doImportingForModule( +static bool doImportingForModuleForTest( Module &M, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> isPrevailing) { if (SummaryFile.empty()) @@ -1481,11 +1505,11 @@ static bool doImportingForModule( // when testing distributed backend handling via the opt tool, when // we have distributed indexes containing exactly the summaries to import. 
if (ImportAllIndex) - ComputeCrossModuleImportForModuleFromIndex(M.getModuleIdentifier(), *Index, - ImportList); + ComputeCrossModuleImportForModuleFromIndexForTest(M.getModuleIdentifier(), + *Index, ImportList); else - ComputeCrossModuleImportForModule(M.getModuleIdentifier(), isPrevailing, - *Index, ImportList); + ComputeCrossModuleImportForModuleForTest(M.getModuleIdentifier(), + isPrevailing, *Index, ImportList); // Conservatively mark all internal values as promoted. This interface is // only used when doing importing via the function importing pass. The pass @@ -1533,7 +1557,7 @@ PreservedAnalyses FunctionImportPass::run(Module &M, auto isPrevailing = [](GlobalValue::GUID, const GlobalValueSummary *) { return true; }; - if (!doImportingForModule(M, isPrevailing)) + if (!doImportingForModuleForTest(M, isPrevailing)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index 3d6c501e4596..a4c12006ee24 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -5,45 +5,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This specialises functions with constant parameters. Constant parameters -// like function pointers and constant globals are propagated to the callee by -// specializing the function. The main benefit of this pass at the moment is -// that indirect calls are transformed into direct calls, which provides inline -// opportunities that the inliner would not have been able to achieve. That's -// why function specialisation is run before the inliner in the optimisation -// pipeline; that is by design. 
Otherwise, we would only benefit from constant -// passing, which is a valid use-case too, but hasn't been explored much in -// terms of performance uplifts, cost-model and compile-time impact. -// -// Current limitations: -// - It does not yet handle integer ranges. We do support "literal constants", -// but that's off by default under an option. -// - The cost-model could be further looked into (it mainly focuses on inlining -// benefits), -// -// Ideas: -// - With a function specialization attribute for arguments, we could have -// a direct way to steer function specialization, avoiding the cost-model, -// and thus control compile-times / code-size. -// -// Todos: -// - Specializing recursive functions relies on running the transformation a -// number of times, which is controlled by option -// `func-specialization-max-iters`. Thus, increasing this value and the -// number of iterations, will linearly increase the number of times recursive -// functions get specialized, see also the discussion in -// https://reviews.llvm.org/D106426 for details. Perhaps there is a -// compile-time friendlier way to control/limit the number of specialisations -// for recursive functions. -// - Don't transform the function if function specialization does not trigger; -// the SCCPSolver may make IR changes. 
-// -// References: -// - 2021 LLVM Dev Mtg “Introducing function specialisation, and can we enable -// it by default?”, https://www.youtube.com/watch?v=zJiCjeXgV5Q -// -//===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/FunctionSpecialization.h" #include "llvm/ADT/Statistic.h" @@ -78,16 +39,47 @@ static cl::opt<unsigned> MaxClones( "The maximum number of clones allowed for a single function " "specialization")); +static cl::opt<unsigned> + MaxDiscoveryIterations("funcspec-max-discovery-iterations", cl::init(100), + cl::Hidden, + cl::desc("The maximum number of iterations allowed " + "when searching for transitive " + "phis")); + static cl::opt<unsigned> MaxIncomingPhiValues( - "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc( - "The maximum number of incoming values a PHI node can have to be " - "considered during the specialization bonus estimation")); + "funcspec-max-incoming-phi-values", cl::init(8), cl::Hidden, + cl::desc("The maximum number of incoming values a PHI node can have to be " + "considered during the specialization bonus estimation")); + +static cl::opt<unsigned> MaxBlockPredecessors( + "funcspec-max-block-predecessors", cl::init(2), cl::Hidden, cl::desc( + "The maximum number of predecessors a basic block can have to be " + "considered during the estimation of dead code")); static cl::opt<unsigned> MinFunctionSize( - "funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc( + "funcspec-min-function-size", cl::init(300), cl::Hidden, cl::desc( "Don't specialize functions that have less than this number of " "instructions")); +static cl::opt<unsigned> MaxCodeSizeGrowth( + "funcspec-max-codesize-growth", cl::init(3), cl::Hidden, cl::desc( + "Maximum codesize growth allowed per function")); + +static cl::opt<unsigned> MinCodeSizeSavings( + "funcspec-min-codesize-savings", cl::init(20), cl::Hidden, cl::desc( + "Reject specializations whose codesize savings are less 
than this" + "much percent of the original function size")); + +static cl::opt<unsigned> MinLatencySavings( + "funcspec-min-latency-savings", cl::init(40), cl::Hidden, + cl::desc("Reject specializations whose latency savings are less than this" + "much percent of the original function size")); + +static cl::opt<unsigned> MinInliningBonus( + "funcspec-min-inlining-bonus", cl::init(300), cl::Hidden, cl::desc( + "Reject specializations whose inlining bonus is less than this" + "much percent of the original function size")); + static cl::opt<bool> SpecializeOnAddress( "funcspec-on-address", cl::init(false), cl::Hidden, cl::desc( "Enable function specialization on the address of global values")); @@ -101,32 +93,32 @@ static cl::opt<bool> SpecializeLiteralConstant( "Enable specialization of functions that take a literal constant as an " "argument")); -// Estimates the instruction cost of all the basic blocks in \p WorkList. -// The successors of such blocks are added to the list as long as they are -// executable and they have a unique predecessor. \p WorkList represents -// the basic blocks of a specialization which become dead once we replace -// instructions that are known to be constants. The aim here is to estimate -// the combination of size and latency savings in comparison to the non -// specialized version of the function. -static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList, - DenseSet<BasicBlock *> &DeadBlocks, - ConstMap &KnownConstants, SCCPSolver &Solver, - BlockFrequencyInfo &BFI, - TargetTransformInfo &TTI) { - Cost Bonus = 0; +bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB, BasicBlock *Succ, + DenseSet<BasicBlock *> &DeadBlocks) { + unsigned I = 0; + return all_of(predecessors(Succ), + [&I, BB, Succ, &DeadBlocks] (BasicBlock *Pred) { + return I++ < MaxBlockPredecessors && + (Pred == BB || Pred == Succ || DeadBlocks.contains(Pred)); + }); +} +// Estimates the codesize savings due to dead code after constant propagation. 
+// \p WorkList represents the basic blocks of a specialization which will +// eventually become dead once we replace instructions that are known to be +// constants. The successors of such blocks are added to the list as long as +// the \p Solver found they were executable prior to specialization, and only +// if all their predecessors are dead. +Cost InstCostVisitor::estimateBasicBlocks( + SmallVectorImpl<BasicBlock *> &WorkList) { + Cost CodeSize = 0; // Accumulate the instruction cost of each basic block weighted by frequency. while (!WorkList.empty()) { BasicBlock *BB = WorkList.pop_back_val(); - uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() / - BFI.getEntryFreq(); - if (!Weight) - continue; - - // These blocks are considered dead as far as the InstCostVisitor is - // concerned. They haven't been proven dead yet by the Solver, but - // may become if we propagate the constant specialization arguments. + // These blocks are considered dead as far as the InstCostVisitor + // is concerned. They haven't been proven dead yet by the Solver, + // but may become if we propagate the specialization arguments. if (!DeadBlocks.insert(BB).second) continue; @@ -139,74 +131,100 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList, if (KnownConstants.contains(&I)) continue; - Bonus += Weight * - TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + Cost C = TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); - LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus - << " after user " << I << "\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: CodeSize " << C + << " for user " << I << "\n"); + CodeSize += C; } // Keep adding dead successors to the list as long as they are - // executable and they have a unique predecessor. + // executable and only reachable from dead blocks. 
for (BasicBlock *SuccBB : successors(BB)) - if (Solver.isBlockExecutable(SuccBB) && - SuccBB->getUniquePredecessor() == BB) + if (isBlockExecutable(SuccBB) && + canEliminateSuccessor(BB, SuccBB, DeadBlocks)) WorkList.push_back(SuccBB); } - return Bonus; + return CodeSize; } static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) { if (auto *C = dyn_cast<Constant>(V)) return C; - if (auto It = KnownConstants.find(V); It != KnownConstants.end()) - return It->second; - return nullptr; + return KnownConstants.lookup(V); } -Cost InstCostVisitor::getBonusFromPendingPHIs() { - Cost Bonus = 0; +Bonus InstCostVisitor::getBonusFromPendingPHIs() { + Bonus B; while (!PendingPHIs.empty()) { Instruction *Phi = PendingPHIs.pop_back_val(); - Bonus += getUserBonus(Phi); + // The pending PHIs could have been proven dead by now. + if (isBlockExecutable(Phi->getParent())) + B += getUserBonus(Phi); } - return Bonus; + return B; +} + +/// Compute a bonus for replacing argument \p A with constant \p C. +Bonus InstCostVisitor::getSpecializationBonus(Argument *A, Constant *C) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " + << C->getNameOrAsOperand() << "\n"); + Bonus B; + for (auto *U : A->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + if (isBlockExecutable(UI->getParent())) + B += getUserBonus(UI, A, C); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated bonus {CodeSize = " + << B.CodeSize << ", Latency = " << B.Latency + << "} for argument " << *A << "\n"); + return B; } -Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) { +Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) { + // We have already propagated a constant for this user. + if (KnownConstants.contains(User)) + return {0, 0}; + // Cache the iterator before visiting. LastVisited = Use ? 
KnownConstants.insert({Use, C}).first : KnownConstants.end(); - if (auto *I = dyn_cast<SwitchInst>(User)) - return estimateSwitchInst(*I); - - if (auto *I = dyn_cast<BranchInst>(User)) - return estimateBranchInst(*I); - - C = visit(*User); - if (!C) - return 0; + Cost CodeSize = 0; + if (auto *I = dyn_cast<SwitchInst>(User)) { + CodeSize = estimateSwitchInst(*I); + } else if (auto *I = dyn_cast<BranchInst>(User)) { + CodeSize = estimateBranchInst(*I); + } else { + C = visit(*User); + if (!C) + return {0, 0}; + } + // Even though it doesn't make sense to bind switch and branch instructions + // with a constant, unlike any other instruction type, it prevents estimating + // their bonus multiple times. KnownConstants.insert({User, C}); + CodeSize += TTI.getInstructionCost(User, TargetTransformInfo::TCK_CodeSize); + uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() / - BFI.getEntryFreq(); - if (!Weight) - return 0; + BFI.getEntryFreq().getFrequency(); - Cost Bonus = Weight * - TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency); + Cost Latency = Weight * + TTI.getInstructionCost(User, TargetTransformInfo::TCK_Latency); - LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus - << " for user " << *User << "\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: {CodeSize = " << CodeSize + << ", Latency = " << Latency << "} for user " + << *User << "\n"); + Bonus B(CodeSize, Latency); for (auto *U : User->users()) if (auto *UI = dyn_cast<Instruction>(U)) - if (UI != User && Solver.isBlockExecutable(UI->getParent())) - Bonus += getUserBonus(UI, User, C); + if (UI != User && isBlockExecutable(UI->getParent())) + B += getUserBonus(UI, User, C); - return Bonus; + return B; } Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { @@ -226,14 +244,12 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { SmallVector<BasicBlock *> WorkList; for (const auto &Case : I.cases()) { BasicBlock *BB = Case.getCaseSuccessor(); - if (BB == Succ || 
!Solver.isBlockExecutable(BB) || - BB->getUniquePredecessor() != I.getParent()) - continue; - WorkList.push_back(BB); + if (BB != Succ && isBlockExecutable(BB) && + canEliminateSuccessor(I.getParent(), BB, DeadBlocks)) + WorkList.push_back(BB); } - return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI, - TTI); + return estimateBasicBlocks(WorkList); } Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { @@ -246,12 +262,55 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { // Initialize the worklist with the dead successor as long as // it is executable and has a unique predecessor. SmallVector<BasicBlock *> WorkList; - if (Solver.isBlockExecutable(Succ) && - Succ->getUniquePredecessor() == I.getParent()) + if (isBlockExecutable(Succ) && + canEliminateSuccessor(I.getParent(), Succ, DeadBlocks)) WorkList.push_back(Succ); - return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI, - TTI); + return estimateBasicBlocks(WorkList); +} + +bool InstCostVisitor::discoverTransitivelyIncomingValues( + Constant *Const, PHINode *Root, DenseSet<PHINode *> &TransitivePHIs) { + + SmallVector<PHINode *, 64> WorkList; + WorkList.push_back(Root); + unsigned Iter = 0; + + while (!WorkList.empty()) { + PHINode *PN = WorkList.pop_back_val(); + + if (++Iter > MaxDiscoveryIterations || + PN->getNumIncomingValues() > MaxIncomingPhiValues) + return false; + + if (!TransitivePHIs.insert(PN).second) + continue; + + for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) { + Value *V = PN->getIncomingValue(I); + + // Disregard self-references and dead incoming values. + if (auto *Inst = dyn_cast<Instruction>(V)) + if (Inst == PN || DeadBlocks.contains(PN->getIncomingBlock(I))) + continue; + + if (Constant *C = findConstantFor(V, KnownConstants)) { + // Not all incoming values are the same constant. Bail immediately. 
+ if (C != Const) + return false; + continue; + } + + if (auto *Phi = dyn_cast<PHINode>(V)) { + WorkList.push_back(Phi); + continue; + } + + // We can't reason about anything else. + return false; + } + } + return true; } Constant *InstCostVisitor::visitPHINode(PHINode &I) { @@ -260,23 +319,52 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) { bool Inserted = VisitedPHIs.insert(&I).second; Constant *Const = nullptr; + bool HaveSeenIncomingPHI = false; for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) { Value *V = I.getIncomingValue(Idx); + + // Disregard self-references and dead incoming values. if (auto *Inst = dyn_cast<Instruction>(V)) if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx))) continue; - Constant *C = findConstantFor(V, KnownConstants); - if (!C) { - if (Inserted) - PendingPHIs.push_back(&I); - return nullptr; + + if (Constant *C = findConstantFor(V, KnownConstants)) { + if (!Const) + Const = C; + // Not all incoming values are the same constant. Bail immediately. + if (C != Const) + return nullptr; + continue; } - if (!Const) - Const = C; - else if (C != Const) + + if (Inserted) { + // First time we are seeing this phi. We will retry later, after + // all the constant arguments have been propagated. Bail for now. + PendingPHIs.push_back(&I); return nullptr; + } + + if (isa<PHINode>(V)) { + // Perhaps it is a Transitive Phi. We will confirm later. + HaveSeenIncomingPHI = true; + continue; + } + + // We can't reason about anything else. 
+ return nullptr; } + + if (!Const) + return nullptr; + + if (!HaveSeenIncomingPHI) + return Const; + + DenseSet<PHINode *> TransitivePHIs; + if (!discoverTransitivelyIncomingValues(Const, &I, TransitivePHIs)) + return nullptr; + return Const; } @@ -479,10 +567,7 @@ void FunctionSpecializer::promoteConstantStackValues(Function *F) { Value *GV = new GlobalVariable(M, ConstVal->getType(), true, GlobalValue::InternalLinkage, ConstVal, - "funcspec.arg"); - if (ArgOpType != ConstVal->getType()) - GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType); - + "specialized.arg." + Twine(++NGlobals)); Call->setArgOperand(Idx, GV); } } @@ -572,13 +657,18 @@ bool FunctionSpecializer::run() { if (!Inserted && !Metrics.isRecursive && !SpecializeLiteralConstant) continue; + int64_t Sz = *Metrics.NumInsts.getValue(); + assert(Sz > 0 && "CodeSize should be positive"); + // It is safe to down cast from int64_t, NumInsts is always positive. + unsigned FuncSize = static_cast<unsigned>(Sz); + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for " - << F.getName() << " is " << Metrics.NumInsts << "\n"); + << F.getName() << " is " << FuncSize << "\n"); if (Inserted && Metrics.isRecursive) promoteConstantStackValues(&F); - if (!findSpecializations(&F, Metrics.NumInsts, AllSpecs, SM)) { + if (!findSpecializations(&F, FuncSize, AllSpecs, SM)) { LLVM_DEBUG( dbgs() << "FnSpecialization: No possible specializations found for " << F.getName() << "\n"); @@ -706,14 +796,15 @@ void FunctionSpecializer::removeDeadFunctions() { /// Clone the function \p F and remove the ssa_copy intrinsics added by /// the SCCPSolver in the cloned version. -static Function *cloneCandidateFunction(Function *F) { +static Function *cloneCandidateFunction(Function *F, unsigned NSpecs) { ValueToValueMapTy Mappings; Function *Clone = CloneFunction(F, Mappings); + Clone->setName(F->getName() + ".specialized." 
+ Twine(NSpecs)); removeSSACopy(*Clone); return Clone; } -bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost, +bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize, SmallVectorImpl<Spec> &AllSpecs, SpecMap &SM) { // A mapping from a specialisation signature to the index of the respective @@ -779,20 +870,48 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost, AllSpecs[Index].CallSites.push_back(&CS); } else { // Calculate the specialisation gain. - Cost Score = 0; + Bonus B; + unsigned Score = 0; InstCostVisitor Visitor = getInstCostVisitorFor(F); - for (ArgInfo &A : S.Args) - Score += getSpecializationBonus(A.Formal, A.Actual, Visitor); - Score += Visitor.getBonusFromPendingPHIs(); + for (ArgInfo &A : S.Args) { + B += Visitor.getSpecializationBonus(A.Formal, A.Actual); + Score += getInliningBonus(A.Formal, A.Actual); + } + B += Visitor.getBonusFromPendingPHIs(); + - LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = " - << Score << "\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization bonus {CodeSize = " + << B.CodeSize << ", Latency = " << B.Latency + << ", Inlining = " << Score << "}\n"); + + FunctionGrowth[F] += FuncSize - B.CodeSize; + + auto IsProfitable = [](Bonus &B, unsigned Score, unsigned FuncSize, + unsigned FuncGrowth) -> bool { + // No check required. + if (ForceSpecialization) + return true; + // Minimum inlining bonus. + if (Score > MinInliningBonus * FuncSize / 100) + return true; + // Minimum codesize savings. + if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100) + return false; + // Minimum latency savings. + if (B.Latency < MinLatencySavings * FuncSize / 100) + return false; + // Maximum codesize growth. + if (FuncGrowth / FuncSize > MaxCodeSizeGrowth) + return false; + return true; + }; // Discard unprofitable specialisations. 
- if (!ForceSpecialization && Score <= SpecCost) + if (!IsProfitable(B, Score, FuncSize, FunctionGrowth[F])) continue; // Create a new specialisation entry. + Score += std::max(B.CodeSize, B.Latency); auto &Spec = AllSpecs.emplace_back(F, S, Score); if (CS.getFunction() != F) Spec.CallSites.push_back(&CS); @@ -838,7 +957,7 @@ bool FunctionSpecializer::isCandidateFunction(Function *F) { Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) { - Function *Clone = cloneCandidateFunction(F); + Function *Clone = cloneCandidateFunction(F, Specializations.size() + 1); // The original function does not neccessarily have internal linkage, but the // clone must. @@ -859,30 +978,14 @@ Function *FunctionSpecializer::createSpecialization(Function *F, return Clone; } -/// Compute a bonus for replacing argument \p A with constant \p C. -Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, - InstCostVisitor &Visitor) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " - << C->getNameOrAsOperand() << "\n"); - - Cost TotalCost = 0; - for (auto *U : A->users()) - if (auto *UI = dyn_cast<Instruction>(U)) - if (Solver.isBlockExecutable(UI->getParent())) - TotalCost += Visitor.getUserBonus(UI, A, C); - - LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus " - << TotalCost << " for argument " << *A << "\n"); - - // The below heuristic is only concerned with exposing inlining - // opportunities via indirect call promotion. If the argument is not a - // (potentially casted) function pointer, give up. - // - // TODO: Perhaps we should consider checking such inlining opportunities - // while traversing the users of the specialization arguments ? +/// Compute the inlining bonus for replacing argument \p A with constant \p C. +/// The below heuristic is only concerned with exposing inlining +/// opportunities via indirect call promotion. 
If the argument is not a +/// (potentially casted) function pointer, give up. +unsigned FunctionSpecializer::getInliningBonus(Argument *A, Constant *C) { Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts()); if (!CalledFunction) - return TotalCost; + return 0; // Get TTI for the called function (used for the inline cost). auto &CalleeTTI = (GetTTI)(*CalledFunction); @@ -892,7 +995,7 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, // calls to be promoted to direct calls. If the indirect call promotion // would likely enable the called function to be inlined, specializing is a // good idea. - int Bonus = 0; + int InliningBonus = 0; for (User *U : A->users()) { if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) continue; @@ -919,15 +1022,15 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, // We clamp the bonus for this call to be between zero and the default // threshold. if (IC.isAlways()) - Bonus += Params.DefaultThreshold; + InliningBonus += Params.DefaultThreshold; else if (IC.isVariable() && IC.getCostDelta() > 0) - Bonus += IC.getCostDelta(); + InliningBonus += IC.getCostDelta(); - LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus + LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << InliningBonus << " for user " << *U << "\n"); } - return TotalCost + Bonus; + return InliningBonus > 0 ? 
static_cast<unsigned>(InliningBonus) : 0; } /// Determine if it is possible to specialise the function for constant values diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 1ccc523ead8a..951372adcfa9 100644 --- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -17,7 +17,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator_range.h" @@ -390,7 +389,7 @@ static bool collectSRATypes(DenseMap<uint64_t, GlobalPart> &Parts, } // Scalable types not currently supported. - if (isa<ScalableVectorType>(Ty)) + if (Ty->isScalableTy()) return false; auto IsStored = [](Value *V, Constant *Initializer) { @@ -930,25 +929,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, } // Update users of the allocation to use the new global instead. - BitCastInst *TheBC = nullptr; - while (!CI->use_empty()) { - Instruction *User = cast<Instruction>(CI->user_back()); - if (BitCastInst *BCI = dyn_cast<BitCastInst>(User)) { - if (BCI->getType() == NewGV->getType()) { - BCI->replaceAllUsesWith(NewGV); - BCI->eraseFromParent(); - } else { - BCI->setOperand(0, NewGV); - } - } else { - if (!TheBC) - TheBC = new BitCastInst(NewGV, CI->getType(), "newgv", CI); - User->replaceUsesOfWith(CI, TheBC); - } - } - - SmallSetVector<Constant *, 1> RepValues; - RepValues.insert(NewGV); + CI->replaceAllUsesWith(NewGV); // If there is a comparison against null, we will insert a global bool to // keep track of whether the global was initialized yet or not. 
@@ -980,9 +961,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, Use &LoadUse = *LI->use_begin(); ICmpInst *ICI = dyn_cast<ICmpInst>(LoadUse.getUser()); if (!ICI) { - auto *CE = ConstantExpr::getBitCast(NewGV, LI->getType()); - RepValues.insert(CE); - LoadUse.set(CE); + LoadUse.set(NewGV); continue; } @@ -1028,8 +1007,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, // To further other optimizations, loop over all users of NewGV and try to // constant prop them. This will promote GEP instructions with constant // indices into GEP constant-exprs, which will allow global-opt to hack on it. - for (auto *CE : RepValues) - ConstantPropUsersOf(CE, DL, TLI); + ConstantPropUsersOf(NewGV, DL, TLI); return NewGV; } @@ -1474,7 +1452,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, if (!GS.HasMultipleAccessingFunctions && GS.AccessingFunction && GV->getValueType()->isSingleValueType() && - GV->getType()->getAddressSpace() == 0 && + GV->getType()->getAddressSpace() == DL.getAllocaAddrSpace() && !GV->isExternallyInitialized() && GS.AccessingFunction->doesNotRecurse() && isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV, @@ -1584,7 +1562,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, GV->getAddressSpace()); NGV->takeName(GV); NGV->copyAttributesFrom(GV); - GV->replaceAllUsesWith(ConstantExpr::getBitCast(NGV, GV->getType())); + GV->replaceAllUsesWith(NGV); GV->eraseFromParent(); GV = NGV; } @@ -1635,7 +1613,7 @@ processGlobal(GlobalValue &GV, function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree) { - if (GV.getName().startswith("llvm.")) + if (GV.getName().starts_with("llvm.")) return false; GlobalStatus GS; @@ -1701,13 +1679,16 @@ static void RemoveAttribute(Function *F, Attribute::AttrKind A) { /// idea here is that we don't want to mess with the convention if the 
user /// explicitly requested something with performance implications like coldcc, /// GHC, or anyregcc. -static bool hasChangeableCC(Function *F) { +static bool hasChangeableCCImpl(Function *F) { CallingConv::ID CC = F->getCallingConv(); // FIXME: Is it worth transforming x86_stdcallcc and x86_fastcallcc? if (CC != CallingConv::C && CC != CallingConv::X86_ThisCall) return false; + if (F->isVarArg()) + return false; + // FIXME: Change CC for the whole chain of musttail calls when possible. // // Can't change CC of the function that either has musttail calls, or is a @@ -1727,7 +1708,16 @@ static bool hasChangeableCC(Function *F) { if (BB.getTerminatingMustTailCall()) return false; - return true; + return !F->hasAddressTaken(); +} + +using ChangeableCCCacheTy = SmallDenseMap<Function *, bool, 8>; +static bool hasChangeableCC(Function *F, + ChangeableCCCacheTy &ChangeableCCCache) { + auto Res = ChangeableCCCache.try_emplace(F, false); + if (Res.second) + Res.first->second = hasChangeableCCImpl(F); + return Res.first->second; } /// Return true if the block containing the call site has a BlockFrequency of @@ -1781,7 +1771,8 @@ static void changeCallSitesToColdCC(Function *F) { // coldcc calling convention. static bool hasOnlyColdCalls(Function &F, - function_ref<BlockFrequencyInfo &(Function &)> GetBFI) { + function_ref<BlockFrequencyInfo &(Function &)> GetBFI, + ChangeableCCCacheTy &ChangeableCCCache) { for (BasicBlock &BB : F) { for (Instruction &I : BB) { if (CallInst *CI = dyn_cast<CallInst>(&I)) { @@ -1800,8 +1791,7 @@ hasOnlyColdCalls(Function &F, if (!CalledFn->hasLocalLinkage()) return false; // Check if it's valid to use coldcc calling convention. 
- if (!hasChangeableCC(CalledFn) || CalledFn->isVarArg() || - CalledFn->hasAddressTaken()) + if (!hasChangeableCC(CalledFn, ChangeableCCCache)) return false; BlockFrequencyInfo &CallerBFI = GetBFI(F); if (!isColdCallSite(*CI, CallerBFI)) @@ -1873,12 +1863,9 @@ static void RemovePreallocated(Function *F) { CB->eraseFromParent(); Builder.SetInsertPoint(PreallocatedSetup); - auto *StackSave = - Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::stacksave)); - + auto *StackSave = Builder.CreateStackSave(); Builder.SetInsertPoint(NewCB->getNextNonDebugInstruction()); - Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::stackrestore), - StackSave); + Builder.CreateStackRestore(StackSave); // Replace @llvm.call.preallocated.arg() with alloca. // Cannot modify users() while iterating over it, so make a copy. @@ -1905,10 +1892,8 @@ static void RemovePreallocated(Function *F) { Builder.SetInsertPoint(InsertBefore); auto *Alloca = Builder.CreateAlloca(ArgType, AddressSpace, nullptr, "paarg"); - auto *BitCast = Builder.CreateBitCast( - Alloca, Type::getInt8PtrTy(M->getContext()), UseCall->getName()); - ArgAllocas[AllocArgIndex] = BitCast; - AllocaReplacement = BitCast; + ArgAllocas[AllocArgIndex] = Alloca; + AllocaReplacement = Alloca; } UseCall->replaceAllUsesWith(AllocaReplacement); @@ -1931,9 +1916,10 @@ OptimizeFunctions(Module &M, bool Changed = false; + ChangeableCCCacheTy ChangeableCCCache; std::vector<Function *> AllCallsCold; for (Function &F : llvm::make_early_inc_range(M)) - if (hasOnlyColdCalls(F, GetBFI)) + if (hasOnlyColdCalls(F, GetBFI, ChangeableCCCache)) AllCallsCold.push_back(&F); // Optimize functions. 
@@ -1995,7 +1981,7 @@ OptimizeFunctions(Module &M, continue; } - if (hasChangeableCC(&F) && !F.isVarArg() && !F.hasAddressTaken()) { + if (hasChangeableCC(&F, ChangeableCCCache)) { NumInternalFunc++; TargetTransformInfo &TTI = GetTTI(F); // Change the calling convention to coldcc if either stress testing is @@ -2005,6 +1991,7 @@ OptimizeFunctions(Module &M, if (EnableColdCCStressTest || (TTI.useColdCCForColdCall(F) && isValidCandidateForColdCC(F, GetBFI, AllCallsCold))) { + ChangeableCCCache.erase(&F); F.setCallingConv(CallingConv::Cold); changeCallSitesToColdCC(&F); Changed = true; @@ -2012,7 +1999,7 @@ OptimizeFunctions(Module &M, } } - if (hasChangeableCC(&F) && !F.isVarArg() && !F.hasAddressTaken()) { + if (hasChangeableCC(&F, ChangeableCCCache)) { // If this function has a calling convention worth changing, is not a // varargs function, and is only called directly, promote it to use the // Fast calling convention. @@ -2117,19 +2104,18 @@ static void setUsedInitializer(GlobalVariable &V, const auto *VEPT = cast<PointerType>(VAT->getArrayElementType()); // Type of pointer to the array of pointers. - PointerType *Int8PtrTy = - Type::getInt8PtrTy(V.getContext(), VEPT->getAddressSpace()); + PointerType *PtrTy = + PointerType::get(V.getContext(), VEPT->getAddressSpace()); SmallVector<Constant *, 8> UsedArray; for (GlobalValue *GV : Init) { - Constant *Cast = - ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy); + Constant *Cast = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, PtrTy); UsedArray.push_back(Cast); } // Sort to get deterministic order. 
array_pod_sort(UsedArray.begin(), UsedArray.end(), compareNames); - ArrayType *ATy = ArrayType::get(Int8PtrTy, UsedArray.size()); + ArrayType *ATy = ArrayType::get(PtrTy, UsedArray.size()); Module *M = V.getParent(); V.removeFromParent(); @@ -2299,7 +2285,7 @@ OptimizeGlobalAliases(Module &M, if (!hasUsesToReplace(J, Used, RenameTarget)) continue; - J.replaceAllUsesWith(ConstantExpr::getBitCast(Aliasee, J.getType())); + J.replaceAllUsesWith(Aliasee); ++NumAliasesResolved; Changed = true; diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp index 599ace9ca79f..fabb3c5fb921 100644 --- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -44,6 +44,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/Support/CommandLine.h" @@ -86,6 +87,11 @@ static cl::opt<int> MaxParametersForSplit( "hotcoldsplit-max-params", cl::init(4), cl::Hidden, cl::desc("Maximum number of parameters for a split function")); +static cl::opt<int> ColdBranchProbDenom( + "hotcoldsplit-cold-probability-denom", cl::init(100), cl::Hidden, + cl::desc("Divisor of cold branch probability." + "BranchProbability = 1/ColdBranchProbDenom")); + namespace { // Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify // this function unless you modify the MBB version as well. @@ -102,6 +108,32 @@ bool blockEndsInUnreachable(const BasicBlock &BB) { return !(isa<ReturnInst>(I) || isa<IndirectBrInst>(I)); } +void analyzeProfMetadata(BasicBlock *BB, + BranchProbability ColdProbThresh, + SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks) { + // TODO: Handle branches with > 2 successors. 
+ BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); + if (!CondBr) + return; + + uint64_t TrueWt, FalseWt; + if (!extractBranchWeights(*CondBr, TrueWt, FalseWt)) + return; + + auto SumWt = TrueWt + FalseWt; + if (SumWt == 0) + return; + + auto TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt); + auto FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt); + + if (TrueProb <= ColdProbThresh) + AnnotatedColdBlocks.insert(CondBr->getSuccessor(0)); + + if (FalseProb <= ColdProbThresh) + AnnotatedColdBlocks.insert(CondBr->getSuccessor(1)); +} + bool unlikelyExecuted(BasicBlock &BB) { // Exception handling blocks are unlikely executed. if (BB.isEHPad() || isa<ResumeInst>(BB.getTerminator())) @@ -183,6 +215,34 @@ bool HotColdSplitting::isFunctionCold(const Function &F) const { return false; } +bool HotColdSplitting::isBasicBlockCold(BasicBlock *BB, + BranchProbability ColdProbThresh, + SmallPtrSetImpl<BasicBlock *> &ColdBlocks, + SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks, + BlockFrequencyInfo *BFI) const { + // This block is already part of some outlining region. + if (ColdBlocks.count(BB)) + return true; + + if (BFI) { + if (PSI->isColdBlock(BB, BFI)) + return true; + } else { + // Find cold blocks of successors of BB during a reverse postorder traversal. + analyzeProfMetadata(BB, ColdProbThresh, AnnotatedColdBlocks); + + // A statically cold BB would be known before it is visited + // because the prof-data of incoming edges are 'analyzed' as part of RPOT. + if (AnnotatedColdBlocks.count(BB)) + return true; + } + + if (EnableStaticAnalysis && unlikelyExecuted(*BB)) + return true; + + return false; +} + // Returns false if the function should not be considered for hot-cold split // optimization. bool HotColdSplitting::shouldOutlineFrom(const Function &F) const { @@ -565,6 +625,9 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { // The set of cold blocks. 
SmallPtrSet<BasicBlock *, 4> ColdBlocks; + // Set of cold blocks obtained with RPOT. + SmallPtrSet<BasicBlock *, 4> AnnotatedColdBlocks; + // The worklist of non-intersecting regions left to outline. SmallVector<OutliningRegion, 2> OutliningWorklist; @@ -587,16 +650,15 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { TargetTransformInfo &TTI = GetTTI(F); OptimizationRemarkEmitter &ORE = (*GetORE)(F); AssumptionCache *AC = LookupAC(F); + auto ColdProbThresh = TTI.getPredictableBranchThreshold().getCompl(); + + if (ColdBranchProbDenom.getNumOccurrences()) + ColdProbThresh = BranchProbability(1, ColdBranchProbDenom.getValue()); // Find all cold regions. for (BasicBlock *BB : RPOT) { - // This block is already part of some outlining region. - if (ColdBlocks.count(BB)) - continue; - - bool Cold = (BFI && PSI->isColdBlock(BB, BFI)) || - (EnableStaticAnalysis && unlikelyExecuted(*BB)); - if (!Cold) + if (!isBasicBlockCold(BB, ColdProbThresh, ColdBlocks, AnnotatedColdBlocks, + BFI)) continue; LLVM_DEBUG({ diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp index e258299c6a4c..a6e19df7c5f1 100644 --- a/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -155,7 +155,7 @@ struct OutlinableGroup { /// \param TargetBB - the BasicBlock to put Instruction into. 
static void moveBBContents(BasicBlock &SourceBB, BasicBlock &TargetBB) { for (Instruction &I : llvm::make_early_inc_range(SourceBB)) - I.moveBefore(TargetBB, TargetBB.end()); + I.moveBeforePreserving(TargetBB, TargetBB.end()); } /// A function to sort the keys of \p Map, which must be a mapping of constant @@ -198,7 +198,7 @@ Value *OutlinableRegion::findCorrespondingValueIn(const OutlinableRegion &Other, BasicBlock * OutlinableRegion::findCorrespondingBlockIn(const OutlinableRegion &Other, BasicBlock *BB) { - Instruction *FirstNonPHI = BB->getFirstNonPHI(); + Instruction *FirstNonPHI = BB->getFirstNonPHIOrDbg(); assert(FirstNonPHI && "block is empty?"); Value *CorrespondingVal = findCorrespondingValueIn(Other, FirstNonPHI); if (!CorrespondingVal) @@ -557,7 +557,7 @@ collectRegionsConstants(OutlinableRegion &Region, // Iterate over the operands in an instruction. If the global value number, // assigned by the IRSimilarityCandidate, has been seen before, we check if - // the the number has been found to be not the same value in each instance. + // the number has been found to be not the same value in each instance. for (Value *V : ID.OperVals) { std::optional<unsigned> GVNOpt = C.getGVN(V); assert(GVNOpt && "Expected a GVN for operand?"); @@ -766,7 +766,7 @@ static void moveFunctionData(Function &Old, Function &New, } } -/// Find the the constants that will need to be lifted into arguments +/// Find the constants that will need to be lifted into arguments /// as they are not the same in each instance of the region. /// /// \param [in] C - The IRSimilarityCandidate containing the region we are @@ -1346,7 +1346,7 @@ findExtractedOutputToOverallOutputMapping(Module &M, OutlinableRegion &Region, // the output, so we add a pointer type to the argument types of the overall // function to handle this output and create a mapping to it. 
if (!TypeFound) { - Group.ArgumentTypes.push_back(Output->getType()->getPointerTo( + Group.ArgumentTypes.push_back(PointerType::get(Output->getContext(), M.getDataLayout().getAllocaAddrSpace())); // Mark the new pointer type as the last value in the aggregate argument // list. diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp index 3e00aebce372..a9747aebf67b 100644 --- a/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/llvm/lib/Transforms/IPO/Inliner.cpp @@ -13,7 +13,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/Inliner.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" @@ -63,7 +62,6 @@ #include <cassert> #include <functional> #include <utility> -#include <vector> using namespace llvm; diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 9b4b3efd7283..733f290b1bc9 100644 --- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -381,8 +381,7 @@ struct ScopedSaveAliaseesAndUsed { appendToCompilerUsed(M, CompilerUsed); for (auto P : FunctionAliases) - P.first->setAliasee( - ConstantExpr::getBitCast(P.second, P.first->getType())); + P.first->setAliasee(P.second); for (auto P : ResolverIFuncs) { // This does not preserve pointer casts that may have been stripped by the @@ -411,16 +410,19 @@ class LowerTypeTestsModule { // selectJumpTableArmEncoding may decide to use Thumb in either case. bool CanUseArmJumpTable = false, CanUseThumbBWJumpTable = false; + // Cache variable used by hasBranchTargetEnforcement(). + int HasBranchTargetEnforcement = -1; + // The jump table type we ended up deciding on. (Usually the same as // Arch, except that 'arm' and 'thumb' are often interchangeable.) 
Triple::ArchType JumpTableArch = Triple::UnknownArch; IntegerType *Int1Ty = Type::getInt1Ty(M.getContext()); IntegerType *Int8Ty = Type::getInt8Ty(M.getContext()); - PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext()); + PointerType *Int8PtrTy = PointerType::getUnqual(M.getContext()); ArrayType *Int8Arr0Ty = ArrayType::get(Type::getInt8Ty(M.getContext()), 0); IntegerType *Int32Ty = Type::getInt32Ty(M.getContext()); - PointerType *Int32PtrTy = PointerType::getUnqual(Int32Ty); + PointerType *Int32PtrTy = PointerType::getUnqual(M.getContext()); IntegerType *Int64Ty = Type::getInt64Ty(M.getContext()); IntegerType *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext(), 0); @@ -492,6 +494,7 @@ class LowerTypeTestsModule { ArrayRef<GlobalTypeMember *> Globals); Triple::ArchType selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions); + bool hasBranchTargetEnforcement(); unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS, @@ -755,9 +758,9 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI, // also conveniently gives us a bit offset to use during the load from // the bitset. 
Value *OffsetSHR = - B.CreateLShr(PtrOffset, ConstantExpr::getZExt(TIL.AlignLog2, IntPtrTy)); + B.CreateLShr(PtrOffset, B.CreateZExt(TIL.AlignLog2, IntPtrTy)); Value *OffsetSHL = B.CreateShl( - PtrOffset, ConstantExpr::getZExt( + PtrOffset, B.CreateZExt( ConstantExpr::getSub( ConstantInt::get(Int8Ty, DL.getPointerSizeInBits(0)), TIL.AlignLog2), @@ -962,7 +965,6 @@ LowerTypeTestsModule::importTypeId(StringRef TypeId) { Int8Arr0Ty); if (auto *GV = dyn_cast<GlobalVariable>(C)) GV->setVisibility(GlobalValue::HiddenVisibility); - C = ConstantExpr::getBitCast(C, Int8PtrTy); return C; }; @@ -1100,15 +1102,13 @@ void LowerTypeTestsModule::importFunction( replaceCfiUses(F, FDecl, isJumpTableCanonical); // Set visibility late because it's used in replaceCfiUses() to determine - // whether uses need to to be replaced. + // whether uses need to be replaced. F->setVisibility(Visibility); } void LowerTypeTestsModule::lowerTypeTestCalls( ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) { - CombinedGlobalAddr = ConstantExpr::getBitCast(CombinedGlobalAddr, Int8PtrTy); - // For each type identifier in this disjoint set... for (Metadata *TypeId : TypeIds) { // Build the bitset. @@ -1196,6 +1196,20 @@ static const unsigned kARMJumpTableEntrySize = 4; static const unsigned kARMBTIJumpTableEntrySize = 8; static const unsigned kARMv6MJumpTableEntrySize = 16; static const unsigned kRISCVJumpTableEntrySize = 8; +static const unsigned kLOONGARCH64JumpTableEntrySize = 8; + +bool LowerTypeTestsModule::hasBranchTargetEnforcement() { + if (HasBranchTargetEnforcement == -1) { + // First time this query has been called. Find out the answer by checking + // the module flags. 
+ if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("branch-target-enforcement"))) + HasBranchTargetEnforcement = (BTE->getZExtValue() != 0); + else + HasBranchTargetEnforcement = 0; + } + return HasBranchTargetEnforcement; +} unsigned LowerTypeTestsModule::getJumpTableEntrySize() { switch (JumpTableArch) { @@ -1209,19 +1223,22 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() { case Triple::arm: return kARMJumpTableEntrySize; case Triple::thumb: - if (CanUseThumbBWJumpTable) + if (CanUseThumbBWJumpTable) { + if (hasBranchTargetEnforcement()) + return kARMBTIJumpTableEntrySize; return kARMJumpTableEntrySize; - else + } else { return kARMv6MJumpTableEntrySize; + } case Triple::aarch64: - if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( - M.getModuleFlag("branch-target-enforcement"))) - if (BTE->getZExtValue()) - return kARMBTIJumpTableEntrySize; + if (hasBranchTargetEnforcement()) + return kARMBTIJumpTableEntrySize; return kARMJumpTableEntrySize; case Triple::riscv32: case Triple::riscv64: return kRISCVJumpTableEntrySize; + case Triple::loongarch64: + return kLOONGARCH64JumpTableEntrySize; default: report_fatal_error("Unsupported architecture for jump tables"); } @@ -1251,10 +1268,8 @@ void LowerTypeTestsModule::createJumpTableEntry( } else if (JumpTableArch == Triple::arm) { AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::aarch64) { - if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( - Dest->getParent()->getModuleFlag("branch-target-enforcement"))) - if (BTE->getZExtValue()) - AsmOS << "bti c\n"; + if (hasBranchTargetEnforcement()) + AsmOS << "bti c\n"; AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::thumb) { if (!CanUseThumbBWJumpTable) { @@ -1281,11 +1296,16 @@ void LowerTypeTestsModule::createJumpTableEntry( << ".balign 4\n" << "1: .word $" << ArgIndex << " - (0b + 4)\n"; } else { + if (hasBranchTargetEnforcement()) + AsmOS << "bti\n"; AsmOS << "b.w $" << 
ArgIndex << "\n"; } } else if (JumpTableArch == Triple::riscv32 || JumpTableArch == Triple::riscv64) { AsmOS << "tail $" << ArgIndex << "@plt\n"; + } else if (JumpTableArch == Triple::loongarch64) { + AsmOS << "pcalau12i $$t0, %pc_hi20($" << ArgIndex << ")\n" + << "jirl $$r0, $$t0, %pc_lo12($" << ArgIndex << ")\n"; } else { report_fatal_error("Unsupported architecture for jump tables"); } @@ -1304,7 +1324,8 @@ void LowerTypeTestsModule::buildBitSetsFromFunctions( ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) { if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm || Arch == Triple::thumb || Arch == Triple::aarch64 || - Arch == Triple::riscv32 || Arch == Triple::riscv64) + Arch == Triple::riscv32 || Arch == Triple::riscv64 || + Arch == Triple::loongarch64) buildBitSetsFromFunctionsNative(TypeIds, Functions); else if (Arch == Triple::wasm32 || Arch == Triple::wasm64) buildBitSetsFromFunctionsWASM(TypeIds, Functions); @@ -1446,9 +1467,19 @@ void LowerTypeTestsModule::createJumpTable( SmallVector<Value *, 16> AsmArgs; AsmArgs.reserve(Functions.size() * 2); - for (GlobalTypeMember *GTM : Functions) + // Check if all entries have the NoUnwind attribute. + // If all entries have it, we can safely mark the + // cfi.jumptable as NoUnwind, otherwise, direct calls + // to the jump table will not handle exceptions properly + bool areAllEntriesNounwind = true; + for (GlobalTypeMember *GTM : Functions) { + if (!llvm::cast<llvm::Function>(GTM->getGlobal()) + ->hasFnAttribute(llvm::Attribute::NoUnwind)) { + areAllEntriesNounwind = false; + } createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs, cast<Function>(GTM->getGlobal())); + } // Align the whole table by entry size. 
F->setAlignment(Align(getJumpTableEntrySize())); @@ -1461,17 +1492,23 @@ void LowerTypeTestsModule::createJumpTable( if (JumpTableArch == Triple::arm) F->addFnAttr("target-features", "-thumb-mode"); if (JumpTableArch == Triple::thumb) { - F->addFnAttr("target-features", "+thumb-mode"); - if (CanUseThumbBWJumpTable) { - // Thumb jump table assembly needs Thumb2. The following attribute is - // added by Clang for -march=armv7. - F->addFnAttr("target-cpu", "cortex-a8"); + if (hasBranchTargetEnforcement()) { + // If we're generating a Thumb jump table with BTI, add a target-features + // setting to ensure BTI can be assembled. + F->addFnAttr("target-features", "+thumb-mode,+pacbti"); + } else { + F->addFnAttr("target-features", "+thumb-mode"); + if (CanUseThumbBWJumpTable) { + // Thumb jump table assembly needs Thumb2. The following attribute is + // added by Clang for -march=armv7. + F->addFnAttr("target-cpu", "cortex-a8"); + } } } // When -mbranch-protection= is used, the inline asm adds a BTI. Suppress BTI // for the function to avoid double BTI. This is a no-op without // -mbranch-protection=. - if (JumpTableArch == Triple::aarch64) { + if (JumpTableArch == Triple::aarch64 || JumpTableArch == Triple::thumb) { F->addFnAttr("branch-target-enforcement", "false"); F->addFnAttr("sign-return-address", "none"); } @@ -1485,8 +1522,13 @@ void LowerTypeTestsModule::createJumpTable( // -fcf-protection=. if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) F->addFnAttr(Attribute::NoCfCheck); - // Make sure we don't emit .eh_frame for this function. - F->addFnAttr(Attribute::NoUnwind); + + // Make sure we don't emit .eh_frame for this function if it isn't needed. + if (areAllEntriesNounwind) + F->addFnAttr(Attribute::NoUnwind); + + // Make sure we do not inline any calls to the cfi.jumptable. 
+ F->addFnAttr(Attribute::NoInline); BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F); IRBuilder<> IRB(BB); @@ -1618,12 +1660,10 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( Function *F = cast<Function>(Functions[I]->getGlobal()); bool IsJumpTableCanonical = Functions[I]->isJumpTableCanonical(); - Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast( - ConstantExpr::getInBoundsGetElementPtr( - JumpTableType, JumpTable, - ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), - ConstantInt::get(IntPtrTy, I)}), - F->getType()); + Constant *CombinedGlobalElemPtr = ConstantExpr::getInBoundsGetElementPtr( + JumpTableType, JumpTable, + ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), + ConstantInt::get(IntPtrTy, I)}); const bool IsExported = Functions[I]->isExported(); if (!IsJumpTableCanonical) { diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp index f835fb26fcb8..70a3f3067d9d 100644 --- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp +++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp @@ -104,11 +104,13 @@ static cl::opt<std::string> MemProfImportSummary( cl::desc("Import summary to use for testing the ThinLTO backend via opt"), cl::Hidden); +namespace llvm { // Indicate we are linking with an allocator that supports hot/cold operator // new interfaces. cl::opt<bool> SupportsHotColdNew( "supports-hot-cold-new", cl::init(false), cl::Hidden, cl::desc("Linking with hot/cold operator new interfaces")); +} // namespace llvm namespace { /// CRTP base for graphs built from either IR or ThinLTO summary index. 
@@ -791,11 +793,10 @@ CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: eraseCalleeEdge(const ContextEdge *Edge) { - auto EI = - std::find_if(CalleeEdges.begin(), CalleeEdges.end(), - [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) { - return CalleeEdge.get() == Edge; - }); + auto EI = llvm::find_if( + CalleeEdges, [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) { + return CalleeEdge.get() == Edge; + }); assert(EI != CalleeEdges.end()); CalleeEdges.erase(EI); } @@ -803,11 +804,10 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: eraseCallerEdge(const ContextEdge *Edge) { - auto EI = - std::find_if(CallerEdges.begin(), CallerEdges.end(), - [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) { - return CallerEdge.get() == Edge; - }); + auto EI = llvm::find_if( + CallerEdges, [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) { + return CallerEdge.get() == Edge; + }); assert(EI != CallerEdges.end()); CallerEdges.erase(EI); } @@ -2093,8 +2093,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( for (auto &Edge : CallerEdges) { // Skip any that have been removed by an earlier recursive call. if (Edge->Callee == nullptr && Edge->Caller == nullptr) { - assert(!std::count(Node->CallerEdges.begin(), Node->CallerEdges.end(), - Edge)); + assert(!llvm::count(Node->CallerEdges, Edge)); continue; } // Ignore any caller we previously visited via another edge. 
@@ -2985,6 +2984,21 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { if (!mayHaveMemprofSummary(CB)) continue; + auto *CalledValue = CB->getCalledOperand(); + auto *CalledFunction = CB->getCalledFunction(); + if (CalledValue && !CalledFunction) { + CalledValue = CalledValue->stripPointerCasts(); + // Stripping pointer casts can reveal a called function. + CalledFunction = dyn_cast<Function>(CalledValue); + } + // Check if this is an alias to a function. If so, get the + // called aliasee for the checks below. + if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) { + assert(!CalledFunction && + "Expected null called function in callsite for alias"); + CalledFunction = dyn_cast<Function>(GA->getAliaseeObject()); + } + CallStack<MDNode, MDNode::op_iterator> CallsiteContext( I.getMetadata(LLVMContext::MD_callsite)); auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof); @@ -3116,13 +3130,13 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { CloneFuncIfNeeded(/*NumClones=*/StackNode.Clones.size()); // Should have skipped indirect calls via mayHaveMemprofSummary. - assert(CB->getCalledFunction()); - assert(!IsMemProfClone(*CB->getCalledFunction())); + assert(CalledFunction); + assert(!IsMemProfClone(*CalledFunction)); // Update the calls per the summary info. // Save orig name since it gets updated in the first iteration // below. - auto CalleeOrigName = CB->getCalledFunction()->getName(); + auto CalleeOrigName = CalledFunction->getName(); for (unsigned J = 0; J < StackNode.Clones.size(); J++) { // Do nothing if this version calls the original version of its // callee. @@ -3130,7 +3144,7 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { continue; auto NewF = M.getOrInsertFunction( getMemProfFuncName(CalleeOrigName, StackNode.Clones[J]), - CB->getCalledFunction()->getFunctionType()); + CalledFunction->getFunctionType()); CallBase *CBClone; // Copy 0 is the original function. 
if (!J) diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp index feda5d6459cb..c8c011d94e4a 100644 --- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -107,6 +107,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/StructuralHash.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -171,15 +172,14 @@ namespace { class FunctionNode { mutable AssertingVH<Function> F; - FunctionComparator::FunctionHash Hash; + IRHash Hash; public: // Note the hash is recalculated potentially multiple times, but it is cheap. - FunctionNode(Function *F) - : F(F), Hash(FunctionComparator::functionHash(*F)) {} + FunctionNode(Function *F) : F(F), Hash(StructuralHash(*F)) {} Function *getFunc() const { return F; } - FunctionComparator::FunctionHash getHash() const { return Hash; } + IRHash getHash() const { return Hash; } /// Replace the reference to the function F by the function G, assuming their /// implementations are equal. @@ -375,9 +375,32 @@ bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) { } #endif +/// Check whether \p F has an intrinsic which references +/// distinct metadata as an operand. The most common +/// instance of this would be CFI checks for function-local types. +static bool hasDistinctMetadataIntrinsic(const Function &F) { + for (const BasicBlock &BB : F) { + for (const Instruction &I : BB.instructionsWithoutDebug()) { + if (!isa<IntrinsicInst>(&I)) + continue; + + for (Value *Op : I.operands()) { + auto *MDL = dyn_cast<MetadataAsValue>(Op); + if (!MDL) + continue; + if (MDNode *N = dyn_cast<MDNode>(MDL->getMetadata())) + if (N->isDistinct()) + return true; + } + } + } + return false; +} + /// Check whether \p F is eligible for function merging. 
static bool isEligibleForMerging(Function &F) { - return !F.isDeclaration() && !F.hasAvailableExternallyLinkage(); + return !F.isDeclaration() && !F.hasAvailableExternallyLinkage() && + !hasDistinctMetadataIntrinsic(F); } bool MergeFunctions::runOnModule(Module &M) { @@ -390,11 +413,10 @@ bool MergeFunctions::runOnModule(Module &M) { // All functions in the module, ordered by hash. Functions with a unique // hash value are easily eliminated. - std::vector<std::pair<FunctionComparator::FunctionHash, Function *>> - HashedFuncs; + std::vector<std::pair<IRHash, Function *>> HashedFuncs; for (Function &Func : M) { if (isEligibleForMerging(Func)) { - HashedFuncs.push_back({FunctionComparator::functionHash(Func), &Func}); + HashedFuncs.push_back({StructuralHash(Func), &Func}); } } @@ -441,7 +463,6 @@ bool MergeFunctions::runOnModule(Module &M) { // Replace direct callers of Old with New. void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { - Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType()); for (Use &U : llvm::make_early_inc_range(Old->uses())) { CallBase *CB = dyn_cast<CallBase>(U.getUser()); if (CB && CB->isCallee(&U)) { @@ -450,7 +471,7 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { // type congruences in byval(), in which case we need to keep the byval // type of the call-site, not the callee function. remove(CB->getFunction()); - U.set(BitcastNew); + U.set(New); } } } @@ -632,7 +653,7 @@ static bool canCreateThunkFor(Function *F) { // Don't merge tiny functions using a thunk, since it can just end up // making the function larger. if (F->size() == 1) { - if (F->front().size() <= 2) { + if (F->front().sizeWithoutDebug() < 2) { LLVM_DEBUG(dbgs() << "canCreateThunkFor: " << F->getName() << " is too small to bother creating a thunk for\n"); return false; @@ -641,6 +662,13 @@ static bool canCreateThunkFor(Function *F) { return true; } +/// Copy metadata from one function to another. 
+static void copyMetadataIfPresent(Function *From, Function *To, StringRef Key) { + if (MDNode *MD = From->getMetadata(Key)) { + To->setMetadata(Key, MD); + } +} + // Replace G with a simple tail call to bitcast(F). Also (unless // MergeFunctionsPDI holds) replace direct uses of G with bitcast(F), // delete G. Under MergeFunctionsPDI, we use G itself for creating @@ -719,6 +747,9 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { } else { NewG->copyAttributesFrom(G); NewG->takeName(G); + // Ensure CFI type metadata is propagated to the new function. + copyMetadataIfPresent(G, NewG, "type"); + copyMetadataIfPresent(G, NewG, "kcfi_type"); removeUsers(G); G->replaceAllUsesWith(NewG); G->eraseFromParent(); @@ -741,10 +772,9 @@ static bool canCreateAliasFor(Function *F) { // Replace G with an alias to F (deleting function G) void MergeFunctions::writeAlias(Function *F, Function *G) { - Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); PointerType *PtrType = G->getType(); auto *GA = GlobalAlias::create(G->getValueType(), PtrType->getAddressSpace(), - G->getLinkage(), "", BitcastF, G->getParent()); + G->getLinkage(), "", F, G->getParent()); const MaybeAlign FAlign = F->getAlign(); const MaybeAlign GAlign = G->getAlign(); @@ -795,6 +825,9 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { F->getAddressSpace(), "", F->getParent()); NewF->copyAttributesFrom(F); NewF->takeName(F); + // Ensure CFI type metadata is propagated to the new function. + copyMetadataIfPresent(F, NewF, "type"); + copyMetadataIfPresent(F, NewF, "kcfi_type"); removeUsers(F); F->replaceAllUsesWith(NewF); @@ -825,9 +858,8 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { // to replace a key in ValueMap<GlobalValue *> with a non-global. GlobalNumbers.erase(G); // If G's address is not significant, replace it entirely. 
- Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); removeUsers(G); - G->replaceAllUsesWith(BitcastF); + G->replaceAllUsesWith(F); } else { // Redirect direct callers of G to F. (See note on MergeFunctionsPDI // above). diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 588f3901e3cb..b2665161c090 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/Assumptions.h" #include "llvm/IR/BasicBlock.h" @@ -42,6 +43,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -156,6 +158,8 @@ STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified, "Number of OpenMP runtime function uses identified"); STATISTIC(NumOpenMPTargetRegionKernels, "Number of OpenMP target region entry points (=kernels) identified"); +STATISTIC(NumNonOpenMPTargetRegionKernels, + "Number of non-OpenMP target region kernels identified"); STATISTIC(NumOpenMPTargetRegionKernelsSPMD, "Number of OpenMP target region entry points (=kernels) executed in " "SPMD-mode instead of generic-mode"); @@ -181,6 +185,92 @@ STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated"); static constexpr auto TAG = "[" DEBUG_TYPE "]"; #endif +namespace KernelInfo { + +// struct ConfigurationEnvironmentTy { +// uint8_t UseGenericStateMachine; +// uint8_t MayUseNestedParallelism; +// llvm::omp::OMPTgtExecModeFlags ExecMode; +// int32_t MinThreads; +// int32_t MaxThreads; +// int32_t MinTeams; +// int32_t MaxTeams; +// }; + +// struct DynamicEnvironmentTy { +// 
uint16_t DebugIndentionLevel; +// }; + +// struct KernelEnvironmentTy { +// ConfigurationEnvironmentTy Configuration; +// IdentTy *Ident; +// DynamicEnvironmentTy *DynamicEnv; +// }; + +#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \ + constexpr const unsigned MEMBER##Idx = IDX; + +KERNEL_ENVIRONMENT_IDX(Configuration, 0) +KERNEL_ENVIRONMENT_IDX(Ident, 1) + +#undef KERNEL_ENVIRONMENT_IDX + +#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \ + constexpr const unsigned MEMBER##Idx = IDX; + +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX + +#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \ + RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \ + return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \ + } + +KERNEL_ENVIRONMENT_GETTER(Ident, Constant) +KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct) + +#undef KERNEL_ENVIRONMENT_GETTER + +#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \ + ConstantInt *get##MEMBER##FromKernelEnvironment( \ + ConstantStruct *KernelEnvC) { \ + ConstantStruct *ConfigC = \ + getConfigurationFromKernelEnvironment(KernelEnvC); \ + return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \ + } + +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams) 
+KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER + +GlobalVariable * +getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) { + constexpr const int InitKernelEnvironmentArgNo = 0; + return cast<GlobalVariable>( + KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo) + ->stripPointerCasts()); +} + +ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) { + GlobalVariable *KernelEnvGV = + getKernelEnvironementGVFromKernelInitCB(KernelInitCB); + return cast<ConstantStruct>(KernelEnvGV->getInitializer()); +} +} // namespace KernelInfo + namespace { struct AAHeapToShared; @@ -196,6 +286,7 @@ struct OMPInformationCache : public InformationCache { : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M), OpenMPPostLink(OpenMPPostLink) { + OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M); OMPBuilder.initialize(); initializeRuntimeFunctions(M); initializeInternalControlVars(); @@ -531,7 +622,7 @@ struct OMPInformationCache : public InformationCache { for (Function &F : M) { for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"}) if (F.hasFnAttribute(Attribute::NoInline) && - F.getName().startswith(Prefix) && + F.getName().starts_with(Prefix) && !F.hasFnAttribute(Attribute::OptimizeNone)) F.removeFnAttr(Attribute::NoInline); } @@ -595,7 +686,7 @@ struct KernelInfoState : AbstractState { /// The parallel regions (identified by the outlined parallel functions) that /// can be reached from the associated function. - BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false> + BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false> ReachedKnownParallelRegions; /// State to track what parallel region we might reach. @@ -610,6 +701,10 @@ struct KernelInfoState : AbstractState { /// one we abort as the kernel is malformed. CallBase *KernelInitCB = nullptr; + /// The constant kernel environement as taken from and passed to + /// __kmpc_target_init. 
+ ConstantStruct *KernelEnvC = nullptr; + /// The __kmpc_target_deinit call in this kernel, if any. If we find more than /// one we abort as the kernel is malformed. CallBase *KernelDeinitCB = nullptr; @@ -651,6 +746,7 @@ struct KernelInfoState : AbstractState { SPMDCompatibilityTracker.indicatePessimisticFixpoint(); ReachedKnownParallelRegions.indicatePessimisticFixpoint(); ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); + NestedParallelism = true; return ChangeStatus::CHANGED; } @@ -680,6 +776,8 @@ struct KernelInfoState : AbstractState { return false; if (ParallelLevels != RHS.ParallelLevels) return false; + if (NestedParallelism != RHS.NestedParallelism) + return false; return true; } @@ -714,6 +812,12 @@ struct KernelInfoState : AbstractState { "assumptions."); KernelDeinitCB = KIS.KernelDeinitCB; } + if (KIS.KernelEnvC) { + if (KernelEnvC && KernelEnvC != KIS.KernelEnvC) + llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt " + "assumptions."); + KernelEnvC = KIS.KernelEnvC; + } SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; @@ -875,6 +979,9 @@ struct OpenMPOpt { } } + if (OMPInfoCache.OpenMPPostLink) + Changed |= removeRuntimeSymbols(); + return Changed; } @@ -903,7 +1010,7 @@ struct OpenMPOpt { /// Print OpenMP GPU kernels for testing. void printKernels() const { for (Function *F : SCC) { - if (!omp::isKernel(*F)) + if (!omp::isOpenMPKernel(*F)) continue; auto Remark = [&](OptimizationRemarkAnalysis ORA) { @@ -1404,6 +1511,37 @@ private: return Changed; } + /// Tries to remove known runtime symbols that are optional from the module. + bool removeRuntimeSymbols() { + // The RPC client symbol is defined in `libc` and indicates that something + // required an RPC server. If its users were all optimized out then we can + // safely remove it. 
+ // TODO: This should be somewhere more common in the future. + if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) { + if (!GV->getType()->isPointerTy()) + return false; + + Constant *C = GV->getInitializer(); + if (!C) + return false; + + // Check to see if the only user of the RPC client is the external handle. + GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts()); + if (!Client || Client->getNumUses() > 1 || + Client->user_back() != GV->getInitializer()) + return false; + + Client->replaceAllUsesWith(PoisonValue::get(Client->getType())); + Client->eraseFromParent(); + + GV->replaceAllUsesWith(PoisonValue::get(GV->getType())); + GV->eraseFromParent(); + + return true; + } + return false; + } + /// Tries to hide the latency of runtime calls that involve host to /// device memory transfers by splitting them into their "issue" and "wait" /// versions. The "issue" is moved upwards as much as possible. The "wait" is @@ -1858,7 +1996,7 @@ private: Function *F = I->getParent()->getParent(); auto &ORE = OREGetter(F); - if (RemarkName.startswith("OMP")) + if (RemarkName.starts_with("OMP")) ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)) << " [" << RemarkName << "]"; @@ -1874,7 +2012,7 @@ private: RemarkCallBack &&RemarkCB) const { auto &ORE = OREGetter(F); - if (RemarkName.startswith("OMP")) + if (RemarkName.starts_with("OMP")) ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)) << " [" << RemarkName << "]"; @@ -1944,7 +2082,7 @@ Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { // TODO: We should use an AA to create an (optimistic and callback // call-aware) call graph. For now we stick to simple patterns that // are less powerful, basically the worst fixpoint. 
- if (isKernel(F)) { + if (isOpenMPKernel(F)) { CachedKernel = Kernel(&F); return *CachedKernel; } @@ -2535,6 +2673,17 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker { } }; +/// Determines if \p BB exits the function unconditionally itself or reaches a +/// block that does through only unique successors. +static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) { + if (succ_empty(BB)) + return true; + const BasicBlock *const Successor = BB->getUniqueSuccessor(); + if (!Successor) + return false; + return hasFunctionEndAsUniqueSuccessor(Successor); +} + struct AAExecutionDomainFunction : public AAExecutionDomain { AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) : AAExecutionDomain(IRP, A) {} @@ -2587,18 +2736,22 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { if (!ED.IsReachedFromAlignedBarrierOnly || ED.EncounteredNonLocalSideEffect) return; + if (!ED.EncounteredAssumes.empty() && !A.isModulePass()) + return; - // We can remove this barrier, if it is one, or all aligned barriers - // reaching the kernel end. In the latter case we can transitively work - // our way back until we find a barrier that guards a side-effect if we - // are dealing with the kernel end here. + // We can remove this barrier, if it is one, or aligned barriers reaching + // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel + // end should only be removed if the kernel end is their unique successor; + // otherwise, they may have side-effects that aren't accounted for in the + // kernel end in their other successors. If those barriers have other + // barriers reaching them, those can be transitively removed as well as + // long as the kernel end is also their unique successor. 
if (CB) { DeletedBarriers.insert(CB); A.deleteAfterManifest(*CB); ++NumBarriersEliminated; Changed = ChangeStatus::CHANGED; } else if (!ED.AlignedBarriers.empty()) { - NumBarriersEliminated += ED.AlignedBarriers.size(); Changed = ChangeStatus::CHANGED; SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(), ED.AlignedBarriers.end()); @@ -2609,7 +2762,10 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { continue; if (LastCB->getFunction() != getAnchorScope()) continue; + if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent())) + continue; if (!DeletedBarriers.count(LastCB)) { + ++NumBarriersEliminated; A.deleteAfterManifest(*LastCB); continue; } @@ -2633,7 +2789,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { HandleAlignedBarrier(CB); // Handle the "kernel end barrier" for kernels too. - if (omp::isKernel(*getAnchorScope())) + if (omp::isOpenMPKernel(*getAnchorScope())) HandleAlignedBarrier(nullptr); return Changed; @@ -2779,9 +2935,11 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr; if (!CB) return false; - const int InitModeArgNo = 1; - auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo)); - return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC); + ConstantStruct *KernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(CB); + ConstantInt *ExecModeC = + KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC); + return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC; } if (C->isZero()) { @@ -2884,11 +3042,11 @@ bool AAExecutionDomainFunction::handleCallees(Attributor &A, } else { // We could not find all predecessors, so this is either a kernel or a // function with external linkage (or with some other weird uses). 
- if (omp::isKernel(*getAnchorScope())) { + if (omp::isOpenMPKernel(*getAnchorScope())) { EntryBBED.IsExecutedByInitialThreadOnly = false; EntryBBED.IsReachedFromAlignedBarrierOnly = true; EntryBBED.EncounteredNonLocalSideEffect = false; - ExitED.IsReachingAlignedBarrierOnly = true; + ExitED.IsReachingAlignedBarrierOnly = false; } else { EntryBBED.IsExecutedByInitialThreadOnly = false; EntryBBED.IsReachedFromAlignedBarrierOnly = false; @@ -2938,7 +3096,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { Function *F = getAnchorScope(); BasicBlock &EntryBB = F->getEntryBlock(); - bool IsKernel = omp::isKernel(*F); + bool IsKernel = omp::isOpenMPKernel(*F); SmallVector<Instruction *> SyncInstWorklist; for (auto &RIt : *RPOT) { @@ -3063,7 +3221,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { if (EDAA && EDAA->getState().isValidState()) { const auto &CalleeED = EDAA->getFunctionExecutionDomain(); ED.IsReachedFromAlignedBarrierOnly = - CalleeED.IsReachedFromAlignedBarrierOnly; + CalleeED.IsReachedFromAlignedBarrierOnly; AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly; if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly) ED.EncounteredNonLocalSideEffect |= @@ -3442,6 +3600,10 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { using Base = StateWrapper<KernelInfoState, AbstractAttribute>; AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {} + /// The callee value is tracked beyond a simple stripPointerCasts, so we allow + /// unknown callees. + static bool requiresCalleeForCallBase() { return false; } + /// Statistics are tracked as part of manifest for now. void trackStatistics() const override {} @@ -3468,7 +3630,8 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { ", #ParLevels: " + (ParallelLevels.isValidState() ? std::to_string(ParallelLevels.size()) - : "<invalid>"); + : "<invalid>") + + ", NestedPar: " + (NestedParallelism ? 
"yes" : "no"); } /// Create an abstract attribute biew for the position \p IRP. @@ -3500,6 +3663,33 @@ struct AAKernelInfoFunction : AAKernelInfo { return GuardedInstructions; } + void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) { + Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction( + KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx}); + assert(NewKernelEnvC && "Failed to create new kernel environment"); + KernelEnvC = cast<ConstantStruct>(NewKernelEnvC); + } + +#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \ + void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \ + ConstantStruct *ConfigC = \ + KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \ + Constant *NewConfigC = ConstantFoldInsertValueInstruction( \ + ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \ + assert(NewConfigC && "Failed to create new configuration environment"); \ + setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \ + } + + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { // This is a high-level transform that might change the constant arguments @@ -3548,61 +3738,73 @@ struct AAKernelInfoFunction : AAKernelInfo { ReachingKernelEntries.insert(Fn); IsKernelEntry = true; - // For kernels we might need to initialize/finalize the IsSPMD state and - // we need to register a simplification callback so that the Attributor - // knows the constant arguments to __kmpc_target_init and - // __kmpc_target_deinit might actually change. 
- - Attributor::SimplifictionCallbackTy StateMachineSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> std::optional<Value *> { - return nullptr; - }; + KernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); + GlobalVariable *KernelEnvGV = + KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB); - Attributor::SimplifictionCallbackTy ModeSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> std::optional<Value *> { - // IRP represents the "SPMDCompatibilityTracker" argument of an - // __kmpc_target_init or - // __kmpc_target_deinit call. We will answer this one with the internal - // state. - if (!SPMDCompatibilityTracker.isValidState()) - return nullptr; - if (!SPMDCompatibilityTracker.isAtFixpoint()) { - if (AA) - A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); + Attributor::GlobalVariableSimplifictionCallbackTy + KernelConfigurationSimplifyCB = + [&](const GlobalVariable &GV, const AbstractAttribute *AA, + bool &UsedAssumedInformation) -> std::optional<Constant *> { + if (!isAtFixpoint()) { + if (!AA) + return nullptr; UsedAssumedInformation = true; - } else { - UsedAssumedInformation = false; + A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); } - auto *Val = ConstantInt::getSigned( - IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()), - SPMDCompatibilityTracker.isAssumed() ? 
OMP_TGT_EXEC_MODE_SPMD - : OMP_TGT_EXEC_MODE_GENERIC); - return Val; + return KernelEnvC; }; - constexpr const int InitModeArgNo = 1; - constexpr const int DeinitModeArgNo = 1; - constexpr const int InitUseStateMachineArgNo = 2; - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), - StateMachineSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo), - ModeSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo), - ModeSimplifyCB); + A.registerGlobalVariableSimplificationCallback( + *KernelEnvGV, KernelConfigurationSimplifyCB); // Check if we know we are in SPMD-mode already. - ConstantInt *ModeArg = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); - if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) + ConstantInt *ExecModeC = + KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC); + ConstantInt *AssumedExecModeC = ConstantInt::get( + ExecModeC->getType(), + ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD); + if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD) SPMDCompatibilityTracker.indicateOptimisticFixpoint(); - // This is a generic region but SPMDization is disabled so stop tracking. else if (DisableOpenMPOptSPMDization) + // This is a generic region but SPMDization is disabled so stop + // tracking. 
SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + else + setExecModeOfKernelEnvironment(AssumedExecModeC); + + const Triple T(Fn->getParent()->getTargetTriple()); + auto *Int32Ty = Type::getInt32Ty(Fn->getContext()); + auto [MinThreads, MaxThreads] = + OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn); + if (MinThreads) + setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads)); + if (MaxThreads) + setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads)); + auto [MinTeams, MaxTeams] = + OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn); + if (MinTeams) + setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams)); + if (MaxTeams) + setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams)); + + ConstantInt *MayUseNestedParallelismC = + KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC); + ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get( + MayUseNestedParallelismC->getType(), NestedParallelism); + setMayUseNestedParallelismOfKernelEnvironment( + AssumedMayUseNestedParallelismC); + + if (!DisableOpenMPOptStateMachineRewrite) { + ConstantInt *UseGenericStateMachineC = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + KernelEnvC); + ConstantInt *AssumedUseGenericStateMachineC = + ConstantInt::get(UseGenericStateMachineC->getType(), false); + setUseGenericStateMachineOfKernelEnvironment( + AssumedUseGenericStateMachineC); + } // Register virtual uses of functions we might need to preserve. 
auto RegisterVirtualUse = [&](RuntimeFunction RFKind, @@ -3703,22 +3905,32 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!KernelInitCB || !KernelDeinitCB) return ChangeStatus::UNCHANGED; - /// Insert nested Parallelism global variable - Function *Kernel = getAnchorScope(); - Module &M = *Kernel->getParent(); - Type *Int8Ty = Type::getInt8Ty(M.getContext()); - auto *GV = new GlobalVariable( - M, Int8Ty, /* isConstant */ true, GlobalValue::WeakAnyLinkage, - ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0), - Kernel->getName() + "_nested_parallelism"); - GV->setVisibility(GlobalValue::HiddenVisibility); - - // If we can we change the execution mode to SPMD-mode otherwise we build a - // custom state machine. ChangeStatus Changed = ChangeStatus::UNCHANGED; + + bool HasBuiltStateMachine = true; if (!changeToSPMDMode(A, Changed)) { if (!KernelInitCB->getCalledFunction()->isDeclaration()) - return buildCustomStateMachine(A); + HasBuiltStateMachine = buildCustomStateMachine(A, Changed); + else + HasBuiltStateMachine = false; + } + + // We need to reset KernelEnvC if specific rewriting is not done. + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); + ConstantInt *OldUseGenericStateMachineVal = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC); + if (!HasBuiltStateMachine) + setUseGenericStateMachineOfKernelEnvironment( + OldUseGenericStateMachineVal); + + // At last, update the KernelEnvc + GlobalVariable *KernelEnvGV = + KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB); + if (KernelEnvGV->getInitializer() != KernelEnvC) { + KernelEnvGV->setInitializer(KernelEnvC); + Changed = ChangeStatus::CHANGED; } return Changed; @@ -3788,14 +4000,14 @@ struct AAKernelInfoFunction : AAKernelInfo { // Find escaping outputs from the guarded region to outside users and // broadcast their values to them. 
for (Instruction &I : *RegionStartBB) { - SmallPtrSet<Instruction *, 4> OutsideUsers; - for (User *Usr : I.users()) { - Instruction &UsrI = *cast<Instruction>(Usr); + SmallVector<Use *, 4> OutsideUses; + for (Use &U : I.uses()) { + Instruction &UsrI = *cast<Instruction>(U.getUser()); if (UsrI.getParent() != RegionStartBB) - OutsideUsers.insert(&UsrI); + OutsideUses.push_back(&U); } - if (OutsideUsers.empty()) + if (OutsideUses.empty()) continue; HasBroadcastValues = true; @@ -3818,8 +4030,8 @@ struct AAKernelInfoFunction : AAKernelInfo { RegionBarrierBB->getTerminator()); // Emit a load instruction and replace uses of the output value. - for (Instruction *UsrI : OutsideUsers) - UsrI->replaceUsesOfWith(&I, LoadI); + for (Use *U : OutsideUses) + A.changeUseAfterManifest(*U, *LoadI); } auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); @@ -4043,19 +4255,14 @@ struct AAKernelInfoFunction : AAKernelInfo { auto *CB = cast<CallBase>(Kernel->user_back()); Kernel = CB->getCaller(); } - assert(omp::isKernel(*Kernel) && "Expected kernel function!"); + assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!"); // Check if the kernel is already in SPMD mode, if so, return success. - GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( - (Kernel->getName() + "_exec_mode").str()); - assert(ExecMode && "Kernel without exec mode?"); - assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!"); - - // Set the global exec mode flag to indicate SPMD-Generic mode. 
- assert(isa<ConstantInt>(ExecMode->getInitializer()) && - "ExecMode is not an integer!"); - const int8_t ExecModeVal = - cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue(); + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); + auto *ExecModeC = + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC); + const int8_t ExecModeVal = ExecModeC->getSExtValue(); if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC) return true; @@ -4073,27 +4280,8 @@ struct AAKernelInfoFunction : AAKernelInfo { // kernel is executed in. assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC && "Initially non-SPMD kernel has SPMD exec mode!"); - ExecMode->setInitializer( - ConstantInt::get(ExecMode->getInitializer()->getType(), - ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); - - // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. - const int InitModeArgNo = 1; - const int DeinitModeArgNo = 1; - const int InitUseStateMachineArgNo = 2; - - auto &Ctx = getAnchorValue().getContext(); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitModeArgNo), - *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), - OMP_TGT_EXEC_MODE_SPMD)); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), - *ConstantInt::getBool(Ctx, false)); - A.changeUseAfterManifest( - KernelDeinitCB->getArgOperandUse(DeinitModeArgNo), - *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), - OMP_TGT_EXEC_MODE_SPMD)); + setExecModeOfKernelEnvironment(ConstantInt::get( + ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); ++NumOpenMPTargetRegionKernelsSPMD; @@ -4104,46 +4292,47 @@ struct AAKernelInfoFunction : AAKernelInfo { return true; }; - ChangeStatus buildCustomStateMachine(Attributor &A) { + bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) { // If we have disabled state machine rewrites, don't make a custom one if (DisableOpenMPOptStateMachineRewrite) - return 
ChangeStatus::UNCHANGED; + return false; // Don't rewrite the state machine if we are not in a valid state. if (!ReachedKnownParallelRegions.isValidState()) - return ChangeStatus::UNCHANGED; + return false; auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); if (!OMPInfoCache.runtimeFnsAvailable( {OMPRTL___kmpc_get_hardware_num_threads_in_block, OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic, OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel})) - return ChangeStatus::UNCHANGED; + return false; - const int InitModeArgNo = 1; - const int InitUseStateMachineArgNo = 2; + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); // Check if the current configuration is non-SPMD and generic state machine. // If we already have SPMD mode or a custom state machine we do not need to // go any further. If it is anything but a constant something is weird and // we give up. - ConstantInt *UseStateMachine = dyn_cast<ConstantInt>( - KernelInitCB->getArgOperand(InitUseStateMachineArgNo)); - ConstantInt *Mode = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); + ConstantInt *UseStateMachineC = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC); + ConstantInt *ModeC = + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC); // If we are stuck with generic mode, try to create a custom device (=GPU) // state machine which is specialized for the parallel regions that are // reachable by the kernel. - if (!UseStateMachine || UseStateMachine->isZero() || !Mode || - (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) - return ChangeStatus::UNCHANGED; + if (UseStateMachineC->isZero() || + (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) + return false; + + Changed = ChangeStatus::CHANGED; // If not SPMD mode, indicate we use a custom state machine now. 
- auto &Ctx = getAnchorValue().getContext(); - auto *FalseVal = ConstantInt::getBool(Ctx, false); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal); + setUseGenericStateMachineOfKernelEnvironment( + ConstantInt::get(UseStateMachineC->getType(), false)); // If we don't actually need a state machine we are done here. This can // happen if there simply are no parallel regions. In the resulting kernel @@ -4157,7 +4346,7 @@ struct AAKernelInfoFunction : AAKernelInfo { }; A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark); - return ChangeStatus::CHANGED; + return true; } // Keep track in the statistics of our new shiny custom state machine. @@ -4222,6 +4411,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // UserCodeEntryBB: // user code // __kmpc_target_deinit(...) // + auto &Ctx = getAnchorValue().getContext(); Function *Kernel = getAssociatedFunction(); assert(Kernel && "Expected an associated function!"); @@ -4292,7 +4482,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // Create local storage for the work function pointer. 
const DataLayout &DL = M.getDataLayout(); - Type *VoidPtrTy = Type::getInt8PtrTy(Ctx); + Type *VoidPtrTy = PointerType::getUnqual(Ctx); Instruction *WorkFnAI = new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr, "worker.work_fn.addr", &Kernel->getEntryBlock().front()); @@ -4304,7 +4494,7 @@ struct AAKernelInfoFunction : AAKernelInfo { StateMachineBeginBB->end()), DLoc)); - Value *Ident = KernelInitCB->getArgOperand(0); + Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC); Value *GTid = KernelInitCB; FunctionCallee BarrierFn = @@ -4337,9 +4527,6 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionType *ParallelRegionFnTy = FunctionType::get( Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)}, false); - Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( - WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast", - StateMachineBeginBB); Instruction *IsDone = ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, @@ -4358,11 +4545,15 @@ struct AAKernelInfoFunction : AAKernelInfo { Value *ZeroArg = Constant::getNullValue(ParallelRegionFnTy->getParamType(0)); + const unsigned int WrapperFunctionArgNo = 6; + // Now that we have most of the CFG skeleton it is time for the if-cascade // that checks the function pointer we got from the runtime against the // parallel regions we expect, if there are any. 
for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) { - auto *ParallelRegion = ReachedKnownParallelRegions[I]; + auto *CB = ReachedKnownParallelRegions[I]; + auto *ParallelRegion = dyn_cast<Function>( + CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts()); BasicBlock *PRExecuteBB = BasicBlock::Create( Ctx, "worker_state_machine.parallel_region.execute", Kernel, StateMachineEndParallelBB); @@ -4374,13 +4565,15 @@ struct AAKernelInfoFunction : AAKernelInfo { BasicBlock *PRNextBB = BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check", Kernel, StateMachineEndParallelBB); + A.registerManifestAddedBasicBlock(*PRExecuteBB); + A.registerManifestAddedBasicBlock(*PRNextBB); // Check if we need to compare the pointer at all or if we can just // call the parallel region function. Value *IsPR; if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) { Instruction *CmpI = ICmpInst::Create( - ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion, + ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion, "worker.check_parallel_region", StateMachineIfCascadeCurrentBB); CmpI->setDebugLoc(DLoc); IsPR = CmpI; @@ -4400,7 +4593,7 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!ReachedUnknownParallelRegions.empty()) { StateMachineIfCascadeCurrentBB->setName( "worker_state_machine.parallel_region.fallback.execute"); - CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "", + CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "", StateMachineIfCascadeCurrentBB) ->setDebugLoc(DLoc); } @@ -4423,7 +4616,7 @@ struct AAKernelInfoFunction : AAKernelInfo { BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB) ->setDebugLoc(DLoc); - return ChangeStatus::CHANGED; + return true; } /// Fixpoint iteration update function. 
Will be called every time a dependence @@ -4431,6 +4624,46 @@ struct AAKernelInfoFunction : AAKernelInfo { ChangeStatus updateImpl(Attributor &A) override { KernelInfoState StateBefore = getState(); + // When we leave this function this RAII will make sure the member + // KernelEnvC is updated properly depending on the state. That member is + // used for simplification of values and needs to be up to date at all + // times. + struct UpdateKernelEnvCRAII { + AAKernelInfoFunction &AA; + + UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {} + + ~UpdateKernelEnvCRAII() { + if (!AA.KernelEnvC) + return; + + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB); + + if (!AA.isValidState()) { + AA.KernelEnvC = ExistingKernelEnvC; + return; + } + + if (!AA.ReachedKnownParallelRegions.isValidState()) + AA.setUseGenericStateMachineOfKernelEnvironment( + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC)); + + if (!AA.SPMDCompatibilityTracker.isValidState()) + AA.setExecModeOfKernelEnvironment( + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC)); + + ConstantInt *MayUseNestedParallelismC = + KernelInfo::getMayUseNestedParallelismFromKernelEnvironment( + AA.KernelEnvC); + ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get( + MayUseNestedParallelismC->getType(), AA.NestedParallelism); + AA.setMayUseNestedParallelismOfKernelEnvironment( + NewMayUseNestedParallelismC); + } + } RAII(*this); + // Callback to check a read/write instruction. auto CheckRWInst = [&](Instruction &I) { // We handle calls later. @@ -4634,15 +4867,13 @@ struct AAKernelInfoCallSite : AAKernelInfo { AAKernelInfo::initialize(A); CallBase &CB = cast<CallBase>(getAssociatedValue()); - Function *Callee = getAssociatedFunction(); - auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>( *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); // Check for SPMD-mode assumptions. 
if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) { - SPMDCompatibilityTracker.indicateOptimisticFixpoint(); indicateOptimisticFixpoint(); + return; } // First weed out calls we do not care about, that is readonly/readnone @@ -4657,124 +4888,156 @@ struct AAKernelInfoCallSite : AAKernelInfo { // we will handle them explicitly in the switch below. If it is not, we // will use an AAKernelInfo object on the callee to gather information and // merge that into the current state. The latter happens in the updateImpl. - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); - if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { - // Unknown caller or declarations are not analyzable, we give up. - if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { + auto CheckCallee = [&](Function *Callee, unsigned NumCallees) { + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); + if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { + // Unknown caller or declarations are not analyzable, we give up. + if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { - // Unknown callees might contain parallel regions, except if they have - // an appropriate assumption attached. - if (!AssumptionAA || - !(AssumptionAA->hasAssumption("omp_no_openmp") || - AssumptionAA->hasAssumption("omp_no_parallelism"))) - ReachedUnknownParallelRegions.insert(&CB); + // Unknown callees might contain parallel regions, except if they have + // an appropriate assumption attached. + if (!AssumptionAA || + !(AssumptionAA->hasAssumption("omp_no_openmp") || + AssumptionAA->hasAssumption("omp_no_parallelism"))) + ReachedUnknownParallelRegions.insert(&CB); - // If SPMDCompatibilityTracker is not fixed, we need to give up on the - // idea we can run something unknown in SPMD-mode. 
- if (!SPMDCompatibilityTracker.isAtFixpoint()) { - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - } + // If SPMDCompatibilityTracker is not fixed, we need to give up on the + // idea we can run something unknown in SPMD-mode. + if (!SPMDCompatibilityTracker.isAtFixpoint()) { + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); + } - // We have updated the state for this unknown call properly, there won't - // be any change so we indicate a fixpoint. - indicateOptimisticFixpoint(); + // We have updated the state for this unknown call properly, there + // won't be any change so we indicate a fixpoint. + indicateOptimisticFixpoint(); + } + // If the callee is known and can be used in IPO, we will update the + // state based on the callee state in updateImpl. + return; + } + if (NumCallees > 1) { + indicatePessimisticFixpoint(); + return; } - // If the callee is known and can be used in IPO, we will update the state - // based on the callee state in updateImpl. - return; - } - const unsigned int WrapperFunctionArgNo = 6; - RuntimeFunction RF = It->getSecond(); - switch (RF) { - // All the functions we know are compatible with SPMD mode. 
- case OMPRTL___kmpc_is_spmd_exec_mode: - case OMPRTL___kmpc_distribute_static_fini: - case OMPRTL___kmpc_for_static_fini: - case OMPRTL___kmpc_global_thread_num: - case OMPRTL___kmpc_get_hardware_num_threads_in_block: - case OMPRTL___kmpc_get_hardware_num_blocks: - case OMPRTL___kmpc_single: - case OMPRTL___kmpc_end_single: - case OMPRTL___kmpc_master: - case OMPRTL___kmpc_end_master: - case OMPRTL___kmpc_barrier: - case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2: - case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2: - case OMPRTL___kmpc_nvptx_end_reduce_nowait: - break; - case OMPRTL___kmpc_distribute_static_init_4: - case OMPRTL___kmpc_distribute_static_init_4u: - case OMPRTL___kmpc_distribute_static_init_8: - case OMPRTL___kmpc_distribute_static_init_8u: - case OMPRTL___kmpc_for_static_init_4: - case OMPRTL___kmpc_for_static_init_4u: - case OMPRTL___kmpc_for_static_init_8: - case OMPRTL___kmpc_for_static_init_8u: { - // Check the schedule and allow static schedule in SPMD mode. - unsigned ScheduleArgOpNo = 2; - auto *ScheduleTypeCI = - dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); - unsigned ScheduleTypeVal = - ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0; - switch (OMPScheduleType(ScheduleTypeVal)) { - case OMPScheduleType::UnorderedStatic: - case OMPScheduleType::UnorderedStaticChunked: - case OMPScheduleType::OrderedDistribute: - case OMPScheduleType::OrderedDistributeChunked: + RuntimeFunction RF = It->getSecond(); + switch (RF) { + // All the functions we know are compatible with SPMD mode. 
+ case OMPRTL___kmpc_is_spmd_exec_mode: + case OMPRTL___kmpc_distribute_static_fini: + case OMPRTL___kmpc_for_static_fini: + case OMPRTL___kmpc_global_thread_num: + case OMPRTL___kmpc_get_hardware_num_threads_in_block: + case OMPRTL___kmpc_get_hardware_num_blocks: + case OMPRTL___kmpc_single: + case OMPRTL___kmpc_end_single: + case OMPRTL___kmpc_master: + case OMPRTL___kmpc_end_master: + case OMPRTL___kmpc_barrier: + case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2: + case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2: + case OMPRTL___kmpc_error: + case OMPRTL___kmpc_flush: + case OMPRTL___kmpc_get_hardware_thread_id_in_block: + case OMPRTL___kmpc_get_warp_size: + case OMPRTL_omp_get_thread_num: + case OMPRTL_omp_get_num_threads: + case OMPRTL_omp_get_max_threads: + case OMPRTL_omp_in_parallel: + case OMPRTL_omp_get_dynamic: + case OMPRTL_omp_get_cancellation: + case OMPRTL_omp_get_nested: + case OMPRTL_omp_get_schedule: + case OMPRTL_omp_get_thread_limit: + case OMPRTL_omp_get_supported_active_levels: + case OMPRTL_omp_get_max_active_levels: + case OMPRTL_omp_get_level: + case OMPRTL_omp_get_ancestor_thread_num: + case OMPRTL_omp_get_team_size: + case OMPRTL_omp_get_active_level: + case OMPRTL_omp_in_final: + case OMPRTL_omp_get_proc_bind: + case OMPRTL_omp_get_num_places: + case OMPRTL_omp_get_num_procs: + case OMPRTL_omp_get_place_proc_ids: + case OMPRTL_omp_get_place_num: + case OMPRTL_omp_get_partition_num_places: + case OMPRTL_omp_get_partition_place_nums: + case OMPRTL_omp_get_wtime: break; - default: + case OMPRTL___kmpc_distribute_static_init_4: + case OMPRTL___kmpc_distribute_static_init_4u: + case OMPRTL___kmpc_distribute_static_init_8: + case OMPRTL___kmpc_distribute_static_init_8u: + case OMPRTL___kmpc_for_static_init_4: + case OMPRTL___kmpc_for_static_init_4u: + case OMPRTL___kmpc_for_static_init_8: + case OMPRTL___kmpc_for_static_init_8u: { + // Check the schedule and allow static schedule in SPMD mode. 
+ unsigned ScheduleArgOpNo = 2; + auto *ScheduleTypeCI = + dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); + unsigned ScheduleTypeVal = + ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0; + switch (OMPScheduleType(ScheduleTypeVal)) { + case OMPScheduleType::UnorderedStatic: + case OMPScheduleType::UnorderedStaticChunked: + case OMPScheduleType::OrderedDistribute: + case OMPScheduleType::OrderedDistributeChunked: + break; + default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); + break; + }; + } break; + case OMPRTL___kmpc_target_init: + KernelInitCB = &CB; + break; + case OMPRTL___kmpc_target_deinit: + KernelDeinitCB = &CB; + break; + case OMPRTL___kmpc_parallel_51: + if (!handleParallel51(A, CB)) + indicatePessimisticFixpoint(); + return; + case OMPRTL___kmpc_omp_task: + // We do not look into tasks right now, just give up. SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); + ReachedUnknownParallelRegions.insert(&CB); break; - }; - } break; - case OMPRTL___kmpc_target_init: - KernelInitCB = &CB; - break; - case OMPRTL___kmpc_target_deinit: - KernelDeinitCB = &CB; - break; - case OMPRTL___kmpc_parallel_51: - if (auto *ParallelRegion = dyn_cast<Function>( - CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) { - ReachedKnownParallelRegions.insert(ParallelRegion); - /// Check nested parallelism - auto *FnAA = A.getAAFor<AAKernelInfo>( - *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); - NestedParallelism |= !FnAA || !FnAA->getState().isValidState() || - !FnAA->ReachedKnownParallelRegions.empty() || - !FnAA->ReachedUnknownParallelRegions.empty(); + case OMPRTL___kmpc_alloc_shared: + case OMPRTL___kmpc_free_shared: + // Return without setting a fixpoint, to be resolved in updateImpl. + return; + default: + // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, + // generally. However, they do not hide parallel regions. 
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); break; } - // The condition above should usually get the parallel region function - // pointer and record it. In the off chance it doesn't we assume the - // worst. - ReachedUnknownParallelRegions.insert(&CB); - break; - case OMPRTL___kmpc_omp_task: - // We do not look into tasks right now, just give up. - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - ReachedUnknownParallelRegions.insert(&CB); - break; - case OMPRTL___kmpc_alloc_shared: - case OMPRTL___kmpc_free_shared: - // Return without setting a fixpoint, to be resolved in updateImpl. + // All other OpenMP runtime calls will not reach parallel regions so they + // can be safely ignored for now. Since it is a known OpenMP runtime call + // we have now modeled all effects and there is no need for any update. + indicateOptimisticFixpoint(); + }; + + const auto *AACE = + A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) { + CheckCallee(getAssociatedFunction(), 1); return; - default: - // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, - // generally. However, they do not hide parallel regions. - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - break; } - // All other OpenMP runtime calls will not reach parallel regions so they - // can be safely ignored for now. Since it is a known OpenMP runtime call we - // have now modeled all effects and there is no need for any update. 
- indicateOptimisticFixpoint(); + const auto &OptimisticEdges = AACE->getOptimisticEdges(); + for (auto *Callee : OptimisticEdges) { + CheckCallee(Callee, OptimisticEdges.size()); + if (isAtFixpoint()) + break; + } } ChangeStatus updateImpl(Attributor &A) override { @@ -4782,62 +5045,115 @@ struct AAKernelInfoCallSite : AAKernelInfo { // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F); + KernelInfoState StateBefore = getState(); - // If F is not a runtime function, propagate the AAKernelInfo of the callee. - if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { - const IRPosition &FnPos = IRPosition::function(*F); - auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); - if (!FnAA) + auto CheckCallee = [&](Function *F, int NumCallees) { + const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F); + + // If F is not a runtime function, propagate the AAKernelInfo of the + // callee. + if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { + const IRPosition &FnPos = IRPosition::function(*F); + auto *FnAA = + A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); + if (!FnAA) + return indicatePessimisticFixpoint(); + if (getState() == FnAA->getState()) + return ChangeStatus::UNCHANGED; + getState() = FnAA->getState(); + return ChangeStatus::CHANGED; + } + if (NumCallees > 1) return indicatePessimisticFixpoint(); - if (getState() == FnAA->getState()) - return ChangeStatus::UNCHANGED; - getState() = FnAA->getState(); - return ChangeStatus::CHANGED; - } - // F is a runtime function that allocates or frees memory, check - // AAHeapToStack and AAHeapToShared. 
- KernelInfoState StateBefore = getState(); - assert((It->getSecond() == OMPRTL___kmpc_alloc_shared || - It->getSecond() == OMPRTL___kmpc_free_shared) && - "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call"); + CallBase &CB = cast<CallBase>(getAssociatedValue()); + if (It->getSecond() == OMPRTL___kmpc_parallel_51) { + if (!handleParallel51(A, CB)) + return indicatePessimisticFixpoint(); + return StateBefore == getState() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } - CallBase &CB = cast<CallBase>(getAssociatedValue()); + // F is a runtime function that allocates or frees memory, check + // AAHeapToStack and AAHeapToShared. + assert( + (It->getSecond() == OMPRTL___kmpc_alloc_shared || + It->getSecond() == OMPRTL___kmpc_free_shared) && + "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call"); - auto *HeapToStackAA = A.getAAFor<AAHeapToStack>( - *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); - auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>( - *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); + auto *HeapToStackAA = A.getAAFor<AAHeapToStack>( + *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); + auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>( + *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); - RuntimeFunction RF = It->getSecond(); + RuntimeFunction RF = It->getSecond(); - switch (RF) { - // If neither HeapToStack nor HeapToShared assume the call is removed, - // assume SPMD incompatibility. 
- case OMPRTL___kmpc_alloc_shared: - if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) && - (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB))) - SPMDCompatibilityTracker.insert(&CB); - break; - case OMPRTL___kmpc_free_shared: - if ((!HeapToStackAA || - !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) && - (!HeapToSharedAA || - !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB))) + switch (RF) { + // If neither HeapToStack nor HeapToShared assume the call is removed, + // assume SPMD incompatibility. + case OMPRTL___kmpc_alloc_shared: + if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) && + (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB))) + SPMDCompatibilityTracker.insert(&CB); + break; + case OMPRTL___kmpc_free_shared: + if ((!HeapToStackAA || + !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) && + (!HeapToSharedAA || + !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB))) + SPMDCompatibilityTracker.insert(&CB); + break; + default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); - break; - default: - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); + } + return ChangeStatus::CHANGED; + }; + + const auto *AACE = + A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) { + if (Function *F = getAssociatedFunction()) + CheckCallee(F, /*NumCallees=*/1); + } else { + const auto &OptimisticEdges = AACE->getOptimisticEdges(); + for (auto *Callee : OptimisticEdges) { + CheckCallee(Callee, OptimisticEdges.size()); + if (isAtFixpoint()) + break; + } } return StateBefore == getState() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } + + /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was + /// handled, if a problem occurred, false is returned. 
+ bool handleParallel51(Attributor &A, CallBase &CB) { + const unsigned int NonWrapperFunctionArgNo = 5; + const unsigned int WrapperFunctionArgNo = 6; + auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed() + ? NonWrapperFunctionArgNo + : WrapperFunctionArgNo; + + auto *ParallelRegion = dyn_cast<Function>( + CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts()); + if (!ParallelRegion) + return false; + + ReachedKnownParallelRegions.insert(&CB); + /// Check nested parallelism + auto *FnAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); + NestedParallelism |= !FnAA || !FnAA->getState().isValidState() || + !FnAA->ReachedKnownParallelRegions.empty() || + !FnAA->ReachedKnownParallelRegions.isValidState() || + !FnAA->ReachedUnknownParallelRegions.isValidState() || + !FnAA->ReachedUnknownParallelRegions.empty(); + return true; + } }; struct AAFoldRuntimeCall @@ -5251,6 +5567,11 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) { UsedAssumedInformation, AA::Interprocedural); continue; } + if (auto *CI = dyn_cast<CallBase>(&I)) { + if (CI->isIndirectCall()) + A.getOrCreateAAFor<AAIndirectCallInfo>( + IRPosition::callsite_function(*CI)); + } if (auto *SI = dyn_cast<StoreInst>(&I)) { A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI)); continue; @@ -5569,7 +5890,9 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); } -bool llvm::omp::isKernel(Function &Fn) { return Fn.hasFnAttribute("kernel"); } +bool llvm::omp::isOpenMPKernel(Function &Fn) { + return Fn.hasFnAttribute("kernel"); +} KernelSet llvm::omp::getDeviceKernels(Module &M) { // TODO: Create a more cross-platform way of determining device kernels. 
@@ -5591,10 +5914,13 @@ KernelSet llvm::omp::getDeviceKernels(Module &M) { if (!KernelFn) continue; - assert(isKernel(*KernelFn) && "Inconsistent kernel function annotation"); - ++NumOpenMPTargetRegionKernels; - - Kernels.insert(KernelFn); + // We are only interested in OpenMP target regions. Others, such as kernels + // generated by CUDA but linked together, are not interesting to this pass. + if (isOpenMPKernel(*KernelFn)) { + ++NumOpenMPTargetRegionKernels; + Kernels.insert(KernelFn); + } else + ++NumNonOpenMPTargetRegionKernels; } return Kernels; diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp index b88ba2dec24b..aa4f205ec5bd 100644 --- a/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -161,7 +161,7 @@ struct FunctionOutliningInfo { // The dominating block of the region to be outlined. BasicBlock *NonReturnBlock = nullptr; - // The set of blocks in Entries that that are predecessors to ReturnBlock + // The set of blocks in Entries that are predecessors to ReturnBlock SmallVector<BasicBlock *, 4> ReturnBlockPreds; }; @@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline( const DataLayout &DL = Caller->getParent()->getDataLayout(); // The savings of eliminating the call: - int NonWeightedSavings = getCallsiteCost(CB, DL); + int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL); BlockFrequency NormWeightedSavings(NonWeightedSavings); // Weighted saving is smaller than weighted cost, return false @@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB, } if (CallInst *CI = dyn_cast<CallInst>(&I)) { - InlineCost += getCallsiteCost(*CI, DL); + InlineCost += getCallsiteCost(*TTI, *CI, DL); continue; } if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) { - InlineCost += getCallsiteCost(*II, DL); + InlineCost += getCallsiteCost(*TTI, *II, DL); continue; } @@ -1042,7 +1042,7 @@ void 
PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const { ClonedOI->ReturnBlock = ClonedOI->ReturnBlock->splitBasicBlock( ClonedOI->ReturnBlock->getFirstNonPHI()->getIterator()); BasicBlock::iterator I = PreReturn->begin(); - Instruction *Ins = &ClonedOI->ReturnBlock->front(); + BasicBlock::iterator Ins = ClonedOI->ReturnBlock->begin(); SmallVector<Instruction *, 4> DeadPhis; while (I != PreReturn->end()) { PHINode *OldPhi = dyn_cast<PHINode>(I); @@ -1050,9 +1050,10 @@ void PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const { break; PHINode *RetPhi = - PHINode::Create(OldPhi->getType(), NumPredsFromEntries + 1, "", Ins); + PHINode::Create(OldPhi->getType(), NumPredsFromEntries + 1, ""); + RetPhi->insertBefore(Ins); OldPhi->replaceAllUsesWith(RetPhi); - Ins = ClonedOI->ReturnBlock->getFirstNonPHI(); + Ins = ClonedOI->ReturnBlock->getFirstNonPHIIt(); RetPhi->addIncoming(&*I, PreReturn); for (BasicBlock *E : ClonedOI->ReturnBlockPreds) { diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp index e2e6364df906..b1f9b827dcba 100644 --- a/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/llvm/lib/Transforms/IPO/SCCP.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/AttributeMask.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ModRef.h" @@ -43,7 +44,7 @@ STATISTIC(NumInstReplaced, "Number of instructions replaced with (simpler) instruction"); static cl::opt<unsigned> FuncSpecMaxIters( - "funcspec-max-iters", cl::init(1), cl::Hidden, cl::desc( + "funcspec-max-iters", cl::init(10), cl::Hidden, cl::desc( "The maximum number of iterations function specialization is run")); static void findReturnsToZap(Function &F, @@ -235,11 +236,11 @@ static bool runIPSCCP( // nodes in executable blocks we found values for. 
The function's entry // block is not part of BlocksToErase, so we have to handle it separately. for (BasicBlock *BB : BlocksToErase) { - NumInstRemoved += changeToUnreachable(BB->getFirstNonPHI(), + NumInstRemoved += changeToUnreachable(BB->getFirstNonPHIOrDbg(), /*PreserveLCSSA=*/false, &DTU); } if (!Solver.isBlockExecutable(&F.front())) - NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(), + NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHIOrDbg(), /*PreserveLCSSA=*/false, &DTU); BasicBlock *NewUnreachableBB = nullptr; @@ -371,6 +372,18 @@ static bool runIPSCCP( StoreInst *SI = cast<StoreInst>(GV->user_back()); SI->eraseFromParent(); } + + // Try to create a debug constant expression for the global variable + // initializer value. + SmallVector<DIGlobalVariableExpression *, 1> GVEs; + GV->getDebugInfo(GVEs); + if (GVEs.size() == 1) { + DIBuilder DIB(M); + if (DIExpression *InitExpr = getExpressionForConstant( + DIB, *GV->getInitializer(), *GV->getValueType())) + GVEs[0]->replaceOperandWith(1, InitExpr); + } + MadeChanges = true; M.eraseGlobalVariable(GV); ++NumGlobalConst; diff --git a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp index 3ddf5fe20edb..f7a54d428f20 100644 --- a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp +++ b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/SampleContextTracker.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/InstrTypes.h" @@ -29,7 +28,7 @@ using namespace sampleprof; namespace llvm { ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite, - StringRef CalleeName) { + FunctionId CalleeName) { if (CalleeName.empty()) return getHottestChildContext(CallSite); @@ -104,7 +103,7 @@ 
SampleContextTracker::moveContextSamples(ContextTrieNode &ToNodeParent, } void ContextTrieNode::removeChildContext(const LineLocation &CallSite, - StringRef CalleeName) { + FunctionId CalleeName) { uint64_t Hash = FunctionSamples::getCallSiteHash(CalleeName, CallSite); // Note this essentially calls dtor and destroys that child context AllChildContext.erase(Hash); @@ -114,7 +113,7 @@ std::map<uint64_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { return AllChildContext; } -StringRef ContextTrieNode::getFuncName() const { return FuncName; } +FunctionId ContextTrieNode::getFuncName() const { return FuncName; } FunctionSamples *ContextTrieNode::getFunctionSamples() const { return FuncSamples; @@ -178,7 +177,7 @@ void ContextTrieNode::dumpTree() { } ContextTrieNode *ContextTrieNode::getOrCreateChildContext( - const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) { + const LineLocation &CallSite, FunctionId CalleeName, bool AllowCreate) { uint64_t Hash = FunctionSamples::getCallSiteHash(CalleeName, CallSite); auto It = AllChildContext.find(Hash); if (It != AllChildContext.end()) { @@ -201,7 +200,7 @@ SampleContextTracker::SampleContextTracker( : GUIDToFuncNameMap(GUIDToFuncNameMap) { for (auto &FuncSample : Profiles) { FunctionSamples *FSamples = &FuncSample.second; - SampleContext Context = FuncSample.first; + SampleContext Context = FuncSample.second.getContext(); LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context.toString() << "\n"); ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); @@ -232,14 +231,12 @@ SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, return nullptr; CalleeName = FunctionSamples::getCanonicalFnName(CalleeName); - // Convert real function names to MD5 names, if the input profile is - // MD5-based. 
- std::string FGUID; - CalleeName = getRepInFormat(CalleeName, FunctionSamples::UseMD5, FGUID); + + FunctionId FName = getRepInFormat(CalleeName); // For indirect call, CalleeName will be empty, in which case the context // profile for callee with largest total samples will be returned. - ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName); + ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, FName); if (CalleeContext) { FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); LLVM_DEBUG(if (FSamples) { @@ -305,27 +302,23 @@ SampleContextTracker::getContextSamplesFor(const SampleContext &Context) { SampleContextTracker::ContextSamplesTy & SampleContextTracker::getAllContextSamplesFor(const Function &Func) { StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); - return FuncToCtxtProfiles[CanonName]; + return FuncToCtxtProfiles[getRepInFormat(CanonName)]; } SampleContextTracker::ContextSamplesTy & SampleContextTracker::getAllContextSamplesFor(StringRef Name) { - return FuncToCtxtProfiles[Name]; + return FuncToCtxtProfiles[getRepInFormat(Name)]; } FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func, bool MergeContext) { StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); - return getBaseSamplesFor(CanonName, MergeContext); + return getBaseSamplesFor(getRepInFormat(CanonName), MergeContext); } -FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, +FunctionSamples *SampleContextTracker::getBaseSamplesFor(FunctionId Name, bool MergeContext) { LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n"); - // Convert real function names to MD5 names, if the input profile is - // MD5-based. - std::string FGUID; - Name = getRepInFormat(Name, FunctionSamples::UseMD5, FGUID); // Base profile is top-level node (child of root node), so try to retrieve // existing top-level node for given function first. 
If it exists, it could be @@ -373,7 +366,7 @@ void SampleContextTracker::markContextSamplesInlined( ContextTrieNode &SampleContextTracker::getRootContext() { return RootContext; } void SampleContextTracker::promoteMergeContextSamplesTree( - const Instruction &Inst, StringRef CalleeName) { + const Instruction &Inst, FunctionId CalleeName) { LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n" << Inst << "\n"); // Get the caller context for the call instruction, we don't use callee @@ -458,9 +451,9 @@ void SampleContextTracker::dump() { RootContext.dumpTree(); } StringRef SampleContextTracker::getFuncNameFor(ContextTrieNode *Node) const { if (!FunctionSamples::UseMD5) - return Node->getFuncName(); + return Node->getFuncName().stringRef(); assert(GUIDToFuncNameMap && "GUIDToFuncNameMap needs to be populated first"); - return GUIDToFuncNameMap->lookup(std::stoull(Node->getFuncName().data())); + return GUIDToFuncNameMap->lookup(Node->getFuncName().getHashCode()); } ContextTrieNode * @@ -470,7 +463,7 @@ SampleContextTracker::getContextFor(const SampleContext &Context) { ContextTrieNode * SampleContextTracker::getCalleeContextFor(const DILocation *DIL, - StringRef CalleeName) { + FunctionId CalleeName) { assert(DIL && "Expect non-null location"); ContextTrieNode *CallContext = getContextFor(DIL); @@ -485,7 +478,7 @@ SampleContextTracker::getCalleeContextFor(const DILocation *DIL, ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { assert(DIL && "Expect non-null location"); - SmallVector<std::pair<LineLocation, StringRef>, 10> S; + SmallVector<std::pair<LineLocation, FunctionId>, 10> S; // Use C++ linkage name if possible. 
const DILocation *PrevDIL = DIL; @@ -494,7 +487,8 @@ ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { if (Name.empty()) Name = PrevDIL->getScope()->getSubprogram()->getName(); S.push_back( - std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), Name)); + std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), + getRepInFormat(Name))); PrevDIL = DIL; } @@ -503,24 +497,14 @@ ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName(); if (RootName.empty()) RootName = PrevDIL->getScope()->getSubprogram()->getName(); - S.push_back(std::make_pair(LineLocation(0, 0), RootName)); - - // Convert real function names to MD5 names, if the input profile is - // MD5-based. - std::list<std::string> MD5Names; - if (FunctionSamples::UseMD5) { - for (auto &Location : S) { - MD5Names.emplace_back(); - getRepInFormat(Location.second, FunctionSamples::UseMD5, MD5Names.back()); - Location.second = MD5Names.back(); - } - } + S.push_back(std::make_pair(LineLocation(0, 0), + getRepInFormat(RootName))); ContextTrieNode *ContextNode = &RootContext; int I = S.size(); while (--I >= 0 && ContextNode) { LineLocation &CallSite = S[I].first; - StringRef CalleeName = S[I].second; + FunctionId CalleeName = S[I].second; ContextNode = ContextNode->getChildContext(CallSite, CalleeName); } @@ -540,10 +524,10 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, // Create child node at parent line/disc location if (AllowCreate) { ContextNode = - ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.FuncName); + ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.Func); } else { ContextNode = - ContextNode->getChildContext(CallSiteLoc, Callsite.FuncName); + ContextNode->getChildContext(CallSiteLoc, Callsite.Func); } CallSiteLoc = Callsite.Location; } @@ -553,12 +537,14 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext 
&Context, return ContextNode; } -ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) { +ContextTrieNode * +SampleContextTracker::getTopLevelContextNode(FunctionId FName) { assert(!FName.empty() && "Top level node query must provide valid name"); return RootContext.getChildContext(LineLocation(0, 0), FName); } -ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { +ContextTrieNode & +SampleContextTracker::addTopLevelContextNode(FunctionId FName) { assert(!getTopLevelContextNode(FName) && "Node to add must not exist"); return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName); } @@ -638,7 +624,7 @@ void SampleContextTracker::createContextLessProfileMap( FunctionSamples *FProfile = Node->getFunctionSamples(); // Profile's context can be empty, use ContextNode's func name. if (FProfile) - ContextLessProfiles[Node->getFuncName()].merge(*FProfile); + ContextLessProfiles.Create(Node->getFuncName()).merge(*FProfile); } } } // namespace llvm diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp index a53baecd4776..6c6f0a0eca72 100644 --- a/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -56,6 +56,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/PseudoProbe.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/ProfileData/InstrProf.h" @@ -142,11 +143,6 @@ static cl::opt<bool> PersistProfileStaleness( cl::desc("Compute stale profile statistical metrics and write it into the " "native object file(.llvm_stats section).")); -static cl::opt<bool> FlattenProfileForMatching( - "flatten-profile-for-matching", cl::Hidden, cl::init(true), - cl::desc( - "Use flattened profile for stale profile detection and matching.")); - static cl::opt<bool> ProfileSampleAccurate( "profile-sample-accurate", cl::Hidden, cl::init(false), 
cl::desc("If the sample profile is accurate, we will mark all un-sampled " @@ -429,7 +425,7 @@ struct CandidateComparer { return LCS->getBodySamples().size() > RCS->getBodySamples().size(); // Tie breaker using GUID so we have stable/deterministic inlining order - return LCS->getGUID(LCS->getName()) < RCS->getGUID(RCS->getName()); + return LCS->getGUID() < RCS->getGUID(); } }; @@ -458,32 +454,44 @@ class SampleProfileMatcher { uint64_t MismatchedFuncHashSamples = 0; uint64_t TotalFuncHashSamples = 0; + // A dummy name for unknown indirect callee, used to differentiate from a + // non-call instruction that also has an empty callee name. + static constexpr const char *UnknownIndirectCallee = + "unknown.indirect.callee"; + public: SampleProfileMatcher(Module &M, SampleProfileReader &Reader, const PseudoProbeManager *ProbeManager) - : M(M), Reader(Reader), ProbeManager(ProbeManager) { - if (FlattenProfileForMatching) { - ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles, - FunctionSamples::ProfileIsCS); - } - } + : M(M), Reader(Reader), ProbeManager(ProbeManager){}; void runOnModule(); private: FunctionSamples *getFlattenedSamplesFor(const Function &F) { StringRef CanonFName = FunctionSamples::getCanonicalFnName(F); - auto It = FlattenedProfiles.find(CanonFName); + auto It = FlattenedProfiles.find(FunctionId(CanonFName)); if (It != FlattenedProfiles.end()) return &It->second; return nullptr; } - void runOnFunction(const Function &F, const FunctionSamples &FS); + void runOnFunction(const Function &F); + void findIRAnchors(const Function &F, + std::map<LineLocation, StringRef> &IRAnchors); + void findProfileAnchors( + const FunctionSamples &FS, + std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors); + void countMismatchedSamples(const FunctionSamples &FS); void countProfileMismatches( + const Function &F, const FunctionSamples &FS, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, 
std::unordered_set<FunctionId>> + &ProfileAnchors); + void countProfileCallsiteMismatches( const FunctionSamples &FS, - const std::unordered_set<LineLocation, LineLocationHash> - &MatchedCallsiteLocs, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites); - LocToLocMap &getIRToProfileLocationMap(const Function &F) { auto Ret = FuncMappings.try_emplace( FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap()); @@ -491,12 +499,10 @@ private: } void distributeIRToProfileLocationMap(); void distributeIRToProfileLocationMap(FunctionSamples &FS); - void populateProfileCallsites( - const FunctionSamples &FS, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap); void runStaleProfileMatching( - const std::map<LineLocation, StringRef> &IRLocations, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap, + const Function &F, const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, LocToLocMap &IRToProfileLocationMap); }; @@ -538,7 +544,6 @@ protected: findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const; void findExternalInlineCandidate(CallBase *CB, const FunctionSamples *Samples, DenseSet<GlobalValue::GUID> &InlinedGUIDs, - const StringMap<Function *> &SymbolMap, uint64_t Threshold); // Attempt to promote indirect call and also inline the promoted call bool tryPromoteAndInlineCandidate( @@ -573,7 +578,7 @@ protected: /// the function name. If the function name contains suffix, additional /// entry is added to map from the stripped name to the function if there /// is one-to-one mapping. 
- StringMap<Function *> SymbolMap; + HashKeyMap<std::unordered_map, FunctionId, Function *> SymbolMap; std::function<AssumptionCache &(Function &)> GetAC; std::function<TargetTransformInfo &(Function &)> GetTTI; @@ -615,6 +620,11 @@ protected: // All the Names used in FunctionSamples including outline function // names, inline instance names and call target names. StringSet<> NamesInProfile; + // MD5 version of NamesInProfile. Either NamesInProfile or GUIDsInProfile is + // populated, depends on whether the profile uses MD5. Because the name table + // generally contains several magnitude more entries than the number of + // functions, we do not want to convert all names from one form to another. + llvm::DenseSet<uint64_t> GUIDsInProfile; // For symbol in profile symbol list, whether to regard their profiles // to be accurate. It is mainly decided by existance of profile symbol @@ -759,8 +769,7 @@ SampleProfileLoader::findIndirectCallFunctionSamples( assert(L && R && "Expect non-null FunctionSamples"); if (L->getHeadSamplesEstimate() != R->getHeadSamplesEstimate()) return L->getHeadSamplesEstimate() > R->getHeadSamplesEstimate(); - return FunctionSamples::getGUID(L->getName()) < - FunctionSamples::getGUID(R->getName()); + return L->getGUID() < R->getGUID(); }; if (FunctionSamples::ProfileIsCS) { @@ -970,13 +979,13 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate( // This prevents allocating an array of zero length in callees below. 
if (MaxNumPromotions == 0) return false; - auto CalleeFunctionName = Candidate.CalleeSamples->getFuncName(); + auto CalleeFunctionName = Candidate.CalleeSamples->getFunction(); auto R = SymbolMap.find(CalleeFunctionName); - if (R == SymbolMap.end() || !R->getValue()) + if (R == SymbolMap.end() || !R->second) return false; auto &CI = *Candidate.CallInstr; - if (!doesHistoryAllowICP(CI, R->getValue()->getName())) + if (!doesHistoryAllowICP(CI, R->second->getName())) return false; const char *Reason = "Callee function not available"; @@ -986,17 +995,17 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate( // clone the caller first, and inline the cloned caller if it is // recursive. As llvm does not inline recursive calls, we will // simply ignore it instead of handling it explicitly. - if (!R->getValue()->isDeclaration() && R->getValue()->getSubprogram() && - R->getValue()->hasFnAttribute("use-sample-profile") && - R->getValue() != &F && isLegalToPromote(CI, R->getValue(), &Reason)) { + if (!R->second->isDeclaration() && R->second->getSubprogram() && + R->second->hasFnAttribute("use-sample-profile") && + R->second != &F && isLegalToPromote(CI, R->second, &Reason)) { // For promoted target, set its value with NOMORE_ICP_MAGICNUM count // in the value profile metadata so the target won't be promoted again. 
SmallVector<InstrProfValueData, 1> SortedCallTargets = {InstrProfValueData{ - Function::getGUID(R->getValue()->getName()), NOMORE_ICP_MAGICNUM}}; + Function::getGUID(R->second->getName()), NOMORE_ICP_MAGICNUM}}; updateIDTMetaData(CI, SortedCallTargets, 0); auto *DI = &pgo::promoteIndirectCall( - CI, R->getValue(), Candidate.CallsiteCount, Sum, false, ORE); + CI, R->second, Candidate.CallsiteCount, Sum, false, ORE); if (DI) { Sum -= Candidate.CallsiteCount; // Do not prorate the indirect callsite distribution since the original @@ -1025,7 +1034,8 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate( } } else { LLVM_DEBUG(dbgs() << "\nFailed to promote indirect call to " - << Candidate.CalleeSamples->getFuncName() << " because " + << FunctionSamples::getCanonicalFnName( + Candidate.CallInstr->getName())<< " because " << Reason << "\n"); } return false; @@ -1070,8 +1080,7 @@ void SampleProfileLoader::emitOptimizationRemarksForInlineCandidates( void SampleProfileLoader::findExternalInlineCandidate( CallBase *CB, const FunctionSamples *Samples, - DenseSet<GlobalValue::GUID> &InlinedGUIDs, - const StringMap<Function *> &SymbolMap, uint64_t Threshold) { + DenseSet<GlobalValue::GUID> &InlinedGUIDs, uint64_t Threshold) { // If ExternalInlineAdvisor(ReplayInlineAdvisor) wants to inline an external // function make sure it's imported @@ -1080,7 +1089,7 @@ void SampleProfileLoader::findExternalInlineCandidate( // just add the direct GUID and move on if (!Samples) { InlinedGUIDs.insert( - FunctionSamples::getGUID(CB->getCalledFunction()->getName())); + Function::getGUID(CB->getCalledFunction()->getName())); return; } // Otherwise, drop the threshold to import everything that we can @@ -1121,22 +1130,20 @@ void SampleProfileLoader::findExternalInlineCandidate( CalleeSample->getContext().hasAttribute(ContextShouldBeInlined); if (!PreInline && CalleeSample->getHeadSamplesEstimate() < Threshold) continue; - - StringRef Name = CalleeSample->getFuncName(); - Function *Func = 
SymbolMap.lookup(Name); + + Function *Func = SymbolMap.lookup(CalleeSample->getFunction()); // Add to the import list only when it's defined out of module. if (!Func || Func->isDeclaration()) - InlinedGUIDs.insert(FunctionSamples::getGUID(CalleeSample->getName())); + InlinedGUIDs.insert(CalleeSample->getGUID()); // Import hot CallTargets, which may not be available in IR because full // profile annotation cannot be done until backend compilation in ThinLTO. for (const auto &BS : CalleeSample->getBodySamples()) for (const auto &TS : BS.second.getCallTargets()) - if (TS.getValue() > Threshold) { - StringRef CalleeName = CalleeSample->getFuncName(TS.getKey()); - const Function *Callee = SymbolMap.lookup(CalleeName); + if (TS.second > Threshold) { + const Function *Callee = SymbolMap.lookup(TS.first); if (!Callee || Callee->isDeclaration()) - InlinedGUIDs.insert(FunctionSamples::getGUID(TS.getKey())); + InlinedGUIDs.insert(TS.first.getHashCode()); } // Import hot child context profile associted with callees. 
Note that this @@ -1234,7 +1241,7 @@ bool SampleProfileLoader::inlineHotFunctions( for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) { uint64_t SumOrigin = Sum; if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap, + findExternalInlineCandidate(I, FS, InlinedGUIDs, PSI->getOrCompHotCountThreshold()); continue; } @@ -1255,7 +1262,7 @@ bool SampleProfileLoader::inlineHotFunctions( } } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { findExternalInlineCandidate(I, findCalleeFunctionSamples(*I), - InlinedGUIDs, SymbolMap, + InlinedGUIDs, PSI->getOrCompHotCountThreshold()); } } @@ -1504,7 +1511,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( for (const auto *FS : CalleeSamples) { // TODO: Consider disable pre-lTO ICP for MonoLTO as well if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap, + findExternalInlineCandidate(I, FS, InlinedGUIDs, PSI->getOrCompHotCountThreshold()); continue; } @@ -1557,7 +1564,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( } } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { findExternalInlineCandidate(I, findCalleeFunctionSamples(*I), - InlinedGUIDs, SymbolMap, + InlinedGUIDs, PSI->getOrCompHotCountThreshold()); } } @@ -1619,7 +1626,12 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples( // Note that we have to do the merge right after processing function. // This allows OutlineFS's profile to be used for annotation during // top-down processing of functions' annotation. - FunctionSamples *OutlineFS = Reader->getOrCreateSamplesFor(*Callee); + FunctionSamples *OutlineFS = Reader->getSamplesFor(*Callee); + // If outlined function does not exist in the profile, add it to a + // separate map so that it does not rehash the original profile. 
+ if (!OutlineFS) + OutlineFS = &OutlineFunctionSamples[ + FunctionId(FunctionSamples::getCanonicalFnName(Callee->getName()))]; OutlineFS->merge(*FS, 1); // Set outlined profile to be synthetic to not bias the inliner. OutlineFS->SetContextSynthetic(); @@ -1638,7 +1650,7 @@ GetSortedValueDataFromCallTargets(const SampleRecord::CallTargetMap &M) { SmallVector<InstrProfValueData, 2> R; for (const auto &I : SampleRecord::SortCallTargets(M)) { R.emplace_back( - InstrProfValueData{FunctionSamples::getGUID(I.first), I.second}); + InstrProfValueData{I.first.getHashCode(), I.second}); } return R; } @@ -1699,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { else if (OverwriteExistingWeights) I.setMetadata(LLVMContext::MD_prof, nullptr); } else if (!isa<IntrinsicInst>(&I)) { - I.setMetadata(LLVMContext::MD_prof, - MDB.createBranchWeights( - {static_cast<uint32_t>(BlockWeights[BB])})); + setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])}); } } } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) { @@ -1709,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { // clear it for cold code. for (auto &I : *BB) { if (isa<CallInst>(I) || isa<InvokeInst>(I)) { - if (cast<CallBase>(I).isIndirectCall()) + if (cast<CallBase>(I).isIndirectCall()) { I.setMetadata(LLVMContext::MD_prof, nullptr); - else - I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0)); + } else { + setBranchWeights(I, {uint32_t(0)}); + } } } } @@ -1792,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { if (MaxWeight > 0 && (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) { LLVM_DEBUG(dbgs() << "SUCCESS. 
Found non-zero weights.\n"); - TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + setBranchWeights(*TI, Weights); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst) << "most popular destination for conditional branches at " @@ -1865,7 +1876,8 @@ SampleProfileLoader::buildProfiledCallGraph(Module &M) { for (Function &F : M) { if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) continue; - ProfiledCG->addProfiledFunction(FunctionSamples::getCanonicalFnName(F)); + ProfiledCG->addProfiledFunction( + getRepInFormat(FunctionSamples::getCanonicalFnName(F))); } return ProfiledCG; @@ -1913,7 +1925,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) { // on the profile to favor more inlining. This is only a problem with CS // profile. // 3. Transitive indirect call edges due to inlining. When a callee function - // (say B) is inlined into into a caller function (say A) in LTO prelink, + // (say B) is inlined into a caller function (say A) in LTO prelink, // every call edge originated from the callee B will be transferred to // the caller A. 
If any transferred edge (say A->C) is indirect, the // original profiled indirect edge B->C, even if considered, would not @@ -2016,8 +2028,16 @@ bool SampleProfileLoader::doInitialization(Module &M, ProfileAccurateForSymsInList && PSL && !ProfileSampleAccurate; if (ProfAccForSymsInList) { NamesInProfile.clear(); - if (auto NameTable = Reader->getNameTable()) - NamesInProfile.insert(NameTable->begin(), NameTable->end()); + GUIDsInProfile.clear(); + if (auto NameTable = Reader->getNameTable()) { + if (FunctionSamples::UseMD5) { + for (auto Name : *NameTable) + GUIDsInProfile.insert(Name.getHashCode()); + } else { + for (auto Name : *NameTable) + NamesInProfile.insert(Name.stringRef()); + } + } CoverageTracker.setProfAccForSymsInList(true); } @@ -2103,77 +2123,200 @@ bool SampleProfileLoader::doInitialization(Module &M, return true; } -void SampleProfileMatcher::countProfileMismatches( - const FunctionSamples &FS, - const std::unordered_set<LineLocation, LineLocationHash> - &MatchedCallsiteLocs, - uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) { +void SampleProfileMatcher::findIRAnchors( + const Function &F, std::map<LineLocation, StringRef> &IRAnchors) { + // For inlined code, recover the original callsite and callee by finding the + // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the + // top-level frame is "main:1", the callsite is "1" and the callee is "foo". 
+ auto FindTopLevelInlinedCallsite = [](const DILocation *DIL) { + assert((DIL && DIL->getInlinedAt()) && "No inlined callsite"); + const DILocation *PrevDIL = nullptr; + do { + PrevDIL = DIL; + DIL = DIL->getInlinedAt(); + } while (DIL->getInlinedAt()); - auto isInvalidLineOffset = [](uint32_t LineOffset) { - return LineOffset & 0x8000; + LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL); + StringRef CalleeName = PrevDIL->getSubprogramLinkageName(); + return std::make_pair(Callsite, CalleeName); }; - // Check if there are any callsites in the profile that does not match to any - // IR callsites, those callsite samples will be discarded. - for (auto &I : FS.getBodySamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) - continue; + auto GetCanonicalCalleeName = [](const CallBase *CB) { + StringRef CalleeName = UnknownIndirectCallee; + if (Function *Callee = CB->getCalledFunction()) + CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); + return CalleeName; + }; + + // Extract profile matching anchors in the IR. + for (auto &BB : F) { + for (auto &I : BB) { + DILocation *DIL = I.getDebugLoc(); + if (!DIL) + continue; + + if (FunctionSamples::ProfileIsProbeBased) { + if (auto Probe = extractProbe(I)) { + // Flatten inlined IR for the matching. + if (DIL->getInlinedAt()) { + IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); + } else { + // Use empty StringRef for basic block probe. + StringRef CalleeName; + if (const auto *CB = dyn_cast<CallBase>(&I)) { + // Skip the probe inst whose callee name is "llvm.pseudoprobe". + if (!isa<IntrinsicInst>(&I)) + CalleeName = GetCanonicalCalleeName(CB); + } + IRAnchors.emplace(LineLocation(Probe->Id, 0), CalleeName); + } + } + } else { + // TODO: For line-number based profile(AutoFDO), currently only support + // find callsite anchors. In future, we need to parse all the non-call + // instructions to extract the line locations for profile matching. 
+ if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) + continue; - uint64_t Count = I.second.getSamples(); - if (!I.second.getCallTargets().empty()) { - TotalCallsiteSamples += Count; - FuncProfiledCallsites++; - if (!MatchedCallsiteLocs.count(Loc)) { - MismatchedCallsiteSamples += Count; - FuncMismatchedCallsites++; + if (DIL->getInlinedAt()) { + IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); + } else { + LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL); + StringRef CalleeName = GetCanonicalCalleeName(dyn_cast<CallBase>(&I)); + IRAnchors.emplace(Callsite, CalleeName); + } } } } +} - for (auto &I : FS.getCallsiteSamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) - continue; +void SampleProfileMatcher::countMismatchedSamples(const FunctionSamples &FS) { + const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID()); + // Skip the function that is external or renamed. + if (!FuncDesc) + return; + + if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) { + MismatchedFuncHashSamples += FS.getTotalSamples(); + return; + } + for (const auto &I : FS.getCallsiteSamples()) + for (const auto &CS : I.second) + countMismatchedSamples(CS.second); +} + +void SampleProfileMatcher::countProfileMismatches( + const Function &F, const FunctionSamples &FS, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors) { + [[maybe_unused]] bool IsFuncHashMismatch = false; + if (FunctionSamples::ProfileIsProbeBased) { + TotalFuncHashSamples += FS.getTotalSamples(); + TotalProfiledFunc++; + const auto *FuncDesc = ProbeManager->getDesc(F); + if (FuncDesc) { + if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) { + NumMismatchedFuncHash++; + IsFuncHashMismatch = true; + } + countMismatchedSamples(FS); + } + } + + uint64_t FuncMismatchedCallsites = 0; + uint64_t FuncProfiledCallsites = 0; + countProfileCallsiteMismatches(FS, IRAnchors, 
ProfileAnchors, + FuncMismatchedCallsites, + FuncProfiledCallsites); + TotalProfiledCallsites += FuncProfiledCallsites; + NumMismatchedCallsites += FuncMismatchedCallsites; + LLVM_DEBUG({ + if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch && + FuncMismatchedCallsites) + dbgs() << "Function checksum is matched but there are " + << FuncMismatchedCallsites << "/" << FuncProfiledCallsites + << " mismatched callsites.\n"; + }); +} + +void SampleProfileMatcher::countProfileCallsiteMismatches( + const FunctionSamples &FS, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, + uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) { + + // Check if there are any callsites in the profile that does not match to any + // IR callsites, those callsite samples will be discarded. + for (const auto &I : ProfileAnchors) { + const auto &Loc = I.first; + const auto &Callees = I.second; + assert(!Callees.empty() && "Callees should not be empty"); + + StringRef IRCalleeName; + const auto &IR = IRAnchors.find(Loc); + if (IR != IRAnchors.end()) + IRCalleeName = IR->second; - uint64_t Count = 0; - for (auto &FM : I.second) { - Count += FM.second.getHeadSamplesEstimate(); + // Compute number of samples in the original profile. + uint64_t CallsiteSamples = 0; + auto CTM = FS.findCallTargetMapAt(Loc); + if (CTM) { + for (const auto &I : CTM.get()) + CallsiteSamples += I.second; } - TotalCallsiteSamples += Count; + const auto *FSMap = FS.findFunctionSamplesMapAt(Loc); + if (FSMap) { + for (const auto &I : *FSMap) + CallsiteSamples += I.second.getTotalSamples(); + } + + bool CallsiteIsMatched = false; + // Since indirect call does not have CalleeName, check conservatively if + // callsite in the profile is a callsite location. This is to reduce num of + // false positive since otherwise all the indirect call samples will be + // reported as mismatching. 
+ if (IRCalleeName == UnknownIndirectCallee) + CallsiteIsMatched = true; + else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName))) + CallsiteIsMatched = true; + FuncProfiledCallsites++; - if (!MatchedCallsiteLocs.count(Loc)) { - MismatchedCallsiteSamples += Count; + TotalCallsiteSamples += CallsiteSamples; + if (!CallsiteIsMatched) { FuncMismatchedCallsites++; + MismatchedCallsiteSamples += CallsiteSamples; } } } -// Populate the anchors(direct callee name) from profile. -void SampleProfileMatcher::populateProfileCallsites( - const FunctionSamples &FS, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap) { +void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS, + std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) { + auto isInvalidLineOffset = [](uint32_t LineOffset) { + return LineOffset & 0x8000; + }; + for (const auto &I : FS.getBodySamples()) { - const auto &Loc = I.first; - const auto &CTM = I.second.getCallTargets(); - // Filter out possible indirect calls, use direct callee name as anchor. - if (CTM.size() == 1) { - StringRef CalleeName = CTM.begin()->first(); - const auto &Candidates = CalleeToCallsitesMap.try_emplace( - CalleeName, std::set<LineLocation>()); - Candidates.first->second.insert(Loc); + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + for (const auto &I : I.second.getCallTargets()) { + auto Ret = ProfileAnchors.try_emplace(Loc, + std::unordered_set<FunctionId>()); + Ret.first->second.insert(I.first); } } for (const auto &I : FS.getCallsiteSamples()) { const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; const auto &CalleeMap = I.second; - // Filter out possible indirect calls, use direct callee name as anchor. 
- if (CalleeMap.size() == 1) { - StringRef CalleeName = CalleeMap.begin()->first; - const auto &Candidates = CalleeToCallsitesMap.try_emplace( - CalleeName, std::set<LineLocation>()); - Candidates.first->second.insert(Loc); + for (const auto &I : CalleeMap) { + auto Ret = ProfileAnchors.try_emplace(Loc, + std::unordered_set<FunctionId>()); + Ret.first->second.insert(I.first); } } } @@ -2196,12 +2339,30 @@ void SampleProfileMatcher::populateProfileCallsites( // [1, 2, 3(foo), 4, 7, 8(bar), 9] // The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9]. void SampleProfileMatcher::runStaleProfileMatching( - const std::map<LineLocation, StringRef> &IRLocations, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap, + const Function &F, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, LocToLocMap &IRToProfileLocationMap) { + LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName() + << "\n"); assert(IRToProfileLocationMap.empty() && "Run stale profile matching only once per function"); + std::unordered_map<FunctionId, std::set<LineLocation>> + CalleeToCallsitesMap; + for (const auto &I : ProfileAnchors) { + const auto &Loc = I.first; + const auto &Callees = I.second; + // Filter out possible indirect calls, use direct callee name as anchor. + if (Callees.size() == 1) { + FunctionId CalleeName = *Callees.begin(); + const auto &Candidates = CalleeToCallsitesMap.try_emplace( + CalleeName, std::set<LineLocation>()); + Candidates.first->second.insert(Loc); + } + } + auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) { // Skip the unchanged location mapping to save memory. 
if (From != To) @@ -2212,18 +2373,19 @@ void SampleProfileMatcher::runStaleProfileMatching( int32_t LocationDelta = 0; SmallVector<LineLocation> LastMatchedNonAnchors; - for (const auto &IR : IRLocations) { + for (const auto &IR : IRAnchors) { const auto &Loc = IR.first; - StringRef CalleeName = IR.second; + auto CalleeName = IR.second; bool IsMatchedAnchor = false; // Match the anchor location in lexical order. if (!CalleeName.empty()) { - auto ProfileAnchors = CalleeToCallsitesMap.find(CalleeName); - if (ProfileAnchors != CalleeToCallsitesMap.end() && - !ProfileAnchors->second.empty()) { - auto CI = ProfileAnchors->second.begin(); + auto CandidateAnchors = CalleeToCallsitesMap.find( + getRepInFormat(CalleeName)); + if (CandidateAnchors != CalleeToCallsitesMap.end() && + !CandidateAnchors->second.empty()) { + auto CI = CandidateAnchors->second.begin(); const auto Candidate = *CI; - ProfileAnchors->second.erase(CI); + CandidateAnchors->second.erase(CI); InsertMatching(Loc, Candidate); LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName << " is matched from " << Loc << " to " << Candidate @@ -2261,122 +2423,56 @@ void SampleProfileMatcher::runStaleProfileMatching( } } -void SampleProfileMatcher::runOnFunction(const Function &F, - const FunctionSamples &FS) { - bool IsFuncHashMismatch = false; - if (FunctionSamples::ProfileIsProbeBased) { - uint64_t Count = FS.getTotalSamples(); - TotalFuncHashSamples += Count; - TotalProfiledFunc++; - if (!ProbeManager->profileIsValid(F, FS)) { - MismatchedFuncHashSamples += Count; - NumMismatchedFuncHash++; - IsFuncHashMismatch = true; - } - } - - std::unordered_set<LineLocation, LineLocationHash> MatchedCallsiteLocs; - // The value of the map is the name of direct callsite and use empty StringRef - // for non-direct-call site. - std::map<LineLocation, StringRef> IRLocations; - - // Extract profile matching anchors and profile mismatch metrics in the IR. 
- for (auto &BB : F) { - for (auto &I : BB) { - // TODO: Support line-number based location(AutoFDO). - if (FunctionSamples::ProfileIsProbeBased && isa<PseudoProbeInst>(&I)) { - if (std::optional<PseudoProbe> Probe = extractProbe(I)) - IRLocations.emplace(LineLocation(Probe->Id, 0), StringRef()); - } - - if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) - continue; - - const auto *CB = dyn_cast<CallBase>(&I); - if (auto &DLoc = I.getDebugLoc()) { - LineLocation IRCallsite = FunctionSamples::getCallSiteIdentifier(DLoc); - - StringRef CalleeName; - if (Function *Callee = CB->getCalledFunction()) - CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); - - // Force to overwrite the callee name in case any non-call location was - // written before. - auto R = IRLocations.emplace(IRCallsite, CalleeName); - R.first->second = CalleeName; - assert((!FunctionSamples::ProfileIsProbeBased || R.second || - R.first->second == CalleeName) && - "Overwrite non-call or different callee name location for " - "pseudo probe callsite"); +void SampleProfileMatcher::runOnFunction(const Function &F) { + // We need to use flattened function samples for matching. + // Unlike IR, which includes all callsites from the source code, the callsites + // in profile only show up when they are hit by samples, i,e. the profile + // callsites in one context may differ from those in another context. To get + // the maximum number of callsites, we merge the function profiles from all + // contexts, aka, the flattened profile to find profile anchors. + const auto *FSFlattened = getFlattenedSamplesFor(F); + if (!FSFlattened) + return; - // Go through all the callsites on the IR and flag the callsite if the - // target name is the same as the one in the profile. - const auto CTM = FS.findCallTargetMapAt(IRCallsite); - const auto CallsiteFS = FS.findFunctionSamplesMapAt(IRCallsite); - - // Indirect call case. 
- if (CalleeName.empty()) { - // Since indirect call does not have the CalleeName, check - // conservatively if callsite in the profile is a callsite location. - // This is to avoid nums of false positive since otherwise all the - // indirect call samples will be reported as mismatching. - if ((CTM && !CTM->empty()) || (CallsiteFS && !CallsiteFS->empty())) - MatchedCallsiteLocs.insert(IRCallsite); - } else { - // Check if the call target name is matched for direct call case. - if ((CTM && CTM->count(CalleeName)) || - (CallsiteFS && CallsiteFS->count(CalleeName))) - MatchedCallsiteLocs.insert(IRCallsite); - } - } - } - } + // Anchors for IR. It's a map from IR location to callee name, callee name is + // empty for non-call instruction and use a dummy name(UnknownIndirectCallee) + // for unknown indrect callee name. + std::map<LineLocation, StringRef> IRAnchors; + findIRAnchors(F, IRAnchors); + // Anchors for profile. It's a map from callsite location to a set of callee + // name. + std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors; + findProfileAnchors(*FSFlattened, ProfileAnchors); // Detect profile mismatch for profile staleness metrics report. - if (ReportProfileStaleness || PersistProfileStaleness) { - uint64_t FuncMismatchedCallsites = 0; - uint64_t FuncProfiledCallsites = 0; - countProfileMismatches(FS, MatchedCallsiteLocs, FuncMismatchedCallsites, - FuncProfiledCallsites); - TotalProfiledCallsites += FuncProfiledCallsites; - NumMismatchedCallsites += FuncMismatchedCallsites; - LLVM_DEBUG({ - if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch && - FuncMismatchedCallsites) - dbgs() << "Function checksum is matched but there are " - << FuncMismatchedCallsites << "/" << FuncProfiledCallsites - << " mismatched callsites.\n"; - }); + // Skip reporting the metrics for imported functions. 
+ if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) && + (ReportProfileStaleness || PersistProfileStaleness)) { + // Use top-level nested FS for counting profile mismatch metrics since + // currently once a callsite is mismatched, all its children profiles are + // dropped. + if (const auto *FS = Reader.getSamplesFor(F)) + countProfileMismatches(F, *FS, IRAnchors, ProfileAnchors); } - if (IsFuncHashMismatch && SalvageStaleProfile) { - LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName() - << "\n"); - - StringMap<std::set<LineLocation>> CalleeToCallsitesMap; - populateProfileCallsites(FS, CalleeToCallsitesMap); - + // Run profile matching for checksum mismatched profile, currently only + // support for pseudo-probe. + if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased && + !ProbeManager->profileIsValid(F, *FSFlattened)) { // The matching result will be saved to IRToProfileLocationMap, create a new // map for each function. - auto &IRToProfileLocationMap = getIRToProfileLocationMap(F); - - runStaleProfileMatching(IRLocations, CalleeToCallsitesMap, - IRToProfileLocationMap); + runStaleProfileMatching(F, IRAnchors, ProfileAnchors, + getIRToProfileLocationMap(F)); } } void SampleProfileMatcher::runOnModule() { + ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles, + FunctionSamples::ProfileIsCS); for (auto &F : M) { if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) continue; - FunctionSamples *FS = nullptr; - if (FlattenProfileForMatching) - FS = getFlattenedSamplesFor(F); - else - FS = Reader.getSamplesFor(F); - if (!FS) - continue; - runOnFunction(F, *FS); + runOnFunction(F); } if (SalvageStaleProfile) distributeIRToProfileLocationMap(); @@ -2424,7 +2520,7 @@ void SampleProfileMatcher::runOnModule() { void SampleProfileMatcher::distributeIRToProfileLocationMap( FunctionSamples &FS) { - const auto ProfileMappings = FuncMappings.find(FS.getName()); + const auto ProfileMappings = 
FuncMappings.find(FS.getFuncName()); if (ProfileMappings != FuncMappings.end()) { FS.setIRToProfileLocationMap(&(ProfileMappings->second)); } @@ -2466,10 +2562,10 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, Function *F = dyn_cast<Function>(N_F.getValue()); if (F == nullptr || OrigName.empty()) continue; - SymbolMap[OrigName] = F; + SymbolMap[FunctionId(OrigName)] = F; StringRef NewName = FunctionSamples::getCanonicalFnName(*F); if (OrigName != NewName && !NewName.empty()) { - auto r = SymbolMap.insert(std::make_pair(NewName, F)); + auto r = SymbolMap.emplace(FunctionId(NewName), F); // Failiing to insert means there is already an entry in SymbolMap, // thus there are multiple functions that are mapped to the same // stripped name. In this case of name conflicting, set the value @@ -2482,11 +2578,11 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, if (Remapper) { if (auto MapName = Remapper->lookUpNameInProfile(OrigName)) { if (*MapName != OrigName && !MapName->empty()) - SymbolMap.insert(std::make_pair(*MapName, F)); + SymbolMap.emplace(FunctionId(*MapName), F); } } } - assert(SymbolMap.count(StringRef()) == 0 && + assert(SymbolMap.count(FunctionId()) == 0 && "No empty StringRef should be added in SymbolMap"); if (ReportProfileStaleness || PersistProfileStaleness || @@ -2550,7 +2646,9 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) // but not cold accumulatively...), so the outline function showing up as // cold in sampled binary will actually not be cold after current build. 
StringRef CanonName = FunctionSamples::getCanonicalFnName(F); - if (NamesInProfile.count(CanonName)) + if ((FunctionSamples::UseMD5 && + GUIDsInProfile.count(Function::getGUID(CanonName))) || + (!FunctionSamples::UseMD5 && NamesInProfile.count(CanonName))) initialEntryCount = -1; } @@ -2571,8 +2669,24 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) if (FunctionSamples::ProfileIsCS) Samples = ContextTracker->getBaseSamplesFor(F); - else + else { Samples = Reader->getSamplesFor(F); + // Try search in previously inlined functions that were split or duplicated + // into base. + if (!Samples) { + StringRef CanonName = FunctionSamples::getCanonicalFnName(F); + auto It = OutlineFunctionSamples.find(FunctionId(CanonName)); + if (It != OutlineFunctionSamples.end()) { + Samples = &It->second; + } else if (auto Remapper = Reader->getRemapper()) { + if (auto RemppedName = Remapper->lookUpNameInProfile(CanonName)) { + It = OutlineFunctionSamples.find(FunctionId(*RemppedName)); + if (It != OutlineFunctionSamples.end()) + Samples = &It->second; + } + } + } + } if (Samples && !Samples->empty()) return emitAnnotations(F); diff --git a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp index 0a42de7224b4..8f0b12d0cfed 100644 --- a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" @@ -95,13 +96,13 @@ void PseudoProbeVerifier::runAfterPass(StringRef PassID, Any IR) { std::string Banner = "\n*** Pseudo Probe Verification After " + PassID.str() + " ***\n"; dbgs() << Banner; - if (const auto **M = any_cast<const Module *>(&IR)) + if (const auto **M = llvm::any_cast<const Module *>(&IR)) runAfterPass(*M); - 
else if (const auto **F = any_cast<const Function *>(&IR)) + else if (const auto **F = llvm::any_cast<const Function *>(&IR)) runAfterPass(*F); - else if (const auto **C = any_cast<const LazyCallGraph::SCC *>(&IR)) + else if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) runAfterPass(*C); - else if (const auto **L = any_cast<const Loop *>(&IR)) + else if (const auto **L = llvm::any_cast<const Loop *>(&IR)) runAfterPass(*L); else llvm_unreachable("Unknown IR unit"); @@ -221,12 +222,26 @@ void SampleProfileProber::computeProbeIdForBlocks() { } void SampleProfileProber::computeProbeIdForCallsites() { + LLVMContext &Ctx = F->getContext(); + Module *M = F->getParent(); + for (auto &BB : *F) { for (auto &I : BB) { if (!isa<CallBase>(I)) continue; if (isa<IntrinsicInst>(&I)) continue; + + // The current implementation uses the lower 16 bits of the discriminator + // so anything larger than 0xFFFF will be ignored. + if (LastProbeId >= 0xFFFF) { + std::string Msg = "Pseudo instrumentation incomplete for " + + std::string(F->getName()) + " because it's too large"; + Ctx.diagnose( + DiagnosticInfoSampleProfile(M->getName().data(), Msg, DS_Warning)); + return; + } + CallProbeIds[&I] = ++LastProbeId; } } diff --git a/llvm/lib/Transforms/IPO/StripSymbols.cpp b/llvm/lib/Transforms/IPO/StripSymbols.cpp index 147513452789..28d7d4ba6b01 100644 --- a/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -30,12 +30,18 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/TypeFinder.h" #include "llvm/IR/ValueSymbolTable.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +static cl::opt<bool> + StripGlobalConstants("strip-global-constants", cl::init(false), cl::Hidden, + cl::desc("Removes debug compile units which reference " + "to non-existing global constants")); + /// OnlyUsedBy - Return true if V 
is only used by Usr. static bool OnlyUsedBy(Value *V, Value *Usr) { for (User *U : V->users()) @@ -73,7 +79,7 @@ static void StripSymtab(ValueSymbolTable &ST, bool PreserveDbgInfo) { Value *V = VI->getValue(); ++VI; if (!isa<GlobalValue>(V) || cast<GlobalValue>(V)->hasLocalLinkage()) { - if (!PreserveDbgInfo || !V->getName().startswith("llvm.dbg")) + if (!PreserveDbgInfo || !V->getName().starts_with("llvm.dbg")) // Set name to "", removing from symbol table! V->setName(""); } @@ -88,7 +94,7 @@ static void StripTypeNames(Module &M, bool PreserveDbgInfo) { for (StructType *STy : StructTypes) { if (STy->isLiteral() || STy->getName().empty()) continue; - if (PreserveDbgInfo && STy->getName().startswith("llvm.dbg")) + if (PreserveDbgInfo && STy->getName().starts_with("llvm.dbg")) continue; STy->setName(""); @@ -118,13 +124,13 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) { for (GlobalVariable &GV : M.globals()) { if (GV.hasLocalLinkage() && !llvmUsedValues.contains(&GV)) - if (!PreserveDbgInfo || !GV.getName().startswith("llvm.dbg")) + if (!PreserveDbgInfo || !GV.getName().starts_with("llvm.dbg")) GV.setName(""); // Internal symbols can't participate in linkage } for (Function &I : M) { if (I.hasLocalLinkage() && !llvmUsedValues.contains(&I)) - if (!PreserveDbgInfo || !I.getName().startswith("llvm.dbg")) + if (!PreserveDbgInfo || !I.getName().starts_with("llvm.dbg")) I.setName(""); // Internal symbols can't participate in linkage if (auto *Symtab = I.getValueSymbolTable()) StripSymtab(*Symtab, PreserveDbgInfo); @@ -216,7 +222,8 @@ static bool stripDeadDebugInfoImpl(Module &M) { // Create our live global variable list. bool GlobalVariableChange = false; for (auto *DIG : DIC->getGlobalVariables()) { - if (DIG->getExpression() && DIG->getExpression()->isConstant()) + if (DIG->getExpression() && DIG->getExpression()->isConstant() && + !StripGlobalConstants) LiveGVs.insert(DIG); // Make sure we only visit each global variable only once. 
diff --git a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index d46f9a6c6757..f6f895676084 100644 --- a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -111,7 +111,7 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, // Now compute the callsite count from relative frequency and // entry count: BasicBlock *CSBB = CB.getParent(); - Scaled64 EntryFreq(BFI.getEntryFreq(), 0); + Scaled64 EntryFreq(BFI.getEntryFreq().getFrequency(), 0); Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0); BBCount /= EntryFreq; BBCount *= Counts[Caller]; diff --git a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index fc1e70b1b3d3..e5f9fa1dda88 100644 --- a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -186,7 +186,7 @@ void simplifyExternals(Module &M) { if (!F.isDeclaration() || F.getFunctionType() == EmptyFT || // Changing the type of an intrinsic may invalidate the IR. - F.getName().startswith("llvm.")) + F.getName().starts_with("llvm.")) continue; Function *NewF = @@ -198,7 +198,7 @@ void simplifyExternals(Module &M) { AttributeList::FunctionIndex, F.getAttributes().getFnAttrs())); NewF->takeName(&F); - F.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, F.getType())); + F.replaceAllUsesWith(NewF); F.eraseFromParent(); } @@ -329,7 +329,7 @@ void splitAndWriteThinLTOBitcode( // comdat in MergedM to keep the comdat together. 
DenseSet<const Comdat *> MergedMComdats; for (GlobalVariable &GV : M.globals()) - if (HasTypeMetadata(&GV)) { + if (!GV.isDeclaration() && HasTypeMetadata(&GV)) { if (const auto *C = GV.getComdat()) MergedMComdats.insert(C); forEachVirtualFunction(GV.getInitializer(), [&](Function *F) { diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index d33258642365..85afc020dbf8 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -58,7 +58,6 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -369,8 +368,6 @@ template <> struct DenseMapInfo<VTableSlotSummary> { } // end namespace llvm -namespace { - // Returns true if the function must be unreachable based on ValueInfo. // // In particular, identifies a function as unreachable in the following @@ -378,7 +375,7 @@ namespace { // 1) All summaries are live. // 2) All function summaries indicate it's unreachable // 3) There is no non-function with the same GUID (which is rare) -bool mustBeUnreachableFunction(ValueInfo TheFnVI) { +static bool mustBeUnreachableFunction(ValueInfo TheFnVI) { if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) { // Returns false if ValueInfo is absent, or the summary list is empty // (e.g., function declarations). @@ -403,6 +400,7 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) { return true; } +namespace { // A virtual call site. VTable is the loaded virtual table pointer, and CS is // the indirect virtual call. 
struct VirtualCallSite { @@ -590,7 +588,7 @@ struct DevirtModule { : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree), ExportSummary(ExportSummary), ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), - Int8PtrTy(Type::getInt8PtrTy(M.getContext())), + Int8PtrTy(PointerType::getUnqual(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())), Int64Ty(Type::getInt64Ty(M.getContext())), IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), @@ -776,20 +774,59 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return PreservedAnalyses::none(); } -namespace llvm { // Enable whole program visibility if enabled by client (e.g. linker) or // internal option, and not force disabled. -bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { +bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) && !DisableWholeProgramVisibility; } +static bool +typeIDVisibleToRegularObj(StringRef TypeID, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + // TypeID for member function pointer type is an internal construct + // and won't exist in IsVisibleToRegularObj. The full TypeID + // will be present and participate in invalidation. + if (TypeID.ends_with(".virtual")) + return false; + + // TypeID that doesn't start with Itanium mangling (_ZTS) will be + // non-externally visible types which cannot interact with + // external native files. See CodeGenModule::CreateMetadataIdentifierImpl. + if (!TypeID.consume_front("_ZTS")) + return false; + + // TypeID is keyed off the type name symbol (_ZTS). However, the native + // object may not contain this symbol if it does not contain a key + // function for the base type and thus only contains a reference to the + // type info (_ZTI). To catch this case we query using the type info + // symbol corresponding to the TypeID. 
+ std::string typeInfo = ("_ZTI" + TypeID).str(); + return IsVisibleToRegularObj(typeInfo); +} + +static bool +skipUpdateDueToValidation(GlobalVariable &GV, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + SmallVector<MDNode *, 2> Types; + GV.getMetadata(LLVMContext::MD_type, Types); + + for (auto Type : Types) + if (auto *TypeID = dyn_cast<MDString>(Type->getOperand(1).get())) + return typeIDVisibleToRegularObj(TypeID->getString(), + IsVisibleToRegularObj); + + return false; +} + /// If whole program visibility asserted, then upgrade all public vcall /// visibility metadata on vtable definitions to linkage unit visibility in /// Module IR (for regular or hybrid LTO). -void updateVCallVisibilityInModule( +void llvm::updateVCallVisibilityInModule( Module &M, bool WholeProgramVisibilityEnabledInLTO, - const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) { + const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, + bool ValidateAllVtablesHaveTypeInfos, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) return; for (GlobalVariable &GV : M.globals()) { @@ -800,13 +837,19 @@ void updateVCallVisibilityInModule( GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic && // Don't upgrade the visibility for symbols exported to the dynamic // linker, as we have no information on their eventual use. - !DynamicExportSymbols.count(GV.getGUID())) + !DynamicExportSymbols.count(GV.getGUID()) && + // With validation enabled, we want to exclude symbols visible to + // regular objects. Local symbols will be in this group due to the + // current implementation but those with VCallVisibilityTranslationUnit + // will have already been marked in clang so are unaffected. 
+ !(ValidateAllVtablesHaveTypeInfos && + skipUpdateDueToValidation(GV, IsVisibleToRegularObj))) GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit); } } -void updatePublicTypeTestCalls(Module &M, - bool WholeProgramVisibilityEnabledInLTO) { +void llvm::updatePublicTypeTestCalls(Module &M, + bool WholeProgramVisibilityEnabledInLTO) { Function *PublicTypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::public_type_test)); if (!PublicTypeTestFunc) @@ -832,12 +875,26 @@ void updatePublicTypeTestCalls(Module &M, } } +/// Based on typeID string, get all associated vtable GUIDS that are +/// visible to regular objects. +void llvm::getVisibleToRegularObjVtableGUIDs( + ModuleSummaryIndex &Index, + DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + for (const auto &typeID : Index.typeIdCompatibleVtableMap()) { + if (typeIDVisibleToRegularObj(typeID.first, IsVisibleToRegularObj)) + for (const TypeIdOffsetVtableInfo &P : typeID.second) + VisibleToRegularObjSymbols.insert(P.VTableVI.getGUID()); + } +} + /// If whole program visibility asserted, then upgrade all public vcall /// visibility metadata on vtable definition summaries to linkage unit /// visibility in Module summary index (for ThinLTO). -void updateVCallVisibilityInIndex( +void llvm::updateVCallVisibilityInIndex( ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO, - const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) { + const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, + const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) { if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) return; for (auto &P : Index) { @@ -850,18 +907,24 @@ void updateVCallVisibilityInIndex( if (!GVar || GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) continue; + // With validation enabled, we want to exclude symbols visible to regular + // objects. 
Local symbols will be in this group due to the current + // implementation but those with VCallVisibilityTranslationUnit will have + // already been marked in clang so are unaffected. + if (VisibleToRegularObjSymbols.count(P.first)) + continue; GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); } } } -void runWholeProgramDevirtOnIndex( +void llvm::runWholeProgramDevirtOnIndex( ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run(); } -void updateIndexWPDForExports( +void llvm::updateIndexWPDForExports( ModuleSummaryIndex &Summary, function_ref<bool(StringRef, ValueInfo)> isExported, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { @@ -887,8 +950,6 @@ void updateIndexWPDForExports( } } -} // end namespace llvm - static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { // Check that summary index contains regular LTO module when performing // export to prevent occasional use of index from pure ThinLTO compilation @@ -942,7 +1003,7 @@ bool DevirtModule::runForTesting( ExitOnError ExitOnErr( "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); std::error_code EC; - if (StringRef(ClWriteSummary).endswith(".bc")) { + if (StringRef(ClWriteSummary).ends_with(".bc")) { raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None); ExitOnErr(errorCodeToError(EC)); writeIndexToFile(*Summary, OS); @@ -1045,8 +1106,8 @@ bool DevirtModule::tryFindVirtualCallTargets( } bool DevirtIndex::tryFindVirtualCallTargets( - std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo, - uint64_t ByteOffset) { + std::vector<ValueInfo> &TargetsForSlot, + const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) { for (const TypeIdOffsetVtableInfo &P : TIdInfo) { // Find a representative copy of the vtable initializer. 
// We can have multiple available_externally, linkonce_odr and weak_odr @@ -1203,7 +1264,8 @@ static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) { // to better ensure we have the opportunity to inline them. bool IsExported = false; auto &S = Callee.getSummaryList()[0]; - CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0); + CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false, + /* RelBF = */ 0); auto AddCalls = [&](CallSiteInfo &CSInfo) { for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) { FS->addCall({Callee, CI}); @@ -1437,7 +1499,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, IRBuilder<> IRB(&CB); std::vector<Value *> Args; - Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); + Args.push_back(VCallSite.VTable); llvm::append_range(Args, CB.args()); CallBase *NewCS = nullptr; @@ -1471,10 +1533,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // llvm.type.test and therefore require an llvm.type.test resolution for the // type identifier. 
- std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) { - CBs.first->replaceAllUsesWith(CBs.second); - CBs.first->eraseFromParent(); - }); + for (auto &[Old, New] : CallBases) { + Old->replaceAllUsesWith(New); + Old->eraseFromParent(); + } }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) @@ -1648,8 +1710,7 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, } Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) { - Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy); - return ConstantExpr::getGetElementPtr(Int8Ty, C, + return ConstantExpr::getGetElementPtr(Int8Ty, M->Bits->GV, ConstantInt::get(Int64Ty, M->Offset)); } @@ -1708,8 +1769,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, continue; auto *RetType = cast<IntegerType>(Call.CB.getType()); IRBuilder<> B(&Call.CB); - Value *Addr = - B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte); + Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); if (RetType->getBitWidth() == 1) { Value *Bits = B.CreateLoad(Int8Ty, Addr); Value *BitsAndBit = B.CreateAnd(Bits, Bit); @@ -2007,17 +2067,14 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { if (TypeCheckedLoadFunc->getIntrinsicID() == Intrinsic::type_checked_load_relative) { Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty)); - LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr); + LoadedValue = LoadB.CreateLoad(Int32Ty, GEP); LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy); GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy); LoadedValue = LoadB.CreateAdd(GEP, LoadedValue); LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy); } else { Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = - LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); - LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + LoadedValue = 
LoadB.CreateLoad(Int8PtrTy, GEP); } for (Instruction *LoadedPtr : LoadedPtrs) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 91ca44e0f11e..719a2678fc18 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -830,15 +830,15 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add, // (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C) Constant *NarrowC; if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) { - Constant *WideC = ConstantExpr::getSExt(NarrowC, Ty); - Constant *NewC = ConstantExpr::getAdd(WideC, Op1C); + Value *WideC = Builder.CreateSExt(NarrowC, Ty); + Value *NewC = Builder.CreateAdd(WideC, Op1C); Value *WideX = Builder.CreateSExt(X, Ty); return BinaryOperator::CreateAdd(WideX, NewC); } // (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C) if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) { - Constant *WideC = ConstantExpr::getZExt(NarrowC, Ty); - Constant *NewC = ConstantExpr::getAdd(WideC, Op1C); + Value *WideC = Builder.CreateZExt(NarrowC, Ty); + Value *NewC = Builder.CreateAdd(WideC, Op1C); Value *WideX = Builder.CreateZExt(X, Ty); return BinaryOperator::CreateAdd(WideX, NewC); } @@ -903,8 +903,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { // (X | Op01C) + Op1C --> X + (Op01C + Op1C) iff the `or` is actually an `add` Constant *Op01C; - if (match(Op0, m_Or(m_Value(X), m_ImmConstant(Op01C))) && - haveNoCommonBitsSet(X, Op01C, DL, &AC, &Add, &DT)) + if (match(Op0, m_DisjointOr(m_Value(X), m_ImmConstant(Op01C)))) return BinaryOperator::CreateAdd(X, ConstantExpr::getAdd(Op01C, Op1C)); // (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C) @@ -995,6 +994,69 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { return nullptr; } +// match variations of a^2 + 2*a*b + b^2 +// +// 
to reuse the code between the FP and Int versions, the instruction OpCodes +// and constant types have been turned into template parameters. +// +// Mul2Rhs: The constant to perform the multiplicative equivalent of X*2 with; +// should be `m_SpecificFP(2.0)` for FP and `m_SpecificInt(1)` for Int +// (we're matching `X<<1` instead of `X*2` for Int) +template <bool FP, typename Mul2Rhs> +static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A, + Value *&B) { + constexpr unsigned MulOp = FP ? Instruction::FMul : Instruction::Mul; + constexpr unsigned AddOp = FP ? Instruction::FAdd : Instruction::Add; + constexpr unsigned Mul2Op = FP ? Instruction::FMul : Instruction::Shl; + + // (a * a) + (((a * 2) + b) * b) + if (match(&I, m_c_BinOp( + AddOp, m_OneUse(m_BinOp(MulOp, m_Value(A), m_Deferred(A))), + m_OneUse(m_BinOp( + MulOp, + m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs), + m_Value(B)), + m_Deferred(B)))))) + return true; + + // ((a * b) * 2) or ((a * 2) * b) + // + + // (a * a + b * b) or (b * b + a * a) + return match( + &I, + m_c_BinOp(AddOp, + m_CombineOr( + m_OneUse(m_BinOp( + Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)), + m_OneUse(m_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs), + m_Value(B)))), + m_OneUse(m_c_BinOp( + AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)), + m_BinOp(MulOp, m_Deferred(B), m_Deferred(B)))))); +} + +// Fold integer variations of a^2 + 2*a*b + b^2 -> (a + b)^2 +Instruction *InstCombinerImpl::foldSquareSumInt(BinaryOperator &I) { + Value *A, *B; + if (matchesSquareSum</*FP*/ false>(I, m_SpecificInt(1), A, B)) { + Value *AB = Builder.CreateAdd(A, B); + return BinaryOperator::CreateMul(AB, AB); + } + return nullptr; +} + +// Fold floating point variations of a^2 + 2*a*b + b^2 -> (a + b)^2 +// Requires `nsz` and `reassoc`. 
+Instruction *InstCombinerImpl::foldSquareSumFP(BinaryOperator &I) { + assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch"); + Value *A, *B; + if (matchesSquareSum</*FP*/ true>(I, m_SpecificFP(2.0), A, B)) { + Value *AB = Builder.CreateFAddFMF(A, B, &I); + return BinaryOperator::CreateFMulFMF(AB, AB, &I); + } + return nullptr; +} + // Matches multiplication expression Op * C where C is a constant. Returns the // constant value in C and the other operand in Op. Returns true if such a // match is found. @@ -1146,6 +1208,21 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { return nullptr; } +// Transform: +// (add A, (shl (neg B), Y)) +// -> (sub A, (shl B, Y)) +static Instruction *combineAddSubWithShlAddSub(InstCombiner::BuilderTy &Builder, + const BinaryOperator &I) { + Value *A, *B, *Cnt; + if (match(&I, + m_c_Add(m_OneUse(m_Shl(m_OneUse(m_Neg(m_Value(B))), m_Value(Cnt))), + m_Value(A)))) { + Value *NewShl = Builder.CreateShl(B, Cnt); + return BinaryOperator::CreateSub(A, NewShl); + } + return nullptr; +} + /// Try to reduce signed division by power-of-2 to an arithmetic shift right. static Instruction *foldAddToAshr(BinaryOperator &Add) { // Division must be by power-of-2, but not the minimum signed value. @@ -1156,18 +1233,28 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) { return nullptr; // Rounding is done by adding -1 if the dividend (X) is negative and has any - // low bits set. The canonical pattern for that is an "ugt" compare with SMIN: - // sext (icmp ugt (X & (DivC - 1)), SMIN) - const APInt *MaskC; + // low bits set. It recognizes two canonical patterns: + // 1. For an 'ugt' cmp with the signed minimum value (SMIN), the + // pattern is: sext (icmp ugt (X & (DivC - 1)), SMIN). + // 2. For an 'eq' cmp, the pattern's: sext (icmp eq X & (SMIN + 1), SMIN + 1). + // Note that, by the time we end up here, if possible, ugt has been + // canonicalized into eq. 
+ const APInt *MaskC, *MaskCCmp; ICmpInst::Predicate Pred; if (!match(Add.getOperand(1), m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)), - m_SignMask()))) || - Pred != ICmpInst::ICMP_UGT) + m_APInt(MaskCCmp))))) + return nullptr; + + if ((Pred != ICmpInst::ICMP_UGT || !MaskCCmp->isSignMask()) && + (Pred != ICmpInst::ICMP_EQ || *MaskCCmp != *MaskC)) return nullptr; APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits()); - if (*MaskC != (SMin | (*DivC - 1))) + bool IsMaskValid = Pred == ICmpInst::ICMP_UGT + ? (*MaskC == (SMin | (*DivC - 1))) + : (*DivC == 2 && *MaskC == SMin + 1); + if (!IsMaskValid) return nullptr; // (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC) @@ -1327,8 +1414,10 @@ static Instruction *foldBoxMultiply(BinaryOperator &I) { // ResLo = (CrossSum << HalfBits) + (YLo * XLo) Value *XLo, *YLo; Value *CrossSum; + // Require one-use on the multiply to avoid increasing the number of + // multiplications. if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)), - m_Mul(m_Value(YLo), m_Value(XLo))))) + m_OneUse(m_Mul(m_Value(YLo), m_Value(XLo)))))) return nullptr; // XLo = X & HalfMask @@ -1386,6 +1475,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Instruction *R = foldBinOpShiftWithShift(I)) return R; + if (Instruction *R = combineAddSubWithShlAddSub(Builder, I)) + return R; + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1)) @@ -1406,7 +1498,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B)); // -A + B --> B - A - return BinaryOperator::CreateSub(RHS, A); + auto *Sub = BinaryOperator::CreateSub(RHS, A); + auto *OB0 = cast<OverflowingBinaryOperator>(LHS); + Sub->setHasNoSignedWrap(I.hasNoSignedWrap() && OB0->hasNoSignedWrap()); + + return Sub; } // A + -B --> A - B @@ -1485,8 +1581,9 @@ Instruction 
*InstCombinerImpl::visitAdd(BinaryOperator &I) { return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); // A+B --> A|B iff A and B have no bits set in common. - if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) - return BinaryOperator::CreateOr(LHS, RHS); + WithCache<const Value *> LHSCache(LHS), RHSCache(RHS); + if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I))) + return BinaryOperator::CreateDisjointOr(LHS, RHS); if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; @@ -1576,15 +1673,33 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { m_c_UMin(m_Deferred(A), m_Deferred(B)))))) return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I); + // (~X) + (~Y) --> -2 - (X + Y) + { + // To ensure we can save instructions we need to ensure that we consume both + // LHS/RHS (i.e they have a `not`). + bool ConsumesLHS, ConsumesRHS; + if (isFreeToInvert(LHS, LHS->hasOneUse(), ConsumesLHS) && ConsumesLHS && + isFreeToInvert(RHS, RHS->hasOneUse(), ConsumesRHS) && ConsumesRHS) { + Value *NotLHS = getFreelyInverted(LHS, LHS->hasOneUse(), &Builder); + Value *NotRHS = getFreelyInverted(RHS, RHS->hasOneUse(), &Builder); + assert(NotLHS != nullptr && NotRHS != nullptr && + "isFreeToInvert desynced with getFreelyInverted"); + Value *LHSPlusRHS = Builder.CreateAdd(NotLHS, NotRHS); + return BinaryOperator::CreateSub(ConstantInt::get(RHS->getType(), -2), + LHSPlusRHS); + } + } + // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. 
bool Changed = false; - if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) { + if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) { + if (!I.hasNoUnsignedWrap() && + willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -1610,11 +1725,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // ctpop(A) + ctpop(B) => ctpop(A | B) if A and B have no bits set in common. if (match(LHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(A)))) && match(RHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(B)))) && - haveNoCommonBitsSet(A, B, DL, &AC, &I, &DT)) + haveNoCommonBitsSet(A, B, SQ.getWithInstruction(&I))) return replaceInstUsesWith( I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, {Builder.CreateOr(A, B)})); + if (Instruction *Res = foldSquareSumInt(I)) + return Res; + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) return Res; @@ -1755,10 +1873,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { // instcombined. if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) if (IsValidPromotion(FPType, LHSIntVal->getType())) { - Constant *CI = - ConstantExpr::getFPToSI(CFP, LHSIntVal->getType()); + Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP, + LHSIntVal->getType(), DL); if (LHSConv->hasOneUse() && - ConstantExpr::getSIToFP(CI, I.getType()) == CFP && + ConstantFoldCastOperand(Instruction::SIToFP, CI, I.getType(), DL) == + CFP && willNotOverflowSignedAdd(LHSIntVal, CI, I)) { // Insert the new integer add. 
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv"); @@ -1794,6 +1913,9 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { if (Instruction *F = factorizeFAddFSub(I, Builder)) return F; + if (Instruction *F = foldSquareSumFP(I)) + return F; + // Try to fold fadd into start value of reduction intrinsic. if (match(&I, m_c_FAdd(m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>( m_AnyZeroFP(), m_Value(X))), @@ -2017,14 +2139,16 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // C-(X+C2) --> (C-C2)-X if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) { - // C-C2 never overflow, and C-(X+C2), (X+C2) has NSW - // => (C-C2)-X can have NSW + // C-C2 never overflow, and C-(X+C2), (X+C2) has NSW/NUW + // => (C-C2)-X can have NSW/NUW bool WillNotSOV = willNotOverflowSignedSub(C, C2, I); BinaryOperator *Res = BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); auto *OBO1 = cast<OverflowingBinaryOperator>(Op1); Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() && WillNotSOV); + Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap() && + OBO1->hasNoUnsignedWrap()); return Res; } } @@ -2058,7 +2182,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) || match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1))); })) { - if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this)) + if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation && + I.hasNoSignedWrap(), + Op1, *this)) return BinaryOperator::CreateAdd(NegOp1, Op0); } if (IsNegation) @@ -2093,19 +2219,50 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // ((X - Y) - Op1) --> X - (Y + Op1) if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) { - Value *Add = Builder.CreateAdd(Y, Op1); - return BinaryOperator::CreateSub(X, Add); + OverflowingBinaryOperator *LHSSub = cast<OverflowingBinaryOperator>(Op0); + bool HasNUW = I.hasNoUnsignedWrap() && 
LHSSub->hasNoUnsignedWrap(); + bool HasNSW = HasNUW && I.hasNoSignedWrap() && LHSSub->hasNoSignedWrap(); + Value *Add = Builder.CreateAdd(Y, Op1, "", /* HasNUW */ HasNUW, + /* HasNSW */ HasNSW); + BinaryOperator *Sub = BinaryOperator::CreateSub(X, Add); + Sub->setHasNoUnsignedWrap(HasNUW); + Sub->setHasNoSignedWrap(HasNSW); + return Sub; + } + + { + // (X + Z) - (Y + Z) --> (X - Y) + // This is done in other passes, but we want to be able to consume this + // pattern in InstCombine so we can generate it without creating infinite + // loops. + if (match(Op0, m_Add(m_Value(X), m_Value(Z))) && + match(Op1, m_c_Add(m_Value(Y), m_Specific(Z)))) + return BinaryOperator::CreateSub(X, Y); + + // (X + C0) - (Y + C1) --> (X - Y) + (C0 - C1) + Constant *CX, *CY; + if (match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(CX)))) && + match(Op1, m_OneUse(m_Add(m_Value(Y), m_ImmConstant(CY))))) { + Value *OpsSub = Builder.CreateSub(X, Y); + Constant *ConstsSub = ConstantExpr::getSub(CX, CY); + return BinaryOperator::CreateAdd(OpsSub, ConstsSub); + } } // (~X) - (~Y) --> Y - X - // This is placed after the other reassociations and explicitly excludes a - // sub-of-sub pattern to avoid infinite looping. - if (isFreeToInvert(Op0, Op0->hasOneUse()) && - isFreeToInvert(Op1, Op1->hasOneUse()) && - !match(Op0, m_Sub(m_ImmConstant(), m_Value()))) { - Value *NotOp0 = Builder.CreateNot(Op0); - Value *NotOp1 = Builder.CreateNot(Op1); - return BinaryOperator::CreateSub(NotOp1, NotOp0); + { + // Need to ensure we can consume at least one of the `not` instructions, + // otherwise this can inf loop. 
+ bool ConsumesOp0, ConsumesOp1; + if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) && + isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) && + (ConsumesOp0 || ConsumesOp1)) { + Value *NotOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder); + Value *NotOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder); + assert(NotOp0 != nullptr && NotOp1 != nullptr && + "isFreeToInvert desynced with getFreelyInverted"); + return BinaryOperator::CreateSub(NotOp1, NotOp0); + } } auto m_AddRdx = [](Value *&Vec) { @@ -2520,18 +2677,33 @@ static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) { return nullptr; } -static Instruction *hoistFNegAboveFMulFDiv(Instruction &I, - InstCombiner::BuilderTy &Builder) { - Value *FNeg; - if (!match(&I, m_FNeg(m_Value(FNeg)))) - return nullptr; - +Instruction *InstCombinerImpl::hoistFNegAboveFMulFDiv(Value *FNegOp, + Instruction &FMFSource) { Value *X, *Y; - if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y))))) - return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + if (match(FNegOp, m_FMul(m_Value(X), m_Value(Y)))) { + return cast<Instruction>(Builder.CreateFMulFMF( + Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource)); + } + + if (match(FNegOp, m_FDiv(m_Value(X), m_Value(Y)))) { + return cast<Instruction>(Builder.CreateFDivFMF( + Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource)); + } + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(FNegOp)) { + // Make sure to preserve flags and metadata on the call. 
+ if (II->getIntrinsicID() == Intrinsic::ldexp) { + FastMathFlags FMF = FMFSource.getFastMathFlags() | II->getFastMathFlags(); + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(FMF); - if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y))))) - return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + CallInst *New = Builder.CreateCall( + II->getCalledFunction(), + {Builder.CreateFNeg(II->getArgOperand(0)), II->getArgOperand(1)}); + New->copyMetadata(*II); + return New; + } + } return nullptr; } @@ -2553,13 +2725,13 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) return BinaryOperator::CreateFSubFMF(Y, X, &I); - if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) - return R; - Value *OneUse; if (!match(Op, m_OneUse(m_Value(OneUse)))) return nullptr; + if (Instruction *R = hoistFNegAboveFMulFDiv(OneUse, I)) + return replaceInstUsesWith(I, R); + // Try to eliminate fneg if at least 1 arm of the select is negated. 
Value *Cond; if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) { @@ -2569,8 +2741,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) { S->copyFastMathFlags(&I); if (auto *OldSel = dyn_cast<SelectInst>(Op)) { - FastMathFlags FMF = I.getFastMathFlags(); - FMF |= OldSel->getFastMathFlags(); + FastMathFlags FMF = I.getFastMathFlags() | OldSel->getFastMathFlags(); S->setFastMathFlags(FMF); if (!OldSel->hasNoSignedZeros() && !CommonOperand && !isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition())) @@ -2638,9 +2809,6 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { if (Instruction *X = foldFNegIntoConstant(I, DL)) return X; - if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) - return R; - Value *X, *Y; Constant *C; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 8a1fb6b7f17e..6002f599ca71 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1099,39 +1099,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, return Builder.CreateICmpUGE(Builder.CreateNeg(B), A); } - Value *Base, *Offset; - if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset)))) - return nullptr; - - if (!match(UnsignedICmp, - m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) || - !ICmpInst::isUnsigned(UnsignedPred)) - return nullptr; - - // Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset - // (no overflow and not null) - if ((UnsignedPred == ICmpInst::ICMP_UGE || - UnsignedPred == ICmpInst::ICMP_UGT) && - EqPred == ICmpInst::ICMP_NE && IsAnd) - return Builder.CreateICmpUGT(Base, Offset); - - // Base <=/< Offset || (Base - Offset) == 0 <--> Base <= Offset - // (overflow or null) - if ((UnsignedPred == ICmpInst::ICMP_ULE || - UnsignedPred == ICmpInst::ICMP_ULT) && - EqPred == 
ICmpInst::ICMP_EQ && !IsAnd) - return Builder.CreateICmpULE(Base, Offset); - - // Base <= Offset && (Base - Offset) != 0 --> Base < Offset - if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE && - IsAnd) - return Builder.CreateICmpULT(Base, Offset); - - // Base > Offset || (Base - Offset) == 0 --> Base >= Offset - if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ && - !IsAnd) - return Builder.CreateICmpUGE(Base, Offset); - return nullptr; } @@ -1179,13 +1146,40 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, return nullptr; CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; - if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred) - return nullptr; + auto GetMatchPart = [&](ICmpInst *Cmp, + unsigned OpNo) -> std::optional<IntPart> { + if (Pred == Cmp->getPredicate()) + return matchIntPart(Cmp->getOperand(OpNo)); + + const APInt *C; + // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to: + // (icmp ult (xor x, y), 1 << C) so also look for that. + if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT) { + if (!match(Cmp->getOperand(1), m_Power2(C)) || + !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value()))) + return std::nullopt; + } - std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0)); - std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1)); - std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0)); - std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1)); + // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to: + // (icmp ugt (xor x, y), (1 << C) - 1) so also look for that. + else if (Pred == CmpInst::ICMP_NE && + Cmp->getPredicate() == CmpInst::ICMP_UGT) { + if (!match(Cmp->getOperand(1), m_LowBitMask(C)) || + !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value()))) + return std::nullopt; + } else { + return std::nullopt; + } + + unsigned From = Pred == CmpInst::ICMP_NE ? 
C->popcount() : C->countr_zero(); + Instruction *I = cast<Instruction>(Cmp->getOperand(0)); + return {{I->getOperand(OpNo), From, C->getBitWidth() - From}}; + }; + + std::optional<IntPart> L0 = GetMatchPart(Cmp0, 0); + std::optional<IntPart> R0 = GetMatchPart(Cmp0, 1); + std::optional<IntPart> L1 = GetMatchPart(Cmp1, 0); + std::optional<IntPart> R1 = GetMatchPart(Cmp1, 1); if (!L0 || !R0 || !L1 || !R1) return nullptr; @@ -1616,7 +1610,7 @@ static Instruction *reassociateFCmps(BinaryOperator &BO, /// (~A & ~B) == (~(A | B)) /// (~A | ~B) == (~(A & B)) static Instruction *matchDeMorgansLaws(BinaryOperator &I, - InstCombiner::BuilderTy &Builder) { + InstCombiner &IC) { const Instruction::BinaryOps Opcode = I.getOpcode(); assert((Opcode == Instruction::And || Opcode == Instruction::Or) && "Trying to match De Morgan's Laws with something other than and/or"); @@ -1629,10 +1623,10 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, Value *A, *B; if (match(Op0, m_OneUse(m_Not(m_Value(A)))) && match(Op1, m_OneUse(m_Not(m_Value(B)))) && - !InstCombiner::isFreeToInvert(A, A->hasOneUse()) && - !InstCombiner::isFreeToInvert(B, B->hasOneUse())) { + !IC.isFreeToInvert(A, A->hasOneUse()) && + !IC.isFreeToInvert(B, B->hasOneUse())) { Value *AndOr = - Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan"); + IC.Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan"); return BinaryOperator::CreateNot(AndOr); } @@ -1644,8 +1638,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, Value *C; if (match(Op0, m_OneUse(m_c_BinOp(Opcode, m_Value(A), m_Not(m_Value(B))))) && match(Op1, m_Not(m_Value(C)))) { - Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C); - return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO)); + Value *FlippedBO = IC.Builder.CreateBinOp(FlippedOpcode, B, C); + return BinaryOperator::Create(Opcode, A, IC.Builder.CreateNot(FlippedBO)); } return nullptr; @@ -1669,7 +1663,7 @@ bool 
InstCombinerImpl::shouldOptimizeCast(CastInst *CI) { /// Fold {and,or,xor} (cast X), C. static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, - InstCombiner::BuilderTy &Builder) { + InstCombinerImpl &IC) { Constant *C = dyn_cast<Constant>(Logic.getOperand(1)); if (!C) return nullptr; @@ -1684,21 +1678,17 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, // instruction may be cheaper (particularly in the case of vectors). Value *X; if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { - Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); - Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy); - if (ZextTruncC == C) { + if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) { // LogicOpc (zext X), C --> zext (LogicOpc X, C) - Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC); + Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC); return new ZExtInst(NewOp, DestTy); } } if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) { - Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); - Constant *SextTruncC = ConstantExpr::getSExt(TruncC, DestTy); - if (SextTruncC == C) { + if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) { // LogicOpc (sext X), C --> sext (LogicOpc X, C) - Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC); + Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC); return new SExtInst(NewOp, DestTy); } } @@ -1756,7 +1746,7 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { if (!SrcTy->isIntOrIntVectorTy()) return nullptr; - if (Instruction *Ret = foldLogicCastConstant(I, Cast0, Builder)) + if (Instruction *Ret = foldLogicCastConstant(I, Cast0, *this)) return Ret; CastInst *Cast1 = dyn_cast<CastInst>(Op1); @@ -1802,29 +1792,6 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { return CastInst::Create(CastOpcode, NewOp, DestTy); } - // For now, only 'and'/'or' have optimizations after this. 
- if (LogicOpc == Instruction::Xor) - return nullptr; - - // If this is logic(cast(icmp), cast(icmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src); - ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src); - if (ICmp0 && ICmp1) { - if (Value *Res = - foldAndOrOfICmps(ICmp0, ICmp1, I, LogicOpc == Instruction::And)) - return CastInst::Create(CastOpcode, Res, DestTy); - return nullptr; - } - - // If this is logic(cast(fcmp), cast(fcmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src); - FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src); - if (FCmp0 && FCmp1) - if (Value *R = foldLogicOfFCmps(FCmp0, FCmp1, LogicOpc == Instruction::And)) - return CastInst::Create(CastOpcode, R, DestTy); - return nullptr; } @@ -2160,10 +2127,10 @@ Instruction *InstCombinerImpl::foldBinOpOfDisplacedShifts(BinaryOperator &I) { Constant *ShiftedC1, *ShiftedC2, *AddC; Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); - if (!match(&I, - m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)), - m_Shift(m_ImmConstant(ShiftedC2), - m_Add(m_Deferred(ShAmt), m_ImmConstant(AddC)))))) + if (!match(&I, m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)), + m_Shift(m_ImmConstant(ShiftedC2), + m_AddLike(m_Deferred(ShAmt), + m_ImmConstant(AddC)))))) return nullptr; // Make sure the add constant is a valid shift amount. @@ -2254,6 +2221,14 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y); } + // Canonicalize: + // (X +/- Y) & Y --> ~X & Y when Y is a power of 2. 
+ if (match(&I, m_c_And(m_Value(Y), m_OneUse(m_CombineOr( + m_c_Add(m_Value(X), m_Deferred(Y)), + m_Sub(m_Value(X), m_Deferred(Y)))))) && + isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, /*Depth*/ 0, &I)) + return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y); + const APInt *C; if (match(Op1, m_APInt(C))) { const APInt *XorC; @@ -2300,13 +2275,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { const APInt *AddC; if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) { - // If we add zeros to every bit below a mask, the add has no effect: - // (X + AddC) & LowMaskC --> X & LowMaskC - unsigned Ctlz = C->countl_zero(); - APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz)); - if ((*AddC & LowMask).isZero()) - return BinaryOperator::CreateAnd(X, Op1); - // If we are masking the result of the add down to exactly one bit and // the constant we are adding has no bits set below that bit, then the // add is flipping a single bit. Example: @@ -2455,6 +2423,28 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { } } + // If we are clearing the sign bit of a floating-point value, convert this to + // fabs, then cast back to integer. + // + // This is a generous interpretation for noimplicitfloat, this is not a true + // floating-point operation. + // + // Assumes any IEEE-represented type has the sign bit in the high bit. + // TODO: Unify with APInt matcher. 
This version allows undef unlike m_APInt + Value *CastOp; + if (match(Op0, m_BitCast(m_Value(CastOp))) && + match(Op1, m_MaxSignedValue()) && + !Builder.GetInsertBlock()->getParent()->hasFnAttribute( + Attribute::NoImplicitFloat)) { + Type *EltTy = CastOp->getType()->getScalarType(); + if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && + EltTy->getPrimitiveSizeInBits() == + I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp); + return new BitCastInst(FAbs, I.getType()); + } + } + if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))), m_SignMask())) && match(Y, m_SpecificInt_ICMP( @@ -2479,21 +2469,21 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (I.getType()->isIntOrIntVectorTy(1)) { if (auto *SI0 = dyn_cast<SelectInst>(Op0)) { - if (auto *I = + if (auto *R = foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ true)) - return I; + return R; } if (auto *SI1 = dyn_cast<SelectInst>(Op1)) { - if (auto *I = + if (auto *R = foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ true)) - return I; + return R; } } if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) return FoldedLogic; - if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) + if (Instruction *DeMorgan = matchDeMorgansLaws(I, *this)) return DeMorgan; { @@ -2513,16 +2503,24 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateAnd(Op1, B); // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C - if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) - if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) - if (Op1->hasOneUse() || isFreeToInvert(C, C->hasOneUse())) - return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C)); + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) { + Value *NotC = Op1->hasOneUse() + ? 
Builder.CreateNot(C) + : getFreelyInverted(C, C->hasOneUse(), &Builder); + if (NotC != nullptr) + return BinaryOperator::CreateAnd(Op0, NotC); + } // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C - if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) - if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) - if (Op0->hasOneUse() || isFreeToInvert(C, C->hasOneUse())) - return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C)); + if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))) && + match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) { + Value *NotC = Op0->hasOneUse() + ? Builder.CreateNot(C) + : getFreelyInverted(C, C->hasOneUse(), &Builder); + if (NotC != nullptr) + return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C)); + } // (A | B) & (~A ^ B) -> A & B // (A | B) & (B ^ ~A) -> A & B @@ -2621,23 +2619,34 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // with binop identity constant. But creating a select with non-constant // arm may not be reversible due to poison semantics. Is that a good // canonicalization? - Value *A; - if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && - A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, Op1, Constant::getNullValue(Ty)); - if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && + Value *A, *B; + if (match(&I, m_c_And(m_OneUse(m_SExt(m_Value(A))), m_Value(B))) && A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, Op0, Constant::getNullValue(Ty)); + return SelectInst::Create(A, B, Constant::getNullValue(Ty)); // Similarly, a 'not' of the bool translates to a swap of the select arms: - // ~sext(A) & Op1 --> A ? 0 : Op1 - // Op0 & ~sext(A) --> A ? 0 : Op0 - if (match(Op0, m_Not(m_SExt(m_Value(A)))) && + // ~sext(A) & B / B & ~sext(A) --> A ? 
0 : B + if (match(&I, m_c_And(m_Not(m_SExt(m_Value(A))), m_Value(B))) && A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, Constant::getNullValue(Ty), Op1); - if (match(Op1, m_Not(m_SExt(m_Value(A)))) && + return SelectInst::Create(A, Constant::getNullValue(Ty), B); + + // and(zext(A), B) -> A ? (B & 1) : 0 + if (match(&I, m_c_And(m_OneUse(m_ZExt(m_Value(A))), m_Value(B))) && A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, Constant::getNullValue(Ty), Op0); + return SelectInst::Create(A, Builder.CreateAnd(B, ConstantInt::get(Ty, 1)), + Constant::getNullValue(Ty)); + + // (-1 + A) & B --> A ? 0 : B where A is 0/1. + if (match(&I, m_c_And(m_OneUse(m_Add(m_ZExtOrSelf(m_Value(A)), m_AllOnes())), + m_Value(B)))) { + if (A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Constant::getNullValue(Ty), B); + if (computeKnownBits(A, /* Depth */ 0, &I).countMaxActiveBits() <= 1) { + return SelectInst::Create( + Builder.CreateICmpEQ(A, Constant::getNullValue(A->getType())), B, + Constant::getNullValue(Ty)); + } + } // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 -- with optional sext if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf( @@ -2698,105 +2707,178 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I, } /// Match UB-safe variants of the funnel shift intrinsic. -static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) { +static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, + const DominatorTree &DT) { // TODO: Can we reduce the code duplication between this and the related // rotate matching code under visitSelect and visitTrunc? unsigned Width = Or.getType()->getScalarSizeInBits(); + Instruction *Or0, *Or1; + if (!match(Or.getOperand(0), m_Instruction(Or0)) || + !match(Or.getOperand(1), m_Instruction(Or1))) + return nullptr; + + bool IsFshl = true; // Sub on LSHR. 
+ SmallVector<Value *, 3> FShiftArgs; + // First, find an or'd pair of opposite shifts: // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1) - BinaryOperator *Or0, *Or1; - if (!match(Or.getOperand(0), m_BinOp(Or0)) || - !match(Or.getOperand(1), m_BinOp(Or1))) - return nullptr; + if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) { + Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1; + if (!match(Or0, + m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) || + !match(Or1, + m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || + Or0->getOpcode() == Or1->getOpcode()) + return nullptr; - Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1; - if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) || - !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || - Or0->getOpcode() == Or1->getOpcode()) - return nullptr; + // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(ShVal0, ShVal1); + std::swap(ShAmt0, ShAmt1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); - // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). - if (Or0->getOpcode() == BinaryOperator::LShr) { - std::swap(Or0, Or1); - std::swap(ShVal0, ShVal1); - std::swap(ShAmt0, ShAmt1); - } - assert(Or0->getOpcode() == BinaryOperator::Shl && - Or1->getOpcode() == BinaryOperator::LShr && - "Illegal or(shift,shift) pair"); + // Match the shift amount operands for a funnel shift pattern. This always + // matches a subtraction on the R operand. + auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { + // Check for constant shift amounts that sum to the bitwidth. 
+ const APInt *LI, *RI; + if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI))) + if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width) + return ConstantInt::get(L->getType(), *LI); + + Constant *LC, *RC; + if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) && + match(L, + m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && + match(R, + m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && + match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width))) + return ConstantExpr::mergeUndefsWith(LC, RC); + + // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width. + // We limit this to X < Width in case the backend re-expands the + // intrinsic, and has to reintroduce a shift modulo operation (InstCombine + // might remove it after this fold). This still doesn't guarantee that the + // final codegen will match this original pattern. + if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) { + KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or); + return KnownL.getMaxValue().ult(Width) ? L : nullptr; + } - // Match the shift amount operands for a funnel shift pattern. This always - // matches a subtraction on the R operand. - auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { - // Check for constant shift amounts that sum to the bitwidth. - const APInt *LI, *RI; - if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI))) - if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width) - return ConstantInt::get(L->getType(), *LI); + // For non-constant cases, the following patterns currently only work for + // rotation patterns. + // TODO: Add general funnel-shift compatible patterns. 
+ if (ShVal0 != ShVal1) + return nullptr; - Constant *LC, *RC; - if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) && - match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && - match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && - match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width))) - return ConstantExpr::mergeUndefsWith(LC, RC); + // For non-constant cases we don't support non-pow2 shift masks. + // TODO: Is it worth matching urem as well? + if (!isPowerOf2_32(Width)) + return nullptr; - // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width. - // We limit this to X < Width in case the backend re-expands the intrinsic, - // and has to reintroduce a shift modulo operation (InstCombine might remove - // it after this fold). This still doesn't guarantee that the final codegen - // will match this original pattern. - if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) { - KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or); - return KnownL.getMaxValue().ult(Width) ? L : nullptr; + // The shift amount may be masked with negation: + // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) + Value *X; + unsigned Mask = Width - 1; + if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && + match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) + return X; + + // Similar to above, but the shift amount may be extended after masking, + // so return the extended value as the parameter for the intrinsic. 
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, + m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))), + m_SpecificInt(Mask)))) + return L; + + if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))) + return L; + + return nullptr; + }; + + Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width); + if (!ShAmt) { + ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width); + IsFshl = false; // Sub on SHL. } + if (!ShAmt) + return nullptr; + + FShiftArgs = {ShVal0, ShVal1, ShAmt}; + } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) { + // If there are two 'or' instructions concat variables in opposite order: + // + // Slot1 and Slot2 are all zero bits. + // | Slot1 | Low | Slot2 | High | + // LowHigh = or (shl (zext Low), ZextLowShlAmt), (zext High) + // | Slot2 | High | Slot1 | Low | + // HighLow = or (shl (zext High), ZextHighShlAmt), (zext Low) + // + // the latter 'or' can be safely convert to + // -> HighLow = fshl LowHigh, LowHigh, ZextHighShlAmt + // if ZextLowShlAmt + ZextHighShlAmt == Width. + if (!isa<ZExtInst>(Or1)) + std::swap(Or0, Or1); - // For non-constant cases, the following patterns currently only work for - // rotation patterns. - // TODO: Add general funnel-shift compatible patterns. - if (ShVal0 != ShVal1) + Value *High, *ZextHigh, *Low; + const APInt *ZextHighShlAmt; + if (!match(Or0, + m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt))))) return nullptr; - // For non-constant cases we don't support non-pow2 shift masks. - // TODO: Is it worth matching urem as well? 
- if (!isPowerOf2_32(Width)) + if (!match(Or1, m_ZExt(m_Value(Low))) || + !match(ZextHigh, m_ZExt(m_Value(High)))) return nullptr; - // The shift amount may be masked with negation: - // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) - Value *X; - unsigned Mask = Width - 1; - if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && - match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) - return X; + unsigned HighSize = High->getType()->getScalarSizeInBits(); + unsigned LowSize = Low->getType()->getScalarSizeInBits(); + // Make sure High does not overlap with Low and most significant bits of + // High aren't shifted out. + if (ZextHighShlAmt->ult(LowSize) || ZextHighShlAmt->ugt(Width - HighSize)) + return nullptr; - // Similar to above, but the shift amount may be extended after masking, - // so return the extended value as the parameter for the intrinsic. - if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && - match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))), - m_SpecificInt(Mask)))) - return L; + for (User *U : ZextHigh->users()) { + Value *X, *Y; + if (!match(U, m_Or(m_Value(X), m_Value(Y)))) + continue; - if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && - match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))) - return L; + if (!isa<ZExtInst>(Y)) + std::swap(X, Y); - return nullptr; - }; + const APInt *ZextLowShlAmt; + if (!match(X, m_Shl(m_Specific(Or1), m_APInt(ZextLowShlAmt))) || + !match(Y, m_Specific(ZextHigh)) || !DT.dominates(U, &Or)) + continue; - Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width); - bool IsFshl = true; // Sub on LSHR. - if (!ShAmt) { - ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width); - IsFshl = false; // Sub on SHL. + // HighLow is good concat. If sum of two shifts amount equals to Width, + // LowHigh must also be a good concat. 
+ if (*ZextLowShlAmt + *ZextHighShlAmt != Width) + continue; + + // Low must not overlap with High and most significant bits of Low must + // not be shifted out. + assert(ZextLowShlAmt->uge(HighSize) && + ZextLowShlAmt->ule(Width - LowSize) && "Invalid concat"); + + FShiftArgs = {U, U, ConstantInt::get(Or0->getType(), *ZextHighShlAmt)}; + break; + } } - if (!ShAmt) + + if (FShiftArgs.empty()) return nullptr; Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); - return CallInst::Create(F, {ShVal0, ShVal1, ShAmt}); + return CallInst::Create(F, FShiftArgs); } /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. @@ -3272,14 +3354,14 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1)) { if (auto *SI0 = dyn_cast<SelectInst>(Op0)) { - if (auto *I = + if (auto *R = foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ false)) - return I; + return R; } if (auto *SI1 = dyn_cast<SelectInst>(Op1)) { - if (auto *I = + if (auto *R = foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ false)) - return I; + return R; } } @@ -3290,7 +3372,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { /*MatchBitReversals*/ true)) return BitOp; - if (Instruction *Funnel = matchFunnelShift(I, *this)) + if (Instruction *Funnel = matchFunnelShift(I, *this, DT)) return Funnel; if (Instruction *Concat = matchOrConcat(I, Builder)) @@ -3311,9 +3393,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // If the operands have no common bits set: // or (mul X, Y), X --> add (mul X, Y), X --> mul X, (Y + 1) - if (match(&I, - m_c_Or(m_OneUse(m_Mul(m_Value(X), m_Value(Y))), m_Deferred(X))) && - haveNoCommonBitsSet(Op0, Op1, DL)) { + if (match(&I, m_c_DisjointOr(m_OneUse(m_Mul(m_Value(X), m_Value(Y))), + m_Deferred(X)))) { Value *IncrementY = Builder.CreateAdd(Y, ConstantInt::get(Ty, 1)); return 
BinaryOperator::CreateMul(X, IncrementY); } @@ -3435,7 +3516,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A)))) return BinaryOperator::CreateOr(Op1, Builder.CreateAnd(A, C)); - if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) + if (Instruction *DeMorgan = matchDeMorgansLaws(I, *this)) return DeMorgan; // Canonicalize xor to the RHS. @@ -3581,12 +3662,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // with binop identity constant. But creating a select with non-constant // arm may not be reversible due to poison semantics. Is that a good // canonicalization? - if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && + if (match(&I, m_c_Or(m_OneUse(m_SExt(m_Value(A))), m_Value(B))) && A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op1); - if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && - A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op0); + return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), B); // Note: If we've gotten to the point of visiting the outer OR, then the // inner one couldn't be simplified. 
If it was a constant, then it won't @@ -3628,6 +3706,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { } } + { + // ((A & B) ^ A) | ((A & B) ^ B) -> A ^ B + // (A ^ (A & B)) | (B ^ (A & B)) -> A ^ B + // ((A & B) ^ B) | ((A & B) ^ A) -> A ^ B + // (B ^ (A & B)) | (A ^ (A & B)) -> A ^ B + const auto TryXorOpt = [&](Value *Lhs, Value *Rhs) -> Instruction * { + if (match(Lhs, m_c_Xor(m_And(m_Value(A), m_Value(B)), m_Deferred(A))) && + match(Rhs, + m_c_Xor(m_And(m_Specific(A), m_Specific(B)), m_Deferred(B)))) { + return BinaryOperator::CreateXor(A, B); + } + return nullptr; + }; + + if (Instruction *Result = TryXorOpt(Op0, Op1)) + return Result; + if (Instruction *Result = TryXorOpt(Op1, Op0)) + return Result; + } + if (Instruction *V = canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) return V; @@ -3720,6 +3818,31 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) return Res; + // If we are setting the sign bit of a floating-point value, convert + // this to fneg(fabs), then cast back to integer. + // + // If the result isn't immediately cast back to a float, this will increase + // the number of instructions. This is still probably a better canonical form + // as it enables FP value tracking. + // + // Assumes any IEEE-represented type has the sign bit in the high bit. + // + // This is generous interpretation of noimplicitfloat, this is not a true + // floating-point operation. 
+ Value *CastOp; + if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) && + !Builder.GetInsertBlock()->getParent()->hasFnAttribute( + Attribute::NoImplicitFloat)) { + Type *EltTy = CastOp->getType()->getScalarType(); + if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && + EltTy->getPrimitiveSizeInBits() == + I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp); + Value *FNegFAbs = Builder.CreateFNeg(FAbs); + return new BitCastInst(FNegFAbs, I.getType()); + } + } + return nullptr; } @@ -3931,26 +4054,6 @@ static Instruction *visitMaskedMerge(BinaryOperator &I, return nullptr; } -// Transform -// ~(x ^ y) -// into: -// (~x) ^ y -// or into -// x ^ (~y) -static Instruction *sinkNotIntoXor(BinaryOperator &I, Value *X, Value *Y, - InstCombiner::BuilderTy &Builder) { - // We only want to do the transform if it is free to do. - if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) { - // Ok, good. - } else if (InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) { - std::swap(X, Y); - } else - return nullptr; - - Value *NotX = Builder.CreateNot(X, X->getName() + ".not"); - return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan"); -} - static Instruction *foldNotXor(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { Value *X, *Y; @@ -3959,9 +4062,6 @@ static Instruction *foldNotXor(BinaryOperator &I, if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y)))))) return nullptr; - if (Instruction *NewXor = sinkNotIntoXor(I, X, Y, Builder)) - return NewXor; - auto hasCommonOperand = [](Value *A, Value *B, Value *C, Value *D) { return A == C || A == D || B == C || B == D; }; @@ -4023,13 +4123,13 @@ static bool canFreelyInvert(InstCombiner &IC, Value *Op, Instruction *IgnoredUser) { auto *I = dyn_cast<Instruction>(Op); return I && IC.isFreeToInvert(I, /*WillInvertAllUses=*/true) && - InstCombiner::canFreelyInvertAllUsersOf(I, IgnoredUser); + 
IC.canFreelyInvertAllUsersOf(I, IgnoredUser); } static Value *freelyInvert(InstCombinerImpl &IC, Value *Op, Instruction *IgnoredUser) { auto *I = cast<Instruction>(Op); - IC.Builder.SetInsertPoint(&*I->getInsertionPointAfterDef()); + IC.Builder.SetInsertPoint(*I->getInsertionPointAfterDef()); Value *NotOp = IC.Builder.CreateNot(Op, Op->getName() + ".not"); Op->replaceUsesWithIf(NotOp, [NotOp](Use &U) { return U.getUser() != NotOp; }); @@ -4067,7 +4167,7 @@ bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) { Op0 = freelyInvert(*this, Op0, &I); Op1 = freelyInvert(*this, Op1, &I); - Builder.SetInsertPoint(I.getInsertionPointAfterDef()); + Builder.SetInsertPoint(*I.getInsertionPointAfterDef()); Value *NewLogicOp; if (IsBinaryOp) NewLogicOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not"); @@ -4115,7 +4215,7 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) { *OpToInvert = freelyInvert(*this, *OpToInvert, &I); - Builder.SetInsertPoint(&*I.getInsertionPointAfterDef()); + Builder.SetInsertPoint(*I.getInsertionPointAfterDef()); Value *NewBinOp; if (IsBinaryOp) NewBinOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not"); @@ -4259,15 +4359,6 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { // ~max(~X, Y) --> min(X, ~Y) auto *II = dyn_cast<IntrinsicInst>(NotOp); if (II && II->hasOneUse()) { - if (match(NotOp, m_MaxOrMin(m_Value(X), m_Value(Y))) && - isFreeToInvert(X, X->hasOneUse()) && - isFreeToInvert(Y, Y->hasOneUse())) { - Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); - Value *NotX = Builder.CreateNot(X); - Value *NotY = Builder.CreateNot(Y); - Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, NotX, NotY); - return replaceInstUsesWith(I, InvMaxMin); - } if (match(NotOp, m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y)))) { Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); Value *NotY = Builder.CreateNot(Y); @@ -4317,6 +4408,11 @@ Instruction 
*InstCombinerImpl::foldNot(BinaryOperator &I) { if (Instruction *NewXor = foldNotXor(I, Builder)) return NewXor; + // TODO: Could handle multi-use better by checking if all uses of NotOp (other + // than I) can be inverted. + if (Value *R = getFreelyInverted(NotOp, NotOp->hasOneUse(), &Builder)) + return replaceInstUsesWith(I, R); + return nullptr; } @@ -4366,7 +4462,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { Value *M; if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()), m_c_And(m_Deferred(M), m_Value())))) - return BinaryOperator::CreateOr(Op0, Op1); + return BinaryOperator::CreateDisjointOr(Op0, Op1); if (Instruction *Xor = visitMaskedMerge(I, Builder)) return Xor; @@ -4466,6 +4562,27 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { // a 'not' op and moving it before the shift. Doing that requires // preventing the inverse fold in canShiftBinOpWithConstantRHS(). } + + // If we are XORing the sign bit of a floating-point value, convert + // this to fneg, then cast back to integer. + // + // This is generous interpretation of noimplicitfloat, this is not a true + // floating-point operation. + // + // Assumes any IEEE-represented type has the sign bit in the high bit. + // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt + Value *CastOp; + if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) && + !Builder.GetInsertBlock()->getParent()->hasFnAttribute( + Attribute::NoImplicitFloat)) { + Type *EltTy = CastOp->getType()->getScalarType(); + if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && + EltTy->getPrimitiveSizeInBits() == + I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + Value *FNeg = Builder.CreateFNeg(CastOp); + return new BitCastInst(FNeg, I.getType()); + } + } } // FIXME: This should not be limited to scalar (pull into APInt match above). 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index d3ec6a7aa667..255ce6973a16 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -89,12 +89,6 @@ static cl::opt<unsigned> GuardWideningWindow( cl::desc("How wide an instruction window to bypass looking for " "another guard")); -namespace llvm { -/// enable preservation of attributes in assume like: -/// call void @llvm.assume(i1 true) [ "nonnull"(i32* %PTR) ] -extern cl::opt<bool> EnableKnowledgeRetention; -} // namespace llvm - /// Return the specified type promoted as it would be to pass though a va_arg /// area. static Type *getPromotedType(Type *Ty) { @@ -174,14 +168,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { return nullptr; // Use an integer load+store unless we can find something better. - unsigned SrcAddrSp = - cast<PointerType>(MI->getArgOperand(1)->getType())->getAddressSpace(); - unsigned DstAddrSp = - cast<PointerType>(MI->getArgOperand(0)->getType())->getAddressSpace(); - IntegerType* IntType = IntegerType::get(MI->getContext(), Size<<3); - Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp); - Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp); // If the memcpy has metadata describing the members, see if we can get the // TBAA tag describing our copy. @@ -200,8 +187,8 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { CopyMD = cast<MDNode>(M->getOperand(2)); } - Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy); - Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); + Value *Src = MI->getArgOperand(1); + Value *Dest = MI->getArgOperand(0); LoadInst *L = Builder.CreateLoad(IntType, Src); // Alignment from the mem intrinsic will be better, so use it. 
L->setAlignment(*CopySrcAlign); @@ -291,9 +278,6 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { Type *ITy = IntegerType::get(MI->getContext(), Len*8); // n=1 -> i8. Value *Dest = MI->getDest(); - unsigned DstAddrSp = cast<PointerType>(Dest->getType())->getAddressSpace(); - Type *NewDstPtrTy = PointerType::get(ITy, DstAddrSp); - Dest = Builder.CreateBitCast(Dest, NewDstPtrTy); // Extract the fill value and store. const uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; @@ -301,7 +285,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { StoreInst *S = Builder.CreateStore(FillVal, Dest, MI->isVolatile()); S->copyMetadata(*MI, LLVMContext::MD_DIAssignID); for (auto *DAI : at::getAssignmentMarkers(S)) { - if (any_of(DAI->location_ops(), [&](Value *V) { return V == FillC; })) + if (llvm::is_contained(DAI->location_ops(), FillC)) DAI->replaceVariableLocationOp(FillC, FillVal); } @@ -500,8 +484,6 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II, if (Result->getType()->getPointerAddressSpace() != II.getType()->getPointerAddressSpace()) Result = IC.Builder.CreateAddrSpaceCast(Result, II.getType()); - if (Result->getType() != II.getType()) - Result = IC.Builder.CreateBitCast(Result, II.getType()); return cast<Instruction>(Result); } @@ -532,6 +514,8 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType())); } + Constant *C; + if (IsTZ) { // cttz(-x) -> cttz(x) if (match(Op0, m_Neg(m_Value(X)))) @@ -567,6 +551,38 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) return IC.replaceOperand(II, 0, X); + + // cttz(shl(%const, %val), 1) --> add(cttz(%const, 1), %val) + if (match(Op0, m_Shl(m_ImmConstant(C), m_Value(X))) && + match(Op1, m_One())) { + Value *ConstCttz = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, C, 
Op1); + return BinaryOperator::CreateAdd(ConstCttz, X); + } + + // cttz(lshr exact (%const, %val), 1) --> sub(cttz(%const, 1), %val) + if (match(Op0, m_Exact(m_LShr(m_ImmConstant(C), m_Value(X)))) && + match(Op1, m_One())) { + Value *ConstCttz = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, C, Op1); + return BinaryOperator::CreateSub(ConstCttz, X); + } + } else { + // ctlz(lshr(%const, %val), 1) --> add(ctlz(%const, 1), %val) + if (match(Op0, m_LShr(m_ImmConstant(C), m_Value(X))) && + match(Op1, m_One())) { + Value *ConstCtlz = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1); + return BinaryOperator::CreateAdd(ConstCtlz, X); + } + + // ctlz(shl nuw (%const, %val), 1) --> sub(ctlz(%const, 1), %val) + if (match(Op0, m_NUWShl(m_ImmConstant(C), m_Value(X))) && + match(Op1, m_One())) { + Value *ConstCtlz = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1); + return BinaryOperator::CreateSub(ConstCtlz, X); + } } KnownBits Known = IC.computeKnownBits(Op0, 0, &II); @@ -911,11 +927,27 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { Value *FAbsSrc; if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) { - II.setArgOperand(1, ConstantInt::get(Src1->getType(), fabs(Mask))); + II.setArgOperand(1, ConstantInt::get(Src1->getType(), inverse_fabs(Mask))); return replaceOperand(II, 0, FAbsSrc); } - // TODO: is.fpclass(x, fcInf) -> fabs(x) == inf + if ((OrderedMask == fcInf || OrderedInvertedMask == fcInf) && + (IsOrdered || IsUnordered) && !IsStrict) { + // is.fpclass(x, fcInf) -> fcmp oeq fabs(x), +inf + // is.fpclass(x, ~fcInf) -> fcmp one fabs(x), +inf + // is.fpclass(x, fcInf|fcNan) -> fcmp ueq fabs(x), +inf + // is.fpclass(x, ~(fcInf|fcNan)) -> fcmp une fabs(x), +inf + Constant *Inf = ConstantFP::getInfinity(Src0->getType()); + FCmpInst::Predicate Pred = + IsUnordered ? FCmpInst::FCMP_UEQ : FCmpInst::FCMP_OEQ; + if (OrderedInvertedMask == fcInf) + Pred = IsUnordered ? 
FCmpInst::FCMP_UNE : FCmpInst::FCMP_ONE; + + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Src0); + Value *CmpInf = Builder.CreateFCmp(Pred, Fabs, Inf); + CmpInf->takeName(&II); + return replaceInstUsesWith(II, CmpInf); + } if ((OrderedMask == fcPosInf || OrderedMask == fcNegInf) && (IsOrdered || IsUnordered) && !IsStrict) { @@ -992,8 +1024,7 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { return replaceInstUsesWith(II, FCmp); } - KnownFPClass Known = computeKnownFPClass( - Src0, DL, Mask, 0, &getTargetLibraryInfo(), &AC, &II, &DT); + KnownFPClass Known = computeKnownFPClass(Src0, Mask, &II); // Clear test bits we know must be false from the source value. // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other @@ -1030,6 +1061,20 @@ static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } +static std::optional<bool> getKnownSignOrZero(Value *Op, Instruction *CxtI, + const DataLayout &DL, + AssumptionCache *AC, + DominatorTree *DT) { + if (std::optional<bool> Sign = getKnownSign(Op, CxtI, DL, AC, DT)) + return Sign; + + Value *X, *Y; + if (match(Op, m_NSWSub(m_Value(X), m_Value(Y)))) + return isImpliedByDomCondition(ICmpInst::ICMP_SLE, X, Y, CxtI, DL); + + return std::nullopt; +} + /// Return true if two values \p Op0 and \p Op1 are known to have the same sign. 
static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI, const DataLayout &DL, AssumptionCache *AC, @@ -1530,12 +1575,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) return replaceOperand(*II, 0, X); - if (std::optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) { - // abs(x) -> x if x >= 0 - if (!*Sign) + if (std::optional<bool> Known = + getKnownSignOrZero(IIOperand, II, DL, &AC, &DT)) { + // abs(x) -> x if x >= 0 (include abs(x-y) --> x - y where x >= y) + // abs(x) -> x if x > 0 (include abs(x-y) --> x - y where x > y) + if (!*Known) return replaceInstUsesWith(*II, IIOperand); // abs(x) -> -x if x < 0 + // abs(x) -> -x if x < = 0 (include abs(x-y) --> y - x where x <= y) if (IntMinIsPoison) return BinaryOperator::CreateNSWNeg(IIOperand); return BinaryOperator::CreateNeg(IIOperand); @@ -1580,8 +1628,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Constant *C; if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) && I0->hasOneUse()) { - Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType()); - if (ConstantExpr::getZExt(NarrowC, II->getType()) == C) { + if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType())) { Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC); return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType()); } @@ -1603,13 +1650,26 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Constant *C; if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) && I0->hasOneUse()) { - Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType()); - if (ConstantExpr::getSExt(NarrowC, II->getType()) == C) { + if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType())) { Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC); return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType()); } } + // umin(i1 X, i1 Y) -> and i1 X, 
Y + // smax(i1 X, i1 Y) -> and i1 X, Y + if ((IID == Intrinsic::umin || IID == Intrinsic::smax) && + II->getType()->isIntOrIntVectorTy(1)) { + return BinaryOperator::CreateAnd(I0, I1); + } + + // umax(i1 X, i1 Y) -> or i1 X, Y + // smin(i1 X, i1 Y) -> or i1 X, Y + if ((IID == Intrinsic::umax || IID == Intrinsic::smin) && + II->getType()->isIntOrIntVectorTy(1)) { + return BinaryOperator::CreateOr(I0, I1); + } + if (IID == Intrinsic::smax || IID == Intrinsic::smin) { // smax (neg nsw X), (neg nsw Y) --> neg nsw (smin X, Y) // smin (neg nsw X), (neg nsw Y) --> neg nsw (smax X, Y) @@ -1672,12 +1732,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { Value *A; if (match(X, m_OneUse(m_Not(m_Value(A)))) && - !isFreeToInvert(A, A->hasOneUse()) && - isFreeToInvert(Y, Y->hasOneUse())) { - Value *NotY = Builder.CreateNot(Y); - Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); - Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY); - return BinaryOperator::CreateNot(InvMaxMin); + !isFreeToInvert(A, A->hasOneUse())) { + if (Value *NotY = getFreelyInverted(Y, Y->hasOneUse(), &Builder)) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY); + return BinaryOperator::CreateNot(InvMaxMin); + } } return nullptr; }; @@ -1929,6 +1989,52 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return &CI; break; } + case Intrinsic::ptrmask: { + unsigned BitWidth = DL.getPointerTypeSizeInBits(II->getType()); + KnownBits Known(BitWidth); + if (SimplifyDemandedInstructionBits(*II, Known)) + return II; + + Value *InnerPtr, *InnerMask; + bool Changed = false; + // Combine: + // (ptrmask (ptrmask p, A), B) + // -> (ptrmask p, (and A, B)) + if (match(II->getArgOperand(0), + m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr), + m_Value(InnerMask))))) { + assert(II->getArgOperand(1)->getType() == 
InnerMask->getType() && + "Mask types must match"); + // TODO: If InnerMask == Op1, we could copy attributes from inner + // callsite -> outer callsite. + Value *NewMask = Builder.CreateAnd(II->getArgOperand(1), InnerMask); + replaceOperand(CI, 0, InnerPtr); + replaceOperand(CI, 1, NewMask); + Changed = true; + } + + // See if we can deduce non-null. + if (!CI.hasRetAttr(Attribute::NonNull) && + (Known.isNonZero() || + isKnownNonZero(II, DL, /*Depth*/ 0, &AC, II, &DT))) { + CI.addRetAttr(Attribute::NonNull); + Changed = true; + } + + unsigned NewAlignmentLog = + std::min(Value::MaxAlignmentExponent, + std::min(BitWidth - 1, Known.countMinTrailingZeros())); + // Known bits will capture if we had alignment information associated with + // the pointer argument. + if (NewAlignmentLog > Log2(CI.getRetAlign().valueOrOne())) { + CI.addRetAttr(Attribute::getWithAlignment( + CI.getContext(), Align(uint64_t(1) << NewAlignmentLog))); + Changed = true; + } + if (Changed) + return &CI; + break; + } case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: { if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) @@ -2493,10 +2599,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { VectorType *NewVT = cast<VectorType>(II->getType()); if (Constant *CV0 = dyn_cast<Constant>(Arg0)) { if (Constant *CV1 = dyn_cast<Constant>(Arg1)) { - CV0 = ConstantExpr::getIntegerCast(CV0, NewVT, /*isSigned=*/!Zext); - CV1 = ConstantExpr::getIntegerCast(CV1, NewVT, /*isSigned=*/!Zext); - - return replaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1)); + Value *V0 = Builder.CreateIntCast(CV0, NewVT, /*isSigned=*/!Zext); + Value *V1 = Builder.CreateIntCast(CV1, NewVT, /*isSigned=*/!Zext); + return replaceInstUsesWith(CI, Builder.CreateMul(V0, V1)); } // Couldn't simplify - canonicalize constant to the RHS. 
@@ -2950,24 +3055,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return replaceOperand(CI, 0, InsertTuple); } - auto *DstTy = dyn_cast<FixedVectorType>(ReturnType); - auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + auto *DstTy = dyn_cast<VectorType>(ReturnType); + auto *VecTy = dyn_cast<VectorType>(Vec->getType()); - // Only canonicalize if the the destination vector and Vec are fixed - // vectors. if (DstTy && VecTy) { - unsigned DstNumElts = DstTy->getNumElements(); - unsigned VecNumElts = VecTy->getNumElements(); + auto DstEltCnt = DstTy->getElementCount(); + auto VecEltCnt = VecTy->getElementCount(); unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); // Extracting the entirety of Vec is a nop. - if (VecNumElts == DstNumElts) { + if (DstEltCnt == VecTy->getElementCount()) { replaceInstUsesWith(CI, Vec); return eraseInstFromFunction(CI); } + // Only canonicalize to shufflevector if the destination vector and + // Vec are fixed vectors. + if (VecEltCnt.isScalable() || DstEltCnt.isScalable()) + break; + SmallVector<int, 8> Mask; - for (unsigned i = 0; i != DstNumElts; ++i) + for (unsigned i = 0; i != DstEltCnt.getKnownMinValue(); ++i) Mask.push_back(IdxN + i); Value *Shuffle = Builder.CreateShuffleVector(Vec, Mask); @@ -3943,9 +4051,9 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy); NC->setDebugLoc(Caller->getDebugLoc()); - Instruction *InsertPt = NewCall->getInsertionPointAfterDef(); - assert(InsertPt && "No place to insert cast"); - InsertNewInstBefore(NC, *InsertPt); + auto OptInsertPt = NewCall->getInsertionPointAfterDef(); + assert(OptInsertPt && "No place to insert cast"); + InsertNewInstBefore(NC, *OptInsertPt); Worklist.pushUsersToWorkList(*Caller); } else { NV = PoisonValue::get(Caller->getType()); @@ -3972,8 +4080,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { Instruction * 
InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, IntrinsicInst &Tramp) { - Value *Callee = Call.getCalledOperand(); - Type *CalleeTy = Callee->getType(); FunctionType *FTy = Call.getFunctionType(); AttributeList Attrs = Call.getAttributes(); @@ -4070,12 +4176,8 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, // Replace the trampoline call with a direct call. Let the generic // code sort out any function type mismatches. - FunctionType *NewFTy = FunctionType::get(FTy->getReturnType(), NewTypes, - FTy->isVarArg()); - Constant *NewCallee = - NestF->getType() == PointerType::getUnqual(NewFTy) ? - NestF : ConstantExpr::getBitCast(NestF, - PointerType::getUnqual(NewFTy)); + FunctionType *NewFTy = + FunctionType::get(FTy->getReturnType(), NewTypes, FTy->isVarArg()); AttributeList NewPAL = AttributeList::get(FTy->getContext(), Attrs.getFnAttrs(), Attrs.getRetAttrs(), NewArgAttrs); @@ -4085,19 +4187,18 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, Instruction *NewCaller; if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) { - NewCaller = InvokeInst::Create(NewFTy, NewCallee, - II->getNormalDest(), II->getUnwindDest(), - NewArgs, OpBundles); + NewCaller = InvokeInst::Create(NewFTy, NestF, II->getNormalDest(), + II->getUnwindDest(), NewArgs, OpBundles); cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv()); cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(&Call)) { NewCaller = - CallBrInst::Create(NewFTy, NewCallee, CBI->getDefaultDest(), + CallBrInst::Create(NewFTy, NestF, CBI->getDefaultDest(), CBI->getIndirectDests(), NewArgs, OpBundles); cast<CallBrInst>(NewCaller)->setCallingConv(CBI->getCallingConv()); cast<CallBrInst>(NewCaller)->setAttributes(NewPAL); } else { - NewCaller = CallInst::Create(NewFTy, NewCallee, NewArgs, OpBundles); + NewCaller = CallInst::Create(NewFTy, NestF, NewArgs, OpBundles); cast<CallInst>(NewCaller)->setTailCallKind( 
cast<CallInst>(Call).getTailCallKind()); cast<CallInst>(NewCaller)->setCallingConv( @@ -4113,7 +4214,6 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, // Replace the trampoline call with a direct call. Since there is no 'nest' // parameter, there is no need to adjust the argument list. Let the generic // code sort out any function type mismatches. - Constant *NewCallee = ConstantExpr::getBitCast(NestF, CalleeTy); - Call.setCalledFunction(FTy, NewCallee); + Call.setCalledFunction(FTy, NestF); return &Call; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 5c84f666616d..6629ca840a67 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -29,11 +29,8 @@ using namespace PatternMatch; /// true for, actually insert the code to evaluate the expression. Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned) { - if (Constant *C = dyn_cast<Constant>(V)) { - C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); - // If we got a constantexpr back, try to simplify it with DL info. - return ConstantFoldConstant(C, DL, &TLI); - } + if (Constant *C = dyn_cast<Constant>(V)) + return ConstantFoldIntegerCast(C, Ty, isSigned, DL); // Otherwise, it must be an instruction. Instruction *I = cast<Instruction>(V); @@ -112,7 +109,7 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, } Res->takeName(I); - return InsertNewInstWith(Res, *I); + return InsertNewInstWith(Res, I->getIterator()); } Instruction::CastOps @@ -217,7 +214,8 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { /// free to be evaluated in that type. This is a helper for canEvaluate*. 
static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { if (isa<Constant>(V)) - return true; + return match(V, m_ImmConstant()); + Value *X; if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) && X->getType() == Ty) @@ -229,7 +227,6 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { /// Filter out values that we can not evaluate in the destination type for free. /// This is a helper for canEvaluate*. static bool canNotEvaluateInType(Value *V, Type *Ty) { - assert(!isa<Constant>(V) && "Constant should already be handled."); if (!isa<Instruction>(V)) return true; // We don't extend or shrink something that has multiple uses -- doing so @@ -505,11 +502,13 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) { if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc)) return nullptr; - // We have an unnecessarily wide rotate! - // trunc (or (shl ShVal0, ShAmt), (lshr ShVal1, BitWidth - ShAmt)) - // Narrow the inputs and convert to funnel shift intrinsic: - // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt)) - Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); + // Adjust the width of ShAmt for narrowed funnel shift operation: + // - Zero-extend if ShAmt is narrower than the destination type. + // - Truncate if ShAmt is wider, discarding non-significant high-order bits. + // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal), + // zext/trunc(ShAmt)). + Value *NarrowShAmt = Builder.CreateZExtOrTrunc(ShAmt, DestTy); + Value *X, *Y; X = Y = Builder.CreateTrunc(ShVal0, DestTy); if (ShVal0 != ShVal1) @@ -582,13 +581,15 @@ Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) { APInt(SrcWidth, MaxShiftAmt)))) { auto *OldShift = cast<Instruction>(Trunc.getOperand(0)); bool IsExact = OldShift->isExact(); - auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); - ShAmt = Constant::mergeUndefsWith(ShAmt, C); - Value *Shift = - OldShift->getOpcode() == Instruction::AShr - ? 
Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) - : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); - return CastInst::CreateTruncOrBitCast(Shift, DestTy); + if (Constant *ShAmt = ConstantFoldIntegerCast(C, A->getType(), + /*IsSigned*/ true, DL)) { + ShAmt = Constant::mergeUndefsWith(ShAmt, C); + Value *Shift = + OldShift->getOpcode() == Instruction::AShr + ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) + : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); + return CastInst::CreateTruncOrBitCast(Shift, DestTy); + } } } break; @@ -904,19 +905,18 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. // zext (X != 0) to i32 --> X iff X has only the low bit set. // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. - if (Op1CV->isZero() && Cmp->isEquality() && - (Cmp->getOperand(0)->getType() == Zext.getType() || - Cmp->getPredicate() == ICmpInst::ICMP_NE)) { - // If Op1C some other power of two, convert: - KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); + if (Op1CV->isZero() && Cmp->isEquality()) { // Exactly 1 possible 1? But not the high-bit because that is // canonicalized to this form. + KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); APInt KnownZeroMask(~Known.Zero); - if (KnownZeroMask.isPowerOf2() && - (Zext.getType()->getScalarSizeInBits() != - KnownZeroMask.logBase2() + 1)) { - uint32_t ShAmt = KnownZeroMask.logBase2(); + uint32_t ShAmt = KnownZeroMask.logBase2(); + bool IsExpectShAmt = KnownZeroMask.isPowerOf2() && + (Zext.getType()->getScalarSizeInBits() != ShAmt + 1); + if (IsExpectShAmt && + (Cmp->getOperand(0)->getType() == Zext.getType() || + Cmp->getPredicate() == ICmpInst::ICMP_NE || ShAmt == 0)) { Value *In = Cmp->getOperand(0); if (ShAmt) { // Perform a logical shr by shiftamt. 
@@ -1184,14 +1184,14 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { Value *X; if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && X->getType() == DestTy) - return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, DestTy)); + return BinaryOperator::CreateAnd(X, Builder.CreateZExt(C, DestTy)); // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). Value *And; if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && X->getType() == DestTy) { - Constant *ZC = ConstantExpr::getZExt(C, DestTy); + Value *ZC = Builder.CreateZExt(C, DestTy); return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); } @@ -1202,7 +1202,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { // zext (and (trunc X), C) --> and X, (zext C) if (match(Src, m_And(m_Trunc(m_Value(X)), m_Constant(C))) && X->getType() == DestTy) { - Constant *ZextC = ConstantExpr::getZExt(C, DestTy); + Value *ZextC = Builder.CreateZExt(C, DestTy); return BinaryOperator::CreateAnd(X, ZextC); } @@ -1221,6 +1221,22 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { } } + if (!Zext.hasNonNeg()) { + // If this zero extend is only used by a shift, add nneg flag. + if (Zext.hasOneUse() && + SrcTy->getScalarSizeInBits() > + Log2_64_Ceil(DestTy->getScalarSizeInBits()) && + match(Zext.user_back(), m_Shift(m_Value(), m_Specific(&Zext)))) { + Zext.setNonNeg(); + return &Zext; + } + + if (isKnownNonNegative(Src, SQ.getWithInstruction(&Zext))) { + Zext.setNonNeg(); + return &Zext; + } + } + return nullptr; } @@ -1373,8 +1389,11 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { unsigned DestBitSize = DestTy->getScalarSizeInBits(); // If the value being extended is zero or positive, use a zext instead. 
- if (isKnownNonNegative(Src, DL, 0, &AC, &Sext, &DT)) - return CastInst::Create(Instruction::ZExt, Src, DestTy); + if (isKnownNonNegative(Src, SQ.getWithInstruction(&Sext))) { + auto CI = CastInst::Create(Instruction::ZExt, Src, DestTy); + CI->setNonNeg(true); + return CI; + } // Try to extend the entire expression tree to the wide destination type. if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) { @@ -1445,9 +1464,11 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { // TODO: Eventually this could be subsumed by EvaluateInDifferentType. Constant *BA = nullptr, *CA = nullptr; if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)), - m_Constant(CA))) && + m_ImmConstant(CA))) && BA->isElementWiseEqual(CA) && A->getType() == DestTy) { - Constant *WideCurrShAmt = ConstantExpr::getSExt(CA, DestTy); + Constant *WideCurrShAmt = + ConstantFoldCastOperand(Instruction::SExt, CA, DestTy, DL); + assert(WideCurrShAmt && "Constant folding of ImmConstant cannot fail"); Constant *NumLowbitsLeft = ConstantExpr::getSub( ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt); Constant *NewShAmt = ConstantExpr::getSub( @@ -1915,29 +1936,6 @@ Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) { return nullptr; } -/// Implement the transforms for cast of pointer (bitcast/ptrtoint) -Instruction *InstCombinerImpl::commonPointerCastTransforms(CastInst &CI) { - Value *Src = CI.getOperand(0); - - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) { - // If casting the result of a getelementptr instruction with no offset, turn - // this into a cast of the original pointer! - if (GEP->hasAllZeroIndices() && - // If CI is an addrspacecast and GEP changes the poiner type, merging - // GEP into CI would undo canonicalizing addrspacecast with different - // pointer types, causing infinite loops. 
- (!isa<AddrSpaceCastInst>(CI) || - GEP->getType() == GEP->getPointerOperandType())) { - // Changing the cast operand is usually not a good idea but it is safe - // here because the pointer operand is being replaced with another - // pointer operand so the opcode doesn't need to change. - return replaceOperand(CI, 0, GEP->getOperand(0)); - } - } - - return commonCastTransforms(CI); -} - Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { // If the destination integer type is not the intptr_t type for this target, // do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast @@ -1955,6 +1953,15 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); } + // (ptrtoint (ptrmask P, M)) + // -> (and (ptrtoint P), M) + // This is generally beneficial as `and` is better supported than `ptrmask`. + Value *Ptr, *Mask; + if (match(SrcOp, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(Ptr), + m_Value(Mask)))) && + Mask->getType() == Ty) + return BinaryOperator::CreateAnd(Builder.CreatePtrToInt(Ptr, Ty), Mask); + if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) { // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use. 
// While this can increase the number of instructions it doesn't actually @@ -1979,7 +1986,7 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { return InsertElementInst::Create(Vec, NewCast, Index); } - return commonPointerCastTransforms(CI); + return commonCastTransforms(CI); } /// This input value (which is known to have vector type) is being zero extended @@ -2136,9 +2143,12 @@ static bool collectInsertionElements(Value *V, unsigned Shift, Type *ElementIntTy = IntegerType::get(C->getContext(), ElementSize); for (unsigned i = 0; i != NumElts; ++i) { - unsigned ShiftI = Shift+i*ElementSize; - Constant *Piece = ConstantExpr::getLShr(C, ConstantInt::get(C->getType(), - ShiftI)); + unsigned ShiftI = Shift + i * ElementSize; + Constant *Piece = ConstantFoldBinaryInstruction( + Instruction::LShr, C, ConstantInt::get(C->getType(), ShiftI)); + if (!Piece) + return false; + Piece = ConstantExpr::getTrunc(Piece, ElementIntTy); if (!collectInsertionElements(Piece, ShiftI, Elements, VecEltTy, isBigEndian)) @@ -2701,11 +2711,9 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { if (Instruction *I = foldBitCastSelect(CI, Builder)) return I; - if (SrcTy->isPointerTy()) - return commonPointerCastTransforms(CI); return commonCastTransforms(CI); } Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) { - return commonPointerCastTransforms(CI); + return commonCastTransforms(CI); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 656f04370e17..e42e011bd436 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -12,12 +12,14 @@ #include "InstCombineInternal.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/CmpInstAnalysis.h" 
#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Utils/Local.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" @@ -26,6 +28,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <bitset> using namespace llvm; using namespace PatternMatch; @@ -412,7 +415,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( /// Returns true if we can rewrite Start as a GEP with pointer Base /// and some integer offset. The nodes that need to be re-written /// for this transformation will be added to Explored. -static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, +static bool canRewriteGEPAsOffset(Value *Start, Value *Base, const DataLayout &DL, SetVector<Value *> &Explored) { SmallVector<Value *, 16> WorkList(1, Start); @@ -440,27 +443,15 @@ static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, continue; } - if (!isa<IntToPtrInst>(V) && !isa<PtrToIntInst>(V) && - !isa<GetElementPtrInst>(V) && !isa<PHINode>(V)) + if (!isa<GetElementPtrInst>(V) && !isa<PHINode>(V)) // We've found some value that we can't explore which is different from // the base. Therefore we can't do this transformation. return false; - if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) { - auto *CI = cast<CastInst>(V); - if (!CI->isNoopCast(DL)) - return false; - - if (!Explored.contains(CI->getOperand(0))) - WorkList.push_back(CI->getOperand(0)); - } - if (auto *GEP = dyn_cast<GEPOperator>(V)) { - // We're limiting the GEP to having one index. This will preserve - // the original pointer type. We could handle more cases in the - // future. - if (GEP->getNumIndices() != 1 || !GEP->isInBounds() || - GEP->getSourceElementType() != ElemTy) + // Only allow inbounds GEPs with at most one variable offset. 
+ auto IsNonConst = [](Value *V) { return !isa<ConstantInt>(V); }; + if (!GEP->isInBounds() || count_if(GEP->indices(), IsNonConst) > 1) return false; if (!Explored.contains(GEP->getOperand(0))) @@ -514,7 +505,8 @@ static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, static void setInsertionPoint(IRBuilder<> &Builder, Value *V, bool Before = true) { if (auto *PHI = dyn_cast<PHINode>(V)) { - Builder.SetInsertPoint(&*PHI->getParent()->getFirstInsertionPt()); + BasicBlock *Parent = PHI->getParent(); + Builder.SetInsertPoint(Parent, Parent->getFirstInsertionPt()); return; } if (auto *I = dyn_cast<Instruction>(V)) { @@ -526,7 +518,7 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V, if (auto *A = dyn_cast<Argument>(V)) { // Set the insertion point in the entry block. BasicBlock &Entry = A->getParent()->getEntryBlock(); - Builder.SetInsertPoint(&*Entry.getFirstInsertionPt()); + Builder.SetInsertPoint(&Entry, Entry.getFirstInsertionPt()); return; } // Otherwise, this is a constant and we don't need to set a new @@ -536,7 +528,7 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V, /// Returns a re-written value of Start as an indexed GEP using Base as a /// pointer. -static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, +static Value *rewriteGEPAsOffset(Value *Start, Value *Base, const DataLayout &DL, SetVector<Value *> &Explored, InstCombiner &IC) { @@ -567,36 +559,18 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, // Create all the other instructions. for (Value *Val : Explored) { - if (NewInsts.contains(Val)) continue; - if (auto *CI = dyn_cast<CastInst>(Val)) { - // Don't get rid of the intermediate variable here; the store can grow - // the map which will invalidate the reference to the input value. - Value *V = NewInsts[CI->getOperand(0)]; - NewInsts[CI] = V; - continue; - } if (auto *GEP = dyn_cast<GEPOperator>(Val)) { - Value *Index = NewInsts[GEP->getOperand(1)] ? 
NewInsts[GEP->getOperand(1)] - : GEP->getOperand(1); setInsertionPoint(Builder, GEP); - // Indices might need to be sign extended. GEPs will magically do - // this, but we need to do it ourselves here. - if (Index->getType()->getScalarSizeInBits() != - NewInsts[GEP->getOperand(0)]->getType()->getScalarSizeInBits()) { - Index = Builder.CreateSExtOrTrunc( - Index, NewInsts[GEP->getOperand(0)]->getType(), - GEP->getOperand(0)->getName() + ".sext"); - } - - auto *Op = NewInsts[GEP->getOperand(0)]; + Value *Op = NewInsts[GEP->getOperand(0)]; + Value *OffsetV = emitGEPOffset(&Builder, DL, GEP); if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) - NewInsts[GEP] = Index; + NewInsts[GEP] = OffsetV; else NewInsts[GEP] = Builder.CreateNSWAdd( - Op, Index, GEP->getOperand(0)->getName() + ".add"); + Op, OffsetV, GEP->getOperand(0)->getName() + ".add"); continue; } if (isa<PHINode>(Val)) @@ -624,23 +598,14 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, } } - PointerType *PtrTy = - ElemTy->getPointerTo(Start->getType()->getPointerAddressSpace()); for (Value *Val : Explored) { if (Val == Base) continue; - // Depending on the type, for external users we have to emit - // a GEP or a GEP + ptrtoint. setInsertionPoint(Builder, Val, false); - - // Cast base to the expected type. - Value *NewVal = Builder.CreateBitOrPointerCast( - Base, PtrTy, Start->getName() + "to.ptr"); - NewVal = Builder.CreateInBoundsGEP(ElemTy, NewVal, ArrayRef(NewInsts[Val]), - Val->getName() + ".ptr"); - NewVal = Builder.CreateBitOrPointerCast( - NewVal, Val->getType(), Val->getName() + ".conv"); + // Create GEP for external users. + Value *NewVal = Builder.CreateInBoundsGEP( + Builder.getInt8Ty(), Base, NewInsts[Val], Val->getName() + ".ptr"); IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal); // Add old instruction to worklist for DCE. We don't directly remove it // here because the original compare is one of the users. 
@@ -650,48 +615,6 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, return NewInsts[Start]; } -/// Looks through GEPs, IntToPtrInsts and PtrToIntInsts in order to express -/// the input Value as a constant indexed GEP. Returns a pair containing -/// the GEPs Pointer and Index. -static std::pair<Value *, Value *> -getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) { - Type *IndexType = IntegerType::get(V->getContext(), - DL.getIndexTypeSizeInBits(V->getType())); - - Constant *Index = ConstantInt::getNullValue(IndexType); - while (true) { - if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { - // We accept only inbouds GEPs here to exclude the possibility of - // overflow. - if (!GEP->isInBounds()) - break; - if (GEP->hasAllConstantIndices() && GEP->getNumIndices() == 1 && - GEP->getSourceElementType() == ElemTy) { - V = GEP->getOperand(0); - Constant *GEPIndex = static_cast<Constant *>(GEP->getOperand(1)); - Index = ConstantExpr::getAdd( - Index, ConstantExpr::getSExtOrTrunc(GEPIndex, IndexType)); - continue; - } - break; - } - if (auto *CI = dyn_cast<IntToPtrInst>(V)) { - if (!CI->isNoopCast(DL)) - break; - V = CI->getOperand(0); - continue; - } - if (auto *CI = dyn_cast<PtrToIntInst>(V)) { - if (!CI->isNoopCast(DL)) - break; - V = CI->getOperand(0); - continue; - } - break; - } - return {V, Index}; -} - /// Converts (CMP GEPLHS, RHS) if this change would make RHS a constant. /// We can look through PHIs, GEPs and casts in order to determine a common base /// between GEPLHS and RHS. 
@@ -706,14 +629,19 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, if (!GEPLHS->hasAllConstantIndices()) return nullptr; - Type *ElemTy = GEPLHS->getSourceElementType(); - Value *PtrBase, *Index; - std::tie(PtrBase, Index) = getAsConstantIndexedAddress(ElemTy, GEPLHS, DL); + APInt Offset(DL.getIndexTypeSizeInBits(GEPLHS->getType()), 0); + Value *PtrBase = + GEPLHS->stripAndAccumulateConstantOffsets(DL, Offset, + /*AllowNonInbounds*/ false); + + // Bail if we looked through addrspacecast. + if (PtrBase->getType() != GEPLHS->getType()) + return nullptr; // The set of nodes that will take part in this transformation. SetVector<Value *> Nodes; - if (!canRewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes)) + if (!canRewriteGEPAsOffset(RHS, PtrBase, DL, Nodes)) return nullptr; // We know we can re-write this as @@ -722,13 +650,14 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, // can't have overflow on either side. We can therefore re-write // this as: // OFFSET1 cmp OFFSET2 - Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes, IC); + Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, DL, Nodes, IC); // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written // GEP having PtrBase as the pointer base, and has returned in NewRHS the // offset. Since Index is the offset of LHS to the base pointer, we will now // compare the offsets instead of comparing the pointers. - return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Index, NewRHS); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), + IC.Builder.getInt(Offset), NewRHS); } /// Fold comparisons between a GEP instruction and something else. At this point @@ -844,17 +773,6 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this); } - // If one of the GEPs has all zero indices, recurse. - // FIXME: Handle vector of pointers. 
- if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices()) - return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), - ICmpInst::getSwappedPredicate(Cond), I); - - // If the other GEP has all zero indices, recurse. - // FIXME: Handle vector of pointers. - if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices()) - return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); - bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() && GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) { @@ -894,8 +812,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Only lower this if the icmp is the only user of the GEP or if we expect // the result to fold to a constant! if ((GEPsInBounds || CmpInst::isEquality(Cond)) && - (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && - (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) { + (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && + (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse())) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) Value *L = EmitGEPOffset(GEPLHS); Value *R = EmitGEPOffset(GEPRHS); @@ -1285,9 +1203,9 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) { if (Pred == ICmpInst::ICMP_SGT) { Value *A, *B; if (match(Cmp.getOperand(0), m_SMin(m_Value(A), m_Value(B)))) { - if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) + if (isKnownPositive(A, SQ.getWithInstruction(&Cmp))) return new ICmpInst(Pred, B, Cmp.getOperand(1)); - if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) + if (isKnownPositive(B, SQ.getWithInstruction(&Cmp))) return new ICmpInst(Pred, A, Cmp.getOperand(1)); } } @@ -1554,6 +1472,61 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, return nullptr; } +/// Fold icmp (trunc X), (trunc Y). +/// Fold icmp (trunc X), (zext Y). 
+Instruction * +InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp, + const SimplifyQuery &Q) { + if (Cmp.isSigned()) + return nullptr; + + Value *X, *Y; + ICmpInst::Predicate Pred; + bool YIsZext = false; + // Try to match icmp (trunc X), (trunc Y) + if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) { + if (X->getType() != Y->getType() && + (!Cmp.getOperand(0)->hasOneUse() || !Cmp.getOperand(1)->hasOneUse())) + return nullptr; + if (!isDesirableIntType(X->getType()->getScalarSizeInBits()) && + isDesirableIntType(Y->getType()->getScalarSizeInBits())) { + std::swap(X, Y); + Pred = Cmp.getSwappedPredicate(Pred); + } + } + // Try to match icmp (trunc X), (zext Y) + else if (match(&Cmp, m_c_ICmp(Pred, m_Trunc(m_Value(X)), + m_OneUse(m_ZExt(m_Value(Y)))))) + + YIsZext = true; + else + return nullptr; + + Type *TruncTy = Cmp.getOperand(0)->getType(); + unsigned TruncBits = TruncTy->getScalarSizeInBits(); + + // If this transform will end up changing from desirable types -> undesirable + // types skip it. + if (isDesirableIntType(TruncBits) && + !isDesirableIntType(X->getType()->getScalarSizeInBits())) + return nullptr; + + // Check if the trunc is unneeded. + KnownBits KnownX = llvm::computeKnownBits(X, /*Depth*/ 0, Q); + if (KnownX.countMaxActiveBits() > TruncBits) + return nullptr; + + if (!YIsZext) { + // If Y is also a trunc, make sure it is unneeded. + KnownBits KnownY = llvm::computeKnownBits(Y, /*Depth*/ 0, Q); + if (KnownY.countMaxActiveBits() > TruncBits) + return nullptr; + } + + Value *NewY = Builder.CreateZExtOrTrunc(Y, X->getType()); + return new ICmpInst(Pred, X, NewY); +} + /// Fold icmp (xor X, Y), C. Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, @@ -1944,19 +1917,18 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return nullptr; } -/// Fold icmp eq/ne (or (xor (X1, X2), xor(X3, X4))), 0. 
-static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or, - InstCombiner::BuilderTy &Builder) { - // Are we using xors to bitwise check for a pair or pairs of (in)equalities? - // Convert to a shorter form that has more potential to be folded even - // further. - // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) - // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) - // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) == 0 --> +/// Fold icmp eq/ne (or (xor/sub (X1, X2), xor/sub (X3, X4))), 0. +static Value *foldICmpOrXorSubChain(ICmpInst &Cmp, BinaryOperator *Or, + InstCombiner::BuilderTy &Builder) { + // Are we using xors or subs to bitwise check for a pair or pairs of + // (in)equalities? Convert to a shorter form that has more potential to be + // folded even further. + // ((X1 ^/- X2) || (X3 ^/- X4)) == 0 --> (X1 == X2) && (X3 == X4) + // ((X1 ^/- X2) || (X3 ^/- X4)) != 0 --> (X1 != X2) || (X3 != X4) + // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) == 0 --> // (X1 == X2) && (X3 == X4) && (X5 == X6) - // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) != 0 --> + // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) != 0 --> // (X1 != X2) || (X3 != X4) || (X5 != X6) - // TODO: Implement for sub SmallVector<std::pair<Value *, Value *>, 2> CmpValues; SmallVector<Value *, 16> WorkList(1, Or); @@ -1967,9 +1939,16 @@ static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or, if (match(OrOperatorArgument, m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) { CmpValues.emplace_back(Lhs, Rhs); - } else { - WorkList.push_back(OrOperatorArgument); + return; } + + if (match(OrOperatorArgument, + m_OneUse(m_Sub(m_Value(Lhs), m_Value(Rhs))))) { + CmpValues.emplace_back(Lhs, Rhs); + return; + } + + WorkList.push_back(OrOperatorArgument); }; Value *CurrentValue = WorkList.pop_back_val(); @@ -2082,7 +2061,7 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, return BinaryOperator::Create(BOpc, CmpP, CmpQ); } - if (Value *V = foldICmpOrXorChain(Cmp, Or, 
Builder)) + if (Value *V = foldICmpOrXorSubChain(Cmp, Or, Builder)) return replaceInstUsesWith(Cmp, V); return nullptr; @@ -2443,7 +2422,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, // constant-value-based preconditions in the folds below, then we could assert // those conditions rather than checking them. This is difficult because of // undef/poison (PR34838). - if (IsAShr) { + if (IsAShr && Shr->hasOneUse()) { if (IsExact || Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) { // When ShAmtC can be shifted losslessly: // icmp PRED (ashr exact X, ShAmtC), C --> icmp PRED X, (C << ShAmtC) @@ -2483,7 +2462,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, ConstantInt::getAllOnesValue(ShrTy)); } } - } else { + } else if (!IsAShr) { if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC) @@ -2888,19 +2867,97 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C)); } +static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0, + Value *Op1, IRBuilderBase &Builder, + bool HasOneUse) { + auto FoldConstant = [&](bool Val) { + Constant *Res = Val ? Builder.getTrue() : Builder.getFalse(); + if (Op0->getType()->isVectorTy()) + Res = ConstantVector::getSplat( + cast<VectorType>(Op0->getType())->getElementCount(), Res); + return Res; + }; + + switch (Table.to_ulong()) { + case 0: // 0 0 0 0 + return FoldConstant(false); + case 1: // 0 0 0 1 + return HasOneUse ? Builder.CreateNot(Builder.CreateOr(Op0, Op1)) : nullptr; + case 2: // 0 0 1 0 + return HasOneUse ? Builder.CreateAnd(Builder.CreateNot(Op0), Op1) : nullptr; + case 3: // 0 0 1 1 + return Builder.CreateNot(Op0); + case 4: // 0 1 0 0 + return HasOneUse ? 
Builder.CreateAnd(Op0, Builder.CreateNot(Op1)) : nullptr; + case 5: // 0 1 0 1 + return Builder.CreateNot(Op1); + case 6: // 0 1 1 0 + return Builder.CreateXor(Op0, Op1); + case 7: // 0 1 1 1 + return HasOneUse ? Builder.CreateNot(Builder.CreateAnd(Op0, Op1)) : nullptr; + case 8: // 1 0 0 0 + return Builder.CreateAnd(Op0, Op1); + case 9: // 1 0 0 1 + return HasOneUse ? Builder.CreateNot(Builder.CreateXor(Op0, Op1)) : nullptr; + case 10: // 1 0 1 0 + return Op1; + case 11: // 1 0 1 1 + return HasOneUse ? Builder.CreateOr(Builder.CreateNot(Op0), Op1) : nullptr; + case 12: // 1 1 0 0 + return Op0; + case 13: // 1 1 0 1 + return HasOneUse ? Builder.CreateOr(Op0, Builder.CreateNot(Op1)) : nullptr; + case 14: // 1 1 1 0 + return Builder.CreateOr(Op0, Op1); + case 15: // 1 1 1 1 + return FoldConstant(true); + default: + llvm_unreachable("Invalid Operation"); + } + return nullptr; +} + /// Fold icmp (add X, Y), C. Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, const APInt &C) { Value *Y = Add->getOperand(1); + Value *X = Add->getOperand(0); + + Value *Op0, *Op1; + Instruction *Ext0, *Ext1; + const CmpInst::Predicate Pred = Cmp.getPredicate(); + if (match(Add, + m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))), + m_CombineAnd(m_Instruction(Ext1), + m_ZExtOrSExt(m_Value(Op1))))) && + Op0->getType()->isIntOrIntVectorTy(1) && + Op1->getType()->isIntOrIntVectorTy(1)) { + unsigned BW = C.getBitWidth(); + std::bitset<4> Table; + auto ComputeTable = [&](bool Op0Val, bool Op1Val) { + int Res = 0; + if (Op0Val) + Res += isa<ZExtInst>(Ext0) ? 1 : -1; + if (Op1Val) + Res += isa<ZExtInst>(Ext1) ? 
1 : -1; + return ICmpInst::compare(APInt(BW, Res, true), C, Pred); + }; + + Table[0] = ComputeTable(false, false); + Table[1] = ComputeTable(false, true); + Table[2] = ComputeTable(true, false); + Table[3] = ComputeTable(true, true); + if (auto *Cond = + createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse())) + return replaceInstUsesWith(Cmp, Cond); + } const APInt *C2; if (Cmp.isEquality() || !match(Y, m_APInt(C2))) return nullptr; // Fold icmp pred (add X, C2), C. - Value *X = Add->getOperand(0); Type *Ty = Add->getType(); - const CmpInst::Predicate Pred = Cmp.getPredicate(); // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE @@ -3172,18 +3229,6 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { } } - // Test to see if the operands of the icmp are casted versions of other - // values. If the ptr->ptr cast can be stripped off both arguments, do so. - if (DstType->isPointerTy() && (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { - // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast - // so eliminate it as well. - if (auto *BC2 = dyn_cast<BitCastInst>(Op1)) - Op1 = BC2->getOperand(0); - - Op1 = Builder.CreateBitCast(Op1, SrcType); - return new ICmpInst(Pred, BCSrcOp, Op1); - } - const APInt *C; if (!match(Cmp.getOperand(1), m_APInt(C)) || !DstType->isIntegerTy() || !SrcType->isIntOrIntVectorTy()) @@ -3196,10 +3241,12 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { // icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0 // Example: are all elements equal? --> are zero elements not equal? // TODO: Try harder to reduce compare of 2 freely invertible operands? 
- if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() && - isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) { - Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), DstType); - return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType)); + if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse()) { + if (Value *NotBCSrcOp = + getFreelyInverted(BCSrcOp, BCSrcOp->hasOneUse(), &Builder)) { + Value *Cast = Builder.CreateBitCast(NotBCSrcOp, DstType); + return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType)); + } } // If this is checking if all elements of an extended vector are clear or not, @@ -3878,21 +3925,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { return nullptr; switch (LHSI->getOpcode()) { - case Instruction::GetElementPtr: - // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null - if (RHSC->isNullValue() && - cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) - return new ICmpInst( - I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; case Instruction::PHI: - // Only fold icmp into the PHI if the phi and icmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. 
- if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) - return NV; + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) + return NV; break; case Instruction::IntToPtr: // icmp pred inttoptr(X), null -> icmp pred X, 0 @@ -4243,7 +4278,12 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, /*isNUW=*/false, SQ.getWithInstruction(&I))); if (!NewShAmt) return nullptr; - NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy); + if (NewShAmt->getType() != WidestTy) { + NewShAmt = + ConstantFoldCastOperand(Instruction::ZExt, NewShAmt, WidestTy, SQ.DL); + if (!NewShAmt) + return nullptr; + } unsigned WidestBitWidth = WidestTy->getScalarSizeInBits(); // Is the new shift amount smaller than the bit width? @@ -4424,6 +4464,65 @@ static Instruction *foldICmpXNegX(ICmpInst &I, return nullptr; } +static Instruction *foldICmpAndXX(ICmpInst &I, const SimplifyQuery &Q, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; + // Normalize and operand as operand 0. + CmpInst::Predicate Pred = I.getPredicate(); + if (match(Op1, m_c_And(m_Specific(Op0), m_Value()))) { + std::swap(Op0, Op1); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + if (!match(Op0, m_c_And(m_Specific(Op1), m_Value(A)))) + return nullptr; + + // (icmp (X & Y) u< X --> (X & Y) != X + if (Pred == ICmpInst::ICMP_ULT) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + // (icmp (X & Y) u>= X --> (X & Y) == X + if (Pred == ICmpInst::ICMP_UGE) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + + return nullptr; +} + +static Instruction *foldICmpOrXX(ICmpInst &I, const SimplifyQuery &Q, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; + + // Normalize or operand as operand 0. 
+ CmpInst::Predicate Pred = I.getPredicate(); + if (match(Op1, m_c_Or(m_Specific(Op0), m_Value(A)))) { + std::swap(Op0, Op1); + Pred = ICmpInst::getSwappedPredicate(Pred); + } else if (!match(Op0, m_c_Or(m_Specific(Op1), m_Value(A)))) { + return nullptr; + } + + // icmp (X | Y) u<= X --> (X | Y) == X + if (Pred == ICmpInst::ICMP_ULE) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + + // icmp (X | Y) u> X --> (X | Y) != X + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + if (ICmpInst::isEquality(Pred) && Op0->hasOneUse()) { + // icmp (X | Y) eq/ne Y --> (X & ~Y) eq/ne 0 if Y is freely invertible + if (Value *NotOp1 = + IC.getFreelyInverted(Op1, Op1->hasOneUse(), &IC.Builder)) + return new ICmpInst(Pred, IC.Builder.CreateAnd(A, NotOp1), + Constant::getNullValue(Op1->getType())); + // icmp (X | Y) eq/ne Y --> (~X | Y) eq/ne -1 if X is freely invertible. + if (Value *NotA = IC.getFreelyInverted(A, A->hasOneUse(), &IC.Builder)) + return new ICmpInst(Pred, IC.Builder.CreateOr(Op1, NotA), + Constant::getAllOnesValue(Op1->getType())); + } + return nullptr; +} + static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q, InstCombinerImpl &IC) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; @@ -4746,6 +4845,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (Instruction * R = foldICmpXorXX(I, Q, *this)) return R; + if (Instruction *R = foldICmpOrXX(I, Q, *this)) + return R; { // Try to remove shared multiplier from comparison: @@ -4915,6 +5016,9 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) return replaceInstUsesWith(I, V); + if (Instruction *R = foldICmpAndXX(I, Q, *this)) + return R; + if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder)) return replaceInstUsesWith(I, V); @@ -4924,88 +5028,153 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, return nullptr; } -/// Fold icmp Pred min|max(X, Y), X. 
-static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { - ICmpInst::Predicate Pred = Cmp.getPredicate(); - Value *Op0 = Cmp.getOperand(0); - Value *X = Cmp.getOperand(1); - - // Canonicalize minimum or maximum operand to LHS of the icmp. - if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) || - match(X, m_c_SMax(m_Specific(Op0), m_Value())) || - match(X, m_c_UMin(m_Specific(Op0), m_Value())) || - match(X, m_c_UMax(m_Specific(Op0), m_Value()))) { - std::swap(Op0, X); - Pred = Cmp.getSwappedPredicate(); - } - - Value *Y; - if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) { - // smin(X, Y) == X --> X s<= Y - // smin(X, Y) s>= X --> X s<= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); - - // smin(X, Y) != X --> X s> Y - // smin(X, Y) s< X --> X s> Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT) - return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); - - // These cases should be handled in InstSimplify: - // smin(X, Y) s<= X --> true - // smin(X, Y) s> X --> false +/// Fold icmp Pred min|max(X, Y), Z. 
+Instruction * +InstCombinerImpl::foldICmpWithMinMaxImpl(Instruction &I, + MinMaxIntrinsic *MinMax, Value *Z, + ICmpInst::Predicate Pred) { + Value *X = MinMax->getLHS(); + Value *Y = MinMax->getRHS(); + if (ICmpInst::isSigned(Pred) && !MinMax->isSigned()) return nullptr; - } - - if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) { - // smax(X, Y) == X --> X s>= Y - // smax(X, Y) s<= X --> X s>= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - - // smax(X, Y) != X --> X s< Y - // smax(X, Y) s> X --> X s< Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - - // These cases should be handled in InstSimplify: - // smax(X, Y) s>= X --> true - // smax(X, Y) s< X --> false + if (ICmpInst::isUnsigned(Pred) && MinMax->isSigned()) return nullptr; + SimplifyQuery Q = SQ.getWithInstruction(&I); + auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> { + if (!Val) + return std::nullopt; + if (match(Val, m_One())) + return true; + if (match(Val, m_Zero())) + return false; + return std::nullopt; + }; + auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(Pred, X, Z, Q)); + auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(Pred, Y, Z, Q)); + if (!CmpXZ.has_value() && !CmpYZ.has_value()) + return nullptr; + if (!CmpXZ.has_value()) { + std::swap(X, Y); + std::swap(CmpXZ, CmpYZ); } - if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) { - // umin(X, Y) == X --> X u<= Y - // umin(X, Y) u>= X --> X u<= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE) - return new ICmpInst(ICmpInst::ICMP_ULE, X, Y); - - // umin(X, Y) != X --> X u> Y - // umin(X, Y) u< X --> X u> Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT) - return new ICmpInst(ICmpInst::ICMP_UGT, X, Y); + auto FoldIntoCmpYZ = [&]() -> Instruction * { + if (CmpYZ.has_value()) + return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *CmpYZ)); + return 
ICmpInst::Create(Instruction::ICmp, Pred, Y, Z); + }; - // These cases should be handled in InstSimplify: - // umin(X, Y) u<= X --> true - // umin(X, Y) u> X --> false - return nullptr; + switch (Pred) { + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: { + // If X == Z: + // Expr Result + // min(X, Y) == Z X <= Y + // max(X, Y) == Z X >= Y + // min(X, Y) != Z X > Y + // max(X, Y) != Z X < Y + if ((Pred == ICmpInst::ICMP_EQ) == *CmpXZ) { + ICmpInst::Predicate NewPred = + ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); + if (Pred == ICmpInst::ICMP_NE) + NewPred = ICmpInst::getInversePredicate(NewPred); + return ICmpInst::Create(Instruction::ICmp, NewPred, X, Y); + } + // Otherwise (X != Z): + ICmpInst::Predicate NewPred = MinMax->getPredicate(); + auto MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q)); + if (!MinMaxCmpXZ.has_value()) { + std::swap(X, Y); + std::swap(CmpXZ, CmpYZ); + // Re-check pre-condition X != Z + if (!CmpXZ.has_value() || (Pred == ICmpInst::ICMP_EQ) == *CmpXZ) + break; + MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q)); + } + if (!MinMaxCmpXZ.has_value()) + break; + if (*MinMaxCmpXZ) { + // Expr Fact Result + // min(X, Y) == Z X < Z false + // max(X, Y) == Z X > Z false + // min(X, Y) != Z X < Z true + // max(X, Y) != Z X > Z true + return replaceInstUsesWith( + I, ConstantInt::getBool(I.getType(), Pred == ICmpInst::ICMP_NE)); + } else { + // Expr Fact Result + // min(X, Y) == Z X > Z Y == Z + // max(X, Y) == Z X < Z Y == Z + // min(X, Y) != Z X > Z Y != Z + // max(X, Y) != Z X < Z Y != Z + return FoldIntoCmpYZ(); + } + break; + } + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_ULE: + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: { + bool IsSame = MinMax->getPredicate() == ICmpInst::getStrictPredicate(Pred); + if (*CmpXZ) { + if (IsSame) { + // Expr Fact Result + // min(X, Y) < Z X < Z 
true + // min(X, Y) <= Z X <= Z true + // max(X, Y) > Z X > Z true + // max(X, Y) >= Z X >= Z true + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + } else { + // Expr Fact Result + // max(X, Y) < Z X < Z Y < Z + // max(X, Y) <= Z X <= Z Y <= Z + // min(X, Y) > Z X > Z Y > Z + // min(X, Y) >= Z X >= Z Y >= Z + return FoldIntoCmpYZ(); + } + } else { + if (IsSame) { + // Expr Fact Result + // min(X, Y) < Z X >= Z Y < Z + // min(X, Y) <= Z X > Z Y <= Z + // max(X, Y) > Z X <= Z Y > Z + // max(X, Y) >= Z X < Z Y >= Z + return FoldIntoCmpYZ(); + } else { + // Expr Fact Result + // max(X, Y) < Z X >= Z false + // max(X, Y) <= Z X > Z false + // min(X, Y) > Z X <= Z false + // min(X, Y) >= Z X < Z false + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + } + } + break; + } + default: + break; } - if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) { - // umax(X, Y) == X --> X u>= Y - // umax(X, Y) u<= X --> X u>= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE) - return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); + return nullptr; +} +Instruction *InstCombinerImpl::foldICmpWithMinMax(ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Lhs = Cmp.getOperand(0); + Value *Rhs = Cmp.getOperand(1); - // umax(X, Y) != X --> X u< Y - // umax(X, Y) u> X --> X u< Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Lhs)) { + if (Instruction *Res = foldICmpWithMinMaxImpl(Cmp, MinMax, Rhs, Pred)) + return Res; + } - // These cases should be handled in InstSimplify: - // umax(X, Y) u>= X --> true - // umax(X, Y) u< X --> false - return nullptr; + if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Rhs)) { + if (Instruction *Res = foldICmpWithMinMaxImpl( + Cmp, MinMax, Lhs, ICmpInst::getSwappedPredicate(Pred))) + return Res; } return nullptr; @@ -5173,35 +5342,6 @@ Instruction 
*InstCombinerImpl::foldICmpEquality(ICmpInst &I) { return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); } - // Test if 2 values have different or same signbits: - // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 - // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 - // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0 - // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1 - Instruction *ExtI; - if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) && - (Op0->hasOneUse() || Op1->hasOneUse())) { - unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); - Instruction *ShiftI; - Value *X, *Y; - ICmpInst::Predicate Pred2; - if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), - m_Shr(m_Value(X), - m_SpecificIntAllowUndef(OpWidth - 1)))) && - match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) && - Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) { - unsigned ExtOpc = ExtI->getOpcode(); - unsigned ShiftOpc = ShiftI->getOpcode(); - if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) || - (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) { - Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); - Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) - : Builder.CreateIsNotNeg(Xor); - return replaceInstUsesWith(I, R); - } - } - } - // (A >> C) == (B >> C) --> (A^B) u< (1 << C) // For lshr and ashr pairs. 
const APInt *AP1, *AP2; @@ -5307,6 +5447,40 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { Pred, A, Builder.CreateIntrinsic(Op0->getType(), Intrinsic::fshl, {A, A, B})); + // Canonicalize: + // icmp eq/ne OneUse(A ^ Cst), B --> icmp eq/ne (A ^ B), Cst + Constant *Cst; + if (match(&I, m_c_ICmp(PredUnused, + m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))), + m_CombineAnd(m_Value(B), m_Unless(m_ImmConstant()))))) + return new ICmpInst(Pred, Builder.CreateXor(A, B), Cst); + + { + // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2) + auto m_Matcher = + m_CombineOr(m_CombineOr(m_c_Add(m_Value(B), m_Deferred(A)), + m_c_Xor(m_Value(B), m_Deferred(A))), + m_Sub(m_Value(B), m_Deferred(A))); + std::optional<bool> IsZero = std::nullopt; + if (match(&I, m_c_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)), + m_Deferred(A)))) + IsZero = false; + // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0) + else if (match(&I, + m_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)), + m_Zero()))) + IsZero = true; + + if (IsZero && isKnownToBeAPowerOfTwo(A, /* OrZero */ true, /*Depth*/ 0, &I)) + // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2) + // -> (icmp eq/ne (and X, P2), 0) + // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0) + // -> (icmp eq/ne (and X, P2), P2) + return new ICmpInst(Pred, Builder.CreateAnd(B, A), + *IsZero ? 
A + : ConstantInt::getNullValue(A->getType())); + } + return nullptr; } @@ -5383,8 +5557,8 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { // icmp Pred (ext X), (ext Y) Value *Y; if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) { - bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0)); - bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1)); + bool IsZext0 = isa<ZExtInst>(ICmp.getOperand(0)); + bool IsZext1 = isa<ZExtInst>(ICmp.getOperand(1)); if (IsZext0 != IsZext1) { // If X and Y and both i1 @@ -5396,11 +5570,16 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { return new ICmpInst(ICmp.getPredicate(), Builder.CreateOr(X, Y), Constant::getNullValue(X->getType())); - // If we have mismatched casts, treat the zext of a non-negative source as - // a sext to simulate matching casts. Otherwise, we are done. - // TODO: Can we handle some predicates (equality) without non-negative? - if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) || - (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT))) + // If we have mismatched casts and zext has the nneg flag, we can + // treat the "zext nneg" as "sext". Otherwise, we cannot fold and quit. + + auto *NonNegInst0 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(0)); + auto *NonNegInst1 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(1)); + + bool IsNonNeg0 = NonNegInst0 && NonNegInst0->hasNonNeg(); + bool IsNonNeg1 = NonNegInst1 && NonNegInst1->hasNonNeg(); + + if ((IsZext0 && IsNonNeg0) || (IsZext1 && IsNonNeg1)) IsSignedExt = true; else return nullptr; @@ -5442,25 +5621,20 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { if (!C) return nullptr; - // Compute the constant that would happen if we truncated to SrcTy then - // re-extended to DestTy. + // If a lossless truncate is possible... 
Type *SrcTy = CastOp0->getSrcTy(); - Type *DestTy = CastOp0->getDestTy(); - Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); - Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy); - - // If the re-extended constant didn't change... - if (Res2 == C) { + Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode()); + if (Res) { if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), X, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res); // A signed comparison of sign extended values simplifies into a // signed comparison. if (IsSignedExt && IsSignedCmp) - return new ICmpInst(ICmp.getPredicate(), X, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res); } // The re-extended constant changed, partly changed (in the case of a vector), @@ -5518,13 +5692,8 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { Value *NewOp1 = nullptr; if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { Value *PtrSrc = PtrToIntOp1->getOperand(0); - if (PtrSrc->getType()->getPointerAddressSpace() == - Op0Src->getType()->getPointerAddressSpace()) { + if (PtrSrc->getType() == Op0Src->getType()) NewOp1 = PtrToIntOp1->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (Op0Src->getType() != NewOp1->getType()) - NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType()); - } } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); } @@ -5641,22 +5810,20 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, /// \returns Instruction which must replace the compare instruction, NULL if no /// replacement required. 
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, - Value *OtherVal, + const APInt *OtherVal, InstCombinerImpl &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. if (!isa<IntegerType>(MulVal->getType())) return nullptr; - assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal); - assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal); auto *MulInstr = dyn_cast<Instruction>(MulVal); if (!MulInstr) return nullptr; assert(MulInstr->getOpcode() == Instruction::Mul); - auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)), - *RHS = cast<ZExtOperator>(MulInstr->getOperand(1)); + auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)), + *RHS = cast<ZExtInst>(MulInstr->getOperand(1)); assert(LHS->getOpcode() == Instruction::ZExt); assert(RHS->getOpcode() == Instruction::ZExt); Value *A = LHS->getOperand(0), *B = RHS->getOperand(0); @@ -5709,70 +5876,26 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, // Recognize patterns switch (I.getPredicate()) { - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_NE: - // Recognize pattern: - // mulval = mul(zext A, zext B) - // cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits. 
- ConstantInt *CI; - Value *ValToMask; - if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) { - if (ValToMask != MulVal) - return nullptr; - const APInt &CVal = CI->getValue() + 1; - if (CVal.isPowerOf2()) { - unsigned MaskWidth = CVal.logBase2(); - if (MaskWidth == MulWidth) - break; // Recognized - } - } - return nullptr; - - case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGT: { // Recognize pattern: // mulval = mul(zext A, zext B) // cmp ugt mulval, max - if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { - APInt MaxVal = APInt::getMaxValue(MulWidth); - MaxVal = MaxVal.zext(CI->getBitWidth()); - if (MaxVal.eq(CI->getValue())) - break; // Recognized - } - return nullptr; - - case ICmpInst::ICMP_UGE: - // Recognize pattern: - // mulval = mul(zext A, zext B) - // cmp uge mulval, max+1 - if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { - APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth); - if (MaxVal.eq(CI->getValue())) - break; // Recognized - } - return nullptr; - - case ICmpInst::ICMP_ULE: - // Recognize pattern: - // mulval = mul(zext A, zext B) - // cmp ule mulval, max - if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { - APInt MaxVal = APInt::getMaxValue(MulWidth); - MaxVal = MaxVal.zext(CI->getBitWidth()); - if (MaxVal.eq(CI->getValue())) - break; // Recognized - } + APInt MaxVal = APInt::getMaxValue(MulWidth); + MaxVal = MaxVal.zext(OtherVal->getBitWidth()); + if (MaxVal.eq(*OtherVal)) + break; // Recognized return nullptr; + } - case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULT: { // Recognize pattern: // mulval = mul(zext A, zext B) // cmp ule mulval, max + 1 - if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) { - APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth); - if (MaxVal.eq(CI->getValue())) - break; // Recognized - } + APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth); + if (MaxVal.eq(*OtherVal)) + break; // Recognized return nullptr; + } default: return 
nullptr; @@ -5798,7 +5921,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, if (MulVal->hasNUsesOrMore(2)) { Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value"); for (User *U : make_early_inc_range(MulVal->users())) { - if (U == &I || U == OtherVal) + if (U == &I) continue; if (TruncInst *TI = dyn_cast<TruncInst>(U)) { if (TI->getType()->getPrimitiveSizeInBits() == MulWidth) @@ -5819,34 +5942,10 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, IC.addToWorklist(cast<Instruction>(U)); } } - if (isa<Instruction>(OtherVal)) - IC.addToWorklist(cast<Instruction>(OtherVal)); // The original icmp gets replaced with the overflow value, maybe inverted // depending on predicate. - bool Inverse = false; - switch (I.getPredicate()) { - case ICmpInst::ICMP_NE: - break; - case ICmpInst::ICMP_EQ: - Inverse = true; - break; - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - if (I.getOperand(0) == MulVal) - break; - Inverse = true; - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - if (I.getOperand(1) == MulVal) - break; - Inverse = true; - break; - default: - llvm_unreachable("Unexpected predicate"); - } - if (Inverse) { + if (I.getPredicate() == ICmpInst::ICMP_ULT) { Value *Res = Builder.CreateExtractValue(Call, 1); return BinaryOperator::CreateNot(Res); } @@ -6015,13 +6114,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { KnownBits Op0Known(BitWidth); KnownBits Op1Known(BitWidth); - if (SimplifyDemandedBits(&I, 0, - getDemandedBitsLHSMask(I, BitWidth), - Op0Known, 0)) - return &I; + { + // Don't use dominating conditions when folding icmp using known bits. This + // may convert signed into unsigned predicates in ways that other passes + // (especially IndVarSimplify) may not be able to reliably undo. 
+ SQ.DC = nullptr; + auto _ = make_scope_exit([&]() { SQ.DC = &DC; }); + if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth), + Op0Known, 0)) + return &I; - if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0)) - return &I; + if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0)) + return &I; + } // Given the known and unknown bits, compute a range that the LHS could be // in. Compute the Min, Max and RHS values based on the known bits. For the @@ -6269,57 +6374,70 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE) return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y); + // icmp eq/ne X, (zext/sext (icmp eq/ne X, C)) + ICmpInst::Predicate Pred1, Pred2; const APInt *C; - if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) && - match(I.getOperand(1), m_APInt(C)) && - X->getType()->isIntOrIntVectorTy(1) && - Y->getType()->isIntOrIntVectorTy(1)) { - unsigned BitWidth = C->getBitWidth(); - Pred = I.getPredicate(); - APInt Zero = APInt::getZero(BitWidth); - APInt MinusOne = APInt::getAllOnes(BitWidth); - APInt One(BitWidth, 1); - if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) || - (C->slt(Zero) && Pred == ICmpInst::ICMP_SLT)) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) || - (C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT)) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - - if (I.getOperand(0)->hasOneUse()) { - APInt NewC = *C; - // canonicalize predicate to eq/ne - if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) || - (*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) { - // x s< 0 in [-1, 1] --> x == -1 - // x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1 - NewC = MinusOne; - Pred = ICmpInst::ICMP_EQ; - } else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) || - (*C != Zero && *C != 
One && Pred == ICmpInst::ICMP_ULT)) { - // x s> -1 in [-1, 1] --> x != -1 - // x u< -1 in [-1, 1] --> x != -1 - Pred = ICmpInst::ICMP_NE; - } else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) { - // x s> 0 in [-1, 1] --> x == 1 - NewC = One; - Pred = ICmpInst::ICMP_EQ; - } else if (*C == One && Pred == ICmpInst::ICMP_SLT) { - // x s< 1 in [-1, 1] --> x != 1 - Pred = ICmpInst::ICMP_NE; + Instruction *ExtI; + if (match(&I, m_c_ICmp(Pred1, m_Value(X), + m_CombineAnd(m_Instruction(ExtI), + m_ZExtOrSExt(m_ICmp(Pred2, m_Deferred(X), + m_APInt(C)))))) && + ICmpInst::isEquality(Pred1) && ICmpInst::isEquality(Pred2)) { + bool IsSExt = ExtI->getOpcode() == Instruction::SExt; + bool HasOneUse = ExtI->hasOneUse() && ExtI->getOperand(0)->hasOneUse(); + auto CreateRangeCheck = [&] { + Value *CmpV1 = + Builder.CreateICmp(Pred1, X, Constant::getNullValue(X->getType())); + Value *CmpV2 = Builder.CreateICmp( + Pred1, X, ConstantInt::getSigned(X->getType(), IsSExt ? -1 : 1)); + return BinaryOperator::Create( + Pred1 == ICmpInst::ICMP_EQ ? 
Instruction::Or : Instruction::And, + CmpV1, CmpV2); + }; + if (C->isZero()) { + if (Pred2 == ICmpInst::ICMP_EQ) { + // icmp eq X, (zext/sext (icmp eq X, 0)) --> false + // icmp ne X, (zext/sext (icmp eq X, 0)) --> true + return replaceInstUsesWith( + I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE)); + } else if (!IsSExt || HasOneUse) { + // icmp eq X, (zext (icmp ne X, 0)) --> X == 0 || X == 1 + // icmp ne X, (zext (icmp ne X, 0)) --> X != 0 && X != 1 + // icmp eq X, (sext (icmp ne X, 0)) --> X == 0 || X == -1 + // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X == -1 + return CreateRangeCheck(); } - - if (NewC == MinusOne) { - if (Pred == ICmpInst::ICMP_EQ) - return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y); - if (Pred == ICmpInst::ICMP_NE) - return BinaryOperator::CreateOr(X, Builder.CreateNot(Y)); - } else if (NewC == One) { - if (Pred == ICmpInst::ICMP_EQ) - return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y)); - if (Pred == ICmpInst::ICMP_NE) - return BinaryOperator::CreateOr(Builder.CreateNot(X), Y); + } else if (IsSExt ? 
C->isAllOnes() : C->isOne()) { + if (Pred2 == ICmpInst::ICMP_NE) { + // icmp eq X, (zext (icmp ne X, 1)) --> false + // icmp ne X, (zext (icmp ne X, 1)) --> true + // icmp eq X, (sext (icmp ne X, -1)) --> false + // icmp ne X, (sext (icmp ne X, -1)) --> true + return replaceInstUsesWith( + I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE)); + } else if (!IsSExt || HasOneUse) { + // icmp eq X, (zext (icmp eq X, 1)) --> X == 0 || X == 1 + // icmp ne X, (zext (icmp eq X, 1)) --> X != 0 && X != 1 + // icmp eq X, (sext (icmp eq X, -1)) --> X == 0 || X == -1 + // icmp ne X, (sext (icmp eq X, -1)) --> X != 0 && X == -1 + return CreateRangeCheck(); } + } else { + // when C != 0 && C != 1: + // icmp eq X, (zext (icmp eq X, C)) --> icmp eq X, 0 + // icmp eq X, (zext (icmp ne X, C)) --> icmp eq X, 1 + // icmp ne X, (zext (icmp eq X, C)) --> icmp ne X, 0 + // icmp ne X, (zext (icmp ne X, C)) --> icmp ne X, 1 + // when C != 0 && C != -1: + // icmp eq X, (sext (icmp eq X, C)) --> icmp eq X, 0 + // icmp eq X, (sext (icmp ne X, C)) --> icmp eq X, -1 + // icmp ne X, (sext (icmp eq X, C)) --> icmp ne X, 0 + // icmp ne X, (sext (icmp ne X, C)) --> icmp ne X, -1 + return ICmpInst::Create( + Instruction::ICmp, Pred1, X, + ConstantInt::getSigned(X->getType(), Pred2 == ICmpInst::ICMP_NE + ? (IsSExt ? -1 : 1) + : 0)); } } @@ -6783,6 +6901,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; + if (Instruction *Res = foldICmpTruncWithTruncOrExt(I, Q)) + return Res; + // Test if the ICmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing // any other folding. 
This helps out other analyses which understand @@ -6913,38 +7034,40 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { return Res; { - Value *A, *B; - // Transform (A & ~B) == 0 --> (A & B) != 0 - // and (A & ~B) != 0 --> (A & B) == 0 + Value *X, *Y; + // Transform (X & ~Y) == 0 --> (X & Y) != 0 + // and (X & ~Y) != 0 --> (X & Y) == 0 // if A is a power of 2. - if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_Zero()) && - isKnownToBeAPowerOfTwo(A, false, 0, &I) && I.isEquality()) - return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(A, B), + if (match(Op0, m_And(m_Value(X), m_Not(m_Value(Y)))) && + match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(X, false, 0, &I) && + I.isEquality()) + return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(X, Y), Op1); - // ~X < ~Y --> Y < X - // ~X < C --> X > ~C - if (match(Op0, m_Not(m_Value(A)))) { - if (match(Op1, m_Not(m_Value(B)))) - return new ICmpInst(I.getPredicate(), B, A); - - const APInt *C; - if (match(Op1, m_APInt(C))) - return new ICmpInst(I.getSwappedPredicate(), A, - ConstantInt::get(Op1->getType(), ~(*C))); + // Op0 pred Op1 -> ~Op1 pred ~Op0, if this allows us to drop an instruction. 
+ if (Op0->getType()->isIntOrIntVectorTy()) { + bool ConsumesOp0, ConsumesOp1; + if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) && + isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) && + (ConsumesOp0 || ConsumesOp1)) { + Value *InvOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder); + Value *InvOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder); + assert(InvOp0 && InvOp1 && + "Mismatch between isFreeToInvert and getFreelyInverted"); + return new ICmpInst(I.getSwappedPredicate(), InvOp0, InvOp1); + } } Instruction *AddI = nullptr; - if (match(&I, m_UAddWithOverflow(m_Value(A), m_Value(B), + if (match(&I, m_UAddWithOverflow(m_Value(X), m_Value(Y), m_Instruction(AddI))) && - isa<IntegerType>(A->getType())) { + isa<IntegerType>(X->getType())) { Value *Result; Constant *Overflow; // m_UAddWithOverflow can match patterns that do not include an explicit // "add" instruction, so check the opcode of the matched op. if (AddI->getOpcode() == Instruction::Add && - OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, A, B, *AddI, + OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, X, Y, *AddI, Result, Overflow)) { replaceInstUsesWith(*AddI, Result); eraseInstFromFunction(*AddI); @@ -6952,14 +7075,37 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { } } - // (zext a) * (zext b) --> llvm.umul.with.overflow. - if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this)) + // (zext X) * (zext Y) --> llvm.umul.with.overflow. 
+ if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) && + match(Op1, m_APInt(C))) { + if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this)) return R; } - if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) - return R; + + // Signbit test folds + // Fold (X u>> BitWidth - 1 Pred ZExt(i1)) --> X s< 0 Pred i1 + // Fold (X s>> BitWidth - 1 Pred SExt(i1)) --> X s< 0 Pred i1 + Instruction *ExtI; + if ((I.isUnsigned() || I.isEquality()) && + match(Op1, + m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(Y)))) && + Y->getType()->getScalarSizeInBits() == 1 && + (Op0->hasOneUse() || Op1->hasOneUse())) { + unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); + Instruction *ShiftI; + if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), + m_Shr(m_Value(X), m_SpecificIntAllowUndef( + OpWidth - 1))))) { + unsigned ExtOpc = ExtI->getOpcode(); + unsigned ShiftOpc = ShiftI->getOpcode(); + if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) || + (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) { + Value *SLTZero = + Builder.CreateICmpSLT(X, Constant::getNullValue(X->getType())); + Value *Cmp = Builder.CreateICmp(Pred, SLTZero, Y, I.getName()); + return replaceInstUsesWith(I, Cmp); + } + } } } @@ -7177,17 +7323,14 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, } // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or - // [0, UMAX], but it may still be fractional. See if it is fractional by - // casting the FP value to the integer value and back, checking for equality. + // [0, UMAX], but it may still be fractional. Check whether this is the case + // using the IsExact flag. // Don't do this for zero, because -0.0 is not fractional. - Constant *RHSInt = LHSUnsigned - ? 
ConstantExpr::getFPToUI(RHSC, IntTy) - : ConstantExpr::getFPToSI(RHSC, IntTy); + APSInt RHSInt(IntWidth, LHSUnsigned); + bool IsExact; + RHS.convertToInteger(RHSInt, APFloat::rmTowardZero, &IsExact); if (!RHS.isZero()) { - bool Equal = LHSUnsigned - ? ConstantExpr::getUIToFP(RHSInt, RHSC->getType()) == RHSC - : ConstantExpr::getSIToFP(RHSInt, RHSC->getType()) == RHSC; - if (!Equal) { + if (!IsExact) { // If we had a comparison against a fractional value, we have to adjust // the compare predicate and sometimes the value. RHSC is rounded towards // zero at this point. @@ -7253,7 +7396,7 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, // Lower this FP comparison into an appropriate integer version of the // comparison. - return new ICmpInst(Pred, LHSI->getOperand(0), RHSInt); + return new ICmpInst(Pred, LHSI->getOperand(0), Builder.getInt(RHSInt)); } /// Fold (C / X) < 0.0 --> X < 0.0 if possible. Swap predicate if necessary. @@ -7532,12 +7675,8 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) { switch (LHSI->getOpcode()) { case Instruction::PHI: - // Only fold fcmp into the PHI if the phi and fcmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. 
- if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) - return NV; + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) + return NV; break; case Instruction::SIToFP: case Instruction::UIToFP: diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 701579e1de48..bb620ad8d41c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -16,6 +16,7 @@ #define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" @@ -73,6 +74,10 @@ public: virtual ~InstCombinerImpl() = default; + /// Perform early cleanup and prepare the InstCombine worklist. + bool prepareWorklist(Function &F, + ReversePostOrderTraversal<BasicBlock *> &RPOT); + /// Run the combiner over the entire worklist until it is empty. /// /// \returns true if the IR is changed. 
@@ -93,6 +98,7 @@ public: Instruction *visitSub(BinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); + Instruction *foldFMulReassoc(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); Instruction *visitURem(BinaryOperator &I); Instruction *visitSRem(BinaryOperator &I); @@ -126,7 +132,6 @@ public: Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I); Instruction *commonCastTransforms(CastInst &CI); - Instruction *commonPointerCastTransforms(CastInst &CI); Instruction *visitTrunc(TruncInst &CI); Instruction *visitZExt(ZExtInst &Zext); Instruction *visitSExt(SExtInst &Sext); @@ -193,6 +198,44 @@ public: LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy, const Twine &Suffix = ""); + KnownFPClass computeKnownFPClass(Value *Val, FastMathFlags FMF, + FPClassTest Interested = fcAllFlags, + const Instruction *CtxI = nullptr, + unsigned Depth = 0) const { + return llvm::computeKnownFPClass(Val, FMF, DL, Interested, Depth, &TLI, &AC, + CtxI, &DT); + } + + KnownFPClass computeKnownFPClass(Value *Val, + FPClassTest Interested = fcAllFlags, + const Instruction *CtxI = nullptr, + unsigned Depth = 0) const { + return llvm::computeKnownFPClass(Val, DL, Interested, Depth, &TLI, &AC, + CtxI, &DT); + } + + /// Check if fmul \p MulVal, +0.0 will yield +0.0 (or signed zero is + /// ignorable). 
+ bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF, + const Instruction *CtxI) const; + + Constant *getLosslessTrunc(Constant *C, Type *TruncTy, unsigned ExtOp) { + Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy); + Constant *ExtTruncC = + ConstantFoldCastOperand(ExtOp, TruncC, C->getType(), DL); + if (ExtTruncC && ExtTruncC == C) + return TruncC; + return nullptr; + } + + Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) { + return getLosslessTrunc(C, TruncTy, Instruction::ZExt); + } + + Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) { + return getLosslessTrunc(C, TruncTy, Instruction::SExt); + } + private: bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); bool isDesirableIntType(unsigned BitWidth) const; @@ -252,13 +295,15 @@ private: Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext); - bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, + bool willNotOverflowSignedAdd(const WithCache<const Value *> &LHS, + const WithCache<const Value *> &RHS, const Instruction &CxtI) const { return computeOverflowForSignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; } - bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS, + bool willNotOverflowUnsignedAdd(const WithCache<const Value *> &LHS, + const WithCache<const Value *> &RHS, const Instruction &CxtI) const { return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; @@ -387,15 +432,17 @@ private: Instruction *foldAndOrOfSelectUsingImpliedCond(Value *Op, SelectInst &SI, bool IsAnd); + Instruction *hoistFNegAboveFMulFDiv(Value *FNegOp, Instruction &FMFSource); + public: /// Create and insert the idiom we use to indicate a block is unreachable /// without having to rewrite the CFG from within InstCombine. 
void CreateNonTerminatorUnreachable(Instruction *InsertAt) { auto &Ctx = InsertAt->getContext(); auto *SI = new StoreInst(ConstantInt::getTrue(Ctx), - PoisonValue::get(Type::getInt1PtrTy(Ctx)), + PoisonValue::get(PointerType::getUnqual(Ctx)), /*isVolatile*/ false, Align(1)); - InsertNewInstBefore(SI, *InsertAt); + InsertNewInstBefore(SI, InsertAt->getIterator()); } /// Combiner aware instruction erasure. @@ -412,6 +459,7 @@ public: // use counts. SmallVector<Value *> Ops(I.operands()); Worklist.remove(&I); + DC.removeValue(&I); I.eraseFromParent(); for (Value *Op : Ops) Worklist.handleUseCountDecrement(Op); @@ -498,6 +546,7 @@ public: /// Tries to simplify operands to an integer instruction based on its /// demanded bits. bool SimplifyDemandedInstructionBits(Instruction &Inst); + bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth = 0, @@ -535,6 +584,9 @@ public: Instruction *foldAddWithConstant(BinaryOperator &Add); + Instruction *foldSquareSumInt(BinaryOperator &I); + Instruction *foldSquareSumFP(BinaryOperator &I); + /// Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. 
Instruction *foldPHIArgOpIntoPHI(PHINode &PN); @@ -580,6 +632,9 @@ public: Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, const APInt &C); Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); + Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax, + Value *Z, ICmpInst::Predicate Pred); + Instruction *foldICmpWithMinMax(ICmpInst &Cmp); Instruction *foldICmpEquality(ICmpInst &Cmp); Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); Instruction *foldSignBitTest(ICmpInst &I); @@ -593,6 +648,8 @@ public: ConstantInt *C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, const APInt &C); + Instruction *foldICmpTruncWithTruncOrExt(ICmpInst &Cmp, + const SimplifyQuery &Q); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, const APInt &C); Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, @@ -667,8 +724,12 @@ public: bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock); bool removeInstructionsBeforeUnreachable(Instruction &I); - bool handleUnreachableFrom(Instruction *I); - bool handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc); + void addDeadEdge(BasicBlock *From, BasicBlock *To, + SmallVectorImpl<BasicBlock *> &Worklist); + void handleUnreachableFrom(Instruction *I, + SmallVectorImpl<BasicBlock *> &Worklist); + void handlePotentiallyDeadBlocks(SmallVectorImpl<BasicBlock *> &Worklist); + void handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc); void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); }; @@ -679,16 +740,11 @@ class Negator final { using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; BuilderTy Builder; - const DataLayout &DL; - AssumptionCache &AC; - const DominatorTree &DT; - const bool IsTrulyNegation; SmallDenseMap<Value *, Value *> NegationsCache; - Negator(LLVMContext &C, const DataLayout &DL, AssumptionCache &AC, - const DominatorTree &DT, bool 
IsTrulyNegation); + Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation); #if LLVM_ENABLE_STATS unsigned NumValuesVisitedInThisNegator = 0; @@ -700,13 +756,13 @@ class Negator final { std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I); - [[nodiscard]] Value *visitImpl(Value *V, unsigned Depth); + [[nodiscard]] Value *visitImpl(Value *V, bool IsNSW, unsigned Depth); - [[nodiscard]] Value *negate(Value *V, unsigned Depth); + [[nodiscard]] Value *negate(Value *V, bool IsNSW, unsigned Depth); /// Recurse depth-first and attempt to sink the negation. /// FIXME: use worklist? - [[nodiscard]] std::optional<Result> run(Value *Root); + [[nodiscard]] std::optional<Result> run(Value *Root, bool IsNSW); Negator(const Negator &) = delete; Negator(Negator &&) = delete; @@ -716,7 +772,7 @@ class Negator final { public: /// Attempt to negate \p Root. Retuns nullptr if negation can't be performed, /// otherwise returns negated value. - [[nodiscard]] static Value *Negate(bool LHSIsZero, Value *Root, + [[nodiscard]] static Value *Negate(bool LHSIsZero, bool IsNSW, Value *Root, InstCombinerImpl &IC); }; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 6aa20ee26b9a..b72b68c68d98 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -36,6 +36,13 @@ static cl::opt<unsigned> MaxCopiedFromConstantUsers( cl::desc("Maximum users to visit in copy from constant transform"), cl::Hidden); +namespace llvm { +cl::opt<bool> EnableInferAlignmentPass( + "enable-infer-alignment-pass", cl::init(true), cl::Hidden, cl::ZeroOrMore, + cl::desc("Enable the InferAlignment pass, disabling alignment inference in " + "InstCombine")); +} + /// isOnlyCopiedFromConstantMemory - Recursively walk the uses of a (derived) /// pointer to an alloca. 
Ignore any reads of the pointer, return false if we /// see any stores or other unknown uses. If we see pointer arithmetic, keep @@ -224,7 +231,7 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, Value *Idx[2] = {NullIdx, NullIdx}; Instruction *GEP = GetElementPtrInst::CreateInBounds( NewTy, New, Idx, New->getName() + ".sub"); - IC.InsertNewInstBefore(GEP, *It); + IC.InsertNewInstBefore(GEP, It); // Now make everything use the getelementptr instead of the original // allocation. @@ -380,7 +387,7 @@ void PointerReplacer::replace(Instruction *I) { NewI->takeName(LT); copyMetadataForLoad(*NewI, *LT); - IC.InsertNewInstWith(NewI, *LT); + IC.InsertNewInstWith(NewI, LT->getIterator()); IC.replaceInstUsesWith(*LT, NewI); WorkMap[LT] = NewI; } else if (auto *PHI = dyn_cast<PHINode>(I)) { @@ -398,7 +405,7 @@ void PointerReplacer::replace(Instruction *I) { Indices.append(GEP->idx_begin(), GEP->idx_end()); auto *NewI = GetElementPtrInst::Create(GEP->getSourceElementType(), V, Indices); - IC.InsertNewInstWith(NewI, *GEP); + IC.InsertNewInstWith(NewI, GEP->getIterator()); NewI->takeName(GEP); WorkMap[GEP] = NewI; } else if (auto *BC = dyn_cast<BitCastInst>(I)) { @@ -407,14 +414,14 @@ void PointerReplacer::replace(Instruction *I) { auto *NewT = PointerType::get(BC->getType()->getContext(), V->getType()->getPointerAddressSpace()); auto *NewI = new BitCastInst(V, NewT); - IC.InsertNewInstWith(NewI, *BC); + IC.InsertNewInstWith(NewI, BC->getIterator()); NewI->takeName(BC); WorkMap[BC] = NewI; } else if (auto *SI = dyn_cast<SelectInst>(I)) { auto *NewSI = SelectInst::Create( SI->getCondition(), getReplacement(SI->getTrueValue()), getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI); - IC.InsertNewInstWith(NewSI, *SI); + IC.InsertNewInstWith(NewSI, SI->getIterator()); NewSI->takeName(SI); WorkMap[SI] = NewSI; } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) { @@ -449,7 +456,7 @@ void PointerReplacer::replace(Instruction *I) { 
ASC->getType()->getPointerAddressSpace()) { auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), ""); NewI->takeName(ASC); - IC.InsertNewInstWith(NewI, *ASC); + IC.InsertNewInstWith(NewI, ASC->getIterator()); NewV = NewI; } IC.replaceInstUsesWith(*ASC, NewV); @@ -507,8 +514,6 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { // types. const Align MaxAlign = std::max(EntryAI->getAlign(), AI.getAlign()); EntryAI->setAlignment(MaxAlign); - if (AI.getType() != EntryAI->getType()) - return new BitCastInst(EntryAI, AI.getType()); return replaceInstUsesWith(AI, EntryAI); } } @@ -534,13 +539,11 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace(); - auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace); if (AI.getAddressSpace() == SrcAddrSpace) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); - Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); - Instruction *NewI = replaceInstUsesWith(AI, Cast); + Instruction *NewI = replaceInstUsesWith(AI, TheSrc); eraseInstFromFunction(*Copy); ++NumGlobalCopies; return NewI; @@ -551,8 +554,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); - Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); - PtrReplacer.replacePointer(Cast); + PtrReplacer.replacePointer(TheSrc); ++NumGlobalCopies; } } @@ -582,16 +584,9 @@ LoadInst *InstCombinerImpl::combineLoadToNewType(LoadInst &LI, Type *NewTy, assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && "can't fold an atomic load to requested type"); - Value *Ptr = LI.getPointerOperand(); - unsigned AS = LI.getPointerAddressSpace(); - Type *NewPtrTy = NewTy->getPointerTo(AS); - Value *NewPtr = nullptr; - if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && - 
NewPtr->getType() == NewPtrTy)) - NewPtr = Builder.CreateBitCast(Ptr, NewPtrTy); - - LoadInst *NewLoad = Builder.CreateAlignedLoad( - NewTy, NewPtr, LI.getAlign(), LI.isVolatile(), LI.getName() + Suffix); + LoadInst *NewLoad = + Builder.CreateAlignedLoad(NewTy, LI.getPointerOperand(), LI.getAlign(), + LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); copyMetadataForLoad(*NewLoad, LI); return NewLoad; @@ -606,13 +601,11 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, "can't fold an atomic store of requested type"); Value *Ptr = SI.getPointerOperand(); - unsigned AS = SI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; SI.getAllMetadata(MD); - StoreInst *NewStore = IC.Builder.CreateAlignedStore( - V, IC.Builder.CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), - SI.getAlign(), SI.isVolatile()); + StoreInst *NewStore = + IC.Builder.CreateAlignedStore(V, Ptr, SI.getAlign(), SI.isVolatile()); NewStore->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; @@ -655,29 +648,6 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, return NewStore; } -/// Returns true if instruction represent minmax pattern like: -/// select ((cmp load V1, load V2), V1, V2). -static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { - assert(V->getType()->isPointerTy() && "Expected pointer type."); - // Ignore possible ty* to ixx* bitcast. - V = InstCombiner::peekThroughBitcast(V); - // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax - // pattern. 
- CmpInst::Predicate Pred; - Instruction *L1; - Instruction *L2; - Value *LHS; - Value *RHS; - if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)), - m_Value(LHS), m_Value(RHS)))) - return false; - LoadTy = L1->getType(); - return (match(L1, m_Load(m_Specific(LHS))) && - match(L2, m_Load(m_Specific(RHS)))) || - (match(L1, m_Load(m_Specific(RHS))) && - match(L2, m_Load(m_Specific(LHS)))); -} - /// Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// @@ -818,7 +788,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { return nullptr; const DataLayout &DL = IC.getDataLayout(); - auto EltSize = DL.getTypeAllocSize(ET); + TypeSize EltSize = DL.getTypeAllocSize(ET); const auto Align = LI.getAlign(); auto *Addr = LI.getPointerOperand(); @@ -826,7 +796,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *Zero = ConstantInt::get(IdxType, 0); Value *V = PoisonValue::get(T); - uint64_t Offset = 0; + TypeSize Offset = TypeSize::get(0, ET->isScalableTy()); for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, @@ -834,9 +804,9 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { }; auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), Name + ".elt"); + auto EltAlign = commonAlignment(Align, Offset.getKnownMinValue()); auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, - commonAlignment(Align, Offset), - Name + ".unpack"); + EltAlign, Name + ".unpack"); L->setAAMetadata(LI.getAAMetadata()); V = IC.Builder.CreateInsertValue(V, L, i); Offset += EltSize; @@ -971,7 +941,7 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, Type *SourceElementType = GEPI->getSourceElementType(); // Size information about scalable vectors is not available, so we cannot // deduce whether indexing at n is undefined behaviour or not. Bail out. 
- if (isa<ScalableVectorType>(SourceElementType)) + if (SourceElementType->isScalableTy()) return false; Type *AllocTy = GetElementPtrInst::getIndexedType(SourceElementType, Ops); @@ -1020,7 +990,7 @@ static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr, Instruction *NewGEPI = GEPI->clone(); NewGEPI->setOperand(Idx, ConstantInt::get(GEPI->getOperand(Idx)->getType(), 0)); - IC.InsertNewInstBefore(NewGEPI, *GEPI); + IC.InsertNewInstBefore(NewGEPI, GEPI->getIterator()); return NewGEPI; } } @@ -1062,11 +1032,13 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { if (Instruction *Res = combineLoadToOperationType(*this, LI)) return Res; - // Attempt to improve the alignment. - Align KnownAlign = getOrEnforceKnownAlignment( - Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT); - if (KnownAlign > LI.getAlign()) - LI.setAlignment(KnownAlign); + if (!EnableInferAlignmentPass) { + // Attempt to improve the alignment. + Align KnownAlign = getOrEnforceKnownAlignment( + Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT); + if (KnownAlign > LI.getAlign()) + LI.setAlignment(KnownAlign); + } // Replace GEP indices if possible. 
if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) @@ -1337,7 +1309,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { return false; const DataLayout &DL = IC.getDataLayout(); - auto EltSize = DL.getTypeAllocSize(AT->getElementType()); + TypeSize EltSize = DL.getTypeAllocSize(AT->getElementType()); const auto Align = SI.getAlign(); SmallString<16> EltName = V->getName(); @@ -1349,7 +1321,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *IdxType = Type::getInt64Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - uint64_t Offset = 0; + TypeSize Offset = TypeSize::get(0, AT->getElementType()->isScalableTy()); for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, @@ -1358,7 +1330,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); - auto EltAlign = commonAlignment(Align, Offset); + auto EltAlign = commonAlignment(Align, Offset.getKnownMinValue()); Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); NS->setAAMetadata(SI.getAAMetadata()); Offset += EltSize; @@ -1399,58 +1371,6 @@ static bool equivalentAddressValues(Value *A, Value *B) { return false; } -/// Converts store (bitcast (load (bitcast (select ...)))) to -/// store (load (select ...)), where select is minmax: -/// select ((cmp load V1, load V2), V1, V2). -static bool removeBitcastsFromLoadStoreOnMinMax(InstCombinerImpl &IC, - StoreInst &SI) { - // bitcast? - if (!match(SI.getPointerOperand(), m_BitCast(m_Value()))) - return false; - // load? integer? 
- Value *LoadAddr; - if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr))))) - return false; - auto *LI = cast<LoadInst>(SI.getValueOperand()); - if (!LI->getType()->isIntegerTy()) - return false; - Type *CmpLoadTy; - if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy)) - return false; - - // Make sure the type would actually change. - // This condition can be hit with chains of bitcasts. - if (LI->getType() == CmpLoadTy) - return false; - - // Make sure we're not changing the size of the load/store. - const auto &DL = IC.getDataLayout(); - if (DL.getTypeStoreSizeInBits(LI->getType()) != - DL.getTypeStoreSizeInBits(CmpLoadTy)) - return false; - - if (!all_of(LI->users(), [LI, LoadAddr](User *U) { - auto *SI = dyn_cast<StoreInst>(U); - return SI && SI->getPointerOperand() != LI && - InstCombiner::peekThroughBitcast(SI->getPointerOperand()) != - LoadAddr && - !SI->getPointerOperand()->isSwiftError(); - })) - return false; - - IC.Builder.SetInsertPoint(LI); - LoadInst *NewLI = IC.combineLoadToNewType(*LI, CmpLoadTy); - // Replace all the stores with stores of the newly loaded value. - for (auto *UI : LI->users()) { - auto *USI = cast<StoreInst>(UI); - IC.Builder.SetInsertPoint(USI); - combineStoreToNewValue(IC, *USI, NewLI); - } - IC.replaceInstUsesWith(*LI, PoisonValue::get(LI->getType())); - IC.eraseInstFromFunction(*LI); - return true; -} - Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { Value *Val = SI.getOperand(0); Value *Ptr = SI.getOperand(1); @@ -1459,19 +1379,18 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { if (combineStoreToValueType(*this, SI)) return eraseInstFromFunction(SI); - // Attempt to improve the alignment. - const Align KnownAlign = getOrEnforceKnownAlignment( - Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT); - if (KnownAlign > SI.getAlign()) - SI.setAlignment(KnownAlign); + if (!EnableInferAlignmentPass) { + // Attempt to improve the alignment. 
+ const Align KnownAlign = getOrEnforceKnownAlignment( + Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT); + if (KnownAlign > SI.getAlign()) + SI.setAlignment(KnownAlign); + } // Try to canonicalize the stored type. if (unpackStoreToAggregate(*this, SI)) return eraseInstFromFunction(SI); - if (removeBitcastsFromLoadStoreOnMinMax(*this, SI)) - return eraseInstFromFunction(SI); - // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) return replaceOperand(SI, 1, NewGEPI); @@ -1508,8 +1427,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { --BBI; // Don't count debug info directives, lest they affect codegen, // and we skip pointer-to-pointer bitcasts, which are NOPs. - if (BBI->isDebugOrPseudoInst() || - (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) { + if (BBI->isDebugOrPseudoInst()) { ScanInsts++; continue; } @@ -1560,11 +1478,15 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { // This is a non-terminator unreachable marker. Don't remove it. if (isa<UndefValue>(Ptr)) { - // Remove all instructions after the marker and guaranteed-to-transfer - // instructions before the marker. - if (handleUnreachableFrom(SI.getNextNode()) || - removeInstructionsBeforeUnreachable(SI)) + // Remove guaranteed-to-transfer instructions before the marker. + if (removeInstructionsBeforeUnreachable(SI)) return &SI; + + // Remove all instructions after the marker and handle dead blocks this + // implies. + SmallVector<BasicBlock *> Worklist; + handleUnreachableFrom(SI.getNextNode(), Worklist); + handlePotentiallyDeadBlocks(Worklist); return nullptr; } @@ -1626,8 +1548,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { if (OtherBr->isUnconditional()) { --BBI; // Skip over debugging info and pseudo probes. 
- while (BBI->isDebugOrPseudoInst() || - (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) { + while (BBI->isDebugOrPseudoInst()) { if (BBI==OtherBB->begin()) return false; --BBI; @@ -1681,7 +1602,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { Builder.SetInsertPoint(OtherStore); PN->addIncoming(Builder.CreateBitOrPointerCast(MergedVal, PN->getType()), OtherBB); - MergedVal = InsertNewInstBefore(PN, DestBB->front()); + MergedVal = InsertNewInstBefore(PN, DestBB->begin()); PN->setDebugLoc(MergedLoc); } @@ -1690,7 +1611,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), SI.getAlign(), SI.getOrdering(), SI.getSyncScopeID()); - InsertNewInstBefore(NewSI, *BBI); + InsertNewInstBefore(NewSI, BBI); NewSI->setDebugLoc(MergedLoc); NewSI->mergeDIAssignID({&SI, OtherStore}); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 50458e2773e6..8d5866e98a8e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -258,9 +258,14 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Op0->hasOneUse() && match(Op1, m_NegatedPower2())) { // Interpret X * (-1<<C) as (-X) * (1<<C) and try to sink the negation. // The "* (1<<C)" thus becomes a potential shifting opportunity. 
- if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this)) - return BinaryOperator::CreateMul( - NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName()); + if (Value *NegOp0 = + Negator::Negate(/*IsNegation*/ true, HasNSW, Op0, *this)) { + auto *Op1C = cast<Constant>(Op1); + return replaceInstUsesWith( + I, Builder.CreateMul(NegOp0, ConstantExpr::getNeg(Op1C), "", + /* HasNUW */ false, + HasNSW && Op1C->isNotMinSignedValue())); + } // Try to convert multiply of extended operand to narrow negate and shift // for better analysis. @@ -295,9 +300,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // Canonicalize (X|C1)*MulC -> X*MulC+C1*MulC. Value *X; Constant *C1; - if ((match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(C1))))) || - (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C1)))) && - haveNoCommonBitsSet(X, C1, DL, &AC, &I, &DT))) { + if (match(Op0, m_OneUse(m_AddLike(m_Value(X), m_ImmConstant(C1))))) { // C1*MulC simplifies to a tidier constant. Value *NewC = Builder.CreateMul(C1, MulC); auto *BOp0 = cast<BinaryOperator>(Op0); @@ -555,6 +558,180 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { return nullptr; } +Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *X, *Y; + Constant *C; + + // Reassociate constant RHS with another constant to form constant + // expression. 
+ if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) { + Constant *C1; + if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { + // (C1 / X) * C --> (C * C1) / X + Constant *CC1 = + ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL); + if (CC1 && CC1->isNormalFP()) + return BinaryOperator::CreateFDivFMF(CC1, X, &I); + } + if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { + // (X / C1) * C --> X * (C / C1) + Constant *CDivC1 = + ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL); + if (CDivC1 && CDivC1->isNormalFP()) + return BinaryOperator::CreateFMulFMF(X, CDivC1, &I); + + // If the constant was a denormal, try reassociating differently. + // (X / C1) * C --> X / (C1 / C) + Constant *C1DivC = + ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL); + if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP()) + return BinaryOperator::CreateFDivFMF(X, C1DivC, &I); + } + + // We do not need to match 'fadd C, X' and 'fsub X, C' because they are + // canonicalized to 'fadd X, C'. Distributing the multiply may allow + // further folds and (X * C) + C2 is 'fma'. 
+ if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) { + // (X + C1) * C --> (X * C) + (C * C1) + if (Constant *CC1 = + ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFAddFMF(XC, CC1, &I); + } + } + if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) { + // (C1 - X) * C --> (C * C1) - (X * C) + if (Constant *CC1 = + ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFSubFMF(CC1, XC, &I); + } + } + } + + Value *Z; + if (match(&I, + m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) { + // Sink division: (X / Y) * Z --> (X * Z) / Y + Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I); + return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I); + } + + // sqrt(X) * sqrt(Y) -> sqrt(X * Y) + // nnan disallows the possibility of returning a number if both operands are + // negative (in that case, we should return NaN). + if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) && + match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) { + Value *XY = Builder.CreateFMulFMF(X, Y, &I); + Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I); + return replaceInstUsesWith(I, Sqrt); + } + + // The following transforms are done irrespective of the number of uses + // for the expression "1.0/sqrt(X)". + // 1) 1.0/sqrt(X) * X -> X/sqrt(X) + // 2) X * 1.0/sqrt(X) -> X/sqrt(X) + // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it + // has the necessary (reassoc) fast-math-flags. 
+ if (I.hasNoSignedZeros() && + match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && + match(Y, m_Sqrt(m_Value(X))) && Op1 == X) + return BinaryOperator::CreateFDivFMF(X, Y, &I); + if (I.hasNoSignedZeros() && + match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && + match(Y, m_Sqrt(m_Value(X))) && Op0 == X) + return BinaryOperator::CreateFDivFMF(X, Y, &I); + + // Like the similar transform in instsimplify, this requires 'nsz' because + // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0. + if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) { + // Peek through fdiv to find squaring of square root: + // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y + if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(XX, Y, &I); + } + // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X) + if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(Y, XX, &I); + } + } + + // pow(X, Y) * X --> pow(X, Y+1) + // X * pow(X, Y) --> pow(X, Y+1) + if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X), + m_Value(Y))), + m_Deferred(X)))) { + Value *Y1 = Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I); + Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I); + return replaceInstUsesWith(I, Pow); + } + + if (I.isOnlyUserOfAnyOperand()) { + // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z) + if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) { + auto *YZ = Builder.CreateFAddFMF(Y, Z, &I); + auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I); + return replaceInstUsesWith(I, NewPow); + } + // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y) + if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && + match(Op1, 
m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) { + auto *XZ = Builder.CreateFMulFMF(X, Z, &I); + auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I); + return replaceInstUsesWith(I, NewPow); + } + + // powi(x, y) * powi(x, z) -> powi(x, y + z) + if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) && + Y->getType() == Z->getType()) { + auto *YZ = Builder.CreateAdd(Y, Z); + auto *NewPow = Builder.CreateIntrinsic( + Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); + return replaceInstUsesWith(I, NewPow); + } + + // exp(X) * exp(Y) -> exp(X + Y) + if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) { + Value *XY = Builder.CreateFAddFMF(X, Y, &I); + Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I); + return replaceInstUsesWith(I, Exp); + } + + // exp2(X) * exp2(Y) -> exp2(X + Y) + if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) { + Value *XY = Builder.CreateFAddFMF(X, Y, &I); + Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I); + return replaceInstUsesWith(I, Exp2); + } + } + + // (X*Y) * X => (X*X) * Y where Y != X + // The purpose is two-fold: + // 1) to form a power expression (of X). + // 2) potentially shorten the critical path: After transformation, the + // latency of the instruction Y is amortized by the expression of X*X, + // and therefore Y is in a "less critical" position compared to what it + // was before the transformation. 
+ if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && Op1 != Y) { + Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); + } + if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && Op0 != Y) { + Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -602,176 +779,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); - if (I.hasAllowReassoc()) { - // Reassociate constant RHS with another constant to form constant - // expression. - if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) { - Constant *C1; - if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { - // (C1 / X) * C --> (C * C1) / X - Constant *CC1 = - ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL); - if (CC1 && CC1->isNormalFP()) - return BinaryOperator::CreateFDivFMF(CC1, X, &I); - } - if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { - // (X / C1) * C --> X * (C / C1) - Constant *CDivC1 = - ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL); - if (CDivC1 && CDivC1->isNormalFP()) - return BinaryOperator::CreateFMulFMF(X, CDivC1, &I); - - // If the constant was a denormal, try reassociating differently. - // (X / C1) * C --> X / (C1 / C) - Constant *C1DivC = - ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL); - if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP()) - return BinaryOperator::CreateFDivFMF(X, C1DivC, &I); - } - - // We do not need to match 'fadd C, X' and 'fsub X, C' because they are - // canonicalized to 'fadd X, C'. Distributing the multiply may allow - // further folds and (X * C) + C2 is 'fma'. 
- if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) { - // (X + C1) * C --> (X * C) + (C * C1) - if (Constant *CC1 = ConstantFoldBinaryOpOperands( - Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMulFMF(X, C, &I); - return BinaryOperator::CreateFAddFMF(XC, CC1, &I); - } - } - if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) { - // (C1 - X) * C --> (C * C1) - (X * C) - if (Constant *CC1 = ConstantFoldBinaryOpOperands( - Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMulFMF(X, C, &I); - return BinaryOperator::CreateFSubFMF(CC1, XC, &I); - } - } - } - - Value *Z; - if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), - m_Value(Z)))) { - // Sink division: (X / Y) * Z --> (X * Z) / Y - Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I); - return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I); - } - - // sqrt(X) * sqrt(Y) -> sqrt(X * Y) - // nnan disallows the possibility of returning a number if both operands are - // negative (in that case, we should return NaN). - if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) && - match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) { - Value *XY = Builder.CreateFMulFMF(X, Y, &I); - Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I); - return replaceInstUsesWith(I, Sqrt); - } - - // The following transforms are done irrespective of the number of uses - // for the expression "1.0/sqrt(X)". - // 1) 1.0/sqrt(X) * X -> X/sqrt(X) - // 2) X * 1.0/sqrt(X) -> X/sqrt(X) - // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it - // has the necessary (reassoc) fast-math-flags. 
- if (I.hasNoSignedZeros() && - match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && - match(Y, m_Sqrt(m_Value(X))) && Op1 == X) - return BinaryOperator::CreateFDivFMF(X, Y, &I); - if (I.hasNoSignedZeros() && - match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && - match(Y, m_Sqrt(m_Value(X))) && Op0 == X) - return BinaryOperator::CreateFDivFMF(X, Y, &I); - - // Like the similar transform in instsimplify, this requires 'nsz' because - // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0. - if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && - Op0->hasNUses(2)) { - // Peek through fdiv to find squaring of square root: - // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y - if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) { - Value *XX = Builder.CreateFMulFMF(X, X, &I); - return BinaryOperator::CreateFDivFMF(XX, Y, &I); - } - // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X) - if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) { - Value *XX = Builder.CreateFMulFMF(X, X, &I); - return BinaryOperator::CreateFDivFMF(Y, XX, &I); - } - } - - // pow(X, Y) * X --> pow(X, Y+1) - // X * pow(X, Y) --> pow(X, Y+1) - if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X), - m_Value(Y))), - m_Deferred(X)))) { - Value *Y1 = - Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I); - Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I); - return replaceInstUsesWith(I, Pow); - } - - if (I.isOnlyUserOfAnyOperand()) { - // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z) - if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && - match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) { - auto *YZ = Builder.CreateFAddFMF(Y, Z, &I); - auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I); - return replaceInstUsesWith(I, NewPow); - } - // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y) - if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && - match(Op1, 
m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) { - auto *XZ = Builder.CreateFMulFMF(X, Z, &I); - auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I); - return replaceInstUsesWith(I, NewPow); - } - - // powi(x, y) * powi(x, z) -> powi(x, y + z) - if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && - match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) && - Y->getType() == Z->getType()) { - auto *YZ = Builder.CreateAdd(Y, Z); - auto *NewPow = Builder.CreateIntrinsic( - Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); - return replaceInstUsesWith(I, NewPow); - } - - // exp(X) * exp(Y) -> exp(X + Y) - if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && - match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) { - Value *XY = Builder.CreateFAddFMF(X, Y, &I); - Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I); - return replaceInstUsesWith(I, Exp); - } - - // exp2(X) * exp2(Y) -> exp2(X + Y) - if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) && - match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) { - Value *XY = Builder.CreateFAddFMF(X, Y, &I); - Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I); - return replaceInstUsesWith(I, Exp2); - } - } - - // (X*Y) * X => (X*X) * Y where Y != X - // The purpose is two-fold: - // 1) to form a power expression (of X). - // 2) potentially shorten the critical path: After transformation, the - // latency of the instruction Y is amortized by the expression of X*X, - // and therefore Y is in a "less critical" position compared to what it - // was before the transformation. 
- if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && - Op1 != Y) { - Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I); - return BinaryOperator::CreateFMulFMF(XX, Y, &I); - } - if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && - Op0 != Y) { - Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I); - return BinaryOperator::CreateFMulFMF(XX, Y, &I); - } - } + if (I.hasAllowReassoc()) + if (Instruction *FoldedMul = foldFMulReassoc(I)) + return FoldedMul; // log2(X * 0.5) * Y = log2(X) * Y - Y if (I.isFast()) { @@ -802,7 +812,7 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { I.hasNoSignedZeros() && match(Start, m_Zero())) return replaceInstUsesWith(I, Start); - // minimun(X, Y) * maximum(X, Y) => X * Y. + // minimum(X, Y) * maximum(X, Y) => X * Y. if (match(&I, m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)), m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X), @@ -918,8 +928,7 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, return Remainder.isMinValue(); } -static Instruction *foldIDivShl(BinaryOperator &I, - InstCombiner::BuilderTy &Builder) { +static Value *foldIDivShl(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { assert((I.getOpcode() == Instruction::SDiv || I.getOpcode() == Instruction::UDiv) && "Expected integer divide"); @@ -928,7 +937,6 @@ static Instruction *foldIDivShl(BinaryOperator &I, Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); - Instruction *Ret = nullptr; Value *X, *Y, *Z; // With appropriate no-wrap constraints, remove a common factor in the @@ -943,12 +951,12 @@ static Instruction *foldIDivShl(BinaryOperator &I, // (X * Y) u/ (X << Z) --> Y u>> Z if (!IsSigned && HasNUW) - Ret = BinaryOperator::CreateLShr(Y, Z); + return Builder.CreateLShr(Y, Z, "", I.isExact()); // (X * Y) s/ (X << Z) --> Y s/ (1 << Z) if (IsSigned && HasNSW && (Op0->hasOneUse() || Op1->hasOneUse())) { Value *Shl = Builder.CreateShl(ConstantInt::get(Ty, 1), Z); 
- Ret = BinaryOperator::CreateSDiv(Y, Shl); + return Builder.CreateSDiv(Y, Shl, "", I.isExact()); } } @@ -966,20 +974,38 @@ static Instruction *foldIDivShl(BinaryOperator &I, ((Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap()) || (Shl0->hasNoUnsignedWrap() && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap()))) - Ret = BinaryOperator::CreateUDiv(X, Y); + return Builder.CreateUDiv(X, Y, "", I.isExact()); // For signed div, we need 'nsw' on both shifts + 'nuw' on the divisor. // (X << Z) / (Y << Z) --> X / Y if (IsSigned && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap() && Shl1->hasNoUnsignedWrap()) - Ret = BinaryOperator::CreateSDiv(X, Y); + return Builder.CreateSDiv(X, Y, "", I.isExact()); } - if (!Ret) - return nullptr; + // If X << Y and X << Z does not overflow, then: + // (X << Y) / (X << Z) -> (1 << Y) / (1 << Z) -> 1 << Y >> Z + if (match(Op0, m_Shl(m_Value(X), m_Value(Y))) && + match(Op1, m_Shl(m_Specific(X), m_Value(Z)))) { + auto *Shl0 = cast<OverflowingBinaryOperator>(Op0); + auto *Shl1 = cast<OverflowingBinaryOperator>(Op1); - Ret->setIsExact(I.isExact()); - return Ret; + if (IsSigned ? (Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap()) + : (Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap())) { + Constant *One = ConstantInt::get(X->getType(), 1); + // Only preserve the nsw flag if dividend has nsw + // or divisor has nsw and operator is sdiv. + Value *Dividend = Builder.CreateShl( + One, Y, "shl.dividend", + /*HasNUW*/ true, + /*HasNSW*/ + IsSigned ? 
(Shl0->hasNoUnsignedWrap() || Shl1->hasNoUnsignedWrap()) + : Shl0->hasNoSignedWrap()); + return Builder.CreateLShr(Dividend, Z, "", I.isExact()); + } + } + + return nullptr; } /// This function implements the transforms common to both integer division @@ -1156,8 +1182,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { return NewDiv; } - if (Instruction *R = foldIDivShl(I, Builder)) - return R; + if (Value *R = foldIDivShl(I, Builder)) + return replaceInstUsesWith(I, R); // With the appropriate no-wrap constraint, remove a multiply by the divisor // after peeking through another divide: @@ -1263,7 +1289,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, /// If we have zero-extended operands of an unsigned div or rem, we may be able /// to narrow the operation (sink the zext below the math). static Instruction *narrowUDivURem(BinaryOperator &I, - InstCombiner::BuilderTy &Builder) { + InstCombinerImpl &IC) { Instruction::BinaryOps Opcode = I.getOpcode(); Value *N = I.getOperand(0); Value *D = I.getOperand(1); @@ -1273,7 +1299,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I, X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) { // udiv (zext X), (zext Y) --> zext (udiv X, Y) // urem (zext X), (zext Y) --> zext (urem X, Y) - Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y); + Value *NarrowOp = IC.Builder.CreateBinOp(Opcode, X, Y); return new ZExtInst(NarrowOp, Ty); } @@ -1281,24 +1307,24 @@ static Instruction *narrowUDivURem(BinaryOperator &I, if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) { // If the constant is the same in the smaller type, use the narrow version. 
- Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); - if (ConstantExpr::getZExt(TruncC, Ty) != C) + Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType()); + if (!TruncC) return nullptr; // udiv (zext X), C --> zext (udiv X, C') // urem (zext X), C --> zext (urem X, C') - return new ZExtInst(Builder.CreateBinOp(Opcode, X, TruncC), Ty); + return new ZExtInst(IC.Builder.CreateBinOp(Opcode, X, TruncC), Ty); } if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C))) { // If the constant is the same in the smaller type, use the narrow version. - Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); - if (ConstantExpr::getZExt(TruncC, Ty) != C) + Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType()); + if (!TruncC) return nullptr; // udiv C, (zext X) --> zext (udiv C', X) // urem C, (zext X) --> zext (urem C', X) - return new ZExtInst(Builder.CreateBinOp(Opcode, TruncC, X), Ty); + return new ZExtInst(IC.Builder.CreateBinOp(Opcode, TruncC, X), Ty); } return nullptr; @@ -1346,7 +1372,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { return CastInst::CreateZExtOrBitCast(Cmp, Ty); } - if (Instruction *NarrowDiv = narrowUDivURem(I, Builder)) + if (Instruction *NarrowDiv = narrowUDivURem(I, *this)) return NarrowDiv; // If the udiv operands are non-overflowing multiplies with a common operand, @@ -1405,7 +1431,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { // sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined) if (match(Op1, m_AllOnes()) || (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) - return BinaryOperator::CreateNeg(Op0); + return BinaryOperator::CreateNSWNeg(Op0); // X / INT_MIN --> X == INT_MIN if (match(Op1, m_SignMask())) @@ -1428,7 +1454,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { Constant *NegPow2C = ConstantExpr::getNeg(cast<Constant>(Op1)); Constant *C = 
ConstantExpr::getExactLogBase2(NegPow2C); Value *Ashr = Builder.CreateAShr(Op0, C, I.getName() + ".neg", true); - return BinaryOperator::CreateNeg(Ashr); + return BinaryOperator::CreateNSWNeg(Ashr); } } @@ -1490,7 +1516,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { if (KnownDividend.isNonNegative()) { // If both operands are unsigned, turn this into a udiv. - if (isKnownNonNegative(Op1, DL, 0, &AC, &I, &DT)) { + if (isKnownNonNegative(Op1, SQ.getWithInstruction(&I))) { auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); BO->setIsExact(I.isExact()); return BO; @@ -1516,6 +1542,13 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { } } + // -X / X --> X == INT_MIN ? 1 : -1 + if (isKnownNegation(Op0, Op1)) { + APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + Value *Cond = Builder.CreateICmpEQ(Op0, ConstantInt::get(Ty, MinVal)); + return SelectInst::Create(Cond, ConstantInt::get(Ty, 1), + ConstantInt::getAllOnesValue(Ty)); + } return nullptr; } @@ -1759,6 +1792,21 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { return replaceInstUsesWith(I, Pow); } + // powi(X, Y) / X --> powi(X, Y-1) + // This is legal when (Y - 1) can't wraparound, in which case reassoc and nnan + // are required. 
+ // TODO: Multi-use may be also better off creating Powi(x,y-1) + if (I.hasAllowReassoc() && I.hasNoNaNs() && + match(Op0, m_OneUse(m_Intrinsic<Intrinsic::powi>(m_Specific(Op1), + m_Value(Y)))) && + willNotOverflowSignedSub(Y, ConstantInt::get(Y->getType(), 1), I)) { + Constant *NegOne = ConstantInt::getAllOnesValue(Y->getType()); + Value *Y1 = Builder.CreateAdd(Y, NegOne); + Type *Types[] = {Op1->getType(), Y1->getType()}; + Value *Pow = Builder.CreateIntrinsic(Intrinsic::powi, Types, {Op1, Y1}, &I); + return replaceInstUsesWith(I, Pow); + } + return nullptr; } @@ -1936,7 +1984,7 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { if (Instruction *common = commonIRemTransforms(I)) return common; - if (Instruction *NarrowRem = narrowUDivURem(I, Builder)) + if (Instruction *NarrowRem = narrowUDivURem(I, *this)) return NarrowRem; // X urem Y -> X and Y-1, where Y is a power of 2, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index e24abc48424d..513b185c83a4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -20,7 +20,6 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constant.h" @@ -98,14 +97,13 @@ static cl::opt<unsigned> cl::desc("What is the maximal lookup depth when trying to " "check for viability of negation sinking.")); -Negator::Negator(LLVMContext &C, const DataLayout &DL_, AssumptionCache &AC_, - const DominatorTree &DT_, bool IsTrulyNegation_) - : Builder(C, TargetFolder(DL_), +Negator::Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation_) + : Builder(C, TargetFolder(DL), IRBuilderCallbackInserter([&](Instruction *I) { ++NegatorNumInstructionsCreatedTotal; NewInstructions.push_back(I); })), 
- DL(DL_), AC(AC_), DT(DT_), IsTrulyNegation(IsTrulyNegation_) {} + IsTrulyNegation(IsTrulyNegation_) {} #if LLVM_ENABLE_STATS Negator::~Negator() { @@ -128,7 +126,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // FIXME: can this be reworked into a worklist-based algorithm while preserving // the depth-first, early bailout traversal? -[[nodiscard]] Value *Negator::visitImpl(Value *V, unsigned Depth) { +[[nodiscard]] Value *Negator::visitImpl(Value *V, bool IsNSW, unsigned Depth) { // -(undef) -> undef. if (match(V, m_Undef())) return V; @@ -237,7 +235,8 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // However, only do this either if the old `sub` doesn't stick around, or // it was subtracting from a constant. Otherwise, this isn't profitable. return Builder.CreateSub(I->getOperand(1), I->getOperand(0), - I->getName() + ".neg"); + I->getName() + ".neg", /* HasNUW */ false, + IsNSW && I->hasNoSignedWrap()); } // Some other cases, while still don't require recursion, @@ -302,7 +301,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { switch (I->getOpcode()) { case Instruction::Freeze: { // `freeze` is negatible if its operand is negatible. - Value *NegOp = negate(I->getOperand(0), Depth + 1); + Value *NegOp = negate(I->getOperand(0), IsNSW, Depth + 1); if (!NegOp) // Early return. return nullptr; return Builder.CreateFreeze(NegOp, I->getName() + ".neg"); @@ -313,7 +312,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { SmallVector<Value *, 4> NegatedIncomingValues(PHI->getNumOperands()); for (auto I : zip(PHI->incoming_values(), NegatedIncomingValues)) { if (!(std::get<1>(I) = - negate(std::get<0>(I), Depth + 1))) // Early return. + negate(std::get<0>(I), IsNSW, Depth + 1))) // Early return. return nullptr; } // All incoming values are indeed negatible. Create negated PHI node. 
@@ -336,10 +335,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return NewSelect; } // `select` is negatible if both hands of `select` are negatible. - Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + Value *NegOp1 = negate(I->getOperand(1), IsNSW, Depth + 1); if (!NegOp1) // Early return. return nullptr; - Value *NegOp2 = negate(I->getOperand(2), Depth + 1); + Value *NegOp2 = negate(I->getOperand(2), IsNSW, Depth + 1); if (!NegOp2) return nullptr; // Do preserve the metadata! @@ -349,10 +348,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { case Instruction::ShuffleVector: { // `shufflevector` is negatible if both operands are negatible. auto *Shuf = cast<ShuffleVectorInst>(I); - Value *NegOp0 = negate(I->getOperand(0), Depth + 1); + Value *NegOp0 = negate(I->getOperand(0), IsNSW, Depth + 1); if (!NegOp0) // Early return. return nullptr; - Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + Value *NegOp1 = negate(I->getOperand(1), IsNSW, Depth + 1); if (!NegOp1) return nullptr; return Builder.CreateShuffleVector(NegOp0, NegOp1, Shuf->getShuffleMask(), @@ -361,7 +360,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { case Instruction::ExtractElement: { // `extractelement` is negatible if source operand is negatible. auto *EEI = cast<ExtractElementInst>(I); - Value *NegVector = negate(EEI->getVectorOperand(), Depth + 1); + Value *NegVector = negate(EEI->getVectorOperand(), IsNSW, Depth + 1); if (!NegVector) // Early return. return nullptr; return Builder.CreateExtractElement(NegVector, EEI->getIndexOperand(), @@ -371,10 +370,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // `insertelement` is negatible if both the source vector and // element-to-be-inserted are negatible. 
auto *IEI = cast<InsertElementInst>(I); - Value *NegVector = negate(IEI->getOperand(0), Depth + 1); + Value *NegVector = negate(IEI->getOperand(0), IsNSW, Depth + 1); if (!NegVector) // Early return. return nullptr; - Value *NegNewElt = negate(IEI->getOperand(1), Depth + 1); + Value *NegNewElt = negate(IEI->getOperand(1), IsNSW, Depth + 1); if (!NegNewElt) // Early return. return nullptr; return Builder.CreateInsertElement(NegVector, NegNewElt, IEI->getOperand(2), @@ -382,15 +381,17 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { } case Instruction::Trunc: { // `trunc` is negatible if its operand is negatible. - Value *NegOp = negate(I->getOperand(0), Depth + 1); + Value *NegOp = negate(I->getOperand(0), /* IsNSW */ false, Depth + 1); if (!NegOp) // Early return. return nullptr; return Builder.CreateTrunc(NegOp, I->getType(), I->getName() + ".neg"); } case Instruction::Shl: { // `shl` is negatible if the first operand is negatible. - if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1)) - return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); + IsNSW &= I->hasNoSignedWrap(); + if (Value *NegOp0 = negate(I->getOperand(0), IsNSW, Depth + 1)) + return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg", + /* HasNUW */ false, IsNSW); // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`. auto *Op1C = dyn_cast<Constant>(I->getOperand(1)); if (!Op1C || !IsTrulyNegation) @@ -398,11 +399,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return Builder.CreateMul( I->getOperand(0), ConstantExpr::getShl(Constant::getAllOnesValue(Op1C->getType()), Op1C), - I->getName() + ".neg"); + I->getName() + ".neg", /* HasNUW */ false, IsNSW); } case Instruction::Or: { - if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I, - &DT)) + if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) return nullptr; // Don't know how to handle `or` in general. 
std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); // `or`/`add` are interchangeable when operands have no common bits set. @@ -417,7 +417,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { SmallVector<Value *, 2> NegatedOps, NonNegatedOps; for (Value *Op : I->operands()) { // Can we sink the negation into this operand? - if (Value *NegOp = negate(Op, Depth + 1)) { + if (Value *NegOp = negate(Op, /* IsNSW */ false, Depth + 1)) { NegatedOps.emplace_back(NegOp); // Successfully negated operand! continue; } @@ -446,9 +446,11 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // `xor` is negatible if one of its operands is invertible. // FIXME: InstCombineInverter? But how to connect Inverter and Negator? if (auto *C = dyn_cast<Constant>(Ops[1])) { - Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C)); - return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1), - I->getName() + ".neg"); + if (IsTrulyNegation) { + Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C)); + return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1), + I->getName() + ".neg"); + } } return nullptr; } @@ -458,16 +460,17 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { Value *NegatedOp, *OtherOp; // First try the second operand, in case it's a constant it will be best to // just invert it instead of sinking the `neg` deeper. - if (Value *NegOp1 = negate(Ops[1], Depth + 1)) { + if (Value *NegOp1 = negate(Ops[1], /* IsNSW */ false, Depth + 1)) { NegatedOp = NegOp1; OtherOp = Ops[0]; - } else if (Value *NegOp0 = negate(Ops[0], Depth + 1)) { + } else if (Value *NegOp0 = negate(Ops[0], /* IsNSW */ false, Depth + 1)) { NegatedOp = NegOp0; OtherOp = Ops[1]; } else // Can't negate either of them. 
return nullptr; - return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg"); + return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg", + /* HasNUW */ false, IsNSW && I->hasNoSignedWrap()); } default: return nullptr; // Don't know, likely not negatible for free. @@ -476,7 +479,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { llvm_unreachable("Can't get here. We always return from switch."); } -[[nodiscard]] Value *Negator::negate(Value *V, unsigned Depth) { +[[nodiscard]] Value *Negator::negate(Value *V, bool IsNSW, unsigned Depth) { NegatorMaxDepthVisited.updateMax(Depth); ++NegatorNumValuesVisited; @@ -506,15 +509,16 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { #endif // No luck. Try negating it for real. - Value *NegatedV = visitImpl(V, Depth); + Value *NegatedV = visitImpl(V, IsNSW, Depth); // And cache the (real) result for the future. NegationsCache[V] = NegatedV; return NegatedV; } -[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root) { - Value *Negated = negate(Root, /*Depth=*/0); +[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root, + bool IsNSW) { + Value *Negated = negate(Root, IsNSW, /*Depth=*/0); if (!Negated) { // We must cleanup newly-inserted instructions, to avoid any potential // endless combine looping. 
@@ -525,7 +529,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated); } -[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, Value *Root, +[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, bool IsNSW, Value *Root, InstCombinerImpl &IC) { ++NegatorTotalNegationsAttempted; LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root @@ -534,9 +538,8 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { if (!NegatorEnabled || !DebugCounter::shouldExecute(NegatorCounter)) return nullptr; - Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(), - IC.getDominatorTree(), LHSIsZero); - std::optional<Result> Res = N.run(Root); + Negator N(Root->getContext(), IC.getDataLayout(), LHSIsZero); + std::optional<Result> Res = N.run(Root, IsNSW); if (!Res) { // Negation failed. LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root << "\n"); diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 2f6aa85062a5..20b34c1379d5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -248,7 +248,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { PHINode *NewPtrPHI = PHINode::Create( IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr"); - InsertNewInstBefore(NewPtrPHI, PN); + InsertNewInstBefore(NewPtrPHI, PN.getIterator()); SmallDenseMap<Value *, Instruction *> Casts; for (auto Incoming : zip(PN.blocks(), AvailablePtrVals)) { auto *IncomingBB = std::get<0>(Incoming); @@ -285,10 +285,10 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { if (isa<PHINode>(IncomingI)) InsertPos = BB->getFirstInsertionPt(); assert(InsertPos != BB->end() && "should have checked above"); - InsertNewInstBefore(CI, *InsertPos); + InsertNewInstBefore(CI, InsertPos); } 
else { auto *InsertBB = &IncomingBB->getParent()->getEntryBlock(); - InsertNewInstBefore(CI, *InsertBB->getFirstInsertionPt()); + InsertNewInstBefore(CI, InsertBB->getFirstInsertionPt()); } } NewPtrPHI->addIncoming(CI, IncomingBB); @@ -353,7 +353,7 @@ InstCombinerImpl::foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN) { NewOperand->addIncoming( cast<InsertValueInst>(std::get<1>(Incoming))->getOperand(OpIdx), std::get<0>(Incoming)); - InsertNewInstBefore(NewOperand, PN); + InsertNewInstBefore(NewOperand, PN.getIterator()); } // And finally, create `insertvalue` over the newly-formed PHI nodes. @@ -391,7 +391,7 @@ InstCombinerImpl::foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN) { NewAggregateOperand->addIncoming( cast<ExtractValueInst>(std::get<1>(Incoming))->getAggregateOperand(), std::get<0>(Incoming)); - InsertNewInstBefore(NewAggregateOperand, PN); + InsertNewInstBefore(NewAggregateOperand, PN.getIterator()); // And finally, create `extractvalue` over the newly-formed PHI nodes. 
auto *NewEVI = ExtractValueInst::Create(NewAggregateOperand, @@ -450,7 +450,7 @@ Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) { NewLHS = PHINode::Create(LHSType, PN.getNumIncomingValues(), FirstInst->getOperand(0)->getName() + ".pn"); NewLHS->addIncoming(InLHS, PN.getIncomingBlock(0)); - InsertNewInstBefore(NewLHS, PN); + InsertNewInstBefore(NewLHS, PN.getIterator()); LHSVal = NewLHS; } @@ -458,7 +458,7 @@ Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) { NewRHS = PHINode::Create(RHSType, PN.getNumIncomingValues(), FirstInst->getOperand(1)->getName() + ".pn"); NewRHS->addIncoming(InRHS, PN.getIncomingBlock(0)); - InsertNewInstBefore(NewRHS, PN); + InsertNewInstBefore(NewRHS, PN.getIterator()); RHSVal = NewRHS; } @@ -581,7 +581,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { Value *FirstOp = FirstInst->getOperand(I); PHINode *NewPN = PHINode::Create(FirstOp->getType(), E, FirstOp->getName() + ".pn"); - InsertNewInstBefore(NewPN, PN); + InsertNewInstBefore(NewPN, PN.getIterator()); NewPN->addIncoming(FirstOp, PN.getIncomingBlock(0)); OperandPhis[I] = NewPN; @@ -769,7 +769,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { NewLI->setOperand(0, InVal); delete NewPN; } else { - InsertNewInstBefore(NewPN, PN); + InsertNewInstBefore(NewPN, PN.getIterator()); } // If this was a volatile load that we are merging, make sure to loop through @@ -825,8 +825,8 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) { NumZexts++; } else if (auto *C = dyn_cast<Constant>(V)) { // Make sure that constants can fit in the new type. 
- Constant *Trunc = ConstantExpr::getTrunc(C, NarrowType); - if (ConstantExpr::getZExt(Trunc, C->getType()) != C) + Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType); + if (!Trunc) return nullptr; NewIncoming.push_back(Trunc); NumConsts++; @@ -853,7 +853,7 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) { for (unsigned I = 0; I != NumIncomingValues; ++I) NewPhi->addIncoming(NewIncoming[I], Phi.getIncomingBlock(I)); - InsertNewInstBefore(NewPhi, Phi); + InsertNewInstBefore(NewPhi, Phi.getIterator()); return CastInst::CreateZExtOrBitCast(NewPhi, Phi.getType()); } @@ -943,7 +943,7 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) { PhiVal = InVal; delete NewPN; } else { - InsertNewInstBefore(NewPN, PN); + InsertNewInstBefore(NewPN, PN.getIterator()); PhiVal = NewPN; } @@ -996,8 +996,8 @@ static bool isDeadPHICycle(PHINode *PN, /// Return true if this phi node is always equal to NonPhiInVal. /// This happens with mutually cyclic phi nodes like: /// z = some value; x = phi (y, z); y = phi (x, z) -static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, - SmallPtrSetImpl<PHINode*> &ValueEqualPHIs) { +static bool PHIsEqualValue(PHINode *PN, Value *&NonPhiInVal, + SmallPtrSetImpl<PHINode *> &ValueEqualPHIs) { // See if we already saw this PHI node. if (!ValueEqualPHIs.insert(PN).second) return true; @@ -1010,8 +1010,11 @@ static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, // the value. for (Value *Op : PN->incoming_values()) { if (PHINode *OpPN = dyn_cast<PHINode>(Op)) { - if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs)) - return false; + if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs)) { + if (NonPhiInVal) + return false; + NonPhiInVal = OpPN; + } } else if (Op != NonPhiInVal) return false; } @@ -1368,7 +1371,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, // sinking. 
auto InsertPt = BB->getFirstInsertionPt(); if (InsertPt != BB->end()) { - Self.Builder.SetInsertPoint(&*InsertPt); + Self.Builder.SetInsertPoint(&*BB, InsertPt); return Self.Builder.CreateNot(Cond); } @@ -1437,22 +1440,45 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // are induction variable analysis (sometimes) and ADCE, which is only run // late. if (PHIUser->hasOneUse() && - (isa<BinaryOperator>(PHIUser) || isa<GetElementPtrInst>(PHIUser)) && + (isa<BinaryOperator>(PHIUser) || isa<UnaryOperator>(PHIUser) || + isa<GetElementPtrInst>(PHIUser)) && PHIUser->user_back() == &PN) { return replaceInstUsesWith(PN, PoisonValue::get(PN.getType())); } - // When a PHI is used only to be compared with zero, it is safe to replace - // an incoming value proved as known nonzero with any non-zero constant. - // For example, in the code below, the incoming value %v can be replaced - // with any non-zero constant based on the fact that the PHI is only used to - // be compared with zero and %v is a known non-zero value: - // %v = select %cond, 1, 2 - // %p = phi [%v, BB] ... - // icmp eq, %p, 0 - auto *CmpInst = dyn_cast<ICmpInst>(PHIUser); - // FIXME: To be simple, handle only integer type for now. - if (CmpInst && isa<IntegerType>(PN.getType()) && CmpInst->isEquality() && - match(CmpInst->getOperand(1), m_Zero())) { + } + + // When a PHI is used only to be compared with zero, it is safe to replace + // an incoming value proved as known nonzero with any non-zero constant. + // For example, in the code below, the incoming value %v can be replaced + // with any non-zero constant based on the fact that the PHI is only used to + // be compared with zero and %v is a known non-zero value: + // %v = select %cond, 1, 2 + // %p = phi [%v, BB] ... + // icmp eq, %p, 0 + // FIXME: To be simple, handle only integer type for now. 
+ // This handles a small number of uses to keep the complexity down, and an + // icmp(or(phi)) can equally be replaced with any non-zero constant as the + // "or" will only add bits. + if (!PN.hasNUsesOrMore(3)) { + SmallVector<Instruction *> DropPoisonFlags; + bool AllUsesOfPhiEndsInCmp = all_of(PN.users(), [&](User *U) { + auto *CmpInst = dyn_cast<ICmpInst>(U); + if (!CmpInst) { + // This is always correct as OR only add bits and we are checking + // against 0. + if (U->hasOneUse() && match(U, m_c_Or(m_Specific(&PN), m_Value()))) { + DropPoisonFlags.push_back(cast<Instruction>(U)); + CmpInst = dyn_cast<ICmpInst>(U->user_back()); + } + } + if (!CmpInst || !isa<IntegerType>(PN.getType()) || + !CmpInst->isEquality() || !match(CmpInst->getOperand(1), m_Zero())) { + return false; + } + return true; + }); + // All uses of PHI results in a compare with zero. + if (AllUsesOfPhiEndsInCmp) { ConstantInt *NonZeroConst = nullptr; bool MadeChange = false; for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) { @@ -1461,9 +1487,11 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) { if (!NonZeroConst) NonZeroConst = getAnyNonZeroConstInt(PN); - if (NonZeroConst != VA) { replaceOperand(PN, I, NonZeroConst); + // The "disjoint" flag may no longer hold after the transform. + for (Instruction *I : DropPoisonFlags) + I->dropPoisonGeneratingFlags(); MadeChange = true; } } @@ -1478,7 +1506,9 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // z = some value; x = phi (y, z); y = phi (x, z) // where the phi nodes don't necessarily need to be in the same block. Do a // quick check to see if the PHI node only contains a single non-phi value, if - // so, scan to see if the phi cycle is actually equal to that value. + // so, scan to see if the phi cycle is actually equal to that value. 
If the + // phi has no non-phi values then allow the "NonPhiInVal" to be set later if + // one of the phis itself does not have a single input. { unsigned InValNo = 0, NumIncomingVals = PN.getNumIncomingValues(); // Scan for the first non-phi operand. @@ -1486,25 +1516,25 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { isa<PHINode>(PN.getIncomingValue(InValNo))) ++InValNo; - if (InValNo != NumIncomingVals) { - Value *NonPhiInVal = PN.getIncomingValue(InValNo); + Value *NonPhiInVal = + InValNo != NumIncomingVals ? PN.getIncomingValue(InValNo) : nullptr; - // Scan the rest of the operands to see if there are any conflicts, if so - // there is no need to recursively scan other phis. + // Scan the rest of the operands to see if there are any conflicts, if so + // there is no need to recursively scan other phis. + if (NonPhiInVal) for (++InValNo; InValNo != NumIncomingVals; ++InValNo) { Value *OpVal = PN.getIncomingValue(InValNo); if (OpVal != NonPhiInVal && !isa<PHINode>(OpVal)) break; } - // If we scanned over all operands, then we have one unique value plus - // phi values. Scan PHI nodes to see if they all merge in each other or - // the value. - if (InValNo == NumIncomingVals) { - SmallPtrSet<PHINode*, 16> ValueEqualPHIs; - if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) - return replaceInstUsesWith(PN, NonPhiInVal); - } + // If we scanned over all operands, then we have one unique value plus + // phi values. Scan PHI nodes to see if they all merge in each other or + // the value. + if (InValNo == NumIncomingVals) { + SmallPtrSet<PHINode *, 16> ValueEqualPHIs; + if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) + return replaceInstUsesWith(PN, NonPhiInVal); } } @@ -1512,11 +1542,12 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // the blocks in the same order. This will help identical PHIs be eliminated // by other passes. Other passes shouldn't depend on this for correctness // however. 
- PHINode *FirstPN = cast<PHINode>(PN.getParent()->begin()); - if (&PN != FirstPN) - for (unsigned I = 0, E = FirstPN->getNumIncomingValues(); I != E; ++I) { + auto Res = PredOrder.try_emplace(PN.getParent()); + if (!Res.second) { + const auto &Preds = Res.first->second; + for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) { BasicBlock *BBA = PN.getIncomingBlock(I); - BasicBlock *BBB = FirstPN->getIncomingBlock(I); + BasicBlock *BBB = Preds[I]; if (BBA != BBB) { Value *VA = PN.getIncomingValue(I); unsigned J = PN.getBasicBlockIndex(BBB); @@ -1531,6 +1562,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // this in this case. } } + } else { + // Remember the block order of the first encountered phi node. + append_range(Res.first->second, PN.blocks()); + } // Is there an identical PHI node in this basic block? for (PHINode &IdenticalPN : PN.getParent()->phis()) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 661c50062223..2dda46986f0f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -689,34 +689,40 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal, } /// We want to turn: -/// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) +/// (select (icmp eq (and X, C1), 0), Y, (BinOp Y, C2)) /// into: -/// (or (shl (and X, C1), C3), Y) +/// IF C2 u>= C1 +/// (BinOp Y, (shl (and X, C1), C3)) +/// ELSE +/// (BinOp Y, (lshr (and X, C1), C3)) /// iff: +/// 0 on the RHS is the identity value (i.e add, xor, shl, etc...) /// C1 and C2 are both powers of 2 /// where: -/// C3 = Log(C2) - Log(C1) +/// IF C2 u>= C1 +/// C3 = Log(C2) - Log(C1) +/// ELSE +/// C3 = Log(C1) - Log(C2) /// /// This transform handles cases where: /// 1. The icmp predicate is inverted /// 2. The select operands are reversed /// 3. 
The magnitude of C2 and C1 are flipped -static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, +static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy &Builder) { // Only handle integer compares. Also, if this is a vector select, we need a // vector compare. if (!TrueVal->getType()->isIntOrIntVectorTy() || - TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) + TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) return nullptr; Value *CmpLHS = IC->getOperand(0); Value *CmpRHS = IC->getOperand(1); - Value *V; unsigned C1Log; - bool IsEqualZero; bool NeedAnd = false; + CmpInst::Predicate Pred = IC->getPredicate(); if (IC->isEquality()) { if (!match(CmpRHS, m_Zero())) return nullptr; @@ -725,49 +731,49 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) return nullptr; - V = CmpLHS; C1Log = C1->logBase2(); - IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ; - } else if (IC->getPredicate() == ICmpInst::ICMP_SLT || - IC->getPredicate() == ICmpInst::ICMP_SGT) { - // We also need to recognize (icmp slt (trunc (X)), 0) and - // (icmp sgt (trunc (X)), -1). 
- IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT; - if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) || - (!IsEqualZero && !match(CmpRHS, m_Zero()))) - return nullptr; - - if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V))))) + } else { + APInt C1; + if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) || + !C1.isPowerOf2()) return nullptr; - C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1; + C1Log = C1.logBase2(); NeedAnd = true; - } else { - return nullptr; } + Value *Y, *V = CmpLHS; + BinaryOperator *BinOp; const APInt *C2; - bool OrOnTrueVal = false; - bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2))); - if (!OrOnFalseVal) - OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2))); - - if (!OrOnFalseVal && !OrOnTrueVal) + bool NeedXor; + if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) { + Y = TrueVal; + BinOp = cast<BinaryOperator>(FalseVal); + NeedXor = Pred == ICmpInst::ICMP_NE; + } else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) { + Y = FalseVal; + BinOp = cast<BinaryOperator>(TrueVal); + NeedXor = Pred == ICmpInst::ICMP_EQ; + } else { return nullptr; + } - Value *Y = OrOnFalseVal ? TrueVal : FalseVal; + // Check that 0 on RHS is identity value for this binop. + auto *IdentityC = + ConstantExpr::getBinOpIdentity(BinOp->getOpcode(), BinOp->getType(), + /*AllowRHSConstant*/ true); + if (IdentityC == nullptr || !IdentityC->isNullValue()) + return nullptr; unsigned C2Log = C2->logBase2(); - bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); bool NeedShift = C1Log != C2Log; bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != V->getType()->getScalarSizeInBits(); // Make sure we don't create more instructions than we save. - Value *Or = OrOnFalseVal ? 
FalseVal : TrueVal; - if ((NeedShift + NeedXor + NeedZExtTrunc) > - (IC->hasOneUse() + Or->hasOneUse())) + if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) > + (IC->hasOneUse() + BinOp->hasOneUse())) return nullptr; if (NeedAnd) { @@ -788,7 +794,7 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, if (NeedXor) V = Builder.CreateXor(V, *C2); - return Builder.CreateOr(V, Y); + return Builder.CreateBinOp(BinOp->getOpcode(), Y, V); } /// Canonicalize a set or clear of a masked set of constant bits to @@ -870,7 +876,7 @@ static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) { auto *FalseValI = cast<Instruction>(FalseVal); auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"), - *FalseValI); + FalseValI->getIterator()); IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY); return IC.replaceInstUsesWith(SI, FalseValI); } @@ -1303,45 +1309,28 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, return nullptr; // InstSimplify already performed this fold if it was possible subject to - // current poison-generating flags. Try the transform again with - // poison-generating flags temporarily dropped. - bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false; - if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) { - WasNUW = OBO->hasNoUnsignedWrap(); - WasNSW = OBO->hasNoSignedWrap(); - FalseInst->setHasNoUnsignedWrap(false); - FalseInst->setHasNoSignedWrap(false); - } - if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) { - WasExact = PEO->isExact(); - FalseInst->setIsExact(false); - } - if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) { - WasInBounds = GEP->isInBounds(); - GEP->setIsInBounds(false); - } + // current poison-generating flags. Check whether dropping poison-generating + // flags enables the transform. // Try each equivalence substitution possibility. 
// We have an 'EQ' comparison, so the select's false value will propagate. // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 + SmallVector<Instruction *> DropFlags; if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, - /* AllowRefinement */ false) == TrueVal || + /* AllowRefinement */ false, + &DropFlags) == TrueVal || simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, - /* AllowRefinement */ false) == TrueVal) { + /* AllowRefinement */ false, + &DropFlags) == TrueVal) { + for (Instruction *I : DropFlags) { + I->dropPoisonGeneratingFlagsAndMetadata(); + Worklist.add(I); + } + return replaceInstUsesWith(Sel, FalseVal); } - // Restore poison-generating flags if the transform did not apply. - if (WasNUW) - FalseInst->setHasNoUnsignedWrap(); - if (WasNSW) - FalseInst->setHasNoSignedWrap(); - if (WasExact) - FalseInst->setIsExact(); - if (WasInBounds) - cast<GetElementPtrInst>(FalseInst)->setIsInBounds(); - return nullptr; } @@ -1506,8 +1495,13 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!match(ReplacementLow, m_ImmConstant(LowC)) || !match(ReplacementHigh, m_ImmConstant(HighC))) return nullptr; - ReplacementLow = ConstantExpr::getSExt(LowC, X->getType()); - ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType()); + const DataLayout &DL = Sel0.getModule()->getDataLayout(); + ReplacementLow = + ConstantFoldCastOperand(Instruction::SExt, LowC, X->getType(), DL); + ReplacementHigh = + ConstantFoldCastOperand(Instruction::SExt, HighC, X->getType(), DL); + assert(ReplacementLow && ReplacementHigh && + "Constant folding of ImmConstant cannot fail"); } // All good, finally emit the new pattern. 
@@ -1797,7 +1791,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder)) return V; - if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) + if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) @@ -2094,9 +2088,8 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { // If the constant is the same after truncation to the smaller type and // extension to the original type, we can narrow the select. Type *SelType = Sel.getType(); - Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); - Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); - if (ExtC == C && ExtInst->hasOneUse()) { + Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode); + if (TruncC && ExtInst->hasOneUse()) { Value *TruncCVal = cast<Value>(TruncC); if (ExtInst == Sel.getFalseValue()) std::swap(X, TruncCVal); @@ -2107,23 +2100,6 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType); } - // If one arm of the select is the extend of the condition, replace that arm - // with the extension of the appropriate known bool value. 
- if (Cond == X) { - if (ExtInst == Sel.getTrueValue()) { - // select X, (sext X), C --> select X, -1, C - // select X, (zext X), C --> select X, 1, C - Constant *One = ConstantInt::getTrue(SmallType); - Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType); - return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel); - } else { - // select X, C, (sext X) --> select X, C, 0 - // select X, C, (zext X) --> select X, C, 0 - Constant *Zero = ConstantInt::getNullValue(SelType); - return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel); - } - } - return nullptr; } @@ -2561,7 +2537,7 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB, return nullptr; } - Builder.SetInsertPoint(&*BB->begin()); + Builder.SetInsertPoint(BB, BB->begin()); auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size()); for (auto *Pred : predecessors(BB)) PN->addIncoming(Inputs[Pred], Pred); @@ -2584,6 +2560,61 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT, return nullptr; } +/// Tries to reduce a pattern that arises when calculating the remainder of the +/// Euclidean division. When the divisor is a power of two and is guaranteed not +/// to be negative, a signed remainder can be folded with a bitwise and. +/// +/// (x % n) < 0 ? (x % n) + n : (x % n) +/// -> x & (n - 1) +static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC, + IRBuilderBase &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + ICmpInst::Predicate Pred; + Value *Op, *RemRes, *Remainder; + const APInt *C; + bool TrueIfSigned = false; + + if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) && + IC.isSignBitCheck(Pred, *C, TrueIfSigned))) + return nullptr; + + // If the sign bit is not set, we have a SGE/SGT comparison, and the operands + // of the select are inverted. 
+ if (!TrueIfSigned) + std::swap(TrueVal, FalseVal); + + auto FoldToBitwiseAnd = [&](Value *Remainder) -> Instruction * { + Value *Add = Builder.CreateAdd( + Remainder, Constant::getAllOnesValue(RemRes->getType())); + return BinaryOperator::CreateAnd(Op, Add); + }; + + // Match the general case: + // %rem = srem i32 %x, %n + // %cnd = icmp slt i32 %rem, 0 + // %add = add i32 %rem, %n + // %sel = select i1 %cnd, i32 %add, i32 %rem + if (match(TrueVal, m_Add(m_Value(RemRes), m_Value(Remainder))) && + match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) && + IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) && + FalseVal == RemRes) + return FoldToBitwiseAnd(Remainder); + + // Match the case where the one arm has been replaced by constant 1: + // %rem = srem i32 %n, 2 + // %cnd = icmp slt i32 %rem, 0 + // %sel = select i1 %cnd, i32 1, i32 %rem + if (match(TrueVal, m_One()) && + match(RemRes, m_SRem(m_Value(Op), m_SpecificInt(2))) && + FalseVal == RemRes) + return FoldToBitwiseAnd(ConstantInt::get(RemRes->getType(), 2)); + + return nullptr; +} + static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) { FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition()); if (!FI) @@ -2860,8 +2891,15 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal, std::swap(InnerSel.TrueVal, InnerSel.FalseVal); Value *AltCond = nullptr; - auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) { - return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond))); + auto matchOuterCond = [OuterSel, IsAndVariant, &AltCond](auto m_InnerCond) { + // An unsimplified select condition can match both LogicalAnd and LogicalOr + // (select true, true, false). Since below we assume that LogicalAnd implies + // InnerSel match the FVal and vice versa for LogicalOr, we can't match the + // alternative pattern here. + return IsAndVariant ? 
match(OuterSel.Cond, + m_c_LogicalAnd(m_InnerCond, m_Value(AltCond))) + : match(OuterSel.Cond, + m_c_LogicalOr(m_InnerCond, m_Value(AltCond))); }; // Finally, match the condition that was driving the outermost `select`, @@ -3024,31 +3062,37 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) return replaceOperand(SI, 0, A); - // select a, (select ~a, true, b), false -> select a, b, false - if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) && - match(FalseVal, m_Zero())) - return replaceOperand(SI, 1, B); - // select a, true, (select ~a, b, false) -> select a, true, b - if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) && - match(TrueVal, m_One())) - return replaceOperand(SI, 2, B); // ~(A & B) & (A | B) --> A ^ B if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))), m_c_LogicalOr(m_Deferred(A), m_Deferred(B))))) return BinaryOperator::CreateXor(A, B); - // select (~a | c), a, b -> and a, (or c, freeze(b)) - if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && - CondVal->hasOneUse()) { - FalseVal = Builder.CreateFreeze(FalseVal); - return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); + // select (~a | c), a, b -> select a, (select c, true, b), false + if (match(CondVal, + m_OneUse(m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))))) { + Value *OrV = Builder.CreateSelect(C, One, FalseVal); + return SelectInst::Create(TrueVal, OrV, Zero); + } + // select (c & b), a, b -> select b, (select ~c, true, a), false + if (match(CondVal, m_OneUse(m_c_And(m_Value(C), m_Specific(FalseVal))))) { + if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) { + Value *OrV = Builder.CreateSelect(NotC, One, TrueVal); + return SelectInst::Create(FalseVal, OrV, Zero); + } + } + // select (a | c), a, b -> select a, true, (select ~c, 
b, false) + if (match(CondVal, m_OneUse(m_c_Or(m_Specific(TrueVal), m_Value(C))))) { + if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) { + Value *AndV = Builder.CreateSelect(NotC, FalseVal, Zero); + return SelectInst::Create(TrueVal, One, AndV); + } } - // select (~c & b), a, b -> and b, (or freeze(a), c) - if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && - CondVal->hasOneUse()) { - TrueVal = Builder.CreateFreeze(TrueVal); - return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + // select (c & ~b), a, b -> select b, true, (select c, a, false) + if (match(CondVal, + m_OneUse(m_c_And(m_Value(C), m_Not(m_Specific(FalseVal)))))) { + Value *AndV = Builder.CreateSelect(C, TrueVal, Zero); + return SelectInst::Create(FalseVal, One, AndV); } if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { @@ -3057,7 +3101,7 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { Value *Op1 = IsAnd ? TrueVal : FalseVal; if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); - InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); + InsertNewInstBefore(FI, cast<Instruction>(Y->getUser())->getIterator()); replaceUse(*Y, FI); return replaceInstUsesWith(SI, Op1); } @@ -3272,6 +3316,31 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) { Masked); } +bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF, + const Instruction *CtxI) const { + KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI); + + return Known.isKnownNeverNaN() && Known.isKnownNeverInfinity() && + (FMF.noSignedZeros() || Known.signBitIsZeroOrNaN()); +} + +static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0, + Value *Cmp1, Value *TrueVal, + Value *FalseVal, Instruction &CtxI, + bool SelectIsNSZ) { + Value *MulRHS; + if (match(Cmp1, m_PosZeroFP()) && + match(TrueVal, m_c_FMul(m_Specific(Cmp0), 
m_Value(MulRHS)))) { + FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags(); + // nsz must be on the select, it must be ignored on the multiply. We + // need nnan and ninf on the multiply for the other value. + FMF.setNoSignedZeros(SelectIsNSZ); + return IC.fmulByZeroIsZero(MulRHS, FMF, &CtxI); + } + + return false; +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -3303,28 +3372,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { ConstantInt::getFalse(CondType), SQ, /* AllowRefinement */ true)) return replaceOperand(SI, 2, S); - - // Handle patterns involving sext/zext + not explicitly, - // as simplifyWithOpReplaced() only looks past one instruction. - Value *NotCond; - - // select a, sext(!a), b -> select !a, b, 0 - // select a, zext(!a), b -> select !a, b, 0 - if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond), - m_Not(m_Specific(CondVal)))))) - return SelectInst::Create(NotCond, FalseVal, - Constant::getNullValue(SelType)); - - // select a, b, zext(!a) -> select !a, 1, b - if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond), - m_Not(m_Specific(CondVal)))))) - return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal); - - // select a, b, sext(!a) -> select !a, -1, b - if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond), - m_Not(m_Specific(CondVal)))))) - return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType), - TrueVal); } if (Instruction *R = foldSelectOfBools(SI)) @@ -3362,7 +3409,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + auto *SIFPOp = dyn_cast<FPMathOperator>(&SI); + if (auto *FCmp = dyn_cast<FCmpInst>(CondVal)) { + FCmpInst::Predicate Pred = FCmp->getPredicate(); Value *Cmp0 = FCmp->getOperand(0), *Cmp1 = FCmp->getOperand(1); // Are we selecting a value based on a comparison of the two values? 
if ((Cmp0 == TrueVal && Cmp1 == FalseVal) || @@ -3372,7 +3422,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // // e.g. // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X - if (FCmp->hasOneUse() && FCmpInst::isUnordered(FCmp->getPredicate())) { + if (FCmp->hasOneUse() && FCmpInst::isUnordered(Pred)) { FCmpInst::Predicate InvPred = FCmp->getInversePredicate(); IRBuilder<>::FastMathFlagGuard FMFG(Builder); // FIXME: The FMF should propagate from the select, not the fcmp. @@ -3383,14 +3433,47 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { return replaceInstUsesWith(SI, NewSel); } } + + if (SIFPOp) { + // Fold out scale-if-equals-zero pattern. + // + // This pattern appears in code with denormal range checks after it's + // assumed denormals are treated as zero. This drops a canonicalization. + + // TODO: Could relax the signed zero logic. We just need to know the sign + // of the result matches (fmul x, y has the same sign as x). + // + // TODO: Handle always-canonicalizing variant that selects some value or 1 + // scaling factor in the fmul visitor. + + // TODO: Handle ldexp too + + Value *MatchCmp0 = nullptr; + Value *MatchCmp1 = nullptr; + + // (select (fcmp [ou]eq x, 0.0), (fmul x, K), x => x + // (select (fcmp [ou]ne x, 0.0), x, (fmul x, K) => x + if (Pred == CmpInst::FCMP_OEQ || Pred == CmpInst::FCMP_UEQ) { + MatchCmp0 = FalseVal; + MatchCmp1 = TrueVal; + } else if (Pred == CmpInst::FCMP_ONE || Pred == CmpInst::FCMP_UNE) { + MatchCmp0 = TrueVal; + MatchCmp1 = FalseVal; + } + + if (Cmp0 == MatchCmp0 && + matchFMulByZeroIfResultEqZero(*this, Cmp0, Cmp1, MatchCmp1, MatchCmp0, + SI, SIFPOp->hasNoSignedZeros())) + return replaceInstUsesWith(SI, Cmp0); + } } - if (isa<FPMathOperator>(SI)) { + if (SIFPOp) { // TODO: Try to forward-propagate FMF from select arms to the select. // Canonicalize select of FP values where NaN and -0.0 are not valid as // minnum/maxnum intrinsics. 
- if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) { + if (SIFPOp->hasNoNaNs() && SIFPOp->hasNoSignedZeros()) { Value *X, *Y; if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) return replaceInstUsesWith( @@ -3430,6 +3513,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *I = foldSelectExtConst(SI)) return I; + if (Instruction *I = foldSelectWithSRem(SI, *this, Builder)) + return I; + // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0)) // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx)) auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 89dad455f015..b7958978c450 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -136,9 +136,14 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( assert(IdenticalShOpcodes && "Should not get here with different shifts."); - // All good, we can do this fold. - NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType()); + if (NewShAmt->getType() != X->getType()) { + NewShAmt = ConstantFoldCastOperand(Instruction::ZExt, NewShAmt, + X->getType(), SQ.DL); + if (!NewShAmt) + return nullptr; + } + // All good, we can do this fold. BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt); // The flags can only be propagated if there wasn't a trunc. 
@@ -245,7 +250,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, SumOfShAmts = Constant::replaceUndefsWith( SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(), ExtendedTy->getScalarSizeInBits())); - auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); + auto *ExtendedSumOfShAmts = ConstantFoldCastOperand( + Instruction::ZExt, SumOfShAmts, ExtendedTy, Q.DL); + if (!ExtendedSumOfShAmts) + return nullptr; + // And compute the mask as usual: ~(-1 << (SumOfShAmts)) auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); auto *ExtendedInvertedMask = @@ -278,16 +287,22 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, ShAmtsDiff = Constant::replaceUndefsWith( ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -WidestTyBitWidth)); - auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( + auto *ExtendedNumHighBitsToClear = ConstantFoldCastOperand( + Instruction::ZExt, ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), WidestTyBitWidth, /*isSigned=*/false), ShAmtsDiff), - ExtendedTy); + ExtendedTy, Q.DL); + if (!ExtendedNumHighBitsToClear) + return nullptr; + // And compute the mask as usual: (-1 l>> (NumHighBitsToClear)) auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); - NewMask = - ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear); + NewMask = ConstantFoldBinaryOpOperands(Instruction::LShr, ExtendedAllOnes, + ExtendedNumHighBitsToClear, Q.DL); + if (!NewMask) + return nullptr; } else return nullptr; // Don't know anything about this pattern. @@ -545,8 +560,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, /// this succeeds, getShiftedValue() will be called to produce the value. static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, InstCombinerImpl &IC, Instruction *CxtI) { - // We can always evaluate constants shifted. 
- if (isa<Constant>(V)) + // We can always evaluate immediate constants. + if (match(V, m_ImmConstant())) return true; Instruction *I = dyn_cast<Instruction>(V); @@ -709,13 +724,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Mul: { assert(!isLeftShift && "Unexpected shift direction!"); auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0)); - IC.InsertNewInstWith(Neg, *I); + IC.InsertNewInstWith(Neg, I->getIterator()); unsigned TypeWidth = I->getType()->getScalarSizeInBits(); APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits); auto *And = BinaryOperator::CreateAnd(Neg, ConstantInt::get(I->getType(), Mask)); And->takeName(I); - return IC.InsertNewInstWith(And, *I); + return IC.InsertNewInstWith(And, I->getIterator()); } } } @@ -745,7 +760,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, // (C2 >> X) >> C1 --> (C2 >> C1) >> X Constant *C2; Value *X; - if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X)))) + if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X)))) return BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); @@ -928,6 +943,60 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { return new ZExtInst(Overflow, Ty); } +// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits. +static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) { + assert(I.isShift() && "Expected a shift as input"); + // We already have all the flags. + if (I.getOpcode() == Instruction::Shl) { + if (I.hasNoUnsignedWrap() && I.hasNoSignedWrap()) + return false; + } else { + if (I.isExact()) + return false; + + // shr (shl X, Y), Y + if (match(I.getOperand(0), m_Shl(m_Value(), m_Specific(I.getOperand(1))))) { + I.setIsExact(); + return true; + } + } + + // Compute what we know about shift count. 
+ KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q); + unsigned BitWidth = KnownCnt.getBitWidth(); + // Since shift produces a poison value if RHS is equal to or larger than the + // bit width, we can safely assume that RHS is less than the bit width. + uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1); + + KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q); + bool Changed = false; + + if (I.getOpcode() == Instruction::Shl) { + // If we have as many leading zeros than maximum shift cnt we have nuw. + if (!I.hasNoUnsignedWrap() && MaxCnt <= KnownAmt.countMinLeadingZeros()) { + I.setHasNoUnsignedWrap(); + Changed = true; + } + // If we have more sign bits than maximum shift cnt we have nsw. + if (!I.hasNoSignedWrap()) { + if (MaxCnt < KnownAmt.countMinSignBits() || + MaxCnt < ComputeNumSignBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, + Q.CxtI, Q.DT)) { + I.setHasNoSignedWrap(); + Changed = true; + } + } + return Changed; + } + + // If we have at least as many trailing zeros as maximum count then we have + // exact. + Changed = MaxCnt <= KnownAmt.countMinTrailingZeros(); + I.setIsExact(Changed); + + return Changed; +} + Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -976,7 +1045,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1) Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoUnsignedWrap( + I.hasNoUnsignedWrap() || + (ShrAmt && + cast<Instruction>(Op0)->getOpcode() == Instruction::LShr && + I.hasNoSignedWrap())); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); return NewShl; } @@ -997,7 +1070,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { // If C1 < C: (X >>? 
C1) << C --> (X << (C - C1)) & (-1 << C) Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoUnsignedWrap( + I.hasNoUnsignedWrap() || + (ShrAmt && + cast<Instruction>(Op0)->getOpcode() == Instruction::LShr && + I.hasNoSignedWrap())); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); Builder.Insert(NewShl); APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); @@ -1108,22 +1185,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { Value *NewShift = Builder.CreateShl(X, Op1); return BinaryOperator::CreateSub(NewLHS, NewShift); } - - // If the shifted-out value is known-zero, then this is a NUW shift. - if (!I.hasNoUnsignedWrap() && - MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0, - &I)) { - I.setHasNoUnsignedWrap(); - return &I; - } - - // If the shifted-out value is all signbits, then this is a NSW shift. - if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) { - I.setHasNoSignedWrap(); - return &I; - } } + if (setShiftFlags(I, Q)) + return &I; + // Transform (x >> y) << y to x & (-1 << y) // Valid for any type of right-shift. Value *X; @@ -1161,15 +1227,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { Value *NegX = Builder.CreateNeg(X, "neg"); return BinaryOperator::CreateAnd(NegX, X); } - - // The only way to shift out the 1 is with an over-shift, so that would - // be poison with or without "nuw". Undef is excluded because (undef << X) - // is not undef (it is zero). 
- Constant *ConstantOne = cast<Constant>(Op0); - if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) { - I.setHasNoUnsignedWrap(); - return &I; - } } return nullptr; @@ -1235,9 +1292,10 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { unsigned ShlAmtC = C1->getZExtValue(); Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C --> X <<nuw (C1 - C) + // (X <<nuw C1) >>u C --> X <<nuw/nsw (C1 - C) auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(true); + NewShl->setHasNoSignedWrap(ShAmtC > 0); return NewShl; } if (Op0->hasOneUse()) { @@ -1370,12 +1428,13 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { if (Op0->hasOneUse()) { APInt NewMulC = MulC->lshr(ShAmtC); // if c is divisible by (1 << ShAmtC): - // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC) + // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC) if (MulC->eq(NewMulC.shl(ShAmtC))) { auto *NewMul = BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC)); - BinaryOperator *OrigMul = cast<BinaryOperator>(Op0); - NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap()); + assert(ShAmtC != 0 && + "lshr X, 0 should be handled by simplifyLShrInst."); + NewMul->setHasNoSignedWrap(true); return NewMul; } } @@ -1414,15 +1473,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *And = Builder.CreateAnd(BoolX, BoolY); return new ZExtInst(And, Ty); } - - // If the shifted-out value is known-zero, then this is an exact shift. 
- if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { - I.setIsExact(); - return &I; - } } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (setShiftFlags(I, Q)) + return &I; + // Transform (x << y) >> y to x & (-1 >> y) if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); @@ -1581,15 +1637,12 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty); } - - // If the shifted-out value is known-zero, then this is an exact shift. - if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { - I.setIsExact(); - return &I; - } } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (setShiftFlags(I, Q)) + return &I; + // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)` // as the pattern to splat the lowest bit. // FIXME: iff X is already masked, we don't need the one-use check. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 00eece9534b0..046ce9d1207e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -24,6 +24,12 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +static cl::opt<bool> + VerifyKnownBits("instcombine-verify-known-bits", + cl::desc("Verify that computeKnownBits() and " + "SimplifyDemandedBits() are consistent"), + cl::Hidden, cl::init(false)); + /// Check to see if the specified operand of the specified instruction is a /// constant integer. If so, check to see if there are any bits set in the /// constant that are not demanded. If so, shrink the constant and return true. 
@@ -48,15 +54,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, return true; } +/// Returns the bitwidth of the given scalar or pointer type. For vector types, +/// returns the element type's bitwidth. +static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { + if (unsigned BitWidth = Ty->getScalarSizeInBits()) + return BitWidth; + return DL.getPointerTypeSizeInBits(Ty); +} /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if /// the instruction has any properties that allow us to simplify its operands. -bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { - unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); - KnownBits Known(BitWidth); - APInt DemandedMask(APInt::getAllOnes(BitWidth)); - +bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst, + KnownBits &Known) { + APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth())); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 0, &Inst); if (!V) return false; @@ -65,6 +76,13 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { return true; } +/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if +/// the instruction has any properties that allow us to simplify its operands. +bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { + KnownBits Known(getBitWidth(Inst.getType(), DL)); + return SimplifyDemandedInstructionBits(Inst, Known); +} + /// This form of SimplifyDemandedBits simplifies the specified instruction /// operand if possible, updating it in place. It returns true if it made any /// change and false otherwise. @@ -95,8 +113,8 @@ bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, /// expression. /// Known.One and Known.Zero always follow the invariant that: /// Known.One & Known.Zero == 0. -/// That is, a bit can't be both 1 and 0. 
Note that the bits in Known.One and -/// Known.Zero may only be accurate for those bits set in DemandedMask. Note +/// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero +/// are accurate even for bits not in DemandedMask. Note /// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must all /// be the same. /// @@ -143,7 +161,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI); KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth); - // If this is the root being simplified, allow it to have multiple uses, // just set the DemandedMask to all bits so that we can try to simplify the // operands. This allows visitTruncInst (for example) to simplify the @@ -196,7 +213,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, DL, &AC, CxtI, &DT); + Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -220,13 +237,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If either the LHS or the RHS are One, the result is One. if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown, - Depth + 1)) + Depth + 1)) { + // Disjoint flag may not longer hold. + I->dropPoisonGeneratingFlags(); return I; + } assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, DL, &AC, CxtI, &DT); + Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. 
@@ -244,6 +264,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (ShrinkDemandedConstant(I, 1, DemandedMask)) return I; + // Infer disjoint flag if no common bits are set. + if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) { + WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown), + RHSCache(I->getOperand(1), RHSKnown); + if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(I))) { + cast<PossiblyDisjointInst>(I)->setIsDisjoint(true); + return I; + } + } + break; } case Instruction::Xor: { @@ -265,7 +295,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, DL, &AC, CxtI, &DT); + Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -284,9 +314,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0 if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) { Instruction *Or = - BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1), - I->getName()); - return InsertNewInstWith(Or, *I); + BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1)); + if (DemandedMask.isAllOnes()) + cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true); + Or->takeName(I); + return InsertNewInstWith(Or, I->getIterator()); } // If all of the demanded bits on one side are known, and all of the set @@ -298,7 +330,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Constant *AndC = Constant::getIntegerValue(VTy, ~RHSKnown.One & DemandedMask); Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC); - return InsertNewInstWith(And, *I); + return InsertNewInstWith(And, I->getIterator()); } // If the RHS is a constant, see if we can change it. 
Don't alter a -1 @@ -330,11 +362,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue()); Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); - InsertNewInstWith(NewAnd, *I); + InsertNewInstWith(NewAnd, I->getIterator()); Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue()); Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); - return InsertNewInstWith(NewXor, *I); + return InsertNewInstWith(NewXor, I->getIterator()); } } break; @@ -411,36 +443,21 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); KnownBits InputKnown(SrcBitWidth); - if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) { + // For zext nneg, we may have dropped the instruction which made the + // input non-negative. + I->dropPoisonGeneratingFlags(); return I; + } assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?"); + if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() && + !InputKnown.isNegative()) + InputKnown.makeNonNegative(); Known = InputKnown.zextOrTrunc(BitWidth); - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - break; - } - case Instruction::BitCast: - if (!I->getOperand(0)->getType()->isIntOrIntVectorTy()) - return nullptr; // vector->int or fp->int? - - if (auto *DstVTy = dyn_cast<VectorType>(VTy)) { - if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { - if (isa<ScalableVectorType>(DstVTy) || - isa<ScalableVectorType>(SrcVTy) || - cast<FixedVectorType>(DstVTy)->getNumElements() != - cast<FixedVectorType>(SrcVTy)->getNumElements()) - // Don't touch a bitcast between vectors of different element counts. - return nullptr; - } else - // Don't touch a scalar-to-vector bitcast. 
- return nullptr; - } else if (I->getOperand(0)->getType()->isVectorTy()) - // Don't touch a vector-to-scalar bitcast. - return nullptr; - if (SimplifyDemandedBits(I, 0, DemandedMask, Known, Depth + 1)) - return I; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; + } case Instruction::SExt: { // Compute the bits in the result that are not present in the input. unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); @@ -461,8 +478,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (InputKnown.isNonNegative() || DemandedMask.getActiveBits() <= SrcBitWidth) { // Convert to ZExt cast. - CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName()); - return InsertNewInstWith(NewCast, *I); + CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy); + NewCast->takeName(I); + return InsertNewInstWith(NewCast, I->getIterator()); } // If the sign bit of the input is known set or clear, then we know the @@ -586,7 +604,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) { Constant *ShiftC = ConstantInt::get(VTy, CTZ); Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); - return InsertNewInstWith(Shl, *I); + return InsertNewInstWith(Shl, I->getIterator()); } } // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because: @@ -595,7 +613,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) { Constant *One = ConstantInt::get(VTy, 1); Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One); - return InsertNewInstWith(And1, *I); + return InsertNewInstWith(And1, I->getIterator()); } computeKnownBits(I, Known, Depth, CxtI); @@ -624,10 +642,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (DemandedMask.countr_zero() >= 
ShiftAmt && match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); - Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC); - if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) { + Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C, + LeftShiftAmtC, DL); + if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC, + DL) == C) { Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); - return InsertNewInstWith(Lshr, *I); + return InsertNewInstWith(Lshr, I->getIterator()); } } @@ -688,24 +708,23 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Constant *C; if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) { Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt); - Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC); - if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) { + Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C, + RightShiftAmtC, DL); + if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC, + RightShiftAmtC, DL) == C) { Instruction *Shl = BinaryOperator::CreateShl(NewC, X); - return InsertNewInstWith(Shl, *I); + return InsertNewInstWith(Shl, I->getIterator()); } } } // Unsigned shift right. APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); - - // If the shift is exact, then it does demand the low bits (and knows that - // they are zero). - if (cast<LShrOperator>(I)->isExact()) - DemandedMaskIn.setLowBits(ShiftAmt); - - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) { + // exact flag may not longer hold. 
+ I->dropPoisonGeneratingFlags(); return I; + } assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShiftAmt); Known.One.lshrInPlace(ShiftAmt); @@ -733,7 +752,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Perform the logical shift right. Instruction *NewVal = BinaryOperator::CreateLShr( I->getOperand(0), I->getOperand(1), I->getName()); - return InsertNewInstWith(NewVal, *I); + return InsertNewInstWith(NewVal, I->getIterator()); } const APInt *SA; @@ -747,13 +766,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (DemandedMask.countl_zero() <= ShiftAmt) DemandedMaskIn.setSignBit(); - // If the shift is exact, then it does demand the low bits (and knows that - // they are zero). - if (cast<AShrOperator>(I)->isExact()) - DemandedMaskIn.setLowBits(ShiftAmt); - - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) { + // exact flag may not longer hold. + I->dropPoisonGeneratingFlags(); return I; + } assert(!Known.hasConflict() && "Bits known to be one AND zero?"); // Compute the new bits that are at the top now plus sign bits. @@ -770,7 +787,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), I->getOperand(1)); LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); - return InsertNewInstWith(LShr, *I); + LShr->takeName(I); + return InsertNewInstWith(LShr, I->getIterator()); } else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one. 
Known.One |= HighBits; } @@ -867,7 +885,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, match(II->getArgOperand(0), m_Not(m_Value(X)))) { Function *Ctpop = Intrinsic::getDeclaration( II->getModule(), Intrinsic::ctpop, VTy); - return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I); + return InsertNewInstWith(CallInst::Create(Ctpop, {X}), I->getIterator()); } break; } @@ -894,10 +912,52 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, NewVal = BinaryOperator::CreateShl( II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ)); NewVal->takeName(I); - return InsertNewInstWith(NewVal, *I); + return InsertNewInstWith(NewVal, I->getIterator()); } break; } + case Intrinsic::ptrmask: { + unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits(); + RHSKnown = KnownBits(MaskWidth); + // If either the LHS or the RHS are Zero, the result is zero. + if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) || + SimplifyDemandedBits( + I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth), + RHSKnown, Depth + 1)) + return I; + + // TODO: Should be 1-extend + RHSKnown = RHSKnown.anyextOrTrunc(BitWidth); + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + Known = LHSKnown & RHSKnown; + KnownBitsComputed = true; + + // If the client is only demanding bits we know to be zero, return + // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer + // provenance, but making the mask zero will be easily optimizable in + // the backend. + if (DemandedMask.isSubsetOf(Known.Zero) && + !match(I->getOperand(1), m_Zero())) + return replaceOperand( + *I, 1, Constant::getNullValue(I->getOperand(1)->getType())); + + // Mask in demanded space does nothing. 
+ // NOTE: We may have attributes associated with the return value of the + // llvm.ptrmask intrinsic that will be lost when we just return the + // operand. We should try to preserve them. + if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) + return I->getOperand(0); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant( + I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth))) + return I; + + break; + } + case Intrinsic::fshr: case Intrinsic::fshl: { const APInt *SA; @@ -918,7 +978,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) return I; } else { // fshl is a rotate - // Avoid converting rotate into funnel shift. + // Avoid converting rotate into funnel shift. // Only simplify if one operand is constant. LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I); if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) && @@ -982,10 +1042,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } } + if (V->getType()->isPointerTy()) { + Align Alignment = V->getPointerAlignment(DL); + Known.Zero.setLowBits(Log2(Alignment)); + } + // If the client is only demanding bits that we know, return the known - // constant. - if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) + // constant. We can't directly simplify pointers as a constant because of + // pointer provenance. + // TODO: We could return `(inttoptr const)` for pointers. 
+ if (!V->getType()->isPointerTy() && DemandedMask.isSubsetOf(Known.Zero | Known.One)) return Constant::getIntegerValue(VTy, Known.One); + + if (VerifyKnownBits) { + KnownBits ReferenceKnown = computeKnownBits(V, Depth, CxtI); + if (Known != ReferenceKnown) { + errs() << "Mismatched known bits for " << *V << " in " + << I->getFunction()->getName() << "\n"; + errs() << "computeKnownBits(): " << ReferenceKnown << "\n"; + errs() << "SimplifyDemandedBits(): " << Known << "\n"; + std::abort(); + } + } + return nullptr; } @@ -1009,8 +1088,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( case Instruction::And: { computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - Known = LHSKnown & RHSKnown; - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1029,8 +1109,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( case Instruction::Or: { computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - Known = LHSKnown | RHSKnown; - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. 
@@ -1051,8 +1132,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( case Instruction::Xor: { computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - Known = LHSKnown ^ RHSKnown; - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1085,7 +1167,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown); - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::Sub: { @@ -1101,7 +1183,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown); - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::AShr: { @@ -1219,7 +1301,7 @@ Value *InstCombinerImpl::simplifyShrShlDemandedBits( New->setIsExact(true); } - return InsertNewInstWith(New, *Shl); + return InsertNewInstWith(New, Shl->getIterator()); } return nullptr; @@ -1549,7 +1631,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, Instruction *New = InsertElementInst::Create( Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx), Shuffle->getName()); - InsertNewInstWith(New, *Shuffle); + InsertNewInstWith(New, 
Shuffle->getIterator()); return New; } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 4a5ffef2b08e..c8b58c51d4e6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -132,7 +132,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, // Create a scalar PHI node that will replace the vector PHI node // just before the current PHI node. PHINode *scalarPHI = cast<PHINode>(InsertNewInstWith( - PHINode::Create(EI.getType(), PN->getNumIncomingValues(), ""), *PN)); + PHINode::Create(EI.getType(), PN->getNumIncomingValues(), ""), PN->getIterator())); // Scalarize each PHI operand. for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { Value *PHIInVal = PN->getIncomingValue(i); @@ -148,10 +148,10 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, Value *Op = InsertNewInstWith( ExtractElementInst::Create(B0->getOperand(opId), Elt, B0->getOperand(opId)->getName() + ".Elt"), - *B0); + B0->getIterator()); Value *newPHIUser = InsertNewInstWith( BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(), - scalarPHI, Op, B0), *B0); + scalarPHI, Op, B0), B0->getIterator()); scalarPHI->addIncoming(newPHIUser, inBB); } else { // Scalarize PHI input: @@ -165,7 +165,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, InsertPos = inBB->getFirstInsertionPt(); } - InsertNewInstWith(newEI, *InsertPos); + InsertNewInstWith(newEI, InsertPos); scalarPHI->addIncoming(newEI, inBB); } @@ -441,7 +441,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { if (IndexC->getValue().getActiveBits() <= BitWidth) Idx = ConstantInt::get(Ty, IndexC->getValue().zextOrTrunc(BitWidth)); else - Idx = UndefValue::get(Ty); + Idx = PoisonValue::get(Ty); return replaceInstUsesWith(EI, Idx); } } @@ -742,7 +742,7 @@ static bool 
replaceExtractElements(InsertElementInst *InsElt, if (ExtVecOpInst && !isa<PHINode>(ExtVecOpInst)) WideVec->insertAfter(ExtVecOpInst); else - IC.InsertNewInstWith(WideVec, *ExtElt->getParent()->getFirstInsertionPt()); + IC.InsertNewInstWith(WideVec, ExtElt->getParent()->getFirstInsertionPt()); // Replace extracts from the original narrow vector with extracts from the new // wide vector. @@ -751,7 +751,7 @@ static bool replaceExtractElements(InsertElementInst *InsElt, if (!OldExt || OldExt->getParent() != WideVec->getParent()) continue; auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); - IC.InsertNewInstWith(NewExt, *OldExt); + IC.InsertNewInstWith(NewExt, OldExt->getIterator()); IC.replaceInstUsesWith(*OldExt, NewExt); // Add the old extracts to the worklist for DCE. We can't remove the // extracts directly, because they may still be used by the calling code. @@ -1121,7 +1121,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // Note that the same block can be a predecessor more than once, // and we need to preserve that invariant for the PHI node. BuilderTy::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(UseBB->getFirstNonPHI()); + Builder.SetInsertPoint(UseBB, UseBB->getFirstNonPHIIt()); auto *PHI = Builder.CreatePHI(AggTy, Preds.size(), OrigIVI.getName() + ".merged"); for (BasicBlock *Pred : Preds) @@ -2122,8 +2122,8 @@ static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) { NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i]; // A select mask with undef elements might look like an identity mask. 
- assert((ShuffleVectorInst::isSelectMask(NewMask) || - ShuffleVectorInst::isIdentityMask(NewMask)) && + assert((ShuffleVectorInst::isSelectMask(NewMask, NumElts) || + ShuffleVectorInst::isIdentityMask(NewMask, NumElts)) && "Unexpected shuffle mask"); return new ShuffleVectorInst(X, Y, NewMask); } @@ -2197,9 +2197,9 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, !match(Op1, m_Undef()) || match(Mask, m_ZeroMask()) || IndexC == 0) return nullptr; - // Insert into element 0 of an undef vector. - UndefValue *UndefVec = UndefValue::get(Shuf.getType()); - Value *NewIns = Builder.CreateInsertElement(UndefVec, X, (uint64_t)0); + // Insert into element 0 of a poison vector. + PoisonValue *PoisonVec = PoisonValue::get(Shuf.getType()); + Value *NewIns = Builder.CreateInsertElement(PoisonVec, X, (uint64_t)0); // Splat from element 0. Any mask element that is undefined remains undefined. // For example: diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index afd6e034f46d..f072f5cec309 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -130,13 +130,6 @@ STATISTIC(NumReassoc , "Number of reassociations"); DEBUG_COUNTER(VisitCounter, "instcombine-visit", "Controls which instructions are visited"); -// FIXME: these limits eventually should be as low as 2. 
-#ifndef NDEBUG -static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100; -#else -static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000; -#endif - static cl::opt<bool> EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), cl::init(true)); @@ -145,12 +138,6 @@ static cl::opt<unsigned> MaxSinkNumUsers( "instcombine-max-sink-users", cl::init(32), cl::desc("Maximum number of undroppable users for instruction sinking")); -static cl::opt<unsigned> InfiniteLoopDetectionThreshold( - "instcombine-infinite-loop-threshold", - cl::desc("Number of instruction combining iterations considered an " - "infinite loop"), - cl::init(InstCombineDefaultInfiniteLoopThreshold), cl::Hidden); - static cl::opt<unsigned> MaxArraySize("instcombine-maxarray-size", cl::init(1024), cl::desc("Maximum array size considered when doing a combine")); @@ -358,15 +345,19 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, // Fold the constants together in the destination type: // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC) + const DataLayout &DL = IC.getDataLayout(); Type *DestTy = C1->getType(); - Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy); - Constant *FoldedC = - ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, IC.getDataLayout()); + Constant *CastC2 = ConstantFoldCastOperand(CastOpcode, C2, DestTy, DL); + if (!CastC2) + return false; + Constant *FoldedC = ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, DL); if (!FoldedC) return false; IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0)); IC.replaceOperand(*BinOp1, 1, FoldedC); + BinOp1->dropPoisonGeneratingFlags(); + Cast->dropPoisonGeneratingFlags(); return true; } @@ -542,12 +533,12 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { BinaryOperator::Create(Opcode, A, B); if (isa<FPMathOperator>(NewBO)) { - FastMathFlags Flags = I.getFastMathFlags(); - Flags &= Op0->getFastMathFlags(); - Flags &= 
Op1->getFastMathFlags(); - NewBO->setFastMathFlags(Flags); + FastMathFlags Flags = I.getFastMathFlags() & + Op0->getFastMathFlags() & + Op1->getFastMathFlags(); + NewBO->setFastMathFlags(Flags); } - InsertNewInstWith(NewBO, I); + InsertNewInstWith(NewBO, I.getIterator()); NewBO->takeName(Op1); replaceOperand(I, 0, NewBO); replaceOperand(I, 1, CRes); @@ -749,7 +740,16 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ, // 2) BinOp1 == BinOp2 (if BinOp == `add`, then also requires `shl`). // // -> (BinOp (logic_shift (BinOp X, Y)), Mask) +// +// (Binop1 (Binop2 (arithmetic_shift X, Amt), Mask), (arithmetic_shift Y, Amt)) +// IFF +// 1) Binop1 is bitwise logical operator `and`, `or` or `xor` +// 2) Binop2 is `not` +// +// -> (arithmetic_shift Binop1((not X), Y), Amt) + Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { + const DataLayout &DL = I.getModule()->getDataLayout(); auto IsValidBinOpc = [](unsigned Opc) { switch (Opc) { default: @@ -768,11 +768,13 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { // constraints. auto IsCompletelyDistributable = [](unsigned BinOpc1, unsigned BinOpc2, unsigned ShOpc) { + assert(ShOpc != Instruction::AShr); return (BinOpc1 != Instruction::Add && BinOpc2 != Instruction::Add) || ShOpc == Instruction::Shl; }; auto GetInvShift = [](unsigned ShOpc) { + assert(ShOpc != Instruction::AShr); return ShOpc == Instruction::LShr ? Instruction::Shl : Instruction::LShr; }; @@ -796,23 +798,23 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { // Otherwise, need mask that meets the below requirement. 
// (logic_shift (inv_logic_shift Mask, ShAmt), ShAmt) == Mask - return ConstantExpr::get( - ShOpc, ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift), - CShift) == CMask; + Constant *MaskInvShift = + ConstantFoldBinaryOpOperands(GetInvShift(ShOpc), CMask, CShift, DL); + return ConstantFoldBinaryOpOperands(ShOpc, MaskInvShift, CShift, DL) == + CMask; }; auto MatchBinOp = [&](unsigned ShOpnum) -> Instruction * { Constant *CMask, *CShift; Value *X, *Y, *ShiftedX, *Mask, *Shift; if (!match(I.getOperand(ShOpnum), - m_OneUse(m_LogicalShift(m_Value(Y), m_Value(Shift))))) + m_OneUse(m_Shift(m_Value(Y), m_Value(Shift))))) return nullptr; if (!match(I.getOperand(1 - ShOpnum), m_BinOp(m_Value(ShiftedX), m_Value(Mask)))) return nullptr; - if (!match(ShiftedX, - m_OneUse(m_LogicalShift(m_Value(X), m_Specific(Shift))))) + if (!match(ShiftedX, m_OneUse(m_Shift(m_Value(X), m_Specific(Shift))))) return nullptr; // Make sure we are matching instruction shifts and not ConstantExpr @@ -836,6 +838,18 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { if (!IsValidBinOpc(I.getOpcode()) || !IsValidBinOpc(BinOpc)) return nullptr; + if (ShOpc == Instruction::AShr) { + if (Instruction::isBitwiseLogicOp(I.getOpcode()) && + BinOpc == Instruction::Xor && match(Mask, m_AllOnes())) { + Value *NotX = Builder.CreateNot(X); + Value *NewBinOp = Builder.CreateBinOp(I.getOpcode(), Y, NotX); + return BinaryOperator::Create( + static_cast<Instruction::BinaryOps>(ShOpc), NewBinOp, Shift); + } + + return nullptr; + } + // If BinOp1 == BinOp2 and it's bitwise or shl with add, then just // distribute to drop the shift irrelevant of constants. 
if (BinOpc == I.getOpcode() && @@ -857,7 +871,8 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { if (!CanDistributeBinops(I.getOpcode(), BinOpc, ShOpc, CMask, CShift)) return nullptr; - Constant *NewCMask = ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift); + Constant *NewCMask = + ConstantFoldBinaryOpOperands(GetInvShift(ShOpc), CMask, CShift, DL); Value *NewBinOp2 = Builder.CreateBinOp( static_cast<Instruction::BinaryOps>(BinOpc), X, NewCMask); Value *NewBinOp1 = Builder.CreateBinOp(I.getOpcode(), Y, NewBinOp2); @@ -924,13 +939,17 @@ InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) { // If the value used in the zext/sext is the select condition, or the negated // of the select condition, the binop can be simplified. - if (CondVal == A) - return SelectInst::Create(CondVal, NewFoldedConst(false, TrueVal), + if (CondVal == A) { + Value *NewTrueVal = NewFoldedConst(false, TrueVal); + return SelectInst::Create(CondVal, NewTrueVal, NewFoldedConst(true, FalseVal)); + } - if (match(A, m_Not(m_Specific(CondVal)))) - return SelectInst::Create(CondVal, NewFoldedConst(true, TrueVal), + if (match(A, m_Not(m_Specific(CondVal)))) { + Value *NewTrueVal = NewFoldedConst(true, TrueVal); + return SelectInst::Create(CondVal, NewTrueVal, NewFoldedConst(false, FalseVal)); + } return nullptr; } @@ -1167,6 +1186,8 @@ void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) { break; case Instruction::Xor: replaceInstUsesWith(cast<Instruction>(*U), I); + // Add to worklist for DCE. 
+ addToWorklist(cast<Instruction>(U)); break; default: llvm_unreachable("Got unexpected user - out of sync with " @@ -1268,7 +1289,7 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI, Value *NewOp, InstCombiner &IC) { Instruction *Clone = I.clone(); Clone->replaceUsesOfWith(SI, NewOp); - IC.InsertNewInstBefore(Clone, *SI); + IC.InsertNewInstBefore(Clone, SI->getIterator()); return Clone; } @@ -1302,6 +1323,21 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return nullptr; } + // Test if a FCmpInst instruction is used exclusively by a select as + // part of a minimum or maximum operation. If so, refrain from doing + // any other folding. This helps out other analyses which understand + // non-obfuscated minimum and maximum idioms. And in this case, at + // least one of the comparison operands has at least one user besides + // the compare (the select), which would often largely negate the + // benefit of folding anyway. + if (auto *CI = dyn_cast<FCmpInst>(SI->getCondition())) { + if (CI->hasOneUse()) { + Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); + if ((TV == Op0 && FV == Op1) || (FV == Op0 && TV == Op1)) + return nullptr; + } + } + // Make sure that one of the select arms constant folds successfully. Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ true); Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ false); @@ -1316,6 +1352,47 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } +static Value *simplifyInstructionWithPHI(Instruction &I, PHINode *PN, + Value *InValue, BasicBlock *InBB, + const DataLayout &DL, + const SimplifyQuery SQ) { + // NB: It is a precondition of this transform that the operands be + // phi translatable! 
This is usually trivially satisfied by limiting it + // to constant ops, and for selects we do a more sophisticated check. + SmallVector<Value *> Ops; + for (Value *Op : I.operands()) { + if (Op == PN) + Ops.push_back(InValue); + else + Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB)); + } + + // Don't consider the simplification successful if we get back a constant + // expression. That's just an instruction in hiding. + // Also reject the case where we simplify back to the phi node. We wouldn't + // be able to remove it in that case. + Value *NewVal = simplifyInstructionWithOperands( + &I, Ops, SQ.getWithInstruction(InBB->getTerminator())); + if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) + return NewVal; + + // Check if incoming PHI value can be replaced with constant + // based on implied condition. + BranchInst *TerminatorBI = dyn_cast<BranchInst>(InBB->getTerminator()); + const ICmpInst *ICmp = dyn_cast<ICmpInst>(&I); + if (TerminatorBI && TerminatorBI->isConditional() && + TerminatorBI->getSuccessor(0) != TerminatorBI->getSuccessor(1) && ICmp) { + bool LHSIsTrue = TerminatorBI->getSuccessor(0) == PN->getParent(); + std::optional<bool> ImpliedCond = + isImpliedCondition(TerminatorBI->getCondition(), ICmp->getPredicate(), + Ops[0], Ops[1], DL, LHSIsTrue); + if (ImpliedCond) + return ConstantInt::getBool(I.getType(), ImpliedCond.value()); + } + + return nullptr; +} + Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { unsigned NumPHIValues = PN->getNumIncomingValues(); if (NumPHIValues == 0) @@ -1344,29 +1421,11 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { Value *InVal = PN->getIncomingValue(i); BasicBlock *InBB = PN->getIncomingBlock(i); - // NB: It is a precondition of this transform that the operands be - // phi translatable! This is usually trivially satisfied by limiting it - // to constant ops, and for selects we do a more sophisticated check. 
- SmallVector<Value *> Ops; - for (Value *Op : I.operands()) { - if (Op == PN) - Ops.push_back(InVal); - else - Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB)); - } - - // Don't consider the simplification successful if we get back a constant - // expression. That's just an instruction in hiding. - // Also reject the case where we simplify back to the phi node. We wouldn't - // be able to remove it in that case. - Value *NewVal = simplifyInstructionWithOperands( - &I, Ops, SQ.getWithInstruction(InBB->getTerminator())); - if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) { + if (auto *NewVal = simplifyInstructionWithPHI(I, PN, InVal, InBB, DL, SQ)) { NewPhiValues.push_back(NewVal); continue; } - if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. if (NonSimplifiedBB) return nullptr; // More than one non-simplified value. NonSimplifiedBB = InBB; @@ -1402,7 +1461,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { // Okay, we can do the transformation: create the new PHI node. 
PHINode *NewPN = PHINode::Create(I.getType(), PN->getNumIncomingValues()); - InsertNewInstBefore(NewPN, *PN); + InsertNewInstBefore(NewPN, PN->getIterator()); NewPN->takeName(PN); NewPN->setDebugLoc(PN->getDebugLoc()); @@ -1417,7 +1476,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { else U = U->DoPHITranslation(PN->getParent(), NonSimplifiedBB); } - InsertNewInstBefore(Clone, *NonSimplifiedBB->getTerminator()); + InsertNewInstBefore(Clone, NonSimplifiedBB->getTerminator()->getIterator()); } for (unsigned i = 0; i != NumPHIValues; ++i) { @@ -1848,8 +1907,8 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) { Constant *WideC; if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC))) return nullptr; - Constant *NarrowC = ConstantExpr::getTrunc(WideC, X->getType()); - if (ConstantExpr::getCast(CastOpc, NarrowC, BO.getType()) != WideC) + Constant *NarrowC = getLosslessTrunc(WideC, X->getType(), CastOpc); + if (!NarrowC) return nullptr; Y = NarrowC; } @@ -1940,7 +1999,7 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), 0); if (NumVarIndices != Src->getNumIndices()) { // FIXME: getIndexedOffsetInType() does not handled scalable vectors. - if (isa<ScalableVectorType>(BaseType)) + if (BaseType->isScalableTy()) return nullptr; SmallVector<Value *> ConstantIndices; @@ -2048,12 +2107,126 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, return nullptr; } +Value *InstCombiner::getFreelyInvertedImpl(Value *V, bool WillInvertAllUses, + BuilderTy *Builder, + bool &DoesConsume, unsigned Depth) { + static Value *const NonNull = reinterpret_cast<Value *>(uintptr_t(1)); + // ~(~(X)) -> X. + Value *A, *B; + if (match(V, m_Not(m_Value(A)))) { + DoesConsume = true; + return A; + } + + Constant *C; + // Constants can be considered to be not'ed values. 
+ if (match(V, m_ImmConstant(C))) + return ConstantExpr::getNot(C); + + if (Depth++ >= MaxAnalysisRecursionDepth) + return nullptr; + + // The rest of the cases require that we invert all uses so don't bother + // doing the analysis if we know we can't use the result. + if (!WillInvertAllUses) + return nullptr; + + // Compares can be inverted if all of their uses are being modified to use + // the ~V. + if (auto *I = dyn_cast<CmpInst>(V)) { + if (Builder != nullptr) + return Builder->CreateCmp(I->getInversePredicate(), I->getOperand(0), + I->getOperand(1)); + return NonNull; + } + + // If `V` is of the form `A + B` then `-1 - V` can be folded into + // `(-1 - B) - A` if we are willing to invert all of the uses. + if (match(V, m_Add(m_Value(A), m_Value(B)))) { + if (auto *BV = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateSub(BV, A) : NonNull; + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateSub(AV, B) : NonNull; + return nullptr; + } + + // If `V` is of the form `A ^ ~B` then `~(A ^ ~B)` can be folded + // into `A ^ B` if we are willing to invert all of the uses. + if (match(V, m_Xor(m_Value(A), m_Value(B)))) { + if (auto *BV = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateXor(A, BV) : NonNull; + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateXor(AV, B) : NonNull; + return nullptr; + } + + // If `V` is of the form `B - A` then `-1 - V` can be folded into + // `A + (-1 - B)` if we are willing to invert all of the uses. + if (match(V, m_Sub(m_Value(A), m_Value(B)))) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? 
Builder->CreateAdd(AV, B) : NonNull; + return nullptr; + } + + // If `V` is of the form `(~A) s>> B` then `~((~A) s>> B)` can be folded + // into `A s>> B` if we are willing to invert all of the uses. + if (match(V, m_AShr(m_Value(A), m_Value(B)))) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateAShr(AV, B) : NonNull; + return nullptr; + } + + // Treat lshr with non-negative operand as ashr. + if (match(V, m_LShr(m_Value(A), m_Value(B))) && + isKnownNonNegative(A, SQ.getWithInstruction(cast<Instruction>(V)), + Depth)) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateAShr(AV, B) : NonNull; + return nullptr; + } + + Value *Cond; + // LogicOps are special in that we canonicalize them at the cost of an + // instruction. + bool IsSelect = match(V, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) && + !shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(V)); + // Selects/min/max with invertible operands are freely invertible + if (IsSelect || match(V, m_MaxOrMin(m_Value(A), m_Value(B)))) { + if (!getFreelyInvertedImpl(B, B->hasOneUse(), /*Builder*/ nullptr, + DoesConsume, Depth)) + return nullptr; + if (Value *NotA = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) { + if (Builder != nullptr) { + Value *NotB = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, + DoesConsume, Depth); + assert(NotB != nullptr && + "Unable to build inverted value for known freely invertable op"); + if (auto *II = dyn_cast<IntrinsicInst>(V)) + return Builder->CreateBinaryIntrinsic( + getInverseMinMaxIntrinsic(II->getIntrinsicID()), NotA, NotB); + return Builder->CreateSelect(Cond, NotA, NotB); + } + return NonNull; + } + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Value *PtrOp = GEP.getOperand(0); SmallVector<Value *, 8> Indices(GEP.indices()); Type 
*GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); - bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); + bool IsGEPSrcEleScalable = GEPEltType->isScalableTy(); if (Value *V = simplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.isInBounds(), SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); @@ -2221,7 +2394,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { NewGEP->setOperand(DI, NewPN); } - NewGEP->insertInto(GEP.getParent(), GEP.getParent()->getFirstInsertionPt()); + NewGEP->insertBefore(*GEP.getParent(), GEP.getParent()->getFirstInsertionPt()); return replaceOperand(GEP, 0, NewGEP); } @@ -2264,11 +2437,43 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, GEPType); } } - // We do not handle pointer-vector geps here. if (GEPType->isVectorTy()) return nullptr; + if (GEP.getNumIndices() == 1) { + // Try to replace ADD + GEP with GEP + GEP. 
+ Value *Idx1, *Idx2; + if (match(GEP.getOperand(1), + m_OneUse(m_Add(m_Value(Idx1), m_Value(Idx2))))) { + // %idx = add i64 %idx1, %idx2 + // %gep = getelementptr i32, ptr %ptr, i64 %idx + // as: + // %newptr = getelementptr i32, ptr %ptr, i64 %idx1 + // %newgep = getelementptr i32, ptr %newptr, i64 %idx2 + auto *NewPtr = Builder.CreateGEP(GEP.getResultElementType(), + GEP.getPointerOperand(), Idx1); + return GetElementPtrInst::Create(GEP.getResultElementType(), NewPtr, + Idx2); + } + ConstantInt *C; + if (match(GEP.getOperand(1), m_OneUse(m_SExt(m_OneUse(m_NSWAdd( + m_Value(Idx1), m_ConstantInt(C))))))) { + // %add = add nsw i32 %idx1, idx2 + // %sidx = sext i32 %add to i64 + // %gep = getelementptr i32, ptr %ptr, i64 %sidx + // as: + // %newptr = getelementptr i32, ptr %ptr, i32 %idx1 + // %newgep = getelementptr i32, ptr %newptr, i32 idx2 + auto *NewPtr = Builder.CreateGEP( + GEP.getResultElementType(), GEP.getPointerOperand(), + Builder.CreateSExt(Idx1, GEP.getOperand(1)->getType())); + return GetElementPtrInst::Create( + GEP.getResultElementType(), NewPtr, + Builder.CreateSExt(C, GEP.getOperand(1)->getType())); + } + } + if (!GEP.isInBounds()) { unsigned IdxWidth = DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()); @@ -2362,6 +2567,26 @@ static bool isAllocSiteRemovable(Instruction *AI, unsigned OtherIndex = (ICI->getOperand(0) == PI) ? 1 : 0; if (!isNeverEqualToUnescapedAlloc(ICI->getOperand(OtherIndex), TLI, AI)) return false; + + // Do not fold compares to aligned_alloc calls, as they may have to + // return null in case the required alignment cannot be satisfied, + // unless we can prove that both alignment and size are valid. + auto AlignmentAndSizeKnownValid = [](CallBase *CB) { + // Check if alignment and size of a call to aligned_alloc is valid, + // that is alignment is a power-of-2 and the size is a multiple of the + // alignment. 
+ const APInt *Alignment; + const APInt *Size; + return match(CB->getArgOperand(0), m_APInt(Alignment)) && + match(CB->getArgOperand(1), m_APInt(Size)) && + Alignment->isPowerOf2() && Size->urem(*Alignment).isZero(); + }; + auto *CB = dyn_cast<CallBase>(AI); + LibFunc TheLibFunc; + if (CB && TLI.getLibFunc(*CB->getCalledFunction(), TheLibFunc) && + TLI.has(TheLibFunc) && TheLibFunc == LibFunc_aligned_alloc && + !AlignmentAndSizeKnownValid(CB)) + return false; Users.emplace_back(I); continue; } @@ -2451,9 +2676,10 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { // If we are removing an alloca with a dbg.declare, insert dbg.value calls // before each store. SmallVector<DbgVariableIntrinsic *, 8> DVIs; + SmallVector<DPValue *, 8> DPVs; std::unique_ptr<DIBuilder> DIB; if (isa<AllocaInst>(MI)) { - findDbgUsers(DVIs, &MI); + findDbgUsers(DVIs, &MI, &DPVs); DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false)); } @@ -2493,6 +2719,9 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { for (auto *DVI : DVIs) if (DVI->isAddressOfVariable()) ConvertDebugDeclareToDebugValue(DVI, SI, *DIB); + for (auto *DPV : DPVs) + if (DPV->isAddressOfVariable()) + ConvertDebugDeclareToDebugValue(DPV, SI, *DIB); } else { // Casts, GEP, or anything else: we're about to delete this instruction, // so it can not have any valid uses. @@ -2531,9 +2760,15 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { // If there is a dead store to `%a` in @trivially_inlinable_no_op, the // "arg0" dbg.value may be stale after the call. However, failing to remove // the DW_OP_deref dbg.value causes large gaps in location coverage. + // + // FIXME: the Assignment Tracking project has now likely made this + // redundant (and it's sometimes harmful). 
for (auto *DVI : DVIs) if (DVI->isAddressOfVariable() || DVI->getExpression()->startsWithDeref()) DVI->eraseFromParent(); + for (auto *DPV : DPVs) + if (DPV->isAddressOfVariable() || DPV->getExpression()->startsWithDeref()) + DPV->eraseFromParent(); return eraseInstFromFunction(MI); } @@ -2612,7 +2847,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI, for (Instruction &Instr : llvm::make_early_inc_range(*FreeInstrBB)) { if (&Instr == FreeInstrBBTerminator) break; - Instr.moveBefore(TI); + Instr.moveBeforePreserving(TI); } assert(FreeInstrBB->size() == 1 && "Only the branch instruction should remain"); @@ -2746,55 +2981,77 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) { return nullptr; } +void InstCombinerImpl::addDeadEdge(BasicBlock *From, BasicBlock *To, + SmallVectorImpl<BasicBlock *> &Worklist) { + if (!DeadEdges.insert({From, To}).second) + return; + + // Replace phi node operands in successor with poison. + for (PHINode &PN : To->phis()) + for (Use &U : PN.incoming_values()) + if (PN.getIncomingBlock(U) == From && !isa<PoisonValue>(U)) { + replaceUse(U, PoisonValue::get(PN.getType())); + addToWorklist(&PN); + MadeIRChange = true; + } + + Worklist.push_back(To); +} + // Under the assumption that I is unreachable, remove it and following -// instructions. -bool InstCombinerImpl::handleUnreachableFrom(Instruction *I) { - bool Changed = false; +// instructions. Changes are reported directly to MadeIRChange. 
+void InstCombinerImpl::handleUnreachableFrom( + Instruction *I, SmallVectorImpl<BasicBlock *> &Worklist) { BasicBlock *BB = I->getParent(); for (Instruction &Inst : make_early_inc_range( make_range(std::next(BB->getTerminator()->getReverseIterator()), std::next(I->getReverseIterator())))) { if (!Inst.use_empty() && !Inst.getType()->isTokenTy()) { replaceInstUsesWith(Inst, PoisonValue::get(Inst.getType())); - Changed = true; + MadeIRChange = true; } if (Inst.isEHPad() || Inst.getType()->isTokenTy()) continue; + // RemoveDIs: erase debug-info on this instruction manually. + Inst.dropDbgValues(); eraseInstFromFunction(Inst); - Changed = true; + MadeIRChange = true; } - // Replace phi node operands in successor blocks with poison. + // RemoveDIs: to match behaviour in dbg.value mode, drop debug-info on + // terminator too. + BB->getTerminator()->dropDbgValues(); + + // Handle potentially dead successors. for (BasicBlock *Succ : successors(BB)) - for (PHINode &PN : Succ->phis()) - for (Use &U : PN.incoming_values()) - if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) { - replaceUse(U, PoisonValue::get(PN.getType())); - addToWorklist(&PN); - Changed = true; - } + addDeadEdge(BB, Succ, Worklist); +} - // TODO: Successor blocks may also be dead. - return Changed; +void InstCombinerImpl::handlePotentiallyDeadBlocks( + SmallVectorImpl<BasicBlock *> &Worklist) { + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + if (!all_of(predecessors(BB), [&](BasicBlock *Pred) { + return DeadEdges.contains({Pred, BB}) || DT.dominates(BB, Pred); + })) + continue; + + handleUnreachableFrom(&BB->front(), Worklist); + } } -bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB, +void InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc) { - bool Changed = false; + SmallVector<BasicBlock *> Worklist; for (BasicBlock *Succ : successors(BB)) { // The live successor isn't dead. 
if (Succ == LiveSucc) continue; - if (!all_of(predecessors(Succ), [&](BasicBlock *Pred) { - return DT.dominates(BasicBlockEdge(BB, Succ), - BasicBlockEdge(Pred, Succ)); - })) - continue; - - Changed |= handleUnreachableFrom(&Succ->front()); + addDeadEdge(BB, Succ, Worklist); } - return Changed; + + handlePotentiallyDeadBlocks(Worklist); } Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { @@ -2840,14 +3097,17 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { return &BI; } - if (isa<UndefValue>(Cond) && - handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr)) - return &BI; - if (auto *CI = dyn_cast<ConstantInt>(Cond)) - if (handlePotentiallyDeadSuccessors(BI.getParent(), - BI.getSuccessor(!CI->getZExtValue()))) - return &BI; + if (isa<UndefValue>(Cond)) { + handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr); + return nullptr; + } + if (auto *CI = dyn_cast<ConstantInt>(Cond)) { + handlePotentiallyDeadSuccessors(BI.getParent(), + BI.getSuccessor(!CI->getZExtValue())); + return nullptr; + } + DC.registerBranch(&BI); return nullptr; } @@ -2866,14 +3126,6 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { return replaceOperand(SI, 0, Op0); } - if (isa<UndefValue>(Cond) && - handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr)) - return &SI; - if (auto *CI = dyn_cast<ConstantInt>(Cond)) - if (handlePotentiallyDeadSuccessors( - SI.getParent(), SI.findCaseValue(CI)->getCaseSuccessor())) - return &SI; - KnownBits Known = computeKnownBits(Cond, 0, &SI); unsigned LeadingKnownZeros = Known.countMinLeadingZeros(); unsigned LeadingKnownOnes = Known.countMinLeadingOnes(); @@ -2906,6 +3158,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { return replaceOperand(SI, 0, NewCond); } + if (isa<UndefValue>(Cond)) { + handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr); + return nullptr; + } + if (auto *CI = dyn_cast<ConstantInt>(Cond)) { + 
handlePotentiallyDeadSuccessors(SI.getParent(), + SI.findCaseValue(CI)->getCaseSuccessor()); + return nullptr; + } + return nullptr; } @@ -3532,7 +3794,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, Value *StartV = StartU->get(); BasicBlock *StartBB = PN->getIncomingBlock(*StartU); bool StartNeedsFreeze = !isGuaranteedNotToBeUndefOrPoison(StartV); - // We can't insert freeze if the the start value is the result of the + // We can't insert freeze if the start value is the result of the // terminator (e.g. an invoke). if (StartNeedsFreeze && StartBB->getTerminator() == StartV) return nullptr; @@ -3583,19 +3845,27 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) { // *all* uses if the operand is an invoke/callbr and the use is in a phi on // the normal/default destination. This is why the domination check in the // replacement below is still necessary. - Instruction *MoveBefore; + BasicBlock::iterator MoveBefore; if (isa<Argument>(Op)) { MoveBefore = - &*FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca(); + FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca(); } else { - MoveBefore = cast<Instruction>(Op)->getInsertionPointAfterDef(); - if (!MoveBefore) + auto MoveBeforeOpt = cast<Instruction>(Op)->getInsertionPointAfterDef(); + if (!MoveBeforeOpt) return false; + MoveBefore = *MoveBeforeOpt; } + // Don't move to the position of a debug intrinsic. + if (isa<DbgInfoIntrinsic>(MoveBefore)) + MoveBefore = MoveBefore->getNextNonDebugInstruction()->getIterator(); + // Re-point iterator to come after any debug-info records, if we're + // running in "RemoveDIs" mode + MoveBefore.setHeadBit(false); + bool Changed = false; - if (&FI != MoveBefore) { - FI.moveBefore(MoveBefore); + if (&FI != &*MoveBefore) { + FI.moveBefore(*MoveBefore->getParent(), MoveBefore); Changed = true; } @@ -3798,7 +4068,7 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, /// the new position. 
BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt(); - I->moveBefore(&*InsertPos); + I->moveBefore(*DestBlock, InsertPos); ++NumSunkInst; // Also sink all related debug uses from the source basic block. Otherwise we @@ -3808,10 +4078,19 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, // here, but that computation has been sunk. SmallVector<DbgVariableIntrinsic *, 2> DbgUsers; findDbgUsers(DbgUsers, I); - // Process the sinking DbgUsers in reverse order, as we only want to clone the - // last appearing debug intrinsic for each given variable. + + // For all debug values in the destination block, the sunk instruction + // will still be available, so they do not need to be dropped. + SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSalvage; + SmallVector<DPValue *, 2> DPValuesToSalvage; + for (auto &DbgUser : DbgUsers) + if (DbgUser->getParent() != DestBlock) + DbgUsersToSalvage.push_back(DbgUser); + + // Process the sinking DbgUsersToSalvage in reverse order, as we only want + // to clone the last appearing debug intrinsic for each given variable. SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSink; - for (DbgVariableIntrinsic *DVI : DbgUsers) + for (DbgVariableIntrinsic *DVI : DbgUsersToSalvage) if (DVI->getParent() == SrcBlock) DbgUsersToSink.push_back(DVI); llvm::sort(DbgUsersToSink, @@ -3847,7 +4126,10 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, // Perform salvaging without the clones, then sink the clones. if (!DIIClones.empty()) { - salvageDebugInfoForDbgValues(*I, DbgUsers); + // RemoveDIs: pass in empty vector of DPValues until we get to instrumenting + // this pass. + SmallVector<DPValue *, 1> DummyDPValues; + salvageDebugInfoForDbgValues(*I, DbgUsersToSalvage, DummyDPValues); // The clones are in reverse order of original appearance, reverse again to // maintain the original order. 
for (auto &DIIClone : llvm::reverse(DIIClones)) { @@ -4093,43 +4375,52 @@ public: } }; -/// Populate the IC worklist from a function, by walking it in depth-first -/// order and adding all reachable code to the worklist. +/// Populate the IC worklist from a function, by walking it in reverse +/// post-order and adding all reachable code to the worklist. /// /// This has a couple of tricks to make the code faster and more powerful. In /// particular, we constant fold and DCE instructions as we go, to avoid adding /// them to the worklist (this significantly speeds up instcombine on code where /// many instructions are dead or constant). Additionally, if we find a branch /// whose condition is a known constant, we only visit the reachable successors. -static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, - const TargetLibraryInfo *TLI, - InstructionWorklist &ICWorklist) { +bool InstCombinerImpl::prepareWorklist( + Function &F, ReversePostOrderTraversal<BasicBlock *> &RPOT) { bool MadeIRChange = false; - SmallPtrSet<BasicBlock *, 32> Visited; - SmallVector<BasicBlock*, 256> Worklist; - Worklist.push_back(&F.front()); - + SmallPtrSet<BasicBlock *, 32> LiveBlocks; SmallVector<Instruction *, 128> InstrsForInstructionWorklist; DenseMap<Constant *, Constant *> FoldedConstants; AliasScopeTracker SeenAliasScopes; - do { - BasicBlock *BB = Worklist.pop_back_val(); + auto HandleOnlyLiveSuccessor = [&](BasicBlock *BB, BasicBlock *LiveSucc) { + for (BasicBlock *Succ : successors(BB)) + if (Succ != LiveSucc && DeadEdges.insert({BB, Succ}).second) + for (PHINode &PN : Succ->phis()) + for (Use &U : PN.incoming_values()) + if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) { + U.set(PoisonValue::get(PN.getType())); + MadeIRChange = true; + } + }; - // We have now visited this block! If we've already been here, ignore it. 
- if (!Visited.insert(BB).second) + for (BasicBlock *BB : RPOT) { + if (!BB->isEntryBlock() && all_of(predecessors(BB), [&](BasicBlock *Pred) { + return DeadEdges.contains({Pred, BB}) || DT.dominates(BB, Pred); + })) { + HandleOnlyLiveSuccessor(BB, nullptr); continue; + } + LiveBlocks.insert(BB); for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { // ConstantProp instruction if trivially constant. if (!Inst.use_empty() && (Inst.getNumOperands() == 0 || isa<Constant>(Inst.getOperand(0)))) - if (Constant *C = ConstantFoldInstruction(&Inst, DL, TLI)) { + if (Constant *C = ConstantFoldInstruction(&Inst, DL, &TLI)) { LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << Inst << '\n'); Inst.replaceAllUsesWith(C); ++NumConstProp; - if (isInstructionTriviallyDead(&Inst, TLI)) + if (isInstructionTriviallyDead(&Inst, &TLI)) Inst.eraseFromParent(); MadeIRChange = true; continue; @@ -4143,7 +4434,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, auto *C = cast<Constant>(U); Constant *&FoldRes = FoldedConstants[C]; if (!FoldRes) - FoldRes = ConstantFoldConstant(C, DL, TLI); + FoldRes = ConstantFoldConstant(C, DL, &TLI); if (FoldRes != C) { LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << Inst @@ -4163,37 +4454,39 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, } } - // Recursively visit successors. If this is a branch or switch on a - // constant, only visit the reachable successor. + // If this is a branch or switch on a constant, mark only the single + // live successor. Otherwise assume all successors are live. Instruction *TI = BB->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(TI); BI && BI->isConditional()) { - if (isa<UndefValue>(BI->getCondition())) + if (isa<UndefValue>(BI->getCondition())) { // Branch on undef is UB. 
+ HandleOnlyLiveSuccessor(BB, nullptr); continue; + } if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition())) { bool CondVal = Cond->getZExtValue(); - BasicBlock *ReachableBB = BI->getSuccessor(!CondVal); - Worklist.push_back(ReachableBB); + HandleOnlyLiveSuccessor(BB, BI->getSuccessor(!CondVal)); continue; } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - if (isa<UndefValue>(SI->getCondition())) + if (isa<UndefValue>(SI->getCondition())) { // Switch on undef is UB. + HandleOnlyLiveSuccessor(BB, nullptr); continue; + } if (auto *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { - Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor()); + HandleOnlyLiveSuccessor(BB, + SI->findCaseValue(Cond)->getCaseSuccessor()); continue; } } - - append_range(Worklist, successors(TI)); - } while (!Worklist.empty()); + } // Remove instructions inside unreachable blocks. This prevents the // instcombine code from having to deal with some bad special cases, and // reduces use counts of instructions. for (BasicBlock &BB : F) { - if (Visited.count(&BB)) + if (LiveBlocks.count(&BB)) continue; unsigned NumDeadInstInBB; @@ -4210,11 +4503,11 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // of the function down. This jives well with the way that it adds all uses // of instructions to the worklist after doing a transformation, thus avoiding // some N^2 behavior in pathological cases. - ICWorklist.reserve(InstrsForInstructionWorklist.size()); + Worklist.reserve(InstrsForInstructionWorklist.size()); for (Instruction *Inst : reverse(InstrsForInstructionWorklist)) { // DCE instruction if trivially dead. As we iterate in reverse program // order here, we will clean up whole chains of dead instructions. 
- if (isInstructionTriviallyDead(Inst, TLI) || + if (isInstructionTriviallyDead(Inst, &TLI) || SeenAliasScopes.isNoAliasScopeDeclDead(Inst)) { ++NumDeadInst; LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); @@ -4224,7 +4517,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, continue; } - ICWorklist.push(Inst); + Worklist.push(Inst); } return MadeIRChange; @@ -4234,7 +4527,7 @@ static bool combineInstructionsOverFunction( Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, - ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) { + ProfileSummaryInfo *PSI, LoopInfo *LI, const InstCombineOptions &Opts) { auto &DL = F.getParent()->getDataLayout(); /// Builder - This is an IRBuilder that automatically inserts new @@ -4247,6 +4540,8 @@ static bool combineInstructionsOverFunction( AC.registerAssumption(Assume); })); + ReversePostOrderTraversal<BasicBlock *> RPOT(&F.front()); + // Lower dbg.declare intrinsics otherwise their value may be clobbered // by instcombiner. bool MadeIRChange = false; @@ -4256,35 +4551,33 @@ static bool combineInstructionsOverFunction( // Iterate while there is work to do. 
unsigned Iteration = 0; while (true) { - ++NumWorklistIterations; ++Iteration; - if (Iteration > InfiniteLoopDetectionThreshold) { - report_fatal_error( - "Instruction Combining seems stuck in an infinite loop after " + - Twine(InfiniteLoopDetectionThreshold) + " iterations."); - } - - if (Iteration > MaxIterations) { - LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << MaxIterations + if (Iteration > Opts.MaxIterations && !Opts.VerifyFixpoint) { + LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << Opts.MaxIterations << " on " << F.getName() - << " reached; stopping before reaching a fixpoint\n"); + << " reached; stopping without verifying fixpoint\n"); break; } + ++NumWorklistIterations; LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); - MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT, ORE, BFI, PSI, DL, LI); IC.MaxArraySizeForCombine = MaxArraySize; - - if (!IC.run()) + bool MadeChangeInThisIteration = IC.prepareWorklist(F, RPOT); + MadeChangeInThisIteration |= IC.run(); + if (!MadeChangeInThisIteration) break; MadeIRChange = true; + if (Iteration > Opts.MaxIterations) { + report_fatal_error( + "Instruction Combining did not reach a fixpoint after " + + Twine(Opts.MaxIterations) + " iterations"); + } } if (Iteration == 1) @@ -4307,7 +4600,8 @@ void InstCombinePass::printPipeline( OS, MapClassName2PassName); OS << '<'; OS << "max-iterations=" << Options.MaxIterations << ";"; - OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info"; + OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info;"; + OS << (Options.VerifyFixpoint ? 
"" : "no-") << "verify-fixpoint"; OS << '>'; } @@ -4333,7 +4627,7 @@ PreservedAnalyses InstCombinePass::run(Function &F, &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, - BFI, PSI, Options.MaxIterations, LI)) + BFI, PSI, LI, Options)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); @@ -4382,8 +4676,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { nullptr; return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, - BFI, PSI, - InstCombineDefaultMaxIterations, LI); + BFI, PSI, LI, InstCombineOptions()); } char InstructionCombiningPass::ID = 0; diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index bde5fba20f3b..b175e6f93f3e 100644 --- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -201,8 +201,8 @@ static cl::opt<bool> ClRecover( static cl::opt<bool> ClInsertVersionCheck( "asan-guard-against-version-mismatch", - cl::desc("Guard against compiler/runtime version mismatch."), - cl::Hidden, cl::init(true)); + cl::desc("Guard against compiler/runtime version mismatch."), cl::Hidden, + cl::init(true)); // This flag may need to be replaced with -f[no-]asan-reads. 
static cl::opt<bool> ClInstrumentReads("asan-instrument-reads", @@ -323,10 +323,9 @@ static cl::opt<unsigned> ClRealignStack( static cl::opt<int> ClInstrumentationWithCallsThreshold( "asan-instrumentation-with-call-threshold", - cl::desc( - "If the function being instrumented contains more than " - "this number of memory accesses, use callbacks instead of " - "inline checks (-1 means never use callbacks)."), + cl::desc("If the function being instrumented contains more than " + "this number of memory accesses, use callbacks instead of " + "inline checks (-1 means never use callbacks)."), cl::Hidden, cl::init(7000)); static cl::opt<std::string> ClMemoryAccessCallbackPrefix( @@ -491,7 +490,8 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, bool IsMIPS32 = TargetTriple.isMIPS32(); bool IsMIPS64 = TargetTriple.isMIPS64(); bool IsArmOrThumb = TargetTriple.isARM() || TargetTriple.isThumb(); - bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64; + bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64 || + TargetTriple.getArch() == Triple::aarch64_be; bool IsLoongArch64 = TargetTriple.isLoongArch64(); bool IsRISCV64 = TargetTriple.getArch() == Triple::riscv64; bool IsWindows = TargetTriple.isOSWindows(); @@ -644,8 +644,9 @@ namespace { /// AddressSanitizer: instrument the code in module to find memory bugs. struct AddressSanitizer { AddressSanitizer(Module &M, const StackSafetyGlobalInfo *SSGI, - bool CompileKernel = false, bool Recover = false, - bool UseAfterScope = false, + int InstrumentationWithCallsThreshold, + uint32_t MaxInlinePoisoningSize, bool CompileKernel = false, + bool Recover = false, bool UseAfterScope = false, AsanDetectStackUseAfterReturnMode UseAfterReturn = AsanDetectStackUseAfterReturnMode::Runtime) : CompileKernel(ClEnableKasan.getNumOccurrences() > 0 ? 
ClEnableKasan @@ -654,12 +655,19 @@ struct AddressSanitizer { UseAfterScope(UseAfterScope || ClUseAfterScope), UseAfterReturn(ClUseAfterReturn.getNumOccurrences() ? ClUseAfterReturn : UseAfterReturn), - SSGI(SSGI) { + SSGI(SSGI), + InstrumentationWithCallsThreshold( + ClInstrumentationWithCallsThreshold.getNumOccurrences() > 0 + ? ClInstrumentationWithCallsThreshold + : InstrumentationWithCallsThreshold), + MaxInlinePoisoningSize(ClMaxInlinePoisoningSize.getNumOccurrences() > 0 + ? ClMaxInlinePoisoningSize + : MaxInlinePoisoningSize) { C = &(M.getContext()); DL = &M.getDataLayout(); LongSize = M.getDataLayout().getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); - Int8PtrTy = Type::getInt8PtrTy(*C); + PtrTy = PointerType::getUnqual(*C); Int32Ty = Type::getInt32Ty(*C); TargetTriple = Triple(M.getTargetTriple()); @@ -751,8 +759,8 @@ private: bool UseAfterScope; AsanDetectStackUseAfterReturnMode UseAfterReturn; Type *IntptrTy; - Type *Int8PtrTy; Type *Int32Ty; + PointerType *PtrTy; ShadowMapping Mapping; FunctionCallee AsanHandleNoReturnFunc; FunctionCallee AsanPtrCmpFunction, AsanPtrSubFunction; @@ -773,17 +781,22 @@ private: FunctionCallee AMDGPUAddressShared; FunctionCallee AMDGPUAddressPrivate; + int InstrumentationWithCallsThreshold; + uint32_t MaxInlinePoisoningSize; }; class ModuleAddressSanitizer { public: - ModuleAddressSanitizer(Module &M, bool CompileKernel = false, - bool Recover = false, bool UseGlobalsGC = true, - bool UseOdrIndicator = true, + ModuleAddressSanitizer(Module &M, bool InsertVersionCheck, + bool CompileKernel = false, bool Recover = false, + bool UseGlobalsGC = true, bool UseOdrIndicator = true, AsanDtorKind DestructorKind = AsanDtorKind::Global, AsanCtorKind ConstructorKind = AsanCtorKind::Global) : CompileKernel(ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan : CompileKernel), + InsertVersionCheck(ClInsertVersionCheck.getNumOccurrences() > 0 + ? 
ClInsertVersionCheck + : InsertVersionCheck), Recover(ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC && !this->CompileKernel), // Enable aliases as they should have no downside with ODR indicators. @@ -802,10 +815,13 @@ public: // do globals-gc. UseCtorComdat(UseGlobalsGC && ClWithComdat && !this->CompileKernel), DestructorKind(DestructorKind), - ConstructorKind(ConstructorKind) { + ConstructorKind(ClConstructorKind.getNumOccurrences() > 0 + ? ClConstructorKind + : ConstructorKind) { C = &(M.getContext()); int LongSize = M.getDataLayout().getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); + PtrTy = PointerType::getUnqual(*C); TargetTriple = Triple(M.getTargetTriple()); Mapping = getShadowMapping(TargetTriple, LongSize, this->CompileKernel); @@ -819,11 +835,11 @@ public: private: void initializeCallbacks(Module &M); - bool InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat); + void instrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat); void InstrumentGlobalsCOFF(IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers); - void InstrumentGlobalsELF(IRBuilder<> &IRB, Module &M, + void instrumentGlobalsELF(IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers, const std::string &UniqueModuleId); @@ -854,6 +870,7 @@ private: int GetAsanVersion(const Module &M) const; bool CompileKernel; + bool InsertVersionCheck; bool Recover; bool UseGlobalsGC; bool UsePrivateAlias; @@ -862,6 +879,7 @@ private: AsanDtorKind DestructorKind; AsanCtorKind ConstructorKind; Type *IntptrTy; + PointerType *PtrTy; LLVMContext *C; Triple TargetTriple; ShadowMapping Mapping; @@ -1148,22 +1166,22 @@ AddressSanitizerPass::AddressSanitizerPass( AsanCtorKind ConstructorKind) : Options(Options), UseGlobalGC(UseGlobalGC), UseOdrIndicator(UseOdrIndicator), DestructorKind(DestructorKind), - 
ConstructorKind(ClConstructorKind) {} + ConstructorKind(ConstructorKind) {} PreservedAnalyses AddressSanitizerPass::run(Module &M, ModuleAnalysisManager &MAM) { - ModuleAddressSanitizer ModuleSanitizer(M, Options.CompileKernel, - Options.Recover, UseGlobalGC, - UseOdrIndicator, DestructorKind, - ConstructorKind); + ModuleAddressSanitizer ModuleSanitizer( + M, Options.InsertVersionCheck, Options.CompileKernel, Options.Recover, + UseGlobalGC, UseOdrIndicator, DestructorKind, ConstructorKind); bool Modified = false; auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); const StackSafetyGlobalInfo *const SSGI = ClUseStackSafety ? &MAM.getResult<StackSafetyGlobalAnalysis>(M) : nullptr; for (Function &F : M) { - AddressSanitizer FunctionSanitizer(M, SSGI, Options.CompileKernel, - Options.Recover, Options.UseAfterScope, - Options.UseAfterReturn); + AddressSanitizer FunctionSanitizer( + M, SSGI, Options.InstrumentationWithCallsThreshold, + Options.MaxInlinePoisoningSize, Options.CompileKernel, Options.Recover, + Options.UseAfterScope, Options.UseAfterReturn); const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F); Modified |= FunctionSanitizer.instrumentFunction(F, &TLI); } @@ -1188,17 +1206,17 @@ static size_t TypeStoreSizeToSizeIndex(uint32_t TypeSize) { /// Check if \p G has been created by a trusted compiler pass. static bool GlobalWasGeneratedByCompiler(GlobalVariable *G) { // Do not instrument @llvm.global_ctors, @llvm.used, etc. - if (G->getName().startswith("llvm.") || + if (G->getName().starts_with("llvm.") || // Do not instrument gcov counter arrays. - G->getName().startswith("__llvm_gcov_ctr") || + G->getName().starts_with("__llvm_gcov_ctr") || // Do not instrument rtti proxy symbols for function sanitizer. - G->getName().startswith("__llvm_rtti_proxy")) + G->getName().starts_with("__llvm_rtti_proxy")) return true; // Do not instrument asan globals. 
- if (G->getName().startswith(kAsanGenPrefix) || - G->getName().startswith(kSanCovGenPrefix) || - G->getName().startswith(kODRGenPrefix)) + if (G->getName().starts_with(kAsanGenPrefix) || + G->getName().starts_with(kSanCovGenPrefix) || + G->getName().starts_with(kODRGenPrefix)) return true; return false; @@ -1232,15 +1250,13 @@ Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) { void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { InstrumentationIRBuilder IRB(MI); if (isa<MemTransferInst>(MI)) { - IRB.CreateCall( - isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + IRB.CreateCall(isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy, + {MI->getOperand(0), MI->getOperand(1), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); } else if (isa<MemSetInst>(MI)) { IRB.CreateCall( AsanMemset, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + {MI->getOperand(0), IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); } @@ -1570,7 +1586,7 @@ void AddressSanitizer::instrumentMaskedLoadOrStore( InstrumentedAddress = IRB.CreateExtractElement(Addr, Index); } else if (Stride) { Index = IRB.CreateMul(Index, Stride); - Addr = IRB.CreateBitCast(Addr, Type::getInt8PtrTy(*C)); + Addr = IRB.CreateBitCast(Addr, PointerType::getUnqual(*C)); InstrumentedAddress = IRB.CreateGEP(Type::getInt8Ty(*C), Addr, {Index}); } else { InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index}); @@ -1695,9 +1711,8 @@ Instruction *AddressSanitizer::instrumentAMDGPUAddress( return InsertBefore; // Instrument generic addresses in supported addressspaces. 
IRBuilder<> IRB(InsertBefore); - Value *AddrLong = IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()); - Value *IsShared = IRB.CreateCall(AMDGPUAddressShared, {AddrLong}); - Value *IsPrivate = IRB.CreateCall(AMDGPUAddressPrivate, {AddrLong}); + Value *IsShared = IRB.CreateCall(AMDGPUAddressShared, {Addr}); + Value *IsPrivate = IRB.CreateCall(AMDGPUAddressPrivate, {Addr}); Value *IsSharedOrPrivate = IRB.CreateOr(IsShared, IsPrivate); Value *Cmp = IRB.CreateNot(IsSharedOrPrivate); Value *AddrSpaceZeroLanding = @@ -1728,7 +1743,7 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, Module *M = IRB.GetInsertBlock()->getParent()->getParent(); IRB.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::asan_check_memaccess), - {IRB.CreatePointerCast(Addr, Int8PtrTy), + {IRB.CreatePointerCast(Addr, PtrTy), ConstantInt::get(Int32Ty, AccessInfo.Packed)}); return; } @@ -1869,7 +1884,7 @@ ModuleAddressSanitizer::getExcludedAliasedGlobal(const GlobalAlias &GA) const { // When compiling the kernel, globals that are aliased by symbols prefixed // by "__" are special and cannot be padded with a redzone. - if (GA.getName().startswith("__")) + if (GA.getName().starts_with("__")) return dyn_cast<GlobalVariable>(C->stripPointerCastsAndAliases()); return nullptr; @@ -1939,9 +1954,9 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const { // Do not instrument function pointers to initialization and termination // routines: dynamic linker will not properly handle redzones. - if (Section.startswith(".preinit_array") || - Section.startswith(".init_array") || - Section.startswith(".fini_array")) { + if (Section.starts_with(".preinit_array") || + Section.starts_with(".init_array") || + Section.starts_with(".fini_array")) { return false; } @@ -1978,7 +1993,7 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const { // those conform to /usr/lib/objc/runtime.h, so we can't add redzones to // them. 
if (ParsedSegment == "__OBJC" || - (ParsedSegment == "__DATA" && ParsedSection.startswith("__objc_"))) { + (ParsedSegment == "__DATA" && ParsedSection.starts_with("__objc_"))) { LLVM_DEBUG(dbgs() << "Ignoring ObjC runtime global: " << *G << "\n"); return false; } @@ -2006,7 +2021,7 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const { if (CompileKernel) { // Globals that prefixed by "__" are special and cannot be padded with a // redzone. - if (G->getName().startswith("__")) + if (G->getName().starts_with("__")) return false; } @@ -2129,6 +2144,9 @@ ModuleAddressSanitizer::CreateMetadataGlobal(Module &M, Constant *Initializer, M, Initializer->getType(), false, Linkage, Initializer, Twine("__asan_global_") + GlobalValue::dropLLVMManglingEscape(OriginalName)); Metadata->setSection(getGlobalMetadataSection()); + // Place metadata in a large section for x86-64 ELF binaries to mitigate + // relocation pressure. + setGlobalVariableLargeSection(TargetTriple, *Metadata); return Metadata; } @@ -2177,7 +2195,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsCOFF( appendToCompilerUsed(M, MetadataGlobals); } -void ModuleAddressSanitizer::InstrumentGlobalsELF( +void ModuleAddressSanitizer::instrumentGlobalsELF( IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers, const std::string &UniqueModuleId) { @@ -2187,7 +2205,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsELF( // false negative odr violations at link time. If odr indicators are used, we // keep the comdat sections, as link time odr violations will be dectected on // the odr indicator symbols. 
- bool UseComdatForGlobalsGC = UseOdrIndicator; + bool UseComdatForGlobalsGC = UseOdrIndicator && !UniqueModuleId.empty(); SmallVector<GlobalValue *, 16> MetadataGlobals(ExtendedGlobals.size()); for (size_t i = 0; i < ExtendedGlobals.size(); i++) { @@ -2237,7 +2255,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsELF( // We also need to unregister globals at the end, e.g., when a shared library // gets closed. - if (DestructorKind != AsanDtorKind::None) { + if (DestructorKind != AsanDtorKind::None && !MetadataGlobals.empty()) { IRBuilder<> IrbDtor(CreateAsanModuleDtor(M)); IrbDtor.CreateCall(AsanUnregisterElfGlobals, {IRB.CreatePointerCast(RegisteredFlag, IntptrTy), @@ -2343,10 +2361,8 @@ void ModuleAddressSanitizer::InstrumentGlobalsWithMetadataArray( // redzones and inserts this function into llvm.global_ctors. // Sets *CtorComdat to true if the global registration code emitted into the // asan constructor is comdat-compatible. -bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, +void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat) { - *CtorComdat = false; - // Build set of globals that are aliased by some GA, where // getExcludedAliasedGlobal(GA) returns the relevant GlobalVariable. SmallPtrSet<const GlobalVariable *, 16> AliasedGlobalExclusions; @@ -2364,11 +2380,6 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, } size_t n = GlobalsToChange.size(); - if (n == 0) { - *CtorComdat = true; - return false; - } - auto &DL = M.getDataLayout(); // A global is described by a structure @@ -2391,8 +2402,11 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, // We shouldn't merge same module names, as this string serves as unique // module ID in runtime. - GlobalVariable *ModuleName = createPrivateGlobalForString( - M, M.getModuleIdentifier(), /*AllowMerging*/ false, kAsanGenPrefix); + GlobalVariable *ModuleName = + n != 0 + ? 
createPrivateGlobalForString(M, M.getModuleIdentifier(), + /*AllowMerging*/ false, kAsanGenPrefix) + : nullptr; for (size_t i = 0; i < n; i++) { GlobalVariable *G = GlobalsToChange[i]; @@ -2455,7 +2469,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, G->eraseFromParent(); NewGlobals[i] = NewGlobal; - Constant *ODRIndicator = ConstantExpr::getNullValue(IRB.getInt8PtrTy()); + Constant *ODRIndicator = ConstantPointerNull::get(PtrTy); GlobalValue *InstrumentedGlobal = NewGlobal; bool CanUsePrivateAliases = @@ -2470,8 +2484,8 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, // ODR should not happen for local linkage. if (NewGlobal->hasLocalLinkage()) { - ODRIndicator = ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, -1), - IRB.getInt8PtrTy()); + ODRIndicator = + ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, -1), PtrTy); } else if (UseOdrIndicator) { // With local aliases, we need to provide another externally visible // symbol __odr_asan_XXX to detect ODR violation. @@ -2517,19 +2531,27 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, } appendToCompilerUsed(M, ArrayRef<GlobalValue *>(GlobalsToAddToUsedList)); - std::string ELFUniqueModuleId = - (UseGlobalsGC && TargetTriple.isOSBinFormatELF()) ? getUniqueModuleId(&M) - : ""; - - if (!ELFUniqueModuleId.empty()) { - InstrumentGlobalsELF(IRB, M, NewGlobals, Initializers, ELFUniqueModuleId); + if (UseGlobalsGC && TargetTriple.isOSBinFormatELF()) { + // Use COMDAT and register globals even if n == 0 to ensure that (a) the + // linkage unit will only have one module constructor, and (b) the register + // function will be called. The module destructor is not created when n == + // 0. 
*CtorComdat = true; - } else if (UseGlobalsGC && TargetTriple.isOSBinFormatCOFF()) { - InstrumentGlobalsCOFF(IRB, M, NewGlobals, Initializers); - } else if (UseGlobalsGC && ShouldUseMachOGlobalsSection()) { - InstrumentGlobalsMachO(IRB, M, NewGlobals, Initializers); + instrumentGlobalsELF(IRB, M, NewGlobals, Initializers, + getUniqueModuleId(&M)); + } else if (n == 0) { + // When UseGlobalsGC is false, COMDAT can still be used if n == 0, because + // all compile units will have identical module constructor/destructor. + *CtorComdat = TargetTriple.isOSBinFormatELF(); } else { - InstrumentGlobalsWithMetadataArray(IRB, M, NewGlobals, Initializers); + *CtorComdat = false; + if (UseGlobalsGC && TargetTriple.isOSBinFormatCOFF()) { + InstrumentGlobalsCOFF(IRB, M, NewGlobals, Initializers); + } else if (UseGlobalsGC && ShouldUseMachOGlobalsSection()) { + InstrumentGlobalsMachO(IRB, M, NewGlobals, Initializers); + } else { + InstrumentGlobalsWithMetadataArray(IRB, M, NewGlobals, Initializers); + } } // Create calls for poisoning before initializers run and unpoisoning after. @@ -2537,7 +2559,6 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, createInitializerPoisonCalls(M, ModuleName); LLVM_DEBUG(dbgs() << M); - return true; } uint64_t @@ -2588,7 +2609,7 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) { } else { std::string AsanVersion = std::to_string(GetAsanVersion(M)); std::string VersionCheckName = - ClInsertVersionCheck ? (kAsanVersionCheckNamePrefix + AsanVersion) : ""; + InsertVersionCheck ? 
(kAsanVersionCheckNamePrefix + AsanVersion) : ""; std::tie(AsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions(M, kAsanModuleCtorName, kAsanInitName, /*InitArgTypes=*/{}, @@ -2601,10 +2622,10 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) { assert(AsanCtorFunction || ConstructorKind == AsanCtorKind::None); if (AsanCtorFunction) { IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator()); - InstrumentGlobals(IRB, M, &CtorComdat); + instrumentGlobals(IRB, M, &CtorComdat); } else { IRBuilder<> IRB(*C); - InstrumentGlobals(IRB, M, &CtorComdat); + instrumentGlobals(IRB, M, &CtorComdat); } } @@ -2684,15 +2705,12 @@ void AddressSanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo *T ? std::string("") : ClMemoryAccessCallbackPrefix; AsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy); - AsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy); + PtrTy, PtrTy, PtrTy, IntptrTy); + AsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", PtrTy, + PtrTy, PtrTy, IntptrTy); AsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset", TLI->getAttrList(C, {1}, /*Signed=*/false), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt32Ty(), IntptrTy); + PtrTy, PtrTy, IRB.getInt32Ty(), IntptrTy); AsanHandleNoReturnFunc = M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy()); @@ -2705,10 +2723,10 @@ void AddressSanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo *T AsanShadowGlobal = M.getOrInsertGlobal("__asan_shadow", ArrayType::get(IRB.getInt8Ty(), 0)); - AMDGPUAddressShared = M.getOrInsertFunction( - kAMDGPUAddressSharedName, IRB.getInt1Ty(), IRB.getInt8PtrTy()); - AMDGPUAddressPrivate = M.getOrInsertFunction( - kAMDGPUAddressPrivateName, IRB.getInt1Ty(), IRB.getInt8PtrTy()); + AMDGPUAddressShared = + 
M.getOrInsertFunction(kAMDGPUAddressSharedName, IRB.getInt1Ty(), PtrTy); + AMDGPUAddressPrivate = + M.getOrInsertFunction(kAMDGPUAddressPrivateName, IRB.getInt1Ty(), PtrTy); } bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) { @@ -2799,7 +2817,7 @@ bool AddressSanitizer::instrumentFunction(Function &F, return false; if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false; - if (F.getName().startswith("__asan_")) return false; + if (F.getName().starts_with("__asan_")) return false; bool FunctionModified = false; @@ -2890,9 +2908,9 @@ bool AddressSanitizer::instrumentFunction(Function &F, } } - bool UseCalls = (ClInstrumentationWithCallsThreshold >= 0 && + bool UseCalls = (InstrumentationWithCallsThreshold >= 0 && OperandsToInstrument.size() + IntrinToInstrument.size() > - (unsigned)ClInstrumentationWithCallsThreshold); + (unsigned)InstrumentationWithCallsThreshold); const DataLayout &DL = F.getParent()->getDataLayout(); ObjectSizeOpts ObjSizeOpts; ObjSizeOpts.RoundToAlign = true; @@ -3034,7 +3052,7 @@ void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask, Value *Ptr = IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)); Value *Poison = IRB.getIntN(StoreSizeInBytes * 8, Val); IRB.CreateAlignedStore( - Poison, IRB.CreateIntToPtr(Ptr, Poison->getType()->getPointerTo()), + Poison, IRB.CreateIntToPtr(Ptr, PointerType::getUnqual(Poison->getContext())), Align(1)); i += StoreSizeInBytes; @@ -3066,7 +3084,7 @@ void FunctionStackPoisoner::copyToShadow(ArrayRef<uint8_t> ShadowMask, for (; j < End && ShadowMask[j] && Val == ShadowBytes[j]; ++j) { } - if (j - i >= ClMaxInlinePoisoningSize) { + if (j - i >= ASan.MaxInlinePoisoningSize) { copyToShadowInline(ShadowMask, ShadowBytes, Done, i, IRB, ShadowBase); IRB.CreateCall(AsanSetShadowFunc[Val], {IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)), @@ -3500,7 +3518,7 @@ void 
FunctionStackPoisoner::processStaticAllocas() { IntptrTy, IRBPoison.CreateIntToPtr(SavedFlagPtrPtr, IntptrPtrTy)); IRBPoison.CreateStore( Constant::getNullValue(IRBPoison.getInt8Ty()), - IRBPoison.CreateIntToPtr(SavedFlagPtr, IRBPoison.getInt8PtrTy())); + IRBPoison.CreateIntToPtr(SavedFlagPtr, IRBPoison.getPtrTy())); } else { // For larger frames call __asan_stack_free_*. IRBPoison.CreateCall( diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index 709095184af5..ee5b81960417 100644 --- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -37,6 +37,9 @@ using namespace llvm; static cl::opt<bool> SingleTrapBB("bounds-checking-single-trap", cl::desc("Use one trap block per function")); +static cl::opt<bool> DebugTrapBB("bounds-checking-unique-traps", + cl::desc("Always use one trap per check")); + STATISTIC(ChecksAdded, "Bounds checks added"); STATISTIC(ChecksSkipped, "Bounds checks skipped"); STATISTIC(ChecksUnable, "Bounds checks unable to add"); @@ -180,19 +183,27 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, // will create a fresh block every time it is called. BasicBlock *TrapBB = nullptr; auto GetTrapBB = [&TrapBB](BuilderTy &IRB) { - if (TrapBB && SingleTrapBB) - return TrapBB; - Function *Fn = IRB.GetInsertBlock()->getParent(); - // FIXME: This debug location doesn't make a lot of sense in the - // `SingleTrapBB` case. auto DebugLoc = IRB.getCurrentDebugLocation(); IRBuilder<>::InsertPointGuard Guard(IRB); + + if (TrapBB && SingleTrapBB && !DebugTrapBB) + return TrapBB; + TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn); IRB.SetInsertPoint(TrapBB); - auto *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap); - CallInst *TrapCall = IRB.CreateCall(F, {}); + Intrinsic::ID IntrID = DebugTrapBB ? 
Intrinsic::ubsantrap : Intrinsic::trap; + auto *F = Intrinsic::getDeclaration(Fn->getParent(), IntrID); + + CallInst *TrapCall; + if (DebugTrapBB) { + TrapCall = + IRB.CreateCall(F, ConstantInt::get(IRB.getInt8Ty(), Fn->size())); + } else { + TrapCall = IRB.CreateCall(F, {}); + } + TrapCall->setDoesNotReturn(); TrapCall->setDoesNotThrow(); TrapCall->setDebugLoc(DebugLoc); diff --git a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp index d53e12ad1ff5..e2e5f21b376b 100644 --- a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp @@ -66,7 +66,7 @@ static bool runCGProfilePass( if (F.isDeclaration() || !F.getEntryCount()) continue; auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); - if (BFI.getEntryFreq() == 0) + if (BFI.getEntryFreq() == BlockFrequency(0)) continue; TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); for (auto &BB : F) { diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index 3e3be536defc..0a3d8d6000cf 100644 --- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -1593,8 +1593,8 @@ static void insertTrivialPHIs(CHRScope *Scope, // Insert a trivial phi for I (phi [&I, P0], [&I, P1], ...) at // ExitBlock. Replace I with the new phi in UI unless UI is another // phi at ExitBlock. 
- PHINode *PN = PHINode::Create(I.getType(), pred_size(ExitBlock), "", - &ExitBlock->front()); + PHINode *PN = PHINode::Create(I.getType(), pred_size(ExitBlock), ""); + PN->insertBefore(ExitBlock->begin()); for (BasicBlock *Pred : predecessors(ExitBlock)) { PN->addIncoming(&I, Pred); } @@ -1777,6 +1777,13 @@ void CHR::cloneScopeBlocks(CHRScope *Scope, BasicBlock *NewBB = CloneBasicBlock(BB, VMap, ".nonchr", &F); NewBlocks.push_back(NewBB); VMap[BB] = NewBB; + + // Unreachable predecessors will not be cloned and will not have an edge + // to the cloned block. As such, also remove them from any phi nodes. + for (PHINode &PN : make_early_inc_range(NewBB->phis())) + PN.removeIncomingValueIf([&](unsigned Idx) { + return !DT.isReachableFromEntry(PN.getIncomingBlock(Idx)); + }); } // Place the cloned blocks right after the original blocks (right before the @@ -1871,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope, static_cast<uint32_t>(CHRBranchBias.scale(1000)), static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)), }; - MDBuilder MDB(F.getContext()); - MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + setBranchWeights(*MergedBR, Weights); CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1] << "\n"); } diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 8caee5bed8ed..2ba127bba6f6 100644 --- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -564,7 +564,7 @@ class DataFlowSanitizer { /// getShadowTy([n x T]) = [n x getShadowTy(T)] /// getShadowTy(other type) = i16 Type *getShadowTy(Type *OrigTy); - /// Returns the shadow type of of V's type. + /// Returns the shadow type of V's type. 
Type *getShadowTy(Value *V); const uint64_t NumOfElementsInArgOrgTLS = ArgTLSSize / OriginWidthBytes; @@ -1145,7 +1145,7 @@ bool DataFlowSanitizer::initializeModule(Module &M) { Mod = &M; Ctx = &M.getContext(); - Int8Ptr = Type::getInt8PtrTy(*Ctx); + Int8Ptr = PointerType::getUnqual(*Ctx); OriginTy = IntegerType::get(*Ctx, OriginWidthBits); OriginPtrTy = PointerType::getUnqual(OriginTy); PrimitiveShadowTy = IntegerType::get(*Ctx, ShadowWidthBits); @@ -1162,19 +1162,19 @@ bool DataFlowSanitizer::initializeModule(Module &M) { FunctionType::get(IntegerType::get(*Ctx, 64), DFSanLoadLabelAndOriginArgs, /*isVarArg=*/false); DFSanUnimplementedFnTy = FunctionType::get( - Type::getVoidTy(*Ctx), Type::getInt8PtrTy(*Ctx), /*isVarArg=*/false); + Type::getVoidTy(*Ctx), PointerType::getUnqual(*Ctx), /*isVarArg=*/false); Type *DFSanWrapperExternWeakNullArgs[2] = {Int8Ptr, Int8Ptr}; DFSanWrapperExternWeakNullFnTy = FunctionType::get(Type::getVoidTy(*Ctx), DFSanWrapperExternWeakNullArgs, /*isVarArg=*/false); Type *DFSanSetLabelArgs[4] = {PrimitiveShadowTy, OriginTy, - Type::getInt8PtrTy(*Ctx), IntptrTy}; + PointerType::getUnqual(*Ctx), IntptrTy}; DFSanSetLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx), DFSanSetLabelArgs, /*isVarArg=*/false); DFSanNonzeroLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx), std::nullopt, /*isVarArg=*/false); DFSanVarargWrapperFnTy = FunctionType::get( - Type::getVoidTy(*Ctx), Type::getInt8PtrTy(*Ctx), /*isVarArg=*/false); + Type::getVoidTy(*Ctx), PointerType::getUnqual(*Ctx), /*isVarArg=*/false); DFSanConditionalCallbackFnTy = FunctionType::get(Type::getVoidTy(*Ctx), PrimitiveShadowTy, /*isVarArg=*/false); @@ -1288,7 +1288,7 @@ void DataFlowSanitizer::buildExternWeakCheckIfNeeded(IRBuilder<> &IRB, // for a extern weak function, add a check here to help identify the issue. 
if (GlobalValue::isExternalWeakLinkage(F->getLinkage())) { std::vector<Value *> Args; - Args.push_back(IRB.CreatePointerCast(F, IRB.getInt8PtrTy())); + Args.push_back(F); Args.push_back(IRB.CreateGlobalStringPtr(F->getName())); IRB.CreateCall(DFSanWrapperExternWeakNullFn, Args); } @@ -1553,7 +1553,7 @@ bool DataFlowSanitizer::runImpl( assert(isa<Function>(C) && "Personality routine is not a function!"); Function *F = cast<Function>(C); if (!isInstrumented(F)) - llvm::erase_value(FnsToInstrument, F); + llvm::erase(FnsToInstrument, F); } } @@ -1575,7 +1575,7 @@ bool DataFlowSanitizer::runImpl( // below will take care of instrumenting it. Function *NewF = buildWrapperFunction(F, "", GA.getLinkage(), F->getFunctionType()); - GA.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, GA.getType())); + GA.replaceAllUsesWith(NewF); NewF->takeName(&GA); GA.eraseFromParent(); FnsToInstrument.push_back(NewF); @@ -1622,9 +1622,6 @@ bool DataFlowSanitizer::runImpl( WrapperLinkage, FT); NewF->removeFnAttrs(ReadOnlyNoneAttrs); - Value *WrappedFnCst = - ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT)); - // Extern weak functions can sometimes be null at execution time. // Code will sometimes check if an extern weak function is null. 
// This could look something like: @@ -1657,9 +1654,9 @@ bool DataFlowSanitizer::runImpl( } return true; }; - F.replaceUsesWithIf(WrappedFnCst, IsNotCmpUse); + F.replaceUsesWithIf(NewF, IsNotCmpUse); - UnwrappedFnMap[WrappedFnCst] = &F; + UnwrappedFnMap[NewF] = &F; *FI = NewF; if (!F.isDeclaration()) { @@ -2273,8 +2270,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( IRBuilder<> IRB(Pos); CallInst *Call = IRB.CreateCall(DFS.DFSanLoadLabelAndOriginFn, - {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), - ConstantInt::get(DFS.IntptrTy, Size)}); + {Addr, ConstantInt::get(DFS.IntptrTy, Size)}); Call->addRetAttr(Attribute::ZExt); return {IRB.CreateTrunc(IRB.CreateLShr(Call, DFS.OriginWidthBits), DFS.PrimitiveShadowTy), @@ -2436,9 +2432,9 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { if (ClEventCallbacks) { IRBuilder<> IRB(Pos); - Value *Addr8 = IRB.CreateBitCast(LI.getPointerOperand(), DFSF.DFS.Int8Ptr); + Value *Addr = LI.getPointerOperand(); CallInst *CI = - IRB.CreateCall(DFSF.DFS.DFSanLoadCallbackFn, {PrimitiveShadow, Addr8}); + IRB.CreateCall(DFSF.DFS.DFSanLoadCallbackFn, {PrimitiveShadow, Addr}); CI->addParamAttr(0, Attribute::ZExt); } @@ -2530,10 +2526,9 @@ void DFSanFunction::storeOrigin(Instruction *Pos, Value *Addr, uint64_t Size, } if (shouldInstrumentWithCall()) { - IRB.CreateCall(DFS.DFSanMaybeStoreOriginFn, - {CollapsedShadow, - IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), - ConstantInt::get(DFS.IntptrTy, Size), Origin}); + IRB.CreateCall( + DFS.DFSanMaybeStoreOriginFn, + {CollapsedShadow, Addr, ConstantInt::get(DFS.IntptrTy, Size), Origin}); } else { Value *Cmp = convertToBool(CollapsedShadow, IRB, "_dfscmp"); DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); @@ -2554,9 +2549,7 @@ void DFSanFunction::storeZeroPrimitiveShadow(Value *Addr, uint64_t Size, IntegerType::get(*DFS.Ctx, Size * DFS.ShadowWidthBits); Value *ExtZeroShadow = ConstantInt::get(ShadowTy, 0); Value *ShadowAddr = 
DFS.getShadowAddress(Addr, Pos); - Value *ExtShadowAddr = - IRB.CreateBitCast(ShadowAddr, PointerType::getUnqual(ShadowTy)); - IRB.CreateAlignedStore(ExtZeroShadow, ExtShadowAddr, ShadowAlign); + IRB.CreateAlignedStore(ExtZeroShadow, ShadowAddr, ShadowAlign); // Do not write origins for 0 shadows because we do not trace origins for // untainted sinks. } @@ -2611,11 +2604,9 @@ void DFSanFunction::storePrimitiveShadowOrigin(Value *Addr, uint64_t Size, ShadowVec, PrimitiveShadow, ConstantInt::get(Type::getInt32Ty(*DFS.Ctx), I)); } - Value *ShadowVecAddr = - IRB.CreateBitCast(ShadowAddr, PointerType::getUnqual(ShadowVecTy)); do { Value *CurShadowVecAddr = - IRB.CreateConstGEP1_32(ShadowVecTy, ShadowVecAddr, Offset); + IRB.CreateConstGEP1_32(ShadowVecTy, ShadowAddr, Offset); IRB.CreateAlignedStore(ShadowVec, CurShadowVecAddr, ShadowAlign); LeftSize -= ShadowVecSize; ++Offset; @@ -2699,9 +2690,9 @@ void DFSanVisitor::visitStoreInst(StoreInst &SI) { PrimitiveShadow, Origin, &SI); if (ClEventCallbacks) { IRBuilder<> IRB(&SI); - Value *Addr8 = IRB.CreateBitCast(SI.getPointerOperand(), DFSF.DFS.Int8Ptr); + Value *Addr = SI.getPointerOperand(); CallInst *CI = - IRB.CreateCall(DFSF.DFS.DFSanStoreCallbackFn, {PrimitiveShadow, Addr8}); + IRB.CreateCall(DFSF.DFS.DFSanStoreCallbackFn, {PrimitiveShadow, Addr}); CI->addParamAttr(0, Attribute::ZExt); } } @@ -2918,11 +2909,9 @@ void DFSanVisitor::visitMemSetInst(MemSetInst &I) { Value *ValOrigin = DFSF.DFS.shouldTrackOrigins() ? 
DFSF.getOrigin(I.getValue()) : DFSF.DFS.ZeroOrigin; - IRB.CreateCall( - DFSF.DFS.DFSanSetLabelFn, - {ValShadow, ValOrigin, - IRB.CreateBitCast(I.getDest(), Type::getInt8PtrTy(*DFSF.DFS.Ctx)), - IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)}); + IRB.CreateCall(DFSF.DFS.DFSanSetLabelFn, + {ValShadow, ValOrigin, I.getDest(), + IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)}); } void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) { @@ -2933,28 +2922,24 @@ void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) { if (DFSF.DFS.shouldTrackOrigins()) { IRB.CreateCall( DFSF.DFS.DFSanMemOriginTransferFn, - {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(I.getArgOperand(1), IRB.getInt8PtrTy()), + {I.getArgOperand(0), I.getArgOperand(1), IRB.CreateIntCast(I.getArgOperand(2), DFSF.DFS.IntptrTy, false)}); } - Value *RawDestShadow = DFSF.DFS.getShadowAddress(I.getDest(), &I); + Value *DestShadow = DFSF.DFS.getShadowAddress(I.getDest(), &I); Value *SrcShadow = DFSF.DFS.getShadowAddress(I.getSource(), &I); Value *LenShadow = IRB.CreateMul(I.getLength(), ConstantInt::get(I.getLength()->getType(), DFSF.DFS.ShadowWidthBytes)); - Type *Int8Ptr = Type::getInt8PtrTy(*DFSF.DFS.Ctx); - Value *DestShadow = IRB.CreateBitCast(RawDestShadow, Int8Ptr); - SrcShadow = IRB.CreateBitCast(SrcShadow, Int8Ptr); auto *MTI = cast<MemTransferInst>( IRB.CreateCall(I.getFunctionType(), I.getCalledOperand(), {DestShadow, SrcShadow, LenShadow, I.getVolatileCst()})); MTI->setDestAlignment(DFSF.getShadowAlign(I.getDestAlign().valueOrOne())); MTI->setSourceAlignment(DFSF.getShadowAlign(I.getSourceAlign().valueOrOne())); if (ClEventCallbacks) { - IRB.CreateCall(DFSF.DFS.DFSanMemTransferCallbackFn, - {RawDestShadow, - IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)}); + IRB.CreateCall( + DFSF.DFS.DFSanMemTransferCallbackFn, + {DestShadow, IRB.CreateZExtOrTrunc(I.getLength(), DFSF.DFS.IntptrTy)}); } } @@ -3225,10 +3210,9 @@ void 
DFSanVisitor::visitLibAtomicLoad(CallBase &CB) { // TODO: Support ClCombinePointerLabelsOnLoad // TODO: Support ClEventCallbacks - NextIRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, - {NextIRB.CreatePointerCast(DstPtr, NextIRB.getInt8PtrTy()), - NextIRB.CreatePointerCast(SrcPtr, NextIRB.getInt8PtrTy()), - NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); + NextIRB.CreateCall( + DFSF.DFS.DFSanMemShadowOriginTransferFn, + {DstPtr, SrcPtr, NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); } Value *DFSanVisitor::makeAddReleaseOrderingTable(IRBuilder<> &IRB) { @@ -3264,10 +3248,9 @@ void DFSanVisitor::visitLibAtomicStore(CallBase &CB) { // TODO: Support ClCombinePointerLabelsOnStore // TODO: Support ClEventCallbacks - IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, - {IRB.CreatePointerCast(DstPtr, IRB.getInt8PtrTy()), - IRB.CreatePointerCast(SrcPtr, IRB.getInt8PtrTy()), - IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); + IRB.CreateCall( + DFSF.DFS.DFSanMemShadowOriginTransferFn, + {DstPtr, SrcPtr, IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); } void DFSanVisitor::visitLibAtomicExchange(CallBase &CB) { @@ -3285,16 +3268,14 @@ void DFSanVisitor::visitLibAtomicExchange(CallBase &CB) { // the additional complexity to address this is not warrented. 
// Current Target to Dest - IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, - {IRB.CreatePointerCast(DstPtr, IRB.getInt8PtrTy()), - IRB.CreatePointerCast(TargetPtr, IRB.getInt8PtrTy()), - IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); + IRB.CreateCall( + DFSF.DFS.DFSanMemShadowOriginTransferFn, + {DstPtr, TargetPtr, IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); // Current Src to Target (overriding) - IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, - {IRB.CreatePointerCast(TargetPtr, IRB.getInt8PtrTy()), - IRB.CreatePointerCast(SrcPtr, IRB.getInt8PtrTy()), - IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); + IRB.CreateCall( + DFSF.DFS.DFSanMemShadowOriginTransferFn, + {TargetPtr, SrcPtr, IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); } void DFSanVisitor::visitLibAtomicCompareExchange(CallBase &CB) { @@ -3317,13 +3298,10 @@ void DFSanVisitor::visitLibAtomicCompareExchange(CallBase &CB) { // If original call returned true, copy Desired to Target. // If original call returned false, copy Target to Expected. 
- NextIRB.CreateCall( - DFSF.DFS.DFSanMemShadowOriginConditionalExchangeFn, - {NextIRB.CreateIntCast(&CB, NextIRB.getInt8Ty(), false), - NextIRB.CreatePointerCast(TargetPtr, NextIRB.getInt8PtrTy()), - NextIRB.CreatePointerCast(ExpectedPtr, NextIRB.getInt8PtrTy()), - NextIRB.CreatePointerCast(DesiredPtr, NextIRB.getInt8PtrTy()), - NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); + NextIRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginConditionalExchangeFn, + {NextIRB.CreateIntCast(&CB, NextIRB.getInt8Ty(), false), + TargetPtr, ExpectedPtr, DesiredPtr, + NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); } void DFSanVisitor::visitCallBase(CallBase &CB) { diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 21f0b1a92293..1ff0a34bae24 100644 --- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -148,7 +148,7 @@ private: std::string mangleName(const DICompileUnit *CU, GCovFileType FileType); GCOVOptions Options; - support::endianness Endian; + llvm::endianness Endian; raw_ostream *os; // Checksum, produced by hash of EdgeDestinations @@ -750,7 +750,7 @@ static BasicBlock *getInstrBB(CFGMST<Edge, BBInfo> &MST, Edge &E, #ifndef NDEBUG static void dumpEdges(CFGMST<Edge, BBInfo> &MST, GCOVFunction &GF) { size_t ID = 0; - for (auto &E : make_pointee_range(MST.AllEdges)) { + for (const auto &E : make_pointee_range(MST.allEdges())) { GCOVBlock &Src = E.SrcBB ? GF.getBlock(E.SrcBB) : GF.getEntryBlock(); GCOVBlock &Dst = E.DestBB ? GF.getBlock(E.DestBB) : GF.getReturnBlock(); dbgs() << " Edge " << ID++ << ": " << Src.Number << "->" << Dst.Number @@ -788,8 +788,8 @@ bool GCOVProfiler::emitProfileNotes( std::vector<uint8_t> EdgeDestinations; SmallVector<std::pair<GlobalVariable *, MDNode *>, 8> CountersBySP; - Endian = M->getDataLayout().isLittleEndian() ? 
support::endianness::little - : support::endianness::big; + Endian = M->getDataLayout().isLittleEndian() ? llvm::endianness::little + : llvm::endianness::big; unsigned FunctionIdent = 0; for (auto &F : M->functions()) { DISubprogram *SP = F.getSubprogram(); @@ -820,8 +820,8 @@ bool GCOVProfiler::emitProfileNotes( CFGMST<Edge, BBInfo> MST(F, /*InstrumentFuncEntry_=*/false, BPI, BFI); // getInstrBB can split basic blocks and push elements to AllEdges. - for (size_t I : llvm::seq<size_t>(0, MST.AllEdges.size())) { - auto &E = *MST.AllEdges[I]; + for (size_t I : llvm::seq<size_t>(0, MST.numEdges())) { + auto &E = *MST.allEdges()[I]; // For now, disable spanning tree optimization when fork or exec* is // used. if (HasExecOrFork) @@ -836,16 +836,16 @@ bool GCOVProfiler::emitProfileNotes( // Some non-tree edges are IndirectBr which cannot be split. Ignore them // as well. - llvm::erase_if(MST.AllEdges, [](std::unique_ptr<Edge> &E) { + llvm::erase_if(MST.allEdges(), [](std::unique_ptr<Edge> &E) { return E->Removed || (!E->InMST && !E->Place); }); const size_t Measured = std::stable_partition( - MST.AllEdges.begin(), MST.AllEdges.end(), + MST.allEdges().begin(), MST.allEdges().end(), [](std::unique_ptr<Edge> &E) { return E->Place; }) - - MST.AllEdges.begin(); + MST.allEdges().begin(); for (size_t I : llvm::seq<size_t>(0, Measured)) { - Edge &E = *MST.AllEdges[I]; + Edge &E = *MST.allEdges()[I]; GCOVBlock &Src = E.SrcBB ? Func.getBlock(E.SrcBB) : Func.getEntryBlock(); GCOVBlock &Dst = @@ -854,13 +854,13 @@ bool GCOVProfiler::emitProfileNotes( E.DstNumber = Dst.Number; } std::stable_sort( - MST.AllEdges.begin(), MST.AllEdges.begin() + Measured, + MST.allEdges().begin(), MST.allEdges().begin() + Measured, [](const std::unique_ptr<Edge> &L, const std::unique_ptr<Edge> &R) { return L->SrcNumber != R->SrcNumber ? 
L->SrcNumber < R->SrcNumber : L->DstNumber < R->DstNumber; }); - for (const Edge &E : make_pointee_range(MST.AllEdges)) { + for (const Edge &E : make_pointee_range(MST.allEdges())) { GCOVBlock &Src = E.SrcBB ? Func.getBlock(E.SrcBB) : Func.getEntryBlock(); GCOVBlock &Dst = @@ -898,7 +898,9 @@ bool GCOVProfiler::emitProfileNotes( if (Line == Loc.getLine()) continue; Line = Loc.getLine(); - if (SP != getDISubprogram(Loc.getScope())) + MDNode *Scope = Loc.getScope(); + // TODO: Handle blocks from another file due to #line, #include, etc. + if (isa<DILexicalBlockFile>(Scope) || SP != getDISubprogram(Scope)) continue; GCOVLines &Lines = Block.getFile(Filename); @@ -915,7 +917,7 @@ bool GCOVProfiler::emitProfileNotes( CountersBySP.emplace_back(Counters, SP); for (size_t I : llvm::seq<size_t>(0, Measured)) { - const Edge &E = *MST.AllEdges[I]; + const Edge &E = *MST.allEdges()[I]; IRBuilder<> Builder(E.Place, E.Place->getFirstInsertionPt()); Value *V = Builder.CreateConstInBoundsGEP2_64( Counters->getValueType(), Counters, 0, I); @@ -955,7 +957,7 @@ bool GCOVProfiler::emitProfileNotes( continue; } os = &out; - if (Endian == support::endianness::big) { + if (Endian == llvm::endianness::big) { out.write("gcno", 4); out.write(Options.Version, 4); } else { @@ -1029,9 +1031,9 @@ void GCOVProfiler::emitGlobalConstructor( FunctionCallee GCOVProfiler::getStartFileFunc(const TargetLibraryInfo *TLI) { Type *Args[] = { - Type::getInt8PtrTy(*Ctx), // const char *orig_filename - Type::getInt32Ty(*Ctx), // uint32_t version - Type::getInt32Ty(*Ctx), // uint32_t checksum + PointerType::getUnqual(*Ctx), // const char *orig_filename + Type::getInt32Ty(*Ctx), // uint32_t version + Type::getInt32Ty(*Ctx), // uint32_t checksum }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); return M->getOrInsertFunction("llvm_gcda_start_file", FTy, @@ -1051,8 +1053,8 @@ FunctionCallee GCOVProfiler::getEmitFunctionFunc(const TargetLibraryInfo *TLI) { FunctionCallee 
GCOVProfiler::getEmitArcsFunc(const TargetLibraryInfo *TLI) { Type *Args[] = { - Type::getInt32Ty(*Ctx), // uint32_t num_counters - Type::getInt64PtrTy(*Ctx), // uint64_t *counters + Type::getInt32Ty(*Ctx), // uint32_t num_counters + PointerType::getUnqual(*Ctx), // uint64_t *counters }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); return M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy, @@ -1098,19 +1100,16 @@ Function *GCOVProfiler::insertCounterWriteout( // Collect the relevant data into a large constant data structure that we can // walk to write out everything. StructType *StartFileCallArgsTy = StructType::create( - {Builder.getInt8PtrTy(), Builder.getInt32Ty(), Builder.getInt32Ty()}, + {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getInt32Ty()}, "start_file_args_ty"); StructType *EmitFunctionCallArgsTy = StructType::create( {Builder.getInt32Ty(), Builder.getInt32Ty(), Builder.getInt32Ty()}, "emit_function_args_ty"); - StructType *EmitArcsCallArgsTy = StructType::create( - {Builder.getInt32Ty(), Builder.getInt64Ty()->getPointerTo()}, - "emit_arcs_args_ty"); - StructType *FileInfoTy = - StructType::create({StartFileCallArgsTy, Builder.getInt32Ty(), - EmitFunctionCallArgsTy->getPointerTo(), - EmitArcsCallArgsTy->getPointerTo()}, - "file_info"); + auto *PtrTy = Builder.getPtrTy(); + StructType *EmitArcsCallArgsTy = + StructType::create({Builder.getInt32Ty(), PtrTy}, "emit_arcs_args_ty"); + StructType *FileInfoTy = StructType::create( + {StartFileCallArgsTy, Builder.getInt32Ty(), PtrTy, PtrTy}, "file_info"); Constant *Zero32 = Builder.getInt32(0); // Build an explicit array of two zeros for use in ConstantExpr GEP building. 
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 28db47a19092..f7f8fed643e9 100644 --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -17,9 +17,11 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/StackSafetyAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/BinaryFormat/Dwarf.h" #include "llvm/BinaryFormat/ELF.h" @@ -42,7 +44,6 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/IR/NoFolder.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" @@ -52,6 +53,7 @@ #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/MemoryTaggingSupport.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" @@ -134,7 +136,7 @@ static cl::opt<size_t> ClMaxLifetimes( static cl::opt<bool> ClUseAfterScope("hwasan-use-after-scope", cl::desc("detect use after scope within function"), - cl::Hidden, cl::init(false)); + cl::Hidden, cl::init(true)); static cl::opt<bool> ClGenerateTagsWithCalls( "hwasan-generate-tags-with-calls", @@ -223,6 +225,10 @@ static cl::opt<bool> ClInlineAllChecks("hwasan-inline-all-checks", cl::desc("inline all checks"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClInlineFastPathChecks("hwasan-inline-fast-path-checks", + cl::desc("inline all checks"), + cl::Hidden, cl::init(false)); + // Enabled from clang by 
"-fsanitize-hwaddress-experimental-aliasing". static cl::opt<bool> ClUsePageAliases("hwasan-experimental-use-page-aliases", cl::desc("Use page aliasing in HWASan"), @@ -274,9 +280,18 @@ public: initializeModule(); } + void sanitizeFunction(Function &F, FunctionAnalysisManager &FAM); + +private: + struct ShadowTagCheckInfo { + Instruction *TagMismatchTerm = nullptr; + Value *PtrLong = nullptr; + Value *AddrLong = nullptr; + Value *PtrTag = nullptr; + Value *MemTag = nullptr; + }; void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; } - void sanitizeFunction(Function &F, FunctionAnalysisManager &FAM); void initializeModule(); void createHwasanCtorComdat(); @@ -291,18 +306,24 @@ public: Value *memToShadow(Value *Shadow, IRBuilder<> &IRB); int64_t getAccessInfo(bool IsWrite, unsigned AccessSizeIndex); + ShadowTagCheckInfo insertShadowTagCheck(Value *Ptr, Instruction *InsertBefore, + DomTreeUpdater &DTU, LoopInfo *LI); void instrumentMemAccessOutline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, - Instruction *InsertBefore); + Instruction *InsertBefore, + DomTreeUpdater &DTU, LoopInfo *LI); void instrumentMemAccessInline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, - Instruction *InsertBefore); + Instruction *InsertBefore, DomTreeUpdater &DTU, + LoopInfo *LI); bool ignoreMemIntrinsic(MemIntrinsic *MI); void instrumentMemIntrinsic(MemIntrinsic *MI); - bool instrumentMemAccess(InterestingMemoryOperand &O); + bool instrumentMemAccess(InterestingMemoryOperand &O, DomTreeUpdater &DTU, + LoopInfo *LI); bool ignoreAccess(Instruction *Inst, Value *Ptr); void getInterestingMemoryOperands( - Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting); + Instruction *I, const TargetLibraryInfo &TLI, + SmallVectorImpl<InterestingMemoryOperand> &Interesting); void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); @@ -332,7 +353,6 @@ public: void 
instrumentPersonalityFunctions(); -private: LLVMContext *C; Module &M; const StackSafetyGlobalInfo *SSI; @@ -364,7 +384,7 @@ private: Type *VoidTy = Type::getVoidTy(M.getContext()); Type *IntptrTy; - Type *Int8PtrTy; + PointerType *PtrTy; Type *Int8Ty; Type *Int32Ty; Type *Int64Ty = Type::getInt64Ty(M.getContext()); @@ -372,6 +392,7 @@ private: bool CompileKernel; bool Recover; bool OutlinedChecks; + bool InlineFastPath; bool UseShortGranules; bool InstrumentLandingPads; bool InstrumentWithCalls; @@ -420,6 +441,12 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M, HWASan.sanitizeFunction(F, FAM); PreservedAnalyses PA = PreservedAnalyses::none(); + // DominatorTreeAnalysis, PostDominatorTreeAnalysis, and LoopAnalysis + // are incrementally updated throughout this pass whenever + // SplitBlockAndInsertIfThen is called. + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<PostDominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); // GlobalsAA is considered stateless and does not get invalidated unless // explicitly invalidated; PreservedAnalyses::none() is not enough. Sanitizers // make changes that require GlobalsAA to be invalidated. @@ -560,7 +587,7 @@ void HWAddressSanitizer::initializeModule() { C = &(M.getContext()); IRBuilder<> IRB(*C); IntptrTy = IRB.getIntPtrTy(DL); - Int8PtrTy = IRB.getInt8PtrTy(); + PtrTy = IRB.getPtrTy(); Int8Ty = IRB.getInt8Ty(); Int32Ty = IRB.getInt32Ty(); @@ -579,6 +606,13 @@ void HWAddressSanitizer::initializeModule() { TargetTriple.isOSBinFormatELF() && (ClInlineAllChecks.getNumOccurrences() ? !ClInlineAllChecks : !Recover); + InlineFastPath = + (ClInlineFastPathChecks.getNumOccurrences() + ? ClInlineFastPathChecks + : !(TargetTriple.isAndroid() || + TargetTriple.isOSFuchsia())); // These platforms may prefer less + // inlining to reduce binary size. 
+ if (ClMatchAllTag.getNumOccurrences()) { if (ClMatchAllTag != -1) { MatchAllTag = ClMatchAllTag & 0xFF; @@ -633,19 +667,19 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false); HwasanMemoryAccessCallbackFnTy = FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false); - HwasanMemTransferFnTy = FunctionType::get( - Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy, Int8Ty}, false); - HwasanMemsetFnTy = FunctionType::get( - Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}, false); + HwasanMemTransferFnTy = + FunctionType::get(PtrTy, {PtrTy, PtrTy, IntptrTy, Int8Ty}, false); + HwasanMemsetFnTy = + FunctionType::get(PtrTy, {PtrTy, Int32Ty, IntptrTy, Int8Ty}, false); } else { HwasanMemoryAccessCallbackSizedFnTy = FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false); HwasanMemoryAccessCallbackFnTy = FunctionType::get(VoidTy, {IntptrTy}, false); HwasanMemTransferFnTy = - FunctionType::get(Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy}, false); + FunctionType::get(PtrTy, {PtrTy, PtrTy, IntptrTy}, false); HwasanMemsetFnTy = - FunctionType::get(Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy}, false); + FunctionType::get(PtrTy, {PtrTy, Int32Ty, IntptrTy}, false); } for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) { @@ -679,7 +713,7 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { MemIntrinCallbackPrefix + "memset" + MatchAllStr, HwasanMemsetFnTy); HwasanTagMemoryFunc = M.getOrInsertFunction("__hwasan_tag_memory", VoidTy, - Int8PtrTy, Int8Ty, IntptrTy); + PtrTy, Int8Ty, IntptrTy); HwasanGenerateTagFunc = M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty); @@ -699,7 +733,7 @@ Value *HWAddressSanitizer::getOpaqueNoopCast(IRBuilder<> &IRB, Value *Val) { // This prevents code bloat as a result of rematerializing trivial definitions // such as constants or global addresses at every load and store. 
InlineAsm *Asm = - InlineAsm::get(FunctionType::get(Int8PtrTy, {Val->getType()}, false), + InlineAsm::get(FunctionType::get(PtrTy, {Val->getType()}, false), StringRef(""), StringRef("=r,0"), /*hasSideEffects=*/false); return IRB.CreateCall(Asm, {Val}, ".hwasan.shadow"); @@ -713,15 +747,15 @@ Value *HWAddressSanitizer::getShadowNonTls(IRBuilder<> &IRB) { if (Mapping.Offset != kDynamicShadowSentinel) return getOpaqueNoopCast( IRB, ConstantExpr::getIntToPtr( - ConstantInt::get(IntptrTy, Mapping.Offset), Int8PtrTy)); + ConstantInt::get(IntptrTy, Mapping.Offset), PtrTy)); if (Mapping.InGlobal) return getDynamicShadowIfunc(IRB); Value *GlobalDynamicAddress = IRB.GetInsertBlock()->getParent()->getParent()->getOrInsertGlobal( - kHwasanShadowMemoryDynamicAddress, Int8PtrTy); - return IRB.CreateLoad(Int8PtrTy, GlobalDynamicAddress); + kHwasanShadowMemoryDynamicAddress, PtrTy); + return IRB.CreateLoad(PtrTy, GlobalDynamicAddress); } bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { @@ -748,7 +782,8 @@ bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { } void HWAddressSanitizer::getInterestingMemoryOperands( - Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting) { + Instruction *I, const TargetLibraryInfo &TLI, + SmallVectorImpl<InterestingMemoryOperand> &Interesting) { // Skip memory accesses inserted by another instrumentation. 
if (I->hasMetadata(LLVMContext::MD_nosanitize)) return; @@ -786,6 +821,7 @@ void HWAddressSanitizer::getInterestingMemoryOperands( Type *Ty = CI->getParamByValType(ArgNo); Interesting.emplace_back(I, ArgNo, false, Ty, Align(1)); } + maybeMarkSanitizerLibraryCallNoBuiltin(CI, &TLI); } } @@ -824,7 +860,7 @@ Value *HWAddressSanitizer::memToShadow(Value *Mem, IRBuilder<> &IRB) { // Mem >> Scale Value *Shadow = IRB.CreateLShr(Mem, Mapping.Scale); if (Mapping.Offset == 0) - return IRB.CreateIntToPtr(Shadow, Int8PtrTy); + return IRB.CreateIntToPtr(Shadow, PtrTy); // (Mem >> Scale) + Offset return IRB.CreateGEP(Int8Ty, ShadowBase, Shadow); } @@ -839,14 +875,48 @@ int64_t HWAddressSanitizer::getAccessInfo(bool IsWrite, (AccessSizeIndex << HWASanAccessInfo::AccessSizeShift); } +HWAddressSanitizer::ShadowTagCheckInfo +HWAddressSanitizer::insertShadowTagCheck(Value *Ptr, Instruction *InsertBefore, + DomTreeUpdater &DTU, LoopInfo *LI) { + ShadowTagCheckInfo R; + + IRBuilder<> IRB(InsertBefore); + + R.PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy); + R.PtrTag = + IRB.CreateTrunc(IRB.CreateLShr(R.PtrLong, PointerTagShift), Int8Ty); + R.AddrLong = untagPointer(IRB, R.PtrLong); + Value *Shadow = memToShadow(R.AddrLong, IRB); + R.MemTag = IRB.CreateLoad(Int8Ty, Shadow); + Value *TagMismatch = IRB.CreateICmpNE(R.PtrTag, R.MemTag); + + if (MatchAllTag.has_value()) { + Value *TagNotIgnored = IRB.CreateICmpNE( + R.PtrTag, ConstantInt::get(R.PtrTag->getType(), *MatchAllTag)); + TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored); + } + + R.TagMismatchTerm = SplitBlockAndInsertIfThen( + TagMismatch, InsertBefore, false, + MDBuilder(*C).createBranchWeights(1, 100000), &DTU, LI); + + return R; +} + void HWAddressSanitizer::instrumentMemAccessOutline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, - Instruction *InsertBefore) { + Instruction *InsertBefore, + DomTreeUpdater &DTU, + LoopInfo *LI) { assert(!UsePageAliases); const int64_t AccessInfo = getAccessInfo(IsWrite, 
AccessSizeIndex); + + if (InlineFastPath) + InsertBefore = + insertShadowTagCheck(Ptr, InsertBefore, DTU, LI).TagMismatchTerm; + IRBuilder<> IRB(InsertBefore); Module *M = IRB.GetInsertBlock()->getParent()->getParent(); - Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy); IRB.CreateCall(Intrinsic::getDeclaration( M, UseShortGranules ? Intrinsic::hwasan_check_memaccess_shortgranules @@ -856,55 +926,38 @@ void HWAddressSanitizer::instrumentMemAccessOutline(Value *Ptr, bool IsWrite, void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, - Instruction *InsertBefore) { + Instruction *InsertBefore, + DomTreeUpdater &DTU, + LoopInfo *LI) { assert(!UsePageAliases); const int64_t AccessInfo = getAccessInfo(IsWrite, AccessSizeIndex); - IRBuilder<> IRB(InsertBefore); - - Value *PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy); - Value *PtrTag = - IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift), Int8Ty); - Value *AddrLong = untagPointer(IRB, PtrLong); - Value *Shadow = memToShadow(AddrLong, IRB); - Value *MemTag = IRB.CreateLoad(Int8Ty, Shadow); - Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag); - - if (MatchAllTag.has_value()) { - Value *TagNotIgnored = IRB.CreateICmpNE( - PtrTag, ConstantInt::get(PtrTag->getType(), *MatchAllTag)); - TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored); - } - Instruction *CheckTerm = - SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, false, - MDBuilder(*C).createBranchWeights(1, 100000)); + ShadowTagCheckInfo TCI = insertShadowTagCheck(Ptr, InsertBefore, DTU, LI); - IRB.SetInsertPoint(CheckTerm); + IRBuilder<> IRB(TCI.TagMismatchTerm); Value *OutOfShortGranuleTagRange = - IRB.CreateICmpUGT(MemTag, ConstantInt::get(Int8Ty, 15)); - Instruction *CheckFailTerm = - SplitBlockAndInsertIfThen(OutOfShortGranuleTagRange, CheckTerm, !Recover, - MDBuilder(*C).createBranchWeights(1, 100000)); + IRB.CreateICmpUGT(TCI.MemTag, ConstantInt::get(Int8Ty, 15)); + Instruction *CheckFailTerm = 
SplitBlockAndInsertIfThen( + OutOfShortGranuleTagRange, TCI.TagMismatchTerm, !Recover, + MDBuilder(*C).createBranchWeights(1, 100000), &DTU, LI); - IRB.SetInsertPoint(CheckTerm); - Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(PtrLong, 15), Int8Ty); + IRB.SetInsertPoint(TCI.TagMismatchTerm); + Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(TCI.PtrLong, 15), Int8Ty); PtrLowBits = IRB.CreateAdd( PtrLowBits, ConstantInt::get(Int8Ty, (1 << AccessSizeIndex) - 1)); - Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, MemTag); - SplitBlockAndInsertIfThen(PtrLowBitsOOB, CheckTerm, false, - MDBuilder(*C).createBranchWeights(1, 100000), - (DomTreeUpdater *)nullptr, nullptr, - CheckFailTerm->getParent()); + Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, TCI.MemTag); + SplitBlockAndInsertIfThen(PtrLowBitsOOB, TCI.TagMismatchTerm, false, + MDBuilder(*C).createBranchWeights(1, 100000), &DTU, + LI, CheckFailTerm->getParent()); - IRB.SetInsertPoint(CheckTerm); - Value *InlineTagAddr = IRB.CreateOr(AddrLong, 15); - InlineTagAddr = IRB.CreateIntToPtr(InlineTagAddr, Int8PtrTy); + IRB.SetInsertPoint(TCI.TagMismatchTerm); + Value *InlineTagAddr = IRB.CreateOr(TCI.AddrLong, 15); + InlineTagAddr = IRB.CreateIntToPtr(InlineTagAddr, PtrTy); Value *InlineTag = IRB.CreateLoad(Int8Ty, InlineTagAddr); - Value *InlineTagMismatch = IRB.CreateICmpNE(PtrTag, InlineTag); - SplitBlockAndInsertIfThen(InlineTagMismatch, CheckTerm, false, - MDBuilder(*C).createBranchWeights(1, 100000), - (DomTreeUpdater *)nullptr, nullptr, - CheckFailTerm->getParent()); + Value *InlineTagMismatch = IRB.CreateICmpNE(TCI.PtrTag, InlineTag); + SplitBlockAndInsertIfThen(InlineTagMismatch, TCI.TagMismatchTerm, false, + MDBuilder(*C).createBranchWeights(1, 100000), &DTU, + LI, CheckFailTerm->getParent()); IRB.SetInsertPoint(CheckFailTerm); InlineAsm *Asm; @@ -912,7 +965,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, case Triple::x86_64: // The signal handler will find the 
data address in rdi. Asm = InlineAsm::get( - FunctionType::get(VoidTy, {PtrLong->getType()}, false), + FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false), "int3\nnopl " + itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)) + "(%rax)", @@ -923,7 +976,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, case Triple::aarch64_be: // The signal handler will find the data address in x0. Asm = InlineAsm::get( - FunctionType::get(VoidTy, {PtrLong->getType()}, false), + FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false), "brk #" + itostr(0x900 + (AccessInfo & HWASanAccessInfo::RuntimeMask)), "{x0}", /*hasSideEffects=*/true); @@ -931,7 +984,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, case Triple::riscv64: // The signal handler will find the data address in x10. Asm = InlineAsm::get( - FunctionType::get(VoidTy, {PtrLong->getType()}, false), + FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false), "ebreak\naddiw x0, x11, " + itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)), "{x10}", @@ -940,9 +993,10 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, default: report_fatal_error("unsupported architecture"); } - IRB.CreateCall(Asm, PtrLong); + IRB.CreateCall(Asm, TCI.PtrLong); if (Recover) - cast<BranchInst>(CheckFailTerm)->setSuccessor(0, CheckTerm->getParent()); + cast<BranchInst>(CheckFailTerm) + ->setSuccessor(0, TCI.TagMismatchTerm->getParent()); } bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) { @@ -958,40 +1012,28 @@ bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) { void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { IRBuilder<> IRB(MI); if (isa<MemTransferInst>(MI)) { - if (UseMatchAllCallback) { - IRB.CreateCall( - isa<MemMoveInst>(MI) ? 
HwasanMemmove : HwasanMemcpy, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false), - ConstantInt::get(Int8Ty, *MatchAllTag)}); - } else { - IRB.CreateCall( - isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); - } + SmallVector<Value *, 4> Args{ + MI->getOperand(0), MI->getOperand(1), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}; + + if (UseMatchAllCallback) + Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag)); + IRB.CreateCall(isa<MemMoveInst>(MI) ? HwasanMemmove : HwasanMemcpy, Args); } else if (isa<MemSetInst>(MI)) { - if (UseMatchAllCallback) { - IRB.CreateCall( - HwasanMemset, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false), - ConstantInt::get(Int8Ty, *MatchAllTag)}); - } else { - IRB.CreateCall( - HwasanMemset, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); - } + SmallVector<Value *, 4> Args{ + MI->getOperand(0), + IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}; + if (UseMatchAllCallback) + Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag)); + IRB.CreateCall(HwasanMemset, Args); } MI->eraseFromParent(); } -bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) { +bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O, + DomTreeUpdater &DTU, + LoopInfo *LI) { Value *Addr = O.getPtr(); LLVM_DEBUG(dbgs() << "Instrumenting: " << 
O.getInsn() << "\n"); @@ -1006,34 +1048,26 @@ bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) { *O.Alignment >= O.TypeStoreSize / 8)) { size_t AccessSizeIndex = TypeSizeToSizeIndex(O.TypeStoreSize); if (InstrumentWithCalls) { - if (UseMatchAllCallback) { - IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex], - {IRB.CreatePointerCast(Addr, IntptrTy), - ConstantInt::get(Int8Ty, *MatchAllTag)}); - } else { - IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex], - IRB.CreatePointerCast(Addr, IntptrTy)); - } + SmallVector<Value *, 2> Args{IRB.CreatePointerCast(Addr, IntptrTy)}; + if (UseMatchAllCallback) + Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag)); + IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex], + Args); } else if (OutlinedChecks) { - instrumentMemAccessOutline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn()); + instrumentMemAccessOutline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn(), + DTU, LI); } else { - instrumentMemAccessInline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn()); + instrumentMemAccessInline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn(), + DTU, LI); } } else { - if (UseMatchAllCallback) { - IRB.CreateCall( - HwasanMemoryAccessCallbackSized[O.IsWrite], - {IRB.CreatePointerCast(Addr, IntptrTy), - IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize), - ConstantInt::get(IntptrTy, 8)), - ConstantInt::get(Int8Ty, *MatchAllTag)}); - } else { - IRB.CreateCall( - HwasanMemoryAccessCallbackSized[O.IsWrite], - {IRB.CreatePointerCast(Addr, IntptrTy), - IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize), - ConstantInt::get(IntptrTy, 8))}); - } + SmallVector<Value *, 3> Args{ + IRB.CreatePointerCast(Addr, IntptrTy), + IRB.CreateUDiv(IRB.CreateTypeSize(IntptrTy, O.TypeStoreSize), + ConstantInt::get(IntptrTy, 8))}; + if (UseMatchAllCallback) + Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag)); + 
IRB.CreateCall(HwasanMemoryAccessCallbackSized[O.IsWrite], Args); } untagPointerOperand(O.getInsn(), Addr); @@ -1049,7 +1083,7 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, Tag = IRB.CreateTrunc(Tag, Int8Ty); if (InstrumentWithCalls) { IRB.CreateCall(HwasanTagMemoryFunc, - {IRB.CreatePointerCast(AI, Int8PtrTy), Tag, + {IRB.CreatePointerCast(AI, PtrTy), Tag, ConstantInt::get(IntptrTy, AlignedSize)}); } else { size_t ShadowSize = Size >> Mapping.Scale; @@ -1067,9 +1101,9 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, const uint8_t SizeRemainder = Size % Mapping.getObjectAlignment().value(); IRB.CreateStore(ConstantInt::get(Int8Ty, SizeRemainder), IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize)); - IRB.CreateStore(Tag, IRB.CreateConstGEP1_32( - Int8Ty, IRB.CreatePointerCast(AI, Int8PtrTy), - AlignedSize - 1)); + IRB.CreateStore( + Tag, IRB.CreateConstGEP1_32(Int8Ty, IRB.CreatePointerCast(AI, PtrTy), + AlignedSize - 1)); } } } @@ -1183,10 +1217,8 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { // in Bionic's libc/private/bionic_tls.h. 
Function *ThreadPointerFunc = Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); - Value *SlotPtr = IRB.CreatePointerCast( - IRB.CreateConstGEP1_32(Int8Ty, IRB.CreateCall(ThreadPointerFunc), 0x30), - Ty->getPointerTo(0)); - return SlotPtr; + return IRB.CreateConstGEP1_32(Int8Ty, IRB.CreateCall(ThreadPointerFunc), + 0x30); } if (ThreadPtrGlobal) return ThreadPtrGlobal; @@ -1208,7 +1240,7 @@ Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) { Module *M = F->getParent(); auto *GetStackPointerFn = Intrinsic::getDeclaration( M, Intrinsic::frameaddress, - IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); + IRB.getPtrTy(M->getDataLayout().getAllocaAddrSpace())); CachedSP = IRB.CreatePtrToInt( IRB.CreateCall(GetStackPointerFn, {Constant::getNullValue(Int32Ty)}), IntptrTy); @@ -1271,8 +1303,8 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { // Store data to ring buffer. Value *FrameRecordInfo = getFrameRecordInfo(IRB); - Value *RecordPtr = IRB.CreateIntToPtr(ThreadLongMaybeUntagged, - IntptrTy->getPointerTo(0)); + Value *RecordPtr = + IRB.CreateIntToPtr(ThreadLongMaybeUntagged, IRB.getPtrTy(0)); IRB.CreateStore(FrameRecordInfo, RecordPtr); // Update the ring buffer. 
Top byte of ThreadLong defines the size of the @@ -1309,7 +1341,7 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { ThreadLongMaybeUntagged, ConstantInt::get(IntptrTy, (1ULL << kShadowBaseAlignment) - 1)), ConstantInt::get(IntptrTy, 1), "hwasan.shadow"); - ShadowBase = IRB.CreateIntToPtr(ShadowBase, Int8PtrTy); + ShadowBase = IRB.CreateIntToPtr(ShadowBase, PtrTy); } } @@ -1369,7 +1401,7 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, size_t Size = memtag::getAllocaSizeInBytes(*AI); size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); - Value *AICast = IRB.CreatePointerCast(AI, Int8PtrTy); + Value *AICast = IRB.CreatePointerCast(AI, PtrTy); auto HandleLifetime = [&](IntrinsicInst *II) { // Set the lifetime intrinsic to cover the whole alloca. This reduces the @@ -1462,6 +1494,7 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, SmallVector<InterestingMemoryOperand, 16> OperandsToInstrument; SmallVector<MemIntrinsic *, 16> IntrinToInstrument; SmallVector<Instruction *, 8> LandingPadVec; + const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F); memtag::StackInfoBuilder SIB(SSI); for (auto &Inst : instructions(F)) { @@ -1472,7 +1505,7 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, if (InstrumentLandingPads && isa<LandingPadInst>(Inst)) LandingPadVec.push_back(&Inst); - getInterestingMemoryOperands(&Inst, OperandsToInstrument); + getInterestingMemoryOperands(&Inst, TLI, OperandsToInstrument); if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&Inst)) if (!ignoreMemIntrinsic(MI)) @@ -1528,8 +1561,13 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, } } + DominatorTree *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); + PostDominatorTree *PDT = FAM.getCachedResult<PostDominatorTreeAnalysis>(F); + LoopInfo *LI = FAM.getCachedResult<LoopAnalysis>(F); + DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); for (auto &Operand : OperandsToInstrument) - 
instrumentMemAccess(Operand); + instrumentMemAccess(Operand, DTU, LI); + DTU.flush(); if (ClInstrumentMemIntrinsics && !IntrinToInstrument.empty()) { for (auto *Inst : IntrinToInstrument) @@ -1624,7 +1662,7 @@ void HWAddressSanitizer::instrumentGlobals() { if (GV.hasSanitizerMetadata() && GV.getSanitizerMetadata().NoHWAddress) continue; - if (GV.isDeclarationForLinker() || GV.getName().startswith("llvm.") || + if (GV.isDeclarationForLinker() || GV.getName().starts_with("llvm.") || GV.isThreadLocal()) continue; @@ -1682,8 +1720,8 @@ void HWAddressSanitizer::instrumentPersonalityFunctions() { return; FunctionCallee HwasanPersonalityWrapper = M.getOrInsertFunction( - "__hwasan_personality_wrapper", Int32Ty, Int32Ty, Int32Ty, Int64Ty, - Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy); + "__hwasan_personality_wrapper", Int32Ty, Int32Ty, Int32Ty, Int64Ty, PtrTy, + PtrTy, PtrTy, PtrTy, PtrTy); FunctionCallee UnwindGetGR = M.getOrInsertFunction("_Unwind_GetGR", VoidTy); FunctionCallee UnwindGetCFA = M.getOrInsertFunction("_Unwind_GetCFA", VoidTy); @@ -1692,7 +1730,7 @@ void HWAddressSanitizer::instrumentPersonalityFunctions() { if (P.first) ThunkName += ("." + P.first->getName()).str(); FunctionType *ThunkFnTy = FunctionType::get( - Int32Ty, {Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int8PtrTy}, false); + Int32Ty, {Int32Ty, Int32Ty, Int64Ty, PtrTy, PtrTy}, false); bool IsLocal = P.first && (!isa<GlobalValue>(P.first) || cast<GlobalValue>(P.first)->hasLocalLinkage()); auto *ThunkFn = Function::Create(ThunkFnTy, @@ -1710,10 +1748,8 @@ void HWAddressSanitizer::instrumentPersonalityFunctions() { HwasanPersonalityWrapper, {ThunkFn->getArg(0), ThunkFn->getArg(1), ThunkFn->getArg(2), ThunkFn->getArg(3), ThunkFn->getArg(4), - P.first ? IRB.CreateBitCast(P.first, Int8PtrTy) - : Constant::getNullValue(Int8PtrTy), - IRB.CreateBitCast(UnwindGetGR.getCallee(), Int8PtrTy), - IRB.CreateBitCast(UnwindGetCFA.getCallee(), Int8PtrTy)}); + P.first ? 
P.first : Constant::getNullValue(PtrTy), + UnwindGetGR.getCallee(), UnwindGetCFA.getCallee()}); WrapperCall->setTailCall(); IRB.CreateRet(WrapperCall); diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 5c9799235017..7344fea17517 100644 --- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -26,6 +26,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Value.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/Casting.h" @@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee, promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights); if (AttachProfToDirectCall) { - MDBuilder MDB(NewInst.getContext()); - NewInst.setMetadata( - LLVMContext::MD_prof, - MDB.createBranchWeights({static_cast<uint32_t>(Count)})); + setBranchWeights(NewInst, {static_cast<uint32_t>(Count)}); } using namespace ore; diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index a7b1953ce81c..d3282779d9f5 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This pass lowers instrprof_* intrinsics emitted by a frontend for profiling. +// This pass lowers instrprof_* intrinsics emitted by an instrumentor. // It also builds the data structures and initialization code needed for // updating execution counts and emitting the profile at runtime. 
// @@ -14,6 +14,7 @@ #include "llvm/Transforms/Instrumentation/InstrProfiling.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" @@ -23,6 +24,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" @@ -47,6 +49,9 @@ #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/TargetParser/Triple.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> @@ -190,7 +195,8 @@ public: auto *OrigBiasInst = dyn_cast<BinaryOperator>(AddrInst->getOperand(0)); assert(OrigBiasInst->getOpcode() == Instruction::BinaryOps::Add); Value *BiasInst = Builder.Insert(OrigBiasInst->clone()); - Addr = Builder.CreateIntToPtr(BiasInst, Ty->getPointerTo()); + Addr = Builder.CreateIntToPtr(BiasInst, + PointerType::getUnqual(Ty->getContext())); } if (AtomicCounterUpdatePromoted) // automic update currently can only be promoted across the current @@ -241,7 +247,10 @@ public: return; for (BasicBlock *ExitBlock : LoopExitBlocks) { - if (BlockSet.insert(ExitBlock).second) { + if (BlockSet.insert(ExitBlock).second && + llvm::none_of(predecessors(ExitBlock), [&](const BasicBlock *Pred) { + return llvm::isPresplitCoroSuspendExitEdge(*Pred, *ExitBlock); + })) { ExitBlocks.push_back(ExitBlock); InsertPts.push_back(&*ExitBlock->getFirstInsertionPt()); } @@ -430,6 +439,15 @@ bool InstrProfiling::lowerIntrinsics(Function *F) { } else if (auto *IPVP = dyn_cast<InstrProfValueProfileInst>(&Instr)) { lowerValueProfileInst(IPVP); MadeChange = true; + } else if (auto *IPMP = 
dyn_cast<InstrProfMCDCBitmapParameters>(&Instr)) { + IPMP->eraseFromParent(); + MadeChange = true; + } else if (auto *IPBU = dyn_cast<InstrProfMCDCTVBitmapUpdate>(&Instr)) { + lowerMCDCTestVectorBitmapUpdate(IPBU); + MadeChange = true; + } else if (auto *IPTU = dyn_cast<InstrProfMCDCCondBitmapUpdate>(&Instr)) { + lowerMCDCCondBitmapUpdate(IPTU); + MadeChange = true; } } } @@ -544,19 +562,27 @@ bool InstrProfiling::run( // the instrumented function. This is counting the number of instrumented // target value sites to enter it as field in the profile data variable. for (Function &F : M) { - InstrProfInstBase *FirstProfInst = nullptr; - for (BasicBlock &BB : F) - for (auto I = BB.begin(), E = BB.end(); I != E; I++) + InstrProfCntrInstBase *FirstProfInst = nullptr; + for (BasicBlock &BB : F) { + for (auto I = BB.begin(), E = BB.end(); I != E; I++) { if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(I)) computeNumValueSiteCounts(Ind); - else if (FirstProfInst == nullptr && - (isa<InstrProfIncrementInst>(I) || isa<InstrProfCoverInst>(I))) - FirstProfInst = dyn_cast<InstrProfInstBase>(I); + else { + if (FirstProfInst == nullptr && + (isa<InstrProfIncrementInst>(I) || isa<InstrProfCoverInst>(I))) + FirstProfInst = dyn_cast<InstrProfCntrInstBase>(I); + // If the MCDCBitmapParameters intrinsic seen, create the bitmaps. + if (const auto &Params = dyn_cast<InstrProfMCDCBitmapParameters>(I)) + static_cast<void>(getOrCreateRegionBitmaps(Params)); + } + } + } - // Value profiling intrinsic lowering requires per-function profile data - // variable to be created first. - if (FirstProfInst != nullptr) + // Use a profile intrinsic to create the region counters and data variable. + // Also create the data variable based on the MCDCParams. 
+ if (FirstProfInst != nullptr) { static_cast<void>(getOrCreateRegionCounters(FirstProfInst)); + } } for (Function &F : M) @@ -651,15 +677,11 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { SmallVector<OperandBundleDef, 1> OpBundles; Ind->getOperandBundlesAsDefs(OpBundles); if (!IsMemOpSize) { - Value *Args[3] = {Ind->getTargetValue(), - Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), - Builder.getInt32(Index)}; + Value *Args[3] = {Ind->getTargetValue(), DataVar, Builder.getInt32(Index)}; Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), Args, OpBundles); } else { - Value *Args[3] = {Ind->getTargetValue(), - Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), - Builder.getInt32(Index)}; + Value *Args[3] = {Ind->getTargetValue(), DataVar, Builder.getInt32(Index)}; Call = Builder.CreateCall( getOrInsertValueProfilingCall(*M, *TLI, ValueProfilingCallType::MemOp), Args, OpBundles); @@ -670,7 +692,7 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { Ind->eraseFromParent(); } -Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) { +Value *InstrProfiling::getCounterAddress(InstrProfCntrInstBase *I) { auto *Counters = getOrCreateRegionCounters(I); IRBuilder<> Builder(I); @@ -710,6 +732,25 @@ Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) { return Builder.CreateIntToPtr(Add, Addr->getType()); } +Value *InstrProfiling::getBitmapAddress(InstrProfMCDCTVBitmapUpdate *I) { + auto *Bitmaps = getOrCreateRegionBitmaps(I); + IRBuilder<> Builder(I); + + auto *Addr = Builder.CreateConstInBoundsGEP2_32( + Bitmaps->getValueType(), Bitmaps, 0, I->getBitmapIndex()->getZExtValue()); + + if (isRuntimeCounterRelocationEnabled()) { + LLVMContext &Ctx = M->getContext(); + Ctx.diagnose(DiagnosticInfoPGOProfile( + M->getName().data(), + Twine("Runtime counter relocation is presently not supported for MC/DC " + "bitmaps."), + DS_Warning)); + } + + return Addr; +} + void 
InstrProfiling::lowerCover(InstrProfCoverInst *CoverInstruction) { auto *Addr = getCounterAddress(CoverInstruction); IRBuilder<> Builder(CoverInstruction); @@ -769,6 +810,86 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { CoverageNamesVar->eraseFromParent(); } +void InstrProfiling::lowerMCDCTestVectorBitmapUpdate( + InstrProfMCDCTVBitmapUpdate *Update) { + IRBuilder<> Builder(Update); + auto *Int8Ty = Type::getInt8Ty(M->getContext()); + auto *Int8PtrTy = PointerType::getUnqual(M->getContext()); + auto *Int32Ty = Type::getInt32Ty(M->getContext()); + auto *Int64Ty = Type::getInt64Ty(M->getContext()); + auto *MCDCCondBitmapAddr = Update->getMCDCCondBitmapAddr(); + auto *BitmapAddr = getBitmapAddress(Update); + + // Load Temp Val. + // %mcdc.temp = load i32, ptr %mcdc.addr, align 4 + auto *Temp = Builder.CreateLoad(Int32Ty, MCDCCondBitmapAddr, "mcdc.temp"); + + // Calculate byte offset using div8. + // %1 = lshr i32 %mcdc.temp, 3 + auto *BitmapByteOffset = Builder.CreateLShr(Temp, 0x3); + + // Add byte offset to section base byte address. + // %2 = zext i32 %1 to i64 + // %3 = add i64 ptrtoint (ptr @__profbm_test to i64), %2 + auto *BitmapByteAddr = + Builder.CreateAdd(Builder.CreatePtrToInt(BitmapAddr, Int64Ty), + Builder.CreateZExtOrBitCast(BitmapByteOffset, Int64Ty)); + + // Convert to a pointer. + // %4 = inttoptr i32 %3 to ptr + BitmapByteAddr = Builder.CreateIntToPtr(BitmapByteAddr, Int8PtrTy); + + // Calculate bit offset into bitmap byte by using div8 remainder (AND ~8) + // %5 = and i32 %mcdc.temp, 7 + // %6 = trunc i32 %5 to i8 + auto *BitToSet = Builder.CreateTrunc(Builder.CreateAnd(Temp, 0x7), Int8Ty); + + // Shift bit offset left to form a bitmap. + // %7 = shl i8 1, %6 + auto *ShiftedVal = Builder.CreateShl(Builder.getInt8(0x1), BitToSet); + + // Load profile bitmap byte. 
+ // %mcdc.bits = load i8, ptr %4, align 1 + auto *Bitmap = Builder.CreateLoad(Int8Ty, BitmapByteAddr, "mcdc.bits"); + + // Perform logical OR of profile bitmap byte and shifted bit offset. + // %8 = or i8 %mcdc.bits, %7 + auto *Result = Builder.CreateOr(Bitmap, ShiftedVal); + + // Store the updated profile bitmap byte. + // store i8 %8, ptr %3, align 1 + Builder.CreateStore(Result, BitmapByteAddr); + Update->eraseFromParent(); +} + +void InstrProfiling::lowerMCDCCondBitmapUpdate( + InstrProfMCDCCondBitmapUpdate *Update) { + IRBuilder<> Builder(Update); + auto *Int32Ty = Type::getInt32Ty(M->getContext()); + auto *MCDCCondBitmapAddr = Update->getMCDCCondBitmapAddr(); + + // Load the MCDC temporary value from the stack. + // %mcdc.temp = load i32, ptr %mcdc.addr, align 4 + auto *Temp = Builder.CreateLoad(Int32Ty, MCDCCondBitmapAddr, "mcdc.temp"); + + // Zero-extend the evaluated condition boolean value (0 or 1) by 32bits. + // %1 = zext i1 %tobool to i32 + auto *CondV_32 = Builder.CreateZExt(Update->getCondBool(), Int32Ty); + + // Shift the boolean value left (by the condition's ID) to form a bitmap. + // %2 = shl i32 %1, <Update->getCondID()> + auto *ShiftedVal = Builder.CreateShl(CondV_32, Update->getCondID()); + + // Perform logical OR of the bitmap against the loaded MCDC temporary value. + // %3 = or i32 %mcdc.temp, %2 + auto *Result = Builder.CreateOr(Temp, ShiftedVal); + + // Store the updated temporary value back to the stack. + // store i32 %3, ptr %mcdc.addr, align 4 + Builder.CreateStore(Result, MCDCCondBitmapAddr); + Update->eraseFromParent(); +} + /// Get the name of a profiling variable for a particular function. 
static std::string getVarName(InstrProfInstBase *Inc, StringRef Prefix, bool &Renamed) { @@ -784,7 +905,7 @@ static std::string getVarName(InstrProfInstBase *Inc, StringRef Prefix, Renamed = true; uint64_t FuncHash = Inc->getHash()->getZExtValue(); SmallVector<char, 24> HashPostfix; - if (Name.endswith((Twine(".") + Twine(FuncHash)).toStringRef(HashPostfix))) + if (Name.ends_with((Twine(".") + Twine(FuncHash)).toStringRef(HashPostfix))) return (Prefix + Name).str(); return (Prefix + Name + "." + Twine(FuncHash)).str(); } @@ -878,7 +999,7 @@ static inline bool shouldUsePublicSymbol(Function *Fn) { } static inline Constant *getFuncAddrForProfData(Function *Fn) { - auto *Int8PtrTy = Type::getInt8PtrTy(Fn->getContext()); + auto *Int8PtrTy = PointerType::getUnqual(Fn->getContext()); // Store a nullptr in __llvm_profd, if we shouldn't use a real address if (!shouldRecordFunctionAddr(Fn)) return ConstantPointerNull::get(Int8PtrTy); @@ -886,7 +1007,7 @@ static inline Constant *getFuncAddrForProfData(Function *Fn) { // If we can't use an alias, we must use the public symbol, even though this // may require a symbolic relocation. if (shouldUsePublicSymbol(Fn)) - return ConstantExpr::getBitCast(Fn, Int8PtrTy); + return Fn; // When possible use a private alias to avoid symbolic relocations. 
auto *GA = GlobalAlias::create(GlobalValue::LinkageTypes::PrivateLinkage, @@ -909,7 +1030,7 @@ static inline Constant *getFuncAddrForProfData(Function *Fn) { // appendToCompilerUsed(*Fn->getParent(), {GA}); - return ConstantExpr::getBitCast(GA, Int8PtrTy); + return GA; } static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { @@ -924,37 +1045,31 @@ static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { return true; } -GlobalVariable * -InstrProfiling::createRegionCounters(InstrProfInstBase *Inc, StringRef Name, - GlobalValue::LinkageTypes Linkage) { - uint64_t NumCounters = Inc->getNumCounters()->getZExtValue(); - auto &Ctx = M->getContext(); - GlobalVariable *GV; - if (isa<InstrProfCoverInst>(Inc)) { - auto *CounterTy = Type::getInt8Ty(Ctx); - auto *CounterArrTy = ArrayType::get(CounterTy, NumCounters); - // TODO: `Constant::getAllOnesValue()` does not yet accept an array type. - std::vector<Constant *> InitialValues(NumCounters, - Constant::getAllOnesValue(CounterTy)); - GV = new GlobalVariable(*M, CounterArrTy, false, Linkage, - ConstantArray::get(CounterArrTy, InitialValues), - Name); - GV->setAlignment(Align(1)); - } else { - auto *CounterTy = ArrayType::get(Type::getInt64Ty(Ctx), NumCounters); - GV = new GlobalVariable(*M, CounterTy, false, Linkage, - Constant::getNullValue(CounterTy), Name); - GV->setAlignment(Align(8)); - } - return GV; +void InstrProfiling::maybeSetComdat(GlobalVariable *GV, Function *Fn, + StringRef VarName) { + bool DataReferencedByCode = profDataReferencedByCode(*M); + bool NeedComdat = needsComdatForCounter(*Fn, *M); + bool UseComdat = (NeedComdat || TT.isOSBinFormatELF()); + + if (!UseComdat) + return; + + StringRef GroupName = + TT.isOSBinFormatCOFF() && DataReferencedByCode ? 
GV->getName() : VarName; + Comdat *C = M->getOrInsertComdat(GroupName); + if (!NeedComdat) + C->setSelectionKind(Comdat::NoDeduplicate); + GV->setComdat(C); + // COFF doesn't allow the comdat group leader to have private linkage, so + // upgrade private linkage to internal linkage to produce a symbol table + // entry. + if (TT.isOSBinFormatCOFF() && GV->hasPrivateLinkage()) + GV->setLinkage(GlobalValue::InternalLinkage); } -GlobalVariable * -InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { +GlobalVariable *InstrProfiling::setupProfileSection(InstrProfInstBase *Inc, + InstrProfSectKind IPSK) { GlobalVariable *NamePtr = Inc->getName(); - auto &PD = ProfileDataMap[NamePtr]; - if (PD.RegionCounters) - return PD.RegionCounters; // Match the linkage and visibility of the name global. Function *Fn = Inc->getParent()->getParent(); @@ -993,42 +1108,101 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { // nodeduplicate COMDAT which is lowered to a zero-flag section group. This // allows -z start-stop-gc to discard the entire group when the function is // discarded. - bool DataReferencedByCode = profDataReferencedByCode(*M); - bool NeedComdat = needsComdatForCounter(*Fn, *M); bool Renamed; - std::string CntsVarName = - getVarName(Inc, getInstrProfCountersVarPrefix(), Renamed); - std::string DataVarName = - getVarName(Inc, getInstrProfDataVarPrefix(), Renamed); - auto MaybeSetComdat = [&](GlobalVariable *GV) { - bool UseComdat = (NeedComdat || TT.isOSBinFormatELF()); - if (UseComdat) { - StringRef GroupName = TT.isOSBinFormatCOFF() && DataReferencedByCode - ? GV->getName() - : CntsVarName; - Comdat *C = M->getOrInsertComdat(GroupName); - if (!NeedComdat) - C->setSelectionKind(Comdat::NoDeduplicate); - GV->setComdat(C); - // COFF doesn't allow the comdat group leader to have private linkage, so - // upgrade private linkage to internal linkage to produce a symbol table - // entry. 
- if (TT.isOSBinFormatCOFF() && GV->hasPrivateLinkage()) - GV->setLinkage(GlobalValue::InternalLinkage); - } - }; + GlobalVariable *Ptr; + StringRef VarPrefix; + std::string VarName; + if (IPSK == IPSK_cnts) { + VarPrefix = getInstrProfCountersVarPrefix(); + VarName = getVarName(Inc, VarPrefix, Renamed); + InstrProfCntrInstBase *CntrIncrement = dyn_cast<InstrProfCntrInstBase>(Inc); + Ptr = createRegionCounters(CntrIncrement, VarName, Linkage); + } else if (IPSK == IPSK_bitmap) { + VarPrefix = getInstrProfBitmapVarPrefix(); + VarName = getVarName(Inc, VarPrefix, Renamed); + InstrProfMCDCBitmapInstBase *BitmapUpdate = + dyn_cast<InstrProfMCDCBitmapInstBase>(Inc); + Ptr = createRegionBitmaps(BitmapUpdate, VarName, Linkage); + } else { + llvm_unreachable("Profile Section must be for Counters or Bitmaps"); + } + + Ptr->setVisibility(Visibility); + // Put the counters and bitmaps in their own sections so linkers can + // remove unneeded sections. + Ptr->setSection(getInstrProfSectionName(IPSK, TT.getObjectFormat())); + Ptr->setLinkage(Linkage); + maybeSetComdat(Ptr, Fn, VarName); + return Ptr; +} + +GlobalVariable * +InstrProfiling::createRegionBitmaps(InstrProfMCDCBitmapInstBase *Inc, + StringRef Name, + GlobalValue::LinkageTypes Linkage) { + uint64_t NumBytes = Inc->getNumBitmapBytes()->getZExtValue(); + auto *BitmapTy = ArrayType::get(Type::getInt8Ty(M->getContext()), NumBytes); + auto GV = new GlobalVariable(*M, BitmapTy, false, Linkage, + Constant::getNullValue(BitmapTy), Name); + GV->setAlignment(Align(1)); + return GV; +} + +GlobalVariable * +InstrProfiling::getOrCreateRegionBitmaps(InstrProfMCDCBitmapInstBase *Inc) { + GlobalVariable *NamePtr = Inc->getName(); + auto &PD = ProfileDataMap[NamePtr]; + if (PD.RegionBitmaps) + return PD.RegionBitmaps; + + // If RegionBitmaps doesn't already exist, create it by first setting up + // the corresponding profile section. 
+ auto *BitmapPtr = setupProfileSection(Inc, IPSK_bitmap); + PD.RegionBitmaps = BitmapPtr; + PD.NumBitmapBytes = Inc->getNumBitmapBytes()->getZExtValue(); + return PD.RegionBitmaps; +} +GlobalVariable * +InstrProfiling::createRegionCounters(InstrProfCntrInstBase *Inc, StringRef Name, + GlobalValue::LinkageTypes Linkage) { uint64_t NumCounters = Inc->getNumCounters()->getZExtValue(); - LLVMContext &Ctx = M->getContext(); + auto &Ctx = M->getContext(); + GlobalVariable *GV; + if (isa<InstrProfCoverInst>(Inc)) { + auto *CounterTy = Type::getInt8Ty(Ctx); + auto *CounterArrTy = ArrayType::get(CounterTy, NumCounters); + // TODO: `Constant::getAllOnesValue()` does not yet accept an array type. + std::vector<Constant *> InitialValues(NumCounters, + Constant::getAllOnesValue(CounterTy)); + GV = new GlobalVariable(*M, CounterArrTy, false, Linkage, + ConstantArray::get(CounterArrTy, InitialValues), + Name); + GV->setAlignment(Align(1)); + } else { + auto *CounterTy = ArrayType::get(Type::getInt64Ty(Ctx), NumCounters); + GV = new GlobalVariable(*M, CounterTy, false, Linkage, + Constant::getNullValue(CounterTy), Name); + GV->setAlignment(Align(8)); + } + return GV; +} + +GlobalVariable * +InstrProfiling::getOrCreateRegionCounters(InstrProfCntrInstBase *Inc) { + GlobalVariable *NamePtr = Inc->getName(); + auto &PD = ProfileDataMap[NamePtr]; + if (PD.RegionCounters) + return PD.RegionCounters; - auto *CounterPtr = createRegionCounters(Inc, CntsVarName, Linkage); - CounterPtr->setVisibility(Visibility); - CounterPtr->setSection( - getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat())); - CounterPtr->setLinkage(Linkage); - MaybeSetComdat(CounterPtr); + // If RegionCounters doesn't already exist, create it by first setting up + // the corresponding profile section. 
+ auto *CounterPtr = setupProfileSection(Inc, IPSK_cnts); PD.RegionCounters = CounterPtr; + if (DebugInfoCorrelate) { + LLVMContext &Ctx = M->getContext(); + Function *Fn = Inc->getParent()->getParent(); if (auto *SP = Fn->getSubprogram()) { DIBuilder DB(*M, true, SP->getUnit()); Metadata *FunctionNameAnnotation[] = { @@ -1056,16 +1230,58 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { Annotations); CounterPtr->addDebugInfo(DICounter); DB.finalize(); - } else { - std::string Msg = ("Missing debug info for function " + Fn->getName() + - "; required for profile correlation.") - .str(); - Ctx.diagnose( - DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); } + + // Mark the counter variable as used so that it isn't optimized out. + CompilerUsedVars.push_back(PD.RegionCounters); } - auto *Int8PtrTy = Type::getInt8PtrTy(Ctx); + // Create the data variable (if it doesn't already exist). + createDataVariable(Inc); + + return PD.RegionCounters; +} + +void InstrProfiling::createDataVariable(InstrProfCntrInstBase *Inc) { + // When debug information is correlated to profile data, a data variable + // is not needed. + if (DebugInfoCorrelate) + return; + + GlobalVariable *NamePtr = Inc->getName(); + auto &PD = ProfileDataMap[NamePtr]; + + // Return if data variable was already created. + if (PD.DataVar) + return; + + LLVMContext &Ctx = M->getContext(); + + Function *Fn = Inc->getParent()->getParent(); + GlobalValue::LinkageTypes Linkage = NamePtr->getLinkage(); + GlobalValue::VisibilityTypes Visibility = NamePtr->getVisibility(); + + // Due to the limitation of binder as of 2021/09/28, the duplicate weak + // symbols in the same csect won't be discarded. When there are duplicate weak + // symbols, we can NOT guarantee that the relocations get resolved to the + // intended weak symbol, so we can not ensure the correctness of the relative + // CounterPtr, so we have to use private linkage for counter and data symbols. 
+ if (TT.isOSBinFormatXCOFF()) { + Linkage = GlobalValue::PrivateLinkage; + Visibility = GlobalValue::DefaultVisibility; + } + + bool DataReferencedByCode = profDataReferencedByCode(*M); + bool NeedComdat = needsComdatForCounter(*Fn, *M); + bool Renamed; + + // The Data Variable section is anchored to profile counters. + std::string CntsVarName = + getVarName(Inc, getInstrProfCountersVarPrefix(), Renamed); + std::string DataVarName = + getVarName(Inc, getInstrProfDataVarPrefix(), Renamed); + + auto *Int8PtrTy = PointerType::getUnqual(Ctx); // Allocate statically the array of pointers to value profile nodes for // the current function. Constant *ValuesPtrExpr = ConstantPointerNull::get(Int8PtrTy); @@ -1079,19 +1295,18 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { *M, ValuesTy, false, Linkage, Constant::getNullValue(ValuesTy), getVarName(Inc, getInstrProfValuesVarPrefix(), Renamed)); ValuesVar->setVisibility(Visibility); + setGlobalVariableLargeSection(TT, *ValuesVar); ValuesVar->setSection( getInstrProfSectionName(IPSK_vals, TT.getObjectFormat())); ValuesVar->setAlignment(Align(8)); - MaybeSetComdat(ValuesVar); - ValuesPtrExpr = - ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx)); + maybeSetComdat(ValuesVar, Fn, CntsVarName); + ValuesPtrExpr = ValuesVar; } - if (DebugInfoCorrelate) { - // Mark the counter variable as used so that it isn't optimized out. - CompilerUsedVars.push_back(PD.RegionCounters); - return PD.RegionCounters; - } + uint64_t NumCounters = Inc->getNumCounters()->getZExtValue(); + auto *CounterPtr = PD.RegionCounters; + + uint64_t NumBitmapBytes = PD.NumBitmapBytes; // Create data variable. 
auto *IntPtrTy = M->getDataLayout().getIntPtrType(M->getContext()); @@ -1134,6 +1349,16 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { ConstantExpr::getSub(ConstantExpr::getPtrToInt(CounterPtr, IntPtrTy), ConstantExpr::getPtrToInt(Data, IntPtrTy)); + // Bitmaps are relative to the same data variable as profile counters. + GlobalVariable *BitmapPtr = PD.RegionBitmaps; + Constant *RelativeBitmapPtr = ConstantInt::get(IntPtrTy, 0); + + if (BitmapPtr != nullptr) { + RelativeBitmapPtr = + ConstantExpr::getSub(ConstantExpr::getPtrToInt(BitmapPtr, IntPtrTy), + ConstantExpr::getPtrToInt(Data, IntPtrTy)); + } + Constant *DataVals[] = { #define INSTR_PROF_DATA(Type, LLVMType, Name, Init) Init, #include "llvm/ProfileData/InstrProfData.inc" @@ -1143,7 +1368,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { Data->setVisibility(Visibility); Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat())); Data->setAlignment(Align(INSTR_PROF_DATA_ALIGNMENT)); - MaybeSetComdat(Data); + maybeSetComdat(Data, Fn, CntsVarName); PD.DataVar = Data; @@ -1155,8 +1380,6 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { NamePtr->setLinkage(GlobalValue::PrivateLinkage); // Collect the referenced names to be used by emitNameData. 
ReferencedNames.push_back(NamePtr); - - return PD.RegionCounters; } void InstrProfiling::emitVNodes() { @@ -1201,6 +1424,7 @@ void InstrProfiling::emitVNodes() { auto *VNodesVar = new GlobalVariable( *M, VNodesTy, false, GlobalValue::PrivateLinkage, Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName()); + setGlobalVariableLargeSection(TT, *VNodesVar); VNodesVar->setSection( getInstrProfSectionName(IPSK_vnodes, TT.getObjectFormat())); VNodesVar->setAlignment(M->getDataLayout().getABITypeAlign(VNodesTy)); @@ -1228,6 +1452,7 @@ void InstrProfiling::emitNameData() { GlobalValue::PrivateLinkage, NamesVal, getInstrProfNamesVarName()); NamesSize = CompressedNameStr.size(); + setGlobalVariableLargeSection(TT, *NamesVar); NamesVar->setSection( getInstrProfSectionName(IPSK_name, TT.getObjectFormat())); // On COFF, it's important to reduce the alignment down to 1 to prevent the @@ -1248,7 +1473,7 @@ void InstrProfiling::emitRegistration() { // Construct the function. auto *VoidTy = Type::getVoidTy(M->getContext()); - auto *VoidPtrTy = Type::getInt8PtrTy(M->getContext()); + auto *VoidPtrTy = PointerType::getUnqual(M->getContext()); auto *Int64Ty = Type::getInt64Ty(M->getContext()); auto *RegisterFTy = FunctionType::get(VoidTy, false); auto *RegisterF = Function::Create(RegisterFTy, GlobalValue::InternalLinkage, @@ -1265,10 +1490,10 @@ void InstrProfiling::emitRegistration() { IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", RegisterF)); for (Value *Data : CompilerUsedVars) if (!isa<Function>(Data)) - IRB.CreateCall(RuntimeRegisterF, IRB.CreateBitCast(Data, VoidPtrTy)); + IRB.CreateCall(RuntimeRegisterF, Data); for (Value *Data : UsedVars) if (Data != NamesVar && !isa<Function>(Data)) - IRB.CreateCall(RuntimeRegisterF, IRB.CreateBitCast(Data, VoidPtrTy)); + IRB.CreateCall(RuntimeRegisterF, Data); if (NamesVar) { Type *ParamTypes[] = {VoidPtrTy, Int64Ty}; @@ -1277,8 +1502,7 @@ void InstrProfiling::emitRegistration() { auto *NamesRegisterF = 
Function::Create(NamesRegisterTy, GlobalVariable::ExternalLinkage, getInstrProfNamesRegFuncName(), M); - IRB.CreateCall(NamesRegisterF, {IRB.CreateBitCast(NamesVar, VoidPtrTy), - IRB.getInt64(NamesSize)}); + IRB.CreateCall(NamesRegisterF, {NamesVar, IRB.getInt64(NamesSize)}); } IRB.CreateRetVoid(); diff --git a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp index 806afc8fcdf7..199afbe966dd 100644 --- a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -85,3 +85,10 @@ Comdat *llvm::getOrCreateFunctionComdat(Function &F, Triple &T) { return C; } +void llvm::setGlobalVariableLargeSection(Triple &TargetTriple, + GlobalVariable &GV) { + if (TargetTriple.getArch() == Triple::x86_64 && + TargetTriple.getObjectFormat() == Triple::ELF) { + GV.setCodeModel(CodeModel::Large); + } +} diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp index 789ed005d03d..539b7441d24b 100644 --- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp @@ -182,6 +182,7 @@ public: C = &(M.getContext()); LongSize = M.getDataLayout().getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); + PtrTy = PointerType::getUnqual(*C); } /// If it is an interesting memory access, populate information @@ -209,6 +210,7 @@ private: LLVMContext *C; int LongSize; Type *IntptrTy; + PointerType *PtrTy; ShadowMapping Mapping; // These arrays is indexed by AccessIsWrite @@ -267,15 +269,13 @@ Value *MemProfiler::memToShadow(Value *Shadow, IRBuilder<> &IRB) { void MemProfiler::instrumentMemIntrinsic(MemIntrinsic *MI) { IRBuilder<> IRB(MI); if (isa<MemTransferInst>(MI)) { - IRB.CreateCall( - isa<MemMoveInst>(MI) ? 
MemProfMemmove : MemProfMemcpy, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + IRB.CreateCall(isa<MemMoveInst>(MI) ? MemProfMemmove : MemProfMemcpy, + {MI->getOperand(0), MI->getOperand(1), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); } else if (isa<MemSetInst>(MI)) { IRB.CreateCall( MemProfMemset, - {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + {MI->getOperand(0), IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); } @@ -364,13 +364,13 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { StringRef SectionName = GV->getSection(); // Check if the global is in the PGO counters section. auto OF = Triple(I->getModule()->getTargetTriple()).getObjectFormat(); - if (SectionName.endswith( + if (SectionName.ends_with( getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false))) return std::nullopt; } // Do not instrument accesses to LLVM internal variables. 
- if (GV->getName().startswith("__llvm")) + if (GV->getName().starts_with("__llvm")) return std::nullopt; } @@ -519,14 +519,12 @@ void MemProfiler::initializeCallbacks(Module &M) { FunctionType::get(IRB.getVoidTy(), Args1, false)); } MemProfMemmove = M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + "memmove", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); + ClMemoryAccessCallbackPrefix + "memmove", PtrTy, PtrTy, PtrTy, IntptrTy); MemProfMemcpy = M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "memcpy", - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy); - MemProfMemset = M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "memset", - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt32Ty(), IntptrTy); + PtrTy, PtrTy, PtrTy, IntptrTy); + MemProfMemset = + M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "memset", PtrTy, + PtrTy, IRB.getInt32Ty(), IntptrTy); } bool MemProfiler::maybeInsertMemProfInitAtFunctionEntry(Function &F) { @@ -562,7 +560,7 @@ bool MemProfiler::instrumentFunction(Function &F) { return false; if (ClDebugFunc == F.getName()) return false; - if (F.getName().startswith("__memprof_")) + if (F.getName().starts_with("__memprof_")) return false; bool FunctionModified = false; @@ -628,7 +626,7 @@ static void addCallsiteMetadata(Instruction &I, static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset, uint32_t Column) { - llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::support::endianness::little> + llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::endianness::little> HashBuilder; HashBuilder.add(Function, LineOffset, Column); llvm::BLAKE3Result<8> Hash = HashBuilder.final(); @@ -678,13 +676,19 @@ static void readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader, const TargetLibraryInfo &TLI) { auto &Ctx = M.getContext(); - - auto FuncName = getPGOFuncName(F); + // Previously we used getIRPGOFuncName() here. 
If F is local linkage, + // getIRPGOFuncName() returns FuncName with prefix 'FileName;'. But + // llvm-profdata uses FuncName in dwarf to create GUID which doesn't + // contain FileName's prefix. It caused local linkage function can't + // find MemProfRecord. So we use getName() now. + // 'unique-internal-linkage-names' can make MemProf work better for local + // linkage function. + auto FuncName = F.getName(); auto FuncGUID = Function::getGUID(FuncName); - Expected<memprof::MemProfRecord> MemProfResult = - MemProfReader->getMemProfRecord(FuncGUID); - if (Error E = MemProfResult.takeError()) { - handleAllErrors(std::move(E), [&](const InstrProfError &IPE) { + std::optional<memprof::MemProfRecord> MemProfRec; + auto Err = MemProfReader->getMemProfRecord(FuncGUID).moveInto(MemProfRec); + if (Err) { + handleAllErrors(std::move(Err), [&](const InstrProfError &IPE) { auto Err = IPE.get(); bool SkipWarning = false; LLVM_DEBUG(dbgs() << "Error in reading profile for Func " << FuncName @@ -715,6 +719,12 @@ static void readMemprof(Module &M, Function &F, return; } + // Detect if there are non-zero column numbers in the profile. If not, + // treat all column numbers as 0 when matching (i.e. ignore any non-zero + // columns in the IR). The profiled binary might have been built with + // column numbers disabled, for example. + bool ProfileHasColumns = false; + // Build maps of the location hash to all profile data with that leaf location // (allocation info and the callsites). std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo; @@ -722,21 +732,22 @@ static void readMemprof(Module &M, Function &F, // the frame array (see comments below where the map entries are added). 
std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>> LocHashToCallSites; - const auto MemProfRec = std::move(MemProfResult.get()); - for (auto &AI : MemProfRec.AllocSites) { + for (auto &AI : MemProfRec->AllocSites) { // Associate the allocation info with the leaf frame. The later matching // code will match any inlined call sequences in the IR with a longer prefix // of call stack frames. uint64_t StackId = computeStackId(AI.CallStack[0]); LocHashToAllocInfo[StackId].insert(&AI); + ProfileHasColumns |= AI.CallStack[0].Column; } - for (auto &CS : MemProfRec.CallSites) { + for (auto &CS : MemProfRec->CallSites) { // Need to record all frames from leaf up to and including this function, // as any of these may or may not have been inlined at this point. unsigned Idx = 0; for (auto &StackFrame : CS) { uint64_t StackId = computeStackId(StackFrame); LocHashToCallSites[StackId].insert(std::make_pair(&CS, Idx++)); + ProfileHasColumns |= StackFrame.Column; // Once we find this function, we can stop recording. if (StackFrame.Function == FuncGUID) break; @@ -785,21 +796,21 @@ static void readMemprof(Module &M, Function &F, if (Name.empty()) Name = DIL->getScope()->getSubprogram()->getName(); auto CalleeGUID = Function::getGUID(Name); - auto StackId = - computeStackId(CalleeGUID, GetOffset(DIL), DIL->getColumn()); - // LeafFound will only be false on the first iteration, since we either - // set it true or break out of the loop below. + auto StackId = computeStackId(CalleeGUID, GetOffset(DIL), + ProfileHasColumns ? DIL->getColumn() : 0); + // Check if we have found the profile's leaf frame. If yes, collect + // the rest of the call's inlined context starting here. If not, see if + // we find a match further up the inlined context (in case the profile + // was missing debug frames at the leaf). 
if (!LeafFound) { AllocInfoIter = LocHashToAllocInfo.find(StackId); CallSitesIter = LocHashToCallSites.find(StackId); - // Check if the leaf is in one of the maps. If not, no need to look - // further at this call. - if (AllocInfoIter == LocHashToAllocInfo.end() && - CallSitesIter == LocHashToCallSites.end()) - break; - LeafFound = true; + if (AllocInfoIter != LocHashToAllocInfo.end() || + CallSitesIter != LocHashToCallSites.end()) + LeafFound = true; } - InlinedCallStack.push_back(StackId); + if (LeafFound) + InlinedCallStack.push_back(StackId); } // If leaf not in either of the maps, skip inst. if (!LeafFound) diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 83d90049abc3..94af63da38c8 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -152,7 +152,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -550,6 +549,7 @@ public: private: friend struct MemorySanitizerVisitor; + friend struct VarArgHelperBase; friend struct VarArgAMD64Helper; friend struct VarArgMIPS64Helper; friend struct VarArgAArch64Helper; @@ -574,8 +574,9 @@ private: Triple TargetTriple; LLVMContext *C; - Type *IntptrTy; + Type *IntptrTy; ///< Integer type with the size of a ptr in default AS. Type *OriginTy; + PointerType *PtrTy; ///< Integer type with the size of a ptr in default AS. // XxxTLS variables represent the per-thread state in MSan and per-task state // in KMSAN. @@ -595,16 +596,13 @@ private: /// Thread-local origin storage for function return value. Value *RetvalOriginTLS; - /// Thread-local shadow storage for in-register va_arg function - /// parameters (x86_64-specific). 
+ /// Thread-local shadow storage for in-register va_arg function. Value *VAArgTLS; - /// Thread-local shadow storage for in-register va_arg function - /// parameters (x86_64-specific). + /// Thread-local shadow storage for in-register va_arg function. Value *VAArgOriginTLS; - /// Thread-local shadow storage for va_arg overflow area - /// (x86_64-specific). + /// Thread-local shadow storage for va_arg overflow area. Value *VAArgOverflowSizeTLS; /// Are the instrumentation callbacks set up? @@ -823,11 +821,10 @@ void MemorySanitizer::createKernelApi(Module &M, const TargetLibraryInfo &TLI) { PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty()); // Functions for poisoning and unpoisoning memory. - MsanPoisonAllocaFn = - M.getOrInsertFunction("__msan_poison_alloca", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy, IRB.getInt8PtrTy()); + MsanPoisonAllocaFn = M.getOrInsertFunction( + "__msan_poison_alloca", IRB.getVoidTy(), PtrTy, IntptrTy, PtrTy); MsanUnpoisonAllocaFn = M.getOrInsertFunction( - "__msan_unpoison_alloca", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy); + "__msan_unpoison_alloca", IRB.getVoidTy(), PtrTy, IntptrTy); } static Constant *getOrInsertGlobal(Module &M, StringRef Name, Type *Ty) { @@ -894,18 +891,18 @@ void MemorySanitizer::createUserspaceApi(Module &M, const TargetLibraryInfo &TLI FunctionName = "__msan_maybe_store_origin_" + itostr(AccessSize); MaybeStoreOriginFn[AccessSizeIndex] = M.getOrInsertFunction( FunctionName, TLI.getAttrList(C, {0, 2}, /*Signed=*/false), - IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), IRB.getInt8PtrTy(), + IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), PtrTy, IRB.getInt32Ty()); } - MsanSetAllocaOriginWithDescriptionFn = M.getOrInsertFunction( - "__msan_set_alloca_origin_with_descr", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy, IRB.getInt8PtrTy(), IRB.getInt8PtrTy()); - MsanSetAllocaOriginNoDescriptionFn = M.getOrInsertFunction( - "__msan_set_alloca_origin_no_descr", IRB.getVoidTy(), IRB.getInt8PtrTy(), - 
IntptrTy, IRB.getInt8PtrTy()); - MsanPoisonStackFn = M.getOrInsertFunction( - "__msan_poison_stack", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy); + MsanSetAllocaOriginWithDescriptionFn = + M.getOrInsertFunction("__msan_set_alloca_origin_with_descr", + IRB.getVoidTy(), PtrTy, IntptrTy, PtrTy, PtrTy); + MsanSetAllocaOriginNoDescriptionFn = + M.getOrInsertFunction("__msan_set_alloca_origin_no_descr", + IRB.getVoidTy(), PtrTy, IntptrTy, PtrTy); + MsanPoisonStackFn = M.getOrInsertFunction("__msan_poison_stack", + IRB.getVoidTy(), PtrTy, IntptrTy); } /// Insert extern declaration of runtime-provided functions and globals. @@ -923,16 +920,14 @@ void MemorySanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo &TL IRB.getInt32Ty()); MsanSetOriginFn = M.getOrInsertFunction( "__msan_set_origin", TLI.getAttrList(C, {2}, /*Signed=*/false), - IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, IRB.getInt32Ty()); + IRB.getVoidTy(), PtrTy, IntptrTy, IRB.getInt32Ty()); MemmoveFn = - M.getOrInsertFunction("__msan_memmove", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); + M.getOrInsertFunction("__msan_memmove", PtrTy, PtrTy, PtrTy, IntptrTy); MemcpyFn = - M.getOrInsertFunction("__msan_memcpy", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); - MemsetFn = M.getOrInsertFunction( - "__msan_memset", TLI.getAttrList(C, {1}, /*Signed=*/true), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); + M.getOrInsertFunction("__msan_memcpy", PtrTy, PtrTy, PtrTy, IntptrTy); + MemsetFn = M.getOrInsertFunction("__msan_memset", + TLI.getAttrList(C, {1}, /*Signed=*/true), + PtrTy, PtrTy, IRB.getInt32Ty(), IntptrTy); MsanInstrumentAsmStoreFn = M.getOrInsertFunction("__msan_instrument_asm_store", IRB.getVoidTy(), @@ -1046,6 +1041,7 @@ void MemorySanitizer::initializeModule(Module &M) { IRBuilder<> IRB(*C); IntptrTy = IRB.getIntPtrTy(DL); OriginTy = IRB.getInt32Ty(); + PtrTy = IRB.getPtrTy(); ColdCallWeights = 
MDBuilder(*C).createBranchWeights(1, 1000); OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000); @@ -1304,9 +1300,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { FunctionCallee Fn = MS.MaybeStoreOriginFn[SizeIndex]; Value *ConvertedShadow2 = IRB.CreateZExt(ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); - CallBase *CB = IRB.CreateCall( - Fn, {ConvertedShadow2, - IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), Origin}); + CallBase *CB = IRB.CreateCall(Fn, {ConvertedShadow2, Addr, Origin}); CB->addParamAttr(0, Attribute::ZExt); CB->addParamAttr(2, Attribute::ZExt); } else { @@ -1676,7 +1670,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { VectTy->getElementCount()); } assert(IntPtrTy == MS.IntptrTy); - return ShadowTy->getPointerTo(); + return PointerType::get(*MS.C, 0); } Constant *constToIntPtr(Type *IntPtrTy, uint64_t C) const { @@ -1718,6 +1712,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { std::pair<Value *, Value *> getShadowOriginPtrUserspace(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, MaybeAlign Alignment) { + VectorType *VectTy = dyn_cast<VectorType>(Addr->getType()); + if (!VectTy) { + assert(Addr->getType()->isPointerTy()); + } else { + assert(VectTy->getElementType()->isPointerTy()); + } Type *IntptrTy = ptrToIntPtrType(Addr->getType()); Value *ShadowOffset = getShadowPtrOffset(Addr, IRB); Value *ShadowLong = ShadowOffset; @@ -1800,11 +1800,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // TODO: Support callbacs with vectors of addresses. 
unsigned NumElements = cast<FixedVectorType>(VectTy)->getNumElements(); Value *ShadowPtrs = ConstantInt::getNullValue( - FixedVectorType::get(ShadowTy->getPointerTo(), NumElements)); + FixedVectorType::get(IRB.getPtrTy(), NumElements)); Value *OriginPtrs = nullptr; if (MS.TrackOrigins) OriginPtrs = ConstantInt::getNullValue( - FixedVectorType::get(MS.OriginTy->getPointerTo(), NumElements)); + FixedVectorType::get(IRB.getPtrTy(), NumElements)); for (unsigned i = 0; i < NumElements; ++i) { Value *OneAddr = IRB.CreateExtractElement(Addr, ConstantInt::get(IRB.getInt32Ty(), i)); @@ -1832,33 +1832,30 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Compute the shadow address for a given function argument. /// /// Shadow = ParamTLS+ArgOffset. - Value *getShadowPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) { + Value *getShadowPtrForArgument(IRBuilder<> &IRB, int ArgOffset) { Value *Base = IRB.CreatePointerCast(MS.ParamTLS, MS.IntptrTy); if (ArgOffset) Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(getShadowTy(A), 0), - "_msarg"); + return IRB.CreateIntToPtr(Base, IRB.getPtrTy(0), "_msarg"); } /// Compute the origin address for a given function argument. - Value *getOriginPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) { + Value *getOriginPtrForArgument(IRBuilder<> &IRB, int ArgOffset) { if (!MS.TrackOrigins) return nullptr; Value *Base = IRB.CreatePointerCast(MS.ParamOriginTLS, MS.IntptrTy); if (ArgOffset) Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0), - "_msarg_o"); + return IRB.CreateIntToPtr(Base, IRB.getPtrTy(0), "_msarg_o"); } /// Compute the shadow address for a retval. 
- Value *getShadowPtrForRetval(Value *A, IRBuilder<> &IRB) { - return IRB.CreatePointerCast(MS.RetvalTLS, - PointerType::get(getShadowTy(A), 0), "_msret"); + Value *getShadowPtrForRetval(IRBuilder<> &IRB) { + return IRB.CreatePointerCast(MS.RetvalTLS, IRB.getPtrTy(0), "_msret"); } /// Compute the origin address for a retval. - Value *getOriginPtrForRetval(IRBuilder<> &IRB) { + Value *getOriginPtrForRetval() { // We keep a single origin for the entire retval. Might be too optimistic. return MS.RetvalOriginTLS; } @@ -1982,7 +1979,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { CpShadowPtr, Constant::getNullValue(EntryIRB.getInt8Ty()), Size, ArgAlign); } else { - Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset); + Value *Base = getShadowPtrForArgument(EntryIRB, ArgOffset); const Align CopyAlign = std::min(ArgAlign, kShadowTLSAlignment); Value *Cpy = EntryIRB.CreateMemCpy(CpShadowPtr, CopyAlign, Base, CopyAlign, Size); @@ -1991,7 +1988,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins) { Value *OriginPtr = - getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset); + getOriginPtrForArgument(EntryIRB, ArgOffset); // FIXME: OriginSize should be: // alignTo(V % kMinOriginAlignment + Size, kMinOriginAlignment) unsigned OriginSize = alignTo(Size, kMinOriginAlignment); @@ -2010,12 +2007,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(A, getCleanOrigin()); } else { // Shadow over TLS - Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset); + Value *Base = getShadowPtrForArgument(EntryIRB, ArgOffset); ShadowPtr = EntryIRB.CreateAlignedLoad(getShadowTy(&FArg), Base, kShadowTLSAlignment); if (MS.TrackOrigins) { Value *OriginPtr = - getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset); + getOriginPtrForArgument(EntryIRB, ArgOffset); setOrigin(A, EntryIRB.CreateLoad(MS.OriginTy, OriginPtr)); } } @@ -2838,11 +2835,9 @@ struct 
MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitMemMoveInst(MemMoveInst &I) { getShadow(I.getArgOperand(1)); // Ensure shadow initialized IRBuilder<> IRB(&I); - IRB.CreateCall( - MS.MemmoveFn, - {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(I.getArgOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)}); + IRB.CreateCall(MS.MemmoveFn, + {I.getArgOperand(0), I.getArgOperand(1), + IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)}); I.eraseFromParent(); } @@ -2863,11 +2858,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitMemCpyInst(MemCpyInst &I) { getShadow(I.getArgOperand(1)); // Ensure shadow initialized IRBuilder<> IRB(&I); - IRB.CreateCall( - MS.MemcpyFn, - {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(I.getArgOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)}); + IRB.CreateCall(MS.MemcpyFn, + {I.getArgOperand(0), I.getArgOperand(1), + IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)}); I.eraseFromParent(); } @@ -2876,7 +2869,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); IRB.CreateCall( MS.MemsetFn, - {IRB.CreatePointerCast(I.getArgOperand(0), IRB.getInt8PtrTy()), + {I.getArgOperand(0), IRB.CreateIntCast(I.getArgOperand(1), IRB.getInt32Ty(), false), IRB.CreateIntCast(I.getArgOperand(2), MS.IntptrTy, false)}); I.eraseFromParent(); @@ -3385,8 +3378,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *ShadowPtr = getShadowOriginPtr(Addr, IRB, Ty, Align(1), /*isStore*/ true).first; - IRB.CreateStore(getCleanShadow(Ty), - IRB.CreatePointerCast(ShadowPtr, Ty->getPointerTo())); + IRB.CreateStore(getCleanShadow(Ty), ShadowPtr); if (ClCheckAccessAddress) insertShadowCheck(Addr, &I); @@ -4162,7 +4154,7 @@ struct 
MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (Function *Func = CB.getCalledFunction()) { // __sanitizer_unaligned_{load,store} functions may be called by users // and always expects shadows in the TLS. So don't check them. - MayCheckCall &= !Func->getName().startswith("__sanitizer_unaligned_"); + MayCheckCall &= !Func->getName().starts_with("__sanitizer_unaligned_"); } unsigned ArgOffset = 0; @@ -4188,7 +4180,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // in that case getShadow() will copy the actual arg shadow to // __msan_param_tls. Value *ArgShadow = getShadow(A); - Value *ArgShadowBase = getShadowPtrForArgument(A, IRB, ArgOffset); + Value *ArgShadowBase = getShadowPtrForArgument(IRB, ArgOffset); LLVM_DEBUG(dbgs() << " Arg#" << i << ": " << *A << " Shadow: " << *ArgShadow << "\n"); if (ByVal) { @@ -4215,7 +4207,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Store = IRB.CreateMemCpy(ArgShadowBase, Alignment, AShadowPtr, Alignment, Size); if (MS.TrackOrigins) { - Value *ArgOriginBase = getOriginPtrForArgument(A, IRB, ArgOffset); + Value *ArgOriginBase = getOriginPtrForArgument(IRB, ArgOffset); // FIXME: OriginSize should be: // alignTo(A % kMinOriginAlignment + Size, kMinOriginAlignment) unsigned OriginSize = alignTo(Size, kMinOriginAlignment); @@ -4237,7 +4229,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Constant *Cst = dyn_cast<Constant>(ArgShadow); if (MS.TrackOrigins && !(Cst && Cst->isNullValue())) { IRB.CreateStore(getOrigin(A), - getOriginPtrForArgument(A, IRB, ArgOffset)); + getOriginPtrForArgument(IRB, ArgOffset)); } } (void)Store; @@ -4269,7 +4261,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRBBefore(&CB); // Until we have full dynamic coverage, make sure the retval shadow is 0. 
- Value *Base = getShadowPtrForRetval(&CB, IRBBefore); + Value *Base = getShadowPtrForRetval(IRBBefore); IRBBefore.CreateAlignedStore(getCleanShadow(&CB), Base, kShadowTLSAlignment); BasicBlock::iterator NextInsn; @@ -4294,12 +4286,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } IRBuilder<> IRBAfter(&*NextInsn); Value *RetvalShadow = IRBAfter.CreateAlignedLoad( - getShadowTy(&CB), getShadowPtrForRetval(&CB, IRBAfter), + getShadowTy(&CB), getShadowPtrForRetval(IRBAfter), kShadowTLSAlignment, "_msret"); setShadow(&CB, RetvalShadow); if (MS.TrackOrigins) setOrigin(&CB, IRBAfter.CreateLoad(MS.OriginTy, - getOriginPtrForRetval(IRBAfter))); + getOriginPtrForRetval())); } bool isAMustTailRetVal(Value *RetVal) { @@ -4320,7 +4312,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Don't emit the epilogue for musttail call returns. if (isAMustTailRetVal(RetVal)) return; - Value *ShadowPtr = getShadowPtrForRetval(RetVal, IRB); + Value *ShadowPtr = getShadowPtrForRetval(IRB); bool HasNoUndef = F.hasRetAttribute(Attribute::NoUndef); bool StoreShadow = !(MS.EagerChecks && HasNoUndef); // FIXME: Consider using SpecialCaseList to specify a list of functions that @@ -4340,7 +4332,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (StoreShadow) { IRB.CreateAlignedStore(Shadow, ShadowPtr, kShadowTLSAlignment); if (MS.TrackOrigins && StoreOrigin) - IRB.CreateStore(getOrigin(RetVal), getOriginPtrForRetval(IRB)); + IRB.CreateStore(getOrigin(RetVal), getOriginPtrForRetval()); } } @@ -4374,8 +4366,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void poisonAllocaUserspace(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { if (PoisonStack && ClPoisonStackWithCall) { - IRB.CreateCall(MS.MsanPoisonStackFn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); + IRB.CreateCall(MS.MsanPoisonStackFn, {&I, Len}); } else { Value *ShadowBase, *OriginBase; 
std::tie(ShadowBase, OriginBase) = getShadowOriginPtr( @@ -4390,13 +4381,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ClPrintStackNames) { Value *Descr = getLocalVarDescription(I); IRB.CreateCall(MS.MsanSetAllocaOriginWithDescriptionFn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, - IRB.CreatePointerCast(Idptr, IRB.getInt8PtrTy()), - IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy())}); + {&I, Len, Idptr, Descr}); } else { - IRB.CreateCall(MS.MsanSetAllocaOriginNoDescriptionFn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, - IRB.CreatePointerCast(Idptr, IRB.getInt8PtrTy())}); + IRB.CreateCall(MS.MsanSetAllocaOriginNoDescriptionFn, {&I, Len, Idptr}); } } } @@ -4404,12 +4391,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void poisonAllocaKmsan(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { Value *Descr = getLocalVarDescription(I); if (PoisonStack) { - IRB.CreateCall(MS.MsanPoisonAllocaFn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, - IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy())}); + IRB.CreateCall(MS.MsanPoisonAllocaFn, {&I, Len, Descr}); } else { - IRB.CreateCall(MS.MsanUnpoisonAllocaFn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); + IRB.CreateCall(MS.MsanUnpoisonAllocaFn, {&I, Len}); } } @@ -4571,10 +4555,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } if (!ElemTy->isSized()) return; - Value *Ptr = IRB.CreatePointerCast(Operand, IRB.getInt8PtrTy()); Value *SizeVal = IRB.CreateTypeSize(MS.IntptrTy, DL.getTypeStoreSize(ElemTy)); - IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Ptr, SizeVal}); + IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Operand, SizeVal}); } /// Get the number of output arguments returned by pointers. 
@@ -4668,8 +4651,91 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } }; +struct VarArgHelperBase : public VarArgHelper { + Function &F; + MemorySanitizer &MS; + MemorySanitizerVisitor &MSV; + SmallVector<CallInst *, 16> VAStartInstrumentationList; + const unsigned VAListTagSize; + + VarArgHelperBase(Function &F, MemorySanitizer &MS, + MemorySanitizerVisitor &MSV, unsigned VAListTagSize) + : F(F), MS(MS), MSV(MSV), VAListTagSize(VAListTagSize) {} + + Value *getShadowAddrForVAArgument(IRBuilder<> &IRB, unsigned ArgOffset) { + Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); + return IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + } + + /// Compute the shadow address for a given va_arg. + Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, + unsigned ArgOffset) { + Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); + Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), + "_msarg_va_s"); + } + + /// Compute the shadow address for a given va_arg. + Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, + unsigned ArgOffset, unsigned ArgSize) { + // Make sure we don't overflow __msan_va_arg_tls. + if (ArgOffset + ArgSize > kParamTLSSize) + return nullptr; + return getShadowPtrForVAArgument(Ty, IRB, ArgOffset); + } + + /// Compute the origin address for a given va_arg. + Value *getOriginPtrForVAArgument(IRBuilder<> &IRB, int ArgOffset) { + Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy); + // getOriginPtrForVAArgument() is always called after + // getShadowPtrForVAArgument(), so __msan_va_arg_origin_tls can never + // overflow. 
+ Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0), + "_msarg_va_o"); + } + + void CleanUnusedTLS(IRBuilder<> &IRB, Value *ShadowBase, + unsigned BaseOffset) { + // The tails of __msan_va_arg_tls is not large enough to fit full + // value shadow, but it will be copied to backup anyway. Make it + // clean. + if (BaseOffset >= kParamTLSSize) + return; + Value *TailSize = + ConstantInt::getSigned(IRB.getInt32Ty(), kParamTLSSize - BaseOffset); + IRB.CreateMemSet(ShadowBase, ConstantInt::getNullValue(IRB.getInt8Ty()), + TailSize, Align(8)); + } + + void unpoisonVAListTagForInst(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *VAListTag = I.getArgOperand(0); + const Align Alignment = Align(8); + auto [ShadowPtr, OriginPtr] = MSV.getShadowOriginPtr( + VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); + // Unpoison the whole __va_list_tag. + IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), + VAListTagSize, Alignment, false); + } + + void visitVAStartInst(VAStartInst &I) override { + if (F.getCallingConv() == CallingConv::Win64) + return; + VAStartInstrumentationList.push_back(&I); + unpoisonVAListTagForInst(I); + } + + void visitVACopyInst(VACopyInst &I) override { + if (F.getCallingConv() == CallingConv::Win64) + return; + unpoisonVAListTagForInst(I); + } +}; + /// AMD64-specific implementation of VarArgHelper. -struct VarArgAMD64Helper : public VarArgHelper { +struct VarArgAMD64Helper : public VarArgHelperBase { // An unfortunate workaround for asymmetric lowering of va_arg stuff. // See a comment in visitCallBase for more details. 
static const unsigned AMD64GpEndOffset = 48; // AMD64 ABI Draft 0.99.6 p3.5.7 @@ -4678,20 +4744,15 @@ struct VarArgAMD64Helper : public VarArgHelper { static const unsigned AMD64FpEndOffsetNoSSE = AMD64GpEndOffset; unsigned AMD64FpEndOffset; - Function &F; - MemorySanitizer &MS; - MemorySanitizerVisitor &MSV; AllocaInst *VAArgTLSCopy = nullptr; AllocaInst *VAArgTLSOriginCopy = nullptr; Value *VAArgOverflowSize = nullptr; - SmallVector<CallInst *, 16> VAStartInstrumentationList; - enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory }; VarArgAMD64Helper(Function &F, MemorySanitizer &MS, MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV) { + : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/24) { AMD64FpEndOffset = AMD64FpEndOffsetSSE; for (const auto &Attr : F.getAttributes().getFnAttrs()) { if (Attr.isStringAttribute() && @@ -4706,6 +4767,8 @@ struct VarArgAMD64Helper : public VarArgHelper { ArgKind classifyArgument(Value *arg) { // A very rough approximation of X86_64 argument classification rules. 
Type *T = arg->getType(); + if (T->isX86_FP80Ty()) + return AK_Memory; if (T->isFPOrFPVectorTy() || T->isX86_MMXTy()) return AK_FloatingPoint; if (T->isIntegerTy() && T->getPrimitiveSizeInBits() <= 64) @@ -4728,6 +4791,7 @@ struct VarArgAMD64Helper : public VarArgHelper { unsigned FpOffset = AMD64GpEndOffset; unsigned OverflowOffset = AMD64FpEndOffset; const DataLayout &DL = F.getParent()->getDataLayout(); + for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); bool IsByVal = CB.paramHasAttr(ArgNo, Attribute::ByVal); @@ -4740,19 +4804,24 @@ struct VarArgAMD64Helper : public VarArgHelper { assert(A->getType()->isPointerTy()); Type *RealTy = CB.getParamByValType(ArgNo); uint64_t ArgSize = DL.getTypeAllocSize(RealTy); - Value *ShadowBase = getShadowPtrForVAArgument( - RealTy, IRB, OverflowOffset, alignTo(ArgSize, 8)); + uint64_t AlignedSize = alignTo(ArgSize, 8); + unsigned BaseOffset = OverflowOffset; + Value *ShadowBase = + getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); Value *OriginBase = nullptr; if (MS.TrackOrigins) - OriginBase = getOriginPtrForVAArgument(RealTy, IRB, OverflowOffset); - OverflowOffset += alignTo(ArgSize, 8); - if (!ShadowBase) - continue; + OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset); + OverflowOffset += AlignedSize; + + if (OverflowOffset > kParamTLSSize) { + CleanUnusedTLS(IRB, ShadowBase, BaseOffset); + continue; // We have no space to copy shadow there. 
+ } + Value *ShadowPtr, *OriginPtr; std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment, /*isStore*/ false); - IRB.CreateMemCpy(ShadowBase, kShadowTLSAlignment, ShadowPtr, kShadowTLSAlignment, ArgSize); if (MS.TrackOrigins) @@ -4767,37 +4836,42 @@ struct VarArgAMD64Helper : public VarArgHelper { Value *ShadowBase, *OriginBase = nullptr; switch (AK) { case AK_GeneralPurpose: - ShadowBase = - getShadowPtrForVAArgument(A->getType(), IRB, GpOffset, 8); + ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, GpOffset); if (MS.TrackOrigins) - OriginBase = getOriginPtrForVAArgument(A->getType(), IRB, GpOffset); + OriginBase = getOriginPtrForVAArgument(IRB, GpOffset); GpOffset += 8; + assert(GpOffset <= kParamTLSSize); break; case AK_FloatingPoint: - ShadowBase = - getShadowPtrForVAArgument(A->getType(), IRB, FpOffset, 16); + ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, FpOffset); if (MS.TrackOrigins) - OriginBase = getOriginPtrForVAArgument(A->getType(), IRB, FpOffset); + OriginBase = getOriginPtrForVAArgument(IRB, FpOffset); FpOffset += 16; + assert(FpOffset <= kParamTLSSize); break; case AK_Memory: if (IsFixed) continue; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); + uint64_t AlignedSize = alignTo(ArgSize, 8); + unsigned BaseOffset = OverflowOffset; ShadowBase = - getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, 8); - if (MS.TrackOrigins) - OriginBase = - getOriginPtrForVAArgument(A->getType(), IRB, OverflowOffset); - OverflowOffset += alignTo(ArgSize, 8); + getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); + if (MS.TrackOrigins) { + OriginBase = getOriginPtrForVAArgument(IRB, OverflowOffset); + } + OverflowOffset += AlignedSize; + if (OverflowOffset > kParamTLSSize) { + // We have no space to copy shadow there. 
+ CleanUnusedTLS(IRB, ShadowBase, BaseOffset); + continue; + } } // Take fixed arguments into account for GpOffset and FpOffset, // but don't actually store shadows for them. // TODO(glider): don't call get*PtrForVAArgument() for them. if (IsFixed) continue; - if (!ShadowBase) - continue; Value *Shadow = MSV.getShadow(A); IRB.CreateAlignedStore(Shadow, ShadowBase, kShadowTLSAlignment); if (MS.TrackOrigins) { @@ -4813,59 +4887,6 @@ struct VarArgAMD64Helper : public VarArgHelper { IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS); } - /// Compute the shadow address for a given va_arg. - Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - unsigned ArgOffset, unsigned ArgSize) { - // Make sure we don't overflow __msan_va_arg_tls. - if (ArgOffset + ArgSize > kParamTLSSize) - return nullptr; - Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), - "_msarg_va_s"); - } - - /// Compute the origin address for a given va_arg. - Value *getOriginPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, int ArgOffset) { - Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy); - // getOriginPtrForVAArgument() is always called after - // getShadowPtrForVAArgument(), so __msan_va_arg_origin_tls can never - // overflow. - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0), - "_msarg_va_o"); - } - - void unpoisonVAListTagForInst(IntrinsicInst &I) { - IRBuilder<> IRB(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment, - /*isStore*/ true); - - // Unpoison the whole __va_list_tag. - // FIXME: magic ABI constants. 
- IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */ 24, Alignment, false); - // We shouldn't need to zero out the origins, as they're only checked for - // nonzero shadow. - } - - void visitVAStartInst(VAStartInst &I) override { - if (F.getCallingConv() == CallingConv::Win64) - return; - VAStartInstrumentationList.push_back(&I); - unpoisonVAListTagForInst(I); - } - - void visitVACopyInst(VACopyInst &I) override { - if (F.getCallingConv() == CallingConv::Win64) - return; - unpoisonVAListTagForInst(I); - } - void finalizeInstrumentation() override { assert(!VAArgOverflowSize && !VAArgTLSCopy && "finalizeInstrumentation called twice"); @@ -4902,7 +4923,7 @@ struct VarArgAMD64Helper : public VarArgHelper { NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); - Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); + Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 16)), @@ -4919,7 +4940,7 @@ struct VarArgAMD64Helper : public VarArgHelper { if (MS.TrackOrigins) IRB.CreateMemCpy(RegSaveAreaOriginPtr, Alignment, VAArgTLSOriginCopy, Alignment, AMD64FpEndOffset); - Type *OverflowArgAreaPtrTy = Type::getInt64PtrTy(*MS.C); + Type *OverflowArgAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 8)), @@ -4945,18 +4966,14 @@ struct VarArgAMD64Helper : public VarArgHelper { }; /// MIPS64-specific implementation of VarArgHelper. -struct VarArgMIPS64Helper : public VarArgHelper { - Function &F; - MemorySanitizer &MS; - MemorySanitizerVisitor &MSV; +/// NOTE: This is also used for LoongArch64. 
+struct VarArgMIPS64Helper : public VarArgHelperBase { AllocaInst *VAArgTLSCopy = nullptr; Value *VAArgSize = nullptr; - SmallVector<CallInst *, 16> VAStartInstrumentationList; - VarArgMIPS64Helper(Function &F, MemorySanitizer &MS, MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV) {} + : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/8) {} void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override { unsigned VAArgOffset = 0; @@ -4986,42 +5003,6 @@ struct VarArgMIPS64Helper : public VarArgHelper { IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS); } - /// Compute the shadow address for a given va_arg. - Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - unsigned ArgOffset, unsigned ArgSize) { - // Make sure we don't overflow __msan_va_arg_tls. - if (ArgOffset + ArgSize > kParamTLSSize) - return nullptr; - Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), - "_msarg"); - } - - void visitVAStartInst(VAStartInst &I) override { - IRBuilder<> IRB(&I); - VAStartInstrumentationList.push_back(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( - VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */ 8, Alignment, false); - } - - void visitVACopyInst(VACopyInst &I) override { - IRBuilder<> IRB(&I); - VAStartInstrumentationList.push_back(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( - VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */ 8, Alignment, false); - } - 
void finalizeInstrumentation() override { assert(!VAArgSize && !VAArgTLSCopy && "finalizeInstrumentation called twice"); @@ -5051,7 +5032,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { CallInst *OrigInst = VAStartInstrumentationList[i]; NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); - Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); + Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), PointerType::get(RegSaveAreaPtrTy, 0)); @@ -5069,7 +5050,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { }; /// AArch64-specific implementation of VarArgHelper. -struct VarArgAArch64Helper : public VarArgHelper { +struct VarArgAArch64Helper : public VarArgHelperBase { static const unsigned kAArch64GrArgSize = 64; static const unsigned kAArch64VrArgSize = 128; @@ -5081,28 +5062,36 @@ struct VarArgAArch64Helper : public VarArgHelper { AArch64VrBegOffset + kAArch64VrArgSize; static const unsigned AArch64VAEndOffset = AArch64VrEndOffset; - Function &F; - MemorySanitizer &MS; - MemorySanitizerVisitor &MSV; AllocaInst *VAArgTLSCopy = nullptr; Value *VAArgOverflowSize = nullptr; - SmallVector<CallInst *, 16> VAStartInstrumentationList; - enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory }; VarArgAArch64Helper(Function &F, MemorySanitizer &MS, MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV) {} + : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/32) {} - ArgKind classifyArgument(Value *arg) { - Type *T = arg->getType(); - if (T->isFPOrFPVectorTy()) - return AK_FloatingPoint; - if ((T->isIntegerTy() && T->getPrimitiveSizeInBits() <= 64) || - (T->isPointerTy())) - return AK_GeneralPurpose; - return AK_Memory; + // A very rough approximation of aarch64 argument classification rules. 
+ std::pair<ArgKind, uint64_t> classifyArgument(Type *T) { + if (T->isIntOrPtrTy() && T->getPrimitiveSizeInBits() <= 64) + return {AK_GeneralPurpose, 1}; + if (T->isFloatingPointTy() && T->getPrimitiveSizeInBits() <= 128) + return {AK_FloatingPoint, 1}; + + if (T->isArrayTy()) { + auto R = classifyArgument(T->getArrayElementType()); + R.second *= T->getScalarType()->getArrayNumElements(); + return R; + } + + if (const FixedVectorType *FV = dyn_cast<FixedVectorType>(T)) { + auto R = classifyArgument(FV->getScalarType()); + R.second *= FV->getNumElements(); + return R; + } + + LLVM_DEBUG(errs() << "Unknown vararg type: " << *T << "\n"); + return {AK_Memory, 0}; } // The instrumentation stores the argument shadow in a non ABI-specific @@ -5110,7 +5099,7 @@ struct VarArgAArch64Helper : public VarArgHelper { // like x86_64 case, lowers the va_args in the frontend and this pass only // sees the low level code that deals with va_list internals). // The first seven GR registers are saved in the first 56 bytes of the - // va_arg tls arra, followers by the first 8 FP/SIMD registers, and then + // va_arg tls arra, followed by the first 8 FP/SIMD registers, and then // the remaining arguments. // Using constant offset within the va_arg TLS array allows fast copy // in the finalize instrumentation. 
@@ -5122,20 +5111,22 @@ struct VarArgAArch64Helper : public VarArgHelper { const DataLayout &DL = F.getParent()->getDataLayout(); for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); - ArgKind AK = classifyArgument(A); - if (AK == AK_GeneralPurpose && GrOffset >= AArch64GrEndOffset) + auto [AK, RegNum] = classifyArgument(A->getType()); + if (AK == AK_GeneralPurpose && + (GrOffset + RegNum * 8) > AArch64GrEndOffset) AK = AK_Memory; - if (AK == AK_FloatingPoint && VrOffset >= AArch64VrEndOffset) + if (AK == AK_FloatingPoint && + (VrOffset + RegNum * 16) > AArch64VrEndOffset) AK = AK_Memory; Value *Base; switch (AK) { case AK_GeneralPurpose: - Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset, 8); - GrOffset += 8; + Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset); + GrOffset += 8 * RegNum; break; case AK_FloatingPoint: - Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset, 8); - VrOffset += 16; + Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset); + VrOffset += 16 * RegNum; break; case AK_Memory: // Don't count fixed arguments in the overflow area - va_start will @@ -5143,17 +5134,21 @@ struct VarArgAArch64Helper : public VarArgHelper { if (IsFixed) continue; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); - Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, - alignTo(ArgSize, 8)); - OverflowOffset += alignTo(ArgSize, 8); + uint64_t AlignedSize = alignTo(ArgSize, 8); + unsigned BaseOffset = OverflowOffset; + Base = getShadowPtrForVAArgument(A->getType(), IRB, BaseOffset); + OverflowOffset += AlignedSize; + if (OverflowOffset > kParamTLSSize) { + // We have no space to copy shadow there. + CleanUnusedTLS(IRB, Base, BaseOffset); + continue; + } break; } // Count Gp/Vr fixed arguments to their respective offsets, but don't // bother to actually store a shadow. 
if (IsFixed) continue; - if (!Base) - continue; IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } Constant *OverflowSize = @@ -5161,48 +5156,12 @@ struct VarArgAArch64Helper : public VarArgHelper { IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS); } - /// Compute the shadow address for a given va_arg. - Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - unsigned ArgOffset, unsigned ArgSize) { - // Make sure we don't overflow __msan_va_arg_tls. - if (ArgOffset + ArgSize > kParamTLSSize) - return nullptr; - Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), - "_msarg"); - } - - void visitVAStartInst(VAStartInst &I) override { - IRBuilder<> IRB(&I); - VAStartInstrumentationList.push_back(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( - VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */ 32, Alignment, false); - } - - void visitVACopyInst(VACopyInst &I) override { - IRBuilder<> IRB(&I); - VAStartInstrumentationList.push_back(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( - VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */ 32, Alignment, false); - } - // Retrieve a va_list field of 'void*' size. 
Value *getVAField64(IRBuilder<> &IRB, Value *VAListTag, int offset) { Value *SaveAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, offset)), - Type::getInt64PtrTy(*MS.C)); + PointerType::get(*MS.C, 0)); return IRB.CreateLoad(Type::getInt64Ty(*MS.C), SaveAreaPtrPtr); } @@ -5211,7 +5170,7 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *SaveAreaPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, offset)), - Type::getInt32PtrTy(*MS.C)); + PointerType::get(*MS.C, 0)); Value *SaveArea32 = IRB.CreateLoad(IRB.getInt32Ty(), SaveAreaPtr); return IRB.CreateSExt(SaveArea32, MS.IntptrTy); } @@ -5262,21 +5221,25 @@ struct VarArgAArch64Helper : public VarArgHelper { // we need to adjust the offset for both GR and VR fields based on // the __{gr,vr}_offs value (since they are stores based on incoming // named arguments). + Type *RegSaveAreaPtrTy = IRB.getPtrTy(); // Read the stack pointer from the va_list. - Value *StackSaveAreaPtr = getVAField64(IRB, VAListTag, 0); + Value *StackSaveAreaPtr = + IRB.CreateIntToPtr(getVAField64(IRB, VAListTag, 0), RegSaveAreaPtrTy); // Read both the __gr_top and __gr_off and add them up. Value *GrTopSaveAreaPtr = getVAField64(IRB, VAListTag, 8); Value *GrOffSaveArea = getVAField32(IRB, VAListTag, 24); - Value *GrRegSaveAreaPtr = IRB.CreateAdd(GrTopSaveAreaPtr, GrOffSaveArea); + Value *GrRegSaveAreaPtr = IRB.CreateIntToPtr( + IRB.CreateAdd(GrTopSaveAreaPtr, GrOffSaveArea), RegSaveAreaPtrTy); // Read both the __vr_top and __vr_off and add them up. 
Value *VrTopSaveAreaPtr = getVAField64(IRB, VAListTag, 16); Value *VrOffSaveArea = getVAField32(IRB, VAListTag, 28); - Value *VrRegSaveAreaPtr = IRB.CreateAdd(VrTopSaveAreaPtr, VrOffSaveArea); + Value *VrRegSaveAreaPtr = IRB.CreateIntToPtr( + IRB.CreateAdd(VrTopSaveAreaPtr, VrOffSaveArea), RegSaveAreaPtrTy); // It does not know how many named arguments is being used and, on the // callsite all the arguments were saved. Since __gr_off is defined as @@ -5332,18 +5295,13 @@ struct VarArgAArch64Helper : public VarArgHelper { }; /// PowerPC64-specific implementation of VarArgHelper. -struct VarArgPowerPC64Helper : public VarArgHelper { - Function &F; - MemorySanitizer &MS; - MemorySanitizerVisitor &MSV; +struct VarArgPowerPC64Helper : public VarArgHelperBase { AllocaInst *VAArgTLSCopy = nullptr; Value *VAArgSize = nullptr; - SmallVector<CallInst *, 16> VAStartInstrumentationList; - VarArgPowerPC64Helper(Function &F, MemorySanitizer &MS, MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV) {} + : VarArgHelperBase(F, MS, MSV, /*VAListTagSize=*/8) {} void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override { // For PowerPC, we need to deal with alignment of stack arguments - @@ -5431,43 +5389,6 @@ struct VarArgPowerPC64Helper : public VarArgHelper { IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS); } - /// Compute the shadow address for a given va_arg. - Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - unsigned ArgOffset, unsigned ArgSize) { - // Make sure we don't overflow __msan_va_arg_tls. 
- if (ArgOffset + ArgSize > kParamTLSSize) - return nullptr; - Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), - "_msarg"); - } - - void visitVAStartInst(VAStartInst &I) override { - IRBuilder<> IRB(&I); - VAStartInstrumentationList.push_back(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( - VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */ 8, Alignment, false); - } - - void visitVACopyInst(VACopyInst &I) override { - IRBuilder<> IRB(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( - VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); - // Unpoison the whole __va_list_tag. - // FIXME: magic ABI constants. - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */ 8, Alignment, false); - } - void finalizeInstrumentation() override { assert(!VAArgSize && !VAArgTLSCopy && "finalizeInstrumentation called twice"); @@ -5498,7 +5419,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { CallInst *OrigInst = VAStartInstrumentationList[i]; NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); - Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); + Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), PointerType::get(RegSaveAreaPtrTy, 0)); @@ -5516,7 +5437,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { }; /// SystemZ-specific implementation of VarArgHelper. 
-struct VarArgSystemZHelper : public VarArgHelper { +struct VarArgSystemZHelper : public VarArgHelperBase { static const unsigned SystemZGpOffset = 16; static const unsigned SystemZGpEndOffset = 56; static const unsigned SystemZFpOffset = 128; @@ -5528,16 +5449,11 @@ struct VarArgSystemZHelper : public VarArgHelper { static const unsigned SystemZOverflowArgAreaPtrOffset = 16; static const unsigned SystemZRegSaveAreaPtrOffset = 24; - Function &F; - MemorySanitizer &MS; - MemorySanitizerVisitor &MSV; bool IsSoftFloatABI; AllocaInst *VAArgTLSCopy = nullptr; AllocaInst *VAArgTLSOriginCopy = nullptr; Value *VAArgOverflowSize = nullptr; - SmallVector<CallInst *, 16> VAStartInstrumentationList; - enum class ArgKind { GeneralPurpose, FloatingPoint, @@ -5550,7 +5466,7 @@ struct VarArgSystemZHelper : public VarArgHelper { VarArgSystemZHelper(Function &F, MemorySanitizer &MS, MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV), + : VarArgHelperBase(F, MS, MSV, SystemZVAListTagSize), IsSoftFloatABI(F.getFnAttribute("use-soft-float").getValueAsBool()) {} ArgKind classifyArgument(Type *T) { @@ -5711,39 +5627,8 @@ struct VarArgSystemZHelper : public VarArgHelper { IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS); } - Value *getShadowAddrForVAArgument(IRBuilder<> &IRB, unsigned ArgOffset) { - Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); - return IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - } - - Value *getOriginPtrForVAArgument(IRBuilder<> &IRB, int ArgOffset) { - Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0), - "_msarg_va_o"); - } - - void unpoisonVAListTagForInst(IntrinsicInst &I) { - IRBuilder<> IRB(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr, *OriginPtr; - const Align Alignment = Align(8); - std::tie(ShadowPtr, OriginPtr) = - 
MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment, - /*isStore*/ true); - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - SystemZVAListTagSize, Alignment, false); - } - - void visitVAStartInst(VAStartInst &I) override { - VAStartInstrumentationList.push_back(&I); - unpoisonVAListTagForInst(I); - } - - void visitVACopyInst(VACopyInst &I) override { unpoisonVAListTagForInst(I); } - void copyRegSaveArea(IRBuilder<> &IRB, Value *VAListTag) { - Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); + Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd( IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), @@ -5767,8 +5652,10 @@ struct VarArgSystemZHelper : public VarArgHelper { Alignment, RegSaveAreaSize); } + // FIXME: This implementation limits OverflowOffset to kParamTLSSize, so we + // don't know real overflow size and can't clear shadow beyond kParamTLSSize. void copyOverflowArea(IRBuilder<> &IRB, Value *VAListTag) { - Type *OverflowArgAreaPtrTy = Type::getInt64PtrTy(*MS.C); + Type *OverflowArgAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd( IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), @@ -5836,6 +5723,10 @@ struct VarArgSystemZHelper : public VarArgHelper { } }; +// Loongarch64 is not a MIPS, but the current vargs calling convention matches +// the MIPS. +using VarArgLoongArch64Helper = VarArgMIPS64Helper; + /// A no-op implementation of VarArgHelper. 
struct VarArgNoOpHelper : public VarArgHelper { VarArgNoOpHelper(Function &F, MemorySanitizer &MS, @@ -5868,6 +5759,8 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, return new VarArgPowerPC64Helper(Func, Msan, Visitor); else if (TargetTriple.getArch() == Triple::systemz) return new VarArgSystemZHelper(Func, Msan, Visitor); + else if (TargetTriple.isLoongArch64()) + return new VarArgLoongArch64Helper(Func, Msan, Visitor); else return new VarArgNoOpHelper(Func, Msan, Visitor); } diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 3c8f25d73c62..4a5a0b25bebb 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -327,7 +327,6 @@ extern cl::opt<PGOViewCountsType> PGOViewCounts; // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= extern cl::opt<std::string> ViewBlockFreqFuncName; -extern cl::opt<bool> DebugInfoCorrelate; } // namespace llvm static cl::opt<bool> @@ -525,6 +524,7 @@ public: std::vector<std::vector<VPCandidateInfo>> ValueSites; SelectInstVisitor SIVisitor; std::string FuncName; + std::string DeprecatedFuncName; GlobalVariable *FuncNameVar; // CFG hash value for this function. 
@@ -582,21 +582,22 @@ public: if (!IsCS) { NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); NumOfPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size(); - NumOfPGOBB += MST.BBInfos.size(); + NumOfPGOBB += MST.bbInfoSize(); ValueSites[IPVK_IndirectCallTarget] = VPC.get(IPVK_IndirectCallTarget); } else { NumOfCSPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); NumOfCSPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size(); - NumOfCSPGOBB += MST.BBInfos.size(); + NumOfCSPGOBB += MST.bbInfoSize(); } - FuncName = getPGOFuncName(F); + FuncName = getIRPGOFuncName(F); + DeprecatedFuncName = getPGOFuncName(F); computeCFGHash(); if (!ComdatMembers.empty()) renameComdatFunction(); LLVM_DEBUG(dumpInfo("after CFGMST")); - for (auto &E : MST.AllEdges) { + for (const auto &E : MST.allEdges()) { if (E->Removed) continue; IsCS ? NumOfCSPGOEdge++ : NumOfPGOEdge++; @@ -639,7 +640,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 | //(uint64_t)ValueSites[IPVK_MemOPSize].size() << 40 | - (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); + (uint64_t)MST.numEdges() << 32 | JC.getCRC(); } else { // The higher 32 bits. auto updateJCH = [&JCH](uint64_t Num) { @@ -653,7 +654,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { if (BCI) { updateJCH(BCI->getInstrumentedBlocksHash()); } else { - updateJCH((uint64_t)MST.AllEdges.size()); + updateJCH((uint64_t)MST.numEdges()); } // Hash format for context sensitive profile. 
Reserve 4 bits for other @@ -668,7 +669,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { LLVM_DEBUG(dbgs() << "Function Hash Computation for " << F.getName() << ":\n" << " CRC = " << JC.getCRC() << ", Selects = " << SIVisitor.getNumOfSelectInsts() - << ", Edges = " << MST.AllEdges.size() << ", ICSites = " + << ", Edges = " << MST.numEdges() << ", ICSites = " << ValueSites[IPVK_IndirectCallTarget].size()); if (!PGOOldCFGHashing) { LLVM_DEBUG(dbgs() << ", Memops = " << ValueSites[IPVK_MemOPSize].size() @@ -756,8 +757,8 @@ void FuncPGOInstrumentation<Edge, BBInfo>::getInstrumentBBs( // Use a worklist as we will update the vector during the iteration. std::vector<Edge *> EdgeList; - EdgeList.reserve(MST.AllEdges.size()); - for (auto &E : MST.AllEdges) + EdgeList.reserve(MST.numEdges()); + for (const auto &E : MST.allEdges()) EdgeList.push_back(E.get()); for (auto &E : EdgeList) { @@ -874,8 +875,7 @@ static void instrumentOneFunc( F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry, PGOBlockCoverage); - Type *I8PtrTy = Type::getInt8PtrTy(M->getContext()); - auto Name = ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy); + auto Name = FuncInfo.FuncNameVar; auto CFGHash = ConstantInt::get(Type::getInt64Ty(M->getContext()), FuncInfo.FunctionHash); if (PGOFunctionEntryCoverage) { @@ -964,9 +964,8 @@ static void instrumentOneFunc( populateEHOperandBundle(Cand, BlockColors, OpBundles); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), - {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), - Builder.getInt64(FuncInfo.FunctionHash), ToProfile, - Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)}, + {FuncInfo.FuncNameVar, Builder.getInt64(FuncInfo.FunctionHash), + ToProfile, Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)}, OpBundles); } } // IPVK_First <= Kind <= IPVK_Last @@ -1164,12 +1163,12 @@ private: } // end anonymous namespace /// Set up InEdges/OutEdges for all BBs in the MST. 
-static void -setupBBInfoEdges(FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> &FuncInfo) { +static void setupBBInfoEdges( + const FuncPGOInstrumentation<PGOUseEdge, PGOUseBBInfo> &FuncInfo) { // This is not required when there is block coverage inference. if (FuncInfo.BCI) return; - for (auto &E : FuncInfo.MST.AllEdges) { + for (const auto &E : FuncInfo.MST.allEdges()) { if (E->Removed) continue; const BasicBlock *SrcBB = E->SrcBB; @@ -1225,7 +1224,7 @@ bool PGOUseFunc::setInstrumentedCounts( // Set the profile count the Instrumented edges. There are BBs that not in // MST but not instrumented. Need to set the edge count value so that we can // populate the profile counts later. - for (auto &E : FuncInfo.MST.AllEdges) { + for (const auto &E : FuncInfo.MST.allEdges()) { if (E->Removed || E->InMST) continue; const BasicBlock *SrcBB = E->SrcBB; @@ -1336,7 +1335,8 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, auto &Ctx = M->getContext(); uint64_t MismatchedFuncSum = 0; Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord( - FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum); + FuncInfo.FuncName, FuncInfo.FunctionHash, FuncInfo.DeprecatedFuncName, + &MismatchedFuncSum); if (Error E = Result.takeError()) { handleInstrProfError(std::move(E), MismatchedFuncSum); return false; @@ -1381,7 +1381,8 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) { uint64_t MismatchedFuncSum = 0; Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord( - FuncInfo.FuncName, FuncInfo.FunctionHash, &MismatchedFuncSum); + FuncInfo.FuncName, FuncInfo.FunctionHash, FuncInfo.DeprecatedFuncName, + &MismatchedFuncSum); if (auto Err = Result.takeError()) { handleInstrProfError(std::move(Err), MismatchedFuncSum); return; @@ -1436,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) { // If A is 
uncovered, set weight=1. // This setup will allow BFI to give nonzero profile counts to only covered // blocks. - SmallVector<unsigned, 4> Weights; + SmallVector<uint32_t, 4> Weights; for (auto *Succ : successors(&BB)) Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0); if (Weights.size() >= 2) - BB.getTerminator()->setMetadata(LLVMContext::MD_prof, - MDB.createBranchWeights(Weights)); + llvm::setBranchWeights(*BB.getTerminator(), Weights); } unsigned NumCorruptCoverage = 0; @@ -1647,12 +1647,10 @@ void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { Module *M = F.getParent(); IRBuilder<> Builder(&SI); Type *Int64Ty = Builder.getInt64Ty(); - Type *I8PtrTy = Builder.getInt8PtrTy(); auto *Step = Builder.CreateZExt(SI.getCondition(), Int64Ty); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step), - {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), - Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs), + {FuncNameVar, Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step}); ++(*CurCtrIdx); } @@ -1757,17 +1755,10 @@ static void collectComdatMembers( ComdatMembers.insert(std::make_pair(C, &GA)); } -// Don't perform PGO instrumeatnion / profile-use. -static bool skipPGO(const Function &F) { +// Return true if we should not find instrumentation data for this function +static bool skipPGOUse(const Function &F) { if (F.isDeclaration()) return true; - if (F.hasFnAttribute(llvm::Attribute::NoProfile)) - return true; - if (F.hasFnAttribute(llvm::Attribute::SkipProfile)) - return true; - if (F.getInstructionCount() < PGOFunctionSizeThreshold) - return true; - // If there are too many critical edges, PGO might cause // compiler time problem. Skip PGO if the number of // critical edges execeed the threshold. @@ -1785,7 +1776,19 @@ static bool skipPGO(const Function &F) { << " exceed the threshold. 
Skip PGO.\n"); return true; } + return false; +} +// Return true if we should not instrument this function +static bool skipPGOGen(const Function &F) { + if (skipPGOUse(F)) + return true; + if (F.hasFnAttribute(llvm::Attribute::NoProfile)) + return true; + if (F.hasFnAttribute(llvm::Attribute::SkipProfile)) + return true; + if (F.getInstructionCount() < PGOFunctionSizeThreshold) + return true; return false; } @@ -1801,7 +1804,7 @@ static bool InstrumentAllFunctions( collectComdatMembers(M, ComdatMembers); for (auto &F : M) { - if (skipPGO(F)) + if (skipPGOGen(F)) continue; auto &TLI = LookupTLI(F); auto *BPI = LookupBPI(F); @@ -2028,7 +2031,7 @@ static bool annotateAllFunctions( InstrumentFuncEntry = PGOInstrumentEntry; bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage(); for (auto &F : M) { - if (skipPGO(F)) + if (skipPGOUse(F)) continue; auto &TLI = LookupTLI(F); auto *BPI = LookupBPI(F); @@ -2201,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) { void llvm::setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) { - MDBuilder MDB(M->getContext()); assert(MaxCount > 0 && "Bad max count"); uint64_t Scale = calculateCountScale(MaxCount); SmallVector<unsigned, 4> Weights; @@ -2215,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI, misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false); - TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + setBranchWeights(*TI, Weights); if (EmitBranchProbability) { std::string BrCondStr = getBranchCondString(TI); if (BrCondStr.empty()) diff --git a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index 2906fe190984..fd0f69eca96e 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -378,7 +378,7 @@ bool MemOPSizeOpt::perform(MemOp MO) { assert(It != 
DefaultBB->end()); BasicBlock *MergeBB = SplitBlock(DefaultBB, &(*It), DT); MergeBB->setName("MemOP.Merge"); - BFI.setBlockFreq(MergeBB, OrigBBFreq.getFrequency()); + BFI.setBlockFreq(MergeBB, OrigBBFreq); DefaultBB->setName("MemOP.Default"); DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp index d83a3a991c89..230bb8b0a5dc 100644 --- a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp +++ b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp @@ -198,17 +198,16 @@ bool SanitizerBinaryMetadata::run() { // metadata features. // - auto *Int8PtrTy = IRB.getInt8PtrTy(); - auto *Int8PtrPtrTy = PointerType::getUnqual(Int8PtrTy); + auto *PtrTy = IRB.getPtrTy(); auto *Int32Ty = IRB.getInt32Ty(); - const std::array<Type *, 3> InitTypes = {Int32Ty, Int8PtrPtrTy, Int8PtrPtrTy}; + const std::array<Type *, 3> InitTypes = {Int32Ty, PtrTy, PtrTy}; auto *Version = ConstantInt::get(Int32Ty, getVersion()); for (const MetadataInfo *MI : MIS) { const std::array<Value *, InitTypes.size()> InitArgs = { Version, - getSectionMarker(getSectionStart(MI->SectionSuffix), Int8PtrTy), - getSectionMarker(getSectionEnd(MI->SectionSuffix), Int8PtrTy), + getSectionMarker(getSectionStart(MI->SectionSuffix), PtrTy), + getSectionMarker(getSectionEnd(MI->SectionSuffix), PtrTy), }; // We declare the _add and _del functions as weak, and only call them if // there is a valid symbol linked. This allows building binaries with @@ -306,11 +305,11 @@ bool isUARSafeCall(CallInst *CI) { // It's safe to both pass pointers to local variables to them // and to tail-call them. 
return F && (F->isIntrinsic() || F->doesNotReturn() || - F->getName().startswith("__asan_") || - F->getName().startswith("__hwsan_") || - F->getName().startswith("__ubsan_") || - F->getName().startswith("__msan_") || - F->getName().startswith("__tsan_")); + F->getName().starts_with("__asan_") || + F->getName().starts_with("__hwsan_") || + F->getName().starts_with("__ubsan_") || + F->getName().starts_with("__msan_") || + F->getName().starts_with("__tsan_")); } bool hasUseAfterReturnUnsafeUses(Value &V) { @@ -368,11 +367,11 @@ bool SanitizerBinaryMetadata::pretendAtomicAccess(const Value *Addr) { const auto OF = Triple(Mod.getTargetTriple()).getObjectFormat(); const auto ProfSec = getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false); - if (GV->getSection().endswith(ProfSec)) + if (GV->getSection().ends_with(ProfSec)) return true; } - if (GV->getName().startswith("__llvm_gcov") || - GV->getName().startswith("__llvm_gcda")) + if (GV->getName().starts_with("__llvm_gcov") || + GV->getName().starts_with("__llvm_gcda")) return true; return false; diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index f22918141f6e..906687663519 100644 --- a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -261,9 +261,7 @@ private: FunctionCallee SanCovTraceGepFunction; FunctionCallee SanCovTraceSwitchFunction; GlobalVariable *SanCovLowestStack; - Type *Int128PtrTy, *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, - *Int32PtrTy, *Int16PtrTy, *Int16Ty, *Int8Ty, *Int8PtrTy, *Int1Ty, - *Int1PtrTy; + Type *PtrTy, *IntptrTy, *Int64Ty, *Int32Ty, *Int16Ty, *Int8Ty, *Int1Ty; Module *CurModule; std::string CurModuleUniqueId; Triple TargetTriple; @@ -331,11 +329,10 @@ ModuleSanitizerCoverage::CreateSecStartEnd(Module &M, const char *Section, // Account for the fact that on windows-msvc __start_* symbols actually // point to a 
uint64_t before the start of the array. - auto SecStartI8Ptr = IRB.CreatePointerCast(SecStart, Int8PtrTy); + auto SecStartI8Ptr = IRB.CreatePointerCast(SecStart, PtrTy); auto GEP = IRB.CreateGEP(Int8Ty, SecStartI8Ptr, ConstantInt::get(IntptrTy, sizeof(uint64_t))); - return std::make_pair(IRB.CreatePointerCast(GEP, PointerType::getUnqual(Ty)), - SecEnd); + return std::make_pair(GEP, SecEnd); } Function *ModuleSanitizerCoverage::CreateInitCallsForSections( @@ -345,7 +342,6 @@ Function *ModuleSanitizerCoverage::CreateInitCallsForSections( auto SecStart = SecStartEnd.first; auto SecEnd = SecStartEnd.second; Function *CtorFunc; - Type *PtrTy = PointerType::getUnqual(Ty); std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( M, CtorName, InitFunctionName, {PtrTy, PtrTy}, {SecStart, SecEnd}); assert(CtorFunc->getName() == CtorName); @@ -391,15 +387,9 @@ bool ModuleSanitizerCoverage::instrumentModule( FunctionPCsArray = nullptr; FunctionCFsArray = nullptr; IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits()); - IntptrPtrTy = PointerType::getUnqual(IntptrTy); + PtrTy = PointerType::getUnqual(*C); Type *VoidTy = Type::getVoidTy(*C); IRBuilder<> IRB(*C); - Int128PtrTy = PointerType::getUnqual(IRB.getInt128Ty()); - Int64PtrTy = PointerType::getUnqual(IRB.getInt64Ty()); - Int16PtrTy = PointerType::getUnqual(IRB.getInt16Ty()); - Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); - Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty()); - Int1PtrTy = PointerType::getUnqual(IRB.getInt1Ty()); Int64Ty = IRB.getInt64Ty(); Int32Ty = IRB.getInt32Ty(); Int16Ty = IRB.getInt16Ty(); @@ -438,26 +428,26 @@ bool ModuleSanitizerCoverage::instrumentModule( M.getOrInsertFunction(SanCovTraceConstCmp8, VoidTy, Int64Ty, Int64Ty); // Loads. 
- SanCovLoadFunction[0] = M.getOrInsertFunction(SanCovLoad1, VoidTy, Int8PtrTy); + SanCovLoadFunction[0] = M.getOrInsertFunction(SanCovLoad1, VoidTy, PtrTy); SanCovLoadFunction[1] = - M.getOrInsertFunction(SanCovLoad2, VoidTy, Int16PtrTy); + M.getOrInsertFunction(SanCovLoad2, VoidTy, PtrTy); SanCovLoadFunction[2] = - M.getOrInsertFunction(SanCovLoad4, VoidTy, Int32PtrTy); + M.getOrInsertFunction(SanCovLoad4, VoidTy, PtrTy); SanCovLoadFunction[3] = - M.getOrInsertFunction(SanCovLoad8, VoidTy, Int64PtrTy); + M.getOrInsertFunction(SanCovLoad8, VoidTy, PtrTy); SanCovLoadFunction[4] = - M.getOrInsertFunction(SanCovLoad16, VoidTy, Int128PtrTy); + M.getOrInsertFunction(SanCovLoad16, VoidTy, PtrTy); // Stores. SanCovStoreFunction[0] = - M.getOrInsertFunction(SanCovStore1, VoidTy, Int8PtrTy); + M.getOrInsertFunction(SanCovStore1, VoidTy, PtrTy); SanCovStoreFunction[1] = - M.getOrInsertFunction(SanCovStore2, VoidTy, Int16PtrTy); + M.getOrInsertFunction(SanCovStore2, VoidTy, PtrTy); SanCovStoreFunction[2] = - M.getOrInsertFunction(SanCovStore4, VoidTy, Int32PtrTy); + M.getOrInsertFunction(SanCovStore4, VoidTy, PtrTy); SanCovStoreFunction[3] = - M.getOrInsertFunction(SanCovStore8, VoidTy, Int64PtrTy); + M.getOrInsertFunction(SanCovStore8, VoidTy, PtrTy); SanCovStoreFunction[4] = - M.getOrInsertFunction(SanCovStore16, VoidTy, Int128PtrTy); + M.getOrInsertFunction(SanCovStore16, VoidTy, PtrTy); { AttributeList AL; @@ -470,7 +460,7 @@ bool ModuleSanitizerCoverage::instrumentModule( SanCovTraceGepFunction = M.getOrInsertFunction(SanCovTraceGep, VoidTy, IntptrTy); SanCovTraceSwitchFunction = - M.getOrInsertFunction(SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy); + M.getOrInsertFunction(SanCovTraceSwitchName, VoidTy, Int64Ty, PtrTy); Constant *SanCovLowestStackConstant = M.getOrInsertGlobal(SanCovLowestStackName, IntptrTy); @@ -487,7 +477,7 @@ bool ModuleSanitizerCoverage::instrumentModule( SanCovTracePC = M.getOrInsertFunction(SanCovTracePCName, VoidTy); SanCovTracePCGuard = - 
M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, Int32PtrTy); + M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, PtrTy); for (auto &F : M) instrumentFunction(F, DTCallback, PDTCallback); @@ -510,7 +500,7 @@ bool ModuleSanitizerCoverage::instrumentModule( if (Ctor && Options.PCTable) { auto SecStartEnd = CreateSecStartEnd(M, SanCovPCsSectionName, IntptrTy); FunctionCallee InitFunction = declareSanitizerInitFunction( - M, SanCovPCsInitName, {IntptrPtrTy, IntptrPtrTy}); + M, SanCovPCsInitName, {PtrTy, PtrTy}); IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator()); IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second}); } @@ -518,7 +508,7 @@ bool ModuleSanitizerCoverage::instrumentModule( if (Ctor && Options.CollectControlFlow) { auto SecStartEnd = CreateSecStartEnd(M, SanCovCFsSectionName, IntptrTy); FunctionCallee InitFunction = declareSanitizerInitFunction( - M, SanCovCFsInitName, {IntptrPtrTy, IntptrPtrTy}); + M, SanCovCFsInitName, {PtrTy, PtrTy}); IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator()); IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second}); } @@ -616,7 +606,7 @@ void ModuleSanitizerCoverage::instrumentFunction( return; if (F.getName().find(".module_ctor") != std::string::npos) return; // Should not instrument sanitizer init functions. - if (F.getName().startswith("__sanitizer_")) + if (F.getName().starts_with("__sanitizer_")) return; // Don't instrument __sanitizer_* callbacks. // Don't touch available_externally functions, their actual body is elewhere. 
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) @@ -744,19 +734,19 @@ ModuleSanitizerCoverage::CreatePCArray(Function &F, IRBuilder<> IRB(&*F.getEntryBlock().getFirstInsertionPt()); for (size_t i = 0; i < N; i++) { if (&F.getEntryBlock() == AllBlocks[i]) { - PCs.push_back((Constant *)IRB.CreatePointerCast(&F, IntptrPtrTy)); + PCs.push_back((Constant *)IRB.CreatePointerCast(&F, PtrTy)); PCs.push_back((Constant *)IRB.CreateIntToPtr( - ConstantInt::get(IntptrTy, 1), IntptrPtrTy)); + ConstantInt::get(IntptrTy, 1), PtrTy)); } else { PCs.push_back((Constant *)IRB.CreatePointerCast( - BlockAddress::get(AllBlocks[i]), IntptrPtrTy)); - PCs.push_back(Constant::getNullValue(IntptrPtrTy)); + BlockAddress::get(AllBlocks[i]), PtrTy)); + PCs.push_back(Constant::getNullValue(PtrTy)); } } - auto *PCArray = CreateFunctionLocalArrayInSection(N * 2, F, IntptrPtrTy, + auto *PCArray = CreateFunctionLocalArrayInSection(N * 2, F, PtrTy, SanCovPCsSectionName); PCArray->setInitializer( - ConstantArray::get(ArrayType::get(IntptrPtrTy, N * 2), PCs)); + ConstantArray::get(ArrayType::get(PtrTy, N * 2), PCs)); PCArray->setConstant(true); return PCArray; @@ -833,10 +823,9 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch( Int64Ty->getScalarSizeInBits()) Cond = IRB.CreateIntCast(Cond, Int64Ty, false); for (auto It : SI->cases()) { - Constant *C = It.getCaseValue(); - if (C->getType()->getScalarSizeInBits() < - Int64Ty->getScalarSizeInBits()) - C = ConstantExpr::getCast(CastInst::ZExt, It.getCaseValue(), Int64Ty); + ConstantInt *C = It.getCaseValue(); + if (C->getType()->getScalarSizeInBits() < 64) + C = ConstantInt::get(C->getContext(), C->getValue().zext(64)); Initializers.push_back(C); } llvm::sort(drop_begin(Initializers, 2), @@ -850,7 +839,7 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch( ConstantArray::get(ArrayOfInt64Ty, Initializers), "__sancov_gen_cov_switch_values"); IRB.CreateCall(SanCovTraceSwitchFunction, - {Cond, IRB.CreatePointerCast(GV, Int64PtrTy)}); + 
{Cond, IRB.CreatePointerCast(GV, PtrTy)}); } } } @@ -895,16 +884,13 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( : TypeSize == 128 ? 4 : -1; }; - Type *PointerType[5] = {Int8PtrTy, Int16PtrTy, Int32PtrTy, Int64PtrTy, - Int128PtrTy}; for (auto *LI : Loads) { InstrumentationIRBuilder IRB(LI); auto Ptr = LI->getPointerOperand(); int Idx = CallbackIdx(LI->getType()); if (Idx < 0) continue; - IRB.CreateCall(SanCovLoadFunction[Idx], - IRB.CreatePointerCast(Ptr, PointerType[Idx])); + IRB.CreateCall(SanCovLoadFunction[Idx], Ptr); } for (auto *SI : Stores) { InstrumentationIRBuilder IRB(SI); @@ -912,8 +898,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( int Idx = CallbackIdx(SI->getValueOperand()->getType()); if (Idx < 0) continue; - IRB.CreateCall(SanCovStoreFunction[Idx], - IRB.CreatePointerCast(Ptr, PointerType[Idx])); + IRB.CreateCall(SanCovStoreFunction[Idx], Ptr); } } @@ -978,7 +963,7 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, auto GuardPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePointerCast(FunctionGuardArray, IntptrTy), ConstantInt::get(IntptrTy, Idx * 4)), - Int32PtrTy); + PtrTy); IRB.CreateCall(SanCovTracePCGuard, GuardPtr)->setCannotMerge(); } if (Options.Inline8bitCounters) { @@ -1008,7 +993,7 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, Module *M = F.getParent(); Function *GetFrameAddr = Intrinsic::getDeclaration( M, Intrinsic::frameaddress, - IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); + IRB.getPtrTy(M->getDataLayout().getAllocaAddrSpace())); auto FrameAddrPtr = IRB.CreateCall(GetFrameAddr, {Constant::getNullValue(Int32Ty)}); auto FrameAddrInt = IRB.CreatePtrToInt(FrameAddrPtr, IntptrTy); @@ -1059,40 +1044,40 @@ void ModuleSanitizerCoverage::createFunctionControlFlow(Function &F) { for (auto &BB : F) { // blockaddress can not be used on function's entry block. 
if (&BB == &F.getEntryBlock()) - CFs.push_back((Constant *)IRB.CreatePointerCast(&F, IntptrPtrTy)); + CFs.push_back((Constant *)IRB.CreatePointerCast(&F, PtrTy)); else CFs.push_back((Constant *)IRB.CreatePointerCast(BlockAddress::get(&BB), - IntptrPtrTy)); + PtrTy)); for (auto SuccBB : successors(&BB)) { assert(SuccBB != &F.getEntryBlock()); CFs.push_back((Constant *)IRB.CreatePointerCast(BlockAddress::get(SuccBB), - IntptrPtrTy)); + PtrTy)); } - CFs.push_back((Constant *)Constant::getNullValue(IntptrPtrTy)); + CFs.push_back((Constant *)Constant::getNullValue(PtrTy)); for (auto &Inst : BB) { if (CallBase *CB = dyn_cast<CallBase>(&Inst)) { if (CB->isIndirectCall()) { // TODO(navidem): handle indirect calls, for now mark its existence. CFs.push_back((Constant *)IRB.CreateIntToPtr( - ConstantInt::get(IntptrTy, -1), IntptrPtrTy)); + ConstantInt::get(IntptrTy, -1), PtrTy)); } else { auto CalledF = CB->getCalledFunction(); if (CalledF && !CalledF->isIntrinsic()) CFs.push_back( - (Constant *)IRB.CreatePointerCast(CalledF, IntptrPtrTy)); + (Constant *)IRB.CreatePointerCast(CalledF, PtrTy)); } } } - CFs.push_back((Constant *)Constant::getNullValue(IntptrPtrTy)); + CFs.push_back((Constant *)Constant::getNullValue(PtrTy)); } FunctionCFsArray = CreateFunctionLocalArrayInSection( - CFs.size(), F, IntptrPtrTy, SanCovCFsSectionName); + CFs.size(), F, PtrTy, SanCovCFsSectionName); FunctionCFsArray->setInitializer( - ConstantArray::get(ArrayType::get(IntptrPtrTy, CFs.size()), CFs)); + ConstantArray::get(ArrayType::get(PtrTy, CFs.size()), CFs)); FunctionCFsArray->setConstant(true); } diff --git a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index ce35eefb63fa..8ee0bca7e354 100644 --- a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -205,7 +205,7 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) { Attr = 
Attr.addFnAttribute(Ctx, Attribute::NoUnwind); // Initialize the callbacks. TsanFuncEntry = M.getOrInsertFunction("__tsan_func_entry", Attr, - IRB.getVoidTy(), IRB.getInt8PtrTy()); + IRB.getVoidTy(), IRB.getPtrTy()); TsanFuncExit = M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy()); TsanIgnoreBegin = M.getOrInsertFunction("__tsan_ignore_thread_begin", Attr, @@ -220,49 +220,49 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) { std::string BitSizeStr = utostr(BitSize); SmallString<32> ReadName("__tsan_read" + ByteSizeStr); TsanRead[i] = M.getOrInsertFunction(ReadName, Attr, IRB.getVoidTy(), - IRB.getInt8PtrTy()); + IRB.getPtrTy()); SmallString<32> WriteName("__tsan_write" + ByteSizeStr); TsanWrite[i] = M.getOrInsertFunction(WriteName, Attr, IRB.getVoidTy(), - IRB.getInt8PtrTy()); + IRB.getPtrTy()); SmallString<64> UnalignedReadName("__tsan_unaligned_read" + ByteSizeStr); TsanUnalignedRead[i] = M.getOrInsertFunction( - UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); SmallString<64> UnalignedWriteName("__tsan_unaligned_write" + ByteSizeStr); TsanUnalignedWrite[i] = M.getOrInsertFunction( - UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); SmallString<64> VolatileReadName("__tsan_volatile_read" + ByteSizeStr); TsanVolatileRead[i] = M.getOrInsertFunction( - VolatileReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + VolatileReadName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); SmallString<64> VolatileWriteName("__tsan_volatile_write" + ByteSizeStr); TsanVolatileWrite[i] = M.getOrInsertFunction( - VolatileWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + VolatileWriteName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); SmallString<64> UnalignedVolatileReadName("__tsan_unaligned_volatile_read" + ByteSizeStr); TsanUnalignedVolatileRead[i] = M.getOrInsertFunction( - 
UnalignedVolatileReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + UnalignedVolatileReadName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); SmallString<64> UnalignedVolatileWriteName( "__tsan_unaligned_volatile_write" + ByteSizeStr); TsanUnalignedVolatileWrite[i] = M.getOrInsertFunction( - UnalignedVolatileWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + UnalignedVolatileWriteName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); SmallString<64> CompoundRWName("__tsan_read_write" + ByteSizeStr); TsanCompoundRW[i] = M.getOrInsertFunction( - CompoundRWName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + CompoundRWName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); SmallString<64> UnalignedCompoundRWName("__tsan_unaligned_read_write" + ByteSizeStr); TsanUnalignedCompoundRW[i] = M.getOrInsertFunction( - UnalignedCompoundRWName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); + UnalignedCompoundRWName, Attr, IRB.getVoidTy(), IRB.getPtrTy()); Type *Ty = Type::getIntNTy(Ctx, BitSize); - Type *PtrTy = Ty->getPointerTo(); + Type *PtrTy = PointerType::get(Ctx, 0); SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load"); TsanAtomicLoad[i] = M.getOrInsertFunction(AtomicLoadName, @@ -318,9 +318,9 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) { } TsanVptrUpdate = M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy()); + IRB.getPtrTy(), IRB.getPtrTy()); TsanVptrLoad = M.getOrInsertFunction("__tsan_vptr_read", Attr, - IRB.getVoidTy(), IRB.getInt8PtrTy()); + IRB.getVoidTy(), IRB.getPtrTy()); TsanAtomicThreadFence = M.getOrInsertFunction( "__tsan_atomic_thread_fence", TLI.getAttrList(&Ctx, {0}, /*Signed=*/true, /*Ret=*/false, Attr), @@ -332,15 +332,15 @@ void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) { IRB.getVoidTy(), OrdTy); MemmoveFn = - M.getOrInsertFunction("__tsan_memmove", Attr, IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); + 
M.getOrInsertFunction("__tsan_memmove", Attr, IRB.getPtrTy(), + IRB.getPtrTy(), IRB.getPtrTy(), IntptrTy); MemcpyFn = - M.getOrInsertFunction("__tsan_memcpy", Attr, IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); + M.getOrInsertFunction("__tsan_memcpy", Attr, IRB.getPtrTy(), + IRB.getPtrTy(), IRB.getPtrTy(), IntptrTy); MemsetFn = M.getOrInsertFunction( "__tsan_memset", TLI.getAttrList(&Ctx, {1}, /*Signed=*/true, /*Ret=*/false, Attr), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); + IRB.getPtrTy(), IRB.getPtrTy(), IRB.getInt32Ty(), IntptrTy); } static bool isVtableAccess(Instruction *I) { @@ -360,15 +360,10 @@ static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) { StringRef SectionName = GV->getSection(); // Check if the global is in the PGO counters section. auto OF = Triple(M->getTargetTriple()).getObjectFormat(); - if (SectionName.endswith( + if (SectionName.ends_with( getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false))) return false; } - - // Check if the global is private gcov data. - if (GV->getName().startswith("__llvm_gcov") || - GV->getName().startswith("__llvm_gcda")) - return false; } // Do not instrument accesses from different address spaces; we cannot deal @@ -522,6 +517,9 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, // Traverse all instructions, collect loads/stores/returns, check for calls. for (auto &BB : F) { for (auto &Inst : BB) { + // Skip instructions inserted by another instrumentation. 
+ if (Inst.hasMetadata(LLVMContext::MD_nosanitize)) + continue; if (isTsanAtomic(&Inst)) AtomicAccesses.push_back(&Inst); else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) @@ -613,17 +611,14 @@ bool ThreadSanitizer::instrumentLoadOrStore(const InstructionInfo &II, StoredValue = IRB.CreateExtractElement( StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0)); if (StoredValue->getType()->isIntegerTy()) - StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy()); + StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getPtrTy()); // Call TsanVptrUpdate. - IRB.CreateCall(TsanVptrUpdate, - {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), - IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())}); + IRB.CreateCall(TsanVptrUpdate, {Addr, StoredValue}); NumInstrumentedVtableWrites++; return true; } if (!IsWrite && isVtableAccess(II.Inst)) { - IRB.CreateCall(TsanVptrLoad, - IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy())); + IRB.CreateCall(TsanVptrLoad, Addr); NumInstrumentedVtableReads++; return true; } @@ -655,7 +650,7 @@ bool ThreadSanitizer::instrumentLoadOrStore(const InstructionInfo &II, else OnAccessFunc = IsWrite ? 
TsanUnalignedWrite[Idx] : TsanUnalignedRead[Idx]; } - IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy())); + IRB.CreateCall(OnAccessFunc, Addr); if (IsCompoundRW || IsWrite) NumInstrumentedWrites++; if (IsCompoundRW || !IsWrite) @@ -691,17 +686,19 @@ static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) { bool ThreadSanitizer::instrumentMemIntrinsic(Instruction *I) { InstrumentationIRBuilder IRB(I); if (MemSetInst *M = dyn_cast<MemSetInst>(I)) { + Value *Cast1 = IRB.CreateIntCast(M->getArgOperand(1), IRB.getInt32Ty(), false); + Value *Cast2 = IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false); IRB.CreateCall( MemsetFn, - {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()), - IRB.CreateIntCast(M->getArgOperand(1), IRB.getInt32Ty(), false), - IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false)}); + {M->getArgOperand(0), + Cast1, + Cast2}); I->eraseFromParent(); } else if (MemTransferInst *M = dyn_cast<MemTransferInst>(I)) { IRB.CreateCall( isa<MemCpyInst>(M) ? 
MemcpyFn : MemmoveFn, - {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(M->getArgOperand(1), IRB.getInt8PtrTy()), + {M->getArgOperand(0), + M->getArgOperand(1), IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false)}); I->eraseFromParent(); } @@ -724,11 +721,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { int Idx = getMemoryAccessFuncIndex(OrigTy, Addr, DL); if (Idx < 0) return false; - const unsigned ByteSize = 1U << Idx; - const unsigned BitSize = ByteSize * 8; - Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); - Type *PtrTy = Ty->getPointerTo(); - Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), + Value *Args[] = {Addr, createOrdering(&IRB, LI->getOrdering())}; Value *C = IRB.CreateCall(TsanAtomicLoad[Idx], Args); Value *Cast = IRB.CreateBitOrPointerCast(C, OrigTy); @@ -742,8 +735,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { const unsigned ByteSize = 1U << Idx; const unsigned BitSize = ByteSize * 8; Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); - Type *PtrTy = Ty->getPointerTo(); - Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), + Value *Args[] = {Addr, IRB.CreateBitOrPointerCast(SI->getValueOperand(), Ty), createOrdering(&IRB, SI->getOrdering())}; CallInst *C = CallInst::Create(TsanAtomicStore[Idx], Args); @@ -760,8 +752,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { const unsigned ByteSize = 1U << Idx; const unsigned BitSize = ByteSize * 8; Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); - Type *PtrTy = Ty->getPointerTo(); - Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), + Value *Args[] = {Addr, IRB.CreateIntCast(RMWI->getValOperand(), Ty, false), createOrdering(&IRB, RMWI->getOrdering())}; CallInst *C = CallInst::Create(F, Args); @@ -775,12 +766,11 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { const unsigned ByteSize = 1U << Idx; const 
unsigned BitSize = ByteSize * 8; Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); - Type *PtrTy = Ty->getPointerTo(); Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty); Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty); - Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), + Value *Args[] = {Addr, CmpOperand, NewOperand, createOrdering(&IRB, CASI->getSuccessOrdering()), diff --git a/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h b/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h index dd6a1c3f9795..7732eeb4b9c8 100644 --- a/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h +++ b/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.h @@ -22,7 +22,6 @@ #ifndef LLVM_LIB_TRANSFORMS_OBJCARC_DEPENDENCYANALYSIS_H #define LLVM_LIB_TRANSFORMS_OBJCARC_DEPENDENCYANALYSIS_H -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/ObjCARCInstKind.h" namespace llvm { diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index adf86526ebf1..b51e4d46bffe 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -933,7 +933,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, if (IsNullOrUndef(CI->getArgOperand(0))) { Changed = true; new StoreInst(ConstantInt::getTrue(CI->getContext()), - PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI); + PoisonValue::get(PointerType::getUnqual(CI->getContext())), + CI); Value *NewValue = PoisonValue::get(CI->getType()); LLVM_DEBUG( dbgs() << "A null pointer-to-weak-pointer is undefined behavior." 
@@ -952,7 +953,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, IsNullOrUndef(CI->getArgOperand(1))) { Changed = true; new StoreInst(ConstantInt::getTrue(CI->getContext()), - PoisonValue::get(Type::getInt1PtrTy(CI->getContext())), CI); + PoisonValue::get(PointerType::getUnqual(CI->getContext())), + CI); Value *NewValue = PoisonValue::get(CI->getType()); LLVM_DEBUG( diff --git a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp index 9f15772f2fa1..e563ecfb1622 100644 --- a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp +++ b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp @@ -19,7 +19,7 @@ using namespace llvm::objcarc; static StringRef getName(Value *V) { StringRef Name = V->getName(); - if (Name.startswith("\1")) + if (Name.starts_with("\1")) return Name.substr(1); return Name; } diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp index 24354211341f..9af275a9f4e2 100644 --- a/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -544,6 +544,16 @@ ADCEChanged AggressiveDeadCodeElimination::removeDeadInstructions() { // value of the function, and may therefore be deleted safely. // NOTE: We reuse the Worklist vector here for memory efficiency. for (Instruction &I : llvm::reverse(instructions(F))) { + // With "RemoveDIs" debug-info stored in DPValue objects, debug-info + // attached to this instruction, and drop any for scopes that aren't alive, + // like the rest of this loop does. Extending support to assignment tracking + // is future work. + for (DPValue &DPV : make_early_inc_range(I.getDbgValueRange())) { + if (AliveScopes.count(DPV.getDebugLoc()->getScope())) + continue; + I.dropOneDbgValue(&DPV); + } + // Check if the instruction is alive. 
if (isLive(&I)) continue; diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index b259c76fc3a5..f3422a705dca 100644 --- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -83,11 +83,7 @@ static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, const SCEV *OffSCEV, Value *Ptr, ScalarEvolution *SE) { const SCEV *PtrSCEV = SE->getSCEV(Ptr); - // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes - // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV - // may disagree. Trunc/extend so they agree. - PtrSCEV = SE->getTruncateOrZeroExtend( - PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType())); + const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV); if (isa<SCEVCouldNotCompute>(DiffSCEV)) return Align(1); @@ -179,6 +175,9 @@ bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I, // Added to suppress a crash because consumer doesn't expect non-constant // alignments in the assume bundle. TODO: Consider generalizing caller. return false; + if (!cast<SCEVConstant>(AlignSCEV)->getAPInt().isPowerOf2()) + // Only power of two alignments are supported. + return false; if (AlignOB.Inputs.size() == 3) OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get()); else @@ -264,11 +263,17 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall, // Now that we've updated that use of the pointer, look for other uses of // the pointer to update. 
Visited.insert(J); - for (User *UJ : J->users()) { - Instruction *K = cast<Instruction>(UJ); - if (!Visited.count(K)) - WorkList.push_back(K); - } + if (isa<GetElementPtrInst>(J) || isa<PHINode>(J)) + for (auto &U : J->uses()) { + if (U->getType()->isPointerTy()) { + Instruction *K = cast<Instruction>(U.getUser()); + StoreInst *SI = dyn_cast<StoreInst>(K); + if (SI && SI->getPointerOperandIndex() != U.getOperandNo()) + continue; + if (!Visited.count(K)) + WorkList.push_back(K); + } + } } return true; diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp index aeb7c5d461f0..47f663fa0cf0 100644 --- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -62,10 +62,8 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -374,10 +372,10 @@ static void splitCallSite(CallBase &CB, return; } - auto *OriginalBegin = &*TailBB->begin(); + BasicBlock::iterator OriginalBegin = TailBB->begin(); // Replace users of the original call with a PHI mering call-sites split. if (CallPN) { - CallPN->insertBefore(OriginalBegin); + CallPN->insertBefore(*TailBB, OriginalBegin); CB.replaceAllUsesWith(CallPN); } @@ -389,6 +387,7 @@ static void splitCallSite(CallBase &CB, // do not introduce unnecessary PHI nodes for def-use chains from the call // instruction to the beginning of the block. 
auto I = CB.getReverseIterator(); + Instruction *OriginalBeginInst = &*OriginalBegin; while (I != TailBB->rend()) { Instruction *CurrentI = &*I++; if (!CurrentI->use_empty()) { @@ -401,12 +400,13 @@ static void splitCallSite(CallBase &CB, for (auto &Mapping : ValueToValueMaps) NewPN->addIncoming(Mapping[CurrentI], cast<Instruction>(Mapping[CurrentI])->getParent()); - NewPN->insertBefore(&*TailBB->begin()); + NewPN->insertBefore(*TailBB, TailBB->begin()); CurrentI->replaceAllUsesWith(NewPN); } + CurrentI->dropDbgValues(); CurrentI->eraseFromParent(); // We are done once we handled the first original instruction in TailBB. - if (CurrentI == OriginalBegin) + if (CurrentI == OriginalBeginInst) break; } } diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 611e64bd0976..3e5d979f11cc 100644 --- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -761,11 +761,9 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, if (Adj->Offset) { if (Adj->Ty) { // Constant being rebased is a ConstantExpr. - PointerType *Int8PtrTy = Type::getInt8PtrTy( - *Ctx, cast<PointerType>(Adj->Ty)->getAddressSpace()); - Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", Adj->MatInsertPt); Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Adj->Offset, "mat_gep", Adj->MatInsertPt); + // Hide it behind a bitcast. Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", Adj->MatInsertPt); } else // Constant being rebased is a ConstantInt. 
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 15628d32280d..a6fbddca5cba 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -18,13 +18,17 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstraintSystem.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Verifier.h" @@ -83,32 +87,69 @@ static Instruction *getContextInstForUse(Use &U) { } namespace { +/// Struct to express a condition of the form %Op0 Pred %Op1. +struct ConditionTy { + CmpInst::Predicate Pred; + Value *Op0; + Value *Op1; + + ConditionTy() + : Pred(CmpInst::BAD_ICMP_PREDICATE), Op0(nullptr), Op1(nullptr) {} + ConditionTy(CmpInst::Predicate Pred, Value *Op0, Value *Op1) + : Pred(Pred), Op0(Op0), Op1(Op1) {} +}; + /// Represents either -/// * a condition that holds on entry to a block (=conditional fact) +/// * a condition that holds on entry to a block (=condition fact) /// * an assume (=assume fact) /// * a use of a compare instruction to simplify. /// It also tracks the Dominator DFS in and out numbers for each entry. struct FactOrCheck { + enum class EntryTy { + ConditionFact, /// A condition that holds on entry to a block. + InstFact, /// A fact that holds after Inst executed (e.g. an assume or + /// min/mix intrinsic. + InstCheck, /// An instruction to simplify (e.g. an overflow math + /// intrinsics). 
+ UseCheck /// An use of a compare instruction to simplify. + }; + union { Instruction *Inst; Use *U; + ConditionTy Cond; }; + + /// A pre-condition that must hold for the current fact to be added to the + /// system. + ConditionTy DoesHold; + unsigned NumIn; unsigned NumOut; - bool HasInst; - bool Not; + EntryTy Ty; - FactOrCheck(DomTreeNode *DTN, Instruction *Inst, bool Not) + FactOrCheck(EntryTy Ty, DomTreeNode *DTN, Instruction *Inst) : Inst(Inst), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), - HasInst(true), Not(Not) {} + Ty(Ty) {} FactOrCheck(DomTreeNode *DTN, Use *U) - : U(U), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), - HasInst(false), Not(false) {} + : U(U), DoesHold(CmpInst::BAD_ICMP_PREDICATE, nullptr, nullptr), + NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), + Ty(EntryTy::UseCheck) {} + + FactOrCheck(DomTreeNode *DTN, CmpInst::Predicate Pred, Value *Op0, Value *Op1, + ConditionTy Precond = ConditionTy()) + : Cond(Pred, Op0, Op1), DoesHold(Precond), NumIn(DTN->getDFSNumIn()), + NumOut(DTN->getDFSNumOut()), Ty(EntryTy::ConditionFact) {} - static FactOrCheck getFact(DomTreeNode *DTN, Instruction *Inst, - bool Not = false) { - return FactOrCheck(DTN, Inst, Not); + static FactOrCheck getConditionFact(DomTreeNode *DTN, CmpInst::Predicate Pred, + Value *Op0, Value *Op1, + ConditionTy Precond = ConditionTy()) { + return FactOrCheck(DTN, Pred, Op0, Op1, Precond); + } + + static FactOrCheck getInstFact(DomTreeNode *DTN, Instruction *Inst) { + return FactOrCheck(EntryTy::InstFact, DTN, Inst); } static FactOrCheck getCheck(DomTreeNode *DTN, Use *U) { @@ -116,39 +157,47 @@ struct FactOrCheck { } static FactOrCheck getCheck(DomTreeNode *DTN, CallInst *CI) { - return FactOrCheck(DTN, CI, false); + return FactOrCheck(EntryTy::InstCheck, DTN, CI); } bool isCheck() const { - return !HasInst || - match(Inst, m_Intrinsic<Intrinsic::ssub_with_overflow>()); + return Ty == EntryTy::InstCheck || Ty == EntryTy::UseCheck; } Instruction 
*getContextInst() const { - if (HasInst) - return Inst; - return getContextInstForUse(*U); + if (Ty == EntryTy::UseCheck) + return getContextInstForUse(*U); + return Inst; } + Instruction *getInstructionToSimplify() const { assert(isCheck()); - if (HasInst) + if (Ty == EntryTy::InstCheck) return Inst; // The use may have been simplified to a constant already. return dyn_cast<Instruction>(*U); } - bool isConditionFact() const { return !isCheck() && isa<CmpInst>(Inst); } + + bool isConditionFact() const { return Ty == EntryTy::ConditionFact; } }; /// Keep state required to build worklist. struct State { DominatorTree &DT; + LoopInfo &LI; + ScalarEvolution &SE; SmallVector<FactOrCheck, 64> WorkList; - State(DominatorTree &DT) : DT(DT) {} + State(DominatorTree &DT, LoopInfo &LI, ScalarEvolution &SE) + : DT(DT), LI(LI), SE(SE) {} /// Process block \p BB and add known facts to work-list. void addInfoFor(BasicBlock &BB); + /// Try to add facts for loop inductions (AddRecs) in EQ/NE compares + /// controlling the loop header. + void addInfoForInductions(BasicBlock &BB); + /// Returns true if we can add a known condition from BB to its successor /// block Succ. bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const { @@ -172,19 +221,9 @@ struct StackEntry { ValuesToRelease(ValuesToRelease) {} }; -/// Struct to express a pre-condition of the form %Op0 Pred %Op1. -struct PreconditionTy { - CmpInst::Predicate Pred; - Value *Op0; - Value *Op1; - - PreconditionTy(CmpInst::Predicate Pred, Value *Op0, Value *Op1) - : Pred(Pred), Op0(Op0), Op1(Op1) {} -}; - struct ConstraintTy { SmallVector<int64_t, 8> Coefficients; - SmallVector<PreconditionTy, 2> Preconditions; + SmallVector<ConditionTy, 2> Preconditions; SmallVector<SmallVector<int64_t, 8>> ExtraInfo; @@ -327,10 +366,57 @@ struct Decomposition { } }; +// Variable and constant offsets for a chain of GEPs, with base pointer BasePtr. 
+struct OffsetResult { + Value *BasePtr; + APInt ConstantOffset; + MapVector<Value *, APInt> VariableOffsets; + bool AllInbounds; + + OffsetResult() : BasePtr(nullptr), ConstantOffset(0, uint64_t(0)) {} + + OffsetResult(GEPOperator &GEP, const DataLayout &DL) + : BasePtr(GEP.getPointerOperand()), AllInbounds(GEP.isInBounds()) { + ConstantOffset = APInt(DL.getIndexTypeSizeInBits(BasePtr->getType()), 0); + } +}; } // namespace +// Try to collect variable and constant offsets for \p GEP, partly traversing +// nested GEPs. Returns an OffsetResult with nullptr as BasePtr of collecting +// the offset fails. +static OffsetResult collectOffsets(GEPOperator &GEP, const DataLayout &DL) { + OffsetResult Result(GEP, DL); + unsigned BitWidth = Result.ConstantOffset.getBitWidth(); + if (!GEP.collectOffset(DL, BitWidth, Result.VariableOffsets, + Result.ConstantOffset)) + return {}; + + // If we have a nested GEP, check if we can combine the constant offset of the + // inner GEP with the outer GEP. + if (auto *InnerGEP = dyn_cast<GetElementPtrInst>(Result.BasePtr)) { + MapVector<Value *, APInt> VariableOffsets2; + APInt ConstantOffset2(BitWidth, 0); + bool CanCollectInner = InnerGEP->collectOffset( + DL, BitWidth, VariableOffsets2, ConstantOffset2); + // TODO: Support cases with more than 1 variable offset. + if (!CanCollectInner || Result.VariableOffsets.size() > 1 || + VariableOffsets2.size() > 1 || + (Result.VariableOffsets.size() >= 1 && VariableOffsets2.size() >= 1)) { + // More than 1 variable index, use outer result. 
+ return Result; + } + Result.BasePtr = InnerGEP->getPointerOperand(); + Result.ConstantOffset += ConstantOffset2; + if (Result.VariableOffsets.size() == 0 && VariableOffsets2.size() == 1) + Result.VariableOffsets = VariableOffsets2; + Result.AllInbounds &= InnerGEP->isInBounds(); + } + return Result; +} + static Decomposition decompose(Value *V, - SmallVectorImpl<PreconditionTy> &Preconditions, + SmallVectorImpl<ConditionTy> &Preconditions, bool IsSigned, const DataLayout &DL); static bool canUseSExt(ConstantInt *CI) { @@ -338,51 +424,22 @@ static bool canUseSExt(ConstantInt *CI) { return Val.sgt(MinSignedConstraintValue) && Val.slt(MaxConstraintValue); } -static Decomposition -decomposeGEP(GEPOperator &GEP, SmallVectorImpl<PreconditionTy> &Preconditions, - bool IsSigned, const DataLayout &DL) { +static Decomposition decomposeGEP(GEPOperator &GEP, + SmallVectorImpl<ConditionTy> &Preconditions, + bool IsSigned, const DataLayout &DL) { // Do not reason about pointers where the index size is larger than 64 bits, // as the coefficients used to encode constraints are 64 bit integers. if (DL.getIndexTypeSizeInBits(GEP.getPointerOperand()->getType()) > 64) return &GEP; - if (!GEP.isInBounds()) - return &GEP; - assert(!IsSigned && "The logic below only supports decomposition for " "unsinged predicates at the moment."); - Type *PtrTy = GEP.getType()->getScalarType(); - unsigned BitWidth = DL.getIndexTypeSizeInBits(PtrTy); - MapVector<Value *, APInt> VariableOffsets; - APInt ConstantOffset(BitWidth, 0); - if (!GEP.collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) + const auto &[BasePtr, ConstantOffset, VariableOffsets, AllInbounds] = + collectOffsets(GEP, DL); + if (!BasePtr || !AllInbounds) return &GEP; - // Handle the (gep (gep ....), C) case by incrementing the constant - // coefficient of the inner GEP, if C is a constant. 
- auto *InnerGEP = dyn_cast<GEPOperator>(GEP.getPointerOperand()); - if (VariableOffsets.empty() && InnerGEP && InnerGEP->getNumOperands() == 2) { - auto Result = decompose(InnerGEP, Preconditions, IsSigned, DL); - Result.add(ConstantOffset.getSExtValue()); - - if (ConstantOffset.isNegative()) { - unsigned Scale = DL.getTypeAllocSize(InnerGEP->getResultElementType()); - int64_t ConstantOffsetI = ConstantOffset.getSExtValue(); - if (ConstantOffsetI % Scale != 0) - return &GEP; - // Add pre-condition ensuring the GEP is increasing monotonically and - // can be de-composed. - // Both sides are normalized by being divided by Scale. - Preconditions.emplace_back( - CmpInst::ICMP_SGE, InnerGEP->getOperand(1), - ConstantInt::get(InnerGEP->getOperand(1)->getType(), - -1 * (ConstantOffsetI / Scale))); - } - return Result; - } - - Decomposition Result(ConstantOffset.getSExtValue(), - DecompEntry(1, GEP.getPointerOperand())); + Decomposition Result(ConstantOffset.getSExtValue(), DecompEntry(1, BasePtr)); for (auto [Index, Scale] : VariableOffsets) { auto IdxResult = decompose(Index, Preconditions, IsSigned, DL); IdxResult.mul(Scale.getSExtValue()); @@ -401,7 +458,7 @@ decomposeGEP(GEPOperator &GEP, SmallVectorImpl<PreconditionTy> &Preconditions, // Variable } where Coefficient * Variable. The sum of the constant offset and // pairs equals \p V. static Decomposition decompose(Value *V, - SmallVectorImpl<PreconditionTy> &Preconditions, + SmallVectorImpl<ConditionTy> &Preconditions, bool IsSigned, const DataLayout &DL) { auto MergeResults = [&Preconditions, IsSigned, &DL](Value *A, Value *B, @@ -412,6 +469,22 @@ static Decomposition decompose(Value *V, return ResA; }; + Type *Ty = V->getType()->getScalarType(); + if (Ty->isPointerTy() && !IsSigned) { + if (auto *GEP = dyn_cast<GEPOperator>(V)) + return decomposeGEP(*GEP, Preconditions, IsSigned, DL); + if (isa<ConstantPointerNull>(V)) + return int64_t(0); + + return V; + } + + // Don't handle integers > 64 bit. 
Our coefficients are 64-bit large, so + // coefficient add/mul may wrap, while the operation in the full bit width + // would not. + if (!Ty->isIntegerTy() || Ty->getIntegerBitWidth() > 64) + return V; + // Decompose \p V used with a signed predicate. if (IsSigned) { if (auto *CI = dyn_cast<ConstantInt>(V)) { @@ -424,7 +497,7 @@ static Decomposition decompose(Value *V, return MergeResults(Op0, Op1, IsSigned); ConstantInt *CI; - if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI)))) { + if (match(V, m_NSWMul(m_Value(Op0), m_ConstantInt(CI))) && canUseSExt(CI)) { auto Result = decompose(Op0, Preconditions, IsSigned, DL); Result.mul(CI->getSExtValue()); return Result; @@ -439,9 +512,6 @@ static Decomposition decompose(Value *V, return int64_t(CI->getZExtValue()); } - if (auto *GEP = dyn_cast<GEPOperator>(V)) - return decomposeGEP(*GEP, Preconditions, IsSigned, DL); - Value *Op0; bool IsKnownNonNegative = false; if (match(V, m_ZExt(m_Value(Op0)))) { @@ -474,10 +544,8 @@ static Decomposition decompose(Value *V, } // Decompose or as an add if there are no common bits between the operands. 
- if (match(V, m_Or(m_Value(Op0), m_ConstantInt(CI))) && - haveNoCommonBitsSet(Op0, CI, DL)) { + if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) return MergeResults(Op0, CI, IsSigned); - } if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) { if (CI->getSExtValue() < 0 || CI->getSExtValue() >= 64) @@ -544,7 +612,7 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, Pred != CmpInst::ICMP_SLE && Pred != CmpInst::ICMP_SLT) return {}; - SmallVector<PreconditionTy, 4> Preconditions; + SmallVector<ConditionTy, 4> Preconditions; bool IsSigned = CmpInst::isSigned(Pred); auto &Value2Index = getValue2Index(IsSigned); auto ADec = decompose(Op0->stripPointerCastsSameRepresentation(), @@ -637,6 +705,17 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, ConstraintTy ConstraintInfo::getConstraintForSolving(CmpInst::Predicate Pred, Value *Op0, Value *Op1) const { + Constant *NullC = Constant::getNullValue(Op0->getType()); + // Handle trivially true compares directly to avoid adding V UGE 0 constraints + // for all variables in the unsigned system. + if ((Pred == CmpInst::ICMP_ULE && Op0 == NullC) || + (Pred == CmpInst::ICMP_UGE && Op1 == NullC)) { + auto &Value2Index = getValue2Index(false); + // Return constraint that's trivially true. + return ConstraintTy(SmallVector<int64_t, 8>(Value2Index.size(), 0), false, + false, false); + } + // If both operands are known to be non-negative, change signed predicates to // unsigned ones. This increases the reasoning effectiveness in combination // with the signed <-> unsigned transfer logic. 
@@ -654,7 +733,7 @@ ConstraintTy ConstraintInfo::getConstraintForSolving(CmpInst::Predicate Pred, bool ConstraintTy::isValid(const ConstraintInfo &Info) const { return Coefficients.size() > 0 && - all_of(Preconditions, [&Info](const PreconditionTy &C) { + all_of(Preconditions, [&Info](const ConditionTy &C) { return Info.doesHold(C.Pred, C.Op0, C.Op1); }); } @@ -713,6 +792,10 @@ bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A, void ConstraintInfo::transferToOtherSystem( CmpInst::Predicate Pred, Value *A, Value *B, unsigned NumIn, unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack) { + auto IsKnownNonNegative = [this](Value *V) { + return doesHold(CmpInst::ICMP_SGE, V, ConstantInt::get(V->getType(), 0)) || + isKnownNonNegative(V, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1); + }; // Check if we can combine facts from the signed and unsigned systems to // derive additional facts. if (!A->getType()->isIntegerTy()) @@ -724,30 +807,41 @@ void ConstraintInfo::transferToOtherSystem( default: break; case CmpInst::ICMP_ULT: - // If B is a signed positive constant, A >=s 0 and A <s B. - if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) { + case CmpInst::ICMP_ULE: + // If B is a signed positive constant, then A >=s 0 and A <s (or <=s) B. + if (IsKnownNonNegative(B)) { addFact(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0), NumIn, NumOut, DFSInStack); - addFact(CmpInst::ICMP_SLT, A, B, NumIn, NumOut, DFSInStack); + addFact(CmpInst::getSignedPredicate(Pred), A, B, NumIn, NumOut, + DFSInStack); + } + break; + case CmpInst::ICMP_UGE: + case CmpInst::ICMP_UGT: + // If A is a signed positive constant, then B >=s 0 and A >s (or >=s) B. 
+ if (IsKnownNonNegative(A)) { + addFact(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0), NumIn, + NumOut, DFSInStack); + addFact(CmpInst::getSignedPredicate(Pred), A, B, NumIn, NumOut, + DFSInStack); } break; case CmpInst::ICMP_SLT: - if (doesHold(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0))) + if (IsKnownNonNegative(A)) addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack); break; case CmpInst::ICMP_SGT: { if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1))) addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), NumIn, NumOut, DFSInStack); - if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) + if (IsKnownNonNegative(B)) addFact(CmpInst::ICMP_UGT, A, B, NumIn, NumOut, DFSInStack); break; } case CmpInst::ICMP_SGE: - if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) { + if (IsKnownNonNegative(B)) addFact(CmpInst::ICMP_UGE, A, B, NumIn, NumOut, DFSInStack); - } break; } } @@ -762,7 +856,138 @@ static void dumpConstraint(ArrayRef<int64_t> C, } #endif +void State::addInfoForInductions(BasicBlock &BB) { + auto *L = LI.getLoopFor(&BB); + if (!L || L->getHeader() != &BB) + return; + + Value *A; + Value *B; + CmpInst::Predicate Pred; + + if (!match(BB.getTerminator(), + m_Br(m_ICmp(Pred, m_Value(A), m_Value(B)), m_Value(), m_Value()))) + return; + PHINode *PN = dyn_cast<PHINode>(A); + if (!PN) { + Pred = CmpInst::getSwappedPredicate(Pred); + std::swap(A, B); + PN = dyn_cast<PHINode>(A); + } + + if (!PN || PN->getParent() != &BB || PN->getNumIncomingValues() != 2 || + !SE.isSCEVable(PN->getType())) + return; + + BasicBlock *InLoopSucc = nullptr; + if (Pred == CmpInst::ICMP_NE) + InLoopSucc = cast<BranchInst>(BB.getTerminator())->getSuccessor(0); + else if (Pred == CmpInst::ICMP_EQ) + InLoopSucc = cast<BranchInst>(BB.getTerminator())->getSuccessor(1); + else + return; + + if (!L->contains(InLoopSucc) || !L->isLoopExiting(&BB) || InLoopSucc == &BB) + return; + + auto *AR = 
dyn_cast_or_null<SCEVAddRecExpr>(SE.getSCEV(PN)); + BasicBlock *LoopPred = L->getLoopPredecessor(); + if (!AR || AR->getLoop() != L || !LoopPred) + return; + + const SCEV *StartSCEV = AR->getStart(); + Value *StartValue = nullptr; + if (auto *C = dyn_cast<SCEVConstant>(StartSCEV)) { + StartValue = C->getValue(); + } else { + StartValue = PN->getIncomingValueForBlock(LoopPred); + assert(SE.getSCEV(StartValue) == StartSCEV && "inconsistent start value"); + } + + DomTreeNode *DTN = DT.getNode(InLoopSucc); + auto Inc = SE.getMonotonicPredicateType(AR, CmpInst::ICMP_UGT); + bool MonotonicallyIncreasing = + Inc && *Inc == ScalarEvolution::MonotonicallyIncreasing; + if (MonotonicallyIncreasing) { + // SCEV guarantees that AR does not wrap, so PN >= StartValue can be added + // unconditionally. + WorkList.push_back( + FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_UGE, PN, StartValue)); + } + + APInt StepOffset; + if (auto *C = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) + StepOffset = C->getAPInt(); + else + return; + + // Make sure the bound B is loop-invariant. + if (!L->isLoopInvariant(B)) + return; + + // Handle negative steps. + if (StepOffset.isNegative()) { + // TODO: Extend to allow steps > -1. + if (!(-StepOffset).isOne()) + return; + + // AR may wrap. + // Add StartValue >= PN conditional on B <= StartValue which guarantees that + // the loop exits before wrapping with a step of -1. + WorkList.push_back(FactOrCheck::getConditionFact( + DTN, CmpInst::ICMP_UGE, StartValue, PN, + ConditionTy(CmpInst::ICMP_ULE, B, StartValue))); + // Add PN > B conditional on B <= StartValue which guarantees that the loop + // exits when reaching B with a step of -1. 
+ WorkList.push_back(FactOrCheck::getConditionFact( + DTN, CmpInst::ICMP_UGT, PN, B, + ConditionTy(CmpInst::ICMP_ULE, B, StartValue))); + return; + } + + // Make sure AR either steps by 1 or that the value we compare against is a + // GEP based on the same start value and all offsets are a multiple of the + // step size, to guarantee that the induction will reach the value. + if (StepOffset.isZero() || StepOffset.isNegative()) + return; + + if (!StepOffset.isOne()) { + auto *UpperGEP = dyn_cast<GetElementPtrInst>(B); + if (!UpperGEP || UpperGEP->getPointerOperand() != StartValue || + !UpperGEP->isInBounds()) + return; + + MapVector<Value *, APInt> UpperVariableOffsets; + APInt UpperConstantOffset(StepOffset.getBitWidth(), 0); + const DataLayout &DL = BB.getModule()->getDataLayout(); + if (!UpperGEP->collectOffset(DL, StepOffset.getBitWidth(), + UpperVariableOffsets, UpperConstantOffset)) + return; + // All variable offsets and the constant offset have to be a multiple of the + // step. + if (!UpperConstantOffset.urem(StepOffset).isZero() || + any_of(UpperVariableOffsets, [&StepOffset](const auto &P) { + return !P.second.urem(StepOffset).isZero(); + })) + return; + } + + // AR may wrap. Add PN >= StartValue conditional on StartValue <= B which + // guarantees that the loop exits before wrapping in combination with the + // restrictions on B and the step above. + if (!MonotonicallyIncreasing) { + WorkList.push_back(FactOrCheck::getConditionFact( + DTN, CmpInst::ICMP_UGE, PN, StartValue, + ConditionTy(CmpInst::ICMP_ULE, StartValue, B))); + } + WorkList.push_back(FactOrCheck::getConditionFact( + DTN, CmpInst::ICMP_ULT, PN, B, + ConditionTy(CmpInst::ICMP_ULE, StartValue, B))); +} + void State::addInfoFor(BasicBlock &BB) { + addInfoForInductions(BB); + // True as long as long as the current instruction is guaranteed to execute. bool GuaranteedToExecute = true; // Queue conditions and assumes. 
@@ -785,27 +1010,40 @@ void State::addInfoFor(BasicBlock &BB) { } if (isa<MinMaxIntrinsic>(&I)) { - WorkList.push_back(FactOrCheck::getFact(DT.getNode(&BB), &I)); + WorkList.push_back(FactOrCheck::getInstFact(DT.getNode(&BB), &I)); continue; } - Value *Cond; + Value *A, *B; + CmpInst::Predicate Pred; // For now, just handle assumes with a single compare as condition. - if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) && - isa<ICmpInst>(Cond)) { + if (match(&I, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_Value(A), m_Value(B))))) { if (GuaranteedToExecute) { // The assume is guaranteed to execute when BB is entered, hence Cond // holds on entry to BB. - WorkList.emplace_back(FactOrCheck::getFact(DT.getNode(I.getParent()), - cast<Instruction>(Cond))); + WorkList.emplace_back(FactOrCheck::getConditionFact( + DT.getNode(I.getParent()), Pred, A, B)); } else { WorkList.emplace_back( - FactOrCheck::getFact(DT.getNode(I.getParent()), &I)); + FactOrCheck::getInstFact(DT.getNode(I.getParent()), &I)); } } GuaranteedToExecute &= isGuaranteedToTransferExecutionToSuccessor(&I); } + if (auto *Switch = dyn_cast<SwitchInst>(BB.getTerminator())) { + for (auto &Case : Switch->cases()) { + BasicBlock *Succ = Case.getCaseSuccessor(); + Value *V = Case.getCaseValue(); + if (!canAddSuccessor(BB, Succ)) + continue; + WorkList.emplace_back(FactOrCheck::getConditionFact( + DT.getNode(Succ), CmpInst::ICMP_EQ, Switch->getCondition(), V)); + } + return; + } + auto *Br = dyn_cast<BranchInst>(BB.getTerminator()); if (!Br || !Br->isConditional()) return; @@ -837,8 +1075,11 @@ void State::addInfoFor(BasicBlock &BB) { while (!CondWorkList.empty()) { Value *Cur = CondWorkList.pop_back_val(); if (auto *Cmp = dyn_cast<ICmpInst>(Cur)) { - WorkList.emplace_back( - FactOrCheck::getFact(DT.getNode(Successor), Cmp, IsOr)); + WorkList.emplace_back(FactOrCheck::getConditionFact( + DT.getNode(Successor), + IsOr ? 
CmpInst::getInversePredicate(Cmp->getPredicate()) + : Cmp->getPredicate(), + Cmp->getOperand(0), Cmp->getOperand(1))); continue; } if (IsOr && match(Cur, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) { @@ -860,11 +1101,14 @@ void State::addInfoFor(BasicBlock &BB) { if (!CmpI) return; if (canAddSuccessor(BB, Br->getSuccessor(0))) - WorkList.emplace_back( - FactOrCheck::getFact(DT.getNode(Br->getSuccessor(0)), CmpI)); + WorkList.emplace_back(FactOrCheck::getConditionFact( + DT.getNode(Br->getSuccessor(0)), CmpI->getPredicate(), + CmpI->getOperand(0), CmpI->getOperand(1))); if (canAddSuccessor(BB, Br->getSuccessor(1))) - WorkList.emplace_back( - FactOrCheck::getFact(DT.getNode(Br->getSuccessor(1)), CmpI, true)); + WorkList.emplace_back(FactOrCheck::getConditionFact( + DT.getNode(Br->getSuccessor(1)), + CmpInst::getInversePredicate(CmpI->getPredicate()), CmpI->getOperand(0), + CmpI->getOperand(1))); } namespace { @@ -1069,7 +1313,8 @@ static std::optional<bool> checkCondition(CmpInst *Cmp, ConstraintInfo &Info, static bool checkAndReplaceCondition( CmpInst *Cmp, ConstraintInfo &Info, unsigned NumIn, unsigned NumOut, Instruction *ContextInst, Module *ReproducerModule, - ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT) { + ArrayRef<ReproducerEntry> ReproducerCondStack, DominatorTree &DT, + SmallVectorImpl<Instruction *> &ToRemove) { auto ReplaceCmpWithConstant = [&](CmpInst *Cmp, bool IsTrue) { generateReproducer(Cmp, ReproducerModule, ReproducerCondStack, Info, DT); Constant *ConstantC = ConstantInt::getBool( @@ -1090,6 +1335,8 @@ static bool checkAndReplaceCondition( return !II || II->getIntrinsicID() != Intrinsic::assume; }); NumCondsRemoved++; + if (Cmp->use_empty()) + ToRemove.push_back(Cmp); return true; }; @@ -1120,6 +1367,7 @@ static bool checkAndSecondOpImpliedByFirst( FactOrCheck &CB, ConstraintInfo &Info, Module *ReproducerModule, SmallVectorImpl<ReproducerEntry> &ReproducerCondStack, SmallVectorImpl<StackEntry> &DFSInStack) { + CmpInst::Predicate 
Pred; Value *A, *B; Instruction *And = CB.getContextInst(); @@ -1263,7 +1511,8 @@ tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info, return Changed; } -static bool eliminateConstraints(Function &F, DominatorTree &DT, +static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI, + ScalarEvolution &SE, OptimizationRemarkEmitter &ORE) { bool Changed = false; DT.updateDFSNumbers(); @@ -1271,7 +1520,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, for (Value &Arg : F.args()) FunctionArgs.push_back(&Arg); ConstraintInfo Info(F.getParent()->getDataLayout(), FunctionArgs); - State S(DT); + State S(DT, LI, SE); std::unique_ptr<Module> ReproducerModule( DumpReproducers ? new Module(F.getName(), F.getContext()) : nullptr); @@ -1293,8 +1542,9 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, // transfer logic. stable_sort(S.WorkList, [](const FactOrCheck &A, const FactOrCheck &B) { auto HasNoConstOp = [](const FactOrCheck &B) { - return !isa<ConstantInt>(B.Inst->getOperand(0)) && - !isa<ConstantInt>(B.Inst->getOperand(1)); + Value *V0 = B.isConditionFact() ? B.Cond.Op0 : B.Inst->getOperand(0); + Value *V1 = B.isConditionFact() ? B.Cond.Op1 : B.Inst->getOperand(1); + return !isa<ConstantInt>(V0) && !isa<ConstantInt>(V1); }; // If both entries have the same In numbers, conditional facts come first. // Otherwise use the relative order in the basic block. 
@@ -1355,7 +1605,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, } else if (auto *Cmp = dyn_cast<ICmpInst>(Inst)) { bool Simplified = checkAndReplaceCondition( Cmp, Info, CB.NumIn, CB.NumOut, CB.getContextInst(), - ReproducerModule.get(), ReproducerCondStack, S.DT); + ReproducerModule.get(), ReproducerCondStack, S.DT, ToRemove); if (!Simplified && match(CB.getContextInst(), m_LogicalAnd(m_Value(), m_Specific(Inst)))) { Simplified = @@ -1367,8 +1617,11 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, continue; } - LLVM_DEBUG(dbgs() << "fact to add to the system: " << *CB.Inst << "\n"); auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) { + LLVM_DEBUG(dbgs() << "fact to add to the system: " + << CmpInst::getPredicateName(Pred) << " "; + A->printAsOperand(dbgs()); dbgs() << ", "; + B->printAsOperand(dbgs(), false); dbgs() << "\n"); if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) { LLVM_DEBUG( dbgs() @@ -1394,23 +1647,30 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, }; ICmpInst::Predicate Pred; - if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) { - Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); - AddFact(Pred, MinMax, MinMax->getLHS()); - AddFact(Pred, MinMax, MinMax->getRHS()); - continue; + if (!CB.isConditionFact()) { + if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) { + Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); + AddFact(Pred, MinMax, MinMax->getLHS()); + AddFact(Pred, MinMax, MinMax->getRHS()); + continue; + } } - Value *A, *B; - Value *Cmp = CB.Inst; - match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp))); - if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) { - // Use the inverse predicate if required. 
- if (CB.Not) - Pred = CmpInst::getInversePredicate(Pred); - - AddFact(Pred, A, B); + Value *A = nullptr, *B = nullptr; + if (CB.isConditionFact()) { + Pred = CB.Cond.Pred; + A = CB.Cond.Op0; + B = CB.Cond.Op1; + if (CB.DoesHold.Pred != CmpInst::BAD_ICMP_PREDICATE && + !Info.doesHold(CB.DoesHold.Pred, CB.DoesHold.Op0, CB.DoesHold.Op1)) + continue; + } else { + bool Matched = match(CB.Inst, m_Intrinsic<Intrinsic::assume>( + m_ICmp(Pred, m_Value(A), m_Value(B)))); + (void)Matched; + assert(Matched && "Must have an assume intrinsic with a icmp operand"); } + AddFact(Pred, A, B); } if (ReproducerModule && !ReproducerModule->functions().empty()) { @@ -1440,12 +1700,16 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, PreservedAnalyses ConstraintEliminationPass::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - if (!eliminateConstraints(F, DT, ORE)) + if (!eliminateConstraints(F, DT, LI, SE, ORE)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + PA.preserve<ScalarEvolutionAnalysis>(); PA.preserveSet<CFGAnalyses>(); return PA; } diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 48b27a1ea0a2..a5cf875ef354 100644 --- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -55,7 +55,6 @@ static cl::opt<bool> CanonicalizeICmpPredicatesToUnsigned( STATISTIC(NumPhis, "Number of phis propagated"); STATISTIC(NumPhiCommon, "Number of phis deleted via common incoming value"); STATISTIC(NumSelects, "Number of selects propagated"); -STATISTIC(NumMemAccess, "Number of memory access targets propagated"); STATISTIC(NumCmps, 
"Number of comparisons propagated"); STATISTIC(NumReturns, "Number of return values propagated"); STATISTIC(NumDeadCases, "Number of switch cases removed"); @@ -93,6 +92,7 @@ STATISTIC(NumNonNull, "Number of function pointer arguments marked non-null"); STATISTIC(NumMinMax, "Number of llvm.[us]{min,max} intrinsics removed"); STATISTIC(NumUDivURemsNarrowedExpanded, "Number of bound udiv's/urem's expanded"); +STATISTIC(NumZExt, "Number of non-negative deductions"); static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { if (S->getType()->isVectorTy() || isa<Constant>(S->getCondition())) @@ -263,23 +263,6 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT, return Changed; } -static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) { - Value *Pointer = nullptr; - if (LoadInst *L = dyn_cast<LoadInst>(I)) - Pointer = L->getPointerOperand(); - else - Pointer = cast<StoreInst>(I)->getPointerOperand(); - - if (isa<Constant>(Pointer)) return false; - - Constant *C = LVI->getConstant(Pointer, I); - if (!C) return false; - - ++NumMemAccess; - I->replaceUsesOfWith(Pointer, C); - return true; -} - static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) { if (!CanonicalizeICmpPredicatesToUnsigned) return false; @@ -294,8 +277,9 @@ static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) { ICmpInst::Predicate UnsignedPred = ConstantRange::getEquivalentPredWithFlippedSignedness( - Cmp->getPredicate(), LVI->getConstantRange(Cmp->getOperand(0), Cmp), - LVI->getConstantRange(Cmp->getOperand(1), Cmp)); + Cmp->getPredicate(), + LVI->getConstantRangeAtUse(Cmp->getOperandUse(0)), + LVI->getConstantRangeAtUse(Cmp->getOperandUse(1))); if (UnsignedPred == ICmpInst::Predicate::BAD_ICMP_PREDICATE) return false; @@ -470,17 +454,17 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI); // because it is negation-invariant. 
static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) { Value *X = II->getArgOperand(0); - bool IsIntMinPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne(); - Type *Ty = X->getType(); - Constant *IntMin = - ConstantInt::get(Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits())); - LazyValueInfo::Tristate Result; + if (!Ty->isIntegerTy()) + return false; + + bool IsIntMinPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne(); + APInt IntMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + ConstantRange Range = LVI->getConstantRangeAtUse( + II->getOperandUse(0), /*UndefAllowed*/ IsIntMinPoison); // Is X in [0, IntMin]? NOTE: INT_MIN is fine! - Result = LVI->getPredicateAt(CmpInst::Predicate::ICMP_ULE, X, IntMin, II, - /*UseBlockValue=*/true); - if (Result == LazyValueInfo::True) { + if (Range.icmp(CmpInst::ICMP_ULE, IntMin)) { ++NumAbs; II->replaceAllUsesWith(X); II->eraseFromParent(); @@ -488,40 +472,30 @@ static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) { } // Is X in [IntMin, 0]? NOTE: INT_MIN is fine! - Constant *Zero = ConstantInt::getNullValue(Ty); - Result = LVI->getPredicateAt(CmpInst::Predicate::ICMP_SLE, X, Zero, II, - /*UseBlockValue=*/true); - assert(Result != LazyValueInfo::False && "Should have been handled already."); - - if (Result == LazyValueInfo::Unknown) { - // Argument's range crosses zero. - bool Changed = false; - if (!IsIntMinPoison) { - // Can we at least tell that the argument is never INT_MIN? 
- Result = LVI->getPredicateAt(CmpInst::Predicate::ICMP_NE, X, IntMin, II, - /*UseBlockValue=*/true); - if (Result == LazyValueInfo::True) { - ++NumNSW; - ++NumSubNSW; - II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); - Changed = true; - } - } - return Changed; - } + if (Range.getSignedMax().isNonPositive()) { + IRBuilder<> B(II); + Value *NegX = B.CreateNeg(X, II->getName(), /*HasNUW=*/false, + /*HasNSW=*/IsIntMinPoison); + ++NumAbs; + II->replaceAllUsesWith(NegX); + II->eraseFromParent(); - IRBuilder<> B(II); - Value *NegX = B.CreateNeg(X, II->getName(), /*HasNUW=*/false, - /*HasNSW=*/IsIntMinPoison); - ++NumAbs; - II->replaceAllUsesWith(NegX); - II->eraseFromParent(); + // See if we can infer some no-wrap flags. + if (auto *BO = dyn_cast<BinaryOperator>(NegX)) + processBinOp(BO, LVI); - // See if we can infer some no-wrap flags. - if (auto *BO = dyn_cast<BinaryOperator>(NegX)) - processBinOp(BO, LVI); + return true; + } - return true; + // Argument's range crosses zero. + // Can we at least tell that the argument is never INT_MIN? + if (!IsIntMinPoison && !Range.contains(IntMin)) { + ++NumNSW; + ++NumSubNSW; + II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); + return true; + } + return false; } // See if this min/max intrinsic always picks it's one specific operand. @@ -783,7 +757,7 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, // NOTE: this transformation introduces two uses of X, // but it may be undef so we must freeze it first. Value *FrozenX = X; - if (!isGuaranteedNotToBeUndefOrPoison(X)) + if (!isGuaranteedNotToBeUndef(X)) FrozenX = B.CreateFreeze(X, X->getName() + ".frozen"); auto *AdjX = B.CreateNUWSub(FrozenX, Y, Instr->getName() + ".urem"); auto *Cmp = @@ -919,6 +893,14 @@ static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR, assert(SDI->getOpcode() == Instruction::SDiv); assert(!SDI->getType()->isVectorTy()); + // Check whether the division folds to a constant. 
+ ConstantRange DivCR = LCR.sdiv(RCR); + if (const APInt *Elem = DivCR.getSingleElement()) { + SDI->replaceAllUsesWith(ConstantInt::get(SDI->getType(), *Elem)); + SDI->eraseFromParent(); + return true; + } + struct Operand { Value *V; Domain D; @@ -1026,12 +1008,31 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) { auto *ZExt = CastInst::CreateZExtOrBitCast(Base, SDI->getType(), "", SDI); ZExt->takeName(SDI); ZExt->setDebugLoc(SDI->getDebugLoc()); + ZExt->setNonNeg(); SDI->replaceAllUsesWith(ZExt); SDI->eraseFromParent(); return true; } +static bool processZExt(ZExtInst *ZExt, LazyValueInfo *LVI) { + if (ZExt->getType()->isVectorTy()) + return false; + + if (ZExt->hasNonNeg()) + return false; + + const Use &Base = ZExt->getOperandUse(0); + if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false) + .isAllNonNegative()) + return false; + + ++NumZExt; + ZExt->setNonNeg(); + + return true; +} + static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) { using OBO = OverflowingBinaryOperator; @@ -1140,10 +1141,6 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, case Instruction::FCmp: BBChanged |= processCmp(cast<CmpInst>(&II), LVI); break; - case Instruction::Load: - case Instruction::Store: - BBChanged |= processMemAccess(&II, LVI); - break; case Instruction::Call: case Instruction::Invoke: BBChanged |= processCallSite(cast<CallBase>(II), LVI); @@ -1162,6 +1159,9 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, case Instruction::SExt: BBChanged |= processSExt(cast<SExtInst>(&II), LVI); break; + case Instruction::ZExt: + BBChanged |= processZExt(cast<ZExtInst>(&II), LVI); + break; case Instruction::Add: case Instruction::Sub: case Instruction::Mul: diff --git a/llvm/lib/Transforms/Scalar/DCE.cpp b/llvm/lib/Transforms/Scalar/DCE.cpp index d309799d95f0..2ad46130dc94 100644 --- a/llvm/lib/Transforms/Scalar/DCE.cpp +++ b/llvm/lib/Transforms/Scalar/DCE.cpp @@ -36,39 +36,6 @@ 
STATISTIC(DCEEliminated, "Number of insts removed"); DEBUG_COUNTER(DCECounter, "dce-transform", "Controls which instructions are eliminated"); -//===--------------------------------------------------------------------===// -// RedundantDbgInstElimination pass implementation -// - -namespace { -struct RedundantDbgInstElimination : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - RedundantDbgInstElimination() : FunctionPass(ID) { - initializeRedundantDbgInstEliminationPass(*PassRegistry::getPassRegistry()); - } - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - bool Changed = false; - for (auto &BB : F) - Changed |= RemoveRedundantDbgInstrs(&BB); - return Changed; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - } -}; -} - -char RedundantDbgInstElimination::ID = 0; -INITIALIZE_PASS(RedundantDbgInstElimination, "redundant-dbg-inst-elim", - "Redundant Dbg Instruction Elimination", false, false) - -Pass *llvm::createRedundantDbgInstEliminationPass() { - return new RedundantDbgInstElimination(); -} - PreservedAnalyses RedundantDbgInstEliminationPass::run(Function &F, FunctionAnalysisManager &AM) { bool Changed = false; diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index f2efe60bdf88..edfeb36f3422 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -100,10 +100,10 @@ static cl::opt<unsigned> MaxPathLength( cl::desc("Max number of blocks searched to find a threading path"), cl::Hidden, cl::init(20)); -static cl::opt<unsigned> MaxNumPaths( - "dfa-max-num-paths", - cl::desc("Max number of paths enumerated around a switch"), - cl::Hidden, cl::init(200)); +static cl::opt<unsigned> + MaxNumPaths("dfa-max-num-paths", + cl::desc("Max number of paths enumerated around a switch"), + cl::Hidden, cl::init(200)); static cl::opt<unsigned> 
CostThreshold("dfa-cost-threshold", @@ -249,16 +249,20 @@ void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold, FT = FalseBlock; // Update the phi node of SI. - SIUse->removeIncomingValue(StartBlock, /* DeletePHIIfEmpty = */ false); SIUse->addIncoming(SI->getTrueValue(), TrueBlock); SIUse->addIncoming(SI->getFalseValue(), FalseBlock); // Update any other PHI nodes in EndBlock. for (PHINode &Phi : EndBlock->phis()) { if (&Phi != SIUse) { - Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), TrueBlock); - Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), FalseBlock); + Value *OrigValue = Phi.getIncomingValueForBlock(StartBlock); + Phi.addIncoming(OrigValue, TrueBlock); + Phi.addIncoming(OrigValue, FalseBlock); } + + // Remove incoming place of original StartBlock, which comes in a indirect + // way (through TrueBlock and FalseBlock) now. + Phi.removeIncomingValue(StartBlock, /* DeletePHIIfEmpty = */ false); } } else { BasicBlock *NewBlock = nullptr; @@ -297,6 +301,7 @@ void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold, {DominatorTree::Insert, StartBlock, FT}}); // The select is now dead. + assert(SI->use_empty() && "Select must be dead now"); SI->eraseFromParent(); } @@ -466,8 +471,9 @@ private: if (!SITerm || !SITerm->isUnconditional()) return false; - if (isa<PHINode>(SIUse) && - SIBB->getSingleSuccessor() != cast<Instruction>(SIUse)->getParent()) + // Only fold the select coming from directly where it is defined. + PHINode *PHIUser = dyn_cast<PHINode>(SIUse); + if (PHIUser && PHIUser->getIncomingBlock(*SI->use_begin()) != SIBB) return false; // If select will not be sunk during unfolding, and it is in the same basic @@ -728,6 +734,10 @@ private: CodeMetrics Metrics; SwitchInst *Switch = SwitchPaths->getSwitchInst(); + // Don't thread switch without multiple successors. + if (Switch->getNumSuccessors() <= 1) + return false; + // Note that DuplicateBlockMap is not being used as intended here. 
It is // just being used to ensure (BB, State) pairs are only counted once. DuplicateBlockMap DuplicateMap; @@ -805,6 +815,8 @@ private: // using binary search, hence the LogBase2(). unsigned CondBranches = APInt(32, Switch->getNumSuccessors()).ceilLogBase2(); + assert(CondBranches > 0 && + "The threaded switch must have multiple branches"); DuplicationCost = Metrics.NumInsts / CondBranches; } else { // Compared with jump tables, the DFA optimizer removes an indirect branch diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index d3fbe49439a8..dd0a290252da 100644 --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -38,9 +38,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" -#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -205,16 +203,17 @@ static bool isShortenableAtTheBeginning(Instruction *I) { return isa<AnyMemSetInst>(I); } -static uint64_t getPointerSize(const Value *V, const DataLayout &DL, - const TargetLibraryInfo &TLI, - const Function *F) { +static std::optional<TypeSize> getPointerSize(const Value *V, + const DataLayout &DL, + const TargetLibraryInfo &TLI, + const Function *F) { uint64_t Size; ObjectSizeOpts Opts; Opts.NullIsUnknownSize = NullPointerIsDefined(F); if (getObjectSize(V, Size, DL, &TLI, Opts)) - return Size; - return MemoryLocation::UnknownSize; + return TypeSize::getFixed(Size); + return std::nullopt; } namespace { @@ -629,20 +628,11 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, Value *OrigDest = DeadIntrinsic->getRawDest(); if (!IsOverwriteEnd) { - Type *Int8PtrTy = - Type::getInt8PtrTy(DeadIntrinsic->getContext(), - 
OrigDest->getType()->getPointerAddressSpace()); - Value *Dest = OrigDest; - if (OrigDest->getType() != Int8PtrTy) - Dest = CastInst::CreatePointerCast(OrigDest, Int8PtrTy, "", DeadI); Value *Indices[1] = { ConstantInt::get(DeadWriteLength->getType(), ToRemoveSize)}; Instruction *NewDestGEP = GetElementPtrInst::CreateInBounds( - Type::getInt8Ty(DeadIntrinsic->getContext()), Dest, Indices, "", DeadI); + Type::getInt8Ty(DeadIntrinsic->getContext()), OrigDest, Indices, "", DeadI); NewDestGEP->setDebugLoc(DeadIntrinsic->getDebugLoc()); - if (NewDestGEP->getType() != OrigDest->getType()) - NewDestGEP = CastInst::CreatePointerCast(NewDestGEP, OrigDest->getType(), - "", DeadI); DeadIntrinsic->setDest(NewDestGEP); } @@ -850,9 +840,6 @@ struct DSEState { // Post-order numbers for each basic block. Used to figure out if memory // accesses are executed before another access. DenseMap<BasicBlock *, unsigned> PostOrderNumbers; - // Values that are only used with assumes. Used to refine pointer escape - // analysis. - SmallPtrSet<const Value *, 32> EphValues; /// Keep track of instructions (partly) overlapping with killing MemoryDefs per /// basic block. @@ -872,10 +859,10 @@ struct DSEState { DSEState &operator=(const DSEState &) = delete; DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, - PostDominatorTree &PDT, AssumptionCache &AC, - const TargetLibraryInfo &TLI, const LoopInfo &LI) - : F(F), AA(AA), EI(DT, LI, EphValues), BatchAA(AA, &EI), MSSA(MSSA), - DT(DT), PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) { + PostDominatorTree &PDT, const TargetLibraryInfo &TLI, + const LoopInfo &LI) + : F(F), AA(AA), EI(DT, &LI), BatchAA(AA, &EI), MSSA(MSSA), DT(DT), + PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) { // Collect blocks with throwing instructions not modeled in MemorySSA and // alloc-like objects. 
unsigned PO = 0; @@ -905,8 +892,6 @@ struct DSEState { AnyUnreachableExit = any_of(PDT.roots(), [](const BasicBlock *E) { return isa<UnreachableInst>(E->getTerminator()); }); - - CodeMetrics::collectEphemeralValues(&F, &AC, EphValues); } LocationSize strengthenLocationSize(const Instruction *I, @@ -958,10 +943,11 @@ struct DSEState { // Check whether the killing store overwrites the whole object, in which // case the size/offset of the dead store does not matter. - if (DeadUndObj == KillingUndObj && KillingLocSize.isPrecise()) { - uint64_t KillingUndObjSize = getPointerSize(KillingUndObj, DL, TLI, &F); - if (KillingUndObjSize != MemoryLocation::UnknownSize && - KillingUndObjSize == KillingLocSize.getValue()) + if (DeadUndObj == KillingUndObj && KillingLocSize.isPrecise() && + isIdentifiedObject(KillingUndObj)) { + std::optional<TypeSize> KillingUndObjSize = + getPointerSize(KillingUndObj, DL, TLI, &F); + if (KillingUndObjSize && *KillingUndObjSize == KillingLocSize.getValue()) return OW_Complete; } @@ -984,9 +970,15 @@ struct DSEState { return isMaskedStoreOverwrite(KillingI, DeadI, BatchAA); } - const uint64_t KillingSize = KillingLocSize.getValue(); - const uint64_t DeadSize = DeadLoc.Size.getValue(); + const TypeSize KillingSize = KillingLocSize.getValue(); + const TypeSize DeadSize = DeadLoc.Size.getValue(); + // Bail on doing Size comparison which depends on AA for now + // TODO: Remove AnyScalable once Alias Analysis deal with scalable vectors + const bool AnyScalable = + DeadSize.isScalable() || KillingLocSize.isScalable(); + if (AnyScalable) + return OW_Unknown; // Query the alias information AliasResult AAR = BatchAA.alias(KillingLoc, DeadLoc); @@ -1076,7 +1068,7 @@ struct DSEState { if (!isInvisibleToCallerOnUnwind(V)) { I.first->second = false; } else if (isNoAliasCall(V)) { - I.first->second = !PointerMayBeCaptured(V, true, false, EphValues); + I.first->second = !PointerMayBeCaptured(V, true, false); } } return I.first->second; @@ -1095,7 +1087,7 @@ 
struct DSEState { // with the killing MemoryDef. But we refrain from doing so for now to // limit compile-time and this does not cause any changes to the number // of stores removed on a large test set in practice. - I.first->second = PointerMayBeCaptured(V, false, true, EphValues); + I.first->second = PointerMayBeCaptured(V, false, true); return !I.first->second; } @@ -1861,6 +1853,10 @@ struct DSEState { if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) || Func != LibFunc_malloc) return false; + // Gracefully handle malloc with unexpected memory attributes. + auto *MallocDef = dyn_cast_or_null<MemoryDef>(MSSA.getMemoryAccess(Malloc)); + if (!MallocDef) + return false; auto shouldCreateCalloc = [](CallInst *Malloc, CallInst *Memset) { // Check for br(icmp ptr, null), truebb, falsebb) pattern at the end @@ -1894,11 +1890,9 @@ struct DSEState { if (!Calloc) return false; MemorySSAUpdater Updater(&MSSA); - auto *LastDef = - cast<MemoryDef>(Updater.getMemorySSA()->getMemoryAccess(Malloc)); auto *NewAccess = - Updater.createMemoryAccessAfter(cast<Instruction>(Calloc), LastDef, - LastDef); + Updater.createMemoryAccessAfter(cast<Instruction>(Calloc), nullptr, + MallocDef); auto *NewAccessMD = cast<MemoryDef>(NewAccess); Updater.insertDef(NewAccessMD, /*RenameUses=*/true); Updater.removeMemoryAccess(Malloc); @@ -2064,12 +2058,11 @@ struct DSEState { static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, PostDominatorTree &PDT, - AssumptionCache &AC, const TargetLibraryInfo &TLI, const LoopInfo &LI) { bool MadeChange = false; - DSEState State(F, AA, MSSA, DT, PDT, AC, TLI, LI); + DSEState State(F, AA, MSSA, DT, PDT, TLI, LI); // For each store: for (unsigned I = 0; I < State.MemDefs.size(); I++) { MemoryDef *KillingDef = State.MemDefs[I]; @@ -2250,10 +2243,9 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); MemorySSA &MSSA = 
AM.getResult<MemorySSAAnalysis>(F).getMSSA(); PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); LoopInfo &LI = AM.getResult<LoopAnalysis>(F); - bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, AC, TLI, LI); + bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI); #ifdef LLVM_ENABLE_STATS if (AreStatisticsEnabled()) diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 67e8e82e408f..f736d429cb63 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -67,6 +67,7 @@ STATISTIC(NumCSE, "Number of instructions CSE'd"); STATISTIC(NumCSECVP, "Number of compare instructions CVP'd"); STATISTIC(NumCSELoad, "Number of load instructions CSE'd"); STATISTIC(NumCSECall, "Number of call instructions CSE'd"); +STATISTIC(NumCSEGEP, "Number of GEP instructions CSE'd"); STATISTIC(NumDSE, "Number of trivial dead stores removed"); DEBUG_COUNTER(CSECounter, "early-cse", @@ -143,11 +144,11 @@ struct SimpleValue { !CI->getFunction()->isPresplitCoroutine(); } return isa<CastInst>(Inst) || isa<UnaryOperator>(Inst) || - isa<BinaryOperator>(Inst) || isa<GetElementPtrInst>(Inst) || - isa<CmpInst>(Inst) || isa<SelectInst>(Inst) || - isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || - isa<ShuffleVectorInst>(Inst) || isa<ExtractValueInst>(Inst) || - isa<InsertValueInst>(Inst) || isa<FreezeInst>(Inst); + isa<BinaryOperator>(Inst) || isa<CmpInst>(Inst) || + isa<SelectInst>(Inst) || isa<ExtractElementInst>(Inst) || + isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) || + isa<ExtractValueInst>(Inst) || isa<InsertValueInst>(Inst) || + isa<FreezeInst>(Inst); } }; @@ -307,21 +308,20 @@ static unsigned getHashValueImpl(SimpleValue Val) { IVI->getOperand(1), hash_combine_range(IVI->idx_begin(), IVI->idx_end())); - assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) || - 
isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || - isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst) || - isa<FreezeInst>(Inst)) && + assert((isa<CallInst>(Inst) || isa<ExtractElementInst>(Inst) || + isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) || + isa<UnaryOperator>(Inst) || isa<FreezeInst>(Inst)) && "Invalid/unknown instruction"); // Handle intrinsics with commutative operands. - // TODO: Extend this to handle intrinsics with >2 operands where the 1st - // 2 operands are commutative. auto *II = dyn_cast<IntrinsicInst>(Inst); - if (II && II->isCommutative() && II->arg_size() == 2) { + if (II && II->isCommutative() && II->arg_size() >= 2) { Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); if (LHS > RHS) std::swap(LHS, RHS); - return hash_combine(II->getOpcode(), LHS, RHS); + return hash_combine( + II->getOpcode(), LHS, RHS, + hash_combine_range(II->value_op_begin() + 2, II->value_op_end())); } // gc.relocate is 'special' call: its second and third operands are @@ -396,13 +396,14 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) { LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate(); } - // TODO: Extend this for >2 args by matching the trailing N-2 args. auto *LII = dyn_cast<IntrinsicInst>(LHSI); auto *RII = dyn_cast<IntrinsicInst>(RHSI); if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() && - LII->isCommutative() && LII->arg_size() == 2) { + LII->isCommutative() && LII->arg_size() >= 2) { return LII->getArgOperand(0) == RII->getArgOperand(1) && - LII->getArgOperand(1) == RII->getArgOperand(0); + LII->getArgOperand(1) == RII->getArgOperand(0) && + std::equal(LII->arg_begin() + 2, LII->arg_end(), + RII->arg_begin() + 2, RII->arg_end()); } // See comment above in `getHashValue()`. @@ -548,12 +549,82 @@ bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) { // currently executing, so conservatively return false if they are in // different basic blocks. 
if (LHSI->isConvergent() && LHSI->getParent() != RHSI->getParent()) - return false; + return false; return LHSI->isIdenticalTo(RHSI); } //===----------------------------------------------------------------------===// +// GEPValue +//===----------------------------------------------------------------------===// + +namespace { + +struct GEPValue { + Instruction *Inst; + std::optional<int64_t> ConstantOffset; + + GEPValue(Instruction *I) : Inst(I) { + assert((isSentinel() || canHandle(I)) && "Inst can't be handled!"); + } + + GEPValue(Instruction *I, std::optional<int64_t> ConstantOffset) + : Inst(I), ConstantOffset(ConstantOffset) { + assert((isSentinel() || canHandle(I)) && "Inst can't be handled!"); + } + + bool isSentinel() const { + return Inst == DenseMapInfo<Instruction *>::getEmptyKey() || + Inst == DenseMapInfo<Instruction *>::getTombstoneKey(); + } + + static bool canHandle(Instruction *Inst) { + return isa<GetElementPtrInst>(Inst); + } +}; + +} // namespace + +namespace llvm { + +template <> struct DenseMapInfo<GEPValue> { + static inline GEPValue getEmptyKey() { + return DenseMapInfo<Instruction *>::getEmptyKey(); + } + + static inline GEPValue getTombstoneKey() { + return DenseMapInfo<Instruction *>::getTombstoneKey(); + } + + static unsigned getHashValue(const GEPValue &Val); + static bool isEqual(const GEPValue &LHS, const GEPValue &RHS); +}; + +} // end namespace llvm + +unsigned DenseMapInfo<GEPValue>::getHashValue(const GEPValue &Val) { + auto *GEP = cast<GetElementPtrInst>(Val.Inst); + if (Val.ConstantOffset.has_value()) + return hash_combine(GEP->getOpcode(), GEP->getPointerOperand(), + Val.ConstantOffset.value()); + return hash_combine( + GEP->getOpcode(), + hash_combine_range(GEP->value_op_begin(), GEP->value_op_end())); +} + +bool DenseMapInfo<GEPValue>::isEqual(const GEPValue &LHS, const GEPValue &RHS) { + if (LHS.isSentinel() || RHS.isSentinel()) + return LHS.Inst == RHS.Inst; + auto *LGEP = cast<GetElementPtrInst>(LHS.Inst); + auto *RGEP = 
cast<GetElementPtrInst>(RHS.Inst); + if (LGEP->getPointerOperand() != RGEP->getPointerOperand()) + return false; + if (LHS.ConstantOffset.has_value() && RHS.ConstantOffset.has_value()) + return LHS.ConstantOffset.value() == RHS.ConstantOffset.value(); + return LGEP->isIdenticalToWhenDefined(RGEP); +} + +//===----------------------------------------------------------------------===// // EarlyCSE implementation //===----------------------------------------------------------------------===// @@ -647,6 +718,13 @@ public: ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>>; CallHTType AvailableCalls; + using GEPMapAllocatorTy = + RecyclingAllocator<BumpPtrAllocator, + ScopedHashTableVal<GEPValue, Value *>>; + using GEPHTType = ScopedHashTable<GEPValue, Value *, DenseMapInfo<GEPValue>, + GEPMapAllocatorTy>; + GEPHTType AvailableGEPs; + /// This is the current generation of the memory value. unsigned CurrentGeneration = 0; @@ -667,9 +745,11 @@ private: class NodeScope { public: NodeScope(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads, - InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls) - : Scope(AvailableValues), LoadScope(AvailableLoads), - InvariantScope(AvailableInvariants), CallScope(AvailableCalls) {} + InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls, + GEPHTType &AvailableGEPs) + : Scope(AvailableValues), LoadScope(AvailableLoads), + InvariantScope(AvailableInvariants), CallScope(AvailableCalls), + GEPScope(AvailableGEPs) {} NodeScope(const NodeScope &) = delete; NodeScope &operator=(const NodeScope &) = delete; @@ -678,6 +758,7 @@ private: LoadHTType::ScopeTy LoadScope; InvariantHTType::ScopeTy InvariantScope; CallHTType::ScopeTy CallScope; + GEPHTType::ScopeTy GEPScope; }; // Contains all the needed information to create a stack for doing a depth @@ -688,13 +769,13 @@ private: public: StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads, InvariantHTType &AvailableInvariants, CallHTType 
&AvailableCalls, - unsigned cg, DomTreeNode *n, DomTreeNode::const_iterator child, + GEPHTType &AvailableGEPs, unsigned cg, DomTreeNode *n, + DomTreeNode::const_iterator child, DomTreeNode::const_iterator end) : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child), EndIter(end), Scopes(AvailableValues, AvailableLoads, AvailableInvariants, - AvailableCalls) - {} + AvailableCalls, AvailableGEPs) {} StackNode(const StackNode &) = delete; StackNode &operator=(const StackNode &) = delete; @@ -1214,6 +1295,20 @@ Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst, return Result; } +static void combineIRFlags(Instruction &From, Value *To) { + if (auto *I = dyn_cast<Instruction>(To)) { + // If I being poison triggers UB, there is no need to drop those + // flags. Otherwise, only retain flags present on both I and Inst. + // TODO: Currently some fast-math flags are not treated as + // poison-generating even though they should. Until this is fixed, + // always retain flags present on both I and Inst for floating point + // instructions. + if (isa<FPMathOperator>(I) || + (I->hasPoisonGeneratingFlags() && !programUndefinedIfPoison(I))) + I->andIRFlags(&From); + } +} + bool EarlyCSE::overridingStores(const ParseMemoryInst &Earlier, const ParseMemoryInst &Later) { // Can we remove Earlier store because of Later store? @@ -1424,7 +1519,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If this is a simple instruction that we can value number, process it. 
if (SimpleValue::canHandle(&Inst)) { - if (auto *CI = dyn_cast<ConstrainedFPIntrinsic>(&Inst)) { + if ([[maybe_unused]] auto *CI = dyn_cast<ConstrainedFPIntrinsic>(&Inst)) { assert(CI->getExceptionBehavior() != fp::ebStrict && "Unexpected ebStrict from SimpleValue::canHandle()"); assert((!CI->getRoundingMode() || @@ -1439,16 +1534,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } - if (auto *I = dyn_cast<Instruction>(V)) { - // If I being poison triggers UB, there is no need to drop those - // flags. Otherwise, only retain flags present on both I and Inst. - // TODO: Currently some fast-math flags are not treated as - // poison-generating even though they should. Until this is fixed, - // always retain flags present on both I and Inst for floating point - // instructions. - if (isa<FPMathOperator>(I) || (I->hasPoisonGeneratingFlags() && !programUndefinedIfPoison(I))) - I->andIRFlags(&Inst); - } + combineIRFlags(Inst, V); Inst.replaceAllUsesWith(V); salvageKnowledge(&Inst, &AC); removeMSSA(Inst); @@ -1561,6 +1647,31 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { continue; } + // Compare GEP instructions based on offset. + if (GEPValue::canHandle(&Inst)) { + auto *GEP = cast<GetElementPtrInst>(&Inst); + APInt Offset = APInt(SQ.DL.getIndexTypeSizeInBits(GEP->getType()), 0); + GEPValue GEPVal(GEP, GEP->accumulateConstantOffset(SQ.DL, Offset) + ? Offset.trySExtValue() + : std::nullopt); + if (Value *V = AvailableGEPs.lookup(GEPVal)) { + LLVM_DEBUG(dbgs() << "EarlyCSE CSE GEP: " << Inst << " to: " << *V + << '\n'); + combineIRFlags(Inst, V); + Inst.replaceAllUsesWith(V); + salvageKnowledge(&Inst, &AC); + removeMSSA(Inst); + Inst.eraseFromParent(); + Changed = true; + ++NumCSEGEP; + continue; + } + + // Otherwise, just remember that we have this GEP. 
+ AvailableGEPs.insert(GEPVal, &Inst); + continue; + } + // A release fence requires that all stores complete before it, but does // not prevent the reordering of following loads 'before' the fence. As a // result, we don't need to consider it as writing to memory and don't need @@ -1675,7 +1786,7 @@ bool EarlyCSE::run() { // Process the root node. nodesToProcess.push_back(new StackNode( AvailableValues, AvailableLoads, AvailableInvariants, AvailableCalls, - CurrentGeneration, DT.getRootNode(), + AvailableGEPs, CurrentGeneration, DT.getRootNode(), DT.getRootNode()->begin(), DT.getRootNode()->end())); assert(!CurrentGeneration && "Create a new EarlyCSE instance to rerun it."); @@ -1698,10 +1809,10 @@ bool EarlyCSE::run() { } else if (NodeToProcess->childIter() != NodeToProcess->end()) { // Push the next child onto the stack. DomTreeNode *child = NodeToProcess->nextChild(); - nodesToProcess.push_back( - new StackNode(AvailableValues, AvailableLoads, AvailableInvariants, - AvailableCalls, NodeToProcess->childGeneration(), - child, child->begin(), child->end())); + nodesToProcess.push_back(new StackNode( + AvailableValues, AvailableLoads, AvailableInvariants, AvailableCalls, + AvailableGEPs, NodeToProcess->childGeneration(), child, + child->begin(), child->end())); } else { // It has been processed, and there are no more children to process, // so delete it and pop it off the stack. diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index 03e8a2507b45..5e58af0edc15 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -760,7 +760,7 @@ PreservedAnalyses GVNPass::run(Function &F, FunctionAnalysisManager &AM) { auto &AA = AM.getResult<AAManager>(F); auto *MemDep = isMemDepEnabled() ? 
&AM.getResult<MemoryDependenceAnalysis>(F) : nullptr; - auto *LI = AM.getCachedResult<LoopAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); auto *MSSA = AM.getCachedResult<MemorySSAAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); bool Changed = runImpl(F, AC, DT, TLI, AA, MemDep, LI, &ORE, @@ -772,8 +772,7 @@ PreservedAnalyses GVNPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<TargetLibraryAnalysis>(); if (MSSA) PA.preserve<MemorySSAAnalysis>(); - if (LI) - PA.preserve<LoopAnalysis>(); + PA.preserve<LoopAnalysis>(); return PA; } @@ -946,9 +945,14 @@ static void replaceValuesPerBlockEntry( SmallVectorImpl<AvailableValueInBlock> &ValuesPerBlock, Value *OldValue, Value *NewValue) { for (AvailableValueInBlock &V : ValuesPerBlock) { - if ((V.AV.isSimpleValue() && V.AV.getSimpleValue() == OldValue) || - (V.AV.isCoercedLoadValue() && V.AV.getCoercedLoadValue() == OldValue)) - V = AvailableValueInBlock::get(V.BB, NewValue); + if (V.AV.Val == OldValue) + V.AV.Val = NewValue; + if (V.AV.isSelectValue()) { + if (V.AV.V1 == OldValue) + V.AV.V1 = NewValue; + if (V.AV.V2 == OldValue) + V.AV.V2 = NewValue; + } } } @@ -1147,13 +1151,11 @@ static Value *findDominatingValue(const MemoryLocation &Loc, Type *LoadTy, BasicBlock *FromBB = From->getParent(); BatchAAResults BatchAA(*AA); for (BasicBlock *BB = FromBB; BB; BB = BB->getSinglePredecessor()) - for (auto I = BB == FromBB ? From->getReverseIterator() : BB->rbegin(), - E = BB->rend(); - I != E; ++I) { + for (auto *Inst = BB == FromBB ? From : BB->getTerminator(); + Inst != nullptr; Inst = Inst->getPrevNonDebugInstruction()) { // Stop the search if limit is reached. 
if (++NumVisitedInsts > MaxNumVisitedInsts) return nullptr; - Instruction *Inst = &*I; if (isModSet(BatchAA.getModRefInfo(Inst, Loc))) return nullptr; if (auto *LI = dyn_cast<LoadInst>(Inst)) @@ -1368,7 +1370,7 @@ LoadInst *GVNPass::findLoadToHoistIntoPred(BasicBlock *Pred, BasicBlock *LoadBB, LoadInst *Load) { // For simplicity we handle a Pred has 2 successors only. auto *Term = Pred->getTerminator(); - if (Term->getNumSuccessors() != 2 || Term->isExceptionalTerminator()) + if (Term->getNumSuccessors() != 2 || Term->isSpecialTerminator()) return nullptr; auto *SuccBB = Term->getSuccessor(0); if (SuccBB == LoadBB) @@ -1416,16 +1418,8 @@ void GVNPass::eliminatePartiallyRedundantLoad( Load->getSyncScopeID(), UnavailableBlock->getTerminator()); NewLoad->setDebugLoc(Load->getDebugLoc()); if (MSSAU) { - auto *MSSA = MSSAU->getMemorySSA(); - // Get the defining access of the original load or use the load if it is a - // MemoryDef (e.g. because it is volatile). The inserted loads are - // guaranteed to load from the same definition. - auto *LoadAcc = MSSA->getMemoryAccess(Load); - auto *DefiningAcc = - isa<MemoryDef>(LoadAcc) ? 
LoadAcc : LoadAcc->getDefiningAccess(); auto *NewAccess = MSSAU->createMemoryAccessInBB( - NewLoad, DefiningAcc, NewLoad->getParent(), - MemorySSA::BeforeTerminator); + NewLoad, nullptr, NewLoad->getParent(), MemorySSA::BeforeTerminator); if (auto *NewDef = dyn_cast<MemoryDef>(NewAccess)) MSSAU->insertDef(NewDef, /*RenameUses=*/true); else @@ -1444,8 +1438,7 @@ void GVNPass::eliminatePartiallyRedundantLoad( if (auto *RangeMD = Load->getMetadata(LLVMContext::MD_range)) NewLoad->setMetadata(LLVMContext::MD_range, RangeMD); if (auto *AccessMD = Load->getMetadata(LLVMContext::MD_access_group)) - if (LI && - LI->getLoopFor(Load->getParent()) == LI->getLoopFor(UnavailableBlock)) + if (LI->getLoopFor(Load->getParent()) == LI->getLoopFor(UnavailableBlock)) NewLoad->setMetadata(LLVMContext::MD_access_group, AccessMD); // We do not propagate the old load's debug location, because the new @@ -1482,6 +1475,7 @@ void GVNPass::eliminatePartiallyRedundantLoad( // Perform PHI construction. Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this); // ConstructSSAForLoadSet is responsible for combining metadata. + ICF->removeUsersOf(Load); Load->replaceAllUsesWith(V); if (isa<PHINode>(V)) V->takeName(Load); @@ -1752,9 +1746,6 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, bool GVNPass::performLoopLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, UnavailBlkVect &UnavailableBlocks) { - if (!LI) - return false; - const Loop *L = LI->getLoopFor(Load->getParent()); // TODO: Generalize to other loop blocks that dominate the latch. if (!L || L->getHeader() != Load->getParent()) @@ -1901,6 +1892,7 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) { // Perform PHI construction. Value *V = ConstructSSAForLoadSet(Load, ValuesPerBlock, *this); // ConstructSSAForLoadSet is responsible for combining metadata. 
+ ICF->removeUsersOf(Load); Load->replaceAllUsesWith(V); if (isa<PHINode>(V)) @@ -1922,7 +1914,7 @@ bool GVNPass::processNonLocalLoad(LoadInst *Load) { // Step 4: Eliminate partial redundancy. if (!isPREEnabled() || !isLoadPREEnabled()) return Changed; - if (!isLoadInLoopPREEnabled() && LI && LI->getLoopFor(Load->getParent())) + if (!isLoadInLoopPREEnabled() && LI->getLoopFor(Load->getParent())) return Changed; if (performLoopLoadPRE(Load, ValuesPerBlock, UnavailableBlocks) || @@ -1998,12 +1990,12 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) { if (Cond->isZero()) { Type *Int8Ty = Type::getInt8Ty(V->getContext()); + Type *PtrTy = PointerType::get(V->getContext(), 0); // Insert a new store to null instruction before the load to indicate that // this code is not reachable. FIXME: We could insert unreachable // instruction directly because we can modify the CFG. auto *NewS = new StoreInst(PoisonValue::get(Int8Ty), - Constant::getNullValue(Int8Ty->getPointerTo()), - IntrinsicI); + Constant::getNullValue(PtrTy), IntrinsicI); if (MSSAU) { const MemoryUseOrDef *FirstNonDom = nullptr; const auto *AL = @@ -2023,14 +2015,12 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { } } - // This added store is to null, so it will never executed and we can - // just use the LiveOnEntry def as defining access. auto *NewDef = FirstNonDom ? MSSAU->createMemoryAccessBefore( - NewS, MSSAU->getMemorySSA()->getLiveOnEntryDef(), + NewS, nullptr, const_cast<MemoryUseOrDef *>(FirstNonDom)) : MSSAU->createMemoryAccessInBB( - NewS, MSSAU->getMemorySSA()->getLiveOnEntryDef(), + NewS, nullptr, NewS->getParent(), MemorySSA::BeforeTerminator); MSSAU->insertDef(cast<MemoryDef>(NewDef), /*RenameUses=*/false); @@ -2177,6 +2167,7 @@ bool GVNPass::processLoad(LoadInst *L) { Value *AvailableValue = AV->MaterializeAdjustedValue(L, L, *this); // MaterializeAdjustedValue is responsible for combining metadata. 
+ ICF->removeUsersOf(L); L->replaceAllUsesWith(AvailableValue); markInstructionForDeletion(L); if (MSSAU) @@ -2695,7 +2686,7 @@ bool GVNPass::processInstruction(Instruction *I) { /// runOnFunction - This is the main transformation entry point for a function. bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, const TargetLibraryInfo &RunTLI, AAResults &RunAA, - MemoryDependenceResults *RunMD, LoopInfo *LI, + MemoryDependenceResults *RunMD, LoopInfo &LI, OptimizationRemarkEmitter *RunORE, MemorySSA *MSSA) { AC = &RunAC; DT = &RunDT; @@ -2705,7 +2696,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, MD = RunMD; ImplicitControlFlowTracking ImplicitCFT; ICF = &ImplicitCFT; - this->LI = LI; + this->LI = &LI; VN.setMemDep(MD); ORE = RunORE; InvalidBlockRPONumbers = true; @@ -2719,7 +2710,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, // Merge unconditional branches, allowing PRE to catch more // optimization opportunities. for (BasicBlock &BB : llvm::make_early_inc_range(F)) { - bool removedBlock = MergeBlockIntoPredecessor(&BB, &DTU, LI, MSSAU, MD); + bool removedBlock = MergeBlockIntoPredecessor(&BB, &DTU, &LI, MSSAU, MD); if (removedBlock) ++NumGVNBlocks; @@ -2778,7 +2769,12 @@ bool GVNPass::processBlock(BasicBlock *BB) { // use our normal hash approach for phis. Instead, simply look for // obvious duplicates. The first pass of GVN will tend to create // identical phis, and the second or later passes can eliminate them. 
- ChangedFunction |= EliminateDuplicatePHINodes(BB); + SmallPtrSet<PHINode *, 8> PHINodesToRemove; + ChangedFunction |= EliminateDuplicatePHINodes(BB, PHINodesToRemove); + for (PHINode *PN : PHINodesToRemove) { + VN.erase(PN); + removeInstruction(PN); + } for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { @@ -2997,9 +2993,9 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) { ++NumGVNPRE; // Create a PHI to make the value available in this block. - PHINode *Phi = - PHINode::Create(CurInst->getType(), predMap.size(), - CurInst->getName() + ".pre-phi", &CurrentBlock->front()); + PHINode *Phi = PHINode::Create(CurInst->getType(), predMap.size(), + CurInst->getName() + ".pre-phi"); + Phi->insertBefore(CurrentBlock->begin()); for (unsigned i = 0, e = predMap.size(); i != e; ++i) { if (Value *V = predMap[i].first) { // If we use an existing value in this phi, we have to patch the original @@ -3290,8 +3286,6 @@ public: if (skipFunction(F)) return false; - auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); - auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); return Impl.runImpl( F, getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), @@ -3301,7 +3295,7 @@ public: Impl.isMemDepEnabled() ? &getAnalysis<MemoryDependenceWrapperPass>().getMemDep() : nullptr, - LIWP ? &LIWP->getLoopInfo() : nullptr, + getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), MSSAWP ? &MSSAWP->getMSSA() : nullptr); } diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp index 26a6978656e6..2b38831139a5 100644 --- a/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -850,8 +850,9 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, // Create a new PHI in the successor block and populate it. 
auto *Op = I0->getOperand(O); assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!"); - auto *PN = PHINode::Create(Op->getType(), Insts.size(), - Op->getName() + ".sink", &BBEnd->front()); + auto *PN = + PHINode::Create(Op->getType(), Insts.size(), Op->getName() + ".sink"); + PN->insertBefore(BBEnd->begin()); for (auto *I : Insts) PN->addIncoming(I->getOperand(O), I->getParent()); NewOperands.push_back(PN); diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp index 62b40a23e38c..3bbf6642a90c 100644 --- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -45,16 +45,14 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" @@ -123,12 +121,12 @@ static void eliminateGuard(Instruction *GuardInst, MemorySSAUpdater *MSSAU) { /// condition should stay invariant. Otherwise there can be a miscompile, like /// the one described at https://github.com/llvm/llvm-project/issues/60234. The /// safest way to do it is to expand the new condition at WC's block. 
-static Instruction *findInsertionPointForWideCondition(Instruction *Guard) { - Value *Condition, *WC; - BasicBlock *IfTrue, *IfFalse; - if (parseWidenableBranch(Guard, Condition, WC, IfTrue, IfFalse)) +static Instruction *findInsertionPointForWideCondition(Instruction *WCOrGuard) { + if (isGuard(WCOrGuard)) + return WCOrGuard; + if (auto WC = extractWidenableCondition(WCOrGuard)) return cast<Instruction>(WC); - return Guard; + return nullptr; } class GuardWideningImpl { @@ -157,8 +155,8 @@ class GuardWideningImpl { /// maps BasicBlocks to the set of guards seen in that block. bool eliminateInstrViaWidening( Instruction *Instr, const df_iterator<DomTreeNode *> &DFSI, - const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> & - GuardsPerBlock, bool InvertCondition = false); + const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> + &GuardsPerBlock); /// Used to keep track of which widening potential is more effective. enum WideningScore { @@ -181,11 +179,12 @@ class GuardWideningImpl { static StringRef scoreTypeToString(WideningScore WS); /// Compute the score for widening the condition in \p DominatedInstr - /// into \p DominatingGuard. If \p InvertCond is set, then we widen the - /// inverted condition of the dominating guard. + /// into \p WideningPoint. WideningScore computeWideningScore(Instruction *DominatedInstr, - Instruction *DominatingGuard, - bool InvertCond); + Instruction *ToWiden, + Instruction *WideningPoint, + SmallVectorImpl<Value *> &ChecksToHoist, + SmallVectorImpl<Value *> &ChecksToWiden); /// Helper to check if \p V can be hoisted to \p InsertPos. 
bool canBeHoistedTo(const Value *V, const Instruction *InsertPos) const { @@ -196,19 +195,36 @@ class GuardWideningImpl { bool canBeHoistedTo(const Value *V, const Instruction *InsertPos, SmallPtrSetImpl<const Instruction *> &Visited) const; + bool canBeHoistedTo(const SmallVectorImpl<Value *> &Checks, + const Instruction *InsertPos) const { + return all_of(Checks, + [&](const Value *V) { return canBeHoistedTo(V, InsertPos); }); + } /// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c /// canBeHoistedTo returned true. void makeAvailableAt(Value *V, Instruction *InsertPos) const; + void makeAvailableAt(const SmallVectorImpl<Value *> &Checks, + Instruction *InsertPos) const { + for (Value *V : Checks) + makeAvailableAt(V, InsertPos); + } + /// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try - /// to generate an expression computing the logical AND of \p Cond0 and (\p - /// Cond1 XOR \p InvertCondition). - /// Return true if the expression computing the AND is only as - /// expensive as computing one of the two. If \p InsertPt is true then - /// actually generate the resulting expression, make it available at \p - /// InsertPt and return it in \p Result (else no change to the IR is made). - bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt, - Value *&Result, bool InvertCondition); + /// to generate an expression computing the logical AND of \p ChecksToHoist + /// and \p ChecksToWiden. Return true if the expression computing the AND is + /// only as expensive as computing one of the set of expressions. If \p + /// InsertPt is true then actually generate the resulting expression, make it + /// available at \p InsertPt and return it in \p Result (else no change to the + /// IR is made). 
+ std::optional<Value *> mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, + SmallVectorImpl<Value *> &ChecksToWiden, + Instruction *InsertPt); + + /// Generate the logical AND of \p ChecksToHoist and \p OldCondition and make + /// it available at InsertPt + Value *hoistChecks(SmallVectorImpl<Value *> &ChecksToHoist, + Value *OldCondition, Instruction *InsertPt); /// Adds freeze to Orig and push it as far as possible very aggressively. /// Also replaces all uses of frozen instruction with frozen version. @@ -253,16 +269,19 @@ class GuardWideningImpl { } }; - /// Parse \p CheckCond into a conjunction (logical-and) of range checks; and + /// Parse \p ToParse into a conjunction (logical-and) of range checks; and /// append them to \p Checks. Returns true on success, may clobber \c Checks /// on failure. - bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks) { - SmallPtrSet<const Value *, 8> Visited; - return parseRangeChecks(CheckCond, Checks, Visited); + bool parseRangeChecks(SmallVectorImpl<Value *> &ToParse, + SmallVectorImpl<RangeCheck> &Checks) { + for (auto CheckCond : ToParse) { + if (!parseRangeChecks(CheckCond, Checks)) + return false; + } + return true; } - bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks, - SmallPtrSetImpl<const Value *> &Visited); + bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks); /// Combine the checks in \p Checks into a smaller set of checks and append /// them into \p CombinedChecks. Return true on success (i.e. all of checks @@ -271,23 +290,24 @@ class GuardWideningImpl { bool combineRangeChecks(SmallVectorImpl<RangeCheck> &Checks, SmallVectorImpl<RangeCheck> &CombinedChecks) const; - /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of - /// computing only one of the two expressions? 
- bool isWideningCondProfitable(Value *Cond0, Value *Cond1, bool InvertCond) { - Value *ResultUnused; - return widenCondCommon(Cond0, Cond1, /*InsertPt=*/nullptr, ResultUnused, - InvertCond); + /// Can we compute the logical AND of \p ChecksToHoist and \p ChecksToWiden + /// for the price of computing only one of the set of expressions? + bool isWideningCondProfitable(SmallVectorImpl<Value *> &ChecksToHoist, + SmallVectorImpl<Value *> &ChecksToWiden) { + return mergeChecks(ChecksToHoist, ChecksToWiden, /*InsertPt=*/nullptr) + .has_value(); } - /// If \p InvertCondition is false, Widen \p ToWiden to fail if - /// \p NewCondition is false, otherwise make it fail if \p NewCondition is - /// true (in addition to whatever it is already checking). - void widenGuard(Instruction *ToWiden, Value *NewCondition, - bool InvertCondition) { - Value *Result; + /// Widen \p ChecksToWiden to fail if any of \p ChecksToHoist is false + void widenGuard(SmallVectorImpl<Value *> &ChecksToHoist, + SmallVectorImpl<Value *> &ChecksToWiden, + Instruction *ToWiden) { Instruction *InsertPt = findInsertionPointForWideCondition(ToWiden); - widenCondCommon(getCondition(ToWiden), NewCondition, InsertPt, Result, - InvertCondition); + auto MergedCheck = mergeChecks(ChecksToHoist, ChecksToWiden, InsertPt); + Value *Result = MergedCheck ? *MergedCheck + : hoistChecks(ChecksToHoist, + getCondition(ToWiden), InsertPt); + if (isGuardAsWidenableBranch(ToWiden)) { setWidenableBranchCond(cast<BranchInst>(ToWiden), Result); return; @@ -353,12 +373,15 @@ bool GuardWideningImpl::run() { bool GuardWideningImpl::eliminateInstrViaWidening( Instruction *Instr, const df_iterator<DomTreeNode *> &DFSI, - const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> & - GuardsInBlock, bool InvertCondition) { + const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> + &GuardsInBlock) { + SmallVector<Value *> ChecksToHoist; + parseWidenableGuard(Instr, ChecksToHoist); // Ignore trivial true or false conditions. 
These instructions will be // trivially eliminated by any cleanup pass. Do not erase them because other // guards can possibly be widened into them. - if (isa<ConstantInt>(getCondition(Instr))) + if (ChecksToHoist.empty() || + (ChecksToHoist.size() == 1 && isa<ConstantInt>(ChecksToHoist.front()))) return false; Instruction *BestSoFar = nullptr; @@ -394,10 +417,15 @@ bool GuardWideningImpl::eliminateInstrViaWidening( assert((i == (e - 1)) == (Instr->getParent() == CurBB) && "Bad DFS?"); for (auto *Candidate : make_range(I, E)) { - auto Score = computeWideningScore(Instr, Candidate, InvertCondition); - LLVM_DEBUG(dbgs() << "Score between " << *getCondition(Instr) - << " and " << *getCondition(Candidate) << " is " - << scoreTypeToString(Score) << "\n"); + auto *WideningPoint = findInsertionPointForWideCondition(Candidate); + if (!WideningPoint) + continue; + SmallVector<Value *> CandidateChecks; + parseWidenableGuard(Candidate, CandidateChecks); + auto Score = computeWideningScore(Instr, Candidate, WideningPoint, + ChecksToHoist, CandidateChecks); + LLVM_DEBUG(dbgs() << "Score between " << *Instr << " and " << *Candidate + << " is " << scoreTypeToString(Score) << "\n"); if (Score > BestScoreSoFar) { BestScoreSoFar = Score; BestSoFar = Candidate; @@ -416,22 +444,22 @@ bool GuardWideningImpl::eliminateInstrViaWidening( LLVM_DEBUG(dbgs() << "Widening " << *Instr << " into " << *BestSoFar << " with score " << scoreTypeToString(BestScoreSoFar) << "\n"); - widenGuard(BestSoFar, getCondition(Instr), InvertCondition); - auto NewGuardCondition = InvertCondition - ? 
ConstantInt::getFalse(Instr->getContext()) - : ConstantInt::getTrue(Instr->getContext()); + SmallVector<Value *> ChecksToWiden; + parseWidenableGuard(BestSoFar, ChecksToWiden); + widenGuard(ChecksToHoist, ChecksToWiden, BestSoFar); + auto NewGuardCondition = ConstantInt::getTrue(Instr->getContext()); setCondition(Instr, NewGuardCondition); EliminatedGuardsAndBranches.push_back(Instr); WidenedGuards.insert(BestSoFar); return true; } -GuardWideningImpl::WideningScore -GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, - Instruction *DominatingGuard, - bool InvertCond) { +GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( + Instruction *DominatedInstr, Instruction *ToWiden, + Instruction *WideningPoint, SmallVectorImpl<Value *> &ChecksToHoist, + SmallVectorImpl<Value *> &ChecksToWiden) { Loop *DominatedInstrLoop = LI.getLoopFor(DominatedInstr->getParent()); - Loop *DominatingGuardLoop = LI.getLoopFor(DominatingGuard->getParent()); + Loop *DominatingGuardLoop = LI.getLoopFor(WideningPoint->getParent()); bool HoistingOutOfLoop = false; if (DominatingGuardLoop != DominatedInstrLoop) { @@ -444,10 +472,12 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, HoistingOutOfLoop = true; } - auto *WideningPoint = findInsertionPointForWideCondition(DominatingGuard); - if (!canBeHoistedTo(getCondition(DominatedInstr), WideningPoint)) + if (!canBeHoistedTo(ChecksToHoist, WideningPoint)) return WS_IllegalOrNegative; - if (!canBeHoistedTo(getCondition(DominatingGuard), WideningPoint)) + // Further in the GuardWideningImpl::hoistChecks the entire condition might be + // widened, not the parsed list of checks. So we need to check the possibility + // of that condition hoisting. 
+ if (!canBeHoistedTo(getCondition(ToWiden), WideningPoint)) return WS_IllegalOrNegative; // If the guard was conditional executed, it may never be reached @@ -458,8 +488,7 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, // here. TODO: evaluate cost model for spurious deopt // NOTE: As written, this also lets us hoist right over another guard which // is essentially just another spelling for control flow. - if (isWideningCondProfitable(getCondition(DominatedInstr), - getCondition(DominatingGuard), InvertCond)) + if (isWideningCondProfitable(ChecksToHoist, ChecksToWiden)) return HoistingOutOfLoop ? WS_VeryPositive : WS_Positive; if (HoistingOutOfLoop) @@ -495,7 +524,7 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, // control flow (guards, calls which throw, etc...). That choice appears // arbitrary (we assume that implicit control flow exits are all rare). auto MaybeHoistingToHotterBlock = [&]() { - const auto *DominatingBlock = DominatingGuard->getParent(); + const auto *DominatingBlock = WideningPoint->getParent(); const auto *DominatedBlock = DominatedInstr->getParent(); // Descend as low as we can, always taking the likely successor. @@ -521,7 +550,8 @@ GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, if (!DT.dominates(DominatingBlock, DominatedBlock)) return true; // TODO: diamond, triangle cases - if (!PDT) return true; + if (!PDT) + return true; return !PDT->dominates(DominatedBlock, DominatingBlock); }; @@ -566,35 +596,47 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { } // Return Instruction before which we can insert freeze for the value V as close -// to def as possible. If there is no place to add freeze, return nullptr. -static Instruction *getFreezeInsertPt(Value *V, const DominatorTree &DT) { +// to def as possible. If there is no place to add freeze, return empty. 
+static std::optional<BasicBlock::iterator> +getFreezeInsertPt(Value *V, const DominatorTree &DT) { auto *I = dyn_cast<Instruction>(V); if (!I) - return &*DT.getRoot()->getFirstNonPHIOrDbgOrAlloca(); + return DT.getRoot()->getFirstNonPHIOrDbgOrAlloca()->getIterator(); - auto *Res = I->getInsertionPointAfterDef(); + std::optional<BasicBlock::iterator> Res = I->getInsertionPointAfterDef(); // If there is no place to add freeze - return nullptr. - if (!Res || !DT.dominates(I, Res)) - return nullptr; + if (!Res || !DT.dominates(I, &**Res)) + return std::nullopt; + + Instruction *ResInst = &**Res; // If there is a User dominated by original I, then it should be dominated // by Freeze instruction as well. if (any_of(I->users(), [&](User *U) { Instruction *User = cast<Instruction>(U); - return Res != User && DT.dominates(I, User) && !DT.dominates(Res, User); + return ResInst != User && DT.dominates(I, User) && + !DT.dominates(ResInst, User); })) - return nullptr; + return std::nullopt; return Res; } Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) { if (isGuaranteedNotToBePoison(Orig, nullptr, InsertPt, &DT)) return Orig; - Instruction *InsertPtAtDef = getFreezeInsertPt(Orig, DT); - if (!InsertPtAtDef) - return new FreezeInst(Orig, "gw.freeze", InsertPt); - if (isa<Constant>(Orig) || isa<GlobalValue>(Orig)) - return new FreezeInst(Orig, "gw.freeze", InsertPtAtDef); + std::optional<BasicBlock::iterator> InsertPtAtDef = + getFreezeInsertPt(Orig, DT); + if (!InsertPtAtDef) { + FreezeInst *FI = new FreezeInst(Orig, "gw.freeze"); + FI->insertBefore(InsertPt); + return FI; + } + if (isa<Constant>(Orig) || isa<GlobalValue>(Orig)) { + BasicBlock::iterator InsertPt = *InsertPtAtDef; + FreezeInst *FI = new FreezeInst(Orig, "gw.freeze"); + FI->insertBefore(*InsertPt->getParent(), InsertPt); + return FI; + } SmallSet<Value *, 16> Visited; SmallVector<Value *, 16> Worklist; @@ -613,8 +655,10 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction 
*InsertPt) { if (Visited.insert(Def).second) { if (isGuaranteedNotToBePoison(Def, nullptr, InsertPt, &DT)) return true; - CacheOfFreezes[Def] = new FreezeInst(Def, Def->getName() + ".gw.fr", - getFreezeInsertPt(Def, DT)); + BasicBlock::iterator InsertPt = *getFreezeInsertPt(Def, DT); + FreezeInst *FI = new FreezeInst(Def, Def->getName() + ".gw.fr"); + FI->insertBefore(*InsertPt->getParent(), InsertPt); + CacheOfFreezes[Def] = FI; } if (CacheOfFreezes.count(Def)) @@ -655,8 +699,9 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) { Value *Result = Orig; for (Value *V : NeedFreeze) { - auto *FreezeInsertPt = getFreezeInsertPt(V, DT); - FreezeInst *FI = new FreezeInst(V, V->getName() + ".gw.fr", FreezeInsertPt); + BasicBlock::iterator FreezeInsertPt = *getFreezeInsertPt(V, DT); + FreezeInst *FI = new FreezeInst(V, V->getName() + ".gw.fr"); + FI->insertBefore(*FreezeInsertPt->getParent(), FreezeInsertPt); ++FreezeAdded; if (V == Orig) Result = FI; @@ -667,20 +712,25 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) { return Result; } -bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, - Instruction *InsertPt, Value *&Result, - bool InvertCondition) { +std::optional<Value *> +GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, + SmallVectorImpl<Value *> &ChecksToWiden, + Instruction *InsertPt) { using namespace llvm::PatternMatch; + Value *Result = nullptr; { // L >u C0 && L >u C1 -> L >u max(C0, C1) ConstantInt *RHS0, *RHS1; Value *LHS; ICmpInst::Predicate Pred0, Pred1; - if (match(Cond0, m_ICmp(Pred0, m_Value(LHS), m_ConstantInt(RHS0))) && - match(Cond1, m_ICmp(Pred1, m_Specific(LHS), m_ConstantInt(RHS1)))) { - if (InvertCondition) - Pred1 = ICmpInst::getInversePredicate(Pred1); + // TODO: Support searching for pairs to merge from both whole lists of + // ChecksToHoist and ChecksToWiden. 
+ if (ChecksToWiden.size() == 1 && ChecksToHoist.size() == 1 && + match(ChecksToWiden.front(), + m_ICmp(Pred0, m_Value(LHS), m_ConstantInt(RHS0))) && + match(ChecksToHoist.front(), + m_ICmp(Pred1, m_Specific(LHS), m_ConstantInt(RHS1)))) { ConstantRange CR0 = ConstantRange::makeExactICmpRegion(Pred0, RHS0->getValue()); @@ -697,12 +747,12 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, if (Intersect->getEquivalentICmp(Pred, NewRHSAP)) { if (InsertPt) { ConstantInt *NewRHS = - ConstantInt::get(Cond0->getContext(), NewRHSAP); + ConstantInt::get(InsertPt->getContext(), NewRHSAP); assert(canBeHoistedTo(LHS, InsertPt) && "must be"); makeAvailableAt(LHS, InsertPt); Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); } - return true; + return Result; } } } @@ -710,12 +760,10 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, { SmallVector<GuardWideningImpl::RangeCheck, 4> Checks, CombinedChecks; - // TODO: Support InvertCondition case? - if (!InvertCondition && - parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) && + if (parseRangeChecks(ChecksToWiden, Checks) && + parseRangeChecks(ChecksToHoist, Checks) && combineRangeChecks(Checks, CombinedChecks)) { if (InsertPt) { - Result = nullptr; for (auto &RC : CombinedChecks) { makeAvailableAt(RC.getCheckInst(), InsertPt); if (Result) @@ -728,40 +776,32 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, Result->setName("wide.chk"); Result = freezeAndPush(Result, InsertPt); } - return true; + return Result; } } + // We were not able to compute ChecksToHoist AND ChecksToWiden for the price + // of one. + return std::nullopt; +} - // Base case -- just logical-and the two conditions together. 
- - if (InsertPt) { - makeAvailableAt(Cond0, InsertPt); - makeAvailableAt(Cond1, InsertPt); - if (InvertCondition) - Cond1 = BinaryOperator::CreateNot(Cond1, "inverted", InsertPt); - Cond1 = freezeAndPush(Cond1, InsertPt); - Result = BinaryOperator::CreateAnd(Cond0, Cond1, "wide.chk", InsertPt); - } - - // We were not able to compute Cond0 AND Cond1 for the price of one. - return false; +Value *GuardWideningImpl::hoistChecks(SmallVectorImpl<Value *> &ChecksToHoist, + Value *OldCondition, + Instruction *InsertPt) { + assert(!ChecksToHoist.empty()); + IRBuilder<> Builder(InsertPt); + makeAvailableAt(ChecksToHoist, InsertPt); + makeAvailableAt(OldCondition, InsertPt); + Value *Result = Builder.CreateAnd(ChecksToHoist); + Result = freezeAndPush(Result, InsertPt); + Result = Builder.CreateAnd(OldCondition, Result); + Result->setName("wide.chk"); + return Result; } bool GuardWideningImpl::parseRangeChecks( - Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, - SmallPtrSetImpl<const Value *> &Visited) { - if (!Visited.insert(CheckCond).second) - return true; - + Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks) { using namespace llvm::PatternMatch; - { - Value *AndLHS, *AndRHS; - if (match(CheckCond, m_And(m_Value(AndLHS), m_Value(AndRHS)))) - return parseRangeChecks(AndLHS, Checks) && - parseRangeChecks(AndRHS, Checks); - } - auto *IC = dyn_cast<ICmpInst>(CheckCond); if (!IC || !IC->getOperand(0)->getType()->isIntegerTy() || (IC->getPredicate() != ICmpInst::ICMP_ULT && @@ -934,6 +974,15 @@ StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { PreservedAnalyses GuardWideningPass::run(Function &F, FunctionAnalysisManager &AM) { + // Avoid requesting analyses if there are no guards or widenable conditions. 
+ auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty(); + auto *WCDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_widenable_condition)); + bool HasWidenableConditions = WCDecl && !WCDecl->use_empty(); + if (!HasIntrinsicGuards && !HasWidenableConditions) + return PreservedAnalyses::all(); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); @@ -976,109 +1025,3 @@ PreservedAnalyses GuardWideningPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -namespace { -struct GuardWideningLegacyPass : public FunctionPass { - static char ID; - - GuardWideningLegacyPass() : FunctionPass(ID) { - initializeGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - std::unique_ptr<MemorySSAUpdater> MSSAU; - if (MSSAWP) - MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); - return GuardWideningImpl(DT, &PDT, LI, AC, MSSAU ? MSSAU.get() : nullptr, - DT.getRootNode(), - [](BasicBlock *) { return true; }) - .run(); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } -}; - -/// Same as above, but restricted to a single loop at a time. 
Can be -/// scheduled with other loop passes w/o breaking out of LPM -struct LoopGuardWideningLegacyPass : public LoopPass { - static char ID; - - LoopGuardWideningLegacyPass() : LoopPass(ID) { - initializeLoopGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - *L->getHeader()->getParent()); - auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); - auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; - auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - std::unique_ptr<MemorySSAUpdater> MSSAU; - if (MSSAWP) - MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); - - BasicBlock *RootBB = L->getLoopPredecessor(); - if (!RootBB) - RootBB = L->getHeader(); - auto BlockFilter = [&](BasicBlock *BB) { - return BB == RootBB || L->contains(BB); - }; - return GuardWideningImpl(DT, PDT, LI, AC, MSSAU ? 
MSSAU.get() : nullptr, - DT.getNode(RootBB), BlockFilter) - .run(); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - getLoopAnalysisUsage(AU); - AU.addPreserved<PostDominatorTreeWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } -}; -} - -char GuardWideningLegacyPass::ID = 0; -char LoopGuardWideningLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(GuardWideningLegacyPass, "guard-widening", "Widen guards", - false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(GuardWideningLegacyPass, "guard-widening", "Widen guards", - false, false) - -INITIALIZE_PASS_BEGIN(LoopGuardWideningLegacyPass, "loop-guard-widening", - "Widen guards (within a single loop, as a loop pass)", - false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(LoopGuardWideningLegacyPass, "loop-guard-widening", - "Widen guards (within a single loop, as a loop pass)", - false, false) - -FunctionPass *llvm::createGuardWideningPass() { - return new GuardWideningLegacyPass(); -} - -Pass *llvm::createLoopGuardWideningPass() { - return new LoopGuardWideningLegacyPass(); -} diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 40475d9563b2..41c4d6236173 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -1997,20 +1997,12 @@ bool IndVarSimplify::run(Loop *L) { TTI, PreHeader->getTerminator())) continue; - // Check preconditions for proper SCEVExpander operation. SCEV does not - // express SCEVExpander's dependencies, such as LoopSimplify. Instead - // any pass that uses the SCEVExpander must do it. 
This does not work - // well for loop passes because SCEVExpander makes assumptions about - // all loops, while LoopPassManager only forces the current loop to be - // simplified. - // - // FIXME: SCEV expansion has no way to bail out, so the caller must - // explicitly check any assumptions made by SCEV. Brittle. - const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ExitCount); - if (!AR || AR->getLoop()->getLoopPreheader()) - Changed |= linearFunctionTestReplace(L, ExitingBB, - ExitCount, IndVar, - Rewriter); + if (!Rewriter.isSafeToExpand(ExitCount)) + continue; + + Changed |= linearFunctionTestReplace(L, ExitingBB, + ExitCount, IndVar, + Rewriter); } } // Clear the rewriter cache, because values that are in the rewriter's cache diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index b52589baeee7..5f82af1ca46d 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -81,6 +81,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopConstrainer.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" @@ -91,7 +92,6 @@ #include <limits> #include <optional> #include <utility> -#include <vector> using namespace llvm; using namespace llvm::PatternMatch; @@ -129,8 +129,6 @@ static cl::opt<bool> PrintScaledBoundaryRangeChecks("irce-print-scaled-boundary-range-checks", cl::Hidden, cl::init(false)); -static const char *ClonedLoopTag = "irce.loop.clone"; - #define DEBUG_TYPE "irce" namespace { @@ -241,8 +239,6 @@ public: SmallVectorImpl<InductiveRangeCheck> &Checks, bool &Changed); }; -struct LoopStructure; - class InductiveRangeCheckElimination { ScalarEvolution &SE; 
BranchProbabilityInfo *BPI; @@ -554,649 +550,6 @@ void InductiveRangeCheck::extractRangeChecksFromBranch( Checks, Visited); } -// Add metadata to the loop L to disable loop optimizations. Callers need to -// confirm that optimizing loop L is not beneficial. -static void DisableAllLoopOptsOnLoop(Loop &L) { - // We do not care about any existing loopID related metadata for L, since we - // are setting all loop metadata to false. - LLVMContext &Context = L.getHeader()->getContext(); - // Reserve first location for self reference to the LoopID metadata node. - MDNode *Dummy = MDNode::get(Context, {}); - MDNode *DisableUnroll = MDNode::get( - Context, {MDString::get(Context, "llvm.loop.unroll.disable")}); - Metadata *FalseVal = - ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0)); - MDNode *DisableVectorize = MDNode::get( - Context, - {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal}); - MDNode *DisableLICMVersioning = MDNode::get( - Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")}); - MDNode *DisableDistribution= MDNode::get( - Context, - {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal}); - MDNode *NewLoopID = - MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize, - DisableLICMVersioning, DisableDistribution}); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - L.setLoopID(NewLoopID); -} - -namespace { - -// Keeps track of the structure of a loop. This is similar to llvm::Loop, -// except that it is more lightweight and can track the state of a loop through -// changing and potentially invalid IR. This structure also formalizes the -// kinds of loops we can deal with -- ones that have a single latch that is also -// an exiting block *and* have a canonical induction variable. 
-struct LoopStructure { - const char *Tag = ""; - - BasicBlock *Header = nullptr; - BasicBlock *Latch = nullptr; - - // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th - // successor is `LatchExit', the exit block of the loop. - BranchInst *LatchBr = nullptr; - BasicBlock *LatchExit = nullptr; - unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max(); - - // The loop represented by this instance of LoopStructure is semantically - // equivalent to: - // - // intN_ty inc = IndVarIncreasing ? 1 : -1; - // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT; - // - // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase) - // ... body ... - - Value *IndVarBase = nullptr; - Value *IndVarStart = nullptr; - Value *IndVarStep = nullptr; - Value *LoopExitAt = nullptr; - bool IndVarIncreasing = false; - bool IsSignedPredicate = true; - - LoopStructure() = default; - - template <typename M> LoopStructure map(M Map) const { - LoopStructure Result; - Result.Tag = Tag; - Result.Header = cast<BasicBlock>(Map(Header)); - Result.Latch = cast<BasicBlock>(Map(Latch)); - Result.LatchBr = cast<BranchInst>(Map(LatchBr)); - Result.LatchExit = cast<BasicBlock>(Map(LatchExit)); - Result.LatchBrExitIdx = LatchBrExitIdx; - Result.IndVarBase = Map(IndVarBase); - Result.IndVarStart = Map(IndVarStart); - Result.IndVarStep = Map(IndVarStep); - Result.LoopExitAt = Map(LoopExitAt); - Result.IndVarIncreasing = IndVarIncreasing; - Result.IsSignedPredicate = IsSignedPredicate; - return Result; - } - - static std::optional<LoopStructure> parseLoopStructure(ScalarEvolution &, - Loop &, const char *&); -}; - -/// This class is used to constrain loops to run within a given iteration space. -/// The algorithm this class implements is given a Loop and a range [Begin, -/// End). 
The algorithm then tries to break out a "main loop" out of the loop -/// it is given in a way that the "main loop" runs with the induction variable -/// in a subset of [Begin, End). The algorithm emits appropriate pre and post -/// loops to run any remaining iterations. The pre loop runs any iterations in -/// which the induction variable is < Begin, and the post loop runs any -/// iterations in which the induction variable is >= End. -class LoopConstrainer { - // The representation of a clone of the original loop we started out with. - struct ClonedLoop { - // The cloned blocks - std::vector<BasicBlock *> Blocks; - - // `Map` maps values in the clonee into values in the cloned version - ValueToValueMapTy Map; - - // An instance of `LoopStructure` for the cloned loop - LoopStructure Structure; - }; - - // Result of rewriting the range of a loop. See changeIterationSpaceEnd for - // more details on what these fields mean. - struct RewrittenRangeInfo { - BasicBlock *PseudoExit = nullptr; - BasicBlock *ExitSelector = nullptr; - std::vector<PHINode *> PHIValuesAtPseudoExit; - PHINode *IndVarEnd = nullptr; - - RewrittenRangeInfo() = default; - }; - - // Calculated subranges we restrict the iteration space of the main loop to. - // See the implementation of `calculateSubRanges' for more details on how - // these fields are computed. `LowLimit` is std::nullopt if there is no - // restriction on low end of the restricted iteration space of the main loop. - // `HighLimit` is std::nullopt if there is no restriction on high end of the - // restricted iteration space of the main loop. - - struct SubRanges { - std::optional<const SCEV *> LowLimit; - std::optional<const SCEV *> HighLimit; - }; - - // Compute a safe set of limits for the main loop to run in -- effectively the - // intersection of `Range' and the iteration space of the original loop. - // Return std::nullopt if unable to compute the set of subranges. 
- std::optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const; - - // Clone `OriginalLoop' and return the result in CLResult. The IR after - // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- - // the PHI nodes say that there is an incoming edge from `OriginalPreheader` - // but there is no such edge. - void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; - - // Create the appropriate loop structure needed to describe a cloned copy of - // `Original`. The clone is described by `VM`. - Loop *createClonedLoopStructure(Loop *Original, Loop *Parent, - ValueToValueMapTy &VM, bool IsSubloop); - - // Rewrite the iteration space of the loop denoted by (LS, Preheader). The - // iteration space of the rewritten loop ends at ExitLoopAt. The start of the - // iteration space is not changed. `ExitLoopAt' is assumed to be slt - // `OriginalHeaderCount'. - // - // If there are iterations left to execute, control is made to jump to - // `ContinuationBlock', otherwise they take the normal loop exit. The - // returned `RewrittenRangeInfo' object is populated as follows: - // - // .PseudoExit is a basic block that unconditionally branches to - // `ContinuationBlock'. - // - // .ExitSelector is a basic block that decides, on exit from the loop, - // whether to branch to the "true" exit or to `PseudoExit'. - // - // .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value - // for each PHINode in the loop header on taking the pseudo exit. - // - // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate - // preheader because it is made to branch to the loop header only - // conditionally. - RewrittenRangeInfo - changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader, - Value *ExitLoopAt, - BasicBlock *ContinuationBlock) const; - - // The loop denoted by `LS' has `OldPreheader' as its preheader. This - // function creates a new preheader for `LS' and returns it. 
- BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, - const char *Tag) const; - - // `ContinuationBlockAndPreheader' was the continuation block for some call to - // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'. - // This function rewrites the PHI nodes in `LS.Header' to start with the - // correct value. - void rewriteIncomingValuesForPHIs( - LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader, - const LoopConstrainer::RewrittenRangeInfo &RRI) const; - - // Even though we do not preserve any passes at this time, we at least need to - // keep the parent loop structure consistent. The `LPPassManager' seems to - // verify this after running a loop pass. This function adds the list of - // blocks denoted by BBs to this loops parent loop if required. - void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs); - - // Some global state. - Function &F; - LLVMContext &Ctx; - ScalarEvolution &SE; - DominatorTree &DT; - LoopInfo &LI; - function_ref<void(Loop *, bool)> LPMAddNewLoop; - - // Information about the original loop we started out with. - Loop &OriginalLoop; - - const IntegerType *ExitCountTy = nullptr; - BasicBlock *OriginalPreheader = nullptr; - - // The preheader of the main loop. This may or may not be different from - // `OriginalPreheader'. - BasicBlock *MainLoopPreheader = nullptr; - - // The range we need to run the main loop in. 
- InductiveRangeCheck::Range Range; - - // The structure of the main loop (see comment at the beginning of this class - // for a definition) - LoopStructure MainLoopStructure; - -public: - LoopConstrainer(Loop &L, LoopInfo &LI, - function_ref<void(Loop *, bool)> LPMAddNewLoop, - const LoopStructure &LS, ScalarEvolution &SE, - DominatorTree &DT, InductiveRangeCheck::Range R) - : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), - SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), - Range(R), MainLoopStructure(LS) {} - - // Entry point for the algorithm. Returns true on success. - bool run(); -}; - -} // end anonymous namespace - -/// Given a loop with an deccreasing induction variable, is it possible to -/// safely calculate the bounds of a new loop using the given Predicate. -static bool isSafeDecreasingBound(const SCEV *Start, - const SCEV *BoundSCEV, const SCEV *Step, - ICmpInst::Predicate Pred, - unsigned LatchBrExitIdx, - Loop *L, ScalarEvolution &SE) { - if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && - Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) - return false; - - if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) - return false; - - assert(SE.isKnownNegative(Step) && "expecting negative step"); - - LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n"); - LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); - LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); - LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); - LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n"); - LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); - - bool IsSigned = ICmpInst::isSigned(Pred); - // The predicate that we need to check that the induction variable lies - // within bounds. - ICmpInst::Predicate BoundPred = - IsSigned ? 
CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; - - if (LatchBrExitIdx == 1) - return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); - - assert(LatchBrExitIdx == 0 && - "LatchBrExitIdx should be either 0 or 1"); - - const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); - unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); - APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) : - APInt::getMinValue(BitWidth); - const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); - - const SCEV *MinusOne = - SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType())); - - return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) && - SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit); - -} - -/// Given a loop with an increasing induction variable, is it possible to -/// safely calculate the bounds of a new loop using the given Predicate. -static bool isSafeIncreasingBound(const SCEV *Start, - const SCEV *BoundSCEV, const SCEV *Step, - ICmpInst::Predicate Pred, - unsigned LatchBrExitIdx, - Loop *L, ScalarEvolution &SE) { - if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && - Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) - return false; - - if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) - return false; - - LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n"); - LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); - LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); - LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); - LLVM_DEBUG(dbgs() << "irce: Pred: " << Pred << "\n"); - LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); - - bool IsSigned = ICmpInst::isSigned(Pred); - // The predicate that we need to check that the induction variable lies - // within bounds. - ICmpInst::Predicate BoundPred = - IsSigned ? 
CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; - - if (LatchBrExitIdx == 1) - return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); - - assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); - - const SCEV *StepMinusOne = - SE.getMinusSCEV(Step, SE.getOne(Step->getType())); - unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); - APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) : - APInt::getMaxValue(BitWidth); - const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); - - return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start, - SE.getAddExpr(BoundSCEV, Step)) && - SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); -} - -/// Returns estimate for max latch taken count of the loop of the narrowest -/// available type. If the latch block has such estimate, it is returned. -/// Otherwise, we use max exit count of whole loop (that is potentially of wider -/// type than latch check itself), which is still better than no estimate. 
-static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE, - const Loop &L) { - const SCEV *FromBlock = - SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum); - if (isa<SCEVCouldNotCompute>(FromBlock)) - return SE.getSymbolicMaxBackedgeTakenCount(&L); - return FromBlock; -} - -std::optional<LoopStructure> -LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, - const char *&FailureReason) { - if (!L.isLoopSimplifyForm()) { - FailureReason = "loop not in LoopSimplify form"; - return std::nullopt; - } - - BasicBlock *Latch = L.getLoopLatch(); - assert(Latch && "Simplified loops only have one latch!"); - - if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) { - FailureReason = "loop has already been cloned"; - return std::nullopt; - } - - if (!L.isLoopExiting(Latch)) { - FailureReason = "no loop latch"; - return std::nullopt; - } - - BasicBlock *Header = L.getHeader(); - BasicBlock *Preheader = L.getLoopPreheader(); - if (!Preheader) { - FailureReason = "no preheader"; - return std::nullopt; - } - - BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); - if (!LatchBr || LatchBr->isUnconditional()) { - FailureReason = "latch terminator not conditional branch"; - return std::nullopt; - } - - unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 
1 : 0; - - ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); - if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { - FailureReason = "latch terminator branch not conditional on integral icmp"; - return std::nullopt; - } - - const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L); - if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) { - FailureReason = "could not compute latch count"; - return std::nullopt; - } - assert(SE.getLoopDisposition(MaxBETakenCount, &L) == - ScalarEvolution::LoopInvariant && - "loop variant exit count doesn't make sense!"); - - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *LeftValue = ICI->getOperand(0); - const SCEV *LeftSCEV = SE.getSCEV(LeftValue); - IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType()); - - Value *RightValue = ICI->getOperand(1); - const SCEV *RightSCEV = SE.getSCEV(RightValue); - - // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. - if (!isa<SCEVAddRecExpr>(LeftSCEV)) { - if (isa<SCEVAddRecExpr>(RightSCEV)) { - std::swap(LeftSCEV, RightSCEV); - std::swap(LeftValue, RightValue); - Pred = ICmpInst::getSwappedPredicate(Pred); - } else { - FailureReason = "no add recurrences in the icmp"; - return std::nullopt; - } - } - - auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { - if (AR->getNoWrapFlags(SCEV::FlagNSW)) - return true; - - IntegerType *Ty = cast<IntegerType>(AR->getType()); - IntegerType *WideTy = - IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); - - const SCEVAddRecExpr *ExtendAfterOp = - dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); - if (ExtendAfterOp) { - const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); - const SCEV *ExtendedStep = - SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); - - bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && - ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; - - if (NoSignedWrap) - return true; - } - - // We may have proved 
this when computing the sign extension above. - return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; - }; - - // `ICI` is interpreted as taking the backedge if the *next* value of the - // induction variable satisfies some constraint. - - const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); - if (IndVarBase->getLoop() != &L) { - FailureReason = "LHS in cmp is not an AddRec for this loop"; - return std::nullopt; - } - if (!IndVarBase->isAffine()) { - FailureReason = "LHS in icmp not induction variable"; - return std::nullopt; - } - const SCEV* StepRec = IndVarBase->getStepRecurrence(SE); - if (!isa<SCEVConstant>(StepRec)) { - FailureReason = "LHS in icmp not induction variable"; - return std::nullopt; - } - ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue(); - - if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) { - FailureReason = "LHS in icmp needs nsw for equality predicates"; - return std::nullopt; - } - - assert(!StepCI->isZero() && "Zero step?"); - bool IsIncreasing = !StepCI->isNegative(); - bool IsSignedPredicate; - const SCEV *StartNext = IndVarBase->getStart(); - const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); - const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); - const SCEV *Step = SE.getSCEV(StepCI); - - const SCEV *FixedRightSCEV = nullptr; - - // If RightValue resides within loop (but still being loop invariant), - // regenerate it as preheader. - if (auto *I = dyn_cast<Instruction>(RightValue)) - if (L.contains(I->getParent())) - FixedRightSCEV = RightSCEV; - - if (IsIncreasing) { - bool DecreasedRightValueByOne = false; - if (StepCI->isOne()) { - // Try to turn eq/ne predicates to those we can work with. - if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) - // while (++i != len) { while (++i < len) { - // ... ---> ... - // } } - // If both parts are known non-negative, it is profitable to use - // unsigned comparison in increasing loop. 
This allows us to make the - // comparison check against "RightSCEV + 1" more optimistic. - if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) && - isKnownNonNegativeInLoop(RightSCEV, &L, SE)) - Pred = ICmpInst::ICMP_ULT; - else - Pred = ICmpInst::ICMP_SLT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { - // while (true) { while (true) { - // if (++i == len) ---> if (++i > len - 1) - // break; break; - // ... ... - // } } - if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && - cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) { - Pred = ICmpInst::ICMP_UGT; - RightSCEV = SE.getMinusSCEV(RightSCEV, - SE.getOne(RightSCEV->getType())); - DecreasedRightValueByOne = true; - } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) { - Pred = ICmpInst::ICMP_SGT; - RightSCEV = SE.getMinusSCEV(RightSCEV, - SE.getOne(RightSCEV->getType())); - DecreasedRightValueByOne = true; - } - } - } - - bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); - bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); - bool FoundExpectedPred = - (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0); - - if (!FoundExpectedPred) { - FailureReason = "expected icmp slt semantically, found something else"; - return std::nullopt; - } - - IsSignedPredicate = ICmpInst::isSigned(Pred); - if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { - FailureReason = "unsigned latch conditions are explicitly prohibited"; - return std::nullopt; - } - - if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred, - LatchBrExitIdx, &L, SE)) { - FailureReason = "Unsafe loop bounds"; - return std::nullopt; - } - if (LatchBrExitIdx == 0) { - // We need to increase the right value unless we have already decreased - // it virtually when we replaced EQ with SGT. 
- if (!DecreasedRightValueByOne) - FixedRightSCEV = - SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); - } else { - assert(!DecreasedRightValueByOne && - "Right value can be decreased only for LatchBrExitIdx == 0!"); - } - } else { - bool IncreasedRightValueByOne = false; - if (StepCI->isMinusOne()) { - // Try to turn eq/ne predicates to those we can work with. - if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) - // while (--i != len) { while (--i > len) { - // ... ---> ... - // } } - // We intentionally don't turn the predicate into UGT even if we know - // that both operands are non-negative, because it will only pessimize - // our check against "RightSCEV - 1". - Pred = ICmpInst::ICMP_SGT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { - // while (true) { while (true) { - // if (--i == len) ---> if (--i < len + 1) - // break; break; - // ... ... - // } } - if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && - cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) { - Pred = ICmpInst::ICMP_ULT; - RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); - IncreasedRightValueByOne = true; - } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { - Pred = ICmpInst::ICMP_SLT; - RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); - IncreasedRightValueByOne = true; - } - } - } - - bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); - bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); - - bool FoundExpectedPred = - (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0); - - if (!FoundExpectedPred) { - FailureReason = "expected icmp sgt semantically, found something else"; - return std::nullopt; - } - - IsSignedPredicate = - Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; - - if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { - FailureReason = "unsigned latch conditions are explicitly prohibited"; - return std::nullopt; - 
} - - if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred, - LatchBrExitIdx, &L, SE)) { - FailureReason = "Unsafe bounds"; - return std::nullopt; - } - - if (LatchBrExitIdx == 0) { - // We need to decrease the right value unless we have already increased - // it virtually when we replaced EQ with SLT. - if (!IncreasedRightValueByOne) - FixedRightSCEV = - SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); - } else { - assert(!IncreasedRightValueByOne && - "Right value can be increased only for LatchBrExitIdx == 0!"); - } - } - BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); - - assert(!L.contains(LatchExit) && "expected an exit block!"); - const DataLayout &DL = Preheader->getModule()->getDataLayout(); - SCEVExpander Expander(SE, DL, "irce"); - Instruction *Ins = Preheader->getTerminator(); - - if (FixedRightSCEV) - RightValue = - Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins); - - Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins); - IndVarStartV->setName("indvar.start"); - - LoopStructure Result; - - Result.Tag = "main"; - Result.Header = Header; - Result.Latch = Latch; - Result.LatchBr = LatchBr; - Result.LatchExit = LatchExit; - Result.LatchBrExitIdx = LatchBrExitIdx; - Result.IndVarStart = IndVarStartV; - Result.IndVarStep = StepCI; - Result.IndVarBase = LeftValue; - Result.IndVarIncreasing = IsIncreasing; - Result.LoopExitAt = RightValue; - Result.IsSignedPredicate = IsSignedPredicate; - - FailureReason = nullptr; - - return Result; -} - /// If the type of \p S matches with \p Ty, return \p S. Otherwise, return /// signed or unsigned extension of \p S to type \p Ty. static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, @@ -1204,17 +557,23 @@ static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, return Signed ? 
SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty); } -std::optional<LoopConstrainer::SubRanges> -LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { +// Compute a safe set of limits for the main loop to run in -- effectively the +// intersection of `Range' and the iteration space of the original loop. +// Return std::nullopt if unable to compute the set of subranges. +static std::optional<LoopConstrainer::SubRanges> +calculateSubRanges(ScalarEvolution &SE, const Loop &L, + InductiveRangeCheck::Range &Range, + const LoopStructure &MainLoopStructure) { auto *RTy = cast<IntegerType>(Range.getType()); // We only support wide range checks and narrow latches. - if (!AllowNarrowLatchCondition && RTy != ExitCountTy) + if (!AllowNarrowLatchCondition && RTy != MainLoopStructure.ExitCountTy) return std::nullopt; - if (RTy->getBitWidth() < ExitCountTy->getBitWidth()) + if (RTy->getBitWidth() < MainLoopStructure.ExitCountTy->getBitWidth()) return std::nullopt; LoopConstrainer::SubRanges Result; + bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; // I think we can be more aggressive here and make this nuw / nsw if the // addition that feeds into the icmp for the latch's terminating branch is nuw // / nsw. In any case, a wrapping 2's complement addition is safe. @@ -1245,7 +604,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { // `End`, decrementing by one every time. // // * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the - // induction variable is decreasing we know that that the smallest value + // induction variable is decreasing we know that the smallest value // the loop body is actually executed with is `INT_SMIN` == `Smallest`. // // * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. 
In @@ -1258,7 +617,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { GreatestSeen = Start; } - auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) { + auto Clamp = [&SE, Smallest, Greatest, IsSignedPredicate](const SCEV *S) { return IsSignedPredicate ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)) : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S)); @@ -1283,464 +642,6 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { return Result; } -void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, - const char *Tag) const { - for (BasicBlock *BB : OriginalLoop.getBlocks()) { - BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F); - Result.Blocks.push_back(Clone); - Result.Map[BB] = Clone; - } - - auto GetClonedValue = [&Result](Value *V) { - assert(V && "null values not in domain!"); - auto It = Result.Map.find(V); - if (It == Result.Map.end()) - return V; - return static_cast<Value *>(It->second); - }; - - auto *ClonedLatch = - cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch())); - ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag, - MDNode::get(Ctx, {})); - - Result.Structure = MainLoopStructure.map(GetClonedValue); - Result.Structure.Tag = Tag; - - for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) { - BasicBlock *ClonedBB = Result.Blocks[i]; - BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i]; - - assert(Result.Map[OriginalBB] == ClonedBB && "invariant!"); - - for (Instruction &I : *ClonedBB) - RemapInstruction(&I, Result.Map, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - - // Exit blocks will now have one more predecessor and their PHI nodes need - // to be edited to reflect that. No phi nodes need to be introduced because - // the loop is in LCSSA. 
- - for (auto *SBB : successors(OriginalBB)) { - if (OriginalLoop.contains(SBB)) - continue; // not an exit block - - for (PHINode &PN : SBB->phis()) { - Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB); - PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB); - SE.forgetValue(&PN); - } - } - } -} - -LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( - const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, - BasicBlock *ContinuationBlock) const { - // We start with a loop with a single latch: - // - // +--------------------+ - // | | - // | preheader | - // | | - // +--------+-----------+ - // | ----------------\ - // | / | - // +--------v----v------+ | - // | | | - // | header | | - // | | | - // +--------------------+ | - // | - // ..... | - // | - // +--------------------+ | - // | | | - // | latch >----------/ - // | | - // +-------v------------+ - // | - // | - // | +--------------------+ - // | | | - // +---> original exit | - // | | - // +--------------------+ - // - // We change the control flow to look like - // - // - // +--------------------+ - // | | - // | preheader >-------------------------+ - // | | | - // +--------v-----------+ | - // | /-------------+ | - // | / | | - // +--------v--v--------+ | | - // | | | | - // | header | | +--------+ | - // | | | | | | - // +--------------------+ | | +-----v-----v-----------+ - // | | | | - // | | | .pseudo.exit | - // | | | | - // | | +-----------v-----------+ - // | | | - // ..... 
| | | - // | | +--------v-------------+ - // +--------------------+ | | | | - // | | | | | ContinuationBlock | - // | latch >------+ | | | - // | | | +----------------------+ - // +---------v----------+ | - // | | - // | | - // | +---------------^-----+ - // | | | - // +-----> .exit.selector | - // | | - // +----------v----------+ - // | - // +--------------------+ | - // | | | - // | original exit <----+ - // | | - // +--------------------+ - - RewrittenRangeInfo RRI; - - BasicBlock *BBInsertLocation = LS.Latch->getNextNode(); - RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", - &F, BBInsertLocation); - RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, - BBInsertLocation); - - BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); - bool Increasing = LS.IndVarIncreasing; - bool IsSignedPredicate = LS.IsSignedPredicate; - - IRBuilder<> B(PreheaderJump); - auto *RangeTy = Range.getBegin()->getType(); - auto NoopOrExt = [&](Value *V) { - if (V->getType() == RangeTy) - return V; - return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName()) - : B.CreateZExt(V, RangeTy, "wide." + V->getName()); - }; - - // EnterLoopCond - is it okay to start executing this `LS'? - Value *EnterLoopCond = nullptr; - auto Pred = - Increasing - ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT) - : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); - Value *IndVarStart = NoopOrExt(LS.IndVarStart); - EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt); - - B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); - PreheaderJump->eraseFromParent(); - - LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); - B.SetInsertPoint(LS.LatchBr); - Value *IndVarBase = NoopOrExt(LS.IndVarBase); - Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt); - - Value *CondForBranch = LS.LatchBrExitIdx == 1 - ? 
TakeBackedgeLoopCond - : B.CreateNot(TakeBackedgeLoopCond); - - LS.LatchBr->setCondition(CondForBranch); - - B.SetInsertPoint(RRI.ExitSelector); - - // IterationsLeft - are there any more iterations left, given the original - // upper bound on the induction variable? If not, we branch to the "real" - // exit. - Value *LoopExitAt = NoopOrExt(LS.LoopExitAt); - Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt); - B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); - - BranchInst *BranchToContinuation = - BranchInst::Create(ContinuationBlock, RRI.PseudoExit); - - // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of - // each of the PHI nodes in the loop header. This feeds into the initial - // value of the same PHI nodes if/when we continue execution. - for (PHINode &PN : LS.Header->phis()) { - PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy", - BranchToContinuation); - - NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader); - NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch), - RRI.ExitSelector); - RRI.PHIValuesAtPseudoExit.push_back(NewPHI); - } - - RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end", - BranchToContinuation); - RRI.IndVarEnd->addIncoming(IndVarStart, Preheader); - RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector); - - // The latch exit now has a branch from `RRI.ExitSelector' instead of - // `LS.Latch'. The PHI nodes need to be updated to reflect that. 
- LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector); - - return RRI; -} - -void LoopConstrainer::rewriteIncomingValuesForPHIs( - LoopStructure &LS, BasicBlock *ContinuationBlock, - const LoopConstrainer::RewrittenRangeInfo &RRI) const { - unsigned PHIIndex = 0; - for (PHINode &PN : LS.Header->phis()) - PN.setIncomingValueForBlock(ContinuationBlock, - RRI.PHIValuesAtPseudoExit[PHIIndex++]); - - LS.IndVarStart = RRI.IndVarEnd; -} - -BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, - BasicBlock *OldPreheader, - const char *Tag) const { - BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); - BranchInst::Create(LS.Header, Preheader); - - LS.Header->replacePhiUsesWith(OldPreheader, Preheader); - - return Preheader; -} - -void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { - Loop *ParentLoop = OriginalLoop.getParentLoop(); - if (!ParentLoop) - return; - - for (BasicBlock *BB : BBs) - ParentLoop->addBasicBlockToLoop(BB, LI); -} - -Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, - ValueToValueMapTy &VM, - bool IsSubloop) { - Loop &New = *LI.AllocateLoop(); - if (Parent) - Parent->addChildLoop(&New); - else - LI.addTopLevelLoop(&New); - LPMAddNewLoop(&New, IsSubloop); - - // Add all of the blocks in Original to the new loop. - for (auto *BB : Original->blocks()) - if (LI.getLoopFor(BB) == Original) - New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI); - - // Add all of the subloops to the new loop. 
- for (Loop *SubLoop : *Original) - createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true); - - return &New; -} - -bool LoopConstrainer::run() { - BasicBlock *Preheader = nullptr; - const SCEV *MaxBETakenCount = - getNarrowestLatchMaxTakenCountEstimate(SE, OriginalLoop); - Preheader = OriginalLoop.getLoopPreheader(); - assert(!isa<SCEVCouldNotCompute>(MaxBETakenCount) && Preheader != nullptr && - "preconditions!"); - ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType()); - - OriginalPreheader = Preheader; - MainLoopPreheader = Preheader; - - bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; - std::optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); - if (!MaybeSR) { - LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n"); - return false; - } - - SubRanges SR = *MaybeSR; - bool Increasing = MainLoopStructure.IndVarIncreasing; - IntegerType *IVTy = - cast<IntegerType>(Range.getBegin()->getType()); - - SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); - Instruction *InsertPt = OriginalPreheader->getTerminator(); - - // It would have been better to make `PreLoop' and `PostLoop' - // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy - // constructor. - ClonedLoop PreLoop, PostLoop; - bool NeedsPreLoop = - Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value(); - bool NeedsPostLoop = - Increasing ? 
SR.HighLimit.has_value() : SR.LowLimit.has_value(); - - Value *ExitPreLoopAt = nullptr; - Value *ExitMainLoopAt = nullptr; - const SCEVConstant *MinusOneS = - cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */)); - - if (NeedsPreLoop) { - const SCEV *ExitPreLoopAtSCEV = nullptr; - - if (Increasing) - ExitPreLoopAtSCEV = *SR.LowLimit; - else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, - IsSignedPredicate)) - ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); - else { - LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "preloop exit limit. HighLimit = " - << *(*SR.HighLimit) << "\n"); - return false; - } - - if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) { - LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" - << " preloop exit limit " << *ExitPreLoopAtSCEV - << " at block " << InsertPt->getParent()->getName() - << "\n"); - return false; - } - - ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); - ExitPreLoopAt->setName("exit.preloop.at"); - } - - if (NeedsPostLoop) { - const SCEV *ExitMainLoopAtSCEV = nullptr; - - if (Increasing) - ExitMainLoopAtSCEV = *SR.HighLimit; - else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, - IsSignedPredicate)) - ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); - else { - LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "mainloop exit limit. 
LowLimit = " - << *(*SR.LowLimit) << "\n"); - return false; - } - - if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) { - LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" - << " main loop exit limit " << *ExitMainLoopAtSCEV - << " at block " << InsertPt->getParent()->getName() - << "\n"); - return false; - } - - ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); - ExitMainLoopAt->setName("exit.mainloop.at"); - } - - // We clone these ahead of time so that we don't have to deal with changing - // and temporarily invalid IR as we transform the loops. - if (NeedsPreLoop) - cloneLoop(PreLoop, "preloop"); - if (NeedsPostLoop) - cloneLoop(PostLoop, "postloop"); - - RewrittenRangeInfo PreLoopRRI; - - if (NeedsPreLoop) { - Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, - PreLoop.Structure.Header); - - MainLoopPreheader = - createPreheader(MainLoopStructure, Preheader, "mainloop"); - PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, - ExitPreLoopAt, MainLoopPreheader); - rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, - PreLoopRRI); - } - - BasicBlock *PostLoopPreheader = nullptr; - RewrittenRangeInfo PostLoopRRI; - - if (NeedsPostLoop) { - PostLoopPreheader = - createPreheader(PostLoop.Structure, Preheader, "postloop"); - PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, - ExitMainLoopAt, PostLoopPreheader); - rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, - PostLoopRRI); - } - - BasicBlock *NewMainLoopPreheader = - MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr; - BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit, - PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit, - PostLoopRRI.ExitSelector, NewMainLoopPreheader}; - - // Some of the above may be nullptr, filter them out before passing to - // addToParentLoopIfNeeded. 
- auto NewBlocksEnd = - std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); - - addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd)); - - DT.recalculate(F); - - // We need to first add all the pre and post loop blocks into the loop - // structures (as part of createClonedLoopStructure), and then update the - // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating - // LI when LoopSimplifyForm is generated. - Loop *PreL = nullptr, *PostL = nullptr; - if (!PreLoop.Blocks.empty()) { - PreL = createClonedLoopStructure(&OriginalLoop, - OriginalLoop.getParentLoop(), PreLoop.Map, - /* IsSubLoop */ false); - } - - if (!PostLoop.Blocks.empty()) { - PostL = - createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(), - PostLoop.Map, /* IsSubLoop */ false); - } - - // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. - auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) { - formLCSSARecursively(*L, DT, &LI, &SE); - simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true); - // Pre/post loops are slow paths, we do not need to perform any loop - // optimizations on them. - if (!IsOriginalLoop) - DisableAllLoopOptsOnLoop(*L); - }; - if (PreL) - CanonicalizeLoop(PreL, false); - if (PostL) - CanonicalizeLoop(PostL, false); - CanonicalizeLoop(&OriginalLoop, true); - - /// At this point: - /// - We've broken a "main loop" out of the loop in a way that the "main loop" - /// runs with the induction variable in a subset of [Begin, End). - /// - There is no overflow when computing "main loop" exit limit. - /// - Max latch taken count of the loop is limited. - /// It guarantees that induction variable will not overflow iterating in the - /// "main loop". - if (auto BO = dyn_cast<BinaryOperator>(MainLoopStructure.IndVarBase)) - if (IsSignedPredicate) - BO->setHasNoSignedWrap(true); - /// TODO: support unsigned predicate. 
- /// To add NUW flag we need to prove that both operands of BO are - /// non-negative. E.g: - /// ... - /// %iv.next = add nsw i32 %iv, -1 - /// %cmp = icmp ult i32 %iv.next, %n - /// br i1 %cmp, label %loopexit, label %loop - /// - /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will - /// overflow, therefore NUW flag is not legal here. - - return true; -} - /// Computes and returns a range of values for the induction variable (IndVar) /// in which the range check can be safely elided. If it cannot compute such a /// range, returns std::nullopt. @@ -2108,7 +1009,8 @@ bool InductiveRangeCheckElimination::run( const char *FailureReason = nullptr; std::optional<LoopStructure> MaybeLoopStructure = - LoopStructure::parseLoopStructure(SE, *L, FailureReason); + LoopStructure::parseLoopStructure(SE, *L, AllowUnsignedLatchCondition, + FailureReason); if (!MaybeLoopStructure) { LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason << "\n";); @@ -2147,7 +1049,15 @@ bool InductiveRangeCheckElimination::run( if (!SafeIterRange) return Changed; - LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, *SafeIterRange); + std::optional<LoopConstrainer::SubRanges> MaybeSR = + calculateSubRanges(SE, *L, *SafeIterRange, LS); + if (!MaybeSR) { + LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n"); + return false; + } + + LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, + SafeIterRange->getBegin()->getType(), *MaybeSR); if (LC.run()) { Changed = true; diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index c2b5a12fd63f..1bf50d79e533 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -164,9 +164,13 @@ class InferAddressSpaces : public FunctionPass { public: static char ID; - InferAddressSpaces() : - FunctionPass(ID), FlatAddrSpace(UninitializedAddressSpace) {} - InferAddressSpaces(unsigned AS) 
: FunctionPass(ID), FlatAddrSpace(AS) {} + InferAddressSpaces() + : FunctionPass(ID), FlatAddrSpace(UninitializedAddressSpace) { + initializeInferAddressSpacesPass(*PassRegistry::getPassRegistry()); + } + InferAddressSpaces(unsigned AS) : FunctionPass(ID), FlatAddrSpace(AS) { + initializeInferAddressSpacesPass(*PassRegistry::getPassRegistry()); + } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); @@ -221,8 +225,8 @@ class InferAddressSpacesImpl { Value *V, PostorderStackTy &PostorderStack, DenseSet<Value *> &Visited) const; - bool rewriteIntrinsicOperands(IntrinsicInst *II, - Value *OldV, Value *NewV) const; + bool rewriteIntrinsicOperands(IntrinsicInst *II, Value *OldV, + Value *NewV) const; void collectRewritableIntrinsicOperands(IntrinsicInst *II, PostorderStackTy &PostorderStack, DenseSet<Value *> &Visited) const; @@ -473,7 +477,7 @@ void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack( } // Returns all flat address expressions in function F. The elements are ordered -// ordered in postorder. +// in postorder. 
std::vector<WeakTrackingVH> InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { // This function implements a non-recursive postorder traversal of a partial @@ -483,8 +487,7 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { DenseSet<Value *> Visited; auto PushPtrOperand = [&](Value *Ptr) { - appendsFlatAddressExpressionToPostorderStack(Ptr, PostorderStack, - Visited); + appendsFlatAddressExpressionToPostorderStack(Ptr, PostorderStack, Visited); }; // Look at operations that may be interesting accelerate by moving to a known @@ -519,8 +522,11 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { PushPtrOperand(ASC->getPointerOperand()); } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) { if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI)) - PushPtrOperand( - cast<Operator>(I2P->getOperand(0))->getOperand(0)); + PushPtrOperand(cast<Operator>(I2P->getOperand(0))->getOperand(0)); + } else if (auto *RI = dyn_cast<ReturnInst>(&I)) { + if (auto *RV = RI->getReturnValue(); + RV && RV->getType()->isPtrOrPtrVectorTy()) + PushPtrOperand(RV); } } @@ -923,12 +929,14 @@ bool InferAddressSpacesImpl::updateAddressSpace( Value *Src1 = Op.getOperand(2); auto I = InferredAddrSpace.find(Src0); - unsigned Src0AS = (I != InferredAddrSpace.end()) ? - I->second : Src0->getType()->getPointerAddressSpace(); + unsigned Src0AS = (I != InferredAddrSpace.end()) + ? I->second + : Src0->getType()->getPointerAddressSpace(); auto J = InferredAddrSpace.find(Src1); - unsigned Src1AS = (J != InferredAddrSpace.end()) ? - J->second : Src1->getType()->getPointerAddressSpace(); + unsigned Src1AS = (J != InferredAddrSpace.end()) + ? 
J->second + : Src1->getType()->getPointerAddressSpace(); auto *C0 = dyn_cast<Constant>(Src0); auto *C1 = dyn_cast<Constant>(Src1); @@ -1097,7 +1105,8 @@ bool InferAddressSpacesImpl::isSafeToCastConstAddrSpace(Constant *C, // If we already have a constant addrspacecast, it should be safe to cast it // off. if (Op->getOpcode() == Instruction::AddrSpaceCast) - return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS); + return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), + NewAS); if (Op->getOpcode() == Instruction::IntToPtr && Op->getType()->getPointerAddressSpace() == FlatAddrSpace) @@ -1128,7 +1137,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( // construction. ValueToValueMapTy ValueWithNewAddrSpace; SmallVector<const Use *, 32> PoisonUsesToFix; - for (Value* V : Postorder) { + for (Value *V : Postorder) { unsigned NewAddrSpace = InferredAddrSpace.lookup(V); // In some degenerate cases (e.g. invalid IR in unreachable code), we may @@ -1161,6 +1170,8 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( } SmallVector<Instruction *, 16> DeadInstructions; + ValueToValueMapTy VMap; + ValueMapper VMapper(VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); // Replaces the uses of the old address expressions with the new ones. 
for (const WeakTrackingVH &WVH : Postorder) { @@ -1174,18 +1185,41 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( << *NewV << '\n'); if (Constant *C = dyn_cast<Constant>(V)) { - Constant *Replace = ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), - C->getType()); + Constant *Replace = + ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), C->getType()); if (C != Replace) { LLVM_DEBUG(dbgs() << "Inserting replacement const cast: " << Replace << ": " << *Replace << '\n'); - C->replaceAllUsesWith(Replace); + SmallVector<User *, 16> WorkList; + for (User *U : make_early_inc_range(C->users())) { + if (auto *I = dyn_cast<Instruction>(U)) { + if (I->getFunction() == F) + I->replaceUsesOfWith(C, Replace); + } else { + WorkList.append(U->user_begin(), U->user_end()); + } + } + if (!WorkList.empty()) { + VMap[C] = Replace; + DenseSet<User *> Visited{WorkList.begin(), WorkList.end()}; + while (!WorkList.empty()) { + User *U = WorkList.pop_back_val(); + if (auto *I = dyn_cast<Instruction>(U)) { + if (I->getFunction() == F) + VMapper.remapInstruction(*I); + continue; + } + for (User *U2 : U->users()) + if (Visited.insert(U2).second) + WorkList.push_back(U2); + } + } V = Replace; } } Value::use_iterator I, E, Next; - for (I = V->use_begin(), E = V->use_end(); I != E; ) { + for (I = V->use_begin(), E = V->use_end(); I != E;) { Use &U = *I; // Some users may see the same pointer operand in multiple operands. Skip @@ -1205,6 +1239,11 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( // Skip if the current user is the new value itself. if (CurUser == NewV) continue; + + if (auto *CurUserI = dyn_cast<Instruction>(CurUser); + CurUserI && CurUserI->getFunction() != F) + continue; + // Handle more complex cases like intrinsic that need to be remangled. 
if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) { if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV)) @@ -1241,8 +1280,8 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) { if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) { Cmp->setOperand(SrcIdx, NewV); - Cmp->setOperand(OtherIdx, - ConstantExpr::getAddrSpaceCast(KOtherSrc, NewV->getType())); + Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast( + KOtherSrc, NewV->getType())); continue; } } diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp new file mode 100644 index 000000000000..b75b8d486fbb --- /dev/null +++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp @@ -0,0 +1,91 @@ +//===- InferAlignment.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Infer alignment for load, stores and other memory operations based on +// trailing zero known bits information. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/InferAlignment.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Instructions.h" +#include "llvm/InitializePasses.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" + +using namespace llvm; + +static bool tryToImproveAlign( + const DataLayout &DL, Instruction *I, + function_ref<Align(Value *PtrOp, Align OldAlign, Align PrefAlign)> Fn) { + if (auto *LI = dyn_cast<LoadInst>(I)) { + Value *PtrOp = LI->getPointerOperand(); + Align OldAlign = LI->getAlign(); + Align NewAlign = Fn(PtrOp, OldAlign, DL.getPrefTypeAlign(LI->getType())); + if (NewAlign > OldAlign) { + LI->setAlignment(NewAlign); + return true; + } + } else if (auto *SI = dyn_cast<StoreInst>(I)) { + Value *PtrOp = SI->getPointerOperand(); + Value *ValOp = SI->getValueOperand(); + Align OldAlign = SI->getAlign(); + Align NewAlign = Fn(PtrOp, OldAlign, DL.getPrefTypeAlign(ValOp->getType())); + if (NewAlign > OldAlign) { + SI->setAlignment(NewAlign); + return true; + } + } + // TODO: Also handle memory intrinsics. + return false; +} + +bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) { + const DataLayout &DL = F.getParent()->getDataLayout(); + bool Changed = false; + + // Enforce preferred type alignment if possible. We do this as a separate + // pass first, because it may improve the alignments we infer below. + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + Changed |= tryToImproveAlign( + DL, &I, [&](Value *PtrOp, Align OldAlign, Align PrefAlign) { + if (PrefAlign > OldAlign) + return std::max(OldAlign, + tryEnforceAlignment(PtrOp, PrefAlign, DL)); + return OldAlign; + }); + } + } + + // Compute alignment from known bits. 
+ for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + Changed |= tryToImproveAlign( + DL, &I, [&](Value *PtrOp, Align OldAlign, Align PrefAlign) { + KnownBits Known = computeKnownBits(PtrOp, DL, 0, &AC, &I, &DT); + unsigned TrailZ = std::min(Known.countMinTrailingZeros(), + +Value::MaxAlignmentExponent); + return Align(1ull << std::min(Known.getBitWidth() - 1, TrailZ)); + }); + } + } + + return Changed; +} + +PreservedAnalyses InferAlignmentPass::run(Function &F, + FunctionAnalysisManager &AM) { + AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + inferAlignment(F, AC, DT); + // Changes to alignment shouldn't invalidated analyses. + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index 24390f1b54f6..8603c5cf9c02 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -102,11 +102,6 @@ static cl::opt<unsigned> PhiDuplicateThreshold( cl::desc("Max PHIs in BB to duplicate for jump threading"), cl::init(76), cl::Hidden); -static cl::opt<bool> PrintLVIAfterJumpThreading( - "print-lvi-after-jump-threading", - cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false), - cl::Hidden); - static cl::opt<bool> ThreadAcrossLoopHeaders( "jump-threading-across-loop-headers", cl::desc("Allow JumpThreading to thread across loop headers, for testing"), @@ -228,17 +223,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { if (BP >= BranchProbability(50, 100)) continue; - SmallVector<uint32_t, 2> Weights; + uint32_t Weights[2]; if (PredBr->getSuccessor(0) == PredOutEdge.second) { - Weights.push_back(BP.getNumerator()); - Weights.push_back(BP.getCompl().getNumerator()); + Weights[0] = BP.getNumerator(); + Weights[1] = BP.getCompl().getNumerator(); } else { - Weights.push_back(BP.getCompl().getNumerator()); - 
Weights.push_back(BP.getNumerator()); + Weights[0] = BP.getCompl().getNumerator(); + Weights[1] = BP.getNumerator(); } - PredBr->setMetadata(LLVMContext::MD_prof, - MDBuilder(PredBr->getParent()->getContext()) - .createBranchWeights(Weights)); + setBranchWeights(*PredBr, Weights); } } @@ -259,11 +252,6 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, &DT, nullptr, DomTreeUpdater::UpdateStrategy::Lazy), std::nullopt, std::nullopt); - if (PrintLVIAfterJumpThreading) { - dbgs() << "LVI for function '" << F.getName() << "':\n"; - LVI.printLVI(F, getDomTreeUpdater()->getDomTree(), dbgs()); - } - if (!Changed) return PreservedAnalyses::all(); @@ -412,6 +400,10 @@ static bool replaceFoldableUses(Instruction *Cond, Value *ToVal, if (Cond->getParent() == KnownAtEndOfBB) Changed |= replaceNonLocalUsesWith(Cond, ToVal); for (Instruction &I : reverse(*KnownAtEndOfBB)) { + // Replace any debug-info record users of Cond with ToVal. + for (DPValue &DPV : I.getDbgValueRange()) + DPV.replaceVariableLocationOp(Cond, ToVal, true); + // Reached the Cond whose uses we are trying to replace, so there are no // more uses. if (&I == Cond) @@ -568,6 +560,8 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( Value *V, BasicBlock *BB, PredValueInfo &Result, ConstantPreference Preference, DenseSet<Value *> &RecursionSet, Instruction *CxtI) { + const DataLayout &DL = BB->getModule()->getDataLayout(); + // This method walks up use-def chains recursively. Because of this, we could // get into an infinite loop going around loops in the use-def chain. To // prevent this, keep track of what (value, block) pairs we've already visited @@ -635,16 +629,19 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( // Handle Cast instructions. 
if (CastInst *CI = dyn_cast<CastInst>(I)) { Value *Source = CI->getOperand(0); - computeValueKnownInPredecessorsImpl(Source, BB, Result, Preference, + PredValueInfoTy Vals; + computeValueKnownInPredecessorsImpl(Source, BB, Vals, Preference, RecursionSet, CxtI); - if (Result.empty()) + if (Vals.empty()) return false; // Convert the known values. - for (auto &R : Result) - R.first = ConstantExpr::getCast(CI->getOpcode(), R.first, CI->getType()); + for (auto &Val : Vals) + if (Constant *Folded = ConstantFoldCastOperand(CI->getOpcode(), Val.first, + CI->getType(), DL)) + Result.emplace_back(Folded, Val.second); - return true; + return !Result.empty(); } if (FreezeInst *FI = dyn_cast<FreezeInst>(I)) { @@ -726,7 +723,6 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( if (Preference != WantInteger) return false; if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) { - const DataLayout &DL = BO->getModule()->getDataLayout(); PredValueInfoTy LHSVals; computeValueKnownInPredecessorsImpl(BO->getOperand(0), BB, LHSVals, WantInteger, RecursionSet, CxtI); @@ -757,7 +753,10 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( PHINode *PN = dyn_cast<PHINode>(CmpLHS); if (!PN) PN = dyn_cast<PHINode>(CmpRHS); - if (PN && PN->getParent() == BB) { + // Do not perform phi translation across a loop header phi, because this + // may result in comparison of values from two different loop iterations. + // FIXME: This check is broken if LoopHeaders is not populated. + if (PN && PN->getParent() == BB && !LoopHeaders.contains(BB)) { const DataLayout &DL = PN->getModule()->getDataLayout(); // We can do this simplification if any comparisons fold to true or false. // See if any do. 
@@ -1269,6 +1268,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { if (IsLoadCSE) { LoadInst *NLoadI = cast<LoadInst>(AvailableVal); combineMetadataForCSE(NLoadI, LoadI, false); + LVI->forgetValue(NLoadI); }; // If the returned value is the load itself, replace with poison. This can @@ -1432,8 +1432,8 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { // Create a PHI node at the start of the block for the PRE'd load value. pred_iterator PB = pred_begin(LoadBB), PE = pred_end(LoadBB); - PHINode *PN = PHINode::Create(LoadI->getType(), std::distance(PB, PE), "", - &LoadBB->front()); + PHINode *PN = PHINode::Create(LoadI->getType(), std::distance(PB, PE), ""); + PN->insertBefore(LoadBB->begin()); PN->takeName(LoadI); PN->setDebugLoc(LoadI->getDebugLoc()); @@ -1461,6 +1461,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { for (LoadInst *PredLoadI : CSELoads) { combineMetadataForCSE(PredLoadI, LoadI, true); + LVI->forgetValue(PredLoadI); } LoadI->replaceAllUsesWith(PN); @@ -1899,7 +1900,7 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) { return false; const Instruction *TI = SinglePred->getTerminator(); - if (TI->isExceptionalTerminator() || TI->getNumSuccessors() != 1 || + if (TI->isSpecialTerminator() || TI->getNumSuccessors() != 1 || SinglePred == BB || hasAddressTakenAndUsed(BB)) return false; @@ -1954,6 +1955,7 @@ void JumpThreadingPass::updateSSA( SSAUpdater SSAUpdate; SmallVector<Use *, 16> UsesToRename; SmallVector<DbgValueInst *, 4> DbgValues; + SmallVector<DPValue *, 4> DPValues; for (Instruction &I : *BB) { // Scan all uses of this instruction to see if it is used outside of its @@ -1970,15 +1972,16 @@ void JumpThreadingPass::updateSSA( } // Find debug values outside of the block - findDbgValues(DbgValues, &I); - DbgValues.erase(remove_if(DbgValues, - [&](const DbgValueInst *DbgVal) { - return DbgVal->getParent() == BB; - }), - DbgValues.end()); + 
findDbgValues(DbgValues, &I, &DPValues); + llvm::erase_if(DbgValues, [&](const DbgValueInst *DbgVal) { + return DbgVal->getParent() == BB; + }); + llvm::erase_if(DPValues, [&](const DPValue *DPVal) { + return DPVal->getParent() == BB; + }); // If there are no uses outside the block, we're done with this instruction. - if (UsesToRename.empty() && DbgValues.empty()) + if (UsesToRename.empty() && DbgValues.empty() && DPValues.empty()) continue; LLVM_DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n"); @@ -1991,9 +1994,11 @@ void JumpThreadingPass::updateSSA( while (!UsesToRename.empty()) SSAUpdate.RewriteUse(*UsesToRename.pop_back_val()); - if (!DbgValues.empty()) { + if (!DbgValues.empty() || !DPValues.empty()) { SSAUpdate.UpdateDebugValues(&I, DbgValues); + SSAUpdate.UpdateDebugValues(&I, DPValues); DbgValues.clear(); + DPValues.clear(); } LLVM_DEBUG(dbgs() << "\n"); @@ -2036,6 +2041,26 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, return true; }; + // Duplicate implementation of the above dbg.value code, using DPValues + // instead. + auto RetargetDPValueIfPossible = [&](DPValue *DPV) { + SmallSet<std::pair<Value *, Value *>, 16> OperandsToRemap; + for (auto *Op : DPV->location_ops()) { + Instruction *OpInst = dyn_cast<Instruction>(Op); + if (!OpInst) + continue; + + auto I = ValueMapping.find(OpInst); + if (I != ValueMapping.end()) + OperandsToRemap.insert({OpInst, I->second}); + } + + for (auto &[OldOp, MappedOp] : OperandsToRemap) + DPV->replaceVariableLocationOp(OldOp, MappedOp); + }; + + BasicBlock *RangeBB = BI->getParent(); + // Clone the phi nodes of the source basic block into NewBB. The resulting // phi nodes are trivial since NewBB only has one predecessor, but SSAUpdater // might need to rewrite the operand of the cloned phi. 
@@ -2054,6 +2079,12 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, identifyNoAliasScopesToClone(BI, BE, NoAliasScopes); cloneNoAliasScopes(NoAliasScopes, ClonedScopes, "thread", Context); + auto CloneAndRemapDbgInfo = [&](Instruction *NewInst, Instruction *From) { + auto DPVRange = NewInst->cloneDebugInfoFrom(From); + for (DPValue &DPV : DPVRange) + RetargetDPValueIfPossible(&DPV); + }; + // Clone the non-phi instructions of the source basic block into NewBB, // keeping track of the mapping and using it to remap operands in the cloned // instructions. @@ -2064,6 +2095,8 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, ValueMapping[&*BI] = New; adaptNoAliasScopes(New, ClonedScopes, Context); + CloneAndRemapDbgInfo(New, &*BI); + if (RetargetDbgValueIfPossible(New)) continue; @@ -2076,6 +2109,17 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, } } + // There may be DPValues on the terminator, clone directly from marker + // to marker as there isn't an instruction there. + if (BE != RangeBB->end() && BE->hasDbgValues()) { + // Dump them at the end. 
+ DPMarker *Marker = RangeBB->getMarker(BE); + DPMarker *EndMarker = NewBB->createMarker(NewBB->end()); + auto DPVRange = EndMarker->cloneDebugInfoFrom(Marker, std::nullopt); + for (DPValue &DPV : DPVRange) + RetargetDPValueIfPossible(&DPV); + } + return ValueMapping; } @@ -2245,7 +2289,7 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB, assert(BPI && "It's expected BPI to exist along with BFI"); auto NewBBFreq = BFI->getBlockFreq(PredPredBB) * BPI->getEdgeProbability(PredPredBB, PredBB); - BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + BFI->setBlockFreq(NewBB, NewBBFreq); } // We are going to have to map operands from the original BB block to the new @@ -2371,7 +2415,7 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB, assert(BPI && "It's expected BPI to exist along with BFI"); auto NewBBFreq = BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB); - BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + BFI->setBlockFreq(NewBB, NewBBFreq); } // Copy all the instructions from BB to NewBB except the terminator. @@ -2456,7 +2500,7 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB, NewBBFreq += FreqMap.lookup(Pred); } if (BFI) // Apply the summed frequency to NewBB. - BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + BFI->setBlockFreq(NewBB, NewBBFreq); } DTU->applyUpdatesPermissive(Updates); @@ -2496,7 +2540,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB, auto NewBBFreq = BFI->getBlockFreq(NewBB); auto BB2SuccBBFreq = BBOrigFreq * BPI->getEdgeProbability(BB, SuccBB); auto BBNewFreq = BBOrigFreq - NewBBFreq; - BFI->setBlockFreq(BB, BBNewFreq.getFrequency()); + BFI->setBlockFreq(BB, BBNewFreq); // Collect updated outgoing edges' frequencies from BB and use them to update // edge probabilities. 
@@ -2567,9 +2611,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB, Weights.push_back(Prob.getNumerator()); auto TI = BB->getTerminator(); - TI->setMetadata( - LLVMContext::MD_prof, - MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights)); + setBranchWeights(*TI, Weights); } } @@ -2663,6 +2705,9 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( if (!New->mayHaveSideEffects()) { New->eraseFromParent(); New = nullptr; + // Clone debug-info on the elided instruction to the destination + // position. + OldPredBranch->cloneDebugInfoFrom(&*BI, std::nullopt, true); } } else { ValueMapping[&*BI] = New; @@ -2670,6 +2715,8 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( if (New) { // Otherwise, insert the new instruction into the block. New->setName(BI->getName()); + // Clone across any debug-info attached to the old instruction. + New->cloneDebugInfoFrom(&*BI); // Update Dominance from simplified New instruction operands. for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) if (BasicBlock *SuccBB = dyn_cast<BasicBlock>(New->getOperand(i))) @@ -2754,7 +2801,7 @@ void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB, BranchProbability PredToNewBBProb = BranchProbability::getBranchProbability( TrueWeight, TrueWeight + FalseWeight); auto NewBBFreq = BFI->getBlockFreq(Pred) * PredToNewBBProb; - BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + BFI->setBlockFreq(NewBB, NewBBFreq); } // The select is now dead. 
@@ -2924,7 +2971,9 @@ bool JumpThreadingPass::tryToUnfoldSelectInCurrBB(BasicBlock *BB) { Value *Cond = SI->getCondition(); if (!isGuaranteedNotToBeUndefOrPoison(Cond, nullptr, SI)) Cond = new FreezeInst(Cond, "cond.fr", SI); - Instruction *Term = SplitBlockAndInsertIfThen(Cond, SI, false); + MDNode *BranchWeights = getBranchWeightMDNode(*SI); + Instruction *Term = + SplitBlockAndInsertIfThen(Cond, SI, false, BranchWeights); BasicBlock *SplitBB = SI->getParent(); BasicBlock *NewBB = Term->getParent(); PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI); @@ -3059,8 +3108,8 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard, if (!isa<PHINode>(&*BI)) ToRemove.push_back(&*BI); - Instruction *InsertionPoint = &*BB->getFirstInsertionPt(); - assert(InsertionPoint && "Empty block?"); + BasicBlock::iterator InsertionPoint = BB->getFirstInsertionPt(); + assert(InsertionPoint != BB->end() && "Empty block?"); // Substitute with Phis & remove. for (auto *Inst : reverse(ToRemove)) { if (!Inst->use_empty()) { @@ -3070,6 +3119,7 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard, NewPN->insertBefore(InsertionPoint); Inst->replaceAllUsesWith(NewPN); } + Inst->dropDbgValues(); Inst->eraseFromParent(); } return true; diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index f8fab03f151d..d0afe09ce41d 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -108,6 +108,8 @@ STATISTIC(NumGEPsHoisted, "Number of geps reassociated and hoisted out of the loop"); STATISTIC(NumAddSubHoisted, "Number of add/subtract expressions reassociated " "and hoisted out of the loop"); +STATISTIC(NumFPAssociationsHoisted, "Number of invariant FP expressions " + "reassociated and hoisted out of the loop"); /// Memory promotion is enabled by default. 
static cl::opt<bool> @@ -127,6 +129,12 @@ static cl::opt<uint32_t> MaxNumUsesTraversed( cl::desc("Max num uses visited for identifying load " "invariance in loop using invariant start (default = 8)")); +static cl::opt<unsigned> FPAssociationUpperLimit( + "licm-max-num-fp-reassociations", cl::init(5U), cl::Hidden, + cl::desc( + "Set upper limit for the number of transformations performed " + "during a single round of hoisting the reassociated expressions.")); + // Experimental option to allow imprecision in LICM in pathological cases, in // exchange for faster compile. This is to be removed if MemorySSA starts to // address the same issue. LICM calls MemorySSAWalker's @@ -473,12 +481,12 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, }); if (!HasCatchSwitch) { - SmallVector<Instruction *, 8> InsertPts; + SmallVector<BasicBlock::iterator, 8> InsertPts; SmallVector<MemoryAccess *, 8> MSSAInsertPts; InsertPts.reserve(ExitBlocks.size()); MSSAInsertPts.reserve(ExitBlocks.size()); for (BasicBlock *ExitBlock : ExitBlocks) { - InsertPts.push_back(&*ExitBlock->getFirstInsertionPt()); + InsertPts.push_back(ExitBlock->getFirstInsertionPt()); MSSAInsertPts.push_back(nullptr); } @@ -985,7 +993,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, // loop invariant). If so make them unconditional by moving them to their // immediate dominator. We iterate through the instructions in reverse order // which ensures that when we rehoist an instruction we rehoist its operands, - // and also keep track of where in the block we are rehoisting to to make sure + // and also keep track of where in the block we are rehoisting to make sure // that we rehoist instructions before the instructions that use them. Instruction *HoistPoint = nullptr; if (ControlFlowHoisting) { @@ -1031,7 +1039,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, // invariant.start has no uses. 
static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, Loop *CurLoop) { - Value *Addr = LI->getOperand(0); + Value *Addr = LI->getPointerOperand(); const DataLayout &DL = LI->getModule()->getDataLayout(); const TypeSize LocSizeInBits = DL.getTypeSizeInBits(LI->getType()); @@ -1047,20 +1055,6 @@ static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, if (LocSizeInBits.isScalable()) return false; - // if the type is i8 addrspace(x)*, we know this is the type of - // llvm.invariant.start operand - auto *PtrInt8Ty = PointerType::get(Type::getInt8Ty(LI->getContext()), - LI->getPointerAddressSpace()); - unsigned BitcastsVisited = 0; - // Look through bitcasts until we reach the i8* type (this is invariant.start - // operand type). - while (Addr->getType() != PtrInt8Ty) { - auto *BC = dyn_cast<BitCastInst>(Addr); - // Avoid traversing high number of bitcast uses. - if (++BitcastsVisited > MaxNumUsesTraversed || !BC) - return false; - Addr = BC->getOperand(0); - } // If we've ended up at a global/constant, bail. We shouldn't be looking at // uselists for non-local Values in a loop pass. if (isa<Constant>(Addr)) @@ -1480,8 +1474,9 @@ static Instruction *cloneInstructionInExitBlock( if (LI->wouldBeOutOfLoopUseRequiringLCSSA(Op.get(), PN.getParent())) { auto *OInst = cast<Instruction>(Op.get()); PHINode *OpPN = - PHINode::Create(OInst->getType(), PN.getNumIncomingValues(), - OInst->getName() + ".lcssa", &ExitBlock.front()); + PHINode::Create(OInst->getType(), PN.getNumIncomingValues(), + OInst->getName() + ".lcssa"); + OpPN->insertBefore(ExitBlock.begin()); for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) OpPN->addIncoming(OInst, PN.getIncomingBlock(i)); Op = OpPN; @@ -1799,7 +1794,7 @@ namespace { class LoopPromoter : public LoadAndStorePromoter { Value *SomePtr; // Designated pointer to store to. 
SmallVectorImpl<BasicBlock *> &LoopExitBlocks; - SmallVectorImpl<Instruction *> &LoopInsertPts; + SmallVectorImpl<BasicBlock::iterator> &LoopInsertPts; SmallVectorImpl<MemoryAccess *> &MSSAInsertPts; PredIteratorCache &PredCache; MemorySSAUpdater &MSSAU; @@ -1823,7 +1818,8 @@ class LoopPromoter : public LoadAndStorePromoter { // We need to create an LCSSA PHI node for the incoming value and // store that. PHINode *PN = PHINode::Create(I->getType(), PredCache.size(BB), - I->getName() + ".lcssa", &BB->front()); + I->getName() + ".lcssa"); + PN->insertBefore(BB->begin()); for (BasicBlock *Pred : PredCache.get(BB)) PN->addIncoming(I, Pred); return PN; @@ -1832,7 +1828,7 @@ class LoopPromoter : public LoadAndStorePromoter { public: LoopPromoter(Value *SP, ArrayRef<const Instruction *> Insts, SSAUpdater &S, SmallVectorImpl<BasicBlock *> &LEB, - SmallVectorImpl<Instruction *> &LIP, + SmallVectorImpl<BasicBlock::iterator> &LIP, SmallVectorImpl<MemoryAccess *> &MSSAIP, PredIteratorCache &PIC, MemorySSAUpdater &MSSAU, LoopInfo &li, DebugLoc dl, Align Alignment, bool UnorderedAtomic, const AAMDNodes &AATags, @@ -1855,7 +1851,7 @@ public: Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock); LiveInValue = maybeInsertLCSSAPHI(LiveInValue, ExitBlock); Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock); - Instruction *InsertPos = LoopInsertPts[i]; + BasicBlock::iterator InsertPos = LoopInsertPts[i]; StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); if (UnorderedAtomic) NewSI->setOrdering(AtomicOrdering::Unordered); @@ -1934,23 +1930,6 @@ bool isNotVisibleOnUnwindInLoop(const Value *Object, const Loop *L, isNotCapturedBeforeOrInLoop(Object, L, DT); } -// We don't consider globals as writable: While the physical memory is writable, -// we may not have provenance to perform the write. -bool isWritableObject(const Value *Object) { - // TODO: Alloca might not be writable after its lifetime ends. - // See https://github.com/llvm/llvm-project/issues/51838. 
- if (isa<AllocaInst>(Object)) - return true; - - // TODO: Also handle sret. - if (auto *A = dyn_cast<Argument>(Object)) - return A->hasByValAttr(); - - // TODO: Noalias has nothing to do with writability, this should check for - // an allocator function. - return isNoAliasCall(Object); -} - bool isThreadLocalObject(const Value *Object, const Loop *L, DominatorTree *DT, TargetTransformInfo *TTI) { // The object must be function-local to start with, and then not captured @@ -1970,7 +1949,7 @@ bool isThreadLocalObject(const Value *Object, const Loop *L, DominatorTree *DT, bool llvm::promoteLoopAccessesToScalars( const SmallSetVector<Value *, 8> &PointerMustAliases, SmallVectorImpl<BasicBlock *> &ExitBlocks, - SmallVectorImpl<Instruction *> &InsertPts, + SmallVectorImpl<BasicBlock::iterator> &InsertPts, SmallVectorImpl<MemoryAccess *> &MSSAInsertPts, PredIteratorCache &PIC, LoopInfo *LI, DominatorTree *DT, AssumptionCache *AC, const TargetLibraryInfo *TLI, TargetTransformInfo *TTI, Loop *CurLoop, @@ -2192,7 +2171,10 @@ bool llvm::promoteLoopAccessesToScalars( // violating the memory model. if (StoreSafety == StoreSafetyUnknown) { Value *Object = getUnderlyingObject(SomePtr); - if (isWritableObject(Object) && + bool ExplicitlyDereferenceableOnly; + if (isWritableObject(Object, ExplicitlyDereferenceableOnly) && + (!ExplicitlyDereferenceableOnly || + isDereferenceablePointer(SomePtr, AccessTy, MDL)) && isThreadLocalObject(Object, CurLoop, DT, TTI)) StoreSafety = StoreSafe; } @@ -2511,7 +2493,7 @@ static bool hoistGEP(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, // handle both offsets being non-negative. 
const DataLayout &DL = GEP->getModule()->getDataLayout(); auto NonNegative = [&](Value *V) { - return isKnownNonNegative(V, DL, 0, AC, GEP, DT); + return isKnownNonNegative(V, SimplifyQuery(DL, DT, AC, GEP)); }; bool IsInBounds = Src->isInBounds() && GEP->isInBounds() && all_of(Src->indices(), NonNegative) && @@ -2561,8 +2543,9 @@ static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS, // we want to avoid this. auto &DL = L.getHeader()->getModule()->getDataLayout(); bool ProvedNoOverflowAfterReassociate = - computeOverflowForSignedSub(InvariantRHS, InvariantOp, DL, AC, &ICmp, - DT) == llvm::OverflowResult::NeverOverflows; + computeOverflowForSignedSub(InvariantRHS, InvariantOp, + SimplifyQuery(DL, DT, AC, &ICmp)) == + llvm::OverflowResult::NeverOverflows; if (!ProvedNoOverflowAfterReassociate) return false; auto *Preheader = L.getLoopPreheader(); @@ -2612,15 +2595,16 @@ static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS, // we want to avoid this. Likewise, for "C1 - LV < C2" we need to prove that // "C1 - C2" does not overflow. auto &DL = L.getHeader()->getModule()->getDataLayout(); + SimplifyQuery SQ(DL, DT, AC, &ICmp); if (VariantSubtracted) { // C1 - LV < C2 --> LV > C1 - C2 - if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, DL, AC, &ICmp, - DT) != llvm::OverflowResult::NeverOverflows) + if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, SQ) != + llvm::OverflowResult::NeverOverflows) return false; } else { // LV - C1 < C2 --> LV < C1 + C2 - if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, DL, AC, &ICmp, - DT) != llvm::OverflowResult::NeverOverflows) + if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, SQ) != + llvm::OverflowResult::NeverOverflows) return false; } auto *Preheader = L.getLoopPreheader(); @@ -2674,6 +2658,72 @@ static bool hoistAddSub(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, return false; } +/// Try to reassociate expressions like ((A1 * B1) + (A2 * B2) + ...) 
* C where +/// A1, A2, ... and C are loop invariants into expressions like +/// ((A1 * C * B1) + (A2 * C * B2) + ...) and hoist the (A1 * C), (A2 * C), ... +/// invariant expressions. This functions returns true only if any hoisting has +/// actually occured. +static bool hoistFPAssociation(Instruction &I, Loop &L, + ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU, AssumptionCache *AC, + DominatorTree *DT) { + using namespace PatternMatch; + Value *VariantOp = nullptr, *InvariantOp = nullptr; + + if (!match(&I, m_FMul(m_Value(VariantOp), m_Value(InvariantOp))) || + !I.hasAllowReassoc() || !I.hasNoSignedZeros()) + return false; + if (L.isLoopInvariant(VariantOp)) + std::swap(VariantOp, InvariantOp); + if (L.isLoopInvariant(VariantOp) || !L.isLoopInvariant(InvariantOp)) + return false; + Value *Factor = InvariantOp; + + // First, we need to make sure we should do the transformation. + SmallVector<Use *> Changes; + SmallVector<BinaryOperator *> Worklist; + if (BinaryOperator *VariantBinOp = dyn_cast<BinaryOperator>(VariantOp)) + Worklist.push_back(VariantBinOp); + while (!Worklist.empty()) { + BinaryOperator *BO = Worklist.pop_back_val(); + if (!BO->hasOneUse() || !BO->hasAllowReassoc() || !BO->hasNoSignedZeros()) + return false; + BinaryOperator *Op0, *Op1; + if (match(BO, m_FAdd(m_BinOp(Op0), m_BinOp(Op1)))) { + Worklist.push_back(Op0); + Worklist.push_back(Op1); + continue; + } + if (BO->getOpcode() != Instruction::FMul || L.isLoopInvariant(BO)) + return false; + Use &U0 = BO->getOperandUse(0); + Use &U1 = BO->getOperandUse(1); + if (L.isLoopInvariant(U0)) + Changes.push_back(&U0); + else if (L.isLoopInvariant(U1)) + Changes.push_back(&U1); + else + return false; + if (Changes.size() > FPAssociationUpperLimit) + return false; + } + if (Changes.empty()) + return false; + + // We know we should do it so let's do the transformation. 
+ auto *Preheader = L.getLoopPreheader(); + assert(Preheader && "Loop is not in simplify form?"); + IRBuilder<> Builder(Preheader->getTerminator()); + for (auto *U : Changes) { + assert(L.isLoopInvariant(U->get())); + Instruction *Ins = cast<Instruction>(U->getUser()); + U->set(Builder.CreateFMulFMF(U->get(), Factor, Ins, "factor.op.fmul")); + } + I.replaceAllUsesWith(VariantOp); + eraseInstruction(I, SafetyInfo, MSSAU); + return true; +} + static bool hoistArithmetics(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU, AssumptionCache *AC, @@ -2701,6 +2751,12 @@ static bool hoistArithmetics(Instruction &I, Loop &L, return true; } + if (hoistFPAssociation(I, L, SafetyInfo, MSSAU, AC, DT)) { + ++NumHoisted; + ++NumFPAssociationsHoisted; + return true; + } + return false; } diff --git a/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp b/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp index 9ae55b9018da..3d3f22d686e3 100644 --- a/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp +++ b/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp @@ -20,7 +20,8 @@ PreservedAnalyses LoopAccessInfoPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { auto &LAIs = AM.getResult<LoopAccessAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); - OS << "Loop access info in function '" << F.getName() << "':\n"; + OS << "Printing analysis 'Loop Access Analysis' for function '" << F.getName() + << "':\n"; SmallPriorityWorklist<Loop *, 4> Worklist; appendLoopsToWorklist(LI, Worklist); diff --git a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp index 2b9800f11912..9a27a08c86eb 100644 --- a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp +++ b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp @@ -430,7 +430,7 @@ static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); // Update phi node in exit block of post-loop. 
- Builder.SetInsertPoint(&PostLoopPreHeader->front()); + Builder.SetInsertPoint(PostLoopPreHeader, PostLoopPreHeader->begin()); for (PHINode &PN : PostLoop->getExitBlock()->phis()) { for (auto i : seq<int>(0, PN.getNumOperands())) { // Check incoming block is pre-loop's exiting block. diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 7c2770979a90..cc1f56014eee 100644 --- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -399,7 +399,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { continue; unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace(); - Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace); + Type *I8Ptr = PointerType::get(BB->getContext(), PtrAddrSpace); Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt); IRBuilder<> Builder(P.InsertPt); diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 27196e46ca56..626888c74bad 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -104,9 +104,9 @@ static cl::opt<unsigned> DistributeSCEVCheckThreshold( static cl::opt<unsigned> PragmaDistributeSCEVCheckThreshold( "loop-distribute-scev-check-threshold-with-pragma", cl::init(128), cl::Hidden, - cl::desc( - "The maximum number of SCEV checks allowed for Loop " - "Distribution for loop marked with #pragma loop distribute(enable)")); + cl::desc("The maximum number of SCEV checks allowed for Loop " + "Distribution for loop marked with #pragma clang loop " + "distribute(enable)")); static cl::opt<bool> EnableLoopDistribute( "enable-loop-distribute", cl::Hidden, diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index edc8a4956dd1..b1add3c42976 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ 
b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -641,8 +641,9 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, // Check if the multiply could not overflow due to known ranges of the // input values. OverflowResult OR = computeOverflowForUnsignedMul( - FI.InnerTripCount, FI.OuterTripCount, DL, AC, - FI.OuterLoop->getLoopPreheader()->getTerminator(), DT); + FI.InnerTripCount, FI.OuterTripCount, + SimplifyQuery(DL, DT, AC, + FI.OuterLoop->getLoopPreheader()->getTerminator())); if (OR != OverflowResult::MayOverflow) return OR; diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp index d35b562be0aa..e0b224d5ef73 100644 --- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -1411,7 +1411,7 @@ private: } // Walk through all uses in FC1. For each use, find the reaching def. If the - // def is located in FC0 then it is is not safe to fuse. + // def is located in FC0 then it is not safe to fuse. for (BasicBlock *BB : FC1.L->blocks()) for (Instruction &I : *BB) for (auto &Op : I.operands()) @@ -1473,12 +1473,13 @@ private: for (Instruction *I : HoistInsts) { assert(I->getParent() == FC1.Preheader); - I->moveBefore(FC0.Preheader->getTerminator()); + I->moveBefore(*FC0.Preheader, + FC0.Preheader->getTerminator()->getIterator()); } // insert instructions in reverse order to maintain dominance relationship for (Instruction *I : reverse(SinkInsts)) { assert(I->getParent() == FC1.Preheader); - I->moveBefore(&*FC1.ExitBlock->getFirstInsertionPt()); + I->moveBefore(*FC1.ExitBlock, FC1.ExitBlock->getFirstInsertionPt()); } } @@ -1491,7 +1492,7 @@ private: /// 2. The successors of the guard have the same flow into/around the loop. /// If the compare instructions are identical, then the first successor of the /// guard must go to the same place (either the preheader of the loop or the - /// NonLoopBlock). 
In other words, the the first successor of both loops must + /// NonLoopBlock). In other words, the first successor of both loops must /// both go into the loop (i.e., the preheader) or go around the loop (i.e., /// the NonLoopBlock). The same must be true for the second successor. bool haveIdenticalGuards(const FusionCandidate &FC0, @@ -1624,7 +1625,7 @@ private: // first, or undef otherwise. This is sound as exiting the first implies the // second will exit too, __without__ taking the back-edge. [Their // trip-counts are equal after all. - // KB: Would this sequence be simpler to just just make FC0.ExitingBlock go + // KB: Would this sequence be simpler to just make FC0.ExitingBlock go // to FC1.Header? I think this is basically what the three sequences are // trying to accomplish; however, doing this directly in the CFG may mean // the DT/PDT becomes invalid @@ -1671,7 +1672,7 @@ private: // exiting the first and jumping to the header of the second does not break // the SSA property of the phis originally in the first loop. See also the // comment above. - Instruction *L1HeaderIP = &FC1.Header->front(); + BasicBlock::iterator L1HeaderIP = FC1.Header->begin(); for (PHINode *LCPHI : OriginalFC0PHIs) { int L1LatchBBIdx = LCPHI->getBasicBlockIndex(FC1.Latch); assert(L1LatchBBIdx >= 0 && @@ -1679,8 +1680,9 @@ private: Value *LCV = LCPHI->getIncomingValue(L1LatchBBIdx); - PHINode *L1HeaderPHI = PHINode::Create( - LCV->getType(), 2, LCPHI->getName() + ".afterFC0", L1HeaderIP); + PHINode *L1HeaderPHI = + PHINode::Create(LCV->getType(), 2, LCPHI->getName() + ".afterFC0"); + L1HeaderPHI->insertBefore(L1HeaderIP); L1HeaderPHI->addIncoming(LCV, FC0.Latch); L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()), FC0.ExitingBlock); @@ -1953,7 +1955,7 @@ private: // exiting the first and jumping to the header of the second does not break // the SSA property of the phis originally in the first loop. See also the // comment above. 
- Instruction *L1HeaderIP = &FC1.Header->front(); + BasicBlock::iterator L1HeaderIP = FC1.Header->begin(); for (PHINode *LCPHI : OriginalFC0PHIs) { int L1LatchBBIdx = LCPHI->getBasicBlockIndex(FC1.Latch); assert(L1LatchBBIdx >= 0 && @@ -1961,8 +1963,9 @@ private: Value *LCV = LCPHI->getIncomingValue(L1LatchBBIdx); - PHINode *L1HeaderPHI = PHINode::Create( - LCV->getType(), 2, LCPHI->getName() + ".afterFC0", L1HeaderIP); + PHINode *L1HeaderPHI = + PHINode::Create(LCV->getType(), 2, LCPHI->getName() + ".afterFC0"); + L1HeaderPHI->insertBefore(L1HeaderIP); L1HeaderPHI->addIncoming(LCV, FC0.Latch); L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()), FC0.ExitingBlock); diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 8572a442e784..3721564890dd 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -24,12 +24,6 @@ // memcmp, strlen, etc. // Future floating point idioms to recognize in -ffast-math mode: // fpowi -// Future integer operation idioms to recognize: -// ctpop -// -// Beware that isel's default lowering for ctpop is highly inefficient for -// i64 and larger types when i64 is legal and the value has few bits set. It -// would be good to enhance isel to emit a loop for ctpop in this case. // // This could recognize common matrix multiplies and dot product idioms and // replace them with calls to BLAS (if linked in??). 
@@ -948,9 +942,13 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, // to be exactly the size of the memset, which is (BECount+1)*StoreSize const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount); const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV); - if (BECst && ConstSize) - AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) * - ConstSize->getValue()->getZExtValue()); + if (BECst && ConstSize) { + std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue(); + std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue(); + // FIXME: Should this check for overflow? + if (BEInt && SizeInt) + AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt); + } // TODO: For this to be really effective, we have to dive into the pointer // operand in the store. Store to &A[i] of 100 will always return may alias @@ -1023,7 +1021,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( SCEVExpander Expander(*SE, *DL, "loop-idiom"); SCEVExpanderCleaner ExpCleaner(Expander); - Type *DestInt8PtrTy = Builder.getInt8PtrTy(DestAS); + Type *DestInt8PtrTy = Builder.getPtrTy(DestAS); Type *IntIdxTy = DL->getIndexType(DestPtr->getType()); bool Changed = false; @@ -1107,7 +1105,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( PatternValue, ".memset_pattern"); GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these. GV->setAlignment(Align(16)); - Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy); + Value *PatternPtr = GV; NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); // Set the TBAA info if present. @@ -1284,7 +1282,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // feeds the stores. Check for an alias by generating the base address and // checking everything. 
Value *StoreBasePtr = Expander.expandCodeFor( - StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator()); + StrStart, Builder.getPtrTy(StrAS), Preheader->getTerminator()); // From here on out, conservatively report to the pass manager that we've // changed the IR, even if we later clean up these added instructions. There @@ -1336,8 +1334,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // For a memcpy, we have to make sure that the input array is not being // mutated by the loop. - Value *LoadBasePtr = Expander.expandCodeFor( - LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator()); + Value *LoadBasePtr = Expander.expandCodeFor(LdStart, Builder.getPtrTy(LdAS), + Preheader->getTerminator()); // If the store is a memcpy instruction, we must check if it will write to // the load memory locations. So remove it from the ignored stores. @@ -2026,7 +2024,8 @@ void LoopIdiomRecognize::transformLoopToCountable( auto *LbBr = cast<BranchInst>(Body->getTerminator()); ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition()); - PHINode *TcPhi = PHINode::Create(CountTy, 2, "tcphi", &Body->front()); + PHINode *TcPhi = PHINode::Create(CountTy, 2, "tcphi"); + TcPhi->insertBefore(Body->begin()); Builder.SetInsertPoint(LbCond); Instruction *TcDec = cast<Instruction>(Builder.CreateSub( @@ -2132,7 +2131,8 @@ void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB, ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition()); Type *Ty = TripCnt->getType(); - PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi", &Body->front()); + PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi"); + TcPhi->insertBefore(Body->begin()); Builder.SetInsertPoint(LbCond); Instruction *TcDec = cast<Instruction>( @@ -2411,7 +2411,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() { // it's use count. 
Instruction *InsertPt = nullptr; if (auto *BitPosI = dyn_cast<Instruction>(BitPos)) - InsertPt = BitPosI->getInsertionPointAfterDef(); + InsertPt = &**BitPosI->getInsertionPointAfterDef(); else InsertPt = &*DT->getRoot()->getFirstNonPHIOrDbgOrAlloca(); if (!InsertPt) @@ -2493,7 +2493,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() { // Step 4: Rewrite the loop into a countable form, with canonical IV. // The new canonical induction variable. - Builder.SetInsertPoint(&LoopHeaderBB->front()); + Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->begin()); auto *IV = Builder.CreatePHI(Ty, 2, CurLoop->getName() + ".iv"); // The induction itself. @@ -2817,11 +2817,11 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() { // Step 3: Rewrite the loop into a countable form, with canonical IV. // The new canonical induction variable. - Builder.SetInsertPoint(&LoopHeaderBB->front()); + Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->begin()); auto *CIV = Builder.CreatePHI(Ty, 2, CurLoop->getName() + ".iv"); // The induction itself. 
- Builder.SetInsertPoint(LoopHeaderBB->getFirstNonPHI()); + Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->getFirstNonPHIIt()); auto *CIVNext = Builder.CreateAdd(CIV, ConstantInt::get(Ty, 1), CIV->getName() + ".next", /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2); diff --git a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index c9798a80978d..cfe069d00bce 100644 --- a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -29,8 +29,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" @@ -172,46 +170,6 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, return Changed; } -namespace { - -class LoopInstSimplifyLegacyPass : public LoopPass { -public: - static char ID; // Pass ID, replacement for typeid - - LoopInstSimplifyLegacyPass() : LoopPass(ID) { - initializeLoopInstSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - AssumptionCache &AC = - getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - *L->getHeader()->getParent()); - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( - *L->getHeader()->getParent()); - MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MemorySSAUpdater MSSAU(MSSA); - - return simplifyLoopInst(*L, DT, LI, AC, TLI, &MSSAU); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - 
AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.setPreservesCFG(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - getLoopAnalysisUsage(AU); - } -}; - -} // end anonymous namespace - PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { @@ -231,18 +189,3 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -char LoopInstSimplifyLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(LoopInstSimplifyLegacyPass, "loop-instsimplify", - "Simplify instructions in loops", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(LoopInstSimplifyLegacyPass, "loop-instsimplify", - "Simplify instructions in loops", false, false) - -Pass *llvm::createLoopInstSimplifyPass() { - return new LoopInstSimplifyLegacyPass(); -} diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 91286ebcea33..277f530ee25f 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1374,7 +1374,7 @@ bool LoopInterchangeTransform::transform() { for (Instruction &I : make_early_inc_range(make_range(InnerLoopPreHeader->begin(), std::prev(InnerLoopPreHeader->end())))) - I.moveBefore(OuterLoopHeader->getTerminator()); + I.moveBeforePreserving(OuterLoopHeader->getTerminator()); } Transformed |= adjustLoopLinks(); diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 179ccde8d035..5ec387300aac 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -195,7 +195,8 @@ public: Instruction *Source = 
Dep.getSource(LAI); Instruction *Destination = Dep.getDestination(LAI); - if (Dep.Type == MemoryDepChecker::Dependence::Unknown) { + if (Dep.Type == MemoryDepChecker::Dependence::Unknown || + Dep.Type == MemoryDepChecker::Dependence::IndirectUnsafe) { if (isa<LoadInst>(Source)) LoadsWithUnknownDepedence.insert(Source); if (isa<LoadInst>(Destination)) @@ -443,8 +444,8 @@ public: Cand.Load->getType(), InitialPtr, "load_initial", /* isVolatile */ false, Cand.Load->getAlign(), PH->getTerminator()); - PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", - &L->getHeader()->front()); + PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded"); + PHI->insertBefore(L->getHeader()->begin()); PHI->addIncoming(Initial, PH); Type *LoadType = Initial->getType(); diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index 2c8a3351281b..a4f2dbf9a582 100644 --- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -269,11 +269,12 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, PI.pushBeforeNonSkippedPassCallback([&LAR, &LI](StringRef PassID, Any IR) { if (isSpecialPass(PassID, {"PassManager"})) return; - assert(any_cast<const Loop *>(&IR) || any_cast<const LoopNest *>(&IR)); - const Loop **LPtr = any_cast<const Loop *>(&IR); + assert(llvm::any_cast<const Loop *>(&IR) || + llvm::any_cast<const LoopNest *>(&IR)); + const Loop **LPtr = llvm::any_cast<const Loop *>(&IR); const Loop *L = LPtr ? *LPtr : nullptr; if (!L) - L = &any_cast<const LoopNest *>(IR)->getOutermostLoop(); + L = &llvm::any_cast<const LoopNest *>(IR)->getOutermostLoop(); assert(L && "Loop should be valid for printing"); // Verify the loop structure and LCSSA form before visiting the loop. 
@@ -312,7 +313,8 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, if (LAR.MSSA && !PassPA.getChecker<MemorySSAAnalysis>().preserved()) report_fatal_error("Loop pass manager using MemorySSA contains a pass " - "that does not preserve MemorySSA"); + "that does not preserve MemorySSA", + /*gen_crash_diag*/ false); #ifndef NDEBUG // LoopAnalysisResults should always be valid. diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index 12852ae5c460..027dbb9c0f71 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -282,7 +282,7 @@ class LoopPredication { Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops); /// Same as above, *except* that this uses the SCEV definition of invariant /// which is that an expression *can be made* invariant via SCEVExpander. - /// Thus, this version is only suitable for finding an insert point to be be + /// Thus, this version is only suitable for finding an insert point to be /// passed to SCEVExpander! 
Instruction *findInsertPt(const SCEVExpander &Expander, Instruction *User, ArrayRef<const SCEV *> Ops); @@ -307,8 +307,9 @@ class LoopPredication { widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, Instruction *Guard); - unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition, - SCEVExpander &Expander, Instruction *Guard); + void widenChecks(SmallVectorImpl<Value *> &Checks, + SmallVectorImpl<Value *> &WidenedChecks, + SCEVExpander &Expander, Instruction *Guard); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander); // If the loop always exits through another block in the loop, we should not @@ -326,49 +327,8 @@ public: bool runOnLoop(Loop *L); }; -class LoopPredicationLegacyPass : public LoopPass { -public: - static char ID; - LoopPredicationLegacyPass() : LoopPass(ID) { - initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<BranchProbabilityInfoWrapperPass>(); - getLoopAnalysisUsage(AU); - AU.addPreserved<MemorySSAWrapperPass>(); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - std::unique_ptr<MemorySSAUpdater> MSSAU; - if (MSSAWP) - MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); - auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - LoopPredication LP(AA, DT, SE, LI, MSSAU ? 
MSSAU.get() : nullptr); - return LP.runOnLoop(L); - } -}; - -char LoopPredicationLegacyPass::ID = 0; } // end namespace -INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", - "Loop predication", false, false) -INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", - "Loop predication", false, false) - -Pass *llvm::createLoopPredicationPass() { - return new LoopPredicationLegacyPass(); -} - PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { @@ -754,58 +714,15 @@ LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, } } -unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks, - Value *Condition, - SCEVExpander &Expander, - Instruction *Guard) { - unsigned NumWidened = 0; - // The guard condition is expected to be in form of: - // cond1 && cond2 && cond3 ... - // Iterate over subconditions looking for icmp conditions which can be - // widened across loop iterations. Widening these conditions remember the - // resulting list of subconditions in Checks vector. 
- SmallVector<Value *, 4> Worklist(1, Condition); - SmallPtrSet<Value *, 4> Visited; - Visited.insert(Condition); - Value *WideableCond = nullptr; - do { - Value *Condition = Worklist.pop_back_val(); - Value *LHS, *RHS; - using namespace llvm::PatternMatch; - if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) { - if (Visited.insert(LHS).second) - Worklist.push_back(LHS); - if (Visited.insert(RHS).second) - Worklist.push_back(RHS); - continue; - } - - if (match(Condition, - m_Intrinsic<Intrinsic::experimental_widenable_condition>())) { - // Pick any, we don't care which - WideableCond = Condition; - continue; - } - - if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { - if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, - Guard)) { - Checks.push_back(*NewRangeCheck); - NumWidened++; - continue; +void LoopPredication::widenChecks(SmallVectorImpl<Value *> &Checks, + SmallVectorImpl<Value *> &WidenedChecks, + SCEVExpander &Expander, Instruction *Guard) { + for (auto &Check : Checks) + if (ICmpInst *ICI = dyn_cast<ICmpInst>(Check)) + if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) { + WidenedChecks.push_back(Check); + Check = *NewRangeCheck; } - } - - // Save the condition as is if we can't widen it - Checks.push_back(Condition); - } while (!Worklist.empty()); - // At the moment, our matching logic for wideable conditions implicitly - // assumes we preserve the form: (br (and Cond, WC())). FIXME - // Note that if there were multiple calls to wideable condition in the - // traversal, we only need to keep one, and which one is arbitrary. 
- if (WideableCond) - Checks.push_back(WideableCond); - return NumWidened; } bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, @@ -815,12 +732,13 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, TotalConsidered++; SmallVector<Value *, 4> Checks; - unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander, - Guard); - if (NumWidened == 0) + SmallVector<Value *> WidenedChecks; + parseWidenableGuard(Guard, Checks); + widenChecks(Checks, WidenedChecks, Expander, Guard); + if (WidenedChecks.empty()) return false; - TotalWidened += NumWidened; + TotalWidened += WidenedChecks.size(); // Emit the new guard condition IRBuilder<> Builder(findInsertPt(Guard, Checks)); @@ -833,7 +751,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, } RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); - LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); + LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n"); return true; } @@ -843,20 +761,19 @@ bool LoopPredication::widenWidenableBranchGuardConditions( LLVM_DEBUG(dbgs() << "Processing guard:\n"); LLVM_DEBUG(BI->dump()); - Value *Cond, *WC; - BasicBlock *IfTrueBB, *IfFalseBB; - bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB); - assert(Parsed && "Must be able to parse widenable branch"); - (void)Parsed; - TotalConsidered++; SmallVector<Value *, 4> Checks; - unsigned NumWidened = collectChecks(Checks, BI->getCondition(), - Expander, BI); - if (NumWidened == 0) + SmallVector<Value *> WidenedChecks; + parseWidenableGuard(BI, Checks); + // At the moment, our matching logic for wideable conditions implicitly + // assumes we preserve the form: (br (and Cond, WC())). 
FIXME + auto WC = extractWidenableCondition(BI); + Checks.push_back(WC); + widenChecks(Checks, WidenedChecks, Expander, BI); + if (WidenedChecks.empty()) return false; - TotalWidened += NumWidened; + TotalWidened += WidenedChecks.size(); // Emit the new guard condition IRBuilder<> Builder(findInsertPt(BI, Checks)); @@ -864,17 +781,18 @@ bool LoopPredication::widenWidenableBranchGuardConditions( auto *OldCond = BI->getCondition(); BI->setCondition(AllChecks); if (InsertAssumesOfPredicatedGuardsConditions) { + BasicBlock *IfTrueBB = BI->getSuccessor(0); Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt()); // If this block has other predecessors, we might not be able to use Cond. // In this case, create a Phi where every other input is `true` and input // from guard block is Cond. - Value *AssumeCond = Cond; + Value *AssumeCond = Builder.CreateAnd(WidenedChecks); if (!IfTrueBB->getUniquePredecessor()) { auto *GuardBB = BI->getParent(); - auto *PN = Builder.CreatePHI(Cond->getType(), pred_size(IfTrueBB), + auto *PN = Builder.CreatePHI(AssumeCond->getType(), pred_size(IfTrueBB), "assume.cond"); for (auto *Pred : predecessors(IfTrueBB)) - PN->addIncoming(Pred == GuardBB ? Cond : Builder.getTrue(), Pred); + PN->addIncoming(Pred == GuardBB ? 
AssumeCond : Builder.getTrue(), Pred); AssumeCond = PN; } Builder.CreateAssumption(AssumeCond); @@ -883,7 +801,7 @@ bool LoopPredication::widenWidenableBranchGuardConditions( assert(isGuardAsWidenableBranch(BI) && "Stopped being a guard after transform?"); - LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); + LLVM_DEBUG(dbgs() << "Widened checks = " << WidenedChecks.size() << "\n"); return true; } @@ -1008,6 +926,9 @@ bool LoopPredication::isLoopProfitableToPredicate() { Numerator += Weight; Denominator += Weight; } + // If all weights are zero act as if there was no profile data + if (Denominator == 0) + return BranchProbability::getBranchProbability(1, NumSucc); return BranchProbability::getBranchProbability(Numerator, Denominator); } else { assert(LatchBlock != ExitingBlock && @@ -1070,13 +991,9 @@ static BranchInst *FindWidenableTerminatorAboveLoop(Loop *L, LoopInfo &LI) { } while (true); if (BasicBlock *Pred = BB->getSinglePredecessor()) { - auto *Term = Pred->getTerminator(); - - Value *Cond, *WC; - BasicBlock *IfTrueBB, *IfFalseBB; - if (parseWidenableBranch(Term, Cond, WC, IfTrueBB, IfFalseBB) && - IfTrueBB == BB) - return cast<BranchInst>(Term); + if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator())) + if (BI->getSuccessor(0) == BB && isWidenableBranch(BI)) + return BI; } return nullptr; } @@ -1164,13 +1081,13 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { if (!BI) continue; - Use *Cond, *WC; - BasicBlock *IfTrueBB, *IfFalseBB; - if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) && - L->contains(IfTrueBB)) { - WC->set(ConstantInt::getTrue(IfTrueBB->getContext())); - ChangedLoop = true; - } + if (auto WC = extractWidenableCondition(BI)) + if (L->contains(BI->getSuccessor(0))) { + assert(WC->hasOneUse() && "Not appropriate widenable branch!"); + WC->user_back()->replaceUsesOfWith( + WC, ConstantInt::getTrue(BI->getContext())); + ChangedLoop = true; + } } if (ChangedLoop) SE->forgetLoop(L); diff 
--git a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 8d59fdff9236..028a487ecdbc 100644 --- a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -20,13 +20,11 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" @@ -734,52 +732,3 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -namespace { -class LoopSimplifyCFGLegacyPass : public LoopPass { -public: - static char ID; // Pass ID, replacement for typeid - LoopSimplifyCFGLegacyPass() : LoopPass(ID) { - initializeLoopSimplifyCFGLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - std::optional<MemorySSAUpdater> MSSAU; - if (MSSAA) - MSSAU = MemorySSAUpdater(&MSSAA->getMSSA()); - if (MSSAA && VerifyMemorySSA) - MSSAU->getMemorySSA()->verifyMemorySSA(); - bool DeleteCurrentLoop = false; - bool Changed = simplifyLoopCFG(*L, DT, LI, SE, MSSAU ? 
&*MSSAU : nullptr, - DeleteCurrentLoop); - if (DeleteCurrentLoop) - LPM.markLoopAsDeleted(*L); - return Changed; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addPreserved<MemorySSAWrapperPass>(); - AU.addPreserved<DependenceAnalysisWrapperPass>(); - getLoopAnalysisUsage(AU); - } -}; -} // end namespace - -char LoopSimplifyCFGLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", - "Simplify loop CFG", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", - "Simplify loop CFG", false, false) - -Pass *llvm::createLoopSimplifyCFGPass() { - return new LoopSimplifyCFGLegacyPass(); -} diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp index 597c159682c5..6eedf95e7575 100644 --- a/llvm/lib/Transforms/Scalar/LoopSink.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -36,13 +36,11 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" @@ -79,7 +77,7 @@ static cl::opt<unsigned> MaxNumberOfUseBBsForSinking( /// AdjustedFreq(BBs) = 99 / SinkFrequencyPercentThreshold% static BlockFrequency adjustedSumFreq(SmallPtrSetImpl<BasicBlock *> &BBs, BlockFrequencyInfo &BFI) { - BlockFrequency T = 0; + BlockFrequency T(0); for (BasicBlock *B : BBs) T += BFI.getBlockFreq(B); if (BBs.size() > 1) @@ -222,9 +220,11 @@ static bool sinkInstruction( // order. No need to stable sort as the block numbers are a total ordering. 
SmallVector<BasicBlock *, 2> SortedBBsToSinkInto; llvm::append_range(SortedBBsToSinkInto, BBsToSinkInto); - llvm::sort(SortedBBsToSinkInto, [&](BasicBlock *A, BasicBlock *B) { - return LoopBlockNumber.find(A)->second < LoopBlockNumber.find(B)->second; - }); + if (SortedBBsToSinkInto.size() > 1) { + llvm::sort(SortedBBsToSinkInto, [&](BasicBlock *A, BasicBlock *B) { + return LoopBlockNumber.find(A)->second < LoopBlockNumber.find(B)->second; + }); + } BasicBlock *MoveBB = *SortedBBsToSinkInto.begin(); // FIXME: Optimize the efficiency for cloned value replacement. The current @@ -388,58 +388,3 @@ PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { return PA; } - -namespace { -struct LegacyLoopSinkPass : public LoopPass { - static char ID; - LegacyLoopSinkPass() : LoopPass(ID) { - initializeLegacyLoopSinkPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - - BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) - return false; - - // Enable LoopSink only when runtime profile is available. - // With static profile, the sinking decision may be sub-optimal. - if (!Preheader->getParent()->hasProfileData()) - return false; - - AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); - auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); - bool Changed = sinkLoopInvariantInstructions( - *L, AA, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), - getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(), - MSSA, SE ? 
&SE->getSE() : nullptr); - - if (VerifyMemorySSA) - MSSA.verifyMemorySSA(); - - return Changed; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - getLoopAnalysisUsage(AU); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } -}; -} - -char LegacyLoopSinkPass::ID = 0; -INITIALIZE_PASS_BEGIN(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false, - false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false, false) - -Pass *llvm::createLoopSinkPass() { return new LegacyLoopSinkPass(); } diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index a4369b83e732..39607464dd00 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -67,6 +67,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/IVUsers.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" @@ -188,8 +189,8 @@ static cl::opt<unsigned> SetupCostDepthLimit( "lsr-setupcost-depth-limit", cl::Hidden, cl::init(7), cl::desc("The limit on recursion depth for LSRs setup cost")); -static cl::opt<bool> AllowTerminatingConditionFoldingAfterLSR( - "lsr-term-fold", cl::Hidden, cl::init(false), +static cl::opt<cl::boolOrDefault> AllowTerminatingConditionFoldingAfterLSR( + "lsr-term-fold", cl::Hidden, cl::desc("Attempt to replace primary IV with other IV.")); static cl::opt<bool> AllowDropSolutionIfLessProfitable( @@ -943,12 +944,6 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI, } } - // All pointers have the same requirements, so 
canonicalize them to an - // arbitrary pointer type to minimize variation. - if (PointerType *PTy = dyn_cast<PointerType>(AccessTy.MemTy)) - AccessTy.MemTy = PointerType::get(IntegerType::get(PTy->getContext(), 1), - PTy->getAddressSpace()); - return AccessTy; } @@ -2794,18 +2789,6 @@ static Value *getWideOperand(Value *Oper) { return Oper; } -/// Return true if we allow an IV chain to include both types. -static bool isCompatibleIVType(Value *LVal, Value *RVal) { - Type *LType = LVal->getType(); - Type *RType = RVal->getType(); - return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy() && - // Different address spaces means (possibly) - // different types of the pointer implementation, - // e.g. i16 vs i32 so disallow that. - (LType->getPointerAddressSpace() == - RType->getPointerAddressSpace())); -} - /// Return an approximation of this SCEV expression's "base", or NULL for any /// constant. Returning the expression itself is conservative. Returning a /// deeper subexpression is more precise and valid as long as it isn't less @@ -2985,7 +2968,7 @@ void LSRInstance::ChainInstruction(Instruction *UserInst, Instruction *IVOper, continue; Value *PrevIV = getWideOperand(Chain.Incs.back().IVOperand); - if (!isCompatibleIVType(PrevIV, NextIV)) + if (PrevIV->getType() != NextIV->getType()) continue; // A phi node terminates a chain. @@ -3279,7 +3262,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, // do this if we also found a wide value for the head of the chain. 
if (isa<PHINode>(Chain.tailUserInst())) { for (PHINode &Phi : L->getHeader()->phis()) { - if (!isCompatibleIVType(&Phi, IVSrc)) + if (Phi.getType() != IVSrc->getType()) continue; Instruction *PostIncV = dyn_cast<Instruction>( Phi.getIncomingValueForBlock(L->getLoopLatch())); @@ -3488,6 +3471,11 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { SmallVector<const SCEV *, 8> Worklist(RegUses.begin(), RegUses.end()); SmallPtrSet<const SCEV *, 32> Visited; + // Don't collect outside uses if we are favoring postinc - the instructions in + // the loop are more important than the ones outside of it. + if (AMK == TTI::AMK_PostIndexed) + return; + while (!Worklist.empty()) { const SCEV *S = Worklist.pop_back_val(); @@ -5559,10 +5547,12 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, "a scale at the same time!"); Constant *C = ConstantInt::getSigned(SE.getEffectiveSCEVType(OpTy), -(uint64_t)Offset); - if (C->getType() != OpTy) - C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, - OpTy, false), - C, OpTy); + if (C->getType() != OpTy) { + C = ConstantFoldCastOperand( + CastInst::getCastOpcode(C, false, OpTy, false), C, OpTy, + CI->getModule()->getDataLayout()); + assert(C && "Cast of ConstantInt should have folded"); + } CI->setOperand(1, C); } @@ -5610,7 +5600,8 @@ void LSRInstance::RewriteForPHI( .setKeepOneInputPHIs()); } else { SmallVector<BasicBlock*, 2> NewBBs; - SplitLandingPadPredecessors(Parent, BB, "", "", NewBBs, &DT, &LI); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + SplitLandingPadPredecessors(Parent, BB, "", "", NewBBs, &DTU, &LI); NewBB = NewBBs[0]; } // If NewBB==NULL, then SplitCriticalEdge refused to split because all @@ -6949,7 +6940,19 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, } } - if (AllowTerminatingConditionFoldingAfterLSR) { + const bool EnableFormTerm = [&] { + switch (AllowTerminatingConditionFoldingAfterLSR) { + case cl::BOU_TRUE: + return true; + case 
cl::BOU_FALSE: + return false; + case cl::BOU_UNSET: + return TTI.shouldFoldTerminatingConditionAfterLSR(); + } + llvm_unreachable("Unhandled cl::boolOrDefault enum"); + }(); + + if (EnableFormTerm) { if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) { auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 9c6e4ebf62a9..7b4c54370e48 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -111,7 +111,7 @@ static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) { if (!S) continue; - if (S->getString().startswith(Prefix)) + if (S->getString().starts_with(Prefix)) return true; } } @@ -153,9 +153,11 @@ static bool computeUnrollAndJamCount( LoopInfo *LI, AssumptionCache *AC, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned OuterTripCount, - unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount, - unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP, + unsigned OuterTripMultiple, const UnrollCostEstimator &OuterUCE, + unsigned InnerTripCount, unsigned InnerLoopSize, + TargetTransformInfo::UnrollingPreferences &UP, TargetTransformInfo::PeelingPreferences &PP) { + unsigned OuterLoopSize = OuterUCE.getRolledLoopSize(); // First up use computeUnrollCount from the loop unroller to get a count // for unrolling the outer loop, plus any loops requiring explicit // unrolling we leave to the unroller. 
This uses UP.Threshold / @@ -165,7 +167,7 @@ static bool computeUnrollAndJamCount( bool UseUpperBound = false; bool ExplicitUnroll = computeUnrollCount( L, TTI, DT, LI, AC, SE, EphValues, ORE, OuterTripCount, MaxTripCount, - /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, PP, + /*MaxOrZero*/ false, OuterTripMultiple, OuterUCE, UP, PP, UseUpperBound); if (ExplicitUnroll || UseUpperBound) { // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it @@ -318,39 +320,28 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, } // Approximate the loop size and collect useful info - unsigned NumInlineCandidates; - bool NotDuplicatable; - bool Convergent; SmallPtrSet<const Value *, 32> EphValues; CodeMetrics::collectEphemeralValues(L, &AC, EphValues); Loop *SubLoop = L->getSubLoops()[0]; - InstructionCost InnerLoopSizeIC = - ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable, - Convergent, TTI, EphValues, UP.BEInsns); - InstructionCost OuterLoopSizeIC = - ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent, - TTI, EphValues, UP.BEInsns); - LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterLoopSizeIC << "\n"); - LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSizeIC << "\n"); + UnrollCostEstimator InnerUCE(SubLoop, TTI, EphValues, UP.BEInsns); + UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns); - if (!InnerLoopSizeIC.isValid() || !OuterLoopSizeIC.isValid()) { + if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) { LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions" - << " with invalid cost.\n"); + << " which cannot be duplicated or have invalid cost.\n"); return LoopUnrollResult::Unmodified; } - unsigned InnerLoopSize = *InnerLoopSizeIC.getValue(); - unsigned OuterLoopSize = *OuterLoopSizeIC.getValue(); - if (NotDuplicatable) { - LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable " - "instructions.\n"); - return LoopUnrollResult::Unmodified; - } 
- if (NumInlineCandidates != 0) { + unsigned InnerLoopSize = InnerUCE.getRolledLoopSize(); + LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterUCE.getRolledLoopSize() + << "\n"); + LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSize << "\n"); + + if (InnerUCE.NumInlineCandidates != 0 || OuterUCE.NumInlineCandidates != 0) { LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); return LoopUnrollResult::Unmodified; } - if (Convergent) { + if (InnerUCE.Convergent || OuterUCE.Convergent) { LLVM_DEBUG( dbgs() << " Not unrolling loop with convergent instructions.\n"); return LoopUnrollResult::Unmodified; @@ -379,7 +370,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // Decide if, and by how much, to unroll bool IsCountSetExplicitly = computeUnrollAndJamCount( L, SubLoop, TTI, DT, LI, &AC, SE, EphValues, &ORE, OuterTripCount, - OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP, PP); + OuterTripMultiple, OuterUCE, InnerTripCount, InnerLoopSize, UP, PP); if (UP.Count <= 1) return LoopUnrollResult::Unmodified; // Unroll factor (Count) must be less or equal to TripCount. diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 335b489d3cb2..f14541a1a037 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -662,19 +662,16 @@ static std::optional<EstimatedUnrollCost> analyzeLoopUnrollCost( unsigned(*RolledDynamicCost.getValue())}}; } -/// ApproximateLoopSize - Approximate the size of the loop. 
-InstructionCost llvm::ApproximateLoopSize( - const Loop *L, unsigned &NumCalls, bool &NotDuplicatable, bool &Convergent, - const TargetTransformInfo &TTI, +UnrollCostEstimator::UnrollCostEstimator( + const Loop *L, const TargetTransformInfo &TTI, const SmallPtrSetImpl<const Value *> &EphValues, unsigned BEInsns) { CodeMetrics Metrics; for (BasicBlock *BB : L->blocks()) Metrics.analyzeBasicBlock(BB, TTI, EphValues); - NumCalls = Metrics.NumInlineCandidates; + NumInlineCandidates = Metrics.NumInlineCandidates; NotDuplicatable = Metrics.notDuplicatable; Convergent = Metrics.convergent; - - InstructionCost LoopSize = Metrics.NumInsts; + LoopSize = Metrics.NumInsts; // Don't allow an estimate of size zero. This would allows unrolling of loops // with huge iteration counts, which is a compile time problem even if it's @@ -685,8 +682,17 @@ InstructionCost llvm::ApproximateLoopSize( if (LoopSize.isValid() && LoopSize < BEInsns + 1) // This is an open coded max() on InstructionCost LoopSize = BEInsns + 1; +} - return LoopSize; +uint64_t UnrollCostEstimator::getUnrolledLoopSize( + const TargetTransformInfo::UnrollingPreferences &UP, + unsigned CountOverwrite) const { + unsigned LS = *LoopSize.getValue(); + assert(LS >= UP.BEInsns && "LoopSize should not be less than BEInsns!"); + if (CountOverwrite) + return static_cast<uint64_t>(LS - UP.BEInsns) * CountOverwrite + UP.BEInsns; + else + return static_cast<uint64_t>(LS - UP.BEInsns) * UP.Count + UP.BEInsns; } // Returns the loop hint metadata node with the given name (for example, @@ -746,36 +752,10 @@ static unsigned getFullUnrollBoostingFactor(const EstimatedUnrollCost &Cost, return MaxPercentThresholdBoost; } -// Produce an estimate of the unrolled cost of the specified loop. This -// is used to a) produce a cost estimate for partial unrolling and b) to -// cheaply estimate cost for full unrolling when we don't want to symbolically -// evaluate all iterations. 
-class UnrollCostEstimator { - const unsigned LoopSize; - -public: - UnrollCostEstimator(Loop &L, unsigned LoopSize) : LoopSize(LoopSize) {} - - // Returns loop size estimation for unrolled loop, given the unrolling - // configuration specified by UP. - uint64_t - getUnrolledLoopSize(const TargetTransformInfo::UnrollingPreferences &UP, - const unsigned CountOverwrite = 0) const { - assert(LoopSize >= UP.BEInsns && - "LoopSize should not be less than BEInsns!"); - if (CountOverwrite) - return static_cast<uint64_t>(LoopSize - UP.BEInsns) * CountOverwrite + - UP.BEInsns; - else - return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count + - UP.BEInsns; - } -}; - static std::optional<unsigned> shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo, const unsigned TripMultiple, const unsigned TripCount, - const UnrollCostEstimator UCE, + unsigned MaxTripCount, const UnrollCostEstimator UCE, const TargetTransformInfo::UnrollingPreferences &UP) { // Using unroll pragma @@ -796,6 +776,10 @@ shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo, if (PInfo.PragmaFullUnroll && TripCount != 0) return TripCount; + if (PInfo.PragmaEnableUnroll && !TripCount && MaxTripCount && + MaxTripCount <= UnrollMaxUpperBound) + return MaxTripCount; + // if didn't return until here, should continue to other priorties return std::nullopt; } @@ -888,14 +872,14 @@ shouldPartialUnroll(const unsigned LoopSize, const unsigned TripCount, // refactored into it own function. 
bool llvm::computeUnrollCount( Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI, - AssumptionCache *AC, - ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, + AssumptionCache *AC, ScalarEvolution &SE, + const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned TripCount, unsigned MaxTripCount, - bool MaxOrZero, unsigned TripMultiple, unsigned LoopSize, + bool MaxOrZero, unsigned TripMultiple, const UnrollCostEstimator &UCE, TargetTransformInfo::UnrollingPreferences &UP, TargetTransformInfo::PeelingPreferences &PP, bool &UseUpperBound) { - UnrollCostEstimator UCE(*L, LoopSize); + unsigned LoopSize = UCE.getRolledLoopSize(); const bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0; const bool PragmaFullUnroll = hasUnrollFullPragma(L); @@ -922,7 +906,7 @@ bool llvm::computeUnrollCount( // 1st priority is unroll count set by "unroll-count" option. // 2nd priority is unroll count set by pragma. if (auto UnrollFactor = shouldPragmaUnroll(L, PInfo, TripMultiple, TripCount, - UCE, UP)) { + MaxTripCount, UCE, UP)) { UP.Count = *UnrollFactor; if (UserUnrollCount || (PragmaCount > 0)) { @@ -1177,9 +1161,6 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, return LoopUnrollResult::Unmodified; bool OptForSize = L->getHeader()->getParent()->hasOptSize(); - unsigned NumInlineCandidates; - bool NotDuplicatable; - bool Convergent; TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( L, SE, TTI, BFI, PSI, ORE, OptLevel, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, @@ -1196,30 +1177,22 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, SmallPtrSet<const Value *, 32> EphValues; CodeMetrics::collectEphemeralValues(L, &AC, EphValues); - InstructionCost LoopSizeIC = - ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent, - TTI, EphValues, UP.BEInsns); - 
LLVM_DEBUG(dbgs() << " Loop Size = " << LoopSizeIC << "\n"); - - if (!LoopSizeIC.isValid()) { + UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns); + if (!UCE.canUnroll()) { LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions" - << " with invalid cost.\n"); + << " which cannot be duplicated or have invalid cost.\n"); return LoopUnrollResult::Unmodified; } - unsigned LoopSize = *LoopSizeIC.getValue(); - if (NotDuplicatable) { - LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" - << " instructions.\n"); - return LoopUnrollResult::Unmodified; - } + unsigned LoopSize = UCE.getRolledLoopSize(); + LLVM_DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); // When optimizing for size, use LoopSize + 1 as threshold (we use < Threshold // later), to (fully) unroll loops, if it does not increase code size. if (OptForSize) UP.Threshold = std::max(UP.Threshold, LoopSize + 1); - if (NumInlineCandidates != 0) { + if (UCE.NumInlineCandidates != 0) { LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); return LoopUnrollResult::Unmodified; } @@ -1261,7 +1234,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, // Assuming n is the same on all threads, any kind of unrolling is // safe. But currently llvm's notion of convergence isn't powerful // enough to express this. - if (Convergent) + if (UCE.Convergent) UP.AllowRemainder = false; // Try to find the trip count upper bound if we cannot find the exact trip @@ -1277,8 +1250,8 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, // fully unroll the loop. 
bool UseUpperBound = false; bool IsCountSetExplicitly = computeUnrollCount( - L, TTI, DT, LI, &AC, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero, - TripMultiple, LoopSize, UP, PP, UseUpperBound); + L, TTI, DT, LI, &AC, SE, EphValues, &ORE, TripCount, MaxTripCount, + MaxOrZero, TripMultiple, UCE, UP, PP, UseUpperBound); if (!UP.Count) return LoopUnrollResult::Unmodified; diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 454aa56be531..6f87e4d91d2c 100644 --- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -13,7 +13,6 @@ #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -21,10 +20,8 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/MisExpect.h" #include <cmath> @@ -105,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) { misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true); SI.setCondition(ArgValue); - - SI.setMetadata(LLVMContext::MD_prof, - MDBuilder(CI->getContext()).createBranchWeights(Weights)); - + setBranchWeights(SI, Weights); return true; } @@ -416,29 +410,3 @@ PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F, return PreservedAnalyses::all(); } - -namespace { -/// Legacy pass for lowering expect intrinsics out of the IR. -/// -/// When this pass is run over a function it uses expect intrinsics which feed -/// branches and switches to provide branch weight metadata for those -/// terminators. 
It then removes the expect intrinsics from the IR so the rest -/// of the optimizer can ignore them. -class LowerExpectIntrinsic : public FunctionPass { -public: - static char ID; - LowerExpectIntrinsic() : FunctionPass(ID) { - initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); } -}; -} // namespace - -char LowerExpectIntrinsic::ID = 0; -INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect", - "Lower 'expect' Intrinsics", false, false) - -FunctionPass *llvm::createLowerExpectIntrinsicPass() { - return new LowerExpectIntrinsic(); -} diff --git a/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp index 8dc037b10cc8..a59ecdda1746 100644 --- a/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp +++ b/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -20,25 +20,10 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/GuardUtils.h" using namespace llvm; -namespace { -struct LowerGuardIntrinsicLegacyPass : public FunctionPass { - static char ID; - LowerGuardIntrinsicLegacyPass() : FunctionPass(ID) { - initializeLowerGuardIntrinsicLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; -}; -} - static bool lowerGuardIntrinsic(Function &F) { // Check if we can cheaply rule out the possibility of not having any work to // do. 
@@ -71,19 +56,6 @@ static bool lowerGuardIntrinsic(Function &F) { return true; } -bool LowerGuardIntrinsicLegacyPass::runOnFunction(Function &F) { - return lowerGuardIntrinsic(F); -} - -char LowerGuardIntrinsicLegacyPass::ID = 0; -INITIALIZE_PASS(LowerGuardIntrinsicLegacyPass, "lower-guard-intrinsic", - "Lower the guard intrinsic to normal control flow", false, - false) - -Pass *llvm::createLowerGuardIntrinsicPass() { - return new LowerGuardIntrinsicLegacyPass(); -} - PreservedAnalyses LowerGuardIntrinsicPass::run(Function &F, FunctionAnalysisManager &AM) { if (lowerGuardIntrinsic(F)) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index f46ea6a20afa..72b9db1e73d7 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -19,6 +19,7 @@ #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" @@ -36,12 +37,9 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MatrixBuilder.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/MatrixUtils.h" @@ -180,7 +178,6 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, assert((!isa<ConstantInt>(Stride) || cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && "Stride must be >= the number of elements in the result vector."); - unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); // Compute the start of the vector with index VecIdx as VecIdx * 
Stride. Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); @@ -192,11 +189,7 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, else VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); - // Cast elementwise vector start pointer to a pointer to a vector - // (EltType x NumElements)*. - auto *VecType = FixedVectorType::get(EltType, NumElements); - Type *VecPtrType = PointerType::get(VecType, AS); - return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast"); + return VecStart; } /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. @@ -1063,13 +1056,6 @@ public: return Changed; } - /// Turns \p BasePtr into an elementwise pointer to \p EltType. - Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { - unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); - Type *EltPtrType = PointerType::get(EltType, AS); - return Builder.CreatePointerCast(BasePtr, EltPtrType); - } - /// Replace intrinsic calls bool VisitCallInst(CallInst *Inst) { if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic()) @@ -1121,7 +1107,7 @@ public: auto *VType = cast<VectorType>(Ty); Type *EltTy = VType->getElementType(); Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); - Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); + Value *EltPtr = Ptr; MatrixTy Result; for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { Value *GEP = computeVectorAddr( @@ -1147,17 +1133,11 @@ public: Value *Offset = Builder.CreateAdd( Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); - unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); - Value *EltPtr = - Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); - Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows 
* ResultShape.NumColumns); - Type *TilePtrTy = PointerType::get(TileTy, AS); - Value *TilePtr = - Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); - return loadMatrix(TileTy, TilePtr, Align, + return loadMatrix(TileTy, TileStart, Align, Builder.getInt64(MatrixShape.getStride()), IsVolatile, ResultShape, Builder); } @@ -1193,17 +1173,11 @@ public: Value *Offset = Builder.CreateAdd( Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); - unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); - Value *EltPtr = - Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); - Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); + Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * StoreVal.getNumColumns()); - Type *TilePtrTy = PointerType::get(TileTy, AS); - Value *TilePtr = - Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); - storeMatrix(TileTy, StoreVal, TilePtr, MAlign, + storeMatrix(TileTy, StoreVal, TileStart, MAlign, Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); } @@ -1213,7 +1187,7 @@ public: MaybeAlign MAlign, Value *Stride, bool IsVolatile, IRBuilder<> &Builder) { auto VType = cast<VectorType>(Ty); - Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); + Value *EltPtr = Ptr; for (auto Vec : enumerate(StoreVal.vectors())) { Value *GEP = computeVectorAddr( EltPtr, @@ -2180,7 +2154,7 @@ public: /// Returns true if \p V is a matrix value in the given subprogram. bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); } - /// If \p V is a matrix value, print its shape as as NumRows x NumColumns to + /// If \p V is a matrix value, print its shape as NumRows x NumColumns to /// \p SS. 
void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) { auto M = Inst2Matrix.find(V); @@ -2201,7 +2175,7 @@ public: write("<no called fn>"); else { StringRef Name = CI->getCalledFunction()->getName(); - if (!Name.startswith("llvm.matrix")) { + if (!Name.starts_with("llvm.matrix")) { write(Name); return; } diff --git a/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp b/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp index e2de322933bc..3c977b816a05 100644 --- a/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp +++ b/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp @@ -19,24 +19,10 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" using namespace llvm; -namespace { -struct LowerWidenableConditionLegacyPass : public FunctionPass { - static char ID; - LowerWidenableConditionLegacyPass() : FunctionPass(ID) { - initializeLowerWidenableConditionLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; -}; -} - static bool lowerWidenableCondition(Function &F) { // Check if we can cheaply rule out the possibility of not having any work to // do. 
@@ -65,19 +51,6 @@ static bool lowerWidenableCondition(Function &F) { return true; } -bool LowerWidenableConditionLegacyPass::runOnFunction(Function &F) { - return lowerWidenableCondition(F); -} - -char LowerWidenableConditionLegacyPass::ID = 0; -INITIALIZE_PASS(LowerWidenableConditionLegacyPass, "lower-widenable-condition", - "Lower the widenable condition to default true value", false, - false) - -Pass *llvm::createLowerWidenableConditionPass() { - return new LowerWidenableConditionLegacyPass(); -} - PreservedAnalyses LowerWidenableConditionPass::run(Function &F, FunctionAnalysisManager &AM) { if (lowerWidenableCondition(F)) diff --git a/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp b/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp index a3f09a5a33c3..78e474f925b5 100644 --- a/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp +++ b/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp @@ -42,17 +42,6 @@ using namespace llvm; -namespace { -struct MakeGuardsExplicitLegacyPass : public FunctionPass { - static char ID; - MakeGuardsExplicitLegacyPass() : FunctionPass(ID) { - initializeMakeGuardsExplicitLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; -}; -} - static void turnToExplicitForm(CallInst *Guard, Function *DeoptIntrinsic) { // Replace the guard with an explicit branch (just like in GuardWidening). 
BasicBlock *OriginalBB = Guard->getParent(); @@ -89,15 +78,6 @@ static bool explicifyGuards(Function &F) { return true; } -bool MakeGuardsExplicitLegacyPass::runOnFunction(Function &F) { - return explicifyGuards(F); -} - -char MakeGuardsExplicitLegacyPass::ID = 0; -INITIALIZE_PASS(MakeGuardsExplicitLegacyPass, "make-guards-explicit", - "Lower the guard intrinsic to explicit control flow form", - false, false) - PreservedAnalyses MakeGuardsExplicitPass::run(Function &F, FunctionAnalysisManager &) { if (explicifyGuards(F)) diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 68642a01b37c..0e55249d63a8 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -19,12 +19,15 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" @@ -66,9 +69,9 @@ static cl::opt<bool> EnableMemCpyOptWithoutLibcalls( STATISTIC(NumMemCpyInstr, "Number of memcpy instructions deleted"); STATISTIC(NumMemSetInfer, "Number of memsets inferred"); -STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy"); -STATISTIC(NumCpyToSet, "Number of memcpys converted to memset"); -STATISTIC(NumCallSlot, "Number of call slot optimizations performed"); +STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy"); +STATISTIC(NumCpyToSet, "Number of memcpys converted to memset"); +STATISTIC(NumCallSlot, "Number of call slot optimizations 
performed"); STATISTIC(NumStackMove, "Number of stack-move optimizations performed"); namespace { @@ -333,6 +336,17 @@ static bool writtenBetween(MemorySSA *MSSA, BatchAAResults &AA, return !MSSA->dominates(Clobber, Start); } +// Update AA metadata +static void combineAAMetadata(Instruction *ReplInst, Instruction *I) { + // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be + // handled here, but combineMetadata doesn't support them yet + unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, + LLVMContext::MD_invariant_group, + LLVMContext::MD_access_group}; + combineMetadata(ReplInst, I, KnownIDs, true); +} + /// When scanning forward over instructions, we look for some other patterns to /// fold away. In particular, this looks for stores to neighboring locations of /// memory. If it sees enough consecutive ones, it attempts to merge them @@ -357,21 +371,13 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, // Keeps track of the last memory use or def before the insertion point for // the new memset. The new MemoryDef for the inserted memsets will be inserted - // after MemInsertPoint. It points to either LastMemDef or to the last user - // before the insertion point of the memset, if there are any such users. + // after MemInsertPoint. MemoryUseOrDef *MemInsertPoint = nullptr; - // Keeps track of the last MemoryDef between StartInst and the insertion point - // for the new memset. This will become the defining access of the inserted - // memsets. - MemoryDef *LastMemDef = nullptr; for (++BI; !BI->isTerminator(); ++BI) { auto *CurrentAcc = cast_or_null<MemoryUseOrDef>( MSSAU->getMemorySSA()->getMemoryAccess(&*BI)); - if (CurrentAcc) { + if (CurrentAcc) MemInsertPoint = CurrentAcc; - if (auto *CurrentDef = dyn_cast<MemoryDef>(CurrentAcc)) - LastMemDef = CurrentDef; - } // Calls that only access inaccessible memory do not block merging // accessible stores. 
@@ -475,16 +481,13 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, if (!Range.TheStores.empty()) AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc()); - assert(LastMemDef && MemInsertPoint && - "Both LastMemDef and MemInsertPoint need to be set"); auto *NewDef = cast<MemoryDef>(MemInsertPoint->getMemoryInst() == &*BI ? MSSAU->createMemoryAccessBefore( - AMemSet, LastMemDef, MemInsertPoint) + AMemSet, nullptr, MemInsertPoint) : MSSAU->createMemoryAccessAfter( - AMemSet, LastMemDef, MemInsertPoint)); + AMemSet, nullptr, MemInsertPoint)); MSSAU->insertDef(NewDef, /*RenameUses=*/true); - LastMemDef = NewDef; MemInsertPoint = NewDef; // Zap all the stores. @@ -693,7 +696,7 @@ bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI, auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); - auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); + auto *NewAccess = MSSAU->createMemoryAccessAfter(M, nullptr, LastDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); eraseInstruction(SI); @@ -814,7 +817,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // store, so we do not need to rename uses. auto *StoreDef = cast<MemoryDef>(MSSA->getMemoryAccess(SI)); auto *NewAccess = MSSAU->createMemoryAccessBefore( - M, StoreDef->getDefiningAccess(), StoreDef); + M, nullptr, StoreDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/false); eraseInstruction(SI); @@ -922,10 +925,12 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, return false; } - // Check that accessing the first srcSize bytes of dest will not cause a - // trap. Otherwise the transform is invalid since it might cause a trap - // to occur earlier than it otherwise would. - if (!isDereferenceableAndAlignedPointer(cpyDest, Align(1), APInt(64, cpySize), + // Check that storing to the first srcSize bytes of dest will not cause a + // trap or data race. 
+ bool ExplicitlyDereferenceableOnly; + if (!isWritableObject(getUnderlyingObject(cpyDest), + ExplicitlyDereferenceableOnly) || + !isDereferenceableAndAlignedPointer(cpyDest, Align(1), APInt(64, cpySize), DL, C, AC, DT)) { LLVM_DEBUG(dbgs() << "Call Slot: Dest pointer not dereferenceable\n"); return false; @@ -1040,12 +1045,13 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, // Since we're changing the parameter to the callsite, we need to make sure // that what would be the new parameter dominates the callsite. + bool NeedMoveGEP = false; if (!DT->dominates(cpyDest, C)) { // Support moving a constant index GEP before the call. auto *GEP = dyn_cast<GetElementPtrInst>(cpyDest); if (GEP && GEP->hasAllConstantIndices() && DT->dominates(GEP->getPointerOperand(), C)) - GEP->moveBefore(C); + NeedMoveGEP = true; else return false; } @@ -1064,29 +1070,19 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, // We can't create address space casts here because we don't know if they're // safe for the target. - if (cpySrc->getType()->getPointerAddressSpace() != - cpyDest->getType()->getPointerAddressSpace()) + if (cpySrc->getType() != cpyDest->getType()) return false; for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI) if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc && - cpySrc->getType()->getPointerAddressSpace() != - C->getArgOperand(ArgI)->getType()->getPointerAddressSpace()) + cpySrc->getType() != C->getArgOperand(ArgI)->getType()) return false; // All the checks have passed, so do the transformation. bool changedArgument = false; for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI) if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc) { - Value *Dest = cpySrc->getType() == cpyDest->getType() ? 
cpyDest - : CastInst::CreatePointerCast(cpyDest, cpySrc->getType(), - cpyDest->getName(), C); changedArgument = true; - if (C->getArgOperand(ArgI)->getType() == Dest->getType()) - C->setArgOperand(ArgI, Dest); - else - C->setArgOperand(ArgI, CastInst::CreatePointerCast( - Dest, C->getArgOperand(ArgI)->getType(), - Dest->getName(), C)); + C->setArgOperand(ArgI, cpyDest); } if (!changedArgument) @@ -1098,22 +1094,20 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, cast<AllocaInst>(cpyDest)->setAlignment(srcAlign); } + if (NeedMoveGEP) { + auto *GEP = dyn_cast<GetElementPtrInst>(cpyDest); + GEP->moveBefore(C); + } + if (SkippedLifetimeStart) { SkippedLifetimeStart->moveBefore(C); MSSAU->moveBefore(MSSA->getMemoryAccess(SkippedLifetimeStart), MSSA->getMemoryAccess(C)); } - // Update AA metadata - // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be - // handled here, but combineMetadata doesn't support them yet - unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, - LLVMContext::MD_invariant_group, - LLVMContext::MD_access_group}; - combineMetadata(C, cpyLoad, KnownIDs, true); + combineAAMetadata(C, cpyLoad); if (cpyLoad != cpyStore) - combineMetadata(C, cpyStore, KnownIDs, true); + combineAAMetadata(C, cpyStore); ++NumCallSlot; return true; @@ -1203,7 +1197,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M))); auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); - auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, nullptr, LastDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); // Remove the instruction we're replacing. 
@@ -1300,12 +1294,8 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize); Value *MemsetLen = Builder.CreateSelect( Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); - unsigned DestAS = Dest->getType()->getPointerAddressSpace(); Instruction *NewMemSet = Builder.CreateMemSet( - Builder.CreateGEP( - Builder.getInt8Ty(), - Builder.CreatePointerCast(Dest, Builder.getInt8PtrTy(DestAS)), - SrcSize), + Builder.CreateGEP(Builder.getInt8Ty(), Dest, SrcSize), MemSet->getOperand(1), MemsetLen, Alignment); assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) && @@ -1315,7 +1305,7 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); auto *NewAccess = MSSAU->createMemoryAccessBefore( - NewMemSet, LastDef->getDefiningAccess(), LastDef); + NewMemSet, nullptr, LastDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); eraseInstruction(MemSet); @@ -1420,7 +1410,7 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, CopySize, MemCpy->getDestAlign()); auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); - auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, nullptr, LastDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); return true; @@ -1440,7 +1430,7 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, // allocas that aren't captured. 
bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, AllocaInst *DestAlloca, - AllocaInst *SrcAlloca, uint64_t Size, + AllocaInst *SrcAlloca, TypeSize Size, BatchAAResults &BAA) { LLVM_DEBUG(dbgs() << "Stack Move: Attempting to optimize:\n" << *Store << "\n"); @@ -1451,35 +1441,30 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, return false; } - // 1. Check that copy is full. Calculate the static size of the allocas to be - // merged, bail out if we can't. + // Check that copy is full with static size. const DataLayout &DL = DestAlloca->getModule()->getDataLayout(); std::optional<TypeSize> SrcSize = SrcAlloca->getAllocationSize(DL); - if (!SrcSize || SrcSize->isScalable() || Size != SrcSize->getFixedValue()) { + if (!SrcSize || Size != *SrcSize) { LLVM_DEBUG(dbgs() << "Stack Move: Source alloca size mismatch\n"); return false; } std::optional<TypeSize> DestSize = DestAlloca->getAllocationSize(DL); - if (!DestSize || DestSize->isScalable() || - Size != DestSize->getFixedValue()) { + if (!DestSize || Size != *DestSize) { LLVM_DEBUG(dbgs() << "Stack Move: Destination alloca size mismatch\n"); return false; } - // 2-1. Check that src and dest are static allocas, which are not affected by - // stacksave/stackrestore. - if (!SrcAlloca->isStaticAlloca() || !DestAlloca->isStaticAlloca() || - SrcAlloca->getParent() != Load->getParent() || - SrcAlloca->getParent() != Store->getParent()) + if (!SrcAlloca->isStaticAlloca() || !DestAlloca->isStaticAlloca()) return false; - // 2-2. Check that src and dest are never captured, unescaped allocas. Also - // collect lifetime markers first/last users in order to shrink wrap the - // lifetimes, and instructions with noalias metadata to remove them. + // Check that src and dest are never captured, unescaped allocas. 
Also + // find the nearest common dominator and postdominator for all users in + // order to shrink wrap the lifetimes, and instructions with noalias metadata + // to remove them. SmallVector<Instruction *, 4> LifetimeMarkers; - Instruction *FirstUser = nullptr, *LastUser = nullptr; SmallSet<Instruction *, 4> NoAliasInstrs; + bool SrcNotDom = false; // Recursively track the user and check whether modified alias exist. auto IsDereferenceableOrNull = [](Value *V, const DataLayout &DL) -> bool { @@ -1499,6 +1484,12 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, Instruction *I = Worklist.back(); Worklist.pop_back(); for (const Use &U : I->uses()) { + auto *UI = cast<Instruction>(U.getUser()); + // If any use that isn't dominated by SrcAlloca exists, we move src + // alloca to the entry before the transformation. + if (!DT->dominates(SrcAlloca, UI)) + SrcNotDom = true; + if (Visited.size() >= MaxUsesToExplore) { LLVM_DEBUG( dbgs() @@ -1512,22 +1503,15 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, return false; case UseCaptureKind::PASSTHROUGH: // Instructions cannot have non-instruction users. - Worklist.push_back(cast<Instruction>(U.getUser())); + Worklist.push_back(UI); continue; case UseCaptureKind::NO_CAPTURE: { - auto *UI = cast<Instruction>(U.getUser()); - if (DestAlloca->getParent() != UI->getParent()) - return false; - if (!FirstUser || UI->comesBefore(FirstUser)) - FirstUser = UI; - if (!LastUser || LastUser->comesBefore(UI)) - LastUser = UI; if (UI->isLifetimeStartOrEnd()) { // We note the locations of these intrinsic calls so that we can // delete them later if the optimization succeeds, this is safe // since both llvm.lifetime.start and llvm.lifetime.end intrinsics - // conceptually fill all the bytes of the alloca with an undefined - // value. + // practically fill all the bytes of the alloca with an undefined + // value, although conceptually marked as alive/dead. 
int64_t Size = cast<ConstantInt>(UI->getOperand(0))->getSExtValue(); if (Size < 0 || Size == DestSize) { LifetimeMarkers.push_back(UI); @@ -1545,37 +1529,64 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, return true; }; - // 3. Check that dest has no Mod/Ref, except full size lifetime intrinsics, - // from the alloca to the Store. + // Check that dest has no Mod/Ref, from the alloca to the Store, except full + // size lifetime intrinsics. And collect modref inst for the reachability + // check. ModRefInfo DestModRef = ModRefInfo::NoModRef; MemoryLocation DestLoc(DestAlloca, LocationSize::precise(Size)); + SmallVector<BasicBlock *, 8> ReachabilityWorklist; auto DestModRefCallback = [&](Instruction *UI) -> bool { // We don't care about the store itself. if (UI == Store) return true; ModRefInfo Res = BAA.getModRefInfo(UI, DestLoc); - // FIXME: For multi-BB cases, we need to see reachability from it to - // store. - // Bailout if Dest may have any ModRef before Store. - if (UI->comesBefore(Store) && isModOrRefSet(Res)) - return false; - DestModRef |= BAA.getModRefInfo(UI, DestLoc); + DestModRef |= Res; + if (isModOrRefSet(Res)) { + // Instructions reachability checks. + // FIXME: adding the Instruction version isPotentiallyReachableFromMany on + // lib/Analysis/CFG.cpp (currently only for BasicBlocks) might be helpful. + if (UI->getParent() == Store->getParent()) { + // The same block case is special because it's the only time we're + // looking within a single block to see which instruction comes first. + // Once we start looking at multiple blocks, the first instruction of + // the block is reachable, so we only need to determine reachability + // between whole blocks. + BasicBlock *BB = UI->getParent(); + // If A comes before B, then B is definitively reachable from A. + if (UI->comesBefore(Store)) + return false; + + // If the user's parent block is entry, no predecessor exists. 
+ if (BB->isEntryBlock()) + return true; + + // Otherwise, continue doing the normal per-BB CFG walk. + ReachabilityWorklist.append(succ_begin(BB), succ_end(BB)); + } else { + ReachabilityWorklist.push_back(UI->getParent()); + } + } return true; }; if (!CaptureTrackingWithModRef(DestAlloca, DestModRefCallback)) return false; + // Bailout if Dest may have any ModRef before Store. + if (!ReachabilityWorklist.empty() && + isPotentiallyReachableFromMany(ReachabilityWorklist, Store->getParent(), + nullptr, DT, nullptr)) + return false; - // 3. Check that, from after the Load to the end of the BB, - // 3-1. if the dest has any Mod, src has no Ref, and - // 3-2. if the dest has any Ref, src has no Mod except full-sized lifetimes. + // Check that, from after the Load to the end of the BB, + // - if the dest has any Mod, src has no Ref, and + // - if the dest has any Ref, src has no Mod except full-sized lifetimes. MemoryLocation SrcLoc(SrcAlloca, LocationSize::precise(Size)); auto SrcModRefCallback = [&](Instruction *UI) -> bool { - // Any ModRef before Load doesn't matter, also Load and Store can be - // ignored. - if (UI->comesBefore(Load) || UI == Load || UI == Store) + // Any ModRef post-dominated by Load doesn't matter, also Load and Store + // themselves can be ignored. + if (PDT->dominates(Load, UI) || UI == Load || UI == Store) return true; ModRefInfo Res = BAA.getModRefInfo(UI, SrcLoc); if ((isModSet(DestModRef) && isRefSet(Res)) || @@ -1588,7 +1599,12 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, if (!CaptureTrackingWithModRef(SrcAlloca, SrcModRefCallback)) return false; - // We can do the transformation. First, align the allocas appropriately. + // We can do the transformation. First, move the SrcAlloca to the start of the + // BB. + if (SrcNotDom) + SrcAlloca->moveBefore(*SrcAlloca->getParent(), + SrcAlloca->getParent()->getFirstInsertionPt()); + // Align the allocas appropriately. 
SrcAlloca->setAlignment( std::max(SrcAlloca->getAlign(), DestAlloca->getAlign())); @@ -1599,28 +1615,10 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, // Drop metadata on the source alloca. SrcAlloca->dropUnknownNonDebugMetadata(); - // Do "shrink wrap" the lifetimes, if the original lifetime intrinsics exists. + // TODO: Reconstruct merged lifetime markers. + // Remove all other lifetime markers. if the original lifetime intrinsics + // exists. if (!LifetimeMarkers.empty()) { - LLVMContext &C = SrcAlloca->getContext(); - IRBuilder<> Builder(C); - - ConstantInt *AllocaSize = ConstantInt::get(Type::getInt64Ty(C), Size); - // Create a new lifetime start marker before the first user of src or alloca - // users. - Builder.SetInsertPoint(FirstUser->getParent(), FirstUser->getIterator()); - Builder.CreateLifetimeStart(SrcAlloca, AllocaSize); - - // Create a new lifetime end marker after the last user of src or alloca - // users. - // FIXME: If the last user is the terminator for the bb, we can insert - // lifetime.end marker to the immidiate post-dominator, but currently do - // nothing. - if (!LastUser->isTerminator()) { - Builder.SetInsertPoint(LastUser->getParent(), ++LastUser->getIterator()); - Builder.CreateLifetimeEnd(SrcAlloca, AllocaSize); - } - - // Remove all other lifetime markers. for (Instruction *I : LifetimeMarkers) eraseInstruction(I); } @@ -1637,6 +1635,16 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, return true; } +static bool isZeroSize(Value *Size) { + if (auto *I = dyn_cast<Instruction>(Size)) + if (auto *Res = simplifyInstruction(I, I->getModule()->getDataLayout())) + Size = Res; + // Treat undef/poison size like zero. + if (auto *C = dyn_cast<Constant>(Size)) + return isa<UndefValue>(C) || C->isNullValue(); + return false; +} + /// Perform simplification of memcpy's. 
If we have memcpy A /// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite /// B to be a memcpy from X to Z (or potentially a memmove, depending on @@ -1653,6 +1661,19 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { return true; } + // If the size is zero, remove the memcpy. This also prevents infinite loops + // in processMemSetMemCpyDependence, which is a no-op for zero-length memcpys. + if (isZeroSize(M->getLength())) { + ++BBI; + eraseInstruction(M); + return true; + } + + MemoryUseOrDef *MA = MSSA->getMemoryAccess(M); + if (!MA) + // Degenerate case: memcpy marked as not accessing memory. + return false; + // If copying from a constant, try to turn the memcpy into a memset. if (auto *GV = dyn_cast<GlobalVariable>(M->getSource())) if (GV->isConstant() && GV->hasDefinitiveInitializer()) @@ -1661,10 +1682,9 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { IRBuilder<> Builder(M); Instruction *NewM = Builder.CreateMemSet( M->getRawDest(), ByteVal, M->getLength(), M->getDestAlign(), false); - auto *LastDef = - cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); + auto *LastDef = cast<MemoryDef>(MA); auto *NewAccess = - MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->createMemoryAccessAfter(NewM, nullptr, LastDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); eraseInstruction(M); @@ -1673,7 +1693,6 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { } BatchAAResults BAA(*AA); - MemoryUseOrDef *MA = MSSA->getMemoryAccess(M); // FIXME: Not using getClobberingMemoryAccess() here due to PR54682. 
MemoryAccess *AnyClobber = MA->getDefiningAccess(); MemoryLocation DestLoc = MemoryLocation::getForDest(M); @@ -1751,8 +1770,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { ConstantInt *Len = dyn_cast<ConstantInt>(M->getLength()); if (Len == nullptr) return false; - if (performStackMoveOptzn(M, M, DestAlloca, SrcAlloca, Len->getZExtValue(), - BAA)) { + if (performStackMoveOptzn(M, M, DestAlloca, SrcAlloca, + TypeSize::getFixed(Len->getZExtValue()), BAA)) { // Avoid invalidating the iterator. BBI = M->getNextNonDebugInstruction()->getIterator(); eraseInstruction(M); @@ -1831,9 +1850,8 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { DT) < *ByValAlign) return false; - // The address space of the memcpy source must match the byval argument - if (MDep->getSource()->getType()->getPointerAddressSpace() != - ByValArg->getType()->getPointerAddressSpace()) + // The type of the memcpy source must match the byval argument + if (MDep->getSource()->getType() != ByValArg->getType()) return false; // Verify that the copied-from memory doesn't change in between the memcpy and @@ -1851,6 +1869,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { << " " << CB << "\n"); // Otherwise we're good! Update the byval argument. + combineAAMetadata(&CB, MDep); CB.setArgOperand(ArgNo, MDep->getSource()); ++NumMemCpyInstr; return true; @@ -1907,9 +1926,8 @@ bool MemCpyOptPass::processImmutArgument(CallBase &CB, unsigned ArgNo) { if (!MDep || MDep->isVolatile() || AI != MDep->getDest()) return false; - // The address space of the memcpy source must match the immut argument - if (MDep->getSource()->getType()->getPointerAddressSpace() != - ImmutArg->getType()->getPointerAddressSpace()) + // The type of the memcpy source must match the immut argument + if (MDep->getSource()->getType() != ImmutArg->getType()) return false; // 2-1. The length of the memcpy must be equal to the size of the alloca. 
@@ -1946,6 +1964,7 @@ bool MemCpyOptPass::processImmutArgument(CallBase &CB, unsigned ArgNo) { << " " << CB << "\n"); // Otherwise we're good! Update the immut argument. + combineAAMetadata(&CB, MDep); CB.setArgOperand(ArgNo, MDep->getSource()); ++NumMemCpyInstr; return true; @@ -2004,9 +2023,10 @@ PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { auto *AA = &AM.getResult<AAManager>(F); auto *AC = &AM.getResult<AssumptionAnalysis>(F); auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); + auto *PDT = &AM.getResult<PostDominatorTreeAnalysis>(F); auto *MSSA = &AM.getResult<MemorySSAAnalysis>(F); - bool MadeChange = runImpl(F, &TLI, AA, AC, DT, &MSSA->getMSSA()); + bool MadeChange = runImpl(F, &TLI, AA, AC, DT, PDT, &MSSA->getMSSA()); if (!MadeChange) return PreservedAnalyses::all(); @@ -2018,12 +2038,14 @@ PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { bool MemCpyOptPass::runImpl(Function &F, TargetLibraryInfo *TLI_, AliasAnalysis *AA_, AssumptionCache *AC_, - DominatorTree *DT_, MemorySSA *MSSA_) { + DominatorTree *DT_, PostDominatorTree *PDT_, + MemorySSA *MSSA_) { bool MadeChange = false; TLI = TLI_; AA = AA_; AC = AC_; DT = DT_; + PDT = PDT_; MSSA = MSSA_; MemorySSAUpdater MSSAU_(MSSA_); MSSAU = &MSSAU_; diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp index 311a6435ba7c..1e0906717549 100644 --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -275,7 +275,7 @@ void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const { // Do the actual spliting. 
for (Instruction *Inst : reverse(OtherInsts)) - Inst->moveBefore(*NewParent, NewParent->begin()); + Inst->moveBeforePreserving(*NewParent, NewParent->begin()); } bool BCECmpBlock::canSplit(AliasAnalysis &AA) const { diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 6c5453831ade..d65054a6ff9d 100644 --- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -80,7 +80,6 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" @@ -217,8 +216,8 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, if (Opd1 == Opd2) return nullptr; - auto *NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink", - &BB->front()); + auto *NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink"); + NewPN->insertBefore(BB->begin()); NewPN->applyMergedLocation(S0->getDebugLoc(), S1->getDebugLoc()); NewPN->addIncoming(Opd1, S0->getParent()); NewPN->addIncoming(Opd2, S1->getParent()); @@ -269,7 +268,7 @@ void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0, // Create the new store to be inserted at the join point. StoreInst *SNew = cast<StoreInst>(S0->clone()); - SNew->insertBefore(&*InsertPt); + SNew->insertBefore(InsertPt); // New PHI operand? Use it. 
if (PHINode *NewPN = getPHIOperand(BB, S0, S1)) SNew->setOperand(0, NewPN); @@ -378,52 +377,6 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) { return Changed; } -namespace { -class MergedLoadStoreMotionLegacyPass : public FunctionPass { - const bool SplitFooterBB; -public: - static char ID; // Pass identification, replacement for typeid - MergedLoadStoreMotionLegacyPass(bool SplitFooterBB = false) - : FunctionPass(ID), SplitFooterBB(SplitFooterBB) { - initializeMergedLoadStoreMotionLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - /// - /// Run the transformation for each function - /// - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - MergedLoadStoreMotion Impl(SplitFooterBB); - return Impl.run(F, getAnalysis<AAResultsWrapperPass>().getAAResults()); - } - -private: - void getAnalysisUsage(AnalysisUsage &AU) const override { - if (!SplitFooterBB) - AU.setPreservesCFG(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } -}; - -char MergedLoadStoreMotionLegacyPass::ID = 0; -} // anonymous namespace - -/// -/// createMergedLoadStoreMotionPass - The public interface to this file. 
-/// -FunctionPass *llvm::createMergedLoadStoreMotionPass(bool SplitFooterBB) { - return new MergedLoadStoreMotionLegacyPass(SplitFooterBB); -} - -INITIALIZE_PASS_BEGIN(MergedLoadStoreMotionLegacyPass, "mldst-motion", - "MergedLoadStoreMotion", false, false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion", - "MergedLoadStoreMotion", false, false) - PreservedAnalyses MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) { MergedLoadStoreMotion Impl(Options.SplitFooterBB); diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index 9c3e9a2fd018..7fe1a222021e 100644 --- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -359,12 +359,13 @@ bool NaryReassociatePass::requiresSignExtension(Value *Index, GetElementPtrInst * NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I, Type *IndexedType) { + SimplifyQuery SQ(*DL, DT, AC, GEP); Value *IndexToSplit = GEP->getOperand(I + 1); if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) { IndexToSplit = SExt->getOperand(0); } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) { // zext can be treated as sext if the source is non-negative. - if (isKnownNonNegative(ZExt->getOperand(0), *DL, 0, AC, GEP, DT)) + if (isKnownNonNegative(ZExt->getOperand(0), SQ)) IndexToSplit = ZExt->getOperand(0); } @@ -373,8 +374,7 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, // nsw, we cannot split the add because // sext(LHS + RHS) != sext(LHS) + sext(RHS). 
if (requiresSignExtension(IndexToSplit, GEP) && - computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) != - OverflowResult::NeverOverflows) + computeOverflowForSignedAdd(AO, SQ) != OverflowResult::NeverOverflows) return nullptr; Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1); @@ -402,7 +402,7 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, IndexExprs.push_back(SE->getSCEV(Index)); // Replace the I-th index with LHS. IndexExprs[I] = SE->getSCEV(LHS); - if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) && + if (isKnownNonNegative(LHS, SimplifyQuery(*DL, DT, AC, GEP)) && DL->getTypeSizeInBits(LHS->getType()).getFixedValue() < DL->getTypeSizeInBits(GEP->getOperand(I)->getType()) .getFixedValue()) { diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index 1af40e2c4e62..19ac9526b5f8 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -774,7 +774,7 @@ private: // Symbolic evaluation. ExprResult checkExprResults(Expression *, Instruction *, Value *) const; - ExprResult performSymbolicEvaluation(Value *, + ExprResult performSymbolicEvaluation(Instruction *, SmallPtrSetImpl<Value *> &) const; const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *, Instruction *, @@ -1904,7 +1904,7 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { LastPredInfo = PI; // In phi of ops cases, we may have predicate info that we are evaluating // in a different context. - if (!DT->dominates(PBranch->To, getBlockForValue(I))) + if (!DT->dominates(PBranch->To, I->getParent())) continue; // TODO: Along the false edge, we may know more things too, like // icmp of @@ -1961,95 +1961,88 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { return createExpression(I); } -// Substitute and symbolize the value before value numbering. +// Substitute and symbolize the instruction before value numbering. 
NewGVN::ExprResult -NewGVN::performSymbolicEvaluation(Value *V, +NewGVN::performSymbolicEvaluation(Instruction *I, SmallPtrSetImpl<Value *> &Visited) const { const Expression *E = nullptr; - if (auto *C = dyn_cast<Constant>(V)) - E = createConstantExpression(C); - else if (isa<Argument>(V) || isa<GlobalVariable>(V)) { - E = createVariableExpression(V); - } else { - // TODO: memory intrinsics. - // TODO: Some day, we should do the forward propagation and reassociation - // parts of the algorithm. - auto *I = cast<Instruction>(V); - switch (I->getOpcode()) { - case Instruction::ExtractValue: - case Instruction::InsertValue: - E = performSymbolicAggrValueEvaluation(I); - break; - case Instruction::PHI: { - SmallVector<ValPair, 3> Ops; - auto *PN = cast<PHINode>(I); - for (unsigned i = 0; i < PN->getNumOperands(); ++i) - Ops.push_back({PN->getIncomingValue(i), PN->getIncomingBlock(i)}); - // Sort to ensure the invariant createPHIExpression requires is met. - sortPHIOps(Ops); - E = performSymbolicPHIEvaluation(Ops, I, getBlockForValue(I)); - } break; - case Instruction::Call: - return performSymbolicCallEvaluation(I); - break; - case Instruction::Store: - E = performSymbolicStoreEvaluation(I); - break; - case Instruction::Load: - E = performSymbolicLoadEvaluation(I); - break; - case Instruction::BitCast: - case Instruction::AddrSpaceCast: - case Instruction::Freeze: - return createExpression(I); - break; - case Instruction::ICmp: - case Instruction::FCmp: - return performSymbolicCmpEvaluation(I); - break; - case Instruction::FNeg: - case Instruction::Add: - case Instruction::FAdd: - case Instruction::Sub: - case Instruction::FSub: - case Instruction::Mul: - case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: - case Instruction::FRem: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - case Instruction::And: - case Instruction::Or: - case 
Instruction::Xor: - case Instruction::Trunc: - case Instruction::ZExt: - case Instruction::SExt: - case Instruction::FPToUI: - case Instruction::FPToSI: - case Instruction::UIToFP: - case Instruction::SIToFP: - case Instruction::FPTrunc: - case Instruction::FPExt: - case Instruction::PtrToInt: - case Instruction::IntToPtr: - case Instruction::Select: - case Instruction::ExtractElement: - case Instruction::InsertElement: - case Instruction::GetElementPtr: - return createExpression(I); - break; - case Instruction::ShuffleVector: - // FIXME: Add support for shufflevector to createExpression. - return ExprResult::none(); - default: - return ExprResult::none(); - } + // TODO: memory intrinsics. + // TODO: Some day, we should do the forward propagation and reassociation + // parts of the algorithm. + switch (I->getOpcode()) { + case Instruction::ExtractValue: + case Instruction::InsertValue: + E = performSymbolicAggrValueEvaluation(I); + break; + case Instruction::PHI: { + SmallVector<ValPair, 3> Ops; + auto *PN = cast<PHINode>(I); + for (unsigned i = 0; i < PN->getNumOperands(); ++i) + Ops.push_back({PN->getIncomingValue(i), PN->getIncomingBlock(i)}); + // Sort to ensure the invariant createPHIExpression requires is met. 
+ sortPHIOps(Ops); + E = performSymbolicPHIEvaluation(Ops, I, getBlockForValue(I)); + } break; + case Instruction::Call: + return performSymbolicCallEvaluation(I); + break; + case Instruction::Store: + E = performSymbolicStoreEvaluation(I); + break; + case Instruction::Load: + E = performSymbolicLoadEvaluation(I); + break; + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::Freeze: + return createExpression(I); + break; + case Instruction::ICmp: + case Instruction::FCmp: + return performSymbolicCmpEvaluation(I); + break; + case Instruction::FNeg: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::Select: + case Instruction::ExtractElement: + case Instruction::InsertElement: + case Instruction::GetElementPtr: + return createExpression(I); + break; + case Instruction::ShuffleVector: + // FIXME: Add support for shufflevector to createExpression. + return ExprResult::none(); + default: + return ExprResult::none(); } return ExprResult::some(E); } @@ -2772,6 +2765,9 @@ NewGVN::makePossiblePHIOfOps(Instruction *I, // Clone the instruction, create an expression from it that is // translated back into the predecessor, and see if we have a leader. 
Instruction *ValueOp = I->clone(); + // Emit the temporal instruction in the predecessor basic block where the + // corresponding value is defined. + ValueOp->insertBefore(PredBB->getTerminator()); if (MemAccess) TempToMemory.insert({ValueOp, MemAccess}); bool SafeForPHIOfOps = true; @@ -2801,7 +2797,7 @@ NewGVN::makePossiblePHIOfOps(Instruction *I, FoundVal = !SafeForPHIOfOps ? nullptr : findLeaderForInst(ValueOp, Visited, MemAccess, I, PredBB); - ValueOp->deleteValue(); + ValueOp->eraseFromParent(); if (!FoundVal) { // We failed to find a leader for the current ValueOp, but this might // change in case of the translated operands change. @@ -3542,7 +3538,7 @@ struct NewGVN::ValueDFS { // the second. We only want it to be less than if the DFS orders are equal. // // Each LLVM instruction only produces one value, and thus the lowest-level - // differentiator that really matters for the stack (and what we use as as a + // differentiator that really matters for the stack (and what we use as a // replacement) is the local dfs number. // Everything else in the structure is instruction level, and only affects // the order in which we will replace operands of a given instruction. @@ -4034,9 +4030,18 @@ bool NewGVN::eliminateInstructions(Function &F) { // because stores are put in terms of the stored value, we skip // stored values here. If the stored value is really dead, it will // still be marked for deletion when we process it in its own class. - if (!EliminationStack.empty() && Def != EliminationStack.back() && - isa<Instruction>(Def) && !FromStore) - markInstructionForDeletion(cast<Instruction>(Def)); + auto *DefI = dyn_cast<Instruction>(Def); + if (!EliminationStack.empty() && DefI && !FromStore) { + Value *DominatingLeader = EliminationStack.back(); + if (DominatingLeader != Def) { + // Even if the instruction is removed, we still need to update + // flags/metadata due to downstreams users of the leader. 
+ if (!match(DefI, m_Intrinsic<Intrinsic::ssa_copy>())) + patchReplacementInstruction(DefI, DominatingLeader); + + markInstructionForDeletion(DefI); + } + } continue; } // At this point, we know it is a Use we are trying to possibly @@ -4095,9 +4100,12 @@ bool NewGVN::eliminateInstructions(Function &F) { // For copy instructions, we use their operand as a leader, // which means we remove a user of the copy and it may become dead. if (isSSACopy) { - unsigned &IIUseCount = UseCounts[II]; - if (--IIUseCount == 0) - ProbablyDead.insert(II); + auto It = UseCounts.find(II); + if (It != UseCounts.end()) { + unsigned &IIUseCount = It->second; + if (--IIUseCount == 0) + ProbablyDead.insert(II); + } } ++LeaderUseCount; AnythingReplaced = true; diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 40c84e249523..818c7b40d489 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -466,7 +466,8 @@ using RepeatedValue = std::pair<Value*, APInt>; /// type and thus make the expression bigger. static bool LinearizeExprTree(Instruction *I, SmallVectorImpl<RepeatedValue> &Ops, - ReassociatePass::OrderedSet &ToRedo) { + ReassociatePass::OrderedSet &ToRedo, + bool &HasNUW) { assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) && "Expected a UnaryOperator or BinaryOperator!"); LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); @@ -515,6 +516,9 @@ static bool LinearizeExprTree(Instruction *I, std::pair<Instruction*, APInt> P = Worklist.pop_back_val(); I = P.first; // We examine the operands of this binary operator. + if (isa<OverflowingBinaryOperator>(I)) + HasNUW &= I->hasNoUnsignedWrap(); + for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); APInt Weight = P.second; // Number of paths to this operand. 
@@ -657,7 +661,8 @@ static bool LinearizeExprTree(Instruction *I, /// Now that the operands for this expression tree are /// linearized and optimized, emit them in-order. void ReassociatePass::RewriteExprTree(BinaryOperator *I, - SmallVectorImpl<ValueEntry> &Ops) { + SmallVectorImpl<ValueEntry> &Ops, + bool HasNUW) { assert(Ops.size() > 1 && "Single values should be used directly!"); // Since our optimizations should never increase the number of operations, the @@ -814,14 +819,20 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, if (ExpressionChangedStart) { bool ClearFlags = true; do { - // Preserve FastMathFlags. + // Preserve flags. if (ClearFlags) { if (isa<FPMathOperator>(I)) { FastMathFlags Flags = I->getFastMathFlags(); ExpressionChangedStart->clearSubclassOptionalData(); ExpressionChangedStart->setFastMathFlags(Flags); - } else + } else { ExpressionChangedStart->clearSubclassOptionalData(); + // Note that it doesn't hold for mul if one of the operands is zero. + // TODO: We can preserve NUW flag if we prove that all mul operands + // are non-zero. 
+ if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add) + ExpressionChangedStart->setHasNoUnsignedWrap(); + } } if (ExpressionChangedStart == ExpressionChangedEnd) @@ -921,16 +932,20 @@ static Value *NegateValue(Value *V, Instruction *BI, TheNeg->getParent()->getParent() != BI->getParent()->getParent()) continue; - Instruction *InsertPt; + BasicBlock::iterator InsertPt; if (Instruction *InstInput = dyn_cast<Instruction>(V)) { - InsertPt = InstInput->getInsertionPointAfterDef(); - if (!InsertPt) + auto InsertPtOpt = InstInput->getInsertionPointAfterDef(); + if (!InsertPtOpt) continue; + InsertPt = *InsertPtOpt; } else { - InsertPt = &*TheNeg->getFunction()->getEntryBlock().begin(); + InsertPt = TheNeg->getFunction() + ->getEntryBlock() + .getFirstNonPHIOrDbg() + ->getIterator(); } - TheNeg->moveBefore(InsertPt); + TheNeg->moveBefore(*InsertPt->getParent(), InsertPt); if (TheNeg->getOpcode() == Instruction::Sub) { TheNeg->setHasNoUnsignedWrap(false); TheNeg->setHasNoSignedWrap(false); @@ -1171,7 +1186,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { return nullptr; SmallVector<RepeatedValue, 8> Tree; - MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts); + bool HasNUW = true; + MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW); SmallVector<ValueEntry, 8> Factors; Factors.reserve(Tree.size()); for (unsigned i = 0, e = Tree.size(); i != e; ++i) { @@ -1213,7 +1229,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { if (!FoundFactor) { // Make sure to restore the operands to the expression tree. 
- RewriteExprTree(BO, Factors); + RewriteExprTree(BO, Factors, HasNUW); return nullptr; } @@ -1225,7 +1241,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { RedoInsts.insert(BO); V = Factors[0].Op; } else { - RewriteExprTree(BO, Factors); + RewriteExprTree(BO, Factors, HasNUW); V = BO; } @@ -2252,9 +2268,10 @@ void ReassociatePass::OptimizeInst(Instruction *I) { // with no common bits set, convert it to X+Y. if (I->getOpcode() == Instruction::Or && shouldConvertOrWithNoCommonBitsToAdd(I) && !isLoadCombineCandidate(I) && - haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), - I->getModule()->getDataLayout(), /*AC=*/nullptr, I, - /*DT=*/nullptr)) { + (cast<PossiblyDisjointInst>(I)->isDisjoint() || + haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), + SimplifyQuery(I->getModule()->getDataLayout(), + /*DT=*/nullptr, /*AC=*/nullptr, I)))) { Instruction *NI = convertOrWithNoCommonBitsToAdd(I); RedoInsts.insert(I); MadeChange = true; @@ -2349,7 +2366,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { // First, walk the expression tree, linearizing the tree, collecting the // operand information. SmallVector<RepeatedValue, 8> Tree; - MadeChange |= LinearizeExprTree(I, Tree, RedoInsts); + bool HasNUW = true; + MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW); SmallVector<ValueEntry, 8> Ops; Ops.reserve(Tree.size()); for (const RepeatedValue &E : Tree) @@ -2542,7 +2560,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { dbgs() << '\n'); // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. - RewriteExprTree(I, Ops); + RewriteExprTree(I, Ops, HasNUW); } void @@ -2550,7 +2568,7 @@ ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) { // Make a "pairmap" of how often each operand pair occurs. 
for (BasicBlock *BI : RPOT) { for (Instruction &I : *BI) { - if (!I.isAssociative()) + if (!I.isAssociative() || !I.isBinaryOp()) continue; // Ignore nodes that aren't at the root of trees. diff --git a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp index db7a1f24660c..6c2b3e9bd4a7 100644 --- a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp @@ -25,8 +25,6 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -107,36 +105,3 @@ PreservedAnalyses RegToMemPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<LoopAnalysis>(); return PA; } - -namespace { -struct RegToMemLegacy : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - RegToMemLegacy() : FunctionPass(ID) { - initializeRegToMemLegacyPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequiredID(BreakCriticalEdgesID); - AU.addPreservedID(BreakCriticalEdgesID); - } - - bool runOnFunction(Function &F) override { - if (F.isDeclaration() || skipFunction(F)) - return false; - return runPass(F); - } -}; -} // namespace - -char RegToMemLegacy::ID = 0; -INITIALIZE_PASS_BEGIN(RegToMemLegacy, "reg2mem", - "Demote all values to stack slots", false, false) -INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges) -INITIALIZE_PASS_END(RegToMemLegacy, "reg2mem", - "Demote all values to stack slots", false, false) - -// createDemoteRegisterToMemory - Provide an entry point to create this pass. 
-char &llvm::DemoteRegisterToMemoryID = RegToMemLegacy::ID; -FunctionPass *llvm::createDemoteRegisterToMemoryPass() { - return new RegToMemLegacy(); -} diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 908bda5709a0..40b4ea92e1ff 100644 --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" @@ -54,15 +55,12 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" @@ -995,7 +993,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, NewState.meet(OpState); }); - BDVState OldState = States[BDV]; + BDVState OldState = Pair.second; if (OldState != NewState) { Progress = true; States[BDV] = NewState; @@ -1014,8 +1012,44 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, } #endif - // Handle all instructions that have a vector BDV, but the instruction itself - // is of scalar type. + // Even though we have identified a concrete base (or a conflict) for all live + // pointers at this point, there are cases where the base is of an + // incompatible type compared to the original instruction. 
We conservatively + // mark those as conflicts to ensure that corresponding BDVs will be generated + // in the next steps. + + // this is a rather explicit check for all cases where we should mark the + // state as a conflict to force the latter stages of the algorithm to emit + // the BDVs. + // TODO: in many cases the instructions emited for the conflicting states + // will be identical to the I itself (if the I's operate on their BDVs + // themselves). We should expoit this, but can't do it here since it would + // break the invariant about the BDVs not being known to be a base. + // TODO: the code also does not handle constants at all - the algorithm relies + // on all constants having the same BDV and therefore constant-only insns + // will never be in conflict, but this check is ignored here. If the + // constant conflicts will be to BDVs themselves, they will be identical + // instructions and will get optimized away (as in the above TODO) + auto MarkConflict = [&](Instruction *I, Value *BaseValue) { + // II and EE mixes vector & scalar so is always a conflict + if (isa<InsertElementInst>(I) || isa<ExtractElementInst>(I)) + return true; + // Shuffle vector is always a conflict as it creates new vector from + // existing ones. + if (isa<ShuffleVectorInst>(I)) + return true; + // Any instructions where the computed base type differs from the + // instruction type. An example is where an extract instruction is used by a + // select. Here the select's BDV is a vector (because of extract's BDV), + // while the select itself is a scalar type. Note that the IE and EE + // instruction check is not fully subsumed by the vector<->scalar check at + // the end, this is due to the BDV algorithm being ignorant of BDV types at + // this junction. 
+ if (!areBothVectorOrScalar(BaseValue, I)) + return true; + return false; + }; + for (auto Pair : States) { Instruction *I = cast<Instruction>(Pair.first); BDVState State = Pair.second; @@ -1028,30 +1062,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); - if (!State.isBase() || !isa<VectorType>(BaseValue->getType())) + // since we only mark vec-scalar insns as conflicts in the pass, our work is + // done if the instruction already conflicts + if (State.isConflict()) continue; - // extractelement instructions are a bit special in that we may need to - // insert an extract even when we know an exact base for the instruction. - // The problem is that we need to convert from a vector base to a scalar - // base for the particular indice we're interested in. - if (isa<ExtractElementInst>(I)) { - auto *EE = cast<ExtractElementInst>(I); - // TODO: In many cases, the new instruction is just EE itself. We should - // exploit this, but can't do it here since it would break the invariant - // about the BDV not being known to be a base. - auto *BaseInst = ExtractElementInst::Create( - State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); - BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); - States[I] = BDVState(I, BDVState::Base, BaseInst); - setKnownBase(BaseInst, /* IsKnownBase */true, KnownBases); - } else if (!isa<VectorType>(I->getType())) { - // We need to handle cases that have a vector base but the instruction is - // a scalar type (these could be phis or selects or any instruction that - // are of scalar type, but the base can be a vector type). We - // conservatively set this as conflict. Setting the base value for these - // conflicts is handled in the next loop which traverses States. 
+ + if (MarkConflict(I, BaseValue)) States[I] = BDVState(I, BDVState::Conflict); - } } #ifndef NDEBUG @@ -1234,6 +1251,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, VerifyStates(); #endif + // get the data layout to compare the sizes of base/derived pointer values + [[maybe_unused]] auto &DL = + cast<llvm::Instruction>(Def)->getModule()->getDataLayout(); // Cache all of our results so we can cheaply reuse them // NOTE: This is actually two caches: one of the base defining value // relation and one of the base pointer relation! FIXME @@ -1241,6 +1261,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, auto *BDV = Pair.first; Value *Base = Pair.second.getBaseValue(); assert(BDV && Base); + // Whenever we have a derived ptr(s), their base + // ptr(s) must be of the same size, not necessarily the same type + assert(DL.getTypeAllocSize(BDV->getType()) == + DL.getTypeAllocSize(Base->getType()) && + "Derived and base values should have same size"); // Only values that do not have known bases or those that have differing // type (scalar versus vector) from a possible known base should be in the // lattice. @@ -1425,14 +1450,15 @@ static constexpr Attribute::AttrKind FnAttrsToStrip[] = {Attribute::Memory, Attribute::NoSync, Attribute::NoFree}; // Create new attribute set containing only attributes which can be transferred -// from original call to the safepoint. -static AttributeList legalizeCallAttributes(LLVMContext &Ctx, - AttributeList OrigAL, +// from the original call to the safepoint. +static AttributeList legalizeCallAttributes(CallBase *Call, bool IsMemIntrinsic, AttributeList StatepointAL) { + AttributeList OrigAL = Call->getAttributes(); if (OrigAL.isEmpty()) return StatepointAL; // Remove the readonly, readnone, and statepoint function attributes. 
+ LLVMContext &Ctx = Call->getContext(); AttrBuilder FnAttrs(Ctx, OrigAL.getFnAttrs()); for (auto Attr : FnAttrsToStrip) FnAttrs.removeAttribute(Attr); @@ -1442,8 +1468,24 @@ static AttributeList legalizeCallAttributes(LLVMContext &Ctx, FnAttrs.removeAttribute(A); } - // Just skip parameter and return attributes for now - return StatepointAL.addFnAttributes(Ctx, FnAttrs); + StatepointAL = StatepointAL.addFnAttributes(Ctx, FnAttrs); + + // The memory intrinsics do not have a 1:1 correspondence of the original + // call arguments to the produced statepoint. Do not transfer the argument + // attributes to avoid putting them on incorrect arguments. + if (IsMemIntrinsic) + return StatepointAL; + + // Attach the argument attributes from the original call at the corresponding + // arguments in the statepoint. Note that any argument attributes that are + // invalid after lowering are stripped in stripNonValidDataFromBody. + for (unsigned I : llvm::seq(Call->arg_size())) + StatepointAL = StatepointAL.addParamAttributes( + Ctx, GCStatepointInst::CallArgsBeginPos + I, + AttrBuilder(Ctx, OrigAL.getParamAttrs(I))); + + // Return attributes are later attached to the gc.result intrinsic. + return StatepointAL; } /// Helper function to place all gc relocates necessary for the given @@ -1480,7 +1522,7 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, auto getGCRelocateDecl = [&](Type *Ty) { assert(isHandledGCPointerType(Ty, GC)); auto AS = Ty->getScalarType()->getPointerAddressSpace(); - Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS); + Type *NewTy = PointerType::get(M->getContext(), AS); if (auto *VT = dyn_cast<VectorType>(Ty)) NewTy = FixedVectorType::get(NewTy, cast<FixedVectorType>(VT)->getNumElements()); @@ -1633,6 +1675,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // with a return value, we lower then as never returning calls to // __llvm_deoptimize that are followed by unreachable to get better codegen. 
bool IsDeoptimize = false; + bool IsMemIntrinsic = false; StatepointDirectives SD = parseStatepointDirectivesFromAttrs(Call->getAttributes()); @@ -1673,6 +1716,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ IsDeoptimize = true; } else if (IID == Intrinsic::memcpy_element_unordered_atomic || IID == Intrinsic::memmove_element_unordered_atomic) { + IsMemIntrinsic = true; + // Unordered atomic memcpy and memmove intrinsics which are not explicitly // marked as "gc-leaf-function" should be lowered in a GC parseable way. // Specifically, these calls should be lowered to the @@ -1788,12 +1833,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ SPCall->setTailCallKind(CI->getTailCallKind()); SPCall->setCallingConv(CI->getCallingConv()); - // Currently we will fail on parameter attributes and on certain - // function attributes. In case if we can handle this set of attributes - - // set up function attrs directly on statepoint and return attrs later for + // Set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. - SPCall->setAttributes(legalizeCallAttributes( - CI->getContext(), CI->getAttributes(), SPCall->getAttributes())); + SPCall->setAttributes( + legalizeCallAttributes(CI, IsMemIntrinsic, SPCall->getAttributes())); Token = cast<GCStatepointInst>(SPCall); @@ -1815,12 +1858,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ SPInvoke->setCallingConv(II->getCallingConv()); - // Currently we will fail on parameter attributes and on certain - // function attributes. In case if we can handle this set of attributes - - // set up function attrs directly on statepoint and return attrs later for + // Set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. 
- SPInvoke->setAttributes(legalizeCallAttributes( - II->getContext(), II->getAttributes(), SPInvoke->getAttributes())); + SPInvoke->setAttributes( + legalizeCallAttributes(II, IsMemIntrinsic, SPInvoke->getAttributes())); Token = cast<GCStatepointInst>(SPInvoke); @@ -1830,7 +1871,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ UnwindBlock->getUniquePredecessor() && "can't safely insert in this block!"); - Builder.SetInsertPoint(&*UnwindBlock->getFirstInsertionPt()); + Builder.SetInsertPoint(UnwindBlock, UnwindBlock->getFirstInsertionPt()); Builder.SetCurrentDebugLocation(II->getDebugLoc()); // Attach exceptional gc relocates to the landingpad. @@ -1845,7 +1886,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ NormalDest->getUniquePredecessor() && "can't safely insert in this block!"); - Builder.SetInsertPoint(&*NormalDest->getFirstInsertionPt()); + Builder.SetInsertPoint(NormalDest, NormalDest->getFirstInsertionPt()); // gc relocates will be generated later as if it were regular call // statepoint diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp index fcdc503c54a4..69679b608f8d 100644 --- a/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -17,10 +17,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/SCCP.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -51,7 +48,6 @@ #include "llvm/Transforms/Utils/SCCPSolver.h" #include <cassert> #include <utility> -#include <vector> using namespace llvm; diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 983a75e1d708..f578762d2b49 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -26,6 +26,7 @@ 
#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -70,6 +71,7 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" @@ -91,10 +93,10 @@ #include <string> #include <tuple> #include <utility> +#include <variant> #include <vector> using namespace llvm; -using namespace llvm::sroa; #define DEBUG_TYPE "sroa" @@ -123,6 +125,138 @@ static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false), cl::Hidden); namespace { +class AllocaSliceRewriter; +class AllocaSlices; +class Partition; + +class SelectHandSpeculativity { + unsigned char Storage = 0; // None are speculatable by default. + using TrueVal = Bitfield::Element<bool, 0, 1>; // Low 0'th bit. + using FalseVal = Bitfield::Element<bool, 1, 1>; // Low 1'th bit. +public: + SelectHandSpeculativity() = default; + SelectHandSpeculativity &setAsSpeculatable(bool isTrueVal); + bool isSpeculatable(bool isTrueVal) const; + bool areAllSpeculatable() const; + bool areAnySpeculatable() const; + bool areNoneSpeculatable() const; + // For interop as int half of PointerIntPair. + explicit operator intptr_t() const { return static_cast<intptr_t>(Storage); } + explicit SelectHandSpeculativity(intptr_t Storage_) : Storage(Storage_) {} +}; +static_assert(sizeof(SelectHandSpeculativity) == sizeof(unsigned char)); + +using PossiblySpeculatableLoad = + PointerIntPair<LoadInst *, 2, SelectHandSpeculativity>; +using UnspeculatableStore = StoreInst *; +using RewriteableMemOp = + std::variant<PossiblySpeculatableLoad, UnspeculatableStore>; +using RewriteableMemOps = SmallVector<RewriteableMemOp, 2>; + +/// An optimization pass providing Scalar Replacement of Aggregates. 
+/// +/// This pass takes allocations which can be completely analyzed (that is, they +/// don't escape) and tries to turn them into scalar SSA values. There are +/// a few steps to this process. +/// +/// 1) It takes allocations of aggregates and analyzes the ways in which they +/// are used to try to split them into smaller allocations, ideally of +/// a single scalar data type. It will split up memcpy and memset accesses +/// as necessary and try to isolate individual scalar accesses. +/// 2) It will transform accesses into forms which are suitable for SSA value +/// promotion. This can be replacing a memset with a scalar store of an +/// integer value, or it can involve speculating operations on a PHI or +/// select to be a PHI or select of the results. +/// 3) Finally, this will try to detect a pattern of accesses which map cleanly +/// onto insert and extract operations on a vector value, and convert them to +/// this form. By doing so, it will enable promotion of vector aggregates to +/// SSA vector values. +class SROA { + LLVMContext *const C; + DomTreeUpdater *const DTU; + AssumptionCache *const AC; + const bool PreserveCFG; + + /// Worklist of alloca instructions to simplify. + /// + /// Each alloca in the function is added to this. Each new alloca formed gets + /// added to it as well to recursively simplify unless that alloca can be + /// directly promoted. Finally, each time we rewrite a use of an alloca other + /// the one being actively rewritten, we add it back onto the list if not + /// already present to ensure it is re-visited. + SmallSetVector<AllocaInst *, 16> Worklist; + + /// A collection of instructions to delete. + /// We try to batch deletions to simplify code and make things a bit more + /// efficient. We also make sure there is no dangling pointers. + SmallVector<WeakVH, 8> DeadInsts; + + /// Post-promotion worklist. 
+ /// + /// Sometimes we discover an alloca which has a high probability of becoming + /// viable for SROA after a round of promotion takes place. In those cases, + /// the alloca is enqueued here for re-processing. + /// + /// Note that we have to be very careful to clear allocas out of this list in + /// the event they are deleted. + SmallSetVector<AllocaInst *, 16> PostPromotionWorklist; + + /// A collection of alloca instructions we can directly promote. + std::vector<AllocaInst *> PromotableAllocas; + + /// A worklist of PHIs to speculate prior to promoting allocas. + /// + /// All of these PHIs have been checked for the safety of speculation and by + /// being speculated will allow promoting allocas currently in the promotable + /// queue. + SmallSetVector<PHINode *, 8> SpeculatablePHIs; + + /// A worklist of select instructions to rewrite prior to promoting + /// allocas. + SmallMapVector<SelectInst *, RewriteableMemOps, 8> SelectsToRewrite; + + /// Select instructions that use an alloca and are subsequently loaded can be + /// rewritten to load both input pointers and then select between the result, + /// allowing the load of the alloca to be promoted. + /// From this: + /// %P2 = select i1 %cond, ptr %Alloca, ptr %Other + /// %V = load <type>, ptr %P2 + /// to: + /// %V1 = load <type>, ptr %Alloca -> will be mem2reg'd + /// %V2 = load <type>, ptr %Other + /// %V = select i1 %cond, <type> %V1, <type> %V2 + /// + /// We can do this to a select if its only uses are loads + /// and if either the operand to the select can be loaded unconditionally, + /// or if we are allowed to perform CFG modifications. + /// If found an intervening bitcast with a single use of the load, + /// allow the promotion. 
+ static std::optional<RewriteableMemOps> + isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG); + +public: + SROA(LLVMContext *C, DomTreeUpdater *DTU, AssumptionCache *AC, + SROAOptions PreserveCFG_) + : C(C), DTU(DTU), AC(AC), + PreserveCFG(PreserveCFG_ == SROAOptions::PreserveCFG) {} + + /// Main run method used by both the SROAPass and by the legacy pass. + std::pair<bool /*Changed*/, bool /*CFGChanged*/> runSROA(Function &F); + +private: + friend class AllocaSliceRewriter; + + bool presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS); + AllocaInst *rewritePartition(AllocaInst &AI, AllocaSlices &AS, Partition &P); + bool splitAlloca(AllocaInst &AI, AllocaSlices &AS); + std::pair<bool /*Changed*/, bool /*CFGChanged*/> runOnAlloca(AllocaInst &AI); + void clobberUse(Use &U); + bool deleteDeadInstructions(SmallPtrSetImpl<AllocaInst *> &DeletedAllocas); + bool promoteAllocas(Function &F); +}; + +} // end anonymous namespace + /// Calculate the fragment of a variable to use when slicing a store /// based on the slice dimensions, existing fragment, and base storage /// fragment. @@ -131,7 +265,9 @@ namespace { /// UseNoFrag - The new slice already covers the whole variable. /// Skip - The new alloca slice doesn't include this variable. /// FIXME: Can we use calculateFragmentIntersect instead? +namespace { enum FragCalcResult { UseFrag, UseNoFrag, Skip }; +} static FragCalcResult calculateFragment(DILocalVariable *Variable, uint64_t NewStorageSliceOffsetInBits, @@ -330,6 +466,8 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, } } +namespace { + /// A custom IRBuilder inserter which prefixes all names, but only in /// Assert builds. class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter { @@ -422,8 +560,6 @@ public: bool operator!=(const Slice &RHS) const { return !operator==(RHS); } }; -} // end anonymous namespace - /// Representation of the alloca slices. 
/// /// This class represents the slices of an alloca which are formed by its @@ -431,7 +567,7 @@ public: /// for the slices used and we reflect that in this structure. The uses are /// stored, sorted by increasing beginning offset and with unsplittable slices /// starting at a particular offset before splittable slices. -class llvm::sroa::AllocaSlices { +class AllocaSlices { public: /// Construct the slices of a particular alloca. AllocaSlices(const DataLayout &DL, AllocaInst &AI); @@ -563,7 +699,7 @@ private: /// /// Objects of this type are produced by traversing the alloca's slices, but /// are only ephemeral and not persistent. -class llvm::sroa::Partition { +class Partition { private: friend class AllocaSlices; friend class AllocaSlices::partition_iterator; @@ -628,6 +764,8 @@ public: ArrayRef<Slice *> splitSliceTails() const { return SplitTails; } }; +} // end anonymous namespace + /// An iterator over partitions of the alloca's slices. /// /// This iterator implements the core algorithm for partitioning the alloca's @@ -1144,6 +1282,7 @@ private: } if (II.isLaunderOrStripInvariantGroup()) { + insertUse(II, Offset, AllocSize, true); enqueueUsers(II); return; } @@ -1169,16 +1308,24 @@ private: std::tie(UsedI, I) = Uses.pop_back_val(); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - Size = - std::max(Size, DL.getTypeStoreSize(LI->getType()).getFixedValue()); + TypeSize LoadSize = DL.getTypeStoreSize(LI->getType()); + if (LoadSize.isScalable()) { + PI.setAborted(LI); + return nullptr; + } + Size = std::max(Size, LoadSize.getFixedValue()); continue; } if (StoreInst *SI = dyn_cast<StoreInst>(I)) { Value *Op = SI->getOperand(0); if (Op == UsedI) return SI; - Size = - std::max(Size, DL.getTypeStoreSize(Op->getType()).getFixedValue()); + TypeSize StoreSize = DL.getTypeStoreSize(Op->getType()); + if (StoreSize.isScalable()) { + PI.setAborted(SI); + return nullptr; + } + Size = std::max(Size, StoreSize.getFixedValue()); continue; } @@ -1525,38 +1672,37 @@ static void 
speculatePHINodeLoads(IRBuilderTy &IRB, PHINode &PN) { PN.eraseFromParent(); } -sroa::SelectHandSpeculativity & -sroa::SelectHandSpeculativity::setAsSpeculatable(bool isTrueVal) { +SelectHandSpeculativity & +SelectHandSpeculativity::setAsSpeculatable(bool isTrueVal) { if (isTrueVal) - Bitfield::set<sroa::SelectHandSpeculativity::TrueVal>(Storage, true); + Bitfield::set<SelectHandSpeculativity::TrueVal>(Storage, true); else - Bitfield::set<sroa::SelectHandSpeculativity::FalseVal>(Storage, true); + Bitfield::set<SelectHandSpeculativity::FalseVal>(Storage, true); return *this; } -bool sroa::SelectHandSpeculativity::isSpeculatable(bool isTrueVal) const { - return isTrueVal - ? Bitfield::get<sroa::SelectHandSpeculativity::TrueVal>(Storage) - : Bitfield::get<sroa::SelectHandSpeculativity::FalseVal>(Storage); +bool SelectHandSpeculativity::isSpeculatable(bool isTrueVal) const { + return isTrueVal ? Bitfield::get<SelectHandSpeculativity::TrueVal>(Storage) + : Bitfield::get<SelectHandSpeculativity::FalseVal>(Storage); } -bool sroa::SelectHandSpeculativity::areAllSpeculatable() const { +bool SelectHandSpeculativity::areAllSpeculatable() const { return isSpeculatable(/*isTrueVal=*/true) && isSpeculatable(/*isTrueVal=*/false); } -bool sroa::SelectHandSpeculativity::areAnySpeculatable() const { +bool SelectHandSpeculativity::areAnySpeculatable() const { return isSpeculatable(/*isTrueVal=*/true) || isSpeculatable(/*isTrueVal=*/false); } -bool sroa::SelectHandSpeculativity::areNoneSpeculatable() const { +bool SelectHandSpeculativity::areNoneSpeculatable() const { return !areAnySpeculatable(); } -static sroa::SelectHandSpeculativity +static SelectHandSpeculativity isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) { assert(LI.isSimple() && "Only for simple loads"); - sroa::SelectHandSpeculativity Spec; + SelectHandSpeculativity Spec; const DataLayout &DL = SI.getModule()->getDataLayout(); for (Value *Value : {SI.getTrueValue(), SI.getFalseValue()}) @@ 
-1569,8 +1715,8 @@ isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) { return Spec; } -std::optional<sroa::RewriteableMemOps> -SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) { +std::optional<RewriteableMemOps> +SROA::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) { RewriteableMemOps Ops; for (User *U : SI.users()) { @@ -1604,7 +1750,7 @@ SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) { continue; } - sroa::SelectHandSpeculativity Spec = + SelectHandSpeculativity Spec = isSafeLoadOfSelectToSpeculate(*LI, SI, PreserveCFG); if (PreserveCFG && !Spec.areAllSpeculatable()) return {}; // Give up on this `select`. @@ -1655,7 +1801,7 @@ static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI, template <typename T> static void rewriteMemOpOfSelect(SelectInst &SI, T &I, - sroa::SelectHandSpeculativity Spec, + SelectHandSpeculativity Spec, DomTreeUpdater &DTU) { assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Only for load and store!"); LLVM_DEBUG(dbgs() << " original mem op: " << I << "\n"); @@ -1711,7 +1857,7 @@ static void rewriteMemOpOfSelect(SelectInst &SI, T &I, } static void rewriteMemOpOfSelect(SelectInst &SelInst, Instruction &I, - sroa::SelectHandSpeculativity Spec, + SelectHandSpeculativity Spec, DomTreeUpdater &DTU) { if (auto *LI = dyn_cast<LoadInst>(&I)) rewriteMemOpOfSelect(SelInst, *LI, Spec, DTU); @@ -1722,13 +1868,13 @@ static void rewriteMemOpOfSelect(SelectInst &SelInst, Instruction &I, } static bool rewriteSelectInstMemOps(SelectInst &SI, - const sroa::RewriteableMemOps &Ops, + const RewriteableMemOps &Ops, IRBuilderTy &IRB, DomTreeUpdater *DTU) { bool CFGChanged = false; LLVM_DEBUG(dbgs() << " original select: " << SI << "\n"); for (const RewriteableMemOp &Op : Ops) { - sroa::SelectHandSpeculativity Spec; + SelectHandSpeculativity Spec; Instruction *I; if (auto *const *US = std::get_if<UnspeculatableStore>(&Op)) { I = *US; @@ -2421,14 +2567,15 @@ static Value 
*insertVector(IRBuilderTy &IRB, Value *Old, Value *V, return V; } +namespace { + /// Visitor to rewrite instructions using p particular slice of an alloca /// to use a new alloca. /// /// Also implements the rewriting to vector-based accesses when the partition /// passes the isVectorPromotionViable predicate. Most of the rewriting logic /// lives here. -class llvm::sroa::AllocaSliceRewriter - : public InstVisitor<AllocaSliceRewriter, bool> { +class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> { // Befriend the base class so it can delegate to private visit methods. friend class InstVisitor<AllocaSliceRewriter, bool>; @@ -2436,7 +2583,7 @@ class llvm::sroa::AllocaSliceRewriter const DataLayout &DL; AllocaSlices &AS; - SROAPass &Pass; + SROA &Pass; AllocaInst &OldAI, &NewAI; const uint64_t NewAllocaBeginOffset, NewAllocaEndOffset; Type *NewAllocaTy; @@ -2489,12 +2636,12 @@ class llvm::sroa::AllocaSliceRewriter if (!IsVolatile || AddrSpace == NewAI.getType()->getPointerAddressSpace()) return &NewAI; - Type *AccessTy = NewAI.getAllocatedType()->getPointerTo(AddrSpace); + Type *AccessTy = IRB.getPtrTy(AddrSpace); return IRB.CreateAddrSpaceCast(&NewAI, AccessTy); } public: - AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROAPass &Pass, + AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROA &Pass, AllocaInst &OldAI, AllocaInst &NewAI, uint64_t NewAllocaBeginOffset, uint64_t NewAllocaEndOffset, bool IsIntegerPromotable, @@ -2697,7 +2844,7 @@ private: NewEndOffset == NewAllocaEndOffset && (canConvertValue(DL, NewAllocaTy, TargetTy) || (IsLoadPastEnd && NewAllocaTy->isIntegerTy() && - TargetTy->isIntegerTy()))) { + TargetTy->isIntegerTy() && !LI.isVolatile()))) { Value *NewPtr = getPtrToNewAI(LI.getPointerAddressSpace(), LI.isVolatile()); LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), NewPtr, @@ -2732,7 +2879,7 @@ private: "endian_shift"); } } else { - Type *LTy = TargetTy->getPointerTo(AS); + Type *LTy = 
IRB.getPtrTy(AS); LoadInst *NewLI = IRB.CreateAlignedLoad(TargetTy, getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(), LI.isVolatile(), LI.getName()); @@ -2762,9 +2909,9 @@ private: // basis for the new value. This allows us to replace the uses of LI with // the computed value, and then replace the placeholder with LI, leaving // LI only used for this computation. - Value *Placeholder = new LoadInst( - LI.getType(), PoisonValue::get(LI.getType()->getPointerTo(AS)), "", - false, Align(1)); + Value *Placeholder = + new LoadInst(LI.getType(), PoisonValue::get(IRB.getPtrTy(AS)), "", + false, Align(1)); V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset, "insert"); LI.replaceAllUsesWith(V); @@ -2875,26 +3022,10 @@ private: if (IntTy && V->getType()->isIntegerTy()) return rewriteIntegerStore(V, SI, AATags); - const bool IsStorePastEnd = - DL.getTypeStoreSize(V->getType()).getFixedValue() > SliceSize; StoreInst *NewSI; if (NewBeginOffset == NewAllocaBeginOffset && NewEndOffset == NewAllocaEndOffset && - (canConvertValue(DL, V->getType(), NewAllocaTy) || - (IsStorePastEnd && NewAllocaTy->isIntegerTy() && - V->getType()->isIntegerTy()))) { - // If this is an integer store past the end of slice (and thus the bytes - // past that point are irrelevant or this is unreachable), truncate the - // value prior to storing. 
- if (auto *VITy = dyn_cast<IntegerType>(V->getType())) - if (auto *AITy = dyn_cast<IntegerType>(NewAllocaTy)) - if (VITy->getBitWidth() > AITy->getBitWidth()) { - if (DL.isBigEndian()) - V = IRB.CreateLShr(V, VITy->getBitWidth() - AITy->getBitWidth(), - "endian_shift"); - V = IRB.CreateTrunc(V, AITy, "load.trunc"); - } - + canConvertValue(DL, V->getType(), NewAllocaTy)) { V = convertValue(DL, IRB, V, NewAllocaTy); Value *NewPtr = getPtrToNewAI(SI.getPointerAddressSpace(), SI.isVolatile()); @@ -2903,7 +3034,7 @@ private: IRB.CreateAlignedStore(V, NewPtr, NewAI.getAlign(), SI.isVolatile()); } else { unsigned AS = SI.getPointerAddressSpace(); - Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS)); + Value *NewPtr = getNewAllocaSlicePtr(IRB, IRB.getPtrTy(AS)); NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(), SI.isVolatile()); } @@ -3126,8 +3257,7 @@ private: if (IsDest) { // Update the address component of linked dbg.assigns. for (auto *DAI : at::getAssignmentMarkers(&II)) { - if (any_of(DAI->location_ops(), - [&](Value *V) { return V == II.getDest(); }) || + if (llvm::is_contained(DAI->location_ops(), II.getDest()) || DAI->getAddress() == II.getDest()) DAI->replaceVariableLocationOp(II.getDest(), AdjustedPtr); } @@ -3259,7 +3389,6 @@ private: } else { OtherTy = NewAllocaTy; } - OtherPtrTy = OtherTy->getPointerTo(OtherAS); Value *AdjPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy, OtherPtr->getName() + "."); @@ -3337,7 +3466,8 @@ private: } bool visitIntrinsicInst(IntrinsicInst &II) { - assert((II.isLifetimeStartOrEnd() || II.isDroppable()) && + assert((II.isLifetimeStartOrEnd() || II.isLaunderOrStripInvariantGroup() || + II.isDroppable()) && "Unexpected intrinsic!"); LLVM_DEBUG(dbgs() << " original: " << II << "\n"); @@ -3351,6 +3481,9 @@ private: return true; } + if (II.isLaunderOrStripInvariantGroup()) + return true; + assert(II.getArgOperand(1) == OldPtr); // Lifetime intrinsics are only promotable if they cover the 
whole alloca. // Therefore, we drop lifetime intrinsics which don't cover the whole @@ -3368,7 +3501,7 @@ private: NewEndOffset - NewBeginOffset); // Lifetime intrinsics always expect an i8* so directly get such a pointer // for the new alloca slice. - Type *PointerTy = IRB.getInt8PtrTy(OldPtr->getType()->getPointerAddressSpace()); + Type *PointerTy = IRB.getPtrTy(OldPtr->getType()->getPointerAddressSpace()); Value *Ptr = getNewAllocaSlicePtr(IRB, PointerTy); Value *New; if (II.getIntrinsicID() == Intrinsic::lifetime_start) @@ -3422,7 +3555,8 @@ private: // dominate the PHI. IRBuilderBase::InsertPointGuard Guard(IRB); if (isa<PHINode>(OldPtr)) - IRB.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt()); + IRB.SetInsertPoint(OldPtr->getParent(), + OldPtr->getParent()->getFirstInsertionPt()); else IRB.SetInsertPoint(OldPtr); IRB.SetCurrentDebugLocation(OldPtr->getDebugLoc()); @@ -3472,8 +3606,6 @@ private: } }; -namespace { - /// Visitor to rewrite aggregate loads and stores as scalar. /// /// This pass aggressively rewrites all aggregate loads and stores on @@ -3811,7 +3943,7 @@ private: SmallVector<Value *, 4> Index(GEPI.indices()); bool IsInBounds = GEPI.isInBounds(); - IRB.SetInsertPoint(GEPI.getParent()->getFirstNonPHI()); + IRB.SetInsertPoint(GEPI.getParent(), GEPI.getParent()->getFirstNonPHIIt()); PHINode *NewPN = IRB.CreatePHI(GEPI.getType(), PHI->getNumIncomingValues(), PHI->getName() + ".sroa.phi"); for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { @@ -4046,7 +4178,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, /// there all along. /// /// \returns true if any changes are made. 
-bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { +bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { LLVM_DEBUG(dbgs() << "Pre-splitting loads and stores\n"); // Track the loads and stores which are candidates for pre-splitting here, in @@ -4268,7 +4400,7 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { for (;;) { auto *PartTy = Type::getIntNTy(LI->getContext(), PartSize * 8); auto AS = LI->getPointerAddressSpace(); - auto *PartPtrTy = PartTy->getPointerTo(AS); + auto *PartPtrTy = LI->getPointerOperandType(); LoadInst *PLoad = IRB.CreateAlignedLoad( PartTy, getAdjustedPtr(IRB, DL, BasePtr, @@ -4323,8 +4455,7 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { for (int Idx = 0, Size = SplitLoads.size(); Idx < Size; ++Idx) { LoadInst *PLoad = SplitLoads[Idx]; uint64_t PartOffset = Idx == 0 ? 0 : Offsets.Splits[Idx - 1]; - auto *PartPtrTy = - PLoad->getType()->getPointerTo(SI->getPointerAddressSpace()); + auto *PartPtrTy = SI->getPointerOperandType(); auto AS = SI->getPointerAddressSpace(); StoreInst *PStore = IRB.CreateAlignedStore( @@ -4404,8 +4535,8 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { int Idx = 0, Size = Offsets.Splits.size(); for (;;) { auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8); - auto *LoadPartPtrTy = PartTy->getPointerTo(LI->getPointerAddressSpace()); - auto *StorePartPtrTy = PartTy->getPointerTo(SI->getPointerAddressSpace()); + auto *LoadPartPtrTy = LI->getPointerOperandType(); + auto *StorePartPtrTy = SI->getPointerOperandType(); // Either lookup a split load or create one. LoadInst *PLoad; @@ -4526,8 +4657,8 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { /// appropriate new offsets. It also evaluates how successful the rewrite was /// at enabling promotion and if it was successful queues the alloca to be /// promoted. 
-AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, - Partition &P) { +AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, + Partition &P) { // Try to compute a friendly type for this partition of the alloca. This // won't always succeed, in which case we fall back to a legal integer type // or an i8 array of an appropriate size. @@ -4709,7 +4840,7 @@ AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, /// Walks the slices of an alloca and form partitions based on them, /// rewriting each of their uses. -bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { +bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { if (AS.begin() == AS.end()) return false; @@ -4900,7 +5031,7 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } /// Clobber a use with poison, deleting the used value if it becomes dead. -void SROAPass::clobberUse(Use &U) { +void SROA::clobberUse(Use &U) { Value *OldV = U; // Replace the use with an poison value. U = PoisonValue::get(OldV->getType()); @@ -4920,7 +5051,7 @@ void SROAPass::clobberUse(Use &U) { /// the slices of the alloca, and then hands it off to be split and /// rewritten as needed. std::pair<bool /*Changed*/, bool /*CFGChanged*/> -SROAPass::runOnAlloca(AllocaInst &AI) { +SROA::runOnAlloca(AllocaInst &AI) { bool Changed = false; bool CFGChanged = false; @@ -5002,7 +5133,7 @@ SROAPass::runOnAlloca(AllocaInst &AI) { /// /// We also record the alloca instructions deleted here so that they aren't /// subsequently handed to mem2reg to promote. -bool SROAPass::deleteDeadInstructions( +bool SROA::deleteDeadInstructions( SmallPtrSetImpl<AllocaInst *> &DeletedAllocas) { bool Changed = false; while (!DeadInsts.empty()) { @@ -5043,7 +5174,7 @@ bool SROAPass::deleteDeadInstructions( /// This attempts to promote whatever allocas have been identified as viable in /// the PromotableAllocas list. If that list is empty, there is nothing to do. 
/// This function returns whether any promotion occurred. -bool SROAPass::promoteAllocas(Function &F) { +bool SROA::promoteAllocas(Function &F) { if (PromotableAllocas.empty()) return false; @@ -5060,12 +5191,8 @@ bool SROAPass::promoteAllocas(Function &F) { return true; } -PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU, - AssumptionCache &RunAC) { +std::pair<bool /*Changed*/, bool /*CFGChanged*/> SROA::runSROA(Function &F) { LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n"); - C = &F.getContext(); - DTU = &RunDTU; - AC = &RunAC; const DataLayout &DL = F.getParent()->getDataLayout(); BasicBlock &EntryBB = F.getEntryBlock(); @@ -5116,56 +5243,50 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU, assert((!CFGChanged || !PreserveCFG) && "Should not have modified the CFG when told to preserve it."); - if (!Changed) - return PreservedAnalyses::all(); - - if (isAssignmentTrackingEnabled(*F.getParent())) { + if (Changed && isAssignmentTrackingEnabled(*F.getParent())) { for (auto &BB : F) RemoveRedundantDbgInstrs(&BB); } - PreservedAnalyses PA; - if (!CFGChanged) - PA.preserveSet<CFGAnalyses>(); - PA.preserve<DominatorTreeAnalysis>(); - return PA; -} - -PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT, - AssumptionCache &RunAC) { - DomTreeUpdater DTU(RunDT, DomTreeUpdater::UpdateStrategy::Lazy); - return runImpl(F, DTU, RunAC); + return {Changed, CFGChanged}; } PreservedAnalyses SROAPass::run(Function &F, FunctionAnalysisManager &AM) { DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); - return runImpl(F, DT, AC); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + auto [Changed, CFGChanged] = + SROA(&F.getContext(), &DTU, &AC, PreserveCFG).runSROA(F); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + if (!CFGChanged) + PA.preserveSet<CFGAnalyses>(); + 
PA.preserve<DominatorTreeAnalysis>(); + return PA; } void SROAPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<SROAPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << (PreserveCFG ? "<preserve-cfg>" : "<modify-cfg>"); + OS << (PreserveCFG == SROAOptions::PreserveCFG ? "<preserve-cfg>" + : "<modify-cfg>"); } -SROAPass::SROAPass(SROAOptions PreserveCFG_) - : PreserveCFG(PreserveCFG_ == SROAOptions::PreserveCFG) {} +SROAPass::SROAPass(SROAOptions PreserveCFG) : PreserveCFG(PreserveCFG) {} + +namespace { /// A legacy pass for the legacy pass manager that wraps the \c SROA pass. -/// -/// This is in the llvm namespace purely to allow it to be a friend of the \c -/// SROA pass. -class llvm::sroa::SROALegacyPass : public FunctionPass { - /// The SROA implementation. - SROAPass Impl; +class SROALegacyPass : public FunctionPass { + SROAOptions PreserveCFG; public: static char ID; SROALegacyPass(SROAOptions PreserveCFG = SROAOptions::PreserveCFG) - : FunctionPass(ID), Impl(PreserveCFG) { + : FunctionPass(ID), PreserveCFG(PreserveCFG) { initializeSROALegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -5173,10 +5294,13 @@ public: if (skipFunction(F)) return false; - auto PA = Impl.runImpl( - F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F)); - return !PA.areAllPreserved(); + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionCache &AC = + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + auto [Changed, _] = + SROA(&F.getContext(), &DTU, &AC, PreserveCFG).runSROA(F); + return Changed; } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -5189,6 +5313,8 @@ public: StringRef getPassName() const override { return "SROA"; } }; +} // end anonymous namespace + char SROALegacyPass::ID = 0; FunctionPass 
*llvm::createSROAPass(bool PreserveCFG) { diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp index 37b032e4d7c7..4ce6ce93be33 100644 --- a/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -21,41 +21,27 @@ using namespace llvm; void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeConstantHoistingLegacyPassPass(Registry); initializeDCELegacyPassPass(Registry); - initializeScalarizerLegacyPassPass(Registry); - initializeGuardWideningLegacyPassPass(Registry); - initializeLoopGuardWideningLegacyPassPass(Registry); initializeGVNLegacyPassPass(Registry); initializeEarlyCSELegacyPassPass(Registry); initializeEarlyCSEMemSSALegacyPassPass(Registry); - initializeMakeGuardsExplicitLegacyPassPass(Registry); initializeFlattenCFGLegacyPassPass(Registry); initializeInferAddressSpacesPass(Registry); initializeInstSimplifyLegacyPassPass(Registry); initializeLegacyLICMPassPass(Registry); - initializeLegacyLoopSinkPassPass(Registry); initializeLoopDataPrefetchLegacyPassPass(Registry); - initializeLoopInstSimplifyLegacyPassPass(Registry); - initializeLoopPredicationLegacyPassPass(Registry); initializeLoopRotateLegacyPassPass(Registry); initializeLoopStrengthReducePass(Registry); initializeLoopUnrollPass(Registry); initializeLowerAtomicLegacyPassPass(Registry); initializeLowerConstantIntrinsicsPass(Registry); - initializeLowerExpectIntrinsicPass(Registry); - initializeLowerGuardIntrinsicLegacyPassPass(Registry); - initializeLowerWidenableConditionLegacyPassPass(Registry); initializeMergeICmpsLegacyPassPass(Registry); - initializeMergedLoadStoreMotionLegacyPassPass(Registry); initializeNaryReassociateLegacyPassPass(Registry); initializePartiallyInlineLibCallsLegacyPassPass(Registry); initializeReassociateLegacyPassPass(Registry); - initializeRedundantDbgInstEliminationPass(Registry); - initializeRegToMemLegacyPass(Registry); initializeScalarizeMaskedMemIntrinLegacyPassPass(Registry); 
initializeSROALegacyPassPass(Registry); initializeCFGSimplifyPassPass(Registry); initializeStructurizeCFGLegacyPassPass(Registry); - initializeSimpleLoopUnswitchLegacyPassPass(Registry); initializeSinkingLegacyPassPass(Registry); initializeTailCallElimPass(Registry); initializeTLSVariableHoistLegacyPassPass(Registry); @@ -63,5 +49,4 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeSpeculativeExecutionLegacyPassPass(Registry); initializeStraightLineStrengthReduceLegacyPassPass(Registry); initializePlaceBackedgeSafepointsLegacyPassPass(Registry); - initializeLoopSimplifyCFGLegacyPassPass(Registry); } diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 86b55dfd304a..3eca9ac7c267 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -36,8 +36,6 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/Local.h" @@ -282,12 +280,10 @@ T getWithDefaultOverride(const cl::opt<T> &ClOption, class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> { public: - ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT, - ScalarizerPassOptions Options) - : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT), - ScalarizeVariableInsertExtract( - getWithDefaultOverride(ClScalarizeVariableInsertExtract, - Options.ScalarizeVariableInsertExtract)), + ScalarizerVisitor(DominatorTree *DT, ScalarizerPassOptions Options) + : DT(DT), ScalarizeVariableInsertExtract(getWithDefaultOverride( + ClScalarizeVariableInsertExtract, + Options.ScalarizeVariableInsertExtract)), ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore, Options.ScalarizeLoadStore)), ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits, @@ -337,8 +333,6 @@ private: 
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; - unsigned ParallelLoopAccessMDKind; - DominatorTree *DT; const bool ScalarizeVariableInsertExtract; @@ -346,31 +340,8 @@ private: const unsigned ScalarizeMinBits; }; -class ScalarizerLegacyPass : public FunctionPass { -public: - static char ID; - - ScalarizerLegacyPass() : FunctionPass(ID) { - initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage& AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - } -}; - } // end anonymous namespace -char ScalarizerLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer", - "Scalarize vector operations", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer", - "Scalarize vector operations", false, false) - Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, const VectorSplit &VS, ValueVector *cachePtr) : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) { @@ -443,22 +414,6 @@ Value *Scatterer::operator[](unsigned Frag) { return CV[Frag]; } -bool ScalarizerLegacyPass::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - - Module &M = *F.getParent(); - unsigned ParallelLoopAccessMDKind = - M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, ScalarizerPassOptions()); - return Impl.visit(F); -} - -FunctionPass *llvm::createScalarizerPass() { - return new ScalarizerLegacyPass(); -} - bool ScalarizerVisitor::visit(Function &F) { assert(Gathered.empty() && Scattered.empty()); @@ -558,7 +513,7 @@ bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) { || Tag == LLVMContext::MD_invariant_load || Tag == LLVMContext::MD_alias_scope || Tag == 
LLVMContext::MD_noalias - || Tag == ParallelLoopAccessMDKind + || Tag == LLVMContext::MD_mem_parallel_loop_access || Tag == LLVMContext::MD_access_group); } @@ -730,7 +685,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { // vector type, which is true for all current intrinsics. for (unsigned I = 0; I != NumArgs; ++I) { Value *OpI = CI.getOperand(I); - if (auto *OpVecTy = dyn_cast<FixedVectorType>(OpI->getType())) { + if ([[maybe_unused]] auto *OpVecTy = + dyn_cast<FixedVectorType>(OpI->getType())) { assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements()); std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType()); if (!OpVS || OpVS->NumPacked != VS->NumPacked) { @@ -1253,11 +1209,8 @@ bool ScalarizerVisitor::finish() { } PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) { - Module &M = *F.getParent(); - unsigned ParallelLoopAccessMDKind = - M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); - ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, Options); + ScalarizerVisitor Impl(DT, Options); bool Changed = Impl.visit(F); PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 89d0b7c33e0d..b8c9d9d100f1 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -524,7 +524,7 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, // FIXME: this does not appear to be covered by any tests // (with x86/aarch64 backends at least) if (BO->getOpcode() == Instruction::Or && - !haveNoCommonBitsSet(LHS, RHS, DL, nullptr, BO, DT)) + !haveNoCommonBitsSet(LHS, RHS, SimplifyQuery(DL, DT, /*AC*/ nullptr, BO))) return false; // FIXME: We don't currently support constants from the RHS of subs, @@ -661,15 +661,16 @@ Value 
*ConstantOffsetExtractor::applyExts(Value *V) { // in the reversed order. for (CastInst *I : llvm::reverse(ExtInsts)) { if (Constant *C = dyn_cast<Constant>(Current)) { - // If Current is a constant, apply s/zext using ConstantExpr::getCast. - // ConstantExpr::getCast emits a ConstantInt if C is a ConstantInt. - Current = ConstantExpr::getCast(I->getOpcode(), C, I->getType()); - } else { - Instruction *Ext = I->clone(); - Ext->setOperand(0, Current); - Ext->insertBefore(IP); - Current = Ext; + // Try to constant fold the cast. + Current = ConstantFoldCastOperand(I->getOpcode(), C, I->getType(), DL); + if (Current) + continue; } + + Instruction *Ext = I->clone(); + Ext->setOperand(0, Current); + Ext->insertBefore(IP); + Current = Ext; } return Current; } @@ -830,7 +831,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { if (GTI.isSequential()) { // Constant offsets of scalable types are not really constant. - if (isa<ScalableVectorType>(GTI.getIndexedType())) + if (GTI.getIndexedType()->isScalableTy()) continue; // Tries to extract a constant offset from this GEP index. @@ -1019,7 +1020,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { if (GTI.isSequential()) { // Constant offsets of scalable types are not really constant. 
- if (isa<ScalableVectorType>(GTI.getIndexedType())) + if (GTI.getIndexedType()->isScalableTy()) continue; // Splits this GEP index into a variadic part and a constant offset, and diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index ad7d34b61470..7eb0ba1c2c17 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -24,7 +24,6 @@ #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" @@ -46,8 +45,6 @@ #include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -368,10 +365,11 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, bool FullUnswitch) { assert(&ExitBB != &UnswitchedBB && "Must have different loop exit and unswitched blocks!"); - Instruction *InsertPt = &*UnswitchedBB.begin(); + BasicBlock::iterator InsertPt = UnswitchedBB.begin(); for (PHINode &PN : ExitBB.phis()) { auto *NewPN = PHINode::Create(PN.getType(), /*NumReservedValues*/ 2, - PN.getName() + ".split", InsertPt); + PN.getName() + ".split"); + NewPN->insertBefore(InsertPt); // Walk backwards over the old PHI node's inputs to minimize the cost of // removing each one. 
We have to do this weird loop manually so that we @@ -609,7 +607,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, UnswitchedBB = LoopExitBB; } else { UnswitchedBB = - SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI, MSSAU); + SplitBlock(LoopExitBB, LoopExitBB->begin(), &DT, &LI, MSSAU, "", false); } if (MSSAU && VerifyMemorySSA) @@ -623,7 +621,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // If fully unswitching, we can use the existing branch instruction. // Splice it into the old PH to gate reaching the new preheader and re-point // its successors. - OldPH->splice(OldPH->end(), BI.getParent(), BI.getIterator()); + BI.moveBefore(*OldPH, OldPH->end()); BI.setCondition(Cond); if (MSSAU) { // Temporarily clone the terminator, to make MSSA update cheaper by @@ -882,7 +880,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH); } else { auto *SplitBB = - SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI, MSSAU); + SplitBlock(DefaultExitBB, DefaultExitBB->begin(), &DT, &LI, MSSAU); rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); @@ -909,7 +907,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, BasicBlock *&SplitExitBB = SplitExitBBMap[ExitBB]; if (!SplitExitBB) { // If this is the first time we see this, do the split and remember it. - SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSAU); + SplitExitBB = SplitBlock(ExitBB, ExitBB->begin(), &DT, &LI, MSSAU); rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); @@ -1210,7 +1208,7 @@ static BasicBlock *buildClonedLoopBlocks( // place to merge the CFG, so split the exit first. 
This is always safe to // do because there cannot be any non-loop predecessors of a loop exit in // loop simplified form. - auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSAU); + auto *MergeBB = SplitBlock(ExitBB, ExitBB->begin(), &DT, &LI, MSSAU); // Rearrange the names to make it easier to write test cases by having the // exit block carry the suffix rather than the merge block carrying the @@ -1246,8 +1244,8 @@ static BasicBlock *buildClonedLoopBlocks( SE->forgetValue(&I); auto *MergePN = - PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi", - &*MergeBB->getFirstInsertionPt()); + PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi"); + MergePN->insertBefore(MergeBB->getFirstInsertionPt()); I.replaceAllUsesWith(MergePN); MergePN->addIncoming(&I, ExitBB); MergePN->addIncoming(&ClonedI, ClonedExitBB); @@ -1259,8 +1257,11 @@ static BasicBlock *buildClonedLoopBlocks( // everything available. Also, we have inserted new instructions which may // include assume intrinsics, so we update the assumption cache while // processing this. 
+ Module *M = ClonedPH->getParent()->getParent(); for (auto *ClonedBB : NewBlocks) for (Instruction &I : *ClonedBB) { + RemapDPValueRange(M, I.getDbgValueRange(), VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); if (auto *II = dyn_cast<AssumeInst>(&I)) @@ -1684,13 +1685,12 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, BB->eraseFromParent(); } -static void -deleteDeadBlocksFromLoop(Loop &L, - SmallVectorImpl<BasicBlock *> &ExitBlocks, - DominatorTree &DT, LoopInfo &LI, - MemorySSAUpdater *MSSAU, - ScalarEvolution *SE, - function_ref<void(Loop &, StringRef)> DestroyLoopCB) { +static void deleteDeadBlocksFromLoop(Loop &L, + SmallVectorImpl<BasicBlock *> &ExitBlocks, + DominatorTree &DT, LoopInfo &LI, + MemorySSAUpdater *MSSAU, + ScalarEvolution *SE, + LPMUpdater &LoopUpdater) { // Find all the dead blocks tied to this loop, and remove them from their // successors. SmallSetVector<BasicBlock *, 8> DeadBlockSet; @@ -1740,7 +1740,7 @@ deleteDeadBlocksFromLoop(Loop &L, }) && "If the child loop header is dead all blocks in the child loop must " "be dead as well!"); - DestroyLoopCB(*ChildL, ChildL->getName()); + LoopUpdater.markLoopAsDeleted(*ChildL, ChildL->getName()); if (SE) SE->forgetBlockAndLoopDispositions(); LI.destroy(ChildL); @@ -2084,8 +2084,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, ParentL->removeChildLoop(llvm::find(*ParentL, &L)); else LI.removeLoop(llvm::find(LI, &L)); - // markLoopAsDeleted for L should be triggered by the caller (it is typically - // done by using the UnswitchCB callback). + // markLoopAsDeleted for L should be triggered by the caller (it is + // typically done within postUnswitch). 
if (SE) SE->forgetBlockAndLoopDispositions(); LI.destroy(&L); @@ -2122,17 +2122,56 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { } while (!DomWorklist.empty()); } +void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName, + bool CurrentLoopValid, bool PartiallyInvariant, + bool InjectedCondition, ArrayRef<Loop *> NewLoops) { + // If we did a non-trivial unswitch, we have added new (cloned) loops. + if (!NewLoops.empty()) + U.addSiblingLoops(NewLoops); + + // If the current loop remains valid, we should revisit it to catch any + // other unswitch opportunities. Otherwise, we need to mark it as deleted. + if (CurrentLoopValid) { + if (PartiallyInvariant) { + // Mark the new loop as partially unswitched, to avoid unswitching on + // the same condition again. + auto &Context = L.getHeader()->getContext(); + MDNode *DisableUnswitchMD = MDNode::get( + Context, + MDString::get(Context, "llvm.loop.unswitch.partial.disable")); + MDNode *NewLoopID = makePostTransformationMetadata( + Context, L.getLoopID(), {"llvm.loop.unswitch.partial"}, + {DisableUnswitchMD}); + L.setLoopID(NewLoopID); + } else if (InjectedCondition) { + // Do the same for injection of invariant conditions. 
+ auto &Context = L.getHeader()->getContext(); + MDNode *DisableUnswitchMD = MDNode::get( + Context, + MDString::get(Context, "llvm.loop.unswitch.injection.disable")); + MDNode *NewLoopID = makePostTransformationMetadata( + Context, L.getLoopID(), {"llvm.loop.unswitch.injection"}, + {DisableUnswitchMD}); + L.setLoopID(NewLoopID); + } else + U.revisitCurrentLoop(); + } else + U.markLoopAsDeleted(L, LoopName); +} + static void unswitchNontrivialInvariants( Loop &L, Instruction &TI, ArrayRef<Value *> Invariants, IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, - function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze) { + AssumptionCache &AC, ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + LPMUpdater &LoopUpdater, bool InsertFreeze, bool InjectedCondition) { auto *ParentBB = TI.getParent(); BranchInst *BI = dyn_cast<BranchInst>(&TI); SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI); + // Save the current loop name in a variable so that we can report it even + // after it has been deleted. + std::string LoopName(L.getName()); + // We can only unswitch switches, conditional branches with an invariant // condition, or combining invariant conditions with an instruction or // partially invariant instructions. @@ -2295,7 +2334,7 @@ static void unswitchNontrivialInvariants( if (FullUnswitch) { // Splice the terminator from the original loop and rewrite its // successors. - SplitBB->splice(SplitBB->end(), ParentBB, TI.getIterator()); + TI.moveBefore(*SplitBB, SplitBB->end()); // Keep a clone of the terminator for MSSA updates. Instruction *NewTI = TI.clone(); @@ -2445,7 +2484,7 @@ static void unswitchNontrivialInvariants( // Now that our cloned loops have been built, we can update the original loop. 
// First we delete the dead blocks from it and then we rebuild the loop // structure taking these deletions into account. - deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB); + deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE, LoopUpdater); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -2581,7 +2620,8 @@ static void unswitchNontrivialInvariants( for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops)) if (UpdatedL->getParentLoop() == ParentL) SibLoops.push_back(UpdatedL); - UnswitchCB(IsStillLoop, PartiallyInvariant, SibLoops); + postUnswitch(L, LoopUpdater, LoopName, IsStillLoop, PartiallyInvariant, + InjectedCondition, SibLoops); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -2979,13 +3019,6 @@ static bool shouldTryInjectInvariantCondition( /// the metadata. bool shouldTryInjectBasingOnMetadata(const BranchInst *BI, const BasicBlock *TakenSucc) { - // Skip branches that have already been unswithed this way. After successful - // unswitching of injected condition, we will still have a copy of this loop - // which looks exactly the same as original one. To prevent the 2nd attempt - // of unswitching it in the same pass, mark this branch as "nothing to do - // here". 
- if (BI->hasMetadata("llvm.invariant.condition.injection.disabled")) - return false; SmallVector<uint32_t> Weights; if (!extractBranchWeights(*BI, Weights)) return false; @@ -3060,7 +3093,6 @@ injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L, auto *InjectedCond = ICmpInst::Create(Instruction::ICmp, Pred, LHS, RHS, "injected.cond", Preheader->getTerminator()); - auto *OldCond = TI->getCondition(); BasicBlock *CheckBlock = BasicBlock::Create(Ctx, BB->getName() + ".check", BB->getParent(), InLoopSucc); @@ -3069,12 +3101,9 @@ injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L, Builder.CreateCondBr(InjectedCond, InLoopSucc, CheckBlock); Builder.SetInsertPoint(CheckBlock); - auto *NewTerm = Builder.CreateCondBr(OldCond, InLoopSucc, OutOfLoopSucc); - + Builder.CreateCondBr(TI->getCondition(), TI->getSuccessor(0), + TI->getSuccessor(1)); TI->eraseFromParent(); - // Prevent infinite unswitching. - NewTerm->setMetadata("llvm.invariant.condition.injection.disabled", - MDNode::get(BB->getContext(), {})); // Fixup phis. for (auto &I : *InLoopSucc) { @@ -3439,12 +3468,11 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT, Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT); } -static bool unswitchBestCondition( - Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - AAResults &AA, TargetTransformInfo &TTI, - function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - function_ref<void(Loop &, StringRef)> DestroyLoopCB) { +static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, AAResults &AA, + TargetTransformInfo &TTI, ScalarEvolution *SE, + MemorySSAUpdater *MSSAU, + LPMUpdater &LoopUpdater) { // Collect all invariant conditions within this loop (as opposed to an inner // loop which would be handled when visiting that inner loop). 
SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates; @@ -3452,9 +3480,10 @@ static bool unswitchBestCondition( Instruction *PartialIVCondBranch = nullptr; collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo, PartialIVCondBranch, L, LI, AA, MSSAU); - collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo, - PartialIVCondBranch, L, DT, LI, AA, - MSSAU); + if (!findOptionMDForLoop(&L, "llvm.loop.unswitch.injection.disable")) + collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo, + PartialIVCondBranch, L, DT, LI, AA, + MSSAU); // If we didn't find any candidates, we're done. if (UnswitchCandidates.empty()) return false; @@ -3475,8 +3504,11 @@ static bool unswitchBestCondition( return false; } - if (Best.hasPendingInjection()) + bool InjectedCondition = false; + if (Best.hasPendingInjection()) { Best = injectPendingInvariantConditions(Best, L, DT, LI, AC, MSSAU); + InjectedCondition = true; + } assert(!Best.hasPendingInjection() && "All injections should have been done by now!"); @@ -3503,8 +3535,8 @@ static bool unswitchBestCondition( LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << Best.Cost << ") terminator: " << *Best.TI << "\n"); unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT, - LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB, - InsertFreeze); + LI, AC, SE, MSSAU, LoopUpdater, InsertFreeze, + InjectedCondition); return true; } @@ -3523,20 +3555,18 @@ static bool unswitchBestCondition( /// true, we will attempt to do non-trivial unswitching as well as trivial /// unswitching. /// -/// The `UnswitchCB` callback provided will be run after unswitching is -/// complete, with the first parameter set to `true` if the provided loop -/// remains a loop, and a list of new sibling loops created. 
+/// The `postUnswitch` function will be run after unswitching is complete +/// with information on whether or not the provided loop remains a loop and +/// a list of new sibling loops created. /// /// If `SE` is non-null, we will update that analysis based on the unswitching /// done. -static bool -unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - AAResults &AA, TargetTransformInfo &TTI, bool Trivial, - bool NonTrivial, - function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, - function_ref<void(Loop &, StringRef)> DestroyLoopCB) { +static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, AAResults &AA, + TargetTransformInfo &TTI, bool Trivial, + bool NonTrivial, ScalarEvolution *SE, + MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); @@ -3548,7 +3578,9 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. - UnswitchCB(/*CurrentLoopValid*/ true, false, {}); + postUnswitch(L, LoopUpdater, L.getName(), + /*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false, + /*InjectedCondition*/ false, {}); return true; } @@ -3617,8 +3649,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, // Try to unswitch the best invariant condition. We prefer this full unswitch to // a partial unswitch when possible below the threshold. 
- if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU, - DestroyLoopCB)) + if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, SE, MSSAU, LoopUpdater)) return true; // No other opportunities to unswitch. @@ -3638,41 +3669,6 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); - // Save the current loop name in a variable so that we can report it even - // after it has been deleted. - std::string LoopName = std::string(L.getName()); - - auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, - bool PartiallyInvariant, - ArrayRef<Loop *> NewLoops) { - // If we did a non-trivial unswitch, we have added new (cloned) loops. - if (!NewLoops.empty()) - U.addSiblingLoops(NewLoops); - - // If the current loop remains valid, we should revisit it to catch any - // other unswitch opportunities. Otherwise, we need to mark it as deleted. - if (CurrentLoopValid) { - if (PartiallyInvariant) { - // Mark the new loop as partially unswitched, to avoid unswitching on - // the same condition again. - auto &Context = L.getHeader()->getContext(); - MDNode *DisableUnswitchMD = MDNode::get( - Context, - MDString::get(Context, "llvm.loop.unswitch.partial.disable")); - MDNode *NewLoopID = makePostTransformationMetadata( - Context, L.getLoopID(), {"llvm.loop.unswitch.partial"}, - {DisableUnswitchMD}); - L.setLoopID(NewLoopID); - } else - U.revisitCurrentLoop(); - } else - U.markLoopAsDeleted(L, LoopName); - }; - - auto DestroyLoopCB = [&U](Loop &L, StringRef Name) { - U.markLoopAsDeleted(L, Name); - }; - std::optional<MemorySSAUpdater> MSSAU; if (AR.MSSA) { MSSAU = MemorySSAUpdater(AR.MSSA); @@ -3680,8 +3676,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, AR.MSSA->verifyMemorySSA(); } if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial, - UnswitchCB, &AR.SE, MSSAU ? 
&*MSSAU : nullptr, PSI, AR.BFI, - DestroyLoopCB)) + &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, U)) return PreservedAnalyses::all(); if (AR.MSSA && VerifyMemorySSA) @@ -3707,104 +3702,3 @@ void SimpleLoopUnswitchPass::printPipeline( OS << (Trivial ? "" : "no-") << "trivial"; OS << '>'; } - -namespace { - -class SimpleLoopUnswitchLegacyPass : public LoopPass { - bool NonTrivial; - -public: - static char ID; // Pass ID, replacement for typeid - - explicit SimpleLoopUnswitchLegacyPass(bool NonTrivial = false) - : LoopPass(ID), NonTrivial(NonTrivial) { - initializeSimpleLoopUnswitchLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - getLoopAnalysisUsage(AU); - } -}; - -} // end anonymous namespace - -bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipLoop(L)) - return false; - - Function &F = *L->getHeader()->getParent(); - - LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L - << "\n"); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MemorySSAUpdater MSSAU(MSSA); - - auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); - auto *SE = SEWP ? 
&SEWP->getSE() : nullptr; - - auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, bool PartiallyInvariant, - ArrayRef<Loop *> NewLoops) { - // If we did a non-trivial unswitch, we have added new (cloned) loops. - for (auto *NewL : NewLoops) - LPM.addLoop(*NewL); - - // If the current loop remains valid, re-add it to the queue. This is - // a little wasteful as we'll finish processing the current loop as well, - // but it is the best we can do in the old PM. - if (CurrentLoopValid) { - // If the current loop has been unswitched using a partially invariant - // condition, we should not re-add the current loop to avoid unswitching - // on the same condition again. - if (!PartiallyInvariant) - LPM.addLoop(*L); - } else - LPM.markLoopAsDeleted(*L); - }; - - auto DestroyLoopCB = [&LPM](Loop &L, StringRef /* Name */) { - LPM.markLoopAsDeleted(L); - }; - - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - bool Changed = - unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial, UnswitchCB, SE, - &MSSAU, nullptr, nullptr, DestroyLoopCB); - - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - - // Historically this pass has had issues with the dominator tree so verify it - // in asserts builds. 
- assert(DT.verify(DominatorTree::VerificationLevel::Fast)); - - return Changed; -} - -char SimpleLoopUnswitchLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", - "Simple unswitch loops", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", - "Simple unswitch loops", false, false) - -Pass *llvm::createSimpleLoopUnswitchLegacyPass(bool NonTrivial) { - return new SimpleLoopUnswitchLegacyPass(NonTrivial); -} diff --git a/llvm/lib/Transforms/Scalar/Sink.cpp b/llvm/lib/Transforms/Scalar/Sink.cpp index 8b99f73b850b..46bcfd6b41ce 100644 --- a/llvm/lib/Transforms/Scalar/Sink.cpp +++ b/llvm/lib/Transforms/Scalar/Sink.cpp @@ -67,9 +67,8 @@ static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo, assert(Inst && "Instruction to be sunk is null"); assert(SuccToSinkTo && "Candidate sink target is null"); - // It's never legal to sink an instruction into a block which terminates in an - // EH-pad. - if (SuccToSinkTo->getTerminator()->isExceptionalTerminator()) + // It's never legal to sink an instruction into an EH-pad block. + if (SuccToSinkTo->isEHPad()) return false; // If the block has multiple predecessors, this would introduce computation @@ -131,15 +130,16 @@ static bool SinkInstruction(Instruction *Inst, for (Use &U : Inst->uses()) { Instruction *UseInst = cast<Instruction>(U.getUser()); BasicBlock *UseBlock = UseInst->getParent(); - // Don't worry about dead users. - if (!DT.isReachableFromEntry(UseBlock)) - continue; if (PHINode *PN = dyn_cast<PHINode>(UseInst)) { // PHI nodes use the operand in the predecessor block, not the block with // the PHI. 
unsigned Num = PHINode::getIncomingValueNumForOperand(U.getOperandNo()); UseBlock = PN->getIncomingBlock(Num); } + // Don't worry about dead users. + if (!DT.isReachableFromEntry(UseBlock)) + continue; + if (SuccToSinkTo) SuccToSinkTo = DT.findNearestCommonDominator(SuccToSinkTo, UseBlock); else diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index e866fe681127..7a5318d4404c 100644 --- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -316,7 +316,7 @@ bool SpeculativeExecutionPass::considerHoistingFromTo( auto Current = I; ++I; if (!NotHoisted.count(&*Current)) { - Current->moveBefore(ToBlock.getTerminator()); + Current->moveBeforePreserving(ToBlock.getTerminator()); } } return true; @@ -346,4 +346,14 @@ PreservedAnalyses SpeculativeExecutionPass::run(Function &F, PA.preserveSet<CFGAnalyses>(); return PA; } + +void SpeculativeExecutionPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<SpeculativeExecutionPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << '<'; + if (OnlyIfDivergentTarget) + OS << "only-if-divergent-target"; + OS << '>'; +} } // namespace llvm diff --git a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index fdb41cb415df..543469d62fe7 100644 --- a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -680,7 +680,7 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( if (BumpWithUglyGEP) { // C = (char *)Basis + Bump unsigned AS = Basis.Ins->getType()->getPointerAddressSpace(); - Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS); + Type *CharTy = PointerType::get(Basis.Ins->getContext(), AS); Reduced = Builder.CreateBitCast(Basis.Ins, CharTy); Reduced = 
Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump, "", InBounds); diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index fac5695c7bea..7d96a3478858 100644 --- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -42,6 +42,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> @@ -353,7 +354,6 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { if (SkipUniformRegions) AU.addRequired<UniformityInfoWrapperPass>(); - AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); @@ -368,7 +368,6 @@ char StructurizeCFGLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(StructurizeCFGLegacyPass, "structurizecfg", "Structurize the CFG", false, false) INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) INITIALIZE_PASS_END(StructurizeCFGLegacyPass, "structurizecfg", @@ -1173,6 +1172,8 @@ bool StructurizeCFG::run(Region *R, DominatorTree *DT) { this->DT = DT; Func = R->getEntry()->getParent(); + assert(hasOnlySimpleTerminator(*Func) && "Unsupported block terminator."); + ParentRegion = R; orderNodes(); diff --git a/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp b/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp index 4ec7181ad859..58ea5b68d548 100644 --- a/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp +++ b/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp @@ -32,7 +32,6 @@ #include <cassert> #include <cstdint> #include <iterator> -#include <tuple> #include <utility> using namespace llvm; diff --git 
a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 4f1350e4ebb9..c6e8505d5ab4 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -369,8 +369,14 @@ static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { if (!I->isAssociative() || !I->isCommutative()) return false; - assert(I->getNumOperands() == 2 && - "Associative/commutative operations should have 2 args!"); + assert(I->getNumOperands() >= 2 && + "Associative/commutative operations should have at least 2 args!"); + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + // Accumulators must have an identity. + if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(), I->getType())) + return false; + } // Exactly one operand should be the result of the call instruction. if ((I->getOperand(0) == CI && I->getOperand(1) == CI) || @@ -518,10 +524,10 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { // block, insert a PHI node for each argument of the function. // For now, we initialize each PHI to only have the real arguments // which are passed in. - Instruction *InsertPos = &HeaderBB->front(); + BasicBlock::iterator InsertPos = HeaderBB->begin(); for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { - PHINode *PN = - PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos); + PHINode *PN = PHINode::Create(I->getType(), 2, I->getName() + ".tr"); + PN->insertBefore(InsertPos); I->replaceAllUsesWith(PN); // Everyone use the PHI node now! 
PN->addIncoming(&*I, NewEntry); ArgumentPHIs.push_back(PN); @@ -534,8 +540,10 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { Type *RetType = F.getReturnType(); if (!RetType->isVoidTy()) { Type *BoolType = Type::getInt1Ty(F.getContext()); - RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos); - RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos); + RetPN = PHINode::Create(RetType, 2, "ret.tr"); + RetPN->insertBefore(InsertPos); + RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr"); + RetKnownPN->insertBefore(InsertPos); RetPN->addIncoming(PoisonValue::get(RetType), NewEntry); RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry); @@ -555,7 +563,8 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) { // Start by inserting a new PHI node for the accumulator. pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB); AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1, - "accumulator.tr", &HeaderBB->front()); + "accumulator.tr"); + AccPN->insertBefore(HeaderBB->begin()); // Loop over all of the predecessors of the tail recursion block. 
For the // real entry into the function we seed the PHI with the identity constant for @@ -566,8 +575,8 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) { for (pred_iterator PI = PB; PI != PE; ++PI) { BasicBlock *P = *PI; if (P == &F.getEntryBlock()) { - Constant *Identity = ConstantExpr::getBinOpIdentity( - AccRecInstr->getOpcode(), AccRecInstr->getType()); + Constant *Identity = + ConstantExpr::getIdentity(AccRecInstr, AccRecInstr->getType()); AccPN->addIncoming(Identity, P); } else { AccPN->addIncoming(AccPN, P); @@ -675,6 +684,12 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) { for (unsigned I = 0, E = CI->arg_size(); I != E; ++I) { if (CI->isByValArgument(I)) { copyLocalTempOfByValueOperandIntoArguments(CI, I); + // When eliminating a tail call, we modify the values of the arguments. + // Therefore, if the byval parameter has a readonly attribute, we have to + // remove it. It is safe because, from the perspective of a caller, the + // byval parameter is always treated as "readonly," even if the readonly + // attribute is removed. 
+ F.removeParamAttr(I, Attribute::ReadOnly); ArgumentPHIs[I]->addIncoming(F.getArg(I), BB); } else ArgumentPHIs[I]->addIncoming(CI->getArgOperand(I), BB); diff --git a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp index 2195406c144c..6ca737df49b9 100644 --- a/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp +++ b/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp @@ -153,19 +153,17 @@ static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) { static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str, Value *Length, bool isLast) { auto Int64Ty = Builder.getInt64Ty(); - auto CharPtrTy = Builder.getInt8PtrTy(); + auto PtrTy = Builder.getPtrTy(); auto Int32Ty = Builder.getInt32Ty(); auto M = Builder.GetInsertBlock()->getModule(); auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty, - Int64Ty, CharPtrTy, Int64Ty, Int32Ty); + Int64Ty, PtrTy, Int64Ty, Int32Ty); auto IsLastInt32 = Builder.getInt32(isLast); return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32}); } static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg, bool IsLast) { - Arg = Builder.CreateBitCast( - Arg, Builder.getInt8PtrTy(Arg->getType()->getPointerAddressSpace())); auto Length = getStrlenWithNull(Builder, Arg); return callAppendStringN(Builder, Desc, Arg, Length, IsLast); } @@ -299,9 +297,9 @@ static Value *callBufferedPrintfStart( Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind); Type *Tys_alloc[1] = {Builder.getInt32Ty()}; - Type *I8Ptr = - Builder.getInt8PtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace()); - FunctionType *FTy_alloc = FunctionType::get(I8Ptr, Tys_alloc, false); + Type *PtrTy = + Builder.getPtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace()); + FunctionType *FTy_alloc = FunctionType::get(PtrTy, Tys_alloc, false); auto PrintfAllocFn = M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr); diff --git 
a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp index 7d127400651e..f95d5e23c9c8 100644 --- a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp @@ -63,13 +63,10 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" #include <utility> diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp index 45cf98e65a5a..efa8e874b955 100644 --- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp +++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp @@ -19,7 +19,6 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Utils/Local.h" @@ -587,37 +586,3 @@ PreservedAnalyses AssumeBuilderPass::run(Function &F, PA.preserveSet<CFGAnalyses>(); return PA; } - -namespace { -class AssumeBuilderPassLegacyPass : public FunctionPass { -public: - static char ID; - - AssumeBuilderPassLegacyPass() : FunctionPass(ID) { - initializeAssumeBuilderPassLegacyPassPass(*PassRegistry::getPassRegistry()); - } - bool runOnFunction(Function &F) override { - AssumptionCache &AC = - getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - DominatorTreeWrapperPass *DTWP = - getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - for (Instruction &I : instructions(F)) - salvageKnowledge(&I, &AC, DTWP ? 
&DTWP->getDomTree() : nullptr); - return true; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - - AU.setPreservesAll(); - } -}; -} // namespace - -char AssumeBuilderPassLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(AssumeBuilderPassLegacyPass, "assume-builder", - "Assume Builder", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(AssumeBuilderPassLegacyPass, "assume-builder", - "Assume Builder", false, false) diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index f06ea89cc61d..b700edf8ea6c 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -194,7 +194,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, // Don't break unwinding instructions or terminators with other side-effects. Instruction *PTI = PredBB->getTerminator(); - if (PTI->isExceptionalTerminator() || PTI->mayHaveSideEffects()) + if (PTI->isSpecialTerminator() || PTI->mayHaveSideEffects()) return false; // Can't merge if there are multiple distinct successors. @@ -300,7 +300,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, PredBB->back().eraseFromParent(); // Move terminator instruction. - PredBB->splice(PredBB->end(), BB); + BB->back().moveBeforePreserving(*PredBB, PredBB->end()); // Terminator may be a memory accessing instruction too. if (MSSAU) @@ -382,7 +382,39 @@ bool llvm::MergeBlockSuccessorsIntoGivenBlocks( /// - Check fully overlapping fragments and not only identical fragments. /// - Support dbg.declare. dbg.label, and possibly other meta instructions being /// part of the sequence of consecutive instructions. 
+static bool DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { + SmallVector<DPValue *, 8> ToBeRemoved; + SmallDenseSet<DebugVariable> VariableSet; + for (auto &I : reverse(*BB)) { + for (DPValue &DPV : reverse(I.getDbgValueRange())) { + DebugVariable Key(DPV.getVariable(), DPV.getExpression(), + DPV.getDebugLoc()->getInlinedAt()); + auto R = VariableSet.insert(Key); + // If the same variable fragment is described more than once it is enough + // to keep the last one (i.e. the first found since we for reverse + // iteration). + // FIXME: add assignment tracking support (see parallel implementation + // below). + if (!R.second) + ToBeRemoved.push_back(&DPV); + continue; + } + // Sequence with consecutive dbg.value instrs ended. Clear the map to + // restart identifying redundant instructions if case we find another + // dbg.value sequence. + VariableSet.clear(); + } + + for (auto &DPV : ToBeRemoved) + DPV->eraseFromParent(); + + return !ToBeRemoved.empty(); +} + static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { + if (BB->IsNewDbgInfoFormat) + return DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BB); + SmallVector<DbgValueInst *, 8> ToBeRemoved; SmallDenseSet<DebugVariable> VariableSet; for (auto &I : reverse(*BB)) { @@ -440,7 +472,38 @@ static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { /// /// Possible improvements: /// - Keep track of non-overlapping fragments. +static bool DPValuesRemoveRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { + SmallVector<DPValue *, 8> ToBeRemoved; + DenseMap<DebugVariable, std::pair<SmallVector<Value *, 4>, DIExpression *>> + VariableMap; + for (auto &I : *BB) { + for (DPValue &DPV : I.getDbgValueRange()) { + DebugVariable Key(DPV.getVariable(), std::nullopt, + DPV.getDebugLoc()->getInlinedAt()); + auto VMI = VariableMap.find(Key); + // Update the map if we found a new value/expression describing the + // variable, or if the variable wasn't mapped already. 
+ SmallVector<Value *, 4> Values(DPV.location_ops()); + if (VMI == VariableMap.end() || VMI->second.first != Values || + VMI->second.second != DPV.getExpression()) { + VariableMap[Key] = {Values, DPV.getExpression()}; + continue; + } + // Found an identical mapping. Remember the instruction for later removal. + ToBeRemoved.push_back(&DPV); + } + } + + for (auto *DPV : ToBeRemoved) + DPV->eraseFromParent(); + + return !ToBeRemoved.empty(); +} + static bool removeRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { + if (BB->IsNewDbgInfoFormat) + return DPValuesRemoveRedundantDbgInstrsUsingForwardScan(BB); + SmallVector<DbgValueInst *, 8> ToBeRemoved; DenseMap<DebugVariable, std::pair<SmallVector<Value *, 4>, DIExpression *>> VariableMap; @@ -852,9 +915,11 @@ void llvm::createPHIsForSplitLoopExit(ArrayRef<BasicBlock *> Preds, continue; // Otherwise a new PHI is needed. Create one and populate it. - PHINode *NewPN = PHINode::Create( - PN.getType(), Preds.size(), "split", - SplitBB->isLandingPad() ? &SplitBB->front() : SplitBB->getTerminator()); + PHINode *NewPN = PHINode::Create(PN.getType(), Preds.size(), "split"); + BasicBlock::iterator InsertPos = + SplitBB->isLandingPad() ? SplitBB->begin() + : SplitBB->getTerminator()->getIterator(); + NewPN->insertBefore(InsertPos); for (BasicBlock *BB : Preds) NewPN->addIncoming(V, BB); @@ -877,7 +942,7 @@ llvm::SplitAllCriticalEdges(Function &F, return NumBroken; } -static BasicBlock *SplitBlockImpl(BasicBlock *Old, Instruction *SplitPt, +static BasicBlock *SplitBlockImpl(BasicBlock *Old, BasicBlock::iterator SplitPt, DomTreeUpdater *DTU, DominatorTree *DT, LoopInfo *LI, MemorySSAUpdater *MSSAU, const Twine &BBName, bool Before) { @@ -887,7 +952,7 @@ static BasicBlock *SplitBlockImpl(BasicBlock *Old, Instruction *SplitPt, DTU ? DTU : (DT ? 
&LocalDTU : nullptr), LI, MSSAU, BBName); } - BasicBlock::iterator SplitIt = SplitPt->getIterator(); + BasicBlock::iterator SplitIt = SplitPt; while (isa<PHINode>(SplitIt) || SplitIt->isEHPad()) { ++SplitIt; assert(SplitIt != SplitPt->getParent()->end()); @@ -933,14 +998,14 @@ static BasicBlock *SplitBlockImpl(BasicBlock *Old, Instruction *SplitPt, return New; } -BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, +BasicBlock *llvm::SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI, MemorySSAUpdater *MSSAU, const Twine &BBName, bool Before) { return SplitBlockImpl(Old, SplitPt, /*DTU=*/nullptr, DT, LI, MSSAU, BBName, Before); } -BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, +BasicBlock *llvm::SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DomTreeUpdater *DTU, LoopInfo *LI, MemorySSAUpdater *MSSAU, const Twine &BBName, bool Before) { @@ -948,12 +1013,12 @@ BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, Before); } -BasicBlock *llvm::splitBlockBefore(BasicBlock *Old, Instruction *SplitPt, +BasicBlock *llvm::splitBlockBefore(BasicBlock *Old, BasicBlock::iterator SplitPt, DomTreeUpdater *DTU, LoopInfo *LI, MemorySSAUpdater *MSSAU, const Twine &BBName) { - BasicBlock::iterator SplitIt = SplitPt->getIterator(); + BasicBlock::iterator SplitIt = SplitPt; while (isa<PHINode>(SplitIt) || SplitIt->isEHPad()) ++SplitIt; std::string Name = BBName.str(); @@ -1137,14 +1202,11 @@ static void UpdatePHINodes(BasicBlock *OrigBB, BasicBlock *NewBB, // If all incoming values for the new PHI would be the same, just don't // make a new PHI. Instead, just remove the incoming values from the old // PHI. - - // NOTE! This loop walks backwards for a reason! First off, this minimizes - // the cost of removal if we end up removing a large number of values, and - // second off, this ensures that the indices for the incoming values - // aren't invalidated when we remove one. 
- for (int64_t i = PN->getNumIncomingValues() - 1; i >= 0; --i) - if (PredSet.count(PN->getIncomingBlock(i))) - PN->removeIncomingValue(i, false); + PN->removeIncomingValueIf( + [&](unsigned Idx) { + return PredSet.contains(PN->getIncomingBlock(Idx)); + }, + /* DeletePHIIfEmpty */ false); // Add an incoming value to the PHI node in the loop for the preheader // edge. @@ -1394,17 +1456,6 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, ArrayRef<BasicBlock *> Preds, const char *Suffix1, const char *Suffix2, SmallVectorImpl<BasicBlock *> &NewBBs, - DominatorTree *DT, LoopInfo *LI, - MemorySSAUpdater *MSSAU, - bool PreserveLCSSA) { - return SplitLandingPadPredecessorsImpl( - OrigBB, Preds, Suffix1, Suffix2, NewBBs, - /*DTU=*/nullptr, DT, LI, MSSAU, PreserveLCSSA); -} -void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, - ArrayRef<BasicBlock *> Preds, - const char *Suffix1, const char *Suffix2, - SmallVectorImpl<BasicBlock *> &NewBBs, DomTreeUpdater *DTU, LoopInfo *LI, MemorySSAUpdater *MSSAU, bool PreserveLCSSA) { @@ -1472,7 +1523,7 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, } Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, - Instruction *SplitBefore, + BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI, @@ -1485,7 +1536,7 @@ Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, } Instruction *llvm::SplitBlockAndInsertIfElse(Value *Cond, - Instruction *SplitBefore, + BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI, @@ -1497,7 +1548,7 @@ Instruction *llvm::SplitBlockAndInsertIfElse(Value *Cond, return ElseBlock->getTerminator(); } -void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, +void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, BasicBlock::iterator SplitBefore, Instruction **ThenTerm, Instruction **ElseTerm, MDNode *BranchWeights, @@ -1513,7 
+1564,7 @@ void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, } void llvm::SplitBlockAndInsertIfThenElse( - Value *Cond, Instruction *SplitBefore, BasicBlock **ThenBlock, + Value *Cond, BasicBlock::iterator SplitBefore, BasicBlock **ThenBlock, BasicBlock **ElseBlock, bool UnreachableThen, bool UnreachableElse, MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI) { assert((ThenBlock || ElseBlock) && @@ -1530,7 +1581,7 @@ void llvm::SplitBlockAndInsertIfThenElse( } LLVMContext &C = Head->getContext(); - BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); + BasicBlock *Tail = Head->splitBasicBlock(SplitBefore); BasicBlock *TrueBlock = Tail; BasicBlock *FalseBlock = Tail; bool ThenToTailEdge = false; @@ -2077,3 +2128,25 @@ void llvm::InvertBranch(BranchInst *PBI, IRBuilderBase &Builder) { PBI->setCondition(NewCond); PBI->swapSuccessors(); } + +bool llvm::hasOnlySimpleTerminator(const Function &F) { + for (auto &BB : F) { + auto *Term = BB.getTerminator(); + if (!(isa<ReturnInst>(Term) || isa<UnreachableInst>(Term) || + isa<BranchInst>(Term))) + return false; + } + return true; +} + +bool llvm::isPresplitCoroSuspendExitEdge(const BasicBlock &Src, + const BasicBlock &Dest) { + assert(Src.getParent() == Dest.getParent()); + if (!Src.getParent()->isPresplitCoroutine()) + return false; + if (auto *SW = dyn_cast<SwitchInst>(Src.getTerminator())) + if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition())) + return Intr->getIntrinsicID() == Intrinsic::coro_suspend && + SW->getDefaultDest() == &Dest; + return false; +} diff --git a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp index ddb35756030f..5fb796cc3db6 100644 --- a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -387,7 +387,7 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, if (ShouldUpdateAnalysis) { // Copy the BFI/BPI from Target to BodyBlock. 
BPI->setEdgeProbability(BodyBlock, EdgeProbabilities); - BFI->setBlockFreq(BodyBlock, BFI->getBlockFreq(Target).getFrequency()); + BFI->setBlockFreq(BodyBlock, BFI->getBlockFreq(Target)); } // It's possible Target was its own successor through an indirectbr. // In this case, the indirectbr now comes from BodyBlock. @@ -411,10 +411,10 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, BPI->getEdgeProbability(Src, DirectSucc); } if (ShouldUpdateAnalysis) { - BFI->setBlockFreq(DirectSucc, BlockFreqForDirectSucc.getFrequency()); + BFI->setBlockFreq(DirectSucc, BlockFreqForDirectSucc); BlockFrequency NewBlockFreqForTarget = BFI->getBlockFreq(Target) - BlockFreqForDirectSucc; - BFI->setBlockFreq(Target, NewBlockFreqForTarget.getFrequency()); + BFI->setBlockFreq(Target, NewBlockFreqForTarget); } // Ok, now fix up the PHIs. We know the two blocks only have PHIs, and that @@ -449,8 +449,8 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, // Create a PHI in the body block, to merge the direct and indirect // predecessors. 
- PHINode *MergePHI = - PHINode::Create(IndPHI->getType(), 2, "merge", &*MergeInsert); + PHINode *MergePHI = PHINode::Create(IndPHI->getType(), 2, "merge"); + MergePHI->insertBefore(MergeInsert); MergePHI->addIncoming(NewIndPHI, Target); MergePHI->addIncoming(DirPHI, DirectSucc); diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index 5de8ff84de77..12741dc5af5a 100644 --- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -1425,11 +1425,6 @@ StringRef llvm::getFloatFn(const Module *M, const TargetLibraryInfo *TLI, //- Emit LibCalls ------------------------------------------------------------// -Value *llvm::castToCStr(Value *V, IRBuilderBase &B) { - unsigned AS = V->getType()->getPointerAddressSpace(); - return B.CreateBitCast(V, B.getInt8PtrTy(AS), "cstr"); -} - static IntegerType *getIntTy(IRBuilderBase &B, const TargetLibraryInfo *TLI) { return B.getIntNTy(TLI->getIntSize()); } @@ -1461,63 +1456,64 @@ static Value *emitLibCall(LibFunc TheLibFunc, Type *ReturnType, Value *llvm::emitStrLen(Value *Ptr, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { + Type *CharPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_strlen, SizeTTy, - B.getInt8PtrTy(), castToCStr(Ptr, B), B, TLI); + return emitLibCall(LibFunc_strlen, SizeTTy, CharPtrTy, Ptr, B, TLI); } Value *llvm::emitStrDup(Value *Ptr, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - return emitLibCall(LibFunc_strdup, B.getInt8PtrTy(), B.getInt8PtrTy(), - castToCStr(Ptr, B), B, TLI); + Type *CharPtrTy = B.getPtrTy(); + return emitLibCall(LibFunc_strdup, CharPtrTy, CharPtrTy, Ptr, B, TLI); } Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); - return emitLibCall(LibFunc_strchr, I8Ptr, {I8Ptr, IntTy}, - 
{castToCStr(Ptr, B), ConstantInt::get(IntTy, C)}, B, TLI); + return emitLibCall(LibFunc_strchr, CharPtrTy, {CharPtrTy, IntTy}, + {Ptr, ConstantInt::get(IntTy, C)}, B, TLI); } Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { + Type *CharPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall( LibFunc_strncmp, IntTy, - {B.getInt8PtrTy(), B.getInt8PtrTy(), SizeTTy}, - {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); + {CharPtrTy, CharPtrTy, SizeTTy}, + {Ptr1, Ptr2, Len}, B, TLI); } Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = Dst->getType(); - return emitLibCall(LibFunc_strcpy, I8Ptr, {I8Ptr, I8Ptr}, - {castToCStr(Dst, B), castToCStr(Src, B)}, B, TLI); + Type *CharPtrTy = Dst->getType(); + return emitLibCall(LibFunc_strcpy, CharPtrTy, {CharPtrTy, CharPtrTy}, + {Dst, Src}, B, TLI); } Value *llvm::emitStpCpy(Value *Dst, Value *Src, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); - return emitLibCall(LibFunc_stpcpy, I8Ptr, {I8Ptr, I8Ptr}, - {castToCStr(Dst, B), castToCStr(Src, B)}, B, TLI); + Type *CharPtrTy = B.getPtrTy(); + return emitLibCall(LibFunc_stpcpy, CharPtrTy, {CharPtrTy, CharPtrTy}, + {Dst, Src}, B, TLI); } Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_strncpy, I8Ptr, {I8Ptr, I8Ptr, SizeTTy}, - {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI); + return emitLibCall(LibFunc_strncpy, CharPtrTy, {CharPtrTy, CharPtrTy, SizeTTy}, + {Dst, Src, Len}, B, TLI); } Value *llvm::emitStpNCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type 
*CharPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_stpncpy, I8Ptr, {I8Ptr, I8Ptr, SizeTTy}, - {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI); + return emitLibCall(LibFunc_stpncpy, CharPtrTy, {CharPtrTy, CharPtrTy, SizeTTy}, + {Dst, Src, Len}, B, TLI); } Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, @@ -1530,13 +1526,11 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, AttributeList AS; AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind); - Type *I8Ptr = B.getInt8PtrTy(); + Type *VoidPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); FunctionCallee MemCpy = getOrInsertLibFunc(M, *TLI, LibFunc_memcpy_chk, - AttributeList::get(M->getContext(), AS), I8Ptr, - I8Ptr, I8Ptr, SizeTTy, SizeTTy); - Dst = castToCStr(Dst, B); - Src = castToCStr(Src, B); + AttributeList::get(M->getContext(), AS), VoidPtrTy, + VoidPtrTy, VoidPtrTy, SizeTTy, SizeTTy); CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize}); if (const Function *F = dyn_cast<Function>(MemCpy.getCallee()->stripPointerCasts())) @@ -1546,140 +1540,141 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, Value *llvm::emitMemPCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *VoidPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_mempcpy, I8Ptr, - {I8Ptr, I8Ptr, SizeTTy}, + return emitLibCall(LibFunc_mempcpy, VoidPtrTy, + {VoidPtrTy, VoidPtrTy, SizeTTy}, {Dst, Src, Len}, B, TLI); } Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *VoidPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_memchr, I8Ptr, - 
{I8Ptr, IntTy, SizeTTy}, - {castToCStr(Ptr, B), Val, Len}, B, TLI); + return emitLibCall(LibFunc_memchr, VoidPtrTy, + {VoidPtrTy, IntTy, SizeTTy}, + {Ptr, Val, Len}, B, TLI); } Value *llvm::emitMemRChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *VoidPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_memrchr, I8Ptr, - {I8Ptr, IntTy, SizeTTy}, - {castToCStr(Ptr, B), Val, Len}, B, TLI); + return emitLibCall(LibFunc_memrchr, VoidPtrTy, + {VoidPtrTy, IntTy, SizeTTy}, + {Ptr, Val, Len}, B, TLI); } Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *VoidPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall(LibFunc_memcmp, IntTy, - {I8Ptr, I8Ptr, SizeTTy}, - {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); + {VoidPtrTy, VoidPtrTy, SizeTTy}, + {Ptr1, Ptr2, Len}, B, TLI); } Value *llvm::emitBCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *VoidPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall(LibFunc_bcmp, IntTy, - {I8Ptr, I8Ptr, SizeTTy}, - {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); + {VoidPtrTy, VoidPtrTy, SizeTTy}, + {Ptr1, Ptr2, Len}, B, TLI); } Value *llvm::emitMemCCpy(Value *Ptr1, Value *Ptr2, Value *Val, Value *Len, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *VoidPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_memccpy, I8Ptr, - {I8Ptr, I8Ptr, IntTy, SizeTTy}, + return emitLibCall(LibFunc_memccpy, VoidPtrTy, + 
{VoidPtrTy, VoidPtrTy, IntTy, SizeTTy}, {Ptr1, Ptr2, Val, Len}, B, TLI); } Value *llvm::emitSNPrintf(Value *Dest, Value *Size, Value *Fmt, ArrayRef<Value *> VariadicArgs, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); - SmallVector<Value *, 8> Args{castToCStr(Dest, B), Size, castToCStr(Fmt, B)}; + SmallVector<Value *, 8> Args{Dest, Size, Fmt}; llvm::append_range(Args, VariadicArgs); return emitLibCall(LibFunc_snprintf, IntTy, - {I8Ptr, SizeTTy, I8Ptr}, + {CharPtrTy, SizeTTy, CharPtrTy}, Args, B, TLI, /*IsVaArgs=*/true); } Value *llvm::emitSPrintf(Value *Dest, Value *Fmt, ArrayRef<Value *> VariadicArgs, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); - SmallVector<Value *, 8> Args{castToCStr(Dest, B), castToCStr(Fmt, B)}; + SmallVector<Value *, 8> Args{Dest, Fmt}; llvm::append_range(Args, VariadicArgs); return emitLibCall(LibFunc_sprintf, IntTy, - {I8Ptr, I8Ptr}, Args, B, TLI, + {CharPtrTy, CharPtrTy}, Args, B, TLI, /*IsVaArgs=*/true); } Value *llvm::emitStrCat(Value *Dest, Value *Src, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - return emitLibCall(LibFunc_strcat, B.getInt8PtrTy(), - {B.getInt8PtrTy(), B.getInt8PtrTy()}, - {castToCStr(Dest, B), castToCStr(Src, B)}, B, TLI); + Type *CharPtrTy = B.getPtrTy(); + return emitLibCall(LibFunc_strcat, CharPtrTy, + {CharPtrTy, CharPtrTy}, + {Dest, Src}, B, TLI); } Value *llvm::emitStrLCpy(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall(LibFunc_strlcpy, SizeTTy, - {I8Ptr, I8Ptr, SizeTTy}, - {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); + {CharPtrTy, CharPtrTy, SizeTTy}, + {Dest, Src, Size}, B, TLI); } Value 
*llvm::emitStrLCat(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall(LibFunc_strlcat, SizeTTy, - {I8Ptr, I8Ptr, SizeTTy}, - {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); + {CharPtrTy, CharPtrTy, SizeTTy}, + {Dest, Src, Size}, B, TLI); } Value *llvm::emitStrNCat(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *SizeTTy = getSizeTTy(B, TLI); - return emitLibCall(LibFunc_strncat, I8Ptr, - {I8Ptr, I8Ptr, SizeTTy}, - {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); + return emitLibCall(LibFunc_strncat, CharPtrTy, + {CharPtrTy, CharPtrTy, SizeTTy}, + {Dest, Src, Size}, B, TLI); } Value *llvm::emitVSNPrintf(Value *Dest, Value *Size, Value *Fmt, Value *VAList, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall( LibFunc_vsnprintf, IntTy, - {I8Ptr, SizeTTy, I8Ptr, VAList->getType()}, - {castToCStr(Dest, B), Size, castToCStr(Fmt, B), VAList}, B, TLI); + {CharPtrTy, SizeTTy, CharPtrTy, VAList->getType()}, + {Dest, Size, Fmt, VAList}, B, TLI); } Value *llvm::emitVSPrintf(Value *Dest, Value *Fmt, Value *VAList, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *CharPtrTy = B.getPtrTy(); Type *IntTy = getIntTy(B, TLI); return emitLibCall(LibFunc_vsprintf, IntTy, - {I8Ptr, I8Ptr, VAList->getType()}, - {castToCStr(Dest, B), castToCStr(Fmt, B), VAList}, B, TLI); + {CharPtrTy, CharPtrTy, VAList->getType()}, + {Dest, Fmt, VAList}, B, TLI); } /// Append a suffix to the function name according to the type of 'Op'. 
@@ -1829,9 +1824,9 @@ Value *llvm::emitPutS(Value *Str, IRBuilderBase &B, Type *IntTy = getIntTy(B, TLI); StringRef PutsName = TLI->getName(LibFunc_puts); FunctionCallee PutS = getOrInsertLibFunc(M, *TLI, LibFunc_puts, IntTy, - B.getInt8PtrTy()); + B.getPtrTy()); inferNonMandatoryLibFuncAttrs(M, PutsName, *TLI); - CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), PutsName); + CallInst *CI = B.CreateCall(PutS, Str, PutsName); if (const Function *F = dyn_cast<Function>(PutS.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1867,10 +1862,10 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilderBase &B, Type *IntTy = getIntTy(B, TLI); StringRef FPutsName = TLI->getName(LibFunc_fputs); FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputs, IntTy, - B.getInt8PtrTy(), File->getType()); + B.getPtrTy(), File->getType()); if (File->getType()->isPointerTy()) inferNonMandatoryLibFuncAttrs(M, FPutsName, *TLI); - CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, FPutsName); + CallInst *CI = B.CreateCall(F, {Str, File}, FPutsName); if (const Function *Fn = dyn_cast<Function>(F.getCallee()->stripPointerCasts())) @@ -1887,13 +1882,13 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilderBase &B, Type *SizeTTy = getSizeTTy(B, TLI); StringRef FWriteName = TLI->getName(LibFunc_fwrite); FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fwrite, - SizeTTy, B.getInt8PtrTy(), SizeTTy, + SizeTTy, B.getPtrTy(), SizeTTy, SizeTTy, File->getType()); if (File->getType()->isPointerTy()) inferNonMandatoryLibFuncAttrs(M, FWriteName, *TLI); CallInst *CI = - B.CreateCall(F, {castToCStr(Ptr, B), Size, + B.CreateCall(F, {Ptr, Size, ConstantInt::get(SizeTTy, 1), File}); if (const Function *Fn = @@ -1911,7 +1906,7 @@ Value *llvm::emitMalloc(Value *Num, IRBuilderBase &B, const DataLayout &DL, StringRef MallocName = TLI->getName(LibFunc_malloc); Type *SizeTTy = getSizeTTy(B, TLI); FunctionCallee Malloc = getOrInsertLibFunc(M, 
*TLI, LibFunc_malloc, - B.getInt8PtrTy(), SizeTTy); + B.getPtrTy(), SizeTTy); inferNonMandatoryLibFuncAttrs(M, MallocName, *TLI); CallInst *CI = B.CreateCall(Malloc, Num, MallocName); @@ -1931,7 +1926,7 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, IRBuilderBase &B, StringRef CallocName = TLI.getName(LibFunc_calloc); Type *SizeTTy = getSizeTTy(B, &TLI); FunctionCallee Calloc = getOrInsertLibFunc(M, TLI, LibFunc_calloc, - B.getInt8PtrTy(), SizeTTy, SizeTTy); + B.getPtrTy(), SizeTTy, SizeTTy); inferNonMandatoryLibFuncAttrs(M, CallocName, TLI); CallInst *CI = B.CreateCall(Calloc, {Num, Size}, CallocName); @@ -1950,7 +1945,7 @@ Value *llvm::emitHotColdNew(Value *Num, IRBuilderBase &B, return nullptr; StringRef Name = TLI->getName(NewFunc); - FunctionCallee Func = M->getOrInsertFunction(Name, B.getInt8PtrTy(), + FunctionCallee Func = M->getOrInsertFunction(Name, B.getPtrTy(), Num->getType(), B.getInt8Ty()); inferNonMandatoryLibFuncAttrs(M, Name, *TLI); CallInst *CI = B.CreateCall(Func, {Num, B.getInt8(HotCold)}, Name); @@ -1971,7 +1966,7 @@ Value *llvm::emitHotColdNewNoThrow(Value *Num, Value *NoThrow, IRBuilderBase &B, StringRef Name = TLI->getName(NewFunc); FunctionCallee Func = - M->getOrInsertFunction(Name, B.getInt8PtrTy(), Num->getType(), + M->getOrInsertFunction(Name, B.getPtrTy(), Num->getType(), NoThrow->getType(), B.getInt8Ty()); inferNonMandatoryLibFuncAttrs(M, Name, *TLI); CallInst *CI = B.CreateCall(Func, {Num, NoThrow, B.getInt8(HotCold)}, Name); @@ -1992,7 +1987,7 @@ Value *llvm::emitHotColdNewAligned(Value *Num, Value *Align, IRBuilderBase &B, StringRef Name = TLI->getName(NewFunc); FunctionCallee Func = M->getOrInsertFunction( - Name, B.getInt8PtrTy(), Num->getType(), Align->getType(), B.getInt8Ty()); + Name, B.getPtrTy(), Num->getType(), Align->getType(), B.getInt8Ty()); inferNonMandatoryLibFuncAttrs(M, Name, *TLI); CallInst *CI = B.CreateCall(Func, {Num, Align, B.getInt8(HotCold)}, Name); @@ -2013,7 +2008,7 @@ Value 
*llvm::emitHotColdNewAlignedNoThrow(Value *Num, Value *Align, StringRef Name = TLI->getName(NewFunc); FunctionCallee Func = M->getOrInsertFunction( - Name, B.getInt8PtrTy(), Num->getType(), Align->getType(), + Name, B.getPtrTy(), Num->getType(), Align->getType(), NoThrow->getType(), B.getInt8Ty()); inferNonMandatoryLibFuncAttrs(M, Name, *TLI); CallInst *CI = diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index b488e3bb0cbd..e42cdab64446 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -111,7 +111,7 @@ static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst, if (OrigInst->getType()->isVoidTy() || OrigInst->use_empty()) return; - Builder.SetInsertPoint(&MergeBlock->front()); + Builder.SetInsertPoint(MergeBlock, MergeBlock->begin()); PHINode *Phi = Builder.CreatePHI(OrigInst->getType(), 0); SmallVector<User *, 16> UsersToUpdate(OrigInst->users()); for (User *U : UsersToUpdate) diff --git a/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp b/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp index a1ee3df907ec..fb4d82885377 100644 --- a/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp +++ b/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp @@ -30,6 +30,7 @@ #include "llvm/Transforms/Utils/CanonicalizeFreezeInLoops.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/IVDescriptors.h" #include "llvm/Analysis/LoopAnalysisManager.h" diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp index d55208602b71..c0f333364fa5 100644 --- a/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -44,6 +44,7 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, ClonedCodeInfo *CodeInfo, DebugInfoFinder *DIFinder) { 
BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "", F); + NewBB->IsNewDbgInfoFormat = BB->IsNewDbgInfoFormat; if (BB->hasName()) NewBB->setName(BB->getName() + NameSuffix); @@ -58,7 +59,10 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, Instruction *NewInst = I.clone(); if (I.hasName()) NewInst->setName(I.getName() + NameSuffix); - NewInst->insertInto(NewBB, NewBB->end()); + + NewInst->insertBefore(*NewBB, NewBB->end()); + NewInst->cloneDebugInfoFrom(&I); + VMap[&I] = NewInst; // Add instruction map to value. if (isa<CallInst>(I) && !I.isDebugOrPseudoInst()) { @@ -90,6 +94,7 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, const char *NameSuffix, ClonedCodeInfo *CodeInfo, ValueMapTypeRemapper *TypeMapper, ValueMaterializer *Materializer) { + NewFunc->setIsNewDbgInfoFormat(OldFunc->IsNewDbgInfoFormat); assert(NameSuffix && "NameSuffix cannot be null!"); #ifndef NDEBUG @@ -267,9 +272,13 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, BB = cast<BasicBlock>(VMap[&OldFunc->front()])->getIterator(), BE = NewFunc->end(); BB != BE; ++BB) - // Loop over all instructions, fixing each one as we find it... - for (Instruction &II : *BB) + // Loop over all instructions, fixing each one as we find it, and any + // attached debug-info records. + for (Instruction &II : *BB) { RemapInstruction(&II, VMap, RemapFlag, TypeMapper, Materializer); + RemapDPValueRange(II.getModule(), II.getDbgValueRange(), VMap, RemapFlag, + TypeMapper, Materializer); + } // Only update !llvm.dbg.cu for DifferentModule (not CloneModule). In the // same module, the compile unit will already be listed (or not). When @@ -327,6 +336,7 @@ Function *llvm::CloneFunction(Function *F, ValueToValueMapTy &VMap, // Create the new function... 
Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), F->getName(), F->getParent()); + NewF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat); // Loop over the arguments, copying the names of the mapped arguments over... Function::arg_iterator DestI = NewF->arg_begin(); @@ -472,6 +482,7 @@ void PruningFunctionCloner::CloneBlock( BasicBlock *NewBB; Twine NewName(BB->hasName() ? Twine(BB->getName()) + NameSuffix : ""); BBEntry = NewBB = BasicBlock::Create(BB->getContext(), NewName, NewFunc); + NewBB->IsNewDbgInfoFormat = BB->IsNewDbgInfoFormat; // It is only legal to clone a function if a block address within that // function is never referenced outside of the function. Given that, we @@ -491,6 +502,22 @@ void PruningFunctionCloner::CloneBlock( bool hasCalls = false, hasDynamicAllocas = false, hasStaticAllocas = false; bool hasMemProfMetadata = false; + // Keep a cursor pointing at the last place we cloned debug-info records from. + BasicBlock::const_iterator DbgCursor = StartingInst; + auto CloneDbgRecordsToHere = + [NewBB, &DbgCursor](Instruction *NewInst, BasicBlock::const_iterator II) { + if (!NewBB->IsNewDbgInfoFormat) + return; + + // Clone debug-info records onto this instruction. Iterate through any + // source-instructions we've cloned and then subsequently optimised + // away, so that their debug-info doesn't go missing. + for (; DbgCursor != II; ++DbgCursor) + NewInst->cloneDebugInfoFrom(&*DbgCursor, std::nullopt, false); + NewInst->cloneDebugInfoFrom(&*II); + DbgCursor = std::next(II); + }; + // Loop over all instructions, and copy them over, DCE'ing as we go. This // loop doesn't include the terminator. 
for (BasicBlock::const_iterator II = StartingInst, IE = --BB->end(); II != IE; @@ -540,6 +567,8 @@ void PruningFunctionCloner::CloneBlock( hasMemProfMetadata |= II->hasMetadata(LLVMContext::MD_memprof); } + CloneDbgRecordsToHere(NewInst, II); + if (CodeInfo) { CodeInfo->OrigVMap[&*II] = NewInst; if (auto *CB = dyn_cast<CallBase>(&*II)) @@ -597,6 +626,9 @@ void PruningFunctionCloner::CloneBlock( if (OldTI->hasName()) NewInst->setName(OldTI->getName() + NameSuffix); NewInst->insertInto(NewBB, NewBB->end()); + + CloneDbgRecordsToHere(NewInst, OldTI->getIterator()); + VMap[OldTI] = NewInst; // Add instruction map to value. if (CodeInfo) { @@ -608,6 +640,13 @@ void PruningFunctionCloner::CloneBlock( // Recursively clone any reachable successor blocks. append_range(ToClone, successors(BB->getTerminator())); + } else { + // If we didn't create a new terminator, clone DPValues from the old + // terminator onto the new terminator. + Instruction *NewInst = NewBB->getTerminator(); + assert(NewInst); + + CloneDbgRecordsToHere(NewInst, OldTI->getIterator()); } if (CodeInfo) { @@ -845,12 +884,22 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, TypeMapper, Materializer); } + // Do the same for DPValues, touching all the instructions in the cloned + // range of blocks. + Function::iterator Begin = cast<BasicBlock>(VMap[StartingBB])->getIterator(); + for (BasicBlock &BB : make_range(Begin, NewFunc->end())) { + for (Instruction &I : BB) { + RemapDPValueRange(I.getModule(), I.getDbgValueRange(), VMap, + ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, + TypeMapper, Materializer); + } + } + // Simplify conditional branches and switches with a constant operand. We try // to prune these out when cloning, but if the simplification required // looking through PHI nodes, those are only available after forming the full // basic block. That may leave some here, and we still want to prune the dead // code as early as possible. 
- Function::iterator Begin = cast<BasicBlock>(VMap[StartingBB])->getIterator(); for (BasicBlock &BB : make_range(Begin, NewFunc->end())) ConstantFoldTerminator(&BB); @@ -939,10 +988,14 @@ void llvm::CloneAndPruneFunctionInto( void llvm::remapInstructionsInBlocks(ArrayRef<BasicBlock *> Blocks, ValueToValueMapTy &VMap) { // Rewrite the code to refer to itself. - for (auto *BB : Blocks) - for (auto &Inst : *BB) + for (auto *BB : Blocks) { + for (auto &Inst : *BB) { + RemapDPValueRange(Inst.getModule(), Inst.getDbgValueRange(), VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); RemapInstruction(&Inst, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + } + } } /// Clones a loop \p OrigLoop. Returns the loop and the blocks in \p @@ -1066,6 +1119,7 @@ BasicBlock *llvm::DuplicateInstructionsInSplitBetween( Instruction *New = BI->clone(); New->setName(BI->getName()); New->insertBefore(NewTerm); + New->cloneDebugInfoFrom(&*BI); ValueMapping[&*BI] = New; // Remap operands to patch up intra-block references. diff --git a/llvm/lib/Transforms/Utils/CloneModule.cpp b/llvm/lib/Transforms/Utils/CloneModule.cpp index 55e051298a9a..00e40fe73d90 100644 --- a/llvm/lib/Transforms/Utils/CloneModule.cpp +++ b/llvm/lib/Transforms/Utils/CloneModule.cpp @@ -34,6 +34,8 @@ static void copyComdat(GlobalObject *Dst, const GlobalObject *Src) { /// copies of global variables and functions, and making their (initializers and /// references, respectively) refer to the right globals. /// +/// Cloning un-materialized modules is not currently supported, so any +/// modules initialized via lazy loading should be materialized before cloning std::unique_ptr<Module> llvm::CloneModule(const Module &M) { // Create the value map that maps things from the old module over to the new // module. 
@@ -49,6 +51,9 @@ std::unique_ptr<Module> llvm::CloneModule(const Module &M, std::unique_ptr<Module> llvm::CloneModule( const Module &M, ValueToValueMapTy &VMap, function_ref<bool(const GlobalValue *)> ShouldCloneDefinition) { + + assert(M.isMaterialized() && "Module must be materialized before cloning!"); + // First off, we need to create the new module. std::unique_ptr<Module> New = std::make_unique<Module>(M.getModuleIdentifier(), M.getContext()); @@ -56,6 +61,7 @@ std::unique_ptr<Module> llvm::CloneModule( New->setDataLayout(M.getDataLayout()); New->setTargetTriple(M.getTargetTriple()); New->setModuleInlineAsm(M.getModuleInlineAsm()); + New->IsNewDbgInfoFormat = M.IsNewDbgInfoFormat; // Loop over all of the global variables, making corresponding globals in the // new module. Here we add them to the VMap and to the new Module. We diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index c390af351a69..9c1186232e02 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -245,12 +245,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI, AssumptionCache *AC, bool AllowVarArgs, bool AllowAlloca, - BasicBlock *AllocationBlock, std::string Suffix) + BasicBlock *AllocationBlock, std::string Suffix, + bool ArgsInZeroAddressSpace) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), BPI(BPI), AC(AC), AllocationBlock(AllocationBlock), AllowVarArgs(AllowVarArgs), Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), - Suffix(Suffix) {} + Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, @@ -567,7 +568,7 @@ void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC, for (Instruction *I : LifetimeBitcastUsers) { 
Module *M = AIFunc->getParent(); LLVMContext &Ctx = M->getContext(); - auto *Int8PtrTy = Type::getInt8PtrTy(Ctx); + auto *Int8PtrTy = PointerType::getUnqual(Ctx); CastInst *CastI = CastInst::CreatePointerCast(AI, Int8PtrTy, "lt.cast", I); I->replaceUsesOfWith(I->getOperand(1), CastI); @@ -721,7 +722,8 @@ void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) { // Create a new PHI node in the new region, which has an incoming value // from OldPred of PN. PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion, - PN->getName() + ".ce", &NewBB->front()); + PN->getName() + ".ce"); + NewPN->insertBefore(NewBB->begin()); PN->replaceAllUsesWith(NewPN); NewPN->addIncoming(PN, OldPred); @@ -766,6 +768,7 @@ void CodeExtractor::severSplitPHINodesOfExits( NewBB = BasicBlock::Create(ExitBB->getContext(), ExitBB->getName() + ".split", ExitBB->getParent(), ExitBB); + NewBB->IsNewDbgInfoFormat = ExitBB->IsNewDbgInfoFormat; SmallVector<BasicBlock *, 4> Preds(predecessors(ExitBB)); for (BasicBlock *PredBB : Preds) if (Blocks.count(PredBB)) @@ -775,9 +778,9 @@ void CodeExtractor::severSplitPHINodesOfExits( } // Split this PHI. - PHINode *NewPN = - PHINode::Create(PN.getType(), IncomingVals.size(), - PN.getName() + ".ce", NewBB->getFirstNonPHI()); + PHINode *NewPN = PHINode::Create(PN.getType(), IncomingVals.size(), + PN.getName() + ".ce"); + NewPN->insertBefore(NewBB->getFirstNonPHIIt()); for (unsigned i : IncomingVals) NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i)); for (unsigned i : reverse(IncomingVals)) @@ -865,7 +868,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, StructType *StructTy = nullptr; if (AggregateArgs && !AggParamTy.empty()) { StructTy = StructType::get(M->getContext(), AggParamTy); - ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace())); + ParamTy.push_back(PointerType::get( + StructTy, ArgsInZeroAddressSpace ? 
0 : DL.getAllocaAddrSpace())); } LLVM_DEBUG({ @@ -886,6 +890,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, Function *newFunction = Function::Create( funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(), oldFunction->getName() + "." + SuffixToUse, M); + newFunction->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; // Inherit all of the target dependent attributes and white-listed // target independent attributes. @@ -919,6 +924,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::PresplitCoroutine: case Attribute::Memory: case Attribute::NoFPClass: + case Attribute::CoroDestroyOnlyWhenComplete: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: @@ -940,6 +946,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::NoSanitizeBounds: case Attribute::NoSanitizeCoverage: case Attribute::NullPointerIsValid: + case Attribute::OptimizeForDebugging: case Attribute::OptForFuzzing: case Attribute::OptimizeNone: case Attribute::OptimizeForSize: @@ -990,6 +997,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::ImmArg: case Attribute::ByRef: case Attribute::WriteOnly: + case Attribute::Writable: // These are not really attributes. case Attribute::None: case Attribute::EndAttrKinds: @@ -1185,8 +1193,15 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg", AllocationBlock ? 
&*AllocationBlock->getFirstInsertionPt() : &codeReplacer->getParent()->front().front()); - params.push_back(Struct); + if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) { + auto *StructSpaceCast = new AddrSpaceCastInst( + Struct, PointerType ::get(Context, 0), "structArg.ascast"); + StructSpaceCast->insertAfter(Struct); + params.push_back(StructSpaceCast); + } else { + params.push_back(Struct); + } // Store aggregated inputs in the struct. for (unsigned i = 0, e = StructValues.size(); i != e; ++i) { if (inputs.contains(StructValues[i])) { @@ -1492,10 +1507,14 @@ void CodeExtractor::calculateNewCallTerminatorWeights( static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) { for (Instruction &I : instructions(F)) { SmallVector<DbgVariableIntrinsic *, 4> DbgUsers; - findDbgUsers(DbgUsers, &I); + SmallVector<DPValue *, 4> DPValues; + findDbgUsers(DbgUsers, &I, &DPValues); for (DbgVariableIntrinsic *DVI : DbgUsers) if (DVI->getFunction() != &F) DVI->eraseFromParent(); + for (DPValue *DPV : DPValues) + if (DPV->getFunction() != &F) + DPV->eraseFromParent(); } } @@ -1531,6 +1550,16 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, /*LineNo=*/0, SPType, /*ScopeLine=*/0, DINode::FlagZero, SPFlags); NewFunc.setSubprogram(NewSP); + auto IsInvalidLocation = [&NewFunc](Value *Location) { + // Location is invalid if it isn't a constant or an instruction, or is an + // instruction but isn't in the new function. + if (!Location || + (!isa<Constant>(Location) && !isa<Instruction>(Location))) + return true; + Instruction *LocationInst = dyn_cast<Instruction>(Location); + return LocationInst && LocationInst->getFunction() != &NewFunc; + }; + // Debug intrinsics in the new function need to be updated in one of two // ways: // 1) They need to be deleted, because they describe a value in the old @@ -1539,8 +1568,41 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, // point to a variable in the wrong scope. 
SmallDenseMap<DINode *, DINode *> RemappedMetadata; SmallVector<Instruction *, 4> DebugIntrinsicsToDelete; + SmallVector<DPValue *, 4> DPVsToDelete; DenseMap<const MDNode *, MDNode *> Cache; + + auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar) { + DINode *&NewVar = RemappedMetadata[OldVar]; + if (!NewVar) { + DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram( + *OldVar->getScope(), *NewSP, Ctx, Cache); + NewVar = DIB.createAutoVariable( + NewScope, OldVar->getName(), OldVar->getFile(), OldVar->getLine(), + OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero, + OldVar->getAlignInBits()); + } + return cast<DILocalVariable>(NewVar); + }; + + auto UpdateDPValuesOnInst = [&](Instruction &I) -> void { + for (auto &DPV : I.getDbgValueRange()) { + // Apply the two updates that dbg.values get: invalid operands, and + // variable metadata fixup. + // FIXME: support dbg.assign form of DPValues. + if (any_of(DPV.location_ops(), IsInvalidLocation)) { + DPVsToDelete.push_back(&DPV); + continue; + } + if (!DPV.getDebugLoc().getInlinedAt()) + DPV.setVariable(GetUpdatedDIVariable(DPV.getVariable())); + DPV.setDebugLoc(DebugLoc::replaceInlinedAtSubprogram(DPV.getDebugLoc(), + *NewSP, Ctx, Cache)); + } + }; + for (Instruction &I : instructions(NewFunc)) { + UpdateDPValuesOnInst(I); + auto *DII = dyn_cast<DbgInfoIntrinsic>(&I); if (!DII) continue; @@ -1562,41 +1624,28 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, continue; } - auto IsInvalidLocation = [&NewFunc](Value *Location) { - // Location is invalid if it isn't a constant or an instruction, or is an - // instruction but isn't in the new function. 
- if (!Location || - (!isa<Constant>(Location) && !isa<Instruction>(Location))) - return true; - Instruction *LocationInst = dyn_cast<Instruction>(Location); - return LocationInst && LocationInst->getFunction() != &NewFunc; - }; - auto *DVI = cast<DbgVariableIntrinsic>(DII); // If any of the used locations are invalid, delete the intrinsic. if (any_of(DVI->location_ops(), IsInvalidLocation)) { DebugIntrinsicsToDelete.push_back(DVI); continue; } + // DbgAssign intrinsics have an extra Value argument: + if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI); + DAI && IsInvalidLocation(DAI->getAddress())) { + DebugIntrinsicsToDelete.push_back(DVI); + continue; + } // If the variable was in the scope of the old function, i.e. it was not // inlined, point the intrinsic to a fresh variable within the new function. - if (!DVI->getDebugLoc().getInlinedAt()) { - DILocalVariable *OldVar = DVI->getVariable(); - DINode *&NewVar = RemappedMetadata[OldVar]; - if (!NewVar) { - DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram( - *OldVar->getScope(), *NewSP, Ctx, Cache); - NewVar = DIB.createAutoVariable( - NewScope, OldVar->getName(), OldVar->getFile(), OldVar->getLine(), - OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero, - OldVar->getAlignInBits()); - } - DVI->setVariable(cast<DILocalVariable>(NewVar)); - } + if (!DVI->getDebugLoc().getInlinedAt()) + DVI->setVariable(GetUpdatedDIVariable(DVI->getVariable())); } for (auto *DII : DebugIntrinsicsToDelete) DII->eraseFromParent(); + for (auto *DPV : DPVsToDelete) + DPV->getMarker()->MarkedInstr->dropOneDbgValue(DPV); DIB.finalizeSubprogram(NewSP); // Fix up the scope information attached to the line locations in the new @@ -1702,11 +1751,14 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), "codeRepl", oldFunction, header); + codeReplacer->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; // The new function needs a 
root node because other nodes can branch to the // head of the region, but the entry node of a function cannot have preds. BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), "newFuncRoot"); + newFuncRoot->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; + auto *BranchI = BranchInst::Create(header); // If the original function has debug info, we have to add a debug location // to the new branch instruction from the artificial entry block. @@ -1772,11 +1824,11 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, // Update the entry count of the function. if (BFI) { - auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); + auto Count = BFI->getProfileCountFromFreq(EntryFreq); if (Count) newFunction->setEntryCount( ProfileCount(*Count, Function::PCT_Real)); // FIXME - BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); + BFI->setBlockFreq(codeReplacer, EntryFreq); } CallInst *TheCall = diff --git a/llvm/lib/Transforms/Utils/CodeLayout.cpp b/llvm/lib/Transforms/Utils/CodeLayout.cpp index ac74a1c116cc..95edd27c675d 100644 --- a/llvm/lib/Transforms/Utils/CodeLayout.cpp +++ b/llvm/lib/Transforms/Utils/CodeLayout.cpp @@ -45,8 +45,11 @@ #include "llvm/Support/Debug.h" #include <cmath> +#include <set> using namespace llvm; +using namespace llvm::codelayout; + #define DEBUG_TYPE "code-layout" namespace llvm { @@ -61,8 +64,8 @@ cl::opt<bool> ApplyExtTspWithoutProfile( cl::init(true), cl::Hidden); } // namespace llvm -// Algorithm-specific params. The values are tuned for the best performance -// of large-scale front-end bound binaries. +// Algorithm-specific params for Ext-TSP. The values are tuned for the best +// performance of large-scale front-end bound binaries. 
static cl::opt<double> ForwardWeightCond( "ext-tsp-forward-weight-cond", cl::ReallyHidden, cl::init(0.1), cl::desc("The weight of conditional forward jumps for ExtTSP value")); @@ -96,10 +99,10 @@ static cl::opt<unsigned> BackwardDistance( cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP")); // The maximum size of a chain created by the algorithm. The size is bounded -// so that the algorithm can efficiently process extremely large instance. +// so that the algorithm can efficiently process extremely large instances. static cl::opt<unsigned> - MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(4096), - cl::desc("The maximum size of a chain to create.")); + MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(512), + cl::desc("The maximum size of a chain to create")); // The maximum size of a chain for splitting. Larger values of the threshold // may yield better quality at the cost of worsen run-time. @@ -107,11 +110,29 @@ static cl::opt<unsigned> ChainSplitThreshold( "ext-tsp-chain-split-threshold", cl::ReallyHidden, cl::init(128), cl::desc("The maximum size of a chain to apply splitting")); -// The option enables splitting (large) chains along in-coming and out-going -// jumps. This typically results in a better quality. -static cl::opt<bool> EnableChainSplitAlongJumps( - "ext-tsp-enable-chain-split-along-jumps", cl::ReallyHidden, cl::init(true), - cl::desc("The maximum size of a chain to apply splitting")); +// The maximum ratio between densities of two chains for merging. +static cl::opt<double> MaxMergeDensityRatio( + "ext-tsp-max-merge-density-ratio", cl::ReallyHidden, cl::init(100), + cl::desc("The maximum ratio between densities of two chains for merging")); + +// Algorithm-specific options for CDSort. 
+static cl::opt<unsigned> CacheEntries("cdsort-cache-entries", cl::ReallyHidden, + cl::desc("The size of the cache")); + +static cl::opt<unsigned> CacheSize("cdsort-cache-size", cl::ReallyHidden, + cl::desc("The size of a line in the cache")); + +static cl::opt<unsigned> + CDMaxChainSize("cdsort-max-chain-size", cl::ReallyHidden, + cl::desc("The maximum size of a chain to create")); + +static cl::opt<double> DistancePower( + "cdsort-distance-power", cl::ReallyHidden, + cl::desc("The power exponent for the distance-based locality")); + +static cl::opt<double> FrequencyScale( + "cdsort-frequency-scale", cl::ReallyHidden, + cl::desc("The scale factor for the frequency-based locality")); namespace { @@ -199,11 +220,14 @@ struct NodeT { NodeT &operator=(const NodeT &) = delete; NodeT &operator=(NodeT &&) = default; - explicit NodeT(size_t Index, uint64_t Size, uint64_t EC) - : Index(Index), Size(Size), ExecutionCount(EC) {} + explicit NodeT(size_t Index, uint64_t Size, uint64_t Count) + : Index(Index), Size(Size), ExecutionCount(Count) {} bool isEntry() const { return Index == 0; } + // Check if Other is a successor of the node. + bool isSuccessor(const NodeT *Other) const; + // The total execution count of outgoing jumps. 
uint64_t outCount() const; @@ -267,7 +291,7 @@ struct ChainT { size_t numBlocks() const { return Nodes.size(); } - double density() const { return static_cast<double>(ExecutionCount) / Size; } + double density() const { return ExecutionCount / Size; } bool isEntry() const { return Nodes[0]->Index == 0; } @@ -280,9 +304,9 @@ struct ChainT { } ChainEdge *getEdge(ChainT *Other) const { - for (auto It : Edges) { - if (It.first == Other) - return It.second; + for (const auto &[Chain, ChainEdge] : Edges) { + if (Chain == Other) + return ChainEdge; } return nullptr; } @@ -302,13 +326,13 @@ struct ChainT { Edges.push_back(std::make_pair(Other, Edge)); } - void merge(ChainT *Other, const std::vector<NodeT *> &MergedBlocks) { - Nodes = MergedBlocks; - // Update the chain's data + void merge(ChainT *Other, std::vector<NodeT *> MergedBlocks) { + Nodes = std::move(MergedBlocks); + // Update the chain's data. ExecutionCount += Other->ExecutionCount; Size += Other->Size; Id = Nodes[0]->Index; - // Update the node's data + // Update the node's data. for (size_t Idx = 0; Idx < Nodes.size(); Idx++) { Nodes[Idx]->CurChain = this; Nodes[Idx]->CurIndex = Idx; @@ -328,8 +352,9 @@ struct ChainT { uint64_t Id; // Cached ext-tsp score for the chain. double Score{0}; - // The total execution count of the chain. - uint64_t ExecutionCount{0}; + // The total execution count of the chain. Since the execution count of + // a basic block is uint64_t, using doubles here to avoid overflow. + double ExecutionCount{0}; // The total size of the chain. uint64_t Size{0}; // Nodes of the chain. @@ -340,7 +365,7 @@ struct ChainT { /// An edge in the graph representing jumps between two chains. /// When nodes are merged into chains, the edges are combined too so that -/// there is always at most one edge between a pair of chains +/// there is always at most one edge between a pair of chains. 
struct ChainEdge { ChainEdge(const ChainEdge &) = delete; ChainEdge(ChainEdge &&) = default; @@ -424,53 +449,57 @@ private: bool CacheValidBackward{false}; }; +bool NodeT::isSuccessor(const NodeT *Other) const { + for (JumpT *Jump : OutJumps) + if (Jump->Target == Other) + return true; + return false; +} + uint64_t NodeT::outCount() const { uint64_t Count = 0; - for (JumpT *Jump : OutJumps) { + for (JumpT *Jump : OutJumps) Count += Jump->ExecutionCount; - } return Count; } uint64_t NodeT::inCount() const { uint64_t Count = 0; - for (JumpT *Jump : InJumps) { + for (JumpT *Jump : InJumps) Count += Jump->ExecutionCount; - } return Count; } void ChainT::mergeEdges(ChainT *Other) { - // Update edges adjacent to chain Other - for (auto EdgeIt : Other->Edges) { - ChainT *DstChain = EdgeIt.first; - ChainEdge *DstEdge = EdgeIt.second; + // Update edges adjacent to chain Other. + for (const auto &[DstChain, DstEdge] : Other->Edges) { ChainT *TargetChain = DstChain == Other ? this : DstChain; ChainEdge *CurEdge = getEdge(TargetChain); if (CurEdge == nullptr) { DstEdge->changeEndpoint(Other, this); this->addEdge(TargetChain, DstEdge); - if (DstChain != this && DstChain != Other) { + if (DstChain != this && DstChain != Other) DstChain->addEdge(this, DstEdge); - } } else { CurEdge->moveJumps(DstEdge); } - // Cleanup leftover edge - if (DstChain != Other) { + // Cleanup leftover edge. + if (DstChain != Other) DstChain->removeEdge(Other); - } } } using NodeIter = std::vector<NodeT *>::const_iterator; +static std::vector<NodeT *> EmptyList; -/// A wrapper around three chains of nodes; it is used to avoid extra -/// instantiation of the vectors. -struct MergedChain { - MergedChain(NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(), - NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(), - NodeIter End3 = NodeIter()) +/// A wrapper around three concatenated vectors (chains) of nodes; it is used +/// to avoid extra instantiation of the vectors. 
+struct MergedNodesT { + MergedNodesT(NodeIter Begin1, NodeIter End1, + NodeIter Begin2 = EmptyList.begin(), + NodeIter End2 = EmptyList.end(), + NodeIter Begin3 = EmptyList.begin(), + NodeIter End3 = EmptyList.end()) : Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3), End3(End3) {} @@ -504,15 +533,35 @@ private: NodeIter End3; }; +/// A wrapper around two concatenated vectors (chains) of jumps. +struct MergedJumpsT { + MergedJumpsT(const std::vector<JumpT *> *Jumps1, + const std::vector<JumpT *> *Jumps2 = nullptr) { + assert(!Jumps1->empty() && "cannot merge empty jump list"); + JumpArray[0] = Jumps1; + JumpArray[1] = Jumps2; + } + + template <typename F> void forEach(const F &Func) const { + for (auto Jumps : JumpArray) + if (Jumps != nullptr) + for (JumpT *Jump : *Jumps) + Func(Jump); + } + +private: + std::array<const std::vector<JumpT *> *, 2> JumpArray{nullptr, nullptr}; +}; + /// Merge two chains of nodes respecting a given 'type' and 'offset'. /// /// If MergeType == 0, then the result is a concatenation of two chains. /// Otherwise, the first chain is cut into two sub-chains at the offset, /// and merged using all possible ways of concatenating three chains. -MergedChain mergeNodes(const std::vector<NodeT *> &X, - const std::vector<NodeT *> &Y, size_t MergeOffset, - MergeTypeT MergeType) { - // Split the first chain, X, into X1 and X2 +MergedNodesT mergeNodes(const std::vector<NodeT *> &X, + const std::vector<NodeT *> &Y, size_t MergeOffset, + MergeTypeT MergeType) { + // Split the first chain, X, into X1 and X2. NodeIter BeginX1 = X.begin(); NodeIter EndX1 = X.begin() + MergeOffset; NodeIter BeginX2 = X.begin() + MergeOffset; @@ -520,18 +569,18 @@ MergedChain mergeNodes(const std::vector<NodeT *> &X, NodeIter BeginY = Y.begin(); NodeIter EndY = Y.end(); - // Construct a new chain from the three existing ones + // Construct a new chain from the three existing ones. 
switch (MergeType) { case MergeTypeT::X_Y: - return MergedChain(BeginX1, EndX2, BeginY, EndY); + return MergedNodesT(BeginX1, EndX2, BeginY, EndY); case MergeTypeT::Y_X: - return MergedChain(BeginY, EndY, BeginX1, EndX2); + return MergedNodesT(BeginY, EndY, BeginX1, EndX2); case MergeTypeT::X1_Y_X2: - return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2); + return MergedNodesT(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2); case MergeTypeT::Y_X2_X1: - return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1); + return MergedNodesT(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1); case MergeTypeT::X2_X1_Y: - return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY); + return MergedNodesT(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY); } llvm_unreachable("unexpected chain merge type"); } @@ -539,15 +588,14 @@ MergedChain mergeNodes(const std::vector<NodeT *> &X, /// The implementation of the ExtTSP algorithm. class ExtTSPImpl { public: - ExtTSPImpl(const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<EdgeCountT> &EdgeCounts) + ExtTSPImpl(ArrayRef<uint64_t> NodeSizes, ArrayRef<uint64_t> NodeCounts, + ArrayRef<EdgeCount> EdgeCounts) : NumNodes(NodeSizes.size()) { initialize(NodeSizes, NodeCounts, EdgeCounts); } /// Run the algorithm and return an optimized ordering of nodes. - void run(std::vector<uint64_t> &Result) { + std::vector<uint64_t> run() { // Pass 1: Merge nodes with their mutually forced successors mergeForcedPairs(); @@ -558,78 +606,80 @@ public: mergeColdChains(); // Collect nodes from all chains - concatChains(Result); + return concatChains(); } private: /// Initialize the algorithm's data structures. 
- void initialize(const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<EdgeCountT> &EdgeCounts) { - // Initialize nodes + void initialize(const ArrayRef<uint64_t> &NodeSizes, + const ArrayRef<uint64_t> &NodeCounts, + const ArrayRef<EdgeCount> &EdgeCounts) { + // Initialize nodes. AllNodes.reserve(NumNodes); for (uint64_t Idx = 0; Idx < NumNodes; Idx++) { uint64_t Size = std::max<uint64_t>(NodeSizes[Idx], 1ULL); uint64_t ExecutionCount = NodeCounts[Idx]; - // The execution count of the entry node is set to at least one + // The execution count of the entry node is set to at least one. if (Idx == 0 && ExecutionCount == 0) ExecutionCount = 1; AllNodes.emplace_back(Idx, Size, ExecutionCount); } - // Initialize jumps between nodes + // Initialize jumps between the nodes. SuccNodes.resize(NumNodes); PredNodes.resize(NumNodes); std::vector<uint64_t> OutDegree(NumNodes, 0); AllJumps.reserve(EdgeCounts.size()); - for (auto It : EdgeCounts) { - uint64_t Pred = It.first.first; - uint64_t Succ = It.first.second; - OutDegree[Pred]++; - // Ignore self-edges - if (Pred == Succ) + for (auto Edge : EdgeCounts) { + ++OutDegree[Edge.src]; + // Ignore self-edges. + if (Edge.src == Edge.dst) continue; - SuccNodes[Pred].push_back(Succ); - PredNodes[Succ].push_back(Pred); - uint64_t ExecutionCount = It.second; - if (ExecutionCount > 0) { - NodeT &PredNode = AllNodes[Pred]; - NodeT &SuccNode = AllNodes[Succ]; - AllJumps.emplace_back(&PredNode, &SuccNode, ExecutionCount); + SuccNodes[Edge.src].push_back(Edge.dst); + PredNodes[Edge.dst].push_back(Edge.src); + if (Edge.count > 0) { + NodeT &PredNode = AllNodes[Edge.src]; + NodeT &SuccNode = AllNodes[Edge.dst]; + AllJumps.emplace_back(&PredNode, &SuccNode, Edge.count); SuccNode.InJumps.push_back(&AllJumps.back()); PredNode.OutJumps.push_back(&AllJumps.back()); + // Adjust execution counts. 
+ PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Edge.count); + SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Edge.count); } } for (JumpT &Jump : AllJumps) { - assert(OutDegree[Jump.Source->Index] > 0); + assert(OutDegree[Jump.Source->Index] > 0 && + "incorrectly computed out-degree of the block"); Jump.IsConditional = OutDegree[Jump.Source->Index] > 1; } - // Initialize chains + // Initialize chains. AllChains.reserve(NumNodes); HotChains.reserve(NumNodes); for (NodeT &Node : AllNodes) { + // Create a chain. AllChains.emplace_back(Node.Index, &Node); Node.CurChain = &AllChains.back(); - if (Node.ExecutionCount > 0) { + if (Node.ExecutionCount > 0) HotChains.push_back(&AllChains.back()); - } } - // Initialize chain edges + // Initialize chain edges. AllEdges.reserve(AllJumps.size()); for (NodeT &PredNode : AllNodes) { for (JumpT *Jump : PredNode.OutJumps) { + assert(Jump->ExecutionCount > 0 && "incorrectly initialized jump"); NodeT *SuccNode = Jump->Target; ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain); - // this edge is already present in the graph + // This edge is already present in the graph. if (CurEdge != nullptr) { assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr); CurEdge->appendJump(Jump); continue; } - // this is a new edge + // This is a new edge. AllEdges.emplace_back(Jump); PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back()); SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back()); @@ -642,7 +692,7 @@ private: /// to B are from A. Such nodes should be adjacent in the optimal ordering; /// the method finds and merges such pairs of nodes. void mergeForcedPairs() { - // Find fallthroughs based on edge weights + // Find forced pairs of blocks. 
for (NodeT &Node : AllNodes) { if (SuccNodes[Node.Index].size() == 1 && PredNodes[SuccNodes[Node.Index][0]].size() == 1 && @@ -669,12 +719,12 @@ private: } if (SuccNode == nullptr) continue; - // Break the cycle + // Break the cycle. AllNodes[Node.ForcedPred->Index].ForcedSucc = nullptr; Node.ForcedPred = nullptr; } - // Merge nodes with their fallthrough successors + // Merge nodes with their fallthrough successors. for (NodeT &Node : AllNodes) { if (Node.ForcedPred == nullptr && Node.ForcedSucc != nullptr) { const NodeT *CurBlock = &Node; @@ -689,33 +739,42 @@ private: /// Merge pairs of chains while improving the ExtTSP objective. void mergeChainPairs() { - /// Deterministically compare pairs of chains + /// Deterministically compare pairs of chains. auto compareChainPairs = [](const ChainT *A1, const ChainT *B1, const ChainT *A2, const ChainT *B2) { - if (A1 != A2) - return A1->Id < A2->Id; - return B1->Id < B2->Id; + return std::make_tuple(A1->Id, B1->Id) < std::make_tuple(A2->Id, B2->Id); }; while (HotChains.size() > 1) { ChainT *BestChainPred = nullptr; ChainT *BestChainSucc = nullptr; MergeGainT BestGain; - // Iterate over all pairs of chains + // Iterate over all pairs of chains. for (ChainT *ChainPred : HotChains) { - // Get candidates for merging with the current chain - for (auto EdgeIt : ChainPred->Edges) { - ChainT *ChainSucc = EdgeIt.first; - ChainEdge *Edge = EdgeIt.second; - // Ignore loop edges - if (ChainPred == ChainSucc) + // Get candidates for merging with the current chain. + for (const auto &[ChainSucc, Edge] : ChainPred->Edges) { + // Ignore loop edges. + if (Edge->isSelfEdge()) continue; - - // Stop early if the combined chain violates the maximum allowed size + // Skip the merge if the combined chain violates the maximum specified + // size. if (ChainPred->numBlocks() + ChainSucc->numBlocks() >= MaxChainSize) continue; + // Don't merge the chains if they have vastly different densities. 
+ // Skip the merge if the ratio between the densities exceeds + // MaxMergeDensityRatio. Smaller values of the option result in fewer + // merges, and hence, more chains. + const double ChainPredDensity = ChainPred->density(); + const double ChainSuccDensity = ChainSucc->density(); + assert(ChainPredDensity > 0.0 && ChainSuccDensity > 0.0 && + "incorrectly computed chain densities"); + auto [MinDensity, MaxDensity] = + std::minmax(ChainPredDensity, ChainSuccDensity); + const double Ratio = MaxDensity / MinDensity; + if (Ratio > MaxMergeDensityRatio) + continue; - // Compute the gain of merging the two chains + // Compute the gain of merging the two chains. MergeGainT CurGain = getBestMergeGain(ChainPred, ChainSucc, Edge); if (CurGain.score() <= EPS) continue; @@ -731,11 +790,11 @@ private: } } - // Stop merging when there is no improvement + // Stop merging when there is no improvement. if (BestGain.score() <= EPS) break; - // Merge the best pair of chains + // Merge the best pair of chains. mergeChains(BestChainPred, BestChainSucc, BestGain.mergeOffset(), BestGain.mergeType()); } @@ -743,7 +802,7 @@ private: /// Merge remaining nodes into chains w/o taking jump counts into /// consideration. This allows to maintain the original node order in the - /// absence of profile data + /// absence of profile data. void mergeColdChains() { for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) { // Iterating in reverse order to make sure original fallthrough jumps are @@ -764,24 +823,22 @@ private: } /// Compute the Ext-TSP score for a given node order and a list of jumps. 
- double extTSPScore(const MergedChain &MergedBlocks, - const std::vector<JumpT *> &Jumps) const { - if (Jumps.empty()) - return 0.0; + double extTSPScore(const MergedNodesT &Nodes, + const MergedJumpsT &Jumps) const { uint64_t CurAddr = 0; - MergedBlocks.forEach([&](const NodeT *Node) { + Nodes.forEach([&](const NodeT *Node) { Node->EstimatedAddr = CurAddr; CurAddr += Node->Size; }); double Score = 0; - for (JumpT *Jump : Jumps) { + Jumps.forEach([&](const JumpT *Jump) { const NodeT *SrcBlock = Jump->Source; const NodeT *DstBlock = Jump->Target; Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size, DstBlock->EstimatedAddr, Jump->ExecutionCount, Jump->IsConditional); - } + }); return Score; } @@ -793,74 +850,76 @@ private: /// element being the corresponding merging type. MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc, ChainEdge *Edge) const { - if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) { + if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) return Edge->getCachedMergeGain(ChainPred, ChainSucc); - } - // Precompute jumps between ChainPred and ChainSucc - auto Jumps = Edge->jumps(); + assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps"); + // Precompute jumps between ChainPred and ChainSucc. ChainEdge *EdgePP = ChainPred->getEdge(ChainPred); - if (EdgePP != nullptr) { - Jumps.insert(Jumps.end(), EdgePP->jumps().begin(), EdgePP->jumps().end()); - } - assert(!Jumps.empty() && "trying to merge chains w/o jumps"); + MergedJumpsT Jumps(&Edge->jumps(), EdgePP ? &EdgePP->jumps() : nullptr); - // The object holds the best currently chosen gain of merging the two chains + // This object holds the best chosen gain of merging two chains. MergeGainT Gain = MergeGainT(); /// Given a merge offset and a list of merge types, try to merge two chains - /// and update Gain with a better alternative + /// and update Gain with a better alternative. 
auto tryChainMerging = [&](size_t Offset, const std::vector<MergeTypeT> &MergeTypes) { - // Skip merging corresponding to concatenation w/o splitting + // Skip merging corresponding to concatenation w/o splitting. if (Offset == 0 || Offset == ChainPred->Nodes.size()) return; - // Skip merging if it breaks Forced successors + // Skip merging if it breaks Forced successors. NodeT *Node = ChainPred->Nodes[Offset - 1]; if (Node->ForcedSucc != nullptr) return; // Apply the merge, compute the corresponding gain, and update the best - // value, if the merge is beneficial + // value, if the merge is beneficial. for (const MergeTypeT &MergeType : MergeTypes) { Gain.updateIfLessThan( computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType)); } }; - // Try to concatenate two chains w/o splitting + // Try to concatenate two chains w/o splitting. Gain.updateIfLessThan( computeMergeGain(ChainPred, ChainSucc, Jumps, 0, MergeTypeT::X_Y)); - if (EnableChainSplitAlongJumps) { - // Attach (a part of) ChainPred before the first node of ChainSucc - for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) { - const NodeT *SrcBlock = Jump->Source; - if (SrcBlock->CurChain != ChainPred) - continue; - size_t Offset = SrcBlock->CurIndex + 1; - tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y}); - } + // Attach (a part of) ChainPred before the first node of ChainSucc. 
+ for (JumpT *Jump : ChainSucc->Nodes.front()->InJumps) { + const NodeT *SrcBlock = Jump->Source; + if (SrcBlock->CurChain != ChainPred) + continue; + size_t Offset = SrcBlock->CurIndex + 1; + tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::X2_X1_Y}); + } - // Attach (a part of) ChainPred after the last node of ChainSucc - for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) { - const NodeT *DstBlock = Jump->Source; - if (DstBlock->CurChain != ChainPred) - continue; - size_t Offset = DstBlock->CurIndex; - tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1}); - } + // Attach (a part of) ChainPred after the last node of ChainSucc. + for (JumpT *Jump : ChainSucc->Nodes.back()->OutJumps) { + const NodeT *DstBlock = Jump->Target; + if (DstBlock->CurChain != ChainPred) + continue; + size_t Offset = DstBlock->CurIndex; + tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1}); } - // Try to break ChainPred in various ways and concatenate with ChainSucc + // Try to break ChainPred in various ways and concatenate with ChainSucc. if (ChainPred->Nodes.size() <= ChainSplitThreshold) { for (size_t Offset = 1; Offset < ChainPred->Nodes.size(); Offset++) { - // Try to split the chain in different ways. In practice, applying - // X2_Y_X1 merging is almost never provides benefits; thus, we exclude - // it from consideration to reduce the search space + // Do not split the chain along a fall-through jump. One of the two + // loops above may still "break" such a jump whenever it results in a + // new fall-through. + const NodeT *BB = ChainPred->Nodes[Offset - 1]; + const NodeT *BB2 = ChainPred->Nodes[Offset]; + if (BB->isSuccessor(BB2)) + continue; + + // In practice, applying X2_Y_X1 merging almost never provides benefits; + // thus, we exclude it from consideration to reduce the search space. 
tryChainMerging(Offset, {MergeTypeT::X1_Y_X2, MergeTypeT::Y_X2_X1, MergeTypeT::X2_X1_Y}); } } + Edge->setCachedMergeGain(ChainPred, ChainSucc, Gain); return Gain; } @@ -870,19 +929,20 @@ private: /// /// The two chains are not modified in the method. MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc, - const std::vector<JumpT *> &Jumps, - size_t MergeOffset, MergeTypeT MergeType) const { - auto MergedBlocks = + const MergedJumpsT &Jumps, size_t MergeOffset, + MergeTypeT MergeType) const { + MergedNodesT MergedNodes = mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType); - // Do not allow a merge that does not preserve the original entry point + // Do not allow a merge that does not preserve the original entry point. if ((ChainPred->isEntry() || ChainSucc->isEntry()) && - !MergedBlocks.getFirstNode()->isEntry()) + !MergedNodes.getFirstNode()->isEntry()) return MergeGainT(); - // The gain for the new chain - auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->Score; - return MergeGainT(NewGainScore, MergeOffset, MergeType); + // The gain for the new chain. + double NewScore = extTSPScore(MergedNodes, Jumps); + double CurScore = ChainPred->Score; + return MergeGainT(NewScore - CurScore, MergeOffset, MergeType); } /// Merge chain From into chain Into, update the list of active chains, @@ -891,39 +951,398 @@ private: MergeTypeT MergeType) { assert(Into != From && "a chain cannot be merged with itself"); - // Merge the nodes - MergedChain MergedNodes = + // Merge the nodes. + MergedNodesT MergedNodes = mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType); Into->merge(From, MergedNodes.getNodes()); - // Merge the edges + // Merge the edges. Into->mergeEdges(From); From->clear(); - // Update cached ext-tsp score for the new chain + // Update cached ext-tsp score for the new chain. 
ChainEdge *SelfEdge = Into->getEdge(Into); if (SelfEdge != nullptr) { - MergedNodes = MergedChain(Into->Nodes.begin(), Into->Nodes.end()); - Into->Score = extTSPScore(MergedNodes, SelfEdge->jumps()); + MergedNodes = MergedNodesT(Into->Nodes.begin(), Into->Nodes.end()); + MergedJumpsT MergedJumps(&SelfEdge->jumps()); + Into->Score = extTSPScore(MergedNodes, MergedJumps); } - // Remove the chain from the list of active chains - llvm::erase_value(HotChains, From); + // Remove the chain from the list of active chains. + llvm::erase(HotChains, From); - // Invalidate caches + // Invalidate caches. for (auto EdgeIt : Into->Edges) EdgeIt.second->invalidateCache(); } /// Concatenate all chains into the final order. - void concatChains(std::vector<uint64_t> &Order) { - // Collect chains and calculate density stats for their sorting + std::vector<uint64_t> concatChains() { + // Collect non-empty chains. + std::vector<const ChainT *> SortedChains; + for (ChainT &Chain : AllChains) { + if (!Chain.Nodes.empty()) + SortedChains.push_back(&Chain); + } + + // Sorting chains by density in the decreasing order. + std::sort(SortedChains.begin(), SortedChains.end(), + [&](const ChainT *L, const ChainT *R) { + // Place the entry point at the beginning of the order. + if (L->isEntry() != R->isEntry()) + return L->isEntry(); + + // Compare by density and break ties by chain identifiers. + return std::make_tuple(-L->density(), L->Id) < + std::make_tuple(-R->density(), R->Id); + }); + + // Collect the nodes in the order specified by their chains. + std::vector<uint64_t> Order; + Order.reserve(NumNodes); + for (const ChainT *Chain : SortedChains) + for (NodeT *Node : Chain->Nodes) + Order.push_back(Node->Index); + return Order; + } + +private: + /// The number of nodes in the graph. + const size_t NumNodes; + + /// Successors of each node. + std::vector<std::vector<uint64_t>> SuccNodes; + + /// Predecessors of each node. 
+ std::vector<std::vector<uint64_t>> PredNodes; + + /// All nodes (basic blocks) in the graph. + std::vector<NodeT> AllNodes; + + /// All jumps between the nodes. + std::vector<JumpT> AllJumps; + + /// All chains of nodes. + std::vector<ChainT> AllChains; + + /// All edges between the chains. + std::vector<ChainEdge> AllEdges; + + /// Active chains. The vector gets updated at runtime when chains are merged. + std::vector<ChainT *> HotChains; +}; + +/// The implementation of the Cache-Directed Sort (CDSort) algorithm for +/// ordering functions represented by a call graph. +class CDSortImpl { +public: + CDSortImpl(const CDSortConfig &Config, ArrayRef<uint64_t> NodeSizes, + ArrayRef<uint64_t> NodeCounts, ArrayRef<EdgeCount> EdgeCounts, + ArrayRef<uint64_t> EdgeOffsets) + : Config(Config), NumNodes(NodeSizes.size()) { + initialize(NodeSizes, NodeCounts, EdgeCounts, EdgeOffsets); + } + + /// Run the algorithm and return an ordered set of function clusters. + std::vector<uint64_t> run() { + // Merge pairs of chains while improving the objective. + mergeChainPairs(); + + // Collect nodes from all the chains. + return concatChains(); + } + +private: + /// Initialize the algorithm's data structures. + void initialize(const ArrayRef<uint64_t> &NodeSizes, + const ArrayRef<uint64_t> &NodeCounts, + const ArrayRef<EdgeCount> &EdgeCounts, + const ArrayRef<uint64_t> &EdgeOffsets) { + // Initialize nodes. + AllNodes.reserve(NumNodes); + for (uint64_t Node = 0; Node < NumNodes; Node++) { + uint64_t Size = std::max<uint64_t>(NodeSizes[Node], 1ULL); + uint64_t ExecutionCount = NodeCounts[Node]; + AllNodes.emplace_back(Node, Size, ExecutionCount); + TotalSamples += ExecutionCount; + if (ExecutionCount > 0) + TotalSize += Size; + } + + // Initialize jumps between the nodes. 
+ SuccNodes.resize(NumNodes); + PredNodes.resize(NumNodes); + AllJumps.reserve(EdgeCounts.size()); + for (size_t I = 0; I < EdgeCounts.size(); I++) { + auto [Pred, Succ, Count] = EdgeCounts[I]; + // Ignore recursive calls. + if (Pred == Succ) + continue; + + SuccNodes[Pred].push_back(Succ); + PredNodes[Succ].push_back(Pred); + if (Count > 0) { + NodeT &PredNode = AllNodes[Pred]; + NodeT &SuccNode = AllNodes[Succ]; + AllJumps.emplace_back(&PredNode, &SuccNode, Count); + AllJumps.back().Offset = EdgeOffsets[I]; + SuccNode.InJumps.push_back(&AllJumps.back()); + PredNode.OutJumps.push_back(&AllJumps.back()); + // Adjust execution counts. + PredNode.ExecutionCount = std::max(PredNode.ExecutionCount, Count); + SuccNode.ExecutionCount = std::max(SuccNode.ExecutionCount, Count); + } + } + + // Initialize chains. + AllChains.reserve(NumNodes); + for (NodeT &Node : AllNodes) { + // Adjust execution counts. + Node.ExecutionCount = std::max(Node.ExecutionCount, Node.inCount()); + Node.ExecutionCount = std::max(Node.ExecutionCount, Node.outCount()); + // Create chain. + AllChains.emplace_back(Node.Index, &Node); + Node.CurChain = &AllChains.back(); + } + + // Initialize chain edges. + AllEdges.reserve(AllJumps.size()); + for (NodeT &PredNode : AllNodes) { + for (JumpT *Jump : PredNode.OutJumps) { + NodeT *SuccNode = Jump->Target; + ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain); + // This edge is already present in the graph. + if (CurEdge != nullptr) { + assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr); + CurEdge->appendJump(Jump); + continue; + } + // This is a new edge. + AllEdges.emplace_back(Jump); + PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back()); + SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back()); + } + } + } + + /// Merge pairs of chains while there is an improvement in the objective. + void mergeChainPairs() { + // Create a priority queue containing all edges ordered by the merge gain. 
+ auto GainComparator = [](ChainEdge *L, ChainEdge *R) { + return std::make_tuple(-L->gain(), L->srcChain()->Id, L->dstChain()->Id) < + std::make_tuple(-R->gain(), R->srcChain()->Id, R->dstChain()->Id); + }; + std::set<ChainEdge *, decltype(GainComparator)> Queue(GainComparator); + + // Insert the edges into the queue. + [[maybe_unused]] size_t NumActiveChains = 0; + for (NodeT &Node : AllNodes) { + if (Node.ExecutionCount == 0) + continue; + ++NumActiveChains; + for (const auto &[_, Edge] : Node.CurChain->Edges) { + // Ignore self-edges. + if (Edge->isSelfEdge()) + continue; + // Ignore already processed edges. + if (Edge->gain() != -1.0) + continue; + + // Compute the gain of merging the two chains. + MergeGainT Gain = getBestMergeGain(Edge); + Edge->setMergeGain(Gain); + + if (Edge->gain() > EPS) + Queue.insert(Edge); + } + } + + // Merge the chains while the gain of merging is positive. + while (!Queue.empty()) { + // Extract the best (top) edge for merging. + ChainEdge *BestEdge = *Queue.begin(); + Queue.erase(Queue.begin()); + ChainT *BestSrcChain = BestEdge->srcChain(); + ChainT *BestDstChain = BestEdge->dstChain(); + + // Remove outdated edges from the queue. + for (const auto &[_, ChainEdge] : BestSrcChain->Edges) + Queue.erase(ChainEdge); + for (const auto &[_, ChainEdge] : BestDstChain->Edges) + Queue.erase(ChainEdge); + + // Merge the best pair of chains. + MergeGainT BestGain = BestEdge->getMergeGain(); + mergeChains(BestSrcChain, BestDstChain, BestGain.mergeOffset(), + BestGain.mergeType()); + --NumActiveChains; + + // Insert newly created edges into the queue. + for (const auto &[_, Edge] : BestSrcChain->Edges) { + // Ignore loop edges. + if (Edge->isSelfEdge()) + continue; + if (Edge->srcChain()->numBlocks() + Edge->dstChain()->numBlocks() > + Config.MaxChainSize) + continue; + + // Compute the gain of merging the two chains. 
+ MergeGainT Gain = getBestMergeGain(Edge); + Edge->setMergeGain(Gain); + + if (Edge->gain() > EPS) + Queue.insert(Edge); + } + } + + LLVM_DEBUG(dbgs() << "Cache-directed function sorting reduced the number" + << " of chains from " << NumNodes << " to " + << NumActiveChains << "\n"); + } + + /// Compute the gain of merging two chains. + /// + /// The function considers all possible ways of merging two chains and + /// computes the one having the largest increase in ExtTSP objective. The + /// result is a pair with the first element being the gain and the second + /// element being the corresponding merging type. + MergeGainT getBestMergeGain(ChainEdge *Edge) const { + assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps"); + // Precompute jumps between ChainPred and ChainSucc. + MergedJumpsT Jumps(&Edge->jumps()); + ChainT *SrcChain = Edge->srcChain(); + ChainT *DstChain = Edge->dstChain(); + + // This object holds the best currently chosen gain of merging two chains. + MergeGainT Gain = MergeGainT(); + + /// Given a list of merge types, try to merge two chains and update Gain + /// with a better alternative. + auto tryChainMerging = [&](const std::vector<MergeTypeT> &MergeTypes) { + // Apply the merge, compute the corresponding gain, and update the best + // value, if the merge is beneficial. + for (const MergeTypeT &MergeType : MergeTypes) { + MergeGainT NewGain = + computeMergeGain(SrcChain, DstChain, Jumps, MergeType); + + // When forward and backward gains are the same, prioritize merging that + // preserves the original order of the functions in the binary. + if (std::abs(Gain.score() - NewGain.score()) < EPS) { + if ((MergeType == MergeTypeT::X_Y && SrcChain->Id < DstChain->Id) || + (MergeType == MergeTypeT::Y_X && SrcChain->Id > DstChain->Id)) { + Gain = NewGain; + } + } else if (NewGain.score() > Gain.score() + EPS) { + Gain = NewGain; + } + } + }; + + // Try to concatenate two chains w/o splitting. 
+ tryChainMerging({MergeTypeT::X_Y, MergeTypeT::Y_X}); + + return Gain; + } + + /// Compute the score gain of merging two chains, respecting a given type. + /// + /// The two chains are not modified in the method. + MergeGainT computeMergeGain(ChainT *ChainPred, ChainT *ChainSucc, + const MergedJumpsT &Jumps, + MergeTypeT MergeType) const { + // This doesn't depend on the ordering of the nodes + double FreqGain = freqBasedLocalityGain(ChainPred, ChainSucc); + + // Merge offset is always 0, as the chains are not split. + size_t MergeOffset = 0; + auto MergedBlocks = + mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType); + double DistGain = distBasedLocalityGain(MergedBlocks, Jumps); + + double GainScore = DistGain + Config.FrequencyScale * FreqGain; + // Scale the result to increase the importance of merging short chains. + if (GainScore >= 0.0) + GainScore /= std::min(ChainPred->Size, ChainSucc->Size); + + return MergeGainT(GainScore, MergeOffset, MergeType); + } + + /// Compute the change of the frequency locality after merging the chains. + double freqBasedLocalityGain(ChainT *ChainPred, ChainT *ChainSucc) const { + auto missProbability = [&](double ChainDensity) { + double PageSamples = ChainDensity * Config.CacheSize; + if (PageSamples >= TotalSamples) + return 0.0; + double P = PageSamples / TotalSamples; + return pow(1.0 - P, static_cast<double>(Config.CacheEntries)); + }; + + // Cache misses on the chains before merging. 
+ double CurScore = + ChainPred->ExecutionCount * missProbability(ChainPred->density()) + + ChainSucc->ExecutionCount * missProbability(ChainSucc->density()); + + // Cache misses on the merged chain + double MergedCounts = ChainPred->ExecutionCount + ChainSucc->ExecutionCount; + double MergedSize = ChainPred->Size + ChainSucc->Size; + double MergedDensity = static_cast<double>(MergedCounts) / MergedSize; + double NewScore = MergedCounts * missProbability(MergedDensity); + + return CurScore - NewScore; + } + + /// Compute the distance locality for a jump / call. + double distScore(uint64_t SrcAddr, uint64_t DstAddr, uint64_t Count) const { + uint64_t Dist = SrcAddr <= DstAddr ? DstAddr - SrcAddr : SrcAddr - DstAddr; + double D = Dist == 0 ? 0.1 : static_cast<double>(Dist); + return static_cast<double>(Count) * std::pow(D, -Config.DistancePower); + } + + /// Compute the change of the distance locality after merging the chains. + double distBasedLocalityGain(const MergedNodesT &Nodes, + const MergedJumpsT &Jumps) const { + uint64_t CurAddr = 0; + Nodes.forEach([&](const NodeT *Node) { + Node->EstimatedAddr = CurAddr; + CurAddr += Node->Size; + }); + + double CurScore = 0; + double NewScore = 0; + Jumps.forEach([&](const JumpT *Jump) { + uint64_t SrcAddr = Jump->Source->EstimatedAddr + Jump->Offset; + uint64_t DstAddr = Jump->Target->EstimatedAddr; + NewScore += distScore(SrcAddr, DstAddr, Jump->ExecutionCount); + CurScore += distScore(0, TotalSize, Jump->ExecutionCount); + }); + return NewScore - CurScore; + } + + /// Merge chain From into chain Into, update the list of active chains, + /// adjacency information, and the corresponding cached values. + void mergeChains(ChainT *Into, ChainT *From, size_t MergeOffset, + MergeTypeT MergeType) { + assert(Into != From && "a chain cannot be merged with itself"); + + // Merge the nodes. 
+ MergedNodesT MergedNodes = + mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType); + Into->merge(From, MergedNodes.getNodes()); + + // Merge the edges. + Into->mergeEdges(From); + From->clear(); + } + + /// Concatenate all chains into the final order. + std::vector<uint64_t> concatChains() { + // Collect chains and calculate density stats for their sorting. std::vector<const ChainT *> SortedChains; DenseMap<const ChainT *, double> ChainDensity; for (ChainT &Chain : AllChains) { if (!Chain.Nodes.empty()) { SortedChains.push_back(&Chain); - // Using doubles to avoid overflow of ExecutionCounts + // Using doubles to avoid overflow of ExecutionCounts. double Size = 0; double ExecutionCount = 0; for (NodeT *Node : Chain.Nodes) { @@ -935,30 +1354,29 @@ private: } } - // Sorting chains by density in the decreasing order - std::stable_sort(SortedChains.begin(), SortedChains.end(), - [&](const ChainT *L, const ChainT *R) { - // Make sure the original entry point is at the - // beginning of the order - if (L->isEntry() != R->isEntry()) - return L->isEntry(); - - const double DL = ChainDensity[L]; - const double DR = ChainDensity[R]; - // Compare by density and break ties by chain identifiers - return (DL != DR) ? (DL > DR) : (L->Id < R->Id); - }); + // Sort chains by density in the decreasing order. + std::sort(SortedChains.begin(), SortedChains.end(), + [&](const ChainT *L, const ChainT *R) { + const double DL = ChainDensity[L]; + const double DR = ChainDensity[R]; + // Compare by density and break ties by chain identifiers. + return std::make_tuple(-DL, L->Id) < + std::make_tuple(-DR, R->Id); + }); - // Collect the nodes in the order specified by their chains + // Collect the nodes in the order specified by their chains. 
+ std::vector<uint64_t> Order; Order.reserve(NumNodes); - for (const ChainT *Chain : SortedChains) { - for (NodeT *Node : Chain->Nodes) { + for (const ChainT *Chain : SortedChains) + for (NodeT *Node : Chain->Nodes) Order.push_back(Node->Index); - } - } + return Order; } private: + /// Config for the algorithm. + const CDSortConfig Config; + /// The number of nodes in the graph. const size_t NumNodes; @@ -968,10 +1386,10 @@ private: /// Predecessors of each node. std::vector<std::vector<uint64_t>> PredNodes; - /// All nodes (basic blocks) in the graph. + /// All nodes (functions) in the graph. std::vector<NodeT> AllNodes; - /// All jumps between the nodes. + /// All jumps (function calls) between the nodes. std::vector<JumpT> AllJumps; /// All chains of nodes. @@ -980,65 +1398,95 @@ private: /// All edges between the chains. std::vector<ChainEdge> AllEdges; - /// Active chains. The vector gets updated at runtime when chains are merged. - std::vector<ChainT *> HotChains; + /// The total number of samples in the graph. + uint64_t TotalSamples{0}; + + /// The total size of the nodes in the graph. + uint64_t TotalSize{0}; }; } // end of anonymous namespace std::vector<uint64_t> -llvm::applyExtTspLayout(const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<EdgeCountT> &EdgeCounts) { - // Verify correctness of the input data +codelayout::computeExtTspLayout(ArrayRef<uint64_t> NodeSizes, + ArrayRef<uint64_t> NodeCounts, + ArrayRef<EdgeCount> EdgeCounts) { + // Verify correctness of the input data. assert(NodeCounts.size() == NodeSizes.size() && "Incorrect input"); assert(NodeSizes.size() > 2 && "Incorrect input"); - // Apply the reordering algorithm + // Apply the reordering algorithm. ExtTSPImpl Alg(NodeSizes, NodeCounts, EdgeCounts); - std::vector<uint64_t> Result; - Alg.run(Result); + std::vector<uint64_t> Result = Alg.run(); - // Verify correctness of the output + // Verify correctness of the output. 
assert(Result.front() == 0 && "Original entry point is not preserved"); assert(Result.size() == NodeSizes.size() && "Incorrect size of layout"); return Result; } -double llvm::calcExtTspScore(const std::vector<uint64_t> &Order, - const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<EdgeCountT> &EdgeCounts) { - // Estimate addresses of the blocks in memory +double codelayout::calcExtTspScore(ArrayRef<uint64_t> Order, + ArrayRef<uint64_t> NodeSizes, + ArrayRef<uint64_t> NodeCounts, + ArrayRef<EdgeCount> EdgeCounts) { + // Estimate addresses of the blocks in memory. std::vector<uint64_t> Addr(NodeSizes.size(), 0); for (size_t Idx = 1; Idx < Order.size(); Idx++) { Addr[Order[Idx]] = Addr[Order[Idx - 1]] + NodeSizes[Order[Idx - 1]]; } std::vector<uint64_t> OutDegree(NodeSizes.size(), 0); - for (auto It : EdgeCounts) { - uint64_t Pred = It.first.first; - OutDegree[Pred]++; - } + for (auto Edge : EdgeCounts) + ++OutDegree[Edge.src]; - // Increase the score for each jump + // Increase the score for each jump. 
double Score = 0; - for (auto It : EdgeCounts) { - uint64_t Pred = It.first.first; - uint64_t Succ = It.first.second; - uint64_t Count = It.second; - bool IsConditional = OutDegree[Pred] > 1; - Score += ::extTSPScore(Addr[Pred], NodeSizes[Pred], Addr[Succ], Count, - IsConditional); + for (auto Edge : EdgeCounts) { + bool IsConditional = OutDegree[Edge.src] > 1; + Score += ::extTSPScore(Addr[Edge.src], NodeSizes[Edge.src], Addr[Edge.dst], + Edge.count, IsConditional); } return Score; } -double llvm::calcExtTspScore(const std::vector<uint64_t> &NodeSizes, - const std::vector<uint64_t> &NodeCounts, - const std::vector<EdgeCountT> &EdgeCounts) { +double codelayout::calcExtTspScore(ArrayRef<uint64_t> NodeSizes, + ArrayRef<uint64_t> NodeCounts, + ArrayRef<EdgeCount> EdgeCounts) { std::vector<uint64_t> Order(NodeSizes.size()); for (size_t Idx = 0; Idx < NodeSizes.size(); Idx++) { Order[Idx] = Idx; } return calcExtTspScore(Order, NodeSizes, NodeCounts, EdgeCounts); } + +std::vector<uint64_t> codelayout::computeCacheDirectedLayout( + const CDSortConfig &Config, ArrayRef<uint64_t> FuncSizes, + ArrayRef<uint64_t> FuncCounts, ArrayRef<EdgeCount> CallCounts, + ArrayRef<uint64_t> CallOffsets) { + // Verify correctness of the input data. + assert(FuncCounts.size() == FuncSizes.size() && "Incorrect input"); + + // Apply the reordering algorithm. + CDSortImpl Alg(Config, FuncSizes, FuncCounts, CallCounts, CallOffsets); + std::vector<uint64_t> Result = Alg.run(); + assert(Result.size() == FuncSizes.size() && "Incorrect size of layout"); + return Result; +} + +std::vector<uint64_t> codelayout::computeCacheDirectedLayout( + ArrayRef<uint64_t> FuncSizes, ArrayRef<uint64_t> FuncCounts, + ArrayRef<EdgeCount> CallCounts, ArrayRef<uint64_t> CallOffsets) { + CDSortConfig Config; + // Populate the config from the command-line options. 
+ if (CacheEntries.getNumOccurrences() > 0) + Config.CacheEntries = CacheEntries; + if (CacheSize.getNumOccurrences() > 0) + Config.CacheSize = CacheSize; + if (CDMaxChainSize.getNumOccurrences() > 0) + Config.MaxChainSize = CDMaxChainSize; + if (DistancePower.getNumOccurrences() > 0) + Config.DistancePower = DistancePower; + if (FrequencyScale.getNumOccurrences() > 0) + Config.FrequencyScale = FrequencyScale; + return computeCacheDirectedLayout(Config, FuncSizes, FuncCounts, CallCounts, + CallOffsets); +} diff --git a/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp b/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp index 4a6719741719..6a2dae5bab68 100644 --- a/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp +++ b/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp @@ -417,7 +417,7 @@ void llvm::moveInstructionsToTheBeginning(BasicBlock &FromBB, BasicBlock &ToBB, Instruction *MovePos = ToBB.getFirstNonPHIOrDbg(); if (isSafeToMoveBefore(I, *MovePos, DT, &PDT, &DI)) - I.moveBefore(MovePos); + I.moveBeforePreserving(MovePos); } } @@ -429,7 +429,7 @@ void llvm::moveInstructionsToTheEnd(BasicBlock &FromBB, BasicBlock &ToBB, while (FromBB.size() > 1) { Instruction &I = FromBB.front(); if (isSafeToMoveBefore(I, *MovePos, DT, &PDT, &DI)) - I.moveBefore(MovePos); + I.moveBeforePreserving(MovePos); } } diff --git a/llvm/lib/Transforms/Utils/CtorUtils.cpp b/llvm/lib/Transforms/Utils/CtorUtils.cpp index e07c92df2265..507729bc5ebc 100644 --- a/llvm/lib/Transforms/Utils/CtorUtils.cpp +++ b/llvm/lib/Transforms/Utils/CtorUtils.cpp @@ -52,12 +52,9 @@ static void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemov NGV->takeName(GCL); // Nuke the old list, replacing any uses with the new one. 
- if (!GCL->use_empty()) { - Constant *V = NGV; - if (V->getType() != GCL->getType()) - V = ConstantExpr::getBitCast(V, GCL->getType()); - GCL->replaceAllUsesWith(V); - } + if (!GCL->use_empty()) + GCL->replaceAllUsesWith(NGV); + GCL->eraseFromParent(); } diff --git a/llvm/lib/Transforms/Utils/DXILUpgrade.cpp b/llvm/lib/Transforms/Utils/DXILUpgrade.cpp new file mode 100644 index 000000000000..735686ddce38 --- /dev/null +++ b/llvm/lib/Transforms/Utils/DXILUpgrade.cpp @@ -0,0 +1,36 @@ +//===- DXILUpgrade.cpp - Upgrade DXIL metadata to LLVM constructs ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/DXILUpgrade.h" + +using namespace llvm; + +static bool handleValVerMetadata(Module &M) { + NamedMDNode *ValVer = M.getNamedMetadata("dx.valver"); + if (!ValVer) + return false; + + // We don't need the validation version internally, so we drop it. + ValVer->dropAllReferences(); + ValVer->eraseFromParent(); + return true; +} + +PreservedAnalyses DXILUpgradePass::run(Module &M, ModuleAnalysisManager &AM) { + PreservedAnalyses PA; + // We never add, remove, or change functions here. + PA.preserve<FunctionAnalysisManagerModuleProxy>(); + PA.preserveSet<AllAnalysesOn<Function>>(); + + bool Changed = false; + Changed |= handleValVerMetadata(M); + + if (!Changed) + return PreservedAnalyses::all(); + return PA; +} diff --git a/llvm/lib/Transforms/Utils/Debugify.cpp b/llvm/lib/Transforms/Utils/Debugify.cpp index 93cad0888a56..d0cc603426d2 100644 --- a/llvm/lib/Transforms/Utils/Debugify.cpp +++ b/llvm/lib/Transforms/Utils/Debugify.cpp @@ -801,7 +801,15 @@ bool checkDebugifyMetadata(Module &M, /// legacy module pass manager. 
struct DebugifyModulePass : public ModulePass { bool runOnModule(Module &M) override { - return applyDebugify(M, Mode, DebugInfoBeforePass, NameOfWrappedPass); + bool NewDebugMode = M.IsNewDbgInfoFormat; + if (NewDebugMode) + M.convertFromNewDbgValues(); + + bool Result = applyDebugify(M, Mode, DebugInfoBeforePass, NameOfWrappedPass); + + if (NewDebugMode) + M.convertToNewDbgValues(); + return Result; } DebugifyModulePass(enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo, @@ -826,7 +834,15 @@ private: /// single function, used with the legacy module pass manager. struct DebugifyFunctionPass : public FunctionPass { bool runOnFunction(Function &F) override { - return applyDebugify(F, Mode, DebugInfoBeforePass, NameOfWrappedPass); + bool NewDebugMode = F.IsNewDbgInfoFormat; + if (NewDebugMode) + F.convertFromNewDbgValues(); + + bool Result = applyDebugify(F, Mode, DebugInfoBeforePass, NameOfWrappedPass); + + if (NewDebugMode) + F.convertToNewDbgValues(); + return Result; } DebugifyFunctionPass( @@ -852,13 +868,24 @@ private: /// legacy module pass manager. struct CheckDebugifyModulePass : public ModulePass { bool runOnModule(Module &M) override { + bool NewDebugMode = M.IsNewDbgInfoFormat; + if (NewDebugMode) + M.convertFromNewDbgValues(); + + bool Result; if (Mode == DebugifyMode::SyntheticDebugInfo) - return checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass, + Result = checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass, "CheckModuleDebugify", Strip, StatsMap); - return checkDebugInfoMetadata( + else + Result = checkDebugInfoMetadata( M, M.functions(), *DebugInfoBeforePass, "CheckModuleDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); + + if (NewDebugMode) + M.convertToNewDbgValues(); + + return Result; } CheckDebugifyModulePass( @@ -891,16 +918,26 @@ private: /// with the legacy module pass manager. 
struct CheckDebugifyFunctionPass : public FunctionPass { bool runOnFunction(Function &F) override { + bool NewDebugMode = F.IsNewDbgInfoFormat; + if (NewDebugMode) + F.convertFromNewDbgValues(); + Module &M = *F.getParent(); auto FuncIt = F.getIterator(); + bool Result; if (Mode == DebugifyMode::SyntheticDebugInfo) - return checkDebugifyMetadata(M, make_range(FuncIt, std::next(FuncIt)), + Result = checkDebugifyMetadata(M, make_range(FuncIt, std::next(FuncIt)), NameOfWrappedPass, "CheckFunctionDebugify", Strip, StatsMap); - return checkDebugInfoMetadata( + else + Result = checkDebugInfoMetadata( M, make_range(FuncIt, std::next(FuncIt)), *DebugInfoBeforePass, "CheckFunctionDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); + + if (NewDebugMode) + F.convertToNewDbgValues(); + return Result; } CheckDebugifyFunctionPass( @@ -972,6 +1009,10 @@ createDebugifyFunctionPass(enum DebugifyMode Mode, } PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) { + bool NewDebugMode = M.IsNewDbgInfoFormat; + if (NewDebugMode) + M.convertFromNewDbgValues(); + if (Mode == DebugifyMode::SyntheticDebugInfo) applyDebugifyMetadata(M, M.functions(), "ModuleDebugify: ", /*ApplyToMF*/ nullptr); @@ -979,6 +1020,10 @@ PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) { collectDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass, "ModuleDebugify (original debuginfo)", NameOfWrappedPass); + + if (NewDebugMode) + M.convertToNewDbgValues(); + PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); return PA; @@ -1010,6 +1055,10 @@ FunctionPass *createCheckDebugifyFunctionPass( PreservedAnalyses NewPMCheckDebugifyPass::run(Module &M, ModuleAnalysisManager &) { + bool NewDebugMode = M.IsNewDbgInfoFormat; + if (NewDebugMode) + M.convertFromNewDbgValues(); + if (Mode == DebugifyMode::SyntheticDebugInfo) checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass, "CheckModuleDebugify", Strip, StatsMap); @@ -1018,6 
+1067,10 @@ PreservedAnalyses NewPMCheckDebugifyPass::run(Module &M, M, M.functions(), *DebugInfoBeforePass, "CheckModuleDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); + + if (NewDebugMode) + M.convertToNewDbgValues(); + return PreservedAnalyses::all(); } @@ -1035,13 +1088,13 @@ void DebugifyEachInstrumentation::registerCallbacks( return; PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); - if (const auto **CF = any_cast<const Function *>(&IR)) { + if (const auto **CF = llvm::any_cast<const Function *>(&IR)) { Function &F = *const_cast<Function *>(*CF); applyDebugify(F, Mode, DebugInfoBeforePass, P); MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent()) .getManager() .invalidate(F, PA); - } else if (const auto **CM = any_cast<const Module *>(&IR)) { + } else if (const auto **CM = llvm::any_cast<const Module *>(&IR)) { Module &M = *const_cast<Module *>(*CM); applyDebugify(M, Mode, DebugInfoBeforePass, P); MAM.invalidate(M, PA); @@ -1053,7 +1106,7 @@ void DebugifyEachInstrumentation::registerCallbacks( return; PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); - if (const auto **CF = any_cast<const Function *>(&IR)) { + if (const auto **CF = llvm::any_cast<const Function *>(&IR)) { auto &F = *const_cast<Function *>(*CF); Module &M = *F.getParent(); auto It = F.getIterator(); @@ -1069,7 +1122,7 @@ void DebugifyEachInstrumentation::registerCallbacks( MAM.getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent()) .getManager() .invalidate(F, PA); - } else if (const auto **CM = any_cast<const Module *>(&IR)) { + } else if (const auto **CM = llvm::any_cast<const Module *>(&IR)) { Module &M = *const_cast<Module *>(*CM); if (Mode == DebugifyMode::SyntheticDebugInfo) checkDebugifyMetadata(M, M.functions(), P, "CheckModuleDebugify", diff --git a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp index d424ebbef99d..092f1799755d 100644 --- 
a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -35,7 +35,7 @@ static void insertCall(Function &CurFn, StringRef Func, Triple TargetTriple(M.getTargetTriple()); if (TargetTriple.isOSAIX() && Func == "__mcount") { Type *SizeTy = M.getDataLayout().getIntPtrType(C); - Type *SizePtrTy = SizeTy->getPointerTo(); + Type *SizePtrTy = PointerType::getUnqual(C); GlobalVariable *GV = new GlobalVariable(M, SizeTy, /*isConstant=*/false, GlobalValue::InternalLinkage, ConstantInt::get(SizeTy, 0)); @@ -54,7 +54,7 @@ static void insertCall(Function &CurFn, StringRef Func, } if (Func == "__cyg_profile_func_enter" || Func == "__cyg_profile_func_exit") { - Type *ArgTypes[] = {Type::getInt8PtrTy(C), Type::getInt8PtrTy(C)}; + Type *ArgTypes[] = {PointerType::getUnqual(C), PointerType::getUnqual(C)}; FunctionCallee Fn = M.getOrInsertFunction( Func, FunctionType::get(Type::getVoidTy(C), ArgTypes, false)); @@ -65,9 +65,7 @@ static void insertCall(Function &CurFn, StringRef Func, InsertionPt); RetAddr->setDebugLoc(DL); - Value *Args[] = {ConstantExpr::getBitCast(&CurFn, Type::getInt8PtrTy(C)), - RetAddr}; - + Value *Args[] = {&CurFn, RetAddr}; CallInst *Call = CallInst::Create(Fn, ArrayRef<Value *>(Args), "", InsertionPt); Call->setDebugLoc(DL); diff --git a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp index 88c838685bca..cc00106fcbfe 100644 --- a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp +++ b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -70,7 +70,7 @@ IRBuilder<> *EscapeEnumerator::Next() { // Create a cleanup block. 
LLVMContext &C = F.getContext(); BasicBlock *CleanupBB = BasicBlock::Create(C, CleanupBBName, &F); - Type *ExnTy = StructType::get(Type::getInt8PtrTy(C), Type::getInt32Ty(C)); + Type *ExnTy = StructType::get(PointerType::getUnqual(C), Type::getInt32Ty(C)); if (!F.hasPersonalityFn()) { FunctionCallee PersFn = getDefaultPersonalityFn(F.getParent()); F.setPersonalityFn(cast<Constant>(PersFn.getCallee())); diff --git a/llvm/lib/Transforms/Utils/FixIrreducible.cpp b/llvm/lib/Transforms/Utils/FixIrreducible.cpp index dda236167363..11e24d0585be 100644 --- a/llvm/lib/Transforms/Utils/FixIrreducible.cpp +++ b/llvm/lib/Transforms/Utils/FixIrreducible.cpp @@ -87,10 +87,8 @@ struct FixIrreducible : public FunctionPass { } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreservedID(LowerSwitchID); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); } @@ -106,7 +104,6 @@ FunctionPass *llvm::createFixIrreduciblePass() { return new FixIrreducible(); } INITIALIZE_PASS_BEGIN(FixIrreducible, "fix-irreducible", "Convert irreducible control-flow into natural loops", false /* Only looks at CFG */, false /* Analysis Pass */) -INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(FixIrreducible, "fix-irreducible", @@ -317,6 +314,8 @@ static bool FixIrreducibleImpl(Function &F, LoopInfo &LI, DominatorTree &DT) { LLVM_DEBUG(dbgs() << "===== Fix irreducible control-flow in function: " << F.getName() << "\n"); + assert(hasOnlySimpleTerminator(F) && "Unsupported block terminator."); + bool Changed = false; SmallVector<Loop *, 8> WorkList; diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp index 8daeb92130ba..79ca99d1566c 100644 --- 
a/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -160,10 +160,23 @@ int FunctionComparator::cmpAttrs(const AttributeList L, int FunctionComparator::cmpMetadata(const Metadata *L, const Metadata *R) const { // TODO: the following routine coerce the metadata contents into constants - // before comparison. + // or MDStrings before comparison. // It ignores any other cases, so that the metadata nodes are considered // equal even though this is not correct. // We should structurally compare the metadata nodes to be perfect here. + + auto *MDStringL = dyn_cast<MDString>(L); + auto *MDStringR = dyn_cast<MDString>(R); + if (MDStringL && MDStringR) { + if (MDStringL == MDStringR) + return 0; + return MDStringL->getString().compare(MDStringR->getString()); + } + if (MDStringR) + return -1; + if (MDStringL) + return 1; + auto *CL = dyn_cast<ConstantAsMetadata>(L); auto *CR = dyn_cast<ConstantAsMetadata>(R); if (CL == CR) @@ -820,6 +833,21 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) const { if (ConstR) return -1; + const MetadataAsValue *MetadataValueL = dyn_cast<MetadataAsValue>(L); + const MetadataAsValue *MetadataValueR = dyn_cast<MetadataAsValue>(R); + if (MetadataValueL && MetadataValueR) { + if (MetadataValueL == MetadataValueR) + return 0; + + return cmpMetadata(MetadataValueL->getMetadata(), + MetadataValueR->getMetadata()); + } + + if (MetadataValueL) + return 1; + if (MetadataValueR) + return -1; + const InlineAsm *InlineAsmL = dyn_cast<InlineAsm>(L); const InlineAsm *InlineAsmR = dyn_cast<InlineAsm>(R); @@ -958,67 +986,3 @@ int FunctionComparator::compare() { } return 0; } - -namespace { - -// Accumulate the hash of a sequence of 64-bit integers. This is similar to a -// hash of a sequence of 64bit ints, but the entire input does not need to be -// available at once. 
This interface is necessary for functionHash because it -// needs to accumulate the hash as the structure of the function is traversed -// without saving these values to an intermediate buffer. This form of hashing -// is not often needed, as usually the object to hash is just read from a -// buffer. -class HashAccumulator64 { - uint64_t Hash; - -public: - // Initialize to random constant, so the state isn't zero. - HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; } - - void add(uint64_t V) { Hash = hashing::detail::hash_16_bytes(Hash, V); } - - // No finishing is required, because the entire hash value is used. - uint64_t getHash() { return Hash; } -}; - -} // end anonymous namespace - -// A function hash is calculated by considering only the number of arguments and -// whether a function is varargs, the order of basic blocks (given by the -// successors of each basic block in depth first order), and the order of -// opcodes of each instruction within each of these basic blocks. This mirrors -// the strategy compare() uses to compare functions by walking the BBs in depth -// first order and comparing each instruction in sequence. Because this hash -// does not look at the operands, it is insensitive to things such as the -// target of calls and the constants used in the function, which makes it useful -// when possibly merging functions which are the same modulo constants and call -// targets. -FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { - HashAccumulator64 H; - H.add(F.isVarArg()); - H.add(F.arg_size()); - - SmallVector<const BasicBlock *, 8> BBs; - SmallPtrSet<const BasicBlock *, 16> VisitedBBs; - - // Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(), - // accumulating the hash of the function "structure." 
(BB and opcode sequence) - BBs.push_back(&F.getEntryBlock()); - VisitedBBs.insert(BBs[0]); - while (!BBs.empty()) { - const BasicBlock *BB = BBs.pop_back_val(); - // This random value acts as a block header, as otherwise the partition of - // opcodes into BBs wouldn't affect the hash, only the order of the opcodes - H.add(45798); - for (const auto &Inst : *BB) { - H.add(Inst.getOpcode()); - } - const Instruction *Term = BB->getTerminator(); - for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { - if (!VisitedBBs.insert(Term->getSuccessor(i)).second) - continue; - BBs.push_back(Term->getSuccessor(i)); - } - } - return H.getHash(); -} diff --git a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp index dab0be3a9fde..0990c750af55 100644 --- a/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp +++ b/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp @@ -91,18 +91,16 @@ static void addMappingsFromTLI(const TargetLibraryInfo &TLI, CallInst &CI) { Mappings.end()); auto AddVariantDecl = [&](const ElementCount &VF, bool Predicate) { - const std::string TLIName = - std::string(TLI.getVectorizedFunction(ScalarName, VF, Predicate)); - if (!TLIName.empty()) { - std::string MangledName = VFABI::mangleTLIVectorName( - TLIName, ScalarName, CI.arg_size(), VF, Predicate); + const VecDesc *VD = TLI.getVectorMappingInfo(ScalarName, VF, Predicate); + if (VD && !VD->getVectorFnName().empty()) { + std::string MangledName = VD->getVectorFunctionABIVariantString(); if (!OriginalSetOfMappings.count(MangledName)) { Mappings.push_back(MangledName); ++NumCallInjected; } - Function *VariantF = M->getFunction(TLIName); + Function *VariantF = M->getFunction(VD->getVectorFnName()); if (!VariantF) - addVariantDeclaration(CI, VF, Predicate, TLIName); + addVariantDeclaration(CI, VF, Predicate, VD->getVectorFnName()); } }; diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 
f7b93fc8fd06..39d5f6e53c1d 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -30,6 +30,7 @@ #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -189,20 +190,21 @@ BasicBlock *LandingPadInliningInfo::getInnerResumeDest() { const unsigned PHICapacity = 2; // Create corresponding new PHIs for all the PHIs in the outer landing pad. - Instruction *InsertPoint = &InnerResumeDest->front(); + BasicBlock::iterator InsertPoint = InnerResumeDest->begin(); BasicBlock::iterator I = OuterResumeDest->begin(); for (unsigned i = 0, e = UnwindDestPHIValues.size(); i != e; ++i, ++I) { PHINode *OuterPHI = cast<PHINode>(I); PHINode *InnerPHI = PHINode::Create(OuterPHI->getType(), PHICapacity, - OuterPHI->getName() + ".lpad-body", - InsertPoint); + OuterPHI->getName() + ".lpad-body"); + InnerPHI->insertBefore(InsertPoint); OuterPHI->replaceAllUsesWith(InnerPHI); InnerPHI->addIncoming(OuterPHI, OuterResumeDest); } // Create a PHI for the exception values. 
- InnerEHValuesPHI = PHINode::Create(CallerLPad->getType(), PHICapacity, - "eh.lpad-body", InsertPoint); + InnerEHValuesPHI = + PHINode::Create(CallerLPad->getType(), PHICapacity, "eh.lpad-body"); + InnerEHValuesPHI->insertBefore(InsertPoint); CallerLPad->replaceAllUsesWith(InnerEHValuesPHI); InnerEHValuesPHI->addIncoming(CallerLPad, OuterResumeDest); @@ -1331,38 +1333,51 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap, } } -static bool MayContainThrowingOrExitingCall(Instruction *Begin, - Instruction *End) { +static bool MayContainThrowingOrExitingCallAfterCB(CallBase *Begin, + ReturnInst *End) { assert(Begin->getParent() == End->getParent() && "Expected to be in same basic block!"); + auto BeginIt = Begin->getIterator(); + assert(BeginIt != End->getIterator() && "Non-empty BB has empty iterator"); return !llvm::isGuaranteedToTransferExecutionToSuccessor( - Begin->getIterator(), End->getIterator(), InlinerAttributeWindow + 1); + ++BeginIt, End->getIterator(), InlinerAttributeWindow + 1); } -static AttrBuilder IdentifyValidAttributes(CallBase &CB) { +// Only allow these white listed attributes to be propagated back to the +// callee. This is because other attributes may only be valid on the call +// itself, i.e. attributes such as signext and zeroext. - AttrBuilder AB(CB.getContext(), CB.getAttributes().getRetAttrs()); - if (!AB.hasAttributes()) - return AB; +// Attributes that are always okay to propagate as if they are violated its +// immediate UB. +static AttrBuilder IdentifyValidUBGeneratingAttributes(CallBase &CB) { AttrBuilder Valid(CB.getContext()); - // Only allow these white listed attributes to be propagated back to the - // callee. This is because other attributes may only be valid on the call - // itself, i.e. attributes such as signext and zeroext. 
- if (auto DerefBytes = AB.getDereferenceableBytes()) + if (auto DerefBytes = CB.getRetDereferenceableBytes()) Valid.addDereferenceableAttr(DerefBytes); - if (auto DerefOrNullBytes = AB.getDereferenceableOrNullBytes()) + if (auto DerefOrNullBytes = CB.getRetDereferenceableOrNullBytes()) Valid.addDereferenceableOrNullAttr(DerefOrNullBytes); - if (AB.contains(Attribute::NoAlias)) + if (CB.hasRetAttr(Attribute::NoAlias)) Valid.addAttribute(Attribute::NoAlias); - if (AB.contains(Attribute::NonNull)) + if (CB.hasRetAttr(Attribute::NoUndef)) + Valid.addAttribute(Attribute::NoUndef); + return Valid; +} + +// Attributes that need additional checks as propagating them may change +// behavior or cause new UB. +static AttrBuilder IdentifyValidPoisonGeneratingAttributes(CallBase &CB) { + AttrBuilder Valid(CB.getContext()); + if (CB.hasRetAttr(Attribute::NonNull)) Valid.addAttribute(Attribute::NonNull); + if (CB.hasRetAttr(Attribute::Alignment)) + Valid.addAlignmentAttr(CB.getRetAlign()); return Valid; } static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { - AttrBuilder Valid = IdentifyValidAttributes(CB); - if (!Valid.hasAttributes()) + AttrBuilder ValidUB = IdentifyValidUBGeneratingAttributes(CB); + AttrBuilder ValidPG = IdentifyValidPoisonGeneratingAttributes(CB); + if (!ValidUB.hasAttributes() && !ValidPG.hasAttributes()) return; auto *CalledFunction = CB.getCalledFunction(); auto &Context = CalledFunction->getContext(); @@ -1397,7 +1412,7 @@ static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { // limit the check to both RetVal and RI are in the same basic block and // there are no throwing/exiting instructions between these instructions. if (RI->getParent() != RetVal->getParent() || - MayContainThrowingOrExitingCall(RetVal, RI)) + MayContainThrowingOrExitingCallAfterCB(RetVal, RI)) continue; // Add to the existing attributes of NewRetVal, i.e. the cloned call // instruction. 
@@ -1406,7 +1421,62 @@ static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { // existing attribute value (i.e. attributes such as dereferenceable, // dereferenceable_or_null etc). See AttrBuilder::merge for more details. AttributeList AL = NewRetVal->getAttributes(); - AttributeList NewAL = AL.addRetAttributes(Context, Valid); + if (ValidUB.getDereferenceableBytes() < AL.getRetDereferenceableBytes()) + ValidUB.removeAttribute(Attribute::Dereferenceable); + if (ValidUB.getDereferenceableOrNullBytes() < + AL.getRetDereferenceableOrNullBytes()) + ValidUB.removeAttribute(Attribute::DereferenceableOrNull); + AttributeList NewAL = AL.addRetAttributes(Context, ValidUB); + // Attributes that may generate poison returns are a bit tricky. If we + // propagate them, other uses of the callsite might have their behavior + // change or cause UB (if they have noundef) b.c of the new potential + // poison. + // Take the following three cases: + // + // 1) + // define nonnull ptr @foo() { + // %p = call ptr @bar() + // call void @use(ptr %p) willreturn nounwind + // ret ptr %p + // } + // + // 2) + // define noundef nonnull ptr @foo() { + // %p = call ptr @bar() + // call void @use(ptr %p) willreturn nounwind + // ret ptr %p + // } + // + // 3) + // define nonnull ptr @foo() { + // %p = call noundef ptr @bar() + // ret ptr %p + // } + // + // In case 1, we can't propagate nonnull because poison value in @use may + // change behavior or trigger UB. + // In case 2, we don't need to be concerned about propagating nonnull, as + // any new poison at @use will trigger UB anyways. + // In case 3, we can never propagate nonnull because it may create UB due to + // the noundef on @bar. + if (ValidPG.getAlignment().valueOrOne() < AL.getRetAlignment().valueOrOne()) + ValidPG.removeAttribute(Attribute::Alignment); + if (ValidPG.hasAttributes()) { + // Three checks. 
+ // If the callsite has `noundef`, then a poison due to violating the + // return attribute will create UB anyways so we can always propagate. + // Otherwise, if the return value (callee to be inlined) has `noundef`, we + // can't propagate as a new poison return will cause UB. + // Finally, check if the return value has no uses whose behavior may + // change/may cause UB if we potentially return poison. At the moment this + // is implemented overly conservatively with a single-use check. + // TODO: Update the single-use check to iterate through uses and only bail + // if we have a potentially dangerous use. + + if (CB.hasRetAttr(Attribute::NoUndef) || + (RetVal->hasOneUse() && !RetVal->hasRetAttr(Attribute::NoUndef))) + NewAL = NewAL.addRetAttributes(Context, ValidPG); + } NewRetVal->setAttributes(NewAL); } } @@ -1515,10 +1585,10 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, if (ByValAlignment) Alignment = std::max(Alignment, *ByValAlignment); - Value *NewAlloca = - new AllocaInst(ByValType, DL.getAllocaAddrSpace(), nullptr, Alignment, - Arg->getName(), &*Caller->begin()->begin()); - IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca)); + AllocaInst *NewAlloca = new AllocaInst(ByValType, DL.getAllocaAddrSpace(), + nullptr, Alignment, Arg->getName()); + NewAlloca->insertBefore(Caller->begin()->begin()); + IFI.StaticAllocas.push_back(NewAlloca); // Uses of the argument in the function should use our new alloca // instead. @@ -1538,8 +1608,8 @@ static bool isUsedByLifetimeMarker(Value *V) { // lifetime.start or lifetime.end intrinsics. 
static bool hasLifetimeMarkers(AllocaInst *AI) { Type *Ty = AI->getType(); - Type *Int8PtrTy = Type::getInt8PtrTy(Ty->getContext(), - Ty->getPointerAddressSpace()); + Type *Int8PtrTy = + PointerType::get(Ty->getContext(), Ty->getPointerAddressSpace()); if (Ty == Int8PtrTy) return isUsedByLifetimeMarker(AI); @@ -1596,48 +1666,71 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, // the call site location instead. bool NoInlineLineTables = Fn->hasFnAttribute("no-inline-line-tables"); - for (; FI != Fn->end(); ++FI) { - for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); - BI != BE; ++BI) { - // Loop metadata needs to be updated so that the start and end locs - // reference inlined-at locations. - auto updateLoopInfoLoc = [&Ctx, &InlinedAtNode, - &IANodes](Metadata *MD) -> Metadata * { - if (auto *Loc = dyn_cast_or_null<DILocation>(MD)) - return inlineDebugLoc(Loc, InlinedAtNode, Ctx, IANodes).get(); - return MD; - }; - updateLoopMetadataDebugLocations(*BI, updateLoopInfoLoc); + // Helper-util for updating the metadata attached to an instruction. + auto UpdateInst = [&](Instruction &I) { + // Loop metadata needs to be updated so that the start and end locs + // reference inlined-at locations. 
+ auto updateLoopInfoLoc = [&Ctx, &InlinedAtNode, + &IANodes](Metadata *MD) -> Metadata * { + if (auto *Loc = dyn_cast_or_null<DILocation>(MD)) + return inlineDebugLoc(Loc, InlinedAtNode, Ctx, IANodes).get(); + return MD; + }; + updateLoopMetadataDebugLocations(I, updateLoopInfoLoc); - if (!NoInlineLineTables) - if (DebugLoc DL = BI->getDebugLoc()) { - DebugLoc IDL = - inlineDebugLoc(DL, InlinedAtNode, BI->getContext(), IANodes); - BI->setDebugLoc(IDL); - continue; - } + if (!NoInlineLineTables) + if (DebugLoc DL = I.getDebugLoc()) { + DebugLoc IDL = + inlineDebugLoc(DL, InlinedAtNode, I.getContext(), IANodes); + I.setDebugLoc(IDL); + return; + } - if (CalleeHasDebugInfo && !NoInlineLineTables) - continue; + if (CalleeHasDebugInfo && !NoInlineLineTables) + return; - // If the inlined instruction has no line number, or if inline info - // is not being generated, make it look as if it originates from the call - // location. This is important for ((__always_inline, __nodebug__)) - // functions which must use caller location for all instructions in their - // function body. + // If the inlined instruction has no line number, or if inline info + // is not being generated, make it look as if it originates from the call + // location. This is important for ((__always_inline, __nodebug__)) + // functions which must use caller location for all instructions in their + // function body. - // Don't update static allocas, as they may get moved later. - if (auto *AI = dyn_cast<AllocaInst>(BI)) - if (allocaWouldBeStaticInEntry(AI)) - continue; + // Don't update static allocas, as they may get moved later. + if (auto *AI = dyn_cast<AllocaInst>(&I)) + if (allocaWouldBeStaticInEntry(AI)) + return; - // Do not force a debug loc for pseudo probes, since they do not need to - // be debuggable, and also they are expected to have a zero/null dwarf - // discriminator at this point which could be violated otherwise. 
- if (isa<PseudoProbeInst>(BI)) - continue; + // Do not force a debug loc for pseudo probes, since they do not need to + // be debuggable, and also they are expected to have a zero/null dwarf + // discriminator at this point which could be violated otherwise. + if (isa<PseudoProbeInst>(I)) + return; - BI->setDebugLoc(TheCallDL); + I.setDebugLoc(TheCallDL); + }; + + // Helper-util for updating debug-info records attached to instructions. + auto UpdateDPV = [&](DPValue *DPV) { + assert(DPV->getDebugLoc() && "Debug Value must have debug loc"); + if (NoInlineLineTables) { + DPV->setDebugLoc(TheCallDL); + return; + } + DebugLoc DL = DPV->getDebugLoc(); + DebugLoc IDL = + inlineDebugLoc(DL, InlinedAtNode, + DPV->getMarker()->getParent()->getContext(), IANodes); + DPV->setDebugLoc(IDL); + }; + + // Iterate over all instructions, updating metadata and debug-info records. + for (; FI != Fn->end(); ++FI) { + for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; + ++BI) { + UpdateInst(*BI); + for (DPValue &DPV : BI->getDbgValueRange()) { + UpdateDPV(&DPV); + } } // Remove debug info intrinsics if we're not keeping inline info. @@ -1647,11 +1740,12 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, if (isa<DbgInfoIntrinsic>(BI)) { BI = BI->eraseFromParent(); continue; + } else { + BI->dropDbgValues(); } ++BI; } } - } } @@ -1760,12 +1854,12 @@ static void updateCallerBFI(BasicBlock *CallSiteBlock, continue; auto *OrigBB = cast<BasicBlock>(Entry.first); auto *ClonedBB = cast<BasicBlock>(Entry.second); - uint64_t Freq = CalleeBFI->getBlockFreq(OrigBB).getFrequency(); + BlockFrequency Freq = CalleeBFI->getBlockFreq(OrigBB); if (!ClonedBBs.insert(ClonedBB).second) { // Multiple blocks in the callee might get mapped to one cloned block in // the caller since we prune the callee as we clone it. When that happens, // we want to use the maximum among the original blocks' frequencies. 
- uint64_t NewFreq = CallerBFI->getBlockFreq(ClonedBB).getFrequency(); + BlockFrequency NewFreq = CallerBFI->getBlockFreq(ClonedBB); if (NewFreq > Freq) Freq = NewFreq; } @@ -1773,8 +1867,7 @@ static void updateCallerBFI(BasicBlock *CallSiteBlock, } BasicBlock *EntryClone = cast<BasicBlock>(VMap.lookup(&CalleeEntryBlock)); CallerBFI->setBlockFreqAndScale( - EntryClone, CallerBFI->getBlockFreq(CallSiteBlock).getFrequency(), - ClonedBBs); + EntryClone, CallerBFI->getBlockFreq(CallSiteBlock), ClonedBBs); } /// Update the branch metadata for cloned call instructions. @@ -1882,8 +1975,7 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind, Builder.SetInsertPoint(II); Function *IFn = Intrinsic::getDeclaration(Mod, Intrinsic::objc_release); - Value *BC = Builder.CreateBitCast(RetOpnd, IFn->getArg(0)->getType()); - Builder.CreateCall(IFn, BC, ""); + Builder.CreateCall(IFn, RetOpnd, ""); } II->eraseFromParent(); InsertRetainCall = false; @@ -1918,8 +2010,7 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind, // to objc_retain. Builder.SetInsertPoint(RI); Function *IFn = Intrinsic::getDeclaration(Mod, Intrinsic::objc_retain); - Value *BC = Builder.CreateBitCast(RetOpnd, IFn->getArg(0)->getType()); - Builder.CreateCall(IFn, BC, ""); + Builder.CreateCall(IFn, RetOpnd, ""); } } } @@ -1953,9 +2044,11 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // The inliner does not know how to inline through calls with operand bundles // in general ... + Value *ConvergenceControlToken = nullptr; if (CB.hasOperandBundles()) { for (int i = 0, e = CB.getNumOperandBundles(); i != e; ++i) { - uint32_t Tag = CB.getOperandBundleAt(i).getTagID(); + auto OBUse = CB.getOperandBundleAt(i); + uint32_t Tag = OBUse.getTagID(); // ... but it knows how to inline through "deopt" operand bundles ... 
if (Tag == LLVMContext::OB_deopt) continue; @@ -1966,11 +2059,37 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, continue; if (Tag == LLVMContext::OB_kcfi) continue; + if (Tag == LLVMContext::OB_convergencectrl) { + ConvergenceControlToken = OBUse.Inputs[0].get(); + continue; + } return InlineResult::failure("unsupported operand bundle"); } } + // FIXME: The check below is redundant and incomplete. According to spec, if a + // convergent call is missing a token, then the caller is using uncontrolled + // convergence. If the callee has an entry intrinsic, then the callee is using + // controlled convergence, and the call cannot be inlined. A proper + // implemenation of this check requires a whole new analysis that identifies + // convergence in every function. For now, we skip that and just do this one + // cursory check. The underlying assumption is that in a compiler flow that + // fully implements convergence control tokens, there is no mixing of + // controlled and uncontrolled convergent operations in the whole program. + if (CB.isConvergent()) { + auto *I = CalledFunc->getEntryBlock().getFirstNonPHI(); + if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) { + if (IntrinsicCall->getIntrinsicID() == + Intrinsic::experimental_convergence_entry) { + if (!ConvergenceControlToken) { + return InlineResult::failure( + "convergent call needs convergencectrl operand"); + } + } + } + } + // If the call to the callee cannot throw, set the 'nounwind' flag on any // calls that we inline. 
bool MarkNoUnwind = CB.doesNotThrow(); @@ -2260,6 +2379,17 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, IFI.GetAssumptionCache(*Caller).registerAssumption(II); } + if (ConvergenceControlToken) { + auto *I = FirstNewBlock->getFirstNonPHI(); + if (auto *IntrinsicCall = dyn_cast<IntrinsicInst>(I)) { + if (IntrinsicCall->getIntrinsicID() == + Intrinsic::experimental_convergence_entry) { + IntrinsicCall->replaceAllUsesWith(ConvergenceControlToken); + IntrinsicCall->eraseFromParent(); + } + } + } + // If there are any alloca instructions in the block that used to be the entry // block for the callee, move them to the entry block of the caller. First // calculate which instruction they should be inserted before. We insert the @@ -2296,6 +2426,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Transfer all of the allocas over in a block. Using splice means // that the instructions aren't removed from the symbol table, then // reinserted. + I.setTailBit(true); Caller->getEntryBlock().splice(InsertPoint, &*FirstNewBlock, AI->getIterator(), I); } @@ -2400,7 +2531,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // `Caller->isPresplitCoroutine()` would affect AlwaysInliner at O0 only. if ((InsertLifetime || Caller->isPresplitCoroutine()) && !IFI.StaticAllocas.empty()) { - IRBuilder<> builder(&FirstNewBlock->front()); + IRBuilder<> builder(&*FirstNewBlock, FirstNewBlock->begin()); for (unsigned ai = 0, ae = IFI.StaticAllocas.size(); ai != ae; ++ai) { AllocaInst *AI = IFI.StaticAllocas[ai]; // Don't mark swifterror allocas. They can't have bitcast uses. @@ -2454,14 +2585,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // If the inlined code contained dynamic alloca instructions, wrap the inlined // code with llvm.stacksave/llvm.stackrestore intrinsics. 
if (InlinedFunctionInfo.ContainsDynamicAllocas) { - Module *M = Caller->getParent(); - // Get the two intrinsics we care about. - Function *StackSave = Intrinsic::getDeclaration(M, Intrinsic::stacksave); - Function *StackRestore=Intrinsic::getDeclaration(M,Intrinsic::stackrestore); - // Insert the llvm.stacksave. CallInst *SavedPtr = IRBuilder<>(&*FirstNewBlock, FirstNewBlock->begin()) - .CreateCall(StackSave, {}, "savedstack"); + .CreateStackSave("savedstack"); // Insert a call to llvm.stackrestore before any return instructions in the // inlined function. @@ -2472,7 +2598,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, continue; if (InlinedDeoptimizeCalls && RI->getParent()->getTerminatingDeoptimizeCall()) continue; - IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr); + IRBuilder<>(RI).CreateStackRestore(SavedPtr); } } @@ -2574,6 +2700,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, Builder.CreateRetVoid(); else Builder.CreateRet(NewDeoptCall); + // Since the ret type is changed, remove the incompatible attributes. + NewDeoptCall->removeRetAttrs( + AttributeFuncs::typeIncompatible(NewDeoptCall->getType())); } // Leave behind the normal returns so we can merge control flow. @@ -2704,8 +2833,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (IFI.CallerBFI) { // Copy original BB's block frequency to AfterCallBB - IFI.CallerBFI->setBlockFreq( - AfterCallBB, IFI.CallerBFI->getBlockFreq(OrigBB).getFrequency()); + IFI.CallerBFI->setBlockFreq(AfterCallBB, + IFI.CallerBFI->getBlockFreq(OrigBB)); } // Change the branch that used to go to AfterCallBB to branch to the first @@ -2731,8 +2860,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // The PHI node should go at the front of the new basic block to merge all // possible incoming values. 
if (!CB.use_empty()) { - PHI = PHINode::Create(RTy, Returns.size(), CB.getName(), - &AfterCallBB->front()); + PHI = PHINode::Create(RTy, Returns.size(), CB.getName()); + PHI->insertBefore(AfterCallBB->begin()); // Anything that used the result of the function call should now use the // PHI node as their operand. CB.replaceAllUsesWith(PHI); diff --git a/llvm/lib/Transforms/Utils/LCSSA.cpp b/llvm/lib/Transforms/Utils/LCSSA.cpp index c36b0533580b..5e0c312fe149 100644 --- a/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -160,7 +160,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, if (SSAUpdate.HasValueForBlock(ExitBB)) continue; PHINode *PN = PHINode::Create(I->getType(), PredCache.size(ExitBB), - I->getName() + ".lcssa", &ExitBB->front()); + I->getName() + ".lcssa"); + PN->insertBefore(ExitBB->begin()); if (InsertedPHIs) InsertedPHIs->push_back(PN); // Get the debug location from the original instruction. @@ -241,7 +242,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, } SmallVector<DbgValueInst *, 4> DbgValues; - llvm::findDbgValues(DbgValues, I); + SmallVector<DPValue *, 4> DPValues; + llvm::findDbgValues(DbgValues, I, &DPValues); // Update pre-existing debug value uses that reside outside the loop. for (auto *DVI : DbgValues) { @@ -257,6 +259,21 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, DVI->replaceVariableLocationOp(I, V); } + // RemoveDIs: copy-paste of block above, using non-instruction debug-info + // records. + for (DPValue *DPV : DPValues) { + BasicBlock *UserBB = DPV->getMarker()->getParent(); + if (InstBB == UserBB || L->contains(UserBB)) + continue; + // We currently only handle debug values residing in blocks that were + // traversed while rewriting the uses. If we inserted just a single PHI, + // we will handle all relevant debug values. + Value *V = AddedPHIs.size() == 1 ? 
AddedPHIs[0] + : SSAUpdate.FindValueForBlock(UserBB); + if (V) + DPV->replaceVariableLocationOp(I, V); + } + // SSAUpdater might have inserted phi-nodes inside other loops. We'll need // to post-process them to keep LCSSA form. for (PHINode *InsertedPN : LocalInsertedPHIs) { diff --git a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index cdcfb5050bff..6220f8509309 100644 --- a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -101,7 +101,7 @@ private: float Val) { Constant *V = ConstantFP::get(BBBuilder.getContext(), APFloat(Val)); if (!Arg->getType()->isFloatTy()) - V = ConstantExpr::getFPExtend(V, Arg->getType()); + V = ConstantFoldCastInstruction(Instruction::FPExt, V, Arg->getType()); if (BBBuilder.GetInsertBlock()->getParent()->hasFnAttribute(Attribute::StrictFP)) BBBuilder.setIsFPConstrained(true); return BBBuilder.CreateFCmp(Cmp, Arg, V); diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index f153ace5d3fc..51f39e0ba0cc 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -69,6 +69,7 @@ #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" @@ -86,6 +87,8 @@ using namespace llvm; using namespace llvm::PatternMatch; +extern cl::opt<bool> UseNewDbgInfoFormat; + #define DEBUG_TYPE "local" STATISTIC(NumRemoved, "Number of unreachable basic blocks removed"); @@ -227,9 +230,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Remove weight for this case. std::swap(Weights[Idx + 1], Weights.back()); Weights.pop_back(); - SI->setMetadata(LLVMContext::MD_prof, - MDBuilder(BB->getContext()). 
- createBranchWeights(Weights)); + setBranchWeights(*SI, Weights); } // Remove this entry. BasicBlock *ParentBB = SI->getParent(); @@ -414,7 +415,7 @@ bool llvm::wouldInstructionBeTriviallyDeadOnUnusedPaths( return wouldInstructionBeTriviallyDead(I, TLI); } -bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, +bool llvm::wouldInstructionBeTriviallyDead(const Instruction *I, const TargetLibraryInfo *TLI) { if (I->isTerminator()) return false; @@ -428,7 +429,7 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, if (isa<DbgVariableIntrinsic>(I)) return false; - if (DbgLabelInst *DLI = dyn_cast<DbgLabelInst>(I)) { + if (const DbgLabelInst *DLI = dyn_cast<DbgLabelInst>(I)) { if (DLI->getLabel()) return false; return true; @@ -443,9 +444,16 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, if (!II) return false; + switch (II->getIntrinsicID()) { + case Intrinsic::experimental_guard: { + // Guards on true are operationally no-ops. In the future we can + // consider more sophisticated tradeoffs for guards considering potential + // for check widening, but for now we keep things simple. + auto *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0)); + return Cond && Cond->isOne(); + } // TODO: These intrinsics are not safe to remove, because this may remove // a well-defined trap. - switch (II->getIntrinsicID()) { case Intrinsic::wasm_trunc_signed: case Intrinsic::wasm_trunc_unsigned: case Intrinsic::ptrauth_auth: @@ -461,7 +469,7 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, // Special case intrinsics that "may have side effects" but can be deleted // when dead. - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { // Safe to delete llvm.stacksave and launder.invariant.group if dead. 
if (II->getIntrinsicID() == Intrinsic::stacksave || II->getIntrinsicID() == Intrinsic::launder_invariant_group) @@ -484,13 +492,9 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, return false; } - // Assumptions are dead if their condition is trivially true. Guards on - // true are operationally no-ops. In the future we can consider more - // sophisticated tradeoffs for guards considering potential for check - // widening, but for now we keep things simple. - if ((II->getIntrinsicID() == Intrinsic::assume && - isAssumeWithEmptyBundle(cast<AssumeInst>(*II))) || - II->getIntrinsicID() == Intrinsic::experimental_guard) { + // Assumptions are dead if their condition is trivially true. + if (II->getIntrinsicID() == Intrinsic::assume && + isAssumeWithEmptyBundle(cast<AssumeInst>(*II))) { if (ConstantInt *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0))) return !Cond->isZero(); @@ -605,10 +609,13 @@ void llvm::RecursivelyDeleteTriviallyDeadInstructions( bool llvm::replaceDbgUsesWithUndef(Instruction *I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; - findDbgUsers(DbgUsers, I); + SmallVector<DPValue *, 1> DPUsers; + findDbgUsers(DbgUsers, I, &DPUsers); for (auto *DII : DbgUsers) DII->setKillLocation(); - return !DbgUsers.empty(); + for (auto *DPV : DPUsers) + DPV->setKillLocation(); + return !DbgUsers.empty() || !DPUsers.empty(); } /// areAllUsesEqual - Check whether the uses of a value are all the same. @@ -847,17 +854,17 @@ static bool CanMergeValues(Value *First, Value *Second) { /// branch to Succ, into Succ. /// /// Assumption: Succ is the single successor for BB. 
-static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { +static bool +CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ, + const SmallPtrSetImpl<BasicBlock *> &BBPreds) { assert(*succ_begin(BB) == Succ && "Succ is not successor of BB!"); LLVM_DEBUG(dbgs() << "Looking to fold " << BB->getName() << " into " << Succ->getName() << "\n"); // Shortcut, if there is only a single predecessor it must be BB and merging // is always safe - if (Succ->getSinglePredecessor()) return true; - - // Make a list of the predecessors of BB - SmallPtrSet<BasicBlock*, 16> BBPreds(pred_begin(BB), pred_end(BB)); + if (Succ->getSinglePredecessor()) + return true; // Look at all the phi nodes in Succ, to see if they present a conflict when // merging these blocks @@ -997,6 +1004,35 @@ static void replaceUndefValuesInPhi(PHINode *PN, } } +// Only when they shares a single common predecessor, return true. +// Only handles cases when BB can't be merged while its predecessors can be +// redirected. +static bool +CanRedirectPredsOfEmptyBBToSucc(BasicBlock *BB, BasicBlock *Succ, + const SmallPtrSetImpl<BasicBlock *> &BBPreds, + const SmallPtrSetImpl<BasicBlock *> &SuccPreds, + BasicBlock *&CommonPred) { + + // There must be phis in BB, otherwise BB will be merged into Succ directly + if (BB->phis().empty() || Succ->phis().empty()) + return false; + + // BB must have predecessors not shared that can be redirected to Succ + if (!BB->hasNPredecessorsOrMore(2)) + return false; + + // Get single common predecessors of both BB and Succ + for (BasicBlock *SuccPred : SuccPreds) { + if (BBPreds.count(SuccPred)) { + if (CommonPred) + return false; + CommonPred = SuccPred; + } + } + + return true; +} + /// Replace a value flowing from a block to a phi with /// potentially multiple instances of that value flowing from the /// block's predecessors to the phi. 
@@ -1004,9 +1040,11 @@ static void replaceUndefValuesInPhi(PHINode *PN, /// \param BB The block with the value flowing into the phi. /// \param BBPreds The predecessors of BB. /// \param PN The phi that we are updating. +/// \param CommonPred The common predecessor of BB and PN's BasicBlock static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB, const PredBlockVector &BBPreds, - PHINode *PN) { + PHINode *PN, + BasicBlock *CommonPred) { Value *OldVal = PN->removeIncomingValue(BB, false); assert(OldVal && "No entry in PHI for Pred BB!"); @@ -1034,26 +1072,39 @@ static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB, // will trigger asserts if we try to clean it up now, without also // simplifying the corresponding conditional branch). BasicBlock *PredBB = OldValPN->getIncomingBlock(i); + + if (PredBB == CommonPred) + continue; + Value *PredVal = OldValPN->getIncomingValue(i); - Value *Selected = selectIncomingValueForBlock(PredVal, PredBB, - IncomingValues); + Value *Selected = + selectIncomingValueForBlock(PredVal, PredBB, IncomingValues); // And add a new incoming value for this predecessor for the // newly retargeted branch. PN->addIncoming(Selected, PredBB); } + if (CommonPred) + PN->addIncoming(OldValPN->getIncomingValueForBlock(CommonPred), BB); + } else { for (unsigned i = 0, e = BBPreds.size(); i != e; ++i) { // Update existing incoming values in PN for this // predecessor of BB. BasicBlock *PredBB = BBPreds[i]; - Value *Selected = selectIncomingValueForBlock(OldVal, PredBB, - IncomingValues); + + if (PredBB == CommonPred) + continue; + + Value *Selected = + selectIncomingValueForBlock(OldVal, PredBB, IncomingValues); // And add a new incoming value for this predecessor for the // newly retargeted branch. 
PN->addIncoming(Selected, PredBB); } + if (CommonPred) + PN->addIncoming(OldVal, BB); } replaceUndefValuesInPhi(PN, IncomingValues); @@ -1064,13 +1115,30 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, assert(BB != &BB->getParent()->getEntryBlock() && "TryToSimplifyUncondBranchFromEmptyBlock called on entry block!"); - // We can't eliminate infinite loops. + // We can't simplify infinite loops. BasicBlock *Succ = cast<BranchInst>(BB->getTerminator())->getSuccessor(0); - if (BB == Succ) return false; + if (BB == Succ) + return false; + + SmallPtrSet<BasicBlock *, 16> BBPreds(pred_begin(BB), pred_end(BB)); + SmallPtrSet<BasicBlock *, 16> SuccPreds(pred_begin(Succ), pred_end(Succ)); - // Check to see if merging these blocks would cause conflicts for any of the - // phi nodes in BB or Succ. If not, we can safely merge. - if (!CanPropagatePredecessorsForPHIs(BB, Succ)) return false; + // The single common predecessor of BB and Succ when BB cannot be killed + BasicBlock *CommonPred = nullptr; + + bool BBKillable = CanPropagatePredecessorsForPHIs(BB, Succ, BBPreds); + + // Even if we can not fold bB into Succ, we may be able to redirect the + // predecessors of BB to Succ. + bool BBPhisMergeable = + BBKillable || + CanRedirectPredsOfEmptyBBToSucc(BB, Succ, BBPreds, SuccPreds, CommonPred); + + if (!BBKillable && !BBPhisMergeable) + return false; + + // Check to see if merging these blocks/phis would cause conflicts for any of + // the phi nodes in BB or Succ. If not, we can safely merge. // Check for cases where Succ has multiple predecessors and a PHI node in BB // has uses which will not disappear when the PHI nodes are merged. 
It is @@ -1099,6 +1167,11 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, } } + if (BBPhisMergeable && CommonPred) + LLVM_DEBUG(dbgs() << "Found Common Predecessor between: " << BB->getName() + << " and " << Succ->getName() << " : " + << CommonPred->getName() << "\n"); + // 'BB' and 'BB->Pred' are loop latches, bail out to presrve inner loop // metadata. // @@ -1171,25 +1244,37 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, if (PredTI->hasMetadata(LLVMContext::MD_loop)) return false; - LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB); + if (BBKillable) + LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB); + else if (BBPhisMergeable) + LLVM_DEBUG(dbgs() << "Merge Phis in Trivial BB: \n" << *BB); SmallVector<DominatorTree::UpdateType, 32> Updates; + if (DTU) { // To avoid processing the same predecessor more than once. SmallPtrSet<BasicBlock *, 8> SeenPreds; - // All predecessors of BB will be moved to Succ. - SmallPtrSet<BasicBlock *, 8> PredsOfSucc(pred_begin(Succ), pred_end(Succ)); + // All predecessors of BB (except the common predecessor) will be moved to + // Succ. Updates.reserve(Updates.size() + 2 * pred_size(BB) + 1); - for (auto *PredOfBB : predecessors(BB)) - // This predecessor of BB may already have Succ as a successor. - if (!PredsOfSucc.contains(PredOfBB)) + + for (auto *PredOfBB : predecessors(BB)) { + // Do not modify those common predecessors of BB and Succ + if (!SuccPreds.contains(PredOfBB)) if (SeenPreds.insert(PredOfBB).second) Updates.push_back({DominatorTree::Insert, PredOfBB, Succ}); + } + SeenPreds.clear(); + for (auto *PredOfBB : predecessors(BB)) - if (SeenPreds.insert(PredOfBB).second) + // When BB cannot be killed, do not remove the edge between BB and + // CommonPred. 
+ if (SeenPreds.insert(PredOfBB).second && PredOfBB != CommonPred) Updates.push_back({DominatorTree::Delete, PredOfBB, BB}); - Updates.push_back({DominatorTree::Delete, BB, Succ}); + + if (BBKillable) + Updates.push_back({DominatorTree::Delete, BB, Succ}); } if (isa<PHINode>(Succ->begin())) { @@ -1201,21 +1286,19 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, // Loop over all of the PHI nodes in the successor of BB. for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { PHINode *PN = cast<PHINode>(I); - - redirectValuesFromPredecessorsToPhi(BB, BBPreds, PN); + redirectValuesFromPredecessorsToPhi(BB, BBPreds, PN, CommonPred); } } if (Succ->getSinglePredecessor()) { // BB is the only predecessor of Succ, so Succ will end up with exactly // the same predecessors BB had. - // Copy over any phi, debug or lifetime instruction. BB->getTerminator()->eraseFromParent(); - Succ->splice(Succ->getFirstNonPHI()->getIterator(), BB); + Succ->splice(Succ->getFirstNonPHIIt(), BB); } else { while (PHINode *PN = dyn_cast<PHINode>(&BB->front())) { - // We explicitly check for such uses in CanPropagatePredecessorsForPHIs. + // We explicitly check for such uses for merging phis. assert(PN->use_empty() && "There shouldn't be any uses here!"); PN->eraseFromParent(); } @@ -1228,26 +1311,42 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, for (BasicBlock *Pred : predecessors(BB)) Pred->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopMD); - // Everything that jumped to BB now goes to Succ. - BB->replaceAllUsesWith(Succ); - if (!Succ->hasName()) Succ->takeName(BB); + if (BBKillable) { + // Everything that jumped to BB now goes to Succ. + BB->replaceAllUsesWith(Succ); - // Clear the successor list of BB to match updates applying to DTU later. 
- if (BB->getTerminator()) - BB->back().eraseFromParent(); - new UnreachableInst(BB->getContext(), BB); - assert(succ_empty(BB) && "The successor list of BB isn't empty before " - "applying corresponding DTU updates."); + if (!Succ->hasName()) + Succ->takeName(BB); + + // Clear the successor list of BB to match updates applying to DTU later. + if (BB->getTerminator()) + BB->back().eraseFromParent(); + + new UnreachableInst(BB->getContext(), BB); + assert(succ_empty(BB) && "The successor list of BB isn't empty before " + "applying corresponding DTU updates."); + } else if (BBPhisMergeable) { + // Everything except CommonPred that jumped to BB now goes to Succ. + BB->replaceUsesWithIf(Succ, [BBPreds, CommonPred](Use &U) -> bool { + if (Instruction *UseInst = dyn_cast<Instruction>(U.getUser())) + return UseInst->getParent() != CommonPred && + BBPreds.contains(UseInst->getParent()); + return false; + }); + } if (DTU) DTU->applyUpdates(Updates); - DeleteDeadBlock(BB, DTU); + if (BBKillable) + DeleteDeadBlock(BB, DTU); return true; } -static bool EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB) { +static bool +EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB, + SmallPtrSetImpl<PHINode *> &ToRemove) { // This implementation doesn't currently consider undef operands // specially. Theoretically, two phis which are identical except for // one having an undef where the other doesn't could be collapsed. @@ -1263,12 +1362,14 @@ static bool EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB) { // Note that we only look in the upper square's triangle, // we already checked that the lower triangle PHI's aren't identical. for (auto J = I; PHINode *DuplicatePN = dyn_cast<PHINode>(J); ++J) { + if (ToRemove.contains(DuplicatePN)) + continue; if (!DuplicatePN->isIdenticalToWhenDefined(PN)) continue; // A duplicate. Replace this PHI with the base PHI. 
++NumPHICSEs; DuplicatePN->replaceAllUsesWith(PN); - DuplicatePN->eraseFromParent(); + ToRemove.insert(DuplicatePN); Changed = true; // The RAUW can change PHIs that we already visited. @@ -1279,7 +1380,9 @@ static bool EliminateDuplicatePHINodesNaiveImpl(BasicBlock *BB) { return Changed; } -static bool EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB) { +static bool +EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB, + SmallPtrSetImpl<PHINode *> &ToRemove) { // This implementation doesn't currently consider undef operands // specially. Theoretically, two phis which are identical except for // one having an undef where the other doesn't could be collapsed. @@ -1343,12 +1446,14 @@ static bool EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB) { // Examine each PHI. bool Changed = false; for (auto I = BB->begin(); PHINode *PN = dyn_cast<PHINode>(I++);) { + if (ToRemove.contains(PN)) + continue; auto Inserted = PHISet.insert(PN); if (!Inserted.second) { // A duplicate. Replace this PHI with its duplicate. ++NumPHICSEs; PN->replaceAllUsesWith(*Inserted.first); - PN->eraseFromParent(); + ToRemove.insert(PN); Changed = true; // The RAUW can change PHIs that we already visited. Start over from the @@ -1361,25 +1466,27 @@ static bool EliminateDuplicatePHINodesSetBasedImpl(BasicBlock *BB) { return Changed; } -bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { +bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB, + SmallPtrSetImpl<PHINode *> &ToRemove) { if ( #ifndef NDEBUG !PHICSEDebugHash && #endif hasNItemsOrLess(BB->phis(), PHICSENumPHISmallSize)) - return EliminateDuplicatePHINodesNaiveImpl(BB); - return EliminateDuplicatePHINodesSetBasedImpl(BB); + return EliminateDuplicatePHINodesNaiveImpl(BB, ToRemove); + return EliminateDuplicatePHINodesSetBasedImpl(BB, ToRemove); } -/// If the specified pointer points to an object that we control, try to modify -/// the object's alignment to PrefAlign. 
Returns a minimum known alignment of -/// the value after the operation, which may be lower than PrefAlign. -/// -/// Increating value alignment isn't often possible though. If alignment is -/// important, a more reliable approach is to simply align all global variables -/// and allocation instructions to their preferred alignment from the beginning. -static Align tryEnforceAlignment(Value *V, Align PrefAlign, - const DataLayout &DL) { +bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { + SmallPtrSet<PHINode *, 8> ToRemove; + bool Changed = EliminateDuplicatePHINodes(BB, ToRemove); + for (PHINode *PN : ToRemove) + PN->eraseFromParent(); + return Changed; +} + +Align llvm::tryEnforceAlignment(Value *V, Align PrefAlign, + const DataLayout &DL) { V = V->stripPointerCasts(); if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) { @@ -1463,12 +1570,18 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, // is removed by LowerDbgDeclare(), we need to make sure that we are // not inserting the same dbg.value intrinsic over and over. SmallVector<DbgValueInst *, 1> DbgValues; - findDbgValues(DbgValues, APN); + SmallVector<DPValue *, 1> DPValues; + findDbgValues(DbgValues, APN, &DPValues); for (auto *DVI : DbgValues) { assert(is_contained(DVI->getValues(), APN)); if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr)) return true; } + for (auto *DPV : DPValues) { + assert(is_contained(DPV->location_ops(), APN)); + if ((DPV->getVariable() == DIVar) && (DPV->getExpression() == DIExpr)) + return true; + } return false; } @@ -1504,6 +1617,67 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { // Could not determine size of variable. Conservatively return false. return false; } +// RemoveDIs: duplicate implementation of the above, using DPValues, the +// replacement for dbg.values. 
+static bool valueCoversEntireFragment(Type *ValTy, DPValue *DPV) { + const DataLayout &DL = DPV->getModule()->getDataLayout(); + TypeSize ValueSize = DL.getTypeAllocSizeInBits(ValTy); + if (std::optional<uint64_t> FragmentSize = DPV->getFragmentSizeInBits()) + return TypeSize::isKnownGE(ValueSize, TypeSize::getFixed(*FragmentSize)); + + // We can't always calculate the size of the DI variable (e.g. if it is a + // VLA). Try to use the size of the alloca that the dbg intrinsic describes + // intead. + if (DPV->isAddressOfVariable()) { + // DPV should have exactly 1 location when it is an address. + assert(DPV->getNumVariableLocationOps() == 1 && + "address of variable must have exactly 1 location operand."); + if (auto *AI = + dyn_cast_or_null<AllocaInst>(DPV->getVariableLocationOp(0))) { + if (std::optional<TypeSize> FragmentSize = AI->getAllocationSizeInBits(DL)) { + return TypeSize::isKnownGE(ValueSize, *FragmentSize); + } + } + } + // Could not determine size of variable. Conservatively return false. + return false; +} + +static void insertDbgValueOrDPValue(DIBuilder &Builder, Value *DV, + DILocalVariable *DIVar, + DIExpression *DIExpr, + const DebugLoc &NewLoc, + BasicBlock::iterator Instr) { + if (!UseNewDbgInfoFormat) { + auto *DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, + (Instruction *)nullptr); + DbgVal->insertBefore(Instr); + } else { + // RemoveDIs: if we're using the new debug-info format, allocate a + // DPValue directly instead of a dbg.value intrinsic. 
+ ValueAsMetadata *DVAM = ValueAsMetadata::get(DV); + DPValue *DV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get()); + Instr->getParent()->insertDPValueBefore(DV, Instr); + } +} + +static void insertDbgValueOrDPValueAfter(DIBuilder &Builder, Value *DV, + DILocalVariable *DIVar, + DIExpression *DIExpr, + const DebugLoc &NewLoc, + BasicBlock::iterator Instr) { + if (!UseNewDbgInfoFormat) { + auto *DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, + (Instruction *)nullptr); + DbgVal->insertAfter(&*Instr); + } else { + // RemoveDIs: if we're using the new debug-info format, allocate a + // DPValue directly instead of a dbg.value intrinsic. + ValueAsMetadata *DVAM = ValueAsMetadata::get(DV); + DPValue *DV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get()); + Instr->getParent()->insertDPValueAfter(DV, &*Instr); + } +} /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value /// that has an associated llvm.dbg.declare intrinsic. @@ -1533,7 +1707,8 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, DIExpr->isDeref() || (!DIExpr->startsWithDeref() && valueCoversEntireFragment(DV->getType(), DII)); if (CanConvert) { - Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI); + insertDbgValueOrDPValue(Builder, DV, DIVar, DIExpr, NewLoc, + SI->getIterator()); return; } @@ -1545,7 +1720,19 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, // know which part) we insert an dbg.value intrinsic to indicate that we // know nothing about the variable's content. DV = UndefValue::get(DV->getType()); - Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI); + insertDbgValueOrDPValue(Builder, DV, DIVar, DIExpr, NewLoc, + SI->getIterator()); +} + +// RemoveDIs: duplicate the getDebugValueLoc method using DPValues instead of +// dbg.value intrinsics. +static DebugLoc getDebugValueLocDPV(DPValue *DPV) { + // Original dbg.declare must have a location. 
+ const DebugLoc &DeclareLoc = DPV->getDebugLoc(); + MDNode *Scope = DeclareLoc.getScope(); + DILocation *InlinedAt = DeclareLoc.getInlinedAt(); + // Produce an unknown location with the correct scope / inlinedAt fields. + return DILocation::get(DPV->getContext(), 0, 0, Scope, InlinedAt); } /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value @@ -1571,9 +1758,40 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, // future if multi-location support is added to the IR, it might be // preferable to keep tracking both the loaded value and the original // address in case the alloca can not be elided. - Instruction *DbgValue = Builder.insertDbgValueIntrinsic( - LI, DIVar, DIExpr, NewLoc, (Instruction *)nullptr); - DbgValue->insertAfter(LI); + insertDbgValueOrDPValueAfter(Builder, LI, DIVar, DIExpr, NewLoc, + LI->getIterator()); +} + +void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, StoreInst *SI, + DIBuilder &Builder) { + assert(DPV->isAddressOfVariable()); + auto *DIVar = DPV->getVariable(); + assert(DIVar && "Missing variable"); + auto *DIExpr = DPV->getExpression(); + Value *DV = SI->getValueOperand(); + + DebugLoc NewLoc = getDebugValueLocDPV(DPV); + + if (!valueCoversEntireFragment(DV->getType(), DPV)) { + // FIXME: If storing to a part of the variable described by the dbg.declare, + // then we want to insert a DPValue.value for the corresponding fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV + << '\n'); + // For now, when there is a store to parts of the variable (but we do not + // know which part) we insert an DPValue record to indicate that we know + // nothing about the variable's content. 
+ DV = UndefValue::get(DV->getType()); + ValueAsMetadata *DVAM = ValueAsMetadata::get(DV); + DPValue *NewDPV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get()); + SI->getParent()->insertDPValueBefore(NewDPV, SI->getIterator()); + return; + } + + assert(UseNewDbgInfoFormat); + // Create a DPValue directly and insert. + ValueAsMetadata *DVAM = ValueAsMetadata::get(DV); + DPValue *NewDPV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get()); + SI->getParent()->insertDPValueBefore(NewDPV, SI->getIterator()); } /// Inserts a llvm.dbg.value intrinsic after a phi that has an associated @@ -1604,8 +1822,38 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, // The block may be a catchswitch block, which does not have a valid // insertion point. // FIXME: Insert dbg.value markers in the successors when appropriate. - if (InsertionPt != BB->end()) - Builder.insertDbgValueIntrinsic(APN, DIVar, DIExpr, NewLoc, &*InsertionPt); + if (InsertionPt != BB->end()) { + insertDbgValueOrDPValue(Builder, APN, DIVar, DIExpr, NewLoc, InsertionPt); + } +} + +void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, LoadInst *LI, + DIBuilder &Builder) { + auto *DIVar = DPV->getVariable(); + auto *DIExpr = DPV->getExpression(); + assert(DIVar && "Missing variable"); + + if (!valueCoversEntireFragment(LI->getType(), DPV)) { + // FIXME: If only referring to a part of the variable described by the + // dbg.declare, then we want to insert a DPValue for the corresponding + // fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV + << '\n'); + return; + } + + DebugLoc NewLoc = getDebugValueLocDPV(DPV); + + // We are now tracking the loaded value instead of the address. In the + // future if multi-location support is added to the IR, it might be + // preferable to keep tracking both the loaded value and the original + // address in case the alloca can not be elided. + assert(UseNewDbgInfoFormat); + + // Create a DPValue directly and insert. 
+ ValueAsMetadata *LIVAM = ValueAsMetadata::get(LI); + DPValue *DV = new DPValue(LIVAM, DIVar, DIExpr, NewLoc.get()); + LI->getParent()->insertDPValueAfter(DV, LI); } /// Determine whether this alloca is either a VLA or an array. @@ -1618,6 +1866,36 @@ static bool isArray(AllocaInst *AI) { static bool isStructure(AllocaInst *AI) { return AI->getAllocatedType() && AI->getAllocatedType()->isStructTy(); } +void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, PHINode *APN, + DIBuilder &Builder) { + auto *DIVar = DPV->getVariable(); + auto *DIExpr = DPV->getExpression(); + assert(DIVar && "Missing variable"); + + if (PhiHasDebugValue(DIVar, DIExpr, APN)) + return; + + if (!valueCoversEntireFragment(APN->getType(), DPV)) { + // FIXME: If only referring to a part of the variable described by the + // dbg.declare, then we want to insert a DPValue for the corresponding + // fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV + << '\n'); + return; + } + + BasicBlock *BB = APN->getParent(); + auto InsertionPt = BB->getFirstInsertionPt(); + + DebugLoc NewLoc = getDebugValueLocDPV(DPV); + + // The block may be a catchswitch block, which does not have a valid + // insertion point. + // FIXME: Insert DPValue markers in the successors when appropriate. + if (InsertionPt != BB->end()) { + insertDbgValueOrDPValue(Builder, APN, DIVar, DIExpr, NewLoc, InsertionPt); + } +} /// LowerDbgDeclare - Lowers llvm.dbg.declare intrinsics into appropriate set /// of llvm.dbg.value intrinsics. 
@@ -1674,8 +1952,8 @@ bool llvm::LowerDbgDeclare(Function &F) { DebugLoc NewLoc = getDebugValueLoc(DDI); auto *DerefExpr = DIExpression::append(DDI->getExpression(), dwarf::DW_OP_deref); - DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), DerefExpr, - NewLoc, CI); + insertDbgValueOrDPValue(DIB, AI, DDI->getVariable(), DerefExpr, + NewLoc, CI->getIterator()); } } else if (BitCastInst *BI = dyn_cast<BitCastInst>(U)) { if (BI->getType()->isPointerTy()) @@ -1694,6 +1972,69 @@ bool llvm::LowerDbgDeclare(Function &F) { return Changed; } +// RemoveDIs: re-implementation of insertDebugValuesForPHIs, but which pulls the +// debug-info out of the block's DPValues rather than dbg.value intrinsics. +static void insertDPValuesForPHIs(BasicBlock *BB, + SmallVectorImpl<PHINode *> &InsertedPHIs) { + assert(BB && "No BasicBlock to clone DPValue(s) from."); + if (InsertedPHIs.size() == 0) + return; + + // Map existing PHI nodes to their DPValues. + DenseMap<Value *, DPValue *> DbgValueMap; + for (auto &I : *BB) { + for (auto &DPV : I.getDbgValueRange()) { + for (Value *V : DPV.location_ops()) + if (auto *Loc = dyn_cast_or_null<PHINode>(V)) + DbgValueMap.insert({Loc, &DPV}); + } + } + if (DbgValueMap.size() == 0) + return; + + // Map a pair of the destination BB and old DPValue to the new DPValue, + // so that if a DPValue is being rewritten to use more than one of the + // inserted PHIs in the same destination BB, we can update the same DPValue + // with all the new PHIs instead of creating one copy for each. + MapVector<std::pair<BasicBlock *, DPValue *>, DPValue *> NewDbgValueMap; + // Then iterate through the new PHIs and look to see if they use one of the + // previously mapped PHIs. If so, create a new DPValue that will propagate + // the info through the new PHI. If we use more than one new PHI in a single + // destination BB with the same old dbg.value, merge the updates so that we + // get a single new DPValue with all the new PHIs. 
+ for (auto PHI : InsertedPHIs) { + BasicBlock *Parent = PHI->getParent(); + // Avoid inserting a debug-info record into an EH block. + if (Parent->getFirstNonPHI()->isEHPad()) + continue; + for (auto VI : PHI->operand_values()) { + auto V = DbgValueMap.find(VI); + if (V != DbgValueMap.end()) { + DPValue *DbgII = cast<DPValue>(V->second); + auto NewDI = NewDbgValueMap.find({Parent, DbgII}); + if (NewDI == NewDbgValueMap.end()) { + DPValue *NewDbgII = DbgII->clone(); + NewDI = NewDbgValueMap.insert({{Parent, DbgII}, NewDbgII}).first; + } + DPValue *NewDbgII = NewDI->second; + // If PHI contains VI as an operand more than once, we may have + // replaced it in NewDbgII; confirm that it is present. + if (is_contained(NewDbgII->location_ops(), VI)) + NewDbgII->replaceVariableLocationOp(VI, PHI); + } + } + } + // Insert the new DPValues into their destination blocks. + for (auto DI : NewDbgValueMap) { + BasicBlock *Parent = DI.first.first; + DPValue *NewDbgII = DI.second; + auto InsertionPt = Parent->getFirstInsertionPt(); + assert(InsertionPt != Parent->end() && "Ill-formed basic block"); + + InsertionPt->DbgMarker->insertDPValue(NewDbgII, true); + } +} + /// Propagate dbg.value intrinsics through the newly inserted PHIs. void llvm::insertDebugValuesForPHIs(BasicBlock *BB, SmallVectorImpl<PHINode *> &InsertedPHIs) { @@ -1701,6 +2042,8 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB, if (InsertedPHIs.size() == 0) return; + insertDPValuesForPHIs(BB, InsertedPHIs); + + // Map existing PHI nodes to their dbg.values. 
ValueToValueMapTy DbgValueMap; for (auto &I : *BB) { @@ -1775,44 +2118,60 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, return !DbgDeclares.empty(); } -static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, - DIBuilder &Builder, int Offset) { - const DebugLoc &Loc = DVI->getDebugLoc(); - auto *DIVar = DVI->getVariable(); - auto *DIExpr = DVI->getExpression(); +static void updateOneDbgValueForAlloca(const DebugLoc &Loc, + DILocalVariable *DIVar, + DIExpression *DIExpr, Value *NewAddress, + DbgValueInst *DVI, DPValue *DPV, + DIBuilder &Builder, int Offset) { assert(DIVar && "Missing variable"); - // This is an alloca-based llvm.dbg.value. The first thing it should do with - // the alloca pointer is dereference it. Otherwise we don't know how to handle - // it and give up. + // This is an alloca-based dbg.value/DPValue. The first thing it should do + // with the alloca pointer is dereference it. Otherwise we don't know how to + // handle it and give up. if (!DIExpr || DIExpr->getNumElements() < 1 || DIExpr->getElement(0) != dwarf::DW_OP_deref) return; // Insert the offset before the first deref. - // We could just change the offset argument of dbg.value, but it's unsigned... 
if (Offset) DIExpr = DIExpression::prepend(DIExpr, 0, Offset); - Builder.insertDbgValueIntrinsic(NewAddress, DIVar, DIExpr, Loc, DVI); - DVI->eraseFromParent(); + if (DVI) { + DVI->setExpression(DIExpr); + DVI->replaceVariableLocationOp(0u, NewAddress); + } else { + assert(DPV); + DPV->setExpression(DIExpr); + DPV->replaceVariableLocationOp(0u, NewAddress); + } } void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress, DIBuilder &Builder, int Offset) { - if (auto *L = LocalAsMetadata::getIfExists(AI)) - if (auto *MDV = MetadataAsValue::getIfExists(AI->getContext(), L)) - for (Use &U : llvm::make_early_inc_range(MDV->uses())) - if (auto *DVI = dyn_cast<DbgValueInst>(U.getUser())) - replaceOneDbgValueForAlloca(DVI, NewAllocaAddress, Builder, Offset); + SmallVector<DbgValueInst *, 1> DbgUsers; + SmallVector<DPValue *, 1> DPUsers; + findDbgValues(DbgUsers, AI, &DPUsers); + + // Attempt to replace dbg.values that use this alloca. + for (auto *DVI : DbgUsers) + updateOneDbgValueForAlloca(DVI->getDebugLoc(), DVI->getVariable(), + DVI->getExpression(), NewAllocaAddress, DVI, + nullptr, Builder, Offset); + + // Replace any DPValues that use this alloca. + for (DPValue *DPV : DPUsers) + updateOneDbgValueForAlloca(DPV->getDebugLoc(), DPV->getVariable(), + DPV->getExpression(), NewAllocaAddress, nullptr, + DPV, Builder, Offset); } /// Where possible to salvage debug information for \p I do so. /// If not possible mark undef. void llvm::salvageDebugInfo(Instruction &I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; - findDbgUsers(DbgUsers, &I); - salvageDebugInfoForDbgValues(I, DbgUsers); + SmallVector<DPValue *, 1> DPUsers; + findDbgUsers(DbgUsers, &I, &DPUsers); + salvageDebugInfoForDbgValues(I, DbgUsers, DPUsers); } /// Salvage the address component of \p DAI. 
@@ -1850,7 +2209,8 @@ static void salvageDbgAssignAddress(DbgAssignIntrinsic *DAI) { } void llvm::salvageDebugInfoForDbgValues( - Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers) { + Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers, + ArrayRef<DPValue *> DPUsers) { // These are arbitrary chosen limits on the maximum number of values and the // maximum size of a debug expression we can salvage up to, used for // performance reasons. @@ -1916,12 +2276,70 @@ void llvm::salvageDebugInfoForDbgValues( LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); Salvaged = true; } + // Duplicate of above block for DPValues. + for (auto *DPV : DPUsers) { + // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they + // are implicitly pointing out the value as a DWARF memory location + // description. + bool StackValue = DPV->getType() == DPValue::LocationType::Value; + auto DPVLocation = DPV->location_ops(); + assert( + is_contained(DPVLocation, &I) && + "DbgVariableIntrinsic must use salvaged instruction as its location"); + SmallVector<Value *, 4> AdditionalValues; + // 'I' may appear more than once in DPV's location ops, and each use of 'I' + // must be updated in the DIExpression and potentially have additional + // values added; thus we call salvageDebugInfoImpl for each 'I' instance in + // DPVLocation. 
+ Value *Op0 = nullptr; + DIExpression *SalvagedExpr = DPV->getExpression(); + auto LocItr = find(DPVLocation, &I); + while (SalvagedExpr && LocItr != DPVLocation.end()) { + SmallVector<uint64_t, 16> Ops; + unsigned LocNo = std::distance(DPVLocation.begin(), LocItr); + uint64_t CurrentLocOps = SalvagedExpr->getNumLocationOperands(); + Op0 = salvageDebugInfoImpl(I, CurrentLocOps, Ops, AdditionalValues); + if (!Op0) + break; + SalvagedExpr = + DIExpression::appendOpsToArg(SalvagedExpr, Ops, LocNo, StackValue); + LocItr = std::find(++LocItr, DPVLocation.end(), &I); + } + // salvageDebugInfoImpl should fail on examining the first element of + // DbgUsers, or none of them. + if (!Op0) + break; + + DPV->replaceVariableLocationOp(&I, Op0); + bool IsValidSalvageExpr = + SalvagedExpr->getNumElements() <= MaxExpressionSize; + if (AdditionalValues.empty() && IsValidSalvageExpr) { + DPV->setExpression(SalvagedExpr); + } else if (DPV->getType() == DPValue::LocationType::Value && + IsValidSalvageExpr && + DPV->getNumVariableLocationOps() + AdditionalValues.size() <= + MaxDebugArgs) { + DPV->addVariableLocationOps(AdditionalValues, SalvagedExpr); + } else { + // Do not salvage using DIArgList for dbg.addr/dbg.declare, as it is + // currently only valid for stack value expressions. + // Also do not salvage if the resulting DIArgList would contain an + // unreasonably large number of values. + Value *Undef = UndefValue::get(I.getOperand(0)->getType()); + DPV->replaceVariableLocationOp(I.getOperand(0), Undef); + } + LLVM_DEBUG(dbgs() << "SALVAGE: " << DPV << '\n'); + Salvaged = true; + } if (Salvaged) return; for (auto *DII : DbgUsers) DII->setKillLocation(); + + for (auto *DPV : DPUsers) + DPV->setKillLocation(); } Value *getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL, @@ -2136,16 +2554,20 @@ using DbgValReplacement = std::optional<DIExpression *>; /// changes are made. 
static bool rewriteDebugUsers( Instruction &From, Value &To, Instruction &DomPoint, DominatorTree &DT, - function_ref<DbgValReplacement(DbgVariableIntrinsic &DII)> RewriteExpr) { + function_ref<DbgValReplacement(DbgVariableIntrinsic &DII)> RewriteExpr, + function_ref<DbgValReplacement(DPValue &DPV)> RewriteDPVExpr) { // Find debug users of From. SmallVector<DbgVariableIntrinsic *, 1> Users; - findDbgUsers(Users, &From); - if (Users.empty()) + SmallVector<DPValue *, 1> DPUsers; + findDbgUsers(Users, &From, &DPUsers); + if (Users.empty() && DPUsers.empty()) return false; // Prevent use-before-def of To. bool Changed = false; + SmallPtrSet<DbgVariableIntrinsic *, 1> UndefOrSalvage; + SmallPtrSet<DPValue *, 1> UndefOrSalvageDPV; if (isa<Instruction>(&To)) { bool DomPointAfterFrom = From.getNextNonDebugInstruction() == &DomPoint; @@ -2163,6 +2585,25 @@ static bool rewriteDebugUsers( UndefOrSalvage.insert(DII); } } + + // DPValue implementation of the above. + for (auto *DPV : DPUsers) { + Instruction *MarkedInstr = DPV->getMarker()->MarkedInstr; + Instruction *NextNonDebug = MarkedInstr; + // The next instruction might still be a dbg.declare, skip over it. + if (isa<DbgVariableIntrinsic>(NextNonDebug)) + NextNonDebug = NextNonDebug->getNextNonDebugInstruction(); + + if (DomPointAfterFrom && NextNonDebug == &DomPoint) { + LLVM_DEBUG(dbgs() << "MOVE: " << *DPV << '\n'); + DPV->removeFromParent(); + // Ensure there's a marker. + DomPoint.getParent()->insertDPValueAfter(DPV, &DomPoint); + Changed = true; + } else if (!DT.dominates(&DomPoint, MarkedInstr)) { + UndefOrSalvageDPV.insert(DPV); + } + } } // Update debug users without use-before-def risk. 
@@ -2179,8 +2620,21 @@ static bool rewriteDebugUsers( LLVM_DEBUG(dbgs() << "REWRITE: " << *DII << '\n'); Changed = true; } + for (auto *DPV : DPUsers) { + if (UndefOrSalvageDPV.count(DPV)) + continue; - if (!UndefOrSalvage.empty()) { + DbgValReplacement DVR = RewriteDPVExpr(*DPV); + if (!DVR) + continue; + + DPV->replaceVariableLocationOp(&From, &To); + DPV->setExpression(*DVR); + LLVM_DEBUG(dbgs() << "REWRITE: " << DPV << '\n'); + Changed = true; + } + + if (!UndefOrSalvage.empty() || !UndefOrSalvageDPV.empty()) { // Try to salvage the remaining debug users. salvageDebugInfo(From); Changed = true; @@ -2228,12 +2682,15 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, auto Identity = [&](DbgVariableIntrinsic &DII) -> DbgValReplacement { return DII.getExpression(); }; + auto IdentityDPV = [&](DPValue &DPV) -> DbgValReplacement { + return DPV.getExpression(); + }; // Handle no-op conversions. Module &M = *From.getModule(); const DataLayout &DL = M.getDataLayout(); if (isBitCastSemanticsPreserving(DL, FromTy, ToTy)) - return rewriteDebugUsers(From, To, DomPoint, DT, Identity); + return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDPV); // Handle integer-to-integer widening and narrowing. // FIXME: Use DW_OP_convert when it's available everywhere. @@ -2245,7 +2702,7 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, // When the width of the result grows, assume that a debugger will only // access the low `FromBits` bits when inspecting the source variable. if (FromBits < ToBits) - return rewriteDebugUsers(From, To, DomPoint, DT, Identity); + return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDPV); // The width of the result has shrunk. Use sign/zero extension to describe // the source variable's high bits. 
@@ -2261,7 +2718,22 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, return DIExpression::appendExt(DII.getExpression(), ToBits, FromBits, Signed); }; - return rewriteDebugUsers(From, To, DomPoint, DT, SignOrZeroExt); + // RemoveDIs: duplicate implementation working on DPValues rather than on + // dbg.value intrinsics. + auto SignOrZeroExtDPV = [&](DPValue &DPV) -> DbgValReplacement { + DILocalVariable *Var = DPV.getVariable(); + + // Without knowing signedness, sign/zero extension isn't possible. + auto Signedness = Var->getSignedness(); + if (!Signedness) + return std::nullopt; + + bool Signed = *Signedness == DIBasicType::Signedness::Signed; + return DIExpression::appendExt(DPV.getExpression(), ToBits, FromBits, + Signed); + }; + return rewriteDebugUsers(From, To, DomPoint, DT, SignOrZeroExt, + SignOrZeroExtDPV); } // TODO: Floating-point conversions, vectors. @@ -2275,12 +2747,17 @@ llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { // Delete the instructions backwards, as it has a reduced likelihood of // having to update as many def-use and use-def chains. Instruction *EndInst = BB->getTerminator(); // Last not to be deleted. + // RemoveDIs: erasing debug-info must be done manually. + EndInst->dropDbgValues(); while (EndInst != &BB->front()) { // Delete the next to last instruction. Instruction *Inst = &*--EndInst->getIterator(); if (!Inst->use_empty() && !Inst->getType()->isTokenTy()) Inst->replaceAllUsesWith(PoisonValue::get(Inst->getType())); if (Inst->isEHPad() || Inst->getType()->isTokenTy()) { + // EHPads can't have DPValues attached to them, but it might be possible + // for things with token type. + Inst->dropDbgValues(); EndInst = Inst; continue; } @@ -2288,6 +2765,8 @@ llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { ++NumDeadDbgInst; else ++NumDeadInst; + // RemoveDIs: erasing debug-info must be done manually. 
+ Inst->dropDbgValues(); Inst->eraseFromParent(); } return {NumDeadInst, NumDeadDbgInst}; @@ -2329,6 +2808,7 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool PreserveLCSSA, Updates.push_back({DominatorTree::Delete, BB, UniqueSuccessor}); DTU->applyUpdates(Updates); } + BB->flushTerminatorDbgValues(); return NumInstrsRemoved; } @@ -2482,9 +2962,9 @@ static bool markAliveBlocks(Function &F, // If we found a call to a no-return function, insert an unreachable // instruction after it. Make sure there isn't *already* one there // though. - if (!isa<UnreachableInst>(CI->getNextNode())) { + if (!isa<UnreachableInst>(CI->getNextNonDebugInstruction())) { // Don't insert a call to llvm.trap right before the unreachable. - changeToUnreachable(CI->getNextNode(), false, DTU); + changeToUnreachable(CI->getNextNonDebugInstruction(), false, DTU); Changed = true; } break; @@ -2896,9 +3376,10 @@ static unsigned replaceDominatedUsesWith(Value *From, Value *To, for (Use &U : llvm::make_early_inc_range(From->uses())) { if (!Dominates(Root, U)) continue; + LLVM_DEBUG(dbgs() << "Replace dominated use of '"; + From->printAsOperand(dbgs()); + dbgs() << "' with " << *To << " in " << *U.getUser() << "\n"); U.set(To); - LLVM_DEBUG(dbgs() << "Replace dominated use of '" << From->getName() - << "' as " << *To << " in " << *U << "\n"); ++Count; } return Count; @@ -3017,9 +3498,12 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, void llvm::dropDebugUsers(Instruction &I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; - findDbgUsers(DbgUsers, &I); + SmallVector<DPValue *, 1> DPUsers; + findDbgUsers(DbgUsers, &I, &DPUsers); for (auto *DII : DbgUsers) DII->eraseFromParent(); + for (auto *DPV : DPUsers) + DPV->eraseFromParent(); } void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, @@ -3051,6 +3535,8 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, I->dropUBImplyingAttrsAndMetadata(); if 
(I->isUsedByMetadata()) dropDebugUsers(*I); + // RemoveDIs: drop debug-info too as the following code does. + I->dropDbgValues(); if (I->isDebugOrPseudoInst()) { // Remove DbgInfo and pseudo probe Intrinsics. II = I->eraseFromParent(); @@ -3063,6 +3549,41 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, BB->getTerminator()->getIterator()); } +DIExpression *llvm::getExpressionForConstant(DIBuilder &DIB, const Constant &C, + Type &Ty) { + // Create integer constant expression. + auto createIntegerExpression = [&DIB](const Constant &CV) -> DIExpression * { + const APInt &API = cast<ConstantInt>(&CV)->getValue(); + std::optional<int64_t> InitIntOpt = API.trySExtValue(); + return InitIntOpt ? DIB.createConstantValueExpression( + static_cast<uint64_t>(*InitIntOpt)) + : nullptr; + }; + + if (isa<ConstantInt>(C)) + return createIntegerExpression(C); + + if (Ty.isFloatTy() || Ty.isDoubleTy()) { + const APFloat &APF = cast<ConstantFP>(&C)->getValueAPF(); + return DIB.createConstantValueExpression( + APF.bitcastToAPInt().getZExtValue()); + } + + if (!Ty.isPointerTy()) + return nullptr; + + if (isa<ConstantPointerNull>(C)) + return DIB.createConstantValueExpression(0); + + if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(&C)) + if (CE->getOpcode() == Instruction::IntToPtr) { + const Value *V = CE->getOperand(0); + if (auto CI = dyn_cast_or_null<ConstantInt>(V)) + return createIntegerExpression(*CI); + } + return nullptr; +} + namespace { /// A potential constituent of a bitreverse or bswap expression. 
See diff --git a/llvm/lib/Transforms/Utils/LoopConstrainer.cpp b/llvm/lib/Transforms/Utils/LoopConstrainer.cpp new file mode 100644 index 000000000000..ea6d952cfa7d --- /dev/null +++ b/llvm/lib/Transforms/Utils/LoopConstrainer.cpp @@ -0,0 +1,904 @@ +#include "llvm/Transforms/Utils/LoopConstrainer.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Dominators.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" + +using namespace llvm; + +static const char *ClonedLoopTag = "loop_constrainer.loop.clone"; + +#define DEBUG_TYPE "loop-constrainer" + +/// Given a loop with a decreasing induction variable, is it possible to +/// safely calculate the bounds of a new loop using the given Predicate. +static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV, + const SCEV *Step, ICmpInst::Predicate Pred, + unsigned LatchBrExitIdx, Loop *L, + ScalarEvolution &SE) { + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) + return false; + + if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) + return false; + + assert(SE.isKnownNegative(Step) && "expecting negative step"); + + LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n"); + LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n"); + LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n"); + LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n"); + LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n"); + LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n"); + + bool IsSigned = ICmpInst::isSigned(Pred); + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSigned ? 
CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + + if (LatchBrExitIdx == 1) + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + + assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1"); + + const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) + : APInt::getMinValue(BitWidth); + const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); + + const SCEV *MinusOne = + SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType())); + + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit); +} + +/// Given a loop with an increasing induction variable, is it possible to +/// safely calculate the bounds of a new loop using the given Predicate. +static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV, + const SCEV *Step, ICmpInst::Predicate Pred, + unsigned LatchBrExitIdx, Loop *L, + ScalarEvolution &SE) { + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) + return false; + + if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) + return false; + + LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n"); + LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n"); + LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n"); + LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n"); + LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n"); + LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n"); + + bool IsSigned = ICmpInst::isSigned(Pred); + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSigned ? 
CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; + + if (LatchBrExitIdx == 1) + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + + assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); + + const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType())); + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) + : APInt::getMaxValue(BitWidth); + const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); + + return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start, + SE.getAddExpr(BoundSCEV, Step)) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); +} + +/// Returns estimate for max latch taken count of the loop of the narrowest +/// available type. If the latch block has such estimate, it is returned. +/// Otherwise, we use max exit count of whole loop (that is potentially of wider +/// type than latch check itself), which is still better than no estimate. 
+static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE, + const Loop &L) { + const SCEV *FromBlock = + SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum); + if (isa<SCEVCouldNotCompute>(FromBlock)) + return SE.getSymbolicMaxBackedgeTakenCount(&L); + return FromBlock; +} + +std::optional<LoopStructure> +LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, + bool AllowUnsignedLatchCond, + const char *&FailureReason) { + if (!L.isLoopSimplifyForm()) { + FailureReason = "loop not in LoopSimplify form"; + return std::nullopt; + } + + BasicBlock *Latch = L.getLoopLatch(); + assert(Latch && "Simplified loops only have one latch!"); + + if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) { + FailureReason = "loop has already been cloned"; + return std::nullopt; + } + + if (!L.isLoopExiting(Latch)) { + FailureReason = "no loop latch"; + return std::nullopt; + } + + BasicBlock *Header = L.getHeader(); + BasicBlock *Preheader = L.getLoopPreheader(); + if (!Preheader) { + FailureReason = "no preheader"; + return std::nullopt; + } + + BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); + if (!LatchBr || LatchBr->isUnconditional()) { + FailureReason = "latch terminator not conditional branch"; + return std::nullopt; + } + + unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 
1 : 0; + + ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); + if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { + FailureReason = "latch terminator branch not conditional on integral icmp"; + return std::nullopt; + } + + const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L); + if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) { + FailureReason = "could not compute latch count"; + return std::nullopt; + } + assert(SE.getLoopDisposition(MaxBETakenCount, &L) == + ScalarEvolution::LoopInvariant && + "loop variant exit count doesn't make sense!"); + + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *LeftValue = ICI->getOperand(0); + const SCEV *LeftSCEV = SE.getSCEV(LeftValue); + IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType()); + + Value *RightValue = ICI->getOperand(1); + const SCEV *RightSCEV = SE.getSCEV(RightValue); + + // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. + if (!isa<SCEVAddRecExpr>(LeftSCEV)) { + if (isa<SCEVAddRecExpr>(RightSCEV)) { + std::swap(LeftSCEV, RightSCEV); + std::swap(LeftValue, RightValue); + Pred = ICmpInst::getSwappedPredicate(Pred); + } else { + FailureReason = "no add recurrences in the icmp"; + return std::nullopt; + } + } + + auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { + if (AR->getNoWrapFlags(SCEV::FlagNSW)) + return true; + + IntegerType *Ty = cast<IntegerType>(AR->getType()); + IntegerType *WideTy = + IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); + + const SCEVAddRecExpr *ExtendAfterOp = + dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); + if (ExtendAfterOp) { + const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); + const SCEV *ExtendedStep = + SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); + + bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && + ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; + + if (NoSignedWrap) + return true; + } + + // We may have proved 
this when computing the sign extension above. + return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; + }; + + // `ICI` is interpreted as taking the backedge if the *next* value of the + // induction variable satisfies some constraint. + + const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); + if (IndVarBase->getLoop() != &L) { + FailureReason = "LHS in cmp is not an AddRec for this loop"; + return std::nullopt; + } + if (!IndVarBase->isAffine()) { + FailureReason = "LHS in icmp not induction variable"; + return std::nullopt; + } + const SCEV *StepRec = IndVarBase->getStepRecurrence(SE); + if (!isa<SCEVConstant>(StepRec)) { + FailureReason = "LHS in icmp not induction variable"; + return std::nullopt; + } + ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue(); + + if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) { + FailureReason = "LHS in icmp needs nsw for equality predicates"; + return std::nullopt; + } + + assert(!StepCI->isZero() && "Zero step?"); + bool IsIncreasing = !StepCI->isNegative(); + bool IsSignedPredicate; + const SCEV *StartNext = IndVarBase->getStart(); + const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); + const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); + const SCEV *Step = SE.getSCEV(StepCI); + + const SCEV *FixedRightSCEV = nullptr; + + // If RightValue resides within loop (but still being loop invariant), + // regenerate it as preheader. + if (auto *I = dyn_cast<Instruction>(RightValue)) + if (L.contains(I->getParent())) + FixedRightSCEV = RightSCEV; + + if (IsIncreasing) { + bool DecreasedRightValueByOne = false; + if (StepCI->isOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (++i != len) { while (++i < len) { + // ... ---> ... + // } } + // If both parts are known non-negative, it is profitable to use + // unsigned comparison in increasing loop. 
This allows us to make the + // comparison check against "RightSCEV + 1" more optimistic. + if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) && + isKnownNonNegativeInLoop(RightSCEV, &L, SE)) + Pred = ICmpInst::ICMP_ULT; + else + Pred = ICmpInst::ICMP_SLT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { + // while (true) { while (true) { + // if (++i == len) ---> if (++i > len - 1) + // break; break; + // ... ... + // } } + if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && + cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) { + Pred = ICmpInst::ICMP_UGT; + RightSCEV = + SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) { + Pred = ICmpInst::ICMP_SGT; + RightSCEV = + SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } + } + } + + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); + bool FoundExpectedPred = + (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0); + + if (!FoundExpectedPred) { + FailureReason = "expected icmp slt semantically, found something else"; + return std::nullopt; + } + + IsSignedPredicate = ICmpInst::isSigned(Pred); + if (!IsSignedPredicate && !AllowUnsignedLatchCond) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return std::nullopt; + } + + if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred, + LatchBrExitIdx, &L, SE)) { + FailureReason = "Unsafe loop bounds"; + return std::nullopt; + } + if (LatchBrExitIdx == 0) { + // We need to increase the right value unless we have already decreased + // it virtually when we replaced EQ with SGT. 
+ if (!DecreasedRightValueByOne) + FixedRightSCEV = + SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + } else { + assert(!DecreasedRightValueByOne && + "Right value can be decreased only for LatchBrExitIdx == 0!"); + } + } else { + bool IncreasedRightValueByOne = false; + if (StepCI->isMinusOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (--i != len) { while (--i > len) { + // ... ---> ... + // } } + // We intentionally don't turn the predicate into UGT even if we know + // that both operands are non-negative, because it will only pessimize + // our check against "RightSCEV - 1". + Pred = ICmpInst::ICMP_SGT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { + // while (true) { while (true) { + // if (--i == len) ---> if (--i < len + 1) + // break; break; + // ... ... + // } } + if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && + cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) { + Pred = ICmpInst::ICMP_ULT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { + Pred = ICmpInst::ICMP_SLT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } + } + } + + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); + + bool FoundExpectedPred = + (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0); + + if (!FoundExpectedPred) { + FailureReason = "expected icmp sgt semantically, found something else"; + return std::nullopt; + } + + IsSignedPredicate = + Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; + + if (!IsSignedPredicate && !AllowUnsignedLatchCond) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return std::nullopt; + } + 
+ if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred, + LatchBrExitIdx, &L, SE)) { + FailureReason = "Unsafe bounds"; + return std::nullopt; + } + + if (LatchBrExitIdx == 0) { + // We need to decrease the right value unless we have already increased + // it virtually when we replaced EQ with SLT. + if (!IncreasedRightValueByOne) + FixedRightSCEV = + SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); + } else { + assert(!IncreasedRightValueByOne && + "Right value can be increased only for LatchBrExitIdx == 0!"); + } + } + BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); + + assert(!L.contains(LatchExit) && "expected an exit block!"); + const DataLayout &DL = Preheader->getModule()->getDataLayout(); + SCEVExpander Expander(SE, DL, "loop-constrainer"); + Instruction *Ins = Preheader->getTerminator(); + + if (FixedRightSCEV) + RightValue = + Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins); + + Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins); + IndVarStartV->setName("indvar.start"); + + LoopStructure Result; + + Result.Tag = "main"; + Result.Header = Header; + Result.Latch = Latch; + Result.LatchBr = LatchBr; + Result.LatchExit = LatchExit; + Result.LatchBrExitIdx = LatchBrExitIdx; + Result.IndVarStart = IndVarStartV; + Result.IndVarStep = StepCI; + Result.IndVarBase = LeftValue; + Result.IndVarIncreasing = IsIncreasing; + Result.LoopExitAt = RightValue; + Result.IsSignedPredicate = IsSignedPredicate; + Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType()); + + FailureReason = nullptr; + + return Result; +} + +// Add metadata to the loop L to disable loop optimizations. Callers need to +// confirm that optimizing loop L is not beneficial. +static void DisableAllLoopOptsOnLoop(Loop &L) { + // We do not care about any existing loopID related metadata for L, since we + // are setting all loop metadata to false. 
+ LLVMContext &Context = L.getHeader()->getContext(); + // Reserve first location for self reference to the LoopID metadata node. + MDNode *Dummy = MDNode::get(Context, {}); + MDNode *DisableUnroll = MDNode::get( + Context, {MDString::get(Context, "llvm.loop.unroll.disable")}); + Metadata *FalseVal = + ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0)); + MDNode *DisableVectorize = MDNode::get( + Context, + {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal}); + MDNode *DisableLICMVersioning = MDNode::get( + Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")}); + MDNode *DisableDistribution = MDNode::get( + Context, + {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal}); + MDNode *NewLoopID = + MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize, + DisableLICMVersioning, DisableDistribution}); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + L.setLoopID(NewLoopID); +} + +LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI, + function_ref<void(Loop *, bool)> LPMAddNewLoop, + const LoopStructure &LS, ScalarEvolution &SE, + DominatorTree &DT, Type *T, SubRanges SR) + : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE), + DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T), + MainLoopStructure(LS), SR(SR) {} + +void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, + const char *Tag) const { + for (BasicBlock *BB : OriginalLoop.getBlocks()) { + BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F); + Result.Blocks.push_back(Clone); + Result.Map[BB] = Clone; + } + + auto GetClonedValue = [&Result](Value *V) { + assert(V && "null values not in domain!"); + auto It = Result.Map.find(V); + if (It == Result.Map.end()) + return V; + return static_cast<Value *>(It->second); + }; + + auto *ClonedLatch = + cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch())); + 
ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag, + MDNode::get(Ctx, {})); + + Result.Structure = MainLoopStructure.map(GetClonedValue); + Result.Structure.Tag = Tag; + + for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) { + BasicBlock *ClonedBB = Result.Blocks[i]; + BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i]; + + assert(Result.Map[OriginalBB] == ClonedBB && "invariant!"); + + for (Instruction &I : *ClonedBB) + RemapInstruction(&I, Result.Map, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + + // Exit blocks will now have one more predecessor and their PHI nodes need + // to be edited to reflect that. No phi nodes need to be introduced because + // the loop is in LCSSA. + + for (auto *SBB : successors(OriginalBB)) { + if (OriginalLoop.contains(SBB)) + continue; // not an exit block + + for (PHINode &PN : SBB->phis()) { + Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB); + PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB); + SE.forgetValue(&PN); + } + } + } +} + +LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( + const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, + BasicBlock *ContinuationBlock) const { + // We start with a loop with a single latch: + // + // +--------------------+ + // | | + // | preheader | + // | | + // +--------+-----------+ + // | ----------------\ + // | / | + // +--------v----v------+ | + // | | | + // | header | | + // | | | + // +--------------------+ | + // | + // ..... 
| + // | + // +--------------------+ | + // | | | + // | latch >----------/ + // | | + // +-------v------------+ + // | + // | + // | +--------------------+ + // | | | + // +---> original exit | + // | | + // +--------------------+ + // + // We change the control flow to look like + // + // + // +--------------------+ + // | | + // | preheader >-------------------------+ + // | | | + // +--------v-----------+ | + // | /-------------+ | + // | / | | + // +--------v--v--------+ | | + // | | | | + // | header | | +--------+ | + // | | | | | | + // +--------------------+ | | +-----v-----v-----------+ + // | | | | + // | | | .pseudo.exit | + // | | | | + // | | +-----------v-----------+ + // | | | + // ..... | | | + // | | +--------v-------------+ + // +--------------------+ | | | | + // | | | | | ContinuationBlock | + // | latch >------+ | | | + // | | | +----------------------+ + // +---------v----------+ | + // | | + // | | + // | +---------------^-----+ + // | | | + // +-----> .exit.selector | + // | | + // +----------v----------+ + // | + // +--------------------+ | + // | | | + // | original exit <----+ + // | | + // +--------------------+ + + RewrittenRangeInfo RRI; + + BasicBlock *BBInsertLocation = LS.Latch->getNextNode(); + RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", + &F, BBInsertLocation); + RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, + BBInsertLocation); + + BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); + bool Increasing = LS.IndVarIncreasing; + bool IsSignedPredicate = LS.IsSignedPredicate; + + IRBuilder<> B(PreheaderJump); + auto NoopOrExt = [&](Value *V) { + if (V->getType() == RangeTy) + return V; + return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName()) + : B.CreateZExt(V, RangeTy, "wide." + V->getName()); + }; + + // EnterLoopCond - is it okay to start executing this `LS'? + Value *EnterLoopCond = nullptr; + auto Pred = + Increasing + ? 
(IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT) + : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); + Value *IndVarStart = NoopOrExt(LS.IndVarStart); + EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt); + + B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); + PreheaderJump->eraseFromParent(); + + LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); + B.SetInsertPoint(LS.LatchBr); + Value *IndVarBase = NoopOrExt(LS.IndVarBase); + Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt); + + Value *CondForBranch = LS.LatchBrExitIdx == 1 + ? TakeBackedgeLoopCond + : B.CreateNot(TakeBackedgeLoopCond); + + LS.LatchBr->setCondition(CondForBranch); + + B.SetInsertPoint(RRI.ExitSelector); + + // IterationsLeft - are there any more iterations left, given the original + // upper bound on the induction variable? If not, we branch to the "real" + // exit. + Value *LoopExitAt = NoopOrExt(LS.LoopExitAt); + Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt); + B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); + + BranchInst *BranchToContinuation = + BranchInst::Create(ContinuationBlock, RRI.PseudoExit); + + // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of + // each of the PHI nodes in the loop header. This feeds into the initial + // value of the same PHI nodes if/when we continue execution. 
+ for (PHINode &PN : LS.Header->phis()) { + PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy", + BranchToContinuation); + + NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader); + NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch), + RRI.ExitSelector); + RRI.PHIValuesAtPseudoExit.push_back(NewPHI); + } + + RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end", + BranchToContinuation); + RRI.IndVarEnd->addIncoming(IndVarStart, Preheader); + RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector); + + // The latch exit now has a branch from `RRI.ExitSelector' instead of + // `LS.Latch'. The PHI nodes need to be updated to reflect that. + LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector); + + return RRI; +} + +void LoopConstrainer::rewriteIncomingValuesForPHIs( + LoopStructure &LS, BasicBlock *ContinuationBlock, + const LoopConstrainer::RewrittenRangeInfo &RRI) const { + unsigned PHIIndex = 0; + for (PHINode &PN : LS.Header->phis()) + PN.setIncomingValueForBlock(ContinuationBlock, + RRI.PHIValuesAtPseudoExit[PHIIndex++]); + + LS.IndVarStart = RRI.IndVarEnd; +} + +BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, + BasicBlock *OldPreheader, + const char *Tag) const { + BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); + BranchInst::Create(LS.Header, Preheader); + + LS.Header->replacePhiUsesWith(OldPreheader, Preheader); + + return Preheader; +} + +void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { + Loop *ParentLoop = OriginalLoop.getParentLoop(); + if (!ParentLoop) + return; + + for (BasicBlock *BB : BBs) + ParentLoop->addBasicBlockToLoop(BB, LI); +} + +Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, + ValueToValueMapTy &VM, + bool IsSubloop) { + Loop &New = *LI.AllocateLoop(); + if (Parent) + Parent->addChildLoop(&New); + else + LI.addTopLevelLoop(&New); + LPMAddNewLoop(&New, IsSubloop); + + // 
Add all of the blocks in Original to the new loop. + for (auto *BB : Original->blocks()) + if (LI.getLoopFor(BB) == Original) + New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI); + + // Add all of the subloops to the new loop. + for (Loop *SubLoop : *Original) + createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true); + + return &New; +} + +bool LoopConstrainer::run() { + BasicBlock *Preheader = OriginalLoop.getLoopPreheader(); + assert(Preheader != nullptr && "precondition!"); + + OriginalPreheader = Preheader; + MainLoopPreheader = Preheader; + bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; + bool Increasing = MainLoopStructure.IndVarIncreasing; + IntegerType *IVTy = cast<IntegerType>(RangeTy); + + SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer"); + Instruction *InsertPt = OriginalPreheader->getTerminator(); + + // It would have been better to make `PreLoop' and `PostLoop' + // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy + // constructor. + ClonedLoop PreLoop, PostLoop; + bool NeedsPreLoop = + Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value(); + bool NeedsPostLoop = + Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value(); + + Value *ExitPreLoopAt = nullptr; + Value *ExitMainLoopAt = nullptr; + const SCEVConstant *MinusOneS = + cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */)); + + if (NeedsPreLoop) { + const SCEV *ExitPreLoopAtSCEV = nullptr; + + if (Increasing) + ExitPreLoopAtSCEV = *SR.LowLimit; + else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); + else { + LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing " + << "preloop exit limit. 
HighLimit = " + << *(*SR.HighLimit) << "\n"); + return false; + } + + if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) { + LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the" + << " preloop exit limit " << *ExitPreLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() + << "\n"); + return false; + } + + ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); + ExitPreLoopAt->setName("exit.preloop.at"); + } + + if (NeedsPostLoop) { + const SCEV *ExitMainLoopAtSCEV = nullptr; + + if (Increasing) + ExitMainLoopAtSCEV = *SR.HighLimit; + else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); + else { + LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing " + << "mainloop exit limit. LowLimit = " + << *(*SR.LowLimit) << "\n"); + return false; + } + + if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) { + LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the" + << " main loop exit limit " << *ExitMainLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() + << "\n"); + return false; + } + + ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); + ExitMainLoopAt->setName("exit.mainloop.at"); + } + + // We clone these ahead of time so that we don't have to deal with changing + // and temporarily invalid IR as we transform the loops. 
+ if (NeedsPreLoop) + cloneLoop(PreLoop, "preloop"); + if (NeedsPostLoop) + cloneLoop(PostLoop, "postloop"); + + RewrittenRangeInfo PreLoopRRI; + + if (NeedsPreLoop) { + Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, + PreLoop.Structure.Header); + + MainLoopPreheader = + createPreheader(MainLoopStructure, Preheader, "mainloop"); + PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, + ExitPreLoopAt, MainLoopPreheader); + rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, + PreLoopRRI); + } + + BasicBlock *PostLoopPreheader = nullptr; + RewrittenRangeInfo PostLoopRRI; + + if (NeedsPostLoop) { + PostLoopPreheader = + createPreheader(PostLoop.Structure, Preheader, "postloop"); + PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, + ExitMainLoopAt, PostLoopPreheader); + rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, + PostLoopRRI); + } + + BasicBlock *NewMainLoopPreheader = + MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr; + BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit, + PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit, + PostLoopRRI.ExitSelector, NewMainLoopPreheader}; + + // Some of the above may be nullptr, filter them out before passing to + // addToParentLoopIfNeeded. + auto NewBlocksEnd = + std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); + + addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd)); + + DT.recalculate(F); + + // We need to first add all the pre and post loop blocks into the loop + // structures (as part of createClonedLoopStructure), and then update the + // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating + // LI when LoopSimplifyForm is generated. 
+ Loop *PreL = nullptr, *PostL = nullptr; + if (!PreLoop.Blocks.empty()) { + PreL = createClonedLoopStructure(&OriginalLoop, + OriginalLoop.getParentLoop(), PreLoop.Map, + /* IsSubLoop */ false); + } + + if (!PostLoop.Blocks.empty()) { + PostL = + createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(), + PostLoop.Map, /* IsSubLoop */ false); + } + + // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. + auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) { + formLCSSARecursively(*L, DT, &LI, &SE); + simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true); + // Pre/post loops are slow paths, we do not need to perform any loop + // optimizations on them. + if (!IsOriginalLoop) + DisableAllLoopOptsOnLoop(*L); + }; + if (PreL) + CanonicalizeLoop(PreL, false); + if (PostL) + CanonicalizeLoop(PostL, false); + CanonicalizeLoop(&OriginalLoop, true); + + /// At this point: + /// - We've broken a "main loop" out of the loop in a way that the "main loop" + /// runs with the induction variable in a subset of [Begin, End). + /// - There is no overflow when computing "main loop" exit limit. + /// - Max latch taken count of the loop is limited. + /// It guarantees that induction variable will not overflow iterating in the + /// "main loop". + if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase)) + if (IsSignedPredicate) + cast<BinaryOperator>(MainLoopStructure.IndVarBase) + ->setHasNoSignedWrap(true); + /// TODO: support unsigned predicate. + /// To add NUW flag we need to prove that both operands of BO are + /// non-negative. E.g: + /// ... + /// %iv.next = add nsw i32 %iv, -1 + /// %cmp = icmp ult i32 %iv.next, %n + /// br i1 %cmp, label %loopexit, label %loop + /// + /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will + /// overflow, therefore NUW flag is not legal here. 
+ + return true; +} diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp index d701cf110154..f76fa3bb6c61 100644 --- a/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -351,11 +351,20 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, MaxPeelCount = std::min((unsigned)SC->getAPInt().getLimitedValue() - 1, MaxPeelCount); - auto ComputePeelCount = [&](Value *Condition) -> void { - if (!Condition->getType()->isIntegerTy()) + const unsigned MaxDepth = 4; + std::function<void(Value *, unsigned)> ComputePeelCount = + [&](Value *Condition, unsigned Depth) -> void { + if (!Condition->getType()->isIntegerTy() || Depth >= MaxDepth) return; Value *LeftVal, *RightVal; + if (match(Condition, m_And(m_Value(LeftVal), m_Value(RightVal))) || + match(Condition, m_Or(m_Value(LeftVal), m_Value(RightVal)))) { + ComputePeelCount(LeftVal, Depth + 1); + ComputePeelCount(RightVal, Depth + 1); + return; + } + CmpInst::Predicate Pred; if (!match(Condition, m_ICmp(Pred, m_Value(LeftVal), m_Value(RightVal)))) return; @@ -443,7 +452,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, for (BasicBlock *BB : L.blocks()) { for (Instruction &I : *BB) { if (SelectInst *SI = dyn_cast<SelectInst>(&I)) - ComputePeelCount(SI->getCondition()); + ComputePeelCount(SI->getCondition(), 0); } auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); @@ -454,7 +463,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, if (L.getLoopLatch() == BB) continue; - ComputePeelCount(BI->getCondition()); + ComputePeelCount(BI->getCondition(), 0); } return DesiredPeelCount; @@ -624,21 +633,24 @@ struct WeightInfo { /// F/(F+E) is a probability to go to loop and E/(F+E) is a probability to /// go to exit. /// Then, Estimated ExitCount = F / E. 
-/// For I-th (counting from 0) peeled off iteration we set the the weights for +/// For I-th (counting from 0) peeled off iteration we set the weights for /// the peeled exit as (EC - I, 1). It gives us reasonable distribution, /// The probability to go to exit 1/(EC-I) increases. At the same time /// the estimated exit count in the remainder loop reduces by I. /// To avoid dealing with division rounding we can just multiple both part /// of weights to E and use weight as (F - I * E, E). static void updateBranchWeights(Instruction *Term, WeightInfo &Info) { - MDBuilder MDB(Term->getContext()); - Term->setMetadata(LLVMContext::MD_prof, - MDB.createBranchWeights(Info.Weights)); + setBranchWeights(*Term, Info.Weights); for (auto [Idx, SubWeight] : enumerate(Info.SubWeights)) if (SubWeight != 0) - Info.Weights[Idx] = Info.Weights[Idx] > SubWeight - ? Info.Weights[Idx] - SubWeight - : 1; + // Don't set the probability of taking the edge from latch to loop header + // to less than 1:1 ratio (meaning Weight should not be lower than + // SubWeight), as this could significantly reduce the loop's hotness, + // which would be incorrect in the case of underestimating the trip count. + Info.Weights[Idx] = + Info.Weights[Idx] > SubWeight + ? std::max(Info.Weights[Idx] - SubWeight, SubWeight) + : SubWeight; } /// Initialize the weights for all exiting blocks. @@ -685,14 +697,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos, } } -/// Update the weights of original exiting block after peeling off all -/// iterations. -static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) { - MDBuilder MDB(Term->getContext()); - Term->setMetadata(LLVMContext::MD_prof, - MDB.createBranchWeights(Info.Weights)); -} - /// Clones the body of the loop L, putting it between \p InsertTop and \p /// InsertBot. 
/// \param IterNumber The serial number of the iteration currently being @@ -1028,8 +1032,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, PHI->setIncomingValueForBlock(NewPreHeader, NewVal); } - for (const auto &[Term, Info] : Weights) - fixupBranchWeights(Term, Info); + for (const auto &[Term, Info] : Weights) { + setBranchWeights(*Term, Info.Weights); + } // Update Metadata for count of peeled off iterations. unsigned AlreadyPeeled = 0; diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index d81db5647c60..76280ed492b3 100644 --- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -25,6 +25,8 @@ #include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -50,6 +52,9 @@ static cl::opt<bool> cl::desc("Allow loop rotation multiple times in order to reach " "a better latch exit")); +// Probability that a rotated loop has zero trip count / is never entered. +static constexpr uint32_t ZeroTripCountWeights[] = {1, 127}; + namespace { /// A simple loop rotation transformation. class LoopRotate { @@ -154,7 +159,8 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, // Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug // intrinsics. SmallVector<DbgValueInst *, 1> DbgValues; - llvm::findDbgValues(DbgValues, OrigHeaderVal); + SmallVector<DPValue *, 1> DPValues; + llvm::findDbgValues(DbgValues, OrigHeaderVal, &DPValues); for (auto &DbgValue : DbgValues) { // The original users in the OrigHeader are already using the original // definitions. 
@@ -175,6 +181,29 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, NewVal = UndefValue::get(OrigHeaderVal->getType()); DbgValue->replaceVariableLocationOp(OrigHeaderVal, NewVal); } + + // RemoveDIs: duplicate implementation for non-instruction debug-info + // storage in DPValues. + for (DPValue *DPV : DPValues) { + // The original users in the OrigHeader are already using the original + // definitions. + BasicBlock *UserBB = DPV->getMarker()->getParent(); + if (UserBB == OrigHeader) + continue; + + // Users in the OrigPreHeader need to use the value to which the + // original definitions are mapped and anything else can be handled by + // the SSAUpdater. To avoid adding PHINodes, check if the value is + // available in UserBB, if not substitute undef. + Value *NewVal; + if (UserBB == OrigPreheader) + NewVal = OrigPreHeaderVal; + else if (SSA.HasValueForBlock(UserBB)) + NewVal = SSA.GetValueInMiddleOfBlock(UserBB); + else + NewVal = UndefValue::get(OrigHeaderVal->getType()); + DPV->replaceVariableLocationOp(OrigHeaderVal, NewVal); + } } } @@ -244,6 +273,123 @@ static bool canRotateDeoptimizingLatchExit(Loop *L) { return false; } +static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, + bool HasConditionalPreHeader, + bool SuccsSwapped) { + MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI); + if (WeightMD == nullptr) + return; + + // LoopBI should currently be a clone of PreHeaderBI with the same + // metadata. But we double check to make sure we don't have a degenerate case + // where instsimplify changed the instructions. + if (WeightMD != getBranchWeightMDNode(LoopBI)) + return; + + SmallVector<uint32_t, 2> Weights; + extractFromBranchWeightMD(WeightMD, Weights); + if (Weights.size() != 2) + return; + uint32_t OrigLoopExitWeight = Weights[0]; + uint32_t OrigLoopBackedgeWeight = Weights[1]; + + if (SuccsSwapped) + std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight); + + // Update branch weights. 
Consider the following edge-counts: + // + // | |-------- | + // V V | V + // Br i1 ... | Br i1 ... + // | | | | | + // x| y| | becomes: | y0| |----- + // V V | | V V | + // Exit Loop | | Loop | + // | | | Br i1 ... | + // ----- | | | | + // x0| x1| y1 | | + // V V ---- + // Exit + // + // The following must hold: + // - x == x0 + x1 # counts to "exit" must stay the same. + // - y0 == x - x0 == x1 # how often loop was entered at all. + // - y1 == y - y0 # How often loop was repeated (after first iter.). + // + // We cannot generally deduce how often we had a zero-trip count loop so we + // have to make a guess for how to distribute x among the new x0 and x1. + + uint32_t ExitWeight0; // aka x0 + uint32_t ExitWeight1; // aka x1 + uint32_t EnterWeight; // aka y0 + uint32_t LoopBackWeight; // aka y1 + if (OrigLoopExitWeight > 0 && OrigLoopBackedgeWeight > 0) { + ExitWeight0 = 0; + if (HasConditionalPreHeader) { + // Here we cannot know how many 0-trip count loops we have, so we guess: + if (OrigLoopBackedgeWeight >= OrigLoopExitWeight) { + // If the loop count is bigger than the exit count then we set + // probabilities as if 0-trip count nearly never happens. + ExitWeight0 = ZeroTripCountWeights[0]; + // Scale up counts if necessary so we can match `ZeroTripCountWeights` + // for the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio. + while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) { + // ... but don't overflow. + uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1); + if ((OrigLoopBackedgeWeight & HighBit) != 0 || + (OrigLoopExitWeight & HighBit) != 0) + break; + OrigLoopBackedgeWeight <<= 1; + OrigLoopExitWeight <<= 1; + } + } else { + // If there's a higher exit-count than backedge-count then we set + // probabilities as if there are only 0-trip and 1-trip cases. 
+ ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight; + } + } + ExitWeight1 = OrigLoopExitWeight - ExitWeight0; + EnterWeight = ExitWeight1; + LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight; + } else if (OrigLoopExitWeight == 0) { + if (OrigLoopBackedgeWeight == 0) { + // degenerate case... keep everything zero... + ExitWeight0 = 0; + ExitWeight1 = 0; + EnterWeight = 0; + LoopBackWeight = 0; + } else { + // Special case "LoopExitWeight == 0" weights which behaves like an + // endless where we don't want loop-enttry (y0) to be the same as + // loop-exit (x1). + ExitWeight0 = 0; + ExitWeight1 = 0; + EnterWeight = 1; + LoopBackWeight = OrigLoopBackedgeWeight; + } + } else { + // loop is never entered. + assert(OrigLoopBackedgeWeight == 0 && "remaining case is backedge zero"); + ExitWeight0 = 1; + ExitWeight1 = 1; + EnterWeight = 0; + LoopBackWeight = 0; + } + + const uint32_t LoopBIWeights[] = { + SuccsSwapped ? LoopBackWeight : ExitWeight1, + SuccsSwapped ? ExitWeight1 : LoopBackWeight, + }; + setBranchWeights(LoopBI, LoopBIWeights); + if (HasConditionalPreHeader) { + const uint32_t PreHeaderBIWeights[] = { + SuccsSwapped ? EnterWeight : ExitWeight0, + SuccsSwapped ? ExitWeight0 : EnterWeight, + }; + setBranchWeights(PreHeaderBI, PreHeaderBIWeights); + } +} + /// Rotate loop LP. Return true if the loop is rotated. /// /// \param SimplifiedLatch is true if the latch was just folded into the final @@ -363,7 +509,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // loop. Otherwise loop is not suitable for rotation. BasicBlock *Exit = BI->getSuccessor(0); BasicBlock *NewHeader = BI->getSuccessor(1); - if (L->contains(Exit)) + bool BISuccsSwapped = L->contains(Exit); + if (BISuccsSwapped) std::swap(Exit, NewHeader); assert(NewHeader && "Unable to determine new loop header"); assert(L->contains(NewHeader) && !L->contains(Exit) && @@ -394,20 +541,32 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // duplication. 
using DbgIntrinsicHash = std::pair<std::pair<hash_code, DILocalVariable *>, DIExpression *>; - auto makeHash = [](DbgVariableIntrinsic *D) -> DbgIntrinsicHash { + auto makeHash = [](auto *D) -> DbgIntrinsicHash { auto VarLocOps = D->location_ops(); return {{hash_combine_range(VarLocOps.begin(), VarLocOps.end()), D->getVariable()}, D->getExpression()}; }; + SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics; for (Instruction &I : llvm::drop_begin(llvm::reverse(*OrigPreheader))) { - if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&I)) + if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&I)) { DbgIntrinsics.insert(makeHash(DII)); - else + // Until RemoveDIs supports dbg.declares in DPValue format, we'll need + // to collect DPValues attached to any other debug intrinsics. + for (const DPValue &DPV : DII->getDbgValueRange()) + DbgIntrinsics.insert(makeHash(&DPV)); + } else { break; + } } + // Build DPValue hashes for DPValues attached to the terminator, which isn't + // considered in the loop above. + for (const DPValue &DPV : + OrigPreheader->getTerminator()->getDbgValueRange()) + DbgIntrinsics.insert(makeHash(&DPV)); + // Remember the local noalias scope declarations in the header. After the // rotation, they must be duplicated and the scope must be cloned. This // avoids unwanted interaction across iterations. @@ -416,6 +575,29 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { if (auto *Decl = dyn_cast<NoAliasScopeDeclInst>(&I)) NoAliasDeclInstructions.push_back(Decl); + Module *M = OrigHeader->getModule(); + + // Track the next DPValue to clone. 
If we have a sequence where an + // instruction is hoisted instead of being cloned: + // DPValue blah + // %foo = add i32 0, 0 + // DPValue xyzzy + // %bar = call i32 @foobar() + // where %foo is hoisted, then the DPValue "blah" will be seen twice, once + // attached to %foo, then when %foo his hoisted it will "fall down" onto the + // function call: + // DPValue blah + // DPValue xyzzy + // %bar = call i32 @foobar() + // causing it to appear attached to the call too. + // + // To avoid this, cloneDebugInfoFrom takes an optional "start cloning from + // here" position to account for this behaviour. We point it at any DPValues + // on the next instruction, here labelled xyzzy, before we hoist %foo. + // Later, we only only clone DPValues from that position (xyzzy) onwards, + // which avoids cloning DPValue "blah" multiple times. + std::optional<DPValue::self_iterator> NextDbgInst = std::nullopt; + while (I != E) { Instruction *Inst = &*I++; @@ -428,7 +610,21 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() && !Inst->mayWriteToMemory() && !Inst->isTerminator() && !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) { + + if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat) { + auto DbgValueRange = + LoopEntryBranch->cloneDebugInfoFrom(Inst, NextDbgInst); + RemapDPValueRange(M, DbgValueRange, ValueMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + // Erase anything we've seen before. 
+ for (DPValue &DPV : make_early_inc_range(DbgValueRange)) + if (DbgIntrinsics.count(makeHash(&DPV))) + DPV.eraseFromParent(); + } + + NextDbgInst = I->getDbgValueRange().begin(); Inst->moveBefore(LoopEntryBranch); + ++NumInstrsHoisted; continue; } @@ -439,6 +635,17 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { ++NumInstrsDuplicated; + if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat) { + auto Range = C->cloneDebugInfoFrom(Inst, NextDbgInst); + RemapDPValueRange(M, Range, ValueMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + NextDbgInst = std::nullopt; + // Erase anything we've seen before. + for (DPValue &DPV : make_early_inc_range(Range)) + if (DbgIntrinsics.count(makeHash(&DPV))) + DPV.eraseFromParent(); + } + // Eagerly remap the operands of the instruction. RemapInstruction(C, ValueMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); @@ -553,6 +760,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // OrigPreHeader's old terminator (the original branch into the loop), and // remove the corresponding incoming values from the PHI nodes in OrigHeader. LoopEntryBranch->eraseFromParent(); + OrigPreheader->flushTerminatorDbgValues(); // Update MemorySSA before the rewrite call below changes the 1:1 // instruction:cloned_instruction_or_value mapping. @@ -605,9 +813,14 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // to split as many edges. 
BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator()); assert(PHBI->isConditional() && "Should be clone of BI condbr!"); - if (!isa<ConstantInt>(PHBI->getCondition()) || - PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) != - NewHeader) { + const Value *Cond = PHBI->getCondition(); + const bool HasConditionalPreHeader = + !isa<ConstantInt>(Cond) || + PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader; + + updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped); + + if (HasConditionalPreHeader) { // The conditional branch can't be folded, handle the general case. // Split edges as necessary to preserve LoopSimplify form. diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 3e604fdf2e11..07e622b1577f 100644 --- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -429,8 +429,8 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, PN->setIncomingBlock(0, PN->getIncomingBlock(PreheaderIdx)); } // Nuke all entries except the zero'th. - for (unsigned i = 0, e = PN->getNumIncomingValues()-1; i != e; ++i) - PN->removeIncomingValue(e-i, false); + PN->removeIncomingValueIf([](unsigned Idx) { return Idx != 0; }, + /* DeletePHIIfEmpty */ false); // Finally, add the newly constructed PHI node as the entry for the BEBlock. 
PN->addIncoming(NewPN, BEBlock); diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index 511dd61308f9..ee6f7b35750a 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -24,7 +24,6 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/ilist_iterator.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -838,7 +837,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, DTUToUse ? nullptr : DT)) { // Dest has been folded into Fold. Update our worklists accordingly. std::replace(Latches.begin(), Latches.end(), Dest, Fold); - llvm::erase_value(UnrolledLoopBlocks, Dest); + llvm::erase(UnrolledLoopBlocks, Dest); } } } diff --git a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp index 31b8cd34eb24..3c06a6e47a30 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -19,7 +19,6 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 1e22eca30d2d..612f69970881 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -56,6 +56,17 @@ static cl::opt<bool> UnrollRuntimeOtherExitPredictable( "unroll-runtime-other-exit-predictable", cl::init(false), cl::Hidden, cl::desc("Assume the non latch exit block to be predictable")); +// Probability that the loop trip count is so small that after the prolog +// we do not enter the 
unrolled loop at all. +// It is unlikely that the loop trip count is smaller than the unroll factor; +// other than that, the choice of constant is not tuned yet. +static const uint32_t UnrolledLoopHeaderWeights[] = {1, 127}; +// Probability that the loop trip count is so small that we skip the unrolled +// loop completely and immediately enter the epilogue loop. +// It is unlikely that the loop trip count is smaller than the unroll factor; +// other than that, the choice of constant is not tuned yet. +static const uint32_t EpilogHeaderWeights[] = {1, 127}; + /// Connect the unrolling prolog code to the original loop. /// The unrolling prolog code contains code to execute the /// 'extra' iterations if the run-time trip count modulo the @@ -105,8 +116,8 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, // PrologLatch. When supporting multiple-exiting block loops, we can have // two or more blocks that have the LatchExit as the target in the // original loop. - PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr", - PrologExit->getFirstNonPHI()); + PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr"); + NewPN->insertBefore(PrologExit->getFirstNonPHIIt()); // Adding a value to the new PHI node from the original loop preheader. // This is the value that skips all the prolog code. if (L->contains(&PN)) { @@ -169,7 +180,14 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, SplitBlockPredecessors(OriginalLoopLatchExit, Preds, ".unr-lcssa", DT, LI, nullptr, PreserveLCSSA); // Add the branch to the exit block (around the unrolled loop) - B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader); + MDNode *BranchWeights = nullptr; + if (hasBranchWeightMD(*Latch->getTerminator())) { + // Assume loop is nearly always entered. 
+ MDBuilder MDB(B.getContext()); + BranchWeights = MDB.createBranchWeights(UnrolledLoopHeaderWeights); + } + B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader, + BranchWeights); InsertPt->eraseFromParent(); if (DT) { auto *NewDom = DT->findNearestCommonDominator(OriginalLoopLatchExit, @@ -194,8 +212,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, BasicBlock *Exit, BasicBlock *PreHeader, BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader, ValueToValueMapTy &VMap, DominatorTree *DT, - LoopInfo *LI, bool PreserveLCSSA, - ScalarEvolution &SE) { + LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE, + unsigned Count) { BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Loop must have a latch"); BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]); @@ -269,8 +287,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, for (PHINode &PN : Succ->phis()) { // Add new PHI nodes to the loop exit block and update epilog // PHIs with the new PHI values. - PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr", - NewExit->getFirstNonPHI()); + PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr"); + NewPN->insertBefore(NewExit->getFirstNonPHIIt()); // Adding a value to the new PHI node from the unrolling loop preheader. NewPN->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader); // Adding a value to the new PHI node from the unrolling loop latch. @@ -292,7 +310,13 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, nullptr, PreserveLCSSA); // Add the branch to the exit block (around the unrolling loop) - B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit); + MDNode *BranchWeights = nullptr; + if (hasBranchWeightMD(*Latch->getTerminator())) { + // Assume equal distribution in interval [0, Count). 
+ MDBuilder MDB(B.getContext()); + BranchWeights = MDB.createBranchWeights(1, Count - 1); + } + B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights); InsertPt->eraseFromParent(); if (DT) { auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit); @@ -316,8 +340,9 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder, const bool UnrollRemainder, BasicBlock *InsertTop, BasicBlock *InsertBot, BasicBlock *Preheader, - std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, - ValueToValueMapTy &VMap, DominatorTree *DT, LoopInfo *LI) { + std::vector<BasicBlock *> &NewBlocks, + LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, + DominatorTree *DT, LoopInfo *LI, unsigned Count) { StringRef suffix = UseEpilogRemainder ? "epil" : "prol"; BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); @@ -363,14 +388,34 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder, BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]); BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator()); IRBuilder<> Builder(LatchBR); - PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, - suffix + ".iter", - FirstLoopBB->getFirstNonPHI()); + PHINode *NewIdx = + PHINode::Create(NewIter->getType(), 2, suffix + ".iter"); + NewIdx->insertBefore(FirstLoopBB->getFirstNonPHIIt()); auto *Zero = ConstantInt::get(NewIdx->getType(), 0); auto *One = ConstantInt::get(NewIdx->getType(), 1); - Value *IdxNext = Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next"); + Value *IdxNext = + Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next"); Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp"); - Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot); + MDNode *BranchWeights = nullptr; + if (hasBranchWeightMD(*LatchBR)) { + uint32_t ExitWeight; + uint32_t BackEdgeWeight; + if (Count >= 3) { + // Note: We do not enter this loop for zero-remainders. 
The check + // is at the end of the loop. We assume equal distribution between + // possible remainders in [1, Count). + ExitWeight = 1; + BackEdgeWeight = (Count - 2) / 2; + } else { + // Unnecessary backedge, should never be taken. The conditional + // jump should be optimized away later. + ExitWeight = 1; + BackEdgeWeight = 0; + } + MDBuilder MDB(Builder.getContext()); + BranchWeights = MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + } + Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights); NewIdx->addIncoming(Zero, InsertTop); NewIdx->addIncoming(IdxNext, NewBB); LatchBR->eraseFromParent(); @@ -464,32 +509,6 @@ static bool canProfitablyUnrollMultiExitLoop( // know of kinds of multiexit loops that would benefit from unrolling. } -// Assign the maximum possible trip count as the back edge weight for the -// remainder loop if the original loop comes with a branch weight. -static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop, - Loop *RemainderLoop, - uint64_t UnrollFactor) { - uint64_t TrueWeight, FalseWeight; - BranchInst *LatchBR = - cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator()); - if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight)) - return; - uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() - ? FalseWeight - : TrueWeight; - assert(UnrollFactor > 1); - uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight; - BasicBlock *Header = RemainderLoop->getHeader(); - BasicBlock *Latch = RemainderLoop->getLoopLatch(); - auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator()); - unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1); - MDBuilder MDB(RemainderLatchBR->getContext()); - MDNode *WeightNode = - HeaderIdx ? 
MDB.createBranchWeights(ExitWeight, BackEdgeWeight) - : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); - RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); -} - /// Calculate ModVal = (BECount + 1) % Count on the abstract integer domain /// accounting for the possibility of unsigned overflow in the 2s complement /// domain. Preconditions: @@ -775,7 +794,13 @@ bool llvm::UnrollRuntimeLoopRemainder( BasicBlock *RemainderLoop = UseEpilogRemainder ? NewExit : PrologPreHeader; BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit; // Branch to either remainder (extra iterations) loop or unrolling loop. - B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop); + MDNode *BranchWeights = nullptr; + if (hasBranchWeightMD(*Latch->getTerminator())) { + // Assume loop is nearly always entered. + MDBuilder MDB(B.getContext()); + BranchWeights = MDB.createBranchWeights(EpilogHeaderWeights); + } + B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights); PreHeaderBR->eraseFromParent(); if (DT) { if (UseEpilogRemainder) @@ -804,12 +829,7 @@ bool llvm::UnrollRuntimeLoopRemainder( BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader; Loop *remainderLoop = CloneLoopBlocks( L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot, - NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI); - - // Assign the maximum possible trip count as the back edge weight for the - // remainder loop if the original loop comes with a branch weight. - if (remainderLoop && !UnrollRemainder) - updateLatchBranchWeightsForRemainderLoop(L, remainderLoop, Count); + NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count); // Insert the cloned blocks into the function. F->splice(InsertBot->getIterator(), F, NewBlocks[0]->getIterator(), F->end()); @@ -893,9 +913,12 @@ bool llvm::UnrollRuntimeLoopRemainder( // Rewrite the cloned instruction operands to use the values created when the // clone is created. 
for (BasicBlock *BB : NewBlocks) { + Module *M = BB->getModule(); for (Instruction &I : *BB) { RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + RemapDPValueRange(M, I.getDbgValueRange(), VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); } } @@ -903,7 +926,7 @@ bool llvm::UnrollRuntimeLoopRemainder( // Connect the epilog code to the original loop and update the // PHI functions. ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader, - NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE); + NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count); // Update counter in loop for unrolling. // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count. @@ -912,8 +935,8 @@ bool llvm::UnrollRuntimeLoopRemainder( IRBuilder<> B2(NewPreHeader->getTerminator()); Value *TestVal = B2.CreateSub(TripCount, ModVal, "unroll_iter"); BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); - PHINode *NewIdx = PHINode::Create(TestVal->getType(), 2, "niter", - Header->getFirstNonPHI()); + PHINode *NewIdx = PHINode::Create(TestVal->getType(), 2, "niter"); + NewIdx->insertBefore(Header->getFirstNonPHIIt()); B2.SetInsertPoint(LatchBR); auto *Zero = ConstantInt::get(NewIdx->getType(), 0); auto *One = ConstantInt::get(NewIdx->getType(), 1); diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 7d6662c44f07..59485126b280 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -296,7 +296,7 @@ std::optional<MDNode *> llvm::makeFollowupLoopID( StringRef AttrName = cast<MDString>(NameMD)->getString(); // Do not inherit excluded attributes. 
- return !AttrName.startswith(InheritOptionsExceptPrefix); + return !AttrName.starts_with(InheritOptionsExceptPrefix); }; if (InheritThisAttribute(Op)) @@ -556,12 +556,8 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // Removes all incoming values from all other exiting blocks (including // duplicate values from an exiting block). // Nuke all entries except the zero'th entry which is the preheader entry. - // NOTE! We need to remove Incoming Values in the reverse order as done - // below, to keep the indices valid for deletion (removeIncomingValues - // updates getNumIncomingValues and shifts all values down into the - // operand being deleted). - for (unsigned i = 0, e = P.getNumIncomingValues() - 1; i != e; ++i) - P.removeIncomingValue(e - i, false); + P.removeIncomingValueIf([](unsigned Idx) { return Idx != 0; }, + /* DeletePHIIfEmpty */ false); assert((P.getNumIncomingValues() == 1 && P.getIncomingBlock(PredIndex) == Preheader) && @@ -608,6 +604,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // Use a map to unique and a vector to guarantee deterministic ordering. llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet; llvm::SmallVector<DbgVariableIntrinsic *, 4> DeadDebugInst; + llvm::SmallVector<DPValue *, 4> DeadDPValues; if (ExitBlock) { // Given LCSSA form is satisfied, we should not have users of instructions @@ -632,6 +629,24 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, "Unexpected user in reachable block"); U.set(Poison); } + + // RemoveDIs: do the same as below for DPValues. + if (Block->IsNewDbgInfoFormat) { + for (DPValue &DPV : + llvm::make_early_inc_range(I.getDbgValueRange())) { + DebugVariable Key(DPV.getVariable(), DPV.getExpression(), + DPV.getDebugLoc().get()); + if (!DeadDebugSet.insert(Key).second) + continue; + // Unlinks the DPV from it's container, for later insertion. 
+ DPV.removeFromParent(); + DeadDPValues.push_back(&DPV); + } + } + + // For one of each variable encountered, preserve a debug intrinsic (set + // to Poison) and transfer it to the loop exit. This terminates any + // variable locations that were set during the loop. auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I); if (!DVI) continue; @@ -646,12 +661,22 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // be be replaced with undef. Loop invariant values will still be available. // Move dbg.values out the loop so that earlier location ranges are still // terminated and loop invariant assignments are preserved. - Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI(); - assert(InsertDbgValueBefore && + DIBuilder DIB(*ExitBlock->getModule()); + BasicBlock::iterator InsertDbgValueBefore = + ExitBlock->getFirstInsertionPt(); + assert(InsertDbgValueBefore != ExitBlock->end() && "There should be a non-PHI instruction in exit block, else these " "instructions will have no parent."); + for (auto *DVI : DeadDebugInst) - DVI->moveBefore(InsertDbgValueBefore); + DVI->moveBefore(*ExitBlock, InsertDbgValueBefore); + + // Due to the "head" bit in BasicBlock::iterator, we're going to insert + // each DPValue right at the start of the block, wheras dbg.values would be + // repeatedly inserted before the first instruction. To replicate this + // behaviour, do it backwards. 
+ for (DPValue *DPV : llvm::reverse(DeadDPValues)) + ExitBlock->insertDPValueBefore(DPV, InsertDbgValueBefore); } // Remove the block from the reference counting scheme, so that we can @@ -937,8 +962,8 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { } } -Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal, - RecurKind RK, Value *Left, Value *Right) { +Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, + RecurKind RK, Value *Left, Value *Right) { if (auto VTy = dyn_cast<VectorType>(Left->getType())) StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal); Value *Cmp = @@ -1028,14 +1053,12 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); } -Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder, - const TargetTransformInfo *TTI, - Value *Src, - const RecurrenceDescriptor &Desc, - PHINode *OrigPhi) { - assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind( - Desc.getRecurrenceKind()) && - "Unexpected reduction kind"); +Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src, + const RecurrenceDescriptor &Desc, + PHINode *OrigPhi) { + assert( + RecurrenceDescriptor::isAnyOfRecurrenceKind(Desc.getRecurrenceKind()) && + "Unexpected reduction kind"); Value *InitVal = Desc.getRecurrenceStartValue(); Value *NewVal = nullptr; @@ -1068,9 +1091,8 @@ Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder, return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select"); } -Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, - const TargetTransformInfo *TTI, - Value *Src, RecurKind RdxKind) { +Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src, + RecurKind RdxKind) { auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); switch (RdxKind) { case RecurKind::Add: @@ -1111,7 +1133,6 @@ Value 
*llvm::createSimpleTargetReduction(IRBuilderBase &Builder, } Value *llvm::createTargetReduction(IRBuilderBase &B, - const TargetTransformInfo *TTI, const RecurrenceDescriptor &Desc, Value *Src, PHINode *OrigPhi) { // TODO: Support in-order reductions based on the recurrence descriptor. @@ -1121,10 +1142,10 @@ Value *llvm::createTargetReduction(IRBuilderBase &B, B.setFastMathFlags(Desc.getFastMathFlags()); RecurKind RK = Desc.getRecurrenceKind(); - if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) - return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi); + if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) + return createAnyOfTargetReduction(B, Src, Desc, OrigPhi); - return createSimpleTargetReduction(B, TTI, Src, RK); + return createSimpleTargetReduction(B, Src, RK); } Value *llvm::createOrderedReduction(IRBuilderBase &B, @@ -1453,7 +1474,7 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI, // Note that we must not perform expansions until after // we query *all* the costs, because if we perform temporary expansion // inbetween, one that we might not intend to keep, said expansion - // *may* affect cost calculation of the the next SCEV's we'll query, + // *may* affect cost calculation of the next SCEV's we'll query, // and next SCEV may errneously get smaller cost. // Collect all the candidate PHINodes to be rewritten. @@ -1632,42 +1653,92 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, struct PointerBounds { TrackingVH<Value> Start; TrackingVH<Value> End; + Value *StrideToCheck; }; /// Expand code for the lower and upper bound of the pointer group \p CG /// in \p TheLoop. \return the values for the bounds. 
static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, Loop *TheLoop, Instruction *Loc, - SCEVExpander &Exp) { + SCEVExpander &Exp, bool HoistRuntimeChecks) { LLVMContext &Ctx = Loc->getContext(); - Type *PtrArithTy = Type::getInt8PtrTy(Ctx, CG->AddressSpace); + Type *PtrArithTy = PointerType::get(Ctx, CG->AddressSpace); Value *Start = nullptr, *End = nullptr; LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); - Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); - End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); + const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr; + + // If the Low and High values are themselves loop-variant, then we may want + // to expand the range to include those covered by the outer loop as well. + // There is a trade-off here with the advantage being that creating checks + // using the expanded range permits the runtime memory checks to be hoisted + // out of the outer loop. This reduces the cost of entering the inner loop, + // which can be significant for low trip counts. The disadvantage is that + // there is a chance we may now never enter the vectorized inner loop, + // whereas using a restricted range check could have allowed us to enter at + // least once. This is why the behaviour is not currently the default and is + // controlled by the parameter 'HoistRuntimeChecks'. 
+ if (HoistRuntimeChecks && TheLoop->getParentLoop() && + isa<SCEVAddRecExpr>(High) && isa<SCEVAddRecExpr>(Low)) { + auto *HighAR = cast<SCEVAddRecExpr>(High); + auto *LowAR = cast<SCEVAddRecExpr>(Low); + const Loop *OuterLoop = TheLoop->getParentLoop(); + const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE()); + if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) && + HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) { + BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); + const SCEV *OuterExitCount = + Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch); + if (!isa<SCEVCouldNotCompute>(OuterExitCount) && + OuterExitCount->getType()->isIntegerTy()) { + const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration( + OuterExitCount, *Exp.getSE()); + if (!isa<SCEVCouldNotCompute>(NewHigh)) { + LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include " + "outer loop in order to permit hoisting\n"); + High = NewHigh; + Low = cast<SCEVAddRecExpr>(Low)->getStart(); + // If there is a possibility that the stride is negative then we have + // to generate extra checks to ensure the stride is positive. + if (!Exp.getSE()->isKnownNonNegative(Recur)) { + Stride = Recur; + LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is " + "positive: " + << *Stride << '\n'); + } + } + } + } + } + + Start = Exp.expandCodeFor(Low, PtrArithTy, Loc); + End = Exp.expandCodeFor(High, PtrArithTy, Loc); if (CG->NeedsFreeze) { IRBuilder<> Builder(Loc); Start = Builder.CreateFreeze(Start, Start->getName() + ".fr"); End = Builder.CreateFreeze(End, End->getName() + ".fr"); } - LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n"); - return {Start, End}; + Value *StrideVal = + Stride ? 
Exp.expandCodeFor(Stride, Stride->getType(), Loc) : nullptr; + LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n"); + return {Start, End, StrideVal}; } /// Turns a collection of checks into a collection of expanded upper and /// lower bounds for both pointers in the check. static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, - Instruction *Loc, SCEVExpander &Exp) { + Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) { SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds; // Here we're relying on the SCEV Expander's cache to only emit code for the // same bounds once. transform(PointerChecks, std::back_inserter(ChecksWithBounds), [&](const RuntimePointerCheck &Check) { - PointerBounds First = expandBounds(Check.first, L, Loc, Exp), - Second = expandBounds(Check.second, L, Loc, Exp); + PointerBounds First = expandBounds(Check.first, L, Loc, Exp, + HoistRuntimeChecks), + Second = expandBounds(Check.second, L, Loc, Exp, + HoistRuntimeChecks); return std::make_pair(First, Second); }); @@ -1677,10 +1748,11 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, Value *llvm::addRuntimeChecks( Instruction *Loc, Loop *TheLoop, const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, - SCEVExpander &Exp) { + SCEVExpander &Exp, bool HoistRuntimeChecks) { // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible. 
// TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible - auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp); + auto ExpandedChecks = + expandBounds(PointerChecks, TheLoop, Loc, Exp, HoistRuntimeChecks); LLVMContext &Ctx = Loc->getContext(); IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx, @@ -1693,21 +1765,13 @@ Value *llvm::addRuntimeChecks( const PointerBounds &A = Check.first, &B = Check.second; // Check if two pointers (A and B) conflict where conflict is computed as: // start(A) <= end(B) && start(B) <= end(A) - unsigned AS0 = A.Start->getType()->getPointerAddressSpace(); - unsigned AS1 = B.Start->getType()->getPointerAddressSpace(); - assert((AS0 == B.End->getType()->getPointerAddressSpace()) && - (AS1 == A.End->getType()->getPointerAddressSpace()) && + assert((A.Start->getType()->getPointerAddressSpace() == + B.End->getType()->getPointerAddressSpace()) && + (B.Start->getType()->getPointerAddressSpace() == + A.End->getType()->getPointerAddressSpace()) && "Trying to bounds check pointers with different address spaces"); - Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); - Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); - - Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc"); - Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc"); - Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc"); - Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc"); - // [A|B].Start points to the first accessed byte under base [A|B]. // [A|B].End points to the last accessed byte, plus one. 
// There is no conflict when the intervals are disjoint: @@ -1716,9 +1780,21 @@ Value *llvm::addRuntimeChecks( // bound0 = (B.Start < A.End) // bound1 = (A.Start < B.End) // IsConflict = bound0 & bound1 - Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0"); - Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1"); + Value *Cmp0 = ChkBuilder.CreateICmpULT(A.Start, B.End, "bound0"); + Value *Cmp1 = ChkBuilder.CreateICmpULT(B.Start, A.End, "bound1"); Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); + if (A.StrideToCheck) { + Value *IsNegativeStride = ChkBuilder.CreateICmpSLT( + A.StrideToCheck, ConstantInt::get(A.StrideToCheck->getType(), 0), + "stride.check"); + IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride); + } + if (B.StrideToCheck) { + Value *IsNegativeStride = ChkBuilder.CreateICmpSLT( + B.StrideToCheck, ConstantInt::get(B.StrideToCheck->getType(), 0), + "stride.check"); + IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride); + } if (MemoryRuntimeCheck) { IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); @@ -1740,23 +1816,31 @@ Value *llvm::addDiffRuntimeChecks( // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; + auto &SE = *Expander.getSE(); + // Map to keep track of created compares, The key is the pair of operands for + // the compare, to allow detecting and re-using redundant compares. + DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares; for (const auto &C : Checks) { Type *Ty = C.SinkStart->getType(); // Compute VF * IC * AccessSize. 
auto *VFTimesUFTimesSize = ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()), ConstantInt::get(Ty, IC * C.AccessSize)); - Value *Sink = Expander.expandCodeFor(C.SinkStart, Ty, Loc); - Value *Src = Expander.expandCodeFor(C.SrcStart, Ty, Loc); - if (C.NeedsFreeze) { - IRBuilder<> Builder(Loc); - Sink = Builder.CreateFreeze(Sink, Sink->getName() + ".fr"); - Src = Builder.CreateFreeze(Src, Src->getName() + ".fr"); - } - Value *Diff = ChkBuilder.CreateSub(Sink, Src); - Value *IsConflict = - ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check"); + Value *Diff = Expander.expandCodeFor( + SE.getMinusSCEV(C.SinkStart, C.SrcStart), Ty, Loc); + + // Check if the same compare has already been created earlier. In that case, + // there is no need to check it again. + Value *IsConflict = SeenCompares.lookup({Diff, VFTimesUFTimesSize}); + if (IsConflict) + continue; + IsConflict = + ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check"); + SeenCompares.insert({{Diff, VFTimesUFTimesSize}, IsConflict}); + if (C.NeedsFreeze) + IsConflict = + ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr"); if (MemoryRuntimeCheck) { IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 78ebe75c121b..548b0f3c55f0 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -145,8 +145,8 @@ void LoopVersioning::addPHINodes( } // If not create it. 
if (!PN) { - PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver", - &PHIBlock->front()); + PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver"); + PN->insertBefore(PHIBlock->begin()); SmallVector<User*, 8> UsersToUpdate; for (User *U : Inst->users()) if (!VersionedLoop->contains(cast<Instruction>(U)->getParent())) diff --git a/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp b/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp index 195c274ff18e..4908535cba54 100644 --- a/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp +++ b/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp @@ -128,7 +128,7 @@ static bool runImpl(Module &M) { // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d); LLVMContext &C = M.getContext(); - PointerType *VoidStar = Type::getInt8PtrTy(C); + PointerType *VoidStar = PointerType::getUnqual(C); Type *AtExitFuncArgs[] = {VoidStar}; FunctionType *AtExitFuncTy = FunctionType::get(Type::getVoidTy(C), AtExitFuncArgs, @@ -140,6 +140,17 @@ static bool runImpl(Module &M) { {PointerType::get(AtExitFuncTy, 0), VoidStar, VoidStar}, /*isVarArg=*/false)); + // If __cxa_atexit is defined (e.g. in the case of LTO) and arg0 is not + // actually used (i.e. it's dummy/stub function as used in emscripten when + // the program never exits) we can simply return early and clear out + // @llvm.global_dtors. + if (auto F = dyn_cast<Function>(AtExit.getCallee())) { + if (F && F->hasExactDefinition() && F->getArg(0)->getNumUses() == 0) { + GV->eraseFromParent(); + return true; + } + } + // Declare __dso_local. 
Type *DsoHandleTy = Type::getInt8Ty(C); Constant *DsoHandle = M.getOrInsertGlobal("__dso_handle", DsoHandleTy, [&] { diff --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 906eb71fc2d9..c75de8687879 100644 --- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -64,17 +64,6 @@ void llvm::createMemCpyLoopKnownSize( IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); - // Cast the Src and Dst pointers to pointers to the loop operand type (if - // needed). - PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS); - PointerType *DstOpType = PointerType::get(LoopOpType, DstAS); - if (SrcAddr->getType() != SrcOpType) { - SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType); - } - if (DstAddr->getType() != DstOpType) { - DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType); - } - Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize)); Align PartSrcAlign(commonAlignment(SrcAlign, LoopOpSize)); @@ -137,13 +126,9 @@ void llvm::createMemCpyLoopKnownSize( uint64_t GepIndex = BytesCopied / OperandSize; assert(GepIndex * OperandSize == BytesCopied && "Division should have no Remainder!"); - // Cast source to operand type and load - PointerType *SrcPtrType = PointerType::get(OpTy, SrcAS); - Value *CastedSrc = SrcAddr->getType() == SrcPtrType - ? SrcAddr - : RBuilder.CreateBitCast(SrcAddr, SrcPtrType); + Value *SrcGEP = RBuilder.CreateInBoundsGEP( - OpTy, CastedSrc, ConstantInt::get(TypeOfCopyLen, GepIndex)); + OpTy, SrcAddr, ConstantInt::get(TypeOfCopyLen, GepIndex)); LoadInst *Load = RBuilder.CreateAlignedLoad(OpTy, SrcGEP, PartSrcAlign, SrcIsVolatile); if (!CanOverlap) { @@ -151,13 +136,8 @@ void llvm::createMemCpyLoopKnownSize( Load->setMetadata(LLVMContext::MD_alias_scope, MDNode::get(Ctx, NewScope)); } - // Cast destination to operand type and store. 
- PointerType *DstPtrType = PointerType::get(OpTy, DstAS); - Value *CastedDst = DstAddr->getType() == DstPtrType - ? DstAddr - : RBuilder.CreateBitCast(DstAddr, DstPtrType); Value *DstGEP = RBuilder.CreateInBoundsGEP( - OpTy, CastedDst, ConstantInt::get(TypeOfCopyLen, GepIndex)); + OpTy, DstAddr, ConstantInt::get(TypeOfCopyLen, GepIndex)); StoreInst *Store = RBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); if (!CanOverlap) { @@ -206,15 +186,6 @@ void llvm::createMemCpyLoopUnknownSize( IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); - PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS); - PointerType *DstOpType = PointerType::get(LoopOpType, DstAS); - if (SrcAddr->getType() != SrcOpType) { - SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType); - } - if (DstAddr->getType() != DstOpType) { - DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType); - } - // Calculate the loop trip count, and remaining bytes to copy after the loop. Type *CopyLenType = CopyLen->getType(); IntegerType *ILengthType = dyn_cast<IntegerType>(CopyLenType); @@ -305,13 +276,9 @@ void llvm::createMemCpyLoopUnknownSize( ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index"); ResidualIndex->addIncoming(Zero, ResHeaderBB); - Value *SrcAsResLoopOpType = ResBuilder.CreateBitCast( - SrcAddr, PointerType::get(ResLoopOpType, SrcAS)); - Value *DstAsResLoopOpType = ResBuilder.CreateBitCast( - DstAddr, PointerType::get(ResLoopOpType, DstAS)); Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex); - Value *SrcGEP = ResBuilder.CreateInBoundsGEP( - ResLoopOpType, SrcAsResLoopOpType, FullOffset); + Value *SrcGEP = + ResBuilder.CreateInBoundsGEP(ResLoopOpType, SrcAddr, FullOffset); LoadInst *Load = ResBuilder.CreateAlignedLoad(ResLoopOpType, SrcGEP, PartSrcAlign, SrcIsVolatile); if (!CanOverlap) { @@ -319,8 +286,8 @@ void llvm::createMemCpyLoopUnknownSize( Load->setMetadata(LLVMContext::MD_alias_scope, MDNode::get(Ctx, NewScope)); } - Value *DstGEP = 
ResBuilder.CreateInBoundsGEP( - ResLoopOpType, DstAsResLoopOpType, FullOffset); + Value *DstGEP = + ResBuilder.CreateInBoundsGEP(ResLoopOpType, DstAddr, FullOffset); StoreInst *Store = ResBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); if (!CanOverlap) { @@ -479,11 +446,6 @@ static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr, IRBuilder<> Builder(OrigBB->getTerminator()); - // Cast pointer to the type of value getting stored - unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); - DstAddr = Builder.CreateBitCast(DstAddr, - PointerType::get(SetValue->getType(), dstAS)); - Builder.CreateCondBr( Builder.CreateICmpEQ(ConstantInt::get(TypeOfCopyLen, 0), CopyLen), NewBB, LoopBB); diff --git a/llvm/lib/Transforms/Utils/MetaRenamer.cpp b/llvm/lib/Transforms/Utils/MetaRenamer.cpp index 44ac65f265f0..fd0112ae529c 100644 --- a/llvm/lib/Transforms/Utils/MetaRenamer.cpp +++ b/llvm/lib/Transforms/Utils/MetaRenamer.cpp @@ -151,7 +151,7 @@ void MetaRename(Module &M, auto IsNameExcluded = [](StringRef &Name, SmallVectorImpl<StringRef> &ExcludedPrefixes) { return any_of(ExcludedPrefixes, - [&Name](auto &Prefix) { return Name.startswith(Prefix); }); + [&Name](auto &Prefix) { return Name.starts_with(Prefix); }); }; // Leave library functions alone because their presence or absence could @@ -159,7 +159,7 @@ void MetaRename(Module &M, auto ExcludeLibFuncs = [&](Function &F) { LibFunc Tmp; StringRef Name = F.getName(); - return Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || + return Name.starts_with("llvm.") || (!Name.empty() && Name[0] == 1) || GetTLI(F).getLibFunc(F, Tmp) || IsNameExcluded(Name, ExcludedFuncPrefixes); }; @@ -177,7 +177,7 @@ void MetaRename(Module &M, // Rename all aliases for (GlobalAlias &GA : M.aliases()) { StringRef Name = GA.getName(); - if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || + if (Name.starts_with("llvm.") || (!Name.empty() && Name[0] == 1) || 
IsNameExcluded(Name, ExcludedAliasesPrefixes)) continue; @@ -187,7 +187,7 @@ void MetaRename(Module &M, // Rename all global variables for (GlobalVariable &GV : M.globals()) { StringRef Name = GV.getName(); - if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || + if (Name.starts_with("llvm.") || (!Name.empty() && Name[0] == 1) || IsNameExcluded(Name, ExcludedGlobalsPrefixes)) continue; diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 1e243ef74df7..7de0959ca57e 100644 --- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -44,17 +44,17 @@ static void appendToGlobalArray(StringRef ArrayName, Module &M, Function *F, } GVCtor->eraseFromParent(); } else { - EltTy = StructType::get( - IRB.getInt32Ty(), PointerType::get(FnTy, F->getAddressSpace()), - IRB.getInt8PtrTy()); + EltTy = StructType::get(IRB.getInt32Ty(), + PointerType::get(FnTy, F->getAddressSpace()), + IRB.getPtrTy()); } // Build a 3 field global_ctor entry. We don't take a comdat key. Constant *CSVals[3]; CSVals[0] = IRB.getInt32(Priority); CSVals[1] = F; - CSVals[2] = Data ? ConstantExpr::getPointerCast(Data, IRB.getInt8PtrTy()) - : Constant::getNullValue(IRB.getInt8PtrTy()); + CSVals[2] = Data ? 
ConstantExpr::getPointerCast(Data, IRB.getPtrTy()) + : Constant::getNullValue(IRB.getPtrTy()); Constant *RuntimeCtorInit = ConstantStruct::get(EltTy, ArrayRef(CSVals, EltTy->getNumElements())); @@ -96,7 +96,7 @@ static void appendToUsedList(Module &M, StringRef Name, ArrayRef<GlobalValue *> if (GV) GV->eraseFromParent(); - Type *ArrayEltTy = llvm::Type::getInt8PtrTy(M.getContext()); + Type *ArrayEltTy = llvm::PointerType::getUnqual(M.getContext()); for (auto *V : Values) Init.insert(ConstantExpr::getPointerBitCastOrAddrSpaceCast(V, ArrayEltTy)); @@ -301,7 +301,7 @@ std::string llvm::getUniqueModuleId(Module *M) { MD5 Md5; bool ExportsSymbols = false; auto AddGlobal = [&](GlobalValue &GV) { - if (GV.isDeclaration() || GV.getName().startswith("llvm.") || + if (GV.isDeclaration() || GV.getName().starts_with("llvm.") || !GV.hasExternalLinkage() || GV.hasComdat()) return; ExportsSymbols = true; @@ -346,7 +346,8 @@ void VFABI::setVectorVariantNames(CallInst *CI, #ifndef NDEBUG for (const std::string &VariantMapping : VariantMappings) { LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n"); - std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *M); + std::optional<VFInfo> VI = + VFABI::tryDemangleForVFABI(VariantMapping, CI->getFunctionType()); assert(VI && "Cannot add an invalid VFABI name."); assert(M->getNamedValue(VI->VectorName) && "Cannot add variant to attribute: " diff --git a/llvm/lib/Transforms/Utils/MoveAutoInit.cpp b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp index b0ca0b15c08e..a977ad87b79f 100644 --- a/llvm/lib/Transforms/Utils/MoveAutoInit.cpp +++ b/llvm/lib/Transforms/Utils/MoveAutoInit.cpp @@ -14,7 +14,6 @@ #include "llvm/Transforms/Utils/MoveAutoInit.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringSet.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ValueTracking.h" @@ -50,7 +49,7 @@ static 
std::optional<MemoryLocation> writeToAlloca(const Instruction &I) { else if (auto *SI = dyn_cast<StoreInst>(&I)) ML = MemoryLocation::get(SI); else - assert(false && "memory location set"); + return std::nullopt; if (isa<AllocaInst>(getUnderlyingObject(ML.Ptr))) return ML; @@ -202,7 +201,7 @@ static bool runMoveAutoInit(Function &F, DominatorTree &DT, MemorySSA &MSSA) { // if two instructions are moved from the same BB to the same BB, we insert // the second one in the front, then the first on top of it. for (auto &Job : reverse(JobList)) { - Job.first->moveBefore(&*Job.second->getFirstInsertionPt()); + Job.first->moveBefore(*Job.second, Job.second->getFirstInsertionPt()); MSSAU.moveToPlace(MSSA.getMemoryAccess(Job.first), Job.first->getParent(), MemorySSA::InsertionPlace::Beginning); } diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp index 1f16ba78bdb0..902977b08d15 100644 --- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -23,7 +23,6 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" @@ -33,12 +32,6 @@ using namespace llvm; using namespace PatternMatch; -INITIALIZE_PASS_BEGIN(PredicateInfoPrinterLegacyPass, "print-predicateinfo", - "PredicateInfo Printer", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(PredicateInfoPrinterLegacyPass, "print-predicateinfo", - "PredicateInfo Printer", false, false) static cl::opt<bool> VerifyPredicateInfo( "verify-predicateinfo", cl::init(false), cl::Hidden, cl::desc("Verify PredicateInfo in legacy printer pass.")); @@ -835,20 +828,6 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const { void 
PredicateInfo::verifyPredicateInfo() const {} -char PredicateInfoPrinterLegacyPass::ID = 0; - -PredicateInfoPrinterLegacyPass::PredicateInfoPrinterLegacyPass() - : FunctionPass(ID) { - initializePredicateInfoPrinterLegacyPassPass( - *PassRegistry::getPassRegistry()); -} - -void PredicateInfoPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesAll(); - AU.addRequiredTransitive<DominatorTreeWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); -} - // Replace ssa_copy calls created by PredicateInfo with their operand. static void replaceCreatedSSACopys(PredicateInfo &PredInfo, Function &F) { for (Instruction &Inst : llvm::make_early_inc_range(instructions(F))) { @@ -862,18 +841,6 @@ static void replaceCreatedSSACopys(PredicateInfo &PredInfo, Function &F) { } } -bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) { - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto PredInfo = std::make_unique<PredicateInfo>(F, DT, AC); - PredInfo->print(dbgs()); - if (VerifyPredicateInfo) - PredInfo->verifyPredicateInfo(); - - replaceCreatedSSACopys(*PredInfo, F); - return false; -} - PreservedAnalyses PredicateInfoPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 2e5f40d39912..717b6d301c8c 100644 --- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugProgramInstruction.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -172,6 +173,7 @@ public: struct AllocaInfo { using DbgUserVec = 
SmallVector<DbgVariableIntrinsic *, 1>; + using DPUserVec = SmallVector<DPValue *, 1>; SmallVector<BasicBlock *, 32> DefiningBlocks; SmallVector<BasicBlock *, 32> UsingBlocks; @@ -182,6 +184,7 @@ struct AllocaInfo { /// Debug users of the alloca - does not include dbg.assign intrinsics. DbgUserVec DbgUsers; + DPUserVec DPUsers; /// Helper to update assignment tracking debug info. AssignmentTrackingInfo AssignmentTracking; @@ -192,6 +195,7 @@ struct AllocaInfo { OnlyBlock = nullptr; OnlyUsedInOneBlock = true; DbgUsers.clear(); + DPUsers.clear(); AssignmentTracking.clear(); } @@ -225,7 +229,7 @@ struct AllocaInfo { } } DbgUserVec AllDbgUsers; - findDbgUsers(AllDbgUsers, AI); + findDbgUsers(AllDbgUsers, AI, &DPUsers); std::copy_if(AllDbgUsers.begin(), AllDbgUsers.end(), std::back_inserter(DbgUsers), [](DbgVariableIntrinsic *DII) { return !isa<DbgAssignIntrinsic>(DII); @@ -329,6 +333,7 @@ struct PromoteMem2Reg { /// describes it, if any, so that we can convert it to a dbg.value /// intrinsic if the alloca gets promoted. SmallVector<AllocaInfo::DbgUserVec, 8> AllocaDbgUsers; + SmallVector<AllocaInfo::DPUserVec, 8> AllocaDPUsers; /// For each alloca, keep an instance of a helper class that gives us an easy /// way to update assignment tracking debug info if the alloca is promoted. @@ -525,14 +530,18 @@ static bool rewriteSingleStoreAlloca( // Record debuginfo for the store and remove the declaration's // debuginfo. 
- for (DbgVariableIntrinsic *DII : Info.DbgUsers) { - if (DII->isAddressOfVariable()) { - ConvertDebugDeclareToDebugValue(DII, Info.OnlyStore, DIB); - DII->eraseFromParent(); - } else if (DII->getExpression()->startsWithDeref()) { - DII->eraseFromParent(); + auto ConvertDebugInfoForStore = [&](auto &Container) { + for (auto *DbgItem : Container) { + if (DbgItem->isAddressOfVariable()) { + ConvertDebugDeclareToDebugValue(DbgItem, Info.OnlyStore, DIB); + DbgItem->eraseFromParent(); + } else if (DbgItem->getExpression()->startsWithDeref()) { + DbgItem->eraseFromParent(); + } } - } + }; + ConvertDebugInfoForStore(Info.DbgUsers); + ConvertDebugInfoForStore(Info.DPUsers); // Remove dbg.assigns linked to the alloca as these are now redundant. at::deleteAssignmentMarkers(AI); @@ -629,12 +638,18 @@ static bool promoteSingleBlockAlloca( StoreInst *SI = cast<StoreInst>(AI->user_back()); // Update assignment tracking info for the store we're going to delete. Info.AssignmentTracking.updateForDeletedStore(SI, DIB, DbgAssignsToDelete); + // Record debuginfo for the store before removing it. - for (DbgVariableIntrinsic *DII : Info.DbgUsers) { - if (DII->isAddressOfVariable()) { - ConvertDebugDeclareToDebugValue(DII, SI, DIB); + auto DbgUpdateForStore = [&](auto &Container) { + for (auto *DbgItem : Container) { + if (DbgItem->isAddressOfVariable()) { + ConvertDebugDeclareToDebugValue(DbgItem, SI, DIB); + } } - } + }; + DbgUpdateForStore(Info.DbgUsers); + DbgUpdateForStore(Info.DPUsers); + SI->eraseFromParent(); LBI.deleteValue(SI); } @@ -644,9 +659,14 @@ static bool promoteSingleBlockAlloca( AI->eraseFromParent(); // The alloca's debuginfo can be removed as well. 
- for (DbgVariableIntrinsic *DII : Info.DbgUsers) - if (DII->isAddressOfVariable() || DII->getExpression()->startsWithDeref()) - DII->eraseFromParent(); + auto DbgUpdateForAlloca = [&](auto &Container) { + for (auto *DbgItem : Container) + if (DbgItem->isAddressOfVariable() || + DbgItem->getExpression()->startsWithDeref()) + DbgItem->eraseFromParent(); + }; + DbgUpdateForAlloca(Info.DbgUsers); + DbgUpdateForAlloca(Info.DPUsers); ++NumLocalPromoted; return true; @@ -657,6 +677,7 @@ void PromoteMem2Reg::run() { AllocaDbgUsers.resize(Allocas.size()); AllocaATInfo.resize(Allocas.size()); + AllocaDPUsers.resize(Allocas.size()); AllocaInfo Info; LargeBlockInfo LBI; @@ -720,6 +741,8 @@ void PromoteMem2Reg::run() { AllocaDbgUsers[AllocaNum] = Info.DbgUsers; if (!Info.AssignmentTracking.empty()) AllocaATInfo[AllocaNum] = Info.AssignmentTracking; + if (!Info.DPUsers.empty()) + AllocaDPUsers[AllocaNum] = Info.DPUsers; // Keep the reverse mapping of the 'Allocas' array for the rename pass. AllocaLookup[Allocas[AllocaNum]] = AllocaNum; @@ -795,11 +818,16 @@ void PromoteMem2Reg::run() { } // Remove alloca's dbg.declare intrinsics from the function. - for (auto &DbgUsers : AllocaDbgUsers) { - for (auto *DII : DbgUsers) - if (DII->isAddressOfVariable() || DII->getExpression()->startsWithDeref()) - DII->eraseFromParent(); - } + auto RemoveDbgDeclares = [&](auto &Container) { + for (auto &DbgUsers : Container) { + for (auto *DbgItem : DbgUsers) + if (DbgItem->isAddressOfVariable() || + DbgItem->getExpression()->startsWithDeref()) + DbgItem->eraseFromParent(); + } + }; + RemoveDbgDeclares(AllocaDbgUsers); + RemoveDbgDeclares(AllocaDPUsers); // Loop over all of the PHI nodes and see if there are any that we can get // rid of because they merge all of the same incoming values. This can @@ -981,8 +1009,8 @@ bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo, // Create a PhiNode using the dereferenced type... and add the phi-node to the // BasicBlock. 
PN = PHINode::Create(Allocas[AllocaNo]->getAllocatedType(), getNumPreds(BB), - Allocas[AllocaNo]->getName() + "." + Twine(Version++), - &BB->front()); + Allocas[AllocaNo]->getName() + "." + Twine(Version++)); + PN->insertBefore(BB->begin()); ++NumPHIInsert; PhiToAllocaMap[PN] = AllocaNo; return true; @@ -1041,9 +1069,13 @@ NextIteration: // The currently active variable for this block is now the PHI. IncomingVals[AllocaNo] = APN; AllocaATInfo[AllocaNo].updateForNewPhi(APN, DIB); - for (DbgVariableIntrinsic *DII : AllocaDbgUsers[AllocaNo]) - if (DII->isAddressOfVariable()) - ConvertDebugDeclareToDebugValue(DII, APN, DIB); + auto ConvertDbgDeclares = [&](auto &Container) { + for (auto *DbgItem : Container) + if (DbgItem->isAddressOfVariable()) + ConvertDebugDeclareToDebugValue(DbgItem, APN, DIB); + }; + ConvertDbgDeclares(AllocaDbgUsers[AllocaNo]); + ConvertDbgDeclares(AllocaDPUsers[AllocaNo]); // Get the next phi node. ++PNI; @@ -1098,9 +1130,13 @@ NextIteration: IncomingLocs[AllocaNo] = SI->getDebugLoc(); AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB, &DbgAssignsToDelete); - for (DbgVariableIntrinsic *DII : AllocaDbgUsers[ai->second]) - if (DII->isAddressOfVariable()) - ConvertDebugDeclareToDebugValue(DII, SI, DIB); + auto ConvertDbgDeclares = [&](auto &Container) { + for (auto *DbgItem : Container) + if (DbgItem->isAddressOfVariable()) + ConvertDebugDeclareToDebugValue(DbgItem, SI, DIB); + }; + ConvertDbgDeclares(AllocaDbgUsers[ai->second]); + ConvertDbgDeclares(AllocaDPUsers[ai->second]); SI->eraseFromParent(); } } diff --git a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp index c9ff94dc9744..ea628d7c3d7d 100644 --- a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp +++ b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp @@ -153,17 +153,12 @@ static void convertToRelLookupTable(GlobalVariable &LookupTable) { Builder.SetInsertPoint(Load); Function *LoadRelIntrinsic = 
llvm::Intrinsic::getDeclaration( &M, Intrinsic::load_relative, {Index->getType()}); - Value *Base = Builder.CreateBitCast(RelLookupTable, Builder.getInt8PtrTy()); // Create a call to load.relative intrinsic that computes the target address // by adding base address (lookup table address) and relative offset. - Value *Result = Builder.CreateCall(LoadRelIntrinsic, {Base, Offset}, + Value *Result = Builder.CreateCall(LoadRelIntrinsic, {RelLookupTable, Offset}, "reltable.intrinsic"); - // Create a bitcast instruction if necessary. - if (Load->getType() != Builder.getInt8PtrTy()) - Result = Builder.CreateBitCast(Result, Load->getType(), "reltable.bitcast"); - // Replace load instruction with the new generated instruction sequence. Load->replaceAllUsesWith(Result); // Remove Load and GEP instructions. diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index de3626a24212..ab95698abc43 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -107,9 +107,7 @@ bool SCCPSolver::tryToReplaceWithConstant(Value *V) { static bool refineInstruction(SCCPSolver &Solver, const SmallPtrSetImpl<Value *> &InsertedValues, Instruction &Inst) { - if (!isa<OverflowingBinaryOperator>(Inst)) - return false; - + bool Changed = false; auto GetRange = [&Solver, &InsertedValues](Value *Op) { if (auto *Const = dyn_cast<ConstantInt>(Op)) return ConstantRange(Const->getValue()); @@ -120,23 +118,32 @@ static bool refineInstruction(SCCPSolver &Solver, return getConstantRange(Solver.getLatticeValueFor(Op), Op->getType(), /*UndefAllowed=*/false); }; - auto RangeA = GetRange(Inst.getOperand(0)); - auto RangeB = GetRange(Inst.getOperand(1)); - bool Changed = false; - if (!Inst.hasNoUnsignedWrap()) { - auto NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( - Instruction::BinaryOps(Inst.getOpcode()), RangeB, - OverflowingBinaryOperator::NoUnsignedWrap); - if (NUWRange.contains(RangeA)) { - 
Inst.setHasNoUnsignedWrap(); - Changed = true; + + if (isa<OverflowingBinaryOperator>(Inst)) { + auto RangeA = GetRange(Inst.getOperand(0)); + auto RangeB = GetRange(Inst.getOperand(1)); + if (!Inst.hasNoUnsignedWrap()) { + auto NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::BinaryOps(Inst.getOpcode()), RangeB, + OverflowingBinaryOperator::NoUnsignedWrap); + if (NUWRange.contains(RangeA)) { + Inst.setHasNoUnsignedWrap(); + Changed = true; + } } - } - if (!Inst.hasNoSignedWrap()) { - auto NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( - Instruction::BinaryOps(Inst.getOpcode()), RangeB, OverflowingBinaryOperator::NoSignedWrap); - if (NSWRange.contains(RangeA)) { - Inst.setHasNoSignedWrap(); + if (!Inst.hasNoSignedWrap()) { + auto NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::BinaryOps(Inst.getOpcode()), RangeB, + OverflowingBinaryOperator::NoSignedWrap); + if (NSWRange.contains(RangeA)) { + Inst.setHasNoSignedWrap(); + Changed = true; + } + } + } else if (isa<ZExtInst>(Inst) && !Inst.hasNonNeg()) { + auto Range = GetRange(Inst.getOperand(0)); + if (Range.isAllNonNegative()) { + Inst.setNonNeg(); Changed = true; } } @@ -171,6 +178,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, if (InsertedValues.count(Op0) || !isNonNegative(Op0)) return false; NewInst = new ZExtInst(Op0, Inst.getType(), "", &Inst); + NewInst->setNonNeg(); break; } case Instruction::AShr: { @@ -179,6 +187,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, if (InsertedValues.count(Op0) || !isNonNegative(Op0)) return false; NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", &Inst); + NewInst->setIsExact(Inst.isExact()); break; } case Instruction::SDiv: @@ -191,6 +200,8 @@ static bool replaceSignedInst(SCCPSolver &Solver, auto NewOpcode = Inst.getOpcode() == Instruction::SDiv ? 
Instruction::UDiv : Instruction::URem; NewInst = BinaryOperator::Create(NewOpcode, Op0, Op1, "", &Inst); + if (Inst.getOpcode() == Instruction::SDiv) + NewInst->setIsExact(Inst.isExact()); break; } default: @@ -1029,8 +1040,9 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, return; } - // Unwinding instructions successors are always executable. - if (TI.isExceptionalTerminator()) { + // We cannot analyze special terminators, so consider all successors + // executable. + if (TI.isSpecialTerminator()) { Succs.assign(TI.getNumSuccessors(), true); return; } @@ -1098,13 +1110,6 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, return; } - // In case of callbr, we pessimistically assume that all successors are - // feasible. - if (isa<CallBrInst>(&TI)) { - Succs.assign(TI.getNumSuccessors(), true); - return; - } - LLVM_DEBUG(dbgs() << "Unknown terminator instruction: " << TI << '\n'); llvm_unreachable("SCCP: Don't know how to handle this terminator!"); } @@ -1231,10 +1236,12 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) { // Fold the constant as we build. 
- Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL); - markConstant(&I, C); - } else if (I.getDestTy()->isIntegerTy() && - I.getSrcTy()->isIntOrIntVectorTy()) { + if (Constant *C = + ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL)) + return (void)markConstant(&I, C); + } + + if (I.getDestTy()->isIntegerTy() && I.getSrcTy()->isIntOrIntVectorTy()) { auto &LV = getValueState(&I); ConstantRange OpRange = getConstantRange(OpSt, I.getSrcTy()); @@ -1539,11 +1546,8 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { return (void)markOverdefined(&I); } - Constant *Ptr = Operands[0]; - auto Indices = ArrayRef(Operands.begin() + 1, Operands.end()); - Constant *C = - ConstantExpr::getGetElementPtr(I.getSourceElementType(), Ptr, Indices); - markConstant(&I, C); + if (Constant *C = ConstantFoldInstOperands(&I, Operands, DL)) + markConstant(&I, C); } void SCCPInstVisitor::visitStoreInst(StoreInst &SI) { diff --git a/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/llvm/lib/Transforms/Utils/SSAUpdater.cpp index ebe9cb27f5ab..fc21fb552137 100644 --- a/llvm/lib/Transforms/Utils/SSAUpdater.cpp +++ b/llvm/lib/Transforms/Utils/SSAUpdater.cpp @@ -156,8 +156,9 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { } // Ok, we have no way out, insert a new one now. - PHINode *InsertedPHI = PHINode::Create(ProtoType, PredValues.size(), - ProtoName, &BB->front()); + PHINode *InsertedPHI = + PHINode::Create(ProtoType, PredValues.size(), ProtoName); + InsertedPHI->insertBefore(BB->begin()); // Fill in all the predecessors of the PHI. 
for (const auto &PredValue : PredValues) @@ -198,12 +199,18 @@ void SSAUpdater::RewriteUse(Use &U) { void SSAUpdater::UpdateDebugValues(Instruction *I) { SmallVector<DbgValueInst *, 4> DbgValues; - llvm::findDbgValues(DbgValues, I); + SmallVector<DPValue *, 4> DPValues; + llvm::findDbgValues(DbgValues, I, &DPValues); for (auto &DbgValue : DbgValues) { if (DbgValue->getParent() == I->getParent()) continue; UpdateDebugValue(I, DbgValue); } + for (auto &DPV : DPValues) { + if (DPV->getParent() == I->getParent()) + continue; + UpdateDebugValue(I, DPV); + } } void SSAUpdater::UpdateDebugValues(Instruction *I, @@ -213,16 +220,31 @@ void SSAUpdater::UpdateDebugValues(Instruction *I, } } +void SSAUpdater::UpdateDebugValues(Instruction *I, + SmallVectorImpl<DPValue *> &DPValues) { + for (auto &DPV : DPValues) { + UpdateDebugValue(I, DPV); + } +} + void SSAUpdater::UpdateDebugValue(Instruction *I, DbgValueInst *DbgValue) { BasicBlock *UserBB = DbgValue->getParent(); if (HasValueForBlock(UserBB)) { Value *NewVal = GetValueAtEndOfBlock(UserBB); DbgValue->replaceVariableLocationOp(I, NewVal); - } - else + } else DbgValue->setKillLocation(); } +void SSAUpdater::UpdateDebugValue(Instruction *I, DPValue *DPV) { + BasicBlock *UserBB = DPV->getParent(); + if (HasValueForBlock(UserBB)) { + Value *NewVal = GetValueAtEndOfBlock(UserBB); + DPV->replaceVariableLocationOp(I, NewVal); + } else + DPV->setKillLocation(); +} + void SSAUpdater::RewriteUseAfterInsertions(Use &U) { Instruction *User = cast<Instruction>(U.getUser()); @@ -295,8 +317,9 @@ public: /// Reserve space for the operands but do not fill them in yet. 
static Value *CreateEmptyPHI(BasicBlock *BB, unsigned NumPreds, SSAUpdater *Updater) { - PHINode *PHI = PHINode::Create(Updater->ProtoType, NumPreds, - Updater->ProtoName, &BB->front()); + PHINode *PHI = + PHINode::Create(Updater->ProtoType, NumPreds, Updater->ProtoName); + PHI->insertBefore(BB->begin()); return PHI; } diff --git a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp index 31d62fbf0618..101b70d8def4 100644 --- a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp +++ b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp @@ -159,7 +159,7 @@ public: /// Get the total flow from a given source node. /// Returns a list of pairs (target node, amount of flow to the target). - const std::vector<std::pair<uint64_t, int64_t>> getFlow(uint64_t Src) const { + std::vector<std::pair<uint64_t, int64_t>> getFlow(uint64_t Src) const { std::vector<std::pair<uint64_t, int64_t>> Flow; for (const auto &Edge : Edges[Src]) { if (Edge.Flow > 0) diff --git a/llvm/lib/Transforms/Utils/SanitizerStats.cpp b/llvm/lib/Transforms/Utils/SanitizerStats.cpp index fd21ee4cc408..b80c5a6f9d68 100644 --- a/llvm/lib/Transforms/Utils/SanitizerStats.cpp +++ b/llvm/lib/Transforms/Utils/SanitizerStats.cpp @@ -21,7 +21,7 @@ using namespace llvm; SanitizerStatReport::SanitizerStatReport(Module *M) : M(M) { - StatTy = ArrayType::get(Type::getInt8PtrTy(M->getContext()), 2); + StatTy = ArrayType::get(PointerType::getUnqual(M->getContext()), 2); EmptyModuleStatsTy = makeModuleStatsTy(); ModuleStatsGV = new GlobalVariable(*M, EmptyModuleStatsTy, false, @@ -33,28 +33,28 @@ ArrayType *SanitizerStatReport::makeModuleStatsArrayTy() { } StructType *SanitizerStatReport::makeModuleStatsTy() { - return StructType::get(M->getContext(), {Type::getInt8PtrTy(M->getContext()), - Type::getInt32Ty(M->getContext()), - makeModuleStatsArrayTy()}); + return StructType::get(M->getContext(), + {PointerType::getUnqual(M->getContext()), + 
Type::getInt32Ty(M->getContext()), + makeModuleStatsArrayTy()}); } void SanitizerStatReport::create(IRBuilder<> &B, SanitizerStatKind SK) { Function *F = B.GetInsertBlock()->getParent(); Module *M = F->getParent(); - PointerType *Int8PtrTy = B.getInt8PtrTy(); + PointerType *PtrTy = B.getPtrTy(); IntegerType *IntPtrTy = B.getIntPtrTy(M->getDataLayout()); - ArrayType *StatTy = ArrayType::get(Int8PtrTy, 2); + ArrayType *StatTy = ArrayType::get(PtrTy, 2); Inits.push_back(ConstantArray::get( StatTy, - {Constant::getNullValue(Int8PtrTy), + {Constant::getNullValue(PtrTy), ConstantExpr::getIntToPtr( ConstantInt::get(IntPtrTy, uint64_t(SK) << (IntPtrTy->getBitWidth() - kSanitizerStatKindBits)), - Int8PtrTy)})); + PtrTy)})); - FunctionType *StatReportTy = - FunctionType::get(B.getVoidTy(), Int8PtrTy, false); + FunctionType *StatReportTy = FunctionType::get(B.getVoidTy(), PtrTy, false); FunctionCallee StatReport = M->getOrInsertFunction("__sanitizer_stat_report", StatReportTy); @@ -64,7 +64,7 @@ void SanitizerStatReport::create(IRBuilder<> &B, SanitizerStatKind SK) { ConstantInt::get(IntPtrTy, 0), ConstantInt::get(B.getInt32Ty(), 2), ConstantInt::get(IntPtrTy, Inits.size() - 1), }); - B.CreateCall(StatReport, ConstantExpr::getBitCast(InitAddr, Int8PtrTy)); + B.CreateCall(StatReport, InitAddr); } void SanitizerStatReport::finish() { @@ -73,7 +73,7 @@ void SanitizerStatReport::finish() { return; } - PointerType *Int8PtrTy = Type::getInt8PtrTy(M->getContext()); + PointerType *Int8PtrTy = PointerType::getUnqual(M->getContext()); IntegerType *Int32Ty = Type::getInt32Ty(M->getContext()); Type *VoidTy = Type::getVoidTy(M->getContext()); @@ -85,8 +85,7 @@ void SanitizerStatReport::finish() { {Constant::getNullValue(Int8PtrTy), ConstantInt::get(Int32Ty, Inits.size()), ConstantArray::get(makeModuleStatsArrayTy(), Inits)})); - ModuleStatsGV->replaceAllUsesWith( - ConstantExpr::getBitCast(NewModuleStatsGV, ModuleStatsGV->getType())); + ModuleStatsGV->replaceAllUsesWith(NewModuleStatsGV); 
ModuleStatsGV->eraseFromParent(); // Create a global constructor to register NewModuleStatsGV. @@ -99,7 +98,7 @@ void SanitizerStatReport::finish() { FunctionCallee StatInit = M->getOrInsertFunction("__sanitizer_stat_init", StatInitTy); - B.CreateCall(StatInit, ConstantExpr::getBitCast(NewModuleStatsGV, Int8PtrTy)); + B.CreateCall(StatInit, NewModuleStatsGV); B.CreateRetVoid(); appendToGlobalCtors(*M, F, 0); diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 20844271b943..cd3ac317cd23 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -170,11 +170,10 @@ Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) { if (Op == Instruction::IntToPtr) { auto *PtrTy = cast<PointerType>(Ty); if (DL.isNonIntegralPointerType(PtrTy)) { - auto *Int8PtrTy = Builder.getInt8PtrTy(PtrTy->getAddressSpace()); assert(DL.getTypeAllocSize(Builder.getInt8Ty()) == 1 && "alloc size of i8 must by 1 byte for the GEP to be correct"); return Builder.CreateGEP( - Builder.getInt8Ty(), Constant::getNullValue(Int8PtrTy), V, "scevgep"); + Builder.getInt8Ty(), Constant::getNullValue(PtrTy), V, "scevgep"); } } // Short-circuit unnecessary bitcasts. @@ -313,11 +312,11 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, /// loop-invariant portions of expressions, after considering what /// can be folded using target addressing modes. /// -Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Type *Ty, Value *V) { +Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Value *V) { assert(!isa<Instruction>(V) || SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint())); - Value *Idx = expandCodeForImpl(Offset, Ty); + Value *Idx = expand(Offset); // Fold a GEP with constant operands. 
if (Constant *CLHS = dyn_cast<Constant>(V)) @@ -339,7 +338,7 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Type *Ty, Value *V) { if (IP->getOpcode() == Instruction::GetElementPtr && IP->getOperand(0) == V && IP->getOperand(1) == Idx && cast<GEPOperator>(&*IP)->getSourceElementType() == - Type::getInt8Ty(Ty->getContext())) + Builder.getInt8Ty()) return &*IP; if (IP == BlockBegin) break; } @@ -457,8 +456,6 @@ public: } Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { - Type *Ty = SE.getEffectiveSCEVType(S->getType()); - // Collect all the add operands in a loop, along with their associated loops. // Iterate in reverse so that constants are emitted last, all else equal, and // so that pointer operands are inserted first, which the code below relies on @@ -498,20 +495,19 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { X = SE.getSCEV(U->getValue()); NewOps.push_back(X); } - Sum = expandAddToGEP(SE.getAddExpr(NewOps), Ty, Sum); + Sum = expandAddToGEP(SE.getAddExpr(NewOps), Sum); } else if (Op->isNonConstantNegative()) { // Instead of doing a negate and add, just do a subtract. - Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty); - Sum = InsertNoopCastOfTo(Sum, Ty); + Value *W = expand(SE.getNegativeSCEV(Op)); Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true); ++I; } else { // A simple add. - Value *W = expandCodeForImpl(Op, Ty); - Sum = InsertNoopCastOfTo(Sum, Ty); + Value *W = expand(Op); // Canonicalize a constant to the RHS. 
- if (isa<Constant>(Sum)) std::swap(Sum, W); + if (isa<Constant>(Sum)) + std::swap(Sum, W); Sum = InsertBinop(Instruction::Add, Sum, W, S->getNoWrapFlags(), /*IsSafeToHoist*/ true); ++I; @@ -522,7 +518,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { } Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { - Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Type *Ty = S->getType(); // Collect all the mul operands in a loop, along with their associated loops. // Iterate in reverse so that constants are emitted last, all else equal. @@ -541,7 +537,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { // Expand the calculation of X pow N in the following manner: // Let N = P1 + P2 + ... + PK, where all P are powers of 2. Then: // X pow N = (X pow P1) * (X pow P2) * ... * (X pow PK). - const auto ExpandOpBinPowN = [this, &I, &OpsAndLoops, &Ty]() { + const auto ExpandOpBinPowN = [this, &I, &OpsAndLoops]() { auto E = I; // Calculate how many times the same operand from the same loop is included // into this power. @@ -559,7 +555,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them // that are needed into the result. - Value *P = expandCodeForImpl(I->second, Ty); + Value *P = expand(I->second); Value *Result = nullptr; if (Exponent & 1) Result = P; @@ -584,14 +580,12 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { Prod = ExpandOpBinPowN(); } else if (I->second->isAllOnesValue()) { // Instead of doing a multiply by negative one, just do a negate. - Prod = InsertNoopCastOfTo(Prod, Ty); Prod = InsertBinop(Instruction::Sub, Constant::getNullValue(Ty), Prod, SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true); ++I; } else { // A simple mul. Value *W = ExpandOpBinPowN(); - Prod = InsertNoopCastOfTo(Prod, Ty); // Canonicalize a constant to the RHS. 
if (isa<Constant>(Prod)) std::swap(Prod, W); const APInt *RHS; @@ -616,18 +610,16 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { } Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { - Type *Ty = SE.getEffectiveSCEVType(S->getType()); - - Value *LHS = expandCodeForImpl(S->getLHS(), Ty); + Value *LHS = expand(S->getLHS()); if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) { const APInt &RHS = SC->getAPInt(); if (RHS.isPowerOf2()) return InsertBinop(Instruction::LShr, LHS, - ConstantInt::get(Ty, RHS.logBase2()), + ConstantInt::get(SC->getType(), RHS.logBase2()), SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true); } - Value *RHS = expandCodeForImpl(S->getRHS(), Ty); + Value *RHS = expand(S->getRHS()); return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap, /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS())); } @@ -803,12 +795,11 @@ bool SCEVExpander::isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV, /// Typically this is the LatchBlock terminator or IVIncInsertPos, but we may /// need to materialize IV increments elsewhere to handle difficult situations. Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L, - Type *ExpandTy, Type *IntTy, bool useSubtract) { Value *IncV; // If the PHI is a pointer, use a GEP, otherwise use an add or sub. - if (ExpandTy->isPointerTy()) { - IncV = expandAddToGEP(SE.getSCEV(StepV), IntTy, PN); + if (PN->getType()->isPointerTy()) { + IncV = expandAddToGEP(SE.getSCEV(StepV), PN); } else { IncV = useSubtract ? Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") : @@ -824,12 +815,11 @@ static bool canBeCheaplyTransformed(ScalarEvolution &SE, const SCEVAddRecExpr *Requested, bool &InvertStep) { // We can't transform to match a pointer PHI. 
- if (Phi->getType()->isPointerTy()) + Type *PhiTy = Phi->getType(); + Type *RequestedTy = Requested->getType(); + if (PhiTy->isPointerTy() || RequestedTy->isPointerTy()) return false; - Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType()); - Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType()); - if (RequestedTy->getIntegerBitWidth() > PhiTy->getIntegerBitWidth()) return false; @@ -886,12 +876,10 @@ static bool IsIncrementNUW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) { /// values, and return the PHI. PHINode * SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, - const Loop *L, - Type *ExpandTy, - Type *IntTy, - Type *&TruncTy, + const Loop *L, Type *&TruncTy, bool &InvertStep) { - assert((!IVIncInsertLoop||IVIncInsertPos) && "Uninitialized insert position"); + assert((!IVIncInsertLoop || IVIncInsertPos) && + "Uninitialized insert position"); // Reuse a previously-inserted PHI, if present. BasicBlock *LatchBlock = L->getLoopLatch(); @@ -962,7 +950,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, // later. AddRecPhiMatch = &PN; IncV = TempIncV; - TruncTy = SE.getEffectiveSCEVType(Normalized->getType()); + TruncTy = Normalized->getType(); } } @@ -996,8 +984,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, assert(L->getLoopPreheader() && "Can't expand add recurrences without a loop preheader!"); Value *StartV = - expandCodeForImpl(Normalized->getStart(), ExpandTy, - L->getLoopPreheader()->getTerminator()); + expand(Normalized->getStart(), L->getLoopPreheader()->getTerminator()); // StartV must have been be inserted into L's preheader to dominate the new // phi. @@ -1008,6 +995,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, // Expand code for the step value. Do this before creating the PHI so that PHI // reuse code doesn't see an incomplete PHI. 
const SCEV *Step = Normalized->getStepRecurrence(SE); + Type *ExpandTy = Normalized->getType(); // If the stride is negative, insert a sub instead of an add for the increment // (unless it's a constant, because subtracts of constants are canonicalized // to adds). @@ -1015,8 +1003,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, if (useSubtract) Step = SE.getNegativeSCEV(Step); // Expand the step somewhere that dominates the loop header. - Value *StepV = expandCodeForImpl( - Step, IntTy, &*L->getHeader()->getFirstInsertionPt()); + Value *StepV = expand(Step, L->getHeader()->getFirstInsertionPt()); // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if // we actually do emit an addition. It does not apply if we emit a @@ -1047,7 +1034,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, Instruction *InsertPos = L == IVIncInsertLoop ? IVIncInsertPos : Pred->getTerminator(); Builder.SetInsertPoint(InsertPos); - Value *IncV = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); + Value *IncV = expandIVInc(PN, StepV, L, useSubtract); if (isa<OverflowingBinaryOperator>(IncV)) { if (IncrementIsNUW) @@ -1070,8 +1057,6 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, } Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { - Type *STy = S->getType(); - Type *IntTy = SE.getEffectiveSCEVType(STy); const Loop *L = S->getLoop(); // Determine a normalized form of this expression, which is the expression @@ -1084,51 +1069,17 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { normalizeForPostIncUse(S, Loops, SE, /*CheckInvertible=*/false)); } - // Strip off any non-loop-dominating component from the addrec start. 
- const SCEV *Start = Normalized->getStart(); - const SCEV *PostLoopOffset = nullptr; - if (!SE.properlyDominates(Start, L->getHeader())) { - PostLoopOffset = Start; - Start = SE.getConstant(Normalized->getType(), 0); - Normalized = cast<SCEVAddRecExpr>( - SE.getAddRecExpr(Start, Normalized->getStepRecurrence(SE), - Normalized->getLoop(), - Normalized->getNoWrapFlags(SCEV::FlagNW))); - } - - // Strip off any non-loop-dominating component from the addrec step. + [[maybe_unused]] const SCEV *Start = Normalized->getStart(); const SCEV *Step = Normalized->getStepRecurrence(SE); - const SCEV *PostLoopScale = nullptr; - if (!SE.dominates(Step, L->getHeader())) { - PostLoopScale = Step; - Step = SE.getConstant(Normalized->getType(), 1); - if (!Start->isZero()) { - // The normalization below assumes that Start is constant zero, so if - // it isn't re-associate Start to PostLoopOffset. - assert(!PostLoopOffset && "Start not-null but PostLoopOffset set?"); - PostLoopOffset = Start; - Start = SE.getConstant(Normalized->getType(), 0); - } - Normalized = - cast<SCEVAddRecExpr>(SE.getAddRecExpr( - Start, Step, Normalized->getLoop(), - Normalized->getNoWrapFlags(SCEV::FlagNW))); - } - - // Expand the core addrec. If we need post-loop scaling, force it to - // expand to an integer type to avoid the need for additional casting. - Type *ExpandTy = PostLoopScale ? IntTy : STy; - // We can't use a pointer type for the addrec if the pointer type is - // non-integral. - Type *AddRecPHIExpandTy = - DL.isNonIntegralPointerType(STy) ? Normalized->getType() : ExpandTy; + assert(SE.properlyDominates(Start, L->getHeader()) && + "Start does not properly dominate loop header"); + assert(SE.dominates(Step, L->getHeader()) && "Step not dominate loop header"); // In some cases, we decide to reuse an existing phi node but need to truncate // it and/or invert the step. 
Type *TruncTy = nullptr; bool InvertStep = false; - PHINode *PN = getAddRecExprPHILiterally(Normalized, L, AddRecPHIExpandTy, - IntTy, TruncTy, InvertStep); + PHINode *PN = getAddRecExprPHILiterally(Normalized, L, TruncTy, InvertStep); // Accommodate post-inc mode, if necessary. Value *Result; @@ -1167,59 +1118,29 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { // inserting an extra IV increment. StepV might fold into PostLoopOffset, // but hopefully expandCodeFor handles that. bool useSubtract = - !ExpandTy->isPointerTy() && Step->isNonConstantNegative(); + !S->getType()->isPointerTy() && Step->isNonConstantNegative(); if (useSubtract) Step = SE.getNegativeSCEV(Step); Value *StepV; { // Expand the step somewhere that dominates the loop header. SCEVInsertPointGuard Guard(Builder, this); - StepV = expandCodeForImpl( - Step, IntTy, &*L->getHeader()->getFirstInsertionPt()); + StepV = expand(Step, L->getHeader()->getFirstInsertionPt()); } - Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); + Result = expandIVInc(PN, StepV, L, useSubtract); } } // We have decided to reuse an induction variable of a dominating loop. Apply // truncation and/or inversion of the step. if (TruncTy) { - Type *ResTy = Result->getType(); - // Normalize the result type. - if (ResTy != SE.getEffectiveSCEVType(ResTy)) - Result = InsertNoopCastOfTo(Result, SE.getEffectiveSCEVType(ResTy)); // Truncate the result. if (TruncTy != Result->getType()) Result = Builder.CreateTrunc(Result, TruncTy); // Invert the result. if (InvertStep) - Result = Builder.CreateSub( - expandCodeForImpl(Normalized->getStart(), TruncTy), Result); - } - - // Re-apply any non-loop-dominating scale. - if (PostLoopScale) { - assert(S->isAffine() && "Can't linearly scale non-affine recurrences."); - Result = InsertNoopCastOfTo(Result, IntTy); - Result = Builder.CreateMul(Result, - expandCodeForImpl(PostLoopScale, IntTy)); - } - - // Re-apply any non-loop-dominating offset. 
- if (PostLoopOffset) { - if (isa<PointerType>(ExpandTy)) { - if (Result->getType()->isIntegerTy()) { - Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy); - Result = expandAddToGEP(SE.getUnknown(Result), IntTy, Base); - } else { - Result = expandAddToGEP(PostLoopOffset, IntTy, Result); - } - } else { - Result = InsertNoopCastOfTo(Result, IntTy); - Result = Builder.CreateAdd( - Result, expandCodeForImpl(PostLoopOffset, IntTy)); - } + Result = Builder.CreateSub(expand(Normalized->getStart()), Result); } return Result; @@ -1260,8 +1181,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { S->getNoWrapFlags(SCEV::FlagNW))); BasicBlock::iterator NewInsertPt = findInsertPointAfter(cast<Instruction>(V), &*Builder.GetInsertPoint()); - V = expandCodeForImpl(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr, - &*NewInsertPt); + V = expand(SE.getTruncateExpr(SE.getUnknown(V), Ty), NewInsertPt); return V; } @@ -1269,7 +1189,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { if (!S->getStart()->isZero()) { if (isa<PointerType>(S->getType())) { Value *StartV = expand(SE.getPointerBase(S)); - return expandAddToGEP(SE.removePointerBase(S), Ty, StartV); + return expandAddToGEP(SE.removePointerBase(S), StartV); } SmallVector<const SCEV *, 4> NewOps(S->operands()); @@ -1292,8 +1212,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { // specified loop. 
BasicBlock *Header = L->getHeader(); pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); - CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar", - &Header->front()); + CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar"); + CanonicalIV->insertBefore(Header->begin()); rememberInstruction(CanonicalIV); SmallSet<BasicBlock *, 4> PredSeen; @@ -1361,34 +1281,25 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { } Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) { - Value *V = - expandCodeForImpl(S->getOperand(), S->getOperand()->getType()); + Value *V = expand(S->getOperand()); return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt, GetOptimalInsertionPointForCastOf(V)); } Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) { - Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *V = expandCodeForImpl( - S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()) - ); - return Builder.CreateTrunc(V, Ty); + Value *V = expand(S->getOperand()); + return Builder.CreateTrunc(V, S->getType()); } Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) { - Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *V = expandCodeForImpl( - S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()) - ); - return Builder.CreateZExt(V, Ty); + Value *V = expand(S->getOperand()); + return Builder.CreateZExt(V, S->getType(), "", + SE.isKnownNonNegative(S->getOperand())); } Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { - Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *V = expandCodeForImpl( - S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()) - ); - return Builder.CreateSExt(V, Ty); + Value *V = expand(S->getOperand()); + return Builder.CreateSExt(V, S->getType()); } Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S, @@ -1399,7 +1310,7 @@ Value *SCEVExpander::expandMinMaxExpr(const 
SCEVNAryExpr *S, if (IsSequential) LHS = Builder.CreateFreeze(LHS); for (int i = S->getNumOperands() - 2; i >= 0; --i) { - Value *RHS = expandCodeForImpl(S->getOperand(i), Ty); + Value *RHS = expand(S->getOperand(i)); if (IsSequential && i != 0) RHS = Builder.CreateFreeze(RHS); Value *Sel; @@ -1440,14 +1351,14 @@ Value *SCEVExpander::visitVScale(const SCEVVScale *S) { return Builder.CreateVScale(ConstantInt::get(S->getType(), 1)); } -Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, - Instruction *IP) { +Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty, + BasicBlock::iterator IP) { setInsertPoint(IP); - Value *V = expandCodeForImpl(SH, Ty); + Value *V = expandCodeFor(SH, Ty); return V; } -Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty) { +Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) { // Expand the code for this SCEV. Value *V = expand(SH); @@ -1459,8 +1370,64 @@ Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty) { return V; } -Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S, - const Instruction *InsertPt) { +static bool +canReuseInstruction(ScalarEvolution &SE, const SCEV *S, Instruction *I, + SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) { + // If the instruction cannot be poison, it's always safe to reuse. + if (programUndefinedIfPoison(I)) + return true; + + // Otherwise, it is possible that I is more poisonous that S. Collect the + // poison-contributors of S, and then check whether I has any additional + // poison-contributors. Poison that is contributed through poison-generating + // flags is handled by dropping those flags instead. 
+ SmallPtrSet<const Value *, 8> PoisonVals; + SE.getPoisonGeneratingValues(PoisonVals, S); + + SmallVector<Value *> Worklist; + SmallPtrSet<Value *, 8> Visited; + Worklist.push_back(I); + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + // Avoid walking large instruction graphs. + if (Visited.size() > 16) + return false; + + // Either the value can't be poison, or the S would also be poison if it + // is. + if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V)) + continue; + + auto *I = dyn_cast<Instruction>(V); + if (!I) + return false; + + // FIXME: Ignore vscale, even though it technically could be poison. Do this + // because SCEV currently assumes it can't be poison. Remove this special + // case once we proper model when vscale can be poison. + if (auto *II = dyn_cast<IntrinsicInst>(I); + II && II->getIntrinsicID() == Intrinsic::vscale) + continue; + + if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false)) + return false; + + // If the instruction can't create poison, we can recurse to its operands. + if (I->hasPoisonGeneratingFlagsOrMetadata()) + DropPoisonGeneratingInsts.push_back(I); + + for (Value *Op : I->operands()) + Worklist.push_back(Op); + } + return true; +} + +Value *SCEVExpander::FindValueInExprValueMap( + const SCEV *S, const Instruction *InsertPt, + SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) { // If the expansion is not in CanonicalMode, and the SCEV contains any // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally. if (!CanonicalMode && SE.containsAddRecurrence(S)) @@ -1470,20 +1437,24 @@ Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S, if (isa<SCEVConstant>(S)) return nullptr; - // Choose a Value from the set which dominates the InsertPt. - // InsertPt should be inside the Value's parent loop so as not to break - // the LCSSA form. 
for (Value *V : SE.getSCEVValues(S)) { Instruction *EntInst = dyn_cast<Instruction>(V); if (!EntInst) continue; + // Choose a Value from the set which dominates the InsertPt. + // InsertPt should be inside the Value's parent loop so as not to break + // the LCSSA form. assert(EntInst->getFunction() == InsertPt->getFunction()); - if (S->getType() == V->getType() && - SE.DT.dominates(EntInst, InsertPt) && - (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || - SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) + if (S->getType() != V->getType() || !SE.DT.dominates(EntInst, InsertPt) || + !(SE.LI.getLoopFor(EntInst->getParent()) == nullptr || + SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) + continue; + + // Make sure reusing the instruction is poison-safe. + if (canReuseInstruction(SE, S, EntInst, DropPoisonGeneratingInsts)) return V; + DropPoisonGeneratingInsts.clear(); } return nullptr; } @@ -1497,7 +1468,7 @@ Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S, Value *SCEVExpander::expand(const SCEV *S) { // Compute an insertion point for this SCEV object. Hoist the instructions // as far out in the loop nest as possible. - Instruction *InsertPt = &*Builder.GetInsertPoint(); + BasicBlock::iterator InsertPt = Builder.GetInsertPoint(); // We can move insertion point only if there is no div or rem operations // otherwise we are risky to move it over the check for zero denominator. @@ -1521,24 +1492,25 @@ Value *SCEVExpander::expand(const SCEV *S) { L = L->getParentLoop()) { if (SE.isLoopInvariant(S, L)) { if (!L) break; - if (BasicBlock *Preheader = L->getLoopPreheader()) - InsertPt = Preheader->getTerminator(); - else + if (BasicBlock *Preheader = L->getLoopPreheader()) { + InsertPt = Preheader->getTerminator()->getIterator(); + } else { // LSR sets the insertion point for AddRec start/step values to the // block start to simplify value reuse, even though it's an invalid // position. SCEVExpander must correct for this in all cases. 
- InsertPt = &*L->getHeader()->getFirstInsertionPt(); + InsertPt = L->getHeader()->getFirstInsertionPt(); + } } else { // If the SCEV is computable at this level, insert it into the header // after the PHIs (and after any other instructions that we've inserted // there) so that it is guaranteed to dominate any user inside the loop. if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L)) - InsertPt = &*L->getHeader()->getFirstInsertionPt(); + InsertPt = L->getHeader()->getFirstInsertionPt(); - while (InsertPt->getIterator() != Builder.GetInsertPoint() && - (isInsertedInstruction(InsertPt) || - isa<DbgInfoIntrinsic>(InsertPt))) { - InsertPt = &*std::next(InsertPt->getIterator()); + while (InsertPt != Builder.GetInsertPoint() && + (isInsertedInstruction(&*InsertPt) || + isa<DbgInfoIntrinsic>(&*InsertPt))) { + InsertPt = std::next(InsertPt); } break; } @@ -1546,26 +1518,40 @@ Value *SCEVExpander::expand(const SCEV *S) { } // Check to see if we already expanded this here. - auto I = InsertedExpressions.find(std::make_pair(S, InsertPt)); + auto I = InsertedExpressions.find(std::make_pair(S, &*InsertPt)); if (I != InsertedExpressions.end()) return I->second; SCEVInsertPointGuard Guard(Builder, this); - Builder.SetInsertPoint(InsertPt); + Builder.SetInsertPoint(InsertPt->getParent(), InsertPt); // Expand the expression into instructions. - Value *V = FindValueInExprValueMap(S, InsertPt); + SmallVector<Instruction *> DropPoisonGeneratingInsts; + Value *V = FindValueInExprValueMap(S, &*InsertPt, DropPoisonGeneratingInsts); if (!V) { V = visit(S); V = fixupLCSSAFormFor(V); } else { - // If we're reusing an existing instruction, we are effectively CSEing two - // copies of the instruction (with potentially different flags). As such, - // we need to drop any poison generating flags unless we can prove that - // said flags must be valid for all new users. 
- if (auto *I = dyn_cast<Instruction>(V)) - if (I->hasPoisonGeneratingFlags() && !programUndefinedIfPoison(I)) - I->dropPoisonGeneratingFlags(); + for (Instruction *I : DropPoisonGeneratingInsts) { + I->dropPoisonGeneratingFlagsAndMetadata(); + // See if we can re-infer from first principles any of the flags we just + // dropped. + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) + if (auto Flags = SE.getStrengthenedNoWrapFlagsFromBinOp(OBO)) { + auto *BO = cast<BinaryOperator>(I); + BO->setHasNoUnsignedWrap( + ScalarEvolution::maskFlags(*Flags, SCEV::FlagNUW) == SCEV::FlagNUW); + BO->setHasNoSignedWrap( + ScalarEvolution::maskFlags(*Flags, SCEV::FlagNSW) == SCEV::FlagNSW); + } + if (auto *NNI = dyn_cast<PossiblyNonNegInst>(I)) { + auto *Src = NNI->getOperand(0); + if (isImpliedByDomCondition(ICmpInst::ICMP_SGE, Src, + Constant::getNullValue(Src->getType()), I, + DL).value_or(false)) + NNI->setNonNeg(true); + } + } } // Remember the expanded value for this SCEV at this location. // @@ -1573,7 +1559,7 @@ Value *SCEVExpander::expand(const SCEV *S) { // the expression at this insertion point. If the mapped value happened to be // a postinc expansion, it could be reused by a non-postinc user, but only if // its insertion point was already at the head of the loop. 
- InsertedExpressions[std::make_pair(S, InsertPt)] = V; + InsertedExpressions[std::make_pair(S, &*InsertPt)] = V; return V; } @@ -1710,13 +1696,13 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, << *IsomorphicInc << '\n'); Value *NewInc = OrigInc; if (OrigInc->getType() != IsomorphicInc->getType()) { - Instruction *IP = nullptr; + BasicBlock::iterator IP; if (PHINode *PN = dyn_cast<PHINode>(OrigInc)) - IP = &*PN->getParent()->getFirstInsertionPt(); + IP = PN->getParent()->getFirstInsertionPt(); else - IP = OrigInc->getNextNode(); + IP = OrigInc->getNextNonDebugInstruction()->getIterator(); - IRBuilder<> Builder(IP); + IRBuilder<> Builder(IP->getParent(), IP); Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); NewInc = Builder.CreateTruncOrBitCast( OrigInc, IsomorphicInc->getType(), IVName); @@ -1734,7 +1720,8 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, ++NumElim; Value *NewIV = OrigPhiRef; if (OrigPhiRef->getType() != Phi->getType()) { - IRBuilder<> Builder(&*L->getHeader()->getFirstInsertionPt()); + IRBuilder<> Builder(L->getHeader(), + L->getHeader()->getFirstInsertionPt()); Builder.SetCurrentDebugLocation(Phi->getDebugLoc()); NewIV = Builder.CreateTruncOrBitCast(OrigPhiRef, Phi->getType(), IVName); } @@ -1744,9 +1731,9 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, return NumElim; } -Value *SCEVExpander::getRelatedExistingExpansion(const SCEV *S, - const Instruction *At, - Loop *L) { +bool SCEVExpander::hasRelatedExistingExpansion(const SCEV *S, + const Instruction *At, + Loop *L) { using namespace llvm::PatternMatch; SmallVector<BasicBlock *, 4> ExitingBlocks; @@ -1763,17 +1750,18 @@ Value *SCEVExpander::getRelatedExistingExpansion(const SCEV *S, continue; if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At)) - return LHS; + return true; if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At)) - return RHS; + return true; } // Use expand's logic which is used for reusing a 
previous Value in // ExprValueMap. Note that we don't currently model the cost of // needing to drop poison generating flags on the instruction if we // want to reuse it. We effectively assume that has zero cost. - return FindValueInExprValueMap(S, At); + SmallVector<Instruction *> DropPoisonGeneratingInsts; + return FindValueInExprValueMap(S, At, DropPoisonGeneratingInsts) != nullptr; } template<typename T> static InstructionCost costAndCollectOperands( @@ -1951,7 +1939,7 @@ bool SCEVExpander::isHighCostExpansionHelper( // If we can find an existing value for this scev available at the point "At" // then consider the expression cheap. - if (getRelatedExistingExpansion(S, &At, L)) + if (hasRelatedExistingExpansion(S, &At, L)) return false; // Consider the expression to be free. TargetTransformInfo::TargetCostKind CostKind = @@ -1993,7 +1981,7 @@ bool SCEVExpander::isHighCostExpansionHelper( // At the beginning of this function we already tried to find existing // value for plain 'S'. Now try to lookup 'S + 1' since it is common // pattern involving division. This is just a simple search heuristic. - if (getRelatedExistingExpansion( + if (hasRelatedExistingExpansion( SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), &At, L)) return false; // Consider it to be free. 
@@ -2045,10 +2033,8 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred, Instruction *IP) { - Value *Expr0 = - expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP); - Value *Expr1 = - expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP); + Value *Expr0 = expand(Pred->getLHS(), IP); + Value *Expr1 = expand(Pred->getRHS(), IP); Builder.SetInsertPoint(IP); auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate()); @@ -2080,17 +2066,15 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, // Step >= 0, Start + |Step| * Backedge > Start // and |Step| * Backedge doesn't unsigned overflow. - IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits); Builder.SetInsertPoint(Loc); - Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc); + Value *TripCountVal = expand(ExitCount, Loc); IntegerType *Ty = IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy)); - Value *StepValue = expandCodeForImpl(Step, Ty, Loc); - Value *NegStepValue = - expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc); - Value *StartValue = expandCodeForImpl(Start, ARTy, Loc); + Value *StepValue = expand(Step, Loc); + Value *NegStepValue = expand(SE.getNegativeSCEV(Step), Loc); + Value *StartValue = expand(Start, Loc); ConstantInt *Zero = ConstantInt::get(Loc->getContext(), APInt::getZero(DstBits)); @@ -2136,9 +2120,7 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, bool NeedPosCheck = !SE.isKnownNegative(Step); bool NeedNegCheck = !SE.isKnownPositive(Step); - if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARTy)) { - StartValue = InsertNoopCastOfTo( - StartValue, Builder.getInt8PtrTy(ARPtrTy->getAddressSpace())); + if (isa<PointerType>(ARTy)) { Value *NegMulV = Builder.CreateNeg(MulV); if (NeedPosCheck) Add = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, MulV); @@ -2171,7 +2153,7 @@ Value 
*SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, // If the backedge taken count type is larger than the AR type, // check that we don't drop any bits by truncating it. If we are // dropping bits, then we have overflow (unless the step is zero). - if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) { + if (SrcBits > DstBits) { auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits); auto *BackedgeCheck = Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal, @@ -2244,7 +2226,7 @@ Value *SCEVExpander::fixupLCSSAFormFor(Value *V) { // instruction. Type *ToTy; if (DefI->getType()->isIntegerTy()) - ToTy = DefI->getType()->getPointerTo(); + ToTy = PointerType::get(DefI->getContext(), 0); else ToTy = Type::getInt32Ty(DefI->getContext()); Instruction *User = @@ -2306,12 +2288,6 @@ struct SCEVFindUnsafe { } } if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { - const SCEV *Step = AR->getStepRecurrence(SE); - if (!AR->isAffine() && !SE.dominates(Step, AR->getLoop()->getHeader())) { - IsUnsafe = true; - return false; - } - // For non-affine addrecs or in non-canonical mode we need a preheader // to insert into. 
if (!AR->getLoop()->getLoopPreheader() && diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index d3a9a41aef15..c09cf9c2325c 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -271,7 +271,10 @@ class SimplifyCFGOpt { bool tryToSimplifyUncondBranchWithICmpInIt(ICmpInst *ICI, IRBuilder<> &Builder); - bool HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly); + bool hoistCommonCodeFromSuccessors(BasicBlock *BB, bool EqTermsOnly); + bool hoistSuccIdenticalTerminatorToSwitchOrIf( + Instruction *TI, Instruction *I1, + SmallVectorImpl<Instruction *> &OtherSuccTIs); bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB); bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond, BasicBlock *TrueBB, BasicBlock *FalseBB, @@ -499,7 +502,7 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout &DL) { return CI; else return cast<ConstantInt>( - ConstantExpr::getIntegerCast(CI, PtrTy, /*isSigned=*/false)); + ConstantFoldIntegerCast(CI, PtrTy, /*isSigned=*/false, DL)); } return nullptr; } @@ -819,7 +822,7 @@ BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases( static void EliminateBlockCases(BasicBlock *BB, std::vector<ValueEqualityComparisonCase> &Cases) { - llvm::erase_value(Cases, BB); + llvm::erase(Cases, BB); } /// Return true if there are any keys in C1 that exist in C2 as well. @@ -1098,12 +1101,13 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses( // Note that there may be multiple predecessor blocks, so we cannot move // bonus instructions to a predecessor block. 
for (Instruction &BonusInst : *BB) { - if (isa<DbgInfoIntrinsic>(BonusInst) || BonusInst.isTerminator()) + if (BonusInst.isTerminator()) continue; Instruction *NewBonusInst = BonusInst.clone(); - if (PTI->getDebugLoc() != NewBonusInst->getDebugLoc()) { + if (!isa<DbgInfoIntrinsic>(BonusInst) && + PTI->getDebugLoc() != NewBonusInst->getDebugLoc()) { // Unless the instruction has the same !dbg location as the original // branch, drop it. When we fold the bonus instructions we want to make // sure we reset their debug locations in order to avoid stepping on @@ -1113,7 +1117,6 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses( RemapInstruction(NewBonusInst, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - VMap[&BonusInst] = NewBonusInst; // If we speculated an instruction, we need to drop any metadata that may // result in undefined behavior, as the metadata might have been valid @@ -1123,8 +1126,16 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses( NewBonusInst->dropUBImplyingAttrsAndMetadata(); NewBonusInst->insertInto(PredBlock, PTI->getIterator()); + auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst); + RemapDPValueRange(NewBonusInst->getModule(), Range, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + + if (isa<DbgInfoIntrinsic>(BonusInst)) + continue; + NewBonusInst->takeName(&BonusInst); BonusInst.setName(NewBonusInst->getName() + ".old"); + VMap[&BonusInst] = NewBonusInst; // Update (liveout) uses of bonus instructions, // now that the bonus instruction has been cloned into predecessor. 
@@ -1303,7 +1314,7 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding( } for (const std::pair<BasicBlock *, int /*Num*/> &NewSuccessor : NewSuccessors) { - for (auto I : seq(0, NewSuccessor.second)) { + for (auto I : seq(NewSuccessor.second)) { (void)I; AddPredecessorToBlock(NewSuccessor.first, Pred, BB); } @@ -1408,8 +1419,9 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(Instruction *TI, } // If we would need to insert a select that uses the value of this invoke -// (comments in HoistThenElseCodeToIf explain why we would need to do this), we -// can't hoist the invoke, as there is nowhere to put the select in this case. +// (comments in hoistSuccIdenticalTerminatorToSwitchOrIf explain why we would +// need to do this), we can't hoist the invoke, as there is nowhere to put the +// select in this case. static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, Instruction *I1, Instruction *I2) { for (BasicBlock *Succ : successors(BB1)) { @@ -1424,9 +1436,9 @@ static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, return true; } -// Get interesting characteristics of instructions that `HoistThenElseCodeToIf` -// didn't hoist. They restrict what kind of instructions can be reordered -// across. +// Get interesting characteristics of instructions that +// `hoistCommonCodeFromSuccessors` didn't hoist. They restrict what kind of +// instructions can be reordered across. enum SkipFlags { SkipReadMem = 1, SkipSideEffect = 2, @@ -1484,7 +1496,7 @@ static bool isSafeToHoistInstr(Instruction *I, unsigned Flags) { static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValueMayBeModified = false); -/// Helper function for HoistThenElseCodeToIf. Return true if identical +/// Helper function for hoistCommonCodeFromSuccessors. Return true if identical /// instructions \p I1 and \p I2 can and should be hoisted. 
static bool shouldHoistCommonInstructions(Instruction *I1, Instruction *I2, const TargetTransformInfo &TTI) { @@ -1515,62 +1527,51 @@ static bool shouldHoistCommonInstructions(Instruction *I1, Instruction *I2, return true; } -/// Given a conditional branch that goes to BB1 and BB2, hoist any common code -/// in the two blocks up into the branch block. The caller of this function -/// guarantees that BI's block dominates BB1 and BB2. If EqTermsOnly is given, -/// only perform hoisting in case both blocks only contain a terminator. In that -/// case, only the original BI will be replaced and selects for PHIs are added. -bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly) { +/// Hoist any common code in the successor blocks up into the block. This +/// function guarantees that BB dominates all successors. If EqTermsOnly is +/// given, only perform hoisting in case both blocks only contain a terminator. +/// In that case, only the original BI will be replaced and selects for PHIs are +/// added. +bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB, + bool EqTermsOnly) { // This does very trivial matching, with limited scanning, to find identical - // instructions in the two blocks. In particular, we don't want to get into - // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As + // instructions in the two blocks. In particular, we don't want to get into + // O(N1*N2*...) situations here where Ni are the sizes of these successors. As // such, we currently just scan for obviously identical instructions in an // identical order, possibly separated by the same number of non-identical // instructions. - BasicBlock *BB1 = BI->getSuccessor(0); // The true destination. 
- BasicBlock *BB2 = BI->getSuccessor(1); // The false destination + unsigned int SuccSize = succ_size(BB); + if (SuccSize < 2) + return false; // If either of the blocks has it's address taken, then we can't do this fold, // because the code we'd hoist would no longer run when we jump into the block // by it's address. - if (BB1->hasAddressTaken() || BB2->hasAddressTaken()) - return false; + for (auto *Succ : successors(BB)) + if (Succ->hasAddressTaken() || !Succ->getSinglePredecessor()) + return false; - BasicBlock::iterator BB1_Itr = BB1->begin(); - BasicBlock::iterator BB2_Itr = BB2->begin(); + auto *TI = BB->getTerminator(); - Instruction *I1 = &*BB1_Itr++, *I2 = &*BB2_Itr++; - // Skip debug info if it is not identical. - DbgInfoIntrinsic *DBI1 = dyn_cast<DbgInfoIntrinsic>(I1); - DbgInfoIntrinsic *DBI2 = dyn_cast<DbgInfoIntrinsic>(I2); - if (!DBI1 || !DBI2 || !DBI1->isIdenticalToWhenDefined(DBI2)) { - while (isa<DbgInfoIntrinsic>(I1)) - I1 = &*BB1_Itr++; - while (isa<DbgInfoIntrinsic>(I2)) - I2 = &*BB2_Itr++; + // The second of pair is a SkipFlags bitmask. + using SuccIterPair = std::pair<BasicBlock::iterator, unsigned>; + SmallVector<SuccIterPair, 8> SuccIterPairs; + for (auto *Succ : successors(BB)) { + BasicBlock::iterator SuccItr = Succ->begin(); + if (isa<PHINode>(*SuccItr)) + return false; + SuccIterPairs.push_back(SuccIterPair(SuccItr, 0)); } - if (isa<PHINode>(I1)) - return false; - - BasicBlock *BIParent = BI->getParent(); - - bool Changed = false; - - auto _ = make_scope_exit([&]() { - if (Changed) - ++NumHoistCommonCode; - }); // Check if only hoisting terminators is allowed. This does not add new // instructions to the hoist location. if (EqTermsOnly) { // Skip any debug intrinsics, as they are free to hoist. 
- auto *I1NonDbg = &*skipDebugIntrinsics(I1->getIterator()); - auto *I2NonDbg = &*skipDebugIntrinsics(I2->getIterator()); - if (!I1NonDbg->isIdenticalToWhenDefined(I2NonDbg)) - return false; - if (!I1NonDbg->isTerminator()) - return false; + for (auto &SuccIter : make_first_range(SuccIterPairs)) { + auto *INonDbg = &*skipDebugIntrinsics(SuccIter); + if (!INonDbg->isTerminator()) + return false; + } // Now we know that we only need to hoist debug intrinsics and the // terminator. Let the loop below handle those 2 cases. } @@ -1579,153 +1580,235 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, bool EqTermsOnly) { // many instructions we skip, serving as a compilation time control as well as // preventing excessive increase of life ranges. unsigned NumSkipped = 0; + // If we find an unreachable instruction at the beginning of a basic block, we + // can still hoist instructions from the rest of the basic blocks. + if (SuccIterPairs.size() > 2) { + erase_if(SuccIterPairs, + [](const auto &Pair) { return isa<UnreachableInst>(Pair.first); }); + if (SuccIterPairs.size() < 2) + return false; + } - // Record any skipped instuctions that may read memory, write memory or have - // side effects, or have implicit control flow. - unsigned SkipFlagsBB1 = 0; - unsigned SkipFlagsBB2 = 0; + bool Changed = false; for (;;) { + auto *SuccIterPairBegin = SuccIterPairs.begin(); + auto &BB1ItrPair = *SuccIterPairBegin++; + auto OtherSuccIterPairRange = + iterator_range(SuccIterPairBegin, SuccIterPairs.end()); + auto OtherSuccIterRange = make_first_range(OtherSuccIterPairRange); + + Instruction *I1 = &*BB1ItrPair.first; + auto *BB1 = I1->getParent(); + + // Skip debug info if it is not identical. 
+ bool AllDbgInstsAreIdentical = all_of(OtherSuccIterRange, [I1](auto &Iter) { + Instruction *I2 = &*Iter; + return I1->isIdenticalToWhenDefined(I2); + }); + if (!AllDbgInstsAreIdentical) { + while (isa<DbgInfoIntrinsic>(I1)) + I1 = &*++BB1ItrPair.first; + for (auto &SuccIter : OtherSuccIterRange) { + Instruction *I2 = &*SuccIter; + while (isa<DbgInfoIntrinsic>(I2)) + I2 = &*++SuccIter; + } + } + + bool AllInstsAreIdentical = true; + bool HasTerminator = I1->isTerminator(); + for (auto &SuccIter : OtherSuccIterRange) { + Instruction *I2 = &*SuccIter; + HasTerminator |= I2->isTerminator(); + if (AllInstsAreIdentical && !I1->isIdenticalToWhenDefined(I2)) + AllInstsAreIdentical = false; + } + // If we are hoisting the terminator instruction, don't move one (making a // broken BB), instead clone it, and remove BI. - if (I1->isTerminator() || I2->isTerminator()) { + if (HasTerminator) { + // Even if BB, which contains only one unreachable instruction, is ignored + // at the beginning of the loop, we can hoist the terminator instruction. // If any instructions remain in the block, we cannot hoist terminators. - if (NumSkipped || !I1->isIdenticalToWhenDefined(I2)) + if (NumSkipped || !AllInstsAreIdentical) return Changed; - goto HoistTerminator; + SmallVector<Instruction *, 8> Insts; + for (auto &SuccIter : OtherSuccIterRange) + Insts.push_back(&*SuccIter); + return hoistSuccIdenticalTerminatorToSwitchOrIf(TI, I1, Insts) || Changed; } - if (I1->isIdenticalToWhenDefined(I2) && - // Even if the instructions are identical, it may not be safe to hoist - // them if we have skipped over instructions with side effects or their - // operands weren't hoisted. 
- isSafeToHoistInstr(I1, SkipFlagsBB1) && - isSafeToHoistInstr(I2, SkipFlagsBB2) && - shouldHoistCommonInstructions(I1, I2, TTI)) { - if (isa<DbgInfoIntrinsic>(I1) || isa<DbgInfoIntrinsic>(I2)) { - assert(isa<DbgInfoIntrinsic>(I1) && isa<DbgInfoIntrinsic>(I2)); + if (AllInstsAreIdentical) { + unsigned SkipFlagsBB1 = BB1ItrPair.second; + AllInstsAreIdentical = + isSafeToHoistInstr(I1, SkipFlagsBB1) && + all_of(OtherSuccIterPairRange, [=](const auto &Pair) { + Instruction *I2 = &*Pair.first; + unsigned SkipFlagsBB2 = Pair.second; + // Even if the instructions are identical, it may not + // be safe to hoist them if we have skipped over + // instructions with side effects or their operands + // weren't hoisted. + return isSafeToHoistInstr(I2, SkipFlagsBB2) && + shouldHoistCommonInstructions(I1, I2, TTI); + }); + } + + if (AllInstsAreIdentical) { + BB1ItrPair.first++; + if (isa<DbgInfoIntrinsic>(I1)) { // The debug location is an integral part of a debug info intrinsic // and can't be separated from it or replaced. Instead of attempting // to merge locations, simply hoist both copies of the intrinsic. - BIParent->splice(BI->getIterator(), BB1, I1->getIterator()); - BIParent->splice(BI->getIterator(), BB2, I2->getIterator()); + I1->moveBeforePreserving(TI); + for (auto &SuccIter : OtherSuccIterRange) { + auto *I2 = &*SuccIter++; + assert(isa<DbgInfoIntrinsic>(I2)); + I2->moveBeforePreserving(TI); + } } else { // For a normal instruction, we just move one to right before the // branch, then replace all uses of the other with the first. Finally, // we remove the now redundant second instruction. - BIParent->splice(BI->getIterator(), BB1, I1->getIterator()); - if (!I2->use_empty()) - I2->replaceAllUsesWith(I1); - I1->andIRFlags(I2); - combineMetadataForCSE(I1, I2, true); - - // I1 and I2 are being combined into a single instruction. Its debug - // location is the merged locations of the original instructions. 
- I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); - - I2->eraseFromParent(); + I1->moveBeforePreserving(TI); + BB->splice(TI->getIterator(), BB1, I1->getIterator()); + for (auto &SuccIter : OtherSuccIterRange) { + Instruction *I2 = &*SuccIter++; + assert(I2 != I1); + if (!I2->use_empty()) + I2->replaceAllUsesWith(I1); + I1->andIRFlags(I2); + combineMetadataForCSE(I1, I2, true); + // I1 and I2 are being combined into a single instruction. Its debug + // location is the merged locations of the original instructions. + I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); + I2->eraseFromParent(); + } } + if (!Changed) + NumHoistCommonCode += SuccIterPairs.size(); Changed = true; - ++NumHoistCommonInstrs; + NumHoistCommonInstrs += SuccIterPairs.size(); } else { if (NumSkipped >= HoistCommonSkipLimit) return Changed; // We are about to skip over a pair of non-identical instructions. Record // if any have characteristics that would prevent reordering instructions // across them. - SkipFlagsBB1 |= skippedInstrFlags(I1); - SkipFlagsBB2 |= skippedInstrFlags(I2); + for (auto &SuccIterPair : SuccIterPairs) { + Instruction *I = &*SuccIterPair.first++; + SuccIterPair.second |= skippedInstrFlags(I); + } ++NumSkipped; } - - I1 = &*BB1_Itr++; - I2 = &*BB2_Itr++; - // Skip debug info if it is not identical. - DbgInfoIntrinsic *DBI1 = dyn_cast<DbgInfoIntrinsic>(I1); - DbgInfoIntrinsic *DBI2 = dyn_cast<DbgInfoIntrinsic>(I2); - if (!DBI1 || !DBI2 || !DBI1->isIdenticalToWhenDefined(DBI2)) { - while (isa<DbgInfoIntrinsic>(I1)) - I1 = &*BB1_Itr++; - while (isa<DbgInfoIntrinsic>(I2)) - I2 = &*BB2_Itr++; - } } +} - return Changed; +bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf( + Instruction *TI, Instruction *I1, + SmallVectorImpl<Instruction *> &OtherSuccTIs) { -HoistTerminator: - // It may not be possible to hoist an invoke. 
+ auto *BI = dyn_cast<BranchInst>(TI); + + bool Changed = false; + BasicBlock *TIParent = TI->getParent(); + BasicBlock *BB1 = I1->getParent(); + + // Use only for an if statement. + auto *I2 = *OtherSuccTIs.begin(); + auto *BB2 = I2->getParent(); + if (BI) { + assert(OtherSuccTIs.size() == 1); + assert(BI->getSuccessor(0) == I1->getParent()); + assert(BI->getSuccessor(1) == I2->getParent()); + } + + // In the case of an if statement, we try to hoist an invoke. // FIXME: Can we define a safety predicate for CallBr? - if (isa<InvokeInst>(I1) && !isSafeToHoistInvoke(BB1, BB2, I1, I2)) - return Changed; + // FIXME: Test case llvm/test/Transforms/SimplifyCFG/2009-06-15-InvokeCrash.ll + // removed in 4c923b3b3fd0ac1edebf0603265ca3ba51724937 commit? + if (isa<InvokeInst>(I1) && (!BI || !isSafeToHoistInvoke(BB1, BB2, I1, I2))) + return false; // TODO: callbr hoisting currently disabled pending further study. if (isa<CallBrInst>(I1)) - return Changed; + return false; for (BasicBlock *Succ : successors(BB1)) { for (PHINode &PN : Succ->phis()) { Value *BB1V = PN.getIncomingValueForBlock(BB1); - Value *BB2V = PN.getIncomingValueForBlock(BB2); - if (BB1V == BB2V) - continue; + for (Instruction *OtherSuccTI : OtherSuccTIs) { + Value *BB2V = PN.getIncomingValueForBlock(OtherSuccTI->getParent()); + if (BB1V == BB2V) + continue; - // Check for passingValueIsAlwaysUndefined here because we would rather - // eliminate undefined control flow then converting it to a select. - if (passingValueIsAlwaysUndefined(BB1V, &PN) || - passingValueIsAlwaysUndefined(BB2V, &PN)) - return Changed; + // In the case of an if statement, check for + // passingValueIsAlwaysUndefined here because we would rather eliminate + // undefined control flow then converting it to a select. + if (!BI || passingValueIsAlwaysUndefined(BB1V, &PN) || + passingValueIsAlwaysUndefined(BB2V, &PN)) + return false; + } } } // Okay, it is safe to hoist the terminator. 
Instruction *NT = I1->clone(); - NT->insertInto(BIParent, BI->getIterator()); + NT->insertInto(TIParent, TI->getIterator()); if (!NT->getType()->isVoidTy()) { I1->replaceAllUsesWith(NT); - I2->replaceAllUsesWith(NT); + for (Instruction *OtherSuccTI : OtherSuccTIs) + OtherSuccTI->replaceAllUsesWith(NT); NT->takeName(I1); } Changed = true; - ++NumHoistCommonInstrs; + NumHoistCommonInstrs += OtherSuccTIs.size() + 1; // Ensure terminator gets a debug location, even an unknown one, in case // it involves inlinable calls. - NT->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); + SmallVector<DILocation *, 4> Locs; + Locs.push_back(I1->getDebugLoc()); + for (auto *OtherSuccTI : OtherSuccTIs) + Locs.push_back(OtherSuccTI->getDebugLoc()); + NT->setDebugLoc(DILocation::getMergedLocations(Locs)); // PHIs created below will adopt NT's merged DebugLoc. IRBuilder<NoFolder> Builder(NT); - // Hoisting one of the terminators from our successor is a great thing. - // Unfortunately, the successors of the if/else blocks may have PHI nodes in - // them. If they do, all PHI entries for BB1/BB2 must agree for all PHI - // nodes, so we insert select instruction to compute the final result. - std::map<std::pair<Value *, Value *>, SelectInst *> InsertedSelects; - for (BasicBlock *Succ : successors(BB1)) { - for (PHINode &PN : Succ->phis()) { - Value *BB1V = PN.getIncomingValueForBlock(BB1); - Value *BB2V = PN.getIncomingValueForBlock(BB2); - if (BB1V == BB2V) - continue; + // In the case of an if statement, hoisting one of the terminators from our + // successor is a great thing. Unfortunately, the successors of the if/else + // blocks may have PHI nodes in them. If they do, all PHI entries for BB1/BB2 + // must agree for all PHI nodes, so we insert select instruction to compute + // the final result. 
+ if (BI) { + std::map<std::pair<Value *, Value *>, SelectInst *> InsertedSelects; + for (BasicBlock *Succ : successors(BB1)) { + for (PHINode &PN : Succ->phis()) { + Value *BB1V = PN.getIncomingValueForBlock(BB1); + Value *BB2V = PN.getIncomingValueForBlock(BB2); + if (BB1V == BB2V) + continue; - // These values do not agree. Insert a select instruction before NT - // that determines the right value. - SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)]; - if (!SI) { - // Propagate fast-math-flags from phi node to its replacement select. - IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); - if (isa<FPMathOperator>(PN)) - Builder.setFastMathFlags(PN.getFastMathFlags()); + // These values do not agree. Insert a select instruction before NT + // that determines the right value. + SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)]; + if (!SI) { + // Propagate fast-math-flags from phi node to its replacement select. + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + if (isa<FPMathOperator>(PN)) + Builder.setFastMathFlags(PN.getFastMathFlags()); - SI = cast<SelectInst>( - Builder.CreateSelect(BI->getCondition(), BB1V, BB2V, - BB1V->getName() + "." + BB2V->getName(), BI)); - } + SI = cast<SelectInst>(Builder.CreateSelect( + BI->getCondition(), BB1V, BB2V, + BB1V->getName() + "." + BB2V->getName(), BI)); + } - // Make the PHI node use the select for all incoming values for BB1/BB2 - for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) - if (PN.getIncomingBlock(i) == BB1 || PN.getIncomingBlock(i) == BB2) - PN.setIncomingValue(i, SI); + // Make the PHI node use the select for all incoming values for BB1/BB2 + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) + if (PN.getIncomingBlock(i) == BB1 || PN.getIncomingBlock(i) == BB2) + PN.setIncomingValue(i, SI); + } } } @@ -1733,16 +1816,16 @@ HoistTerminator: // Update any PHI nodes in our new successors. 
for (BasicBlock *Succ : successors(BB1)) { - AddPredecessorToBlock(Succ, BIParent, BB1); + AddPredecessorToBlock(Succ, TIParent, BB1); if (DTU) - Updates.push_back({DominatorTree::Insert, BIParent, Succ}); + Updates.push_back({DominatorTree::Insert, TIParent, Succ}); } if (DTU) - for (BasicBlock *Succ : successors(BI)) - Updates.push_back({DominatorTree::Delete, BIParent, Succ}); + for (BasicBlock *Succ : successors(TI)) + Updates.push_back({DominatorTree::Delete, TIParent, Succ}); - EraseTerminatorAndDCECond(BI); + EraseTerminatorAndDCECond(TI); if (DTU) DTU->applyUpdates(Updates); return Changed; @@ -1808,10 +1891,19 @@ static bool canSinkInstructions( } const Instruction *I0 = Insts.front(); - for (auto *I : Insts) + for (auto *I : Insts) { if (!I->isSameOperationAs(I0)) return false; + // swifterror pointers can only be used by a load or store; sinking a load + // or store would require introducing a select for the pointer operand, + // which isn't allowed for swifterror pointers. + if (isa<StoreInst>(I) && I->getOperand(1)->isSwiftError()) + return false; + if (isa<LoadInst>(I) && I->getOperand(0)->isSwiftError()) + return false; + } + // All instructions in Insts are known to be the same opcode. If they have a // use, check that the only user is a PHI or in the same block as the // instruction, because if a user is in the same block as an instruction we're @@ -1952,8 +2044,9 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { // Create a new PHI in the successor block and populate it. 
auto *Op = I0->getOperand(O); assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!"); - auto *PN = PHINode::Create(Op->getType(), Insts.size(), - Op->getName() + ".sink", &BBEnd->front()); + auto *PN = + PHINode::Create(Op->getType(), Insts.size(), Op->getName() + ".sink"); + PN->insertBefore(BBEnd->begin()); for (auto *I : Insts) PN->addIncoming(I->getOperand(O), I->getParent()); NewOperands.push_back(PN); @@ -1963,7 +2056,8 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { // and move it to the start of the successor block. for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) I0->getOperandUse(O).set(NewOperands[O]); - I0->moveBefore(&*BBEnd->getFirstInsertionPt()); + + I0->moveBefore(*BBEnd, BBEnd->getFirstInsertionPt()); // Update metadata and IR flags, and merge debug locations. for (auto *I : Insts) @@ -2765,8 +2859,8 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB, Value *OrigV = PN.getIncomingValueForBlock(BB); Value *ThenV = PN.getIncomingValueForBlock(ThenBB); - // FIXME: Try to remove some of the duplication with HoistThenElseCodeToIf. - // Skip PHIs which are trivial. + // FIXME: Try to remove some of the duplication with + // hoistCommonCodeFromSuccessors. Skip PHIs which are trivial. if (ThenV == OrigV) continue; @@ -3009,7 +3103,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, // store %merge, %x.dest, !DIAssignID !2 // dbg.assign %merge, "x", ..., !2 for (auto *DAI : at::getAssignmentMarkers(SpeculatedStore)) { - if (any_of(DAI->location_ops(), [&](Value *V) { return V == OrigV; })) + if (llvm::is_contained(DAI->location_ops(), OrigV)) DAI->replaceVariableLocationOp(OrigV, S); } } @@ -3036,6 +3130,11 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, } // Hoist the instructions. + // In "RemoveDIs" non-instr debug-info mode, drop DPValues attached to these + // instructions, in the same way that dbg.value intrinsics are dropped at the + // end of this block. 
+ for (auto &It : make_range(ThenBB->begin(), ThenBB->end())) + It.dropDbgValues(); BB->splice(BI->getIterator(), ThenBB, ThenBB->begin(), std::prev(ThenBB->end())); @@ -3207,6 +3306,10 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, BasicBlock::iterator InsertPt = EdgeBB->getFirstInsertionPt(); DenseMap<Value *, Value *> TranslateMap; // Track translated values. TranslateMap[Cond] = CB; + + // RemoveDIs: track instructions that we optimise away while folding, so + // that we can copy DPValues from them later. + BasicBlock::iterator SrcDbgCursor = BB->begin(); for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) { if (PHINode *PN = dyn_cast<PHINode>(BBI)) { TranslateMap[PN] = PN->getIncomingValueForBlock(EdgeBB); @@ -3241,6 +3344,15 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, TranslateMap[&*BBI] = N; } if (N) { + // Copy all debug-info attached to instructions from the last we + // successfully clone, up to this instruction (they might have been + // folded away). + for (; SrcDbgCursor != BBI; ++SrcDbgCursor) + N->cloneDebugInfoFrom(&*SrcDbgCursor); + SrcDbgCursor = std::next(BBI); + // Clone debug-info on this instruction too. + N->cloneDebugInfoFrom(&*BBI); + // Register the new instruction with the assumption cache if necessary. 
if (auto *Assume = dyn_cast<AssumeInst>(N)) if (AC) @@ -3248,6 +3360,10 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, } } + for (; &*SrcDbgCursor != BI; ++SrcDbgCursor) + InsertPt->cloneDebugInfoFrom(&*SrcDbgCursor); + InsertPt->cloneDebugInfoFrom(BI); + BB->removePredecessor(EdgeBB); BranchInst *EdgeBI = cast<BranchInst>(EdgeBB->getTerminator()); EdgeBI->setSuccessor(0, RealDest); @@ -3652,22 +3768,22 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, ValueToValueMapTy VMap; // maps original values to cloned values CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(BB, PredBlock, VMap); + Module *M = BB->getModule(); + + if (PredBlock->IsNewDbgInfoFormat) { + PredBlock->getTerminator()->cloneDebugInfoFrom(BB->getTerminator()); + for (DPValue &DPV : PredBlock->getTerminator()->getDbgValueRange()) { + RemapDPValue(M, &DPV, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + } + } + // Now that the Cond was cloned into the predecessor basic block, // or/and the two conditions together. Value *BICond = VMap[BI->getCondition()]; PBI->setCondition( createLogicalOp(Builder, Opc, PBI->getCondition(), BICond, "or.cond")); - // Copy any debug value intrinsics into the end of PredBlock. 
- for (Instruction &I : *BB) { - if (isa<DbgInfoIntrinsic>(I)) { - Instruction *NewI = I.clone(); - RemapInstruction(NewI, VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - NewI->insertBefore(PBI); - } - } - ++NumFoldBranchToCommonDest; return true; } @@ -3867,7 +3983,8 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB, (!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != BB)) return V; - PHI = PHINode::Create(V->getType(), 2, "simplifycfg.merge", &Succ->front()); + PHI = PHINode::Create(V->getType(), 2, "simplifycfg.merge"); + PHI->insertBefore(Succ->begin()); PHI->addIncoming(V, BB); for (BasicBlock *PredBB : predecessors(Succ)) if (PredBB != BB) @@ -3991,7 +4108,9 @@ static bool mergeConditionalStoreToAddress( Value *QPHI = ensureValueAvailableInSuccessor(QStore->getValueOperand(), QStore->getParent(), PPHI); - IRBuilder<> QB(&*PostBB->getFirstInsertionPt()); + BasicBlock::iterator PostBBFirst = PostBB->getFirstInsertionPt(); + IRBuilder<> QB(PostBB, PostBBFirst); + QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc()); Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond); Value *QPred = QStore->getParent() == QTB ? 
QCond : QB.CreateNot(QCond); @@ -4002,9 +4121,11 @@ static bool mergeConditionalStoreToAddress( QPred = QB.CreateNot(QPred); Value *CombinedPred = QB.CreateOr(PPred, QPred); - auto *T = SplitBlockAndInsertIfThen(CombinedPred, &*QB.GetInsertPoint(), + BasicBlock::iterator InsertPt = QB.GetInsertPoint(); + auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt, /*Unreachable=*/false, /*BranchWeights=*/nullptr, DTU); + QB.SetInsertPoint(T); StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address)); SI->setAAMetadata(PStore->getAAMetadata().merge(QStore->getAAMetadata())); @@ -4140,10 +4261,10 @@ static bool tryWidenCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // 2) We can sink side effecting instructions into BI's fallthrough // successor provided they doesn't contribute to computation of // BI's condition. - Value *CondWB, *WC; - BasicBlock *IfTrueBB, *IfFalseBB; - if (!parseWidenableBranch(PBI, CondWB, WC, IfTrueBB, IfFalseBB) || - IfTrueBB != BI->getParent() || !BI->getParent()->getSinglePredecessor()) + BasicBlock *IfTrueBB = PBI->getSuccessor(0); + BasicBlock *IfFalseBB = PBI->getSuccessor(1); + if (!isWidenableBranch(PBI) || IfTrueBB != BI->getParent() || + !BI->getParent()->getSinglePredecessor()) return false; if (!IfFalseBB->phis().empty()) return false; // TODO @@ -4256,6 +4377,21 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, if (PBI->getSuccessor(PBIOp) == BB) return false; + // If predecessor's branch probability to BB is too low don't merge branches. 
+ SmallVector<uint32_t, 2> PredWeights; + if (!PBI->getMetadata(LLVMContext::MD_unpredictable) && + extractBranchWeights(*PBI, PredWeights) && + (static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]) != 0) { + + BranchProbability CommonDestProb = BranchProbability::getBranchProbability( + PredWeights[PBIOp], + static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]); + + BranchProbability Likely = TTI.getPredictableBranchThreshold(); + if (CommonDestProb >= Likely) + return false; + } + // Do not perform this transformation if it would require // insertion of a large number of select instructions. For targets // without predication/cmovs, this is a big pessimization. @@ -5088,6 +5224,15 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { bool Changed = false; + // Ensure that any debug-info records that used to occur after the Unreachable + // are moved to in front of it -- otherwise they'll "dangle" at the end of + // the block. + BB->flushTerminatorDbgValues(); + + // Debug-info records on the unreachable inst itself should be deleted, as + // below we delete everything past the final executable instruction. + UI->dropDbgValues(); + // If there are any instructions immediately before the unreachable that can // be removed, do so. while (UI->getIterator() != BB->begin()) { @@ -5104,6 +5249,10 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { // block will be the unwind edges of Invoke/CatchSwitch/CleanupReturn, // and we can therefore guarantee this block will be erased. + // If we're deleting this, we're deleting any subsequent dbg.values, so + // delete DPValue records of variable information. 
+ BBI->dropDbgValues(); + // Delete this instruction (any uses are guaranteed to be dead) BBI->replaceAllUsesWith(PoisonValue::get(BBI->getType())); BBI->eraseFromParent(); @@ -5667,7 +5816,7 @@ getCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, for (Instruction &I : CaseDest->instructionsWithoutDebug(false)) { if (I.isTerminator()) { // If the terminator is a simple branch, continue to the next block. - if (I.getNumSuccessors() != 1 || I.isExceptionalTerminator()) + if (I.getNumSuccessors() != 1 || I.isSpecialTerminator()) return false; Pred = CaseDest; CaseDest = I.getSuccessor(0); @@ -5890,8 +6039,8 @@ static void removeSwitchAfterSelectFold(SwitchInst *SI, PHINode *PHI, // Remove the switch. - while (PHI->getBasicBlockIndex(SelectBB) >= 0) - PHI->removeIncomingValue(SelectBB); + PHI->removeIncomingValueIf( + [&](unsigned Idx) { return PHI->getIncomingBlock(Idx) == SelectBB; }); PHI->addIncoming(SelectValue, SelectBB); SmallPtrSet<BasicBlock *, 4> RemovedSuccessors; @@ -6051,8 +6200,9 @@ SwitchLookupTable::SwitchLookupTable( bool LinearMappingPossible = true; APInt PrevVal; APInt DistToPrev; - // When linear map is monotonic, we can attach nsw. - bool Wrapped = false; + // When linear map is monotonic and signed overflow doesn't happen on + // maximum index, we can attach nsw on Add and Mul. + bool NonMonotonic = false; assert(TableSize >= 2 && "Should be a SingleValue table."); // Check if there is the same distance between two consecutive values. for (uint64_t I = 0; I < TableSize; ++I) { @@ -6072,7 +6222,7 @@ SwitchLookupTable::SwitchLookupTable( LinearMappingPossible = false; break; } - Wrapped |= + NonMonotonic |= Dist.isStrictlyPositive() ? 
Val.sle(PrevVal) : Val.sgt(PrevVal); } PrevVal = Val; @@ -6080,7 +6230,10 @@ SwitchLookupTable::SwitchLookupTable( if (LinearMappingPossible) { LinearOffset = cast<ConstantInt>(TableContents[0]); LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev); - LinearMapValWrapped = Wrapped; + bool MayWrap = false; + APInt M = LinearMultiplier->getValue(); + (void)M.smul_ov(APInt(M.getBitWidth(), TableSize - 1), MayWrap); + LinearMapValWrapped = NonMonotonic || MayWrap; Kind = LinearMapKind; ++NumLinearMaps; return; @@ -6503,9 +6656,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // If the default destination is unreachable, or if the lookup table covers // all values of the conditional variable, branch directly to the lookup table // BB. Otherwise, check that the condition is within the case range. - const bool DefaultIsReachable = + bool DefaultIsReachable = !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); - const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize); // Create the BB that does the lookups. Module &Mod = *CommonDest->getParent()->getParent(); @@ -6536,6 +6688,28 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, BranchInst *RangeCheckBranch = nullptr; + // Grow the table to cover all possible index values to avoid the range check. + // It will use the default result to fill in the table hole later, so make + // sure it exist. + if (UseSwitchConditionAsTableIndex && HasDefaultResults) { + ConstantRange CR = computeConstantRange(TableIndex, /* ForSigned */ false); + // Grow the table shouldn't have any size impact by checking + // WouldFitInRegister. + // TODO: Consider growing the table also when it doesn't fit in a register + // if no optsize is specified. 
+ const uint64_t UpperBound = CR.getUpper().getLimitedValue(); + if (!CR.isUpperWrapped() && all_of(ResultTypes, [&](const auto &KV) { + return SwitchLookupTable::WouldFitInRegister( + DL, UpperBound, KV.second /* ResultType */); + })) { + // The default branch is unreachable after we enlarge the lookup table. + // Adjust DefaultIsReachable to reuse code path. + TableSize = UpperBound; + DefaultIsReachable = false; + } + } + + const bool GeneratingCoveredLookupTable = (MaxTableSize == TableSize); if (!DefaultIsReachable || GeneratingCoveredLookupTable) { Builder.CreateBr(LookupBB); if (DTU) @@ -6697,9 +6871,6 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, // This transform can be done speculatively because it is so cheap - it // results in a single rotate operation being inserted. - // FIXME: It's possible that optimizing a switch on powers of two might also - // be beneficial - flag values are often powers of two and we could use a CLZ - // as the key function. // countTrailingZeros(0) returns 64. As Values is guaranteed to have more than // one element and LLVM disallows duplicate cases, Shift is guaranteed to be @@ -6744,6 +6915,80 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Tries to transform switch of powers of two to reduce switch range. +/// For example, switch like: +/// switch (C) { case 1: case 2: case 64: case 128: } +/// will be transformed to: +/// switch (count_trailing_zeros(C)) { case 0: case 1: case 6: case 7: } +/// +/// This transformation allows better lowering and could allow transforming into +/// a lookup table. 
+static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder, + const DataLayout &DL, + const TargetTransformInfo &TTI) { + Value *Condition = SI->getCondition(); + LLVMContext &Context = SI->getContext(); + auto *CondTy = cast<IntegerType>(Condition->getType()); + + if (CondTy->getIntegerBitWidth() > 64 || + !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) + return false; + + const auto CttzIntrinsicCost = TTI.getIntrinsicInstrCost( + IntrinsicCostAttributes(Intrinsic::cttz, CondTy, + {Condition, ConstantInt::getTrue(Context)}), + TTI::TCK_SizeAndLatency); + + if (CttzIntrinsicCost > TTI::TCC_Basic) + // Inserting intrinsic is too expensive. + return false; + + // Only bother with this optimization if there are more than 3 switch cases. + // SDAG will only bother creating jump tables for 4 or more cases. + if (SI->getNumCases() < 4) + return false; + + // We perform this optimization only for switches with + // unreachable default case. + // This assumtion will save us from checking if `Condition` is a power of two. + if (!isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg())) + return false; + + // Check that switch cases are powers of two. + SmallVector<uint64_t, 4> Values; + for (const auto &Case : SI->cases()) { + uint64_t CaseValue = Case.getCaseValue()->getValue().getZExtValue(); + if (llvm::has_single_bit(CaseValue)) + Values.push_back(CaseValue); + else + return false; + } + + // isSwichDense requires case values to be sorted. + llvm::sort(Values); + if (!isSwitchDense(Values.size(), llvm::countr_zero(Values.back()) - + llvm::countr_zero(Values.front()) + 1)) + // Transform is unable to generate dense switch. + return false; + + Builder.SetInsertPoint(SI); + + // Replace each case with its trailing zeros number. 
+ for (auto &Case : SI->cases()) { + auto *OrigValue = Case.getCaseValue(); + Case.setValue(ConstantInt::get(OrigValue->getType(), + OrigValue->getValue().countr_zero())); + } + + // Replace condition with its trailing zeros number. + auto *ConditionTrailingZeros = Builder.CreateIntrinsic( + Intrinsic::cttz, {CondTy}, {Condition, ConstantInt::getTrue(Context)}); + + SI->setCondition(ConditionTrailingZeros); + + return true; +} + bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { BasicBlock *BB = SI->getParent(); @@ -6791,9 +7036,16 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { SwitchToLookupTable(SI, Builder, DTU, DL, TTI)) return requestResimplify(); + if (simplifySwitchOfPowersOfTwo(SI, Builder, DL, TTI)) + return requestResimplify(); + if (ReduceSwitchRange(SI, Builder, DL, TTI)) return requestResimplify(); + if (HoistCommon && + hoistCommonCodeFromSuccessors(SI->getParent(), !Options.HoistCommonInsts)) + return requestResimplify(); + return false; } @@ -6978,7 +7230,8 @@ bool SimplifyCFGOpt::simplifyUncondBranch(BranchInst *BI, // branches to us and our successor, fold the comparison into the // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. - if (FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI, + if (Options.SpeculateBlocks && + FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI, Options.BonusInstThreshold)) return requestResimplify(); return false; @@ -7048,7 +7301,8 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. 
- if (FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI, + if (Options.SpeculateBlocks && + FoldBranchToCommonDest(BI, DTU, /*MSSAU=*/nullptr, &TTI, Options.BonusInstThreshold)) return requestResimplify(); @@ -7058,7 +7312,8 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // can hoist it up to the branching block. if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { - if (HoistCommon && HoistThenElseCodeToIf(BI, !Options.HoistCommonInsts)) + if (HoistCommon && hoistCommonCodeFromSuccessors( + BI->getParent(), !Options.HoistCommonInsts)) return requestResimplify(); } else { // If Successor #1 has multiple preds, we may be able to conditionally diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index a28916bc9baf..722ed03db3de 100644 --- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -539,7 +539,8 @@ bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) { for (auto *ICI : ICmpUsers) { bool IsSwapped = L->isLoopInvariant(ICI->getOperand(0)); auto *Op1 = IsSwapped ? ICI->getOperand(0) : ICI->getOperand(1); - Instruction *Ext = nullptr; + IRBuilder<> Builder(ICI); + Value *Ext = nullptr; // For signed/unsigned predicate, replace the old comparison with comparison // of immediate IV against sext/zext of the invariant argument. If we can // use either sext or zext (i.e. 
we are dealing with equality predicate), @@ -550,18 +551,18 @@ bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) { if (IsSwapped) Pred = ICmpInst::getSwappedPredicate(Pred); if (CanUseZExt(ICI)) { assert(DoesZExtCollapse && "Unprofitable zext?"); - Ext = new ZExtInst(Op1, IVTy, "zext", ICI); + Ext = Builder.CreateZExt(Op1, IVTy, "zext"); Pred = ICmpInst::getUnsignedPredicate(Pred); } else { assert(DoesSExtCollapse && "Unprofitable sext?"); - Ext = new SExtInst(Op1, IVTy, "sext", ICI); + Ext = Builder.CreateSExt(Op1, IVTy, "sext"); assert(Pred == ICmpInst::getSignedPredicate(Pred) && "Must be signed!"); } bool Changed; L->makeLoopInvariant(Ext, Changed); (void)Changed; - ICmpInst *NewICI = new ICmpInst(ICI, Pred, IV, Ext); - ICI->replaceAllUsesWith(NewICI); + auto *NewCmp = Builder.CreateICmp(Pred, IV, Ext); + ICI->replaceAllUsesWith(NewCmp); DeadInsts.emplace_back(ICI); } @@ -659,12 +660,12 @@ bool SimplifyIndvar::replaceFloatIVWithIntegerIV(Instruction *UseInst) { Instruction *IVOperand = cast<Instruction>(UseInst->getOperand(0)); // Get the symbolic expression for this instruction. const SCEV *IV = SE->getSCEV(IVOperand); - unsigned MaskBits; + int MaskBits; if (UseInst->getOpcode() == CastInst::SIToFP) - MaskBits = SE->getSignedRange(IV).getMinSignedBits(); + MaskBits = (int)SE->getSignedRange(IV).getMinSignedBits(); else - MaskBits = SE->getUnsignedRange(IV).getActiveBits(); - unsigned DestNumSigBits = UseInst->getType()->getFPMantissaWidth(); + MaskBits = (int)SE->getUnsignedRange(IV).getActiveBits(); + int DestNumSigBits = UseInst->getType()->getFPMantissaWidth(); if (MaskBits <= DestNumSigBits) { for (User *U : UseInst->users()) { // Match for fptosi/fptoui of sitofp and with same type. 
@@ -908,8 +909,9 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { if (replaceIVUserWithLoopInvariant(UseInst)) continue; - // Go further for the bitcast ''prtoint ptr to i64' - if (isa<PtrToIntInst>(UseInst)) + // Go further for the bitcast 'prtoint ptr to i64' or if the cast is done + // by truncation + if ((isa<PtrToIntInst>(UseInst)) || (isa<TruncInst>(UseInst))) for (Use &U : UseInst->uses()) { Instruction *User = cast<Instruction>(U.getUser()); if (replaceIVUserWithLoopInvariant(User)) @@ -1373,16 +1375,32 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) { DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0; assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU"); - const SCEV *ExtendOperExpr = nullptr; const OverflowingBinaryOperator *OBO = cast<OverflowingBinaryOperator>(DU.NarrowUse); ExtendKind ExtKind = getExtendKind(DU.NarrowDef); - if (ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) - ExtendOperExpr = SE->getSignExtendExpr( - SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType); - else if (ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap()) - ExtendOperExpr = SE->getZeroExtendExpr( - SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType); + if (!(ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) && + !(ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())) { + ExtKind = ExtendKind::Unknown; + + // For a non-negative NarrowDef, we can choose either type of + // extension. We want to use the current extend kind if legal + // (see above), and we only hit this code if we need to check + // the opposite case. 
+ if (DU.NeverNegative) { + if (OBO->hasNoSignedWrap()) { + ExtKind = ExtendKind::Sign; + } else if (OBO->hasNoUnsignedWrap()) { + ExtKind = ExtendKind::Zero; + } + } + } + + const SCEV *ExtendOperExpr = + SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)); + if (ExtKind == ExtendKind::Sign) + ExtendOperExpr = SE->getSignExtendExpr(ExtendOperExpr, WideType); + else if (ExtKind == ExtendKind::Zero) + ExtendOperExpr = SE->getZeroExtendExpr(ExtendOperExpr, WideType); else return {nullptr, ExtendKind::Unknown}; @@ -1493,10 +1511,6 @@ bool WidenIV::widenLoopCompare(WidenIV::NarrowIVDefUse DU) { assert(CastWidth <= IVWidth && "Unexpected width while widening compare."); // Widen the compare instruction. - auto *InsertPt = getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI); - if (!InsertPt) - return false; - IRBuilder<> Builder(InsertPt); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef); // Widen the other operand of the compare, if necessary. @@ -1673,7 +1687,8 @@ bool WidenIV::widenWithVariantUse(WidenIV::NarrowIVDefUse DU) { assert(LoopExitingBlock && L->contains(LoopExitingBlock) && "Not a LCSSA Phi?"); WidePN->addIncoming(WideBO, LoopExitingBlock); - Builder.SetInsertPoint(&*User->getParent()->getFirstInsertionPt()); + Builder.SetInsertPoint(User->getParent(), + User->getParent()->getFirstInsertionPt()); auto *TruncPN = Builder.CreateTrunc(WidePN, User->getType()); User->replaceAllUsesWith(TruncPN); DeadInsts.emplace_back(User); @@ -1726,7 +1741,8 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri PHINode::Create(DU.WideDef->getType(), 1, UsePhi->getName() + ".wide", UsePhi); WidePhi->addIncoming(DU.WideDef, UsePhi->getIncomingBlock(0)); - IRBuilder<> Builder(&*WidePhi->getParent()->getFirstInsertionPt()); + BasicBlock *WidePhiBB = WidePhi->getParent(); + IRBuilder<> Builder(WidePhiBB, WidePhiBB->getFirstInsertionPt()); Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType()); 
UsePhi->replaceAllUsesWith(Trunc); DeadInsts.emplace_back(UsePhi); @@ -1786,65 +1802,70 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri return nullptr; } - // Does this user itself evaluate to a recurrence after widening? - WidenedRecTy WideAddRec = getExtendedOperandRecurrence(DU); - if (!WideAddRec.first) - WideAddRec = getWideRecurrence(DU); - - assert((WideAddRec.first == nullptr) == - (WideAddRec.second == ExtendKind::Unknown)); - if (!WideAddRec.first) { - // If use is a loop condition, try to promote the condition instead of - // truncating the IV first. - if (widenLoopCompare(DU)) + auto tryAddRecExpansion = [&]() -> Instruction* { + // Does this user itself evaluate to a recurrence after widening? + WidenedRecTy WideAddRec = getExtendedOperandRecurrence(DU); + if (!WideAddRec.first) + WideAddRec = getWideRecurrence(DU); + assert((WideAddRec.first == nullptr) == + (WideAddRec.second == ExtendKind::Unknown)); + if (!WideAddRec.first) return nullptr; - // We are here about to generate a truncate instruction that may hurt - // performance because the scalar evolution expression computed earlier - // in WideAddRec.first does not indicate a polynomial induction expression. - // In that case, look at the operands of the use instruction to determine - // if we can still widen the use instead of truncating its operand. - if (widenWithVariantUse(DU)) + // Reuse the IV increment that SCEVExpander created as long as it dominates + // NarrowUse. + Instruction *WideUse = nullptr; + if (WideAddRec.first == WideIncExpr && + Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) + WideUse = WideInc; + else { + WideUse = cloneIVUser(DU, WideAddRec.first); + if (!WideUse) + return nullptr; + } + // Evaluation of WideAddRec ensured that the narrow expression could be + // extended outside the loop without overflow. 
This suggests that the wide use + // evaluates to the same expression as the extended narrow use, but doesn't + // absolutely guarantee it. Hence the following failsafe check. In rare cases + // where it fails, we simply throw away the newly created wide use. + if (WideAddRec.first != SE->getSCEV(WideUse)) { + LLVM_DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse << ": " + << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first + << "\n"); + DeadInsts.emplace_back(WideUse); return nullptr; + }; - // This user does not evaluate to a recurrence after widening, so don't - // follow it. Instead insert a Trunc to kill off the original use, - // eventually isolating the original narrow IV so it can be removed. - truncateIVUse(DU, DT, LI); - return nullptr; - } + // if we reached this point then we are going to replace + // DU.NarrowUse with WideUse. Reattach DbgValue then. + replaceAllDbgUsesWith(*DU.NarrowUse, *WideUse, *WideUse, *DT); - // Reuse the IV increment that SCEVExpander created as long as it dominates - // NarrowUse. - Instruction *WideUse = nullptr; - if (WideAddRec.first == WideIncExpr && - Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) - WideUse = WideInc; - else { - WideUse = cloneIVUser(DU, WideAddRec.first); - if (!WideUse) - return nullptr; - } - // Evaluation of WideAddRec ensured that the narrow expression could be - // extended outside the loop without overflow. This suggests that the wide use - // evaluates to the same expression as the extended narrow use, but doesn't - // absolutely guarantee it. Hence the following failsafe check. In rare cases - // where it fails, we simply throw away the newly created wide use. 
- if (WideAddRec.first != SE->getSCEV(WideUse)) { - LLVM_DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse << ": " - << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first - << "\n"); - DeadInsts.emplace_back(WideUse); + ExtendKindMap[DU.NarrowUse] = WideAddRec.second; + // Returning WideUse pushes it on the worklist. + return WideUse; + }; + + if (auto *I = tryAddRecExpansion()) + return I; + + // If use is a loop condition, try to promote the condition instead of + // truncating the IV first. + if (widenLoopCompare(DU)) return nullptr; - } - // if we reached this point then we are going to replace - // DU.NarrowUse with WideUse. Reattach DbgValue then. - replaceAllDbgUsesWith(*DU.NarrowUse, *WideUse, *WideUse, *DT); + // We are here about to generate a truncate instruction that may hurt + // performance because the scalar evolution expression computed earlier + // in WideAddRec.first does not indicate a polynomial induction expression. + // In that case, look at the operands of the use instruction to determine + // if we can still widen the use instead of truncating its operand. + if (widenWithVariantUse(DU)) + return nullptr; - ExtendKindMap[DU.NarrowUse] = WideAddRec.second; - // Returning WideUse pushes it on the worklist. - return WideUse; + // This user does not evaluate to a recurrence after widening, so don't + // follow it. Instead insert a Trunc to kill off the original use, + // eventually isolating the original narrow IV so it can be removed. + truncateIVUse(DU, DT, LI); + return nullptr; } /// Add eligible users of NarrowDef to NarrowIVUsers. @@ -1944,13 +1965,15 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { // SCEVExpander. Henceforth, we produce 1-to-1 narrow to wide uses. 
if (BasicBlock *LatchBlock = L->getLoopLatch()) { WideInc = - cast<Instruction>(WidePhi->getIncomingValueForBlock(LatchBlock)); - WideIncExpr = SE->getSCEV(WideInc); - // Propagate the debug location associated with the original loop increment - // to the new (widened) increment. - auto *OrigInc = - cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock)); - WideInc->setDebugLoc(OrigInc->getDebugLoc()); + dyn_cast<Instruction>(WidePhi->getIncomingValueForBlock(LatchBlock)); + if (WideInc) { + WideIncExpr = SE->getSCEV(WideInc); + // Propagate the debug location associated with the original loop + // increment to the new (widened) increment. + auto *OrigInc = + cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock)); + WideInc->setDebugLoc(OrigInc->getDebugLoc()); + } } LLVM_DEBUG(dbgs() << "Wide IV: " << *WidePhi << "\n"); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 5b0951252c07..760a626c8b6f 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -227,9 +227,21 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr, return ConstantInt::get(RetTy, Result); } +static bool isOnlyUsedInComparisonWithZero(Value *V) { + for (User *U : V->users()) { + if (ICmpInst *IC = dyn_cast<ICmpInst>(U)) + if (Constant *C = dyn_cast<Constant>(IC->getOperand(1))) + if (C->isNullValue()) + continue; + // Unknown instruction. 
+ return false; + } + return true; +} + static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len, const DataLayout &DL) { - if (!isOnlyUsedInZeroComparison(CI)) + if (!isOnlyUsedInComparisonWithZero(CI)) return false; if (!isDereferenceableAndAlignedPointer(Str, Align(1), APInt(64, Len), DL)) @@ -1136,7 +1148,7 @@ Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilderBase &B) { Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) { // fold strstr(x, x) -> x. if (CI->getArgOperand(0) == CI->getArgOperand(1)) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); + return CI->getArgOperand(0); // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 if (isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { @@ -1164,7 +1176,7 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) { // fold strstr(x, "") -> x. if (HasStr2 && ToFindStr.empty()) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); + return CI->getArgOperand(0); // If both strings are known, constant fold it. if (HasStr1 && HasStr2) { @@ -1174,16 +1186,13 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) { return Constant::getNullValue(CI->getType()); // strstr("abcd", "bc") -> gep((char*)"abcd", 1) - Value *Result = castToCStr(CI->getArgOperand(0), B); - Result = - B.CreateConstInBoundsGEP1_64(B.getInt8Ty(), Result, Offset, "strstr"); - return B.CreateBitCast(Result, CI->getType()); + return B.CreateConstInBoundsGEP1_64(B.getInt8Ty(), CI->getArgOperand(0), + Offset, "strstr"); } // fold strstr(x, "y") -> strchr(x, 'y'). if (HasStr2 && ToFindStr.size() == 1) { - Value *StrChr = emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI); - return StrChr ? 
B.CreateBitCast(StrChr, CI->getType()) : nullptr; + return emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI); } annotateNonNullNoUndefBasedOnAccess(CI, {0, 1}); @@ -1380,7 +1389,7 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { if (isOnlyUsedInEqualityComparison(CI, SrcStr)) // S is dereferenceable so it's safe to load from it and fold // memchr(S, C, N) == S to N && *S == C for any C and N. - // TODO: This is safe even even for nonconstant S. + // TODO: This is safe even for nonconstant S. return memChrToCharCompare(CI, Size, B, DL); // From now on we need a constant length and constant array. @@ -1522,12 +1531,10 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS if (Len == 1) { - Value *LHSV = - B.CreateZExt(B.CreateLoad(B.getInt8Ty(), castToCStr(LHS, B), "lhsc"), - CI->getType(), "lhsv"); - Value *RHSV = - B.CreateZExt(B.CreateLoad(B.getInt8Ty(), castToCStr(RHS, B), "rhsc"), - CI->getType(), "rhsv"); + Value *LHSV = B.CreateZExt(B.CreateLoad(B.getInt8Ty(), LHS, "lhsc"), + CI->getType(), "lhsv"); + Value *RHSV = B.CreateZExt(B.CreateLoad(B.getInt8Ty(), RHS, "rhsc"), + CI->getType(), "rhsv"); return B.CreateSub(LHSV, RHSV, "chardiff"); } @@ -1833,7 +1840,7 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilderBase &B, StringRef CallerName = CI->getFunction()->getName(); if (!CallerName.empty() && CallerName.back() == 'f' && CallerName.size() == (CalleeName.size() + 1) && - CallerName.startswith(CalleeName)) + CallerName.starts_with(CalleeName)) return nullptr; } @@ -2368,8 +2375,8 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { FMF.setNoSignedZeros(); B.setFastMathFlags(FMF); - Intrinsic::ID IID = Callee->getName().startswith("fmin") ? Intrinsic::minnum - : Intrinsic::maxnum; + Intrinsic::ID IID = Callee->getName().starts_with("fmin") ? 
Intrinsic::minnum + : Intrinsic::maxnum; Function *F = Intrinsic::getDeclaration(CI->getModule(), IID, CI->getType()); return copyFlags( *CI, B.CreateCall(F, {CI->getArgOperand(0), CI->getArgOperand(1)})); @@ -3066,7 +3073,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); - Value *Ptr = castToCStr(Dest, B); + Value *Ptr = Dest; B.CreateStore(V, Ptr); Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); B.CreateStore(B.getInt8(0), Ptr); @@ -3093,9 +3100,6 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, return ConstantInt::get(CI->getType(), SrcLen - 1); } else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) { // sprintf(dest, "%s", str) -> stpcpy(dest, str) - dest - // Handle mismatched pointer types (goes away with typeless pointers?). - V = B.CreatePointerCast(V, B.getInt8PtrTy()); - Dest = B.CreatePointerCast(Dest, B.getInt8PtrTy()); Value *PtrDiff = B.CreatePtrDiff(B.getInt8Ty(), V, Dest); return B.CreateIntCast(PtrDiff, CI->getType(), false); } @@ -3261,7 +3265,7 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, if (!CI->getArgOperand(3)->getType()->isIntegerTy()) return nullptr; Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char"); - Value *Ptr = castToCStr(DstArg, B); + Value *Ptr = DstArg; B.CreateStore(V, Ptr); Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); B.CreateStore(B.getInt8(0), Ptr); @@ -3397,8 +3401,7 @@ Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilderBase &B) { // If this is writing one byte, turn it into fputc. // This optimisation is only valid, if the return value is unused. 
if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) - Value *Char = B.CreateLoad(B.getInt8Ty(), - castToCStr(CI->getArgOperand(0), B), "char"); + Value *Char = B.CreateLoad(B.getInt8Ty(), CI->getArgOperand(0), "char"); Type *IntTy = B.getIntNTy(TLI->getIntSize()); Value *Cast = B.CreateIntCast(Char, IntTy, /*isSigned*/ true, "chari"); Value *NewCI = emitFPutC(Cast, CI->getArgOperand(3), B, TLI); diff --git a/llvm/lib/Transforms/Utils/StripGCRelocates.cpp b/llvm/lib/Transforms/Utils/StripGCRelocates.cpp index 0ff88e8b4612..6094f36a77f4 100644 --- a/llvm/lib/Transforms/Utils/StripGCRelocates.cpp +++ b/llvm/lib/Transforms/Utils/StripGCRelocates.cpp @@ -18,8 +18,6 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Statepoint.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" using namespace llvm; @@ -66,21 +64,3 @@ PreservedAnalyses StripGCRelocates::run(Function &F, PA.preserveSet<CFGAnalyses>(); return PA; } - -namespace { -struct StripGCRelocatesLegacy : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - StripGCRelocatesLegacy() : FunctionPass(ID) { - initializeStripGCRelocatesLegacyPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &Info) const override {} - - bool runOnFunction(Function &F) override { return ::stripGCRelocates(F); } -}; -char StripGCRelocatesLegacy::ID = 0; -} // namespace - -INITIALIZE_PASS(StripGCRelocatesLegacy, "strip-gc-relocates", - "Strip gc.relocates inserted through RewriteStatepointsForGC", - true, false) diff --git a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp index c3ae43e567b0..8b4f34209e85 100644 --- a/llvm/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/llvm/lib/Transforms/Utils/SymbolRewriter.cpp @@ -68,8 +68,6 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" 
-#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" diff --git a/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp index 2b706858cbed..d5468909dd4e 100644 --- a/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp +++ b/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp @@ -16,33 +16,9 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" -#include "llvm/InitializePasses.h" #include "llvm/Transforms/Utils.h" using namespace llvm; -char UnifyFunctionExitNodesLegacyPass::ID = 0; - -UnifyFunctionExitNodesLegacyPass::UnifyFunctionExitNodesLegacyPass() - : FunctionPass(ID) { - initializeUnifyFunctionExitNodesLegacyPassPass( - *PassRegistry::getPassRegistry()); -} - -INITIALIZE_PASS(UnifyFunctionExitNodesLegacyPass, "mergereturn", - "Unify function exit nodes", false, false) - -Pass *llvm::createUnifyFunctionExitNodesPass() { - return new UnifyFunctionExitNodesLegacyPass(); -} - -void UnifyFunctionExitNodesLegacyPass::getAnalysisUsage( - AnalysisUsage &AU) const { - // We preserve the non-critical-edgeness property - AU.addPreservedID(BreakCriticalEdgesID); - // This is a cluster of orthogonal Transforms - AU.addPreservedID(LowerSwitchID); -} - namespace { bool unifyUnreachableBlocks(Function &F) { @@ -110,16 +86,6 @@ bool unifyReturnBlocks(Function &F) { } } // namespace -// Unify all exit nodes of the CFG by creating a new BasicBlock, and converting -// all returns to unconditional branches to this new basic block. Also, unify -// all unreachable blocks. 
-bool UnifyFunctionExitNodesLegacyPass::runOnFunction(Function &F) { - bool Changed = false; - Changed |= unifyUnreachableBlocks(F); - Changed |= unifyReturnBlocks(F); - return Changed; -} - PreservedAnalyses UnifyFunctionExitNodesPass::run(Function &F, FunctionAnalysisManager &AM) { bool Changed = false; diff --git a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp index 8c781f59ff5a..2f37f7f972cb 100644 --- a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp +++ b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp @@ -44,10 +44,8 @@ struct UnifyLoopExitsLegacyPass : public FunctionPass { } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequiredID(LowerSwitchID); AU.addRequired<LoopInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreservedID(LowerSwitchID); AU.addPreserved<LoopInfoWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); } @@ -65,7 +63,6 @@ FunctionPass *llvm::createUnifyLoopExitsPass() { INITIALIZE_PASS_BEGIN(UnifyLoopExitsLegacyPass, "unify-loop-exits", "Fixup each natural loop to have a single exit block", false /* Only looks at CFG */, false /* Analysis Pass */) -INITIALIZE_PASS_DEPENDENCY(LowerSwitchLegacyPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(UnifyLoopExitsLegacyPass, "unify-loop-exits", @@ -234,6 +231,8 @@ bool UnifyLoopExitsLegacyPass::runOnFunction(Function &F) { auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + assert(hasOnlySimpleTerminator(F) && "Unsupported block terminator."); + return runImpl(LI, DT); } diff --git a/llvm/lib/Transforms/Utils/Utils.cpp b/llvm/lib/Transforms/Utils/Utils.cpp index 91c743f17764..51e1e824dd26 100644 --- a/llvm/lib/Transforms/Utils/Utils.cpp +++ b/llvm/lib/Transforms/Utils/Utils.cpp @@ -21,7 +21,6 @@ using namespace llvm; /// initializeTransformUtils - Initialize all 
passes in the TransformUtils /// library. void llvm::initializeTransformUtils(PassRegistry &Registry) { - initializeAssumeBuilderPassLegacyPassPass(Registry); initializeBreakCriticalEdgesPass(Registry); initializeCanonicalizeFreezeInLoopsPass(Registry); initializeLCSSAWrapperPassPass(Registry); @@ -30,9 +29,6 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) { initializeLowerInvokeLegacyPassPass(Registry); initializeLowerSwitchLegacyPassPass(Registry); initializePromoteLegacyPassPass(Registry); - initializeUnifyFunctionExitNodesLegacyPassPass(Registry); - initializeStripGCRelocatesLegacyPass(Registry); - initializePredicateInfoPrinterLegacyPassPass(Registry); initializeFixIrreduciblePass(Registry); initializeUnifyLoopExitsLegacyPassPass(Registry); } diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp index 3446e31cc2ef..71d0f09e4771 100644 --- a/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" @@ -145,6 +146,7 @@ public: Value *mapValue(const Value *V); void remapInstruction(Instruction *I); void remapFunction(Function &F); + void remapDPValue(DPValue &DPV); Constant *mapConstant(const Constant *C) { return cast_or_null<Constant>(mapValue(C)); @@ -535,6 +537,39 @@ Value *Mapper::mapValue(const Value *V) { return getVM()[V] = ConstantPointerNull::get(cast<PointerType>(NewTy)); } +void Mapper::remapDPValue(DPValue &V) { + // Remap variables and DILocations. + auto *MappedVar = mapMetadata(V.getVariable()); + auto *MappedDILoc = mapMetadata(V.getDebugLoc()); + V.setVariable(cast<DILocalVariable>(MappedVar)); + V.setDebugLoc(DebugLoc(cast<DILocation>(MappedDILoc))); + + // Find Value operands and remap those. 
+ SmallVector<Value *, 4> Vals, NewVals; + for (Value *Val : V.location_ops()) + Vals.push_back(Val); + for (Value *Val : Vals) + NewVals.push_back(mapValue(Val)); + + // If there are no changes to the Value operands, finished. + if (Vals == NewVals) + return; + + bool IgnoreMissingLocals = Flags & RF_IgnoreMissingLocals; + + // Otherwise, do some replacement. + if (!IgnoreMissingLocals && + llvm::any_of(NewVals, [&](Value *V) { return V == nullptr; })) { + V.setKillLocation(); + } else { + // Either we have all non-empty NewVals, or we're permitted to ignore + // missing locals. + for (unsigned int I = 0; I < Vals.size(); ++I) + if (NewVals[I]) + V.replaceVariableLocationOp(I, NewVals[I]); + } +} + Value *Mapper::mapBlockAddress(const BlockAddress &BA) { Function *F = cast<Function>(mapValue(BA.getFunction())); @@ -1179,6 +1214,17 @@ void ValueMapper::remapInstruction(Instruction &I) { FlushingMapper(pImpl)->remapInstruction(&I); } +void ValueMapper::remapDPValue(Module *M, DPValue &V) { + FlushingMapper(pImpl)->remapDPValue(V); +} + +void ValueMapper::remapDPValueRange( + Module *M, iterator_range<DPValue::self_iterator> Range) { + for (DPValue &DPV : Range) { + remapDPValue(M, DPV); + } +} + void ValueMapper::remapFunction(Function &F) { FlushingMapper(pImpl)->remapFunction(F); } diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 260d7889906b..c0dbd52acbab 100644 --- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -103,7 +103,6 @@ #include "llvm/Support/ModRef.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Vectorize.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -900,9 +899,9 @@ bool Vectorizer::vectorizeChain(Chain &C) { // Chain is in offset order, so C[0] is the instr with the lowest offset, // i.e. the root of the vector. 
- Value *Bitcast = Builder.CreateBitCast( - getLoadStorePointerOperand(C[0].Inst), VecTy->getPointerTo(AS)); - VecInst = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment); + VecInst = Builder.CreateAlignedLoad(VecTy, + getLoadStorePointerOperand(C[0].Inst), + Alignment); unsigned VecIdx = 0; for (const ChainElem &E : C) { @@ -976,8 +975,7 @@ bool Vectorizer::vectorizeChain(Chain &C) { // i.e. the root of the vector. VecInst = Builder.CreateAlignedStore( Vec, - Builder.CreateBitCast(getLoadStorePointerOperand(C[0].Inst), - VecTy->getPointerTo(AS)), + getLoadStorePointerOperand(C[0].Inst), Alignment); } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index f923f0be6621..37a356c43e29 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -289,7 +289,7 @@ void LoopVectorizeHints::getHintsFromMetadata() { } void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) { - if (!Name.startswith(Prefix())) + if (!Name.starts_with(Prefix())) return; Name = Name.substr(Prefix().size(), StringRef::npos); @@ -943,6 +943,11 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } } + // If we found a vectorized variant of a function, note that so LV can + // make better decisions about maximum VF. + if (CI && !VFDatabase::getMappings(*CI).empty()) + VecCallVariantsFound = true; + // Check that the instruction return type is vectorizable. // Also, we can't vectorize extractelement instructions. 
if ((!VectorType::isValidElementType(I.getType()) && @@ -1242,13 +1247,12 @@ bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const { bool LoopVectorizationLegality::blockCanBePredicated( BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs, - SmallPtrSetImpl<const Instruction *> &MaskedOp, - SmallPtrSetImpl<Instruction *> &ConditionalAssumes) const { + SmallPtrSetImpl<const Instruction *> &MaskedOp) const { for (Instruction &I : *BB) { // We can predicate blocks with calls to assume, as long as we drop them in // case we flatten the CFG via predication. if (match(&I, m_Intrinsic<Intrinsic::assume>())) { - ConditionalAssumes.insert(&I); + MaskedOp.insert(&I); continue; } @@ -1345,16 +1349,13 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { } // We must be able to predicate all blocks that need to be predicated. - if (blockNeedsPredication(BB)) { - if (!blockCanBePredicated(BB, SafePointers, MaskedOp, - ConditionalAssumes)) { - reportVectorizationFailure( - "Control flow cannot be substituted for a select", - "control flow cannot be substituted for a select", - "NoCFGForSelect", ORE, TheLoop, - BB->getTerminator()); - return false; - } + if (blockNeedsPredication(BB) && + !blockCanBePredicated(BB, SafePointers, MaskedOp)) { + reportVectorizationFailure( + "Control flow cannot be substituted for a select", + "control flow cannot be substituted for a select", "NoCFGForSelect", + ORE, TheLoop, BB->getTerminator()); + return false; } } @@ -1554,14 +1555,14 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() { // The list of pointers that we can safely read and write to remains empty. SmallPtrSet<Value *, 8> SafePointers; + // Collect masked ops in temporary set first to avoid partially populating + // MaskedOp if a block cannot be predicated. 
SmallPtrSet<const Instruction *, 8> TmpMaskedOp; - SmallPtrSet<Instruction *, 8> TmpConditionalAssumes; // Check and mark all blocks for predication, including those that ordinarily // do not need predication such as the header block. for (BasicBlock *BB : TheLoop->blocks()) { - if (!blockCanBePredicated(BB, SafePointers, TmpMaskedOp, - TmpConditionalAssumes)) { + if (!blockCanBePredicated(BB, SafePointers, TmpMaskedOp)) { LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking as requested.\n"); return false; } @@ -1570,9 +1571,6 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() { LLVM_DEBUG(dbgs() << "LV: can fold tail by masking.\n"); MaskedOp.insert(TmpMaskedOp.begin(), TmpMaskedOp.end()); - ConditionalAssumes.insert(TmpConditionalAssumes.begin(), - TmpConditionalAssumes.end()); - return true; } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 13357cb06c55..577ce8000de2 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -31,6 +31,7 @@ namespace llvm { class LoopInfo; +class DominatorTree; class LoopVectorizationLegality; class LoopVectorizationCostModel; class PredicatedScalarEvolution; @@ -45,13 +46,17 @@ class VPBuilder { VPBasicBlock *BB = nullptr; VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator(); + /// Insert \p VPI in BB at InsertPt if BB is set. 
+ VPInstruction *tryInsertInstruction(VPInstruction *VPI) { + if (BB) + BB->insert(VPI, InsertPt); + return VPI; + } + VPInstruction *createInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL, const Twine &Name = "") { - VPInstruction *Instr = new VPInstruction(Opcode, Operands, DL, Name); - if (BB) - BB->insert(Instr, InsertPt); - return Instr; + return tryInsertInstruction(new VPInstruction(Opcode, Operands, DL, Name)); } VPInstruction *createInstruction(unsigned Opcode, @@ -62,6 +67,7 @@ class VPBuilder { public: VPBuilder() = default; + VPBuilder(VPBasicBlock *InsertBB) { setInsertPoint(InsertBB); } /// Clear the insertion point: created instructions will not be inserted into /// a block. @@ -116,10 +122,11 @@ public: InsertPt = IP; } - /// Insert and return the specified instruction. - VPInstruction *insert(VPInstruction *I) const { - BB->insert(I, InsertPt); - return I; + /// This specifies that created instructions should be inserted at the + /// specified point. + void setInsertPoint(VPRecipeBase *IP) { + BB = IP->getParent(); + InsertPt = IP->getIterator(); } /// Create an N-ary operation with \p Opcode, \p Operands and set \p Inst as @@ -138,6 +145,13 @@ public: return createInstruction(Opcode, Operands, DL, Name); } + VPInstruction *createOverflowingOp(unsigned Opcode, + std::initializer_list<VPValue *> Operands, + VPRecipeWithIRFlags::WrapFlagsTy WrapFlags, + DebugLoc DL, const Twine &Name = "") { + return tryInsertInstruction( + new VPInstruction(Opcode, Operands, WrapFlags, DL, Name)); + } VPValue *createNot(VPValue *Operand, DebugLoc DL, const Twine &Name = "") { return createInstruction(VPInstruction::Not, {Operand}, DL, Name); } @@ -158,6 +172,12 @@ public: Name); } + /// Create a new ICmp VPInstruction with predicate \p Pred and operands \p A + /// and \p B. + /// TODO: add createFCmp when needed. 
+ VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B, + DebugLoc DL = {}, const Twine &Name = ""); + //===--------------------------------------------------------------------===// // RAII helpers. //===--------------------------------------------------------------------===// @@ -268,6 +288,9 @@ class LoopVectorizationPlanner { /// Loop Info analysis. LoopInfo *LI; + /// The dominator tree. + DominatorTree *DT; + /// Target Library Info. const TargetLibraryInfo *TLI; @@ -298,16 +321,14 @@ class LoopVectorizationPlanner { VPBuilder Builder; public: - LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI, - const TargetTransformInfo &TTI, - LoopVectorizationLegality *Legal, - LoopVectorizationCostModel &CM, - InterleavedAccessInfo &IAI, - PredicatedScalarEvolution &PSE, - const LoopVectorizeHints &Hints, - OptimizationRemarkEmitter *ORE) - : OrigLoop(L), LI(LI), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM), IAI(IAI), - PSE(PSE), Hints(Hints), ORE(ORE) {} + LoopVectorizationPlanner( + Loop *L, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, + const TargetTransformInfo &TTI, LoopVectorizationLegality *Legal, + LoopVectorizationCostModel &CM, InterleavedAccessInfo &IAI, + PredicatedScalarEvolution &PSE, const LoopVectorizeHints &Hints, + OptimizationRemarkEmitter *ORE) + : OrigLoop(L), LI(LI), DT(DT), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM), + IAI(IAI), PSE(PSE), Hints(Hints), ORE(ORE) {} /// Plan how to best vectorize, return the best VF and its cost, or /// std::nullopt if vectorization and interleaving should be avoided up front. 
@@ -333,7 +354,7 @@ public: executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan, InnerLoopVectorizer &LB, DominatorTree *DT, bool IsEpilogueVectorization, - DenseMap<const SCEV *, Value *> *ExpandedSCEVs = nullptr); + const DenseMap<const SCEV *, Value *> *ExpandedSCEVs = nullptr); #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printPlans(raw_ostream &O); @@ -377,8 +398,7 @@ private: /// returned VPlan is valid for. If no VPlan can be built for the input range, /// set the largest included VF to the maximum VF for which no plan could be /// built. - std::optional<VPlanPtr> tryToBuildVPlanWithVPRecipes( - VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions); + VPlanPtr tryToBuildVPlanWithVPRecipes(VFRange &Range); /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive, /// according to the information gathered by Legal when it checked if it is diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index d7e40e8ef978..f82e161fb846 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -27,7 +27,7 @@ // // There is a development effort going on to migrate loop vectorizer to the // VPlan infrastructure and to introduce outer loop vectorization support (see -// docs/Proposal/VectorizationPlan.rst and +// docs/VectorizationPlan.rst and // http://lists.llvm.org/pipermail/llvm-dev/2017-December/119523.html). 
For this // purpose, we temporarily introduced the VPlan-native vectorization path: an // alternative vectorization path that is natively implemented on top of the @@ -57,6 +57,7 @@ #include "LoopVectorizationPlanner.h" #include "VPRecipeBuilder.h" #include "VPlan.h" +#include "VPlanAnalysis.h" #include "VPlanHCFGBuilder.h" #include "VPlanTransforms.h" #include "llvm/ADT/APInt.h" @@ -111,10 +112,12 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -390,6 +393,21 @@ static cl::opt<cl::boolOrDefault> ForceSafeDivisor( cl::desc( "Override cost based safe divisor widening for div/rem instructions")); +static cl::opt<bool> UseWiderVFIfCallVariantsPresent( + "vectorizer-maximize-bandwidth-for-vector-calls", cl::init(true), + cl::Hidden, + cl::desc("Try wider VFs if they enable the use of vector variants")); + +// Likelyhood of bypassing the vectorized loop because assumptions about SCEV +// variables not overflowing do not hold. See `emitSCEVChecks`. +static constexpr uint32_t SCEVCheckBypassWeights[] = {1, 127}; +// Likelyhood of bypassing the vectorized loop because pointers overlap. See +// `emitMemRuntimeChecks`. +static constexpr uint32_t MemCheckBypassWeights[] = {1, 127}; +// Likelyhood of bypassing the vectorized loop because there are zero trips left +// after prolog. See `emitIterationCountCheck`. +static constexpr uint32_t MinItersBypassWeights[] = {1, 127}; + /// A helper function that returns true if the given type is irregular. The /// type is irregular if its allocated size doesn't equal the store size of an /// element of the corresponding vector type. 
@@ -408,13 +426,6 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL) { /// we always assume predicated blocks have a 50% chance of executing. static unsigned getReciprocalPredBlockProb() { return 2; } -/// A helper function that returns an integer or floating-point constant with -/// value C. -static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) { - return Ty->isIntegerTy() ? ConstantInt::getSigned(Ty, C) - : ConstantFP::get(Ty, C); -} - /// Returns "best known" trip count for the specified loop \p L as defined by /// the following procedure: /// 1) Returns exact trip count if it is known. @@ -556,10 +567,6 @@ public: const VPIteration &Instance, VPTransformState &State); - /// Construct the vector value of a scalarized value \p V one lane at a time. - void packScalarIntoVectorValue(VPValue *Def, const VPIteration &Instance, - VPTransformState &State); - /// Try to vectorize interleaved access group \p Group with the base address /// given in \p Addr, optionally masking the vector operations if \p /// BlockInMask is non-null. Use \p State to translate given VPValues to IR @@ -634,10 +641,6 @@ protected: /// the block that was created for it. void sinkScalarOperands(Instruction *PredInst); - /// Shrinks vector element sizes to the smallest bitwidth they can be legally - /// represented as. - void truncateToMinimalBitwidths(VPTransformState &State); - /// Returns (and creates if needed) the trip count of the widened loop. Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock); @@ -943,21 +946,21 @@ protected: /// Look for a meaningful debug location on the instruction or it's /// operands. 
-static Instruction *getDebugLocFromInstOrOperands(Instruction *I) { +static DebugLoc getDebugLocFromInstOrOperands(Instruction *I) { if (!I) - return I; + return DebugLoc(); DebugLoc Empty; if (I->getDebugLoc() != Empty) - return I; + return I->getDebugLoc(); for (Use &Op : I->operands()) { if (Instruction *OpInst = dyn_cast<Instruction>(Op)) if (OpInst->getDebugLoc() != Empty) - return OpInst; + return OpInst->getDebugLoc(); } - return I; + return I->getDebugLoc(); } /// Write a \p DebugMsg about vectorization to the debug output stream. If \p I @@ -1021,14 +1024,6 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE, return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop); } -static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy, - ElementCount VF) { - assert(FTy->isFloatingPointTy() && "Expected floating point type!"); - Type *IntTy = IntegerType::get(FTy->getContext(), FTy->getScalarSizeInBits()); - Value *RuntimeVF = getRuntimeVF(B, IntTy, VF); - return B.CreateUIToFP(RuntimeVF, FTy); -} - void reportVectorizationFailure(const StringRef DebugMsg, const StringRef OREMsg, const StringRef ORETag, OptimizationRemarkEmitter *ORE, Loop *TheLoop, @@ -1050,6 +1045,23 @@ void reportVectorizationInfo(const StringRef Msg, const StringRef ORETag, << Msg); } +/// Report successful vectorization of the loop. In case an outer loop is +/// vectorized, prepend "outer" to the vectorization remark. +static void reportVectorization(OptimizationRemarkEmitter *ORE, Loop *TheLoop, + VectorizationFactor VF, unsigned IC) { + LLVM_DEBUG(debugVectorizationMessage( + "Vectorizing: ", TheLoop->isInnermost() ? "innermost loop" : "outer loop", + nullptr)); + StringRef LoopType = TheLoop->isInnermost() ? 
"" : "outer "; + ORE->emit([&]() { + return OptimizationRemark(LV_NAME, "Vectorized", TheLoop->getStartLoc(), + TheLoop->getHeader()) + << "vectorized " << LoopType << "loop (vectorization width: " + << ore::NV("VectorizationFactor", VF.Width) + << ", interleaved count: " << ore::NV("InterleaveCount", IC) << ")"; + }); +} + } // end namespace llvm #ifndef NDEBUG @@ -1104,7 +1116,8 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes( if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) { RecWithFlags->dropPoisonGeneratingFlags(); } else { - Instruction *Instr = CurRec->getUnderlyingInstr(); + Instruction *Instr = dyn_cast_or_null<Instruction>( + CurRec->getVPSingleValue()->getUnderlyingValue()); (void)Instr; assert((!Instr || !Instr->hasPoisonGeneratingFlags()) && "found instruction with poison generating flags not covered by " @@ -1247,6 +1260,13 @@ public: /// avoid redundant calculations. void setCostBasedWideningDecision(ElementCount VF); + /// A call may be vectorized in different ways depending on whether we have + /// vectorized variants available and whether the target supports masking. + /// This function analyzes all calls in the function at the supplied VF, + /// makes a decision based on the costs of available options, and stores that + /// decision in a map for use in planning and plan execution. + void setVectorizedCallDecision(ElementCount VF); + /// A struct that represents some properties of the register usage /// of a loop. struct RegisterUsage { @@ -1270,7 +1290,7 @@ public: void collectElementTypesForWidening(); /// Split reductions into those that happen in the loop, and those that happen - /// outside. In loop reductions are collected into InLoopReductionChains. + /// outside. In loop reductions are collected into InLoopReductions. 
void collectInLoopReductions(); /// Returns true if we should use strict in-order reductions for the given @@ -1358,7 +1378,9 @@ public: CM_Widen_Reverse, // For consecutive accesses with stride -1. CM_Interleave, CM_GatherScatter, - CM_Scalarize + CM_Scalarize, + CM_VectorCall, + CM_IntrinsicCall }; /// Save vectorization decision \p W and \p Cost taken by the cost model for @@ -1414,6 +1436,29 @@ public: return WideningDecisions[InstOnVF].second; } + struct CallWideningDecision { + InstWidening Kind; + Function *Variant; + Intrinsic::ID IID; + std::optional<unsigned> MaskPos; + InstructionCost Cost; + }; + + void setCallWideningDecision(CallInst *CI, ElementCount VF, InstWidening Kind, + Function *Variant, Intrinsic::ID IID, + std::optional<unsigned> MaskPos, + InstructionCost Cost) { + assert(!VF.isScalar() && "Expected vector VF"); + CallWideningDecisions[std::make_pair(CI, VF)] = {Kind, Variant, IID, + MaskPos, Cost}; + } + + CallWideningDecision getCallWideningDecision(CallInst *CI, + ElementCount VF) const { + assert(!VF.isScalar() && "Expected vector VF"); + return CallWideningDecisions.at(std::make_pair(CI, VF)); + } + /// Return True if instruction \p I is an optimizable truncate whose operand /// is an induction variable. Such a truncate will be removed by adding a new /// induction variable with the destination type. @@ -1447,11 +1492,15 @@ public: /// Collect Uniform and Scalar values for the given \p VF. /// The sets depend on CM decision for Load/Store instructions /// that may be vectorized as interleave, gather-scatter or scalarized. + /// Also make a decision on what to do about call instructions in the loop + /// at that VF -- scalarize, call a known vector routine, or call a + /// vector intrinsic. void collectUniformsAndScalars(ElementCount VF) { // Do the analysis once. 
if (VF.isScalar() || Uniforms.contains(VF)) return; setCostBasedWideningDecision(VF); + setVectorizedCallDecision(VF); collectLoopUniforms(VF); collectLoopScalars(VF); } @@ -1606,20 +1655,9 @@ public: return foldTailByMasking() || Legal->blockNeedsPredication(BB); } - /// A SmallMapVector to store the InLoop reduction op chains, mapping phi - /// nodes to the chain of instructions representing the reductions. Uses a - /// MapVector to ensure deterministic iteration order. - using ReductionChainMap = - SmallMapVector<PHINode *, SmallVector<Instruction *, 4>, 4>; - - /// Return the chain of instructions representing an inloop reduction. - const ReductionChainMap &getInLoopReductionChains() const { - return InLoopReductionChains; - } - /// Returns true if the Phi is part of an inloop reduction. bool isInLoopReduction(PHINode *Phi) const { - return InLoopReductionChains.count(Phi); + return InLoopReductions.contains(Phi); } /// Estimate cost of an intrinsic call instruction CI if it were vectorized @@ -1629,16 +1667,13 @@ public: /// Estimate cost of a call instruction CI if it were vectorized with factor /// VF. Return the cost of the instruction, including scalarization overhead - /// if it's needed. The flag NeedToScalarize shows if the call needs to be - /// scalarized - - /// i.e. either vector version isn't available, or is too expensive. - InstructionCost getVectorCallCost(CallInst *CI, ElementCount VF, - Function **Variant, - bool *NeedsMask = nullptr) const; + /// if it's needed. + InstructionCost getVectorCallCost(CallInst *CI, ElementCount VF) const; /// Invalidates decisions already taken by the cost model. void invalidateCostModelingDecisions() { WideningDecisions.clear(); + CallWideningDecisions.clear(); Uniforms.clear(); Scalars.clear(); } @@ -1675,14 +1710,14 @@ private: /// elements is a power-of-2 larger than zero. If scalable vectorization is /// disabled or unsupported, then the scalable part will be equal to /// ElementCount::getScalable(0). 
- FixedScalableVFPair computeFeasibleMaxVF(unsigned ConstTripCount, + FixedScalableVFPair computeFeasibleMaxVF(unsigned MaxTripCount, ElementCount UserVF, bool FoldTailByMasking); /// \return the maximized element count based on the targets vector /// registers and the loop trip-count, but limited to a maximum safe VF. /// This is a helper function of computeFeasibleMaxVF. - ElementCount getMaximizedVFForTarget(unsigned ConstTripCount, + ElementCount getMaximizedVFForTarget(unsigned MaxTripCount, unsigned SmallestType, unsigned WidestType, ElementCount MaxSafeVF, @@ -1705,7 +1740,7 @@ private: /// part of that pattern. std::optional<InstructionCost> getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy, - TTI::TargetCostKind CostKind); + TTI::TargetCostKind CostKind) const; /// Calculate vectorization cost of memory instruction \p I. InstructionCost getMemoryInstructionCost(Instruction *I, ElementCount VF); @@ -1783,15 +1818,12 @@ private: /// scalarized. DenseMap<ElementCount, SmallPtrSet<Instruction *, 4>> ForcedScalars; - /// PHINodes of the reductions that should be expanded in-loop along with - /// their associated chains of reduction operations, in program order from top - /// (PHI) to bottom - ReductionChainMap InLoopReductionChains; + /// PHINodes of the reductions that should be expanded in-loop. + SmallPtrSet<PHINode *, 4> InLoopReductions; /// A Map of inloop reduction operations and their immediate chain operand. /// FIXME: This can be removed once reductions can be costed correctly in - /// vplan. This was added to allow quick lookup to the inloop operations, - /// without having to loop through InLoopReductionChains. + /// VPlan. This was added to allow quick lookup of the inloop operations. 
DenseMap<Instruction *, Instruction *> InLoopReductionImmediateChains; /// Returns the expected difference in cost from scalarizing the expression @@ -1830,6 +1862,11 @@ private: DecisionList WideningDecisions; + using CallDecisionList = + DenseMap<std::pair<CallInst *, ElementCount>, CallWideningDecision>; + + CallDecisionList CallWideningDecisions; + /// Returns true if \p V is expected to be vectorized and it needs to be /// extracted. bool needsExtract(Value *V, ElementCount VF) const { @@ -1933,12 +1970,14 @@ class GeneratedRTChecks { SCEVExpander MemCheckExp; bool CostTooHigh = false; + const bool AddBranchWeights; public: GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI, - TargetTransformInfo *TTI, const DataLayout &DL) + TargetTransformInfo *TTI, const DataLayout &DL, + bool AddBranchWeights) : DT(DT), LI(LI), TTI(TTI), SCEVExp(SE, DL, "scev.check"), - MemCheckExp(SE, DL, "scev.check") {} + MemCheckExp(SE, DL, "scev.check"), AddBranchWeights(AddBranchWeights) {} /// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can /// accurately estimate the cost of the runtime checks. 
The blocks are @@ -1990,9 +2029,9 @@ public: }, IC); } else { - MemRuntimeCheckCond = - addRuntimeChecks(MemCheckBlock->getTerminator(), L, - RtPtrChecking.getChecks(), MemCheckExp); + MemRuntimeCheckCond = addRuntimeChecks( + MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(), + MemCheckExp, VectorizerParams::HoistRuntimeChecks); } assert(MemRuntimeCheckCond && "no RT checks generated although RtPtrChecking " @@ -2131,8 +2170,10 @@ public: DT->addNewBlock(SCEVCheckBlock, Pred); DT->changeImmediateDominator(LoopVectorPreHeader, SCEVCheckBlock); - ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), - BranchInst::Create(Bypass, LoopVectorPreHeader, Cond)); + BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, Cond); + if (AddBranchWeights) + setBranchWeights(BI, SCEVCheckBypassWeights); + ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), &BI); return SCEVCheckBlock; } @@ -2156,9 +2197,12 @@ public: if (auto *PL = LI->getLoopFor(LoopVectorPreHeader)) PL->addBasicBlockToLoop(MemCheckBlock, *LI); - ReplaceInstWithInst( - MemCheckBlock->getTerminator(), - BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond)); + BranchInst &BI = + *BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond); + if (AddBranchWeights) { + setBranchWeights(BI, MemCheckBypassWeights); + } + ReplaceInstWithInst(MemCheckBlock->getTerminator(), &BI); MemCheckBlock->getTerminator()->setDebugLoc( Pred->getTerminator()->getDebugLoc()); @@ -2252,157 +2296,17 @@ static void collectSupportedLoops(Loop &L, LoopInfo *LI, // LoopVectorizationCostModel and LoopVectorizationPlanner. //===----------------------------------------------------------------------===// -/// This function adds -/// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...) -/// to each vector element of Val. The sequence starts at StartIndex. -/// \p Opcode is relevant for FP induction variable. 
-static Value *getStepVector(Value *Val, Value *StartIdx, Value *Step, - Instruction::BinaryOps BinOp, ElementCount VF, - IRBuilderBase &Builder) { - assert(VF.isVector() && "only vector VFs are supported"); - - // Create and check the types. - auto *ValVTy = cast<VectorType>(Val->getType()); - ElementCount VLen = ValVTy->getElementCount(); - - Type *STy = Val->getType()->getScalarType(); - assert((STy->isIntegerTy() || STy->isFloatingPointTy()) && - "Induction Step must be an integer or FP"); - assert(Step->getType() == STy && "Step has wrong type"); - - SmallVector<Constant *, 8> Indices; - - // Create a vector of consecutive numbers from zero to VF. - VectorType *InitVecValVTy = ValVTy; - if (STy->isFloatingPointTy()) { - Type *InitVecValSTy = - IntegerType::get(STy->getContext(), STy->getScalarSizeInBits()); - InitVecValVTy = VectorType::get(InitVecValSTy, VLen); - } - Value *InitVec = Builder.CreateStepVector(InitVecValVTy); - - // Splat the StartIdx - Value *StartIdxSplat = Builder.CreateVectorSplat(VLen, StartIdx); - - if (STy->isIntegerTy()) { - InitVec = Builder.CreateAdd(InitVec, StartIdxSplat); - Step = Builder.CreateVectorSplat(VLen, Step); - assert(Step->getType() == Val->getType() && "Invalid step vec"); - // FIXME: The newly created binary instructions should contain nsw/nuw - // flags, which can be found from the original scalar operations. - Step = Builder.CreateMul(InitVec, Step); - return Builder.CreateAdd(Val, Step, "induction"); - } - - // Floating point induction. - assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) && - "Binary Opcode should be specified for FP induction"); - InitVec = Builder.CreateUIToFP(InitVec, ValVTy); - InitVec = Builder.CreateFAdd(InitVec, StartIdxSplat); - - Step = Builder.CreateVectorSplat(VLen, Step); - Value *MulOp = Builder.CreateFMul(InitVec, Step); - return Builder.CreateBinOp(BinOp, Val, MulOp, "induction"); -} - -/// Compute scalar induction steps. 
\p ScalarIV is the scalar induction -/// variable on which to base the steps, \p Step is the size of the step. -static void buildScalarSteps(Value *ScalarIV, Value *Step, - const InductionDescriptor &ID, VPValue *Def, - VPTransformState &State) { - IRBuilderBase &Builder = State.Builder; - - // Ensure step has the same type as that of scalar IV. - Type *ScalarIVTy = ScalarIV->getType()->getScalarType(); - if (ScalarIVTy != Step->getType()) { - // TODO: Also use VPDerivedIVRecipe when only the step needs truncating, to - // avoid separate truncate here. - assert(Step->getType()->isIntegerTy() && - "Truncation requires an integer step"); - Step = State.Builder.CreateTrunc(Step, ScalarIVTy); - } - - // We build scalar steps for both integer and floating-point induction - // variables. Here, we determine the kind of arithmetic we will perform. - Instruction::BinaryOps AddOp; - Instruction::BinaryOps MulOp; - if (ScalarIVTy->isIntegerTy()) { - AddOp = Instruction::Add; - MulOp = Instruction::Mul; - } else { - AddOp = ID.getInductionOpcode(); - MulOp = Instruction::FMul; - } - - // Determine the number of scalars we need to generate for each unroll - // iteration. - bool FirstLaneOnly = vputils::onlyFirstLaneUsed(Def); - // Compute the scalar steps and save the results in State. - Type *IntStepTy = IntegerType::get(ScalarIVTy->getContext(), - ScalarIVTy->getScalarSizeInBits()); - Type *VecIVTy = nullptr; - Value *UnitStepVec = nullptr, *SplatStep = nullptr, *SplatIV = nullptr; - if (!FirstLaneOnly && State.VF.isScalable()) { - VecIVTy = VectorType::get(ScalarIVTy, State.VF); - UnitStepVec = - Builder.CreateStepVector(VectorType::get(IntStepTy, State.VF)); - SplatStep = Builder.CreateVectorSplat(State.VF, Step); - SplatIV = Builder.CreateVectorSplat(State.VF, ScalarIV); - } - - unsigned StartPart = 0; - unsigned EndPart = State.UF; - unsigned StartLane = 0; - unsigned EndLane = FirstLaneOnly ? 
1 : State.VF.getKnownMinValue(); - if (State.Instance) { - StartPart = State.Instance->Part; - EndPart = StartPart + 1; - StartLane = State.Instance->Lane.getKnownLane(); - EndLane = StartLane + 1; - } - for (unsigned Part = StartPart; Part < EndPart; ++Part) { - Value *StartIdx0 = createStepForVF(Builder, IntStepTy, State.VF, Part); - - if (!FirstLaneOnly && State.VF.isScalable()) { - auto *SplatStartIdx = Builder.CreateVectorSplat(State.VF, StartIdx0); - auto *InitVec = Builder.CreateAdd(SplatStartIdx, UnitStepVec); - if (ScalarIVTy->isFloatingPointTy()) - InitVec = Builder.CreateSIToFP(InitVec, VecIVTy); - auto *Mul = Builder.CreateBinOp(MulOp, InitVec, SplatStep); - auto *Add = Builder.CreateBinOp(AddOp, SplatIV, Mul); - State.set(Def, Add, Part); - // It's useful to record the lane values too for the known minimum number - // of elements so we do those below. This improves the code quality when - // trying to extract the first element, for example. - } - - if (ScalarIVTy->isFloatingPointTy()) - StartIdx0 = Builder.CreateSIToFP(StartIdx0, ScalarIVTy); - - for (unsigned Lane = StartLane; Lane < EndLane; ++Lane) { - Value *StartIdx = Builder.CreateBinOp( - AddOp, StartIdx0, getSignedIntOrFpConstant(ScalarIVTy, Lane)); - // The step returned by `createStepForVF` is a runtime-evaluated value - // when VF is scalable. Otherwise, it should be folded into a Constant. - assert((State.VF.isScalable() || isa<Constant>(StartIdx)) && - "Expected StartIdx to be folded to a constant when VF is not " - "scalable"); - auto *Mul = Builder.CreateBinOp(MulOp, StartIdx, Step); - auto *Add = Builder.CreateBinOp(AddOp, ScalarIV, Mul); - State.set(Def, Add, VPIteration(Part, Lane)); - } - } -} - /// Compute the transformed value of Index at offset StartValue using step /// StepValue. /// For integer induction, returns StartValue + Index * StepValue. /// For pointer induction, returns StartValue[Index * StepValue]. 
/// FIXME: The newly created binary instructions should contain nsw/nuw /// flags, which can be found from the original scalar operations. -static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, - Value *StartValue, Value *Step, - const InductionDescriptor &ID) { +static Value * +emitTransformedIndex(IRBuilderBase &B, Value *Index, Value *StartValue, + Value *Step, + InductionDescriptor::InductionKind InductionKind, + const BinaryOperator *InductionBinOp) { Type *StepTy = Step->getType(); Value *CastedIndex = StepTy->isIntegerTy() ? B.CreateSExtOrTrunc(Index, StepTy) @@ -2446,7 +2350,7 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, return B.CreateMul(X, Y); }; - switch (ID.getKind()) { + switch (InductionKind) { case InductionDescriptor::IK_IntInduction: { assert(!isa<VectorType>(Index->getType()) && "Vector indices not supported for integer inductions yet"); @@ -2464,7 +2368,6 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, assert(!isa<VectorType>(Index->getType()) && "Vector indices not supported for FP inductions yet"); assert(Step->getType()->isFloatingPointTy() && "Expected FP Step value"); - auto InductionBinOp = ID.getInductionBinOp(); assert(InductionBinOp && (InductionBinOp->getOpcode() == Instruction::FAdd || InductionBinOp->getOpcode() == Instruction::FSub) && @@ -2524,17 +2427,6 @@ static bool isIndvarOverflowCheckKnownFalse( return false; } -void InnerLoopVectorizer::packScalarIntoVectorValue(VPValue *Def, - const VPIteration &Instance, - VPTransformState &State) { - Value *ScalarInst = State.get(Def, Instance); - Value *VectorValue = State.get(Def, Instance.Part); - VectorValue = Builder.CreateInsertElement( - VectorValue, ScalarInst, - Instance.Lane.getAsRuntimeExpr(State.Builder, VF)); - State.set(Def, VectorValue, Instance.Part); -} - // Return whether we allow using masked interleave-groups (for dealing with // strided loads/stores that reside in predicated blocks, or for dealing // with 
gaps). @@ -2612,7 +2504,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( for (unsigned Part = 0; Part < UF; Part++) { Value *AddrPart = State.get(Addr, VPIteration(Part, 0)); - State.setDebugLocFromInst(AddrPart); + if (auto *I = dyn_cast<Instruction>(AddrPart)) + State.setDebugLocFrom(I->getDebugLoc()); // Notice current instruction could be any index. Need to adjust the address // to the member of index 0. @@ -2630,14 +2523,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( if (auto *gep = dyn_cast<GetElementPtrInst>(AddrPart->stripPointerCasts())) InBounds = gep->isInBounds(); AddrPart = Builder.CreateGEP(ScalarTy, AddrPart, Idx, "", InBounds); - - // Cast to the vector pointer type. - unsigned AddressSpace = AddrPart->getType()->getPointerAddressSpace(); - Type *PtrTy = VecTy->getPointerTo(AddressSpace); - AddrParts.push_back(Builder.CreateBitCast(AddrPart, PtrTy)); + AddrParts.push_back(AddrPart); } - State.setDebugLocFromInst(Instr); + State.setDebugLocFrom(Instr->getDebugLoc()); Value *PoisonVec = PoisonValue::get(VecTy); auto CreateGroupMask = [this, &BlockInMask, &State, &InterleaveFactor]( @@ -2835,13 +2724,20 @@ void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr, bool IsVoidRetTy = Instr->getType()->isVoidTy(); Instruction *Cloned = Instr->clone(); - if (!IsVoidRetTy) + if (!IsVoidRetTy) { Cloned->setName(Instr->getName() + ".cloned"); +#if !defined(NDEBUG) + // Verify that VPlan type inference results agree with the type of the + // generated values. + assert(State.TypeAnalysis.inferScalarType(RepRecipe) == Cloned->getType() && + "inferred type and type from generated instructions do not match"); +#endif + } RepRecipe->setFlags(Cloned); - if (Instr->getDebugLoc()) - State.setDebugLocFromInst(Instr); + if (auto DL = Instr->getDebugLoc()) + State.setDebugLocFrom(DL); // Replace the operands of the cloned instructions with their scalar // equivalents in the new loop. 
@@ -3019,9 +2915,11 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { // dominator of the exit blocks. DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock); - ReplaceInstWithInst( - TCCheckBlock->getTerminator(), - BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters)); + BranchInst &BI = + *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters); + if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) + setBranchWeights(BI, MinItersBypassWeights); + ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI); LoopBypassBlocks.push_back(TCCheckBlock); } @@ -3151,15 +3049,17 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue( if (II.getInductionBinOp() && isa<FPMathOperator>(II.getInductionBinOp())) B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags()); - EndValue = - emitTransformedIndex(B, VectorTripCount, II.getStartValue(), Step, II); + EndValue = emitTransformedIndex(B, VectorTripCount, II.getStartValue(), + Step, II.getKind(), II.getInductionBinOp()); EndValue->setName("ind.end"); // Compute the end value for the additional bypass (if applicable). if (AdditionalBypass.first) { - B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt())); - EndValueFromAdditionalBypass = emitTransformedIndex( - B, AdditionalBypass.second, II.getStartValue(), Step, II); + B.SetInsertPoint(AdditionalBypass.first, + AdditionalBypass.first->getFirstInsertionPt()); + EndValueFromAdditionalBypass = + emitTransformedIndex(B, AdditionalBypass.second, II.getStartValue(), + Step, II.getKind(), II.getInductionBinOp()); EndValueFromAdditionalBypass->setName("ind.end"); } } @@ -3240,16 +3140,25 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() { // 3) Otherwise, construct a runtime check. 
if (!Cost->requiresScalarEpilogue(VF.isVector()) && !Cost->foldTailByMasking()) { - Instruction *CmpN = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, - Count, VectorTripCount, "cmp.n", - LoopMiddleBlock->getTerminator()); - // Here we use the same DebugLoc as the scalar loop latch terminator instead // of the corresponding compare because they may have ended up with // different line numbers and we want to avoid awkward line stepping while // debugging. Eg. if the compare has got a line number inside the loop. - CmpN->setDebugLoc(ScalarLatchTerm->getDebugLoc()); - cast<BranchInst>(LoopMiddleBlock->getTerminator())->setCondition(CmpN); + // TODO: At the moment, CreateICmpEQ will simplify conditions with constant + // operands. Perform simplification directly on VPlan once the branch is + // modeled there. + IRBuilder<> B(LoopMiddleBlock->getTerminator()); + B.SetCurrentDebugLocation(ScalarLatchTerm->getDebugLoc()); + Value *CmpN = B.CreateICmpEQ(Count, VectorTripCount, "cmp.n"); + BranchInst &BI = *cast<BranchInst>(LoopMiddleBlock->getTerminator()); + BI.setCondition(CmpN); + if (hasBranchWeightMD(*ScalarLatchTerm)) { + // Assume that `Count % VectorTripCount` is equally distributed. + unsigned TripCount = UF * VF.getKnownMinValue(); + assert(TripCount > 0 && "trip count should not be zero"); + const uint32_t Weights[] = {1, TripCount - 1}; + setBranchWeights(BI, Weights); + } } #ifdef EXPENSIVE_CHECKS @@ -3373,7 +3282,8 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, Value *Step = StepVPV->isLiveIn() ? 
StepVPV->getLiveInIRValue() : State.get(StepVPV, {0, 0}); Value *Escape = - emitTransformedIndex(B, CountMinusOne, II.getStartValue(), Step, II); + emitTransformedIndex(B, CountMinusOne, II.getStartValue(), Step, + II.getKind(), II.getInductionBinOp()); Escape->setName("ind.escape"); MissingVals[UI] = Escape; } @@ -3445,76 +3355,33 @@ static void cse(BasicBlock *BB) { } } -InstructionCost LoopVectorizationCostModel::getVectorCallCost( - CallInst *CI, ElementCount VF, Function **Variant, bool *NeedsMask) const { - Function *F = CI->getCalledFunction(); - Type *ScalarRetTy = CI->getType(); - SmallVector<Type *, 4> Tys, ScalarTys; - bool MaskRequired = Legal->isMaskRequired(CI); - for (auto &ArgOp : CI->args()) - ScalarTys.push_back(ArgOp->getType()); +InstructionCost +LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, + ElementCount VF) const { + // We only need to calculate a cost if the VF is scalar; for actual vectors + // we should already have a pre-calculated cost at each VF. + if (!VF.isScalar()) + return CallWideningDecisions.at(std::make_pair(CI, VF)).Cost; - // Estimate cost of scalarized vector call. The source operands are assumed - // to be vectors, so we need to extract individual elements from there, - // execute VF scalar calls, and then gather the result into the vector return - // value. TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - InstructionCost ScalarCallCost = - TTI.getCallInstrCost(F, ScalarRetTy, ScalarTys, CostKind); - if (VF.isScalar()) - return ScalarCallCost; - - // Compute corresponding vector type for return value and arguments. - Type *RetTy = ToVectorTy(ScalarRetTy, VF); - for (Type *ScalarTy : ScalarTys) - Tys.push_back(ToVectorTy(ScalarTy, VF)); - - // Compute costs of unpacking argument values for the scalar calls and - // packing the return values to a vector. 
- InstructionCost ScalarizationCost = - getScalarizationOverhead(CI, VF, CostKind); - - InstructionCost Cost = - ScalarCallCost * VF.getKnownMinValue() + ScalarizationCost; + Type *RetTy = CI->getType(); + if (RecurrenceDescriptor::isFMulAddIntrinsic(CI)) + if (auto RedCost = getReductionPatternCost(CI, VF, RetTy, CostKind)) + return *RedCost; - // If we can't emit a vector call for this function, then the currently found - // cost is the cost we need to return. - InstructionCost MaskCost = 0; - VFShape Shape = VFShape::get(*CI, VF, MaskRequired); - if (NeedsMask) - *NeedsMask = MaskRequired; - Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); - // If we want an unmasked vector function but can't find one matching the VF, - // maybe we can find vector function that does use a mask and synthesize - // an all-true mask. - if (!VecFunc && !MaskRequired) { - Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true); - VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); - // If we found one, add in the cost of creating a mask - if (VecFunc) { - if (NeedsMask) - *NeedsMask = true; - MaskCost = TTI.getShuffleCost( - TargetTransformInfo::SK_Broadcast, - VectorType::get( - IntegerType::getInt1Ty(VecFunc->getFunctionType()->getContext()), - VF)); - } - } + SmallVector<Type *, 4> Tys; + for (auto &ArgOp : CI->args()) + Tys.push_back(ArgOp->getType()); - // We don't support masked function calls yet, but we can scalarize a - // masked call with branches (unless VF is scalable). - if (!TLI || CI->isNoBuiltin() || !VecFunc) - return VF.isScalable() ? InstructionCost::getInvalid() : Cost; + InstructionCost ScalarCallCost = + TTI.getCallInstrCost(CI->getCalledFunction(), RetTy, Tys, CostKind); - // If the corresponding vector cost is cheaper, return its cost. 
- InstructionCost VectorCallCost = - TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost; - if (VectorCallCost < Cost) { - *Variant = VecFunc; - Cost = VectorCallCost; + // If this is an intrinsic we may have a lower cost for it. + if (getVectorIntrinsicIDForCall(CI, TLI)) { + InstructionCost IntrinsicCost = getVectorIntrinsicCost(CI, VF); + return std::min(ScalarCallCost, IntrinsicCost); } - return Cost; + return ScalarCallCost; } static Type *MaybeVectorizeType(Type *Elt, ElementCount VF) { @@ -3558,146 +3425,8 @@ static Type *largestIntegerVectorType(Type *T1, Type *T2) { return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2; } -void InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &State) { - // For every instruction `I` in MinBWs, truncate the operands, create a - // truncated version of `I` and reextend its result. InstCombine runs - // later and will remove any ext/trunc pairs. - SmallPtrSet<Value *, 4> Erased; - for (const auto &KV : Cost->getMinimalBitwidths()) { - // If the value wasn't vectorized, we must maintain the original scalar - // type. The absence of the value from State indicates that it - // wasn't vectorized. - // FIXME: Should not rely on getVPValue at this point. 
- VPValue *Def = State.Plan->getVPValue(KV.first, true); - if (!State.hasAnyVectorValue(Def)) - continue; - for (unsigned Part = 0; Part < UF; ++Part) { - Value *I = State.get(Def, Part); - if (Erased.count(I) || I->use_empty() || !isa<Instruction>(I)) - continue; - Type *OriginalTy = I->getType(); - Type *ScalarTruncatedTy = - IntegerType::get(OriginalTy->getContext(), KV.second); - auto *TruncatedTy = VectorType::get( - ScalarTruncatedTy, cast<VectorType>(OriginalTy)->getElementCount()); - if (TruncatedTy == OriginalTy) - continue; - - IRBuilder<> B(cast<Instruction>(I)); - auto ShrinkOperand = [&](Value *V) -> Value * { - if (auto *ZI = dyn_cast<ZExtInst>(V)) - if (ZI->getSrcTy() == TruncatedTy) - return ZI->getOperand(0); - return B.CreateZExtOrTrunc(V, TruncatedTy); - }; - - // The actual instruction modification depends on the instruction type, - // unfortunately. - Value *NewI = nullptr; - if (auto *BO = dyn_cast<BinaryOperator>(I)) { - NewI = B.CreateBinOp(BO->getOpcode(), ShrinkOperand(BO->getOperand(0)), - ShrinkOperand(BO->getOperand(1))); - - // Any wrapping introduced by shrinking this operation shouldn't be - // considered undefined behavior. So, we can't unconditionally copy - // arithmetic wrapping flags to NewI. 
- cast<BinaryOperator>(NewI)->copyIRFlags(I, /*IncludeWrapFlags=*/false); - } else if (auto *CI = dyn_cast<ICmpInst>(I)) { - NewI = - B.CreateICmp(CI->getPredicate(), ShrinkOperand(CI->getOperand(0)), - ShrinkOperand(CI->getOperand(1))); - } else if (auto *SI = dyn_cast<SelectInst>(I)) { - NewI = B.CreateSelect(SI->getCondition(), - ShrinkOperand(SI->getTrueValue()), - ShrinkOperand(SI->getFalseValue())); - } else if (auto *CI = dyn_cast<CastInst>(I)) { - switch (CI->getOpcode()) { - default: - llvm_unreachable("Unhandled cast!"); - case Instruction::Trunc: - NewI = ShrinkOperand(CI->getOperand(0)); - break; - case Instruction::SExt: - NewI = B.CreateSExtOrTrunc( - CI->getOperand(0), - smallestIntegerVectorType(OriginalTy, TruncatedTy)); - break; - case Instruction::ZExt: - NewI = B.CreateZExtOrTrunc( - CI->getOperand(0), - smallestIntegerVectorType(OriginalTy, TruncatedTy)); - break; - } - } else if (auto *SI = dyn_cast<ShuffleVectorInst>(I)) { - auto Elements0 = - cast<VectorType>(SI->getOperand(0)->getType())->getElementCount(); - auto *O0 = B.CreateZExtOrTrunc( - SI->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements0)); - auto Elements1 = - cast<VectorType>(SI->getOperand(1)->getType())->getElementCount(); - auto *O1 = B.CreateZExtOrTrunc( - SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1)); - - NewI = B.CreateShuffleVector(O0, O1, SI->getShuffleMask()); - } else if (isa<LoadInst>(I) || isa<PHINode>(I)) { - // Don't do anything with the operands, just extend the result. 
- continue; - } else if (auto *IE = dyn_cast<InsertElementInst>(I)) { - auto Elements = - cast<VectorType>(IE->getOperand(0)->getType())->getElementCount(); - auto *O0 = B.CreateZExtOrTrunc( - IE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); - auto *O1 = B.CreateZExtOrTrunc(IE->getOperand(1), ScalarTruncatedTy); - NewI = B.CreateInsertElement(O0, O1, IE->getOperand(2)); - } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) { - auto Elements = - cast<VectorType>(EE->getOperand(0)->getType())->getElementCount(); - auto *O0 = B.CreateZExtOrTrunc( - EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); - NewI = B.CreateExtractElement(O0, EE->getOperand(2)); - } else { - // If we don't know what to do, be conservative and don't do anything. - continue; - } - - // Lastly, extend the result. - NewI->takeName(cast<Instruction>(I)); - Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy); - I->replaceAllUsesWith(Res); - cast<Instruction>(I)->eraseFromParent(); - Erased.insert(I); - State.reset(Def, Res, Part); - } - } - - // We'll have created a bunch of ZExts that are now parentless. Clean up. - for (const auto &KV : Cost->getMinimalBitwidths()) { - // If the value wasn't vectorized, we must maintain the original scalar - // type. The absence of the value from State indicates that it - // wasn't vectorized. - // FIXME: Should not rely on getVPValue at this point. - VPValue *Def = State.Plan->getVPValue(KV.first, true); - if (!State.hasAnyVectorValue(Def)) - continue; - for (unsigned Part = 0; Part < UF; ++Part) { - Value *I = State.get(Def, Part); - ZExtInst *Inst = dyn_cast<ZExtInst>(I); - if (Inst && Inst->use_empty()) { - Value *NewI = Inst->getOperand(0); - Inst->eraseFromParent(); - State.reset(Def, NewI, Part); - } - } - } -} - void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, VPlan &Plan) { - // Insert truncates and extends for any truncated instructions as hints to - // InstCombine. 
- if (VF.isVector()) - truncateToMinimalBitwidths(State); - // Fix widened non-induction PHIs by setting up the PHI operands. if (EnableVPlanNativePath) fixNonInductionPHIs(Plan, State); @@ -3710,6 +3439,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, // Forget the original basic block. PSE.getSE()->forgetLoop(OrigLoop); + PSE.getSE()->forgetBlockAndLoopDispositions(); // After vectorization, the exit blocks of the original loop will have // additional predecessors. Invalidate SCEVs for the exit phis in case SE @@ -3718,7 +3448,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, OrigLoop->getExitBlocks(ExitBlocks); for (BasicBlock *Exit : ExitBlocks) for (PHINode &PN : Exit->phis()) - PSE.getSE()->forgetValue(&PN); + PSE.getSE()->forgetLcssaPhiWithNewPredecessor(OrigLoop, &PN); VPBasicBlock *LatchVPBB = Plan.getVectorLoopRegion()->getExitingBasicBlock(); Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]); @@ -3744,7 +3474,8 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, // Fix LCSSA phis not already fixed earlier. Extracts may need to be generated // in the exit block, so update the builder. - State.Builder.SetInsertPoint(State.CFG.ExitBB->getFirstNonPHI()); + State.Builder.SetInsertPoint(State.CFG.ExitBB, + State.CFG.ExitBB->getFirstNonPHIIt()); for (const auto &KV : Plan.getLiveOuts()) KV.second->fixPhi(Plan, State); @@ -3781,10 +3512,14 @@ void InnerLoopVectorizer::fixCrossIterationPHIs(VPTransformState &State) { // the incoming edges. 
VPBasicBlock *Header = State.Plan->getVectorLoopRegion()->getEntryBasicBlock(); + for (VPRecipeBase &R : Header->phis()) { if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) fixReduction(ReductionPhi, State); - else if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) + } + + for (VPRecipeBase &R : Header->phis()) { + if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) fixFixedOrderRecurrence(FOR, State); } } @@ -3895,7 +3630,7 @@ void InnerLoopVectorizer::fixFixedOrderRecurrence( } // Fix the initial value of the original recurrence in the scalar loop. - Builder.SetInsertPoint(&*LoopScalarPreHeader->begin()); + Builder.SetInsertPoint(LoopScalarPreHeader, LoopScalarPreHeader->begin()); PHINode *Phi = cast<PHINode>(PhiR->getUnderlyingValue()); auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init"); auto *ScalarInit = PhiR->getStartValue()->getLiveInIRValue(); @@ -3919,90 +3654,56 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, RecurKind RK = RdxDesc.getRecurrenceKind(); TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); - State.setDebugLocFromInst(ReductionStartValue); + if (auto *I = dyn_cast<Instruction>(&*ReductionStartValue)) + State.setDebugLocFrom(I->getDebugLoc()); VPValue *LoopExitInstDef = PhiR->getBackedgeValue(); - // This is the vector-clone of the value that leaves the loop. - Type *VecTy = State.get(LoopExitInstDef, 0)->getType(); // Before each round, move the insertion point right between // the PHIs and the values we are going to write. // This allows us to write both PHINodes and the extractelement // instructions. 
- Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); + Builder.SetInsertPoint(LoopMiddleBlock, + LoopMiddleBlock->getFirstInsertionPt()); - State.setDebugLocFromInst(LoopExitInst); + State.setDebugLocFrom(LoopExitInst->getDebugLoc()); Type *PhiTy = OrigPhi->getType(); - - VPBasicBlock *LatchVPBB = - PhiR->getParent()->getEnclosingLoopRegion()->getExitingBasicBlock(); - BasicBlock *VectorLoopLatch = State.CFG.VPBB2IRBB[LatchVPBB]; // If tail is folded by masking, the vector value to leave the loop should be // a Select choosing between the vectorized LoopExitInst and vectorized Phi, // instead of the former. For an inloop reduction the reduction will already // be predicated, and does not need to be handled here. if (Cost->foldTailByMasking() && !PhiR->isInLoop()) { - for (unsigned Part = 0; Part < UF; ++Part) { - Value *VecLoopExitInst = State.get(LoopExitInstDef, Part); - SelectInst *Sel = nullptr; - for (User *U : VecLoopExitInst->users()) { - if (isa<SelectInst>(U)) { - assert(!Sel && "Reduction exit feeding two selects"); - Sel = cast<SelectInst>(U); - } else - assert(isa<PHINode>(U) && "Reduction exit must feed Phi's or select"); - } - assert(Sel && "Reduction exit feeds no select"); - State.reset(LoopExitInstDef, Sel, Part); - - if (isa<FPMathOperator>(Sel)) - Sel->setFastMathFlags(RdxDesc.getFastMathFlags()); - - // If the target can create a predicated operator for the reduction at no - // extra cost in the loop (for example a predicated vadd), it can be - // cheaper for the select to remain in the loop than be sunk out of it, - // and so use the select value for the phi instead of the old - // LoopExitValue. 
- if (PreferPredicatedReductionSelect || - TTI->preferPredicatedReductionSelect( - RdxDesc.getOpcode(), PhiTy, - TargetTransformInfo::ReductionFlags())) { - auto *VecRdxPhi = - cast<PHINode>(State.get(PhiR, Part)); - VecRdxPhi->setIncomingValueForBlock(VectorLoopLatch, Sel); + VPValue *Def = nullptr; + for (VPUser *U : LoopExitInstDef->users()) { + auto *S = dyn_cast<VPInstruction>(U); + if (S && S->getOpcode() == Instruction::Select) { + Def = S; + break; } } + if (Def) + LoopExitInstDef = Def; } + VectorParts RdxParts(UF); + for (unsigned Part = 0; Part < UF; ++Part) + RdxParts[Part] = State.get(LoopExitInstDef, Part); + // If the vector reduction can be performed in a smaller type, we truncate // then extend the loop exit value to enable InstCombine to evaluate the // entire expression in the smaller type. if (VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) { - assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!"); + Builder.SetInsertPoint(LoopMiddleBlock, + LoopMiddleBlock->getFirstInsertionPt()); Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); - Builder.SetInsertPoint(VectorLoopLatch->getTerminator()); - VectorParts RdxParts(UF); - for (unsigned Part = 0; Part < UF; ++Part) { - RdxParts[Part] = State.get(LoopExitInstDef, Part); - Value *Trunc = Builder.CreateTrunc(RdxParts[Part], RdxVecTy); - Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy) - : Builder.CreateZExt(Trunc, VecTy); - for (User *U : llvm::make_early_inc_range(RdxParts[Part]->users())) - if (U != Trunc) { - U->replaceUsesOfWith(RdxParts[Part], Extnd); - RdxParts[Part] = Extnd; - } - } - Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); for (unsigned Part = 0; Part < UF; ++Part) { RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy); - State.reset(LoopExitInstDef, RdxParts[Part], Part); } } // Reduce all of the unrolled parts into a single vector. 
- Value *ReducedPartRdx = State.get(LoopExitInstDef, 0); + Value *ReducedPartRdx = RdxParts[0]; unsigned Op = RecurrenceDescriptor::getOpcode(RK); // The middle block terminator has already been assigned a DebugLoc here (the @@ -4012,21 +3713,21 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // conditional branch, and (c) other passes may add new predecessors which // terminate on this line. This is the easiest way to ensure we don't // accidentally cause an extra step back into the loop while debugging. - State.setDebugLocFromInst(LoopMiddleBlock->getTerminator()); + State.setDebugLocFrom(LoopMiddleBlock->getTerminator()->getDebugLoc()); if (PhiR->isOrdered()) - ReducedPartRdx = State.get(LoopExitInstDef, UF - 1); + ReducedPartRdx = RdxParts[UF - 1]; else { // Floating-point operations should have some FMF to enable the reduction. IRBuilderBase::FastMathFlagGuard FMFG(Builder); Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); for (unsigned Part = 1; Part < UF; ++Part) { - Value *RdxPart = State.get(LoopExitInstDef, Part); - if (Op != Instruction::ICmp && Op != Instruction::FCmp) { + Value *RdxPart = RdxParts[Part]; + if (Op != Instruction::ICmp && Op != Instruction::FCmp) ReducedPartRdx = Builder.CreateBinOp( (Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx"); - } else if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) - ReducedPartRdx = createSelectCmpOp(Builder, ReductionStartValue, RK, - ReducedPartRdx, RdxPart); + else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) + ReducedPartRdx = createAnyOfOp(Builder, ReductionStartValue, RK, + ReducedPartRdx, RdxPart); else ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart); } @@ -4036,7 +3737,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // target reduction in the loop using a Reduction recipe. 
if (VF.isVector() && !PhiR->isInLoop()) { ReducedPartRdx = - createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, OrigPhi); + createTargetReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi); // If the reduction can be performed in a smaller type, we need to extend // the reduction to the wider type before we branch to the original loop. if (PhiTy != RdxDesc.getRecurrenceType()) @@ -4073,7 +3774,8 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // inside the loop, create the final store here. if (StoreInst *SI = RdxDesc.IntermediateStore) { StoreInst *NewSI = - Builder.CreateStore(ReducedPartRdx, SI->getPointerOperand()); + Builder.CreateAlignedStore(ReducedPartRdx, SI->getPointerOperand(), + SI->getAlign()); propagateMetadata(NewSI, SI); // If the reduction value is used in other places, @@ -4402,7 +4104,10 @@ bool LoopVectorizationCostModel::isScalarWithPredication( default: return true; case Instruction::Call: - return !VFDatabase::hasMaskedVariant(*(cast<CallInst>(I)), VF); + if (VF.isScalar()) + return true; + return CallWideningDecisions.at(std::make_pair(cast<CallInst>(I), VF)) + .Kind == CM_Scalarize; case Instruction::Load: case Instruction::Store: { auto *Ptr = getLoadStorePointerOperand(I); @@ -4954,7 +4659,7 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { } FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF( - unsigned ConstTripCount, ElementCount UserVF, bool FoldTailByMasking) { + unsigned MaxTripCount, ElementCount UserVF, bool FoldTailByMasking) { MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); unsigned SmallestType, WidestType; std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); @@ -5042,12 +4747,12 @@ FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF( FixedScalableVFPair Result(ElementCount::getFixed(1), ElementCount::getScalable(0)); if (auto MaxVF = - getMaximizedVFForTarget(ConstTripCount, SmallestType, 
WidestType, + getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType, MaxSafeFixedVF, FoldTailByMasking)) Result.FixedVF = MaxVF; if (auto MaxVF = - getMaximizedVFForTarget(ConstTripCount, SmallestType, WidestType, + getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType, MaxSafeScalableVF, FoldTailByMasking)) if (MaxVF.isScalable()) { Result.ScalableVF = MaxVF; @@ -5071,6 +4776,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { } unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); + unsigned MaxTC = PSE.getSE()->getSmallConstantMaxTripCount(TheLoop); LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); if (TC == 1) { reportVectorizationFailure("Single iteration (non) loop", @@ -5081,7 +4787,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { switch (ScalarEpilogueStatus) { case CM_ScalarEpilogueAllowed: - return computeFeasibleMaxVF(TC, UserVF, false); + return computeFeasibleMaxVF(MaxTC, UserVF, false); case CM_ScalarEpilogueNotAllowedUsePredicate: [[fallthrough]]; case CM_ScalarEpilogueNotNeededUsePredicate: @@ -5119,7 +4825,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking: vectorize with a " "scalar epilogue instead.\n"); ScalarEpilogueStatus = CM_ScalarEpilogueAllowed; - return computeFeasibleMaxVF(TC, UserVF, false); + return computeFeasibleMaxVF(MaxTC, UserVF, false); } return FixedScalableVFPair::getNone(); } @@ -5136,7 +4842,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { InterleaveInfo.invalidateGroupsRequiringScalarEpilogue(); } - FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(TC, UserVF, true); + FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(MaxTC, UserVF, true); // Avoid tail folding if the trip count is known to be a multiple of any VF // we choose. 
@@ -5212,7 +4918,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { } ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( - unsigned ConstTripCount, unsigned SmallestType, unsigned WidestType, + unsigned MaxTripCount, unsigned SmallestType, unsigned WidestType, ElementCount MaxSafeVF, bool FoldTailByMasking) { bool ComputeScalableMaxVF = MaxSafeVF.isScalable(); const TypeSize WidestRegister = TTI.getRegisterBitWidth( @@ -5251,31 +4957,35 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( } // When a scalar epilogue is required, at least one iteration of the scalar - // loop has to execute. Adjust ConstTripCount accordingly to avoid picking a + // loop has to execute. Adjust MaxTripCount accordingly to avoid picking a // max VF that results in a dead vector loop. - if (ConstTripCount > 0 && requiresScalarEpilogue(true)) - ConstTripCount -= 1; + if (MaxTripCount > 0 && requiresScalarEpilogue(true)) + MaxTripCount -= 1; - if (ConstTripCount && ConstTripCount <= WidestRegisterMinEC && - (!FoldTailByMasking || isPowerOf2_32(ConstTripCount))) { - // If loop trip count (TC) is known at compile time there is no point in - // choosing VF greater than TC (as done in the loop below). Select maximum - // power of two which doesn't exceed TC. - // If MaxVectorElementCount is scalable, we only fall back on a fixed VF - // when the TC is less than or equal to the known number of lanes. - auto ClampedConstTripCount = llvm::bit_floor(ConstTripCount); + if (MaxTripCount && MaxTripCount <= WidestRegisterMinEC && + (!FoldTailByMasking || isPowerOf2_32(MaxTripCount))) { + // If upper bound loop trip count (TC) is known at compile time there is no + // point in choosing VF greater than TC (as done in the loop below). Select + // maximum power of two which doesn't exceed TC. If MaxVectorElementCount is + // scalable, we only fall back on a fixed VF when the TC is less than or + // equal to the known number of lanes. 
+ auto ClampedUpperTripCount = llvm::bit_floor(MaxTripCount); LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not " "exceeding the constant trip count: " - << ClampedConstTripCount << "\n"); - return ElementCount::getFixed(ClampedConstTripCount); + << ClampedUpperTripCount << "\n"); + return ElementCount::get( + ClampedUpperTripCount, + FoldTailByMasking ? MaxVectorElementCount.isScalable() : false); } TargetTransformInfo::RegisterKind RegKind = ComputeScalableMaxVF ? TargetTransformInfo::RGK_ScalableVector : TargetTransformInfo::RGK_FixedWidthVector; ElementCount MaxVF = MaxVectorElementCount; - if (MaximizeBandwidth || (MaximizeBandwidth.getNumOccurrences() == 0 && - TTI.shouldMaximizeVectorBandwidth(RegKind))) { + if (MaximizeBandwidth || + (MaximizeBandwidth.getNumOccurrences() == 0 && + (TTI.shouldMaximizeVectorBandwidth(RegKind) || + (UseWiderVFIfCallVariantsPresent && Legal->hasVectorCallVariants())))) { auto MaxVectorElementCountMaxBW = ElementCount::get( llvm::bit_floor(WidestRegister.getKnownMinValue() / SmallestType), ComputeScalableMaxVF); @@ -5947,7 +5657,7 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, HasReductions && any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool { const RecurrenceDescriptor &RdxDesc = Reduction.second; - return RecurrenceDescriptor::isSelectCmpRecurrenceKind( + return RecurrenceDescriptor::isAnyOfRecurrenceKind( RdxDesc.getRecurrenceKind()); }); if (HasSelectCmpReductions) { @@ -6115,6 +5825,8 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { if (ValuesToIgnore.count(I)) continue; + collectInLoopReductions(); + // For each VF find the maximum usage of registers. 
for (unsigned j = 0, e = VFs.size(); j < e; ++j) { // Count the number of registers used, per register class, given all open @@ -6634,10 +6346,11 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, std::optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( - Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) { + Instruction *I, ElementCount VF, Type *Ty, + TTI::TargetCostKind CostKind) const { using namespace llvm::PatternMatch; // Early exit for no inloop reductions - if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty)) + if (InLoopReductions.empty() || VF.isScalar() || !isa<VectorType>(Ty)) return std::nullopt; auto *VectorTy = cast<VectorType>(Ty); @@ -6672,10 +6385,10 @@ LoopVectorizationCostModel::getReductionPatternCost( // Find the reduction this chain is a part of and calculate the basic cost of // the reduction on its own. - Instruction *LastChain = InLoopReductionImmediateChains[RetI]; + Instruction *LastChain = InLoopReductionImmediateChains.at(RetI); Instruction *ReductionPhi = LastChain; while (!isa<PHINode>(ReductionPhi)) - ReductionPhi = InLoopReductionImmediateChains[ReductionPhi]; + ReductionPhi = InLoopReductionImmediateChains.at(ReductionPhi); const RecurrenceDescriptor &RdxDesc = Legal->getReductionVars().find(cast<PHINode>(ReductionPhi))->second; @@ -7093,6 +6806,168 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) { } } +void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) { + assert(!VF.isScalar() && + "Trying to set a vectorization decision for a scalar VF"); + + for (BasicBlock *BB : TheLoop->blocks()) { + // For each instruction in the old loop. 
+ for (Instruction &I : *BB) { + CallInst *CI = dyn_cast<CallInst>(&I); + + if (!CI) + continue; + + InstructionCost ScalarCost = InstructionCost::getInvalid(); + InstructionCost VectorCost = InstructionCost::getInvalid(); + InstructionCost IntrinsicCost = InstructionCost::getInvalid(); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + + Function *ScalarFunc = CI->getCalledFunction(); + Type *ScalarRetTy = CI->getType(); + SmallVector<Type *, 4> Tys, ScalarTys; + bool MaskRequired = Legal->isMaskRequired(CI); + for (auto &ArgOp : CI->args()) + ScalarTys.push_back(ArgOp->getType()); + + // Compute corresponding vector type for return value and arguments. + Type *RetTy = ToVectorTy(ScalarRetTy, VF); + for (Type *ScalarTy : ScalarTys) + Tys.push_back(ToVectorTy(ScalarTy, VF)); + + // An in-loop reduction using an fmuladd intrinsic is a special case; + // we don't want the normal cost for that intrinsic. + if (RecurrenceDescriptor::isFMulAddIntrinsic(CI)) + if (auto RedCost = getReductionPatternCost(CI, VF, RetTy, CostKind)) { + setCallWideningDecision(CI, VF, CM_IntrinsicCall, nullptr, + getVectorIntrinsicIDForCall(CI, TLI), + std::nullopt, *RedCost); + continue; + } + + // Estimate cost of scalarized vector call. The source operands are + // assumed to be vectors, so we need to extract individual elements from + // there, execute VF scalar calls, and then gather the result into the + // vector return value. + InstructionCost ScalarCallCost = + TTI.getCallInstrCost(ScalarFunc, ScalarRetTy, ScalarTys, CostKind); + + // Compute costs of unpacking argument values for the scalar calls and + // packing the return values to a vector. + InstructionCost ScalarizationCost = + getScalarizationOverhead(CI, VF, CostKind); + + ScalarCost = ScalarCallCost * VF.getKnownMinValue() + ScalarizationCost; + + // Find the cost of vectorizing the call, if we can find a suitable + // vector variant of the function. 
+ bool UsesMask = false; + VFInfo FuncInfo; + Function *VecFunc = nullptr; + // Search through any available variants for one we can use at this VF. + for (VFInfo &Info : VFDatabase::getMappings(*CI)) { + // Must match requested VF. + if (Info.Shape.VF != VF) + continue; + + // Must take a mask argument if one is required + if (MaskRequired && !Info.isMasked()) + continue; + + // Check that all parameter kinds are supported + bool ParamsOk = true; + for (VFParameter Param : Info.Shape.Parameters) { + switch (Param.ParamKind) { + case VFParamKind::Vector: + break; + case VFParamKind::OMP_Uniform: { + Value *ScalarParam = CI->getArgOperand(Param.ParamPos); + // Make sure the scalar parameter in the loop is invariant. + if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(ScalarParam), + TheLoop)) + ParamsOk = false; + break; + } + case VFParamKind::OMP_Linear: { + Value *ScalarParam = CI->getArgOperand(Param.ParamPos); + // Find the stride for the scalar parameter in this loop and see if + // it matches the stride for the variant. + // TODO: do we need to figure out the cost of an extract to get the + // first lane? Or do we hope that it will be folded away? + ScalarEvolution *SE = PSE.getSE(); + const auto *SAR = + dyn_cast<SCEVAddRecExpr>(SE->getSCEV(ScalarParam)); + + if (!SAR || SAR->getLoop() != TheLoop) { + ParamsOk = false; + break; + } + + const SCEVConstant *Step = + dyn_cast<SCEVConstant>(SAR->getStepRecurrence(*SE)); + + if (!Step || + Step->getAPInt().getSExtValue() != Param.LinearStepOrPos) + ParamsOk = false; + + break; + } + case VFParamKind::GlobalPredicate: + UsesMask = true; + break; + default: + ParamsOk = false; + break; + } + } + + if (!ParamsOk) + continue; + + // Found a suitable candidate, stop here. + VecFunc = CI->getModule()->getFunction(Info.VectorName); + FuncInfo = Info; + break; + } + + // Add in the cost of synthesizing a mask if one wasn't required. 
+ InstructionCost MaskCost = 0; + if (VecFunc && UsesMask && !MaskRequired) + MaskCost = TTI.getShuffleCost( + TargetTransformInfo::SK_Broadcast, + VectorType::get(IntegerType::getInt1Ty( + VecFunc->getFunctionType()->getContext()), + VF)); + + if (TLI && VecFunc && !CI->isNoBuiltin()) + VectorCost = + TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost; + + // Find the cost of an intrinsic; some targets may have instructions that + // perform the operation without needing an actual call. + Intrinsic::ID IID = getVectorIntrinsicIDForCall(CI, TLI); + if (IID != Intrinsic::not_intrinsic) + IntrinsicCost = getVectorIntrinsicCost(CI, VF); + + InstructionCost Cost = ScalarCost; + InstWidening Decision = CM_Scalarize; + + if (VectorCost <= Cost) { + Cost = VectorCost; + Decision = CM_VectorCall; + } + + if (IntrinsicCost <= Cost) { + Cost = IntrinsicCost; + Decision = CM_IntrinsicCall; + } + + setCallWideningDecision(CI, VF, Decision, VecFunc, IID, + FuncInfo.getParamIndexForOptionalMask(), Cost); + } + } +} + InstructionCost LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, Type *&VectorTy) { @@ -7122,7 +6997,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, // With the exception of GEPs and PHIs, after scalarization there should // only be one copy of the instruction generated in the loop. This is // because the VF is either 1, or any instructions that need scalarizing - // have already been dealt with by the the time we get here. As a result, + // have already been dealt with by the time we get here. As a result, // it means we don't have to multiply the instruction cost by VF. 
assert(I->getOpcode() == Instruction::GetElementPtr || I->getOpcode() == Instruction::PHI || @@ -7350,6 +7225,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, return TTI::CastContextHint::Reversed; case LoopVectorizationCostModel::CM_Unknown: llvm_unreachable("Instr did not go through cost modelling?"); + case LoopVectorizationCostModel::CM_VectorCall: + case LoopVectorizationCostModel::CM_IntrinsicCall: + llvm_unreachable_internal("Instr has invalid widening decision"); } llvm_unreachable("Unhandled case!"); @@ -7407,19 +7285,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, return TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I); } - case Instruction::Call: { - if (RecurrenceDescriptor::isFMulAddIntrinsic(I)) - if (auto RedCost = getReductionPatternCost(I, VF, VectorTy, CostKind)) - return *RedCost; - Function *Variant; - CallInst *CI = cast<CallInst>(I); - InstructionCost CallCost = getVectorCallCost(CI, VF, &Variant); - if (getVectorIntrinsicIDForCall(CI, TLI)) { - InstructionCost IntrinsicCost = getVectorIntrinsicCost(CI, VF); - return std::min(CallCost, IntrinsicCost); - } - return CallCost; - } + case Instruction::Call: + return getVectorCallCost(cast<CallInst>(I), VF); case Instruction::ExtractValue: return TTI.getInstructionCost(I, TTI::TCK_RecipThroughput); case Instruction::Alloca: @@ -7487,8 +7354,9 @@ void LoopVectorizationCostModel::collectInLoopReductions() { SmallVector<Instruction *, 4> ReductionOperations = RdxDesc.getReductionOpChain(Phi, TheLoop); bool InLoop = !ReductionOperations.empty(); + if (InLoop) { - InLoopReductionChains[Phi] = ReductionOperations; + InLoopReductions.insert(Phi); // Add the elements to InLoopReductionImmediateChains for cost modelling. 
Instruction *LastChain = Phi; for (auto *I : ReductionOperations) { @@ -7501,21 +7369,38 @@ void LoopVectorizationCostModel::collectInLoopReductions() { } } +VPValue *VPBuilder::createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B, + DebugLoc DL, const Twine &Name) { + assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE && + Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate"); + return tryInsertInstruction( + new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name)); +} + +// This function will select a scalable VF if the target supports scalable +// vectors and a fixed one otherwise. // TODO: we could return a pair of values that specify the max VF and // min VF, to be used in `buildVPlans(MinVF, MaxVF)` instead of // `buildVPlans(VF, VF)`. We cannot do it because VPLAN at the moment // doesn't have a cost model that can choose which plan to execute if // more than one is generated. -static unsigned determineVPlanVF(const unsigned WidestVectorRegBits, - LoopVectorizationCostModel &CM) { +static ElementCount determineVPlanVF(const TargetTransformInfo &TTI, + LoopVectorizationCostModel &CM) { unsigned WidestType; std::tie(std::ignore, WidestType) = CM.getSmallestAndWidestTypes(); - return WidestVectorRegBits / WidestType; + + TargetTransformInfo::RegisterKind RegKind = + TTI.enableScalableVectorization() + ? TargetTransformInfo::RGK_ScalableVector + : TargetTransformInfo::RGK_FixedWidthVector; + + TypeSize RegSize = TTI.getRegisterBitWidth(RegKind); + unsigned N = RegSize.getKnownMinValue() / WidestType; + return ElementCount::get(N, RegSize.isScalable()); } VectorizationFactor LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { - assert(!UserVF.isScalable() && "scalable vectors not yet supported"); ElementCount VF = UserVF; // Outer loop handling: They may require CFG and instruction level // transformations before even evaluating whether vectorization is profitable. 
@@ -7525,10 +7410,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { // If the user doesn't provide a vectorization factor, determine a // reasonable one. if (UserVF.isZero()) { - VF = ElementCount::getFixed(determineVPlanVF( - TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) - .getFixedValue(), - CM)); + VF = determineVPlanVF(TTI, CM); LLVM_DEBUG(dbgs() << "LV: VPlan computed VF " << VF << ".\n"); // Make sure we have a VF > 1 for stress testing. @@ -7537,6 +7419,17 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { << "overriding computed VF.\n"); VF = ElementCount::getFixed(4); } + } else if (UserVF.isScalable() && !TTI.supportsScalableVectors() && + !ForceTargetSupportsScalableVectors) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing. Scalable VF requested, but " + << "not supported by the target.\n"); + reportVectorizationFailure( + "Scalable vectorization requested but not supported by the target", + "the scalable user-specified vectorization width for outer-loop " + "vectorization cannot be used because the target does not support " + "scalable vectors.", + "ScalableVFUnfeasible", ORE, OrigLoop); + return VectorizationFactor::Disabled(); } assert(EnableVPlanNativePath && "VPlan-native path is not enabled."); assert(isPowerOf2_32(VF.getKnownMinValue()) && @@ -7590,9 +7483,9 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { "VF needs to be a power of two"); // Collect the instructions (and their associated costs) that will be more // profitable to scalarize. 
+ CM.collectInLoopReductions(); if (CM.selectUserVectorizationFactor(UserVF)) { LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); - CM.collectInLoopReductions(); buildVPlansWithVPRecipes(UserVF, UserVF); if (!hasPlanWithVF(UserVF)) { LLVM_DEBUG(dbgs() << "LV: No VPlan could be built for " << UserVF @@ -7616,6 +7509,7 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { ElementCount::isKnownLE(VF, MaxFactors.ScalableVF); VF *= 2) VFCandidates.insert(VF); + CM.collectInLoopReductions(); for (const auto &VF : VFCandidates) { // Collect Uniform and Scalar instructions after vectorization with VF. CM.collectUniformsAndScalars(VF); @@ -7626,7 +7520,6 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { CM.collectInstsToScalarize(VF); } - CM.collectInLoopReductions(); buildVPlansWithVPRecipes(ElementCount::getFixed(1), MaxFactors.FixedVF); buildVPlansWithVPRecipes(ElementCount::getScalable(1), MaxFactors.ScalableVF); @@ -7671,7 +7564,7 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) { if (MD) { const auto *S = dyn_cast<MDString>(MD->getOperand(0)); IsUnrollMetadata = - S && S->getString().startswith("llvm.loop.unroll.disable"); + S && S->getString().starts_with("llvm.loop.unroll.disable"); } MDs.push_back(LoopID->getOperand(i)); } @@ -7695,7 +7588,7 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) { SCEV2ValueTy LoopVectorizationPlanner::executePlan( ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan, InnerLoopVectorizer &ILV, DominatorTree *DT, bool IsEpilogueVectorization, - DenseMap<const SCEV *, Value *> *ExpandedSCEVs) { + const DenseMap<const SCEV *, Value *> *ExpandedSCEVs) { assert(BestVPlan.hasVF(BestVF) && "Trying to execute plan with unsupported VF"); assert(BestVPlan.hasUF(BestUF) && @@ -7711,7 +7604,8 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan( VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE); // Perform the actual loop transformation. 
- VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan}; + VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan, + OrigLoop->getHeader()->getContext()); // 0. Generate SCEV-dependent code into the preheader, including TripCount, // before making any changes to the CFG. @@ -7764,9 +7658,9 @@ SCEV2ValueTy LoopVectorizationPlanner::executePlan( //===------------------------------------------------===// // 2. Copy and widen instructions from the old loop into the new loop. - BestVPlan.prepareToExecute( - ILV.getTripCount(), ILV.getOrCreateVectorTripCount(nullptr), - CanonicalIVStartValue, State, IsEpilogueVectorization); + BestVPlan.prepareToExecute(ILV.getTripCount(), + ILV.getOrCreateVectorTripCount(nullptr), + CanonicalIVStartValue, State); BestVPlan.execute(&State); @@ -7930,9 +7824,11 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, EPI.TripCount = Count; } - ReplaceInstWithInst( - TCCheckBlock->getTerminator(), - BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters)); + BranchInst &BI = + *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters); + if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) + setBranchWeights(BI, MinItersBypassWeights); + ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI); return TCCheckBlock; } @@ -8030,8 +7926,8 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton( // Generate a resume induction for the vector epilogue and put it in the // vector epilogue preheader Type *IdxTy = Legal->getWidestInductionType(); - PHINode *EPResumeVal = PHINode::Create(IdxTy, 2, "vec.epilog.resume.val", - LoopVectorPreHeader->getFirstNonPHI()); + PHINode *EPResumeVal = PHINode::Create(IdxTy, 2, "vec.epilog.resume.val"); + EPResumeVal->insertBefore(LoopVectorPreHeader->getFirstNonPHIIt()); EPResumeVal->addIncoming(EPI.VectorTripCount, VecEpilogueIterationCountCheck); EPResumeVal->addIncoming(ConstantInt::get(IdxTy, 0), 
EPI.MainLoopIterationCountCheck); @@ -8076,9 +7972,22 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck( EPI.EpilogueVF, EPI.EpilogueUF), "min.epilog.iters.check"); - ReplaceInstWithInst( - Insert->getTerminator(), - BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters)); + BranchInst &BI = + *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters); + if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) { + unsigned MainLoopStep = UF * VF.getKnownMinValue(); + unsigned EpilogueLoopStep = + EPI.EpilogueUF * EPI.EpilogueVF.getKnownMinValue(); + // We assume the remaining `Count` is equally distributed in + // [0, MainLoopStep) + // So the probability for `Count < EpilogueLoopStep` should be + // min(MainLoopStep, EpilogueLoopStep) / MainLoopStep + unsigned EstimatedSkipCount = std::min(MainLoopStep, EpilogueLoopStep); + const uint32_t Weights[] = {EstimatedSkipCount, + MainLoopStep - EstimatedSkipCount}; + setBranchWeights(BI, Weights); + } + ReplaceInstWithInst(Insert->getTerminator(), &BI); LoopBypassBlocks.push_back(Insert); return Insert; @@ -8172,6 +8081,33 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, return EdgeMaskCache[Edge] = EdgeMask; } +void VPRecipeBuilder::createHeaderMask(VPlan &Plan) { + BasicBlock *Header = OrigLoop->getHeader(); + + // When not folding the tail, use nullptr to model all-true mask. + if (!CM.foldTailByMasking()) { + BlockMaskCache[Header] = nullptr; + return; + } + + // Introduce the early-exit compare IV <= BTC to form header block mask. + // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by + // constructing the desired canonical IV in the header block as its first + // non-phi instructions. 
+ + VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); + auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi(); + auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV()); + HeaderVPBB->insert(IV, NewInsertionPoint); + + VPBuilder::InsertPointGuard Guard(Builder); + Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint); + VPValue *BlockMask = nullptr; + VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); + BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC); + BlockMaskCache[Header] = BlockMask; +} + VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) { assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); @@ -8180,45 +8116,12 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) { if (BCEntryIt != BlockMaskCache.end()) return BCEntryIt->second; + assert(OrigLoop->getHeader() != BB && + "Loop header must have cached block mask"); + // All-one mask is modelled as no-mask following the convention for masked // load/store/gather/scatter. Initialize BlockMask to no-mask. VPValue *BlockMask = nullptr; - - if (OrigLoop->getHeader() == BB) { - if (!CM.blockNeedsPredicationForAnyReason(BB)) - return BlockMaskCache[BB] = BlockMask; // Loop incoming mask is all-one. - - assert(CM.foldTailByMasking() && "must fold the tail"); - - // If we're using the active lane mask for control flow, then we get the - // mask from the active lane mask PHI that is cached in the VPlan. - TailFoldingStyle TFStyle = CM.getTailFoldingStyle(); - if (useActiveLaneMaskForControlFlow(TFStyle)) - return BlockMaskCache[BB] = Plan.getActiveLaneMaskPhi(); - - // Introduce the early-exit compare IV <= BTC to form header block mask. - // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by - // constructing the desired canonical IV in the header block as its first - // non-phi instructions. 
- - VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); - auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi(); - auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV()); - HeaderVPBB->insert(IV, HeaderVPBB->getFirstNonPhi()); - - VPBuilder::InsertPointGuard Guard(Builder); - Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint); - if (useActiveLaneMask(TFStyle)) { - VPValue *TC = Plan.getTripCount(); - BlockMask = Builder.createNaryOp(VPInstruction::ActiveLaneMask, {IV, TC}, - nullptr, "active.lane.mask"); - } else { - VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); - BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC}); - } - return BlockMaskCache[BB] = BlockMask; - } - // This is the block mask. We OR all incoming edges. for (auto *Predecessor : predecessors(BB)) { VPValue *EdgeMask = createEdgeMask(Predecessor, BB, Plan); @@ -8424,22 +8327,15 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, bool ShouldUseVectorIntrinsic = ID && LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) -> bool { - Function *Variant; - // Is it beneficial to perform intrinsic call compared to lib - // call? - InstructionCost CallCost = - CM.getVectorCallCost(CI, VF, &Variant); - InstructionCost IntrinsicCost = - CM.getVectorIntrinsicCost(CI, VF); - return IntrinsicCost <= CallCost; + return CM.getCallWideningDecision(CI, VF).Kind == + LoopVectorizationCostModel::CM_IntrinsicCall; }, Range); if (ShouldUseVectorIntrinsic) return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), ID); Function *Variant = nullptr; - ElementCount VariantVF; - bool NeedsMask = false; + std::optional<unsigned> MaskPos; // Is better to call a vectorized version of the function than to to scalarize // the call? auto ShouldUseVectorCall = LoopVectorizationPlanner::getDecisionAndClampRange( @@ -8458,16 +8354,19 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, // finds a valid variant. 
if (Variant) return false; - CM.getVectorCallCost(CI, VF, &Variant, &NeedsMask); - // If we found a valid vector variant at this VF, then store the VF - // in case we need to generate a mask. - if (Variant) - VariantVF = VF; - return Variant != nullptr; + LoopVectorizationCostModel::CallWideningDecision Decision = + CM.getCallWideningDecision(CI, VF); + if (Decision.Kind == LoopVectorizationCostModel::CM_VectorCall) { + Variant = Decision.Variant; + MaskPos = Decision.MaskPos; + return true; + } + + return false; }, Range); if (ShouldUseVectorCall) { - if (NeedsMask) { + if (MaskPos.has_value()) { // We have 2 cases that would require a mask: // 1) The block needs to be predicated, either due to a conditional // in the scalar loop or use of an active lane mask with @@ -8482,17 +8381,7 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, Mask = Plan->getVPValueOrAddLiveIn(ConstantInt::getTrue( IntegerType::getInt1Ty(Variant->getFunctionType()->getContext()))); - VFShape Shape = VFShape::get(*CI, VariantVF, /*HasGlobalPred=*/true); - unsigned MaskPos = 0; - - for (const VFInfo &Info : VFDatabase::getMappings(*CI)) - if (Info.Shape == Shape) { - assert(Info.isMasked() && "Vector function info shape mismatch"); - MaskPos = Info.getParamIndexForOptionalMask().value(); - break; - } - - Ops.insert(Ops.begin() + MaskPos, Mask); + Ops.insert(Ops.begin() + *MaskPos, Mask); } return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), @@ -8713,8 +8602,8 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, } if (auto *CI = dyn_cast<CastInst>(Instr)) { - return toVPRecipeResult( - new VPWidenCastRecipe(CI->getOpcode(), Operands[0], CI->getType(), CI)); + return toVPRecipeResult(new VPWidenCastRecipe(CI->getOpcode(), Operands[0], + CI->getType(), *CI)); } return toVPRecipeResult(tryToWiden(Instr, Operands, VPBB, Plan)); @@ -8724,27 +8613,26 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, ElementCount MaxVF) { 
assert(OrigLoop->isInnermost() && "Inner loop expected."); - // Add assume instructions we need to drop to DeadInstructions, to prevent - // them from being added to the VPlan. - // TODO: We only need to drop assumes in blocks that get flattend. If the - // control flow is preserved, we should keep them. - SmallPtrSet<Instruction *, 4> DeadInstructions; - auto &ConditionalAssumes = Legal->getConditionalAssumes(); - DeadInstructions.insert(ConditionalAssumes.begin(), ConditionalAssumes.end()); - auto MaxVFTimes2 = MaxVF * 2; for (ElementCount VF = MinVF; ElementCount::isKnownLT(VF, MaxVFTimes2);) { VFRange SubRange = {VF, MaxVFTimes2}; - if (auto Plan = tryToBuildVPlanWithVPRecipes(SubRange, DeadInstructions)) - VPlans.push_back(std::move(*Plan)); + if (auto Plan = tryToBuildVPlanWithVPRecipes(SubRange)) { + // Now optimize the initial VPlan. + if (!Plan->hasVF(ElementCount::getFixed(1))) + VPlanTransforms::truncateToMinimalBitwidths( + *Plan, CM.getMinimalBitwidths(), PSE.getSE()->getContext()); + VPlanTransforms::optimize(*Plan, *PSE.getSE()); + assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid"); + VPlans.push_back(std::move(Plan)); + } VF = SubRange.End; } } // Add the necessary canonical IV and branch recipes required to control the // loop. -static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, - TailFoldingStyle Style) { +static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW, + DebugLoc DL) { Value *StartIdx = ConstantInt::get(IdxTy, 0); auto *StartV = Plan.getVPValueOrAddLiveIn(StartIdx); @@ -8756,102 +8644,24 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, // Add a CanonicalIVIncrement{NUW} VPInstruction to increment the scalar // IV by VF * UF. - bool HasNUW = Style == TailFoldingStyle::None; auto *CanonicalIVIncrement = - new VPInstruction(HasNUW ? 
VPInstruction::CanonicalIVIncrementNUW - : VPInstruction::CanonicalIVIncrement, - {CanonicalIVPHI}, DL, "index.next"); + new VPInstruction(Instruction::Add, {CanonicalIVPHI, &Plan.getVFxUF()}, + {HasNUW, false}, DL, "index.next"); CanonicalIVPHI->addOperand(CanonicalIVIncrement); VPBasicBlock *EB = TopRegion->getExitingBasicBlock(); - if (useActiveLaneMaskForControlFlow(Style)) { - // Create the active lane mask instruction in the vplan preheader. - VPBasicBlock *VecPreheader = - cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSinglePredecessor()); - - // We can't use StartV directly in the ActiveLaneMask VPInstruction, since - // we have to take unrolling into account. Each part needs to start at - // Part * VF - auto *CanonicalIVIncrementParts = - new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementForPartNUW - : VPInstruction::CanonicalIVIncrementForPart, - {StartV}, DL, "index.part.next"); - VecPreheader->appendRecipe(CanonicalIVIncrementParts); - - // Create the ActiveLaneMask instruction using the correct start values. - VPValue *TC = Plan.getTripCount(); - - VPValue *TripCount, *IncrementValue; - if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) { - // When avoiding a runtime check, the active.lane.mask inside the loop - // uses a modified trip count and the induction variable increment is - // done after the active.lane.mask intrinsic is called. - auto *TCMinusVF = - new VPInstruction(VPInstruction::CalculateTripCountMinusVF, {TC}, DL); - VecPreheader->appendRecipe(TCMinusVF); - IncrementValue = CanonicalIVPHI; - TripCount = TCMinusVF; - } else { - // When the loop is guarded by a runtime overflow check for the loop - // induction variable increment by VF, we can increment the value before - // the get.active.lane mask and use the unmodified tripcount. 
- EB->appendRecipe(CanonicalIVIncrement); - IncrementValue = CanonicalIVIncrement; - TripCount = TC; - } - - auto *EntryALM = new VPInstruction(VPInstruction::ActiveLaneMask, - {CanonicalIVIncrementParts, TC}, DL, - "active.lane.mask.entry"); - VecPreheader->appendRecipe(EntryALM); - - // Now create the ActiveLaneMaskPhi recipe in the main loop using the - // preheader ActiveLaneMask instruction. - auto *LaneMaskPhi = new VPActiveLaneMaskPHIRecipe(EntryALM, DebugLoc()); - Header->insert(LaneMaskPhi, Header->getFirstNonPhi()); - - // Create the active lane mask for the next iteration of the loop. - CanonicalIVIncrementParts = - new VPInstruction(HasNUW ? VPInstruction::CanonicalIVIncrementForPartNUW - : VPInstruction::CanonicalIVIncrementForPart, - {IncrementValue}, DL); - EB->appendRecipe(CanonicalIVIncrementParts); - - auto *ALM = new VPInstruction(VPInstruction::ActiveLaneMask, - {CanonicalIVIncrementParts, TripCount}, DL, - "active.lane.mask.next"); - EB->appendRecipe(ALM); - LaneMaskPhi->addOperand(ALM); - - if (Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) { - // Do the increment of the canonical IV after the active.lane.mask, because - // that value is still based off %CanonicalIVPHI - EB->appendRecipe(CanonicalIVIncrement); - } - - // We have to invert the mask here because a true condition means jumping - // to the exit block. - auto *NotMask = new VPInstruction(VPInstruction::Not, ALM, DL); - EB->appendRecipe(NotMask); + EB->appendRecipe(CanonicalIVIncrement); - VPInstruction *BranchBack = - new VPInstruction(VPInstruction::BranchOnCond, {NotMask}, DL); - EB->appendRecipe(BranchBack); - } else { - EB->appendRecipe(CanonicalIVIncrement); - - // Add the BranchOnCount VPInstruction to the latch. - VPInstruction *BranchBack = new VPInstruction( - VPInstruction::BranchOnCount, - {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL); - EB->appendRecipe(BranchBack); - } + // Add the BranchOnCount VPInstruction to the latch. 
+ VPInstruction *BranchBack = + new VPInstruction(VPInstruction::BranchOnCount, + {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL); + EB->appendRecipe(BranchBack); } // Add exit values to \p Plan. VPLiveOuts are added for each LCSSA phi in the // original exit block. -static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, - VPBasicBlock *MiddleVPBB, Loop *OrigLoop, +static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, Loop *OrigLoop, VPlan &Plan) { BasicBlock *ExitBB = OrigLoop->getUniqueExitBlock(); BasicBlock *ExitingBB = OrigLoop->getExitingBlock(); @@ -8868,8 +8678,8 @@ static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, } } -std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( - VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions) { +VPlanPtr +LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { SmallPtrSet<const InterleaveGroup<Instruction> *, 1> InterleaveGroups; @@ -8880,24 +8690,6 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // process after constructing the initial VPlan. // --------------------------------------------------------------------------- - for (const auto &Reduction : CM.getInLoopReductionChains()) { - PHINode *Phi = Reduction.first; - RecurKind Kind = - Legal->getReductionVars().find(Phi)->second.getRecurrenceKind(); - const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second; - - RecipeBuilder.recordRecipeOf(Phi); - for (const auto &R : ReductionOperations) { - RecipeBuilder.recordRecipeOf(R); - // For min/max reductions, where we have a pair of icmp/select, we also - // need to record the ICmp recipe, so it can be removed later. 
- assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) && - "Only min/max recurrences allowed for inloop reductions"); - if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) - RecipeBuilder.recordRecipeOf(cast<Instruction>(R->getOperand(0))); - } - } - // For each interleave group which is relevant for this (possibly trimmed) // Range, add it to the set of groups to be later applied to the VPlan and add // placeholders for its members' Recipes which we'll be replacing with a @@ -8938,23 +8730,27 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body"); VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch"); VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB); - auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop"); - VPBlockUtils::insertBlockAfter(TopRegion, Plan->getEntry()); - VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block"); - VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion); + Plan->getVectorLoopRegion()->setEntry(HeaderVPBB); + Plan->getVectorLoopRegion()->setExiting(LatchVPBB); // Don't use getDecisionAndClampRange here, because we don't know the UF // so this function is better to be conservative, rather than to split // it up into different VPlans. + // TODO: Consider using getDecisionAndClampRange here to split up VPlans. bool IVUpdateMayOverflow = false; for (ElementCount VF : Range) IVUpdateMayOverflow |= !isIndvarOverflowCheckKnownFalse(&CM, VF); - Instruction *DLInst = - getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()); - addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), - DLInst ? 
DLInst->getDebugLoc() : DebugLoc(), - CM.getTailFoldingStyle(IVUpdateMayOverflow)); + DebugLoc DL = getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()); + TailFoldingStyle Style = CM.getTailFoldingStyle(IVUpdateMayOverflow); + // When not folding the tail, we know that the induction increment will not + // overflow. + bool HasNUW = Style == TailFoldingStyle::None; + addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL); + + // Proactively create header mask. Masks for other blocks are created on + // demand. + RecipeBuilder.createHeaderMask(*Plan); // Scan the body of the loop in a topological order to visit each basic block // after having visited its predecessor basic blocks. @@ -8971,14 +8767,8 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // Introduce each ingredient into VPlan. // TODO: Model and preserve debug intrinsics in VPlan. - for (Instruction &I : BB->instructionsWithoutDebug(false)) { + for (Instruction &I : drop_end(BB->instructionsWithoutDebug(false))) { Instruction *Instr = &I; - - // First filter out irrelevant instructions, to ensure no recipes are - // built for them. - if (isa<BranchInst>(Instr) || DeadInstructions.count(Instr)) - continue; - SmallVector<VPValue *, 4> Operands; auto *Phi = dyn_cast<PHINode>(Instr); if (Phi && Phi->getParent() == OrigLoop->getHeader()) { @@ -9018,11 +8808,18 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( } RecipeBuilder.setRecipe(Instr, Recipe); - if (isa<VPWidenIntOrFpInductionRecipe>(Recipe) && - HeaderVPBB->getFirstNonPhi() != VPBB->end()) { - // Move VPWidenIntOrFpInductionRecipes for optimized truncates to the - // phi section of HeaderVPBB. - assert(isa<TruncInst>(Instr)); + if (isa<VPHeaderPHIRecipe>(Recipe)) { + // VPHeaderPHIRecipes must be kept in the phi section of HeaderVPBB. 
In + // the following cases, VPHeaderPHIRecipes may be created after non-phi + // recipes and need to be moved to the phi section of HeaderVPBB: + // * tail-folding (non-phi recipes computing the header mask are + // introduced earlier than regular header phi recipes, and should appear + // after them) + // * Optimizing truncates to VPWidenIntOrFpInductionRecipe. + + assert((HeaderVPBB->getFirstNonPhi() == VPBB->end() || + CM.foldTailByMasking() || isa<TruncInst>(Instr)) && + "unexpected recipe needs moving"); Recipe->insertBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi()); } else VPBB->appendRecipe(Recipe); @@ -9040,7 +8837,7 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // and there is nothing to fix from vector loop; phis should have incoming // from scalar loop only. } else - addUsersInExitBlock(HeaderVPBB, MiddleVPBB, OrigLoop, *Plan); + addUsersInExitBlock(HeaderVPBB, OrigLoop, *Plan); assert(isa<VPRegionBlock>(Plan->getVectorLoopRegion()) && !Plan->getVectorLoopRegion()->getEntryBasicBlock()->empty() && @@ -9054,8 +8851,7 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // --------------------------------------------------------------------------- // Adjust the recipes for any inloop reductions. - adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExiting()), Plan, - RecipeBuilder, Range.Start); + adjustRecipesForReductions(LatchVPBB, Plan, RecipeBuilder, Range.Start); // Interleave memory: for each Interleave Group we marked earlier as relevant // for this VPlan, replace the Recipes widening its memory instructions with a @@ -9116,21 +8912,18 @@ std::optional<VPlanPtr> LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // Sink users of fixed-order recurrence past the recipe defining the previous // value and introduce FirstOrderRecurrenceSplice VPInstructions. 
if (!VPlanTransforms::adjustFixedOrderRecurrences(*Plan, Builder)) - return std::nullopt; - - VPlanTransforms::removeRedundantCanonicalIVs(*Plan); - VPlanTransforms::removeRedundantInductionCasts(*Plan); - - VPlanTransforms::optimizeInductions(*Plan, *PSE.getSE()); - VPlanTransforms::removeDeadRecipes(*Plan); - - VPlanTransforms::createAndOptimizeReplicateRegions(*Plan); - - VPlanTransforms::removeRedundantExpandSCEVRecipes(*Plan); - VPlanTransforms::mergeBlocksIntoPredecessors(*Plan); + return nullptr; - assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid"); - return std::make_optional(std::move(Plan)); + if (useActiveLaneMask(Style)) { + // TODO: Move checks to VPlanTransforms::addActiveLaneMask once + // TailFoldingStyle is visible there. + bool ForControlFlow = useActiveLaneMaskForControlFlow(Style); + bool WithoutRuntimeCheck = + Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck; + VPlanTransforms::addActiveLaneMask(*Plan, ForControlFlow, + WithoutRuntimeCheck); + } + return Plan; } VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { @@ -9164,8 +8957,11 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { Plan->getVectorLoopRegion()->getExitingBasicBlock()->getTerminator(); Term->eraseFromParent(); - addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), DebugLoc(), - CM.getTailFoldingStyle()); + // Tail folding is not supported for outer loops, so the induction increment + // is guaranteed to not wrap. 
+ bool HasNUW = true; + addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, + DebugLoc()); return Plan; } @@ -9177,105 +8973,211 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { void LoopVectorizationPlanner::adjustRecipesForReductions( VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) { - for (const auto &Reduction : CM.getInLoopReductionChains()) { - PHINode *Phi = Reduction.first; - const RecurrenceDescriptor &RdxDesc = - Legal->getReductionVars().find(Phi)->second; - const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second; + VPBasicBlock *Header = Plan->getVectorLoopRegion()->getEntryBasicBlock(); + // Gather all VPReductionPHIRecipe and sort them so that Intermediate stores + // sank outside of the loop would keep the same order as they had in the + // original loop. + SmallVector<VPReductionPHIRecipe *> ReductionPHIList; + for (VPRecipeBase &R : Header->phis()) { + if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) + ReductionPHIList.emplace_back(ReductionPhi); + } + bool HasIntermediateStore = false; + stable_sort(ReductionPHIList, + [this, &HasIntermediateStore](const VPReductionPHIRecipe *R1, + const VPReductionPHIRecipe *R2) { + auto *IS1 = R1->getRecurrenceDescriptor().IntermediateStore; + auto *IS2 = R2->getRecurrenceDescriptor().IntermediateStore; + HasIntermediateStore |= IS1 || IS2; + + // If neither of the recipes has an intermediate store, keep the + // order the same. + if (!IS1 && !IS2) + return false; + + // If only one of the recipes has an intermediate store, then + // move it towards the beginning of the list. + if (IS1 && !IS2) + return true; + + if (!IS1 && IS2) + return false; - if (MinVF.isScalar() && !CM.useOrderedReductions(RdxDesc)) + // If both recipes have an intermediate store, then the recipe + // with the later store should be processed earlier. So it + // should go to the beginning of the list. 
+ return DT->dominates(IS2, IS1); + }); + + if (HasIntermediateStore && ReductionPHIList.size() > 1) + for (VPRecipeBase *R : ReductionPHIList) + R->moveBefore(*Header, Header->getFirstNonPhi()); + + SmallVector<VPReductionPHIRecipe *> InLoopReductionPhis; + for (VPRecipeBase &R : Header->phis()) { + auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R); + if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered())) continue; + InLoopReductionPhis.push_back(PhiR); + } - // ReductionOperations are orders top-down from the phi's use to the - // LoopExitValue. We keep a track of the previous item (the Chain) to tell - // which of the two operands will remain scalar and which will be reduced. - // For minmax the chain will be the select instructions. - Instruction *Chain = Phi; - for (Instruction *R : ReductionOperations) { - VPRecipeBase *WidenRecipe = RecipeBuilder.getRecipe(R); - RecurKind Kind = RdxDesc.getRecurrenceKind(); + for (VPReductionPHIRecipe *PhiR : InLoopReductionPhis) { + const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor(); + RecurKind Kind = RdxDesc.getRecurrenceKind(); + assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) && + "AnyOf reductions are not allowed for in-loop reductions"); - VPValue *ChainOp = Plan->getVPValue(Chain); - unsigned FirstOpId; - assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) && - "Only min/max recurrences allowed for inloop reductions"); + // Collect the chain of "link" recipes for the reduction starting at PhiR. 
+ SetVector<VPRecipeBase *> Worklist; + Worklist.insert(PhiR); + for (unsigned I = 0; I != Worklist.size(); ++I) { + VPRecipeBase *Cur = Worklist[I]; + for (VPUser *U : Cur->getVPSingleValue()->users()) { + auto *UserRecipe = dyn_cast<VPRecipeBase>(U); + if (!UserRecipe) + continue; + assert(UserRecipe->getNumDefinedValues() == 1 && + "recipes must define exactly one result value"); + Worklist.insert(UserRecipe); + } + } + + // Visit operation "Links" along the reduction chain top-down starting from + // the phi until LoopExitValue. We keep track of the previous item + // (PreviousLink) to tell which of the two operands of a Link will remain + // scalar and which will be reduced. For minmax by select(cmp), Link will be + // the select instructions. + VPRecipeBase *PreviousLink = PhiR; // Aka Worklist[0]. + for (VPRecipeBase *CurrentLink : Worklist.getArrayRef().drop_front()) { + VPValue *PreviousLinkV = PreviousLink->getVPSingleValue(); + + Instruction *CurrentLinkI = CurrentLink->getUnderlyingInstr(); + + // Index of the first operand which holds a non-mask vector operand. + unsigned IndexOfFirstOperand; // Recognize a call to the llvm.fmuladd intrinsic. 
bool IsFMulAdd = (Kind == RecurKind::FMulAdd); - assert((!IsFMulAdd || RecurrenceDescriptor::isFMulAddIntrinsic(R)) && - "Expected instruction to be a call to the llvm.fmuladd intrinsic"); - if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { - assert(isa<VPWidenSelectRecipe>(WidenRecipe) && - "Expected to replace a VPWidenSelectSC"); - FirstOpId = 1; + VPValue *VecOp; + VPBasicBlock *LinkVPBB = CurrentLink->getParent(); + if (IsFMulAdd) { + assert( + RecurrenceDescriptor::isFMulAddIntrinsic(CurrentLinkI) && + "Expected instruction to be a call to the llvm.fmuladd intrinsic"); + assert(((MinVF.isScalar() && isa<VPReplicateRecipe>(CurrentLink)) || + isa<VPWidenCallRecipe>(CurrentLink)) && + CurrentLink->getOperand(2) == PreviousLinkV && + "expected a call where the previous link is the added operand"); + + // If the instruction is a call to the llvm.fmuladd intrinsic then we + // need to create an fmul recipe (multiplying the first two operands of + // the fmuladd together) to use as the vector operand for the fadd + // reduction. 
+ VPInstruction *FMulRecipe = new VPInstruction( + Instruction::FMul, + {CurrentLink->getOperand(0), CurrentLink->getOperand(1)}, + CurrentLinkI->getFastMathFlags()); + LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator()); + VecOp = FMulRecipe; } else { - assert((MinVF.isScalar() || isa<VPWidenRecipe>(WidenRecipe) || - (IsFMulAdd && isa<VPWidenCallRecipe>(WidenRecipe))) && - "Expected to replace a VPWidenSC"); - FirstOpId = 0; + if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { + if (isa<VPWidenRecipe>(CurrentLink)) { + assert(isa<CmpInst>(CurrentLinkI) && + "need to have the compare of the select"); + continue; + } + assert(isa<VPWidenSelectRecipe>(CurrentLink) && + "must be a select recipe"); + IndexOfFirstOperand = 1; + } else { + assert((MinVF.isScalar() || isa<VPWidenRecipe>(CurrentLink)) && + "Expected to replace a VPWidenSC"); + IndexOfFirstOperand = 0; + } + // Note that for non-commutable operands (cmp-selects), the semantics of + // the cmp-select are captured in the recurrence kind. + unsigned VecOpId = + CurrentLink->getOperand(IndexOfFirstOperand) == PreviousLinkV + ? IndexOfFirstOperand + 1 + : IndexOfFirstOperand; + VecOp = CurrentLink->getOperand(VecOpId); + assert(VecOp != PreviousLinkV && + CurrentLink->getOperand(CurrentLink->getNumOperands() - 1 - + (VecOpId - IndexOfFirstOperand)) == + PreviousLinkV && + "PreviousLinkV must be the operand other than VecOp"); } - unsigned VecOpId = - R->getOperand(FirstOpId) == Chain ? 
FirstOpId + 1 : FirstOpId; - VPValue *VecOp = Plan->getVPValue(R->getOperand(VecOpId)); + BasicBlock *BB = CurrentLinkI->getParent(); VPValue *CondOp = nullptr; - if (CM.blockNeedsPredicationForAnyReason(R->getParent())) { + if (CM.blockNeedsPredicationForAnyReason(BB)) { VPBuilder::InsertPointGuard Guard(Builder); - Builder.setInsertPoint(WidenRecipe->getParent(), - WidenRecipe->getIterator()); - CondOp = RecipeBuilder.createBlockInMask(R->getParent(), *Plan); + Builder.setInsertPoint(CurrentLink); + CondOp = RecipeBuilder.createBlockInMask(BB, *Plan); } - if (IsFMulAdd) { - // If the instruction is a call to the llvm.fmuladd intrinsic then we - // need to create an fmul recipe to use as the vector operand for the - // fadd reduction. - VPInstruction *FMulRecipe = new VPInstruction( - Instruction::FMul, {VecOp, Plan->getVPValue(R->getOperand(1))}); - FMulRecipe->setFastMathFlags(R->getFastMathFlags()); - WidenRecipe->getParent()->insert(FMulRecipe, - WidenRecipe->getIterator()); - VecOp = FMulRecipe; - } - VPReductionRecipe *RedRecipe = - new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, &TTI); - WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe); - Plan->removeVPValueFor(R); - Plan->addVPValue(R, RedRecipe); + VPReductionRecipe *RedRecipe = new VPReductionRecipe( + RdxDesc, CurrentLinkI, PreviousLinkV, VecOp, CondOp); // Append the recipe to the end of the VPBasicBlock because we need to // ensure that it comes after all of it's inputs, including CondOp. 
- WidenRecipe->getParent()->appendRecipe(RedRecipe); - WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe); - WidenRecipe->eraseFromParent(); - - if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { - VPRecipeBase *CompareRecipe = - RecipeBuilder.getRecipe(cast<Instruction>(R->getOperand(0))); - assert(isa<VPWidenRecipe>(CompareRecipe) && - "Expected to replace a VPWidenSC"); - assert(cast<VPWidenRecipe>(CompareRecipe)->getNumUsers() == 0 && - "Expected no remaining users"); - CompareRecipe->eraseFromParent(); - } - Chain = R; + // Note that this transformation may leave over dead recipes (including + // CurrentLink), which will be cleaned by a later VPlan transform. + LinkVPBB->appendRecipe(RedRecipe); + CurrentLink->getVPSingleValue()->replaceAllUsesWith(RedRecipe); + PreviousLink = RedRecipe; } } - - // If tail is folded by masking, introduce selects between the phi - // and the live-out instruction of each reduction, at the beginning of the - // dedicated latch block. - if (CM.foldTailByMasking()) { - Builder.setInsertPoint(LatchVPBB, LatchVPBB->begin()); + Builder.setInsertPoint(&*LatchVPBB->begin()); for (VPRecipeBase &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { - VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R); - if (!PhiR || PhiR->isInLoop()) - continue; + VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R); + if (!PhiR || PhiR->isInLoop()) + continue; + + const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor(); + auto *Result = PhiR->getBackedgeValue()->getDefiningRecipe(); + // If tail is folded by masking, introduce selects between the phi + // and the live-out instruction of each reduction, at the beginning of the + // dedicated latch block. 
+ if (CM.foldTailByMasking()) { VPValue *Cond = RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), *Plan); VPValue *Red = PhiR->getBackedgeValue(); assert(Red->getDefiningRecipe()->getParent() != LatchVPBB && "reduction recipe must be defined before latch"); - Builder.createNaryOp(Instruction::Select, {Cond, Red, PhiR}); + FastMathFlags FMFs = RdxDesc.getFastMathFlags(); + Type *PhiTy = PhiR->getOperand(0)->getLiveInIRValue()->getType(); + Result = + PhiTy->isFloatingPointTy() + ? new VPInstruction(Instruction::Select, {Cond, Red, PhiR}, FMFs) + : new VPInstruction(Instruction::Select, {Cond, Red, PhiR}); + Result->insertBefore(&*Builder.getInsertPoint()); + Red->replaceUsesWithIf( + Result->getVPSingleValue(), + [](VPUser &U, unsigned) { return isa<VPLiveOut>(&U); }); + if (PreferPredicatedReductionSelect || + TTI.preferPredicatedReductionSelect( + PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy, + TargetTransformInfo::ReductionFlags())) + PhiR->setOperand(1, Result->getVPSingleValue()); + } + // If the vector reduction can be performed in a smaller type, we truncate + // then extend the loop exit value to enable InstCombine to evaluate the + // entire expression in the smaller type. + Type *PhiTy = PhiR->getStartValue()->getLiveInIRValue()->getType(); + if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) { + assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!"); + Type *RdxTy = RdxDesc.getRecurrenceType(); + auto *Trunc = new VPWidenCastRecipe(Instruction::Trunc, + Result->getVPSingleValue(), RdxTy); + auto *Extnd = + RdxDesc.isSigned() + ? 
new VPWidenCastRecipe(Instruction::SExt, Trunc, PhiTy) + : new VPWidenCastRecipe(Instruction::ZExt, Trunc, PhiTy); + + Trunc->insertAfter(Result); + Extnd->insertAfter(Trunc); + Result->getVPSingleValue()->replaceAllUsesWith(Extnd); + Trunc->setOperand(0, Result->getVPSingleValue()); } } @@ -9313,107 +9215,6 @@ void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent, } #endif -void VPWidenIntOrFpInductionRecipe::execute(VPTransformState &State) { - assert(!State.Instance && "Int or FP induction being replicated."); - - Value *Start = getStartValue()->getLiveInIRValue(); - const InductionDescriptor &ID = getInductionDescriptor(); - TruncInst *Trunc = getTruncInst(); - IRBuilderBase &Builder = State.Builder; - assert(IV->getType() == ID.getStartValue()->getType() && "Types must match"); - assert(State.VF.isVector() && "must have vector VF"); - - // The value from the original loop to which we are mapping the new induction - // variable. - Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV; - - // Fast-math-flags propagate from the original induction instruction. - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - if (ID.getInductionBinOp() && isa<FPMathOperator>(ID.getInductionBinOp())) - Builder.setFastMathFlags(ID.getInductionBinOp()->getFastMathFlags()); - - // Now do the actual transformations, and start with fetching the step value. 
- Value *Step = State.get(getStepValue(), VPIteration(0, 0)); - - assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) && - "Expected either an induction phi-node or a truncate of it!"); - - // Construct the initial value of the vector IV in the vector loop preheader - auto CurrIP = Builder.saveIP(); - BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); - Builder.SetInsertPoint(VectorPH->getTerminator()); - if (isa<TruncInst>(EntryVal)) { - assert(Start->getType()->isIntegerTy() && - "Truncation requires an integer type"); - auto *TruncType = cast<IntegerType>(EntryVal->getType()); - Step = Builder.CreateTrunc(Step, TruncType); - Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); - } - - Value *Zero = getSignedIntOrFpConstant(Start->getType(), 0); - Value *SplatStart = Builder.CreateVectorSplat(State.VF, Start); - Value *SteppedStart = getStepVector( - SplatStart, Zero, Step, ID.getInductionOpcode(), State.VF, State.Builder); - - // We create vector phi nodes for both integer and floating-point induction - // variables. Here, we determine the kind of arithmetic we will perform. - Instruction::BinaryOps AddOp; - Instruction::BinaryOps MulOp; - if (Step->getType()->isIntegerTy()) { - AddOp = Instruction::Add; - MulOp = Instruction::Mul; - } else { - AddOp = ID.getInductionOpcode(); - MulOp = Instruction::FMul; - } - - // Multiply the vectorization factor by the step using integer or - // floating-point arithmetic as appropriate. - Type *StepType = Step->getType(); - Value *RuntimeVF; - if (Step->getType()->isFloatingPointTy()) - RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, State.VF); - else - RuntimeVF = getRuntimeVF(Builder, StepType, State.VF); - Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF); - - // Create a vector splat to use in the induction update. - // - // FIXME: If the step is non-constant, we create the vector splat with - // IRBuilder. 
IRBuilder can constant-fold the multiply, but it doesn't - // handle a constant vector splat. - Value *SplatVF = isa<Constant>(Mul) - ? ConstantVector::getSplat(State.VF, cast<Constant>(Mul)) - : Builder.CreateVectorSplat(State.VF, Mul); - Builder.restoreIP(CurrIP); - - // We may need to add the step a number of times, depending on the unroll - // factor. The last of those goes into the PHI. - PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", - &*State.CFG.PrevBB->getFirstInsertionPt()); - VecInd->setDebugLoc(EntryVal->getDebugLoc()); - Instruction *LastInduction = VecInd; - for (unsigned Part = 0; Part < State.UF; ++Part) { - State.set(this, LastInduction, Part); - - if (isa<TruncInst>(EntryVal)) - State.addMetadata(LastInduction, EntryVal); - - LastInduction = cast<Instruction>( - Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add")); - LastInduction->setDebugLoc(EntryVal->getDebugLoc()); - } - - LastInduction->setName("vec.ind.next"); - VecInd->addIncoming(SteppedStart, VectorPH); - // Add induction update using an incorrect block temporarily. The phi node - // will be fixed after VPlan execution. Note that at this point the latch - // block cannot be used, as it does not exist yet. - // TODO: Model increment value in VPlan, by turning the recipe into a - // multi-def and a subclass of VPHeaderPHIRecipe. 
- VecInd->addIncoming(LastInduction, VectorPH); -} - void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { assert(IndDesc.getKind() == InductionDescriptor::IK_PtrInduction && "Not a pointer induction according to InductionDescriptor!"); @@ -9446,7 +9247,8 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { Value *Step = State.get(getOperand(1), VPIteration(Part, Lane)); Value *SclrGep = emitTransformedIndex( - State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, IndDesc); + State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, + IndDesc.getKind(), IndDesc.getInductionBinOp()); SclrGep->setName("next.gep"); State.set(this, SclrGep, VPIteration(Part, Lane)); } @@ -9513,41 +9315,26 @@ void VPDerivedIVRecipe::execute(VPTransformState &State) { // Fast-math-flags propagate from the original induction instruction. IRBuilder<>::FastMathFlagGuard FMFG(State.Builder); - if (IndDesc.getInductionBinOp() && - isa<FPMathOperator>(IndDesc.getInductionBinOp())) - State.Builder.setFastMathFlags( - IndDesc.getInductionBinOp()->getFastMathFlags()); + if (FPBinOp) + State.Builder.setFastMathFlags(FPBinOp->getFastMathFlags()); Value *Step = State.get(getStepValue(), VPIteration(0, 0)); Value *CanonicalIV = State.get(getCanonicalIV(), VPIteration(0, 0)); - Value *DerivedIV = - emitTransformedIndex(State.Builder, CanonicalIV, - getStartValue()->getLiveInIRValue(), Step, IndDesc); + Value *DerivedIV = emitTransformedIndex( + State.Builder, CanonicalIV, getStartValue()->getLiveInIRValue(), Step, + Kind, cast_if_present<BinaryOperator>(FPBinOp)); DerivedIV->setName("offset.idx"); - if (ResultTy != DerivedIV->getType()) { - assert(Step->getType()->isIntegerTy() && + if (TruncResultTy) { + assert(TruncResultTy != DerivedIV->getType() && + Step->getType()->isIntegerTy() && "Truncation requires an integer step"); - DerivedIV = State.Builder.CreateTrunc(DerivedIV, ResultTy); + DerivedIV = State.Builder.CreateTrunc(DerivedIV, TruncResultTy); } 
assert(DerivedIV != CanonicalIV && "IV didn't need transforming?"); State.set(this, DerivedIV, VPIteration(0, 0)); } -void VPScalarIVStepsRecipe::execute(VPTransformState &State) { - // Fast-math-flags propagate from the original induction instruction. - IRBuilder<>::FastMathFlagGuard FMFG(State.Builder); - if (IndDesc.getInductionBinOp() && - isa<FPMathOperator>(IndDesc.getInductionBinOp())) - State.Builder.setFastMathFlags( - IndDesc.getInductionBinOp()->getFastMathFlags()); - - Value *BaseIV = State.get(getOperand(0), VPIteration(0, 0)); - Value *Step = State.get(getStepValue(), VPIteration(0, 0)); - - buildScalarSteps(BaseIV, Step, IndDesc, this, State); -} - void VPInterleaveRecipe::execute(VPTransformState &State) { assert(!State.Instance && "Interleave group being replicated."); State.ILV->vectorizeInterleaveGroup(IG, definedValues(), State, getAddr(), @@ -9558,48 +9345,51 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { void VPReductionRecipe::execute(VPTransformState &State) { assert(!State.Instance && "Reduction being replicated."); Value *PrevInChain = State.get(getChainOp(), 0); - RecurKind Kind = RdxDesc->getRecurrenceKind(); - bool IsOrdered = State.ILV->useOrderedReductions(*RdxDesc); + RecurKind Kind = RdxDesc.getRecurrenceKind(); + bool IsOrdered = State.ILV->useOrderedReductions(RdxDesc); // Propagate the fast-math flags carried by the underlying instruction. 
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(RdxDesc->getFastMathFlags()); + State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); for (unsigned Part = 0; Part < State.UF; ++Part) { Value *NewVecOp = State.get(getVecOp(), Part); if (VPValue *Cond = getCondOp()) { - Value *NewCond = State.get(Cond, Part); - VectorType *VecTy = cast<VectorType>(NewVecOp->getType()); - Value *Iden = RdxDesc->getRecurrenceIdentity( - Kind, VecTy->getElementType(), RdxDesc->getFastMathFlags()); - Value *IdenVec = - State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden); - Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, IdenVec); + Value *NewCond = State.VF.isVector() ? State.get(Cond, Part) + : State.get(Cond, {Part, 0}); + VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType()); + Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType(); + Value *Iden = RdxDesc.getRecurrenceIdentity(Kind, ElementTy, + RdxDesc.getFastMathFlags()); + if (State.VF.isVector()) { + Iden = + State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden); + } + + Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, Iden); NewVecOp = Select; } Value *NewRed; Value *NextInChain; if (IsOrdered) { if (State.VF.isVector()) - NewRed = createOrderedReduction(State.Builder, *RdxDesc, NewVecOp, + NewRed = createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain); else NewRed = State.Builder.CreateBinOp( - (Instruction::BinaryOps)RdxDesc->getOpcode(Kind), PrevInChain, + (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), PrevInChain, NewVecOp); PrevInChain = NewRed; } else { PrevInChain = State.get(getChainOp(), Part); - NewRed = createTargetReduction(State.Builder, TTI, *RdxDesc, NewVecOp); + NewRed = createTargetReduction(State.Builder, RdxDesc, NewVecOp); } if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { - NextInChain = - createMinMaxOp(State.Builder, RdxDesc->getRecurrenceKind(), - NewRed, 
PrevInChain); + NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(), + NewRed, PrevInChain); } else if (IsOrdered) NextInChain = NewRed; else NextInChain = State.Builder.CreateBinOp( - (Instruction::BinaryOps)RdxDesc->getOpcode(Kind), NewRed, - PrevInChain); + (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), NewRed, PrevInChain); State.set(this, NextInChain, Part); } } @@ -9618,7 +9408,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) { VectorType::get(UI->getType(), State.VF)); State.set(this, Poison, State.Instance->Part); } - State.ILV->packScalarIntoVectorValue(this, *State.Instance, State); + State.packScalarIntoVectorValue(this, *State.Instance); } return; } @@ -9684,9 +9474,16 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { auto &Builder = State.Builder; InnerLoopVectorizer::VectorParts BlockInMaskParts(State.UF); bool isMaskRequired = getMask(); - if (isMaskRequired) - for (unsigned Part = 0; Part < State.UF; ++Part) - BlockInMaskParts[Part] = State.get(getMask(), Part); + if (isMaskRequired) { + // Mask reversal is only neede for non-all-one (null) masks, as reverse of a + // null all-one mask is a null mask. + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *Mask = State.get(getMask(), Part); + if (isReverse()) + Mask = Builder.CreateVectorReverse(Mask, "reverse"); + BlockInMaskParts[Part] = Mask; + } + } const auto CreateVecPtr = [&](unsigned Part, Value *Ptr) -> Value * { // Calculate the pointer for the specific unroll-part. @@ -9697,7 +9494,8 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout(); Type *IndexTy = State.VF.isScalable() && (isReverse() || Part > 0) - ? DL.getIndexType(ScalarDataTy->getPointerTo()) + ? 
DL.getIndexType(PointerType::getUnqual( + ScalarDataTy->getContext())) : Builder.getInt32Ty(); bool InBounds = false; if (auto *gep = dyn_cast<GetElementPtrInst>(Ptr->stripPointerCasts())) @@ -9717,21 +9515,17 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, NumElt, "", InBounds); PartPtr = Builder.CreateGEP(ScalarDataTy, PartPtr, LastLane, "", InBounds); - if (isMaskRequired) // Reverse of a null all-one mask is a null mask. - BlockInMaskParts[Part] = - Builder.CreateVectorReverse(BlockInMaskParts[Part], "reverse"); } else { Value *Increment = createStepForVF(Builder, IndexTy, State.VF, Part); PartPtr = Builder.CreateGEP(ScalarDataTy, Ptr, Increment, "", InBounds); } - unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); - return Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); + return PartPtr; }; // Handle Stores: if (SI) { - State.setDebugLocFromInst(SI); + State.setDebugLocFrom(SI->getDebugLoc()); for (unsigned Part = 0; Part < State.UF; ++Part) { Instruction *NewSI = nullptr; @@ -9764,7 +9558,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { // Handle loads. assert(LI && "Must have a load instruction"); - State.setDebugLocFromInst(LI); + State.setDebugLocFrom(LI->getDebugLoc()); for (unsigned Part = 0; Part < State.UF; ++Part) { Value *NewLI; if (CreateGatherScatter) { @@ -9843,95 +9637,6 @@ static ScalarEpilogueLowering getScalarEpilogueLowering( return CM_ScalarEpilogueAllowed; } -Value *VPTransformState::get(VPValue *Def, unsigned Part) { - // If Values have been set for this Def return the one relevant for \p Part. - if (hasVectorValue(Def, Part)) - return Data.PerPartOutput[Def][Part]; - - auto GetBroadcastInstrs = [this, Def](Value *V) { - bool SafeToHoist = Def->isDefinedOutsideVectorRegions(); - if (VF.isScalar()) - return V; - // Place the code for broadcasting invariant variables in the new preheader. 
- IRBuilder<>::InsertPointGuard Guard(Builder); - if (SafeToHoist) { - BasicBlock *LoopVectorPreHeader = CFG.VPBB2IRBB[cast<VPBasicBlock>( - Plan->getVectorLoopRegion()->getSinglePredecessor())]; - if (LoopVectorPreHeader) - Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); - } - - // Place the code for broadcasting invariant variables in the new preheader. - // Broadcast the scalar into all locations in the vector. - Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast"); - - return Shuf; - }; - - if (!hasScalarValue(Def, {Part, 0})) { - Value *IRV = Def->getLiveInIRValue(); - Value *B = GetBroadcastInstrs(IRV); - set(Def, B, Part); - return B; - } - - Value *ScalarValue = get(Def, {Part, 0}); - // If we aren't vectorizing, we can just copy the scalar map values over - // to the vector map. - if (VF.isScalar()) { - set(Def, ScalarValue, Part); - return ScalarValue; - } - - bool IsUniform = vputils::isUniformAfterVectorization(Def); - - unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1; - // Check if there is a scalar value for the selected lane. - if (!hasScalarValue(Def, {Part, LastLane})) { - // At the moment, VPWidenIntOrFpInductionRecipes, VPScalarIVStepsRecipes and - // VPExpandSCEVRecipes can also be uniform. - assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) || - isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe()) || - isa<VPExpandSCEVRecipe>(Def->getDefiningRecipe())) && - "unexpected recipe found to be invariant"); - IsUniform = true; - LastLane = 0; - } - - auto *LastInst = cast<Instruction>(get(Def, {Part, LastLane})); - // Set the insert point after the last scalarized instruction or after the - // last PHI, if LastInst is a PHI. This ensures the insertelement sequence - // will directly follow the scalar definitions. - auto OldIP = Builder.saveIP(); - auto NewIP = - isa<PHINode>(LastInst) - ? 
BasicBlock::iterator(LastInst->getParent()->getFirstNonPHI()) - : std::next(BasicBlock::iterator(LastInst)); - Builder.SetInsertPoint(&*NewIP); - - // However, if we are vectorizing, we need to construct the vector values. - // If the value is known to be uniform after vectorization, we can just - // broadcast the scalar value corresponding to lane zero for each unroll - // iteration. Otherwise, we construct the vector values using - // insertelement instructions. Since the resulting vectors are stored in - // State, we will only generate the insertelements once. - Value *VectorValue = nullptr; - if (IsUniform) { - VectorValue = GetBroadcastInstrs(ScalarValue); - set(Def, VectorValue, Part); - } else { - // Initialize packing with insertelements to start from undef. - assert(!VF.isScalable() && "VF is assumed to be non scalable."); - Value *Undef = PoisonValue::get(VectorType::get(LastInst->getType(), VF)); - set(Def, Undef, Part); - for (unsigned Lane = 0; Lane < VF.getKnownMinValue(); ++Lane) - ILV->packScalarIntoVectorValue(Def, {Part, Lane}, *this); - VectorValue = get(Def, Part); - } - Builder.restoreIP(OldIP); - return VectorValue; -} - // Process the loop in the VPlan-native vectorization path. This path builds // VPlan upfront in the vectorization pipeline, which allows to apply // VPlan-to-VPlan transformations from the very beginning without modifying the @@ -9960,7 +9665,8 @@ static bool processLoopInVPlanNativePath( // Use the planner for outer loop vectorization. // TODO: CM is not used at this point inside the planner. Turn CM into an // optional argument if we don't need it in the future. - LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, LVL, CM, IAI, PSE, Hints, ORE); + LoopVectorizationPlanner LVP(L, LI, DT, TLI, *TTI, LVL, CM, IAI, PSE, Hints, + ORE); // Get user vectorization factor. 
ElementCount UserVF = Hints.getWidth(); @@ -9979,8 +9685,10 @@ static bool processLoopInVPlanNativePath( VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); { + bool AddBranchWeights = + hasBranchWeightMD(*L->getLoopLatch()->getTerminator()); GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI, - F->getParent()->getDataLayout()); + F->getParent()->getDataLayout(), AddBranchWeights); InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, VF.Width, 1, LVL, &CM, BFI, PSI, Checks); LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \"" @@ -9988,6 +9696,8 @@ static bool processLoopInVPlanNativePath( LVP.executePlan(VF.Width, 1, BestPlan, LB, DT, false); } + reportVectorization(ORE, L, VF, 1); + // Mark the loop as already vectorized to avoid vectorizing again. Hints.setAlreadyVectorized(); assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs())); @@ -10042,7 +9752,8 @@ static void checkMixedPrecision(Loop *L, OptimizationRemarkEmitter *ORE) { static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks, VectorizationFactor &VF, std::optional<unsigned> VScale, Loop *L, - ScalarEvolution &SE) { + ScalarEvolution &SE, + ScalarEpilogueLowering SEL) { InstructionCost CheckCost = Checks.getCost(); if (!CheckCost.isValid()) return false; @@ -10112,11 +9823,13 @@ static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks, // RtC < ScalarC * TC * (1 / X) ==> RtC * X / ScalarC < TC double MinTC2 = RtC * 10 / ScalarC; - // Now pick the larger minimum. If it is not a multiple of VF, choose the - // next closest multiple of VF. This should partly compensate for ignoring - // the epilogue cost. + // Now pick the larger minimum. If it is not a multiple of VF and a scalar + // epilogue is allowed, choose the next closest multiple of VF. This should + // partly compensate for ignoring the epilogue cost. 
uint64_t MinTC = std::ceil(std::max(MinTC1, MinTC2)); - VF.MinProfitableTripCount = ElementCount::getFixed(alignTo(MinTC, IntVF)); + if (SEL == CM_ScalarEpilogueAllowed) + MinTC = alignTo(MinTC, IntVF); + VF.MinProfitableTripCount = ElementCount::getFixed(MinTC); LLVM_DEBUG( dbgs() << "LV: Minimum required TC for runtime checks to be profitable:" @@ -10236,7 +9949,14 @@ bool LoopVectorizePass::processLoop(Loop *L) { else { if (*ExpectedTC > TTI->getMinTripCountTailFoldingThreshold()) { LLVM_DEBUG(dbgs() << "\n"); - SEL = CM_ScalarEpilogueNotAllowedLowTripLoop; + // Predicate tail-folded loops are efficient even when the loop + // iteration count is low. However, setting the epilogue policy to + // `CM_ScalarEpilogueNotAllowedLowTripLoop` prevents vectorizing loops + // with runtime checks. It's more effective to let + // `areRuntimeChecksProfitable` determine if vectorization is beneficial + // for the loop. + if (SEL != CM_ScalarEpilogueNotNeededUsePredicate) + SEL = CM_ScalarEpilogueNotAllowedLowTripLoop; } else { LLVM_DEBUG(dbgs() << " But the target considers the trip count too " "small to consider vectorizing.\n"); @@ -10300,7 +10020,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { LoopVectorizationCostModel CM(SEL, L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, &Hints, IAI); // Use the planner for vectorization. - LoopVectorizationPlanner LVP(L, LI, TLI, *TTI, &LVL, CM, IAI, PSE, Hints, + LoopVectorizationPlanner LVP(L, LI, DT, TLI, *TTI, &LVL, CM, IAI, PSE, Hints, ORE); // Get user vectorization factor and interleave count. @@ -10313,8 +10033,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { VectorizationFactor VF = VectorizationFactor::Disabled(); unsigned IC = 1; + bool AddBranchWeights = + hasBranchWeightMD(*L->getLoopLatch()->getTerminator()); GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI, - F->getParent()->getDataLayout()); + F->getParent()->getDataLayout(), AddBranchWeights); if (MaybeVF) { VF = *MaybeVF; // Select the interleave count. 
@@ -10331,7 +10053,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { Hints.getForce() == LoopVectorizeHints::FK_Enabled; if (!ForceVectorization && !areRuntimeChecksProfitable(Checks, VF, getVScaleForTuning(L, *TTI), L, - *PSE.getSE())) { + *PSE.getSE(), SEL)) { ORE->emit([&]() { return OptimizationRemarkAnalysisAliasing( DEBUG_TYPE, "CantReorderMemOps", L->getStartLoc(), @@ -10553,13 +10275,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { DisableRuntimeUnroll = true; } // Report the vectorization decision. - ORE->emit([&]() { - return OptimizationRemark(LV_NAME, "Vectorized", L->getStartLoc(), - L->getHeader()) - << "vectorized loop (vectorization width: " - << NV("VectorizationFactor", VF.Width) - << ", interleaved count: " << NV("InterleaveCount", IC) << ")"; - }); + reportVectorization(ORE, L, VF, IC); } if (ORE->allowExtraAnalysis(LV_NAME)) @@ -10642,8 +10358,14 @@ LoopVectorizeResult LoopVectorizePass::runImpl( Changed |= CFGChanged |= processLoop(L); - if (Changed) + if (Changed) { LAIs->clear(); + +#ifndef NDEBUG + if (VerifySCEV) + SE->verify(); +#endif + } } // Process each loop nest in the function. 
@@ -10691,10 +10413,6 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, PA.preserve<LoopAnalysis>(); PA.preserve<DominatorTreeAnalysis>(); PA.preserve<ScalarEvolutionAnalysis>(); - -#ifdef EXPENSIVE_CHECKS - SE.verify(); -#endif } if (Result.MadeCFGChange) { diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 821a3fa22a85..fe2aac78e5ab 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -19,7 +19,6 @@ #include "llvm/Transforms/Vectorize/SLPVectorizer.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/PriorityQueue.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" @@ -34,6 +33,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/IVDescriptors.h" @@ -97,7 +97,6 @@ #include <string> #include <tuple> #include <utility> -#include <vector> using namespace llvm; using namespace llvm::PatternMatch; @@ -108,8 +107,9 @@ using namespace slpvectorizer; STATISTIC(NumVectorInstructions, "Number of vector instructions generated"); -cl::opt<bool> RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden, - cl::desc("Run the SLP vectorization passes")); +static cl::opt<bool> + RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden, + cl::desc("Run the SLP vectorization passes")); static cl::opt<int> SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden, @@ -140,10 +140,6 @@ static cl::opt<unsigned> MaxVFOption("slp-max-vf", cl::init(0), cl::Hidden, cl::desc("Maximum SLP vectorization factor (0=unlimited)")); -static cl::opt<int> -MaxStoreLookup("slp-max-store-lookup", cl::init(32), cl::Hidden, - cl::desc("Maximum 
depth of the lookup for consecutive stores.")); - /// Limits the size of scheduling regions in a block. /// It avoid long compile times for _very_ large blocks where vector /// instructions are spread over a wide range. @@ -232,6 +228,17 @@ static bool isVectorLikeInstWithConstOps(Value *V) { return isConstant(I->getOperand(2)); } +#if !defined(NDEBUG) +/// Print a short descriptor of the instruction bundle suitable for debug output. +static std::string shortBundleName(ArrayRef<Value *> VL) { + std::string Result; + raw_string_ostream OS(Result); + OS << "n=" << VL.size() << " [" << *VL.front() << ", ..]"; + OS.flush(); + return Result; +} +#endif + /// \returns true if all of the instructions in \p VL are in the same block or /// false otherwise. static bool allSameBlock(ArrayRef<Value *> VL) { @@ -429,26 +436,6 @@ static SmallBitVector isUndefVector(const Value *V, /// i32 6> /// %2 = mul <4 x i8> %1, %1 /// ret <4 x i8> %2 -/// We convert this initially to something like: -/// %x0 = extractelement <4 x i8> %x, i32 0 -/// %x3 = extractelement <4 x i8> %x, i32 3 -/// %y1 = extractelement <4 x i8> %y, i32 1 -/// %y2 = extractelement <4 x i8> %y, i32 2 -/// %1 = insertelement <4 x i8> poison, i8 %x0, i32 0 -/// %2 = insertelement <4 x i8> %1, i8 %x3, i32 1 -/// %3 = insertelement <4 x i8> %2, i8 %y1, i32 2 -/// %4 = insertelement <4 x i8> %3, i8 %y2, i32 3 -/// %5 = mul <4 x i8> %4, %4 -/// %6 = extractelement <4 x i8> %5, i32 0 -/// %ins1 = insertelement <4 x i8> poison, i8 %6, i32 0 -/// %7 = extractelement <4 x i8> %5, i32 1 -/// %ins2 = insertelement <4 x i8> %ins1, i8 %7, i32 1 -/// %8 = extractelement <4 x i8> %5, i32 2 -/// %ins3 = insertelement <4 x i8> %ins2, i8 %8, i32 2 -/// %9 = extractelement <4 x i8> %5, i32 3 -/// %ins4 = insertelement <4 x i8> %ins3, i8 %9, i32 3 -/// ret <4 x i8> %ins4 -/// InstCombiner transforms this into a shuffle and vector mul /// Mask will return the Shuffle Mask equivalent to the extracted elements. 
/// TODO: Can we split off and reuse the shuffle mask detection from /// ShuffleVectorInst/getShuffleCost? @@ -539,117 +526,6 @@ static std::optional<unsigned> getExtractIndex(Instruction *E) { return *EI->idx_begin(); } -/// Tries to find extractelement instructions with constant indices from fixed -/// vector type and gather such instructions into a bunch, which highly likely -/// might be detected as a shuffle of 1 or 2 input vectors. If this attempt was -/// successful, the matched scalars are replaced by poison values in \p VL for -/// future analysis. -static std::optional<TTI::ShuffleKind> -tryToGatherExtractElements(SmallVectorImpl<Value *> &VL, - SmallVectorImpl<int> &Mask) { - // Scan list of gathered scalars for extractelements that can be represented - // as shuffles. - MapVector<Value *, SmallVector<int>> VectorOpToIdx; - SmallVector<int> UndefVectorExtracts; - for (int I = 0, E = VL.size(); I < E; ++I) { - auto *EI = dyn_cast<ExtractElementInst>(VL[I]); - if (!EI) { - if (isa<UndefValue>(VL[I])) - UndefVectorExtracts.push_back(I); - continue; - } - auto *VecTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType()); - if (!VecTy || !isa<ConstantInt, UndefValue>(EI->getIndexOperand())) - continue; - std::optional<unsigned> Idx = getExtractIndex(EI); - // Undefined index. - if (!Idx) { - UndefVectorExtracts.push_back(I); - continue; - } - SmallBitVector ExtractMask(VecTy->getNumElements(), true); - ExtractMask.reset(*Idx); - if (isUndefVector(EI->getVectorOperand(), ExtractMask).all()) { - UndefVectorExtracts.push_back(I); - continue; - } - VectorOpToIdx[EI->getVectorOperand()].push_back(I); - } - // Sort the vector operands by the maximum number of uses in extractelements. 
- MapVector<unsigned, SmallVector<Value *>> VFToVector; - for (const auto &Data : VectorOpToIdx) - VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()] - .push_back(Data.first); - for (auto &Data : VFToVector) { - stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) { - return VectorOpToIdx.find(V1)->second.size() > - VectorOpToIdx.find(V2)->second.size(); - }); - } - // Find the best pair of the vectors with the same number of elements or a - // single vector. - const int UndefSz = UndefVectorExtracts.size(); - unsigned SingleMax = 0; - Value *SingleVec = nullptr; - unsigned PairMax = 0; - std::pair<Value *, Value *> PairVec(nullptr, nullptr); - for (auto &Data : VFToVector) { - Value *V1 = Data.second.front(); - if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) { - SingleMax = VectorOpToIdx[V1].size() + UndefSz; - SingleVec = V1; - } - Value *V2 = nullptr; - if (Data.second.size() > 1) - V2 = *std::next(Data.second.begin()); - if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + - UndefSz) { - PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz; - PairVec = std::make_pair(V1, V2); - } - } - if (SingleMax == 0 && PairMax == 0 && UndefSz == 0) - return std::nullopt; - // Check if better to perform a shuffle of 2 vectors or just of a single - // vector. - SmallVector<Value *> SavedVL(VL.begin(), VL.end()); - SmallVector<Value *> GatheredExtracts( - VL.size(), PoisonValue::get(VL.front()->getType())); - if (SingleMax >= PairMax && SingleMax) { - for (int Idx : VectorOpToIdx[SingleVec]) - std::swap(GatheredExtracts[Idx], VL[Idx]); - } else { - for (Value *V : {PairVec.first, PairVec.second}) - for (int Idx : VectorOpToIdx[V]) - std::swap(GatheredExtracts[Idx], VL[Idx]); - } - // Add extracts from undefs too. 
- for (int Idx : UndefVectorExtracts) - std::swap(GatheredExtracts[Idx], VL[Idx]); - // Check that gather of extractelements can be represented as just a - // shuffle of a single/two vectors the scalars are extracted from. - std::optional<TTI::ShuffleKind> Res = - isFixedVectorShuffle(GatheredExtracts, Mask); - if (!Res) { - // TODO: try to check other subsets if possible. - // Restore the original VL if attempt was not successful. - VL.swap(SavedVL); - return std::nullopt; - } - // Restore unused scalars from mask, if some of the extractelements were not - // selected for shuffle. - for (int I = 0, E = GatheredExtracts.size(); I < E; ++I) { - auto *EI = dyn_cast<ExtractElementInst>(VL[I]); - if (!EI || !isa<FixedVectorType>(EI->getVectorOperandType()) || - !isa<ConstantInt, UndefValue>(EI->getIndexOperand()) || - is_contained(UndefVectorExtracts, I)) - continue; - if (Mask[I] == PoisonMaskElem && !isa<PoisonValue>(GatheredExtracts[I])) - std::swap(VL[I], GatheredExtracts[I]); - } - return Res; -} - namespace { /// Main data required for vectorization of instructions. @@ -695,7 +571,7 @@ static Value *isOneOf(const InstructionsState &S, Value *Op) { return S.OpValue; } -/// \returns true if \p Opcode is allowed as part of of the main/alternate +/// \returns true if \p Opcode is allowed as part of the main/alternate /// instruction for SLP vectorization. /// /// Example of unsupported opcode is SDIV that can potentially cause UB if the @@ -889,18 +765,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, /// \returns true if all of the values in \p VL have the same type or false /// otherwise. static bool allSameType(ArrayRef<Value *> VL) { - Type *Ty = VL[0]->getType(); - for (int i = 1, e = VL.size(); i < e; i++) - if (VL[i]->getType() != Ty) - return false; - - return true; + Type *Ty = VL.front()->getType(); + return all_of(VL.drop_front(), [&](Value *V) { return V->getType() == Ty; }); } /// \returns True if in-tree use also needs extract. 
This refers to /// possible scalar operand in vectorized instruction. -static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, - TargetLibraryInfo *TLI) { +static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, + TargetLibraryInfo *TLI) { unsigned Opcode = UserInst->getOpcode(); switch (Opcode) { case Instruction::Load: { @@ -914,11 +786,10 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, case Instruction::Call: { CallInst *CI = cast<CallInst>(UserInst); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) { - if (isVectorIntrinsicWithScalarOpAtArg(ID, i)) - return (CI->getArgOperand(i) == Scalar); - } - [[fallthrough]]; + return any_of(enumerate(CI->args()), [&](auto &&Arg) { + return isVectorIntrinsicWithScalarOpAtArg(ID, Arg.index()) && + Arg.value().get() == Scalar; + }); } default: return false; @@ -1181,6 +1052,7 @@ public: void deleteTree() { VectorizableTree.clear(); ScalarToTreeEntry.clear(); + MultiNodeScalars.clear(); MustGather.clear(); EntryToLastInstruction.clear(); ExternalUses.clear(); @@ -1273,7 +1145,7 @@ public: /// {{{i16, i16}, {i16, i16}}, {{i16, i16}, {i16, i16}}} and so on. /// /// \returns number of elements in vector if isomorphism exists, 0 otherwise. - unsigned canMapToVector(Type *T, const DataLayout &DL) const; + unsigned canMapToVector(Type *T) const; /// \returns True if the VectorizableTree is both tiny and not fully /// vectorizable. We do not vectorize such trees. @@ -1324,6 +1196,9 @@ public: } LLVM_DUMP_METHOD void dump() const { dump(dbgs()); } #endif + bool operator == (const EdgeInfo &Other) const { + return UserTE == Other.UserTE && EdgeIdx == Other.EdgeIdx; + } }; /// A helper class used for scoring candidates for two consecutive lanes. 
@@ -1764,7 +1639,7 @@ public: auto *IdxLaneI = dyn_cast<Instruction>(IdxLaneV); if (!IdxLaneI || !isa<Instruction>(OpIdxLaneV)) return 0; - return R.areAllUsersVectorized(IdxLaneI, std::nullopt) + return R.areAllUsersVectorized(IdxLaneI) ? LookAheadHeuristics::ScoreAllUserVectorized : 0; } @@ -1941,7 +1816,7 @@ public: HashMap[NumFreeOpsHash.Hash] = std::make_pair(1, Lane); } else if (NumFreeOpsHash.NumOfAPOs == Min && NumFreeOpsHash.NumOpsWithSameOpcodeParent == SameOpNumber) { - auto It = HashMap.find(NumFreeOpsHash.Hash); + auto *It = HashMap.find(NumFreeOpsHash.Hash); if (It == HashMap.end()) HashMap[NumFreeOpsHash.Hash] = std::make_pair(1, Lane); else @@ -2203,7 +2078,7 @@ public: for (int Pass = 0; Pass != 2; ++Pass) { // Check if no need to reorder operands since they're are perfect or // shuffled diamond match. - // Need to to do it to avoid extra external use cost counting for + // Need to do it to avoid extra external use cost counting for // shuffled matches, which may cause regressions. if (SkipReordering()) break; @@ -2388,6 +2263,18 @@ public: ~BoUpSLP(); private: + /// Determine if a vectorized value \p V in can be demoted to + /// a smaller type with a truncation. We collect the values that will be + /// demoted in ToDemote and additional roots that require investigating in + /// Roots. + /// \param DemotedConsts list of Instruction/OperandIndex pairs that are + /// constant and to be demoted. Required to correctly identify constant nodes + /// to be demoted. + bool collectValuesToDemote( + Value *V, SmallVectorImpl<Value *> &ToDemote, + DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts, + SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const; + /// Check if the operands on the edges \p Edges of the \p UserTE allows /// reordering (i.e. the operands can be reordered because they have only one /// user and reordarable). 
@@ -2410,12 +2297,25 @@ private: TreeEntry *getVectorizedOperand(TreeEntry *UserTE, unsigned OpIdx) { ArrayRef<Value *> VL = UserTE->getOperand(OpIdx); TreeEntry *TE = nullptr; - const auto *It = find_if(VL, [this, &TE](Value *V) { + const auto *It = find_if(VL, [&](Value *V) { TE = getTreeEntry(V); - return TE; + if (TE && is_contained(TE->UserTreeIndices, EdgeInfo(UserTE, OpIdx))) + return true; + auto It = MultiNodeScalars.find(V); + if (It != MultiNodeScalars.end()) { + for (TreeEntry *E : It->second) { + if (is_contained(E->UserTreeIndices, EdgeInfo(UserTE, OpIdx))) { + TE = E; + return true; + } + } + } + return false; }); - if (It != VL.end() && TE->isSame(VL)) + if (It != VL.end()) { + assert(TE->isSame(VL) && "Expected same scalars."); return TE; + } return nullptr; } @@ -2428,13 +2328,16 @@ private: } /// Checks if all users of \p I are the part of the vectorization tree. - bool areAllUsersVectorized(Instruction *I, - ArrayRef<Value *> VectorizedVals) const; + bool areAllUsersVectorized( + Instruction *I, + const SmallDenseSet<Value *> *VectorizedVals = nullptr) const; /// Return information about the vector formed for the specified index /// of a vector of (the same) instruction. - TargetTransformInfo::OperandValueInfo getOperandInfo(ArrayRef<Value *> VL, - unsigned OpIdx); + TargetTransformInfo::OperandValueInfo getOperandInfo(ArrayRef<Value *> Ops); + + /// \ returns the graph entry for the \p Idx operand of the \p E entry. + const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const; /// \returns the cost of the vectorizable entry. InstructionCost getEntryCost(const TreeEntry *E, @@ -2450,15 +2353,22 @@ private: /// vector) and sets \p CurrentOrder to the identity permutation; otherwise /// returns false, setting \p CurrentOrder to either an empty vector or a /// non-identity permutation that allows to reuse extract instructions. + /// \param ResizeAllowed indicates whether it is allowed to handle subvector + /// extract order. 
bool canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, - SmallVectorImpl<unsigned> &CurrentOrder) const; + SmallVectorImpl<unsigned> &CurrentOrder, + bool ResizeAllowed = false) const; /// Vectorize a single entry in the tree. - Value *vectorizeTree(TreeEntry *E); + /// \param PostponedPHIs true, if need to postpone emission of phi nodes to + /// avoid issues with def-use order. + Value *vectorizeTree(TreeEntry *E, bool PostponedPHIs); /// Vectorize a single entry in the tree, the \p Idx-th operand of the entry /// \p E. - Value *vectorizeOperand(TreeEntry *E, unsigned NodeIdx); + /// \param PostponedPHIs true, if need to postpone emission of phi nodes to + /// avoid issues with def-use order. + Value *vectorizeOperand(TreeEntry *E, unsigned NodeIdx, bool PostponedPHIs); /// Create a new vector from a list of scalar values. Produces a sequence /// which exploits values reused across lanes, and arranges the inserts @@ -2477,17 +2387,50 @@ private: /// instruction in the list). Instruction &getLastInstructionInBundle(const TreeEntry *E); - /// Checks if the gathered \p VL can be represented as shuffle(s) of previous - /// tree entries. + /// Tries to find extractelement instructions with constant indices from fixed + /// vector type and gather such instructions into a bunch, which highly likely + /// might be detected as a shuffle of 1 or 2 input vectors. If this attempt + /// was successful, the matched scalars are replaced by poison values in \p VL + /// for future analysis. + std::optional<TargetTransformInfo::ShuffleKind> + tryToGatherSingleRegisterExtractElements(MutableArrayRef<Value *> VL, + SmallVectorImpl<int> &Mask) const; + + /// Tries to find extractelement instructions with constant indices from fixed + /// vector type and gather such instructions into a bunch, which highly likely + /// might be detected as a shuffle of 1 or 2 input vectors. 
If this attempt + /// was successful, the matched scalars are replaced by poison values in \p VL + /// for future analysis. + SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> + tryToGatherExtractElements(SmallVectorImpl<Value *> &VL, + SmallVectorImpl<int> &Mask, + unsigned NumParts) const; + + /// Checks if the gathered \p VL can be represented as a single register + /// shuffle(s) of previous tree entries. /// \param TE Tree entry checked for permutation. /// \param VL List of scalars (a subset of the TE scalar), checked for - /// permutations. + /// permutations. Must form single-register vector. /// \returns ShuffleKind, if gathered values can be represented as shuffles of - /// previous tree entries. \p Mask is filled with the shuffle mask. + /// previous tree entries. \p Part of \p Mask is filled with the shuffle mask. std::optional<TargetTransformInfo::ShuffleKind> - isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, - SmallVectorImpl<int> &Mask, - SmallVectorImpl<const TreeEntry *> &Entries); + isGatherShuffledSingleRegisterEntry( + const TreeEntry *TE, ArrayRef<Value *> VL, MutableArrayRef<int> Mask, + SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part); + + /// Checks if the gathered \p VL can be represented as multi-register + /// shuffle(s) of previous tree entries. + /// \param TE Tree entry checked for permutation. + /// \param VL List of scalars (a subset of the TE scalar), checked for + /// permutations. + /// \returns per-register series of ShuffleKind, if gathered values can be + /// represented as shuffles of previous tree entries. \p Mask is filled with + /// the shuffle mask (also on per-register base). + SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> + isGatherShuffledEntry( + const TreeEntry *TE, ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask, + SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries, + unsigned NumParts); /// \returns the scalarization cost for this list of values. 
Assuming that /// this subtree gets vectorized, we may need to extract the values from the @@ -2517,14 +2460,14 @@ private: /// Helper for `findExternalStoreUsersReorderIndices()`. It iterates over the /// users of \p TE and collects the stores. It returns the map from the store /// pointers to the collected stores. - DenseMap<Value *, SmallVector<StoreInst *, 4>> + DenseMap<Value *, SmallVector<StoreInst *>> collectUserStores(const BoUpSLP::TreeEntry *TE) const; /// Helper for `findExternalStoreUsersReorderIndices()`. It checks if the - /// stores in \p StoresVec can form a vector instruction. If so it returns true - /// and populates \p ReorderIndices with the shuffle indices of the the stores - /// when compared to the sorted vector. - bool canFormVector(const SmallVector<StoreInst *, 4> &StoresVec, + /// stores in \p StoresVec can form a vector instruction. If so it returns + /// true and populates \p ReorderIndices with the shuffle indices of the + /// stores when compared to the sorted vector. + bool canFormVector(ArrayRef<StoreInst *> StoresVec, OrdersType &ReorderIndices) const; /// Iterates through the users of \p TE, looking for scalar stores that can be @@ -2621,10 +2564,18 @@ private: /// The Scalars are vectorized into this value. It is initialized to Null. WeakTrackingVH VectorizedValue = nullptr; + /// New vector phi instructions emitted for the vectorized phi nodes. + PHINode *PHI = nullptr; + /// Do we need to gather this sequence or vectorize it /// (either with vector instruction or with scatter/gather /// intrinsics for store/load)? - enum EntryState { Vectorize, ScatterVectorize, NeedToGather }; + enum EntryState { + Vectorize, + ScatterVectorize, + PossibleStridedVectorize, + NeedToGather + }; EntryState State; /// Does this sequence require some shuffling? @@ -2772,6 +2723,14 @@ private: return FoundLane; } + /// Build a shuffle mask for graph entry which represents a merge of main + /// and alternate operations. 
+ void + buildAltOpShuffleMask(const function_ref<bool(Instruction *)> IsAltOp, + SmallVectorImpl<int> &Mask, + SmallVectorImpl<Value *> *OpScalars = nullptr, + SmallVectorImpl<Value *> *AltScalars = nullptr) const; + #ifndef NDEBUG /// Debug printer. LLVM_DUMP_METHOD void dump() const { @@ -2792,6 +2751,9 @@ private: case ScatterVectorize: dbgs() << "ScatterVectorize\n"; break; + case PossibleStridedVectorize: + dbgs() << "PossibleStridedVectorize\n"; + break; case NeedToGather: dbgs() << "NeedToGather\n"; break; @@ -2892,7 +2854,14 @@ private: } if (Last->State != TreeEntry::NeedToGather) { for (Value *V : VL) { - assert(!getTreeEntry(V) && "Scalar already in tree!"); + const TreeEntry *TE = getTreeEntry(V); + assert((!TE || TE == Last || doesNotNeedToBeScheduled(V)) && + "Scalar already in tree!"); + if (TE) { + if (TE != Last) + MultiNodeScalars.try_emplace(V).first->getSecond().push_back(Last); + continue; + } ScalarToTreeEntry[V] = Last; } // Update the scheduler bundle to point to this TreeEntry. @@ -2905,7 +2874,8 @@ private: for (Value *V : VL) { if (doesNotNeedToBeScheduled(V)) continue; - assert(BundleMember && "Unexpected end of bundle."); + if (!BundleMember) + continue; BundleMember->TE = Last; BundleMember = BundleMember->NextInBundle; } @@ -2913,6 +2883,10 @@ private: assert(!BundleMember && "Bundle and VL out of sync"); } else { MustGather.insert(VL.begin(), VL.end()); + // Build a map for gathered scalars to the nodes where they are used. + for (Value *V : VL) + if (!isConstant(V)) + ValueToGatherNodes.try_emplace(V).first->getSecond().insert(Last); } if (UserTreeIdx.UserTE) @@ -2950,6 +2924,10 @@ private: /// Maps a specific scalar to its tree entry. SmallDenseMap<Value *, TreeEntry *> ScalarToTreeEntry; + /// List of scalars, used in several vectorize nodes, and the list of the + /// nodes. + SmallDenseMap<Value *, SmallVector<TreeEntry *>> MultiNodeScalars; + /// Maps a value to the proposed vectorizable size. 
SmallDenseMap<Value *, unsigned> InstrElementSize; @@ -2995,25 +2973,25 @@ private: /// is invariant in the calling loop. bool isAliased(const MemoryLocation &Loc1, Instruction *Inst1, Instruction *Inst2) { + if (!Loc1.Ptr || !isSimple(Inst1) || !isSimple(Inst2)) + return true; // First check if the result is already in the cache. - AliasCacheKey key = std::make_pair(Inst1, Inst2); - std::optional<bool> &result = AliasCache[key]; - if (result) { - return *result; - } - bool aliased = true; - if (Loc1.Ptr && isSimple(Inst1)) - aliased = isModOrRefSet(BatchAA.getModRefInfo(Inst2, Loc1)); + AliasCacheKey Key = std::make_pair(Inst1, Inst2); + auto It = AliasCache.find(Key); + if (It != AliasCache.end()) + return It->second; + bool Aliased = isModOrRefSet(BatchAA.getModRefInfo(Inst2, Loc1)); // Store the result in the cache. - result = aliased; - return aliased; + AliasCache.try_emplace(Key, Aliased); + AliasCache.try_emplace(std::make_pair(Inst2, Inst1), Aliased); + return Aliased; } using AliasCacheKey = std::pair<Instruction *, Instruction *>; /// Cache for alias results. /// TODO: consider moving this to the AliasAnalysis itself. - DenseMap<AliasCacheKey, std::optional<bool>> AliasCache; + DenseMap<AliasCacheKey, bool> AliasCache; // Cache for pointerMayBeCaptured calls inside AA. This is preserved // globally through SLP because we don't perform any action which @@ -3047,7 +3025,7 @@ private: SetVector<Instruction *> GatherShuffleExtractSeq; /// A list of blocks that we are going to CSE. - SetVector<BasicBlock *> CSEBlocks; + DenseSet<BasicBlock *> CSEBlocks; /// Contains all scheduling relevant data for an instruction. /// A ScheduleData either represents a single instruction or a member of an @@ -3497,7 +3475,7 @@ private: BasicBlock *BB; /// Simple memory allocation for ScheduleData. 
- std::vector<std::unique_ptr<ScheduleData[]>> ScheduleDataChunks; + SmallVector<std::unique_ptr<ScheduleData[]>> ScheduleDataChunks; /// The size of a ScheduleData array in ScheduleDataChunks. int ChunkSize; @@ -3607,7 +3585,7 @@ private: /// where "width" indicates the minimum bit width and "signed" is True if the /// value must be signed-extended, rather than zero-extended, back to its /// original width. - MapVector<Value *, std::pair<uint64_t, bool>> MinBWs; + DenseMap<const TreeEntry *, std::pair<uint64_t, bool>> MinBWs; }; } // end namespace slpvectorizer @@ -3676,7 +3654,7 @@ template <> struct GraphTraits<BoUpSLP *> { template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { using TreeEntry = BoUpSLP::TreeEntry; - DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} + DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {} std::string getNodeLabel(const TreeEntry *Entry, const BoUpSLP *R) { std::string Str; @@ -3699,7 +3677,8 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { const BoUpSLP *) { if (Entry->State == TreeEntry::NeedToGather) return "color=red"; - if (Entry->State == TreeEntry::ScatterVectorize) + if (Entry->State == TreeEntry::ScatterVectorize || + Entry->State == TreeEntry::PossibleStridedVectorize) return "color=blue"; return ""; } @@ -3761,7 +3740,7 @@ static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) { inversePermutation(Order, MaskOrder); } reorderReuses(MaskOrder, Mask); - if (ShuffleVectorInst::isIdentityMask(MaskOrder)) { + if (ShuffleVectorInst::isIdentityMask(MaskOrder, MaskOrder.size())) { Order.clear(); return; } @@ -3779,7 +3758,40 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { OrdersType CurrentOrder(NumScalars, NumScalars); SmallVector<int> Positions; SmallBitVector UsedPositions(NumScalars); - const TreeEntry *STE = nullptr; + DenseMap<const TreeEntry *, unsigned> UsedEntries; + 
DenseMap<Value *, std::pair<const TreeEntry *, unsigned>> ValueToEntryPos; + for (Value *V : TE.Scalars) { + if (!isa<LoadInst, ExtractElementInst, ExtractValueInst>(V)) + continue; + const auto *LocalSTE = getTreeEntry(V); + if (!LocalSTE) + continue; + unsigned Lane = + std::distance(LocalSTE->Scalars.begin(), find(LocalSTE->Scalars, V)); + if (Lane >= NumScalars) + continue; + ++UsedEntries.try_emplace(LocalSTE, 0).first->getSecond(); + ValueToEntryPos.try_emplace(V, LocalSTE, Lane); + } + if (UsedEntries.empty()) + return std::nullopt; + const TreeEntry &BestSTE = + *std::max_element(UsedEntries.begin(), UsedEntries.end(), + [](const std::pair<const TreeEntry *, unsigned> &P1, + const std::pair<const TreeEntry *, unsigned> &P2) { + return P1.second < P2.second; + }) + ->first; + UsedEntries.erase(&BestSTE); + const TreeEntry *SecondBestSTE = nullptr; + if (!UsedEntries.empty()) + SecondBestSTE = + std::max_element(UsedEntries.begin(), UsedEntries.end(), + [](const std::pair<const TreeEntry *, unsigned> &P1, + const std::pair<const TreeEntry *, unsigned> &P2) { + return P1.second < P2.second; + }) + ->first; // Try to find all gathered scalars that are gets vectorized in other // vectorize node. Here we can have only one single tree vector node to // correctly identify order of the gathered scalars. @@ -3787,58 +3799,56 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { Value *V = TE.Scalars[I]; if (!isa<LoadInst, ExtractElementInst, ExtractValueInst>(V)) continue; - if (const auto *LocalSTE = getTreeEntry(V)) { - if (!STE) - STE = LocalSTE; - else if (STE != LocalSTE) - // Take the order only from the single vector node. 
- return std::nullopt; - unsigned Lane = - std::distance(STE->Scalars.begin(), find(STE->Scalars, V)); - if (Lane >= NumScalars) - return std::nullopt; - if (CurrentOrder[Lane] != NumScalars) { - if (Lane != I) - continue; - UsedPositions.reset(CurrentOrder[Lane]); - } - // The partial identity (where only some elements of the gather node are - // in the identity order) is good. - CurrentOrder[Lane] = I; - UsedPositions.set(I); + const auto [LocalSTE, Lane] = ValueToEntryPos.lookup(V); + if (!LocalSTE || (LocalSTE != &BestSTE && LocalSTE != SecondBestSTE)) + continue; + if (CurrentOrder[Lane] != NumScalars) { + if ((CurrentOrder[Lane] >= BestSTE.Scalars.size() || + BestSTE.Scalars[CurrentOrder[Lane]] == V) && + (Lane != I || LocalSTE == SecondBestSTE)) + continue; + UsedPositions.reset(CurrentOrder[Lane]); } + // The partial identity (where only some elements of the gather node are + // in the identity order) is good. + CurrentOrder[Lane] = I; + UsedPositions.set(I); } // Need to keep the order if we have a vector entry and at least 2 scalars or // the vectorized entry has just 2 scalars. 
- if (STE && (UsedPositions.count() > 1 || STE->Scalars.size() == 2)) { - auto &&IsIdentityOrder = [NumScalars](ArrayRef<unsigned> CurrentOrder) { - for (unsigned I = 0; I < NumScalars; ++I) - if (CurrentOrder[I] != I && CurrentOrder[I] != NumScalars) - return false; - return true; - }; - if (IsIdentityOrder(CurrentOrder)) - return OrdersType(); - auto *It = CurrentOrder.begin(); - for (unsigned I = 0; I < NumScalars;) { - if (UsedPositions.test(I)) { - ++I; - continue; - } - if (*It == NumScalars) { - *It = I; - ++I; - } - ++It; + if (BestSTE.Scalars.size() != 2 && UsedPositions.count() <= 1) + return std::nullopt; + auto IsIdentityOrder = [&](ArrayRef<unsigned> CurrentOrder) { + for (unsigned I = 0; I < NumScalars; ++I) + if (CurrentOrder[I] != I && CurrentOrder[I] != NumScalars) + return false; + return true; + }; + if (IsIdentityOrder(CurrentOrder)) + return OrdersType(); + auto *It = CurrentOrder.begin(); + for (unsigned I = 0; I < NumScalars;) { + if (UsedPositions.test(I)) { + ++I; + continue; } - return std::move(CurrentOrder); + if (*It == NumScalars) { + *It = I; + ++I; + } + ++It; } - return std::nullopt; + return std::move(CurrentOrder); } namespace { /// Tracks the state we can represent the loads in the given sequence. -enum class LoadsState { Gather, Vectorize, ScatterVectorize }; +enum class LoadsState { + Gather, + Vectorize, + ScatterVectorize, + PossibleStridedVectorize +}; } // anonymous namespace static bool arePointersCompatible(Value *Ptr1, Value *Ptr2, @@ -3898,6 +3908,7 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, if (IsSorted || all_of(PointerOps, [&](Value *P) { return arePointersCompatible(P, PointerOps.front(), TLI); })) { + bool IsPossibleStrided = false; if (IsSorted) { Value *Ptr0; Value *PtrN; @@ -3913,6 +3924,8 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, // Check that the sorted loads are consecutive. 
if (static_cast<unsigned>(*Diff) == VL.size() - 1) return LoadsState::Vectorize; + // Simple check if not a strided access - clear order. + IsPossibleStrided = *Diff % (VL.size() - 1) == 0; } // TODO: need to improve analysis of the pointers, if not all of them are // GEPs or have > 2 operands, we end up with a gather node, which just @@ -3934,7 +3947,8 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) && !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) - return LoadsState::ScatterVectorize; + return IsPossibleStrided ? LoadsState::PossibleStridedVectorize + : LoadsState::ScatterVectorize; } } @@ -4050,7 +4064,8 @@ static bool areTwoInsertFromSameBuildVector( // Go through the vector operand of insertelement instructions trying to find // either VU as the original vector for IE2 or V as the original vector for // IE1. - SmallSet<int, 8> ReusedIdx; + SmallBitVector ReusedIdx( + cast<VectorType>(VU->getType())->getElementCount().getKnownMinValue()); bool IsReusedIdx = false; do { if (IE2 == VU && !IE1) @@ -4058,16 +4073,18 @@ static bool areTwoInsertFromSameBuildVector( if (IE1 == V && !IE2) return V->hasOneUse(); if (IE1 && IE1 != V) { - IsReusedIdx |= - !ReusedIdx.insert(getInsertIndex(IE1).value_or(*Idx2)).second; + unsigned Idx1 = getInsertIndex(IE1).value_or(*Idx2); + IsReusedIdx |= ReusedIdx.test(Idx1); + ReusedIdx.set(Idx1); if ((IE1 != VU && !IE1->hasOneUse()) || IsReusedIdx) IE1 = nullptr; else IE1 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE1)); } if (IE2 && IE2 != VU) { - IsReusedIdx |= - !ReusedIdx.insert(getInsertIndex(IE2).value_or(*Idx1)).second; + unsigned Idx2 = getInsertIndex(IE2).value_or(*Idx1); + IsReusedIdx |= ReusedIdx.test(Idx2); + ReusedIdx.set(Idx2); if ((IE2 != V && !IE2->hasOneUse()) || IsReusedIdx) IE2 = nullptr; else @@ -4135,13 +4152,16 @@ BoUpSLP::getReorderingData(const TreeEntry 
&TE, bool TopToBottom) { return std::nullopt; // No need to reorder. return std::move(ResOrder); } - if (TE.State == TreeEntry::Vectorize && + if ((TE.State == TreeEntry::Vectorize || + TE.State == TreeEntry::PossibleStridedVectorize) && (isa<LoadInst, ExtractElementInst, ExtractValueInst>(TE.getMainOp()) || (TopToBottom && isa<StoreInst, InsertElementInst>(TE.getMainOp()))) && !TE.isAltShuffle()) return TE.ReorderIndices; if (TE.State == TreeEntry::Vectorize && TE.getOpcode() == Instruction::PHI) { - auto PHICompare = [](llvm::Value *V1, llvm::Value *V2) { + auto PHICompare = [&](unsigned I1, unsigned I2) { + Value *V1 = TE.Scalars[I1]; + Value *V2 = TE.Scalars[I2]; if (V1 == V2) return false; if (!V1->hasOneUse() || !V2->hasOneUse()) @@ -4180,14 +4200,13 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { }; if (!TE.ReorderIndices.empty()) return TE.ReorderIndices; - DenseMap<Value *, unsigned> PhiToId; - SmallVector<Value *, 4> Phis; + DenseMap<unsigned, unsigned> PhiToId; + SmallVector<unsigned> Phis(TE.Scalars.size()); + std::iota(Phis.begin(), Phis.end(), 0); OrdersType ResOrder(TE.Scalars.size()); - for (unsigned Id = 0, Sz = TE.Scalars.size(); Id < Sz; ++Id) { - PhiToId[TE.Scalars[Id]] = Id; - Phis.push_back(TE.Scalars[Id]); - } - llvm::stable_sort(Phis, PHICompare); + for (unsigned Id = 0, Sz = TE.Scalars.size(); Id < Sz; ++Id) + PhiToId[Id] = Id; + stable_sort(Phis, PHICompare); for (unsigned Id = 0, Sz = Phis.size(); Id < Sz; ++Id) ResOrder[Id] = PhiToId[Phis[Id]]; if (IsIdentityOrder(ResOrder)) @@ -4214,7 +4233,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { // Check that gather of extractelements can be represented as // just a shuffle of a single vector. 
OrdersType CurrentOrder; - bool Reuse = canReuseExtract(TE.Scalars, TE.getMainOp(), CurrentOrder); + bool Reuse = canReuseExtract(TE.Scalars, TE.getMainOp(), CurrentOrder, + /*ResizeAllowed=*/true); if (Reuse || !CurrentOrder.empty()) { if (!CurrentOrder.empty()) fixupOrderingIndices(CurrentOrder); @@ -4270,7 +4290,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { static bool isRepeatedNonIdentityClusteredMask(ArrayRef<int> Mask, unsigned Sz) { ArrayRef<int> FirstCluster = Mask.slice(0, Sz); - if (ShuffleVectorInst::isIdentityMask(FirstCluster)) + if (ShuffleVectorInst::isIdentityMask(FirstCluster, Sz)) return false; for (unsigned I = Sz, E = Mask.size(); I < E; I += Sz) { ArrayRef<int> Cluster = Mask.slice(I, Sz); @@ -4386,7 +4406,9 @@ void BoUpSLP::reorderTopToBottom() { ++Cnt; } VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get()); - if (TE->State != TreeEntry::Vectorize || !TE->ReuseShuffleIndices.empty()) + if (!(TE->State == TreeEntry::Vectorize || + TE->State == TreeEntry::PossibleStridedVectorize) || + !TE->ReuseShuffleIndices.empty()) GathersToOrders.try_emplace(TE.get(), *CurrentOrder); if (TE->State == TreeEntry::Vectorize && TE->getOpcode() == Instruction::PHI) @@ -4409,6 +4431,9 @@ void BoUpSLP::reorderTopToBottom() { MapVector<OrdersType, unsigned, DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>> OrdersUses; + // Last chance orders - scatter vectorize. Try to use their orders if no + // other orders or the order is counted already. + SmallVector<OrdersType> StridedVectorizeOrders; SmallPtrSet<const TreeEntry *, 4> VisitedOps; for (const TreeEntry *OpTE : OrderedEntries) { // No need to reorder this nodes, still need to extend and to use shuffle, @@ -4455,6 +4480,11 @@ void BoUpSLP::reorderTopToBottom() { if (Order.empty()) continue; } + // Postpone scatter orders. 
+ if (OpTE->State == TreeEntry::PossibleStridedVectorize) { + StridedVectorizeOrders.push_back(Order); + continue; + } // Stores actually store the mask, not the order, need to invert. if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && OpTE->getOpcode() == Instruction::Store && !Order.empty()) { @@ -4472,8 +4502,21 @@ void BoUpSLP::reorderTopToBottom() { } } // Set order of the user node. - if (OrdersUses.empty()) - continue; + if (OrdersUses.empty()) { + if (StridedVectorizeOrders.empty()) + continue; + // Add (potentially!) strided vectorize orders. + for (OrdersType &Order : StridedVectorizeOrders) + ++OrdersUses.insert(std::make_pair(Order, 0)).first->second; + } else { + // Account (potentially!) strided vectorize orders only if it was used + // already. + for (OrdersType &Order : StridedVectorizeOrders) { + auto *It = OrdersUses.find(Order); + if (It != OrdersUses.end()) + ++It->second; + } + } // Choose the most used order. ArrayRef<unsigned> BestOrder = OrdersUses.front().first; unsigned Cnt = OrdersUses.front().second; @@ -4514,7 +4557,8 @@ void BoUpSLP::reorderTopToBottom() { } continue; } - if (TE->State == TreeEntry::Vectorize && + if ((TE->State == TreeEntry::Vectorize || + TE->State == TreeEntry::PossibleStridedVectorize) && isa<ExtractElementInst, ExtractValueInst, LoadInst, StoreInst, InsertElementInst>(TE->getMainOp()) && !TE->isAltShuffle()) { @@ -4555,6 +4599,10 @@ bool BoUpSLP::canReorderOperands( })) continue; if (TreeEntry *TE = getVectorizedOperand(UserTE, I)) { + // FIXME: Do not reorder (possible!) strided vectorized nodes, they + // require reordering of the operands, which is not implemented yet. + if (TE->State == TreeEntry::PossibleStridedVectorize) + return false; // Do not reorder if operand node is used by many user nodes. if (any_of(TE->UserTreeIndices, [UserTE](const EdgeInfo &EI) { return EI.UserTE != UserTE; })) @@ -4567,7 +4615,8 @@ bool BoUpSLP::canReorderOperands( // simply add to the list of gathered ops. 
// If there are reused scalars, process this node as a regular vectorize // node, just reorder reuses mask. - if (TE->State != TreeEntry::Vectorize && TE->ReuseShuffleIndices.empty()) + if (TE->State != TreeEntry::Vectorize && + TE->ReuseShuffleIndices.empty() && TE->ReorderIndices.empty()) GatherOps.push_back(TE); continue; } @@ -4602,18 +4651,19 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { // Currently the are vectorized loads,extracts without alternate operands + // some gathering of extracts. SmallVector<TreeEntry *> NonVectorized; - for_each(VectorizableTree, [this, &OrderedEntries, &GathersToOrders, - &NonVectorized]( - const std::unique_ptr<TreeEntry> &TE) { - if (TE->State != TreeEntry::Vectorize) + for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree) { + if (TE->State != TreeEntry::Vectorize && + TE->State != TreeEntry::PossibleStridedVectorize) NonVectorized.push_back(TE.get()); if (std::optional<OrdersType> CurrentOrder = getReorderingData(*TE, /*TopToBottom=*/false)) { OrderedEntries.insert(TE.get()); - if (TE->State != TreeEntry::Vectorize || !TE->ReuseShuffleIndices.empty()) + if (!(TE->State == TreeEntry::Vectorize || + TE->State == TreeEntry::PossibleStridedVectorize) || + !TE->ReuseShuffleIndices.empty()) GathersToOrders.try_emplace(TE.get(), *CurrentOrder); } - }); + } // 1. Propagate order to the graph nodes, which use only reordered nodes. // I.e., if the node has operands, that are reordered, try to make at least @@ -4627,6 +4677,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { SmallVector<TreeEntry *> Filtered; for (TreeEntry *TE : OrderedEntries) { if (!(TE->State == TreeEntry::Vectorize || + TE->State == TreeEntry::PossibleStridedVectorize || (TE->State == TreeEntry::NeedToGather && GathersToOrders.count(TE))) || TE->UserTreeIndices.empty() || !TE->ReuseShuffleIndices.empty() || @@ -4649,8 +4700,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { } } // Erase filtered entries. 
- for_each(Filtered, - [&OrderedEntries](TreeEntry *TE) { OrderedEntries.remove(TE); }); + for (TreeEntry *TE : Filtered) + OrderedEntries.remove(TE); SmallVector< std::pair<TreeEntry *, SmallVector<std::pair<unsigned, TreeEntry *>>>> UsersVec(Users.begin(), Users.end()); @@ -4662,10 +4713,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { SmallVector<TreeEntry *> GatherOps; if (!canReorderOperands(Data.first, Data.second, NonVectorized, GatherOps)) { - for_each(Data.second, - [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) { - OrderedEntries.remove(Op.second); - }); + for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) + OrderedEntries.remove(Op.second); continue; } // All operands are reordered and used only in this node - propagate the @@ -4673,6 +4722,9 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { MapVector<OrdersType, unsigned, DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>> OrdersUses; + // Last chance orders - scatter vectorize. Try to use their orders if no + // other orders or the order is counted already. + SmallVector<std::pair<OrdersType, unsigned>> StridedVectorizeOrders; // Do the analysis for each tree entry only once, otherwise the order of // the same node my be considered several times, though might be not // profitable. @@ -4694,6 +4746,11 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { Data.second, [OpTE](const std::pair<unsigned, TreeEntry *> &P) { return P.second == OpTE; }); + // Postpone scatter orders. + if (OpTE->State == TreeEntry::PossibleStridedVectorize) { + StridedVectorizeOrders.emplace_back(Order, NumOps); + continue; + } // Stores actually store the mask, not the order, need to invert. if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && OpTE->getOpcode() == Instruction::Store && !Order.empty()) { @@ -4754,11 +4811,27 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { } // If no orders - skip current nodes and jump to the next one, if any. 
if (OrdersUses.empty()) { - for_each(Data.second, - [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) { - OrderedEntries.remove(Op.second); - }); - continue; + if (StridedVectorizeOrders.empty() || + (Data.first->ReorderIndices.empty() && + Data.first->ReuseShuffleIndices.empty() && + !(IgnoreReorder && + Data.first == VectorizableTree.front().get()))) { + for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) + OrderedEntries.remove(Op.second); + continue; + } + // Add (potentially!) strided vectorize orders. + for (std::pair<OrdersType, unsigned> &Pair : StridedVectorizeOrders) + OrdersUses.insert(std::make_pair(Pair.first, 0)).first->second += + Pair.second; + } else { + // Account (potentially!) strided vectorize orders only if it was used + // already. + for (std::pair<OrdersType, unsigned> &Pair : StridedVectorizeOrders) { + auto *It = OrdersUses.find(Pair.first); + if (It != OrdersUses.end()) + It->second += Pair.second; + } } // Choose the best order. ArrayRef<unsigned> BestOrder = OrdersUses.front().first; @@ -4771,10 +4844,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { } // Set order of the user node (reordering of operands and user nodes). if (BestOrder.empty()) { - for_each(Data.second, - [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) { - OrderedEntries.remove(Op.second); - }); + for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) + OrderedEntries.remove(Op.second); continue; } // Erase operands from OrderedEntries list and adjust their orders. @@ -4796,7 +4867,10 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { continue; } // Gathers are processed separately. 
- if (TE->State != TreeEntry::Vectorize) + if (TE->State != TreeEntry::Vectorize && + TE->State != TreeEntry::PossibleStridedVectorize && + (TE->State != TreeEntry::ScatterVectorize || + TE->ReorderIndices.empty())) continue; assert((BestOrder.size() == TE->ReorderIndices.size() || TE->ReorderIndices.empty()) && @@ -4825,7 +4899,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { Data.first->isAltShuffle()) Data.first->reorderOperands(Mask); if (!isa<InsertElementInst, StoreInst>(Data.first->getMainOp()) || - Data.first->isAltShuffle()) { + Data.first->isAltShuffle() || + Data.first->State == TreeEntry::PossibleStridedVectorize) { reorderScalars(Data.first->Scalars, Mask); reorderOrder(Data.first->ReorderIndices, MaskOrder); if (Data.first->ReuseShuffleIndices.empty() && @@ -4859,10 +4934,12 @@ void BoUpSLP::buildExternalUses( // For each lane: for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { Value *Scalar = Entry->Scalars[Lane]; + if (!isa<Instruction>(Scalar)) + continue; int FoundLane = Entry->findLaneForValue(Scalar); // Check if the scalar is externally used as an extra arg. - auto ExtI = ExternallyUsedValues.find(Scalar); + const auto *ExtI = ExternallyUsedValues.find(Scalar); if (ExtI != ExternallyUsedValues.end()) { LLVM_DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane " << Lane << " from " << *Scalar << ".\n"); @@ -4886,7 +4963,8 @@ void BoUpSLP::buildExternalUses( // be used. 
if (UseScalar != U || UseEntry->State == TreeEntry::ScatterVectorize || - !InTreeUserNeedToExtract(Scalar, UserInst, TLI)) { + UseEntry->State == TreeEntry::PossibleStridedVectorize || + !doesInTreeUserNeedToExtract(Scalar, UserInst, TLI)) { LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U << ".\n"); assert(UseEntry->State != TreeEntry::NeedToGather && "Bad state"); @@ -4906,9 +4984,9 @@ void BoUpSLP::buildExternalUses( } } -DenseMap<Value *, SmallVector<StoreInst *, 4>> +DenseMap<Value *, SmallVector<StoreInst *>> BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const { - DenseMap<Value *, SmallVector<StoreInst *, 4>> PtrToStoresMap; + DenseMap<Value *, SmallVector<StoreInst *>> PtrToStoresMap; for (unsigned Lane : seq<unsigned>(0, TE->Scalars.size())) { Value *V = TE->Scalars[Lane]; // To save compilation time we don't visit if we have too many users. @@ -4947,14 +5025,14 @@ BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const { return PtrToStoresMap; } -bool BoUpSLP::canFormVector(const SmallVector<StoreInst *, 4> &StoresVec, +bool BoUpSLP::canFormVector(ArrayRef<StoreInst *> StoresVec, OrdersType &ReorderIndices) const { // We check whether the stores in StoreVec can form a vector by sorting them // and checking whether they are consecutive. // To avoid calling getPointersDiff() while sorting we create a vector of // pairs {store, offset from first} and sort this instead. 
- SmallVector<std::pair<StoreInst *, int>, 4> StoreOffsetVec(StoresVec.size()); + SmallVector<std::pair<StoreInst *, int>> StoreOffsetVec(StoresVec.size()); StoreInst *S0 = StoresVec[0]; StoreOffsetVec[0] = {S0, 0}; Type *S0Ty = S0->getValueOperand()->getType(); @@ -5023,7 +5101,7 @@ SmallVector<BoUpSLP::OrdersType, 1> BoUpSLP::findExternalStoreUsersReorderIndices(TreeEntry *TE) const { unsigned NumLanes = TE->Scalars.size(); - DenseMap<Value *, SmallVector<StoreInst *, 4>> PtrToStoresMap = + DenseMap<Value *, SmallVector<StoreInst *>> PtrToStoresMap = collectUserStores(TE); // Holds the reorder indices for each candidate store vector that is a user of @@ -5244,6 +5322,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( return TreeEntry::Vectorize; case LoadsState::ScatterVectorize: return TreeEntry::ScatterVectorize; + case LoadsState::PossibleStridedVectorize: + return TreeEntry::PossibleStridedVectorize; case LoadsState::Gather: #ifndef NDEBUG Type *ScalarTy = VL0->getType(); @@ -5416,7 +5496,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); VFShape Shape = VFShape::get( - *CI, ElementCount::getFixed(static_cast<unsigned int>(VL.size())), + CI->getFunctionType(), + ElementCount::getFixed(static_cast<unsigned int>(VL.size())), false /*HasGlobalPred*/); Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); @@ -5488,9 +5569,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, SmallVector<int> ReuseShuffleIndicies; SmallVector<Value *> UniqueValues; - auto &&TryToFindDuplicates = [&VL, &ReuseShuffleIndicies, &UniqueValues, - &UserTreeIdx, - this](const InstructionsState &S) { + SmallVector<Value *> NonUniqueValueVL; + auto TryToFindDuplicates = [&](const InstructionsState &S, + bool DoNotFail = false) { // Check that every instruction appears once in this bundle. 
DenseMap<Value *, unsigned> UniquePositions(VL.size()); for (Value *V : VL) { @@ -5517,6 +5598,24 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, !isConstant(V); })) || !llvm::has_single_bit<uint32_t>(NumUniqueScalarValues)) { + if (DoNotFail && UniquePositions.size() > 1 && + NumUniqueScalarValues > 1 && S.MainOp->isSafeToRemove() && + all_of(UniqueValues, [=](Value *V) { + return isa<ExtractElementInst>(V) || + areAllUsersVectorized(cast<Instruction>(V), + UserIgnoreList); + })) { + unsigned PWSz = PowerOf2Ceil(UniqueValues.size()); + if (PWSz == VL.size()) { + ReuseShuffleIndicies.clear(); + } else { + NonUniqueValueVL.assign(UniqueValues.begin(), UniqueValues.end()); + NonUniqueValueVL.append(PWSz - UniqueValues.size(), + UniqueValues.back()); + VL = NonUniqueValueVL; + } + return true; + } LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return false; @@ -5528,6 +5627,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, InstructionsState S = getSameOpcode(VL, *TLI); + // Don't vectorize ephemeral values. + if (!EphValues.empty()) { + for (Value *V : VL) { + if (EphValues.count(V)) { + LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V + << ") is ephemeral.\n"); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); + return; + } + } + } + // Gather if we hit the RecursionMaxDepth, unless this is a load (or z/sext of // a load), in which case peek through to include it in the tree, without // ballooning over-budget. 
@@ -5633,7 +5744,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, BasicBlock *BB = nullptr; bool IsScatterVectorizeUserTE = UserTreeIdx.UserTE && - UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize; + (UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize || + UserTreeIdx.UserTE->State == TreeEntry::PossibleStridedVectorize); bool AreAllSameInsts = (S.getOpcode() && allSameBlock(VL)) || (S.OpValue->getType()->isPointerTy() && IsScatterVectorizeUserTE && @@ -5665,39 +5777,44 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // We now know that this is a vector of instructions of the same type from // the same block. - // Don't vectorize ephemeral values. - if (!EphValues.empty()) { - for (Value *V : VL) { - if (EphValues.count(V)) { - LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V - << ") is ephemeral.\n"); - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); - return; - } - } - } - // Check if this is a duplicate of another entry. if (TreeEntry *E = getTreeEntry(S.OpValue)) { LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.OpValue << ".\n"); if (!E->isSame(VL)) { - LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); - if (TryToFindDuplicates(S)) - newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + auto It = MultiNodeScalars.find(S.OpValue); + if (It != MultiNodeScalars.end()) { + auto *TEIt = find_if(It->getSecond(), + [&](TreeEntry *ME) { return ME->isSame(VL); }); + if (TEIt != It->getSecond().end()) + E = *TEIt; + else + E = nullptr; + } else { + E = nullptr; + } + } + if (!E) { + if (!doesNotNeedToBeScheduled(S.OpValue)) { + LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); + return; + } + } else { + // Record the reuse of the tree node. 
FIXME, currently this is only used + // to properly draw the graph rather than for the actual vectorization. + E->UserTreeIndices.push_back(UserTreeIdx); + LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue + << ".\n"); return; } - // Record the reuse of the tree node. FIXME, currently this is only used to - // properly draw the graph rather than for the actual vectorization. - E->UserTreeIndices.push_back(UserTreeIdx); - LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue - << ".\n"); - return; } // Check that none of the instructions in the bundle are already in the tree. for (Value *V : VL) { - if (!IsScatterVectorizeUserTE && !isa<Instruction>(V)) + if ((!IsScatterVectorizeUserTE && !isa<Instruction>(V)) || + doesNotNeedToBeScheduled(V)) continue; if (getTreeEntry(V)) { LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V @@ -5725,7 +5842,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Special processing for sorted pointers for ScatterVectorize node with // constant indeces only. if (AreAllSameInsts && UserTreeIdx.UserTE && - UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize && + (UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize || + UserTreeIdx.UserTE->State == TreeEntry::PossibleStridedVectorize) && !(S.getOpcode() && allSameBlock(VL))) { assert(S.OpValue->getType()->isPointerTy() && count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >= @@ -5760,7 +5878,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } // Check that every instruction appears once in this bundle. - if (!TryToFindDuplicates(S)) + if (!TryToFindDuplicates(S, /*DoNotFail=*/true)) return; // Perform specific checks for each particular instruction kind. 
@@ -5780,7 +5898,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, BlockScheduling &BS = *BSRef; - std::optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S); + std::optional<ScheduleData *> Bundle = + BS.tryScheduleBundle(UniqueValues, this, S); #ifdef EXPENSIVE_CHECKS // Make sure we didn't break any internal invariants BS.verify(); @@ -5905,6 +6024,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // from such a struct, we read/write packed bits disagreeing with the // unvectorized version. TreeEntry *TE = nullptr; + fixupOrderingIndices(CurrentOrder); switch (State) { case TreeEntry::Vectorize: if (CurrentOrder.empty()) { @@ -5913,7 +6033,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); } else { - fixupOrderingIndices(CurrentOrder); // Need to reorder. TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies, CurrentOrder); @@ -5921,6 +6040,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } TE->setOperandsInOrder(); break; + case TreeEntry::PossibleStridedVectorize: + // Vectorizing non-consecutive loads with `llvm.masked.gather`. + if (CurrentOrder.empty()) { + TE = newTreeEntry(VL, TreeEntry::PossibleStridedVectorize, Bundle, S, + UserTreeIdx, ReuseShuffleIndicies); + } else { + TE = newTreeEntry(VL, TreeEntry::PossibleStridedVectorize, Bundle, S, + UserTreeIdx, ReuseShuffleIndicies, CurrentOrder); + } + TE->setOperandsInOrder(); + buildTree_rec(PointerOps, Depth + 1, {TE, 0}); + LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n"); + break; case TreeEntry::ScatterVectorize: // Vectorizing non-consecutive loads with `llvm.masked.gather`. 
TE = newTreeEntry(VL, TreeEntry::ScatterVectorize, Bundle, S, @@ -5951,13 +6083,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n"); TE->setOperandsInOrder(); - for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { + for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) { ValueList Operands; // Prepare the operand vector. for (Value *V : VL) - Operands.push_back(cast<Instruction>(V)->getOperand(i)); + Operands.push_back(cast<Instruction>(V)->getOperand(I)); - buildTree_rec(Operands, Depth + 1, {TE, i}); + buildTree_rec(Operands, Depth + 1, {TE, I}); } return; } @@ -6031,13 +6163,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } TE->setOperandsInOrder(); - for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { + for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) { ValueList Operands; // Prepare the operand vector. for (Value *V : VL) - Operands.push_back(cast<Instruction>(V)->getOperand(i)); + Operands.push_back(cast<Instruction>(V)->getOperand(I)); - buildTree_rec(Operands, Depth + 1, {TE, i}); + buildTree_rec(Operands, Depth + 1, {TE, I}); } return; } @@ -6087,8 +6219,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!CI) Operands.back().push_back(Op); else - Operands.back().push_back(ConstantExpr::getIntegerCast( - CI, Ty, CI->getValue().isSignBitSet())); + Operands.back().push_back(ConstantFoldIntegerCast( + CI, Ty, CI->getValue().isSignBitSet(), *DL)); } TE->setOperand(IndexIdx, Operands.back()); @@ -6132,18 +6264,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); TE->setOperandsInOrder(); - for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) { - // For scalar operands no need to to create an entry since no need to + for (unsigned I : seq<unsigned>(0, CI->arg_size())) { + // For scalar operands 
no need to create an entry since no need to // vectorize it. - if (isVectorIntrinsicWithScalarOpAtArg(ID, i)) + if (isVectorIntrinsicWithScalarOpAtArg(ID, I)) continue; ValueList Operands; // Prepare the operand vector. for (Value *V : VL) { auto *CI2 = cast<CallInst>(V); - Operands.push_back(CI2->getArgOperand(i)); + Operands.push_back(CI2->getArgOperand(I)); } - buildTree_rec(Operands, Depth + 1, {TE, i}); + buildTree_rec(Operands, Depth + 1, {TE, I}); } return; } @@ -6194,13 +6326,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } TE->setOperandsInOrder(); - for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { + for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) { ValueList Operands; // Prepare the operand vector. for (Value *V : VL) - Operands.push_back(cast<Instruction>(V)->getOperand(i)); + Operands.push_back(cast<Instruction>(V)->getOperand(I)); - buildTree_rec(Operands, Depth + 1, {TE, i}); + buildTree_rec(Operands, Depth + 1, {TE, I}); } return; } @@ -6210,7 +6342,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, llvm_unreachable("Unexpected vectorization of the instructions."); } -unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { +unsigned BoUpSLP::canMapToVector(Type *T) const { unsigned N = 1; Type *EltTy = T; @@ -6234,15 +6366,16 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { if (!isValidElementType(EltTy)) return 0; - uint64_t VTSize = DL.getTypeStoreSizeInBits(FixedVectorType::get(EltTy, N)); + uint64_t VTSize = DL->getTypeStoreSizeInBits(FixedVectorType::get(EltTy, N)); if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize || - VTSize != DL.getTypeStoreSizeInBits(T)) + VTSize != DL->getTypeStoreSizeInBits(T)) return 0; return N; } bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, - SmallVectorImpl<unsigned> &CurrentOrder) const { + SmallVectorImpl<unsigned> &CurrentOrder, + bool ResizeAllowed) const { const auto *It = 
find_if(VL, [](Value *V) { return isa<ExtractElementInst, ExtractValueInst>(V); }); @@ -6263,8 +6396,7 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, // We have to extract from a vector/aggregate with the same number of elements. unsigned NElts; if (E0->getOpcode() == Instruction::ExtractValue) { - const DataLayout &DL = E0->getModule()->getDataLayout(); - NElts = canMapToVector(Vec->getType(), DL); + NElts = canMapToVector(Vec->getType()); if (!NElts) return false; // Check if load can be rewritten as load of vector. @@ -6275,46 +6407,55 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, NElts = cast<FixedVectorType>(Vec->getType())->getNumElements(); } - if (NElts != VL.size()) - return false; - - // Check that all of the indices extract from the correct offset. - bool ShouldKeepOrder = true; unsigned E = VL.size(); - // Assign to all items the initial value E + 1 so we can check if the extract - // instruction index was used already. - // Also, later we can check that all the indices are used and we have a - // consecutive access in the extract instructions, by checking that no - // element of CurrentOrder still has value E + 1. 
- CurrentOrder.assign(E, E); - unsigned I = 0; - for (; I < E; ++I) { - auto *Inst = dyn_cast<Instruction>(VL[I]); + if (!ResizeAllowed && NElts != E) + return false; + SmallVector<int> Indices(E, PoisonMaskElem); + unsigned MinIdx = NElts, MaxIdx = 0; + for (auto [I, V] : enumerate(VL)) { + auto *Inst = dyn_cast<Instruction>(V); if (!Inst) continue; if (Inst->getOperand(0) != Vec) - break; + return false; if (auto *EE = dyn_cast<ExtractElementInst>(Inst)) if (isa<UndefValue>(EE->getIndexOperand())) continue; std::optional<unsigned> Idx = getExtractIndex(Inst); if (!Idx) - break; + return false; const unsigned ExtIdx = *Idx; - if (ExtIdx != I) { - if (ExtIdx >= E || CurrentOrder[ExtIdx] != E) - break; - ShouldKeepOrder = false; - CurrentOrder[ExtIdx] = I; - } else { - if (CurrentOrder[I] != E) - break; - CurrentOrder[I] = I; - } + if (ExtIdx >= NElts) + continue; + Indices[I] = ExtIdx; + if (MinIdx > ExtIdx) + MinIdx = ExtIdx; + if (MaxIdx < ExtIdx) + MaxIdx = ExtIdx; } - if (I < E) { - CurrentOrder.clear(); + if (MaxIdx - MinIdx + 1 > E) return false; + if (MaxIdx + 1 <= E) + MinIdx = 0; + + // Check that all of the indices extract from the correct offset. + bool ShouldKeepOrder = true; + // Assign to all items the initial value E + 1 so we can check if the extract + // instruction index was used already. + // Also, later we can check that all the indices are used and we have a + // consecutive access in the extract instructions, by checking that no + // element of CurrentOrder still has value E + 1. 
+ CurrentOrder.assign(E, E); + for (unsigned I = 0; I < E; ++I) { + if (Indices[I] == PoisonMaskElem) + continue; + const unsigned ExtIdx = Indices[I] - MinIdx; + if (CurrentOrder[ExtIdx] != E) { + CurrentOrder.clear(); + return false; + } + ShouldKeepOrder &= ExtIdx == I; + CurrentOrder[ExtIdx] = I; } if (ShouldKeepOrder) CurrentOrder.clear(); @@ -6322,9 +6463,9 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, return ShouldKeepOrder; } -bool BoUpSLP::areAllUsersVectorized(Instruction *I, - ArrayRef<Value *> VectorizedVals) const { - return (I->hasOneUse() && is_contained(VectorizedVals, I)) || +bool BoUpSLP::areAllUsersVectorized( + Instruction *I, const SmallDenseSet<Value *> *VectorizedVals) const { + return (I->hasOneUse() && (!VectorizedVals || VectorizedVals->contains(I))) || all_of(I->users(), [this](User *U) { return ScalarToTreeEntry.count(U) > 0 || isVectorLikeInstWithConstOps(U) || @@ -6351,8 +6492,8 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy, auto IntrinsicCost = TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput); - auto Shape = VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>( - VecTy->getNumElements())), + auto Shape = VFShape::get(CI->getFunctionType(), + ElementCount::getFixed(VecTy->getNumElements()), false /*HasGlobalPred*/); Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape); auto LibCost = IntrinsicCost; @@ -6365,16 +6506,11 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy, return {IntrinsicCost, LibCost}; } -/// Build shuffle mask for shuffle graph entries and lists of main and alternate -/// operations operands. 
-static void -buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, - ArrayRef<int> ReusesIndices, - const function_ref<bool(Instruction *)> IsAltOp, - SmallVectorImpl<int> &Mask, - SmallVectorImpl<Value *> *OpScalars = nullptr, - SmallVectorImpl<Value *> *AltScalars = nullptr) { - unsigned Sz = VL.size(); +void BoUpSLP::TreeEntry::buildAltOpShuffleMask( + const function_ref<bool(Instruction *)> IsAltOp, SmallVectorImpl<int> &Mask, + SmallVectorImpl<Value *> *OpScalars, + SmallVectorImpl<Value *> *AltScalars) const { + unsigned Sz = Scalars.size(); Mask.assign(Sz, PoisonMaskElem); SmallVector<int> OrderMask; if (!ReorderIndices.empty()) @@ -6383,7 +6519,7 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, unsigned Idx = I; if (!ReorderIndices.empty()) Idx = OrderMask[I]; - auto *OpInst = cast<Instruction>(VL[Idx]); + auto *OpInst = cast<Instruction>(Scalars[Idx]); if (IsAltOp(OpInst)) { Mask[I] = Sz + Idx; if (AltScalars) @@ -6394,9 +6530,9 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, OpScalars->push_back(OpInst); } } - if (!ReusesIndices.empty()) { - SmallVector<int> NewMask(ReusesIndices.size(), PoisonMaskElem); - transform(ReusesIndices, NewMask.begin(), [&Mask](int Idx) { + if (!ReuseShuffleIndices.empty()) { + SmallVector<int> NewMask(ReuseShuffleIndices.size(), PoisonMaskElem); + transform(ReuseShuffleIndices, NewMask.begin(), [&Mask](int Idx) { return Idx != PoisonMaskElem ? 
Mask[Idx] : PoisonMaskElem; }); Mask.swap(NewMask); @@ -6429,52 +6565,27 @@ static bool isAlternateInstruction(const Instruction *I, return I->getOpcode() == AltOp->getOpcode(); } -TTI::OperandValueInfo BoUpSLP::getOperandInfo(ArrayRef<Value *> VL, - unsigned OpIdx) { - assert(!VL.empty()); - const auto *I0 = cast<Instruction>(*find_if(VL, Instruction::classof)); - const auto *Op0 = I0->getOperand(OpIdx); +TTI::OperandValueInfo BoUpSLP::getOperandInfo(ArrayRef<Value *> Ops) { + assert(!Ops.empty()); + const auto *Op0 = Ops.front(); - const bool IsConstant = all_of(VL, [&](Value *V) { + const bool IsConstant = all_of(Ops, [](Value *V) { // TODO: We should allow undef elements here - const auto *I = dyn_cast<Instruction>(V); - if (!I) - return true; - auto *Op = I->getOperand(OpIdx); - return isConstant(Op) && !isa<UndefValue>(Op); + return isConstant(V) && !isa<UndefValue>(V); }); - const bool IsUniform = all_of(VL, [&](Value *V) { + const bool IsUniform = all_of(Ops, [=](Value *V) { // TODO: We should allow undef elements here - const auto *I = dyn_cast<Instruction>(V); - if (!I) - return false; - return I->getOperand(OpIdx) == Op0; + return V == Op0; }); - const bool IsPowerOfTwo = all_of(VL, [&](Value *V) { + const bool IsPowerOfTwo = all_of(Ops, [](Value *V) { // TODO: We should allow undef elements here - const auto *I = dyn_cast<Instruction>(V); - if (!I) { - assert((isa<UndefValue>(V) || - I0->getOpcode() == Instruction::GetElementPtr) && - "Expected undef or GEP."); - return true; - } - auto *Op = I->getOperand(OpIdx); - if (auto *CI = dyn_cast<ConstantInt>(Op)) + if (auto *CI = dyn_cast<ConstantInt>(V)) return CI->getValue().isPowerOf2(); return false; }); - const bool IsNegatedPowerOfTwo = all_of(VL, [&](Value *V) { + const bool IsNegatedPowerOfTwo = all_of(Ops, [](Value *V) { // TODO: We should allow undef elements here - const auto *I = dyn_cast<Instruction>(V); - if (!I) { - assert((isa<UndefValue>(V) || - I0->getOpcode() == Instruction::GetElementPtr) 
&& - "Expected undef or GEP."); - return true; - } - const auto *Op = I->getOperand(OpIdx); - if (auto *CI = dyn_cast<ConstantInt>(Op)) + if (auto *CI = dyn_cast<ConstantInt>(V)) return CI->getValue().isNegatedPowerOf2(); return false; }); @@ -6505,9 +6616,24 @@ protected: bool IsStrict) { int Limit = Mask.size(); int VF = VecTy->getNumElements(); - return (VF == Limit || !IsStrict) && - all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) && - ShuffleVectorInst::isIdentityMask(Mask); + int Index = -1; + if (VF == Limit && ShuffleVectorInst::isIdentityMask(Mask, Limit)) + return true; + if (!IsStrict) { + // Consider extract subvector starting from index 0. + if (ShuffleVectorInst::isExtractSubvectorMask(Mask, VF, Index) && + Index == 0) + return true; + // All VF-size submasks are identity (e.g. + // <poison,poison,poison,poison,0,1,2,poison,poison,1,2,3> etc. for VF 4). + if (Limit % VF == 0 && all_of(seq<int>(0, Limit / VF), [=](int Idx) { + ArrayRef<int> Slice = Mask.slice(Idx * VF, VF); + return all_of(Slice, [](int I) { return I == PoisonMaskElem; }) || + ShuffleVectorInst::isIdentityMask(Slice, VF); + })) + return true; + } + return false; } /// Tries to combine 2 different masks into single one. 
@@ -6577,7 +6703,8 @@ protected: if (isIdentityMask(Mask, SVTy, /*IsStrict=*/false)) { if (!IdentityOp || !SinglePermute || (isIdentityMask(Mask, SVTy, /*IsStrict=*/true) && - !ShuffleVectorInst::isZeroEltSplatMask(IdentityMask))) { + !ShuffleVectorInst::isZeroEltSplatMask(IdentityMask, + IdentityMask.size()))) { IdentityOp = SV; // Store current mask in the IdentityMask so later we did not lost // this info if IdentityOp is selected as the best candidate for the @@ -6647,7 +6774,7 @@ protected: } if (auto *OpTy = dyn_cast<FixedVectorType>(Op->getType()); !OpTy || !isIdentityMask(Mask, OpTy, SinglePermute) || - ShuffleVectorInst::isZeroEltSplatMask(Mask)) { + ShuffleVectorInst::isZeroEltSplatMask(Mask, Mask.size())) { if (IdentityOp) { V = IdentityOp; assert(Mask.size() == IdentityMask.size() && @@ -6663,7 +6790,7 @@ protected: /*IsStrict=*/true) || (Shuffle && Mask.size() == Shuffle->getShuffleMask().size() && Shuffle->isZeroEltSplat() && - ShuffleVectorInst::isZeroEltSplatMask(Mask))); + ShuffleVectorInst::isZeroEltSplatMask(Mask, Mask.size()))); } V = Op; return false; @@ -6768,11 +6895,9 @@ protected: CombinedMask1[I] = CombinedMask2[I] + (Op1 == Op2 ? 
0 : VF); } } - const int Limit = CombinedMask1.size() * 2; - if (Op1 == Op2 && Limit == 2 * VF && - all_of(CombinedMask1, [=](int Idx) { return Idx < Limit; }) && - (ShuffleVectorInst::isIdentityMask(CombinedMask1) || - (ShuffleVectorInst::isZeroEltSplatMask(CombinedMask1) && + if (Op1 == Op2 && + (ShuffleVectorInst::isIdentityMask(CombinedMask1, VF) || + (ShuffleVectorInst::isZeroEltSplatMask(CombinedMask1, VF) && isa<ShuffleVectorInst>(Op1) && cast<ShuffleVectorInst>(Op1)->getShuffleMask() == ArrayRef(CombinedMask1)))) @@ -6807,10 +6932,29 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors; const TargetTransformInfo &TTI; InstructionCost Cost = 0; - ArrayRef<Value *> VectorizedVals; + SmallDenseSet<Value *> VectorizedVals; BoUpSLP &R; SmallPtrSetImpl<Value *> &CheckedExtracts; constexpr static TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + /// While set, still trying to estimate the cost for the same nodes and we + /// can delay actual cost estimation (virtual shuffle instruction emission). + /// May help better estimate the cost if same nodes must be permuted + allows + /// to move most of the long shuffles cost estimation to TTI. 
+ bool SameNodesEstimated = true; + + static Constant *getAllOnesValue(const DataLayout &DL, Type *Ty) { + if (Ty->getScalarType()->isPointerTy()) { + Constant *Res = ConstantExpr::getIntToPtr( + ConstantInt::getAllOnesValue( + IntegerType::get(Ty->getContext(), + DL.getTypeStoreSizeInBits(Ty->getScalarType()))), + Ty->getScalarType()); + if (auto *VTy = dyn_cast<VectorType>(Ty)) + Res = ConstantVector::getSplat(VTy->getElementCount(), Res); + return Res; + } + return Constant::getAllOnesValue(Ty); + } InstructionCost getBuildVectorCost(ArrayRef<Value *> VL, Value *Root) { if ((!Root && allConstant(VL)) || all_of(VL, UndefValue::classof)) @@ -6821,20 +6965,35 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { // Improve gather cost for gather of loads, if we can group some of the // loads into vector loads. InstructionsState S = getSameOpcode(VL, *R.TLI); - if (VL.size() > 2 && S.getOpcode() == Instruction::Load && - !S.isAltShuffle() && + const unsigned Sz = R.DL->getTypeSizeInBits(VL.front()->getType()); + unsigned MinVF = R.getMinVF(2 * Sz); + if (VL.size() > 2 && + ((S.getOpcode() == Instruction::Load && !S.isAltShuffle()) || + (InVectors.empty() && + any_of(seq<unsigned>(0, VL.size() / MinVF), + [&](unsigned Idx) { + ArrayRef<Value *> SubVL = VL.slice(Idx * MinVF, MinVF); + InstructionsState S = getSameOpcode(SubVL, *R.TLI); + return S.getOpcode() == Instruction::Load && + !S.isAltShuffle(); + }))) && !all_of(Gathers, [&](Value *V) { return R.getTreeEntry(V); }) && !isSplat(Gathers)) { - BoUpSLP::ValueSet VectorizedLoads; + SetVector<Value *> VectorizedLoads; + SmallVector<LoadInst *> VectorizedStarts; + SmallVector<std::pair<unsigned, unsigned>> ScatterVectorized; unsigned StartIdx = 0; unsigned VF = VL.size() / 2; - unsigned VectorizedCnt = 0; - unsigned ScatterVectorizeCnt = 0; - const unsigned Sz = R.DL->getTypeSizeInBits(S.MainOp->getType()); - for (unsigned MinVF = R.getMinVF(2 * Sz); VF >= MinVF; VF /= 2) { + for (; VF >= MinVF; VF /= 
2) { for (unsigned Cnt = StartIdx, End = VL.size(); Cnt + VF <= End; Cnt += VF) { ArrayRef<Value *> Slice = VL.slice(Cnt, VF); + if (S.getOpcode() != Instruction::Load || S.isAltShuffle()) { + InstructionsState SliceS = getSameOpcode(Slice, *R.TLI); + if (SliceS.getOpcode() != Instruction::Load || + SliceS.isAltShuffle()) + continue; + } if (!VectorizedLoads.count(Slice.front()) && !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) { SmallVector<Value *> PointerOps; @@ -6845,12 +7004,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { switch (LS) { case LoadsState::Vectorize: case LoadsState::ScatterVectorize: + case LoadsState::PossibleStridedVectorize: // Mark the vectorized loads so that we don't vectorize them // again. - if (LS == LoadsState::Vectorize) - ++VectorizedCnt; + // TODO: better handling of loads with reorders. + if (LS == LoadsState::Vectorize && CurrentOrder.empty()) + VectorizedStarts.push_back(cast<LoadInst>(Slice.front())); else - ++ScatterVectorizeCnt; + ScatterVectorized.emplace_back(Cnt, VF); VectorizedLoads.insert(Slice.begin(), Slice.end()); // If we vectorized initial block, no need to try to vectorize // it again. @@ -6881,8 +7042,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { } // Exclude potentially vectorized loads from list of gathered // scalars. - auto *LI = cast<LoadInst>(S.MainOp); - Gathers.assign(Gathers.size(), PoisonValue::get(LI->getType())); + Gathers.assign(Gathers.size(), PoisonValue::get(VL.front()->getType())); // The cost for vectorized loads. 
InstructionCost ScalarsCost = 0; for (Value *V : VectorizedLoads) { @@ -6892,17 +7052,24 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { LI->getAlign(), LI->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo(), LI); } - auto *LoadTy = FixedVectorType::get(LI->getType(), VF); - Align Alignment = LI->getAlign(); - GatherCost += - VectorizedCnt * - TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, - LI->getPointerAddressSpace(), CostKind, - TTI::OperandValueInfo(), LI); - GatherCost += ScatterVectorizeCnt * - TTI.getGatherScatterOpCost( - Instruction::Load, LoadTy, LI->getPointerOperand(), - /*VariableMask=*/false, Alignment, CostKind, LI); + auto *LoadTy = FixedVectorType::get(VL.front()->getType(), VF); + for (LoadInst *LI : VectorizedStarts) { + Align Alignment = LI->getAlign(); + GatherCost += + TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, + LI->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo(), LI); + } + for (std::pair<unsigned, unsigned> P : ScatterVectorized) { + auto *LI0 = cast<LoadInst>(VL[P.first]); + Align CommonAlignment = LI0->getAlign(); + for (Value *V : VL.slice(P.first + 1, VF - 1)) + CommonAlignment = + std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); + GatherCost += TTI.getGatherScatterOpCost( + Instruction::Load, LoadTy, LI0->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind, LI0); + } if (NeedInsertSubvectorAnalysis) { // Add the cost for the subvectors insert. for (int I = VF, E = VL.size(); I < E; I += VF) @@ -6938,77 +7105,137 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { : R.getGatherCost(Gathers, !Root && VL.equals(Gathers))); }; - /// Compute the cost of creating a vector of type \p VecTy containing the - /// extracted values from \p VL. 
- InstructionCost computeExtractCost(ArrayRef<Value *> VL, ArrayRef<int> Mask, - TTI::ShuffleKind ShuffleKind) { - auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size()); - unsigned NumOfParts = TTI.getNumberOfParts(VecTy); - - if (ShuffleKind != TargetTransformInfo::SK_PermuteSingleSrc || - !NumOfParts || VecTy->getNumElements() < NumOfParts) - return TTI.getShuffleCost(ShuffleKind, VecTy, Mask); - - bool AllConsecutive = true; - unsigned EltsPerVector = VecTy->getNumElements() / NumOfParts; - unsigned Idx = -1; + /// Compute the cost of creating a vector containing the extracted values from + /// \p VL. + InstructionCost + computeExtractCost(ArrayRef<Value *> VL, ArrayRef<int> Mask, + ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds, + unsigned NumParts) { + assert(VL.size() > NumParts && "Unexpected scalarized shuffle."); + unsigned NumElts = + std::accumulate(VL.begin(), VL.end(), 0, [](unsigned Sz, Value *V) { + auto *EE = dyn_cast<ExtractElementInst>(V); + if (!EE) + return Sz; + auto *VecTy = cast<FixedVectorType>(EE->getVectorOperandType()); + return std::max(Sz, VecTy->getNumElements()); + }); + unsigned NumSrcRegs = TTI.getNumberOfParts( + FixedVectorType::get(VL.front()->getType(), NumElts)); + if (NumSrcRegs == 0) + NumSrcRegs = 1; + // FIXME: this must be moved to TTI for better estimation. + unsigned EltsPerVector = PowerOf2Ceil(std::max( + divideCeil(VL.size(), NumParts), divideCeil(NumElts, NumSrcRegs))); + auto CheckPerRegistersShuffle = + [&](MutableArrayRef<int> Mask) -> std::optional<TTI::ShuffleKind> { + DenseSet<int> RegIndices; + // Check that if trying to permute same single/2 input vectors. 
+ TTI::ShuffleKind ShuffleKind = TTI::SK_PermuteSingleSrc; + int FirstRegId = -1; + for (int &I : Mask) { + if (I == PoisonMaskElem) + continue; + int RegId = (I / NumElts) * NumParts + (I % NumElts) / EltsPerVector; + if (FirstRegId < 0) + FirstRegId = RegId; + RegIndices.insert(RegId); + if (RegIndices.size() > 2) + return std::nullopt; + if (RegIndices.size() == 2) + ShuffleKind = TTI::SK_PermuteTwoSrc; + I = (I % NumElts) % EltsPerVector + + (RegId == FirstRegId ? 0 : EltsPerVector); + } + return ShuffleKind; + }; InstructionCost Cost = 0; // Process extracts in blocks of EltsPerVector to check if the source vector // operand can be re-used directly. If not, add the cost of creating a // shuffle to extract the values into a vector register. - SmallVector<int> RegMask(EltsPerVector, PoisonMaskElem); - for (auto *V : VL) { - ++Idx; - - // Reached the start of a new vector registers. - if (Idx % EltsPerVector == 0) { - RegMask.assign(EltsPerVector, PoisonMaskElem); - AllConsecutive = true; + for (unsigned Part = 0; Part < NumParts; ++Part) { + if (!ShuffleKinds[Part]) continue; - } - - // Need to exclude undefs from analysis. - if (isa<UndefValue>(V) || Mask[Idx] == PoisonMaskElem) + ArrayRef<int> MaskSlice = + Mask.slice(Part * EltsPerVector, + (Part == NumParts - 1 && Mask.size() % EltsPerVector != 0) + ? Mask.size() % EltsPerVector + : EltsPerVector); + SmallVector<int> SubMask(EltsPerVector, PoisonMaskElem); + copy(MaskSlice, SubMask.begin()); + std::optional<TTI::ShuffleKind> RegShuffleKind = + CheckPerRegistersShuffle(SubMask); + if (!RegShuffleKind) { + Cost += TTI.getShuffleCost( + *ShuffleKinds[Part], + FixedVectorType::get(VL.front()->getType(), NumElts), MaskSlice); continue; - - // Check all extracts for a vector register on the target directly - // extract values in order. 
- unsigned CurrentIdx = *getExtractIndex(cast<Instruction>(V)); - if (!isa<UndefValue>(VL[Idx - 1]) && Mask[Idx - 1] != PoisonMaskElem) { - unsigned PrevIdx = *getExtractIndex(cast<Instruction>(VL[Idx - 1])); - AllConsecutive &= PrevIdx + 1 == CurrentIdx && - CurrentIdx % EltsPerVector == Idx % EltsPerVector; - RegMask[Idx % EltsPerVector] = CurrentIdx % EltsPerVector; } - - if (AllConsecutive) - continue; - - // Skip all indices, except for the last index per vector block. - if ((Idx + 1) % EltsPerVector != 0 && Idx + 1 != VL.size()) - continue; - - // If we have a series of extracts which are not consecutive and hence - // cannot re-use the source vector register directly, compute the shuffle - // cost to extract the vector with EltsPerVector elements. - Cost += TTI.getShuffleCost( - TargetTransformInfo::SK_PermuteSingleSrc, - FixedVectorType::get(VecTy->getElementType(), EltsPerVector), - RegMask); + if (*RegShuffleKind != TTI::SK_PermuteSingleSrc || + !ShuffleVectorInst::isIdentityMask(SubMask, EltsPerVector)) { + Cost += TTI.getShuffleCost( + *RegShuffleKind, + FixedVectorType::get(VL.front()->getType(), EltsPerVector), + SubMask); + } } return Cost; } + /// Transforms mask \p CommonMask per given \p Mask to make proper set after + /// shuffle emission. + static void transformMaskAfterShuffle(MutableArrayRef<int> CommonMask, + ArrayRef<int> Mask) { + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (Mask[Idx] != PoisonMaskElem) + CommonMask[Idx] = Idx; + } + /// Adds the cost of reshuffling \p E1 and \p E2 (if present), using given + /// mask \p Mask, register number \p Part, that includes \p SliceSize + /// elements. + void estimateNodesPermuteCost(const TreeEntry &E1, const TreeEntry *E2, + ArrayRef<int> Mask, unsigned Part, + unsigned SliceSize) { + if (SameNodesEstimated) { + // Delay the cost estimation if the same nodes are reshuffling. 
+ // If we already requested the cost of reshuffling of E1 and E2 before, no + // need to estimate another cost with the sub-Mask, instead include this + // sub-Mask into the CommonMask to estimate it later and avoid double cost + // estimation. + if ((InVectors.size() == 2 && + InVectors.front().get<const TreeEntry *>() == &E1 && + InVectors.back().get<const TreeEntry *>() == E2) || + (!E2 && InVectors.front().get<const TreeEntry *>() == &E1)) { + assert(all_of(ArrayRef(CommonMask).slice(Part * SliceSize, SliceSize), + [](int Idx) { return Idx == PoisonMaskElem; }) && + "Expected all poisoned elements."); + ArrayRef<int> SubMask = + ArrayRef(Mask).slice(Part * SliceSize, SliceSize); + copy(SubMask, std::next(CommonMask.begin(), SliceSize * Part)); + return; + } + // Found non-matching nodes - need to estimate the cost for the matched + // and transform mask. + Cost += createShuffle(InVectors.front(), + InVectors.size() == 1 ? nullptr : InVectors.back(), + CommonMask); + transformMaskAfterShuffle(CommonMask, CommonMask); + } + SameNodesEstimated = false; + Cost += createShuffle(&E1, E2, Mask); + transformMaskAfterShuffle(CommonMask, Mask); + } class ShuffleCostBuilder { const TargetTransformInfo &TTI; static bool isEmptyOrIdentity(ArrayRef<int> Mask, unsigned VF) { - int Limit = 2 * VF; + int Index = -1; return Mask.empty() || (VF == Mask.size() && - all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) && - ShuffleVectorInst::isIdentityMask(Mask)); + ShuffleVectorInst::isIdentityMask(Mask, VF)) || + (ShuffleVectorInst::isExtractSubvectorMask(Mask, VF, Index) && + Index == 0); } public: @@ -7021,21 +7248,17 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue(); if (isEmptyOrIdentity(Mask, VF)) return TTI::TCC_Free; - return TTI.getShuffleCost( - TTI::SK_PermuteTwoSrc, - FixedVectorType::get( - cast<VectorType>(V1->getType())->getElementType(), Mask.size()), - Mask); + return 
TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, + cast<VectorType>(V1->getType()), Mask); } InstructionCost createShuffleVector(Value *V1, ArrayRef<int> Mask) const { // Empty mask or identity mask are free. - if (isEmptyOrIdentity(Mask, Mask.size())) + unsigned VF = + cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue(); + if (isEmptyOrIdentity(Mask, VF)) return TTI::TCC_Free; - return TTI.getShuffleCost( - TTI::SK_PermuteSingleSrc, - FixedVectorType::get( - cast<VectorType>(V1->getType())->getElementType(), Mask.size()), - Mask); + return TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, + cast<VectorType>(V1->getType()), Mask); } InstructionCost createIdentity(Value *) const { return TTI::TCC_Free; } InstructionCost createPoison(Type *Ty, unsigned VF) const { @@ -7052,139 +7275,226 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { const PointerUnion<Value *, const TreeEntry *> &P2, ArrayRef<int> Mask) { ShuffleCostBuilder Builder(TTI); + SmallVector<int> CommonMask(Mask.begin(), Mask.end()); Value *V1 = P1.dyn_cast<Value *>(), *V2 = P2.dyn_cast<Value *>(); - unsigned CommonVF = 0; - if (!V1) { + unsigned CommonVF = Mask.size(); + if (!V1 && !V2 && !P2.isNull()) { + // Shuffle 2 entry nodes. const TreeEntry *E = P1.get<const TreeEntry *>(); unsigned VF = E->getVectorFactor(); - if (V2) { - unsigned V2VF = cast<FixedVectorType>(V2->getType())->getNumElements(); - if (V2VF != VF && V2VF == E->Scalars.size()) - VF = E->Scalars.size(); - } else if (!P2.isNull()) { - const TreeEntry *E2 = P2.get<const TreeEntry *>(); - if (E->Scalars.size() == E2->Scalars.size()) - CommonVF = VF = E->Scalars.size(); - } else { - // P2 is empty, check that we have same node + reshuffle (if any). 
- if (E->Scalars.size() == Mask.size() && VF != Mask.size()) { - VF = E->Scalars.size(); - SmallVector<int> CommonMask(Mask.begin(), Mask.end()); - ::addMask(CommonMask, E->getCommonMask()); - V1 = Constant::getNullValue( - FixedVectorType::get(E->Scalars.front()->getType(), VF)); - return BaseShuffleAnalysis::createShuffle<InstructionCost>( - V1, nullptr, CommonMask, Builder); + const TreeEntry *E2 = P2.get<const TreeEntry *>(); + CommonVF = std::max(VF, E2->getVectorFactor()); + assert(all_of(Mask, + [=](int Idx) { + return Idx < 2 * static_cast<int>(CommonVF); + }) && + "All elements in mask must be less than 2 * CommonVF."); + if (E->Scalars.size() == E2->Scalars.size()) { + SmallVector<int> EMask = E->getCommonMask(); + SmallVector<int> E2Mask = E2->getCommonMask(); + if (!EMask.empty() || !E2Mask.empty()) { + for (int &Idx : CommonMask) { + if (Idx == PoisonMaskElem) + continue; + if (Idx < static_cast<int>(CommonVF) && !EMask.empty()) + Idx = EMask[Idx]; + else if (Idx >= static_cast<int>(CommonVF)) + Idx = (E2Mask.empty() ? Idx - CommonVF : E2Mask[Idx - CommonVF]) + + E->Scalars.size(); + } } + CommonVF = E->Scalars.size(); } V1 = Constant::getNullValue( - FixedVectorType::get(E->Scalars.front()->getType(), VF)); - } - if (!V2 && !P2.isNull()) { - const TreeEntry *E = P2.get<const TreeEntry *>(); + FixedVectorType::get(E->Scalars.front()->getType(), CommonVF)); + V2 = getAllOnesValue( + *R.DL, FixedVectorType::get(E->Scalars.front()->getType(), CommonVF)); + } else if (!V1 && P2.isNull()) { + // Shuffle single entry node. 
+ const TreeEntry *E = P1.get<const TreeEntry *>(); unsigned VF = E->getVectorFactor(); - unsigned V1VF = cast<FixedVectorType>(V1->getType())->getNumElements(); - if (!CommonVF && V1VF == E->Scalars.size()) + CommonVF = VF; + assert( + all_of(Mask, + [=](int Idx) { return Idx < static_cast<int>(CommonVF); }) && + "All elements in mask must be less than CommonVF."); + if (E->Scalars.size() == Mask.size() && VF != Mask.size()) { + SmallVector<int> EMask = E->getCommonMask(); + assert(!EMask.empty() && "Expected non-empty common mask."); + for (int &Idx : CommonMask) { + if (Idx != PoisonMaskElem) + Idx = EMask[Idx]; + } CommonVF = E->Scalars.size(); - if (CommonVF) - VF = CommonVF; - V2 = Constant::getNullValue( - FixedVectorType::get(E->Scalars.front()->getType(), VF)); + } + V1 = Constant::getNullValue( + FixedVectorType::get(E->Scalars.front()->getType(), CommonVF)); + } else if (V1 && P2.isNull()) { + // Shuffle single vector. + CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements(); + assert( + all_of(Mask, + [=](int Idx) { return Idx < static_cast<int>(CommonVF); }) && + "All elements in mask must be less than CommonVF."); + } else if (V1 && !V2) { + // Shuffle vector and tree node. 
+ unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements(); + const TreeEntry *E2 = P2.get<const TreeEntry *>(); + CommonVF = std::max(VF, E2->getVectorFactor()); + assert(all_of(Mask, + [=](int Idx) { + return Idx < 2 * static_cast<int>(CommonVF); + }) && + "All elements in mask must be less than 2 * CommonVF."); + if (E2->Scalars.size() == VF && VF != CommonVF) { + SmallVector<int> E2Mask = E2->getCommonMask(); + assert(!E2Mask.empty() && "Expected non-empty common mask."); + for (int &Idx : CommonMask) { + if (Idx == PoisonMaskElem) + continue; + if (Idx >= static_cast<int>(CommonVF)) + Idx = E2Mask[Idx - CommonVF] + VF; + } + CommonVF = VF; + } + V1 = Constant::getNullValue( + FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF)); + V2 = getAllOnesValue( + *R.DL, + FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF)); + } else if (!V1 && V2) { + // Shuffle vector and tree node. + unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements(); + const TreeEntry *E1 = P1.get<const TreeEntry *>(); + CommonVF = std::max(VF, E1->getVectorFactor()); + assert(all_of(Mask, + [=](int Idx) { + return Idx < 2 * static_cast<int>(CommonVF); + }) && + "All elements in mask must be less than 2 * CommonVF."); + if (E1->Scalars.size() == VF && VF != CommonVF) { + SmallVector<int> E1Mask = E1->getCommonMask(); + assert(!E1Mask.empty() && "Expected non-empty common mask."); + for (int &Idx : CommonMask) { + if (Idx == PoisonMaskElem) + continue; + if (Idx >= static_cast<int>(CommonVF)) + Idx = E1Mask[Idx - CommonVF] + VF; + } + CommonVF = VF; + } + V1 = Constant::getNullValue( + FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF)); + V2 = getAllOnesValue( + *R.DL, + FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF)); + } else { + assert(V1 && V2 && "Expected both vectors."); + unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements(); + CommonVF = + std::max(VF, 
cast<FixedVectorType>(V2->getType())->getNumElements()); + assert(all_of(Mask, + [=](int Idx) { + return Idx < 2 * static_cast<int>(CommonVF); + }) && + "All elements in mask must be less than 2 * CommonVF."); + if (V1->getType() != V2->getType()) { + V1 = Constant::getNullValue(FixedVectorType::get( + cast<FixedVectorType>(V1->getType())->getElementType(), CommonVF)); + V2 = getAllOnesValue( + *R.DL, FixedVectorType::get( + cast<FixedVectorType>(V1->getType())->getElementType(), + CommonVF)); + } } - return BaseShuffleAnalysis::createShuffle<InstructionCost>(V1, V2, Mask, - Builder); + InVectors.front() = Constant::getNullValue(FixedVectorType::get( + cast<FixedVectorType>(V1->getType())->getElementType(), + CommonMask.size())); + if (InVectors.size() == 2) + InVectors.pop_back(); + return BaseShuffleAnalysis::createShuffle<InstructionCost>( + V1, V2, CommonMask, Builder); } public: ShuffleCostEstimator(TargetTransformInfo &TTI, ArrayRef<Value *> VectorizedVals, BoUpSLP &R, SmallPtrSetImpl<Value *> &CheckedExtracts) - : TTI(TTI), VectorizedVals(VectorizedVals), R(R), - CheckedExtracts(CheckedExtracts) {} - Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask, - TTI::ShuffleKind ShuffleKind) { + : TTI(TTI), VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()), + R(R), CheckedExtracts(CheckedExtracts) {} + Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask, + ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds, + unsigned NumParts, bool &UseVecBaseAsInput) { + UseVecBaseAsInput = false; if (Mask.empty()) return nullptr; Value *VecBase = nullptr; ArrayRef<Value *> VL = E->Scalars; - auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size()); // If the resulting type is scalarized, do not adjust the cost. 
- unsigned VecNumParts = TTI.getNumberOfParts(VecTy); - if (VecNumParts == VecTy->getNumElements()) + if (NumParts == VL.size()) return nullptr; - DenseMap<Value *, int> ExtractVectorsTys; - for (auto [I, V] : enumerate(VL)) { - // Ignore non-extractelement scalars. - if (isa<UndefValue>(V) || (!Mask.empty() && Mask[I] == PoisonMaskElem)) - continue; - // If all users of instruction are going to be vectorized and this - // instruction itself is not going to be vectorized, consider this - // instruction as dead and remove its cost from the final cost of the - // vectorized tree. - // Also, avoid adjusting the cost for extractelements with multiple uses - // in different graph entries. - const TreeEntry *VE = R.getTreeEntry(V); - if (!CheckedExtracts.insert(V).second || - !R.areAllUsersVectorized(cast<Instruction>(V), VectorizedVals) || - (VE && VE != E)) - continue; - auto *EE = cast<ExtractElementInst>(V); - VecBase = EE->getVectorOperand(); - std::optional<unsigned> EEIdx = getExtractIndex(EE); - if (!EEIdx) - continue; - unsigned Idx = *EEIdx; - if (VecNumParts != TTI.getNumberOfParts(EE->getVectorOperandType())) { - auto It = - ExtractVectorsTys.try_emplace(EE->getVectorOperand(), Idx).first; - It->getSecond() = std::min<int>(It->second, Idx); - } - // Take credit for instruction that will become dead. - if (EE->hasOneUse()) { - Instruction *Ext = EE->user_back(); - if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) { - return isa<GetElementPtrInst>(U); - })) { - // Use getExtractWithExtendCost() to calculate the cost of - // extractelement/ext pair. - Cost -= TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), - EE->getVectorOperandType(), Idx); - // Add back the cost of s|zext which is subtracted separately. - Cost += TTI.getCastInstrCost( - Ext->getOpcode(), Ext->getType(), EE->getType(), - TTI::getCastContextHint(Ext), CostKind, Ext); + // Check if it can be considered reused if same extractelements were + // vectorized already. 
+ bool PrevNodeFound = any_of( + ArrayRef(R.VectorizableTree).take_front(E->Idx), + [&](const std::unique_ptr<TreeEntry> &TE) { + return ((!TE->isAltShuffle() && + TE->getOpcode() == Instruction::ExtractElement) || + TE->State == TreeEntry::NeedToGather) && + all_of(enumerate(TE->Scalars), [&](auto &&Data) { + return VL.size() > Data.index() && + (Mask[Data.index()] == PoisonMaskElem || + isa<UndefValue>(VL[Data.index()]) || + Data.value() == VL[Data.index()]); + }); + }); + SmallPtrSet<Value *, 4> UniqueBases; + unsigned SliceSize = VL.size() / NumParts; + for (unsigned Part = 0; Part < NumParts; ++Part) { + ArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize); + for (auto [I, V] : enumerate(VL.slice(Part * SliceSize, SliceSize))) { + // Ignore non-extractelement scalars. + if (isa<UndefValue>(V) || + (!SubMask.empty() && SubMask[I] == PoisonMaskElem)) continue; - } - } - Cost -= TTI.getVectorInstrCost(*EE, EE->getVectorOperandType(), CostKind, - Idx); - } - // Add a cost for subvector extracts/inserts if required. - for (const auto &Data : ExtractVectorsTys) { - auto *EEVTy = cast<FixedVectorType>(Data.first->getType()); - unsigned NumElts = VecTy->getNumElements(); - if (Data.second % NumElts == 0) - continue; - if (TTI.getNumberOfParts(EEVTy) > VecNumParts) { - unsigned Idx = (Data.second / NumElts) * NumElts; - unsigned EENumElts = EEVTy->getNumElements(); - if (Idx % NumElts == 0) + // If all users of instruction are going to be vectorized and this + // instruction itself is not going to be vectorized, consider this + // instruction as dead and remove its cost from the final cost of the + // vectorized tree. + // Also, avoid adjusting the cost for extractelements with multiple uses + // in different graph entries. 
+ auto *EE = cast<ExtractElementInst>(V); + VecBase = EE->getVectorOperand(); + UniqueBases.insert(VecBase); + const TreeEntry *VE = R.getTreeEntry(V); + if (!CheckedExtracts.insert(V).second || + !R.areAllUsersVectorized(cast<Instruction>(V), &VectorizedVals) || + (VE && VE != E)) continue; - if (Idx + NumElts <= EENumElts) { - Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, - EEVTy, std::nullopt, CostKind, Idx, VecTy); - } else { - // Need to round up the subvector type vectorization factor to avoid a - // crash in cost model functions. Make SubVT so that Idx + VF of SubVT - // <= EENumElts. - auto *SubVT = - FixedVectorType::get(VecTy->getElementType(), EENumElts - Idx); - Cost += TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, - EEVTy, std::nullopt, CostKind, Idx, SubVT); + std::optional<unsigned> EEIdx = getExtractIndex(EE); + if (!EEIdx) + continue; + unsigned Idx = *EEIdx; + // Take credit for instruction that will become dead. + if (EE->hasOneUse() || !PrevNodeFound) { + Instruction *Ext = EE->user_back(); + if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) { + return isa<GetElementPtrInst>(U); + })) { + // Use getExtractWithExtendCost() to calculate the cost of + // extractelement/ext pair. + Cost -= + TTI.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), + EE->getVectorOperandType(), Idx); + // Add back the cost of s|zext which is subtracted separately. 
+ Cost += TTI.getCastInstrCost( + Ext->getOpcode(), Ext->getType(), EE->getType(), + TTI::getCastContextHint(Ext), CostKind, Ext); + continue; + } } - } else { - Cost += TTI.getShuffleCost(TargetTransformInfo::SK_InsertSubvector, - VecTy, std::nullopt, CostKind, 0, EEVTy); + Cost -= TTI.getVectorInstrCost(*EE, EE->getVectorOperandType(), + CostKind, Idx); } } // Check that gather of extractelements can be represented as just a @@ -7192,31 +7502,152 @@ public: // Found the bunch of extractelement instructions that must be gathered // into a vector and can be represented as a permutation elements in a // single input vector or of 2 input vectors. - Cost += computeExtractCost(VL, Mask, ShuffleKind); + // Done for reused if same extractelements were vectorized already. + if (!PrevNodeFound) + Cost += computeExtractCost(VL, Mask, ShuffleKinds, NumParts); + InVectors.assign(1, E); + CommonMask.assign(Mask.begin(), Mask.end()); + transformMaskAfterShuffle(CommonMask, CommonMask); + SameNodesEstimated = false; + if (NumParts != 1 && UniqueBases.size() != 1) { + UseVecBaseAsInput = true; + VecBase = Constant::getNullValue( + FixedVectorType::get(VL.front()->getType(), CommonMask.size())); + } return VecBase; } - void add(const TreeEntry *E1, const TreeEntry *E2, ArrayRef<int> Mask) { - CommonMask.assign(Mask.begin(), Mask.end()); - InVectors.assign({E1, E2}); + /// Checks if the specified entry \p E needs to be delayed because of its + /// dependency nodes. + std::optional<InstructionCost> + needToDelay(const TreeEntry *, + ArrayRef<SmallVector<const TreeEntry *>>) const { + // No need to delay the cost estimation during analysis. 
+ return std::nullopt; } - void add(const TreeEntry *E1, ArrayRef<int> Mask) { - CommonMask.assign(Mask.begin(), Mask.end()); - InVectors.assign(1, E1); + void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) { + if (&E1 == &E2) { + assert(all_of(Mask, + [&](int Idx) { + return Idx < static_cast<int>(E1.getVectorFactor()); + }) && + "Expected single vector shuffle mask."); + add(E1, Mask); + return; + } + if (InVectors.empty()) { + CommonMask.assign(Mask.begin(), Mask.end()); + InVectors.assign({&E1, &E2}); + return; + } + assert(!CommonMask.empty() && "Expected non-empty common mask."); + auto *MaskVecTy = + FixedVectorType::get(E1.Scalars.front()->getType(), Mask.size()); + unsigned NumParts = TTI.getNumberOfParts(MaskVecTy); + if (NumParts == 0 || NumParts >= Mask.size()) + NumParts = 1; + unsigned SliceSize = Mask.size() / NumParts; + const auto *It = + find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; }); + unsigned Part = std::distance(Mask.begin(), It) / SliceSize; + estimateNodesPermuteCost(E1, &E2, Mask, Part, SliceSize); + } + void add(const TreeEntry &E1, ArrayRef<int> Mask) { + if (InVectors.empty()) { + CommonMask.assign(Mask.begin(), Mask.end()); + InVectors.assign(1, &E1); + return; + } + assert(!CommonMask.empty() && "Expected non-empty common mask."); + auto *MaskVecTy = + FixedVectorType::get(E1.Scalars.front()->getType(), Mask.size()); + unsigned NumParts = TTI.getNumberOfParts(MaskVecTy); + if (NumParts == 0 || NumParts >= Mask.size()) + NumParts = 1; + unsigned SliceSize = Mask.size() / NumParts; + const auto *It = + find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; }); + unsigned Part = std::distance(Mask.begin(), It) / SliceSize; + estimateNodesPermuteCost(E1, nullptr, Mask, Part, SliceSize); + if (!SameNodesEstimated && InVectors.size() == 1) + InVectors.emplace_back(&E1); + } + /// Adds 2 input vectors and the mask for their shuffling. 
+ void add(Value *V1, Value *V2, ArrayRef<int> Mask) { + // May come only for shuffling of 2 vectors with extractelements, already + // handled in adjustExtracts. + assert(InVectors.size() == 1 && + all_of(enumerate(CommonMask), + [&](auto P) { + if (P.value() == PoisonMaskElem) + return Mask[P.index()] == PoisonMaskElem; + auto *EI = + cast<ExtractElementInst>(InVectors.front() + .get<const TreeEntry *>() + ->Scalars[P.index()]); + return EI->getVectorOperand() == V1 || + EI->getVectorOperand() == V2; + }) && + "Expected extractelement vectors."); } /// Adds another one input vector and the mask for the shuffling. - void add(Value *V1, ArrayRef<int> Mask) { - assert(CommonMask.empty() && InVectors.empty() && - "Expected empty input mask/vectors."); - CommonMask.assign(Mask.begin(), Mask.end()); - InVectors.assign(1, V1); + void add(Value *V1, ArrayRef<int> Mask, bool ForExtracts = false) { + if (InVectors.empty()) { + assert(CommonMask.empty() && !ForExtracts && + "Expected empty input mask/vectors."); + CommonMask.assign(Mask.begin(), Mask.end()); + InVectors.assign(1, V1); + return; + } + if (ForExtracts) { + // No need to add vectors here, already handled them in adjustExtracts. 
+ assert(InVectors.size() == 1 && + InVectors.front().is<const TreeEntry *>() && !CommonMask.empty() && + all_of(enumerate(CommonMask), + [&](auto P) { + Value *Scalar = InVectors.front() + .get<const TreeEntry *>() + ->Scalars[P.index()]; + if (P.value() == PoisonMaskElem) + return P.value() == Mask[P.index()] || + isa<UndefValue>(Scalar); + if (isa<Constant>(V1)) + return true; + auto *EI = cast<ExtractElementInst>(Scalar); + return EI->getVectorOperand() == V1; + }) && + "Expected only tree entry for extractelement vectors."); + return; + } + assert(!InVectors.empty() && !CommonMask.empty() && + "Expected only tree entries from extracts/reused buildvectors."); + unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements(); + if (InVectors.size() == 2) { + Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask); + transformMaskAfterShuffle(CommonMask, CommonMask); + VF = std::max<unsigned>(VF, CommonMask.size()); + } else if (const auto *InTE = + InVectors.front().dyn_cast<const TreeEntry *>()) { + VF = std::max(VF, InTE->getVectorFactor()); + } else { + VF = std::max( + VF, cast<FixedVectorType>(InVectors.front().get<Value *>()->getType()) + ->getNumElements()); + } + InVectors.push_back(V1); + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem) + CommonMask[Idx] = Mask[Idx] + VF; } - Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) { + Value *gather(ArrayRef<Value *> VL, unsigned MaskVF = 0, + Value *Root = nullptr) { Cost += getBuildVectorCost(VL, Root); if (!Root) { - assert(InVectors.empty() && "Unexpected input vectors for buildvector."); // FIXME: Need to find a way to avoid use of getNullValue here. 
SmallVector<Constant *> Vals; - for (Value *V : VL) { + unsigned VF = VL.size(); + if (MaskVF != 0) + VF = std::min(VF, MaskVF); + for (Value *V : VL.take_front(VF)) { if (isa<UndefValue>(V)) { Vals.push_back(cast<Constant>(V)); continue; @@ -7226,9 +7657,11 @@ public: return ConstantVector::get(Vals); } return ConstantVector::getSplat( - ElementCount::getFixed(VL.size()), - Constant::getNullValue(VL.front()->getType())); + ElementCount::getFixed( + cast<FixedVectorType>(Root->getType())->getNumElements()), + getAllOnesValue(*R.DL, VL.front()->getType())); } + InstructionCost createFreeze(InstructionCost Cost) { return Cost; } /// Finalize emission of the shuffles. InstructionCost finalize(ArrayRef<int> ExtMask, unsigned VF = 0, @@ -7236,31 +7669,24 @@ public: IsFinalized = true; if (Action) { const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front(); - if (InVectors.size() == 2) { + if (InVectors.size() == 2) Cost += createShuffle(Vec, InVectors.back(), CommonMask); - InVectors.pop_back(); - } else { + else Cost += createShuffle(Vec, nullptr, CommonMask); - } for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) if (CommonMask[Idx] != PoisonMaskElem) CommonMask[Idx] = Idx; assert(VF > 0 && "Expected vector length for the final value before action."); - Value *V = Vec.dyn_cast<Value *>(); - if (!Vec.isNull() && !V) - V = Constant::getNullValue(FixedVectorType::get( - Vec.get<const TreeEntry *>()->Scalars.front()->getType(), - CommonMask.size())); + Value *V = Vec.get<Value *>(); Action(V, CommonMask); + InVectors.front() = V; } ::addMask(CommonMask, ExtMask, /*ExtendingManyInputs=*/true); - if (CommonMask.empty()) - return Cost; - int Limit = CommonMask.size() * 2; - if (all_of(CommonMask, [=](int Idx) { return Idx < Limit; }) && - ShuffleVectorInst::isIdentityMask(CommonMask)) + if (CommonMask.empty()) { + assert(InVectors.size() == 1 && "Expected only one vector with no mask"); return Cost; + } return Cost + 
createShuffle(InVectors.front(), InVectors.size() == 2 ? InVectors.back() : nullptr, @@ -7273,28 +7699,63 @@ public: } }; +const BoUpSLP::TreeEntry *BoUpSLP::getOperandEntry(const TreeEntry *E, + unsigned Idx) const { + Value *Op = E->getOperand(Idx).front(); + if (const TreeEntry *TE = getTreeEntry(Op)) { + if (find_if(E->UserTreeIndices, [&](const EdgeInfo &EI) { + return EI.EdgeIdx == Idx && EI.UserTE == E; + }) != TE->UserTreeIndices.end()) + return TE; + auto MIt = MultiNodeScalars.find(Op); + if (MIt != MultiNodeScalars.end()) { + for (const TreeEntry *TE : MIt->second) { + if (find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) { + return EI.EdgeIdx == Idx && EI.UserTE == E; + }) != TE->UserTreeIndices.end()) + return TE; + } + } + } + const auto *It = + find_if(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { + return TE->State == TreeEntry::NeedToGather && + find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) { + return EI.EdgeIdx == Idx && EI.UserTE == E; + }) != TE->UserTreeIndices.end(); + }); + assert(It != VectorizableTree.end() && "Expected vectorizable entry."); + return It->get(); +} + InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, SmallPtrSetImpl<Value *> &CheckedExtracts) { ArrayRef<Value *> VL = E->Scalars; Type *ScalarTy = VL[0]->getType(); - if (auto *SI = dyn_cast<StoreInst>(VL[0])) - ScalarTy = SI->getValueOperand()->getType(); - else if (auto *CI = dyn_cast<CmpInst>(VL[0])) - ScalarTy = CI->getOperand(0)->getType(); - else if (auto *IE = dyn_cast<InsertElementInst>(VL[0])) - ScalarTy = IE->getOperand(1)->getType(); + if (E->State != TreeEntry::NeedToGather) { + if (auto *SI = dyn_cast<StoreInst>(VL[0])) + ScalarTy = SI->getValueOperand()->getType(); + else if (auto *CI = dyn_cast<CmpInst>(VL[0])) + ScalarTy = CI->getOperand(0)->getType(); + else if (auto *IE = dyn_cast<InsertElementInst>(VL[0])) + ScalarTy = IE->getOperand(1)->getType(); + } + if 
(!FixedVectorType::isValidElementType(ScalarTy)) + return InstructionCost::getInvalid(); auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; // If we have computed a smaller type for the expression, update VecTy so // that the costs will be accurate. - if (MinBWs.count(VL[0])) - VecTy = FixedVectorType::get( - IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size()); + auto It = MinBWs.find(E); + if (It != MinBWs.end()) { + ScalarTy = IntegerType::get(F->getContext(), It->second.first); + VecTy = FixedVectorType::get(ScalarTy, VL.size()); + } unsigned EntryVF = E->getVectorFactor(); - auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF); + auto *FinalVecTy = FixedVectorType::get(ScalarTy, EntryVF); bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); if (E->State == TreeEntry::NeedToGather) { @@ -7302,121 +7763,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, return 0; if (isa<InsertElementInst>(VL[0])) return InstructionCost::getInvalid(); - ShuffleCostEstimator Estimator(*TTI, VectorizedVals, *this, - CheckedExtracts); - unsigned VF = E->getVectorFactor(); - SmallVector<int> ReuseShuffleIndicies(E->ReuseShuffleIndices.begin(), - E->ReuseShuffleIndices.end()); - SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end()); - // Build a mask out of the reorder indices and reorder scalars per this - // mask. - SmallVector<int> ReorderMask; - inversePermutation(E->ReorderIndices, ReorderMask); - if (!ReorderMask.empty()) - reorderScalars(GatheredScalars, ReorderMask); - SmallVector<int> Mask; - SmallVector<int> ExtractMask; - std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle; - std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle; - SmallVector<const TreeEntry *> Entries; - Type *ScalarTy = GatheredScalars.front()->getType(); - // Check for gathered extracts. 
- ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask); - SmallVector<Value *> IgnoredVals; - if (UserIgnoreList) - IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end()); - - bool Resized = false; - if (Value *VecBase = Estimator.adjustExtracts( - E, ExtractMask, ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc))) - if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType())) - if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) { - Resized = true; - GatheredScalars.append(VF - GatheredScalars.size(), - PoisonValue::get(ScalarTy)); - } - - // Do not try to look for reshuffled loads for gathered loads (they will be - // handled later), for vectorized scalars, and cases, which are definitely - // not profitable (splats and small gather nodes.) - if (ExtractShuffle || E->getOpcode() != Instruction::Load || - E->isAltShuffle() || - all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) || - isSplat(E->Scalars) || - (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) - GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries); - if (GatherShuffle) { - assert((Entries.size() == 1 || Entries.size() == 2) && - "Expected shuffle of 1 or 2 entries."); - if (*GatherShuffle == TTI::SK_PermuteSingleSrc && - Entries.front()->isSame(E->Scalars)) { - // Perfect match in the graph, will reuse the previously vectorized - // node. Cost is 0. - LLVM_DEBUG( - dbgs() - << "SLP: perfect diamond match for gather bundle that starts with " - << *VL.front() << ".\n"); - // Restore the mask for previous partially matched values. 
- for (auto [I, V] : enumerate(E->Scalars)) { - if (isa<PoisonValue>(V)) { - Mask[I] = PoisonMaskElem; - continue; - } - if (Mask[I] == PoisonMaskElem) - Mask[I] = Entries.front()->findLaneForValue(V); - } - Estimator.add(Entries.front(), Mask); - return Estimator.finalize(E->ReuseShuffleIndices); - } - if (!Resized) { - unsigned VF1 = Entries.front()->getVectorFactor(); - unsigned VF2 = Entries.back()->getVectorFactor(); - if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF) - GatheredScalars.append(VF - GatheredScalars.size(), - PoisonValue::get(ScalarTy)); - } - // Remove shuffled elements from list of gathers. - for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { - if (Mask[I] != PoisonMaskElem) - GatheredScalars[I] = PoisonValue::get(ScalarTy); - } - LLVM_DEBUG(dbgs() << "SLP: shuffled " << Entries.size() - << " entries for bundle that starts with " - << *VL.front() << ".\n";); - if (Entries.size() == 1) - Estimator.add(Entries.front(), Mask); - else - Estimator.add(Entries.front(), Entries.back(), Mask); - if (all_of(GatheredScalars, PoisonValue ::classof)) - return Estimator.finalize(E->ReuseShuffleIndices); - return Estimator.finalize( - E->ReuseShuffleIndices, E->Scalars.size(), - [&](Value *&Vec, SmallVectorImpl<int> &Mask) { - Vec = Estimator.gather(GatheredScalars, - Constant::getNullValue(FixedVectorType::get( - GatheredScalars.front()->getType(), - GatheredScalars.size()))); - }); - } - if (!all_of(GatheredScalars, PoisonValue::classof)) { - auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size()); - bool SameGathers = VL.equals(Gathers); - Value *BV = Estimator.gather( - Gathers, SameGathers ? 
nullptr - : Constant::getNullValue(FixedVectorType::get( - GatheredScalars.front()->getType(), - GatheredScalars.size()))); - SmallVector<int> ReuseMask(Gathers.size(), PoisonMaskElem); - std::iota(ReuseMask.begin(), ReuseMask.end(), 0); - Estimator.add(BV, ReuseMask); - } - if (ExtractShuffle) - Estimator.add(E, std::nullopt); - return Estimator.finalize(E->ReuseShuffleIndices); + return processBuildVector<ShuffleCostEstimator, InstructionCost>( + E, *TTI, VectorizedVals, *this, CheckedExtracts); } InstructionCost CommonCost = 0; SmallVector<int> Mask; - if (!E->ReorderIndices.empty()) { + if (!E->ReorderIndices.empty() && + E->State != TreeEntry::PossibleStridedVectorize) { SmallVector<int> NewMask; if (E->getOpcode() == Instruction::Store) { // For stores the order is actually a mask. @@ -7429,11 +7782,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, } if (NeedToShuffleReuses) ::addMask(Mask, E->ReuseShuffleIndices); - if (!Mask.empty() && !ShuffleVectorInst::isIdentityMask(Mask)) + if (!Mask.empty() && !ShuffleVectorInst::isIdentityMask(Mask, Mask.size())) CommonCost = TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, FinalVecTy, Mask); assert((E->State == TreeEntry::Vectorize || - E->State == TreeEntry::ScatterVectorize) && + E->State == TreeEntry::ScatterVectorize || + E->State == TreeEntry::PossibleStridedVectorize) && "Unhandled state"); assert(E->getOpcode() && ((allSameType(VL) && allSameBlock(VL)) || @@ -7443,7 +7797,34 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, Instruction *VL0 = E->getMainOp(); unsigned ShuffleOrOp = E->isAltShuffle() ? 
(unsigned)Instruction::ShuffleVector : E->getOpcode(); - const unsigned Sz = VL.size(); + SetVector<Value *> UniqueValues(VL.begin(), VL.end()); + const unsigned Sz = UniqueValues.size(); + SmallBitVector UsedScalars(Sz, false); + for (unsigned I = 0; I < Sz; ++I) { + if (getTreeEntry(UniqueValues[I]) == E) + continue; + UsedScalars.set(I); + } + auto GetCastContextHint = [&](Value *V) { + if (const TreeEntry *OpTE = getTreeEntry(V)) { + if (OpTE->State == TreeEntry::ScatterVectorize) + return TTI::CastContextHint::GatherScatter; + if (OpTE->State == TreeEntry::Vectorize && + OpTE->getOpcode() == Instruction::Load && !OpTE->isAltShuffle()) { + if (OpTE->ReorderIndices.empty()) + return TTI::CastContextHint::Normal; + SmallVector<int> Mask; + inversePermutation(OpTE->ReorderIndices, Mask); + if (ShuffleVectorInst::isReverseMask(Mask, Mask.size())) + return TTI::CastContextHint::Reversed; + } + } else { + InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI); + if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle()) + return TTI::CastContextHint::GatherScatter; + } + return TTI::CastContextHint::None; + }; auto GetCostDiff = [=](function_ref<InstructionCost(unsigned)> ScalarEltCost, function_ref<InstructionCost(InstructionCost)> VectorCost) { @@ -7453,13 +7834,49 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, // For some of the instructions no need to calculate cost for each // particular instruction, we can use the cost of the single // instruction x total number of scalar instructions. - ScalarCost = Sz * ScalarEltCost(0); + ScalarCost = (Sz - UsedScalars.count()) * ScalarEltCost(0); } else { - for (unsigned I = 0; I < Sz; ++I) + for (unsigned I = 0; I < Sz; ++I) { + if (UsedScalars.test(I)) + continue; ScalarCost += ScalarEltCost(I); + } } InstructionCost VecCost = VectorCost(CommonCost); + // Check if the current node must be resized, if the parent node is not + // resized. 
+ if (!UnaryInstruction::isCast(E->getOpcode()) && E->Idx != 0) { + const EdgeInfo &EI = E->UserTreeIndices.front(); + if ((EI.UserTE->getOpcode() != Instruction::Select || + EI.EdgeIdx != 0) && + It != MinBWs.end()) { + auto UserBWIt = MinBWs.find(EI.UserTE); + Type *UserScalarTy = + EI.UserTE->getOperand(EI.EdgeIdx).front()->getType(); + if (UserBWIt != MinBWs.end()) + UserScalarTy = IntegerType::get(ScalarTy->getContext(), + UserBWIt->second.first); + if (ScalarTy != UserScalarTy) { + unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); + unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy); + unsigned VecOpcode; + auto *SrcVecTy = + FixedVectorType::get(UserScalarTy, E->getVectorFactor()); + if (BWSz > SrcBWSz) + VecOpcode = Instruction::Trunc; + else + VecOpcode = + It->second.second ? Instruction::SExt : Instruction::ZExt; + TTI::CastContextHint CCH = GetCastContextHint(VL0); + VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, + CostKind); + ScalarCost += + Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy, + CCH, CostKind); + } + } + } LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost - CommonCost, ScalarCost, "Calculated costs for Tree")); return VecCost - ScalarCost; @@ -7550,7 +7967,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, // Count reused scalars. 
InstructionCost ScalarCost = 0; SmallPtrSet<const TreeEntry *, 4> CountedOps; - for (Value *V : VL) { + for (Value *V : UniqueValues) { auto *PHI = dyn_cast<PHINode>(V); if (!PHI) continue; @@ -7571,8 +7988,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, } case Instruction::ExtractValue: case Instruction::ExtractElement: { - auto GetScalarCost = [=](unsigned Idx) { - auto *I = cast<Instruction>(VL[Idx]); + auto GetScalarCost = [&](unsigned Idx) { + auto *I = cast<Instruction>(UniqueValues[Idx]); VectorType *SrcVecTy; if (ShuffleOrOp == Instruction::ExtractElement) { auto *EE = cast<ExtractElementInst>(I); @@ -7680,8 +8097,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, // need to shift the vector. // Do not calculate the cost if the actual size is the register size and // we can merge this shuffle with the following SK_Select. - auto *InsertVecTy = - FixedVectorType::get(SrcVecTy->getElementType(), InsertVecSz); + auto *InsertVecTy = FixedVectorType::get(ScalarTy, InsertVecSz); if (!IsIdentity) Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, InsertVecTy, Mask); @@ -7697,8 +8113,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, buildUseMask(NumElts, InsertMask, UseMask::UndefsAsMask)); if (!InMask.all() && NumScalars != NumElts && !IsWholeSubvector) { if (InsertVecSz != VecSz) { - auto *ActualVecTy = - FixedVectorType::get(SrcVecTy->getElementType(), VecSz); + auto *ActualVecTy = FixedVectorType::get(ScalarTy, VecSz); Cost += TTI->getShuffleCost(TTI::SK_InsertSubvector, ActualVecTy, std::nullopt, CostKind, OffsetBeg - Offset, InsertVecTy); @@ -7729,22 +8144,52 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - auto GetScalarCost = [=](unsigned Idx) { - auto *VI = cast<Instruction>(VL[Idx]); - return TTI->getCastInstrCost(E->getOpcode(), 
ScalarTy, - VI->getOperand(0)->getType(), + auto SrcIt = MinBWs.find(getOperandEntry(E, 0)); + Type *SrcScalarTy = VL0->getOperand(0)->getType(); + auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size()); + unsigned Opcode = ShuffleOrOp; + unsigned VecOpcode = Opcode; + if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() && + (SrcIt != MinBWs.end() || It != MinBWs.end())) { + // Check if the values are candidates to demote. + unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy); + if (SrcIt != MinBWs.end()) { + SrcBWSz = SrcIt->second.first; + SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz); + SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size()); + } + unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); + if (BWSz == SrcBWSz) { + VecOpcode = Instruction::BitCast; + } else if (BWSz < SrcBWSz) { + VecOpcode = Instruction::Trunc; + } else if (It != MinBWs.end()) { + assert(BWSz > SrcBWSz && "Invalid cast!"); + VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt; + } + } + auto GetScalarCost = [&](unsigned Idx) -> InstructionCost { + // Do not count cost here if minimum bitwidth is in effect and it is just + // a bitcast (here it is just a noop). + if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast) + return TTI::TCC_Free; + auto *VI = VL0->getOpcode() == Opcode + ? cast<Instruction>(UniqueValues[Idx]) + : nullptr; + return TTI->getCastInstrCost(Opcode, VL0->getType(), + VL0->getOperand(0)->getType(), TTI::getCastContextHint(VI), CostKind, VI); }; auto GetVectorCost = [=](InstructionCost CommonCost) { - Type *SrcTy = VL0->getOperand(0)->getType(); - auto *SrcVecTy = FixedVectorType::get(SrcTy, VL.size()); - InstructionCost VecCost = CommonCost; - // Check if the values are candidates to demote. 
- if (!MinBWs.count(VL0) || VecTy != SrcVecTy) - VecCost += - TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy, - TTI::getCastContextHint(VL0), CostKind, VL0); - return VecCost; + // Do not count cost here if minimum bitwidth is in effect and it is just + // a bitcast (here it is just a noop). + if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast) + return CommonCost; + auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr; + TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0)); + return CommonCost + + TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind, + VecOpcode == Opcode ? VI : nullptr); }; return GetCostDiff(GetScalarCost, GetVectorCost); } @@ -7761,7 +8206,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, ? CmpInst::BAD_FCMP_PREDICATE : CmpInst::BAD_ICMP_PREDICATE; auto GetScalarCost = [&](unsigned Idx) { - auto *VI = cast<Instruction>(VL[Idx]); + auto *VI = cast<Instruction>(UniqueValues[Idx]); CmpInst::Predicate CurrentPred = ScalarTy->isFloatingPointTy() ? CmpInst::BAD_FCMP_PREDICATE : CmpInst::BAD_ICMP_PREDICATE; @@ -7821,8 +8266,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, case Instruction::And: case Instruction::Or: case Instruction::Xor: { - auto GetScalarCost = [=](unsigned Idx) { - auto *VI = cast<Instruction>(VL[Idx]); + auto GetScalarCost = [&](unsigned Idx) { + auto *VI = cast<Instruction>(UniqueValues[Idx]); unsigned OpIdx = isa<UnaryOperator>(VI) ? 0 : 1; TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(VI->getOperand(0)); TTI::OperandValueInfo Op2Info = @@ -7833,8 +8278,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, }; auto GetVectorCost = [=](InstructionCost CommonCost) { unsigned OpIdx = isa<UnaryOperator>(VL0) ? 
0 : 1; - TTI::OperandValueInfo Op1Info = getOperandInfo(VL, 0); - TTI::OperandValueInfo Op2Info = getOperandInfo(VL, OpIdx); + TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0)); + TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx)); return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info, Op2Info) + CommonCost; @@ -7845,23 +8290,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, return CommonCost + GetGEPCostDiff(VL, VL0); } case Instruction::Load: { - auto GetScalarCost = [=](unsigned Idx) { - auto *VI = cast<LoadInst>(VL[Idx]); + auto GetScalarCost = [&](unsigned Idx) { + auto *VI = cast<LoadInst>(UniqueValues[Idx]); return TTI->getMemoryOpCost(Instruction::Load, ScalarTy, VI->getAlign(), VI->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo(), VI); }; auto *LI0 = cast<LoadInst>(VL0); - auto GetVectorCost = [=](InstructionCost CommonCost) { + auto GetVectorCost = [&](InstructionCost CommonCost) { InstructionCost VecLdCost; if (E->State == TreeEntry::Vectorize) { VecLdCost = TTI->getMemoryOpCost( Instruction::Load, VecTy, LI0->getAlign(), LI0->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo()); } else { - assert(E->State == TreeEntry::ScatterVectorize && "Unknown EntryState"); + assert((E->State == TreeEntry::ScatterVectorize || + E->State == TreeEntry::PossibleStridedVectorize) && + "Unknown EntryState"); Align CommonAlignment = LI0->getAlign(); - for (Value *V : VL) + for (Value *V : UniqueValues) CommonAlignment = std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); VecLdCost = TTI->getGatherScatterOpCost( @@ -7874,7 +8321,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, InstructionCost Cost = GetCostDiff(GetScalarCost, GetVectorCost); // If this node generates masked gather load then it is not a terminal node. // Hence address operand cost is estimated separately. 
- if (E->State == TreeEntry::ScatterVectorize) + if (E->State == TreeEntry::ScatterVectorize || + E->State == TreeEntry::PossibleStridedVectorize) return Cost; // Estimate cost of GEPs since this tree node is a terminator. @@ -7887,7 +8335,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, bool IsReorder = !E->ReorderIndices.empty(); auto GetScalarCost = [=](unsigned Idx) { auto *VI = cast<StoreInst>(VL[Idx]); - TTI::OperandValueInfo OpInfo = getOperandInfo(VI, 0); + TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(VI->getValueOperand()); return TTI->getMemoryOpCost(Instruction::Store, ScalarTy, VI->getAlign(), VI->getPointerAddressSpace(), CostKind, OpInfo, VI); @@ -7896,7 +8344,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0); auto GetVectorCost = [=](InstructionCost CommonCost) { // We know that we can merge the stores. Calculate the cost. - TTI::OperandValueInfo OpInfo = getOperandInfo(VL, 0); + TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0)); return TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(), BaseSI->getPointerAddressSpace(), CostKind, OpInfo) + @@ -7912,8 +8360,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, GetGEPCostDiff(PointerOps, BaseSI->getPointerOperand()); } case Instruction::Call: { - auto GetScalarCost = [=](unsigned Idx) { - auto *CI = cast<CallInst>(VL[Idx]); + auto GetScalarCost = [&](unsigned Idx) { + auto *CI = cast<CallInst>(UniqueValues[Idx]); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); if (ID != Intrinsic::not_intrinsic) { IntrinsicCostAttributes CostAttrs(ID, *CI, 1); @@ -7954,8 +8402,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, } return false; }; - auto GetScalarCost = [=](unsigned Idx) { - auto *VI = cast<Instruction>(VL[Idx]); + auto GetScalarCost = [&](unsigned Idx) { + auto *VI = 
cast<Instruction>(UniqueValues[Idx]); assert(E->isOpcodeOrAlt(VI) && "Unexpected main/alternate opcode"); (void)E; return TTI->getInstructionCost(VI, CostKind); @@ -7995,21 +8443,15 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty, TTI::CastContextHint::None, CostKind); } - if (E->ReuseShuffleIndices.empty()) { - VecCost += - TTI->getShuffleCost(TargetTransformInfo::SK_Select, FinalVecTy); - } else { - SmallVector<int> Mask; - buildShuffleEntryMask( - E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, - [E](Instruction *I) { - assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); - return I->getOpcode() == E->getAltOpcode(); - }, - Mask); - VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, - FinalVecTy, Mask); - } + SmallVector<int> Mask; + E->buildAltOpShuffleMask( + [E](Instruction *I) { + assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); + return I->getOpcode() == E->getAltOpcode(); + }, + Mask); + VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, + FinalVecTy, Mask); return VecCost; }; return GetCostDiff(GetScalarCost, GetVectorCost); @@ -8065,7 +8507,8 @@ bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const { // Gathering cost would be too much for tiny trees. 
if (VectorizableTree[0]->State == TreeEntry::NeedToGather || (VectorizableTree[1]->State == TreeEntry::NeedToGather && - VectorizableTree[0]->State != TreeEntry::ScatterVectorize)) + VectorizableTree[0]->State != TreeEntry::ScatterVectorize && + VectorizableTree[0]->State != TreeEntry::PossibleStridedVectorize)) return false; return true; @@ -8144,6 +8587,23 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const { allConstant(VectorizableTree[1]->Scalars)))) return true; + // If the graph includes only PHI nodes and gathers, it is defnitely not + // profitable for the vectorization, we can skip it, if the cost threshold is + // default. The cost of vectorized PHI nodes is almost always 0 + the cost of + // gathers/buildvectors. + constexpr int Limit = 4; + if (!ForReduction && !SLPCostThreshold.getNumOccurrences() && + !VectorizableTree.empty() && + all_of(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { + return (TE->State == TreeEntry::NeedToGather && + TE->getOpcode() != Instruction::ExtractElement && + count_if(TE->Scalars, + [](Value *V) { return isa<ExtractElementInst>(V); }) <= + Limit) || + TE->getOpcode() == Instruction::PHI; + })) + return true; + // We can vectorize the tree if its size is greater than or equal to the // minimum size specified by the MinTreeSize command line option. if (VectorizableTree.size() >= MinTreeSize) @@ -8435,16 +8895,6 @@ static T *performExtractsShuffleAction( } InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { - // Build a map for gathered scalars to the nodes where they are used. 
- ValueToGatherNodes.clear(); - for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) { - if (EntryPtr->State != TreeEntry::NeedToGather) - continue; - for (Value *V : EntryPtr->Scalars) - if (!isConstant(V)) - ValueToGatherNodes.try_emplace(V).first->getSecond().insert( - EntryPtr.get()); - } InstructionCost Cost = 0; LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size " << VectorizableTree.size() << ".\n"); @@ -8460,8 +8910,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { E->isSame(TE.Scalars)) { // Some gather nodes might be absolutely the same as some vectorizable // nodes after reordering, need to handle it. - LLVM_DEBUG(dbgs() << "SLP: Adding cost 0 for bundle that starts with " - << *TE.Scalars[0] << ".\n" + LLVM_DEBUG(dbgs() << "SLP: Adding cost 0 for bundle " + << shortBundleName(TE.Scalars) << ".\n" << "SLP: Current total cost = " << Cost << "\n"); continue; } @@ -8469,9 +8919,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { InstructionCost C = getEntryCost(&TE, VectorizedVals, CheckedExtracts); Cost += C; - LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C - << " for bundle that starts with " << *TE.Scalars[0] - << ".\n" + LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle " + << shortBundleName(TE.Scalars) << ".\n" << "SLP: Current total cost = " << Cost << "\n"); } @@ -8480,6 +8929,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { SmallVector<MapVector<const TreeEntry *, SmallVector<int>>> ShuffleMasks; SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers; SmallVector<APInt> DemandedElts; + SmallDenseSet<Value *, 4> UsedInserts; + DenseSet<Value *> VectorCasts; for (ExternalUser &EU : ExternalUses) { // We only add extract cost once for the same scalar. 
if (!isa_and_nonnull<InsertElementInst>(EU.User) && @@ -8500,6 +8951,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { // to detect it as a final shuffled/identity match. if (auto *VU = dyn_cast_or_null<InsertElementInst>(EU.User)) { if (auto *FTy = dyn_cast<FixedVectorType>(VU->getType())) { + if (!UsedInserts.insert(VU).second) + continue; std::optional<unsigned> InsertIdx = getInsertIndex(VU); if (InsertIdx) { const TreeEntry *ScalarTE = getTreeEntry(EU.Scalar); @@ -8546,6 +8999,28 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { FirstUsers.emplace_back(VU, ScalarTE); DemandedElts.push_back(APInt::getZero(FTy->getNumElements())); VecId = FirstUsers.size() - 1; + auto It = MinBWs.find(ScalarTE); + if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) { + unsigned BWSz = It->second.second; + unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType()); + unsigned VecOpcode; + if (BWSz < SrcBWSz) + VecOpcode = Instruction::Trunc; + else + VecOpcode = + It->second.second ? Instruction::SExt : Instruction::ZExt; + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + InstructionCost C = TTI->getCastInstrCost( + VecOpcode, FTy, + FixedVectorType::get( + IntegerType::get(FTy->getContext(), It->second.first), + FTy->getNumElements()), + TTI::CastContextHint::None, CostKind); + LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C + << " for extending externally used vector with " + "non-equal minimum bitwidth.\n"); + Cost += C; + } } else { if (isFirstInsertElement(VU, cast<InsertElementInst>(It->first))) It->first = VU; @@ -8567,11 +9042,11 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { // for the extract and the added cost of the sign extend if needed. 
auto *VecTy = FixedVectorType::get(EU.Scalar->getType(), BundleWidth); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - auto *ScalarRoot = VectorizableTree[0]->Scalars[0]; - if (MinBWs.count(ScalarRoot)) { - auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); - auto Extend = - MinBWs[ScalarRoot].second ? Instruction::SExt : Instruction::ZExt; + auto It = MinBWs.find(getTreeEntry(EU.Scalar)); + if (It != MinBWs.end()) { + auto *MinTy = IntegerType::get(F->getContext(), It->second.first); + unsigned Extend = + It->second.second ? Instruction::SExt : Instruction::ZExt; VecTy = FixedVectorType::get(MinTy, BundleWidth); ExtractCost += TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(), VecTy, EU.Lane); @@ -8580,6 +9055,21 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { CostKind, EU.Lane); } } + // Add reduced value cost, if resized. + if (!VectorizedVals.empty()) { + auto BWIt = MinBWs.find(VectorizableTree.front().get()); + if (BWIt != MinBWs.end()) { + Type *DstTy = VectorizableTree.front()->Scalars.front()->getType(); + unsigned OriginalSz = DL->getTypeSizeInBits(DstTy); + unsigned Opcode = Instruction::Trunc; + if (OriginalSz < BWIt->second.first) + Opcode = BWIt->second.second ? 
Instruction::SExt : Instruction::ZExt; + Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first); + Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy, + TTI::CastContextHint::None, + TTI::TCK_RecipThroughput); + } + } InstructionCost SpillCost = getSpillCost(); Cost += SpillCost + ExtractCost; @@ -8590,9 +9080,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { unsigned VecVF = TE->getVectorFactor(); if (VF != VecVF && (any_of(Mask, [VF](int Idx) { return Idx >= static_cast<int>(VF); }) || - (all_of(Mask, - [VF](int Idx) { return Idx < 2 * static_cast<int>(VF); }) && - !ShuffleVectorInst::isIdentityMask(Mask)))) { + !ShuffleVectorInst::isIdentityMask(Mask, VF))) { SmallVector<int> OrigMask(VecVF, PoisonMaskElem); std::copy(Mask.begin(), std::next(Mask.begin(), std::min(VF, VecVF)), OrigMask.begin()); @@ -8611,19 +9099,23 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { // Calculate the cost of the reshuffled vectors, if any. 
for (int I = 0, E = FirstUsers.size(); I < E; ++I) { Value *Base = cast<Instruction>(FirstUsers[I].first)->getOperand(0); - unsigned VF = ShuffleMasks[I].begin()->second.size(); - auto *FTy = FixedVectorType::get( - cast<VectorType>(FirstUsers[I].first->getType())->getElementType(), VF); auto Vector = ShuffleMasks[I].takeVector(); - auto &&EstimateShufflesCost = [this, FTy, - &Cost](ArrayRef<int> Mask, - ArrayRef<const TreeEntry *> TEs) { + unsigned VF = 0; + auto EstimateShufflesCost = [&](ArrayRef<int> Mask, + ArrayRef<const TreeEntry *> TEs) { assert((TEs.size() == 1 || TEs.size() == 2) && "Expected exactly 1 or 2 tree entries."); if (TEs.size() == 1) { - int Limit = 2 * Mask.size(); - if (!all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) || - !ShuffleVectorInst::isIdentityMask(Mask)) { + if (VF == 0) + VF = TEs.front()->getVectorFactor(); + auto *FTy = + FixedVectorType::get(TEs.back()->Scalars.front()->getType(), VF); + if (!ShuffleVectorInst::isIdentityMask(Mask, VF) && + !all_of(enumerate(Mask), [=](const auto &Data) { + return Data.value() == PoisonMaskElem || + (Data.index() < VF && + static_cast<int>(Data.index()) == Data.value()); + })) { InstructionCost C = TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, FTy, Mask); LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C @@ -8634,6 +9126,15 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { Cost += C; } } else { + if (VF == 0) { + if (TEs.front() && + TEs.front()->getVectorFactor() == TEs.back()->getVectorFactor()) + VF = TEs.front()->getVectorFactor(); + else + VF = Mask.size(); + } + auto *FTy = + FixedVectorType::get(TEs.back()->Scalars.front()->getType(), VF); InstructionCost C = TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, FTy, Mask); LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C @@ -8643,6 +9144,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { dbgs() << "SLP: Current total cost = " << Cost << "\n"); Cost += C; } + VF = Mask.size(); return TEs.back(); 
}; (void)performExtractsShuffleAction<const TreeEntry>( @@ -8671,54 +9173,198 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { return Cost; } -std::optional<TargetTransformInfo::ShuffleKind> -BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, - SmallVectorImpl<int> &Mask, - SmallVectorImpl<const TreeEntry *> &Entries) { - Entries.clear(); - // No need to check for the topmost gather node. - if (TE == VectorizableTree.front().get()) +/// Tries to find extractelement instructions with constant indices from fixed +/// vector type and gather such instructions into a bunch, which highly likely +/// might be detected as a shuffle of 1 or 2 input vectors. If this attempt was +/// successful, the matched scalars are replaced by poison values in \p VL for +/// future analysis. +std::optional<TTI::ShuffleKind> +BoUpSLP::tryToGatherSingleRegisterExtractElements( + MutableArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) const { + // Scan list of gathered scalars for extractelements that can be represented + // as shuffles. + MapVector<Value *, SmallVector<int>> VectorOpToIdx; + SmallVector<int> UndefVectorExtracts; + for (int I = 0, E = VL.size(); I < E; ++I) { + auto *EI = dyn_cast<ExtractElementInst>(VL[I]); + if (!EI) { + if (isa<UndefValue>(VL[I])) + UndefVectorExtracts.push_back(I); + continue; + } + auto *VecTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType()); + if (!VecTy || !isa<ConstantInt, UndefValue>(EI->getIndexOperand())) + continue; + std::optional<unsigned> Idx = getExtractIndex(EI); + // Undefined index. 
+ if (!Idx) { + UndefVectorExtracts.push_back(I); + continue; + } + SmallBitVector ExtractMask(VecTy->getNumElements(), true); + ExtractMask.reset(*Idx); + if (isUndefVector(EI->getVectorOperand(), ExtractMask).all()) { + UndefVectorExtracts.push_back(I); + continue; + } + VectorOpToIdx[EI->getVectorOperand()].push_back(I); + } + // Sort the vector operands by the maximum number of uses in extractelements. + MapVector<unsigned, SmallVector<Value *>> VFToVector; + for (const auto &Data : VectorOpToIdx) + VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()] + .push_back(Data.first); + for (auto &Data : VFToVector) { + stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) { + return VectorOpToIdx.find(V1)->second.size() > + VectorOpToIdx.find(V2)->second.size(); + }); + } + // Find the best pair of the vectors with the same number of elements or a + // single vector. + const int UndefSz = UndefVectorExtracts.size(); + unsigned SingleMax = 0; + Value *SingleVec = nullptr; + unsigned PairMax = 0; + std::pair<Value *, Value *> PairVec(nullptr, nullptr); + for (auto &Data : VFToVector) { + Value *V1 = Data.second.front(); + if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) { + SingleMax = VectorOpToIdx[V1].size() + UndefSz; + SingleVec = V1; + } + Value *V2 = nullptr; + if (Data.second.size() > 1) + V2 = *std::next(Data.second.begin()); + if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + + UndefSz) { + PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz; + PairVec = std::make_pair(V1, V2); + } + } + if (SingleMax == 0 && PairMax == 0 && UndefSz == 0) return std::nullopt; + // Check if better to perform a shuffle of 2 vectors or just of a single + // vector. 
+ SmallVector<Value *> SavedVL(VL.begin(), VL.end()); + SmallVector<Value *> GatheredExtracts( + VL.size(), PoisonValue::get(VL.front()->getType())); + if (SingleMax >= PairMax && SingleMax) { + for (int Idx : VectorOpToIdx[SingleVec]) + std::swap(GatheredExtracts[Idx], VL[Idx]); + } else { + for (Value *V : {PairVec.first, PairVec.second}) + for (int Idx : VectorOpToIdx[V]) + std::swap(GatheredExtracts[Idx], VL[Idx]); + } + // Add extracts from undefs too. + for (int Idx : UndefVectorExtracts) + std::swap(GatheredExtracts[Idx], VL[Idx]); + // Check that gather of extractelements can be represented as just a + // shuffle of a single/two vectors the scalars are extracted from. + std::optional<TTI::ShuffleKind> Res = + isFixedVectorShuffle(GatheredExtracts, Mask); + if (!Res) { + // TODO: try to check other subsets if possible. + // Restore the original VL if attempt was not successful. + copy(SavedVL, VL.begin()); + return std::nullopt; + } + // Restore unused scalars from mask, if some of the extractelements were not + // selected for shuffle. + for (int I = 0, E = GatheredExtracts.size(); I < E; ++I) { + if (Mask[I] == PoisonMaskElem && !isa<PoisonValue>(GatheredExtracts[I]) && + isa<UndefValue>(GatheredExtracts[I])) { + std::swap(VL[I], GatheredExtracts[I]); + continue; + } + auto *EI = dyn_cast<ExtractElementInst>(VL[I]); + if (!EI || !isa<FixedVectorType>(EI->getVectorOperandType()) || + !isa<ConstantInt, UndefValue>(EI->getIndexOperand()) || + is_contained(UndefVectorExtracts, I)) + continue; + } + return Res; +} + +/// Tries to find extractelement instructions with constant indices from fixed +/// vector type and gather such instructions into a bunch, which highly likely +/// might be detected as a shuffle of 1 or 2 input vectors. If this attempt was +/// successful, the matched scalars are replaced by poison values in \p VL for +/// future analysis. 
+SmallVector<std::optional<TTI::ShuffleKind>> +BoUpSLP::tryToGatherExtractElements(SmallVectorImpl<Value *> &VL, + SmallVectorImpl<int> &Mask, + unsigned NumParts) const { + assert(NumParts > 0 && "NumParts expected be greater than or equal to 1."); + SmallVector<std::optional<TTI::ShuffleKind>> ShufflesRes(NumParts); Mask.assign(VL.size(), PoisonMaskElem); - assert(TE->UserTreeIndices.size() == 1 && - "Expected only single user of the gather node."); + unsigned SliceSize = VL.size() / NumParts; + for (unsigned Part = 0; Part < NumParts; ++Part) { + // Scan list of gathered scalars for extractelements that can be represented + // as shuffles. + MutableArrayRef<Value *> SubVL = + MutableArrayRef(VL).slice(Part * SliceSize, SliceSize); + SmallVector<int> SubMask; + std::optional<TTI::ShuffleKind> Res = + tryToGatherSingleRegisterExtractElements(SubVL, SubMask); + ShufflesRes[Part] = Res; + copy(SubMask, std::next(Mask.begin(), Part * SliceSize)); + } + if (none_of(ShufflesRes, [](const std::optional<TTI::ShuffleKind> &Res) { + return Res.has_value(); + })) + ShufflesRes.clear(); + return ShufflesRes; +} + +std::optional<TargetTransformInfo::ShuffleKind> +BoUpSLP::isGatherShuffledSingleRegisterEntry( + const TreeEntry *TE, ArrayRef<Value *> VL, MutableArrayRef<int> Mask, + SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part) { + Entries.clear(); // TODO: currently checking only for Scalars in the tree entry, need to count // reused elements too for better cost estimation. - Instruction &UserInst = - getLastInstructionInBundle(TE->UserTreeIndices.front().UserTE); - BasicBlock *ParentBB = nullptr; + const EdgeInfo &TEUseEI = TE->UserTreeIndices.front(); + const Instruction *TEInsertPt = &getLastInstructionInBundle(TEUseEI.UserTE); + const BasicBlock *TEInsertBlock = nullptr; // Main node of PHI entries keeps the correct order of operands/incoming // blocks. 
- if (auto *PHI = - dyn_cast<PHINode>(TE->UserTreeIndices.front().UserTE->getMainOp())) { - ParentBB = PHI->getIncomingBlock(TE->UserTreeIndices.front().EdgeIdx); + if (auto *PHI = dyn_cast<PHINode>(TEUseEI.UserTE->getMainOp())) { + TEInsertBlock = PHI->getIncomingBlock(TEUseEI.EdgeIdx); + TEInsertPt = TEInsertBlock->getTerminator(); } else { - ParentBB = UserInst.getParent(); + TEInsertBlock = TEInsertPt->getParent(); } - auto *NodeUI = DT->getNode(ParentBB); + auto *NodeUI = DT->getNode(TEInsertBlock); assert(NodeUI && "Should only process reachable instructions"); SmallPtrSet<Value *, 4> GatheredScalars(VL.begin(), VL.end()); - auto CheckOrdering = [&](Instruction *LastEI) { - // Check if the user node of the TE comes after user node of EntryPtr, - // otherwise EntryPtr depends on TE. - // Gather nodes usually are not scheduled and inserted before their first - // user node. So, instead of checking dependency between the gather nodes - // themselves, we check the dependency between their user nodes. - // If one user node comes before the second one, we cannot use the second - // gather node as the source vector for the first gather node, because in - // the list of instructions it will be emitted later. - auto *EntryParent = LastEI->getParent(); - auto *NodeEUI = DT->getNode(EntryParent); + auto CheckOrdering = [&](const Instruction *InsertPt) { + // Argument InsertPt is an instruction where vector code for some other + // tree entry (one that shares one or more scalars with TE) is going to be + // generated. This lambda returns true if insertion point of vector code + // for the TE dominates that point (otherwise dependency is the other way + // around). The other node is not limited to be of a gather kind. Gather + // nodes are not scheduled and their vector code is inserted before their + // first user. If user is PHI, that is supposed to be at the end of a + // predecessor block. Otherwise it is the last instruction among scalars of + // the user node. 
So, instead of checking dependency between instructions + // themselves, we check dependency between their insertion points for vector + // code (since each scalar instruction ends up as a lane of a vector + // instruction). + const BasicBlock *InsertBlock = InsertPt->getParent(); + auto *NodeEUI = DT->getNode(InsertBlock); if (!NodeEUI) return false; assert((NodeUI == NodeEUI) == (NodeUI->getDFSNumIn() == NodeEUI->getDFSNumIn()) && "Different nodes should have different DFS numbers"); // Check the order of the gather nodes users. - if (UserInst.getParent() != EntryParent && + if (TEInsertPt->getParent() != InsertBlock && (DT->dominates(NodeUI, NodeEUI) || !DT->dominates(NodeEUI, NodeUI))) return false; - if (UserInst.getParent() == EntryParent && UserInst.comesBefore(LastEI)) + if (TEInsertPt->getParent() == InsertBlock && + TEInsertPt->comesBefore(InsertPt)) return false; return true; }; @@ -8743,43 +9389,42 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, [&](Value *V) { return GatheredScalars.contains(V); }) && "Must contain at least single gathered value."); assert(TEPtr->UserTreeIndices.size() == 1 && - "Expected only single user of the gather node."); - PHINode *EntryPHI = - dyn_cast<PHINode>(TEPtr->UserTreeIndices.front().UserTE->getMainOp()); - Instruction *EntryUserInst = - EntryPHI ? nullptr - : &getLastInstructionInBundle( - TEPtr->UserTreeIndices.front().UserTE); - if (&UserInst == EntryUserInst) { - assert(!EntryPHI && "Unexpected phi node entry."); - // If 2 gathers are operands of the same entry, compare operands - // indices, use the earlier one as the base. 
- if (TE->UserTreeIndices.front().UserTE == - TEPtr->UserTreeIndices.front().UserTE && - TE->UserTreeIndices.front().EdgeIdx < - TEPtr->UserTreeIndices.front().EdgeIdx) + "Expected only single user of a gather node."); + const EdgeInfo &UseEI = TEPtr->UserTreeIndices.front(); + + PHINode *UserPHI = dyn_cast<PHINode>(UseEI.UserTE->getMainOp()); + const Instruction *InsertPt = + UserPHI ? UserPHI->getIncomingBlock(UseEI.EdgeIdx)->getTerminator() + : &getLastInstructionInBundle(UseEI.UserTE); + if (TEInsertPt == InsertPt) { + // If 2 gathers are operands of the same entry (regardless of whether + // user is PHI or else), compare operands indices, use the earlier one + // as the base. + if (TEUseEI.UserTE == UseEI.UserTE && TEUseEI.EdgeIdx < UseEI.EdgeIdx) + continue; + // If the user instruction is used for some reason in different + // vectorized nodes - make it depend on index. + if (TEUseEI.UserTE != UseEI.UserTE && + TEUseEI.UserTE->Idx < UseEI.UserTE->Idx) continue; } - // Check if the user node of the TE comes after user node of EntryPtr, - // otherwise EntryPtr depends on TE. - auto *EntryI = - EntryPHI - ? EntryPHI - ->getIncomingBlock(TEPtr->UserTreeIndices.front().EdgeIdx) - ->getTerminator() - : EntryUserInst; - if ((ParentBB != EntryI->getParent() || - TE->UserTreeIndices.front().EdgeIdx < - TEPtr->UserTreeIndices.front().EdgeIdx || - TE->UserTreeIndices.front().UserTE != - TEPtr->UserTreeIndices.front().UserTE) && - !CheckOrdering(EntryI)) + + // Check if the user node of the TE comes after user node of TEPtr, + // otherwise TEPtr depends on TE. 
+ if ((TEInsertBlock != InsertPt->getParent() || + TEUseEI.EdgeIdx < UseEI.EdgeIdx || TEUseEI.UserTE != UseEI.UserTE) && + !CheckOrdering(InsertPt)) continue; VToTEs.insert(TEPtr); } if (const TreeEntry *VTE = getTreeEntry(V)) { - Instruction &EntryUserInst = getLastInstructionInBundle(VTE); - if (&EntryUserInst == &UserInst || !CheckOrdering(&EntryUserInst)) + Instruction &LastBundleInst = getLastInstructionInBundle(VTE); + if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst)) + continue; + auto It = MinBWs.find(VTE); + // If vectorize node is demoted - do not match. + if (It != MinBWs.end() && + It->second.first != DL->getTypeSizeInBits(V->getType())) continue; VToTEs.insert(VTE); } @@ -8823,8 +9468,10 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, } } - if (UsedTEs.empty()) + if (UsedTEs.empty()) { + Entries.clear(); return std::nullopt; + } unsigned VF = 0; if (UsedTEs.size() == 1) { @@ -8838,9 +9485,19 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, auto *It = find_if(FirstEntries, [=](const TreeEntry *EntryPtr) { return EntryPtr->isSame(VL) || EntryPtr->isSame(TE->Scalars); }); - if (It != FirstEntries.end() && (*It)->getVectorFactor() == VL.size()) { + if (It != FirstEntries.end() && + ((*It)->getVectorFactor() == VL.size() || + ((*It)->getVectorFactor() == TE->Scalars.size() && + TE->ReuseShuffleIndices.size() == VL.size() && + (*It)->isSame(TE->Scalars)))) { Entries.push_back(*It); - std::iota(Mask.begin(), Mask.end(), 0); + if ((*It)->getVectorFactor() == VL.size()) { + std::iota(std::next(Mask.begin(), Part * VL.size()), + std::next(Mask.begin(), (Part + 1) * VL.size()), 0); + } else { + SmallVector<int> CommonMask = TE->getCommonMask(); + copy(CommonMask, Mask.begin()); + } // Clear undef scalars. 
for (int I = 0, Sz = VL.size(); I < Sz; ++I) if (isa<PoisonValue>(VL[I])) @@ -8923,12 +9580,9 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, // by extractelements processing) or may form vector node in future. auto MightBeIgnored = [=](Value *V) { auto *I = dyn_cast<Instruction>(V); - SmallVector<Value *> IgnoredVals; - if (UserIgnoreList) - IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end()); return I && !IsSplatOrUndefs && !ScalarToTreeEntry.count(I) && !isVectorLikeInstWithConstOps(I) && - !areAllUsersVectorized(I, IgnoredVals) && isSimple(I); + !areAllUsersVectorized(I, UserIgnoreList) && isSimple(I); }; // Check that the neighbor instruction may form a full vector node with the // current instruction V. It is possible, if they have same/alternate opcode @@ -8980,7 +9634,10 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, TempEntries.push_back(Entries[I]); } Entries.swap(TempEntries); - if (EntryLanes.size() == Entries.size() && !VL.equals(TE->Scalars)) { + if (EntryLanes.size() == Entries.size() && + !VL.equals(ArrayRef(TE->Scalars) + .slice(Part * VL.size(), + std::min<int>(VL.size(), TE->Scalars.size())))) { // We may have here 1 or 2 entries only. If the number of scalars is equal // to the number of entries, no need to do the analysis, it is not very // profitable. Since VL is not the same as TE->Scalars, it means we already @@ -8993,9 +9650,10 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, // Pair.first is the offset to the vector, while Pair.second is the index of // scalar in the list. 
for (const std::pair<unsigned, int> &Pair : EntryLanes) { - Mask[Pair.second] = Pair.first * VF + - Entries[Pair.first]->findLaneForValue(VL[Pair.second]); - IsIdentity &= Mask[Pair.second] == Pair.second; + unsigned Idx = Part * VL.size() + Pair.second; + Mask[Idx] = Pair.first * VF + + Entries[Pair.first]->findLaneForValue(VL[Pair.second]); + IsIdentity &= Mask[Idx] == Pair.second; } switch (Entries.size()) { case 1: @@ -9010,9 +9668,64 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, break; } Entries.clear(); + // Clear the corresponding mask elements. + std::fill(std::next(Mask.begin(), Part * VL.size()), + std::next(Mask.begin(), (Part + 1) * VL.size()), PoisonMaskElem); return std::nullopt; } +SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> +BoUpSLP::isGatherShuffledEntry( + const TreeEntry *TE, ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask, + SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries, + unsigned NumParts) { + assert(NumParts > 0 && NumParts < VL.size() && + "Expected positive number of registers."); + Entries.clear(); + // No need to check for the topmost gather node. 
+ if (TE == VectorizableTree.front().get()) + return {}; + Mask.assign(VL.size(), PoisonMaskElem); + assert(TE->UserTreeIndices.size() == 1 && + "Expected only single user of the gather node."); + assert(VL.size() % NumParts == 0 && + "Number of scalars must be divisible by NumParts."); + unsigned SliceSize = VL.size() / NumParts; + SmallVector<std::optional<TTI::ShuffleKind>> Res; + for (unsigned Part = 0; Part < NumParts; ++Part) { + ArrayRef<Value *> SubVL = VL.slice(Part * SliceSize, SliceSize); + SmallVectorImpl<const TreeEntry *> &SubEntries = Entries.emplace_back(); + std::optional<TTI::ShuffleKind> SubRes = + isGatherShuffledSingleRegisterEntry(TE, SubVL, Mask, SubEntries, Part); + if (!SubRes) + SubEntries.clear(); + Res.push_back(SubRes); + if (SubEntries.size() == 1 && *SubRes == TTI::SK_PermuteSingleSrc && + SubEntries.front()->getVectorFactor() == VL.size() && + (SubEntries.front()->isSame(TE->Scalars) || + SubEntries.front()->isSame(VL))) { + SmallVector<const TreeEntry *> LocalSubEntries; + LocalSubEntries.swap(SubEntries); + Entries.clear(); + Res.clear(); + std::iota(Mask.begin(), Mask.end(), 0); + // Clear undef scalars. + for (int I = 0, Sz = VL.size(); I < Sz; ++I) + if (isa<PoisonValue>(VL[I])) + Mask[I] = PoisonMaskElem; + Entries.emplace_back(1, LocalSubEntries.front()); + Res.push_back(TargetTransformInfo::SK_PermuteSingleSrc); + return Res; + } + } + if (all_of(Res, + [](const std::optional<TTI::ShuffleKind> &SK) { return !SK; })) { + Entries.clear(); + return {}; + } + return Res; +} + InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc) const { // Find the type of the operands in VL. 
@@ -9224,18 +9937,20 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { auto *Front = E->getMainOp(); Instruction *LastInst = &getLastInstructionInBundle(E); assert(LastInst && "Failed to find last instruction in bundle"); + BasicBlock::iterator LastInstIt = LastInst->getIterator(); // If the instruction is PHI, set the insert point after all the PHIs. bool IsPHI = isa<PHINode>(LastInst); if (IsPHI) - LastInst = LastInst->getParent()->getFirstNonPHI(); + LastInstIt = LastInst->getParent()->getFirstNonPHIIt(); if (IsPHI || (E->State != TreeEntry::NeedToGather && doesNotNeedToSchedule(E->Scalars))) { - Builder.SetInsertPoint(LastInst); + Builder.SetInsertPoint(LastInst->getParent(), LastInstIt); } else { // Set the insertion point after the last instruction in the bundle. Set the // debug location to Front. - Builder.SetInsertPoint(LastInst->getParent(), - std::next(LastInst->getIterator())); + Builder.SetInsertPoint( + LastInst->getParent(), + LastInst->getNextNonDebugInstruction()->getIterator()); } Builder.SetCurrentDebugLocation(Front->getDebugLoc()); } @@ -9271,10 +9986,12 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) { GatherShuffleExtractSeq.insert(InsElt); CSEBlocks.insert(InsElt->getParent()); // Add to our 'need-to-extract' list. - if (TreeEntry *Entry = getTreeEntry(V)) { - // Find which lane we need to extract. - unsigned FoundLane = Entry->findLaneForValue(V); - ExternalUses.emplace_back(V, InsElt, FoundLane); + if (isa<Instruction>(V)) { + if (TreeEntry *Entry = getTreeEntry(V)) { + // Find which lane we need to extract. + unsigned FoundLane = Entry->findLaneForValue(V); + ExternalUses.emplace_back(V, InsElt, FoundLane); + } } return Vec; }; @@ -9367,12 +10084,12 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { /// Holds all of the instructions that we gathered. SetVector<Instruction *> &GatherShuffleExtractSeq; /// A list of blocks that we are going to CSE. 
- SetVector<BasicBlock *> &CSEBlocks; + DenseSet<BasicBlock *> &CSEBlocks; public: ShuffleIRBuilder(IRBuilderBase &Builder, SetVector<Instruction *> &GatherShuffleExtractSeq, - SetVector<BasicBlock *> &CSEBlocks) + DenseSet<BasicBlock *> &CSEBlocks) : Builder(Builder), GatherShuffleExtractSeq(GatherShuffleExtractSeq), CSEBlocks(CSEBlocks) {} ~ShuffleIRBuilder() = default; @@ -9392,7 +10109,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { return V1; unsigned VF = Mask.size(); unsigned LocalVF = cast<FixedVectorType>(V1->getType())->getNumElements(); - if (VF == LocalVF && ShuffleVectorInst::isIdentityMask(Mask)) + if (VF == LocalVF && ShuffleVectorInst::isIdentityMask(Mask, VF)) return V1; Value *Vec = Builder.CreateShuffleVector(V1, Mask); if (auto *I = dyn_cast<Instruction>(Vec)) { @@ -9455,7 +10172,11 @@ public: : Builder(Builder), R(R) {} /// Adjusts extractelements after reusing them. - Value *adjustExtracts(const TreeEntry *E, ArrayRef<int> Mask) { + Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask, + ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds, + unsigned NumParts, bool &UseVecBaseAsInput) { + UseVecBaseAsInput = false; + SmallPtrSet<Value *, 4> UniqueBases; Value *VecBase = nullptr; for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { int Idx = Mask[I]; @@ -9463,6 +10184,10 @@ public: continue; auto *EI = cast<ExtractElementInst>(E->Scalars[I]); VecBase = EI->getVectorOperand(); + if (const TreeEntry *TE = R.getTreeEntry(VecBase)) + VecBase = TE->VectorizedValue; + assert(VecBase && "Expected vectorized value."); + UniqueBases.insert(VecBase); // If the only one use is vectorized - can delete the extractelement // itself. 
if (!EI->hasOneUse() || any_of(EI->users(), [&](User *U) { @@ -9471,14 +10196,97 @@ public: continue; R.eraseInstruction(EI); } - return VecBase; + if (NumParts == 1 || UniqueBases.size() == 1) + return VecBase; + UseVecBaseAsInput = true; + auto TransformToIdentity = [](MutableArrayRef<int> Mask) { + for (auto [I, Idx] : enumerate(Mask)) + if (Idx != PoisonMaskElem) + Idx = I; + }; + // Perform multi-register vector shuffle, joining them into a single virtual + // long vector. + // Need to shuffle each part independently and then insert all this parts + // into a long virtual vector register, forming the original vector. + Value *Vec = nullptr; + SmallVector<int> VecMask(Mask.size(), PoisonMaskElem); + unsigned SliceSize = E->Scalars.size() / NumParts; + for (unsigned Part = 0; Part < NumParts; ++Part) { + ArrayRef<Value *> VL = + ArrayRef(E->Scalars).slice(Part * SliceSize, SliceSize); + MutableArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize); + constexpr int MaxBases = 2; + SmallVector<Value *, MaxBases> Bases(MaxBases); +#ifndef NDEBUG + int PrevSize = 0; +#endif // NDEBUG + for (const auto [I, V]: enumerate(VL)) { + if (SubMask[I] == PoisonMaskElem) + continue; + Value *VecOp = cast<ExtractElementInst>(V)->getVectorOperand(); + if (const TreeEntry *TE = R.getTreeEntry(VecOp)) + VecOp = TE->VectorizedValue; + assert(VecOp && "Expected vectorized value."); + const int Size = + cast<FixedVectorType>(VecOp->getType())->getNumElements(); +#ifndef NDEBUG + assert((PrevSize == Size || PrevSize == 0) && + "Expected vectors of the same size."); + PrevSize = Size; +#endif // NDEBUG + Bases[SubMask[I] < Size ? 
0 : 1] = VecOp; + } + if (!Bases.front()) + continue; + Value *SubVec; + if (Bases.back()) { + SubVec = createShuffle(Bases.front(), Bases.back(), SubMask); + TransformToIdentity(SubMask); + } else { + SubVec = Bases.front(); + } + if (!Vec) { + Vec = SubVec; + assert((Part == 0 || all_of(seq<unsigned>(0, Part), + [&](unsigned P) { + ArrayRef<int> SubMask = + Mask.slice(P * SliceSize, SliceSize); + return all_of(SubMask, [](int Idx) { + return Idx == PoisonMaskElem; + }); + })) && + "Expected first part or all previous parts masked."); + copy(SubMask, std::next(VecMask.begin(), Part * SliceSize)); + } else { + unsigned VF = cast<FixedVectorType>(Vec->getType())->getNumElements(); + if (Vec->getType() != SubVec->getType()) { + unsigned SubVecVF = + cast<FixedVectorType>(SubVec->getType())->getNumElements(); + VF = std::max(VF, SubVecVF); + } + // Adjust SubMask. + for (auto [I, Idx] : enumerate(SubMask)) + if (Idx != PoisonMaskElem) + Idx += VF; + copy(SubMask, std::next(VecMask.begin(), Part * SliceSize)); + Vec = createShuffle(Vec, SubVec, VecMask); + TransformToIdentity(VecMask); + } + } + copy(VecMask, Mask.begin()); + return Vec; } /// Checks if the specified entry \p E needs to be delayed because of its /// dependency nodes. - Value *needToDelay(const TreeEntry *E, ArrayRef<const TreeEntry *> Deps) { + std::optional<Value *> + needToDelay(const TreeEntry *E, + ArrayRef<SmallVector<const TreeEntry *>> Deps) const { // No need to delay emission if all deps are ready. - if (all_of(Deps, [](const TreeEntry *TE) { return TE->VectorizedValue; })) - return nullptr; + if (all_of(Deps, [](ArrayRef<const TreeEntry *> TEs) { + return all_of( + TEs, [](const TreeEntry *TE) { return TE->VectorizedValue; }); + })) + return std::nullopt; // Postpone gather emission, will be emitted after the end of the // process to keep correct order. 
auto *VecTy = FixedVectorType::get(E->Scalars.front()->getType(), @@ -9487,6 +10295,16 @@ public: VecTy, PoisonValue::get(PointerType::getUnqual(VecTy->getContext())), MaybeAlign()); } + /// Adds 2 input vectors (in form of tree entries) and the mask for their + /// shuffling. + void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) { + add(E1.VectorizedValue, E2.VectorizedValue, Mask); + } + /// Adds single input vector (in form of tree entry) and the mask for its + /// shuffling. + void add(const TreeEntry &E1, ArrayRef<int> Mask) { + add(E1.VectorizedValue, Mask); + } /// Adds 2 input vectors and the mask for their shuffling. void add(Value *V1, Value *V2, ArrayRef<int> Mask) { assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors."); @@ -9516,7 +10334,7 @@ public: InVectors.push_back(V1); } /// Adds another one input vector and the mask for the shuffling. - void add(Value *V1, ArrayRef<int> Mask) { + void add(Value *V1, ArrayRef<int> Mask, bool = false) { if (InVectors.empty()) { if (!isa<FixedVectorType>(V1->getType())) { V1 = createShuffle(V1, nullptr, CommonMask); @@ -9578,7 +10396,8 @@ public: inversePermutation(Order, NewMask); add(V1, NewMask); } - Value *gather(ArrayRef<Value *> VL, Value *Root = nullptr) { + Value *gather(ArrayRef<Value *> VL, unsigned MaskVF = 0, + Value *Root = nullptr) { return R.gather(VL, Root); } Value *createFreeze(Value *V) { return Builder.CreateFreeze(V); } @@ -9639,8 +10458,14 @@ public: } }; -Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) { - ArrayRef<Value *> VL = E->getOperand(NodeIdx); +Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx, + bool PostponedPHIs) { + ValueList &VL = E->getOperand(NodeIdx); + if (E->State == TreeEntry::PossibleStridedVectorize && + !E->ReorderIndices.empty()) { + SmallVector<int> Mask(E->ReorderIndices.begin(), E->ReorderIndices.end()); + reorderScalars(VL, Mask); + } const unsigned VF = VL.size(); InstructionsState S = 
getSameOpcode(VL, *TLI); // Special processing for GEPs bundle, which may include non-gep values. @@ -9651,23 +10476,39 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) { S = getSameOpcode(*It, *TLI); } if (S.getOpcode()) { - if (TreeEntry *VE = getTreeEntry(S.OpValue); - VE && VE->isSame(VL) && - (any_of(VE->UserTreeIndices, - [E, NodeIdx](const EdgeInfo &EI) { - return EI.UserTE == E && EI.EdgeIdx == NodeIdx; - }) || - any_of(VectorizableTree, - [E, NodeIdx, VE](const std::unique_ptr<TreeEntry> &TE) { - return TE->isOperandGatherNode({E, NodeIdx}) && - VE->isSame(TE->Scalars); - }))) { + auto CheckSameVE = [&](const TreeEntry *VE) { + return VE->isSame(VL) && + (any_of(VE->UserTreeIndices, + [E, NodeIdx](const EdgeInfo &EI) { + return EI.UserTE == E && EI.EdgeIdx == NodeIdx; + }) || + any_of(VectorizableTree, + [E, NodeIdx, VE](const std::unique_ptr<TreeEntry> &TE) { + return TE->isOperandGatherNode({E, NodeIdx}) && + VE->isSame(TE->Scalars); + })); + }; + TreeEntry *VE = getTreeEntry(S.OpValue); + bool IsSameVE = VE && CheckSameVE(VE); + if (!IsSameVE) { + auto It = MultiNodeScalars.find(S.OpValue); + if (It != MultiNodeScalars.end()) { + auto *I = find_if(It->getSecond(), [&](const TreeEntry *TE) { + return TE != VE && CheckSameVE(TE); + }); + if (I != It->getSecond().end()) { + VE = *I; + IsSameVE = true; + } + } + } + if (IsSameVE) { auto FinalShuffle = [&](Value *V, ArrayRef<int> Mask) { ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); ShuffleBuilder.add(V, Mask); return ShuffleBuilder.finalize(std::nullopt); }; - Value *V = vectorizeTree(VE); + Value *V = vectorizeTree(VE, PostponedPHIs); if (VF != cast<FixedVectorType>(V->getType())->getNumElements()) { if (!VE->ReuseShuffleIndices.empty()) { // Reshuffle to get only unique values. 
@@ -9740,14 +10581,7 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) { assert(I->get()->UserTreeIndices.size() == 1 && "Expected only single user for the gather node."); assert(I->get()->isSame(VL) && "Expected same list of scalars."); - IRBuilder<>::InsertPointGuard Guard(Builder); - if (E->getOpcode() != Instruction::InsertElement && - E->getOpcode() != Instruction::PHI) { - Instruction *LastInst = &getLastInstructionInBundle(E); - assert(LastInst && "Failed to find last instruction in bundle"); - Builder.SetInsertPoint(LastInst); - } - return vectorizeTree(I->get()); + return vectorizeTree(I->get(), PostponedPHIs); } template <typename BVTy, typename ResTy, typename... Args> @@ -9765,7 +10599,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { inversePermutation(E->ReorderIndices, ReorderMask); if (!ReorderMask.empty()) reorderScalars(GatheredScalars, ReorderMask); - auto FindReusedSplat = [&](SmallVectorImpl<int> &Mask) { + auto FindReusedSplat = [&](MutableArrayRef<int> Mask, unsigned InputVF) { if (!isSplat(E->Scalars) || none_of(E->Scalars, [](Value *V) { return isa<UndefValue>(V) && !isa<PoisonValue>(V); })) @@ -9782,70 +10616,102 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { }); if (It == VectorizableTree.end()) return false; - unsigned I = - *find_if_not(Mask, [](int Idx) { return Idx == PoisonMaskElem; }); - int Sz = Mask.size(); - if (all_of(Mask, [Sz](int Idx) { return Idx < 2 * Sz; }) && - ShuffleVectorInst::isIdentityMask(Mask)) + int Idx; + if ((Mask.size() < InputVF && + ShuffleVectorInst::isExtractSubvectorMask(Mask, InputVF, Idx) && + Idx == 0) || + (Mask.size() == InputVF && + ShuffleVectorInst::isIdentityMask(Mask, Mask.size()))) { std::iota(Mask.begin(), Mask.end(), 0); - else + } else { + unsigned I = + *find_if_not(Mask, [](int Idx) { return Idx == PoisonMaskElem; }); std::fill(Mask.begin(), Mask.end(), I); + } return true; }; BVTy ShuffleBuilder(Params...); ResTy 
Res = ResTy(); SmallVector<int> Mask; - SmallVector<int> ExtractMask; - std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle; - std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle; - SmallVector<const TreeEntry *> Entries; + SmallVector<int> ExtractMask(GatheredScalars.size(), PoisonMaskElem); + SmallVector<std::optional<TTI::ShuffleKind>> ExtractShuffles; + Value *ExtractVecBase = nullptr; + bool UseVecBaseAsInput = false; + SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> GatherShuffles; + SmallVector<SmallVector<const TreeEntry *>> Entries; Type *ScalarTy = GatheredScalars.front()->getType(); + auto *VecTy = FixedVectorType::get(ScalarTy, GatheredScalars.size()); + unsigned NumParts = TTI->getNumberOfParts(VecTy); + if (NumParts == 0 || NumParts >= GatheredScalars.size()) + NumParts = 1; if (!all_of(GatheredScalars, UndefValue::classof)) { // Check for gathered extracts. - ExtractShuffle = tryToGatherExtractElements(GatheredScalars, ExtractMask); - SmallVector<Value *> IgnoredVals; - if (UserIgnoreList) - IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end()); bool Resized = false; - if (Value *VecBase = ShuffleBuilder.adjustExtracts(E, ExtractMask)) - if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType())) - if (VF == VecBaseTy->getNumElements() && GatheredScalars.size() != VF) { - Resized = true; - GatheredScalars.append(VF - GatheredScalars.size(), - PoisonValue::get(ScalarTy)); - } + ExtractShuffles = + tryToGatherExtractElements(GatheredScalars, ExtractMask, NumParts); + if (!ExtractShuffles.empty()) { + SmallVector<const TreeEntry *> ExtractEntries; + for (auto [Idx, I] : enumerate(ExtractMask)) { + if (I == PoisonMaskElem) + continue; + if (const auto *TE = getTreeEntry( + cast<ExtractElementInst>(E->Scalars[Idx])->getVectorOperand())) + ExtractEntries.push_back(TE); + } + if (std::optional<ResTy> Delayed = + ShuffleBuilder.needToDelay(E, ExtractEntries)) { + // Delay emission of gathers which are 
not ready yet. + PostponedGathers.insert(E); + // Postpone gather emission, will be emitted after the end of the + // process to keep correct order. + return *Delayed; + } + if (Value *VecBase = ShuffleBuilder.adjustExtracts( + E, ExtractMask, ExtractShuffles, NumParts, UseVecBaseAsInput)) { + ExtractVecBase = VecBase; + if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType())) + if (VF == VecBaseTy->getNumElements() && + GatheredScalars.size() != VF) { + Resized = true; + GatheredScalars.append(VF - GatheredScalars.size(), + PoisonValue::get(ScalarTy)); + } + } + } // Gather extracts after we check for full matched gathers only. - if (ExtractShuffle || E->getOpcode() != Instruction::Load || + if (!ExtractShuffles.empty() || E->getOpcode() != Instruction::Load || E->isAltShuffle() || all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) || isSplat(E->Scalars) || (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) { - GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries); + GatherShuffles = + isGatherShuffledEntry(E, GatheredScalars, Mask, Entries, NumParts); } - if (GatherShuffle) { - if (Value *Delayed = ShuffleBuilder.needToDelay(E, Entries)) { + if (!GatherShuffles.empty()) { + if (std::optional<ResTy> Delayed = + ShuffleBuilder.needToDelay(E, Entries)) { // Delay emission of gathers which are not ready yet. PostponedGathers.insert(E); // Postpone gather emission, will be emitted after the end of the // process to keep correct order. - return Delayed; + return *Delayed; } - assert((Entries.size() == 1 || Entries.size() == 2) && - "Expected shuffle of 1 or 2 entries."); - if (*GatherShuffle == TTI::SK_PermuteSingleSrc && - Entries.front()->isSame(E->Scalars)) { + if (GatherShuffles.size() == 1 && + *GatherShuffles.front() == TTI::SK_PermuteSingleSrc && + Entries.front().front()->isSame(E->Scalars)) { // Perfect match in the graph, will reuse the previously vectorized // node. Cost is 0. 
LLVM_DEBUG( dbgs() - << "SLP: perfect diamond match for gather bundle that starts with " - << *E->Scalars.front() << ".\n"); + << "SLP: perfect diamond match for gather bundle " + << shortBundleName(E->Scalars) << ".\n"); // Restore the mask for previous partially matched values. - if (Entries.front()->ReorderIndices.empty() && - ((Entries.front()->ReuseShuffleIndices.empty() && - E->Scalars.size() == Entries.front()->Scalars.size()) || - (E->Scalars.size() == - Entries.front()->ReuseShuffleIndices.size()))) { + Mask.resize(E->Scalars.size()); + const TreeEntry *FrontTE = Entries.front().front(); + if (FrontTE->ReorderIndices.empty() && + ((FrontTE->ReuseShuffleIndices.empty() && + E->Scalars.size() == FrontTE->Scalars.size()) || + (E->Scalars.size() == FrontTE->ReuseShuffleIndices.size()))) { std::iota(Mask.begin(), Mask.end(), 0); } else { for (auto [I, V] : enumerate(E->Scalars)) { @@ -9853,17 +10719,20 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { Mask[I] = PoisonMaskElem; continue; } - Mask[I] = Entries.front()->findLaneForValue(V); + Mask[I] = FrontTE->findLaneForValue(V); } } - ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask); + ShuffleBuilder.add(*FrontTE, Mask); Res = ShuffleBuilder.finalize(E->getCommonMask()); return Res; } if (!Resized) { - unsigned VF1 = Entries.front()->getVectorFactor(); - unsigned VF2 = Entries.back()->getVectorFactor(); - if ((VF == VF1 || VF == VF2) && GatheredScalars.size() != VF) + if (GatheredScalars.size() != VF && + any_of(Entries, [&](ArrayRef<const TreeEntry *> TEs) { + return any_of(TEs, [&](const TreeEntry *TE) { + return TE->getVectorFactor() == VF; + }); + })) GatheredScalars.append(VF - GatheredScalars.size(), PoisonValue::get(ScalarTy)); } @@ -9943,78 +10812,108 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { if (It != Scalars.end()) { // Replace undefs by the non-poisoned scalars and emit broadcast. 
int Pos = std::distance(Scalars.begin(), It); - for_each(UndefPos, [&](int I) { + for (int I : UndefPos) { // Set the undef position to the non-poisoned scalar. ReuseMask[I] = Pos; // Replace the undef by the poison, in the mask it is replaced by // non-poisoned scalar already. if (I != Pos) Scalars[I] = PoisonValue::get(ScalarTy); - }); + } } else { // Replace undefs by the poisons, emit broadcast and then emit // freeze. - for_each(UndefPos, [&](int I) { + for (int I : UndefPos) { ReuseMask[I] = PoisonMaskElem; if (isa<UndefValue>(Scalars[I])) Scalars[I] = PoisonValue::get(ScalarTy); - }); + } NeedFreeze = true; } } }; - if (ExtractShuffle || GatherShuffle) { + if (!ExtractShuffles.empty() || !GatherShuffles.empty()) { bool IsNonPoisoned = true; - bool IsUsedInExpr = false; + bool IsUsedInExpr = true; Value *Vec1 = nullptr; - if (ExtractShuffle) { + if (!ExtractShuffles.empty()) { // Gather of extractelements can be represented as just a shuffle of // a single/two vectors the scalars are extracted from. // Find input vectors. 
Value *Vec2 = nullptr; for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) { - if (ExtractMask[I] == PoisonMaskElem || - (!Mask.empty() && Mask[I] != PoisonMaskElem)) { + if (!Mask.empty() && Mask[I] != PoisonMaskElem) ExtractMask[I] = PoisonMaskElem; - continue; - } - if (isa<UndefValue>(E->Scalars[I])) - continue; - auto *EI = cast<ExtractElementInst>(E->Scalars[I]); - if (!Vec1) { - Vec1 = EI->getVectorOperand(); - } else if (Vec1 != EI->getVectorOperand()) { - assert((!Vec2 || Vec2 == EI->getVectorOperand()) && - "Expected only 1 or 2 vectors shuffle."); - Vec2 = EI->getVectorOperand(); + } + if (UseVecBaseAsInput) { + Vec1 = ExtractVecBase; + } else { + for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) { + if (ExtractMask[I] == PoisonMaskElem) + continue; + if (isa<UndefValue>(E->Scalars[I])) + continue; + auto *EI = cast<ExtractElementInst>(E->Scalars[I]); + Value *VecOp = EI->getVectorOperand(); + if (const auto *TE = getTreeEntry(VecOp)) + if (TE->VectorizedValue) + VecOp = TE->VectorizedValue; + if (!Vec1) { + Vec1 = VecOp; + } else if (Vec1 != EI->getVectorOperand()) { + assert((!Vec2 || Vec2 == EI->getVectorOperand()) && + "Expected only 1 or 2 vectors shuffle."); + Vec2 = VecOp; + } } } if (Vec2) { + IsUsedInExpr = false; IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1) && isGuaranteedNotToBePoison(Vec2); ShuffleBuilder.add(Vec1, Vec2, ExtractMask); } else if (Vec1) { - IsUsedInExpr = FindReusedSplat(ExtractMask); - ShuffleBuilder.add(Vec1, ExtractMask); + IsUsedInExpr &= FindReusedSplat( + ExtractMask, + cast<FixedVectorType>(Vec1->getType())->getNumElements()); + ShuffleBuilder.add(Vec1, ExtractMask, /*ForExtracts=*/true); IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1); } else { + IsUsedInExpr = false; ShuffleBuilder.add(PoisonValue::get(FixedVectorType::get( ScalarTy, GatheredScalars.size())), - ExtractMask); + ExtractMask, /*ForExtracts=*/true); } } - if (GatherShuffle) { - if (Entries.size() == 1) { - IsUsedInExpr = 
FindReusedSplat(Mask); - ShuffleBuilder.add(Entries.front()->VectorizedValue, Mask); - IsNonPoisoned &= - isGuaranteedNotToBePoison(Entries.front()->VectorizedValue); - } else { - ShuffleBuilder.add(Entries.front()->VectorizedValue, - Entries.back()->VectorizedValue, Mask); - IsNonPoisoned &= - isGuaranteedNotToBePoison(Entries.front()->VectorizedValue) && - isGuaranteedNotToBePoison(Entries.back()->VectorizedValue); + if (!GatherShuffles.empty()) { + unsigned SliceSize = E->Scalars.size() / NumParts; + SmallVector<int> VecMask(Mask.size(), PoisonMaskElem); + for (const auto [I, TEs] : enumerate(Entries)) { + if (TEs.empty()) { + assert(!GatherShuffles[I] && + "No shuffles with empty entries list expected."); + continue; + } + assert((TEs.size() == 1 || TEs.size() == 2) && + "Expected shuffle of 1 or 2 entries."); + auto SubMask = ArrayRef(Mask).slice(I * SliceSize, SliceSize); + VecMask.assign(VecMask.size(), PoisonMaskElem); + copy(SubMask, std::next(VecMask.begin(), I * SliceSize)); + if (TEs.size() == 1) { + IsUsedInExpr &= + FindReusedSplat(VecMask, TEs.front()->getVectorFactor()); + ShuffleBuilder.add(*TEs.front(), VecMask); + if (TEs.front()->VectorizedValue) + IsNonPoisoned &= + isGuaranteedNotToBePoison(TEs.front()->VectorizedValue); + } else { + IsUsedInExpr = false; + ShuffleBuilder.add(*TEs.front(), *TEs.back(), VecMask); + if (TEs.front()->VectorizedValue && TEs.back()->VectorizedValue) + IsNonPoisoned &= + isGuaranteedNotToBePoison(TEs.front()->VectorizedValue) && + isGuaranteedNotToBePoison(TEs.back()->VectorizedValue); + } } } // Try to figure out best way to combine values: build a shuffle and insert @@ -10025,16 +10924,24 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { int MSz = Mask.size(); // Try to build constant vector and shuffle with it only if currently we // have a single permutation and more than 1 scalar constants. 
- bool IsSingleShuffle = !ExtractShuffle || !GatherShuffle; + bool IsSingleShuffle = ExtractShuffles.empty() || GatherShuffles.empty(); bool IsIdentityShuffle = - (ExtractShuffle.value_or(TTI::SK_PermuteTwoSrc) == - TTI::SK_PermuteSingleSrc && + ((UseVecBaseAsInput || + all_of(ExtractShuffles, + [](const std::optional<TTI::ShuffleKind> &SK) { + return SK.value_or(TTI::SK_PermuteTwoSrc) == + TTI::SK_PermuteSingleSrc; + })) && none_of(ExtractMask, [&](int I) { return I >= EMSz; }) && - ShuffleVectorInst::isIdentityMask(ExtractMask)) || - (GatherShuffle.value_or(TTI::SK_PermuteTwoSrc) == - TTI::SK_PermuteSingleSrc && + ShuffleVectorInst::isIdentityMask(ExtractMask, EMSz)) || + (!GatherShuffles.empty() && + all_of(GatherShuffles, + [](const std::optional<TTI::ShuffleKind> &SK) { + return SK.value_or(TTI::SK_PermuteTwoSrc) == + TTI::SK_PermuteSingleSrc; + }) && none_of(Mask, [&](int I) { return I >= MSz; }) && - ShuffleVectorInst::isIdentityMask(Mask)); + ShuffleVectorInst::isIdentityMask(Mask, MSz)); bool EnoughConstsForShuffle = IsSingleShuffle && (none_of(GatheredScalars, @@ -10064,7 +10971,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { if (!all_of(GatheredScalars, PoisonValue::classof)) { SmallVector<int> BVMask(GatheredScalars.size(), PoisonMaskElem); TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true); - Value *BV = ShuffleBuilder.gather(GatheredScalars); + Value *BV = ShuffleBuilder.gather(GatheredScalars, BVMask.size()); ShuffleBuilder.add(BV, BVMask); } if (all_of(NonConstants, [=](Value *V) { @@ -10078,13 +10985,13 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { E->ReuseShuffleIndices, E->Scalars.size(), [&](Value *&Vec, SmallVectorImpl<int> &Mask) { TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false); - Vec = ShuffleBuilder.gather(NonConstants, Vec); + Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec); }); } else if (!allConstant(GatheredScalars)) { // Gather unique scalars 
and all constants. SmallVector<int> ReuseMask(GatheredScalars.size(), PoisonMaskElem); TryPackScalars(GatheredScalars, ReuseMask, /*IsRootPoison=*/true); - Value *BV = ShuffleBuilder.gather(GatheredScalars); + Value *BV = ShuffleBuilder.gather(GatheredScalars, ReuseMask.size()); ShuffleBuilder.add(BV, ReuseMask); Res = ShuffleBuilder.finalize(E->ReuseShuffleIndices); } else { @@ -10109,29 +11016,37 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) { *this); } -Value *BoUpSLP::vectorizeTree(TreeEntry *E) { +Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { IRBuilder<>::InsertPointGuard Guard(Builder); - if (E->VectorizedValue) { + if (E->VectorizedValue && + (E->State != TreeEntry::Vectorize || E->getOpcode() != Instruction::PHI || + E->isAltShuffle())) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *E->Scalars[0] << ".\n"); return E->VectorizedValue; } if (E->State == TreeEntry::NeedToGather) { - if (E->getMainOp() && E->Idx == 0) + // Set insert point for non-reduction initial nodes. 
+ if (E->getMainOp() && E->Idx == 0 && !UserIgnoreList) setInsertPointAfterBundle(E); Value *Vec = createBuildVector(E); E->VectorizedValue = Vec; return Vec; } - auto FinalShuffle = [&](Value *V, const TreeEntry *E) { + auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy, + bool IsSigned) { + if (V->getType() != VecTy) + V = Builder.CreateIntCast(V, VecTy, IsSigned); ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); if (E->getOpcode() == Instruction::Store) { ArrayRef<int> Mask = ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()), E->ReorderIndices.size()); ShuffleBuilder.add(V, Mask); + } else if (E->State == TreeEntry::PossibleStridedVectorize) { + ShuffleBuilder.addOrdered(V, std::nullopt); } else { ShuffleBuilder.addOrdered(V, E->ReorderIndices); } @@ -10139,7 +11054,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { }; assert((E->State == TreeEntry::Vectorize || - E->State == TreeEntry::ScatterVectorize) && + E->State == TreeEntry::ScatterVectorize || + E->State == TreeEntry::PossibleStridedVectorize) && "Unhandled state"); unsigned ShuffleOrOp = E->isAltShuffle() ? 
(unsigned)Instruction::ShuffleVector : E->getOpcode(); @@ -10149,6 +11065,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ScalarTy = Store->getValueOperand()->getType(); else if (auto *IE = dyn_cast<InsertElementInst>(VL0)) ScalarTy = IE->getOperand(1)->getType(); + bool IsSigned = false; + auto It = MinBWs.find(E); + if (It != MinBWs.end()) { + ScalarTy = IntegerType::get(F->getContext(), It->second.first); + IsSigned = It->second.second; + } auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size()); switch (ShuffleOrOp) { case Instruction::PHI: { @@ -10156,32 +11078,45 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { E != VectorizableTree.front().get() || !E->UserTreeIndices.empty()) && "PHI reordering is free."); + if (PostponedPHIs && E->VectorizedValue) + return E->VectorizedValue; auto *PH = cast<PHINode>(VL0); - Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI()); + Builder.SetInsertPoint(PH->getParent(), + PH->getParent()->getFirstNonPHIIt()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues()); - Value *V = NewPhi; + if (PostponedPHIs || !E->VectorizedValue) { + PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues()); + E->PHI = NewPhi; + Value *V = NewPhi; - // Adjust insertion point once all PHI's have been generated. - Builder.SetInsertPoint(&*PH->getParent()->getFirstInsertionPt()); - Builder.SetCurrentDebugLocation(PH->getDebugLoc()); + // Adjust insertion point once all PHI's have been generated. + Builder.SetInsertPoint(PH->getParent(), + PH->getParent()->getFirstInsertionPt()); + Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); - E->VectorizedValue = V; + E->VectorizedValue = V; + if (PostponedPHIs) + return V; + } + PHINode *NewPhi = cast<PHINode>(E->PHI); + // If phi node is fully emitted - exit. 
+ if (NewPhi->getNumIncomingValues() != 0) + return NewPhi; // PHINodes may have multiple entries from the same block. We want to // visit every block once. SmallPtrSet<BasicBlock *, 4> VisitedBBs; - for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { + for (unsigned I : seq<unsigned>(0, PH->getNumIncomingValues())) { ValueList Operands; - BasicBlock *IBB = PH->getIncomingBlock(i); + BasicBlock *IBB = PH->getIncomingBlock(I); // Stop emission if all incoming values are generated. if (NewPhi->getNumIncomingValues() == PH->getNumIncomingValues()) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); - return V; + return NewPhi; } if (!VisitedBBs.insert(IBB).second) { @@ -10191,37 +11126,54 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Builder.SetInsertPoint(IBB->getTerminator()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - Value *Vec = vectorizeOperand(E, i); + Value *Vec = vectorizeOperand(E, I, /*PostponedPHIs=*/true); + if (VecTy != Vec->getType()) { + assert(MinBWs.contains(getOperandEntry(E, I)) && + "Expected item in MinBWs."); + Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second); + } NewPhi->addIncoming(Vec, IBB); } assert(NewPhi->getNumIncomingValues() == PH->getNumIncomingValues() && "Invalid number of incoming values"); - return V; + return NewPhi; } case Instruction::ExtractElement: { Value *V = E->getSingleOperand(0); setInsertPointAfterBundle(E); - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; return V; } case Instruction::ExtractValue: { auto *LI = cast<LoadInst>(E->getSingleOperand(0)); Builder.SetInsertPoint(LI); - auto *PtrTy = PointerType::get(VecTy, LI->getPointerAddressSpace()); - Value *Ptr = Builder.CreateBitCast(LI->getOperand(0), PtrTy); + Value *Ptr = LI->getPointerOperand(); LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlign()); Value *NewV = propagateMetadata(V, E->Scalars); - NewV = FinalShuffle(NewV, E); + NewV = 
FinalShuffle(NewV, E, VecTy, IsSigned); E->VectorizedValue = NewV; return NewV; } case Instruction::InsertElement: { assert(E->ReuseShuffleIndices.empty() && "All inserts should be unique"); Builder.SetInsertPoint(cast<Instruction>(E->Scalars.back())); - Value *V = vectorizeOperand(E, 1); + Value *V = vectorizeOperand(E, 1, PostponedPHIs); + ArrayRef<Value *> Op = E->getOperand(1); + Type *ScalarTy = Op.front()->getType(); + if (cast<VectorType>(V->getType())->getElementType() != ScalarTy) { + assert(ScalarTy->isIntegerTy() && "Expected item in MinBWs."); + std::pair<unsigned, bool> Res = MinBWs.lookup(getOperandEntry(E, 1)); + assert(Res.first > 0 && "Expected item in MinBWs."); + V = Builder.CreateIntCast( + V, + FixedVectorType::get( + ScalarTy, + cast<FixedVectorType>(V->getType())->getNumElements()), + Res.second); + } // Create InsertVector shuffle if necessary auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { @@ -10254,7 +11206,57 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Mask[InsertIdx - Offset] = I; } if (!IsIdentity || NumElts != NumScalars) { - V = Builder.CreateShuffleVector(V, Mask); + Value *V2 = nullptr; + bool IsVNonPoisonous = isGuaranteedNotToBePoison(V) && !isConstant(V); + SmallVector<int> InsertMask(Mask); + if (NumElts != NumScalars && Offset == 0) { + // Follow all insert element instructions from the current buildvector + // sequence. 
+ InsertElementInst *Ins = cast<InsertElementInst>(VL0); + do { + std::optional<unsigned> InsertIdx = getInsertIndex(Ins); + if (!InsertIdx) + break; + if (InsertMask[*InsertIdx] == PoisonMaskElem) + InsertMask[*InsertIdx] = *InsertIdx; + if (!Ins->hasOneUse()) + break; + Ins = dyn_cast_or_null<InsertElementInst>( + Ins->getUniqueUndroppableUser()); + } while (Ins); + SmallBitVector UseMask = + buildUseMask(NumElts, InsertMask, UseMask::UndefsAsMask); + SmallBitVector IsFirstPoison = + isUndefVector<true>(FirstInsert->getOperand(0), UseMask); + SmallBitVector IsFirstUndef = + isUndefVector(FirstInsert->getOperand(0), UseMask); + if (!IsFirstPoison.all()) { + unsigned Idx = 0; + for (unsigned I = 0; I < NumElts; I++) { + if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I) && + IsFirstUndef.test(I)) { + if (IsVNonPoisonous) { + InsertMask[I] = I < NumScalars ? I : 0; + continue; + } + if (!V2) + V2 = UndefValue::get(V->getType()); + if (Idx >= NumScalars) + Idx = NumScalars - 1; + InsertMask[I] = NumScalars + Idx; + ++Idx; + } else if (InsertMask[I] != PoisonMaskElem && + Mask[I] == PoisonMaskElem) { + InsertMask[I] = PoisonMaskElem; + } + } + } else { + InsertMask = Mask; + } + } + if (!V2) + V2 = PoisonValue::get(V->getType()); + V = Builder.CreateShuffleVector(V, V2, InsertMask); if (auto *I = dyn_cast<Instruction>(V)) { GatherShuffleExtractSeq.insert(I); CSEBlocks.insert(I->getParent()); @@ -10273,15 +11275,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if ((!IsIdentity || Offset != 0 || !IsFirstUndef.all()) && NumElts != NumScalars) { if (IsFirstUndef.all()) { - if (!ShuffleVectorInst::isIdentityMask(InsertMask)) { - SmallBitVector IsFirstPoison = - isUndefVector<true>(FirstInsert->getOperand(0), UseMask); - if (!IsFirstPoison.all()) { - for (unsigned I = 0; I < NumElts; I++) { - if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I)) - InsertMask[I] = I + NumElts; + if (!ShuffleVectorInst::isIdentityMask(InsertMask, NumElts)) { + 
SmallBitVector IsFirstPoison = + isUndefVector<true>(FirstInsert->getOperand(0), UseMask); + if (!IsFirstPoison.all()) { + for (unsigned I = 0; I < NumElts; I++) { + if (InsertMask[I] == PoisonMaskElem && !IsFirstPoison.test(I)) + InsertMask[I] = I + NumElts; + } } - } V = Builder.CreateShuffleVector( V, IsFirstPoison.all() ? PoisonValue::get(V->getType()) @@ -10329,15 +11331,36 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::BitCast: { setInsertPointAfterBundle(E); - Value *InVec = vectorizeOperand(E, 0); + Value *InVec = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } auto *CI = cast<CastInst>(VL0); - Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy); - V = FinalShuffle(V, E); + Instruction::CastOps VecOpcode = CI->getOpcode(); + Type *SrcScalarTy = VL0->getOperand(0)->getType(); + auto SrcIt = MinBWs.find(getOperandEntry(E, 0)); + if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() && + (SrcIt != MinBWs.end() || It != MinBWs.end())) { + // Check if the values are candidates to demote. + unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy); + if (SrcIt != MinBWs.end()) + SrcBWSz = SrcIt->second.first; + unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); + if (BWSz == SrcBWSz) { + VecOpcode = Instruction::BitCast; + } else if (BWSz < SrcBWSz) { + VecOpcode = Instruction::Trunc; + } else if (It != MinBWs.end()) { + assert(BWSz > SrcBWSz && "Invalid cast!"); + VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt; + } + } + Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast) + ? 
InVec + : Builder.CreateCast(VecOpcode, InVec, VecTy); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10347,21 +11370,30 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::ICmp: { setInsertPointAfterBundle(E); - Value *L = vectorizeOperand(E, 0); + Value *L = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - Value *R = vectorizeOperand(E, 1); + Value *R = vectorizeOperand(E, 1, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } + if (L->getType() != R->getType()) { + assert((MinBWs.contains(getOperandEntry(E, 0)) || + MinBWs.contains(getOperandEntry(E, 1))) && + "Expected item in MinBWs."); + L = Builder.CreateIntCast(L, VecTy, IsSigned); + R = Builder.CreateIntCast(R, VecTy, IsSigned); + } CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); Value *V = Builder.CreateCmp(P0, L, R); propagateIRFlags(V, E->Scalars, VL0); - V = FinalShuffle(V, E); + // Do not cast for cmps. 
+ VecTy = cast<FixedVectorType>(V->getType()); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10370,24 +11402,31 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Select: { setInsertPointAfterBundle(E); - Value *Cond = vectorizeOperand(E, 0); + Value *Cond = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - Value *True = vectorizeOperand(E, 1); + Value *True = vectorizeOperand(E, 1, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - Value *False = vectorizeOperand(E, 2); + Value *False = vectorizeOperand(E, 2, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } + if (True->getType() != False->getType()) { + assert((MinBWs.contains(getOperandEntry(E, 1)) || + MinBWs.contains(getOperandEntry(E, 2))) && + "Expected item in MinBWs."); + True = Builder.CreateIntCast(True, VecTy, IsSigned); + False = Builder.CreateIntCast(False, VecTy, IsSigned); + } Value *V = Builder.CreateSelect(Cond, True, False); - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10396,7 +11435,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::FNeg: { setInsertPointAfterBundle(E); - Value *Op = vectorizeOperand(E, 0); + Value *Op = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); @@ -10409,7 +11448,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10436,16 +11475,23 @@ Value 
*BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Xor: { setInsertPointAfterBundle(E); - Value *LHS = vectorizeOperand(E, 0); + Value *LHS = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - Value *RHS = vectorizeOperand(E, 1); + Value *RHS = vectorizeOperand(E, 1, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } + if (LHS->getType() != RHS->getType()) { + assert((MinBWs.contains(getOperandEntry(E, 0)) || + MinBWs.contains(getOperandEntry(E, 1))) && + "Expected item in MinBWs."); + LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned); + RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned); + } Value *V = Builder.CreateBinOp( static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, @@ -10454,7 +11500,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10475,14 +11521,18 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // The pointer operand uses an in-tree scalar so we add the new // LoadInst to ExternalUses list to make sure that an extract will // be generated in the future. - if (TreeEntry *Entry = getTreeEntry(PO)) { - // Find which lane we need to extract. - unsigned FoundLane = Entry->findLaneForValue(PO); - ExternalUses.emplace_back(PO, NewLI, FoundLane); + if (isa<Instruction>(PO)) { + if (TreeEntry *Entry = getTreeEntry(PO)) { + // Find which lane we need to extract. 
+ unsigned FoundLane = Entry->findLaneForValue(PO); + ExternalUses.emplace_back(PO, NewLI, FoundLane); + } } } else { - assert(E->State == TreeEntry::ScatterVectorize && "Unhandled state"); - Value *VecPtr = vectorizeOperand(E, 0); + assert((E->State == TreeEntry::ScatterVectorize || + E->State == TreeEntry::PossibleStridedVectorize) && + "Unhandled state"); + Value *VecPtr = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; @@ -10496,35 +11546,32 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *V = propagateMetadata(NewLI, E->Scalars); - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; return V; } case Instruction::Store: { auto *SI = cast<StoreInst>(VL0); - unsigned AS = SI->getPointerAddressSpace(); setInsertPointAfterBundle(E); - Value *VecValue = vectorizeOperand(E, 0); - VecValue = FinalShuffle(VecValue, E); + Value *VecValue = vectorizeOperand(E, 0, PostponedPHIs); + VecValue = FinalShuffle(VecValue, E, VecTy, IsSigned); - Value *ScalarPtr = SI->getPointerOperand(); - Value *VecPtr = Builder.CreateBitCast( - ScalarPtr, VecValue->getType()->getPointerTo(AS)); + Value *Ptr = SI->getPointerOperand(); StoreInst *ST = - Builder.CreateAlignedStore(VecValue, VecPtr, SI->getAlign()); + Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign()); - // The pointer operand uses an in-tree scalar, so add the new BitCast or - // StoreInst to ExternalUses to make sure that an extract will be - // generated in the future. - if (TreeEntry *Entry = getTreeEntry(ScalarPtr)) { - // Find which lane we need to extract. - unsigned FoundLane = Entry->findLaneForValue(ScalarPtr); - ExternalUses.push_back(ExternalUser( - ScalarPtr, ScalarPtr != VecPtr ? 
cast<User>(VecPtr) : ST, - FoundLane)); + // The pointer operand uses an in-tree scalar, so add the new StoreInst to + // ExternalUses to make sure that an extract will be generated in the + // future. + if (isa<Instruction>(Ptr)) { + if (TreeEntry *Entry = getTreeEntry(Ptr)) { + // Find which lane we need to extract. + unsigned FoundLane = Entry->findLaneForValue(Ptr); + ExternalUses.push_back(ExternalUser(Ptr, ST, FoundLane)); + } } Value *V = propagateMetadata(ST, E->Scalars); @@ -10537,7 +11584,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { auto *GEP0 = cast<GetElementPtrInst>(VL0); setInsertPointAfterBundle(E); - Value *Op0 = vectorizeOperand(E, 0); + Value *Op0 = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; @@ -10545,7 +11592,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { SmallVector<Value *> OpVecs; for (int J = 1, N = GEP0->getNumOperands(); J < N; ++J) { - Value *OpVec = vectorizeOperand(E, J); + Value *OpVec = vectorizeOperand(E, J, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; @@ -10563,7 +11610,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { V = propagateMetadata(I, GEPs); } - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10585,41 +11632,42 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { VecCallCosts.first <= VecCallCosts.second; Value *ScalarArg = nullptr; - std::vector<Value *> OpVecs; + SmallVector<Value *> OpVecs; SmallVector<Type *, 2> TysForDecl; // Add return type if intrinsic is overloaded on it. 
if (isVectorIntrinsicWithOverloadTypeAtArg(IID, -1)) TysForDecl.push_back( FixedVectorType::get(CI->getType(), E->Scalars.size())); - for (int j = 0, e = CI->arg_size(); j < e; ++j) { + for (unsigned I : seq<unsigned>(0, CI->arg_size())) { ValueList OpVL; // Some intrinsics have scalar arguments. This argument should not be // vectorized. - if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(IID, j)) { + if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(IID, I)) { CallInst *CEI = cast<CallInst>(VL0); - ScalarArg = CEI->getArgOperand(j); - OpVecs.push_back(CEI->getArgOperand(j)); - if (isVectorIntrinsicWithOverloadTypeAtArg(IID, j)) + ScalarArg = CEI->getArgOperand(I); + OpVecs.push_back(CEI->getArgOperand(I)); + if (isVectorIntrinsicWithOverloadTypeAtArg(IID, I)) TysForDecl.push_back(ScalarArg->getType()); continue; } - Value *OpVec = vectorizeOperand(E, j); + Value *OpVec = vectorizeOperand(E, I, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - LLVM_DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n"); + LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); - if (isVectorIntrinsicWithOverloadTypeAtArg(IID, j)) + if (isVectorIntrinsicWithOverloadTypeAtArg(IID, I)) TysForDecl.push_back(OpVec->getType()); } Function *CF; if (!UseIntrinsic) { VFShape Shape = - VFShape::get(*CI, ElementCount::getFixed(static_cast<unsigned>( - VecTy->getNumElements())), + VFShape::get(CI->getFunctionType(), + ElementCount::getFixed( + static_cast<unsigned>(VecTy->getNumElements())), false /*HasGlobalPred*/); CF = VFDatabase(*CI).getVectorizedFunction(Shape); } else { @@ -10633,7 +11681,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // The scalar argument uses an in-tree scalar so we add the new vectorized // call to ExternalUses list to make sure that an extract will be // generated in the future. 
- if (ScalarArg) { + if (isa_and_present<Instruction>(ScalarArg)) { if (TreeEntry *Entry = getTreeEntry(ScalarArg)) { // Find which lane we need to extract. unsigned FoundLane = Entry->findLaneForValue(ScalarArg); @@ -10643,7 +11691,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } propagateIRFlags(V, E->Scalars, VL0); - V = FinalShuffle(V, E); + V = FinalShuffle(V, E, VecTy, IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10661,20 +11709,27 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *LHS = nullptr, *RHS = nullptr; if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0)) { setInsertPointAfterBundle(E); - LHS = vectorizeOperand(E, 0); + LHS = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - RHS = vectorizeOperand(E, 1); + RHS = vectorizeOperand(E, 1, PostponedPHIs); } else { setInsertPointAfterBundle(E); - LHS = vectorizeOperand(E, 0); + LHS = vectorizeOperand(E, 0, PostponedPHIs); } if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } + if (LHS && RHS && LHS->getType() != RHS->getType()) { + assert((MinBWs.contains(getOperandEntry(E, 0)) || + MinBWs.contains(getOperandEntry(E, 1))) && + "Expected item in MinBWs."); + LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned); + RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned); + } Value *V0, *V1; if (Instruction::isBinaryOp(E->getOpcode())) { @@ -10707,8 +11762,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // each vector operation. 
ValueList OpScalars, AltScalars; SmallVector<int> Mask; - buildShuffleEntryMask( - E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, + E->buildAltOpShuffleMask( [E, this](Instruction *I) { assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); return isAlternateInstruction(I, E->getMainOp(), E->getAltOp(), @@ -10726,6 +11780,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { CSEBlocks.insert(I->getParent()); } + if (V->getType() != VecTy && !isa<CmpInst>(VL0)) + V = Builder.CreateIntCast( + V, FixedVectorType::get(ScalarTy, E->getVectorFactor()), IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -10766,9 +11823,19 @@ Value *BoUpSLP::vectorizeTree( // need to rebuild it. EntryToLastInstruction.clear(); - Builder.SetInsertPoint(ReductionRoot ? ReductionRoot - : &F->getEntryBlock().front()); - auto *VectorRoot = vectorizeTree(VectorizableTree[0].get()); + if (ReductionRoot) + Builder.SetInsertPoint(ReductionRoot->getParent(), + ReductionRoot->getIterator()); + else + Builder.SetInsertPoint(&F->getEntryBlock(), F->getEntryBlock().begin()); + + // Postpone emission of PHIs operands to avoid cyclic dependencies issues. + (void)vectorizeTree(VectorizableTree[0].get(), /*PostponedPHIs=*/true); + for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree) + if (TE->State == TreeEntry::Vectorize && + TE->getOpcode() == Instruction::PHI && !TE->isAltShuffle() && + TE->VectorizedValue) + (void)vectorizeTree(TE.get(), /*PostponedPHIs=*/false); // Run through the list of postponed gathers and emit them, replacing the temp // emitted allocas with actual vector instructions. 
ArrayRef<const TreeEntry *> PostponedNodes = PostponedGathers.getArrayRef(); @@ -10785,9 +11852,32 @@ Value *BoUpSLP::vectorizeTree( TE->VectorizedValue = nullptr; auto *UserI = cast<Instruction>(TE->UserTreeIndices.front().UserTE->VectorizedValue); - Builder.SetInsertPoint(PrevVec); + // If user is a PHI node, its vector code have to be inserted right before + // block terminator. Since the node was delayed, there were some unresolved + // dependencies at the moment when stab instruction was emitted. In a case + // when any of these dependencies turn out an operand of another PHI, coming + // from this same block, position of a stab instruction will become invalid. + // The is because source vector that supposed to feed this gather node was + // inserted at the end of the block [after stab instruction]. So we need + // to adjust insertion point again to the end of block. + if (isa<PHINode>(UserI)) { + // Insert before all users. + Instruction *InsertPt = PrevVec->getParent()->getTerminator(); + for (User *U : PrevVec->users()) { + if (U == UserI) + continue; + auto *UI = dyn_cast<Instruction>(U); + if (!UI || isa<PHINode>(UI) || UI->getParent() != InsertPt->getParent()) + continue; + if (UI->comesBefore(InsertPt)) + InsertPt = UI; + } + Builder.SetInsertPoint(InsertPt); + } else { + Builder.SetInsertPoint(PrevVec); + } Builder.SetCurrentDebugLocation(UserI->getDebugLoc()); - Value *Vec = vectorizeTree(TE); + Value *Vec = vectorizeTree(TE, /*PostponedPHIs=*/false); PrevVec->replaceAllUsesWith(Vec); PostponedValues.try_emplace(Vec).first->second.push_back(TE); // Replace the stub vector node, if it was used before for one of the @@ -10800,26 +11890,6 @@ Value *BoUpSLP::vectorizeTree( eraseInstruction(PrevVec); } - // If the vectorized tree can be rewritten in a smaller type, we truncate the - // vectorized root. InstCombine will then rewrite the entire expression. We - // sign extend the extracted values below. 
- auto *ScalarRoot = VectorizableTree[0]->Scalars[0]; - if (MinBWs.count(ScalarRoot)) { - if (auto *I = dyn_cast<Instruction>(VectorRoot)) { - // If current instr is a phi and not the last phi, insert it after the - // last phi node. - if (isa<PHINode>(I)) - Builder.SetInsertPoint(&*I->getParent()->getFirstInsertionPt()); - else - Builder.SetInsertPoint(&*++BasicBlock::iterator(I)); - } - auto BundleWidth = VectorizableTree[0]->Scalars.size(); - auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); - auto *VecTy = FixedVectorType::get(MinTy, BundleWidth); - auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy); - VectorizableTree[0]->VectorizedValue = Trunc; - } - LLVM_DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); @@ -10829,6 +11899,8 @@ Value *BoUpSLP::vectorizeTree( // Maps extract Scalar to the corresponding extractelement instruction in the // basic block. Only one extractelement per block should be emitted. DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs; + SmallDenseSet<Value *, 4> UsedInserts; + DenseMap<Value *, Value *> VectorCasts; // Extract all of the elements with the external uses. for (const auto &ExternalUse : ExternalUses) { Value *Scalar = ExternalUse.Scalar; @@ -10863,7 +11935,8 @@ Value *BoUpSLP::vectorizeTree( Instruction *I = EEIt->second; if (Builder.GetInsertPoint() != Builder.GetInsertBlock()->end() && Builder.GetInsertPoint()->comesBefore(I)) - I->moveBefore(&*Builder.GetInsertPoint()); + I->moveBefore(*Builder.GetInsertPoint()->getParent(), + Builder.GetInsertPoint()); Ex = I; } } @@ -10886,11 +11959,10 @@ Value *BoUpSLP::vectorizeTree( } // If necessary, sign-extend or zero-extend ScalarRoot // to the larger type. 
- if (!MinBWs.count(ScalarRoot)) - return Ex; - if (MinBWs[ScalarRoot].second) - return Builder.CreateSExt(Ex, Scalar->getType()); - return Builder.CreateZExt(Ex, Scalar->getType()); + if (Scalar->getType() != Ex->getType()) + return Builder.CreateIntCast(Ex, Scalar->getType(), + MinBWs.find(E)->second.second); + return Ex; } assert(isa<FixedVectorType>(Scalar->getType()) && isa<InsertElementInst>(Scalar) && @@ -10908,12 +11980,13 @@ Value *BoUpSLP::vectorizeTree( "ExternallyUsedValues map"); if (auto *VecI = dyn_cast<Instruction>(Vec)) { if (auto *PHI = dyn_cast<PHINode>(VecI)) - Builder.SetInsertPoint(PHI->getParent()->getFirstNonPHI()); + Builder.SetInsertPoint(PHI->getParent(), + PHI->getParent()->getFirstNonPHIIt()); else Builder.SetInsertPoint(VecI->getParent(), std::next(VecI->getIterator())); } else { - Builder.SetInsertPoint(&F->getEntryBlock().front()); + Builder.SetInsertPoint(&F->getEntryBlock(), F->getEntryBlock().begin()); } Value *NewInst = ExtractAndExtendIfNeeded(Vec); // Required to update internally referenced instructions. @@ -10926,12 +11999,26 @@ Value *BoUpSLP::vectorizeTree( // Skip if the scalar is another vector op or Vec is not an instruction. if (!Scalar->getType()->isVectorTy() && isa<Instruction>(Vec)) { if (auto *FTy = dyn_cast<FixedVectorType>(User->getType())) { + if (!UsedInserts.insert(VU).second) + continue; + // Need to use original vector, if the root is truncated. 
+ auto BWIt = MinBWs.find(E); + if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) { + auto VecIt = VectorCasts.find(Scalar); + if (VecIt == VectorCasts.end()) { + IRBuilder<>::InsertPointGuard Guard(Builder); + if (auto *IVec = dyn_cast<Instruction>(Vec)) + Builder.SetInsertPoint(IVec->getNextNonDebugInstruction()); + Vec = Builder.CreateIntCast(Vec, VU->getType(), + BWIt->second.second); + VectorCasts.try_emplace(Scalar, Vec); + } else { + Vec = VecIt->second; + } + } + std::optional<unsigned> InsertIdx = getInsertIndex(VU); if (InsertIdx) { - // Need to use original vector, if the root is truncated. - if (MinBWs.count(Scalar) && - VectorizableTree[0]->VectorizedValue == Vec) - Vec = VectorRoot; auto *It = find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) { // Checks if 2 insertelements are from the same buildvector. @@ -10991,18 +12078,18 @@ Value *BoUpSLP::vectorizeTree( // Find the insertion point for the extractelement lane. if (auto *VecI = dyn_cast<Instruction>(Vec)) { if (PHINode *PH = dyn_cast<PHINode>(User)) { - for (int i = 0, e = PH->getNumIncomingValues(); i != e; ++i) { - if (PH->getIncomingValue(i) == Scalar) { + for (unsigned I : seq<unsigned>(0, PH->getNumIncomingValues())) { + if (PH->getIncomingValue(I) == Scalar) { Instruction *IncomingTerminator = - PH->getIncomingBlock(i)->getTerminator(); + PH->getIncomingBlock(I)->getTerminator(); if (isa<CatchSwitchInst>(IncomingTerminator)) { Builder.SetInsertPoint(VecI->getParent(), std::next(VecI->getIterator())); } else { - Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); + Builder.SetInsertPoint(PH->getIncomingBlock(I)->getTerminator()); } Value *NewInst = ExtractAndExtendIfNeeded(Vec); - PH->setOperand(i, NewInst); + PH->setOperand(I, NewInst); } } } else { @@ -11011,7 +12098,7 @@ Value *BoUpSLP::vectorizeTree( User->replaceUsesOfWith(Scalar, NewInst); } } else { - Builder.SetInsertPoint(&F->getEntryBlock().front()); + Builder.SetInsertPoint(&F->getEntryBlock(), 
F->getEntryBlock().begin()); Value *NewInst = ExtractAndExtendIfNeeded(Vec); User->replaceUsesOfWith(Scalar, NewInst); } @@ -11084,7 +12171,7 @@ Value *BoUpSLP::vectorizeTree( // non-resizing mask. if (Mask.size() != cast<FixedVectorType>(Vals.front()->getType()) ->getNumElements() || - !ShuffleVectorInst::isIdentityMask(Mask)) + !ShuffleVectorInst::isIdentityMask(Mask, Mask.size())) return CreateShuffle(Vals.front(), nullptr, Mask); return Vals.front(); } @@ -11675,7 +12762,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, } } - auto makeControlDependent = [&](Instruction *I) { + auto MakeControlDependent = [&](Instruction *I) { auto *DepDest = getScheduleData(I); assert(DepDest && "must be in schedule window"); DepDest->ControlDependencies.push_back(BundleMember); @@ -11697,7 +12784,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, continue; // Add the dependency - makeControlDependent(I); + MakeControlDependent(I); if (!isGuaranteedToTransferExecutionToSuccessor(I)) // Everything past here must be control dependent on I. 
@@ -11723,7 +12810,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, continue; // Add the dependency - makeControlDependent(I); + MakeControlDependent(I); } } @@ -11741,7 +12828,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, continue; // Add the dependency - makeControlDependent(I); + MakeControlDependent(I); break; } } @@ -11756,7 +12843,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, "NextLoadStore list for non memory effecting bundle?"); MemoryLocation SrcLoc = getLocation(SrcInst); bool SrcMayWrite = BundleMember->Inst->mayWriteToMemory(); - unsigned numAliased = 0; + unsigned NumAliased = 0; unsigned DistToSrc = 1; for (; DepDest; DepDest = DepDest->NextLoadStore) { @@ -11771,13 +12858,13 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, // check this limit even between two read-only instructions. if (DistToSrc >= MaxMemDepDistance || ((SrcMayWrite || DepDest->Inst->mayWriteToMemory()) && - (numAliased >= AliasedCheckLimit || + (NumAliased >= AliasedCheckLimit || SLP->isAliased(SrcLoc, SrcInst, DepDest->Inst)))) { // We increment the counter only if the locations are aliased // (instead of counting all alias checks). This gives a better // balance between reduced runtime and accurate dependencies. - numAliased++; + NumAliased++; DepDest->MemoryDependencies.push_back(BundleMember); BundleMember->Dependencies++; @@ -11879,20 +12966,20 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { // Do the "real" scheduling. while (!ReadyInsts.empty()) { - ScheduleData *picked = *ReadyInsts.begin(); + ScheduleData *Picked = *ReadyInsts.begin(); ReadyInsts.erase(ReadyInsts.begin()); // Move the scheduled instruction(s) to their dedicated places, if not // there yet. 
- for (ScheduleData *BundleMember = picked; BundleMember; + for (ScheduleData *BundleMember = Picked; BundleMember; BundleMember = BundleMember->NextInBundle) { - Instruction *pickedInst = BundleMember->Inst; - if (pickedInst->getNextNode() != LastScheduledInst) - pickedInst->moveBefore(LastScheduledInst); - LastScheduledInst = pickedInst; + Instruction *PickedInst = BundleMember->Inst; + if (PickedInst->getNextNode() != LastScheduledInst) + PickedInst->moveBefore(LastScheduledInst); + LastScheduledInst = PickedInst; } - BS->schedule(picked, ReadyInsts); + BS->schedule(Picked, ReadyInsts); } // Check that we didn't break any of our invariants. @@ -11993,21 +13080,22 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { // Determine if a value V in a vectorizable expression Expr can be demoted to a // smaller type with a truncation. We collect the values that will be demoted // in ToDemote and additional roots that require investigating in Roots. -static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr, - SmallVectorImpl<Value *> &ToDemote, - SmallVectorImpl<Value *> &Roots) { +bool BoUpSLP::collectValuesToDemote( + Value *V, SmallVectorImpl<Value *> &ToDemote, + DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts, + SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const { // We can always demote constants. - if (isa<Constant>(V)) { - ToDemote.push_back(V); + if (isa<Constant>(V)) return true; - } - // If the value is not an instruction in the expression with only one use, it - // cannot be demoted. + // If the value is not a vectorized instruction in the expression with only + // one use, it cannot be demoted. auto *I = dyn_cast<Instruction>(V); - if (!I || !I->hasOneUse() || !Expr.count(I)) + if (!I || !I->hasOneUse() || !getTreeEntry(I) || !Visited.insert(I).second) return false; + unsigned Start = 0; + unsigned End = I->getNumOperands(); switch (I->getOpcode()) { // We can always demote truncations and extensions. 
Since truncations can @@ -12029,16 +13117,21 @@ static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr, case Instruction::And: case Instruction::Or: case Instruction::Xor: - if (!collectValuesToDemote(I->getOperand(0), Expr, ToDemote, Roots) || - !collectValuesToDemote(I->getOperand(1), Expr, ToDemote, Roots)) + if (!collectValuesToDemote(I->getOperand(0), ToDemote, DemotedConsts, Roots, + Visited) || + !collectValuesToDemote(I->getOperand(1), ToDemote, DemotedConsts, Roots, + Visited)) return false; break; // We can demote selects if we can demote their true and false values. case Instruction::Select: { + Start = 1; SelectInst *SI = cast<SelectInst>(I); - if (!collectValuesToDemote(SI->getTrueValue(), Expr, ToDemote, Roots) || - !collectValuesToDemote(SI->getFalseValue(), Expr, ToDemote, Roots)) + if (!collectValuesToDemote(SI->getTrueValue(), ToDemote, DemotedConsts, + Roots, Visited) || + !collectValuesToDemote(SI->getFalseValue(), ToDemote, DemotedConsts, + Roots, Visited)) return false; break; } @@ -12048,7 +13141,8 @@ static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr, case Instruction::PHI: { PHINode *PN = cast<PHINode>(I); for (Value *IncValue : PN->incoming_values()) - if (!collectValuesToDemote(IncValue, Expr, ToDemote, Roots)) + if (!collectValuesToDemote(IncValue, ToDemote, DemotedConsts, Roots, + Visited)) return false; break; } @@ -12058,6 +13152,10 @@ static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr, return false; } + // Gather demoted constant operands. + for (unsigned Idx : seq<unsigned>(Start, End)) + if (isa<Constant>(I->getOperand(Idx))) + DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx); // Record the value that we can demote. ToDemote.push_back(V); return true; @@ -12075,44 +13173,26 @@ void BoUpSLP::computeMinimumValueSizes() { if (!TreeRootIT) return; - // If the expression is not rooted by a store, these roots should have - // external uses. 
We will rely on InstCombine to rewrite the expression in - // the narrower type. However, InstCombine only rewrites single-use values. - // This means that if a tree entry other than a root is used externally, it - // must have multiple uses and InstCombine will not rewrite it. The code - // below ensures that only the roots are used externally. - SmallPtrSet<Value *, 32> Expr(TreeRoot.begin(), TreeRoot.end()); - for (auto &EU : ExternalUses) - if (!Expr.erase(EU.Scalar)) - return; - if (!Expr.empty()) + // Ensure the roots of the vectorizable tree don't form a cycle. + if (!VectorizableTree.front()->UserTreeIndices.empty()) return; - // Collect the scalar values of the vectorizable expression. We will use this - // context to determine which values can be demoted. If we see a truncation, - // we mark it as seeding another demotion. - for (auto &EntryPtr : VectorizableTree) - Expr.insert(EntryPtr->Scalars.begin(), EntryPtr->Scalars.end()); - - // Ensure the roots of the vectorizable tree don't form a cycle. They must - // have a single external user that is not in the vectorizable tree. - for (auto *Root : TreeRoot) - if (!Root->hasOneUse() || Expr.count(*Root->user_begin())) - return; - // Conservatively determine if we can actually truncate the roots of the // expression. Collect the values that can be demoted in ToDemote and // additional roots that require investigating in Roots. SmallVector<Value *, 32> ToDemote; + DenseMap<Instruction *, SmallVector<unsigned>> DemotedConsts; SmallVector<Value *, 4> Roots; - for (auto *Root : TreeRoot) - if (!collectValuesToDemote(Root, Expr, ToDemote, Roots)) + for (auto *Root : TreeRoot) { + DenseSet<Value *> Visited; + if (!collectValuesToDemote(Root, ToDemote, DemotedConsts, Roots, Visited)) return; + } // The maximum bit width required to represent all the values that can be // demoted without loss of precision. It would be safe to truncate the roots // of the expression to this width. 
- auto MaxBitWidth = 8u; + auto MaxBitWidth = 1u; // We first check if all the bits of the roots are demanded. If they're not, // we can truncate the roots to this narrower type. @@ -12137,9 +13217,9 @@ void BoUpSLP::computeMinimumValueSizes() { // maximum bit width required to store the scalar by using ValueTracking to // compute the number of high-order bits we can truncate. if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType()) && - llvm::all_of(TreeRoot, [](Value *R) { - assert(R->hasOneUse() && "Root should have only one use!"); - return isa<GetElementPtrInst>(R->user_back()); + all_of(TreeRoot, [](Value *V) { + return all_of(V->users(), + [](User *U) { return isa<GetElementPtrInst>(U); }); })) { MaxBitWidth = 8u; @@ -12188,12 +13268,39 @@ void BoUpSLP::computeMinimumValueSizes() { // If we can truncate the root, we must collect additional values that might // be demoted as a result. That is, those seeded by truncations we will // modify. - while (!Roots.empty()) - collectValuesToDemote(Roots.pop_back_val(), Expr, ToDemote, Roots); + while (!Roots.empty()) { + DenseSet<Value *> Visited; + collectValuesToDemote(Roots.pop_back_val(), ToDemote, DemotedConsts, Roots, + Visited); + } // Finally, map the values we can demote to the maximum bit with we computed. - for (auto *Scalar : ToDemote) - MinBWs[Scalar] = std::make_pair(MaxBitWidth, !IsKnownPositive); + for (auto *Scalar : ToDemote) { + auto *TE = getTreeEntry(Scalar); + assert(TE && "Expected vectorized scalar."); + if (MinBWs.contains(TE)) + continue; + bool IsSigned = any_of(TE->Scalars, [&](Value *R) { + KnownBits Known = computeKnownBits(R, *DL); + return !Known.isNonNegative(); + }); + MinBWs.try_emplace(TE, MaxBitWidth, IsSigned); + const auto *I = cast<Instruction>(Scalar); + auto DCIt = DemotedConsts.find(I); + if (DCIt != DemotedConsts.end()) { + for (unsigned Idx : DCIt->getSecond()) { + // Check that all instructions operands are demoted. 
+ if (all_of(TE->Scalars, [&](Value *V) { + auto SIt = DemotedConsts.find(cast<Instruction>(V)); + return SIt != DemotedConsts.end() && + is_contained(SIt->getSecond(), Idx); + })) { + const TreeEntry *CTE = getOperandEntry(TE, Idx); + MinBWs.try_emplace(CTE, MaxBitWidth, IsSigned); + } + } + } + } } PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) { @@ -12347,139 +13454,206 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, BoUpSLP::ValueSet VectorizedStores; bool Changed = false; - int E = Stores.size(); - SmallBitVector Tails(E, false); - int MaxIter = MaxStoreLookup.getValue(); - SmallVector<std::pair<int, int>, 16> ConsecutiveChain( - E, std::make_pair(E, INT_MAX)); - SmallVector<SmallBitVector, 4> CheckedPairs(E, SmallBitVector(E, false)); - int IterCnt; - auto &&FindConsecutiveAccess = [this, &Stores, &Tails, &IterCnt, MaxIter, - &CheckedPairs, - &ConsecutiveChain](int K, int Idx) { - if (IterCnt >= MaxIter) - return true; - if (CheckedPairs[Idx].test(K)) - return ConsecutiveChain[K].second == 1 && - ConsecutiveChain[K].first == Idx; - ++IterCnt; - CheckedPairs[Idx].set(K); - CheckedPairs[K].set(Idx); - std::optional<int> Diff = getPointersDiff( - Stores[K]->getValueOperand()->getType(), Stores[K]->getPointerOperand(), - Stores[Idx]->getValueOperand()->getType(), - Stores[Idx]->getPointerOperand(), *DL, *SE, /*StrictCheck=*/true); - if (!Diff || *Diff == 0) - return false; - int Val = *Diff; - if (Val < 0) { - if (ConsecutiveChain[Idx].second > -Val) { - Tails.set(K); - ConsecutiveChain[Idx] = std::make_pair(K, -Val); - } - return false; + // Stores the pair of stores (first_store, last_store) in a range, that were + // already tried to be vectorized. Allows to skip the store ranges that were + // already tried to be vectorized but the attempts were unsuccessful. 
+ DenseSet<std::pair<Value *, Value *>> TriedSequences; + struct StoreDistCompare { + bool operator()(const std::pair<unsigned, int> &Op1, + const std::pair<unsigned, int> &Op2) const { + return Op1.second < Op2.second; } - if (ConsecutiveChain[K].second <= Val) - return false; - - Tails.set(Idx); - ConsecutiveChain[K] = std::make_pair(Idx, Val); - return Val == 1; }; - // Do a quadratic search on all of the given stores in reverse order and find - // all of the pairs of stores that follow each other. - for (int Idx = E - 1; Idx >= 0; --Idx) { - // If a store has multiple consecutive store candidates, search according - // to the sequence: Idx-1, Idx+1, Idx-2, Idx+2, ... - // This is because usually pairing with immediate succeeding or preceding - // candidate create the best chance to find slp vectorization opportunity. - const int MaxLookDepth = std::max(E - Idx, Idx + 1); - IterCnt = 0; - for (int Offset = 1, F = MaxLookDepth; Offset < F; ++Offset) - if ((Idx >= Offset && FindConsecutiveAccess(Idx - Offset, Idx)) || - (Idx + Offset < E && FindConsecutiveAccess(Idx + Offset, Idx))) - break; - } - - // Tracks if we tried to vectorize stores starting from the given tail - // already. - SmallBitVector TriedTails(E, false); - // For stores that start but don't end a link in the chain: - for (int Cnt = E; Cnt > 0; --Cnt) { - int I = Cnt - 1; - if (ConsecutiveChain[I].first == E || Tails.test(I)) - continue; - // We found a store instr that starts a chain. Now follow the chain and try - // to vectorize it. + // A set of pairs (index of store in Stores array ref, Distance of the store + // address relative to base store address in units). + using StoreIndexToDistSet = + std::set<std::pair<unsigned, int>, StoreDistCompare>; + auto TryToVectorize = [&](const StoreIndexToDistSet &Set) { + int PrevDist = -1; BoUpSLP::ValueList Operands; // Collect the chain into a list. 
- while (I != E && !VectorizedStores.count(Stores[I])) { - Operands.push_back(Stores[I]); - Tails.set(I); - if (ConsecutiveChain[I].second != 1) { - // Mark the new end in the chain and go back, if required. It might be - // required if the original stores come in reversed order, for example. - if (ConsecutiveChain[I].first != E && - Tails.test(ConsecutiveChain[I].first) && !TriedTails.test(I) && - !VectorizedStores.count(Stores[ConsecutiveChain[I].first])) { - TriedTails.set(I); - Tails.reset(ConsecutiveChain[I].first); - if (Cnt < ConsecutiveChain[I].first + 2) - Cnt = ConsecutiveChain[I].first + 2; - } - break; + for (auto [Idx, Data] : enumerate(Set)) { + if (Operands.empty() || Data.second - PrevDist == 1) { + Operands.push_back(Stores[Data.first]); + PrevDist = Data.second; + if (Idx != Set.size() - 1) + continue; + } + if (Operands.size() <= 1) { + Operands.clear(); + Operands.push_back(Stores[Data.first]); + PrevDist = Data.second; + continue; } - // Move to the next value in the chain. 
- I = ConsecutiveChain[I].first; - } - assert(!Operands.empty() && "Expected non-empty list of stores."); - unsigned MaxVecRegSize = R.getMaxVecRegSize(); - unsigned EltSize = R.getVectorElementSize(Operands[0]); - unsigned MaxElts = llvm::bit_floor(MaxVecRegSize / EltSize); + unsigned MaxVecRegSize = R.getMaxVecRegSize(); + unsigned EltSize = R.getVectorElementSize(Operands[0]); + unsigned MaxElts = llvm::bit_floor(MaxVecRegSize / EltSize); - unsigned MaxVF = std::min(R.getMaximumVF(EltSize, Instruction::Store), - MaxElts); - auto *Store = cast<StoreInst>(Operands[0]); - Type *StoreTy = Store->getValueOperand()->getType(); - Type *ValueTy = StoreTy; - if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand())) - ValueTy = Trunc->getSrcTy(); - unsigned MinVF = TTI->getStoreMinimumVF( - R.getMinVF(DL->getTypeSizeInBits(ValueTy)), StoreTy, ValueTy); + unsigned MaxVF = + std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts); + auto *Store = cast<StoreInst>(Operands[0]); + Type *StoreTy = Store->getValueOperand()->getType(); + Type *ValueTy = StoreTy; + if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand())) + ValueTy = Trunc->getSrcTy(); + unsigned MinVF = TTI->getStoreMinimumVF( + R.getMinVF(DL->getTypeSizeInBits(ValueTy)), StoreTy, ValueTy); - if (MaxVF <= MinVF) { - LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF << ") <= " - << "MinVF (" << MinVF << ")\n"); - } + if (MaxVF <= MinVF) { + LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF + << ") <= " + << "MinVF (" << MinVF << ")\n"); + } - // FIXME: Is division-by-2 the correct step? Should we assert that the - // register size is a power-of-2? 
- unsigned StartIdx = 0; - for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2) { - for (unsigned Cnt = StartIdx, E = Operands.size(); Cnt + Size <= E;) { - ArrayRef<Value *> Slice = ArrayRef(Operands).slice(Cnt, Size); - if (!VectorizedStores.count(Slice.front()) && - !VectorizedStores.count(Slice.back()) && - vectorizeStoreChain(Slice, R, Cnt, MinVF)) { - // Mark the vectorized stores so that we don't vectorize them again. - VectorizedStores.insert(Slice.begin(), Slice.end()); - Changed = true; - // If we vectorized initial block, no need to try to vectorize it - // again. - if (Cnt == StartIdx) - StartIdx += Size; - Cnt += Size; - continue; + // FIXME: Is division-by-2 the correct step? Should we assert that the + // register size is a power-of-2? + unsigned StartIdx = 0; + for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2) { + for (unsigned Cnt = StartIdx, E = Operands.size(); Cnt + Size <= E;) { + ArrayRef<Value *> Slice = ArrayRef(Operands).slice(Cnt, Size); + assert( + all_of( + Slice, + [&](Value *V) { + return cast<StoreInst>(V)->getValueOperand()->getType() == + cast<StoreInst>(Slice.front()) + ->getValueOperand() + ->getType(); + }) && + "Expected all operands of same type."); + if (!VectorizedStores.count(Slice.front()) && + !VectorizedStores.count(Slice.back()) && + TriedSequences.insert(std::make_pair(Slice.front(), Slice.back())) + .second && + vectorizeStoreChain(Slice, R, Cnt, MinVF)) { + // Mark the vectorized stores so that we don't vectorize them again. + VectorizedStores.insert(Slice.begin(), Slice.end()); + Changed = true; + // If we vectorized initial block, no need to try to vectorize it + // again. + if (Cnt == StartIdx) + StartIdx += Size; + Cnt += Size; + continue; + } + ++Cnt; } - ++Cnt; + // Check if the whole array was vectorized already - exit. + if (StartIdx >= Operands.size()) + break; } - // Check if the whole array was vectorized already - exit. 
- if (StartIdx >= Operands.size()) - break; + Operands.clear(); + Operands.push_back(Stores[Data.first]); + PrevDist = Data.second; } + }; + + // Stores pair (first: index of the store into Stores array ref, address of + // which taken as base, second: sorted set of pairs {index, dist}, which are + // indices of stores in the set and their store location distances relative to + // the base address). + + // Need to store the index of the very first store separately, since the set + // may be reordered after the insertion and the first store may be moved. This + // container allows to reduce number of calls of getPointersDiff() function. + SmallVector<std::pair<unsigned, StoreIndexToDistSet>> SortedStores; + // Inserts the specified store SI with the given index Idx to the set of the + // stores. If the store with the same distance is found already - stop + // insertion, try to vectorize already found stores. If some stores from this + // sequence were not vectorized - try to vectorize them with the new store + // later. But this logic is applied only to the stores, that come before the + // previous store with the same distance. + // Example: + // 1. store x, %p + // 2. store y, %p+1 + // 3. store z, %p+2 + // 4. store a, %p + // 5. store b, %p+3 + // - Scan this from the last to first store. The very first bunch of stores is + // {5, {{4, -3}, {2, -2}, {3, -1}, {5, 0}}} (the element in SortedStores + // vector). + // - The next store in the list - #1 - has the same distance from store #5 as + // the store #4. + // - Try to vectorize sequence of stores 4,2,3,5. + // - If all these stores are vectorized - just drop them. + // - If some of them are not vectorized (say, #3 and #5), do extra analysis. + // - Start new stores sequence. + // The new bunch of stores is {1, {1, 0}}. + // - Add the stores from previous sequence, that were not vectorized. 
+ // Here we consider the stores in the reversed order, rather they are used in + // the IR (Stores are reversed already, see vectorizeStoreChains() function). + // Store #3 can be added -> comes after store #4 with the same distance as + // store #1. + // Store #5 cannot be added - comes before store #4. + // This logic allows to improve the compile time, we assume that the stores + // after previous store with the same distance most likely have memory + // dependencies and no need to waste compile time to try to vectorize them. + // - Try to vectorize the sequence {1, {1, 0}, {3, 2}}. + auto FillStoresSet = [&](unsigned Idx, StoreInst *SI) { + for (std::pair<unsigned, StoreIndexToDistSet> &Set : SortedStores) { + std::optional<int> Diff = getPointersDiff( + Stores[Set.first]->getValueOperand()->getType(), + Stores[Set.first]->getPointerOperand(), + SI->getValueOperand()->getType(), SI->getPointerOperand(), *DL, *SE, + /*StrictCheck=*/true); + if (!Diff) + continue; + auto It = Set.second.find(std::make_pair(Idx, *Diff)); + if (It == Set.second.end()) { + Set.second.emplace(Idx, *Diff); + return; + } + // Try to vectorize the first found set to avoid duplicate analysis. + TryToVectorize(Set.second); + StoreIndexToDistSet PrevSet; + PrevSet.swap(Set.second); + Set.first = Idx; + Set.second.emplace(Idx, 0); + // Insert stores that followed previous match to try to vectorize them + // with this store. + unsigned StartIdx = It->first + 1; + SmallBitVector UsedStores(Idx - StartIdx); + // Distances to previously found dup store (or this store, since they + // store to the same addresses). + SmallVector<int> Dists(Idx - StartIdx, 0); + for (const std::pair<unsigned, int> &Pair : reverse(PrevSet)) { + // Do not try to vectorize sequences, we already tried. 
+ if (Pair.first <= It->first || + VectorizedStores.contains(Stores[Pair.first])) + break; + unsigned BI = Pair.first - StartIdx; + UsedStores.set(BI); + Dists[BI] = Pair.second - It->second; + } + for (unsigned I = StartIdx; I < Idx; ++I) { + unsigned BI = I - StartIdx; + if (UsedStores.test(BI)) + Set.second.emplace(I, Dists[BI]); + } + return; + } + auto &Res = SortedStores.emplace_back(); + Res.first = Idx; + Res.second.emplace(Idx, 0); + }; + StoreInst *PrevStore = Stores.front(); + for (auto [I, SI] : enumerate(Stores)) { + // Check that we do not try to vectorize stores of different types. + if (PrevStore->getValueOperand()->getType() != + SI->getValueOperand()->getType()) { + for (auto &Set : SortedStores) + TryToVectorize(Set.second); + SortedStores.clear(); + PrevStore = SI; + } + FillStoresSet(I, SI); } + // Final vectorization attempt. + for (auto &Set : SortedStores) + TryToVectorize(Set.second); + return Changed; } @@ -12506,7 +13680,7 @@ void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) { // constant index, or a pointer operand that doesn't point to a scalar // type. else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { - auto Idx = GEP->idx_begin()->get(); + Value *Idx = GEP->idx_begin()->get(); if (GEP->getNumIndices() > 1 || isa<Constant>(Idx)) continue; if (!isValidElementType(Idx->getType())) @@ -12541,8 +13715,8 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, // NOTE: the following will give user internal llvm type name, which may // not be useful. 
R.getORE()->emit([&]() { - std::string type_str; - llvm::raw_string_ostream rso(type_str); + std::string TypeStr; + llvm::raw_string_ostream rso(TypeStr); Ty->print(rso); return OptimizationRemarkMissed(SV_NAME, "UnsupportedType", I0) << "Cannot SLP vectorize list: type " @@ -12877,10 +14051,12 @@ class HorizontalReduction { static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, Value *RHS, const Twine &Name, const ReductionOpsListType &ReductionOps) { - bool UseSelect = ReductionOps.size() == 2 || - // Logical or/and. - (ReductionOps.size() == 1 && - isa<SelectInst>(ReductionOps.front().front())); + bool UseSelect = + ReductionOps.size() == 2 || + // Logical or/and. + (ReductionOps.size() == 1 && any_of(ReductionOps.front(), [](Value *V) { + return isa<SelectInst>(V); + })); assert((!UseSelect || ReductionOps.size() != 2 || isa<SelectInst>(ReductionOps[1][0])) && "Expected cmp + select pairs for reduction"); @@ -13314,12 +14490,26 @@ public: // Update the final value in the reduction. Builder.SetCurrentDebugLocation( cast<Instruction>(ReductionOps.front().front())->getDebugLoc()); + if ((isa<PoisonValue>(VectorizedTree) && !isa<PoisonValue>(Res)) || + (isGuaranteedNotToBePoison(Res) && + !isGuaranteedNotToBePoison(VectorizedTree))) { + auto It = ReducedValsToOps.find(Res); + if (It != ReducedValsToOps.end() && + any_of(It->getSecond(), + [](Instruction *I) { return isBoolLogicOp(I); })) + std::swap(VectorizedTree, Res); + } + return createOp(Builder, RdxKind, VectorizedTree, Res, "op.rdx", ReductionOps); } // Initialize the final value in the reduction. return Res; }; + bool AnyBoolLogicOp = + any_of(ReductionOps.back(), [](Value *V) { + return isBoolLogicOp(cast<Instruction>(V)); + }); // The reduction root is used as the insertion point for new instructions, // so set it as externally used to prevent it from being deleted. 
ExternallyUsedValues[ReductionRoot]; @@ -13363,10 +14553,12 @@ public: // Check if the reduction value was not overriden by the extractelement // instruction because of the vectorization and exclude it, if it is not // compatible with other values. - if (auto *Inst = dyn_cast<Instruction>(RdxVal)) - if (isVectorLikeInstWithConstOps(Inst) && - (!S.getOpcode() || !S.isOpcodeOrAlt(Inst))) - continue; + // Also check if the instruction was folded to constant/other value. + auto *Inst = dyn_cast<Instruction>(RdxVal); + if ((Inst && isVectorLikeInstWithConstOps(Inst) && + (!S.getOpcode() || !S.isOpcodeOrAlt(Inst))) || + (S.getOpcode() && !Inst)) + continue; Candidates.push_back(RdxVal); TrackedToOrig.try_emplace(RdxVal, OrigReducedVals[Cnt]); } @@ -13542,11 +14734,9 @@ public: for (unsigned Cnt = 0, Sz = ReducedVals.size(); Cnt < Sz; ++Cnt) { if (Cnt == I || (ShuffledExtracts && Cnt == I - 1)) continue; - for_each(ReducedVals[Cnt], - [&LocalExternallyUsedValues, &TrackedVals](Value *V) { - if (isa<Instruction>(V)) - LocalExternallyUsedValues[TrackedVals[V]]; - }); + for (Value *V : ReducedVals[Cnt]) + if (isa<Instruction>(V)) + LocalExternallyUsedValues[TrackedVals[V]]; } if (!IsSupportedHorRdxIdentityOp) { // Number of uses of the candidates in the vector of values. @@ -13590,7 +14780,7 @@ public: // Update LocalExternallyUsedValues for the scalar, replaced by // extractelement instructions. 
for (const std::pair<Value *, Value *> &Pair : ReplacedExternals) { - auto It = ExternallyUsedValues.find(Pair.first); + auto *It = ExternallyUsedValues.find(Pair.first); if (It == ExternallyUsedValues.end()) continue; LocalExternallyUsedValues[Pair.second].append(It->second); @@ -13604,7 +14794,8 @@ public: InstructionCost ReductionCost = getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF); InstructionCost Cost = TreeCost + ReductionCost; - LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n"); + LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost + << " for reduction\n"); if (!Cost.isValid()) return nullptr; if (Cost >= -SLPCostThreshold) { @@ -13651,7 +14842,9 @@ public: // To prevent poison from leaking across what used to be sequential, // safe, scalar boolean logic operations, the reduction operand must be // frozen. - if (isBoolLogicOp(RdxRootInst)) + if ((isBoolLogicOp(RdxRootInst) || + (AnyBoolLogicOp && VL.size() != TrackedVals.size())) && + !isGuaranteedNotToBePoison(VectorizedRoot)) VectorizedRoot = Builder.CreateFreeze(VectorizedRoot); // Emit code to correctly handle reused reduced values, if required. @@ -13663,6 +14856,16 @@ public: Value *ReducedSubTree = emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI); + if (ReducedSubTree->getType() != VL.front()->getType()) { + ReducedSubTree = Builder.CreateIntCast( + ReducedSubTree, VL.front()->getType(), any_of(VL, [&](Value *R) { + KnownBits Known = computeKnownBits( + R, cast<Instruction>(ReductionOps.front().front()) + ->getModule() + ->getDataLayout()); + return !Known.isNonNegative(); + })); + } // Improved analysis for add/fadd/xor reductions with same scale factor // for all operands of reductions. We can emit scalar ops for them @@ -13715,31 +14918,33 @@ public: // RedOp2 = select i1 ?, i1 RHS, i1 false // Then, we must freeze LHS in the new op. 
- auto &&FixBoolLogicalOps = - [&Builder, VectorizedTree](Value *&LHS, Value *&RHS, - Instruction *RedOp1, Instruction *RedOp2) { - if (!isBoolLogicOp(RedOp1)) - return; - if (LHS == VectorizedTree || getRdxOperand(RedOp1, 0) == LHS || - isGuaranteedNotToBePoison(LHS)) - return; - if (!isBoolLogicOp(RedOp2)) - return; - if (RHS == VectorizedTree || getRdxOperand(RedOp2, 0) == RHS || - isGuaranteedNotToBePoison(RHS)) { - std::swap(LHS, RHS); - return; - } - LHS = Builder.CreateFreeze(LHS); - }; + auto FixBoolLogicalOps = [&, VectorizedTree](Value *&LHS, Value *&RHS, + Instruction *RedOp1, + Instruction *RedOp2, + bool InitStep) { + if (!AnyBoolLogicOp) + return; + if (isBoolLogicOp(RedOp1) && + ((!InitStep && LHS == VectorizedTree) || + getRdxOperand(RedOp1, 0) == LHS || isGuaranteedNotToBePoison(LHS))) + return; + if (isBoolLogicOp(RedOp2) && ((!InitStep && RHS == VectorizedTree) || + getRdxOperand(RedOp2, 0) == RHS || + isGuaranteedNotToBePoison(RHS))) { + std::swap(LHS, RHS); + return; + } + if (LHS != VectorizedTree) + LHS = Builder.CreateFreeze(LHS); + }; // Finish the reduction. // Need to add extra arguments and not vectorized possible reduction // values. // Try to avoid dependencies between the scalar remainders after // reductions. - auto &&FinalGen = - [this, &Builder, &TrackedVals, &FixBoolLogicalOps]( - ArrayRef<std::pair<Instruction *, Value *>> InstVals) { + auto FinalGen = + [&](ArrayRef<std::pair<Instruction *, Value *>> InstVals, + bool InitStep) { unsigned Sz = InstVals.size(); SmallVector<std::pair<Instruction *, Value *>> ExtraReds(Sz / 2 + Sz % 2); @@ -13760,7 +14965,7 @@ public: // sequential, safe, scalar boolean logic operations, the // reduction operand must be frozen. 
FixBoolLogicalOps(StableRdxVal1, StableRdxVal2, InstVals[I].first, - RedOp); + RedOp, InitStep); Value *ExtraRed = createOp(Builder, RdxKind, StableRdxVal1, StableRdxVal2, "op.rdx", ReductionOps); ExtraReds[I / 2] = std::make_pair(InstVals[I].first, ExtraRed); @@ -13790,11 +14995,13 @@ public: ExtraReductions.emplace_back(I, Pair.first); } // Iterate through all not-vectorized reduction values/extra arguments. + bool InitStep = true; while (ExtraReductions.size() > 1) { VectorizedTree = ExtraReductions.front().second; SmallVector<std::pair<Instruction *, Value *>> NewReds = - FinalGen(ExtraReductions); + FinalGen(ExtraReductions, InitStep); ExtraReductions.swap(NewReds); + InitStep = false; } VectorizedTree = ExtraReductions.front().second; @@ -13841,8 +15048,7 @@ private: bool IsCmpSelMinMax, unsigned ReduxWidth, FastMathFlags FMF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - Value *FirstReducedVal = ReducedVals.front(); - Type *ScalarTy = FirstReducedVal->getType(); + Type *ScalarTy = ReducedVals.front()->getType(); FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth); InstructionCost VectorCost = 0, ScalarCost; // If all of the reduced values are constant, the vector cost is 0, since @@ -13916,7 +15122,7 @@ private: } LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost - << " for reduction that starts with " << *FirstReducedVal + << " for reduction of " << shortBundleName(ReducedVals) << " (It is a splitting reduction)\n"); return VectorCost - ScalarCost; } @@ -13931,7 +15137,7 @@ private: "A call to the llvm.fmuladd intrinsic is not handled yet"); ++NumVectorInstructions; - return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind); + return createSimpleTargetReduction(Builder, VectorizedValue, RdxKind); } /// Emits optimized code for unique scalar value reused \p Cnt times. 
@@ -13978,8 +15184,8 @@ private: case RecurKind::Mul: case RecurKind::FMul: case RecurKind::FMulAdd: - case RecurKind::SelectICmp: - case RecurKind::SelectFCmp: + case RecurKind::IAnyOf: + case RecurKind::FAnyOf: case RecurKind::None: llvm_unreachable("Unexpected reduction kind for repeated scalar."); } @@ -14067,8 +15273,8 @@ private: case RecurKind::Mul: case RecurKind::FMul: case RecurKind::FMulAdd: - case RecurKind::SelectICmp: - case RecurKind::SelectFCmp: + case RecurKind::IAnyOf: + case RecurKind::FAnyOf: case RecurKind::None: llvm_unreachable("Unexpected reduction kind for reused scalars."); } @@ -14163,8 +15369,8 @@ static bool findBuildAggregate(Instruction *LastInsertInst, InsertElts.resize(*AggregateSize); findBuildAggregate_rec(LastInsertInst, TTI, BuildVectorOpds, InsertElts, 0); - llvm::erase_value(BuildVectorOpds, nullptr); - llvm::erase_value(InsertElts, nullptr); + llvm::erase(BuildVectorOpds, nullptr); + llvm::erase(InsertElts, nullptr); if (BuildVectorOpds.size() >= 2) return true; @@ -14400,8 +15606,7 @@ bool SLPVectorizerPass::tryToVectorize(ArrayRef<WeakTrackingVH> Insts, bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, BasicBlock *BB, BoUpSLP &R) { - const DataLayout &DL = BB->getModule()->getDataLayout(); - if (!R.canMapToVector(IVI->getType(), DL)) + if (!R.canMapToVector(IVI->getType())) return false; SmallVector<Value *, 16> BuildVectorOpds; @@ -14540,11 +15745,11 @@ static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI, if (BasePred1 > BasePred2) return false; // Compare operands. - bool LEPreds = Pred1 <= Pred2; - bool GEPreds = Pred1 >= Pred2; + bool CI1Preds = Pred1 == BasePred1; + bool CI2Preds = Pred2 == BasePred1; for (int I = 0, E = CI1->getNumOperands(); I < E; ++I) { - auto *Op1 = CI1->getOperand(LEPreds ? I : E - I - 1); - auto *Op2 = CI2->getOperand(GEPreds ? I : E - I - 1); + auto *Op1 = CI1->getOperand(CI1Preds ? I : E - I - 1); + auto *Op2 = CI2->getOperand(CI2Preds ? 
I : E - I - 1); if (Op1->getValueID() < Op2->getValueID()) return !IsCompatibility; if (Op1->getValueID() > Op2->getValueID()) @@ -14690,14 +15895,20 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { return true; if (Opcodes1.size() > Opcodes2.size()) return false; - std::optional<bool> ConstOrder; for (int I = 0, E = Opcodes1.size(); I < E; ++I) { // Undefs are compatible with any other value. if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) { - if (!ConstOrder) - ConstOrder = - !isa<UndefValue>(Opcodes1[I]) && isa<UndefValue>(Opcodes2[I]); - continue; + if (isa<Instruction>(Opcodes1[I])) + return true; + if (isa<Instruction>(Opcodes2[I])) + return false; + if (isa<Constant>(Opcodes1[I]) && !isa<UndefValue>(Opcodes1[I])) + return true; + if (isa<Constant>(Opcodes2[I]) && !isa<UndefValue>(Opcodes2[I])) + return false; + if (isa<UndefValue>(Opcodes1[I]) && isa<UndefValue>(Opcodes2[I])) + continue; + return isa<UndefValue>(Opcodes2[I]); } if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I])) if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { @@ -14713,21 +15924,26 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (NodeI1 != NodeI2) return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); InstructionsState S = getSameOpcode({I1, I2}, *TLI); - if (S.getOpcode()) + if (S.getOpcode() && !S.isAltShuffle()) continue; return I1->getOpcode() < I2->getOpcode(); } - if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) { - if (!ConstOrder) - ConstOrder = Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID(); - continue; - } + if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) + return Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID(); + if (isa<Instruction>(Opcodes1[I])) + return true; + if (isa<Instruction>(Opcodes2[I])) + return false; + if (isa<Constant>(Opcodes1[I])) + return true; + if (isa<Constant>(Opcodes2[I])) + return false; if (Opcodes1[I]->getValueID() < 
Opcodes2[I]->getValueID()) return true; if (Opcodes1[I]->getValueID() > Opcodes2[I]->getValueID()) return false; } - return ConstOrder && *ConstOrder; + return false; }; auto AreCompatiblePHIs = [&PHIToOpcodes, this](Value *V1, Value *V2) { if (V1 == V2) @@ -14775,6 +15991,9 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { Incoming.push_back(P); } + if (Incoming.size() <= 1) + break; + // Find the corresponding non-phi nodes for better matching when trying to // build the tree. for (Value *V : Incoming) { @@ -14837,41 +16056,41 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { return I->use_empty() && (I->getType()->isVoidTy() || isa<CallInst, InvokeInst>(I)); }; - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + for (BasicBlock::iterator It = BB->begin(), E = BB->end(); It != E; ++It) { // Skip instructions with scalable type. The num of elements is unknown at // compile-time for scalable type. - if (isa<ScalableVectorType>(it->getType())) + if (isa<ScalableVectorType>(It->getType())) continue; // Skip instructions marked for the deletion. - if (R.isDeleted(&*it)) + if (R.isDeleted(&*It)) continue; // We may go through BB multiple times so skip the one we have checked. - if (!VisitedInstrs.insert(&*it).second) { - if (HasNoUsers(&*it) && - VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator())) { + if (!VisitedInstrs.insert(&*It).second) { + if (HasNoUsers(&*It) && + VectorizeInsertsAndCmps(/*VectorizeCmps=*/It->isTerminator())) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. Changed = true; - it = BB->begin(); - e = BB->end(); + It = BB->begin(); + E = BB->end(); } continue; } - if (isa<DbgInfoIntrinsic>(it)) + if (isa<DbgInfoIntrinsic>(It)) continue; // Try to vectorize reductions that use PHINodes. 
- if (PHINode *P = dyn_cast<PHINode>(it)) { + if (PHINode *P = dyn_cast<PHINode>(It)) { // Check that the PHI is a reduction PHI. if (P->getNumIncomingValues() == 2) { // Try to match and vectorize a horizontal reduction. Instruction *Root = getReductionInstr(DT, P, BB, LI); if (Root && vectorizeRootInstruction(P, Root, BB, R, TTI)) { Changed = true; - it = BB->begin(); - e = BB->end(); + It = BB->begin(); + E = BB->end(); continue; } } @@ -14896,23 +16115,23 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { continue; } - if (HasNoUsers(&*it)) { + if (HasNoUsers(&*It)) { bool OpsChanged = false; - auto *SI = dyn_cast<StoreInst>(it); + auto *SI = dyn_cast<StoreInst>(It); bool TryToVectorizeRoot = ShouldStartVectorizeHorAtStore || !SI; if (SI) { - auto I = Stores.find(getUnderlyingObject(SI->getPointerOperand())); + auto *I = Stores.find(getUnderlyingObject(SI->getPointerOperand())); // Try to vectorize chain in store, if this is the only store to the // address in the block. // TODO: This is just a temporarily solution to save compile time. Need // to investigate if we can safely turn on slp-vectorize-hor-store // instead to allow lookup for reduction chains in all non-vectorized // stores (need to check side effects and compile time). - TryToVectorizeRoot = (I == Stores.end() || I->second.size() == 1) && - SI->getValueOperand()->hasOneUse(); + TryToVectorizeRoot |= (I == Stores.end() || I->second.size() == 1) && + SI->getValueOperand()->hasOneUse(); } if (TryToVectorizeRoot) { - for (auto *V : it->operand_values()) { + for (auto *V : It->operand_values()) { // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *VI = dyn_cast<Instruction>(V); @@ -14925,21 +16144,21 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // top-tree instructions to try to vectorize as many instructions as // possible. 
OpsChanged |= - VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator()); + VectorizeInsertsAndCmps(/*VectorizeCmps=*/It->isTerminator()); if (OpsChanged) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. Changed = true; - it = BB->begin(); - e = BB->end(); + It = BB->begin(); + E = BB->end(); continue; } } - if (isa<InsertElementInst, InsertValueInst>(it)) - PostProcessInserts.insert(&*it); - else if (isa<CmpInst>(it)) - PostProcessCmps.insert(cast<CmpInst>(&*it)); + if (isa<InsertElementInst, InsertValueInst>(It)) + PostProcessInserts.insert(&*It); + else if (isa<CmpInst>(It)) + PostProcessCmps.insert(cast<CmpInst>(&*It)); } return Changed; @@ -15043,6 +16262,12 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { // compatible (have the same opcode, same parent), otherwise it is // definitely not profitable to try to vectorize them. auto &&StoreSorter = [this](StoreInst *V, StoreInst *V2) { + if (V->getValueOperand()->getType()->getTypeID() < + V2->getValueOperand()->getType()->getTypeID()) + return true; + if (V->getValueOperand()->getType()->getTypeID() > + V2->getValueOperand()->getType()->getTypeID()) + return false; if (V->getPointerOperandType()->getTypeID() < V2->getPointerOperandType()->getTypeID()) return true; @@ -15081,6 +16306,8 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { auto &&AreCompatibleStores = [this](StoreInst *V1, StoreInst *V2) { if (V1 == V2) return true; + if (V1->getValueOperand()->getType() != V2->getValueOperand()->getType()) + return false; if (V1->getPointerOperandType() != V2->getPointerOperandType()) return false; // Undefs are compatible with any other value. @@ -15112,8 +16339,13 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { if (!isValidElementType(Pair.second.front()->getValueOperand()->getType())) continue; + // Reverse stores to do bottom-to-top analysis. 
This is important if the + // values are stores to the same addresses several times, in this case need + // to follow the stores order (reversed to meet the memory dependecies). + SmallVector<StoreInst *> ReversedStores(Pair.second.rbegin(), + Pair.second.rend()); Changed |= tryToVectorizeSequence<StoreInst>( - Pair.second, StoreSorter, AreCompatibleStores, + ReversedStores, StoreSorter, AreCompatibleStores, [this, &R](ArrayRef<StoreInst *> Candidates, bool) { return vectorizeStores(Candidates, R); }, diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h index 1271d1424c03..7ff6749a0908 100644 --- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -133,9 +133,12 @@ public: Ingredient2Recipe[I] = R; } + /// Create the mask for the vector loop header block. + void createHeaderMask(VPlan &Plan); + /// A helper function that computes the predicate of the block BB, assuming - /// that the header block of the loop is set to True. It returns the *entry* - /// mask for the block BB. + /// that the header block of the loop is set to True or the loop mask when + /// tail folding. It returns the *entry* mask for the block BB. 
VPValue *createBlockInMask(BasicBlock *BB, VPlan &Plan); /// A helper function that computes the predicate of the edge between SRC diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index e81b88fd8099..263d9938d1f0 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -19,7 +19,6 @@ #include "VPlan.h" #include "VPlanCFG.h" #include "VPlanDominatorTree.h" -#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -234,6 +233,99 @@ Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) { // set(Def, Extract, Instance); return Extract; } + +Value *VPTransformState::get(VPValue *Def, unsigned Part) { + // If Values have been set for this Def return the one relevant for \p Part. + if (hasVectorValue(Def, Part)) + return Data.PerPartOutput[Def][Part]; + + auto GetBroadcastInstrs = [this, Def](Value *V) { + bool SafeToHoist = Def->isDefinedOutsideVectorRegions(); + if (VF.isScalar()) + return V; + // Place the code for broadcasting invariant variables in the new preheader. + IRBuilder<>::InsertPointGuard Guard(Builder); + if (SafeToHoist) { + BasicBlock *LoopVectorPreHeader = CFG.VPBB2IRBB[cast<VPBasicBlock>( + Plan->getVectorLoopRegion()->getSinglePredecessor())]; + if (LoopVectorPreHeader) + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + } + + // Place the code for broadcasting invariant variables in the new preheader. + // Broadcast the scalar into all locations in the vector. 
+ Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast"); + + return Shuf; + }; + + if (!hasScalarValue(Def, {Part, 0})) { + assert(Def->isLiveIn() && "expected a live-in"); + if (Part != 0) + return get(Def, 0); + Value *IRV = Def->getLiveInIRValue(); + Value *B = GetBroadcastInstrs(IRV); + set(Def, B, Part); + return B; + } + + Value *ScalarValue = get(Def, {Part, 0}); + // If we aren't vectorizing, we can just copy the scalar map values over + // to the vector map. + if (VF.isScalar()) { + set(Def, ScalarValue, Part); + return ScalarValue; + } + + bool IsUniform = vputils::isUniformAfterVectorization(Def); + + unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1; + // Check if there is a scalar value for the selected lane. + if (!hasScalarValue(Def, {Part, LastLane})) { + // At the moment, VPWidenIntOrFpInductionRecipes, VPScalarIVStepsRecipes and + // VPExpandSCEVRecipes can also be uniform. + assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) || + isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe()) || + isa<VPExpandSCEVRecipe>(Def->getDefiningRecipe())) && + "unexpected recipe found to be invariant"); + IsUniform = true; + LastLane = 0; + } + + auto *LastInst = cast<Instruction>(get(Def, {Part, LastLane})); + // Set the insert point after the last scalarized instruction or after the + // last PHI, if LastInst is a PHI. This ensures the insertelement sequence + // will directly follow the scalar definitions. + auto OldIP = Builder.saveIP(); + auto NewIP = + isa<PHINode>(LastInst) + ? BasicBlock::iterator(LastInst->getParent()->getFirstNonPHI()) + : std::next(BasicBlock::iterator(LastInst)); + Builder.SetInsertPoint(&*NewIP); + + // However, if we are vectorizing, we need to construct the vector values. + // If the value is known to be uniform after vectorization, we can just + // broadcast the scalar value corresponding to lane zero for each unroll + // iteration. 
Otherwise, we construct the vector values using + // insertelement instructions. Since the resulting vectors are stored in + // State, we will only generate the insertelements once. + Value *VectorValue = nullptr; + if (IsUniform) { + VectorValue = GetBroadcastInstrs(ScalarValue); + set(Def, VectorValue, Part); + } else { + // Initialize packing with insertelements to start from undef. + assert(!VF.isScalable() && "VF is assumed to be non scalable."); + Value *Undef = PoisonValue::get(VectorType::get(LastInst->getType(), VF)); + set(Def, Undef, Part); + for (unsigned Lane = 0; Lane < VF.getKnownMinValue(); ++Lane) + packScalarIntoVectorValue(Def, {Part, Lane}); + VectorValue = get(Def, Part); + } + Builder.restoreIP(OldIP); + return VectorValue; +} + BasicBlock *VPTransformState::CFGState::getPreheaderBBFor(VPRecipeBase *R) { VPRegionBlock *LoopRegion = R->getParent()->getEnclosingLoopRegion(); return VPBB2IRBB[LoopRegion->getPreheaderVPBB()]; @@ -267,18 +359,15 @@ void VPTransformState::addMetadata(ArrayRef<Value *> To, Instruction *From) { } } -void VPTransformState::setDebugLocFromInst(const Value *V) { - const Instruction *Inst = dyn_cast<Instruction>(V); - if (!Inst) { - Builder.SetCurrentDebugLocation(DebugLoc()); - return; - } - - const DILocation *DIL = Inst->getDebugLoc(); +void VPTransformState::setDebugLocFrom(DebugLoc DL) { + const DILocation *DIL = DL; // When a FSDiscriminator is enabled, we don't need to add the multiply // factors to the discriminators. - if (DIL && Inst->getFunction()->shouldEmitDebugInfoForProfiling() && - !Inst->isDebugOrPseudoInst() && !EnableFSDiscriminator) { + if (DIL && + Builder.GetInsertBlock() + ->getParent() + ->shouldEmitDebugInfoForProfiling() && + !EnableFSDiscriminator) { // FIXME: For scalable vectors, assume vscale=1. 
auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(UF * VF.getKnownMinValue()); @@ -291,6 +380,15 @@ void VPTransformState::setDebugLocFromInst(const Value *V) { Builder.SetCurrentDebugLocation(DIL); } +void VPTransformState::packScalarIntoVectorValue(VPValue *Def, + const VPIteration &Instance) { + Value *ScalarInst = get(Def, Instance); + Value *VectorValue = get(Def, Instance.Part); + VectorValue = Builder.CreateInsertElement( + VectorValue, ScalarInst, Instance.Lane.getAsRuntimeExpr(Builder, VF)); + set(Def, VectorValue, Instance.Part); +} + BasicBlock * VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { // BB stands for IR BasicBlocks. VPBB stands for VPlan VPBasicBlocks. @@ -616,22 +714,17 @@ VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE) { auto Plan = std::make_unique<VPlan>(Preheader, VecPreheader); Plan->TripCount = vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE); + // Create empty VPRegionBlock, to be filled during processing later. + auto *TopRegion = new VPRegionBlock("vector loop", false /*isReplicator*/); + VPBlockUtils::insertBlockAfter(TopRegion, VecPreheader); + VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block"); + VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion); return Plan; } -VPActiveLaneMaskPHIRecipe *VPlan::getActiveLaneMaskPhi() { - VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock(); - for (VPRecipeBase &R : Header->phis()) { - if (isa<VPActiveLaneMaskPHIRecipe>(&R)) - return cast<VPActiveLaneMaskPHIRecipe>(&R); - } - return nullptr; -} - void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, Value *CanonicalIVStartValue, - VPTransformState &State, - bool IsEpilogueVectorization) { + VPTransformState &State) { // Check if the backedge taken count is needed, and if so build it. 
if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) { IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); @@ -648,6 +741,12 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) State.set(&VectorTripCount, VectorTripCountV, Part); + IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); + // FIXME: Model VF * UF computation completely in VPlan. + State.set(&VFxUF, + createStepForVF(Builder, TripCountV->getType(), State.VF, State.UF), + 0); + // When vectorizing the epilogue loop, the canonical induction start value // needs to be changed from zero to the value after the main vector loop. // FIXME: Improve modeling for canonical IV start values in the epilogue loop. @@ -656,16 +755,12 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, auto *IV = getCanonicalIV(); assert(all_of(IV->users(), [](const VPUser *U) { - if (isa<VPScalarIVStepsRecipe>(U) || - isa<VPDerivedIVRecipe>(U)) - return true; - auto *VPI = cast<VPInstruction>(U); - return VPI->getOpcode() == - VPInstruction::CanonicalIVIncrement || - VPI->getOpcode() == - VPInstruction::CanonicalIVIncrementNUW; + return isa<VPScalarIVStepsRecipe>(U) || + isa<VPDerivedIVRecipe>(U) || + cast<VPInstruction>(U)->getOpcode() == + Instruction::Add; }) && - "the canonical IV should only be used by its increments or " + "the canonical IV should only be used by its increment or " "ScalarIVSteps when resetting the start value"); IV->setOperand(0, VPV); } @@ -754,11 +849,14 @@ void VPlan::execute(VPTransformState *State) { } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -LLVM_DUMP_METHOD -void VPlan::print(raw_ostream &O) const { +void VPlan::printLiveIns(raw_ostream &O) const { VPSlotTracker SlotTracker(this); - O << "VPlan '" << getName() << "' {"; + if (VFxUF.getNumUsers() > 0) { + O << "\nLive-in "; + VFxUF.printAsOperand(O, SlotTracker); + O << " = VF * UF"; + } if (VectorTripCount.getNumUsers() > 0) 
{ O << "\nLive-in "; @@ -778,6 +876,15 @@ void VPlan::print(raw_ostream &O) const { TripCount->printAsOperand(O, SlotTracker); O << " = original trip-count"; O << "\n"; +} + +LLVM_DUMP_METHOD +void VPlan::print(raw_ostream &O) const { + VPSlotTracker SlotTracker(this); + + O << "VPlan '" << getName() << "' {"; + + printLiveIns(O); if (!getPreheader()->empty()) { O << "\n"; @@ -895,11 +1002,18 @@ void VPlanPrinter::dump() { OS << "graph [labelloc=t, fontsize=30; label=\"Vectorization Plan"; if (!Plan.getName().empty()) OS << "\\n" << DOT::EscapeString(Plan.getName()); - if (Plan.BackedgeTakenCount) { - OS << ", where:\\n"; - Plan.BackedgeTakenCount->print(OS, SlotTracker); - OS << " := BackedgeTakenCount"; + + { + // Print live-ins. + std::string Str; + raw_string_ostream SS(Str); + Plan.printLiveIns(SS); + SmallVector<StringRef, 0> Lines; + StringRef(Str).rtrim('\n').split(Lines, "\n"); + for (auto Line : Lines) + OS << DOT::EscapeString(Line.str()) << "\\n"; } + OS << "\"]\n"; OS << "node [shape=rect, fontname=Courier, fontsize=30]\n"; OS << "edge [fontname=Courier, fontsize=30]\n"; @@ -1035,6 +1149,26 @@ void VPValue::replaceAllUsesWith(VPValue *New) { } } +void VPValue::replaceUsesWithIf( + VPValue *New, + llvm::function_ref<bool(VPUser &U, unsigned Idx)> ShouldReplace) { + for (unsigned J = 0; J < getNumUsers();) { + VPUser *User = Users[J]; + unsigned NumUsers = getNumUsers(); + for (unsigned I = 0, E = User->getNumOperands(); I < E; ++I) { + if (User->getOperand(I) != this || !ShouldReplace(*User, I)) + continue; + + User->setOperand(I, New); + } + // If a user got removed after updating the current user, the next user to + // update will be moved to the current position, so we only need to + // increment the index if the number of users did not change. 
+ if (NumUsers == getNumUsers()) + J++; + } +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPValue::printAsOperand(raw_ostream &OS, VPSlotTracker &Tracker) const { if (const Value *UV = getUnderlyingValue()) { @@ -1116,6 +1250,8 @@ void VPSlotTracker::assignSlot(const VPValue *V) { } void VPSlotTracker::assignSlots(const VPlan &Plan) { + if (Plan.VFxUF.getNumUsers() > 0) + assignSlot(&Plan.VFxUF); assignSlot(&Plan.VectorTripCount); if (Plan.BackedgeTakenCount) assignSlot(Plan.BackedgeTakenCount); @@ -1139,6 +1275,11 @@ bool vputils::onlyFirstLaneUsed(VPValue *Def) { [Def](VPUser *U) { return U->onlyFirstLaneUsed(Def); }); } +bool vputils::onlyFirstPartUsed(VPValue *Def) { + return all_of(Def->users(), + [Def](VPUser *U) { return U->onlyFirstPartUsed(Def); }); +} + VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, ScalarEvolution &SE) { if (auto *Expanded = Plan.getSCEVExpansion(Expr)) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 73313465adea..94cb76889813 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -23,6 +23,7 @@ #ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_H #define LLVM_TRANSFORMS_VECTORIZE_VPLAN_H +#include "VPlanAnalysis.h" #include "VPlanValue.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" @@ -233,9 +234,9 @@ struct VPIteration { struct VPTransformState { VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder, - InnerLoopVectorizer *ILV, VPlan *Plan) + InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx) : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan), - LVer(nullptr) {} + LVer(nullptr), TypeAnalysis(Ctx) {} /// The chosen Vectorization and Unroll Factors of the loop being vectorized. 
ElementCount VF; @@ -274,10 +275,6 @@ struct VPTransformState { I->second[Part]; } - bool hasAnyVectorValue(VPValue *Def) const { - return Data.PerPartOutput.contains(Def); - } - bool hasScalarValue(VPValue *Def, VPIteration Instance) { auto I = Data.PerPartScalars.find(Def); if (I == Data.PerPartScalars.end()) @@ -349,8 +346,11 @@ struct VPTransformState { /// vector of instructions. void addMetadata(ArrayRef<Value *> To, Instruction *From); - /// Set the debug location in the builder using the debug location in \p V. - void setDebugLocFromInst(const Value *V); + /// Set the debug location in the builder using the debug location \p DL. + void setDebugLocFrom(DebugLoc DL); + + /// Construct the vector value of a scalarized value \p V one lane at a time. + void packScalarIntoVectorValue(VPValue *Def, const VPIteration &Instance); /// Hold state information used when constructing the CFG of the output IR, /// traversing the VPBasicBlocks and generating corresponding IR BasicBlocks. @@ -410,6 +410,9 @@ struct VPTransformState { /// Map SCEVs to their expanded values. Populated when executing /// VPExpandSCEVRecipes. DenseMap<const SCEV *, Value *> ExpandedSCEVs; + + /// VPlan-based type analysis. + VPTypeAnalysis TypeAnalysis; }; /// VPBlockBase is the building block of the Hierarchical Control-Flow Graph. @@ -582,6 +585,8 @@ public: /// This VPBlockBase must have no successors. void setOneSuccessor(VPBlockBase *Successor) { assert(Successors.empty() && "Setting one successor when others exist."); + assert(Successor->getParent() == getParent() && + "connected blocks must have the same parent"); appendSuccessor(Successor); } @@ -693,7 +698,7 @@ public: }; /// VPRecipeBase is a base class modeling a sequence of one or more output IR -/// instructions. VPRecipeBase owns the the VPValues it defines through VPDef +/// instructions. VPRecipeBase owns the VPValues it defines through VPDef /// and is responsible for deleting its defined values. 
Single-value /// VPRecipeBases that also inherit from VPValue must make sure to inherit from /// VPRecipeBase before VPValue. @@ -706,13 +711,18 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>, /// Each VPRecipe belongs to a single VPBasicBlock. VPBasicBlock *Parent = nullptr; + /// The debug location for the recipe. + DebugLoc DL; + public: - VPRecipeBase(const unsigned char SC, ArrayRef<VPValue *> Operands) - : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe) {} + VPRecipeBase(const unsigned char SC, ArrayRef<VPValue *> Operands, + DebugLoc DL = {}) + : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe), DL(DL) {} template <typename IterT> - VPRecipeBase(const unsigned char SC, iterator_range<IterT> Operands) - : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe) {} + VPRecipeBase(const unsigned char SC, iterator_range<IterT> Operands, + DebugLoc DL = {}) + : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe), DL(DL) {} virtual ~VPRecipeBase() = default; /// \return the VPBasicBlock which this VPRecipe belongs to. @@ -789,6 +799,9 @@ public: bool mayReadOrWriteMemory() const { return mayReadFromMemory() || mayWriteToMemory(); } + + /// Returns the debug location of the recipe. + DebugLoc getDebugLoc() const { return DL; } }; // Helper macro to define common classof implementations for recipes. @@ -808,153 +821,30 @@ public: return R->getVPDefID() == VPDefID; \ } -/// This is a concrete Recipe that models a single VPlan-level instruction. -/// While as any Recipe it may generate a sequence of IR instructions when -/// executed, these instructions would always form a single-def expression as -/// the VPInstruction is also a single def-use vertex. -class VPInstruction : public VPRecipeBase, public VPValue { - friend class VPlanSlp; - -public: - /// VPlan opcodes, extending LLVM IR with idiomatics instructions. 
- enum { - FirstOrderRecurrenceSplice = - Instruction::OtherOpsEnd + 1, // Combines the incoming and previous - // values of a first-order recurrence. - Not, - ICmpULE, - SLPLoad, - SLPStore, - ActiveLaneMask, - CalculateTripCountMinusVF, - CanonicalIVIncrement, - CanonicalIVIncrementNUW, - // The next two are similar to the above, but instead increment the - // canonical IV separately for each unrolled part. - CanonicalIVIncrementForPart, - CanonicalIVIncrementForPartNUW, - BranchOnCount, - BranchOnCond - }; - -private: - typedef unsigned char OpcodeTy; - OpcodeTy Opcode; - FastMathFlags FMF; - DebugLoc DL; - - /// An optional name that can be used for the generated IR instruction. - const std::string Name; - - /// Utility method serving execute(): generates a single instance of the - /// modeled instruction. \returns the generated value for \p Part. - /// In some cases an existing value is returned rather than a generated - /// one. - Value *generateInstruction(VPTransformState &State, unsigned Part); - -protected: - void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } - -public: - VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL, - const Twine &Name = "") - : VPRecipeBase(VPDef::VPInstructionSC, Operands), VPValue(this), - Opcode(Opcode), DL(DL), Name(Name.str()) {} - - VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, - DebugLoc DL = {}, const Twine &Name = "") - : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {} - - VP_CLASSOF_IMPL(VPDef::VPInstructionSC) - - VPInstruction *clone() const { - SmallVector<VPValue *, 2> Operands(operands()); - return new VPInstruction(Opcode, Operands, DL, Name); - } - - unsigned getOpcode() const { return Opcode; } - - /// Generate the instruction. - /// TODO: We currently execute only per-part unless a specific instance is - /// provided. 
- void execute(VPTransformState &State) override; - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - /// Print the VPInstruction to \p O. - void print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const override; - - /// Print the VPInstruction to dbgs() (for debugging). - LLVM_DUMP_METHOD void dump() const; -#endif - - /// Return true if this instruction may modify memory. - bool mayWriteToMemory() const { - // TODO: we can use attributes of the called function to rule out memory - // modifications. - return Opcode == Instruction::Store || Opcode == Instruction::Call || - Opcode == Instruction::Invoke || Opcode == SLPStore; - } - - bool hasResult() const { - // CallInst may or may not have a result, depending on the called function. - // Conservatively return calls have results for now. - switch (getOpcode()) { - case Instruction::Ret: - case Instruction::Br: - case Instruction::Store: - case Instruction::Switch: - case Instruction::IndirectBr: - case Instruction::Resume: - case Instruction::CatchRet: - case Instruction::Unreachable: - case Instruction::Fence: - case Instruction::AtomicRMW: - case VPInstruction::BranchOnCond: - case VPInstruction::BranchOnCount: - return false; - default: - return true; - } - } - - /// Set the fast-math flags. - void setFastMathFlags(FastMathFlags FMFNew); - - /// Returns true if the recipe only uses the first lane of operand \p Op. 
- bool onlyFirstLaneUsed(const VPValue *Op) const override { - assert(is_contained(operands(), Op) && - "Op must be an operand of the recipe"); - if (getOperand(0) != Op) - return false; - switch (getOpcode()) { - default: - return false; - case VPInstruction::ActiveLaneMask: - case VPInstruction::CalculateTripCountMinusVF: - case VPInstruction::CanonicalIVIncrement: - case VPInstruction::CanonicalIVIncrementNUW: - case VPInstruction::CanonicalIVIncrementForPart: - case VPInstruction::CanonicalIVIncrementForPartNUW: - case VPInstruction::BranchOnCount: - return true; - }; - llvm_unreachable("switch should return"); - } -}; - /// Class to record LLVM IR flag for a recipe along with it. class VPRecipeWithIRFlags : public VPRecipeBase { enum class OperationType : unsigned char { + Cmp, OverflowingBinOp, + DisjointOp, PossiblyExactOp, GEPOp, FPMathOp, + NonNegOp, Other }; + +public: struct WrapFlagsTy { char HasNUW : 1; char HasNSW : 1; + + WrapFlagsTy(bool HasNUW, bool HasNSW) : HasNUW(HasNUW), HasNSW(HasNSW) {} + }; + +private: + struct DisjointFlagsTy { + char IsDisjoint : 1; }; struct ExactFlagsTy { char IsExact : 1; @@ -962,6 +852,9 @@ class VPRecipeWithIRFlags : public VPRecipeBase { struct GEPFlagsTy { char IsInBounds : 1; }; + struct NonNegFlagsTy { + char NonNeg : 1; + }; struct FastMathFlagsTy { char AllowReassoc : 1; char NoNaNs : 1; @@ -970,56 +863,81 @@ class VPRecipeWithIRFlags : public VPRecipeBase { char AllowReciprocal : 1; char AllowContract : 1; char ApproxFunc : 1; + + FastMathFlagsTy(const FastMathFlags &FMF); }; OperationType OpType; union { + CmpInst::Predicate CmpPredicate; WrapFlagsTy WrapFlags; + DisjointFlagsTy DisjointFlags; ExactFlagsTy ExactFlags; GEPFlagsTy GEPFlags; + NonNegFlagsTy NonNegFlags; FastMathFlagsTy FMFs; - unsigned char AllFlags; + unsigned AllFlags; }; public: template <typename IterT> - VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands) - : VPRecipeBase(SC, Operands) { + VPRecipeWithIRFlags(const 
unsigned char SC, IterT Operands, DebugLoc DL = {}) + : VPRecipeBase(SC, Operands, DL) { OpType = OperationType::Other; AllFlags = 0; } template <typename IterT> - VPRecipeWithIRFlags(const unsigned char SC, iterator_range<IterT> Operands, - Instruction &I) - : VPRecipeWithIRFlags(SC, Operands) { - if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) { + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, Instruction &I) + : VPRecipeWithIRFlags(SC, Operands, I.getDebugLoc()) { + if (auto *Op = dyn_cast<CmpInst>(&I)) { + OpType = OperationType::Cmp; + CmpPredicate = Op->getPredicate(); + } else if (auto *Op = dyn_cast<PossiblyDisjointInst>(&I)) { + OpType = OperationType::DisjointOp; + DisjointFlags.IsDisjoint = Op->isDisjoint(); + } else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) { OpType = OperationType::OverflowingBinOp; - WrapFlags.HasNUW = Op->hasNoUnsignedWrap(); - WrapFlags.HasNSW = Op->hasNoSignedWrap(); + WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()}; } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) { OpType = OperationType::PossiblyExactOp; ExactFlags.IsExact = Op->isExact(); } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { OpType = OperationType::GEPOp; GEPFlags.IsInBounds = GEP->isInBounds(); + } else if (auto *PNNI = dyn_cast<PossiblyNonNegInst>(&I)) { + OpType = OperationType::NonNegOp; + NonNegFlags.NonNeg = PNNI->hasNonNeg(); } else if (auto *Op = dyn_cast<FPMathOperator>(&I)) { OpType = OperationType::FPMathOp; - FastMathFlags FMF = Op->getFastMathFlags(); - FMFs.AllowReassoc = FMF.allowReassoc(); - FMFs.NoNaNs = FMF.noNaNs(); - FMFs.NoInfs = FMF.noInfs(); - FMFs.NoSignedZeros = FMF.noSignedZeros(); - FMFs.AllowReciprocal = FMF.allowReciprocal(); - FMFs.AllowContract = FMF.allowContract(); - FMFs.ApproxFunc = FMF.approxFunc(); + FMFs = Op->getFastMathFlags(); } } + template <typename IterT> + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, + CmpInst::Predicate Pred, DebugLoc DL = 
{}) + : VPRecipeBase(SC, Operands, DL), OpType(OperationType::Cmp), + CmpPredicate(Pred) {} + + template <typename IterT> + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, + WrapFlagsTy WrapFlags, DebugLoc DL = {}) + : VPRecipeBase(SC, Operands, DL), OpType(OperationType::OverflowingBinOp), + WrapFlags(WrapFlags) {} + + template <typename IterT> + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, + FastMathFlags FMFs, DebugLoc DL = {}) + : VPRecipeBase(SC, Operands, DL), OpType(OperationType::FPMathOp), + FMFs(FMFs) {} + static inline bool classof(const VPRecipeBase *R) { - return R->getVPDefID() == VPRecipeBase::VPWidenSC || + return R->getVPDefID() == VPRecipeBase::VPInstructionSC || + R->getVPDefID() == VPRecipeBase::VPWidenSC || R->getVPDefID() == VPRecipeBase::VPWidenGEPSC || + R->getVPDefID() == VPRecipeBase::VPWidenCastSC || R->getVPDefID() == VPRecipeBase::VPReplicateSC; } @@ -1032,6 +950,9 @@ public: WrapFlags.HasNUW = false; WrapFlags.HasNSW = false; break; + case OperationType::DisjointOp: + DisjointFlags.IsDisjoint = false; + break; case OperationType::PossiblyExactOp: ExactFlags.IsExact = false; break; @@ -1042,6 +963,10 @@ public: FMFs.NoNaNs = false; FMFs.NoInfs = false; break; + case OperationType::NonNegOp: + NonNegFlags.NonNeg = false; + break; + case OperationType::Cmp: case OperationType::Other: break; } @@ -1054,6 +979,9 @@ public: I->setHasNoUnsignedWrap(WrapFlags.HasNUW); I->setHasNoSignedWrap(WrapFlags.HasNSW); break; + case OperationType::DisjointOp: + cast<PossiblyDisjointInst>(I)->setIsDisjoint(DisjointFlags.IsDisjoint); + break; case OperationType::PossiblyExactOp: I->setIsExact(ExactFlags.IsExact); break; @@ -1069,43 +997,209 @@ public: I->setHasAllowContract(FMFs.AllowContract); I->setHasApproxFunc(FMFs.ApproxFunc); break; + case OperationType::NonNegOp: + I->setNonNeg(NonNegFlags.NonNeg); + break; + case OperationType::Cmp: case OperationType::Other: break; } } + CmpInst::Predicate getPredicate() const { + 
assert(OpType == OperationType::Cmp && + "recipe doesn't have a compare predicate"); + return CmpPredicate; + } + bool isInBounds() const { assert(OpType == OperationType::GEPOp && "recipe doesn't have inbounds flag"); return GEPFlags.IsInBounds; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - FastMathFlags getFastMathFlags() const { - FastMathFlags Res; - Res.setAllowReassoc(FMFs.AllowReassoc); - Res.setNoNaNs(FMFs.NoNaNs); - Res.setNoInfs(FMFs.NoInfs); - Res.setNoSignedZeros(FMFs.NoSignedZeros); - Res.setAllowReciprocal(FMFs.AllowReciprocal); - Res.setAllowContract(FMFs.AllowContract); - Res.setApproxFunc(FMFs.ApproxFunc); - return Res; + /// Returns true if the recipe has fast-math flags. + bool hasFastMathFlags() const { return OpType == OperationType::FPMathOp; } + + FastMathFlags getFastMathFlags() const; + + bool hasNoUnsignedWrap() const { + assert(OpType == OperationType::OverflowingBinOp && + "recipe doesn't have a NUW flag"); + return WrapFlags.HasNUW; } + bool hasNoSignedWrap() const { + assert(OpType == OperationType::OverflowingBinOp && + "recipe doesn't have a NSW flag"); + return WrapFlags.HasNSW; + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printFlags(raw_ostream &O) const; #endif }; +/// This is a concrete Recipe that models a single VPlan-level instruction. +/// While as any Recipe it may generate a sequence of IR instructions when +/// executed, these instructions would always form a single-def expression as +/// the VPInstruction is also a single def-use vertex. +class VPInstruction : public VPRecipeWithIRFlags, public VPValue { + friend class VPlanSlp; + +public: + /// VPlan opcodes, extending LLVM IR with idiomatics instructions. + enum { + FirstOrderRecurrenceSplice = + Instruction::OtherOpsEnd + 1, // Combines the incoming and previous + // values of a first-order recurrence. + Not, + SLPLoad, + SLPStore, + ActiveLaneMask, + CalculateTripCountMinusVF, + // Increment the canonical IV separately for each unrolled part. 
+ CanonicalIVIncrementForPart, + BranchOnCount, + BranchOnCond + }; + +private: + typedef unsigned char OpcodeTy; + OpcodeTy Opcode; + + /// An optional name that can be used for the generated IR instruction. + const std::string Name; + + /// Utility method serving execute(): generates a single instance of the + /// modeled instruction. \returns the generated value for \p Part. + /// In some cases an existing value is returned rather than a generated + /// one. + Value *generateInstruction(VPTransformState &State, unsigned Part); + +#if !defined(NDEBUG) + /// Return true if the VPInstruction is a floating point math operation, i.e. + /// has fast-math flags. + bool isFPMathOp() const; +#endif + +protected: + void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } + +public: + VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL, + const Twine &Name = "") + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DL), + VPValue(this), Opcode(Opcode), Name(Name.str()) {} + + VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, + DebugLoc DL = {}, const Twine &Name = "") + : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {} + + VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A, + VPValue *B, DebugLoc DL = {}, const Twine &Name = ""); + + VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, + WrapFlagsTy WrapFlags, DebugLoc DL = {}, const Twine &Name = "") + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags, DL), + VPValue(this), Opcode(Opcode), Name(Name.str()) {} + + VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, + FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = ""); + + VP_CLASSOF_IMPL(VPDef::VPInstructionSC) + + unsigned getOpcode() const { return Opcode; } + + /// Generate the instruction. + /// TODO: We currently execute only per-part unless a specific instance is + /// provided. 
+ void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the VPInstruction to \p O. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; + + /// Print the VPInstruction to dbgs() (for debugging). + LLVM_DUMP_METHOD void dump() const; +#endif + + /// Return true if this instruction may modify memory. + bool mayWriteToMemory() const { + // TODO: we can use attributes of the called function to rule out memory + // modifications. + return Opcode == Instruction::Store || Opcode == Instruction::Call || + Opcode == Instruction::Invoke || Opcode == SLPStore; + } + + bool hasResult() const { + // CallInst may or may not have a result, depending on the called function. + // Conservatively return calls have results for now. + switch (getOpcode()) { + case Instruction::Ret: + case Instruction::Br: + case Instruction::Store: + case Instruction::Switch: + case Instruction::IndirectBr: + case Instruction::Resume: + case Instruction::CatchRet: + case Instruction::Unreachable: + case Instruction::Fence: + case Instruction::AtomicRMW: + case VPInstruction::BranchOnCond: + case VPInstruction::BranchOnCount: + return false; + default: + return true; + } + } + + /// Returns true if the recipe only uses the first lane of operand \p Op. + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + if (getOperand(0) != Op) + return false; + switch (getOpcode()) { + default: + return false; + case VPInstruction::ActiveLaneMask: + case VPInstruction::CalculateTripCountMinusVF: + case VPInstruction::CanonicalIVIncrementForPart: + case VPInstruction::BranchOnCount: + return true; + }; + llvm_unreachable("switch should return"); + } + + /// Returns true if the recipe only uses the first part of operand \p Op. 
+ bool onlyFirstPartUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + if (getOperand(0) != Op) + return false; + switch (getOpcode()) { + default: + return false; + case VPInstruction::BranchOnCount: + return true; + }; + llvm_unreachable("switch should return"); + } +}; + /// VPWidenRecipe is a recipe for producing a copy of vector type its /// ingredient. This recipe covers most of the traditional vectorization cases /// where each ingredient transforms into a vectorized version of itself. class VPWidenRecipe : public VPRecipeWithIRFlags, public VPValue { + unsigned Opcode; public: template <typename IterT> VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands) - : VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPValue(this, &I) {} + : VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPValue(this, &I), + Opcode(I.getOpcode()) {} ~VPWidenRecipe() override = default; @@ -1114,6 +1208,8 @@ public: /// Produce widened copies of all Ingredients. void execute(VPTransformState &State) override; + unsigned getOpcode() const { return Opcode; } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print the recipe. void print(raw_ostream &O, const Twine &Indent, @@ -1122,7 +1218,7 @@ public: }; /// VPWidenCastRecipe is a recipe to create vector cast instructions. -class VPWidenCastRecipe : public VPRecipeBase, public VPValue { +class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPValue { /// Cast instruction opcode. 
Instruction::CastOps Opcode; @@ -1131,15 +1227,19 @@ class VPWidenCastRecipe : public VPRecipeBase, public VPValue { public: VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy, - CastInst *UI = nullptr) - : VPRecipeBase(VPDef::VPWidenCastSC, Op), VPValue(this, UI), + CastInst &UI) + : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, UI), VPValue(this, &UI), Opcode(Opcode), ResultTy(ResultTy) { - assert((!UI || UI->getOpcode() == Opcode) && + assert(UI.getOpcode() == Opcode && "opcode of underlying cast doesn't match"); - assert((!UI || UI->getType() == ResultTy) && + assert(UI.getType() == ResultTy && "result type of underlying cast doesn't match"); } + VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy) + : VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPValue(this, nullptr), + Opcode(Opcode), ResultTy(ResultTy) {} + ~VPWidenCastRecipe() override = default; VP_CLASSOF_IMPL(VPDef::VPWidenCastSC) @@ -1196,7 +1296,8 @@ public: struct VPWidenSelectRecipe : public VPRecipeBase, public VPValue { template <typename IterT> VPWidenSelectRecipe(SelectInst &I, iterator_range<IterT> Operands) - : VPRecipeBase(VPDef::VPWidenSelectSC, Operands), VPValue(this, &I) {} + : VPRecipeBase(VPDef::VPWidenSelectSC, Operands, I.getDebugLoc()), + VPValue(this, &I) {} ~VPWidenSelectRecipe() override = default; @@ -1282,8 +1383,8 @@ public: class VPHeaderPHIRecipe : public VPRecipeBase, public VPValue { protected: VPHeaderPHIRecipe(unsigned char VPDefID, Instruction *UnderlyingInstr, - VPValue *Start = nullptr) - : VPRecipeBase(VPDefID, {}), VPValue(this, UnderlyingInstr) { + VPValue *Start = nullptr, DebugLoc DL = {}) + : VPRecipeBase(VPDefID, {}, DL), VPValue(this, UnderlyingInstr) { if (Start) addOperand(Start); } @@ -1404,7 +1505,7 @@ public: bool isCanonical() const; /// Returns the scalar type of the induction. - const Type *getScalarType() const { + Type *getScalarType() const { return Trunc ? 
Trunc->getType() : IV->getType(); } }; @@ -1565,14 +1666,13 @@ public: /// A recipe for vectorizing a phi-node as a sequence of mask-based select /// instructions. class VPBlendRecipe : public VPRecipeBase, public VPValue { - PHINode *Phi; - public: /// The blend operation is a User of the incoming values and of their /// respective masks, ordered [I0, M0, I1, M1, ...]. Note that a single value /// might be incoming with a full mask for which there is no VPValue. VPBlendRecipe(PHINode *Phi, ArrayRef<VPValue *> Operands) - : VPRecipeBase(VPDef::VPBlendSC, Operands), VPValue(this, Phi), Phi(Phi) { + : VPRecipeBase(VPDef::VPBlendSC, Operands, Phi->getDebugLoc()), + VPValue(this, Phi) { assert(Operands.size() > 0 && ((Operands.size() == 1) || (Operands.size() % 2 == 0)) && "Expected either a single incoming value or a positive even number " @@ -1701,16 +1801,13 @@ public: /// The Operands are {ChainOp, VecOp, [Condition]}. class VPReductionRecipe : public VPRecipeBase, public VPValue { /// The recurrence decriptor for the reduction in question. - const RecurrenceDescriptor *RdxDesc; - /// Pointer to the TTI, needed to create the target reduction - const TargetTransformInfo *TTI; + const RecurrenceDescriptor &RdxDesc; public: - VPReductionRecipe(const RecurrenceDescriptor *R, Instruction *I, - VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp, - const TargetTransformInfo *TTI) + VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I, + VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp) : VPRecipeBase(VPDef::VPReductionSC, {ChainOp, VecOp}), VPValue(this, I), - RdxDesc(R), TTI(TTI) { + RdxDesc(R) { if (CondOp) addOperand(CondOp); } @@ -2008,11 +2105,9 @@ public: /// loop). VPWidenCanonicalIVRecipe represents the vector version of the /// canonical induction variable. 
class VPCanonicalIVPHIRecipe : public VPHeaderPHIRecipe { - DebugLoc DL; - public: VPCanonicalIVPHIRecipe(VPValue *StartV, DebugLoc DL) - : VPHeaderPHIRecipe(VPDef::VPCanonicalIVPHISC, nullptr, StartV), DL(DL) {} + : VPHeaderPHIRecipe(VPDef::VPCanonicalIVPHISC, nullptr, StartV, DL) {} ~VPCanonicalIVPHIRecipe() override = default; @@ -2032,8 +2127,8 @@ public: #endif /// Returns the scalar type of the induction. - const Type *getScalarType() const { - return getOperand(0)->getLiveInIRValue()->getType(); + Type *getScalarType() const { + return getStartValue()->getLiveInIRValue()->getType(); } /// Returns true if the recipe only uses the first lane of operand \p Op. @@ -2043,6 +2138,13 @@ public: return true; } + /// Returns true if the recipe only uses the first part of operand \p Op. + bool onlyFirstPartUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } + /// Check if the induction described by \p Kind, /p Start and \p Step is /// canonical, i.e. has the same start, step (of 1), and type as the /// canonical IV. @@ -2055,12 +2157,10 @@ public: /// TODO: It would be good to use the existing VPWidenPHIRecipe instead and /// remove VPActiveLaneMaskPHIRecipe. class VPActiveLaneMaskPHIRecipe : public VPHeaderPHIRecipe { - DebugLoc DL; - public: VPActiveLaneMaskPHIRecipe(VPValue *StartMask, DebugLoc DL) - : VPHeaderPHIRecipe(VPDef::VPActiveLaneMaskPHISC, nullptr, StartMask), - DL(DL) {} + : VPHeaderPHIRecipe(VPDef::VPActiveLaneMaskPHISC, nullptr, StartMask, + DL) {} ~VPActiveLaneMaskPHIRecipe() override = default; @@ -2113,19 +2213,24 @@ public: /// an IV with different start and step values, using Start + CanonicalIV * /// Step. class VPDerivedIVRecipe : public VPRecipeBase, public VPValue { - /// The type of the result value. It may be smaller than the type of the - /// induction and in this case it will get truncated to ResultTy. 
- Type *ResultTy; + /// If not nullptr, the result of the induction will get truncated to + /// TruncResultTy. + Type *TruncResultTy; - /// Induction descriptor for the induction the canonical IV is transformed to. - const InductionDescriptor &IndDesc; + /// Kind of the induction. + const InductionDescriptor::InductionKind Kind; + /// If not nullptr, the floating point induction binary operator. Must be set + /// for floating point inductions. + const FPMathOperator *FPBinOp; public: VPDerivedIVRecipe(const InductionDescriptor &IndDesc, VPValue *Start, VPCanonicalIVPHIRecipe *CanonicalIV, VPValue *Step, - Type *ResultTy) + Type *TruncResultTy) : VPRecipeBase(VPDef::VPDerivedIVSC, {Start, CanonicalIV, Step}), - VPValue(this), ResultTy(ResultTy), IndDesc(IndDesc) {} + VPValue(this), TruncResultTy(TruncResultTy), Kind(IndDesc.getKind()), + FPBinOp(dyn_cast_or_null<FPMathOperator>(IndDesc.getInductionBinOp())) { + } ~VPDerivedIVRecipe() override = default; @@ -2141,6 +2246,11 @@ public: VPSlotTracker &SlotTracker) const override; #endif + Type *getScalarType() const { + return TruncResultTy ? TruncResultTy + : getStartValue()->getLiveInIRValue()->getType(); + } + VPValue *getStartValue() const { return getOperand(0); } VPValue *getCanonicalIV() const { return getOperand(1); } VPValue *getStepValue() const { return getOperand(2); } @@ -2155,14 +2265,23 @@ public: /// A recipe for handling phi nodes of integer and floating-point inductions, /// producing their scalar values. 
-class VPScalarIVStepsRecipe : public VPRecipeBase, public VPValue { - const InductionDescriptor &IndDesc; +class VPScalarIVStepsRecipe : public VPRecipeWithIRFlags, public VPValue { + Instruction::BinaryOps InductionOpcode; public: + VPScalarIVStepsRecipe(VPValue *IV, VPValue *Step, + Instruction::BinaryOps Opcode, FastMathFlags FMFs) + : VPRecipeWithIRFlags(VPDef::VPScalarIVStepsSC, + ArrayRef<VPValue *>({IV, Step}), FMFs), + VPValue(this), InductionOpcode(Opcode) {} + VPScalarIVStepsRecipe(const InductionDescriptor &IndDesc, VPValue *IV, VPValue *Step) - : VPRecipeBase(VPDef::VPScalarIVStepsSC, {IV, Step}), VPValue(this), - IndDesc(IndDesc) {} + : VPScalarIVStepsRecipe( + IV, Step, IndDesc.getInductionOpcode(), + dyn_cast_or_null<FPMathOperator>(IndDesc.getInductionBinOp()) + ? IndDesc.getInductionBinOp()->getFastMathFlags() + : FastMathFlags()) {} ~VPScalarIVStepsRecipe() override = default; @@ -2445,6 +2564,9 @@ class VPlan { /// Represents the vector trip count. VPValue VectorTripCount; + /// Represents the loop-invariant VF * UF of the vector loop region. + VPValue VFxUF; + /// Holds a mapping between Values and their corresponding VPValue inside /// VPlan. Value2VPValueTy Value2VPValue; @@ -2490,15 +2612,17 @@ public: ~VPlan(); - /// Create an initial VPlan with preheader and entry blocks. Creates a - /// VPExpandSCEVRecipe for \p TripCount and uses it as plan's trip count. + /// Create initial VPlan skeleton, having an "entry" VPBasicBlock (wrapping + /// original scalar pre-header) which contains SCEV expansions that need to + /// happen before the CFG is modified; a VPBasicBlock for the vector + /// pre-header, followed by a region for the vector loop, followed by the + /// middle VPBasicBlock. static VPlanPtr createInitialVPlan(const SCEV *TripCount, ScalarEvolution &PSE); /// Prepare the plan for execution, setting up the required live-in values. 
void prepareToExecute(Value *TripCount, Value *VectorTripCount, - Value *CanonicalIVStartValue, VPTransformState &State, - bool IsEpilogueVectorization); + Value *CanonicalIVStartValue, VPTransformState &State); /// Generate the IR code for this VPlan. void execute(VPTransformState *State); @@ -2522,6 +2646,9 @@ public: /// The vector trip count. VPValue &getVectorTripCount() { return VectorTripCount; } + /// Returns VF * UF of the vector loop region. + VPValue &getVFxUF() { return VFxUF; } + /// Mark the plan to indicate that using Value2VPValue is not safe any /// longer, because it may be stale. void disableValue2VPValue() { Value2VPValueEnabled = false; } @@ -2583,13 +2710,10 @@ public: return getVPValue(V); } - void removeVPValueFor(Value *V) { - assert(Value2VPValueEnabled && - "IR value to VPValue mapping may be out of date!"); - Value2VPValue.erase(V); - } - #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the live-ins of this VPlan to \p O. + void printLiveIns(raw_ostream &O) const; + /// Print this VPlan to \p O. void print(raw_ostream &O) const; @@ -2628,10 +2752,6 @@ public: return cast<VPCanonicalIVPHIRecipe>(&*EntryVPBB->begin()); } - /// Find and return the VPActiveLaneMaskPHIRecipe from the header - there - /// be only one at most. If there isn't one, then return nullptr. - VPActiveLaneMaskPHIRecipe *getActiveLaneMaskPhi(); - void addLiveOut(PHINode *PN, VPValue *V); void removeLiveOut(PHINode *PN) { @@ -2959,6 +3079,9 @@ namespace vputils { /// Returns true if only the first lane of \p Def is used. bool onlyFirstLaneUsed(VPValue *Def); +/// Returns true if only the first part of \p Def is used. +bool onlyFirstPartUsed(VPValue *Def); + /// Get or create a VPValue that corresponds to the expansion of \p Expr. If \p /// Expr is a SCEVConstant or SCEVUnknown, return a VPValue wrapping the live-in /// value. Otherwise return a VPExpandSCEVRecipe to expand \p Expr. 
If \p Plan's diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp new file mode 100644 index 000000000000..97a8a1803bbf --- /dev/null +++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp @@ -0,0 +1,237 @@ +//===- VPlanAnalysis.cpp - Various Analyses working on VPlan ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "VPlanAnalysis.h" +#include "VPlan.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace llvm; + +#define DEBUG_TYPE "vplan" + +Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) { + Type *ResTy = inferScalarType(R->getIncomingValue(0)); + for (unsigned I = 1, E = R->getNumIncomingValues(); I != E; ++I) { + VPValue *Inc = R->getIncomingValue(I); + assert(inferScalarType(Inc) == ResTy && + "different types inferred for different incoming values"); + CachedTypes[Inc] = ResTy; + } + return ResTy; +} + +Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { + switch (R->getOpcode()) { + case Instruction::Select: { + Type *ResTy = inferScalarType(R->getOperand(1)); + VPValue *OtherV = R->getOperand(2); + assert(inferScalarType(OtherV) == ResTy && + "different types inferred for different operands"); + CachedTypes[OtherV] = ResTy; + return ResTy; + } + case VPInstruction::FirstOrderRecurrenceSplice: { + Type *ResTy = inferScalarType(R->getOperand(0)); + VPValue *OtherV = R->getOperand(1); + assert(inferScalarType(OtherV) == ResTy && + "different types inferred for different operands"); + CachedTypes[OtherV] = ResTy; + return ResTy; + } + default: + break; + } + // Type inference not implemented for opcode. 
+ LLVM_DEBUG({ + dbgs() << "LV: Found unhandled opcode for: "; + R->getVPSingleValue()->dump(); + }); + llvm_unreachable("Unhandled opcode!"); +} + +Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) { + unsigned Opcode = R->getOpcode(); + switch (Opcode) { + case Instruction::ICmp: + case Instruction::FCmp: + return IntegerType::get(Ctx, 1); + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::SRem: + case Instruction::URem: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + Type *ResTy = inferScalarType(R->getOperand(0)); + assert(ResTy == inferScalarType(R->getOperand(1)) && + "types for both operands must match for binary op"); + CachedTypes[R->getOperand(1)] = ResTy; + return ResTy; + } + case Instruction::FNeg: + case Instruction::Freeze: + return inferScalarType(R->getOperand(0)); + default: + break; + } + + // Type inference not implemented for opcode. 
+ LLVM_DEBUG({ + dbgs() << "LV: Found unhandled opcode for: "; + R->getVPSingleValue()->dump(); + }); + llvm_unreachable("Unhandled opcode!"); +} + +Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) { + auto &CI = *cast<CallInst>(R->getUnderlyingInstr()); + return CI.getType(); +} + +Type *VPTypeAnalysis::inferScalarTypeForRecipe( + const VPWidenMemoryInstructionRecipe *R) { + assert(!R->isStore() && "Store recipes should not define any values"); + return cast<LoadInst>(&R->getIngredient())->getType(); +} + +Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) { + Type *ResTy = inferScalarType(R->getOperand(1)); + VPValue *OtherV = R->getOperand(2); + assert(inferScalarType(OtherV) == ResTy && + "different types inferred for different operands"); + CachedTypes[OtherV] = ResTy; + return ResTy; +} + +Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) { + switch (R->getUnderlyingInstr()->getOpcode()) { + case Instruction::Call: { + unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 
2 : 1); + return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue()) + ->getReturnType(); + } + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::SRem: + case Instruction::URem: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + Type *ResTy = inferScalarType(R->getOperand(0)); + assert(ResTy == inferScalarType(R->getOperand(1)) && + "inferred types for operands of binary op don't match"); + CachedTypes[R->getOperand(1)] = ResTy; + return ResTy; + } + case Instruction::Select: { + Type *ResTy = inferScalarType(R->getOperand(1)); + assert(ResTy == inferScalarType(R->getOperand(2)) && + "inferred types for operands of select op don't match"); + CachedTypes[R->getOperand(2)] = ResTy; + return ResTy; + } + case Instruction::ICmp: + case Instruction::FCmp: + return IntegerType::get(Ctx, 1); + case Instruction::Alloca: + case Instruction::BitCast: + case Instruction::Trunc: + case Instruction::SExt: + case Instruction::ZExt: + case Instruction::FPExt: + case Instruction::FPTrunc: + case Instruction::ExtractValue: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::FPToSI: + case Instruction::FPToUI: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + return R->getUnderlyingInstr()->getType(); + case Instruction::Freeze: + case Instruction::FNeg: + case Instruction::GetElementPtr: + return inferScalarType(R->getOperand(0)); + case Instruction::Load: + return cast<LoadInst>(R->getUnderlyingInstr())->getType(); + case Instruction::Store: + // FIXME: VPReplicateRecipes with store opcodes still define a result + // VPValue, so we need to handle them here. 
Remove the code here once this + // is modeled accurately in VPlan. + return Type::getVoidTy(Ctx); + default: + break; + } + // Type inference not implemented for opcode. + LLVM_DEBUG({ + dbgs() << "LV: Found unhandled opcode for: "; + R->getVPSingleValue()->dump(); + }); + llvm_unreachable("Unhandled opcode"); +} + +Type *VPTypeAnalysis::inferScalarType(const VPValue *V) { + if (Type *CachedTy = CachedTypes.lookup(V)) + return CachedTy; + + if (V->isLiveIn()) + return V->getLiveInIRValue()->getType(); + + Type *ResultTy = + TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe()) + .Case<VPCanonicalIVPHIRecipe, VPFirstOrderRecurrencePHIRecipe, + VPReductionPHIRecipe, VPWidenPointerInductionRecipe>( + [this](const auto *R) { + // Handle header phi recipes, except VPWienIntOrFpInduction + // which needs special handling due it being possibly truncated. + // TODO: consider inferring/caching type of siblings, e.g., + // backedge value, here and in cases below. + return inferScalarType(R->getStartValue()); + }) + .Case<VPWidenIntOrFpInductionRecipe, VPDerivedIVRecipe>( + [](const auto *R) { return R->getScalarType(); }) + .Case<VPPredInstPHIRecipe, VPWidenPHIRecipe, VPScalarIVStepsRecipe, + VPWidenGEPRecipe>([this](const VPRecipeBase *R) { + return inferScalarType(R->getOperand(0)); + }) + .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe, + VPWidenCallRecipe, VPWidenMemoryInstructionRecipe, + VPWidenSelectRecipe>( + [this](const auto *R) { return inferScalarTypeForRecipe(R); }) + .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) { + // TODO: Use info from interleave group. 
+ return V->getUnderlyingValue()->getType(); + }) + .Case<VPWidenCastRecipe>( + [](const VPWidenCastRecipe *R) { return R->getResultType(); }); + assert(ResultTy && "could not infer type for the given VPValue"); + CachedTypes[V] = ResultTy; + return ResultTy; +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h new file mode 100644 index 000000000000..473a7c28e48a --- /dev/null +++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h @@ -0,0 +1,64 @@ +//===- VPlanAnalysis.h - Various Analyses working on VPlan ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H + +#include "llvm/ADT/DenseMap.h" + +namespace llvm { + +class LLVMContext; +class VPValue; +class VPBlendRecipe; +class VPInterleaveRecipe; +class VPInstruction; +class VPReductionPHIRecipe; +class VPWidenRecipe; +class VPWidenCallRecipe; +class VPWidenCastRecipe; +class VPWidenIntOrFpInductionRecipe; +class VPWidenMemoryInstructionRecipe; +struct VPWidenSelectRecipe; +class VPReplicateRecipe; +class Type; + +/// An analysis for type-inference for VPValues. +/// It infers the scalar type for a given VPValue by bottom-up traversing +/// through defining recipes until root nodes with known types are reached (e.g. +/// live-ins or load recipes). The types are then propagated top down through +/// operations. +/// Note that the analysis caches the inferred types. A new analysis object must +/// be constructed once a VPlan has been modified in a way that invalidates any +/// of the previously inferred types. 
+class VPTypeAnalysis { + DenseMap<const VPValue *, Type *> CachedTypes; + LLVMContext &Ctx; + + Type *inferScalarTypeForRecipe(const VPBlendRecipe *R); + Type *inferScalarTypeForRecipe(const VPInstruction *R); + Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R); + Type *inferScalarTypeForRecipe(const VPWidenRecipe *R); + Type *inferScalarTypeForRecipe(const VPWidenIntOrFpInductionRecipe *R); + Type *inferScalarTypeForRecipe(const VPWidenMemoryInstructionRecipe *R); + Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R); + Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R); + +public: + VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {} + + /// Infer the type of \p V. Returns the scalar type of \p V. + Type *inferScalarType(const VPValue *V); + + /// Return the LLVMContext used by the analysis. + LLVMContext &getContext() { return Ctx; } +}; + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index f6e3a2a16db8..f950d4740e41 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -61,6 +61,7 @@ private: // Utility functions. void setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB); + void setRegionPredsFromBB(VPRegionBlock *VPBB, BasicBlock *BB); void fixPhiNodes(); VPBasicBlock *getOrCreateVPBB(BasicBlock *BB); #ifndef NDEBUG @@ -81,14 +82,43 @@ public: // Set predecessors of \p VPBB in the same order as they are in \p BB. \p VPBB // must have no predecessors. 
void PlainCFGBuilder::setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB) { - SmallVector<VPBlockBase *, 8> VPBBPreds; + auto GetLatchOfExit = [this](BasicBlock *BB) -> BasicBlock * { + auto *SinglePred = BB->getSinglePredecessor(); + Loop *LoopForBB = LI->getLoopFor(BB); + if (!SinglePred || LI->getLoopFor(SinglePred) == LoopForBB) + return nullptr; + // The input IR must be in loop-simplify form, ensuring a single predecessor + // for exit blocks. + assert(SinglePred == LI->getLoopFor(SinglePred)->getLoopLatch() && + "SinglePred must be the only loop latch"); + return SinglePred; + }; + if (auto *LatchBB = GetLatchOfExit(BB)) { + auto *PredRegion = getOrCreateVPBB(LatchBB)->getParent(); + assert(VPBB == cast<VPBasicBlock>(PredRegion->getSingleSuccessor()) && + "successor must already be set for PredRegion; it must have VPBB " + "as single successor"); + VPBB->setPredecessors({PredRegion}); + return; + } // Collect VPBB predecessors. + SmallVector<VPBlockBase *, 2> VPBBPreds; for (BasicBlock *Pred : predecessors(BB)) VPBBPreds.push_back(getOrCreateVPBB(Pred)); - VPBB->setPredecessors(VPBBPreds); } +static bool isHeaderBB(BasicBlock *BB, Loop *L) { + return L && BB == L->getHeader(); +} + +void PlainCFGBuilder::setRegionPredsFromBB(VPRegionBlock *Region, + BasicBlock *BB) { + // BB is a loop header block. Connect the region to the loop preheader. + Loop *LoopOfBB = LI->getLoopFor(BB); + Region->setPredecessors({getOrCreateVPBB(LoopOfBB->getLoopPredecessor())}); +} + // Add operands to VPInstructions representing phi nodes from the input IR. void PlainCFGBuilder::fixPhiNodes() { for (auto *Phi : PhisToFix) { @@ -100,38 +130,85 @@ void PlainCFGBuilder::fixPhiNodes() { assert(VPPhi->getNumOperands() == 0 && "Expected VPInstruction with no operands."); + Loop *L = LI->getLoopFor(Phi->getParent()); + if (isHeaderBB(Phi->getParent(), L)) { + // For header phis, make sure the incoming value from the loop + // predecessor is the first operand of the recipe. 
+  assert(Phi->getNumOperands() == 2); +      BasicBlock *LoopPred = L->getLoopPredecessor(); +      VPPhi->addIncoming( +          getOrCreateVPOperand(Phi->getIncomingValueForBlock(LoopPred)), +          BB2VPBB[LoopPred]); +      BasicBlock *LoopLatch = L->getLoopLatch(); +      VPPhi->addIncoming( +          getOrCreateVPOperand(Phi->getIncomingValueForBlock(LoopLatch)), +          BB2VPBB[LoopLatch]); +      continue; +    } + for (unsigned I = 0; I != Phi->getNumOperands(); ++I) VPPhi->addIncoming(getOrCreateVPOperand(Phi->getIncomingValue(I)), BB2VPBB[Phi->getIncomingBlock(I)]); } } +static bool isHeaderVPBB(VPBasicBlock *VPBB) { + return VPBB->getParent() && VPBB->getParent()->getEntry() == VPBB; +} + +/// Return true if \p L loop is contained within \p OuterLoop. +static bool doesContainLoop(const Loop *L, const Loop *OuterLoop) { + if (L->getLoopDepth() < OuterLoop->getLoopDepth()) + return false; + const Loop *P = L; + while (P) { + if (P == OuterLoop) + return true; + P = P->getParentLoop(); + } + return false; +} + // Create a new empty VPBasicBlock for an incoming BasicBlock in the region // corresponding to the containing loop or retrieve an existing one if it was // already created. If no region exists yet for the loop containing \p BB, a new // one is created. VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) { - auto BlockIt = BB2VPBB.find(BB); - if (BlockIt != BB2VPBB.end()) + if (auto *VPBB = BB2VPBB.lookup(BB)) { // Retrieve existing VPBB. - return BlockIt->second; - - // Get or create a region for the loop containing BB. - Loop *CurrentLoop = LI->getLoopFor(BB); - VPRegionBlock *ParentR = nullptr; - if (CurrentLoop) { - auto Iter = Loop2Region.insert({CurrentLoop, nullptr}); - if (Iter.second) - Iter.first->second = new VPRegionBlock( - CurrentLoop->getHeader()->getName().str(), false /*isReplicator*/); - ParentR = Iter.first->second; + return VPBB; } // Create new VPBB.
- LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << BB->getName() << "\n"); - VPBasicBlock *VPBB = new VPBasicBlock(BB->getName()); + StringRef Name = isHeaderBB(BB, TheLoop) ? "vector.body" : BB->getName(); + LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << Name << "\n"); + VPBasicBlock *VPBB = new VPBasicBlock(Name); BB2VPBB[BB] = VPBB; - VPBB->setParent(ParentR); + + // Get or create a region for the loop containing BB. + Loop *LoopOfBB = LI->getLoopFor(BB); + if (!LoopOfBB || !doesContainLoop(LoopOfBB, TheLoop)) + return VPBB; + + auto *RegionOfVPBB = Loop2Region.lookup(LoopOfBB); + if (!isHeaderBB(BB, LoopOfBB)) { + assert(RegionOfVPBB && + "Region should have been created by visiting header earlier"); + VPBB->setParent(RegionOfVPBB); + return VPBB; + } + + assert(!RegionOfVPBB && + "First visit of a header basic block expects to register its region."); + // Handle a header - take care of its Region. + if (LoopOfBB == TheLoop) { + RegionOfVPBB = Plan.getVectorLoopRegion(); + } else { + RegionOfVPBB = new VPRegionBlock(Name.str(), false /*isReplicator*/); + RegionOfVPBB->setParent(Loop2Region[LoopOfBB->getParentLoop()]); + } + RegionOfVPBB->setEntry(VPBB); + Loop2Region[LoopOfBB] = RegionOfVPBB; return VPBB; } @@ -254,6 +331,25 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB, // Main interface to build the plain CFG. void PlainCFGBuilder::buildPlainCFG() { + // 0. Reuse the top-level region, vector-preheader and exit VPBBs from the + // skeleton. These were created directly rather than via getOrCreateVPBB(), + // revisit them now to update BB2VPBB. Note that header/entry and + // latch/exiting VPBB's of top-level region have yet to be created. 
+ VPRegionBlock *TheRegion = Plan.getVectorLoopRegion(); + BasicBlock *ThePreheaderBB = TheLoop->getLoopPreheader(); + assert((ThePreheaderBB->getTerminator()->getNumSuccessors() == 1) && + "Unexpected loop preheader"); + auto *VectorPreheaderVPBB = + cast<VPBasicBlock>(TheRegion->getSinglePredecessor()); + // ThePreheaderBB conceptually corresponds to both Plan.getPreheader() (which + // wraps the original preheader BB) and Plan.getEntry() (which represents the + // new vector preheader); here we're interested in setting BB2VPBB to the + // latter. + BB2VPBB[ThePreheaderBB] = VectorPreheaderVPBB; + BasicBlock *LoopExitBB = TheLoop->getUniqueExitBlock(); + assert(LoopExitBB && "Loops with multiple exits are not supported."); + BB2VPBB[LoopExitBB] = cast<VPBasicBlock>(TheRegion->getSingleSuccessor()); + // 1. Scan the body of the loop in a topological order to visit each basic // block after having visited its predecessor basic blocks. Create a VPBB for // each BB and link it to its successor and predecessor VPBBs. Note that @@ -263,21 +359,11 @@ void PlainCFGBuilder::buildPlainCFG() { // Loop PH needs to be explicitly visited since it's not taken into account by // LoopBlocksDFS. - BasicBlock *ThePreheaderBB = TheLoop->getLoopPreheader(); - assert((ThePreheaderBB->getTerminator()->getNumSuccessors() == 1) && - "Unexpected loop preheader"); - VPBasicBlock *ThePreheaderVPBB = Plan.getEntry(); - BB2VPBB[ThePreheaderBB] = ThePreheaderVPBB; - ThePreheaderVPBB->setName("vector.ph"); for (auto &I : *ThePreheaderBB) { if (I.getType()->isVoidTy()) continue; IRDef2VPValue[&I] = Plan.getVPValueOrAddLiveIn(&I); } - // Create empty VPBB for Loop H so that we can link PH->H. 
- VPBlockBase *HeaderVPBB = getOrCreateVPBB(TheLoop->getHeader()); - HeaderVPBB->setName("vector.body"); - ThePreheaderVPBB->setOneSuccessor(HeaderVPBB); LoopBlocksRPO RPO(TheLoop); RPO.perform(LI); @@ -286,88 +372,55 @@ void PlainCFGBuilder::buildPlainCFG() { // Create or retrieve the VPBasicBlock for this BB and create its // VPInstructions. VPBasicBlock *VPBB = getOrCreateVPBB(BB); + VPRegionBlock *Region = VPBB->getParent(); createVPInstructionsForVPBB(VPBB, BB); + Loop *LoopForBB = LI->getLoopFor(BB); + // Set VPBB predecessors in the same order as they are in the incoming BB. + if (!isHeaderBB(BB, LoopForBB)) { + setVPBBPredsFromBB(VPBB, BB); + } else { + // BB is a loop header, set the predecessor for the region, except for the + // top region, whose predecessor was set when creating VPlan's skeleton. + assert(isHeaderVPBB(VPBB) && "isHeaderBB and isHeaderVPBB disagree"); + if (TheRegion != Region) + setRegionPredsFromBB(Region, BB); + } // Set VPBB successors. We create empty VPBBs for successors if they don't // exist already. Recipes will be created when the successor is visited // during the RPO traversal. - Instruction *TI = BB->getTerminator(); - assert(TI && "Terminator expected."); - unsigned NumSuccs = TI->getNumSuccessors(); - + auto *BI = cast<BranchInst>(BB->getTerminator()); + unsigned NumSuccs = succ_size(BB); if (NumSuccs == 1) { - VPBasicBlock *SuccVPBB = getOrCreateVPBB(TI->getSuccessor(0)); - assert(SuccVPBB && "VPBB Successor not found."); - VPBB->setOneSuccessor(SuccVPBB); - } else if (NumSuccs == 2) { - VPBasicBlock *SuccVPBB0 = getOrCreateVPBB(TI->getSuccessor(0)); - assert(SuccVPBB0 && "Successor 0 not found."); - VPBasicBlock *SuccVPBB1 = getOrCreateVPBB(TI->getSuccessor(1)); - assert(SuccVPBB1 && "Successor 1 not found."); - - // Get VPBB's condition bit. 
- assert(isa<BranchInst>(TI) && "Unsupported terminator!"); - // Look up the branch condition to get the corresponding VPValue - // representing the condition bit in VPlan (which may be in another VPBB). - assert(IRDef2VPValue.count(cast<BranchInst>(TI)->getCondition()) && - "Missing condition bit in IRDef2VPValue!"); - - // Link successors. - VPBB->setTwoSuccessors(SuccVPBB0, SuccVPBB1); - } else - llvm_unreachable("Number of successors not supported."); - - // Set VPBB predecessors in the same order as they are in the incoming BB. - setVPBBPredsFromBB(VPBB, BB); + auto *Successor = getOrCreateVPBB(BB->getSingleSuccessor()); + VPBB->setOneSuccessor(isHeaderVPBB(Successor) + ? Successor->getParent() + : static_cast<VPBlockBase *>(Successor)); + continue; + } + assert(BI->isConditional() && NumSuccs == 2 && BI->isConditional() && + "block must have conditional branch with 2 successors"); + // Look up the branch condition to get the corresponding VPValue + // representing the condition bit in VPlan (which may be in another VPBB). + assert(IRDef2VPValue.contains(BI->getCondition()) && + "Missing condition bit in IRDef2VPValue!"); + VPBasicBlock *Successor0 = getOrCreateVPBB(BI->getSuccessor(0)); + VPBasicBlock *Successor1 = getOrCreateVPBB(BI->getSuccessor(1)); + if (!LoopForBB || BB != LoopForBB->getLoopLatch()) { + VPBB->setTwoSuccessors(Successor0, Successor1); + continue; + } + // For a latch we need to set the successor of the region rather than that + // of VPBB and it should be set to the exit, i.e., non-header successor, + // except for the top region, whose successor was set when creating VPlan's + // skeleton. + if (TheRegion != Region) + Region->setOneSuccessor(isHeaderVPBB(Successor0) ? Successor1 + : Successor0); + Region->setExiting(VPBB); } - // 2. Process outermost loop exit. We created an empty VPBB for the loop - // single exit BB during the RPO traversal of the loop body but Instructions - // weren't visited because it's not part of the the loop. 
- BasicBlock *LoopExitBB = TheLoop->getUniqueExitBlock(); - assert(LoopExitBB && "Loops with multiple exits are not supported."); - VPBasicBlock *LoopExitVPBB = BB2VPBB[LoopExitBB]; - // Loop exit was already set as successor of the loop exiting BB. - // We only set its predecessor VPBB now. - setVPBBPredsFromBB(LoopExitVPBB, LoopExitBB); - - // 3. Fix up region blocks for loops. For each loop, - // * use the header block as entry to the corresponding region, - // * use the latch block as exit of the corresponding region, - // * set the region as successor of the loop pre-header, and - // * set the exit block as successor to the region. - SmallVector<Loop *> LoopWorkList; - LoopWorkList.push_back(TheLoop); - while (!LoopWorkList.empty()) { - Loop *L = LoopWorkList.pop_back_val(); - BasicBlock *Header = L->getHeader(); - BasicBlock *Exiting = L->getLoopLatch(); - assert(Exiting == L->getExitingBlock() && - "Latch must be the only exiting block"); - VPRegionBlock *Region = Loop2Region[L]; - VPBasicBlock *HeaderVPBB = getOrCreateVPBB(Header); - VPBasicBlock *ExitingVPBB = getOrCreateVPBB(Exiting); - - // Disconnect backedge and pre-header from header. - VPBasicBlock *PreheaderVPBB = getOrCreateVPBB(L->getLoopPreheader()); - VPBlockUtils::disconnectBlocks(PreheaderVPBB, HeaderVPBB); - VPBlockUtils::disconnectBlocks(ExitingVPBB, HeaderVPBB); - - Region->setParent(PreheaderVPBB->getParent()); - Region->setEntry(HeaderVPBB); - VPBlockUtils::connectBlocks(PreheaderVPBB, Region); - - // Disconnect exit block from exiting (=latch) block, set exiting block and - // connect region to exit block. - VPBasicBlock *ExitVPBB = getOrCreateVPBB(L->getExitBlock()); - VPBlockUtils::disconnectBlocks(ExitingVPBB, ExitVPBB); - Region->setExiting(ExitingVPBB); - VPBlockUtils::connectBlocks(Region, ExitVPBB); - - // Queue sub-loops for processing. - LoopWorkList.append(L->begin(), L->end()); - } - // 4. The whole CFG has been built at this point so all the input Values must + // 2. 
The whole CFG has been built at this point so all the input Values must // have a VPlan counterpart. Fix VPlan phi nodes by adding their corresponding // VPlan operands. fixPhiNodes(); diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 26c309eed800..c23428e2ba34 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "VPlan.h" +#include "VPlanAnalysis.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" @@ -114,6 +115,16 @@ bool VPRecipeBase::mayHaveSideEffects() const { case VPDerivedIVSC: case VPPredInstPHISC: return false; + case VPInstructionSC: + switch (cast<VPInstruction>(this)->getOpcode()) { + case Instruction::ICmp: + case VPInstruction::Not: + case VPInstruction::CalculateTripCountMinusVF: + case VPInstruction::CanonicalIVIncrementForPart: + return false; + default: + return true; + } case VPWidenCallSC: return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) ->mayHaveSideEffects(); @@ -156,8 +167,13 @@ void VPLiveOut::fixPhi(VPlan &Plan, VPTransformState &State) { VPValue *ExitValue = getOperand(0); if (vputils::isUniformAfterVectorization(ExitValue)) Lane = VPLane::getFirstLane(); + VPBasicBlock *MiddleVPBB = + cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor()); + assert(MiddleVPBB->getNumSuccessors() == 0 && + "the middle block must not have any successors"); + BasicBlock *MiddleBB = State.CFG.VPBB2IRBB[MiddleVPBB]; Phi->addIncoming(State.get(ExitValue, VPIteration(State.UF - 1, Lane)), - State.Builder.GetInsertBlock()); + MiddleBB); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -216,15 +232,55 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB, insertBefore(BB, I); } +FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const { + assert(OpType == 
OperationType::FPMathOp && + "recipe doesn't have fast math flags"); + FastMathFlags Res; + Res.setAllowReassoc(FMFs.AllowReassoc); + Res.setNoNaNs(FMFs.NoNaNs); + Res.setNoInfs(FMFs.NoInfs); + Res.setNoSignedZeros(FMFs.NoSignedZeros); + Res.setAllowReciprocal(FMFs.AllowReciprocal); + Res.setAllowContract(FMFs.AllowContract); + Res.setApproxFunc(FMFs.ApproxFunc); + return Res; +} + +VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, + VPValue *A, VPValue *B, DebugLoc DL, + const Twine &Name) + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}), + Pred, DL), + VPValue(this), Opcode(Opcode), Name(Name.str()) { + assert(Opcode == Instruction::ICmp && + "only ICmp predicates supported at the moment"); +} + +VPInstruction::VPInstruction(unsigned Opcode, + std::initializer_list<VPValue *> Operands, + FastMathFlags FMFs, DebugLoc DL, const Twine &Name) + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs, DL), + VPValue(this), Opcode(Opcode), Name(Name.str()) { + // Make sure the VPInstruction is a floating-point operation. 
+ assert(isFPMathOp() && "this op can't take fast-math flags"); +} + Value *VPInstruction::generateInstruction(VPTransformState &State, unsigned Part) { IRBuilderBase &Builder = State.Builder; - Builder.SetCurrentDebugLocation(DL); + Builder.SetCurrentDebugLocation(getDebugLoc()); if (Instruction::isBinaryOp(getOpcode())) { + if (Part != 0 && vputils::onlyFirstPartUsed(this)) + return State.get(this, 0); + Value *A = State.get(getOperand(0), Part); Value *B = State.get(getOperand(1), Part); - return Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name); + auto *Res = + Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name); + if (auto *I = dyn_cast<Instruction>(Res)) + setFlags(I); + return Res; } switch (getOpcode()) { @@ -232,10 +288,10 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, Value *A = State.get(getOperand(0), Part); return Builder.CreateNot(A, Name); } - case VPInstruction::ICmpULE: { - Value *IV = State.get(getOperand(0), Part); - Value *TC = State.get(getOperand(1), Part); - return Builder.CreateICmpULE(IV, TC, Name); + case Instruction::ICmp: { + Value *A = State.get(getOperand(0), Part); + Value *B = State.get(getOperand(1), Part); + return Builder.CreateCmp(getPredicate(), A, B, Name); } case Instruction::Select: { Value *Cond = State.get(getOperand(0), Part); @@ -285,23 +341,7 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, Value *Zero = ConstantInt::get(ScalarTC->getType(), 0); return Builder.CreateSelect(Cmp, Sub, Zero); } - case VPInstruction::CanonicalIVIncrement: - case VPInstruction::CanonicalIVIncrementNUW: { - if (Part == 0) { - bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW; - auto *Phi = State.get(getOperand(0), 0); - // The loop step is equal to the vectorization factor (num of SIMD - // elements) times the unroll factor (num of SIMD instructions). 
- Value *Step = - createStepForVF(Builder, Phi->getType(), State.VF, State.UF); - return Builder.CreateAdd(Phi, Step, Name, IsNUW, false); - } - return State.get(this, 0); - } - - case VPInstruction::CanonicalIVIncrementForPart: - case VPInstruction::CanonicalIVIncrementForPartNUW: { - bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementForPartNUW; + case VPInstruction::CanonicalIVIncrementForPart: { auto *IV = State.get(getOperand(0), VPIteration(0, 0)); if (Part == 0) return IV; @@ -309,7 +349,8 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, // The canonical IV is incremented by the vectorization factor (num of SIMD // elements) times the unroll part. Value *Step = createStepForVF(Builder, IV->getType(), State.VF, Part); - return Builder.CreateAdd(IV, Step, Name, IsNUW, false); + return Builder.CreateAdd(IV, Step, Name, hasNoUnsignedWrap(), + hasNoSignedWrap()); } case VPInstruction::BranchOnCond: { if (Part != 0) @@ -361,10 +402,25 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, } } +#if !defined(NDEBUG) +bool VPInstruction::isFPMathOp() const { + // Inspired by FPMathOperator::classof. Notable differences are that we don't + // support Call, PHI and Select opcodes here yet. 
+ return Opcode == Instruction::FAdd || Opcode == Instruction::FMul || + Opcode == Instruction::FNeg || Opcode == Instruction::FSub || + Opcode == Instruction::FDiv || Opcode == Instruction::FRem || + Opcode == Instruction::FCmp || Opcode == Instruction::Select; +} +#endif + void VPInstruction::execute(VPTransformState &State) { assert(!State.Instance && "VPInstruction executing an Instance"); IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(FMF); + assert((hasFastMathFlags() == isFPMathOp() || + getOpcode() == Instruction::Select) && + "Recipe not a FPMathOp but has fast-math flags?"); + if (hasFastMathFlags()) + State.Builder.setFastMathFlags(getFastMathFlags()); for (unsigned Part = 0; Part < State.UF; ++Part) { Value *GeneratedValue = generateInstruction(State, Part); if (!hasResult()) @@ -393,9 +449,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, case VPInstruction::Not: O << "not"; break; - case VPInstruction::ICmpULE: - O << "icmp ule"; - break; case VPInstruction::SLPLoad: O << "combined load"; break; @@ -408,12 +461,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, case VPInstruction::FirstOrderRecurrenceSplice: O << "first-order splice"; break; - case VPInstruction::CanonicalIVIncrement: - O << "VF * UF + "; - break; - case VPInstruction::CanonicalIVIncrementNUW: - O << "VF * UF +(nuw) "; - break; case VPInstruction::BranchOnCond: O << "branch-on-cond"; break; @@ -421,49 +468,35 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, O << "TC > VF ? 
TC - VF : 0"; break; case VPInstruction::CanonicalIVIncrementForPart: - O << "VF * Part + "; - break; - case VPInstruction::CanonicalIVIncrementForPartNUW: - O << "VF * Part +(nuw) "; + O << "VF * Part +"; break; case VPInstruction::BranchOnCount: - O << "branch-on-count "; + O << "branch-on-count"; break; default: O << Instruction::getOpcodeName(getOpcode()); } - O << FMF; - - for (const VPValue *Operand : operands()) { - O << " "; - Operand->printAsOperand(O, SlotTracker); - } + printFlags(O); + printOperands(O, SlotTracker); - if (DL) { + if (auto DL = getDebugLoc()) { O << ", !dbg "; DL.print(O); } } #endif -void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { - // Make sure the VPInstruction is a floating-point operation. - assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul || - Opcode == Instruction::FNeg || Opcode == Instruction::FSub || - Opcode == Instruction::FDiv || Opcode == Instruction::FRem || - Opcode == Instruction::FCmp) && - "this op can't take fast-math flags"); - FMF = FMFNew; -} - void VPWidenCallRecipe::execute(VPTransformState &State) { assert(State.VF.isVector() && "not widening"); auto &CI = *cast<CallInst>(getUnderlyingInstr()); assert(!isa<DbgInfoIntrinsic>(CI) && "DbgInfoIntrinsic should have been dropped during VPlan construction"); - State.setDebugLocFromInst(&CI); + State.setDebugLocFrom(CI.getDebugLoc()); + FunctionType *VFTy = nullptr; + if (Variant) + VFTy = Variant->getFunctionType(); for (unsigned Part = 0; Part < State.UF; ++Part) { SmallVector<Type *, 2> TysForDecl; // Add return type if intrinsic is overloaded on it. @@ -475,12 +508,15 @@ void VPWidenCallRecipe::execute(VPTransformState &State) { for (const auto &I : enumerate(operands())) { // Some intrinsics have a scalar argument - don't replace it with a // vector. + // Some vectorized function variants may also take a scalar argument, + // e.g. linear parameters for pointers. 
Value *Arg; - if (VectorIntrinsicID == Intrinsic::not_intrinsic || - !isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index())) - Arg = State.get(I.value(), Part); - else + if ((VFTy && !VFTy->getParamType(I.index())->isVectorTy()) || + (VectorIntrinsicID != Intrinsic::not_intrinsic && + isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index()))) Arg = State.get(I.value(), VPIteration(0, 0)); + else + Arg = State.get(I.value(), Part); if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index())) TysForDecl.push_back(Arg->getType()); Args.push_back(Arg); @@ -553,8 +589,7 @@ void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent, #endif void VPWidenSelectRecipe::execute(VPTransformState &State) { - auto &I = *cast<SelectInst>(getUnderlyingInstr()); - State.setDebugLocFromInst(&I); + State.setDebugLocFrom(getDebugLoc()); // The condition can be loop invariant but still defined inside the // loop. This means that we can't just use the original 'cond' value. 
@@ -569,13 +604,31 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) { Value *Op1 = State.get(getOperand(2), Part); Value *Sel = State.Builder.CreateSelect(Cond, Op0, Op1); State.set(this, Sel, Part); - State.addMetadata(Sel, &I); + State.addMetadata(Sel, dyn_cast_or_null<Instruction>(getUnderlyingValue())); } } +VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy( + const FastMathFlags &FMF) { + AllowReassoc = FMF.allowReassoc(); + NoNaNs = FMF.noNaNs(); + NoInfs = FMF.noInfs(); + NoSignedZeros = FMF.noSignedZeros(); + AllowReciprocal = FMF.allowReciprocal(); + AllowContract = FMF.allowContract(); + ApproxFunc = FMF.approxFunc(); +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const { switch (OpType) { + case OperationType::Cmp: + O << " " << CmpInst::getPredicateName(getPredicate()); + break; + case OperationType::DisjointOp: + if (DisjointFlags.IsDisjoint) + O << " disjoint"; + break; case OperationType::PossiblyExactOp: if (ExactFlags.IsExact) O << " exact"; @@ -593,17 +646,22 @@ void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const { if (GEPFlags.IsInBounds) O << " inbounds"; break; + case OperationType::NonNegOp: + if (NonNegFlags.NonNeg) + O << " nneg"; + break; case OperationType::Other: break; } - O << " "; + if (getNumOperands() > 0) + O << " "; } #endif void VPWidenRecipe::execute(VPTransformState &State) { - auto &I = *cast<Instruction>(getUnderlyingValue()); + State.setDebugLocFrom(getDebugLoc()); auto &Builder = State.Builder; - switch (I.getOpcode()) { + switch (Opcode) { case Instruction::Call: case Instruction::Br: case Instruction::PHI: @@ -630,28 +688,24 @@ void VPWidenRecipe::execute(VPTransformState &State) { case Instruction::Or: case Instruction::Xor: { // Just widen unops and binops. 
- State.setDebugLocFromInst(&I); - for (unsigned Part = 0; Part < State.UF; ++Part) { SmallVector<Value *, 2> Ops; for (VPValue *VPOp : operands()) Ops.push_back(State.get(VPOp, Part)); - Value *V = Builder.CreateNAryOp(I.getOpcode(), Ops); + Value *V = Builder.CreateNAryOp(Opcode, Ops); if (auto *VecOp = dyn_cast<Instruction>(V)) setFlags(VecOp); // Use this vector value for all users of the original instruction. State.set(this, V, Part); - State.addMetadata(V, &I); + State.addMetadata(V, dyn_cast_or_null<Instruction>(getUnderlyingValue())); } break; } case Instruction::Freeze: { - State.setDebugLocFromInst(&I); - for (unsigned Part = 0; Part < State.UF; ++Part) { Value *Op = State.get(getOperand(0), Part); @@ -663,9 +717,7 @@ void VPWidenRecipe::execute(VPTransformState &State) { case Instruction::ICmp: case Instruction::FCmp: { // Widen compares. Generate vector compares. - bool FCmp = (I.getOpcode() == Instruction::FCmp); - auto *Cmp = cast<CmpInst>(&I); - State.setDebugLocFromInst(Cmp); + bool FCmp = Opcode == Instruction::FCmp; for (unsigned Part = 0; Part < State.UF; ++Part) { Value *A = State.get(getOperand(0), Part); Value *B = State.get(getOperand(1), Part); @@ -673,51 +725,64 @@ void VPWidenRecipe::execute(VPTransformState &State) { if (FCmp) { // Propagate fast math flags. IRBuilder<>::FastMathFlagGuard FMFG(Builder); - Builder.setFastMathFlags(Cmp->getFastMathFlags()); - C = Builder.CreateFCmp(Cmp->getPredicate(), A, B); + if (auto *I = dyn_cast_or_null<Instruction>(getUnderlyingValue())) + Builder.setFastMathFlags(I->getFastMathFlags()); + C = Builder.CreateFCmp(getPredicate(), A, B); } else { - C = Builder.CreateICmp(Cmp->getPredicate(), A, B); + C = Builder.CreateICmp(getPredicate(), A, B); } State.set(this, C, Part); - State.addMetadata(C, &I); + State.addMetadata(C, dyn_cast_or_null<Instruction>(getUnderlyingValue())); } break; } default: // This instruction is not vectorized by simple widening. 
- LLVM_DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I); + LLVM_DEBUG(dbgs() << "LV: Found an unhandled opcode : " + << Instruction::getOpcodeName(Opcode)); llvm_unreachable("Unhandled instruction!"); } // end of switch. + +#if !defined(NDEBUG) + // Verify that VPlan type inference results agree with the type of the + // generated values. + for (unsigned Part = 0; Part < State.UF; ++Part) { + assert(VectorType::get(State.TypeAnalysis.inferScalarType(this), + State.VF) == State.get(this, Part)->getType() && + "inferred type and type from generated instructions do not match"); + } +#endif } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "WIDEN "; printAsOperand(O, SlotTracker); - const Instruction *UI = getUnderlyingInstr(); - O << " = " << UI->getOpcodeName(); + O << " = " << Instruction::getOpcodeName(Opcode); printFlags(O); - if (auto *Cmp = dyn_cast<CmpInst>(UI)) - O << Cmp->getPredicate() << " "; printOperands(O, SlotTracker); } #endif void VPWidenCastRecipe::execute(VPTransformState &State) { - auto *I = cast_or_null<Instruction>(getUnderlyingValue()); - if (I) - State.setDebugLocFromInst(I); + State.setDebugLocFrom(getDebugLoc()); auto &Builder = State.Builder; /// Vectorize casts. assert(State.VF.isVector() && "Not vectorizing?"); Type *DestTy = VectorType::get(getResultType(), State.VF); - + VPValue *Op = getOperand(0); for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *A = State.get(getOperand(0), Part); + if (Part > 0 && Op->isLiveIn()) { + // FIXME: Remove once explicit unrolling is implemented using VPlan. 
+ State.set(this, State.get(this, 0), Part); + continue; + } + Value *A = State.get(Op, Part); Value *Cast = Builder.CreateCast(Instruction::CastOps(Opcode), A, DestTy); State.set(this, Cast, Part); - State.addMetadata(Cast, I); + State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue())); } } @@ -727,10 +792,182 @@ void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent, O << Indent << "WIDEN-CAST "; printAsOperand(O, SlotTracker); O << " = " << Instruction::getOpcodeName(Opcode) << " "; + printFlags(O); printOperands(O, SlotTracker); O << " to " << *getResultType(); } +#endif + +/// This function adds +/// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...) +/// to each vector element of Val. The sequence starts at StartIndex. +/// \p Opcode is relevant for FP induction variable. +static Value *getStepVector(Value *Val, Value *StartIdx, Value *Step, + Instruction::BinaryOps BinOp, ElementCount VF, + IRBuilderBase &Builder) { + assert(VF.isVector() && "only vector VFs are supported"); + + // Create and check the types. + auto *ValVTy = cast<VectorType>(Val->getType()); + ElementCount VLen = ValVTy->getElementCount(); + Type *STy = Val->getType()->getScalarType(); + assert((STy->isIntegerTy() || STy->isFloatingPointTy()) && + "Induction Step must be an integer or FP"); + assert(Step->getType() == STy && "Step has wrong type"); + + SmallVector<Constant *, 8> Indices; + + // Create a vector of consecutive numbers from zero to VF. 
+ VectorType *InitVecValVTy = ValVTy; + if (STy->isFloatingPointTy()) { + Type *InitVecValSTy = + IntegerType::get(STy->getContext(), STy->getScalarSizeInBits()); + InitVecValVTy = VectorType::get(InitVecValSTy, VLen); + } + Value *InitVec = Builder.CreateStepVector(InitVecValVTy); + + // Splat the StartIdx + Value *StartIdxSplat = Builder.CreateVectorSplat(VLen, StartIdx); + + if (STy->isIntegerTy()) { + InitVec = Builder.CreateAdd(InitVec, StartIdxSplat); + Step = Builder.CreateVectorSplat(VLen, Step); + assert(Step->getType() == Val->getType() && "Invalid step vec"); + // FIXME: The newly created binary instructions should contain nsw/nuw + // flags, which can be found from the original scalar operations. + Step = Builder.CreateMul(InitVec, Step); + return Builder.CreateAdd(Val, Step, "induction"); + } + + // Floating point induction. + assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) && + "Binary Opcode should be specified for FP induction"); + InitVec = Builder.CreateUIToFP(InitVec, ValVTy); + InitVec = Builder.CreateFAdd(InitVec, StartIdxSplat); + + Step = Builder.CreateVectorSplat(VLen, Step); + Value *MulOp = Builder.CreateFMul(InitVec, Step); + return Builder.CreateBinOp(BinOp, Val, MulOp, "induction"); +} + +/// A helper function that returns an integer or floating-point constant with +/// value C. +static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) { + return Ty->isIntegerTy() ? 
ConstantInt::getSigned(Ty, C) + : ConstantFP::get(Ty, C); +} + +static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy, + ElementCount VF) { + assert(FTy->isFloatingPointTy() && "Expected floating point type!"); + Type *IntTy = IntegerType::get(FTy->getContext(), FTy->getScalarSizeInBits()); + Value *RuntimeVF = getRuntimeVF(B, IntTy, VF); + return B.CreateUIToFP(RuntimeVF, FTy); +} + +void VPWidenIntOrFpInductionRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "Int or FP induction being replicated."); + + Value *Start = getStartValue()->getLiveInIRValue(); + const InductionDescriptor &ID = getInductionDescriptor(); + TruncInst *Trunc = getTruncInst(); + IRBuilderBase &Builder = State.Builder; + assert(IV->getType() == ID.getStartValue()->getType() && "Types must match"); + assert(State.VF.isVector() && "must have vector VF"); + + // The value from the original loop to which we are mapping the new induction + // variable. + Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV; + + // Fast-math-flags propagate from the original induction instruction. + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + if (ID.getInductionBinOp() && isa<FPMathOperator>(ID.getInductionBinOp())) + Builder.setFastMathFlags(ID.getInductionBinOp()->getFastMathFlags()); + + // Now do the actual transformations, and start with fetching the step value. 
+ Value *Step = State.get(getStepValue(), VPIteration(0, 0)); + + assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) && + "Expected either an induction phi-node or a truncate of it!"); + + // Construct the initial value of the vector IV in the vector loop preheader + auto CurrIP = Builder.saveIP(); + BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); + Builder.SetInsertPoint(VectorPH->getTerminator()); + if (isa<TruncInst>(EntryVal)) { + assert(Start->getType()->isIntegerTy() && + "Truncation requires an integer type"); + auto *TruncType = cast<IntegerType>(EntryVal->getType()); + Step = Builder.CreateTrunc(Step, TruncType); + Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); + } + + Value *Zero = getSignedIntOrFpConstant(Start->getType(), 0); + Value *SplatStart = Builder.CreateVectorSplat(State.VF, Start); + Value *SteppedStart = getStepVector( + SplatStart, Zero, Step, ID.getInductionOpcode(), State.VF, State.Builder); + + // We create vector phi nodes for both integer and floating-point induction + // variables. Here, we determine the kind of arithmetic we will perform. + Instruction::BinaryOps AddOp; + Instruction::BinaryOps MulOp; + if (Step->getType()->isIntegerTy()) { + AddOp = Instruction::Add; + MulOp = Instruction::Mul; + } else { + AddOp = ID.getInductionOpcode(); + MulOp = Instruction::FMul; + } + + // Multiply the vectorization factor by the step using integer or + // floating-point arithmetic as appropriate. + Type *StepType = Step->getType(); + Value *RuntimeVF; + if (Step->getType()->isFloatingPointTy()) + RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, State.VF); + else + RuntimeVF = getRuntimeVF(Builder, StepType, State.VF); + Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF); + + // Create a vector splat to use in the induction update. + // + // FIXME: If the step is non-constant, we create the vector splat with + // IRBuilder. 
IRBuilder can constant-fold the multiply, but it doesn't + // handle a constant vector splat. + Value *SplatVF = isa<Constant>(Mul) + ? ConstantVector::getSplat(State.VF, cast<Constant>(Mul)) + : Builder.CreateVectorSplat(State.VF, Mul); + Builder.restoreIP(CurrIP); + + // We may need to add the step a number of times, depending on the unroll + // factor. The last of those goes into the PHI. + PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind"); + VecInd->insertBefore(State.CFG.PrevBB->getFirstInsertionPt()); + VecInd->setDebugLoc(EntryVal->getDebugLoc()); + Instruction *LastInduction = VecInd; + for (unsigned Part = 0; Part < State.UF; ++Part) { + State.set(this, LastInduction, Part); + + if (isa<TruncInst>(EntryVal)) + State.addMetadata(LastInduction, EntryVal); + + LastInduction = cast<Instruction>( + Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add")); + LastInduction->setDebugLoc(EntryVal->getDebugLoc()); + } + + LastInduction->setName("vec.ind.next"); + VecInd->addIncoming(SteppedStart, VectorPH); + // Add induction update using an incorrect block temporarily. The phi node + // will be fixed after VPlan execution. Note that at this point the latch + // block cannot be used, as it does not exist yet. + // TODO: Model increment value in VPlan, by turning the recipe into a + // multi-def and a subclass of VPHeaderPHIRecipe. 
+ VecInd->addIncoming(LastInduction, VectorPH); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "WIDEN-INDUCTION"; @@ -770,17 +1007,112 @@ void VPDerivedIVRecipe::print(raw_ostream &O, const Twine &Indent, O << " * "; getStepValue()->printAsOperand(O, SlotTracker); - if (IndDesc.getStep()->getType() != ResultTy) - O << " (truncated to " << *ResultTy << ")"; + if (TruncResultTy) + O << " (truncated to " << *TruncResultTy << ")"; } #endif +void VPScalarIVStepsRecipe::execute(VPTransformState &State) { + // Fast-math-flags propagate from the original induction instruction. + IRBuilder<>::FastMathFlagGuard FMFG(State.Builder); + if (hasFastMathFlags()) + State.Builder.setFastMathFlags(getFastMathFlags()); + + /// Compute scalar induction steps. \p ScalarIV is the scalar induction + /// variable on which to base the steps, \p Step is the size of the step. + + Value *BaseIV = State.get(getOperand(0), VPIteration(0, 0)); + Value *Step = State.get(getStepValue(), VPIteration(0, 0)); + IRBuilderBase &Builder = State.Builder; + + // Ensure step has the same type as that of scalar IV. + Type *BaseIVTy = BaseIV->getType()->getScalarType(); + if (BaseIVTy != Step->getType()) { + // TODO: Also use VPDerivedIVRecipe when only the step needs truncating, to + // avoid separate truncate here. + assert(Step->getType()->isIntegerTy() && + "Truncation requires an integer step"); + Step = State.Builder.CreateTrunc(Step, BaseIVTy); + } + + // We build scalar steps for both integer and floating-point induction + // variables. Here, we determine the kind of arithmetic we will perform. 
+ Instruction::BinaryOps AddOp; + Instruction::BinaryOps MulOp; + if (BaseIVTy->isIntegerTy()) { + AddOp = Instruction::Add; + MulOp = Instruction::Mul; + } else { + AddOp = InductionOpcode; + MulOp = Instruction::FMul; + } + + // Determine the number of scalars we need to generate for each unroll + // iteration. + bool FirstLaneOnly = vputils::onlyFirstLaneUsed(this); + // Compute the scalar steps and save the results in State. + Type *IntStepTy = + IntegerType::get(BaseIVTy->getContext(), BaseIVTy->getScalarSizeInBits()); + Type *VecIVTy = nullptr; + Value *UnitStepVec = nullptr, *SplatStep = nullptr, *SplatIV = nullptr; + if (!FirstLaneOnly && State.VF.isScalable()) { + VecIVTy = VectorType::get(BaseIVTy, State.VF); + UnitStepVec = + Builder.CreateStepVector(VectorType::get(IntStepTy, State.VF)); + SplatStep = Builder.CreateVectorSplat(State.VF, Step); + SplatIV = Builder.CreateVectorSplat(State.VF, BaseIV); + } + + unsigned StartPart = 0; + unsigned EndPart = State.UF; + unsigned StartLane = 0; + unsigned EndLane = FirstLaneOnly ? 1 : State.VF.getKnownMinValue(); + if (State.Instance) { + StartPart = State.Instance->Part; + EndPart = StartPart + 1; + StartLane = State.Instance->Lane.getKnownLane(); + EndLane = StartLane + 1; + } + for (unsigned Part = StartPart; Part < EndPart; ++Part) { + Value *StartIdx0 = createStepForVF(Builder, IntStepTy, State.VF, Part); + + if (!FirstLaneOnly && State.VF.isScalable()) { + auto *SplatStartIdx = Builder.CreateVectorSplat(State.VF, StartIdx0); + auto *InitVec = Builder.CreateAdd(SplatStartIdx, UnitStepVec); + if (BaseIVTy->isFloatingPointTy()) + InitVec = Builder.CreateSIToFP(InitVec, VecIVTy); + auto *Mul = Builder.CreateBinOp(MulOp, InitVec, SplatStep); + auto *Add = Builder.CreateBinOp(AddOp, SplatIV, Mul); + State.set(this, Add, Part); + // It's useful to record the lane values too for the known minimum number + // of elements so we do those below. 
This improves the code quality when + // trying to extract the first element, for example. + } + + if (BaseIVTy->isFloatingPointTy()) + StartIdx0 = Builder.CreateSIToFP(StartIdx0, BaseIVTy); + + for (unsigned Lane = StartLane; Lane < EndLane; ++Lane) { + Value *StartIdx = Builder.CreateBinOp( + AddOp, StartIdx0, getSignedIntOrFpConstant(BaseIVTy, Lane)); + // The step returned by `createStepForVF` is a runtime-evaluated value + // when VF is scalable. Otherwise, it should be folded into a Constant. + assert((State.VF.isScalable() || isa<Constant>(StartIdx)) && + "Expected StartIdx to be folded to a constant when VF is not " + "scalable"); + auto *Mul = Builder.CreateBinOp(MulOp, StartIdx, Step); + auto *Add = Builder.CreateBinOp(AddOp, BaseIV, Mul); + State.set(this, Add, VPIteration(Part, Lane)); + } + } +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPScalarIVStepsRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent; printAsOperand(O, SlotTracker); - O << Indent << "= SCALAR-STEPS "; + O << " = SCALAR-STEPS "; printOperands(O, SlotTracker); } #endif @@ -874,7 +1206,7 @@ void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent, #endif void VPBlendRecipe::execute(VPTransformState &State) { - State.setDebugLocFromInst(Phi); + State.setDebugLocFrom(getDebugLoc()); // We know that all PHIs in non-header blocks are converted into // selects, so we don't have to worry about the insertion order and we // can just use the builder. 
@@ -916,7 +1248,7 @@ void VPBlendRecipe::execute(VPTransformState &State) { void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "BLEND "; - Phi->printAsOperand(O, false); + printAsOperand(O, SlotTracker); O << " ="; if (getNumIncomingValues() == 1) { // Not a User of any mask: not really blending, this is a @@ -942,14 +1274,14 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, O << " +"; if (isa<FPMathOperator>(getUnderlyingInstr())) O << getUnderlyingInstr()->getFastMathFlags(); - O << " reduce." << Instruction::getOpcodeName(RdxDesc->getOpcode()) << " ("; + O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " ("; getVecOp()->printAsOperand(O, SlotTracker); if (getCondOp()) { O << ", "; getCondOp()->printAsOperand(O, SlotTracker); } O << ")"; - if (RdxDesc->IntermediateStore) + if (RdxDesc.IntermediateStore) O << " (with final reduction value stored in invariant address sank " "outside of loop)"; } @@ -1093,12 +1425,12 @@ void VPWidenMemoryInstructionRecipe::print(raw_ostream &O, const Twine &Indent, void VPCanonicalIVPHIRecipe::execute(VPTransformState &State) { Value *Start = getStartValue()->getLiveInIRValue(); - PHINode *EntryPart = PHINode::Create( - Start->getType(), 2, "index", &*State.CFG.PrevBB->getFirstInsertionPt()); + PHINode *EntryPart = PHINode::Create(Start->getType(), 2, "index"); + EntryPart->insertBefore(State.CFG.PrevBB->getFirstInsertionPt()); BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); EntryPart->addIncoming(Start, VectorPH); - EntryPart->setDebugLoc(DL); + EntryPart->setDebugLoc(getDebugLoc()); for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) State.set(this, EntryPart, Part); } @@ -1108,7 +1440,8 @@ void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "EMIT "; printAsOperand(O, SlotTracker); - O << " = CANONICAL-INDUCTION"; + O << " = 
CANONICAL-INDUCTION "; + printOperands(O, SlotTracker); } #endif @@ -1221,8 +1554,8 @@ void VPFirstOrderRecurrencePHIRecipe::execute(VPTransformState &State) { } // Create a phi node for the new recurrence. - PHINode *EntryPart = PHINode::Create( - VecTy, 2, "vector.recur", &*State.CFG.PrevBB->getFirstInsertionPt()); + PHINode *EntryPart = PHINode::Create(VecTy, 2, "vector.recur"); + EntryPart->insertBefore(State.CFG.PrevBB->getFirstInsertionPt()); EntryPart->addIncoming(VectorInit, VectorPH); State.set(this, EntryPart, 0); } @@ -1254,8 +1587,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) { "recipe must be in the vector loop header"); unsigned LastPartForNewPhi = isOrdered() ? 1 : State.UF; for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { - Value *EntryPart = - PHINode::Create(VecTy, 2, "vec.phi", &*HeaderBB->getFirstInsertionPt()); + Instruction *EntryPart = PHINode::Create(VecTy, 2, "vec.phi"); + EntryPart->insertBefore(HeaderBB->getFirstInsertionPt()); State.set(this, EntryPart, Part); } @@ -1269,8 +1602,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) { Value *Iden = nullptr; RecurKind RK = RdxDesc.getRecurrenceKind(); if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) || - RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) { - // MinMax reduction have the start value as their identify. + RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) { + // MinMax and AnyOf reductions have the start value as their identity. if (ScalarPHI) { Iden = StartV; } else { @@ -1316,23 +1649,7 @@ void VPWidenPHIRecipe::execute(VPTransformState &State) { assert(EnableVPlanNativePath && "Non-native vplans are not expected to have VPWidenPHIRecipes."); - // Currently we enter here in the VPlan-native path for non-induction - // PHIs where all control flow is uniform. We simply widen these PHIs. - // Create a vector phi with no operands - the vector phi operands will be - // set at the end of vector code generation. 
- VPBasicBlock *Parent = getParent(); - VPRegionBlock *LoopRegion = Parent->getEnclosingLoopRegion(); - unsigned StartIdx = 0; - // For phis in header blocks of loop regions, use the index of the value - // coming from the preheader. - if (LoopRegion->getEntryBasicBlock() == Parent) { - for (unsigned I = 0; I < getNumOperands(); ++I) { - if (getIncomingBlock(I) == - LoopRegion->getSinglePredecessor()->getExitingBasicBlock()) - StartIdx = I; - } - } - Value *Op0 = State.get(getOperand(StartIdx), 0); + Value *Op0 = State.get(getOperand(0), 0); Type *VecTy = Op0->getType(); Value *VecPhi = State.Builder.CreatePHI(VecTy, 2, "vec.phi"); State.set(this, VecPhi, 0); @@ -1368,7 +1685,7 @@ void VPActiveLaneMaskPHIRecipe::execute(VPTransformState &State) { PHINode *EntryPart = State.Builder.CreatePHI(StartMask->getType(), 2, "active.lane.mask"); EntryPart->addIncoming(StartMask, VectorPH); - EntryPart->setDebugLoc(DL); + EntryPart->setDebugLoc(getDebugLoc()); State.set(this, EntryPart, Part); } } diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 83bfdfd09d19..ea90ed4a21b1 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -12,17 +12,22 @@ //===----------------------------------------------------------------------===// #include "VPlanTransforms.h" -#include "VPlanDominatorTree.h" #include "VPRecipeBuilder.h" +#include "VPlanAnalysis.h" #include "VPlanCFG.h" +#include "VPlanDominatorTree.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/IVDescriptors.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PatternMatch.h" using namespace llvm; +using namespace llvm::PatternMatch; + void VPlanTransforms::VPInstructionsToVPRecipes( VPlanPtr &Plan, function_ref<const InductionDescriptor *(PHINode *)> @@ -76,7 +81,7 @@ void 
VPlanTransforms::VPInstructionsToVPRecipes( NewRecipe = new VPWidenSelectRecipe(*SI, Ingredient.operands()); } else if (auto *CI = dyn_cast<CastInst>(Inst)) { NewRecipe = new VPWidenCastRecipe( - CI->getOpcode(), Ingredient.getOperand(0), CI->getType(), CI); + CI->getOpcode(), Ingredient.getOperand(0), CI->getType(), *CI); } else { NewRecipe = new VPWidenRecipe(*Inst, Ingredient.operands()); } @@ -158,17 +163,10 @@ static bool sinkScalarOperands(VPlan &Plan) { // TODO: add ".cloned" suffix to name of Clone's VPValue. Clone->insertBefore(SinkCandidate); - for (auto *U : to_vector(SinkCandidate->getVPSingleValue()->users())) { - auto *UI = cast<VPRecipeBase>(U); - if (UI->getParent() == SinkTo) - continue; - - for (unsigned Idx = 0; Idx != UI->getNumOperands(); Idx++) { - if (UI->getOperand(Idx) != SinkCandidate->getVPSingleValue()) - continue; - UI->setOperand(Idx, Clone); - } - } + SinkCandidate->getVPSingleValue()->replaceUsesWithIf( + Clone, [SinkTo](VPUser &U, unsigned) { + return cast<VPRecipeBase>(&U)->getParent() != SinkTo; + }); } SinkCandidate->moveBefore(*SinkTo, SinkTo->getFirstNonPhi()); for (VPValue *Op : SinkCandidate->operands()) @@ -273,16 +271,10 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { VPValue *PredInst1 = cast<VPPredInstPHIRecipe>(&Phi1ToMove)->getOperand(0); VPValue *Phi1ToMoveV = Phi1ToMove.getVPSingleValue(); - for (VPUser *U : to_vector(Phi1ToMoveV->users())) { - auto *UI = dyn_cast<VPRecipeBase>(U); - if (!UI || UI->getParent() != Then2) - continue; - for (unsigned I = 0, E = U->getNumOperands(); I != E; ++I) { - if (Phi1ToMoveV != U->getOperand(I)) - continue; - U->setOperand(I, PredInst1); - } - } + Phi1ToMoveV->replaceUsesWithIf(PredInst1, [Then2](VPUser &U, unsigned) { + auto *UI = dyn_cast<VPRecipeBase>(&U); + return UI && UI->getParent() == Then2; + }); Phi1ToMove.moveBefore(*Merge2, Merge2->begin()); } @@ -479,15 +471,45 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) { // The recipes in the block are 
processed in reverse order, to catch chains // of dead recipes. for (VPRecipeBase &R : make_early_inc_range(reverse(*VPBB))) { - if (R.mayHaveSideEffects() || any_of(R.definedValues(), [](VPValue *V) { - return V->getNumUsers() > 0; - })) + // A user keeps R alive: + if (any_of(R.definedValues(), + [](VPValue *V) { return V->getNumUsers(); })) continue; + + // Having side effects keeps R alive, but do remove conditional assume + // instructions as their conditions may be flattened. + auto *RepR = dyn_cast<VPReplicateRecipe>(&R); + bool IsConditionalAssume = + RepR && RepR->isPredicated() && + match(RepR->getUnderlyingInstr(), m_Intrinsic<Intrinsic::assume>()); + if (R.mayHaveSideEffects() && !IsConditionalAssume) + continue; + R.eraseFromParent(); } } } +static VPValue *createScalarIVSteps(VPlan &Plan, const InductionDescriptor &ID, + ScalarEvolution &SE, Instruction *TruncI, + Type *IVTy, VPValue *StartV, + VPValue *Step) { + VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); + auto IP = HeaderVPBB->getFirstNonPhi(); + VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV(); + Type *TruncTy = TruncI ? TruncI->getType() : IVTy; + VPValue *BaseIV = CanonicalIV; + if (!CanonicalIV->isCanonical(ID.getKind(), StartV, Step, TruncTy)) { + BaseIV = new VPDerivedIVRecipe(ID, StartV, CanonicalIV, Step, + TruncI ? 
TruncI->getType() : nullptr); + HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP); + } + + VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe(ID, BaseIV, Step); + HeaderVPBB->insert(Steps, IP); + return Steps; +} + void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { SmallVector<VPRecipeBase *> ToRemove; VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); @@ -501,36 +523,17 @@ void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { })) continue; - auto IP = HeaderVPBB->getFirstNonPhi(); - VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV(); - Type *ResultTy = WideIV->getPHINode()->getType(); - if (Instruction *TruncI = WideIV->getTruncInst()) - ResultTy = TruncI->getType(); const InductionDescriptor &ID = WideIV->getInductionDescriptor(); - VPValue *Step = WideIV->getStepValue(); - VPValue *BaseIV = CanonicalIV; - if (!CanonicalIV->isCanonical(ID.getKind(), WideIV->getStartValue(), Step, - ResultTy)) { - BaseIV = new VPDerivedIVRecipe(ID, WideIV->getStartValue(), CanonicalIV, - Step, ResultTy); - HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP); - } - - VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe(ID, BaseIV, Step); - HeaderVPBB->insert(Steps, IP); + VPValue *Steps = createScalarIVSteps( + Plan, ID, SE, WideIV->getTruncInst(), WideIV->getPHINode()->getType(), + WideIV->getStartValue(), WideIV->getStepValue()); // Update scalar users of IV to use Step instead. Use SetVector to ensure // the list of users doesn't contain duplicates. 
- SetVector<VPUser *> Users(WideIV->user_begin(), WideIV->user_end()); - for (VPUser *U : Users) { - if (HasOnlyVectorVFs && !U->usesScalars(WideIV)) - continue; - for (unsigned I = 0, E = U->getNumOperands(); I != E; I++) { - if (U->getOperand(I) != WideIV) - continue; - U->setOperand(I, Steps); - } - } + WideIV->replaceUsesWithIf( + Steps, [HasOnlyVectorVFs, WideIV](VPUser &U, unsigned) { + return !HasOnlyVectorVFs || U.usesScalars(WideIV); + }); } } @@ -778,3 +781,375 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) { } } } + +/// Returns true is \p V is constant one. +static bool isConstantOne(VPValue *V) { + if (!V->isLiveIn()) + return false; + auto *C = dyn_cast<ConstantInt>(V->getLiveInIRValue()); + return C && C->isOne(); +} + +/// Returns the llvm::Instruction opcode for \p R. +static unsigned getOpcodeForRecipe(VPRecipeBase &R) { + if (auto *WidenR = dyn_cast<VPWidenRecipe>(&R)) + return WidenR->getUnderlyingInstr()->getOpcode(); + if (auto *WidenC = dyn_cast<VPWidenCastRecipe>(&R)) + return WidenC->getOpcode(); + if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R)) + return RepR->getUnderlyingInstr()->getOpcode(); + if (auto *VPI = dyn_cast<VPInstruction>(&R)) + return VPI->getOpcode(); + return 0; +} + +/// Try to simplify recipe \p R. 
+static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { + switch (getOpcodeForRecipe(R)) { + case Instruction::Mul: { + VPValue *A = R.getOperand(0); + VPValue *B = R.getOperand(1); + if (isConstantOne(A)) + return R.getVPSingleValue()->replaceAllUsesWith(B); + if (isConstantOne(B)) + return R.getVPSingleValue()->replaceAllUsesWith(A); + break; + } + case Instruction::Trunc: { + VPRecipeBase *Ext = R.getOperand(0)->getDefiningRecipe(); + if (!Ext) + break; + unsigned ExtOpcode = getOpcodeForRecipe(*Ext); + if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) + break; + VPValue *A = Ext->getOperand(0); + VPValue *Trunc = R.getVPSingleValue(); + Type *TruncTy = TypeInfo.inferScalarType(Trunc); + Type *ATy = TypeInfo.inferScalarType(A); + if (TruncTy == ATy) { + Trunc->replaceAllUsesWith(A); + } else if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) { + auto *VPC = + new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy); + VPC->insertBefore(&R); + Trunc->replaceAllUsesWith(VPC); + } else if (ATy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits()) { + auto *VPC = new VPWidenCastRecipe(Instruction::Trunc, A, TruncTy); + VPC->insertBefore(&R); + Trunc->replaceAllUsesWith(VPC); + } +#ifndef NDEBUG + // Verify that the cached type info is for both A and its users is still + // accurate by comparing it to freshly computed types. + VPTypeAnalysis TypeInfo2(TypeInfo.getContext()); + assert(TypeInfo.inferScalarType(A) == TypeInfo2.inferScalarType(A)); + for (VPUser *U : A->users()) { + auto *R = dyn_cast<VPRecipeBase>(U); + if (!R) + continue; + for (VPValue *VPV : R->definedValues()) + assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV)); + } +#endif + break; + } + default: + break; + } +} + +/// Try to simplify the recipes in \p Plan. 
+static void simplifyRecipes(VPlan &Plan, LLVMContext &Ctx) { + ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( + Plan.getEntry()); + VPTypeAnalysis TypeInfo(Ctx); + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + simplifyRecipe(R, TypeInfo); + } + } +} + +void VPlanTransforms::truncateToMinimalBitwidths( + VPlan &Plan, const MapVector<Instruction *, uint64_t> &MinBWs, + LLVMContext &Ctx) { +#ifndef NDEBUG + // Count the processed recipes and cross check the count later with MinBWs + // size, to make sure all entries in MinBWs have been handled. + unsigned NumProcessedRecipes = 0; +#endif + // Keep track of created truncates, so they can be re-used. Note that we + // cannot use RAUW after creating a new truncate, as this would could make + // other uses have different types for their operands, making them invalidly + // typed. + DenseMap<VPValue *, VPWidenCastRecipe *> ProcessedTruncs; + VPTypeAnalysis TypeInfo(Ctx); + VPBasicBlock *PH = Plan.getEntry(); + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( + vp_depth_first_deep(Plan.getVectorLoopRegion()))) { + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + if (!isa<VPWidenRecipe, VPWidenCastRecipe, VPReplicateRecipe, + VPWidenSelectRecipe>(&R)) + continue; + + VPValue *ResultVPV = R.getVPSingleValue(); + auto *UI = cast_or_null<Instruction>(ResultVPV->getUnderlyingValue()); + unsigned NewResSizeInBits = MinBWs.lookup(UI); + if (!NewResSizeInBits) + continue; + +#ifndef NDEBUG + NumProcessedRecipes++; +#endif + // If the value wasn't vectorized, we must maintain the original scalar + // type. Skip those here, after incrementing NumProcessedRecipes. Also + // skip casts which do not need to be handled explicitly here, as + // redundant casts will be removed during recipe simplification. 
+ if (isa<VPReplicateRecipe, VPWidenCastRecipe>(&R)) { +#ifndef NDEBUG + // If any of the operands is a live-in and not used by VPWidenRecipe or + // VPWidenSelectRecipe, but in MinBWs, make sure it is counted as + // processed as well. When MinBWs is currently constructed, there is no + // information about whether recipes are widened or replicated and in + // case they are reciplicated the operands are not truncated. Counting + // them them here ensures we do not miss any recipes in MinBWs. + // TODO: Remove once the analysis is done on VPlan. + for (VPValue *Op : R.operands()) { + if (!Op->isLiveIn()) + continue; + auto *UV = dyn_cast_or_null<Instruction>(Op->getUnderlyingValue()); + if (UV && MinBWs.contains(UV) && !ProcessedTruncs.contains(Op) && + all_of(Op->users(), [](VPUser *U) { + return !isa<VPWidenRecipe, VPWidenSelectRecipe>(U); + })) { + // Add an entry to ProcessedTruncs to avoid counting the same + // operand multiple times. + ProcessedTruncs[Op] = nullptr; + NumProcessedRecipes += 1; + } + } +#endif + continue; + } + + Type *OldResTy = TypeInfo.inferScalarType(ResultVPV); + unsigned OldResSizeInBits = OldResTy->getScalarSizeInBits(); + assert(OldResTy->isIntegerTy() && "only integer types supported"); + if (OldResSizeInBits == NewResSizeInBits) + continue; + assert(OldResSizeInBits > NewResSizeInBits && "Nothing to shrink?"); + (void)OldResSizeInBits; + + auto *NewResTy = IntegerType::get(Ctx, NewResSizeInBits); + + // Shrink operands by introducing truncates as needed. + unsigned StartIdx = isa<VPWidenSelectRecipe>(&R) ? 1 : 0; + for (unsigned Idx = StartIdx; Idx != R.getNumOperands(); ++Idx) { + auto *Op = R.getOperand(Idx); + unsigned OpSizeInBits = + TypeInfo.inferScalarType(Op)->getScalarSizeInBits(); + if (OpSizeInBits == NewResSizeInBits) + continue; + assert(OpSizeInBits > NewResSizeInBits && "nothing to truncate"); + auto [ProcessedIter, IterIsEmpty] = + ProcessedTruncs.insert({Op, nullptr}); + VPWidenCastRecipe *NewOp = + IterIsEmpty + ? 
new VPWidenCastRecipe(Instruction::Trunc, Op, NewResTy) + : ProcessedIter->second; + R.setOperand(Idx, NewOp); + if (!IterIsEmpty) + continue; + ProcessedIter->second = NewOp; + if (!Op->isLiveIn()) { + NewOp->insertBefore(&R); + } else { + PH->appendRecipe(NewOp); +#ifndef NDEBUG + auto *OpInst = dyn_cast<Instruction>(Op->getLiveInIRValue()); + bool IsContained = MinBWs.contains(OpInst); + NumProcessedRecipes += IsContained; +#endif + } + } + + // Any wrapping introduced by shrinking this operation shouldn't be + // considered undefined behavior. So, we can't unconditionally copy + // arithmetic wrapping flags to VPW. + if (auto *VPW = dyn_cast<VPRecipeWithIRFlags>(&R)) + VPW->dropPoisonGeneratingFlags(); + + // Extend result to original width. + auto *Ext = new VPWidenCastRecipe(Instruction::ZExt, ResultVPV, OldResTy); + Ext->insertAfter(&R); + ResultVPV->replaceAllUsesWith(Ext); + Ext->setOperand(0, ResultVPV); + } + } + + assert(MinBWs.size() == NumProcessedRecipes && + "some entries in MinBWs haven't been processed"); +} + +void VPlanTransforms::optimize(VPlan &Plan, ScalarEvolution &SE) { + removeRedundantCanonicalIVs(Plan); + removeRedundantInductionCasts(Plan); + + optimizeInductions(Plan, SE); + simplifyRecipes(Plan, SE.getContext()); + removeDeadRecipes(Plan); + + createAndOptimizeReplicateRegions(Plan); + + removeRedundantExpandSCEVRecipes(Plan); + mergeBlocksIntoPredecessors(Plan); +} + +// Add a VPActiveLaneMaskPHIRecipe and related recipes to \p Plan and replace +// the loop terminator with a branch-on-cond recipe with the negated +// active-lane-mask as operand. Note that this turns the loop into an +// uncountable one. Only the existing terminator is replaced, all other existing +// recipes/users remain unchanged, except for poison-generating flags being +// dropped from the canonical IV increment. Return the created +// VPActiveLaneMaskPHIRecipe. 
+// +// The function uses the following definitions: +// +// %TripCount = DataWithControlFlowWithoutRuntimeCheck ? +// calculate-trip-count-minus-VF (original TC) : original TC +// %IncrementValue = DataWithControlFlowWithoutRuntimeCheck ? +// CanonicalIVPhi : CanonicalIVIncrement +// %StartV is the canonical induction start value. +// +// The function adds the following recipes: +// +// vector.ph: +// %TripCount = calculate-trip-count-minus-VF (original TC) +// [if DataWithControlFlowWithoutRuntimeCheck] +// %EntryInc = canonical-iv-increment-for-part %StartV +// %EntryALM = active-lane-mask %EntryInc, %TripCount +// +// vector.body: +// ... +// %P = active-lane-mask-phi [ %EntryALM, %vector.ph ], [ %ALM, %vector.body ] +// ... +// %InLoopInc = canonical-iv-increment-for-part %IncrementValue +// %ALM = active-lane-mask %InLoopInc, TripCount +// %Negated = Not %ALM +// branch-on-cond %Negated +// +static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch( + VPlan &Plan, bool DataAndControlFlowWithoutRuntimeCheck) { + VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); + VPBasicBlock *EB = TopRegion->getExitingBasicBlock(); + auto *CanonicalIVPHI = Plan.getCanonicalIV(); + VPValue *StartV = CanonicalIVPHI->getStartValue(); + + auto *CanonicalIVIncrement = + cast<VPInstruction>(CanonicalIVPHI->getBackedgeValue()); + // TODO: Check if dropping the flags is needed if + // !DataAndControlFlowWithoutRuntimeCheck. + CanonicalIVIncrement->dropPoisonGeneratingFlags(); + DebugLoc DL = CanonicalIVIncrement->getDebugLoc(); + // We can't use StartV directly in the ActiveLaneMask VPInstruction, since + // we have to take unrolling into account. Each part needs to start at + // Part * VF + auto *VecPreheader = cast<VPBasicBlock>(TopRegion->getSinglePredecessor()); + VPBuilder Builder(VecPreheader); + + // Create the ActiveLaneMask instruction using the correct start values. 
+ VPValue *TC = Plan.getTripCount(); + + VPValue *TripCount, *IncrementValue; + if (!DataAndControlFlowWithoutRuntimeCheck) { + // When the loop is guarded by a runtime overflow check for the loop + // induction variable increment by VF, we can increment the value before + // the get.active.lane mask and use the unmodified tripcount. + IncrementValue = CanonicalIVIncrement; + TripCount = TC; + } else { + // When avoiding a runtime check, the active.lane.mask inside the loop + // uses a modified trip count and the induction variable increment is + // done after the active.lane.mask intrinsic is called. + IncrementValue = CanonicalIVPHI; + TripCount = Builder.createNaryOp(VPInstruction::CalculateTripCountMinusVF, + {TC}, DL); + } + auto *EntryIncrement = Builder.createOverflowingOp( + VPInstruction::CanonicalIVIncrementForPart, {StartV}, {false, false}, DL, + "index.part.next"); + + // Create the active lane mask instruction in the VPlan preheader. + auto *EntryALM = + Builder.createNaryOp(VPInstruction::ActiveLaneMask, {EntryIncrement, TC}, + DL, "active.lane.mask.entry"); + + // Now create the ActiveLaneMaskPhi recipe in the main loop using the + // preheader ActiveLaneMask instruction. + auto LaneMaskPhi = new VPActiveLaneMaskPHIRecipe(EntryALM, DebugLoc()); + LaneMaskPhi->insertAfter(CanonicalIVPHI); + + // Create the active lane mask for the next iteration of the loop before the + // original terminator. + VPRecipeBase *OriginalTerminator = EB->getTerminator(); + Builder.setInsertPoint(OriginalTerminator); + auto *InLoopIncrement = + Builder.createOverflowingOp(VPInstruction::CanonicalIVIncrementForPart, + {IncrementValue}, {false, false}, DL); + auto *ALM = Builder.createNaryOp(VPInstruction::ActiveLaneMask, + {InLoopIncrement, TripCount}, DL, + "active.lane.mask.next"); + LaneMaskPhi->addOperand(ALM); + + // Replace the original terminator with BranchOnCond. We have to invert the + // mask here because a true condition means jumping to the exit block. 
+ auto *NotMask = Builder.createNot(ALM, DL); + Builder.createNaryOp(VPInstruction::BranchOnCond, {NotMask}, DL); + OriginalTerminator->eraseFromParent(); + return LaneMaskPhi; +} + +void VPlanTransforms::addActiveLaneMask( + VPlan &Plan, bool UseActiveLaneMaskForControlFlow, + bool DataAndControlFlowWithoutRuntimeCheck) { + assert((!DataAndControlFlowWithoutRuntimeCheck || + UseActiveLaneMaskForControlFlow) && + "DataAndControlFlowWithoutRuntimeCheck implies " + "UseActiveLaneMaskForControlFlow"); + + auto FoundWidenCanonicalIVUser = + find_if(Plan.getCanonicalIV()->users(), + [](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); }); + assert(FoundWidenCanonicalIVUser && + "Must have widened canonical IV when tail folding!"); + auto *WideCanonicalIV = + cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser); + VPRecipeBase *LaneMask; + if (UseActiveLaneMaskForControlFlow) { + LaneMask = addVPLaneMaskPhiAndUpdateExitBranch( + Plan, DataAndControlFlowWithoutRuntimeCheck); + } else { + LaneMask = new VPInstruction(VPInstruction::ActiveLaneMask, + {WideCanonicalIV, Plan.getTripCount()}, + nullptr, "active.lane.mask"); + LaneMask->insertAfter(WideCanonicalIV); + } + + // Walk users of WideCanonicalIV and replace all compares of the form + // (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with an + // active-lane-mask. 
+ VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); + for (VPUser *U : SmallVector<VPUser *>(WideCanonicalIV->users())) { + auto *CompareToReplace = dyn_cast<VPInstruction>(U); + if (!CompareToReplace || + CompareToReplace->getOpcode() != Instruction::ICmp || + CompareToReplace->getPredicate() != CmpInst::ICMP_ULE || + CompareToReplace->getOperand(1) != BTC) + continue; + + assert(CompareToReplace->getOperand(0) == WideCanonicalIV && + "WidenCanonicalIV must be the first operand of the compare"); + CompareToReplace->replaceAllUsesWith(LaneMask->getVPSingleValue()); + CompareToReplace->eraseFromParent(); + } +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 3eccf6e9600d..e8a6da8c3205 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -37,12 +37,56 @@ struct VPlanTransforms { GetIntOrFpInductionDescriptor, ScalarEvolution &SE, const TargetLibraryInfo &TLI); + /// Sink users of fixed-order recurrences after the recipe defining their + /// previous value. Then introduce FirstOrderRecurrenceSplice VPInstructions + /// to combine the value from the recurrence phis and previous values. The + /// current implementation assumes all users can be sunk after the previous + /// value, which is enforced by earlier legality checks. + /// \returns true if all users of fixed-order recurrences could be re-arranged + /// as needed or false if it is not possible. In the latter case, \p Plan is + /// not valid. + static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder); + + /// Clear NSW/NUW flags from reduction instructions if necessary. + static void clearReductionWrapFlags(VPlan &Plan); + + /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the + /// resulting plan to \p BestVF and \p BestUF. 
+ static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, + unsigned BestUF, + PredicatedScalarEvolution &PSE); + + /// Apply VPlan-to-VPlan optimizations to \p Plan, including induction recipe + /// optimizations, dead recipe removal, replicate region optimizations and + /// block merging. + static void optimize(VPlan &Plan, ScalarEvolution &SE); + /// Wrap predicated VPReplicateRecipes with a mask operand in an if-then /// region block and remove the mask operand. Optimize the created regions by /// iteratively sinking scalar operands into the region, followed by merging /// regions until no improvements are remaining. static void createAndOptimizeReplicateRegions(VPlan &Plan); + /// Replace (ICMP_ULE, wide canonical IV, backedge-taken-count) checks with an + /// (active-lane-mask recipe, wide canonical IV, trip-count). If \p + /// UseActiveLaneMaskForControlFlow is true, introduce an + /// VPActiveLaneMaskPHIRecipe. If \p DataAndControlFlowWithoutRuntimeCheck is + /// true, no minimum-iteration runtime check will be created (during skeleton + /// creation) and instead it is handled using active-lane-mask. \p + /// DataAndControlFlowWithoutRuntimeCheck implies \p + /// UseActiveLaneMaskForControlFlow. + static void addActiveLaneMask(VPlan &Plan, + bool UseActiveLaneMaskForControlFlow, + bool DataAndControlFlowWithoutRuntimeCheck); + + /// Insert truncates and extends for any truncated recipe. Redundant casts + /// will be folded later. + static void + truncateToMinimalBitwidths(VPlan &Plan, + const MapVector<Instruction *, uint64_t> &MinBWs, + LLVMContext &Ctx); + +private: /// Remove redundant VPBasicBlocks by merging them into their predecessor if /// the predecessor has a single successor. static bool mergeBlocksIntoPredecessors(VPlan &Plan); @@ -71,24 +115,6 @@ struct VPlanTransforms { /// them with already existing recipes expanding the same SCEV expression. 
static void removeRedundantExpandSCEVRecipes(VPlan &Plan); - /// Sink users of fixed-order recurrences after the recipe defining their - /// previous value. Then introduce FirstOrderRecurrenceSplice VPInstructions - /// to combine the value from the recurrence phis and previous values. The - /// current implementation assumes all users can be sunk after the previous - /// value, which is enforced by earlier legality checks. - /// \returns true if all users of fixed-order recurrences could be re-arranged - /// as needed or false if it is not possible. In the latter case, \p Plan is - /// not valid. - static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder); - - /// Clear NSW/NUW flags from reduction instructions if necessary. - static void clearReductionWrapFlags(VPlan &Plan); - - /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the - /// resulting plan to \p BestVF and \p BestUF. - static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, - unsigned BestUF, - PredicatedScalarEvolution &PSE); }; } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h index ac110bb3b0ef..e5ca52755dd2 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -163,6 +163,13 @@ public: void replaceAllUsesWith(VPValue *New); + /// Go through the uses list for this VPValue and make each use point to \p + /// New if the callback ShouldReplace returns true for the given use specified + /// by a pair of (VPUser, the use index). + void replaceUsesWithIf( + VPValue *New, + llvm::function_ref<bool(VPUser &U, unsigned Idx)> ShouldReplace); + /// Returns the recipe defining this VPValue or nullptr if it is not defined /// by a recipe, i.e. is a live-in. 
VPRecipeBase *getDefiningRecipe(); @@ -296,6 +303,14 @@ public: "Op must be an operand of the recipe"); return false; } + + /// Returns true if the VPUser only uses the first part of operand \p Op. + /// Conservatively returns false. + virtual bool onlyFirstPartUsed(const VPValue *Op) const { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return false; + } }; /// This class augments a recipe with a set of VPValues defined by the recipe. @@ -325,7 +340,7 @@ class VPDef { assert(V->Def == this && "can only remove VPValue linked with this VPDef"); assert(is_contained(DefinedValues, V) && "VPValue to remove must be in DefinedValues"); - erase_value(DefinedValues, V); + llvm::erase(DefinedValues, V); V->Def = nullptr; } diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 13464c9d3496..f18711ba30b7 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -13,6 +13,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/VectorCombine.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -28,6 +30,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/Local.h" #include <numeric> +#include <queue> #define DEBUG_TYPE "vector-combine" #include "llvm/Transforms/Utils/InstructionWorklist.h" @@ -100,8 +103,9 @@ private: Instruction &I); bool foldExtractExtract(Instruction &I); bool foldInsExtFNeg(Instruction &I); - bool foldBitcastShuf(Instruction &I); + bool foldBitcastShuffle(Instruction &I); bool scalarizeBinopOrCmp(Instruction &I); + bool scalarizeVPIntrinsic(Instruction &I); bool foldExtractedCmps(Instruction &I); bool foldSingleElementStore(Instruction &I); bool 
scalarizeLoadExtract(Instruction &I); @@ -258,8 +262,8 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // It is safe and potentially profitable to load a vector directly: // inselt undef, load Scalar, 0 --> load VecPtr IRBuilder<> Builder(Load); - Value *CastedPtr = Builder.CreatePointerBitCastOrAddrSpaceCast( - SrcPtr, MinVecTy->getPointerTo(AS)); + Value *CastedPtr = + Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS)); Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment); VecLd = Builder.CreateShuffleVector(VecLd, Mask); @@ -321,7 +325,7 @@ bool VectorCombine::widenSubvectorLoad(Instruction &I) { IRBuilder<> Builder(Load); Value *CastedPtr = - Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Ty->getPointerTo(AS)); + Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS)); Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment); replaceValue(I, *VecLd); ++NumVecLoad; @@ -677,7 +681,7 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) { /// If this is a bitcast of a shuffle, try to bitcast the source vector to the /// destination type followed by shuffle. This can enable further transforms by /// moving bitcasts or shuffles together. -bool VectorCombine::foldBitcastShuf(Instruction &I) { +bool VectorCombine::foldBitcastShuffle(Instruction &I) { Value *V; ArrayRef<int> Mask; if (!match(&I, m_BitCast( @@ -687,35 +691,43 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) { // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for // scalable type is unknown; Second, we cannot reason if the narrowed shuffle // mask for scalable type is a splat or not. - // 2) Disallow non-vector casts and length-changing shuffles. + // 2) Disallow non-vector casts. // TODO: We could allow any shuffle. 
+ auto *DestTy = dyn_cast<FixedVectorType>(I.getType()); auto *SrcTy = dyn_cast<FixedVectorType>(V->getType()); - if (!SrcTy || I.getOperand(0)->getType() != SrcTy) + if (!DestTy || !SrcTy) + return false; + + unsigned DestEltSize = DestTy->getScalarSizeInBits(); + unsigned SrcEltSize = SrcTy->getScalarSizeInBits(); + if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0) return false; - auto *DestTy = cast<FixedVectorType>(I.getType()); - unsigned DestNumElts = DestTy->getNumElements(); - unsigned SrcNumElts = SrcTy->getNumElements(); SmallVector<int, 16> NewMask; - if (SrcNumElts <= DestNumElts) { + if (DestEltSize <= SrcEltSize) { // The bitcast is from wide to narrow/equal elements. The shuffle mask can // always be expanded to the equivalent form choosing narrower elements. - assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask"); - unsigned ScaleFactor = DestNumElts / SrcNumElts; + assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask"); + unsigned ScaleFactor = SrcEltSize / DestEltSize; narrowShuffleMaskElts(ScaleFactor, Mask, NewMask); } else { // The bitcast is from narrow elements to wide elements. The shuffle mask // must choose consecutive elements to allow casting first. - assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask"); - unsigned ScaleFactor = SrcNumElts / DestNumElts; + assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask"); + unsigned ScaleFactor = DestEltSize / SrcEltSize; if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask)) return false; } + // Bitcast the shuffle src - keep its original width but using the destination + // scalar type. + unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize; + auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts); + // The new shuffle must not cost more than the old shuffle. The bitcast is // moved ahead of the shuffle, so assume that it has the same cost as before. 
InstructionCost DestCost = TTI.getShuffleCost( - TargetTransformInfo::SK_PermuteSingleSrc, DestTy, NewMask); + TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask); InstructionCost SrcCost = TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask); if (DestCost > SrcCost || !DestCost.isValid()) @@ -723,12 +735,131 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) { // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC' ++NumShufOfBitcast; - Value *CastV = Builder.CreateBitCast(V, DestTy); + Value *CastV = Builder.CreateBitCast(V, ShuffleTy); Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask); replaceValue(I, *Shuf); return true; } +/// VP Intrinsics whose vector operands are both splat values may be simplified +/// into the scalar version of the operation and the result splatted. This +/// can lead to scalarization down the line. +bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { + if (!isa<VPIntrinsic>(I)) + return false; + VPIntrinsic &VPI = cast<VPIntrinsic>(I); + Value *Op0 = VPI.getArgOperand(0); + Value *Op1 = VPI.getArgOperand(1); + + if (!isSplatValue(Op0) || !isSplatValue(Op1)) + return false; + + // Check getSplatValue early in this function, to avoid doing unnecessary + // work. + Value *ScalarOp0 = getSplatValue(Op0); + Value *ScalarOp1 = getSplatValue(Op1); + if (!ScalarOp0 || !ScalarOp1) + return false; + + // For the binary VP intrinsics supported here, the result on disabled lanes + // is a poison value. For now, only do this simplification if all lanes + // are active. + // TODO: Relax the condition that all lanes are active by using insertelement + // on inactive lanes. 
+ auto IsAllTrueMask = [](Value *MaskVal) { + if (Value *SplattedVal = getSplatValue(MaskVal)) + if (auto *ConstValue = dyn_cast<Constant>(SplattedVal)) + return ConstValue->isAllOnesValue(); + return false; + }; + if (!IsAllTrueMask(VPI.getArgOperand(2))) + return false; + + // Check to make sure we support scalarization of the intrinsic + Intrinsic::ID IntrID = VPI.getIntrinsicID(); + if (!VPBinOpIntrinsic::isVPBinOp(IntrID)) + return false; + + // Calculate cost of splatting both operands into vectors and the vector + // intrinsic + VectorType *VecTy = cast<VectorType>(VPI.getType()); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + InstructionCost SplatCost = + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) + + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy); + + // Calculate the cost of the VP Intrinsic + SmallVector<Type *, 4> Args; + for (Value *V : VPI.args()) + Args.push_back(V->getType()); + IntrinsicCostAttributes Attrs(IntrID, VecTy, Args); + InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind); + InstructionCost OldCost = 2 * SplatCost + VectorOpCost; + + // Determine scalar opcode + std::optional<unsigned> FunctionalOpcode = + VPI.getFunctionalOpcode(); + std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt; + if (!FunctionalOpcode) { + ScalarIntrID = VPI.getFunctionalIntrinsicID(); + if (!ScalarIntrID) + return false; + } + + // Calculate cost of scalarizing + InstructionCost ScalarOpCost = 0; + if (ScalarIntrID) { + IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args); + ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind); + } else { + ScalarOpCost = + TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType()); + } + + // The existing splats may be kept around if other instructions use them. 
+ InstructionCost CostToKeepSplats = + (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse()); + InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats; + + LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI + << "\n"); + LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost + << ", Cost of scalarizing:" << NewCost << "\n"); + + // We want to scalarize unless the vector variant actually has lower cost. + if (OldCost < NewCost || !NewCost.isValid()) + return false; + + // Scalarize the intrinsic + ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount(); + Value *EVL = VPI.getArgOperand(3); + const DataLayout &DL = VPI.getModule()->getDataLayout(); + + // If the VP op might introduce UB or poison, we can scalarize it provided + // that we know the EVL > 0: If the EVL is zero, then the original VP op + // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by + // scalarizing it. + bool SafeToSpeculate; + if (ScalarIntrID) + SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID) + .hasFnAttr(Attribute::AttrKind::Speculatable); + else + SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode( + *FunctionalOpcode, &VPI, nullptr, &AC, &DT); + if (!SafeToSpeculate && !isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT)) + return false; + + Value *ScalarVal = + ScalarIntrID + ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID, + {ScalarOp0, ScalarOp1}) + : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode), + ScalarOp0, ScalarOp1); + + replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal)); + return true; +} + /// Match a vector binop or compare instruction with at least one inserted /// scalar operand and convert to scalar binop/cmp followed by insertelement. bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { @@ -1013,19 +1144,24 @@ public: /// Check if it is legal to scalarize a memory access to \p VecTy at index \p /// Idx. 
\p Idx must access a valid vector element. -static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy, - Value *Idx, Instruction *CtxI, +static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx, + Instruction *CtxI, AssumptionCache &AC, const DominatorTree &DT) { + // We do checks for both fixed vector types and scalable vector types. + // This is the number of elements of fixed vector types, + // or the minimum number of elements of scalable vector types. + uint64_t NumElements = VecTy->getElementCount().getKnownMinValue(); + if (auto *C = dyn_cast<ConstantInt>(Idx)) { - if (C->getValue().ult(VecTy->getNumElements())) + if (C->getValue().ult(NumElements)) return ScalarizationResult::safe(); return ScalarizationResult::unsafe(); } unsigned IntWidth = Idx->getType()->getScalarSizeInBits(); APInt Zero(IntWidth, 0); - APInt MaxElts(IntWidth, VecTy->getNumElements()); + APInt MaxElts(IntWidth, NumElements); ConstantRange ValidIndices(Zero, MaxElts); ConstantRange IdxRange(IntWidth, true); @@ -1074,8 +1210,7 @@ static Align computeAlignmentAfterScalarization(Align VectorAlignment, // store i32 %b, i32* %1 bool VectorCombine::foldSingleElementStore(Instruction &I) { auto *SI = cast<StoreInst>(&I); - if (!SI->isSimple() || - !isa<FixedVectorType>(SI->getValueOperand()->getType())) + if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType())) return false; // TODO: Combine more complicated patterns (multiple insert) by referencing @@ -1089,13 +1224,13 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { return false; if (auto *Load = dyn_cast<LoadInst>(Source)) { - auto VecTy = cast<FixedVectorType>(SI->getValueOperand()->getType()); + auto VecTy = cast<VectorType>(SI->getValueOperand()->getType()); const DataLayout &DL = I.getModule()->getDataLayout(); Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts(); // Don't optimize for atomic/volatile load or store. 
Ensure memory is not // modified between, vector type matches store size, and index is inbounds. if (!Load->isSimple() || Load->getParent() != SI->getParent() || - !DL.typeSizeEqualsStoreSize(Load->getType()) || + !DL.typeSizeEqualsStoreSize(Load->getType()->getScalarType()) || SrcAddr != SI->getPointerOperand()->stripPointerCasts()) return false; @@ -1130,19 +1265,26 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { if (!match(&I, m_Load(m_Value(Ptr)))) return false; - auto *FixedVT = cast<FixedVectorType>(I.getType()); + auto *VecTy = cast<VectorType>(I.getType()); auto *LI = cast<LoadInst>(&I); const DataLayout &DL = I.getModule()->getDataLayout(); - if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT)) + if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy->getScalarType())) return false; InstructionCost OriginalCost = - TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(), + TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(), LI->getPointerAddressSpace()); InstructionCost ScalarizedCost = 0; Instruction *LastCheckedInst = LI; unsigned NumInstChecked = 0; + DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze; + auto FailureGuard = make_scope_exit([&]() { + // If the transform is aborted, discard the ScalarizationResults. + for (auto &Pair : NeedFreeze) + Pair.second.discard(); + }); + // Check if all users of the load are extracts with no memory modifications // between the load and the extract. Compute the cost of both the original // code and the scalarized version. @@ -1151,9 +1293,6 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { if (!UI || UI->getParent() != LI->getParent()) return false; - if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT)) - return false; - // Check if any instruction between the load and the extract may modify // memory. 
if (LastCheckedInst->comesBefore(UI)) { @@ -1168,22 +1307,23 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { LastCheckedInst = UI; } - auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT); - if (!ScalarIdx.isSafe()) { - // TODO: Freeze index if it is safe to do so. - ScalarIdx.discard(); + auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT); + if (ScalarIdx.isUnsafe()) return false; + if (ScalarIdx.isSafeWithFreeze()) { + NeedFreeze.try_emplace(UI, ScalarIdx); + ScalarIdx.discard(); } auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1)); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; OriginalCost += - TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind, + TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, Index ? Index->getZExtValue() : -1); ScalarizedCost += - TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(), + TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(), Align(1), LI->getPointerAddressSpace()); - ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType()); + ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType()); } if (ScalarizedCost >= OriginalCost) @@ -1192,21 +1332,27 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { // Replace extracts with narrow scalar loads. for (User *U : LI->users()) { auto *EI = cast<ExtractElementInst>(U); - Builder.SetInsertPoint(EI); - Value *Idx = EI->getOperand(1); + + // Insert 'freeze' for poison indexes. 
+ auto It = NeedFreeze.find(EI); + if (It != NeedFreeze.end()) + It->second.freeze(Builder, *cast<Instruction>(Idx)); + + Builder.SetInsertPoint(EI); Value *GEP = - Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx}); + Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx}); auto *NewLoad = cast<LoadInst>(Builder.CreateLoad( - FixedVT->getElementType(), GEP, EI->getName() + ".scalar")); + VecTy->getElementType(), GEP, EI->getName() + ".scalar")); Align ScalarOpAlignment = computeAlignmentAfterScalarization( - LI->getAlign(), FixedVT->getElementType(), Idx, DL); + LI->getAlign(), VecTy->getElementType(), Idx, DL); NewLoad->setAlignment(ScalarOpAlignment); replaceValue(*EI, *NewLoad); } + FailureGuard.release(); return true; } @@ -1340,21 +1486,28 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) { dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType()); if (!ShuffleInputType) return false; - int NumInputElts = ShuffleInputType->getNumElements(); + unsigned NumInputElts = ShuffleInputType->getNumElements(); // Find the mask from sorting the lanes into order. This is most likely to // become a identity or concat mask. Undef elements are pushed to the end. SmallVector<int> ConcatMask; Shuffle->getShuffleMask(ConcatMask); sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; }); + // In the case of a truncating shuffle it's possible for the mask + // to have an index greater than the size of the resulting vector. + // This requires special handling. + bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts; bool UsesSecondVec = - any_of(ConcatMask, [&](int M) { return M >= NumInputElts; }); + any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; }); + + FixedVectorType *VecTyForCost = + (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType; InstructionCost OldCost = TTI.getShuffleCost( - UsesSecondVec ? 
TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType, - Shuffle->getShuffleMask()); + UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, + VecTyForCost, Shuffle->getShuffleMask()); InstructionCost NewCost = TTI.getShuffleCost( - UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType, - ConcatMask); + UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, + VecTyForCost, ConcatMask); LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle << "\n"); @@ -1657,16 +1810,16 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { return SSV->getOperand(Op); return SV->getOperand(Op); }; - Builder.SetInsertPoint(SVI0A->getInsertionPointAfterDef()); + Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef()); Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0), GetShuffleOperand(SVI0A, 1), V1A); - Builder.SetInsertPoint(SVI0B->getInsertionPointAfterDef()); + Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef()); Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0), GetShuffleOperand(SVI0B, 1), V1B); - Builder.SetInsertPoint(SVI1A->getInsertionPointAfterDef()); + Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef()); Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0), GetShuffleOperand(SVI1A, 1), V2A); - Builder.SetInsertPoint(SVI1B->getInsertionPointAfterDef()); + Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef()); Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0), GetShuffleOperand(SVI1B, 1), V2B); Builder.SetInsertPoint(Op0); @@ -1723,9 +1876,6 @@ bool VectorCombine::run() { case Instruction::ShuffleVector: MadeChange |= widenSubvectorLoad(I); break; - case Instruction::Load: - MadeChange |= scalarizeLoadExtract(I); - break; default: break; } @@ -1733,13 +1883,15 @@ bool VectorCombine::run() { // This transform works with scalable and fixed vectors // TODO: Identify and allow 
other scalable transforms - if (isa<VectorType>(I.getType())) + if (isa<VectorType>(I.getType())) { MadeChange |= scalarizeBinopOrCmp(I); + MadeChange |= scalarizeLoadExtract(I); + MadeChange |= scalarizeVPIntrinsic(I); + } if (Opcode == Instruction::Store) MadeChange |= foldSingleElementStore(I); - // If this is an early pipeline invocation of this pass, we are done. if (TryEarlyFoldsOnly) return; @@ -1758,7 +1910,7 @@ bool VectorCombine::run() { MadeChange |= foldSelectShuffle(I); break; case Instruction::BitCast: - MadeChange |= foldBitcastShuf(I); + MadeChange |= foldBitcastShuffle(I); break; } } else { |
