| author | Dimitry Andric <dim@FreeBSD.org> | 2022-03-20 11:40:34 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2022-05-14 11:43:05 +0000 |
| commit | 349cc55c9796c4596a5b9904cd3281af295f878f (patch) | |
| tree | 410c5a785075730a35f1272ca6a7adf72222ad03 /contrib/llvm-project/llvm/lib/Transforms | |
| parent | cb2ae6163174b90e999326ecec3699ee093a5d43 (diff) | |
| parent | c0981da47d5696fe36474fcf86b4ce03ae3ff818 (diff) | |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms')
188 files changed, 14343 insertions, 8757 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 85abbf6d86e0..7243e39c9029 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -18,6 +18,7 @@
 #include "llvm-c/Transforms/AggressiveInstCombine.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
@@ -205,8 +206,8 @@ struct MaskOps {
   bool FoundAnd1;
 
   MaskOps(unsigned BitWidth, bool MatchAnds)
-      : Root(nullptr), Mask(APInt::getNullValue(BitWidth)),
-        MatchAndChain(MatchAnds), FoundAnd1(false) {}
+      : Root(nullptr), Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds),
+        FoundAnd1(false) {}
 };
 
 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
@@ -377,10 +378,10 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
     // Also, we want to avoid matching partial patterns.
     // TODO: It would be more efficient if we removed dead instructions
     // iteratively in this loop rather than waiting until the end.
-    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
+    for (Instruction &I : llvm::reverse(BB)) {
       MadeChange |= foldAnyOrAllBitsSet(I);
       MadeChange |= foldGuardedFunnelShift(I, DT);
-      MadeChange |= tryToRecognizePopCount(I); 
+      MadeChange |= tryToRecognizePopCount(I);
     }
   }
 
@@ -394,10 +395,11 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
 
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
-static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
+static bool runImpl(Function &F, AssumptionCache &AC, TargetLibraryInfo &TLI,
+                    DominatorTree &DT) {
   bool MadeChange = false;
   const DataLayout &DL = F.getParent()->getDataLayout();
-  TruncInstCombine TIC(TLI, DL, DT);
+  TruncInstCombine TIC(AC, TLI, DL, DT);
   MadeChange |= TIC.run(F);
   MadeChange |= foldUnusualPatterns(F, DT);
   return MadeChange;
@@ -406,6 +408,7 @@ static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
 void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
     AnalysisUsage &AU) const {
   AU.setPreservesCFG();
+  AU.addRequired<AssumptionCacheTracker>();
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
   AU.addPreserved<AAResultsWrapperPass>();
@@ -415,16 +418,18 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
 }
 
 bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
+  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-  return runImpl(F, TLI, DT);
+  return runImpl(F, AC, TLI, DT);
 }
 
 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                  FunctionAnalysisManager &AM) {
+  auto &AC = AM.getResult<AssumptionAnalysis>(F);
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
-  if (!runImpl(F, TLI, DT)) {
+  if (!runImpl(F, AC, TLI, DT)) {
     // No changes, all analyses are preserved.
     return PreservedAnalyses::all();
   }
@@ -438,6 +443,7 @@ char AggressiveInstCombinerLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass,
                       "aggressive-instcombine",
                       "Combine pattern based expressions", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine",
diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
index 42bcadfc7dcd..5d69e26d6ecc 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
+++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
@@ -17,6 +17,8 @@
 
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Support/KnownBits.h"
 
 using namespace llvm;
 
@@ -39,16 +41,18 @@ using namespace llvm;
 //===----------------------------------------------------------------------===//
 
 namespace llvm {
-  class DataLayout;
-  class DominatorTree;
-  class Function;
-  class Instruction;
-  class TargetLibraryInfo;
-  class TruncInst;
-  class Type;
-  class Value;
+class AssumptionCache;
+class DataLayout;
+class DominatorTree;
+class Function;
+class Instruction;
+class TargetLibraryInfo;
+class TruncInst;
+class Type;
+class Value;
 
 class TruncInstCombine {
+  AssumptionCache &AC;
   TargetLibraryInfo &TLI;
   const DataLayout &DL;
   const DominatorTree &DT;
@@ -75,9 +79,9 @@ class TruncInstCombine {
   MapVector<Instruction *, Info> InstInfoMap;
 
 public:
-  TruncInstCombine(TargetLibraryInfo &TLI, const DataLayout &DL,
-                   const DominatorTree &DT)
-      : TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {}
+  TruncInstCombine(AssumptionCache &AC, TargetLibraryInfo &TLI,
+                   const DataLayout &DL, const DominatorTree &DT)
+      : AC(AC), TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {}
 
   /// Perform TruncInst pattern optimization on given function.
   bool run(Function &F);
@@ -104,6 +108,18 @@ private:
   /// to be reduced.
   Type *getBestTruncatedType();
 
+  KnownBits computeKnownBits(const Value *V) const {
+    return llvm::computeKnownBits(V, DL, /*Depth=*/0, &AC,
+                                  /*CtxI=*/cast<Instruction>(CurrentTruncInst),
+                                  &DT);
+  }
+
+  unsigned ComputeNumSignBits(const Value *V) const {
+    return llvm::ComputeNumSignBits(
+        V, DL, /*Depth=*/0, &AC, /*CtxI=*/cast<Instruction>(CurrentTruncInst),
+        &DT);
+  }
+
   /// Given a \p V value and a \p SclTy scalar type return the generated reduced
   /// value of \p V based on the type \p SclTy.
/// diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 16b82219e8ca..abac3f801a22 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; @@ -61,9 +62,18 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: + case Instruction::InsertElement: Ops.push_back(I->getOperand(0)); Ops.push_back(I->getOperand(1)); break; + case Instruction::ExtractElement: + Ops.push_back(I->getOperand(0)); + break; case Instruction::Select: Ops.push_back(I->getOperand(1)); Ops.push_back(I->getOperand(2)); @@ -127,6 +137,13 @@ bool TruncInstCombine::buildTruncExpressionDag() { case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: + case Instruction::InsertElement: + case Instruction::ExtractElement: case Instruction::Select: { SmallVector<Value *, 2> Operands; getRelevantOperands(I, Operands); @@ -135,10 +152,9 @@ bool TruncInstCombine::buildTruncExpressionDag() { } default: // TODO: Can handle more cases here: - // 1. shufflevector, extractelement, insertelement - // 2. udiv, urem - // 3. shl, lshr, ashr - // 4. phi node(and loop handling) + // 1. shufflevector + // 2. sdiv, srem + // 3. phi node(and loop handling) // ... return false; } @@ -270,6 +286,50 @@ Type *TruncInstCombine::getBestTruncatedType() { unsigned OrigBitWidth = CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); + // Initialize MinBitWidth for shift instructions with the minimum number + // that is greater than shift amount (i.e. shift amount + 1). + // For `lshr` adjust MinBitWidth so that all potentially truncated + // bits of the value-to-be-shifted are zeros. + // For `ashr` adjust MinBitWidth so that all potentially truncated + // bits of the value-to-be-shifted are sign bits (all zeros or ones) + // and even one (first) untruncated bit is sign bit. + // Exit early if MinBitWidth is not less than original bitwidth. 
+ for (auto &Itr : InstInfoMap) { + Instruction *I = Itr.first; + if (I->isShift()) { + KnownBits KnownRHS = computeKnownBits(I->getOperand(1)); + unsigned MinBitWidth = KnownRHS.getMaxValue() + .uadd_sat(APInt(OrigBitWidth, 1)) + .getLimitedValue(OrigBitWidth); + if (MinBitWidth == OrigBitWidth) + return nullptr; + if (I->getOpcode() == Instruction::LShr) { + KnownBits KnownLHS = computeKnownBits(I->getOperand(0)); + MinBitWidth = + std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits()); + } + if (I->getOpcode() == Instruction::AShr) { + unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0)); + MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1); + } + if (MinBitWidth >= OrigBitWidth) + return nullptr; + Itr.second.MinBitWidth = MinBitWidth; + } + if (I->getOpcode() == Instruction::UDiv || + I->getOpcode() == Instruction::URem) { + unsigned MinBitWidth = 0; + for (const auto &Op : I->operands()) { + KnownBits Known = computeKnownBits(Op); + MinBitWidth = + std::max(Known.getMaxValue().getActiveBits(), MinBitWidth); + if (MinBitWidth >= OrigBitWidth) + return nullptr; + } + Itr.second.MinBitWidth = MinBitWidth; + } + } + // Calculate minimum allowed bit-width allowed for shrinking the currently // visited truncate's operand. unsigned MinBitWidth = getMinBitWidth(); @@ -356,10 +416,32 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { case Instruction::Mul: case Instruction::And: case Instruction::Or: - case Instruction::Xor: { + case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::UDiv: + case Instruction::URem: { Value *LHS = getReducedOperand(I->getOperand(0), SclTy); Value *RHS = getReducedOperand(I->getOperand(1), SclTy); Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); + // Preserve `exact` flag since truncation doesn't change exactness + if (auto *PEO = dyn_cast<PossiblyExactOperator>(I)) + if (auto *ResI = dyn_cast<Instruction>(Res)) + ResI->setIsExact(PEO->isExact()); + break; + } + case Instruction::ExtractElement: { + Value *Vec = getReducedOperand(I->getOperand(0), SclTy); + Value *Idx = I->getOperand(1); + Res = Builder.CreateExtractElement(Vec, Idx); + break; + } + case Instruction::InsertElement: { + Value *Vec = getReducedOperand(I->getOperand(0), SclTy); + Value *NewElt = getReducedOperand(I->getOperand(1), SclTy); + Value *Idx = I->getOperand(2); + Res = Builder.CreateInsertElement(Vec, NewElt, Idx); break; } case Instruction::Select: { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index 5b09cdb35791..67f8828e4c75 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -56,8 +56,10 @@ static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) { bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) { bool Changed = false; - for (auto IB = inst_begin(F), E = inst_end(F); IB != E;) { - Instruction &I = *IB++; + bool IsPrivateAndUnprocessed = + F.hasFnAttribute(CORO_PRESPLIT_ATTR) && F.hasLocalLinkage(); + + for (Instruction &I : llvm::make_early_inc_range(instructions(F))) { if (auto *II = dyn_cast<IntrinsicInst>(&I)) { switch (II->getIntrinsicID()) { default: @@ -71,6 +73,10 @@ bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) { case Intrinsic::coro_alloc: II->replaceAllUsesWith(ConstantInt::getTrue(Context)); break; + case 
Intrinsic::coro_async_resume: + II->replaceAllUsesWith( + ConstantPointerNull::get(cast<PointerType>(I.getType()))); + break; case Intrinsic::coro_id: case Intrinsic::coro_id_retcon: case Intrinsic::coro_id_retcon_once: @@ -80,6 +86,13 @@ bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) { case Intrinsic::coro_subfn_addr: lowerSubFn(Builder, cast<CoroSubFnInst>(II)); break; + case Intrinsic::coro_end: + case Intrinsic::coro_suspend_retcon: + if (IsPrivateAndUnprocessed) { + II->replaceAllUsesWith(UndefValue::get(II->getType())); + } else + continue; + break; case Intrinsic::coro_async_size_replace: auto *Target = cast<ConstantStruct>( cast<GlobalVariable>(II->getArgOperand(0)->stripPointerCasts()) @@ -115,7 +128,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) { return coro::declaresIntrinsics( M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr", "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon", - "llvm.coro.id.retcon.once", "llvm.coro.async.size.replace"}); + "llvm.coro.id.retcon.once", "llvm.coro.async.size.replace", + "llvm.coro.async.resume"}); } PreservedAnalyses CoroCleanupPass::run(Function &F, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp index 5e5e513cdfda..68a34bdcb1cd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -150,8 +150,7 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { CoroIdInst *CoroId = nullptr; SmallVector<CoroFreeInst *, 4> CoroFrees; bool HasCoroSuspend = false; - for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) { - Instruction &I = *IB++; + for (Instruction &I : llvm::make_early_inc_range(instructions(F))) { if (auto *CB = dyn_cast<CallBase>(&I)) { switch (CB->getIntrinsicID()) { default: diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index beae5fdac8ab..ac3d078714ce 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -16,6 +16,7 @@ #include "CoroInternal.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" #include "llvm/Analysis/PtrUseVisitor.h" #include "llvm/Analysis/StackLifetime.h" @@ -435,7 +436,7 @@ private: DenseMap<Value*, unsigned> FieldIndexByKey; public: - FrameTypeBuilder(LLVMContext &Context, DataLayout const &DL, + FrameTypeBuilder(LLVMContext &Context, const DataLayout &DL, Optional<Align> MaxFrameAlignment) : DL(DL), Context(Context), MaxFrameAlignment(MaxFrameAlignment) {} @@ -576,13 +577,8 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F, using AllocaSetType = SmallVector<AllocaInst *, 4>; SmallVector<AllocaSetType, 4> NonOverlapedAllocas; - // We need to add field for allocas at the end of this function. However, this - // function has multiple exits, so we use this helper to avoid redundant code. - struct RTTIHelper { - std::function<void()> func; - RTTIHelper(std::function<void()> &&func) : func(func) {} - ~RTTIHelper() { func(); } - } Helper([&]() { + // We need to add field for allocas at the end of this function. 
+ auto AddFieldForAllocasAtExit = make_scope_exit([&]() { for (auto AllocaList : NonOverlapedAllocas) { auto *LargestAI = *AllocaList.begin(); FieldIDType Id = addFieldForAlloca(LargestAI); @@ -840,8 +836,9 @@ static StringRef solveTypeName(Type *Ty) { return "UnknownType"; } -static DIType *solveDIType(DIBuilder &Builder, Type *Ty, DataLayout &Layout, - DIScope *Scope, unsigned LineNum, +static DIType *solveDIType(DIBuilder &Builder, Type *Ty, + const DataLayout &Layout, DIScope *Scope, + unsigned LineNum, DenseMap<Type *, DIType *> &DITypeCache) { if (DIType *DT = DITypeCache.lookup(Ty)) return DT; @@ -1348,13 +1345,17 @@ struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> { } void visitIntrinsicInst(IntrinsicInst &II) { - if (II.getIntrinsicID() != Intrinsic::lifetime_start) + // When we found the lifetime markers refers to a + // subrange of the original alloca, ignore the lifetime + // markers to avoid misleading the analysis. + if (II.getIntrinsicID() != Intrinsic::lifetime_start || !IsOffsetKnown || + !Offset.isZero()) return Base::visitIntrinsicInst(II); LifetimeStarts.insert(&II); } void visitCallBase(CallBase &CB) { - for (unsigned Op = 0, OpCount = CB.getNumArgOperands(); Op < OpCount; ++Op) + for (unsigned Op = 0, OpCount = CB.arg_size(); Op < OpCount; ++Op) if (U->get() == CB.getArgOperand(Op) && !CB.doesNotCapture(Op)) PI.setEscaped(&CB); handleMayWrite(CB); @@ -1868,8 +1869,7 @@ static void cleanupSinglePredPHIs(Function &F) { } } while (!Worklist.empty()) { - auto *Phi = Worklist.back(); - Worklist.pop_back(); + auto *Phi = Worklist.pop_back_val(); auto *OriginalValue = Phi->getIncomingValue(0); Phi->replaceAllUsesWith(OriginalValue); } @@ -1984,14 +1984,15 @@ static void rewriteMaterializableInstructions(IRBuilder<> &IRB, if (CurrentBlock != U->getParent()) { bool IsInCoroSuspendBlock = isa<AnyCoroSuspendInst>(U); - CurrentBlock = IsInCoroSuspendBlock - ? U->getParent()->getSinglePredecessor() - : U->getParent(); + CurrentBlock = U->getParent(); + auto *InsertBlock = IsInCoroSuspendBlock + ? CurrentBlock->getSinglePredecessor() + : CurrentBlock; CurrentMaterialization = cast<Instruction>(Def)->clone(); CurrentMaterialization->setName(Def->getName()); CurrentMaterialization->insertBefore( - IsInCoroSuspendBlock ? CurrentBlock->getTerminator() - : &*CurrentBlock->getFirstInsertionPt()); + IsInCoroSuspendBlock ? InsertBlock->getTerminator() + : &*InsertBlock->getFirstInsertionPt()); } if (auto *PN = dyn_cast<PHINode>(U)) { assert(PN->getNumIncomingValues() == 1 && @@ -2244,12 +2245,7 @@ static Value *emitSetAndGetSwiftErrorValueAround(Instruction *Call, /// intrinsics and attempting to MemToReg the alloca away. static void eliminateSwiftErrorAlloca(Function &F, AllocaInst *Alloca, coro::Shape &Shape) { - for (auto UI = Alloca->use_begin(), UE = Alloca->use_end(); UI != UE; ) { - // We're likely changing the use list, so use a mutation-safe - // iteration pattern. - auto &Use = *UI; - ++UI; - + for (Use &Use : llvm::make_early_inc_range(Alloca->uses())) { // swifterror values can only be used in very specific ways. // We take advantage of that here. auto User = Use.getUser(); @@ -2510,11 +2506,11 @@ void coro::salvageDebugInfo( DIExpression *Expr = DVI->getExpression(); // Follow the pointer arithmetic all the way to the incoming // function argument and convert into a DIExpression. 
- bool OutermostLoad = true; + bool SkipOutermostLoad = !isa<DbgValueInst>(DVI); Value *Storage = DVI->getVariableLocationOp(0); Value *OriginalStorage = Storage; - while (Storage) { - if (auto *LdInst = dyn_cast<LoadInst>(Storage)) { + while (auto *Inst = dyn_cast_or_null<Instruction>(Storage)) { + if (auto *LdInst = dyn_cast<LoadInst>(Inst)) { Storage = LdInst->getOperand(0); // FIXME: This is a heuristic that works around the fact that // LLVM IR debug intrinsics cannot yet distinguish between @@ -2522,26 +2518,25 @@ void coro::salvageDebugInfo( // implicitly a memory location no DW_OP_deref operation for the // last direct load from an alloca is necessary. This condition // effectively drops the *last* DW_OP_deref in the expression. - if (!OutermostLoad) + if (!SkipOutermostLoad) Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore); - OutermostLoad = false; - } else if (auto *StInst = dyn_cast<StoreInst>(Storage)) { + } else if (auto *StInst = dyn_cast<StoreInst>(Inst)) { Storage = StInst->getOperand(0); - } else if (auto *GEPInst = dyn_cast<GetElementPtrInst>(Storage)) { - SmallVector<Value *> AdditionalValues; - DIExpression *SalvagedExpr = llvm::salvageDebugInfoImpl( - *GEPInst, Expr, - /*WithStackValue=*/false, 0, AdditionalValues); - // Debug declares cannot currently handle additional location - // operands. - if (!SalvagedExpr || !AdditionalValues.empty()) + } else { + SmallVector<uint64_t, 16> Ops; + SmallVector<Value *, 0> AdditionalValues; + Value *Op = llvm::salvageDebugInfoImpl( + *Inst, Expr ? Expr->getNumLocationOperands() : 0, Ops, + AdditionalValues); + if (!Op || !AdditionalValues.empty()) { + // If salvaging failed or salvaging produced more than one location + // operand, give up. break; - Expr = SalvagedExpr; - Storage = GEPInst->getOperand(0); - } else if (auto *BCInst = dyn_cast<llvm::BitCastInst>(Storage)) - Storage = BCInst->getOperand(0); - else - break; + } + Storage = Op; + Expr = DIExpression::appendOpsToArg(Expr, Ops, 0, /*StackValue*/ false); + } + SkipOutermostLoad = false; } if (!Storage) return; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h index 5ed800d67fe9..bf3d781ba43e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h @@ -638,7 +638,7 @@ public: void checkWellFormed() const; Function *getMustTailCallFunction() const { - if (getNumArgOperands() < 3) + if (arg_size() < 3) return nullptr; return cast<Function>( diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index b6932dbbfc3f..fa1d92f439b8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -520,8 +520,8 @@ void CoroCloner::replaceRetconOrAsyncSuspendUses() { } // Try to peephole extracts of an aggregate return. - for (auto UI = NewS->use_begin(), UE = NewS->use_end(); UI != UE; ) { - auto EVI = dyn_cast<ExtractValueInst>((UI++)->getUser()); + for (Use &U : llvm::make_early_inc_range(NewS->uses())) { + auto *EVI = dyn_cast<ExtractValueInst>(U.getUser()); if (!EVI || EVI->getNumIndices() != 1) continue; @@ -622,12 +622,12 @@ static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape, // If there are no arguments, this is a 'get' operation. 
Value *MappedResult; - if (Op->getNumArgOperands() == 0) { + if (Op->arg_empty()) { auto ValueTy = Op->getType(); auto Slot = getSwiftErrorSlot(ValueTy); MappedResult = Builder.CreateLoad(ValueTy, Slot); } else { - assert(Op->getNumArgOperands() == 1); + assert(Op->arg_size() == 1); auto Value = MappedOp->getArgOperand(0); auto ValueTy = Value->getType(); auto Slot = getSwiftErrorSlot(ValueTy); @@ -669,7 +669,7 @@ void CoroCloner::salvageDebugInfo() { for (DbgVariableIntrinsic *DVI : Worklist) { if (IsUnreachableBlock(DVI->getParent())) DVI->eraseFromParent(); - else if (dyn_cast_or_null<AllocaInst>(DVI->getVariableLocationOp(0))) { + else if (isa_and_nonnull<AllocaInst>(DVI->getVariableLocationOp(0))) { // Count all non-debuginfo uses in reachable blocks. unsigned Uses = 0; for (auto *User : DVI->getVariableLocationOp(0)->users()) @@ -738,8 +738,7 @@ void CoroCloner::replaceEntryBlock() { // entry needs to be moved to the new entry. Function *F = OldEntry->getParent(); DominatorTree DT{*F}; - for (auto IT = inst_begin(F), End = inst_end(F); IT != End;) { - Instruction &I = *IT++; + for (Instruction &I : llvm::make_early_inc_range(instructions(F))) { auto *Alloca = dyn_cast<AllocaInst>(&I); if (!Alloca || I.use_empty()) continue; @@ -773,9 +772,8 @@ Value *CoroCloner::deriveNewFramePointer() { auto DbgLoc = cast<CoroSuspendAsyncInst>(VMap[ActiveSuspend])->getDebugLoc(); // Calling i8* (i8*) - auto *CallerContext = Builder.CreateCall( - cast<FunctionType>(ProjectionFunc->getType()->getPointerElementType()), - ProjectionFunc, CalleeContext); + auto *CallerContext = Builder.CreateCall(ProjectionFunc->getFunctionType(), + ProjectionFunc, CalleeContext); CallerContext->setCallingConv(ProjectionFunc->getCallingConv()); CallerContext->setDebugLoc(DbgLoc); // The frame is located after the async_context header. @@ -906,8 +904,7 @@ void CoroCloner::create() { case coro::ABI::Switch: // Bootstrap attributes by copying function attributes from the // original function. This should include optimization settings and so on. - NewAttrs = NewAttrs.addAttributes(Context, AttributeList::FunctionIndex, - OrigAttrs.getFnAttributes()); + NewAttrs = NewAttrs.addFnAttributes(Context, OrigAttrs.getFnAttrs()); addFramePointerAttrs(NewAttrs, Context, 0, Shape.FrameSize, Shape.FrameAlign); @@ -929,9 +926,8 @@ void CoroCloner::create() { } // Transfer the original function's attributes. - auto FnAttrs = OrigF.getAttributes().getFnAttributes(); - NewAttrs = - NewAttrs.addAttributes(Context, AttributeList::FunctionIndex, FnAttrs); + auto FnAttrs = OrigF.getAttributes().getFnAttrs(); + NewAttrs = NewAttrs.addFnAttributes(Context, FnAttrs); break; } case coro::ABI::Retcon: @@ -1144,11 +1140,13 @@ static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, static void postSplitCleanup(Function &F) { removeUnreachableBlocks(F); +#ifndef NDEBUG // For now, we do a mandatory verification step because we don't // entirely trust this pass. Note that we don't want to add a verifier // pass to FPM below because it will also verify all the global data. 
if (verifyFunction(F, &errs())) report_fatal_error("Broken function"); +#endif } // Assuming we arrived at the block NewBlock from Prev instruction, store @@ -1262,7 +1260,7 @@ static bool shouldBeMustTail(const CallInst &CI, const Function &F) { Attribute::SwiftSelf, Attribute::SwiftError}; AttributeList Attrs = CI.getAttributes(); for (auto AK : ABIAttrs) - if (Attrs.hasParamAttribute(0, AK)) + if (Attrs.hasParamAttr(0, AK)) return false; return true; @@ -1357,7 +1355,7 @@ static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) { auto *BB = Worklist.pop_back_val(); Set.insert(BB); for (auto *Pred : predecessors(BB)) - if (Set.count(Pred) == 0) + if (!Set.contains(Pred)) Worklist.push_back(Pred); } @@ -1547,8 +1545,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy, CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, ArrayRef<Value *> Arguments, IRBuilder<> &Builder) { - auto *FnTy = - cast<FunctionType>(MustTailCallFn->getType()->getPointerElementType()); + auto *FnTy = MustTailCallFn->getFunctionType(); // Coerce the arguments, llvm optimizations seem to ignore the types in // vaarg functions and throws away casts in optimized mode. SmallVector<Value *, 8> CallArgs; @@ -1568,8 +1565,8 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, // Reset various things that the optimizer might have decided it // "knows" about the coroutine function due to not seeing a return. F.removeFnAttr(Attribute::NoReturn); - F.removeAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); - F.removeAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + F.removeRetAttr(Attribute::NoAlias); + F.removeRetAttr(Attribute::NonNull); auto &Context = F.getContext(); auto *Int8PtrTy = Type::getInt8PtrTy(Context); @@ -1667,8 +1664,8 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, // Reset various things that the optimizer might have decided it // "knows" about the coroutine function due to not seeing a return. F.removeFnAttr(Attribute::NoReturn); - F.removeAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); - F.removeAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + F.removeRetAttr(Attribute::NoAlias); + F.removeRetAttr(Attribute::NonNull); // Allocate the frame. auto *Id = cast<AnyCoroIdRetconInst>(Shape.CoroBegin->getId()); @@ -1977,9 +1974,9 @@ static void replacePrepare(CallInst *Prepare, LazyCallGraph &CG, // %2 = bitcast %1 to [[TYPE]] // ==> // %2 = @some_function - for (auto UI = Prepare->use_begin(), UE = Prepare->use_end(); UI != UE;) { + for (Use &U : llvm::make_early_inc_range(Prepare->uses())) { // Look for bitcasts back to the original function type. - auto *Cast = dyn_cast<BitCastInst>((UI++)->getUser()); + auto *Cast = dyn_cast<BitCastInst>(U.getUser()); if (!Cast || Cast->getType() != Fn->getType()) continue; @@ -2019,10 +2016,9 @@ static void replacePrepare(CallInst *Prepare, CallGraph &CG) { // %2 = bitcast %1 to [[TYPE]] // ==> // %2 = @some_function - for (auto UI = Prepare->use_begin(), UE = Prepare->use_end(); - UI != UE; ) { + for (Use &U : llvm::make_early_inc_range(Prepare->uses())) { // Look for bitcasts back to the original function type. - auto *Cast = dyn_cast<BitCastInst>((UI++)->getUser()); + auto *Cast = dyn_cast<BitCastInst>(U.getUser()); if (!Cast || Cast->getType() != Fn->getType()) continue; // Check whether the replacement will introduce new direct calls. 
@@ -2059,9 +2055,9 @@ static void replacePrepare(CallInst *Prepare, CallGraph &CG) { static bool replaceAllPrepares(Function *PrepareFn, LazyCallGraph &CG, LazyCallGraph::SCC &C) { bool Changed = false; - for (auto PI = PrepareFn->use_begin(), PE = PrepareFn->use_end(); PI != PE;) { + for (Use &P : llvm::make_early_inc_range(PrepareFn->uses())) { // Intrinsics can only be used in calls. - auto *Prepare = cast<CallInst>((PI++)->getUser()); + auto *Prepare = cast<CallInst>(P.getUser()); replacePrepare(Prepare, CG, C); Changed = true; } @@ -2077,10 +2073,9 @@ static bool replaceAllPrepares(Function *PrepareFn, LazyCallGraph &CG, /// switch coroutines, which are lowered in multiple stages). static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) { bool Changed = false; - for (auto PI = PrepareFn->use_begin(), PE = PrepareFn->use_end(); - PI != PE; ) { + for (Use &P : llvm::make_early_inc_range(PrepareFn->uses())) { // Intrinsics can only be used in calls. - auto *Prepare = cast<CallInst>((PI++)->getUser()); + auto *Prepare = cast<CallInst>(P.getUser()); replacePrepare(Prepare, CG); Changed = true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp index ae2d9e192c87..e4883ef89db7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -126,6 +126,7 @@ static bool isCoroutineIntrinsicName(StringRef Name) { "llvm.coro.alloc", "llvm.coro.async.context.alloc", "llvm.coro.async.context.dealloc", + "llvm.coro.async.resume", "llvm.coro.async.size.replace", "llvm.coro.async.store_resume", "llvm.coro.begin", @@ -311,10 +312,9 @@ void coro::Shape::buildFrom(Function &F) { if (CoroBegin) report_fatal_error( "coroutine should have exactly one defining @llvm.coro.begin"); - CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); - CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); - CB->removeAttribute(AttributeList::FunctionIndex, - Attribute::NoDuplicate); + CB->addRetAttr(Attribute::NonNull); + CB->addRetAttr(Attribute::NoAlias); + CB->removeFnAttr(Attribute::NoDuplicate); CoroBegin = CB; break; } @@ -571,8 +571,8 @@ void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr, llvm_unreachable("Unknown coro::ABI enum"); } -LLVM_ATTRIBUTE_NORETURN -static void fail(const Instruction *I, const char *Reason, Value *V) { +[[noreturn]] static void fail(const Instruction *I, const char *Reason, + Value *V) { #ifndef NDEBUG I->dump(); if (V) { @@ -722,7 +722,7 @@ void CoroAsyncEndInst::checkWellFormed() const { return; auto *FnTy = cast<FunctionType>(MustTailCallFunc->getType()->getPointerElementType()); - if (FnTy->getNumParams() != (getNumArgOperands() - 3)) + if (FnTy->getNumParams() != (arg_size() - 3)) fail(this, "llvm.coro.end.async must tail call function argument type must " "match the tail arguments", diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp index 532599b42e0d..01e724e22dcf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp @@ -73,8 +73,8 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, }, ORE); assert(OIC); - emitInlinedInto(ORE, CB->getDebugLoc(), CB->getParent(), F, *Caller, - *OIC, false, DEBUG_TYPE); + emitInlinedIntoBasedOnCost(ORE, CB->getDebugLoc(), 
CB->getParent(), F, + *Caller, *OIC, false, DEBUG_TYPE); InlineFunctionInfo IFI( /*cg=*/nullptr, GetAssumptionCache, &PSI, @@ -108,8 +108,10 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, // Delete the non-comdat ones from the module and also from our vector. auto NonComdatBegin = partition( InlinedFunctions, [&](Function *F) { return F->hasComdat(); }); - for (Function *F : make_range(NonComdatBegin, InlinedFunctions.end())) + for (Function *F : make_range(NonComdatBegin, InlinedFunctions.end())) { M.getFunctionList().erase(F); + Changed = true; + } InlinedFunctions.erase(NonComdatBegin, InlinedFunctions.end()); if (!InlinedFunctions.empty()) { @@ -117,8 +119,10 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, // are not actually dead. filterDeadComdatFunctions(M, InlinedFunctions); // The remaining functions are actually dead. - for (Function *F : InlinedFunctions) + for (Function *F : InlinedFunctions) { M.getFunctionList().erase(F); + Changed = true; + } } return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index f670a101767e..93bb11433775 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -148,7 +148,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, } else if (!ArgsToPromote.count(&*I)) { // Unchanged argument Params.push_back(I->getType()); - ArgAttrVec.push_back(PAL.getParamAttributes(ArgNo)); + ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo)); } else if (I->use_empty()) { // Dead argument (which are always marked as promotable) ++NumArgumentsDead; @@ -177,9 +177,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // Since loads will only have a single operand, and GEPs only a single // non-index operand, this will record direct loads without any indices, // and gep+loads with the GEP indices. - for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end(); - II != IE; ++II) - Indices.push_back(cast<ConstantInt>(*II)->getSExtValue()); + for (const Use &I : llvm::drop_begin(UI->operands())) + Indices.push_back(cast<ConstantInt>(I)->getSExtValue()); // GEPs with a single 0 index can be merged with direct loads if (Indices.size() == 1 && Indices.front() == 0) Indices.clear(); @@ -231,8 +230,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // Recompute the parameter attributes list based on the new arguments for // the function. - NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttributes(), - PAL.getRetAttributes(), ArgAttrVec)); + NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(), + PAL.getRetAttrs(), ArgAttrVec)); ArgAttrVec.clear(); F->getParent()->getFunctionList().insert(F->getIterator(), NF); @@ -257,7 +256,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, ++I, ++AI, ++ArgNo) if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { Args.push_back(*AI); // Unmodified argument - ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); + ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo)); } else if (ByValArgsToTransform.count(&*I)) { // Emit a GEP and load for each element of the struct. 
Type *AgTy = I->getParamByValType(); @@ -313,9 +312,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, IRB.CreateLoad(OrigLoad->getType(), V, V->getName() + ".val"); newLoad->setAlignment(OrigLoad->getAlign()); // Transfer the AA info too. - AAMDNodes AAInfo; - OrigLoad->getAAMetadata(AAInfo); - newLoad->setAAMetadata(AAInfo); + newLoad->setAAMetadata(OrigLoad->getAAMetadata()); Args.push_back(newLoad); ArgAttrVec.push_back(AttributeSet()); @@ -325,7 +322,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // Push any varargs arguments on the list. for (; AI != CB.arg_end(); ++AI, ++ArgNo) { Args.push_back(*AI); - ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); + ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo)); } SmallVector<OperandBundleDef, 1> OpBundles; @@ -341,9 +338,9 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, NewCS = NewCall; } NewCS->setCallingConv(CB.getCallingConv()); - NewCS->setAttributes( - AttributeList::get(F->getContext(), CallPAL.getFnAttributes(), - CallPAL.getRetAttributes(), ArgAttrVec)); + NewCS->setAttributes(AttributeList::get(F->getContext(), + CallPAL.getFnAttrs(), + CallPAL.getRetAttrs(), ArgAttrVec)); NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg}); Args.clear(); ArgAttrVec.clear(); @@ -1018,11 +1015,12 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, do { LocalChange = false; + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); + for (LazyCallGraph::Node &N : C) { Function &OldF = N.getFunction(); - FunctionAnalysisManager &FAM = - AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); // FIXME: This lambda must only be used with this function. We should // skip the lambda and just get the AA results directly. auto AARGetter = [&](Function &F) -> AAResults & { @@ -1045,6 +1043,13 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, C.getOuterRefSCC().replaceNodeFunction(N, *NewF); FAM.clear(OldF, OldF.getName()); OldF.eraseFromParent(); + + PreservedAnalyses FuncPA; + FuncPA.preserveSet<CFGAnalyses>(); + for (auto *U : NewF->users()) { + auto *UserF = cast<CallBase>(U)->getFunction(); + FAM.invalidate(*UserF, FuncPA); + } } Changed |= LocalChange; @@ -1053,7 +1058,12 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, if (!Changed) return PreservedAnalyses::all(); - return PreservedAnalyses::none(); + PreservedAnalyses PA; + // We've cleared out analyses for deleted functions. + PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + // We've manually invalidated analyses for functions we've modified. 
+ PA.preserveSet<AllAnalysesOn<Function>>(); + return PA; } namespace { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp index 91b16ec66ee3..edadc79e3a9f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp @@ -382,30 +382,30 @@ static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr, if (Attr.isEnumAttribute()) { Attribute::AttrKind Kind = Attr.getKindAsEnum(); - if (Attrs.hasAttribute(AttrIdx, Kind)) + if (Attrs.hasAttributeAtIndex(AttrIdx, Kind)) if (!ForceReplace && - isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind))) + isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind))) return false; - Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr); + Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr); return true; } if (Attr.isStringAttribute()) { StringRef Kind = Attr.getKindAsString(); - if (Attrs.hasAttribute(AttrIdx, Kind)) + if (Attrs.hasAttributeAtIndex(AttrIdx, Kind)) if (!ForceReplace && - isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind))) + isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind))) return false; - Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr); + Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr); return true; } if (Attr.isIntAttribute()) { Attribute::AttrKind Kind = Attr.getKindAsEnum(); - if (Attrs.hasAttribute(AttrIdx, Kind)) + if (Attrs.hasAttributeAtIndex(AttrIdx, Kind)) if (!ForceReplace && - isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind))) + isEqualOrWorse(Attr, Attrs.getAttributeAtIndex(AttrIdx, Kind))) return false; - Attrs = Attrs.removeAttribute(Ctx, AttrIdx, Kind); - Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr); + Attrs = Attrs.removeAttributeAtIndex(Ctx, AttrIdx, Kind); + Attrs = Attrs.addAttributeAtIndex(Ctx, AttrIdx, Attr); return true; } @@ -658,9 +658,9 @@ bool IRPosition::getAttrsFromIRAttr(Attribute::AttrKind AK, else AttrList = getAssociatedFunction()->getAttributes(); - bool HasAttr = AttrList.hasAttribute(getAttrIdx(), AK); + bool HasAttr = AttrList.hasAttributeAtIndex(getAttrIdx(), AK); if (HasAttr) - Attrs.push_back(AttrList.getAttribute(getAttrIdx(), AK)); + Attrs.push_back(AttrList.getAttributeAtIndex(getAttrIdx(), AK)); return HasAttr; } @@ -1043,6 +1043,8 @@ bool Attributor::checkForAllUses(function_ref<bool(const Use &, bool &)> Pred, if (auto *SI = dyn_cast<StoreInst>(U->getUser())) { if (&SI->getOperandUse(0) == U) { + if (!Visited.insert(U).second) + continue; SmallSetVector<Value *, 4> PotentialCopies; if (AA::getPotentialCopiesOfStoredValue(*this, *SI, PotentialCopies, QueryingAA, @@ -1121,6 +1123,10 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred, if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U.getUser())) { if (CE->isCast() && CE->getType()->isPointerTy() && CE->getType()->getPointerElementType()->isFunctionTy()) { + LLVM_DEBUG( + dbgs() << "[Attributor] Use, is constant cast expression, add " + << CE->getNumUses() + << " uses of that expression instead!\n"); for (const Use &CEU : CE->uses()) Uses.push_back(&CEU); continue; @@ -1141,9 +1147,13 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred, const Use *EffectiveUse = ACS.isCallbackCall() ? 
&ACS.getCalleeUseForCallback() : &U; if (!ACS.isCallee(EffectiveUse)) { - if (!RequireAllCallSites) + if (!RequireAllCallSites) { + LLVM_DEBUG(dbgs() << "[Attributor] User " << *EffectiveUse->getUser() + << " is not a call of " << Fn.getName() + << ", skip use\n"); continue; - LLVM_DEBUG(dbgs() << "[Attributor] User " << EffectiveUse->getUser() + } + LLVM_DEBUG(dbgs() << "[Attributor] User " << *EffectiveUse->getUser() << " is an invalid use of " << Fn.getName() << "\n"); return false; } @@ -1413,6 +1423,16 @@ void Attributor::runTillFixpoint() { } while (!Worklist.empty() && (IterationCounter++ < MaxFixedPointIterations || VerifyMaxFixpointIterations)); + if (IterationCounter > MaxFixedPointIterations && !Worklist.empty()) { + auto Remark = [&](OptimizationRemarkMissed ORM) { + return ORM << "Attributor did not reach a fixpoint after " + << ore::NV("Iterations", MaxFixedPointIterations) + << " iterations."; + }; + Function *F = Worklist.front()->getIRPosition().getAssociatedFunction(); + emitRemark<OptimizationRemarkMissed>(F, "FixedPoint", Remark); + } + LLVM_DEBUG(dbgs() << "\n[Attributor] Fixpoint iteration done after: " << IterationCounter << "/" << MaxFixpointIterations << " iterations\n"); @@ -1922,7 +1942,7 @@ void Attributor::createShallowWrapper(Function &F) { CallInst *CI = CallInst::Create(&F, Args, "", EntryBB); CI->setTailCall(true); - CI->addAttribute(AttributeList::FunctionIndex, Attribute::NoInline); + CI->addFnAttr(Attribute::NoInline); ReturnInst::Create(Ctx, CI->getType()->isVoidTy() ? nullptr : CI, EntryBB); NumFnShallowWrappersCreated++; @@ -2015,7 +2035,8 @@ bool Attributor::isValidFunctionSignatureRewrite( if (!RewriteSignatures) return false; - auto CallSiteCanBeChanged = [](AbstractCallSite ACS) { + Function *Fn = Arg.getParent(); + auto CallSiteCanBeChanged = [Fn](AbstractCallSite ACS) { // Forbid the call site to cast the function return type. If we need to // rewrite these functions we need to re-create a cast for the new call site // (if the old had uses). @@ -2023,11 +2044,12 @@ bool Attributor::isValidFunctionSignatureRewrite( ACS.getInstruction()->getType() != ACS.getCalledFunction()->getReturnType()) return false; + if (ACS.getCalledOperand()->getType() != Fn->getType()) + return false; // Forbid must-tail calls for now. return !ACS.isCallbackCall() && !ACS.getInstruction()->isMustTailCall(); }; - Function *Fn = Arg.getParent(); // Avoid var-arg functions for now. if (Fn->isVarArg()) { LLVM_DEBUG(dbgs() << "[Attributor] Cannot rewrite var-args functions\n"); @@ -2157,7 +2179,7 @@ ChangeStatus Attributor::rewriteFunctionSignatures( } else { NewArgumentTypes.push_back(Arg.getType()); NewArgumentAttributes.push_back( - OldFnAttributeList.getParamAttributes(Arg.getArgNo())); + OldFnAttributeList.getParamAttrs(Arg.getArgNo())); } } @@ -2188,8 +2210,8 @@ ChangeStatus Attributor::rewriteFunctionSignatures( // the function. 
LLVMContext &Ctx = OldFn->getContext(); NewFn->setAttributes(AttributeList::get( - Ctx, OldFnAttributeList.getFnAttributes(), - OldFnAttributeList.getRetAttributes(), NewArgumentAttributes)); + Ctx, OldFnAttributeList.getFnAttrs(), OldFnAttributeList.getRetAttrs(), + NewArgumentAttributes)); // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the @@ -2234,7 +2256,7 @@ ChangeStatus Attributor::rewriteFunctionSignatures( } else { NewArgOperands.push_back(ACS.getCallArgOperand(OldArgNum)); NewArgOperandAttributes.push_back( - OldCallAttributeList.getParamAttributes(OldArgNum)); + OldCallAttributeList.getParamAttrs(OldArgNum)); } } @@ -2264,8 +2286,8 @@ ChangeStatus Attributor::rewriteFunctionSignatures( NewCB->setCallingConv(OldCB->getCallingConv()); NewCB->takeName(OldCB); NewCB->setAttributes(AttributeList::get( - Ctx, OldCallAttributeList.getFnAttributes(), - OldCallAttributeList.getRetAttributes(), NewArgOperandAttributes)); + Ctx, OldCallAttributeList.getFnAttrs(), + OldCallAttributeList.getRetAttrs(), NewArgOperandAttributes)); CallSitePairs.push_back({OldCB, NewCB}); return true; @@ -2480,6 +2502,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { // Every function can be "readnone/argmemonly/inaccessiblememonly/...". getOrCreateAAFor<AAMemoryLocation>(FPos); + // Every function can track active assumptions. + getOrCreateAAFor<AAAssumptionInfo>(FPos); + // Every function might be applicable for Heap-To-Stack conversion. if (EnableHeapToStack) getOrCreateAAFor<AAHeapToStack>(FPos); @@ -2565,6 +2590,7 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { auto CallSitePred = [&](Instruction &I) -> bool { auto &CB = cast<CallBase>(I); IRPosition CBRetPos = IRPosition::callsite_returned(CB); + IRPosition CBFnPos = IRPosition::callsite_function(CB); // Call sites might be dead if they do not have side effects and no live // users. The return value might be dead if there are no live users. @@ -2576,6 +2602,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { if (!Callee) return true; + // Every call site can track active assumptions. + getOrCreateAAFor<AAAssumptionInfo>(CBFnPos); + // Skip declarations except if annotations on their call sites were // explicitly requested. 
if (!AnnotateDeclarationCallSites && Callee->isDeclaration() && @@ -2588,7 +2617,7 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { getOrCreateAAFor<AAValueSimplify>(CBRetPos); } - for (int I = 0, E = CB.getNumArgOperands(); I < E; ++I) { + for (int I = 0, E = CB.arg_size(); I < E; ++I) { IRPosition CBArgPos = IRPosition::callsite_argument(CB, I); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 3529923a9082..ec08287393de 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -28,6 +29,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Assumptions.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" @@ -146,6 +148,7 @@ PIPE_OPERATOR(AANoUndef) PIPE_OPERATOR(AACallEdges) PIPE_OPERATOR(AAFunctionReachability) PIPE_OPERATOR(AAPointerInfo) +PIPE_OPERATOR(AAAssumptionInfo) #undef PIPE_OPERATOR @@ -203,46 +206,25 @@ static Value *constructPointer(Type *ResTy, Type *PtrElemTy, Value *Ptr, << "-bytes as " << *ResTy << "\n"); if (Offset) { - SmallVector<Value *, 4> Indices; - std::string GEPName = Ptr->getName().str() + ".0"; - - // Add 0 index to look through the pointer. - assert((uint64_t)Offset < DL.getTypeAllocSize(PtrElemTy) && - "Offset out of bounds"); - Indices.push_back(Constant::getNullValue(IRB.getInt32Ty())); - Type *Ty = PtrElemTy; - do { - auto *STy = dyn_cast<StructType>(Ty); - if (!STy) - // Non-aggregate type, we cast and make byte-wise progress now. - break; - - const StructLayout *SL = DL.getStructLayout(STy); - if (int64_t(SL->getSizeInBytes()) < Offset) - break; - - uint64_t Idx = SL->getElementContainingOffset(Offset); - assert(Idx < STy->getNumElements() && "Offset calculation error!"); - uint64_t Rem = Offset - SL->getElementOffset(Idx); - Ty = STy->getElementType(Idx); - - LLVM_DEBUG(errs() << "Ty: " << *Ty << " Offset: " << Offset - << " Idx: " << Idx << " Rem: " << Rem << "\n"); + APInt IntOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset); + SmallVector<APInt> IntIndices = DL.getGEPIndicesForOffset(Ty, IntOffset); - GEPName += "." + std::to_string(Idx); - Indices.push_back(ConstantInt::get(IRB.getInt32Ty(), Idx)); - Offset = Rem; - } while (Offset); + SmallVector<Value *, 4> ValIndices; + std::string GEPName = Ptr->getName().str(); + for (const APInt &Index : IntIndices) { + ValIndices.push_back(IRB.getInt(Index)); + GEPName += "." + std::to_string(Index.getZExtValue()); + } // Create a GEP for the indices collected above. - Ptr = IRB.CreateGEP(PtrElemTy, Ptr, Indices, GEPName); + Ptr = IRB.CreateGEP(PtrElemTy, Ptr, ValIndices, GEPName); // If an offset is left we use byte-wise adjustment. 
- if (Offset) { + if (IntOffset != 0) { Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy()); - Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt32(Offset), - GEPName + ".b" + Twine(Offset)); + Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(IntOffset), + GEPName + ".b" + Twine(IntOffset.getZExtValue())); } } @@ -431,6 +413,7 @@ const Value *stripAndAccumulateMinimalOffsets( }; return Val->stripAndAccumulateConstantOffsets(DL, Offset, AllowNonInbounds, + /* AllowInvariant */ false, AttributorAnalysis); } @@ -503,6 +486,7 @@ static void clampReturnedValueStates( S ^= *T; } +namespace { /// Helper class for generic deduction: return value -> returned position. template <typename AAType, typename BaseType, typename StateType = typename BaseType::StateType, @@ -661,6 +645,7 @@ struct AACallSiteReturnedFromReturned : public BaseType { return clampStateAndIndicateChange(S, AA.getState()); } }; +} // namespace /// Helper function to accumulate uses. template <class AAType, typename StateType = typename AAType::StateType> @@ -1051,6 +1036,7 @@ private: BooleanState BS; }; +namespace { struct AAPointerInfoImpl : public StateWrapper<AA::PointerInfo::State, AAPointerInfo> { using BaseTy = StateWrapper<AA::PointerInfo::State, AAPointerInfo>; @@ -1207,7 +1193,7 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { } SmallVector<Value *, 8> Indices; - for (Use &Idx : llvm::make_range(GEP->idx_begin(), GEP->idx_end())) { + for (Use &Idx : GEP->indices()) { if (auto *CIdx = dyn_cast<ConstantInt>(Idx)) { Indices.push_back(CIdx); continue; @@ -1244,7 +1230,11 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { } // Check if the PHI operand is not dependent on the PHI itself. - APInt Offset(DL.getIndexTypeSizeInBits(AssociatedValue.getType()), 0); + // TODO: This is not great as we look at the pointer type. However, it + // is unclear where the Offset size comes from with typeless pointers. + APInt Offset( + DL.getIndexSizeInBits(CurPtr->getType()->getPointerAddressSpace()), + 0); if (&AssociatedValue == CurPtr->stripAndAccumulateConstantOffsets( DL, Offset, /* AllowNonInbounds */ true)) { if (Offset != PtrOI.Offset) { @@ -2432,6 +2422,10 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { const size_t NoUBPrevSize = AssumedNoUBInsts.size(); auto InspectMemAccessInstForUB = [&](Instruction &I) { + // Lang ref now states volatile store is not UB, let's skip them. + if (I.isVolatile() && I.mayWriteToMemory()) + return true; + // Skip instructions that are already saved. if (AssumedNoUBInsts.count(&I) || KnownUBInsts.count(&I)) return true; @@ -2511,7 +2505,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { Function *Callee = CB.getCalledFunction(); if (!Callee) return true; - for (unsigned idx = 0; idx < CB.getNumArgOperands(); idx++) { + for (unsigned idx = 0; idx < CB.arg_size(); idx++) { // If current argument is known to be simplified to null pointer and the // corresponding argument position is known to have nonnull attribute, // the argument is poison. Furthermore, if the argument is poison and @@ -3179,8 +3173,7 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { // value passed at this call site. 
// TODO: AbstractCallSite const auto &CB = cast<CallBase>(getAnchorValue()); - for (unsigned OtherArgNo = 0; OtherArgNo < CB.getNumArgOperands(); - OtherArgNo++) + for (unsigned OtherArgNo = 0; OtherArgNo < CB.arg_size(); OtherArgNo++) if (mayAliasWithArgument(A, AAR, MemBehaviorAA, CB, OtherArgNo)) return false; @@ -3398,6 +3391,10 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { } bool isDeadStore(Attributor &A, StoreInst &SI) { + // Lang ref now states volatile store is not UB/dead, let's skip them. + if (SI.isVolatile()) + return false; + bool UsedAssumedInformation = false; SmallSetVector<Value *, 4> PotentialCopies; if (!AA::getPotentialCopiesOfStoredValue(A, SI, PotentialCopies, *this, @@ -5083,6 +5080,7 @@ struct AANoCaptureCallSiteReturned final : AANoCaptureImpl { STATS_DECLTRACK_CSRET_ATTR(nocapture) } }; +} // namespace /// ------------------ Value Simplify Attribute ---------------------------- @@ -5103,6 +5101,7 @@ bool ValueSimplifyStateType::unionAssumed(Optional<Value *> Other) { return true; } +namespace { struct AAValueSimplifyImpl : AAValueSimplify { AAValueSimplifyImpl(const IRPosition &IRP, Attributor &A) : AAValueSimplify(IRP, A) {} @@ -6508,7 +6507,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { auto IsCompatiblePrivArgOfDirectCS = [&](AbstractCallSite ACS) { CallBase *DC = cast<CallBase>(ACS.getInstruction()); int DCArgNo = ACS.getCallArgOperandNo(ArgNo); - assert(DCArgNo >= 0 && unsigned(DCArgNo) < DC->getNumArgOperands() && + assert(DCArgNo >= 0 && unsigned(DCArgNo) < DC->arg_size() && "Expected a direct call operand for callback call operand"); LLVM_DEBUG({ @@ -7331,10 +7330,12 @@ void AAMemoryBehaviorFloating::analyzeUseIn(Attributor &A, const Use &U, case Instruction::Store: // Stores cause the NO_WRITES property to disappear if the use is the - // pointer operand. Note that we do assume that capturing was taken care of - // somewhere else. + // pointer operand. Note that while capturing was taken care of somewhere + // else we need to deal with stores of the value that is not looked through. if (cast<StoreInst>(UserI)->getPointerOperand() == U.get()) removeAssumedBits(NO_WRITES); + else + indicatePessimisticFixpoint(); return; case Instruction::Call: @@ -7380,6 +7381,7 @@ void AAMemoryBehaviorFloating::analyzeUseIn(Attributor &A, const Use &U, if (UserI->mayWriteToMemory()) removeAssumedBits(NO_WRITES); } +} // namespace /// -------------------- Memory Locations Attributes --------------------------- /// Includes read-none, argmemonly, inaccessiblememonly, @@ -7672,11 +7674,14 @@ void AAMemoryLocationImpl::categorizePtrValue( assert(!isa<GEPOperator>(Obj) && "GEPs should have been stripped."); if (isa<UndefValue>(Obj)) continue; - if (auto *Arg = dyn_cast<Argument>(Obj)) { - if (Arg->hasByValAttr()) - MLK = NO_LOCAL_MEM; - else - MLK = NO_ARGUMENT_MEM; + if (isa<Argument>(Obj)) { + // TODO: For now we do not treat byval arguments as local copies performed + // on the call edge, though, we should. To make that happen we need to + // teach various passes, e.g., DSE, about the copy effect of a byval. That + // would also allow us to mark functions only accessing byval arguments as + // readnone again, atguably their acceses have no effect outside of the + // function, like accesses to allocas. + MLK = NO_ARGUMENT_MEM; } else if (auto *GV = dyn_cast<GlobalValue>(Obj)) { // Reading constant memory is not treated as a read "effect" by the // function attr pass so we won't neither. 
Constants defined by TBAA are @@ -7722,7 +7727,7 @@ void AAMemoryLocationImpl::categorizePtrValue( void AAMemoryLocationImpl::categorizeArgumentPointerLocations( Attributor &A, CallBase &CB, AAMemoryLocation::StateType &AccessedLocs, bool &Changed) { - for (unsigned ArgNo = 0, E = CB.getNumArgOperands(); ArgNo < E; ++ArgNo) { + for (unsigned ArgNo = 0, E = CB.arg_size(); ArgNo < E; ++ArgNo) { // Skip non-pointer arguments. const Value *ArgOp = CB.getArgOperand(ArgNo); @@ -8655,31 +8660,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { static bool calculateICmpInst(const ICmpInst *ICI, const APInt &LHS, const APInt &RHS) { - ICmpInst::Predicate Pred = ICI->getPredicate(); - switch (Pred) { - case ICmpInst::ICMP_UGT: - return LHS.ugt(RHS); - case ICmpInst::ICMP_SGT: - return LHS.sgt(RHS); - case ICmpInst::ICMP_EQ: - return LHS.eq(RHS); - case ICmpInst::ICMP_UGE: - return LHS.uge(RHS); - case ICmpInst::ICMP_SGE: - return LHS.sge(RHS); - case ICmpInst::ICMP_ULT: - return LHS.ult(RHS); - case ICmpInst::ICMP_SLT: - return LHS.slt(RHS); - case ICmpInst::ICMP_NE: - return LHS.ne(RHS); - case ICmpInst::ICMP_ULE: - return LHS.ule(RHS); - case ICmpInst::ICMP_SLE: - return LHS.sle(RHS); - default: - llvm_unreachable("Invalid ICmp predicate!"); - } + return ICmpInst::compare(LHS, RHS, ICI->getPredicate()); } static APInt calculateCastInst(const CastInst *CI, const APInt &Src, @@ -8719,25 +8700,25 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { case Instruction::Mul: return LHS * RHS; case Instruction::UDiv: - if (RHS.isNullValue()) { + if (RHS.isZero()) { SkipOperation = true; return LHS; } return LHS.udiv(RHS); case Instruction::SDiv: - if (RHS.isNullValue()) { + if (RHS.isZero()) { SkipOperation = true; return LHS; } return LHS.sdiv(RHS); case Instruction::URem: - if (RHS.isNullValue()) { + if (RHS.isZero()) { SkipOperation = true; return LHS; } return LHS.urem(RHS); case Instruction::SRem: - if (RHS.isNullValue()) { + if (RHS.isZero()) { SkipOperation = true; return LHS; } @@ -9336,32 +9317,69 @@ struct AANoUndefCallSiteReturned final void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noundef) } }; -struct AACallEdgesFunction : public AACallEdges { - AACallEdgesFunction(const IRPosition &IRP, Attributor &A) - : AACallEdges(IRP, A) {} +struct AACallEdgesImpl : public AACallEdges { + AACallEdgesImpl(const IRPosition &IRP, Attributor &A) : AACallEdges(IRP, A) {} + virtual const SetVector<Function *> &getOptimisticEdges() const override { + return CalledFunctions; + } + + virtual bool hasUnknownCallee() const override { return HasUnknownCallee; } + + virtual bool hasNonAsmUnknownCallee() const override { + return HasUnknownCalleeNonAsm; + } + + const std::string getAsStr() const override { + return "CallEdges[" + std::to_string(HasUnknownCallee) + "," + + std::to_string(CalledFunctions.size()) + "]"; + } + + void trackStatistics() const override {} + +protected: + void addCalledFunction(Function *Fn, ChangeStatus &Change) { + if (CalledFunctions.insert(Fn)) { + Change = ChangeStatus::CHANGED; + LLVM_DEBUG(dbgs() << "[AACallEdges] New call edge: " << Fn->getName() + << "\n"); + } + } + + void setHasUnknownCallee(bool NonAsm, ChangeStatus &Change) { + if (!HasUnknownCallee) + Change = ChangeStatus::CHANGED; + if (NonAsm && !HasUnknownCalleeNonAsm) + Change = ChangeStatus::CHANGED; + HasUnknownCalleeNonAsm |= NonAsm; + HasUnknownCallee = true; + } + +private: + /// Optimistic set of functions that might be called by this position. 
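A few hunks above, the hand-written predicate switch in calculateICmpInst collapses into a single call to ICmpInst::compare. A small illustration of that helper on APInt operands; the values are arbitrary and the function is a standalone demo, not code from the patch:

#include "llvm/ADT/APInt.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static void compareDemo() {
  APInt A(/*numBits=*/32, 5);
  APInt B(/*numBits=*/32, -3, /*isSigned=*/true);
  // Signed vs. unsigned interpretation is selected purely by the predicate.
  bool SGT = ICmpInst::compare(A, B, ICmpInst::ICMP_SGT); // true: 5 > -3 signed
  bool UGT = ICmpInst::compare(A, B, ICmpInst::ICMP_UGT); // false: 5 < 0xFFFFFFFD unsigned
  (void)SGT;
  (void)UGT;
}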
+ SetVector<Function *> CalledFunctions; + + /// Is there any call with a unknown callee. + bool HasUnknownCallee = false; + + /// Is there any call with a unknown callee, excluding any inline asm. + bool HasUnknownCalleeNonAsm = false; +}; + +struct AACallEdgesCallSite : public AACallEdgesImpl { + AACallEdgesCallSite(const IRPosition &IRP, Attributor &A) + : AACallEdgesImpl(IRP, A) {} /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { ChangeStatus Change = ChangeStatus::UNCHANGED; - bool OldHasUnknownCallee = HasUnknownCallee; - bool OldHasUnknownCalleeNonAsm = HasUnknownCalleeNonAsm; - - auto AddCalledFunction = [&](Function *Fn) { - if (CalledFunctions.insert(Fn)) { - Change = ChangeStatus::CHANGED; - LLVM_DEBUG(dbgs() << "[AACallEdges] New call edge: " << Fn->getName() - << "\n"); - } - }; auto VisitValue = [&](Value &V, const Instruction *CtxI, bool &HasUnknown, bool Stripped) -> bool { if (Function *Fn = dyn_cast<Function>(&V)) { - AddCalledFunction(Fn); + addCalledFunction(Fn, Change); } else { LLVM_DEBUG(dbgs() << "[AACallEdges] Unrecognized value: " << V << "\n"); - HasUnknown = true; - HasUnknownCalleeNonAsm = true; + setHasUnknownCallee(true, Change); } // Explore all values. @@ -9369,44 +9387,67 @@ struct AACallEdgesFunction : public AACallEdges { }; // Process any value that we might call. - auto ProcessCalledOperand = [&](Value *V, Instruction *Ctx) { + auto ProcessCalledOperand = [&](Value *V) { + bool DummyValue = false; if (!genericValueTraversal<bool>(A, IRPosition::value(*V), *this, - HasUnknownCallee, VisitValue, nullptr, + DummyValue, VisitValue, nullptr, false)) { // If we haven't gone through all values, assume that there are unknown // callees. - HasUnknownCallee = true; - HasUnknownCalleeNonAsm = true; + setHasUnknownCallee(true, Change); } }; - auto ProcessCallInst = [&](Instruction &Inst) { - CallBase &CB = static_cast<CallBase &>(Inst); - if (CB.isInlineAsm()) { - HasUnknownCallee = true; - return true; - } + CallBase *CB = static_cast<CallBase *>(getCtxI()); - // Process callee metadata if available. - if (auto *MD = Inst.getMetadata(LLVMContext::MD_callees)) { - for (auto &Op : MD->operands()) { - Function *Callee = mdconst::extract_or_null<Function>(Op); - if (Callee) - AddCalledFunction(Callee); - } - // Callees metadata grantees that the called function is one of its - // operands, So we are done. - return true; + if (CB->isInlineAsm()) { + setHasUnknownCallee(false, Change); + return Change; + } + + // Process callee metadata if available. + if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) { + for (auto &Op : MD->operands()) { + Function *Callee = mdconst::dyn_extract_or_null<Function>(Op); + if (Callee) + addCalledFunction(Callee, Change); } + return Change; + } - // The most simple case. - ProcessCalledOperand(CB.getCalledOperand(), &Inst); + // The most simple case. + ProcessCalledOperand(CB->getCalledOperand()); - // Process callback functions. - SmallVector<const Use *, 4u> CallbackUses; - AbstractCallSite::getCallbackUses(CB, CallbackUses); - for (const Use *U : CallbackUses) - ProcessCalledOperand(U->get(), &Inst); + // Process callback functions. 
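Just above, the call-site variant consults !callees metadata before falling back to traversing the called operand. A minimal sketch of pulling candidate callees out of that metadata; this is an assumed standalone helper, not the abstract attribute itself:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
using namespace llvm;

// Collect the functions named by a call's !callees metadata, if present.
static void collectCalleesMD(const CallBase &CB,
                             SmallVectorImpl<Function *> &Out) {
  if (MDNode *MD = CB.getMetadata(LLVMContext::MD_callees))
    for (const MDOperand &Op : MD->operands())
      if (Function *Callee = mdconst::dyn_extract_or_null<Function>(Op))
        Out.push_back(Callee);
}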
+ SmallVector<const Use *, 4u> CallbackUses; + AbstractCallSite::getCallbackUses(*CB, CallbackUses); + for (const Use *U : CallbackUses) + ProcessCalledOperand(U->get()); + + return Change; + } +}; + +struct AACallEdgesFunction : public AACallEdgesImpl { + AACallEdgesFunction(const IRPosition &IRP, Attributor &A) + : AACallEdgesImpl(IRP, A) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Change = ChangeStatus::UNCHANGED; + + auto ProcessCallInst = [&](Instruction &Inst) { + CallBase &CB = static_cast<CallBase &>(Inst); + + auto &CBEdges = A.getAAFor<AACallEdges>( + *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); + if (CBEdges.hasNonAsmUnknownCallee()) + setHasUnknownCallee(true, Change); + if (CBEdges.hasUnknownCallee()) + setHasUnknownCallee(false, Change); + + for (Function *F : CBEdges.getOptimisticEdges()) + addCalledFunction(F, Change); return true; }; @@ -9417,155 +9458,323 @@ struct AACallEdgesFunction : public AACallEdges { UsedAssumedInformation)) { // If we haven't looked at all call like instructions, assume that there // are unknown callees. - HasUnknownCallee = true; - HasUnknownCalleeNonAsm = true; + setHasUnknownCallee(true, Change); } - // Track changes. - if (OldHasUnknownCallee != HasUnknownCallee || - OldHasUnknownCalleeNonAsm != HasUnknownCalleeNonAsm) - Change = ChangeStatus::CHANGED; - return Change; } +}; - virtual const SetVector<Function *> &getOptimisticEdges() const override { - return CalledFunctions; - }; +struct AAFunctionReachabilityFunction : public AAFunctionReachability { +private: + struct QuerySet { + void markReachable(Function *Fn) { + Reachable.insert(Fn); + Unreachable.erase(Fn); + } + + ChangeStatus update(Attributor &A, const AAFunctionReachability &AA, + ArrayRef<const AACallEdges *> AAEdgesList) { + ChangeStatus Change = ChangeStatus::UNCHANGED; + + for (auto *AAEdges : AAEdgesList) { + if (AAEdges->hasUnknownCallee()) { + if (!CanReachUnknownCallee) + Change = ChangeStatus::CHANGED; + CanReachUnknownCallee = true; + return Change; + } + } - virtual bool hasUnknownCallee() const override { return HasUnknownCallee; } + for (Function *Fn : make_early_inc_range(Unreachable)) { + if (checkIfReachable(A, AA, AAEdgesList, Fn)) { + Change = ChangeStatus::CHANGED; + markReachable(Fn); + } + } + return Change; + } - virtual bool hasNonAsmUnknownCallee() const override { - return HasUnknownCalleeNonAsm; - } + bool isReachable(Attributor &A, const AAFunctionReachability &AA, + ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) { + // Assume that we can reach the function. + // TODO: Be more specific with the unknown callee. + if (CanReachUnknownCallee) + return true; - const std::string getAsStr() const override { - return "CallEdges[" + std::to_string(HasUnknownCallee) + "," + - std::to_string(CalledFunctions.size()) + "]"; - } + if (Reachable.count(Fn)) + return true; - void trackStatistics() const override {} + if (Unreachable.count(Fn)) + return false; - /// Optimistic set of functions that might be called by this function. - SetVector<Function *> CalledFunctions; + // We need to assume that this function can't reach Fn to prevent + // an infinite loop if this function is recursive. + Unreachable.insert(Fn); - /// Is there any call with a unknown callee. 
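The QuerySet above caches answered queries in its Reachable/Unreachable sets and tentatively marks a target unreachable before recursing, so a recursive function cannot send the query into an infinite loop. Reduced to a one-shot search over a plain call graph (std containers only; nothing here is Attributor API), the core idea looks like this:

#include <map>
#include <set>
#include <string>
#include <vector>

// Call graph: caller -> direct callees.
using CallGraph = std::map<std::string, std::vector<std::string>>;

// One-shot DFS answering "can From reach To?". The Visited set plays the role
// of the tentative Unreachable marking above: a function already on the current
// search path is not re-entered, which is what keeps recursion from looping.
// The real code additionally caches positive and negative answers and lets the
// fixpoint iteration revisit them as call-edge information improves.
static bool canReach(const CallGraph &CG, const std::string &From,
                     const std::string &To, std::set<std::string> &Visited) {
  if (From == To)
    return true;
  if (!Visited.insert(From).second)
    return false; // already being explored; assume unreachable for now
  auto It = CG.find(From);
  if (It == CG.end())
    return false;
  for (const std::string &Callee : It->second)
    if (canReach(CG, Callee, To, Visited))
      return true;
  return false;
}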
- bool HasUnknownCallee = false; + bool Result = checkIfReachable(A, AA, AAEdgesList, Fn); + if (Result) + markReachable(Fn); + return Result; + } - /// Is there any call with a unknown callee, excluding any inline asm. - bool HasUnknownCalleeNonAsm = false; -}; + bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA, + ArrayRef<const AACallEdges *> AAEdgesList, + Function *Fn) const { -struct AAFunctionReachabilityFunction : public AAFunctionReachability { - AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A) - : AAFunctionReachability(IRP, A) {} + // Handle the most trivial case first. + for (auto *AAEdges : AAEdgesList) { + const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges(); - bool canReach(Attributor &A, Function *Fn) const override { - // Assume that we can reach any function if we can reach a call with - // unknown callee. - if (CanReachUnknownCallee) - return true; + if (Edges.count(Fn)) + return true; + } - if (ReachableQueries.count(Fn)) - return true; + SmallVector<const AAFunctionReachability *, 8> Deps; + for (auto &AAEdges : AAEdgesList) { + const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges(); + + for (Function *Edge : Edges) { + // We don't need a dependency if the result is reachable. + const AAFunctionReachability &EdgeReachability = + A.getAAFor<AAFunctionReachability>( + AA, IRPosition::function(*Edge), DepClassTy::NONE); + Deps.push_back(&EdgeReachability); + + if (EdgeReachability.canReach(A, Fn)) + return true; + } + } + + // The result is false for now, set dependencies and leave. + for (auto Dep : Deps) + A.recordDependence(AA, *Dep, DepClassTy::REQUIRED); - if (UnreachableQueries.count(Fn)) return false; + } + + /// Set of functions that we know for sure is reachable. + DenseSet<Function *> Reachable; + + /// Set of functions that are unreachable, but might become reachable. + DenseSet<Function *> Unreachable; + + /// If we can reach a function with a call to a unknown function we assume + /// that we can reach any function. + bool CanReachUnknownCallee = false; + }; +public: + AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A) + : AAFunctionReachability(IRP, A) {} + + bool canReach(Attributor &A, Function *Fn) const override { const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED); - const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges(); - bool Result = checkIfReachable(A, Edges, Fn); + // Attributor returns attributes as const, so this function has to be + // const for users of this attribute to use it without having to do + // a const_cast. + // This is a hack for us to be able to cache queries. + auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this); + bool Result = + NonConstThis->WholeFunction.isReachable(A, *this, {&AAEdges}, Fn); + + return Result; + } + + /// Can \p CB reach \p Fn + bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override { + const AACallEdges &AAEdges = A.getAAFor<AACallEdges>( + *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); // Attributor returns attributes as const, so this function has to be // const for users of this attribute to use it without having to do // a const_cast. // This is a hack for us to be able to cache queries. 
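The const_cast just below exists only so a logically const query can update its memoization caches. The same idiom in isolation, expressed with mutable instead of const_cast; a standalone sketch, not the Attributor code:

#include <map>
#include <utility>

struct CachedQuery {
  bool query(int From, int To) const {
    auto Key = std::make_pair(From, To);
    auto It = Cache.find(Key);
    if (It != Cache.end())
      return It->second;
    bool Result = compute(From, To);
    Cache[Key] = Result; // only the memoization state mutates
    return Result;
  }

private:
  bool compute(int, int) const { return false; } // stand-in for the real analysis
  mutable std::map<std::pair<int, int>, bool> Cache;
};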
auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this); + QuerySet &CBQuery = NonConstThis->CBQueries[&CB]; - if (Result) - NonConstThis->ReachableQueries.insert(Fn); - else - NonConstThis->UnreachableQueries.insert(Fn); + bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn); return Result; } /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - if (CanReachUnknownCallee) - return ChangeStatus::UNCHANGED; - const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED); - const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges(); ChangeStatus Change = ChangeStatus::UNCHANGED; - if (AAEdges.hasUnknownCallee()) { - bool OldCanReachUnknown = CanReachUnknownCallee; - CanReachUnknownCallee = true; - return OldCanReachUnknown ? ChangeStatus::UNCHANGED - : ChangeStatus::CHANGED; - } + Change |= WholeFunction.update(A, *this, {&AAEdges}); - // Check if any of the unreachable functions become reachable. - for (auto Current = UnreachableQueries.begin(); - Current != UnreachableQueries.end();) { - if (!checkIfReachable(A, Edges, *Current)) { - Current++; - continue; - } - ReachableQueries.insert(*Current); - UnreachableQueries.erase(*Current++); - Change = ChangeStatus::CHANGED; + for (auto CBPair : CBQueries) { + const AACallEdges &AAEdges = A.getAAFor<AACallEdges>( + *this, IRPosition::callsite_function(*CBPair.first), + DepClassTy::REQUIRED); + + Change |= CBPair.second.update(A, *this, {&AAEdges}); } return Change; } const std::string getAsStr() const override { - size_t QueryCount = ReachableQueries.size() + UnreachableQueries.size(); + size_t QueryCount = + WholeFunction.Reachable.size() + WholeFunction.Unreachable.size(); - return "FunctionReachability [" + std::to_string(ReachableQueries.size()) + - "," + std::to_string(QueryCount) + "]"; + return "FunctionReachability [" + + std::to_string(WholeFunction.Reachable.size()) + "," + + std::to_string(QueryCount) + "]"; } void trackStatistics() const override {} private: - bool canReachUnknownCallee() const override { return CanReachUnknownCallee; } + bool canReachUnknownCallee() const override { + return WholeFunction.CanReachUnknownCallee; + } - bool checkIfReachable(Attributor &A, const SetVector<Function *> &Edges, - Function *Fn) const { - if (Edges.count(Fn)) - return true; + /// Used to answer if a the whole function can reacha a specific function. + QuerySet WholeFunction; - for (Function *Edge : Edges) { - // We don't need a dependency if the result is reachable. - const AAFunctionReachability &EdgeReachability = - A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Edge), - DepClassTy::NONE); + /// Used to answer if a call base inside this function can reach a specific + /// function. + DenseMap<CallBase *, QuerySet> CBQueries; +}; - if (EdgeReachability.canReach(A, Fn)) - return true; - } - for (Function *Fn : Edges) - A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Fn), - DepClassTy::REQUIRED); +/// ---------------------- Assumption Propagation ------------------------------ +struct AAAssumptionInfoImpl : public AAAssumptionInfo { + AAAssumptionInfoImpl(const IRPosition &IRP, Attributor &A, + const DenseSet<StringRef> &Known) + : AAAssumptionInfo(IRP, A, Known) {} - return false; + bool hasAssumption(const StringRef Assumption) const override { + return isValidState() && setContains(Assumption); } - /// Set of functions that we know for sure is reachable. 
- SmallPtrSet<Function *, 8> ReachableQueries; + /// See AbstractAttribute::getAsStr() + const std::string getAsStr() const override { + const SetContents &Known = getKnown(); + const SetContents &Assumed = getAssumed(); + + const std::string KnownStr = + llvm::join(Known.getSet().begin(), Known.getSet().end(), ","); + const std::string AssumedStr = + (Assumed.isUniversal()) + ? "Universal" + : llvm::join(Assumed.getSet().begin(), Assumed.getSet().end(), ","); + + return "Known [" + KnownStr + "]," + " Assumed [" + AssumedStr + "]"; + } +}; + +/// Propagates assumption information from parent functions to all of their +/// successors. An assumption can be propagated if the containing function +/// dominates the called function. +/// +/// We start with a "known" set of assumptions already valid for the associated +/// function and an "assumed" set that initially contains all possible +/// assumptions. The assumed set is inter-procedurally updated by narrowing its +/// contents as concrete values are known. The concrete values are seeded by the +/// first nodes that are either entries into the call graph, or contains no +/// assumptions. Each node is updated as the intersection of the assumed state +/// with all of its predecessors. +struct AAAssumptionInfoFunction final : AAAssumptionInfoImpl { + AAAssumptionInfoFunction(const IRPosition &IRP, Attributor &A) + : AAAssumptionInfoImpl(IRP, A, + getAssumptions(*IRP.getAssociatedFunction())) {} + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + const auto &Assumptions = getKnown(); + + // Don't manifest a universal set if it somehow made it here. + if (Assumptions.isUniversal()) + return ChangeStatus::UNCHANGED; + + Function *AssociatedFunction = getAssociatedFunction(); + + bool Changed = addAssumptions(*AssociatedFunction, Assumptions.getSet()); + + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + bool Changed = false; + + auto CallSitePred = [&](AbstractCallSite ACS) { + const auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>( + *this, IRPosition::callsite_function(*ACS.getInstruction()), + DepClassTy::REQUIRED); + // Get the set of assumptions shared by all of this function's callers. + Changed |= getIntersection(AssumptionAA.getAssumed()); + return !getAssumed().empty() || !getKnown().empty(); + }; + + bool AllCallSitesKnown; + // Get the intersection of all assumptions held by this node's predecessors. + // If we don't know all the call sites then this is either an entry into the + // call graph or an empty node. This node is known to only contain its own + // assumptions and can be propagated to its successors. + if (!A.checkForAllCallSites(CallSitePred, *this, true, AllCallSitesKnown)) + return indicatePessimisticFixpoint(); - /// Set of functions that are unreachable, but might become reachable. - SmallPtrSet<Function *, 8> UnreachableQueries; + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + } + + void trackStatistics() const override {} +}; + +/// Assumption Info defined for call sites. +struct AAAssumptionInfoCallSite final : AAAssumptionInfoImpl { + + AAAssumptionInfoCallSite(const IRPosition &IRP, Attributor &A) + : AAAssumptionInfoImpl(IRP, A, getInitialAssumptions(IRP)) {} + + /// See AbstractAttribute::initialize(...). 
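The propagation described in the comment block above is, at its core, a meet over string sets: the assumed set starts universal and is narrowed to the intersection of whatever the callers still guarantee. A standalone sketch of that lattice step using std containers, not the patch's actual set-state type:

#include <set>
#include <string>
#include <utility>

struct AssumptionState {
  bool Universal = true;          // top: every assumption still possible
  std::set<std::string> Assumed;  // meaningful only once Universal is false

  // Meet with one caller's assumption set; returns true if the state narrowed.
  bool intersectWith(const std::set<std::string> &Caller) {
    if (Universal) {
      Universal = false;
      Assumed = Caller;
      return true;
    }
    std::set<std::string> Narrowed;
    for (const std::string &A : Assumed)
      if (Caller.count(A))
        Narrowed.insert(A);
    bool Changed = Narrowed.size() != Assumed.size();
    Assumed = std::move(Narrowed);
    return Changed;
  }
};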
+ void initialize(Attributor &A) override { + const IRPosition &FnPos = IRPosition::function(*getAnchorScope()); + A.getAAFor<AAAssumptionInfo>(*this, FnPos, DepClassTy::REQUIRED); + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + // Don't manifest a universal set if it somehow made it here. + if (getKnown().isUniversal()) + return ChangeStatus::UNCHANGED; - /// If we can reach a function with a call to a unknown function we assume - /// that we can reach any function. - bool CanReachUnknownCallee = false; + CallBase &AssociatedCall = cast<CallBase>(getAssociatedValue()); + bool Changed = addAssumptions(AssociatedCall, getAssumed().getSet()); + + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + const IRPosition &FnPos = IRPosition::function(*getAnchorScope()); + auto &AssumptionAA = + A.getAAFor<AAAssumptionInfo>(*this, FnPos, DepClassTy::REQUIRED); + bool Changed = getIntersection(AssumptionAA.getAssumed()); + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} + +private: + /// Helper to initialized the known set as all the assumptions this call and + /// the callee contain. + DenseSet<StringRef> getInitialAssumptions(const IRPosition &IRP) { + const CallBase &CB = cast<CallBase>(IRP.getAssociatedValue()); + auto Assumptions = getAssumptions(CB); + if (Function *F = IRP.getAssociatedFunction()) + set_union(Assumptions, getAssumptions(*F)); + if (Function *F = IRP.getAssociatedFunction()) + set_union(Assumptions, getAssumptions(*F)); + return Assumptions; + } }; } // namespace @@ -9603,6 +9812,7 @@ const char AANoUndef::ID = 0; const char AACallEdges::ID = 0; const char AAFunctionReachability::ID = 0; const char AAPointerInfo::ID = 0; +const char AAAssumptionInfo::ID = 0; // Macro magic to create the static generator function for attributes that // follow the naming scheme. @@ -9704,6 +9914,8 @@ CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAWillReturn) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoReturn) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReturnedValues) CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias) @@ -9723,7 +9935,6 @@ CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReachability) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior) -CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAFunctionReachability) CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp index 8e81f4bad4af..178d3f41963e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp @@ -153,33 +153,30 @@ static bool mergeConstants(Module &M) { // were just merged. 
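The rewritten loops below lean on llvm::make_early_inc_range, which advances the iterator before the loop body runs so the current element may be erased from its parent. A minimal usage sketch of the same idiom, reduced to the dead-global case and assumed standalone:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static void dropDeadLocalGlobals(Module &M) {
  for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) {
    GV.removeDeadConstantUsers();
    if (GV.use_empty() && GV.hasLocalLinkage())
      GV.eraseFromParent(); // safe: the range iterator already moved past GV
  }
}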
while (true) { // Find the canonical constants others will be merged with. - for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); - GVI != E; ) { - GlobalVariable *GV = &*GVI++; - + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) { // If this GV is dead, remove it. - GV->removeDeadConstantUsers(); - if (GV->use_empty() && GV->hasLocalLinkage()) { - GV->eraseFromParent(); + GV.removeDeadConstantUsers(); + if (GV.use_empty() && GV.hasLocalLinkage()) { + GV.eraseFromParent(); ++ChangesMade; continue; } - if (isUnmergeableGlobal(GV, UsedGlobals)) + if (isUnmergeableGlobal(&GV, UsedGlobals)) continue; // This transformation is legal for weak ODR globals in the sense it // doesn't change semantics, but we really don't want to perform it // anyway; it's likely to pessimize code generation, and some tools // (like the Darwin linker in cases involving CFString) don't expect it. - if (GV->isWeakForLinker()) + if (GV.isWeakForLinker()) continue; // Don't touch globals with metadata other then !dbg. - if (hasMetadataOtherThanDebugLoc(GV)) + if (hasMetadataOtherThanDebugLoc(&GV)) continue; - Constant *Init = GV->getInitializer(); + Constant *Init = GV.getInitializer(); // Check to see if the initializer is already known. GlobalVariable *&Slot = CMap[Init]; @@ -188,9 +185,9 @@ static bool mergeConstants(Module &M) { // replace with the current one. If the current is externally visible // it cannot be replace, but can be the canonical constant we merge with. bool FirstConstantFound = !Slot; - if (FirstConstantFound || IsBetterCanonical(*GV, *Slot)) { - Slot = GV; - LLVM_DEBUG(dbgs() << "Cmap[" << *Init << "] = " << GV->getName() + if (FirstConstantFound || IsBetterCanonical(GV, *Slot)) { + Slot = &GV; + LLVM_DEBUG(dbgs() << "Cmap[" << *Init << "] = " << GV.getName() << (FirstConstantFound ? "\n" : " (updated)\n")); } } @@ -199,18 +196,15 @@ static bool mergeConstants(Module &M) { // SameContentReplacements vector. We cannot do the replacement in this pass // because doing so may cause initializers of other globals to be rewritten, // invalidating the Constant* pointers in CMap. - for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); - GVI != E; ) { - GlobalVariable *GV = &*GVI++; - - if (isUnmergeableGlobal(GV, UsedGlobals)) + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) { + if (isUnmergeableGlobal(&GV, UsedGlobals)) continue; // We can only replace constant with local linkage. - if (!GV->hasLocalLinkage()) + if (!GV.hasLocalLinkage()) continue; - Constant *Init = GV->getInitializer(); + Constant *Init = GV.getInitializer(); // Check to see if the initializer is already known. auto Found = CMap.find(Init); @@ -218,16 +212,16 @@ static bool mergeConstants(Module &M) { continue; GlobalVariable *Slot = Found->second; - if (Slot == GV) + if (Slot == &GV) continue; - if (makeMergeable(GV, Slot) == CanMerge::No) + if (makeMergeable(&GV, Slot) == CanMerge::No) continue; // Make all uses of the duplicate constant use the canonical version. 
- LLVM_DEBUG(dbgs() << "Will replace: @" << GV->getName() << " -> @" + LLVM_DEBUG(dbgs() << "Will replace: @" << GV.getName() << " -> @" << Slot->getName() << "\n"); - SameContentReplacements.push_back(std::make_pair(GV, Slot)); + SameContentReplacements.push_back(std::make_pair(&GV, Slot)); } // Now that we have figured out which replacements must be made, do them all diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index d95fd55870f8..fb9ab7954e36 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -175,8 +175,8 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { // to pass in a smaller number of arguments into the new function. // std::vector<Value *> Args; - for (Value::user_iterator I = Fn.user_begin(), E = Fn.user_end(); I != E; ) { - CallBase *CB = dyn_cast<CallBase>(*I++); + for (User *U : llvm::make_early_inc_range(Fn.users())) { + CallBase *CB = dyn_cast<CallBase>(U); if (!CB) continue; @@ -188,9 +188,9 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { if (!PAL.isEmpty()) { SmallVector<AttributeSet, 8> ArgAttrs; for (unsigned ArgNo = 0; ArgNo < NumArgs; ++ArgNo) - ArgAttrs.push_back(PAL.getParamAttributes(ArgNo)); - PAL = AttributeList::get(Fn.getContext(), PAL.getFnAttributes(), - PAL.getRetAttributes(), ArgAttrs); + ArgAttrs.push_back(PAL.getParamAttrs(ArgNo)); + PAL = AttributeList::get(Fn.getContext(), PAL.getFnAttrs(), + PAL.getRetAttrs(), ArgAttrs); } SmallVector<OperandBundleDef, 1> OpBundles; @@ -762,8 +762,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { if (LiveValues.erase(Arg)) { Params.push_back(I->getType()); ArgAlive[ArgI] = true; - ArgAttrVec.push_back(PAL.getParamAttributes(ArgI)); - HasLiveReturnedArg |= PAL.hasParamAttribute(ArgI, Attribute::Returned); + ArgAttrVec.push_back(PAL.getParamAttrs(ArgI)); + HasLiveReturnedArg |= PAL.hasParamAttr(ArgI, Attribute::Returned); } else { ++NumArgumentsEliminated; LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing argument " @@ -838,7 +838,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { assert(NRetTy && "No new return type found?"); // The existing function return attributes. - AttrBuilder RAttrs(PAL.getRetAttributes()); + AttrBuilder RAttrs(PAL.getRetAttrs()); // Remove any incompatible attributes, but only if we removed all return // values. Otherwise, ensure that we don't have any conflicting attributes @@ -853,8 +853,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs); // Strip allocsize attributes. They might refer to the deleted arguments. - AttributeSet FnAttrs = PAL.getFnAttributes().removeAttribute( - F->getContext(), Attribute::AllocSize); + AttributeSet FnAttrs = + PAL.getFnAttrs().removeAttribute(F->getContext(), Attribute::AllocSize); // Reconstruct the AttributesList based on the vector we constructed. assert(ArgAttrVec.size() == Params.size()); @@ -889,7 +889,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Adjust the call return attributes in case the function was changed to // return void. 
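Much of this hunk is mechanical renaming to the newer AttributeList accessors (getFnAttrs, getRetAttrs, getParamAttrs). A compact sketch of the rebuild pattern they feed into, using a hypothetical helper that keeps only the first N per-parameter attribute sets:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

// Hypothetical helper: rebuild an attribute list, preserving function and
// return attributes but only the first N parameter attribute sets.
static AttributeList keepFirstNParamAttrs(LLVMContext &Ctx,
                                          const AttributeList &PAL, unsigned N) {
  SmallVector<AttributeSet, 8> ArgAttrs;
  for (unsigned ArgNo = 0; ArgNo != N; ++ArgNo)
    ArgAttrs.push_back(PAL.getParamAttrs(ArgNo));
  return AttributeList::get(Ctx, PAL.getFnAttrs(), PAL.getRetAttrs(), ArgAttrs);
}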
- AttrBuilder RAttrs(CallPAL.getRetAttributes()); + AttrBuilder RAttrs(CallPAL.getRetAttrs()); RAttrs.remove(AttributeFuncs::typeIncompatible(NRetTy)); AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs); @@ -903,7 +903,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { if (ArgAlive[Pi]) { Args.push_back(*I); // Get original parameter attributes, but skip return attributes. - AttributeSet Attrs = CallPAL.getParamAttributes(Pi); + AttributeSet Attrs = CallPAL.getParamAttrs(Pi); if (NRetTy != RetTy && Attrs.hasAttribute(Attribute::Returned)) { // If the return type has changed, then get rid of 'returned' on the // call site. The alternative is to make all 'returned' attributes on @@ -922,7 +922,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Push any varargs arguments on the list. Don't forget their attributes. for (auto E = CB.arg_end(); I != E; ++I, ++Pi) { Args.push_back(*I); - ArgAttrVec.push_back(CallPAL.getParamAttributes(Pi)); + ArgAttrVec.push_back(CallPAL.getParamAttrs(Pi)); } // Reconstruct the AttributesList based on the vector we constructed. @@ -930,7 +930,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Again, be sure to remove any allocsize attributes, since their indices // may now be incorrect. - AttributeSet FnAttrs = CallPAL.getFnAttributes().removeAttribute( + AttributeSet FnAttrs = CallPAL.getFnAttrs().removeAttribute( F->getContext(), Attribute::AllocSize); AttributeList NewCallPAL = AttributeList::get( @@ -1094,11 +1094,9 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, // fused with the next loop, because deleting a function invalidates // information computed while surveying other functions. LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Deleting dead varargs\n"); - for (Module::iterator I = M.begin(), E = M.end(); I != E; ) { - Function &F = *I++; + for (Function &F : llvm::make_early_inc_range(M)) if (F.getFunctionType()->isVarArg()) Changed |= DeleteDeadVarargs(F); - } // Second phase:loop through the module, determining which arguments are live. // We assume all arguments are dead unless proven otherwise (allowing us to @@ -1109,13 +1107,10 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, SurveyFunction(F); // Now, remove all dead arguments and return values from each function in - // turn. - for (Module::iterator I = M.begin(), E = M.end(); I != E; ) { - // Increment now, because the function will probably get removed (ie. - // replaced by a new one). - Function *F = &*I++; - Changed |= RemoveDeadStuffFromFunction(F); - } + // turn. We use make_early_inc_range here because functions will probably get + // removed (i.e. replaced by new ones). + for (Function &F : llvm::make_early_inc_range(M)) + Changed |= RemoveDeadStuffFromFunction(&F); // Finally, look for any unused parameters in functions with non-local // linkage and replace the passed in parameters with undef. diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp index ba0efd46af16..387f114f6ffa 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp @@ -121,32 +121,27 @@ namespace { } // Visit the Aliases. 
- for (Module::alias_iterator I = M.alias_begin(), E = M.alias_end(); - I != E;) { - Module::alias_iterator CurI = I; - ++I; - - bool Delete = deleteStuff == (bool)Named.count(&*CurI); - makeVisible(*CurI, Delete); + for (GlobalAlias &GA : llvm::make_early_inc_range(M.aliases())) { + bool Delete = deleteStuff == (bool)Named.count(&GA); + makeVisible(GA, Delete); if (Delete) { - Type *Ty = CurI->getValueType(); + Type *Ty = GA.getValueType(); - CurI->removeFromParent(); + GA.removeFromParent(); llvm::Value *Declaration; if (FunctionType *FTy = dyn_cast<FunctionType>(Ty)) { - Declaration = Function::Create(FTy, GlobalValue::ExternalLinkage, - CurI->getAddressSpace(), - CurI->getName(), &M); + Declaration = + Function::Create(FTy, GlobalValue::ExternalLinkage, + GA.getAddressSpace(), GA.getName(), &M); } else { Declaration = - new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, - nullptr, CurI->getName()); - + new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, + nullptr, GA.getName()); } - CurI->replaceAllUsesWith(Declaration); - delete &*CurI; + GA.replaceAllUsesWith(Declaration); + delete &GA; } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 47fdf042f9d4..16d00a0c89e1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -50,14 +50,14 @@ static void forceAttributes(Function &F) { return Kind; }; - for (auto &S : ForceAttributes) { + for (const auto &S : ForceAttributes) { auto Kind = ParseFunctionAndAttr(S); if (Kind == Attribute::None || F.hasFnAttribute(Kind)) continue; F.addFnAttr(Kind); } - for (auto &S : ForceRemoveAttributes) { + for (const auto &S : ForceRemoveAttributes) { auto Kind = ParseFunctionAndAttr(S); if (Kind == Attribute::None || !F.hasFnAttribute(Kind)) continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index ca8660a98ded..cde78713b554 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -14,10 +14,12 @@ #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" @@ -82,6 +84,11 @@ STATISTIC(NumNoFree, "Number of functions marked as nofree"); STATISTIC(NumWillReturn, "Number of functions marked as willreturn"); STATISTIC(NumNoSync, "Number of functions marked as nosync"); +STATISTIC(NumThinLinkNoRecurse, + "Number of functions marked as norecurse during thinlink"); +STATISTIC(NumThinLinkNoUnwind, + "Number of functions marked as nounwind during thinlink"); + static cl::opt<bool> EnableNonnullArgPropagation( "enable-nonnull-arg-prop", cl::init(true), cl::Hidden, cl::desc("Try to propagate nonnull argument attributes from callsites to " @@ -95,6 +102,10 @@ static cl::opt<bool> DisableNoFreeInference( "disable-nofree-inference", cl::Hidden, cl::desc("Stop inferring nofree attribute during function-attrs pass")); +static cl::opt<bool> DisableThinLTOPropagation( + "disable-thinlto-funcattrs", cl::init(true), cl::Hidden, + 
cl::desc("Don't propagate function-attrs in thinLTO")); + namespace { using SCCNodeSet = SmallSetVector<Function *, 8>; @@ -131,12 +142,10 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, // Scan the function body for instructions that may read or write memory. bool ReadsMemory = false; bool WritesMemory = false; - for (inst_iterator II = inst_begin(F), E = inst_end(F); II != E; ++II) { - Instruction *I = &*II; - + for (Instruction &I : instructions(F)) { // Some instructions can be ignored even if they read or write memory. // Detect these now, skipping to the next instruction if one is found. - if (auto *Call = dyn_cast<CallBase>(I)) { + if (auto *Call = dyn_cast<CallBase>(&I)) { // Ignore calls to functions in the same SCC, as long as the call sites // don't have operand bundles. Calls with operand bundles are allowed to // have memory effects not described by the memory effects of the call @@ -170,14 +179,13 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, // Check whether all pointer arguments point to local memory, and // ignore calls that only access local memory. - for (auto CI = Call->arg_begin(), CE = Call->arg_end(); CI != CE; ++CI) { - Value *Arg = *CI; + for (const Use &U : Call->args()) { + const Value *Arg = U; if (!Arg->getType()->isPtrOrPtrVectorTy()) continue; - AAMDNodes AAInfo; - I->getAAMetadata(AAInfo); - MemoryLocation Loc = MemoryLocation::getBeforeOrAfter(Arg, AAInfo); + MemoryLocation Loc = + MemoryLocation::getBeforeOrAfter(Arg, I.getAAMetadata()); // Skip accesses to local or constant memory as they don't impact the // externally visible mod/ref behavior. @@ -192,21 +200,21 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, ReadsMemory = true; } continue; - } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + } else if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { // Ignore non-volatile loads from local memory. (Atomic is okay here.) if (!LI->isVolatile()) { MemoryLocation Loc = MemoryLocation::get(LI); if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) continue; } - } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { + } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { // Ignore non-volatile stores to local memory. (Atomic is okay here.) if (!SI->isVolatile()) { MemoryLocation Loc = MemoryLocation::get(SI); if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) continue; } - } else if (VAArgInst *VI = dyn_cast<VAArgInst>(I)) { + } else if (VAArgInst *VI = dyn_cast<VAArgInst>(&I)) { // Ignore vaargs on local memory. MemoryLocation Loc = MemoryLocation::get(VI); if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) @@ -217,10 +225,10 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, // read or write memory. // // Writes memory, remember that. - WritesMemory |= I->mayWriteToMemory(); + WritesMemory |= I.mayWriteToMemory(); // If this instruction may read memory, remember that. - ReadsMemory |= I->mayReadFromMemory(); + ReadsMemory |= I.mayReadFromMemory(); } if (WritesMemory) { @@ -240,7 +248,8 @@ MemoryAccessKind llvm::computeFunctionBodyMemoryAccess(Function &F, /// Deduce readonly/readnone attributes for the SCC. template <typename AARGetterT> -static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { +static void addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, + SmallSet<Function *, 8> &Changed) { // Check if any of the functions in the SCC read or write memory. 
If they // write memory then they can't be marked readnone or readonly. bool ReadsMemory = false; @@ -255,7 +264,7 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { switch (checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes)) { case MAK_MayWrite: - return false; + return; case MAK_ReadOnly: ReadsMemory = true; break; @@ -271,11 +280,10 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { // If the SCC contains both functions that read and functions that write, then // we cannot add readonly attributes. if (ReadsMemory && WritesMemory) - return false; + return; // Success! Functions in this SCC do not access memory, or only read memory. // Give them the appropriate attribute. - bool MadeChange = false; for (Function *F : SCCNodes) { if (F->doesNotAccessMemory()) @@ -289,7 +297,7 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { if (F->doesNotReadMemory() && WritesMemory) continue; - MadeChange = true; + Changed.insert(F); // Clear out any existing attributes. AttrBuilder AttrsToRemove; @@ -303,7 +311,7 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { AttrsToRemove.addAttribute(Attribute::InaccessibleMemOnly); AttrsToRemove.addAttribute(Attribute::InaccessibleMemOrArgMemOnly); } - F->removeAttributes(AttributeList::FunctionIndex, AttrsToRemove); + F->removeFnAttrs(AttrsToRemove); // Add in the new attribute. if (WritesMemory && !ReadsMemory) @@ -318,8 +326,195 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { else ++NumReadNone; } +} + +// Compute definitive function attributes for a function taking into account +// prevailing definitions and linkage types +static FunctionSummary *calculatePrevailingSummary( + ValueInfo VI, + DenseMap<ValueInfo, FunctionSummary *> &CachedPrevailingSummary, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing) { + + if (CachedPrevailingSummary.count(VI)) + return CachedPrevailingSummary[VI]; + + /// At this point, prevailing symbols have been resolved. The following leads + /// to returning a conservative result: + /// - Multiple instances with local linkage. Normally local linkage would be + /// unique per module + /// as the GUID includes the module path. We could have a guid alias if + /// there wasn't any distinguishing path when each file was compiled, but + /// that should be rare so we'll punt on those. + + /// These next 2 cases should not happen and will assert: + /// - Multiple instances with external linkage. This should be caught in + /// symbol resolution + /// - Non-existent FunctionSummary for Aliasee. This presents a hole in our + /// knowledge meaning we have to go conservative. + + /// Otherwise, we calculate attributes for a function as: + /// 1. If we have a local linkage, take its attributes. If there's somehow + /// multiple, bail and go conservative. + /// 2. If we have an external/WeakODR/LinkOnceODR linkage check that it is + /// prevailing, take its attributes. + /// 3. If we have a Weak/LinkOnce linkage the copies can have semantic + /// differences. However, if the prevailing copy is known it will be used + /// so take its attributes. If the prevailing copy is in a native file + /// all IR copies will be dead and propagation will go conservative. + /// 4. AvailableExternally summaries without a prevailing copy are known to + /// occur in a couple of circumstances: + /// a. 
An internal function gets imported due to its caller getting + /// imported, it becomes AvailableExternally but no prevailing + /// definition exists. Because it has to get imported along with its + /// caller the attributes will be captured by propagating on its + /// caller. + /// b. C++11 [temp.explicit]p10 can generate AvailableExternally + /// definitions of explicitly instanced template declarations + /// for inlining which are ultimately dropped from the TU. Since this + /// is localized to the TU the attributes will have already made it to + /// the callers. + /// These are edge cases and already captured by their callers so we + /// ignore these for now. If they become relevant to optimize in the + /// future this can be revisited. + /// 5. Otherwise, go conservative. + + CachedPrevailingSummary[VI] = nullptr; + FunctionSummary *Local = nullptr; + FunctionSummary *Prevailing = nullptr; + + for (const auto &GVS : VI.getSummaryList()) { + if (!GVS->isLive()) + continue; + + FunctionSummary *FS = dyn_cast<FunctionSummary>(GVS->getBaseObject()); + // Virtual and Unknown (e.g. indirect) calls require going conservative + if (!FS || FS->fflags().HasUnknownCall) + return nullptr; + + const auto &Linkage = GVS->linkage(); + if (GlobalValue::isLocalLinkage(Linkage)) { + if (Local) { + LLVM_DEBUG( + dbgs() + << "ThinLTO FunctionAttrs: Multiple Local Linkage, bailing on " + "function " + << VI.name() << " from " << FS->modulePath() << ". Previous module " + << Local->modulePath() << "\n"); + return nullptr; + } + Local = FS; + } else if (GlobalValue::isExternalLinkage(Linkage)) { + assert(IsPrevailing(VI.getGUID(), GVS.get())); + Prevailing = FS; + break; + } else if (GlobalValue::isWeakODRLinkage(Linkage) || + GlobalValue::isLinkOnceODRLinkage(Linkage) || + GlobalValue::isWeakAnyLinkage(Linkage) || + GlobalValue::isLinkOnceAnyLinkage(Linkage)) { + if (IsPrevailing(VI.getGUID(), GVS.get())) { + Prevailing = FS; + break; + } + } else if (GlobalValue::isAvailableExternallyLinkage(Linkage)) { + // TODO: Handle these cases if they become meaningful + continue; + } + } + + if (Local) { + assert(!Prevailing); + CachedPrevailingSummary[VI] = Local; + } else if (Prevailing) { + assert(!Local); + CachedPrevailingSummary[VI] = Prevailing; + } - return MadeChange; + return CachedPrevailingSummary[VI]; +} + +bool llvm::thinLTOPropagateFunctionAttrs( + ModuleSummaryIndex &Index, + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing) { + // TODO: implement addNoAliasAttrs once + // there's more information about the return type in the summary + if (DisableThinLTOPropagation) + return false; + + DenseMap<ValueInfo, FunctionSummary *> CachedPrevailingSummary; + bool Changed = false; + + auto PropagateAttributes = [&](std::vector<ValueInfo> &SCCNodes) { + // Assume we can propagate unless we discover otherwise + FunctionSummary::FFlags InferredFlags; + InferredFlags.NoRecurse = (SCCNodes.size() == 1); + InferredFlags.NoUnwind = true; + + for (auto &V : SCCNodes) { + FunctionSummary *CallerSummary = + calculatePrevailingSummary(V, CachedPrevailingSummary, IsPrevailing); + + // Function summaries can fail to contain information such as declarations + if (!CallerSummary) + return; + + if (CallerSummary->fflags().MayThrow) + InferredFlags.NoUnwind = false; + + for (const auto &Callee : CallerSummary->calls()) { + FunctionSummary *CalleeSummary = calculatePrevailingSummary( + Callee.first, CachedPrevailingSummary, IsPrevailing); + + if (!CalleeSummary) + return; + + if 
(!CalleeSummary->fflags().NoRecurse) + InferredFlags.NoRecurse = false; + + if (!CalleeSummary->fflags().NoUnwind) + InferredFlags.NoUnwind = false; + + if (!InferredFlags.NoUnwind && !InferredFlags.NoRecurse) + break; + } + } + + if (InferredFlags.NoUnwind || InferredFlags.NoRecurse) { + Changed = true; + for (auto &V : SCCNodes) { + if (InferredFlags.NoRecurse) { + LLVM_DEBUG(dbgs() << "ThinLTO FunctionAttrs: Propagated NoRecurse to " + << V.name() << "\n"); + ++NumThinLinkNoRecurse; + } + + if (InferredFlags.NoUnwind) { + LLVM_DEBUG(dbgs() << "ThinLTO FunctionAttrs: Propagated NoUnwind to " + << V.name() << "\n"); + ++NumThinLinkNoUnwind; + } + + for (auto &S : V.getSummaryList()) { + if (auto *FS = dyn_cast<FunctionSummary>(S.get())) { + if (InferredFlags.NoRecurse) + FS->setNoRecurse(); + + if (InferredFlags.NoUnwind) + FS->setNoUnwind(); + } + } + } + } + }; + + // Call propagation functions on each SCC in the Index + for (scc_iterator<ModuleSummaryIndex *> I = scc_begin(&Index); !I.isAtEnd(); + ++I) { + std::vector<ValueInfo> Nodes(*I); + PropagateAttributes(Nodes); + } + return Changed; } namespace { @@ -395,7 +590,7 @@ struct ArgumentUsesTracker : public CaptureTracker { assert(UseIndex < CB->data_operands_size() && "Indirect function calls should have been filtered above!"); - if (UseIndex >= CB->getNumArgOperands()) { + if (UseIndex >= CB->arg_size()) { // Data operand, but not a argument operand -- must be a bundle operand assert(CB->hasOperandBundles() && "Must be!"); @@ -530,7 +725,7 @@ determinePointerReadAttrs(Argument *A, assert(UseIndex < CB.data_operands_size() && "Data operand use expected!"); - bool IsOperandBundleUse = UseIndex >= CB.getNumArgOperands(); + bool IsOperandBundleUse = UseIndex >= CB.arg_size(); if (UseIndex >= F->arg_size() && !IsOperandBundleUse) { assert(F->isVarArg() && "More params than args in non-varargs call"); @@ -581,9 +776,8 @@ determinePointerReadAttrs(Argument *A, } /// Deduce returned attributes for the SCC. -static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) { - bool Changed = false; - +static void addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { // Check each function in turn, determining if an argument is always returned. for (Function *F : SCCNodes) { // We can infer and propagate function attributes only when we know that the @@ -623,11 +817,9 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) { auto *A = cast<Argument>(RetArg); A->addAttr(Attribute::Returned); ++NumReturned; - Changed = true; + Changed.insert(F); } } - - return Changed; } /// If a callsite has arguments that are also arguments to the parent function, @@ -693,9 +885,8 @@ static bool addReadAttr(Argument *A, Attribute::AttrKind R) { } /// Deduce nocapture attributes for the SCC. -static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { - bool Changed = false; - +static void addArgumentAttrs(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { ArgumentGraph AG; // Check each function in turn, determining which pointer arguments are not @@ -707,7 +898,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { if (!F->hasExactDefinition()) continue; - Changed |= addArgumentAttrsFromCallsites(*F); + if (addArgumentAttrsFromCallsites(*F)) + Changed.insert(F); // Functions that are readonly (or readnone) and nounwind and don't return // a value can't capture arguments. Don't analyze them. 
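Earlier in this file's hunks, the UseIndex checks switch to CB.arg_size(): a data-operand index below arg_size() is a real call argument, while anything past it (but still a data operand) belongs to an operand bundle. A small sketch of that classification, as an assumed standalone helper rather than the tracker itself:

#include "llvm/IR/InstrTypes.h"
#include <iterator>
using namespace llvm;

// Precondition: U is a use inside CB's data operands.
static bool isOperandBundleUse(const CallBase &CB, const Use &U) {
  unsigned UseIndex = std::distance(CB.data_operands_begin(), &U);
  return UseIndex >= CB.arg_size(); // past the arguments -> bundle operand
}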
@@ -718,7 +910,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { if (A->getType()->isPointerTy() && !A->hasNoCaptureAttr()) { A->addAttr(Attribute::NoCapture); ++NumNoCapture; - Changed = true; + Changed.insert(F); } } continue; @@ -737,7 +929,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { // If it's trivially not captured, mark it nocapture now. A->addAttr(Attribute::NoCapture); ++NumNoCapture; - Changed = true; + Changed.insert(F); } else { // If it's not trivially captured and not trivially not captured, // then it must be calling into another function in our SCC. Save @@ -761,7 +953,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { Self.insert(&*A); Attribute::AttrKind R = determinePointerReadAttrs(&*A, Self); if (R != Attribute::None) - Changed = addReadAttr(A, R); + if (addReadAttr(A, R)) + Changed.insert(F); } } } @@ -785,7 +978,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { Argument *A = ArgumentSCC[0]->Definition; A->addAttr(Attribute::NoCapture); ++NumNoCapture; - Changed = true; + Changed.insert(A->getParent()); } continue; } @@ -827,7 +1020,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { Argument *A = ArgumentSCC[i]->Definition; A->addAttr(Attribute::NoCapture); ++NumNoCapture; - Changed = true; + Changed.insert(A->getParent()); } // We also want to compute readonly/readnone. With a small number of false @@ -858,12 +1051,11 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { if (ReadAttr != Attribute::None) { for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) { Argument *A = ArgumentSCC[i]->Definition; - Changed = addReadAttr(A, ReadAttr); + if (addReadAttr(A, ReadAttr)) + Changed.insert(A->getParent()); } } } - - return Changed; } /// Tests whether a function is "malloc-like". @@ -934,7 +1126,8 @@ static bool isFunctionMallocLike(Function *F, const SCCNodeSet &SCCNodes) { } /// Deduce noalias attributes for the SCC. -static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { +static void addNoAliasAttrs(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { // Check each function in turn, determining which functions return noalias // pointers. for (Function *F : SCCNodes) { @@ -946,7 +1139,7 @@ static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { // definition we'll get at link time is *exactly* the definition we see now. // For more details, see GlobalValue::mayBeDerefined. if (!F->hasExactDefinition()) - return false; + return; // We annotate noalias return values, which are only applicable to // pointer types. @@ -954,10 +1147,9 @@ static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { continue; if (!isFunctionMallocLike(F, SCCNodes)) - return false; + return; } - bool MadeChange = false; for (Function *F : SCCNodes) { if (F->returnDoesNotAlias() || !F->getReturnType()->isPointerTy()) @@ -965,10 +1157,8 @@ static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { F->setReturnDoesNotAlias(); ++NumNoAlias; - MadeChange = true; + Changed.insert(F); } - - return MadeChange; } /// Tests whether this function is known to not return null. @@ -1044,26 +1234,24 @@ static bool isReturnNonNull(Function *F, const SCCNodeSet &SCCNodes, } /// Deduce nonnull attributes for the SCC. -static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { +static void addNonNullAttrs(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { // Speculative that all functions in the SCC return only nonnull // pointers. We may refute this as we analyze functions. 
bool SCCReturnsNonNull = true; - bool MadeChange = false; - // Check each function in turn, determining which functions return nonnull // pointers. for (Function *F : SCCNodes) { // Already nonnull. - if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, - Attribute::NonNull)) + if (F->getAttributes().hasRetAttr(Attribute::NonNull)) continue; // We can infer and propagate function attributes only when we know that the // definition we'll get at link time is *exactly* the definition we see now. // For more details, see GlobalValue::mayBeDerefined. if (!F->hasExactDefinition()) - return false; + return; // We annotate nonnull return values, which are only applicable to // pointer types. @@ -1077,9 +1265,9 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { // which prevents us from speculating about the entire SCC LLVM_DEBUG(dbgs() << "Eagerly marking " << F->getName() << " as nonnull\n"); - F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + F->addRetAttr(Attribute::NonNull); ++NumNonNullReturn; - MadeChange = true; + Changed.insert(F); } continue; } @@ -1090,19 +1278,16 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { if (SCCReturnsNonNull) { for (Function *F : SCCNodes) { - if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, - Attribute::NonNull) || + if (F->getAttributes().hasRetAttr(Attribute::NonNull) || !F->getReturnType()->isPointerTy()) continue; LLVM_DEBUG(dbgs() << "SCC marking " << F->getName() << " as nonnull\n"); - F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + F->addRetAttr(Attribute::NonNull); ++NumNonNullReturn; - MadeChange = true; + Changed.insert(F); } } - - return MadeChange; } namespace { @@ -1155,12 +1340,13 @@ public: InferenceDescriptors.push_back(AttrInference); } - bool run(const SCCNodeSet &SCCNodes); + void run(const SCCNodeSet &SCCNodes, SmallSet<Function *, 8> &Changed); }; /// Perform all the requested attribute inference actions according to the /// attribute predicates stored before. -bool AttributeInferer::run(const SCCNodeSet &SCCNodes) { +void AttributeInferer::run(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { SmallVector<InferenceDescriptor, 4> InferInSCC = InferenceDescriptors; // Go through all the functions in SCC and check corresponding attribute // assumptions for each of them. Attributes that are invalid for this SCC @@ -1169,7 +1355,7 @@ bool AttributeInferer::run(const SCCNodeSet &SCCNodes) { // No attributes whose assumptions are still valid - done. if (InferInSCC.empty()) - return false; + return; // Check if our attributes ever need scanning/can be scanned. llvm::erase_if(InferInSCC, [F](const InferenceDescriptor &ID) { @@ -1212,9 +1398,8 @@ bool AttributeInferer::run(const SCCNodeSet &SCCNodes) { } if (InferInSCC.empty()) - return false; + return; - bool Changed = false; for (Function *F : SCCNodes) // At this point InferInSCC contains only functions that were either: // - explicitly skipped from scan/inference, or @@ -1223,10 +1408,9 @@ bool AttributeInferer::run(const SCCNodeSet &SCCNodes) { for (auto &ID : InferInSCC) { if (ID.SkipFunction(*F)) continue; - Changed = true; + Changed.insert(F); ID.SetAttribute(*F); } - return Changed; } struct SCCNodesResult { @@ -1243,7 +1427,7 @@ static bool InstrBreaksNonConvergent(Instruction &I, // Breaks non-convergent assumption if CS is a convergent call to a function // not in the SCC. 
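The nonnull hunk above also moves from indexing the attribute list with ReturnIndex to the dedicated return-attribute helpers. A tiny sketch of the new spelling, as a hypothetical standalone function:

#include "llvm/IR/Function.h"
using namespace llvm;

static void markReturnNonNull(Function &F) {
  if (!F.getReturnType()->isPointerTy())
    return;
  if (!F.getAttributes().hasRetAttr(Attribute::NonNull))
    F.addRetAttr(Attribute::NonNull); // replaces addAttribute(ReturnIndex, ...)
}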
return CB && CB->isConvergent() && - SCCNodes.count(CB->getCalledFunction()) == 0; + !SCCNodes.contains(CB->getCalledFunction()); } /// Helper for NoUnwind inference predicate InstrBreaksAttribute. @@ -1282,7 +1466,8 @@ static bool InstrBreaksNoFree(Instruction &I, const SCCNodeSet &SCCNodes) { /// Attempt to remove convergent function attribute when possible. /// /// Returns true if any changes to function attributes were made. -static bool inferConvergent(const SCCNodeSet &SCCNodes) { +static void inferConvergent(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { AttributeInferer AI; // Request to remove the convergent attribute from all functions in the SCC @@ -1305,7 +1490,7 @@ static bool inferConvergent(const SCCNodeSet &SCCNodes) { }, /* RequiresExactDefinition= */ false}); // Perform all the requested attribute inference actions. - return AI.run(SCCNodes); + AI.run(SCCNodes, Changed); } /// Infer attributes from all functions in the SCC by scanning every @@ -1314,7 +1499,8 @@ static bool inferConvergent(const SCCNodeSet &SCCNodes) { /// - addition of NoUnwind attribute /// /// Returns true if any changes to function attributes were made. -static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { +static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { AttributeInferer AI; if (!DisableNoUnwindInference) @@ -1363,19 +1549,20 @@ static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { /* RequiresExactDefinition= */ true}); // Perform all the requested attribute inference actions. - return AI.run(SCCNodes); + AI.run(SCCNodes, Changed); } -static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { +static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { // Try and identify functions that do not recurse. // If the SCC contains multiple nodes we know for sure there is recursion. if (SCCNodes.size() != 1) - return false; + return; Function *F = *SCCNodes.begin(); if (!F || !F->hasExactDefinition() || F->doesNotRecurse()) - return false; + return; // If all of the calls in F are identifiable and are to norecurse functions, F // is norecurse. This check also detects self-recursion as F is not currently @@ -1386,7 +1573,7 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { Function *Callee = CB->getCalledFunction(); if (!Callee || Callee == F || !Callee->doesNotRecurse()) // Function calls a potentially recursive function. - return false; + return; } // Every call was to a non-recursive function other than this function, and @@ -1394,7 +1581,7 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { // recurse. F->setDoesNotRecurse(); ++NumNoRecurse; - return true; + Changed.insert(F); } static bool instructionDoesNotReturn(Instruction &I) { @@ -1412,9 +1599,8 @@ static bool basicBlockCanReturn(BasicBlock &BB) { } // Set the noreturn function attribute if possible. -static bool addNoReturnAttrs(const SCCNodeSet &SCCNodes) { - bool Changed = false; - +static void addNoReturnAttrs(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { for (Function *F : SCCNodes) { if (!F || !F->hasExactDefinition() || F->hasFnAttribute(Attribute::Naked) || F->doesNotReturn()) @@ -1424,11 +1610,9 @@ static bool addNoReturnAttrs(const SCCNodeSet &SCCNodes) { // FIXME: this doesn't handle recursion or unreachable blocks. 
if (none_of(*F, basicBlockCanReturn)) { F->setDoesNotReturn(); - Changed = true; + Changed.insert(F); } } - - return Changed; } static bool functionWillReturn(const Function &F) { @@ -1461,19 +1645,16 @@ static bool functionWillReturn(const Function &F) { } // Set the willreturn function attribute if possible. -static bool addWillReturn(const SCCNodeSet &SCCNodes) { - bool Changed = false; - +static void addWillReturn(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { for (Function *F : SCCNodes) { if (!F || F->willReturn() || !functionWillReturn(*F)) continue; F->setWillReturn(); NumWillReturn++; - Changed = true; + Changed.insert(F); } - - return Changed; } // Return true if this is an atomic which has an ordering stronger than @@ -1532,7 +1713,8 @@ static bool InstrBreaksNoSync(Instruction &I, const SCCNodeSet &SCCNodes) { } // Infer the nosync attribute. -static bool addNoSyncAttr(const SCCNodeSet &SCCNodes) { +static void addNoSyncAttr(const SCCNodeSet &SCCNodes, + SmallSet<Function *, 8> &Changed) { AttributeInferer AI; AI.registerAttrInference(AttributeInferer::InferenceDescriptor{ Attribute::NoSync, @@ -1549,14 +1731,15 @@ static bool addNoSyncAttr(const SCCNodeSet &SCCNodes) { ++NumNoSync; }, /* RequiresExactDefinition= */ true}); - return AI.run(SCCNodes); + AI.run(SCCNodes, Changed); } static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) { SCCNodesResult Res; Res.HasUnknownCall = false; for (Function *F : Functions) { - if (!F || F->hasOptNone() || F->hasFnAttribute(Attribute::Naked)) { + if (!F || F->hasOptNone() || F->hasFnAttribute(Attribute::Naked) || + F->isPresplitCoroutine()) { // Treat any function we're trying not to optimize as if it were an // indirect call and omit it from the node set used below. Res.HasUnknownCall = true; @@ -1582,32 +1765,33 @@ static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) { } template <typename AARGetterT> -static bool deriveAttrsInPostOrder(ArrayRef<Function *> Functions, - AARGetterT &&AARGetter) { +static SmallSet<Function *, 8> +deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) { SCCNodesResult Nodes = createSCCNodeSet(Functions); - bool Changed = false; // Bail if the SCC only contains optnone functions. if (Nodes.SCCNodes.empty()) - return Changed; + return {}; + + SmallSet<Function *, 8> Changed; - Changed |= addArgumentReturnedAttrs(Nodes.SCCNodes); - Changed |= addReadAttrs(Nodes.SCCNodes, AARGetter); - Changed |= addArgumentAttrs(Nodes.SCCNodes); - Changed |= inferConvergent(Nodes.SCCNodes); - Changed |= addNoReturnAttrs(Nodes.SCCNodes); - Changed |= addWillReturn(Nodes.SCCNodes); + addArgumentReturnedAttrs(Nodes.SCCNodes, Changed); + addReadAttrs(Nodes.SCCNodes, AARGetter, Changed); + addArgumentAttrs(Nodes.SCCNodes, Changed); + inferConvergent(Nodes.SCCNodes, Changed); + addNoReturnAttrs(Nodes.SCCNodes, Changed); + addWillReturn(Nodes.SCCNodes, Changed); // If we have no external nodes participating in the SCC, we can deduce some // more precise attributes as well. 
if (!Nodes.HasUnknownCall) { - Changed |= addNoAliasAttrs(Nodes.SCCNodes); - Changed |= addNonNullAttrs(Nodes.SCCNodes); - Changed |= inferAttrsFromFunctionBodies(Nodes.SCCNodes); - Changed |= addNoRecurseAttrs(Nodes.SCCNodes); + addNoAliasAttrs(Nodes.SCCNodes, Changed); + addNonNullAttrs(Nodes.SCCNodes, Changed); + inferAttrsFromFunctionBodies(Nodes.SCCNodes, Changed); + addNoRecurseAttrs(Nodes.SCCNodes, Changed); } - Changed |= addNoSyncAttr(Nodes.SCCNodes); + addNoSyncAttr(Nodes.SCCNodes, Changed); // Finally, infer the maximal set of attributes from the ones we've inferred // above. This is handling the cases where one attribute on a signature @@ -1615,7 +1799,8 @@ static bool deriveAttrsInPostOrder(ArrayRef<Function *> Functions, // the later is missing (or simply less sophisticated). for (Function *F : Nodes.SCCNodes) if (F) - Changed |= inferAttributesFromOthers(*F); + if (inferAttributesFromOthers(*F)) + Changed.insert(F); return Changed; } @@ -1638,14 +1823,35 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, Functions.push_back(&N.getFunction()); } - if (deriveAttrsInPostOrder(Functions, AARGetter)) { - // We have not changed the call graph or removed/added functions. - PreservedAnalyses PA; - PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); - return PA; + auto ChangedFunctions = deriveAttrsInPostOrder(Functions, AARGetter); + if (ChangedFunctions.empty()) + return PreservedAnalyses::all(); + + // Invalidate analyses for modified functions so that we don't have to + // invalidate all analyses for all functions in this SCC. + PreservedAnalyses FuncPA; + // We haven't changed the CFG for modified functions. + FuncPA.preserveSet<CFGAnalyses>(); + for (Function *Changed : ChangedFunctions) { + FAM.invalidate(*Changed, FuncPA); + // Also invalidate any direct callers of changed functions since analyses + // may care about attributes of direct callees. For example, MemorySSA cares + // about whether or not a call's callee modifies memory and queries that + // through function attributes. + for (auto *U : Changed->users()) { + if (auto *Call = dyn_cast<CallBase>(U)) { + if (Call->getCalledFunction() == Changed) + FAM.invalidate(*Call->getFunction(), FuncPA); + } + } } - return PreservedAnalyses::all(); + PreservedAnalyses PA; + // We have not added or removed functions. + PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + // We already invalidated all relevant function analyses above. 
+ PA.preserveSet<AllAnalysesOn<Function>>(); + return PA; } namespace { @@ -1690,7 +1896,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { Functions.push_back(I->getFunction()); } - return deriveAttrsInPostOrder(Functions, AARGetter); + return !deriveAttrsInPostOrder(Functions, AARGetter).empty(); } bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp index 2f6cf0ca7087..d9b43109f629 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Errc.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" @@ -187,23 +188,6 @@ selectCallee(const ModuleSummaryIndex &Index, return false; } - // For SamplePGO, in computeImportForFunction the OriginalId - // may have been used to locate the callee summary list (See - // comment there). - // The mapping from OriginalId to GUID may return a GUID - // that corresponds to a static variable. Filter it out here. - // This can happen when - // 1) There is a call to a library function which is not defined - // in the index. - // 2) There is a static variable with the OriginalGUID identical - // to the GUID of the library function in 1); - // When this happens, the logic for SamplePGO kicks in and - // the static variable in 2) will be found, which needs to be - // filtered out. - if (GVSummary->getSummaryKind() == GlobalValueSummary::GlobalVarKind) { - Reason = FunctionImporter::ImportFailureReason::GlobalVar; - return false; - } if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) { Reason = FunctionImporter::ImportFailureReason::InterposableLinkage; // There is no point in importing these, we can't inline them @@ -264,21 +248,6 @@ using EdgeInfo = } // anonymous namespace -static ValueInfo -updateValueInfoForIndirectCalls(const ModuleSummaryIndex &Index, ValueInfo VI) { - if (!VI.getSummaryList().empty()) - return VI; - // For SamplePGO, the indirect call targets for local functions will - // have its original name annotated in profile. We try to find the - // corresponding PGOFuncName as the GUID. - // FIXME: Consider updating the edges in the graph after building - // it, rather than needing to perform this mapping on each walk. - auto GUID = Index.getGUIDFromOriginalID(VI.getGUID()); - if (GUID == 0) - return ValueInfo(); - return Index.getValueInfo(GUID); -} - static bool shouldImportGlobal(const ValueInfo &VI, const GVSummaryMapTy &DefinedGVSummaries) { const auto &GVS = DefinedGVSummaries.find(VI.getGUID()); @@ -400,10 +369,6 @@ static void computeImportForFunction( continue; } - VI = updateValueInfoForIndirectCalls(Index, VI); - if (!VI) - continue; - if (DefinedGVSummaries.count(VI.getGUID())) { // FIXME: Consider not skipping import if the module contains // a non-prevailing def with interposable linkage. 
The prevailing copy @@ -496,7 +461,7 @@ static void computeImportForFunction( VI.name().str() + " due to " + getFailureName(Reason); auto Error = make_error<StringError>( - Msg, std::make_error_code(std::errc::operation_not_supported)); + Msg, make_error_code(errc::not_supported)); logAllUnhandledErrors(std::move(Error), errs(), "Error importing module: "); break; @@ -839,16 +804,61 @@ void llvm::ComputeCrossModuleImportForModuleFromIndex( #endif } -void llvm::computeDeadSymbols( +// For SamplePGO, the indirect call targets for local functions will +// have its original name annotated in profile. We try to find the +// corresponding PGOFuncName as the GUID, and fix up the edges +// accordingly. +void updateValueInfoForIndirectCalls(ModuleSummaryIndex &Index, + FunctionSummary *FS) { + for (auto &EI : FS->mutableCalls()) { + if (!EI.first.getSummaryList().empty()) + continue; + auto GUID = Index.getGUIDFromOriginalID(EI.first.getGUID()); + if (GUID == 0) + continue; + // Update the edge to point directly to the correct GUID. + auto VI = Index.getValueInfo(GUID); + if (llvm::any_of( + VI.getSummaryList(), + [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { + // The mapping from OriginalId to GUID may return a GUID + // that corresponds to a static variable. Filter it out here. + // This can happen when + // 1) There is a call to a library function which is not defined + // in the index. + // 2) There is a static variable with the OriginalGUID identical + // to the GUID of the library function in 1); + // When this happens the static variable in 2) will be found, + // which needs to be filtered out. + return SummaryPtr->getSummaryKind() == + GlobalValueSummary::GlobalVarKind; + })) + continue; + EI.first = VI; + } +} + +void llvm::updateIndirectCalls(ModuleSummaryIndex &Index) { + for (const auto &Entry : Index) { + for (auto &S : Entry.second.SummaryList) { + if (auto *FS = dyn_cast<FunctionSummary>(S.get())) + updateValueInfoForIndirectCalls(Index, FS); + } + } +} + +void llvm::computeDeadSymbolsAndUpdateIndirectCalls( ModuleSummaryIndex &Index, const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols, function_ref<PrevailingType(GlobalValue::GUID)> isPrevailing) { assert(!Index.withGlobalValueDeadStripping()); - if (!ComputeDead) - return; - if (GUIDPreservedSymbols.empty()) - // Don't do anything when nothing is live, this is friendly with tests. + if (!ComputeDead || + // Don't do anything when nothing is live, this is friendly with tests. + GUIDPreservedSymbols.empty()) { + // Still need to update indirect calls. + updateIndirectCalls(Index); return; + } unsigned LiveSymbols = 0; SmallVector<ValueInfo, 128> Worklist; Worklist.reserve(GUIDPreservedSymbols.size() * 2); @@ -863,13 +873,16 @@ void llvm::computeDeadSymbols( // Add values flagged in the index as live roots to the worklist. for (const auto &Entry : Index) { auto VI = Index.getValueInfo(Entry); - for (auto &S : Entry.second.SummaryList) + for (auto &S : Entry.second.SummaryList) { + if (auto *FS = dyn_cast<FunctionSummary>(S.get())) + updateValueInfoForIndirectCalls(Index, FS); if (S->isLive()) { LLVM_DEBUG(dbgs() << "Live root: " << VI << "\n"); Worklist.push_back(VI); ++LiveSymbols; break; } + } } // Make value live and add it to the worklist if it was not live before. @@ -882,9 +895,6 @@ void llvm::computeDeadSymbols( // binary, which increases the binary size unnecessarily. Note that // if this code changes, the importer needs to change so that edges // to functions marked dead are skipped. 
- VI = updateValueInfoForIndirectCalls(Index, VI); - if (!VI) - return; if (llvm::any_of(VI.getSummaryList(), [](const std::unique_ptr<llvm::GlobalValueSummary> &S) { @@ -958,7 +968,8 @@ void llvm::computeDeadSymbolsWithConstProp( const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols, function_ref<PrevailingType(GlobalValue::GUID)> isPrevailing, bool ImportEnabled) { - computeDeadSymbols(Index, GUIDPreservedSymbols, isPrevailing); + computeDeadSymbolsAndUpdateIndirectCalls(Index, GUIDPreservedSymbols, + isPrevailing); if (ImportEnabled) Index.propagateAttributes(GUIDPreservedSymbols); } @@ -1040,13 +1051,33 @@ bool llvm::convertToDeclaration(GlobalValue &GV) { return true; } -void llvm::thinLTOResolvePrevailingInModule( - Module &TheModule, const GVSummaryMapTy &DefinedGlobals) { - auto updateLinkage = [&](GlobalValue &GV) { +void llvm::thinLTOFinalizeInModule(Module &TheModule, + const GVSummaryMapTy &DefinedGlobals, + bool PropagateAttrs) { + auto FinalizeInModule = [&](GlobalValue &GV, bool Propagate = false) { // See if the global summary analysis computed a new resolved linkage. const auto &GS = DefinedGlobals.find(GV.getGUID()); if (GS == DefinedGlobals.end()) return; + + if (Propagate) + if (FunctionSummary *FS = dyn_cast<FunctionSummary>(GS->second)) { + if (Function *F = dyn_cast<Function>(&GV)) { + // TODO: propagate ReadNone and ReadOnly. + if (FS->fflags().ReadNone && !F->doesNotAccessMemory()) + F->setDoesNotAccessMemory(); + + if (FS->fflags().ReadOnly && !F->onlyReadsMemory()) + F->setOnlyReadsMemory(); + + if (FS->fflags().NoRecurse && !F->doesNotRecurse()) + F->setDoesNotRecurse(); + + if (FS->fflags().NoUnwind && !F->doesNotThrow()) + F->setDoesNotThrow(); + } + } + auto NewLinkage = GS->second->linkage(); if (GlobalValue::isLocalLinkage(GV.getLinkage()) || // Don't internalize anything here, because the code below @@ -1105,11 +1136,11 @@ void llvm::thinLTOResolvePrevailingInModule( // Process functions and global now for (auto &GV : TheModule) - updateLinkage(GV); + FinalizeInModule(GV, PropagateAttrs); for (auto &GV : TheModule.globals()) - updateLinkage(GV); + FinalizeInModule(GV); for (auto &GV : TheModule.aliases()) - updateLinkage(GV); + FinalizeInModule(GV); } /// Run internalization on \p TheModule based on symmary analysis. @@ -1153,7 +1184,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, /// Make alias a clone of its aliasee. static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) { - Function *Fn = cast<Function>(GA->getBaseObject()); + Function *Fn = cast<Function>(GA->getAliaseeObject()); ValueToValueMapTy VMap; Function *NewFn = CloneFunction(Fn, VMap); @@ -1259,12 +1290,12 @@ Expected<bool> FunctionImporter::importFunctions( if (Error Err = GA.materialize()) return std::move(Err); // Import alias as a copy of its aliasee. - GlobalObject *Base = GA.getBaseObject(); - if (Error Err = Base->materialize()) + GlobalObject *GO = GA.getAliaseeObject(); + if (Error Err = GO->materialize()) return std::move(Err); auto *Fn = replaceAliasWithAliasee(SrcModule.get(), &GA); - LLVM_DEBUG(dbgs() << "Is importing aliasee fn " << Base->getGUID() - << " " << Base->getName() << " from " + LLVM_DEBUG(dbgs() << "Is importing aliasee fn " << GO->getGUID() << " " + << GO->getName() << " from " << SrcModule->getSourceFileName() << "\n"); if (EnableImportMetadata) { // Add 'thinlto_src_module' metadata for statistics and debugging. 
@@ -1303,7 +1334,7 @@ Expected<bool> FunctionImporter::importFunctions( std::move(SrcModule), GlobalsToImport.getArrayRef(), [](GlobalValue &, IRMover::ValueAdder) {}, /*IsPerformingImport=*/true)) - report_fatal_error("Function Import: link error: " + + report_fatal_error(Twine("Function Import: link error: ") + toString(std::move(Err))); ImportedCount += GlobalsToImport.size(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index f61f4312b777..fbd083bb9bbf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -11,7 +11,6 @@ // are propagated to the callee by specializing the function. // // Current limitations: -// - It does not handle specialization of recursive functions, // - It does not yet handle integer ranges. // - Only 1 argument per function is specialised, // - The cost-model could be further looked into, @@ -22,6 +21,18 @@ // a direct way to steer function specialization, avoiding the cost-model, // and thus control compile-times / code-size. // +// Todos: +// - Specializing recursive functions relies on running the transformation a +// number of times, which is controlled by option +// `func-specialization-max-iters`. Thus, increasing this value and the +// number of iterations, will linearly increase the number of times recursive +// functions get specialized, see also the discussion in +// https://reviews.llvm.org/D106426 for details. Perhaps there is a +// compile-time friendlier way to control/limit the number of specialisations +// for recursive functions. +// - Don't transform the function if there is no function specialization +// happens. +// //===----------------------------------------------------------------------===// #include "llvm/ADT/Statistic.h" @@ -59,20 +70,166 @@ static cl::opt<unsigned> MaxConstantsThreshold( "specialization"), cl::init(3)); +static cl::opt<unsigned> SmallFunctionThreshold( + "func-specialization-size-threshold", cl::Hidden, + cl::desc("Don't specialize functions that have less than this theshold " + "number of instructions"), + cl::init(100)); + static cl::opt<unsigned> AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden, cl::desc("Average loop iteration count cost"), cl::init(10)); +static cl::opt<bool> SpecializeOnAddresses( + "func-specialization-on-address", cl::init(false), cl::Hidden, + cl::desc("Enable function specialization on the address of global values")); + +// TODO: This needs checking to see the impact on compile-times, which is why +// this is off by default for now. static cl::opt<bool> EnableSpecializationForLiteralConstant( "function-specialization-for-literal-constant", cl::init(false), cl::Hidden, - cl::desc("Make function specialization available for literal constant.")); + cl::desc("Enable specialization of functions that take a literal constant " + "as an argument.")); + +// Helper to check if \p LV is either a constant or a constant +// range with a single element. This should cover exactly the same cases as the +// old ValueLatticeElement::isConstant() and is intended to be used in the +// transition to ValueLatticeElement. +static bool isConstant(const ValueLatticeElement &LV) { + return LV.isConstant() || + (LV.isConstantRange() && LV.getConstantRange().isSingleElement()); +} // Helper to check if \p LV is either overdefined or a constant int. 
static bool isOverdefined(const ValueLatticeElement &LV) { - return !LV.isUnknownOrUndef() && !LV.isConstant(); + return !LV.isUnknownOrUndef() && !isConstant(LV); +} + +static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) { + Value *StoreValue = nullptr; + for (auto *User : Alloca->users()) { + // We can't use llvm::isAllocaPromotable() as that would fail because of + // the usage in the CallInst, which is what we check here. + if (User == Call) + continue; + if (auto *Bitcast = dyn_cast<BitCastInst>(User)) { + if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call) + return nullptr; + continue; + } + + if (auto *Store = dyn_cast<StoreInst>(User)) { + // This is a duplicate store, bail out. + if (StoreValue || Store->isVolatile()) + return nullptr; + StoreValue = Store->getValueOperand(); + continue; + } + // Bail if there is any other unknown usage. + return nullptr; + } + return dyn_cast_or_null<Constant>(StoreValue); } +// A constant stack value is an AllocaInst that has a single constant +// value stored to it. Return this constant if such an alloca stack value +// is a function argument. +static Constant *getConstantStackValue(CallInst *Call, Value *Val, + SCCPSolver &Solver) { + if (!Val) + return nullptr; + Val = Val->stripPointerCasts(); + if (auto *ConstVal = dyn_cast<ConstantInt>(Val)) + return ConstVal; + auto *Alloca = dyn_cast<AllocaInst>(Val); + if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy()) + return nullptr; + return getPromotableAlloca(Alloca, Call); +} + +// To support specializing recursive functions, it is important to propagate +// constant arguments because after a first iteration of specialisation, a +// reduced example may look like this: +// +// define internal void @RecursiveFn(i32* arg1) { +// %temp = alloca i32, align 4 +// store i32 2 i32* %temp, align 4 +// call void @RecursiveFn.1(i32* nonnull %temp) +// ret void +// } +// +// Before a next iteration, we need to propagate the constant like so +// which allows further specialization in next iterations. +// +// @funcspec.arg = internal constant i32 2 +// +// define internal void @someFunc(i32* arg1) { +// call void @otherFunc(i32* nonnull @funcspec.arg) +// ret void +// } +// +static void constantArgPropagation(SmallVectorImpl<Function *> &WorkList, + Module &M, SCCPSolver &Solver) { + // Iterate over the argument tracked functions see if there + // are any new constant values for the call instruction via + // stack variables. + for (auto *F : WorkList) { + // TODO: Generalize for any read only arguments. + if (F->arg_size() != 1) + continue; + + auto &Arg = *F->arg_begin(); + if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy()) + continue; + + for (auto *User : F->users()) { + auto *Call = dyn_cast<CallInst>(User); + if (!Call) + break; + auto *ArgOp = Call->getArgOperand(0); + auto *ArgOpType = ArgOp->getType(); + auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver); + if (!ConstVal) + break; + + Value *GV = new GlobalVariable(M, ConstVal->getType(), true, + GlobalValue::InternalLinkage, ConstVal, + "funcspec.arg"); + + if (ArgOpType != ConstVal->getType()) + GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType()); + + Call->setArgOperand(0, GV); + + // Add the changed CallInst to Solver Worklist + Solver.visitCall(*Call); + } + } +} + +// ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics +// interfere with the constantArgPropagation optimization. 
+static void removeSSACopy(Function &F) { + for (BasicBlock &BB : F) { + for (Instruction &Inst : llvm::make_early_inc_range(BB)) { + auto *II = dyn_cast<IntrinsicInst>(&Inst); + if (!II) + continue; + if (II->getIntrinsicID() != Intrinsic::ssa_copy) + continue; + Inst.replaceAllUsesWith(II->getOperand(0)); + Inst.eraseFromParent(); + } + } +} + +static void removeSSACopy(Module &M) { + for (Function &F : M) + removeSSACopy(F); +} + +namespace { class FunctionSpecializer { /// The IPSCCP Solver. @@ -115,9 +272,14 @@ public: for (auto *SpecializedFunc : CurrentSpecializations) { SpecializedFuncs.insert(SpecializedFunc); - // TODO: If we want to support specializing specialized functions, - // initialize here the state of the newly created functions, marking - // them argument-tracked and executable. + // Initialize the state of the newly created functions, marking them + // argument-tracked and executable. + if (SpecializedFunc->hasExactDefinition() && + !SpecializedFunc->hasFnAttribute(Attribute::Naked)) + Solver.addTrackedFunction(SpecializedFunc); + Solver.addArgumentTrackedFunction(SpecializedFunc); + FuncDecls.push_back(SpecializedFunc); + Solver.markBlockExecutable(&SpecializedFunc->front()); // Replace the function arguments for the specialized functions. for (Argument &Arg : SpecializedFunc->args()) @@ -138,12 +300,22 @@ public: const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); if (isOverdefined(IV)) return false; - auto *Const = IV.isConstant() ? Solver.getConstant(IV) - : UndefValue::get(V->getType()); + auto *Const = + isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType()); V->replaceAllUsesWith(Const); - // TODO: Update the solver here if we want to specialize specialized - // functions. + for (auto *U : Const->users()) + if (auto *I = dyn_cast<Instruction>(U)) + if (Solver.isBlockExecutable(I->getParent())) + Solver.visit(I); + + // Remove the instruction from Block and Solver. + if (auto *I = dyn_cast<Instruction>(V)) { + if (I->isSafeToRemove()) { + I->eraseFromParent(); + Solver.removeLatticeValueFor(I); + } + } return true; } @@ -152,6 +324,15 @@ private: // also in the cost model. unsigned NbFunctionsSpecialized = 0; + /// Clone the function \p F and remove the ssa_copy intrinsics added by + /// the SCCPSolver in the cloned version. + Function *cloneCandidateFunction(Function *F) { + ValueToValueMapTy EmptyMap; + Function *Clone = CloneFunction(F, EmptyMap); + removeSSACopy(*Clone); + return Clone; + } + /// This function decides whether to specialize function \p F based on the /// known constant values its arguments can take on. Specialization is /// performed on the first interesting argument. Specializations based on @@ -162,9 +343,8 @@ private: SmallVectorImpl<Function *> &Specializations) { // Do not specialize the cloned function again. - if (SpecializedFuncs.contains(F)) { + if (SpecializedFuncs.contains(F)) return false; - } // If we're optimizing the function for size, we shouldn't specialize it. if (F->hasOptSize() || @@ -176,8 +356,25 @@ private: if (!Solver.isBlockExecutable(&F->getEntryBlock())) return false; + // It wastes time to specialize a function which would get inlined finally. + if (F->hasFnAttribute(Attribute::AlwaysInline)) + return false; + LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName() << "\n"); + + // Determine if it would be profitable to create a specialization of the + // function where the argument takes on the given constant value. If so, + // add the constant to Constants. 
+ auto FnSpecCost = getSpecializationCost(F); + if (!FnSpecCost.isValid()) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialisation cost.\n"); + return false; + } + + LLVM_DEBUG(dbgs() << "FnSpecialization: func specialisation cost: "; + FnSpecCost.print(dbgs()); dbgs() << "\n"); + // Determine if we should specialize the function based on the values the // argument can take on. If specialization is not profitable, we continue // on to the next argument. @@ -195,7 +392,7 @@ private: // be set to false by isArgumentInteresting (that function only adds // values to the Constants list that are deemed profitable). SmallVector<Constant *, 4> Constants; - if (!isArgumentInteresting(&A, Constants, IsPartial)) { + if (!isArgumentInteresting(&A, Constants, FnSpecCost, IsPartial)) { LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is not interesting\n"); continue; } @@ -214,8 +411,7 @@ private: for (auto *C : Constants) { // Clone the function. We leave the ValueToValueMap empty to allow // IPSCCP to propagate the constant arguments. - ValueToValueMapTy EmptyMap; - Function *Clone = CloneFunction(F, EmptyMap); + Function *Clone = cloneCandidateFunction(F); Argument *ClonedArg = Clone->arg_begin() + A.getArgNo(); // Rewrite calls to the function so that they call the clone instead. @@ -231,9 +427,10 @@ private: NbFunctionsSpecialized++; } - // TODO: if we want to support specialize specialized functions, and if - // the function has been completely specialized, the original function is - // no longer needed, so we would need to mark it unreachable here. + // If the function has been completely specialized, the original function + // is no longer needed. Mark it unreachable. + if (!IsPartial) + Solver.markFunctionUnreachable(F); // FIXME: Only one argument per function. return true; @@ -253,7 +450,11 @@ private: // If the code metrics reveal that we shouldn't duplicate the function, we // shouldn't specialize it. Set the specialization cost to Invalid. - if (Metrics.notDuplicatable) { + // Or if the lines of codes implies that this function is easy to get + // inlined so that we shouldn't specialize it. + if (Metrics.notDuplicatable || + (!ForceFunctionSpecialization && + Metrics.NumInsts < SmallFunctionThreshold)) { InstructionCost C{}; C.setInvalid(); return C; @@ -379,9 +580,8 @@ private: /// argument. bool isArgumentInteresting(Argument *A, SmallVectorImpl<Constant *> &Constants, + const InstructionCost &FnSpecCost, bool &IsPartial) { - Function *F = A->getParent(); - // For now, don't attempt to specialize functions based on the values of // composite types. if (!A->getType()->isSingleValueType() || A->user_empty()) @@ -420,18 +620,6 @@ private: return false; } - // Determine if it would be profitable to create a specialization of the - // function where the argument takes on the given constant value. If so, - // add the constant to Constants. - auto FnSpecCost = getSpecializationCost(F); - if (!FnSpecCost.isValid()) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialisation cost.\n"); - return false; - } - - LLVM_DEBUG(dbgs() << "FnSpecialization: func specialisation cost: "; - FnSpecCost.print(dbgs()); dbgs() << "\n"); - for (auto *C : PossibleConstants) { LLVM_DEBUG(dbgs() << "FnSpecialization: Constant: " << *C << "\n"); if (ForceFunctionSpecialization) { @@ -475,6 +663,12 @@ private: if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) continue; auto &CS = *cast<CallBase>(U); + // If the call site has attribute minsize set, that callsite won't be + // specialized. 
+ if (CS.hasFnAttr(Attribute::MinSize)) { + AllConstant = false; + continue; + } // If the parent of the call site will never be executed, we don't need // to worry about the passed value. @@ -482,11 +676,25 @@ private: continue; auto *V = CS.getArgOperand(A->getArgNo()); + if (isa<PoisonValue>(V)) + return false; + + // For now, constant expressions are fine but only if they are function + // calls. + if (auto *CE = dyn_cast<ConstantExpr>(V)) + if (!isa<Function>(CE->getOperand(0))) + return false; + // TrackValueOfGlobalVariable only tracks scalar global variables. if (auto *GV = dyn_cast<GlobalVariable>(V)) { - if (!GV->getValueType()->isSingleValueType()) { + // Check if we want to specialize on the address of non-constant + // global values. + if (!GV->isConstant()) + if (!SpecializeOnAddresses) + return false; + + if (!GV->getValueType()->isSingleValueType()) return false; - } } if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() || @@ -506,6 +714,9 @@ private: /// This function modifies calls to function \p F whose argument at index \p /// ArgNo is equal to constant \p C. The calls are rewritten to call function /// \p Clone instead. + /// + /// Callsites that have been marked with the MinSize function attribute won't + /// be specialized and rewritten. void rewriteCallSites(Function *F, Function *Clone, Argument &Arg, Constant *C) { unsigned ArgNo = Arg.getArgNo(); @@ -527,24 +738,7 @@ private: } } }; - -/// Function to clean up the left over intrinsics from SCCP util. -static void cleanup(Module &M) { - for (Function &F : M) { - for (BasicBlock &BB : F) { - for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { - Instruction *Inst = &*BI++; - if (auto *II = dyn_cast<IntrinsicInst>(Inst)) { - if (II->getIntrinsicID() == Intrinsic::ssa_copy) { - Value *Op = II->getOperand(0); - Inst->replaceAllUsesWith(Op); - Inst->eraseFromParent(); - } - } - } - } - } -} +} // namespace bool llvm::runFunctionSpecialization( Module &M, const DataLayout &DL, @@ -597,12 +791,27 @@ bool llvm::runFunctionSpecialization( Solver.trackValueOfGlobalVariable(&G); } + auto &TrackedFuncs = Solver.getArgumentTrackedFunctions(); + SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(), + TrackedFuncs.end()); + + // No tracked functions, so nothing to do: don't run the solver and remove + // the ssa_copy intrinsics that may have been introduced. + if (TrackedFuncs.empty()) { + removeSSACopy(M); + return false; + } + // Solve for constants. auto RunSCCPSolver = [&](auto &WorkList) { bool ResolvedUndefs = true; while (ResolvedUndefs) { + // Not running the solver unnecessary is checked in regression test + // nothing-to-do.ll, so if this debug message is changed, this regression + // test needs updating too. LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n"); + Solver.solve(); LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n"); ResolvedUndefs = false; @@ -615,15 +824,14 @@ bool llvm::runFunctionSpecialization( for (BasicBlock &BB : *F) { if (!Solver.isBlockExecutable(&BB)) continue; + // FIXME: The solver may make changes to the function here, so set + // Changed, even if later function specialization does not trigger. 
for (auto &I : make_early_inc_range(BB)) - FS.tryToReplaceWithConstant(&I); + Changed |= FS.tryToReplaceWithConstant(&I); } } }; - auto &TrackedFuncs = Solver.getArgumentTrackedFunctions(); - SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(), - TrackedFuncs.end()); #ifndef NDEBUG LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n"); for (auto *F : FuncDecls) @@ -637,14 +845,18 @@ bool llvm::runFunctionSpecialization( unsigned I = 0; while (FuncSpecializationMaxIters != I++ && FS.specializeFunctions(FuncDecls, CurrentSpecializations)) { - // TODO: run the solver here for the specialized functions only if we want - // to specialize recursively. + + // Run the solver for the specialized functions. + RunSCCPSolver(CurrentSpecializations); + + // Replace some unresolved constant arguments. + constantArgPropagation(FuncDecls, M, Solver); CurrentSpecializations.clear(); Changed = true; } // Clean up the IR by removing ssa_copy intrinsics. - cleanup(M); + removeSSACopy(M); return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp index fb4cb23b837e..5e5d2086adc2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp @@ -88,7 +88,7 @@ ModulePass *llvm::createGlobalDCEPass() { static bool isEmptyFunction(Function *F) { BasicBlock &Entry = F->getEntryBlock(); for (auto &I : Entry) { - if (isa<DbgInfoIntrinsic>(I)) + if (I.isDebugOrPseudoInst()) continue; if (auto *RI = dyn_cast<ReturnInst>(&I)) return !RI->getReturnValue(); @@ -210,7 +210,7 @@ void GlobalDCEPass::ScanVTableLoad(Function *Caller, Metadata *TypeId, Constant *Ptr = getPointerAtOffset(VTable->getInitializer(), VTableOffset + CallOffset, - *Caller->getParent()); + *Caller->getParent(), VTable); if (!Ptr) { LLVM_DEBUG(dbgs() << "can't find pointer in vtable!\n"); VFESafeVTables.erase(VTable); @@ -416,6 +416,16 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { // virtual function pointers with null, allowing us to remove the // function itself. ++NumVFuncs; + + // Detect vfuncs that are referenced as "relative pointers" which are used + // in Swift vtables, i.e. entries in the form of: + // + // i32 trunc (i64 sub (i64 ptrtoint @f, i64 ptrtoint ...)) to i32) + // + // In this case, replace the whole "sub" expression with constant 0 to + // avoid leaving a weird sub(0, symbol) expression behind. + replaceRelativePointerUsersWithZero(F); + F->replaceNonMetadataUsesWith(ConstantPointerNull::get(F->getType())); } EraseUnusedGlobalValue(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 8750eb9ecc4e..b2c2efed7db8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -208,9 +208,7 @@ CleanupPointerRootUsers(GlobalVariable *GV, SmallVector<std::pair<Instruction *, Instruction *>, 32> Dead; // Constants can't be pointers to dynamically allocated memory. 
- for (Value::user_iterator UI = GV->user_begin(), E = GV->user_end(); - UI != E;) { - User *U = *UI++; + for (User *U : llvm::make_early_inc_range(GV->users())) { if (StoreInst *SI = dyn_cast<StoreInst>(U)) { Value *V = SI->getValueOperand(); if (isa<Constant>(V)) { @@ -703,8 +701,9 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V, !ICmpInst::isSigned(cast<ICmpInst>(U)->getPredicate()) && isa<LoadInst>(U->getOperand(0)) && isa<ConstantPointerNull>(U->getOperand(1))) { - assert(isa<GlobalValue>( - cast<LoadInst>(U->getOperand(0))->getPointerOperand()) && + assert(isa<GlobalValue>(cast<LoadInst>(U->getOperand(0)) + ->getPointerOperand() + ->stripPointerCasts()) && "Should be GlobalVariable"); // This and only this kind of non-signed ICmpInst is to be replaced with // the comparing of the value of the created global init bool later in @@ -720,22 +719,55 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V, /// Return true if all uses of any loads from GV will trap if the loaded value /// is null. Note that this also permits comparisons of the loaded value /// against null, as a special case. -static bool AllUsesOfLoadedValueWillTrapIfNull(const GlobalVariable *GV) { - for (const User *U : GV->users()) - if (const LoadInst *LI = dyn_cast<LoadInst>(U)) { - SmallPtrSet<const PHINode*, 8> PHIs; - if (!AllUsesOfValueWillTrapIfNull(LI, PHIs)) +static bool allUsesOfLoadedValueWillTrapIfNull(const GlobalVariable *GV) { + SmallVector<const Value *, 4> Worklist; + Worklist.push_back(GV); + while (!Worklist.empty()) { + const Value *P = Worklist.pop_back_val(); + for (auto *U : P->users()) { + if (auto *LI = dyn_cast<LoadInst>(U)) { + SmallPtrSet<const PHINode *, 8> PHIs; + if (!AllUsesOfValueWillTrapIfNull(LI, PHIs)) + return false; + } else if (auto *SI = dyn_cast<StoreInst>(U)) { + // Ignore stores to the global. + if (SI->getPointerOperand() != P) + return false; + } else if (auto *CE = dyn_cast<ConstantExpr>(U)) { + if (CE->stripPointerCasts() != GV) + return false; + // Check further the ConstantExpr. + Worklist.push_back(CE); + } else { + // We don't know or understand this user, bail out. return false; - } else if (isa<StoreInst>(U)) { - // Ignore stores to the global. - } else { - // We don't know or understand this user, bail out. - //cerr << "UNKNOWN USER OF GLOBAL!: " << *U; - return false; + } } + } + return true; } +/// Get all the loads/store uses for global variable \p GV. +static void allUsesOfLoadAndStores(GlobalVariable *GV, + SmallVector<Value *, 4> &Uses) { + SmallVector<Value *, 4> Worklist; + Worklist.push_back(GV); + while (!Worklist.empty()) { + auto *P = Worklist.pop_back_val(); + for (auto *U : P->users()) { + if (auto *CE = dyn_cast<ConstantExpr>(U)) { + Worklist.push_back(CE); + continue; + } + + assert((isa<LoadInst>(U) || isa<StoreInst>(U)) && + "Expect only load or store instructions"); + Uses.push_back(U); + } + } +} + static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) { bool Changed = false; for (auto UI = V->user_begin(), E = V->user_end(); UI != E; ) { @@ -817,8 +849,7 @@ static bool OptimizeAwayTrappingUsesOfLoads( bool AllNonStoreUsesGone = true; // Replace all uses of loads with uses of uses of the stored value. 
- for (Value::user_iterator GUI = GV->user_begin(), E = GV->user_end(); GUI != E;){ - User *GlobalUser = *GUI++; + for (User *GlobalUser : llvm::make_early_inc_range(GV->users())) { if (LoadInst *LI = dyn_cast<LoadInst>(GlobalUser)) { Changed |= OptimizeAwayTrappingUsesOfValue(LI, LV); // If we were able to delete all uses of the loads @@ -934,9 +965,8 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, } } - Constant *RepValue = NewGV; - if (NewGV->getType() != GV->getValueType()) - RepValue = ConstantExpr::getBitCast(RepValue, GV->getValueType()); + SmallPtrSet<Constant *, 1> RepValues; + RepValues.insert(NewGV); // If there is a comparison against null, we will insert a global bool to // keep track of whether the global was initialized yet or not. @@ -947,9 +977,11 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, GV->getName()+".init", GV->getThreadLocalMode()); bool InitBoolUsed = false; - // Loop over all uses of GV, processing them in turn. - while (!GV->use_empty()) { - if (StoreInst *SI = dyn_cast<StoreInst>(GV->user_back())) { + // Loop over all instruction uses of GV, processing them in turn. + SmallVector<Value *, 4> Guses; + allUsesOfLoadAndStores(GV, Guses); + for (auto *U : Guses) { + if (StoreInst *SI = dyn_cast<StoreInst>(U)) { // The global is initialized when the store to it occurs. If the stored // value is null value, the global bool is set to false, otherwise true. new StoreInst(ConstantInt::getBool( @@ -961,12 +993,14 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, continue; } - LoadInst *LI = cast<LoadInst>(GV->user_back()); + LoadInst *LI = cast<LoadInst>(U); while (!LI->use_empty()) { Use &LoadUse = *LI->use_begin(); ICmpInst *ICI = dyn_cast<ICmpInst>(LoadUse.getUser()); if (!ICI) { - LoadUse = RepValue; + auto *CE = ConstantExpr::getBitCast(NewGV, LI->getType()); + RepValues.insert(CE); + LoadUse.set(CE); continue; } @@ -1012,40 +1046,53 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, // To further other optimizations, loop over all users of NewGV and try to // constant prop them. This will promote GEP instructions with constant // indices into GEP constant-exprs, which will allow global-opt to hack on it. - ConstantPropUsersOf(NewGV, DL, TLI); - if (RepValue != NewGV) - ConstantPropUsersOf(RepValue, DL, TLI); + for (auto *CE : RepValues) + ConstantPropUsersOf(CE, DL, TLI); return NewGV; } -/// Scan the use-list of V checking to make sure that there are no complex uses -/// of V. We permit simple things like dereferencing the pointer, but not +/// Scan the use-list of GV checking to make sure that there are no complex uses +/// of GV. We permit simple things like dereferencing the pointer, but not /// storing through the address, unless it is to the specified global. static bool -valueIsOnlyUsedLocallyOrStoredToOneGlobal(const Instruction *V, +valueIsOnlyUsedLocallyOrStoredToOneGlobal(const CallInst *CI, const GlobalVariable *GV) { - for (const User *U : V->users()) { - const Instruction *Inst = cast<Instruction>(U); + SmallPtrSet<const Value *, 4> Visited; + SmallVector<const Value *, 4> Worklist; + Worklist.push_back(CI); - if (isa<LoadInst>(Inst) || isa<CmpInst>(Inst)) { - continue; // Fine, ignore. 
- } + while (!Worklist.empty()) { + const Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; - if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - if (SI->getOperand(0) == V && SI->getOperand(1) != GV) - return false; // Storing the pointer itself... bad. - continue; // Otherwise, storing through it, or storing into GV... fine. - } + for (const Use &VUse : V->uses()) { + const User *U = VUse.getUser(); + if (isa<LoadInst>(U) || isa<CmpInst>(U)) + continue; // Fine, ignore. - if (const BitCastInst *BCI = dyn_cast<BitCastInst>(Inst)) { - if (!valueIsOnlyUsedLocallyOrStoredToOneGlobal(BCI, GV)) - return false; - continue; - } + if (auto *SI = dyn_cast<StoreInst>(U)) { + if (SI->getValueOperand() == V && + SI->getPointerOperand()->stripPointerCasts() != GV) + return false; // Storing the pointer not into GV... bad. + continue; // Otherwise, storing through it, or storing into GV... fine. + } - return false; + if (auto *BCI = dyn_cast<BitCastInst>(U)) { + Worklist.push_back(BCI); + continue; + } + + if (auto *GEPI = dyn_cast<GetElementPtrInst>(U)) { + Worklist.push_back(GEPI); + continue; + } + + return false; + } } + return true; } @@ -1066,12 +1113,12 @@ static bool tryToOptimizeStoreOfMallocToGlobal(GlobalVariable *GV, CallInst *CI, // been reached). To do this, we check to see if all uses of the global // would trap if the global were null: this proves that they must all // happen after the malloc. - if (!AllUsesOfLoadedValueWillTrapIfNull(GV)) + if (!allUsesOfLoadedValueWillTrapIfNull(GV)) return false; // We can't optimize this if the malloc itself is used in a complex way, // for example, being stored into multiple globals. This allows the - // malloc to be stored into the specified global, loaded icmp'd. + // malloc to be stored into the specified global, loaded, gep, icmp'd. // These are all things we could transform to using the global for. if (!valueIsOnlyUsedLocallyOrStoredToOneGlobal(CI, GV)) return false; @@ -1112,6 +1159,7 @@ optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, // value was null. if (GV->getInitializer()->getType()->isPointerTy() && GV->getInitializer()->isNullValue() && + StoredOnceVal->getType()->isPointerTy() && !NullPointerIsDefined( nullptr /* F */, GV->getInitializer()->getType()->getPointerAddressSpace())) { @@ -1442,8 +1490,7 @@ static void makeAllConstantUsesInstructions(Constant *C) { append_range(UUsers, U->users()); for (auto *UU : UUsers) { Instruction *UI = cast<Instruction>(UU); - Instruction *NewU = U->getAsInstruction(); - NewU->insertBefore(UI); + Instruction *NewU = U->getAsInstruction(UI); UI->replaceUsesOfWith(U, NewU); } // We've replaced all the uses, so destroy the constant. (destroyConstant @@ -1456,6 +1503,7 @@ static void makeAllConstantUsesInstructions(Constant *C) { /// it if possible. If we make a change, return true. 
static bool processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, + function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree) { auto &DL = GV->getParent()->getDataLayout(); @@ -1554,43 +1602,57 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, if (SRAGlobal(GV, DL)) return true; } - if (GS.StoredType == GlobalStatus::StoredOnce && GS.StoredOnceValue) { + Value *StoredOnceValue = GS.getStoredOnceValue(); + if (GS.StoredType == GlobalStatus::StoredOnce && StoredOnceValue) { + // Avoid speculating constant expressions that might trap (div/rem). + auto *SOVConstant = dyn_cast<Constant>(StoredOnceValue); + if (SOVConstant && SOVConstant->canTrap()) + return Changed; + + Function &StoreFn = + const_cast<Function &>(*GS.StoredOnceStore->getFunction()); + bool CanHaveNonUndefGlobalInitializer = + GetTTI(StoreFn).canHaveNonUndefGlobalInitializerInAddressSpace( + GV->getType()->getAddressSpace()); // If the initial value for the global was an undef value, and if only // one other value was stored into it, we can just change the // initializer to be the stored value, then delete all stores to the // global. This allows us to mark it constant. - if (Constant *SOVConstant = dyn_cast<Constant>(GS.StoredOnceValue)) - if (isa<UndefValue>(GV->getInitializer())) { - // Change the initial value here. - GV->setInitializer(SOVConstant); - - // Clean up any obviously simplifiable users now. - CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI); - - if (GV->use_empty()) { - LLVM_DEBUG(dbgs() << " *** Substituting initializer allowed us to " - << "simplify all users and delete global!\n"); - GV->eraseFromParent(); - ++NumDeleted; - } - ++NumSubstitute; - return true; + // This is restricted to address spaces that allow globals to have + // initializers. NVPTX, for example, does not support initializers for + // shared memory (AS 3). + if (SOVConstant && SOVConstant->getType() == GV->getValueType() && + isa<UndefValue>(GV->getInitializer()) && + CanHaveNonUndefGlobalInitializer) { + // Change the initial value here. + GV->setInitializer(SOVConstant); + + // Clean up any obviously simplifiable users now. + CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI); + + if (GV->use_empty()) { + LLVM_DEBUG(dbgs() << " *** Substituting initializer allowed us to " + << "simplify all users and delete global!\n"); + GV->eraseFromParent(); + ++NumDeleted; } + ++NumSubstitute; + return true; + } // Try to optimize globals based on the knowledge that only one value // (besides its initializer) is ever stored to the global. - if (optimizeOnceStoredGlobal(GV, GS.StoredOnceValue, GS.Ordering, DL, - GetTLI)) + if (optimizeOnceStoredGlobal(GV, StoredOnceValue, GS.Ordering, DL, GetTLI)) return true; // Otherwise, if the global was not a boolean, we can shrink it to be a - // boolean. - if (Constant *SOVConstant = dyn_cast<Constant>(GS.StoredOnceValue)) { - if (GS.Ordering == AtomicOrdering::NotAtomic) { - if (TryToShrinkGlobalToBoolean(GV, SOVConstant)) { - ++NumShrunkToBool; - return true; - } + // boolean. Skip this optimization for AS that doesn't allow an initializer. 
+ if (SOVConstant && GS.Ordering == AtomicOrdering::NotAtomic && + (!isa<UndefValue>(GV->getInitializer()) || + CanHaveNonUndefGlobalInitializer)) { + if (TryToShrinkGlobalToBoolean(GV, SOVConstant)) { + ++NumShrunkToBool; + return true; } } } @@ -1602,6 +1664,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, /// make a change, return true. static bool processGlobal(GlobalValue &GV, + function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree) { if (GV.getName().startswith("llvm.")) @@ -1634,7 +1697,8 @@ processGlobal(GlobalValue &GV, if (GVar->isConstant() || !GVar->hasInitializer()) return Changed; - return processInternalGlobal(GVar, GS, GetTLI, LookupDomTree) || Changed; + return processInternalGlobal(GVar, GS, GetTTI, GetTLI, LookupDomTree) || + Changed; } /// Walk all of the direct calls of the specified function, changing them to @@ -1651,7 +1715,7 @@ static AttributeList StripAttr(LLVMContext &C, AttributeList Attrs, Attribute::AttrKind A) { unsigned AttrIndex; if (Attrs.hasAttrSomewhere(A, &AttrIndex)) - return Attrs.removeAttribute(C, AttrIndex, A); + return Attrs.removeAttributeAtIndex(C, AttrIndex, A); return Attrs; } @@ -1864,10 +1928,8 @@ static void RemovePreallocated(Function *F) { Value *AllocaReplacement = ArgAllocas[AllocArgIndex]; if (!AllocaReplacement) { auto AddressSpace = UseCall->getType()->getPointerAddressSpace(); - auto *ArgType = UseCall - ->getAttribute(AttributeList::FunctionIndex, - Attribute::Preallocated) - .getValueAsType(); + auto *ArgType = + UseCall->getFnAttr(Attribute::Preallocated).getValueAsType(); auto *InsertBefore = PreallocatedSetup->getNextNonDebugInstruction(); Builder.SetInsertPoint(InsertBefore); auto *Alloca = @@ -1897,26 +1959,22 @@ OptimizeFunctions(Module &M, bool Changed = false; std::vector<Function *> AllCallsCold; - for (Module::iterator FI = M.begin(), E = M.end(); FI != E;) { - Function *F = &*FI++; - if (hasOnlyColdCalls(*F, GetBFI)) - AllCallsCold.push_back(F); - } + for (Function &F : llvm::make_early_inc_range(M)) + if (hasOnlyColdCalls(F, GetBFI)) + AllCallsCold.push_back(&F); // Optimize functions. - for (Module::iterator FI = M.begin(), E = M.end(); FI != E; ) { - Function *F = &*FI++; - + for (Function &F : llvm::make_early_inc_range(M)) { // Don't perform global opt pass on naked functions; we don't want fast // calling conventions for naked functions. - if (F->hasFnAttribute(Attribute::Naked)) + if (F.hasFnAttribute(Attribute::Naked)) continue; // Functions without names cannot be referenced outside this module. - if (!F->hasName() && !F->isDeclaration() && !F->hasLocalLinkage()) - F->setLinkage(GlobalValue::InternalLinkage); + if (!F.hasName() && !F.isDeclaration() && !F.hasLocalLinkage()) + F.setLinkage(GlobalValue::InternalLinkage); - if (deleteIfDead(*F, NotDiscardableComdats)) { + if (deleteIfDead(F, NotDiscardableComdats)) { Changed = true; continue; } @@ -1931,17 +1989,17 @@ OptimizeFunctions(Module &M, // some more complicated logic to break these cycles. // Removing unreachable blocks might invalidate the dominator so we // recalculate it. 
- if (!F->isDeclaration()) { - if (removeUnreachableBlocks(*F)) { - auto &DT = LookupDomTree(*F); - DT.recalculate(*F); + if (!F.isDeclaration()) { + if (removeUnreachableBlocks(F)) { + auto &DT = LookupDomTree(F); + DT.recalculate(F); Changed = true; } } - Changed |= processGlobal(*F, GetTLI, LookupDomTree); + Changed |= processGlobal(F, GetTTI, GetTLI, LookupDomTree); - if (!F->hasLocalLinkage()) + if (!F.hasLocalLinkage()) continue; // If we have an inalloca parameter that we can safely remove the @@ -1949,56 +2007,55 @@ OptimizeFunctions(Module &M, // wouldn't be safe in the presence of inalloca. // FIXME: We should also hoist alloca affected by this to the entry // block if possible. - if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca) && - !F->hasAddressTaken() && !hasMustTailCallers(F)) { - RemoveAttribute(F, Attribute::InAlloca); + if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) && + !F.hasAddressTaken() && !hasMustTailCallers(&F)) { + RemoveAttribute(&F, Attribute::InAlloca); Changed = true; } // FIXME: handle invokes // FIXME: handle musttail - if (F->getAttributes().hasAttrSomewhere(Attribute::Preallocated)) { - if (!F->hasAddressTaken() && !hasMustTailCallers(F) && - !hasInvokeCallers(F)) { - RemovePreallocated(F); + if (F.getAttributes().hasAttrSomewhere(Attribute::Preallocated)) { + if (!F.hasAddressTaken() && !hasMustTailCallers(&F) && + !hasInvokeCallers(&F)) { + RemovePreallocated(&F); Changed = true; } continue; } - if (hasChangeableCC(F) && !F->isVarArg() && !F->hasAddressTaken()) { + if (hasChangeableCC(&F) && !F.isVarArg() && !F.hasAddressTaken()) { NumInternalFunc++; - TargetTransformInfo &TTI = GetTTI(*F); + TargetTransformInfo &TTI = GetTTI(F); // Change the calling convention to coldcc if either stress testing is // enabled or the target would like to use coldcc on functions which are // cold at all call sites and the callers contain no other non coldcc // calls. if (EnableColdCCStressTest || - (TTI.useColdCCForColdCall(*F) && - isValidCandidateForColdCC(*F, GetBFI, AllCallsCold))) { - F->setCallingConv(CallingConv::Cold); - changeCallSitesToColdCC(F); + (TTI.useColdCCForColdCall(F) && + isValidCandidateForColdCC(F, GetBFI, AllCallsCold))) { + F.setCallingConv(CallingConv::Cold); + changeCallSitesToColdCC(&F); Changed = true; NumColdCC++; } } - if (hasChangeableCC(F) && !F->isVarArg() && - !F->hasAddressTaken()) { + if (hasChangeableCC(&F) && !F.isVarArg() && !F.hasAddressTaken()) { // If this function has a calling convention worth changing, is not a // varargs function, and is only called directly, promote it to use the // Fast calling convention. - F->setCallingConv(CallingConv::Fast); - ChangeCalleesToFastCall(F); + F.setCallingConv(CallingConv::Fast); + ChangeCalleesToFastCall(&F); ++NumFastCallFns; Changed = true; } - if (F->getAttributes().hasAttrSomewhere(Attribute::Nest) && - !F->hasAddressTaken()) { + if (F.getAttributes().hasAttrSomewhere(Attribute::Nest) && + !F.hasAddressTaken()) { // The function is not used by a trampoline intrinsic, so it is safe // to remove the 'nest' attribute. 
- RemoveAttribute(F, Attribute::Nest); + RemoveAttribute(&F, Attribute::Nest); ++NumNestRemoved; Changed = true; } @@ -2008,35 +2065,34 @@ OptimizeFunctions(Module &M, static bool OptimizeGlobalVars(Module &M, + function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree, SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) { bool Changed = false; - for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); - GVI != E; ) { - GlobalVariable *GV = &*GVI++; + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) { // Global variables without names cannot be referenced outside this module. - if (!GV->hasName() && !GV->isDeclaration() && !GV->hasLocalLinkage()) - GV->setLinkage(GlobalValue::InternalLinkage); + if (!GV.hasName() && !GV.isDeclaration() && !GV.hasLocalLinkage()) + GV.setLinkage(GlobalValue::InternalLinkage); // Simplify the initializer. - if (GV->hasInitializer()) - if (auto *C = dyn_cast<Constant>(GV->getInitializer())) { + if (GV.hasInitializer()) + if (auto *C = dyn_cast<Constant>(GV.getInitializer())) { auto &DL = M.getDataLayout(); // TLI is not used in the case of a Constant, so use default nullptr // for that optional parameter, since we don't have a Function to // provide GetTLI anyway. Constant *New = ConstantFoldConstant(C, DL, /*TLI*/ nullptr); if (New != C) - GV->setInitializer(New); + GV.setInitializer(New); } - if (deleteIfDead(*GV, NotDiscardableComdats)) { + if (deleteIfDead(GV, NotDiscardableComdats)) { Changed = true; continue; } - Changed |= processGlobal(*GV, GetTLI, LookupDomTree); + Changed |= processGlobal(GV, GetTTI, GetTLI, LookupDomTree); } return Changed; } @@ -2425,24 +2481,21 @@ OptimizeGlobalAliases(Module &M, for (GlobalValue *GV : Used.used()) Used.compilerUsedErase(GV); - for (Module::alias_iterator I = M.alias_begin(), E = M.alias_end(); - I != E;) { - GlobalAlias *J = &*I++; - + for (GlobalAlias &J : llvm::make_early_inc_range(M.aliases())) { // Aliases without names cannot be referenced outside this module. - if (!J->hasName() && !J->isDeclaration() && !J->hasLocalLinkage()) - J->setLinkage(GlobalValue::InternalLinkage); + if (!J.hasName() && !J.isDeclaration() && !J.hasLocalLinkage()) + J.setLinkage(GlobalValue::InternalLinkage); - if (deleteIfDead(*J, NotDiscardableComdats)) { + if (deleteIfDead(J, NotDiscardableComdats)) { Changed = true; continue; } // If the alias can change at link time, nothing can be done - bail out. - if (J->isInterposable()) + if (J.isInterposable()) continue; - Constant *Aliasee = J->getAliasee(); + Constant *Aliasee = J.getAliasee(); GlobalValue *Target = dyn_cast<GlobalValue>(Aliasee->stripPointerCasts()); // We can't trivially replace the alias with the aliasee if the aliasee is // non-trivial in some way. We also can't replace the alias with the aliasee @@ -2455,31 +2508,31 @@ OptimizeGlobalAliases(Module &M, // Make all users of the alias use the aliasee instead. bool RenameTarget; - if (!hasUsesToReplace(*J, Used, RenameTarget)) + if (!hasUsesToReplace(J, Used, RenameTarget)) continue; - J->replaceAllUsesWith(ConstantExpr::getBitCast(Aliasee, J->getType())); + J.replaceAllUsesWith(ConstantExpr::getBitCast(Aliasee, J.getType())); ++NumAliasesResolved; Changed = true; if (RenameTarget) { // Give the aliasee the name, linkage and other attributes of the alias. 
- Target->takeName(&*J); - Target->setLinkage(J->getLinkage()); - Target->setDSOLocal(J->isDSOLocal()); - Target->setVisibility(J->getVisibility()); - Target->setDLLStorageClass(J->getDLLStorageClass()); + Target->takeName(&J); + Target->setLinkage(J.getLinkage()); + Target->setDSOLocal(J.isDSOLocal()); + Target->setVisibility(J.getVisibility()); + Target->setDLLStorageClass(J.getDLLStorageClass()); - if (Used.usedErase(&*J)) + if (Used.usedErase(&J)) Used.usedInsert(Target); - if (Used.compilerUsedErase(&*J)) + if (Used.compilerUsedErase(&J)) Used.compilerUsedInsert(Target); - } else if (mayHaveOtherReferences(*J, Used)) + } else if (mayHaveOtherReferences(J, Used)) continue; // Delete the alias. - M.getAliasList().erase(J); + M.getAliasList().erase(&J); ++NumAliasesRemoved; Changed = true; } @@ -2526,7 +2579,7 @@ static bool cxxDtorIsEmpty(const Function &Fn) { return false; for (auto &I : Fn.getEntryBlock()) { - if (isa<DbgInfoIntrinsic>(I)) + if (I.isDebugOrPseudoInst()) continue; if (isa<ReturnInst>(I)) return true; @@ -2552,12 +2605,11 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { // and remove them. bool Changed = false; - for (auto I = CXAAtExitFn->user_begin(), E = CXAAtExitFn->user_end(); - I != E;) { + for (User *U : llvm::make_early_inc_range(CXAAtExitFn->users())) { // We're only interested in calls. Theoretically, we could handle invoke // instructions as well, but neither llvm-gcc nor clang generate invokes // to __cxa_atexit. - CallInst *CI = dyn_cast<CallInst>(*I++); + CallInst *CI = dyn_cast<CallInst>(U); if (!CI) continue; @@ -2614,8 +2666,8 @@ static bool optimizeGlobalsInModule( }); // Optimize non-address-taken globals. - LocalChange |= - OptimizeGlobalVars(M, GetTLI, LookupDomTree, NotDiscardableComdats); + LocalChange |= OptimizeGlobalVars(M, GetTTI, GetTLI, LookupDomTree, + NotDiscardableComdats); // Resolve aliases, when possible. LocalChange |= OptimizeGlobalAliases(M, NotDiscardableComdats); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp index 365b269dc3bf..e7d698c42fcf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp @@ -154,11 +154,8 @@ static bool splitGlobals(Module &M) { return false; bool Changed = false; - for (auto I = M.global_begin(); I != M.global_end();) { - GlobalVariable &GV = *I; - ++I; + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) Changed |= splitGlobal(GV); - } return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp index adf9ffba5780..b8a314c54f18 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DIBuilder.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Mangler.h" #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" @@ -33,6 +34,10 @@ using namespace llvm; using namespace IRSimilarity; +// A command flag to be used for debugging to exclude branches from similarity +// matching and outlining. +extern cl::opt<bool> DisableBranches; + // Set to true if the user wants the ir outliner to run on linkonceodr linkage // functions. This is false by default because the linker can dedupe linkonceodr // functions. 
Since the outliner is confined to a single module (modulo LTO), @@ -71,8 +76,12 @@ struct OutlinableGroup { /// for extraction. bool IgnoreGroup = false; - /// The return block for the overall function. - BasicBlock *EndBB = nullptr; + /// The return blocks for the overall function. + DenseMap<Value *, BasicBlock *> EndBBs; + + /// The PHIBlocks with their corresponding return block based on the return + /// value as the key. + DenseMap<Value *, BasicBlock *> PHIBlocks; /// A set containing the different GVN store sets needed. Each array contains /// a sorted list of the different values that need to be stored into output @@ -87,6 +96,14 @@ struct OutlinableGroup { /// index in ArgumentTypes is an output argument. unsigned NumAggregateInputs = 0; + /// The mapping of the canonical numbering of the values in outlined sections + /// to specific arguments. + DenseMap<unsigned, unsigned> CanonicalNumberToAggArg; + + /// The number of branches in the region target a basic block that is outside + /// of the region. + unsigned BranchesToOutside = 0; + /// The number of instructions that will be outlined by extracting \ref /// Regions. InstructionCost Benefit = 0; @@ -118,20 +135,67 @@ struct OutlinableGroup { /// \param SourceBB - the BasicBlock to pull Instructions from. /// \param TargetBB - the BasicBlock to put Instruction into. static void moveBBContents(BasicBlock &SourceBB, BasicBlock &TargetBB) { - BasicBlock::iterator BBCurr, BBEnd, BBNext; - for (BBCurr = SourceBB.begin(), BBEnd = SourceBB.end(); BBCurr != BBEnd; - BBCurr = BBNext) { - BBNext = std::next(BBCurr); - BBCurr->moveBefore(TargetBB, TargetBB.end()); - } + for (Instruction &I : llvm::make_early_inc_range(SourceBB)) + I.moveBefore(TargetBB, TargetBB.end()); +} + +/// A function to sort the keys of \p Map, which must be a mapping of constant +/// values to basic blocks and return it in \p SortedKeys +/// +/// \param SortedKeys - The vector the keys will be return in and sorted. +/// \param Map - The DenseMap containing keys to sort. +static void getSortedConstantKeys(std::vector<Value *> &SortedKeys, + DenseMap<Value *, BasicBlock *> &Map) { + for (auto &VtoBB : Map) + SortedKeys.push_back(VtoBB.first); + + stable_sort(SortedKeys, [](const Value *LHS, const Value *RHS) { + const ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS); + const ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS); + assert(RHSC && "Not a constant integer in return value?"); + assert(LHSC && "Not a constant integer in return value?"); + + return LHSC->getLimitedValue() < RHSC->getLimitedValue(); + }); +} + +Value *OutlinableRegion::findCorrespondingValueIn(const OutlinableRegion &Other, + Value *V) { + Optional<unsigned> GVN = Candidate->getGVN(V); + assert(GVN.hasValue() && "No GVN for incoming value"); + Optional<unsigned> CanonNum = Candidate->getCanonicalNum(*GVN); + Optional<unsigned> FirstGVN = Other.Candidate->fromCanonicalNum(*CanonNum); + Optional<Value *> FoundValueOpt = Other.Candidate->fromGVN(*FirstGVN); + return FoundValueOpt.getValueOr(nullptr); } void OutlinableRegion::splitCandidate() { assert(!CandidateSplit && "Candidate already split!"); + Instruction *BackInst = Candidate->backInstruction(); + + Instruction *EndInst = nullptr; + // Check whether the last instruction is a terminator, if it is, we do + // not split on the following instruction. We leave the block as it is. 
We + // also check that this is not the last instruction in the Module, otherwise + // the check for whether the current following instruction matches the + // previously recorded instruction will be incorrect. + if (!BackInst->isTerminator() || + BackInst->getParent() != &BackInst->getFunction()->back()) { + EndInst = Candidate->end()->Inst; + assert(EndInst && "Expected an end instruction?"); + } + + // We check if the current instruction following the last instruction in the + // region is the same as the recorded instruction following the last + // instruction. If they do not match, there could be problems in rewriting + // the program after outlining, so we ignore it. + if (!BackInst->isTerminator() && + EndInst != BackInst->getNextNonDebugInstruction()) + return; + Instruction *StartInst = (*Candidate->begin()).Inst; - Instruction *EndInst = (*Candidate->end()).Inst; - assert(StartInst && EndInst && "Expected a start and end instruction?"); + assert(StartInst && "Expected a start instruction?"); StartBB = StartInst->getParent(); PrevBB = StartBB; @@ -153,13 +217,20 @@ void OutlinableRegion::splitCandidate() { std::string OriginalName = PrevBB->getName().str(); StartBB = PrevBB->splitBasicBlock(StartInst, OriginalName + "_to_outline"); - - // This is the case for the inner block since we do not have to include - // multiple blocks. - EndBB = StartBB; - FollowBB = EndBB->splitBasicBlock(EndInst, OriginalName + "_after_outline"); + PrevBB->replaceSuccessorsPhiUsesWith(PrevBB, StartBB); CandidateSplit = true; + if (!BackInst->isTerminator()) { + EndBB = EndInst->getParent(); + FollowBB = EndBB->splitBasicBlock(EndInst, OriginalName + "_after_outline"); + EndBB->replaceSuccessorsPhiUsesWith(EndBB, FollowBB); + FollowBB->replaceSuccessorsPhiUsesWith(PrevBB, FollowBB); + return; + } + + EndBB = BackInst->getParent(); + EndsInBranch = true; + FollowBB = nullptr; } void OutlinableRegion::reattachCandidate() { @@ -180,7 +251,6 @@ void OutlinableRegion::reattachCandidate() { // inst3 // inst4 assert(StartBB != nullptr && "StartBB for Candidate is not defined!"); - assert(FollowBB != nullptr && "StartBB for Candidate is not defined!"); // StartBB should only have one predecessor since we put an unconditional // branch at the end of PrevBB when we split the BasicBlock. @@ -189,21 +259,24 @@ void OutlinableRegion::reattachCandidate() { "No Predecessor for the region start basic block!"); assert(PrevBB->getTerminator() && "Terminator removed from PrevBB!"); - assert(EndBB->getTerminator() && "Terminator removed from EndBB!"); PrevBB->getTerminator()->eraseFromParent(); - EndBB->getTerminator()->eraseFromParent(); moveBBContents(*StartBB, *PrevBB); BasicBlock *PlacementBB = PrevBB; if (StartBB != EndBB) PlacementBB = EndBB; - moveBBContents(*FollowBB, *PlacementBB); + if (!EndsInBranch && PlacementBB->getUniqueSuccessor() != nullptr) { + assert(FollowBB != nullptr && "FollowBB for Candidate is not defined!"); + assert(PlacementBB->getTerminator() && "Terminator removed from EndBB!"); + PlacementBB->getTerminator()->eraseFromParent(); + moveBBContents(*FollowBB, *PlacementBB); + PlacementBB->replaceSuccessorsPhiUsesWith(FollowBB, PlacementBB); + FollowBB->eraseFromParent(); + } PrevBB->replaceSuccessorsPhiUsesWith(StartBB, PrevBB); - PrevBB->replaceSuccessorsPhiUsesWith(FollowBB, PlacementBB); StartBB->eraseFromParent(); - FollowBB->eraseFromParent(); // Make sure to save changes back to the StartBB. 
StartBB = PrevBB; @@ -261,8 +334,9 @@ InstructionCost OutlinableRegion::getBenefit(TargetTransformInfo &TTI) { // division instruction for targets that have a native division instruction. // To be overly conservative, we only add 1 to the number of instructions for // each division instruction. - for (Instruction &I : *StartBB) { - switch (I.getOpcode()) { + for (IRInstructionData &ID : *Candidate) { + Instruction *I = ID.Inst; + switch (I->getOpcode()) { case Instruction::FDiv: case Instruction::FRem: case Instruction::SDiv: @@ -272,7 +346,7 @@ InstructionCost OutlinableRegion::getBenefit(TargetTransformInfo &TTI) { Benefit += 1; break; default: - Benefit += TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); + Benefit += TTI.getInstructionCost(I, TargetTransformInfo::TCK_CodeSize); break; } } @@ -373,8 +447,24 @@ Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group, unsigned FunctionNameSuffix) { assert(!Group.OutlinedFunction && "Function is already defined!"); + Type *RetTy = Type::getVoidTy(M.getContext()); + // All extracted functions _should_ have the same return type at this point + // since the similarity identifier ensures that all branches outside of the + // region occur in the same place. + + // NOTE: Should we ever move to the model that uses a switch at every point + // needed, meaning that we could branch within the region or out, it is + // possible that we will need to switch to using the most general case all of + // the time. + for (OutlinableRegion *R : Group.Regions) { + Type *ExtractedFuncType = R->ExtractedFunction->getReturnType(); + if ((RetTy->isVoidTy() && !ExtractedFuncType->isVoidTy()) || + (RetTy->isIntegerTy(1) && ExtractedFuncType->isIntegerTy(16))) + RetTy = ExtractedFuncType; + } + Group.OutlinedFunctionType = FunctionType::get( - Type::getVoidTy(M.getContext()), Group.ArgumentTypes, false); + RetTy, Group.ArgumentTypes, false); // These functions will only be called from within the same module, so // we can set an internal linkage. @@ -430,21 +520,23 @@ Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group, /// /// \param [in] Old - The function to move the basic blocks from. /// \param [in] New - The function to move the basic blocks to. -/// \returns the first return block for the function in New. -static BasicBlock *moveFunctionData(Function &Old, Function &New) { - Function::iterator CurrBB, NextBB, FinalBB; - BasicBlock *NewEnd = nullptr; - std::vector<Instruction *> DebugInsts; - for (CurrBB = Old.begin(), FinalBB = Old.end(); CurrBB != FinalBB; - CurrBB = NextBB) { - NextBB = std::next(CurrBB); - CurrBB->removeFromParent(); - CurrBB->insertInto(&New); - Instruction *I = CurrBB->getTerminator(); - if (isa<ReturnInst>(I)) - NewEnd = &(*CurrBB); - - for (Instruction &Val : *CurrBB) { +/// \param [out] NewEnds - The return blocks of the new overall function. +static void moveFunctionData(Function &Old, Function &New, + DenseMap<Value *, BasicBlock *> &NewEnds) { + for (BasicBlock &CurrBB : llvm::make_early_inc_range(Old)) { + CurrBB.removeFromParent(); + CurrBB.insertInto(&New); + Instruction *I = CurrBB.getTerminator(); + + // For each block we find a return instruction is, it is a potential exit + // path for the function. We keep track of each block based on the return + // value here. 
+ if (ReturnInst *RI = dyn_cast<ReturnInst>(I)) + NewEnds.insert(std::make_pair(RI->getReturnValue(), &CurrBB)); + + std::vector<Instruction *> DebugInsts; + + for (Instruction &Val : CurrBB) { // We must handle the scoping of called functions differently than // other outlined instructions. if (!isa<CallInst>(&Val)) { @@ -476,8 +568,7 @@ static BasicBlock *moveFunctionData(Function &Old, Function &New) { I->eraseFromParent(); } - assert(NewEnd && "No return instruction for new function?"); - return NewEnd; + assert(NewEnds.size() > 0 && "No return instruction for new function?"); } /// Find the the constants that will need to be lifted into arguments @@ -664,11 +755,22 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region, // function to account for the extracted constants, we have two different // counters as we find extracted arguments, and as we come across overall // arguments. + + // Additionally, in our first pass, for the first extracted function, + // we find argument locations for the canonical value numbering. This + // numbering overrides any discovered location for the extracted code. for (unsigned InputVal : InputGVNs) { + Optional<unsigned> CanonicalNumberOpt = C.getCanonicalNum(InputVal); + assert(CanonicalNumberOpt.hasValue() && "Canonical number not found?"); + unsigned CanonicalNumber = CanonicalNumberOpt.getValue(); + Optional<Value *> InputOpt = C.fromGVN(InputVal); assert(InputOpt.hasValue() && "Global value number not found?"); Value *Input = InputOpt.getValue(); + DenseMap<unsigned, unsigned>::iterator AggArgIt = + Group.CanonicalNumberToAggArg.find(CanonicalNumber); + if (!Group.InputTypesSet) { Group.ArgumentTypes.push_back(Input->getType()); // If the input value has a swifterr attribute, make sure to mark the @@ -684,17 +786,34 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region, // Check if we have a constant. If we do add it to the overall argument // number to Constant map for the region, and continue to the next input. if (Constant *CST = dyn_cast<Constant>(Input)) { - Region.AggArgToConstant.insert(std::make_pair(TypeIndex, CST)); + if (AggArgIt != Group.CanonicalNumberToAggArg.end()) + Region.AggArgToConstant.insert(std::make_pair(AggArgIt->second, CST)); + else { + Group.CanonicalNumberToAggArg.insert( + std::make_pair(CanonicalNumber, TypeIndex)); + Region.AggArgToConstant.insert(std::make_pair(TypeIndex, CST)); + } TypeIndex++; continue; } // It is not a constant, we create the mapping from extracted argument list - // to the overall argument list. + // to the overall argument list, using the canonical location, if it exists. 
assert(ArgInputs.count(Input) && "Input cannot be found!"); - Region.ExtractedArgToAgg.insert(std::make_pair(OriginalIndex, TypeIndex)); - Region.AggArgToExtracted.insert(std::make_pair(TypeIndex, OriginalIndex)); + if (AggArgIt != Group.CanonicalNumberToAggArg.end()) { + if (OriginalIndex != AggArgIt->second) + Region.ChangedArgOrder = true; + Region.ExtractedArgToAgg.insert( + std::make_pair(OriginalIndex, AggArgIt->second)); + Region.AggArgToExtracted.insert( + std::make_pair(AggArgIt->second, OriginalIndex)); + } else { + Group.CanonicalNumberToAggArg.insert( + std::make_pair(CanonicalNumber, TypeIndex)); + Region.ExtractedArgToAgg.insert(std::make_pair(OriginalIndex, TypeIndex)); + Region.AggArgToExtracted.insert(std::make_pair(TypeIndex, OriginalIndex)); + } OriginalIndex++; TypeIndex++; } @@ -718,10 +837,41 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region, /// \param [in] Outputs - The values found by the code extractor. static void findExtractedOutputToOverallOutputMapping(OutlinableRegion &Region, - ArrayRef<Value *> Outputs) { + SetVector<Value *> &Outputs) { OutlinableGroup &Group = *Region.Parent; IRSimilarityCandidate &C = *Region.Candidate; + SmallVector<BasicBlock *> BE; + DenseSet<BasicBlock *> BBSet; + C.getBasicBlocks(BBSet, BE); + + // Find the exits to the region. + SmallPtrSet<BasicBlock *, 1> Exits; + for (BasicBlock *Block : BE) + for (BasicBlock *Succ : successors(Block)) + if (!BBSet.contains(Succ)) + Exits.insert(Succ); + + // After determining which blocks exit to PHINodes, we add these PHINodes to + // the set of outputs to be processed. We also check the incoming values of + // the PHINodes for whether they should no longer be considered outputs. + for (BasicBlock *ExitBB : Exits) { + for (PHINode &PN : ExitBB->phis()) { + // Find all incoming values from the outlining region. + SmallVector<unsigned, 2> IncomingVals; + for (unsigned Idx = 0; Idx < PN.getNumIncomingValues(); ++Idx) + if (BBSet.contains(PN.getIncomingBlock(Idx))) + IncomingVals.push_back(Idx); + + // Do not process PHI if there is one (or fewer) predecessor from region. + if (IncomingVals.size() <= 1) + continue; + + Region.IgnoreRegion = true; + return; + } + } + // This counts the argument number in the extracted function. unsigned OriginalIndex = Region.NumExtractedInputs; @@ -797,7 +947,7 @@ void IROutliner::findAddInputsOutputs(Module &M, OutlinableRegion &Region, // Map the outputs found by the CodeExtractor to the arguments found for // the overall function. - findExtractedOutputToOverallOutputMapping(Region, Outputs.getArrayRef()); + findExtractedOutputToOverallOutputMapping(Region, Outputs); } /// Replace the extracted function in the Region with a call to the overall @@ -820,9 +970,10 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { assert(AggFunc && "Function to replace with is nullptr?"); // If the arguments are the same size, there are not values that need to be - // made argument, or different output registers to handle. We can simply - // replace the called function in this case. - if (AggFunc->arg_size() == Call->arg_size()) { + // made into an argument, the argument ordering has not been change, or + // different output registers to handle. We can simply replace the called + // function in this case. 
+ if (!Region.ChangedArgOrder && AggFunc->arg_size() == Call->arg_size()) { LLVM_DEBUG(dbgs() << "Replace call to " << *Call << " with call to " << *AggFunc << " with same number of arguments\n"); Call->setCalledFunction(AggFunc); @@ -895,6 +1046,9 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { // Transfer any debug information. Call->setDebugLoc(Region.Call->getDebugLoc()); + // Since our output may determine which branch we go to, we make sure to + // propogate this new call value through the module. + OldCall->replaceAllUsesWith(Call); // Remove the old instruction. OldCall->eraseFromParent(); @@ -913,13 +1067,23 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { // region with the arguments of the function for an OutlinableGroup. // /// \param [in] Region - The region of extracted code to be changed. -/// \param [in,out] OutputBB - The BasicBlock for the output stores for this +/// \param [in,out] OutputBBs - The BasicBlock for the output stores for this /// region. -static void replaceArgumentUses(OutlinableRegion &Region, - BasicBlock *OutputBB) { +/// \param [in] FirstFunction - A flag to indicate whether we are using this +/// function to define the overall outlined function for all the regions, or +/// if we are operating on one of the following regions. +static void +replaceArgumentUses(OutlinableRegion &Region, + DenseMap<Value *, BasicBlock *> &OutputBBs, + bool FirstFunction = false) { OutlinableGroup &Group = *Region.Parent; assert(Region.ExtractedFunction && "Region has no extracted function?"); + Function *DominatingFunction = Region.ExtractedFunction; + if (FirstFunction) + DominatingFunction = Group.OutlinedFunction; + DominatorTree DT(*DominatingFunction); + for (unsigned ArgIdx = 0; ArgIdx < Region.ExtractedFunction->arg_size(); ArgIdx++) { assert(Region.ExtractedArgToAgg.find(ArgIdx) != @@ -946,11 +1110,53 @@ static void replaceArgumentUses(OutlinableRegion &Region, assert(InstAsUser && "User is nullptr!"); Instruction *I = cast<Instruction>(InstAsUser); - I->setDebugLoc(DebugLoc()); - LLVM_DEBUG(dbgs() << "Move store for instruction " << *I << " to " - << *OutputBB << "\n"); + BasicBlock *BB = I->getParent(); + SmallVector<BasicBlock *, 4> Descendants; + DT.getDescendants(BB, Descendants); + bool EdgeAdded = false; + if (Descendants.size() == 0) { + EdgeAdded = true; + DT.insertEdge(&DominatingFunction->getEntryBlock(), BB); + DT.getDescendants(BB, Descendants); + } + + // Iterate over the following blocks, looking for return instructions, + // if we find one, find the corresponding output block for the return value + // and move our store instruction there. + for (BasicBlock *DescendBB : Descendants) { + ReturnInst *RI = dyn_cast<ReturnInst>(DescendBB->getTerminator()); + if (!RI) + continue; + Value *RetVal = RI->getReturnValue(); + auto VBBIt = OutputBBs.find(RetVal); + assert(VBBIt != OutputBBs.end() && "Could not find output value!"); + + // If this is storing a PHINode, we must make sure it is included in the + // overall function. 
+ StoreInst *SI = cast<StoreInst>(I); + + Value *ValueOperand = SI->getValueOperand(); + + StoreInst *NewI = cast<StoreInst>(I->clone()); + NewI->setDebugLoc(DebugLoc()); + BasicBlock *OutputBB = VBBIt->second; + OutputBB->getInstList().push_back(NewI); + LLVM_DEBUG(dbgs() << "Move store for instruction " << *I << " to " + << *OutputBB << "\n"); - I->moveBefore(*OutputBB, OutputBB->end()); + if (FirstFunction) + continue; + Value *CorrVal = + Region.findCorrespondingValueIn(*Group.Regions[0], ValueOperand); + assert(CorrVal && "Value is nullptr?"); + NewI->setOperand(0, CorrVal); + } + + // If we added an edge for basic blocks without a predecessor, we remove it + // here. + if (EdgeAdded) + DT.deleteEdge(&DominatingFunction->getEntryBlock(), BB); + I->eraseFromParent(); LLVM_DEBUG(dbgs() << "Replacing uses of output " << *Arg << " in function " << *Region.ExtractedFunction << " with " << *AggArg @@ -990,69 +1196,53 @@ void replaceConstants(OutlinableRegion &Region) { } } -/// For the given function, find all the nondebug or lifetime instructions, -/// and return them as a vector. Exclude any blocks in \p ExludeBlocks. -/// -/// \param [in] F - The function we collect the instructions from. -/// \param [in] ExcludeBlocks - BasicBlocks to ignore. -/// \returns the list of instructions extracted. -static std::vector<Instruction *> -collectRelevantInstructions(Function &F, - DenseSet<BasicBlock *> &ExcludeBlocks) { - std::vector<Instruction *> RelevantInstructions; - - for (BasicBlock &BB : F) { - if (ExcludeBlocks.contains(&BB)) - continue; - - for (Instruction &Inst : BB) { - if (Inst.isLifetimeStartOrEnd()) - continue; - if (isa<DbgInfoIntrinsic>(Inst)) - continue; - - RelevantInstructions.push_back(&Inst); - } - } - - return RelevantInstructions; -} - /// It is possible that there is a basic block that already performs the same /// stores. This returns a duplicate block, if it exists /// -/// \param OutputBB [in] the block we are looking for a duplicate of. +/// \param OutputBBs [in] the blocks we are looking for a duplicate of. /// \param OutputStoreBBs [in] The existing output blocks. /// \returns an optional value with the number output block if there is a match. -Optional<unsigned> -findDuplicateOutputBlock(BasicBlock *OutputBB, - ArrayRef<BasicBlock *> OutputStoreBBs) { +Optional<unsigned> findDuplicateOutputBlock( + DenseMap<Value *, BasicBlock *> &OutputBBs, + std::vector<DenseMap<Value *, BasicBlock *>> &OutputStoreBBs) { - bool WrongInst = false; - bool WrongSize = false; + bool Mismatch = false; unsigned MatchingNum = 0; - for (BasicBlock *CompBB : OutputStoreBBs) { - WrongInst = false; - if (CompBB->size() - 1 != OutputBB->size()) { - WrongSize = true; - MatchingNum++; - continue; - } - - WrongSize = false; - BasicBlock::iterator NIt = OutputBB->begin(); - for (Instruction &I : *CompBB) { - if (isa<BranchInst>(&I)) - continue; + // We compare the new set output blocks to the other sets of output blocks. + // If they are the same number, and have identical instructions, they are + // considered to be the same. 
+ for (DenseMap<Value *, BasicBlock *> &CompBBs : OutputStoreBBs) { + Mismatch = false; + for (std::pair<Value *, BasicBlock *> &VToB : CompBBs) { + DenseMap<Value *, BasicBlock *>::iterator OutputBBIt = + OutputBBs.find(VToB.first); + if (OutputBBIt == OutputBBs.end()) { + Mismatch = true; + break; + } - if (!I.isIdenticalTo(&(*NIt))) { - WrongInst = true; + BasicBlock *CompBB = VToB.second; + BasicBlock *OutputBB = OutputBBIt->second; + if (CompBB->size() - 1 != OutputBB->size()) { + Mismatch = true; break; } - NIt++; + BasicBlock::iterator NIt = OutputBB->begin(); + for (Instruction &I : *CompBB) { + if (isa<BranchInst>(&I)) + continue; + + if (!I.isIdenticalTo(&(*NIt))) { + Mismatch = true; + break; + } + + NIt++; + } } - if (!WrongInst && !WrongSize) + + if (!Mismatch) return MatchingNum; MatchingNum++; @@ -1061,95 +1251,130 @@ findDuplicateOutputBlock(BasicBlock *OutputBB, return None; } +/// Remove empty output blocks from the outlined region. +/// +/// \param BlocksToPrune - Mapping of return values output blocks for the \p +/// Region. +/// \param Region - The OutlinableRegion we are analyzing. +static bool +analyzeAndPruneOutputBlocks(DenseMap<Value *, BasicBlock *> &BlocksToPrune, + OutlinableRegion &Region) { + bool AllRemoved = true; + Value *RetValueForBB; + BasicBlock *NewBB; + SmallVector<Value *, 4> ToRemove; + // Iterate over the output blocks created in the outlined section. + for (std::pair<Value *, BasicBlock *> &VtoBB : BlocksToPrune) { + RetValueForBB = VtoBB.first; + NewBB = VtoBB.second; + + // If there are no instructions, we remove it from the module, and also + // mark the value for removal from the return value to output block mapping. + if (NewBB->size() == 0) { + NewBB->eraseFromParent(); + ToRemove.push_back(RetValueForBB); + continue; + } + + // Mark that we could not remove all the blocks since they were not all + // empty. + AllRemoved = false; + } + + // Remove the return value from the mapping. + for (Value *V : ToRemove) + BlocksToPrune.erase(V); + + // Mark the region as having the no output scheme. + if (AllRemoved) + Region.OutputBlockNum = -1; + + return AllRemoved; +} + /// For the outlined section, move needed the StoreInsts for the output /// registers into their own block. Then, determine if there is a duplicate /// output block already created. /// /// \param [in] OG - The OutlinableGroup of regions to be outlined. /// \param [in] Region - The OutlinableRegion that is being analyzed. -/// \param [in,out] OutputBB - the block that stores for this region will be +/// \param [in,out] OutputBBs - the blocks that stores for this region will be /// placed in. -/// \param [in] EndBB - the final block of the extracted function. +/// \param [in] EndBBs - the final blocks of the extracted function. /// \param [in] OutputMappings - OutputMappings the mapping of values that have /// been replaced by a new output value. /// \param [in,out] OutputStoreBBs - The existing output blocks. -static void -alignOutputBlockWithAggFunc(OutlinableGroup &OG, OutlinableRegion &Region, - BasicBlock *OutputBB, BasicBlock *EndBB, - const DenseMap<Value *, Value *> &OutputMappings, - std::vector<BasicBlock *> &OutputStoreBBs) { - DenseSet<unsigned> ValuesToFind(Region.GVNStores.begin(), - Region.GVNStores.end()); - - // We iterate over the instructions in the extracted function, and find the - // global value number of the instructions. 
If we find a value that should - // be contained in a store, we replace the uses of the value with the value - // from the overall function, so that the store is storing the correct - // value from the overall function. - DenseSet<BasicBlock *> ExcludeBBs(OutputStoreBBs.begin(), - OutputStoreBBs.end()); - ExcludeBBs.insert(OutputBB); - std::vector<Instruction *> ExtractedFunctionInsts = - collectRelevantInstructions(*(Region.ExtractedFunction), ExcludeBBs); - std::vector<Instruction *> OverallFunctionInsts = - collectRelevantInstructions(*OG.OutlinedFunction, ExcludeBBs); - - assert(ExtractedFunctionInsts.size() == OverallFunctionInsts.size() && - "Number of relevant instructions not equal!"); - - unsigned NumInstructions = ExtractedFunctionInsts.size(); - for (unsigned Idx = 0; Idx < NumInstructions; Idx++) { - Value *V = ExtractedFunctionInsts[Idx]; - - if (OutputMappings.find(V) != OutputMappings.end()) - V = OutputMappings.find(V)->second; - Optional<unsigned> GVN = Region.Candidate->getGVN(V); - - // If we have found one of the stored values for output, replace the value - // with the corresponding one from the overall function. - if (GVN.hasValue() && ValuesToFind.erase(GVN.getValue())) { - V->replaceAllUsesWith(OverallFunctionInsts[Idx]); - if (ValuesToFind.size() == 0) - break; - } - - if (ValuesToFind.size() == 0) - break; - } - - assert(ValuesToFind.size() == 0 && "Not all store values were handled!"); - - // If the size of the block is 0, then there are no stores, and we do not - // need to save this block. - if (OutputBB->size() == 0) { - Region.OutputBlockNum = -1; - OutputBB->eraseFromParent(); +static void alignOutputBlockWithAggFunc( + OutlinableGroup &OG, OutlinableRegion &Region, + DenseMap<Value *, BasicBlock *> &OutputBBs, + DenseMap<Value *, BasicBlock *> &EndBBs, + const DenseMap<Value *, Value *> &OutputMappings, + std::vector<DenseMap<Value *, BasicBlock *>> &OutputStoreBBs) { + // If none of the output blocks have any instructions, this means that we do + // not have to determine if it matches any of the other output schemes, and we + // don't have to do anything else. + if (analyzeAndPruneOutputBlocks(OutputBBs, Region)) return; - } - // Determine is there is a duplicate block. + // Determine is there is a duplicate set of blocks. Optional<unsigned> MatchingBB = - findDuplicateOutputBlock(OutputBB, OutputStoreBBs); + findDuplicateOutputBlock(OutputBBs, OutputStoreBBs); - // If there is, we remove the new output block. If it does not, - // we add it to our list of output blocks. + // If there is, we remove the new output blocks. If it does not, + // we add it to our list of sets of output blocks. 
if (MatchingBB.hasValue()) { LLVM_DEBUG(dbgs() << "Set output block for region in function" << Region.ExtractedFunction << " to " << MatchingBB.getValue()); Region.OutputBlockNum = MatchingBB.getValue(); - OutputBB->eraseFromParent(); + for (std::pair<Value *, BasicBlock *> &VtoBB : OutputBBs) + VtoBB.second->eraseFromParent(); return; } Region.OutputBlockNum = OutputStoreBBs.size(); - LLVM_DEBUG(dbgs() << "Create output block for region in" - << Region.ExtractedFunction << " to " - << *OutputBB); - OutputStoreBBs.push_back(OutputBB); - BranchInst::Create(EndBB, OutputBB); + Value *RetValueForBB; + BasicBlock *NewBB; + OutputStoreBBs.push_back(DenseMap<Value *, BasicBlock *>()); + for (std::pair<Value *, BasicBlock *> &VtoBB : OutputBBs) { + RetValueForBB = VtoBB.first; + NewBB = VtoBB.second; + DenseMap<Value *, BasicBlock *>::iterator VBBIt = + EndBBs.find(RetValueForBB); + LLVM_DEBUG(dbgs() << "Create output block for region in" + << Region.ExtractedFunction << " to " + << *NewBB); + BranchInst::Create(VBBIt->second, NewBB); + OutputStoreBBs.back().insert(std::make_pair(RetValueForBB, NewBB)); + } +} + +/// Takes in a mapping, \p OldMap of ConstantValues to BasicBlocks, sorts keys, +/// before creating a basic block for each \p NewMap, and inserting into the new +/// block. Each BasicBlock is named with the scheme "<basename>_<key_idx>". +/// +/// \param OldMap [in] - The mapping to base the new mapping off of. +/// \param NewMap [out] - The output mapping using the keys of \p OldMap. +/// \param ParentFunc [in] - The function to put the new basic block in. +/// \param BaseName [in] - The start of the BasicBlock names to be appended to +/// by an index value. +static void createAndInsertBasicBlocks(DenseMap<Value *, BasicBlock *> &OldMap, + DenseMap<Value *, BasicBlock *> &NewMap, + Function *ParentFunc, Twine BaseName) { + unsigned Idx = 0; + std::vector<Value *> SortedKeys; + + getSortedConstantKeys(SortedKeys, OldMap); + + for (Value *RetVal : SortedKeys) { + BasicBlock *NewBB = BasicBlock::Create( + ParentFunc->getContext(), + Twine(BaseName) + Twine("_") + Twine(static_cast<unsigned>(Idx++)), + ParentFunc); + NewMap.insert(std::make_pair(RetVal, NewBB)); + } } /// Create the switch statement for outlined function to differentiate between @@ -1159,50 +1384,74 @@ alignOutputBlockWithAggFunc(OutlinableGroup &OG, OutlinableRegion &Region, /// matches the needed stores for the extracted section. /// \param [in] M - The module we are outlining from. /// \param [in] OG - The group of regions to be outlined. -/// \param [in] EndBB - The final block of the extracted function. +/// \param [in] EndBBs - The final blocks of the extracted function. /// \param [in,out] OutputStoreBBs - The existing output blocks. -void createSwitchStatement(Module &M, OutlinableGroup &OG, BasicBlock *EndBB, - ArrayRef<BasicBlock *> OutputStoreBBs) { +void createSwitchStatement( + Module &M, OutlinableGroup &OG, DenseMap<Value *, BasicBlock *> &EndBBs, + std::vector<DenseMap<Value *, BasicBlock *>> &OutputStoreBBs) { // We only need the switch statement if there is more than one store // combination. 
if (OG.OutputGVNCombinations.size() > 1) { Function *AggFunc = OG.OutlinedFunction; - // Create a final block - BasicBlock *ReturnBlock = - BasicBlock::Create(M.getContext(), "final_block", AggFunc); - Instruction *Term = EndBB->getTerminator(); - Term->moveBefore(*ReturnBlock, ReturnBlock->end()); - // Put the switch statement in the old end basic block for the function with - // a fall through to the new return block - LLVM_DEBUG(dbgs() << "Create switch statement in " << *AggFunc << " for " - << OutputStoreBBs.size() << "\n"); - SwitchInst *SwitchI = - SwitchInst::Create(AggFunc->getArg(AggFunc->arg_size() - 1), - ReturnBlock, OutputStoreBBs.size(), EndBB); - - unsigned Idx = 0; - for (BasicBlock *BB : OutputStoreBBs) { - SwitchI->addCase(ConstantInt::get(Type::getInt32Ty(M.getContext()), Idx), - BB); - Term = BB->getTerminator(); - Term->setSuccessor(0, ReturnBlock); - Idx++; + // Create a final block for each different return block. + DenseMap<Value *, BasicBlock *> ReturnBBs; + createAndInsertBasicBlocks(OG.EndBBs, ReturnBBs, AggFunc, "final_block"); + + for (std::pair<Value *, BasicBlock *> &RetBlockPair : ReturnBBs) { + std::pair<Value *, BasicBlock *> &OutputBlock = + *OG.EndBBs.find(RetBlockPair.first); + BasicBlock *ReturnBlock = RetBlockPair.second; + BasicBlock *EndBB = OutputBlock.second; + Instruction *Term = EndBB->getTerminator(); + // Move the return value to the final block instead of the original exit + // stub. + Term->moveBefore(*ReturnBlock, ReturnBlock->end()); + // Put the switch statement in the old end basic block for the function + // with a fall through to the new return block. + LLVM_DEBUG(dbgs() << "Create switch statement in " << *AggFunc << " for " + << OutputStoreBBs.size() << "\n"); + SwitchInst *SwitchI = + SwitchInst::Create(AggFunc->getArg(AggFunc->arg_size() - 1), + ReturnBlock, OutputStoreBBs.size(), EndBB); + + unsigned Idx = 0; + for (DenseMap<Value *, BasicBlock *> &OutputStoreBB : OutputStoreBBs) { + DenseMap<Value *, BasicBlock *>::iterator OSBBIt = + OutputStoreBB.find(OutputBlock.first); + + if (OSBBIt == OutputStoreBB.end()) + continue; + + BasicBlock *BB = OSBBIt->second; + SwitchI->addCase( + ConstantInt::get(Type::getInt32Ty(M.getContext()), Idx), BB); + Term = BB->getTerminator(); + Term->setSuccessor(0, ReturnBlock); + Idx++; + } } return; } - // If there needs to be stores, move them from the output block to the end - // block to save on branching instructions. + // If there needs to be stores, move them from the output blocks to their + // corresponding ending block. 
if (OutputStoreBBs.size() == 1) { LLVM_DEBUG(dbgs() << "Move store instructions to the end block in " << *OG.OutlinedFunction << "\n"); - BasicBlock *OutputBlock = OutputStoreBBs[0]; - Instruction *Term = OutputBlock->getTerminator(); - Term->eraseFromParent(); - Term = EndBB->getTerminator(); - moveBBContents(*OutputBlock, *EndBB); - Term->moveBefore(*EndBB, EndBB->end()); - OutputBlock->eraseFromParent(); + DenseMap<Value *, BasicBlock *> OutputBlocks = OutputStoreBBs[0]; + for (std::pair<Value *, BasicBlock *> &VBPair : OutputBlocks) { + DenseMap<Value *, BasicBlock *>::iterator EndBBIt = + EndBBs.find(VBPair.first); + assert(EndBBIt != EndBBs.end() && "Could not find end block"); + BasicBlock *EndBB = EndBBIt->second; + BasicBlock *OutputBB = VBPair.second; + Instruction *Term = OutputBB->getTerminator(); + Term->eraseFromParent(); + Term = EndBB->getTerminator(); + moveBBContents(*OutputBB, *EndBB); + Term->moveBefore(*EndBB, EndBB->end()); + OutputBB->eraseFromParent(); + } } } @@ -1217,42 +1466,44 @@ void createSwitchStatement(Module &M, OutlinableGroup &OG, BasicBlock *EndBB, /// set of stores needed for the different functions. /// \param [in,out] FuncsToRemove - Extracted functions to erase from module /// once outlining is complete. -static void fillOverallFunction(Module &M, OutlinableGroup &CurrentGroup, - std::vector<BasicBlock *> &OutputStoreBBs, - std::vector<Function *> &FuncsToRemove) { +static void fillOverallFunction( + Module &M, OutlinableGroup &CurrentGroup, + std::vector<DenseMap<Value *, BasicBlock *>> &OutputStoreBBs, + std::vector<Function *> &FuncsToRemove) { OutlinableRegion *CurrentOS = CurrentGroup.Regions[0]; // Move first extracted function's instructions into new function. LLVM_DEBUG(dbgs() << "Move instructions from " << *CurrentOS->ExtractedFunction << " to instruction " << *CurrentGroup.OutlinedFunction << "\n"); - - CurrentGroup.EndBB = moveFunctionData(*CurrentOS->ExtractedFunction, - *CurrentGroup.OutlinedFunction); + moveFunctionData(*CurrentOS->ExtractedFunction, + *CurrentGroup.OutlinedFunction, CurrentGroup.EndBBs); // Transfer the attributes from the function to the new function. - for (Attribute A : - CurrentOS->ExtractedFunction->getAttributes().getFnAttributes()) + for (Attribute A : CurrentOS->ExtractedFunction->getAttributes().getFnAttrs()) CurrentGroup.OutlinedFunction->addFnAttr(A); - // Create an output block for the first extracted function. - BasicBlock *NewBB = BasicBlock::Create( - M.getContext(), Twine("output_block_") + Twine(static_cast<unsigned>(0)), - CurrentGroup.OutlinedFunction); + // Create a new set of output blocks for the first extracted function. + DenseMap<Value *, BasicBlock *> NewBBs; + createAndInsertBasicBlocks(CurrentGroup.EndBBs, NewBBs, + CurrentGroup.OutlinedFunction, "output_block_0"); CurrentOS->OutputBlockNum = 0; - replaceArgumentUses(*CurrentOS, NewBB); + replaceArgumentUses(*CurrentOS, NewBBs, true); replaceConstants(*CurrentOS); - // If the new basic block has no new stores, we can erase it from the module. - // It it does, we create a branch instruction to the last basic block from the - // new one. - if (NewBB->size() == 0) { - CurrentOS->OutputBlockNum = -1; - NewBB->eraseFromParent(); - } else { - BranchInst::Create(CurrentGroup.EndBB, NewBB); - OutputStoreBBs.push_back(NewBB); + // We first identify if any output blocks are empty, if they are we remove + // them. We then create a branch instruction to the basic block to the return + // block for the function for each non empty output block. 
+ if (!analyzeAndPruneOutputBlocks(NewBBs, *CurrentOS)) { + OutputStoreBBs.push_back(DenseMap<Value *, BasicBlock *>()); + for (std::pair<Value *, BasicBlock *> &VToBB : NewBBs) { + DenseMap<Value *, BasicBlock *>::iterator VBBIt = + CurrentGroup.EndBBs.find(VToBB.first); + BasicBlock *EndBB = VBBIt->second; + BranchInst::Create(EndBB, VToBB.second); + OutputStoreBBs.back().insert(VToBB); + } } // Replace the call to the extracted function with the outlined function. @@ -1268,25 +1519,28 @@ void IROutliner::deduplicateExtractedSections( std::vector<Function *> &FuncsToRemove, unsigned &OutlinedFunctionNum) { createFunction(M, CurrentGroup, OutlinedFunctionNum); - std::vector<BasicBlock *> OutputStoreBBs; + std::vector<DenseMap<Value *, BasicBlock *>> OutputStoreBBs; OutlinableRegion *CurrentOS; fillOverallFunction(M, CurrentGroup, OutputStoreBBs, FuncsToRemove); + std::vector<Value *> SortedKeys; for (unsigned Idx = 1; Idx < CurrentGroup.Regions.size(); Idx++) { CurrentOS = CurrentGroup.Regions[Idx]; AttributeFuncs::mergeAttributesForOutlining(*CurrentGroup.OutlinedFunction, *CurrentOS->ExtractedFunction); - // Create a new BasicBlock to hold the needed store instructions. - BasicBlock *NewBB = BasicBlock::Create( - M.getContext(), "output_block_" + std::to_string(Idx), - CurrentGroup.OutlinedFunction); - replaceArgumentUses(*CurrentOS, NewBB); + // Create a set of BasicBlocks, one for each return block, to hold the + // needed store instructions. + DenseMap<Value *, BasicBlock *> NewBBs; + createAndInsertBasicBlocks( + CurrentGroup.EndBBs, NewBBs, CurrentGroup.OutlinedFunction, + "output_block_" + Twine(static_cast<unsigned>(Idx))); - alignOutputBlockWithAggFunc(CurrentGroup, *CurrentOS, NewBB, - CurrentGroup.EndBB, OutputMappings, + replaceArgumentUses(*CurrentOS, NewBBs); + alignOutputBlockWithAggFunc(CurrentGroup, *CurrentOS, NewBBs, + CurrentGroup.EndBBs, OutputMappings, OutputStoreBBs); CurrentOS->Call = replaceCalledFunction(M, *CurrentOS); @@ -1294,11 +1548,78 @@ void IROutliner::deduplicateExtractedSections( } // Create a switch statement to handle the different output schemes. - createSwitchStatement(M, CurrentGroup, CurrentGroup.EndBB, OutputStoreBBs); + createSwitchStatement(M, CurrentGroup, CurrentGroup.EndBBs, OutputStoreBBs); OutlinedFunctionNum++; } +/// Checks that the next instruction in the InstructionDataList matches the +/// next instruction in the module. If they do not, there could be the +/// possibility that extra code has been inserted, and we must ignore it. +/// +/// \param ID - The IRInstructionData to check the next instruction of. +/// \returns true if the InstructionDataList and actual instruction match. +static bool nextIRInstructionDataMatchesNextInst(IRInstructionData &ID) { + // We check if there is a discrepancy between the InstructionDataList + // and the actual next instruction in the module. If there is, it means + // that an extra instruction was added, likely by the CodeExtractor. + + // Since we do not have any similarity data about this particular + // instruction, we cannot confidently outline it, and must discard this + // candidate. 
+ IRInstructionDataList::iterator NextIDIt = std::next(ID.getIterator()); + Instruction *NextIDLInst = NextIDIt->Inst; + Instruction *NextModuleInst = nullptr; + if (!ID.Inst->isTerminator()) + NextModuleInst = ID.Inst->getNextNonDebugInstruction(); + else if (NextIDLInst != nullptr) + NextModuleInst = + &*NextIDIt->Inst->getParent()->instructionsWithoutDebug().begin(); + + if (NextIDLInst && NextIDLInst != NextModuleInst) + return false; + + return true; +} + +bool IROutliner::isCompatibleWithAlreadyOutlinedCode( + const OutlinableRegion &Region) { + IRSimilarityCandidate *IRSC = Region.Candidate; + unsigned StartIdx = IRSC->getStartIdx(); + unsigned EndIdx = IRSC->getEndIdx(); + + // A check to make sure that we are not about to attempt to outline something + // that has already been outlined. + for (unsigned Idx = StartIdx; Idx <= EndIdx; Idx++) + if (Outlined.contains(Idx)) + return false; + + // We check if the recorded instruction matches the actual next instruction, + // if it does not, we fix it in the InstructionDataList. + if (!Region.Candidate->backInstruction()->isTerminator()) { + Instruction *NewEndInst = + Region.Candidate->backInstruction()->getNextNonDebugInstruction(); + assert(NewEndInst && "Next instruction is a nullptr?"); + if (Region.Candidate->end()->Inst != NewEndInst) { + IRInstructionDataList *IDL = Region.Candidate->front()->IDL; + IRInstructionData *NewEndIRID = new (InstDataAllocator.Allocate()) + IRInstructionData(*NewEndInst, + InstructionClassifier.visit(*NewEndInst), *IDL); + + // Insert the first IRInstructionData of the new region after the + // last IRInstructionData of the IRSimilarityCandidate. + IDL->insert(Region.Candidate->end(), *NewEndIRID); + } + } + + return none_of(*IRSC, [this](IRInstructionData &ID) { + if (!nextIRInstructionDataMatchesNextInst(ID)) + return true; + + return !this->InstructionClassifier.visit(ID.Inst); + }); +} + void IROutliner::pruneIncompatibleRegions( std::vector<IRSimilarityCandidate> &CandidateVec, OutlinableGroup &CurrentGroup) { @@ -1310,6 +1631,15 @@ void IROutliner::pruneIncompatibleRegions( return LHS.getStartIdx() < RHS.getStartIdx(); }); + IRSimilarityCandidate &FirstCandidate = CandidateVec[0]; + // Since outlining a call and a branch instruction will be the same as only + // outlinining a call instruction, we ignore it as a space saving. + if (FirstCandidate.getLength() == 2) { + if (isa<CallInst>(FirstCandidate.front()->Inst) && + isa<BranchInst>(FirstCandidate.back()->Inst)) + return; + } + unsigned CurrentEndIdx = 0; for (IRSimilarityCandidate &IRSC : CandidateVec) { PreviouslyOutlined = false; @@ -1325,9 +1655,13 @@ void IROutliner::pruneIncompatibleRegions( if (PreviouslyOutlined) continue; - // TODO: If in the future we can outline across BasicBlocks, we will need to - // check all BasicBlocks contained in the region. - if (IRSC.getStartBB()->hasAddressTaken()) + // Check over the instructions, and if the basic block has its address + // taken for use somewhere else, we do not outline that block. + bool BBHasAddressTaken = any_of(IRSC, [](IRInstructionData &ID){ + return ID.Inst->getParent()->hasAddressTaken(); + }); + + if (BBHasAddressTaken) continue; if (IRSC.front()->Inst->getFunction()->hasLinkOnceODRLinkage() && @@ -1340,16 +1674,9 @@ void IROutliner::pruneIncompatibleRegions( continue; bool BadInst = any_of(IRSC, [this](IRInstructionData &ID) { - // We check if there is a discrepancy between the InstructionDataList - // and the actual next instruction in the module. 
If there is, it means - // that an extra instruction was added, likely by the CodeExtractor. - - // Since we do not have any similarity data about this particular - // instruction, we cannot confidently outline it, and must discard this - // candidate. - if (std::next(ID.getIterator())->Inst != - ID.Inst->getNextNonDebugInstruction()) + if (!nextIRInstructionDataMatchesNextInst(ID)) return true; + return !this->InstructionClassifier.visit(ID.Inst); }); @@ -1416,10 +1743,33 @@ static InstructionCost findCostForOutputBlocks(Module &M, OutlinableGroup &CurrentGroup, TargetTransformInfo &TTI) { InstructionCost OutputCost = 0; + unsigned NumOutputBranches = 0; + + IRSimilarityCandidate &Candidate = *CurrentGroup.Regions[0]->Candidate; + DenseSet<BasicBlock *> CandidateBlocks; + Candidate.getBasicBlocks(CandidateBlocks); + + // Count the number of different output branches that point to blocks outside + // of the region. + DenseSet<BasicBlock *> FoundBlocks; + for (IRInstructionData &ID : Candidate) { + if (!isa<BranchInst>(ID.Inst)) + continue; + + for (Value *V : ID.OperVals) { + BasicBlock *BB = static_cast<BasicBlock *>(V); + DenseSet<BasicBlock *>::iterator CBIt = CandidateBlocks.find(BB); + if (CBIt != CandidateBlocks.end() || FoundBlocks.contains(BB)) + continue; + FoundBlocks.insert(BB); + NumOutputBranches++; + } + } + + CurrentGroup.BranchesToOutside = NumOutputBranches; for (const ArrayRef<unsigned> &OutputUse : CurrentGroup.OutputGVNCombinations) { - IRSimilarityCandidate &Candidate = *CurrentGroup.Regions[0]->Candidate; for (unsigned GVN : OutputUse) { Optional<Value *> OV = Candidate.fromGVN(GVN); assert(OV.hasValue() && "Could not find value for GVN?"); @@ -1434,14 +1784,14 @@ static InstructionCost findCostForOutputBlocks(Module &M, LLVM_DEBUG(dbgs() << "Adding: " << StoreCost << " instructions to cost for output of type " << *V->getType() << "\n"); - OutputCost += StoreCost; + OutputCost += StoreCost * NumOutputBranches; } InstructionCost BranchCost = TTI.getCFInstrCost(Instruction::Br, TargetTransformInfo::TCK_CodeSize); LLVM_DEBUG(dbgs() << "Adding " << BranchCost << " to the current cost for" << " a branch instruction\n"); - OutputCost += BranchCost; + OutputCost += BranchCost * NumOutputBranches; } // If there is more than one output scheme, we must have a comparison and @@ -1460,7 +1810,7 @@ static InstructionCost findCostForOutputBlocks(Module &M, LLVM_DEBUG(dbgs() << "Adding: " << TotalCost << " instructions for each switch case for each different" << " output path in a function\n"); - OutputCost += TotalCost; + OutputCost += TotalCost * NumOutputBranches; } return OutputCost; @@ -1548,13 +1898,12 @@ void IROutliner::updateOutputMapping(OutlinableRegion &Region, bool IROutliner::extractSection(OutlinableRegion &Region) { SetVector<Value *> ArgInputs, Outputs, SinkCands; - Region.CE->findInputsOutputs(ArgInputs, Outputs, SinkCands); - assert(Region.StartBB && "StartBB for the OutlinableRegion is nullptr!"); - assert(Region.FollowBB && "FollowBB for the OutlinableRegion is nullptr!"); + BasicBlock *InitialStart = Region.StartBB; Function *OrigF = Region.StartBB->getParent(); CodeExtractorAnalysisCache CEAC(*OrigF); - Region.ExtractedFunction = Region.CE->extractCodeRegion(CEAC); + Region.ExtractedFunction = + Region.CE->extractCodeRegion(CEAC, ArgInputs, Outputs); // If the extraction was successful, find the BasicBlock, and reassign the // OutlinableRegion blocks @@ -1565,7 +1914,23 @@ bool IROutliner::extractSection(OutlinableRegion &Region) { return false; } - BasicBlock 
*RewrittenBB = Region.FollowBB->getSinglePredecessor(); + // Get the block containing the called branch, and reassign the blocks as + // necessary. If the original block still exists, it is because we ended on + // a branch instruction, and so we move the contents into the block before + // and assign the previous block correctly. + User *InstAsUser = Region.ExtractedFunction->user_back(); + BasicBlock *RewrittenBB = cast<Instruction>(InstAsUser)->getParent(); + Region.PrevBB = RewrittenBB->getSinglePredecessor(); + assert(Region.PrevBB && "PrevBB is nullptr?"); + if (Region.PrevBB == InitialStart) { + BasicBlock *NewPrev = InitialStart->getSinglePredecessor(); + Instruction *BI = NewPrev->getTerminator(); + BI->eraseFromParent(); + moveBBContents(*InitialStart, *NewPrev); + Region.PrevBB = NewPrev; + InitialStart->eraseFromParent(); + } + Region.StartBB = RewrittenBB; Region.EndBB = RewrittenBB; @@ -1608,6 +1973,7 @@ bool IROutliner::extractSection(OutlinableRegion &Region) { unsigned IROutliner::doOutline(Module &M) { // Find the possible similarity sections. + InstructionClassifier.EnableBranches = !DisableBranches; IRSimilarityIdentifier &Identifier = getIRSI(M); SimilarityGroupList &SimilarityCandidates = *Identifier.getSimilarity(); @@ -1622,12 +1988,17 @@ unsigned IROutliner::doOutline(Module &M) { return LHS[0].getLength() * LHS.size() > RHS[0].getLength() * RHS.size(); }); + // Creating OutlinableGroups for each SimilarityCandidate to be used in + // each of the following for loops to avoid making an allocator. + std::vector<OutlinableGroup> PotentialGroups(SimilarityCandidates.size()); DenseSet<unsigned> NotSame; - std::vector<Function *> FuncsToRemove; + std::vector<OutlinableGroup *> NegativeCostGroups; + std::vector<OutlinableRegion *> OutlinedRegions; // Iterate over the possible sets of similarity. + unsigned PotentialGroupIdx = 0; for (SimilarityGroup &CandidateVec : SimilarityCandidates) { - OutlinableGroup CurrentGroup; + OutlinableGroup &CurrentGroup = PotentialGroups[PotentialGroupIdx++]; // Remove entries that were previously outlined pruneIncompatibleRegions(CandidateVec, CurrentGroup); @@ -1649,20 +2020,31 @@ unsigned IROutliner::doOutline(Module &M) { // Create a CodeExtractor for each outlinable region. Identify inputs and // outputs for each section using the code extractor and create the argument // types for the Aggregate Outlining Function. - std::vector<OutlinableRegion *> OutlinedRegions; + OutlinedRegions.clear(); for (OutlinableRegion *OS : CurrentGroup.Regions) { // Break the outlinable region out of its parent BasicBlock into its own // BasicBlocks (see function implementation). OS->splitCandidate(); - std::vector<BasicBlock *> BE = {OS->StartBB}; + + // There's a chance that when the region is split, extra instructions are + // added to the region. This makes the region no longer viable + // to be split, so we ignore it for outlining. + if (!OS->CandidateSplit) + continue; + + SmallVector<BasicBlock *> BE; + DenseSet<BasicBlock *> BBSet; + OS->Candidate->getBasicBlocks(BBSet, BE); OS->CE = new (ExtractorAllocator.Allocate()) CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false, false, "outlined"); findAddInputsOutputs(M, *OS, NotSame); if (!OS->IgnoreRegion) OutlinedRegions.push_back(OS); - else - OS->reattachCandidate(); + + // We recombine the blocks together now that we have gathered all the + // needed information. 
+ OS->reattachCandidate(); } CurrentGroup.Regions = std::move(OutlinedRegions); @@ -1675,12 +2057,11 @@ unsigned IROutliner::doOutline(Module &M) { if (CostModel) findCostBenefit(M, CurrentGroup); - // If we are adhering to the cost model, reattach all the candidates + // If we are adhering to the cost model, skip those groups where the cost + // outweighs the benefits. if (CurrentGroup.Cost >= CurrentGroup.Benefit && CostModel) { - for (OutlinableRegion *OS : CurrentGroup.Regions) - OS->reattachCandidate(); - OptimizationRemarkEmitter &ORE = getORE( - *CurrentGroup.Regions[0]->Candidate->getFunction()); + OptimizationRemarkEmitter &ORE = + getORE(*CurrentGroup.Regions[0]->Candidate->getFunction()); ORE.emit([&]() { IRSimilarityCandidate *C = CurrentGroup.Regions[0]->Candidate; OptimizationRemarkMissed R(DEBUG_TYPE, "WouldNotDecreaseSize", @@ -1704,12 +2085,70 @@ unsigned IROutliner::doOutline(Module &M) { continue; } + NegativeCostGroups.push_back(&CurrentGroup); + } + + ExtractorAllocator.DestroyAll(); + + if (NegativeCostGroups.size() > 1) + stable_sort(NegativeCostGroups, + [](const OutlinableGroup *LHS, const OutlinableGroup *RHS) { + return LHS->Benefit - LHS->Cost > RHS->Benefit - RHS->Cost; + }); + + std::vector<Function *> FuncsToRemove; + for (OutlinableGroup *CG : NegativeCostGroups) { + OutlinableGroup &CurrentGroup = *CG; + + OutlinedRegions.clear(); + for (OutlinableRegion *Region : CurrentGroup.Regions) { + // We check whether our region is compatible with what has already been + // outlined, and whether we need to ignore this item. + if (!isCompatibleWithAlreadyOutlinedCode(*Region)) + continue; + OutlinedRegions.push_back(Region); + } + + if (OutlinedRegions.size() < 2) + continue; + + // Reestimate the cost and benefit of the OutlinableGroup. Continue only if + // we are still outlining enough regions to make up for the added cost. + CurrentGroup.Regions = std::move(OutlinedRegions); + if (CostModel) { + CurrentGroup.Benefit = 0; + CurrentGroup.Cost = 0; + findCostBenefit(M, CurrentGroup); + if (CurrentGroup.Cost >= CurrentGroup.Benefit) + continue; + } + OutlinedRegions.clear(); + for (OutlinableRegion *Region : CurrentGroup.Regions) { + Region->splitCandidate(); + if (!Region->CandidateSplit) + continue; + OutlinedRegions.push_back(Region); + } + + CurrentGroup.Regions = std::move(OutlinedRegions); + if (CurrentGroup.Regions.size() < 2) { + for (OutlinableRegion *R : CurrentGroup.Regions) + R->reattachCandidate(); + continue; + } + LLVM_DEBUG(dbgs() << "Outlining regions with cost " << CurrentGroup.Cost << " and benefit " << CurrentGroup.Benefit << "\n"); // Create functions out of all the sections, and mark them as outlined. 
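// Illustrative sketch (not part of this commit): the reworked doOutline above
// first collects every group whose benefit exceeds its cost, then outlines them
// in order of decreasing net benefit. Group is a stand-in for OutlinableGroup;
// the comparator mirrors the stable_sort over NegativeCostGroups shown above.
#include <algorithm>
#include <vector>

struct Group { int Cost = 0; int Benefit = 0; };

void orderByNetBenefit(std::vector<Group *> &Groups) {
  std::stable_sort(Groups.begin(), Groups.end(),
                   [](const Group *L, const Group *R) {
                     return L->Benefit - L->Cost > R->Benefit - R->Cost;
                   });
}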
OutlinedRegions.clear(); for (OutlinableRegion *OS : CurrentGroup.Regions) { + SmallVector<BasicBlock *> BE; + DenseSet<BasicBlock *> BBSet; + OS->Candidate->getBasicBlocks(BBSet, BE); + OS->CE = new (ExtractorAllocator.Allocate()) + CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false, + false, "outlined"); bool FunctionOutlined = extractSection(*OS); if (FunctionOutlined) { unsigned StartIdx = OS->Candidate->getStartIdx(); @@ -1767,6 +2206,7 @@ bool IROutliner::run(Module &M) { } // Pass Manager Boilerplate +namespace { class IROutlinerLegacyPass : public ModulePass { public: static char ID; @@ -1782,6 +2222,7 @@ public: bool runOnModule(Module &M) override; }; +} // namespace bool IROutlinerLegacyPass::runOnModule(Module &M) { if (skipModule(M)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp index 59260af88832..992c2b292e1e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp @@ -31,9 +31,11 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/InlineOrder.h" #include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/ReplayInlineAdvisor.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/Utils/ImportedFunctionsInliningStatistics.h" @@ -96,9 +98,53 @@ static cl::opt<std::string> CGSCCInlineReplayFile( "cgscc-inline-replay", cl::init(""), cl::value_desc("filename"), cl::desc( "Optimization remarks file containing inline remarks to be replayed " - "by inlining from cgscc inline remarks."), + "by cgscc inlining."), cl::Hidden); +static cl::opt<ReplayInlinerSettings::Scope> CGSCCInlineReplayScope( + "cgscc-inline-replay-scope", + cl::init(ReplayInlinerSettings::Scope::Function), + cl::values(clEnumValN(ReplayInlinerSettings::Scope::Function, "Function", + "Replay on functions that have remarks associated " + "with them (default)"), + clEnumValN(ReplayInlinerSettings::Scope::Module, "Module", + "Replay on the entire module")), + cl::desc("Whether inline replay should be applied to the entire " + "Module or just the Functions (default) that are present as " + "callers in remarks during cgscc inlining."), + cl::Hidden); + +static cl::opt<ReplayInlinerSettings::Fallback> CGSCCInlineReplayFallback( + "cgscc-inline-replay-fallback", + cl::init(ReplayInlinerSettings::Fallback::Original), + cl::values( + clEnumValN( + ReplayInlinerSettings::Fallback::Original, "Original", + "All decisions not in replay send to original advisor (default)"), + clEnumValN(ReplayInlinerSettings::Fallback::AlwaysInline, + "AlwaysInline", "All decisions not in replay are inlined"), + clEnumValN(ReplayInlinerSettings::Fallback::NeverInline, "NeverInline", + "All decisions not in replay are not inlined")), + cl::desc( + "How cgscc inline replay treats sites that don't come from the replay. 
" + "Original: defers to original advisor, AlwaysInline: inline all sites " + "not in replay, NeverInline: inline no sites not in replay"), + cl::Hidden); + +static cl::opt<CallSiteFormat::Format> CGSCCInlineReplayFormat( + "cgscc-inline-replay-format", + cl::init(CallSiteFormat::Format::LineColumnDiscriminator), + cl::values( + clEnumValN(CallSiteFormat::Format::Line, "Line", "<Line Number>"), + clEnumValN(CallSiteFormat::Format::LineColumn, "LineColumn", + "<Line Number>:<Column Number>"), + clEnumValN(CallSiteFormat::Format::LineDiscriminator, + "LineDiscriminator", "<Line Number>.<Discriminator>"), + clEnumValN(CallSiteFormat::Format::LineColumnDiscriminator, + "LineColumnDiscriminator", + "<Line Number>:<Column Number>.<Discriminator> (default)")), + cl::desc("How cgscc inline replay file is formatted"), cl::Hidden); + static cl::opt<bool> InlineEnablePriorityOrder( "inline-enable-priority-order", cl::Hidden, cl::init(false), cl::desc("Enable the priority inline order for the inliner")); @@ -463,7 +509,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, } ++NumInlined; - emitInlinedInto(ORE, DLoc, Block, *Callee, *Caller, *OIC); + emitInlinedIntoBasedOnCost(ORE, DLoc, Block, *Callee, *Caller, *OIC); // If inlining this function gave us any new call sites, throw them // onto our worklist to process. They are useful inline candidates. @@ -661,9 +707,12 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, std::make_unique<DefaultInlineAdvisor>(M, FAM, getInlineParams()); if (!CGSCCInlineReplayFile.empty()) - OwnedAdvisor = std::make_unique<ReplayInlineAdvisor>( + OwnedAdvisor = getReplayInlineAdvisor( M, FAM, M.getContext(), std::move(OwnedAdvisor), - CGSCCInlineReplayFile, + ReplayInlinerSettings{CGSCCInlineReplayFile, + CGSCCInlineReplayScope, + CGSCCInlineReplayFallback, + {CGSCCInlineReplayFormat}}, /*EmitRemarks=*/true); return *OwnedAdvisor; @@ -674,153 +723,6 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, return *IAA->getAdvisor(); } -template <typename T> class InlineOrder { -public: - using reference = T &; - using const_reference = const T &; - - virtual ~InlineOrder() {} - - virtual size_t size() = 0; - - virtual void push(const T &Elt) = 0; - - virtual T pop() = 0; - - virtual const_reference front() = 0; - - virtual void erase_if(function_ref<bool(T)> Pred) = 0; - - bool empty() { return !size(); } -}; - -template <typename T, typename Container = SmallVector<T, 16>> -class DefaultInlineOrder : public InlineOrder<T> { - using reference = T &; - using const_reference = const T &; - -public: - size_t size() override { return Calls.size() - FirstIndex; } - - void push(const T &Elt) override { Calls.push_back(Elt); } - - T pop() override { - assert(size() > 0); - return Calls[FirstIndex++]; - } - - const_reference front() override { - assert(size() > 0); - return Calls[FirstIndex]; - } - - void erase_if(function_ref<bool(T)> Pred) override { - Calls.erase(std::remove_if(Calls.begin() + FirstIndex, Calls.end(), Pred), - Calls.end()); - } - -private: - Container Calls; - size_t FirstIndex = 0; -}; - -class Priority { -public: - Priority(int Size) : Size(Size) {} - - static bool isMoreDesirable(const Priority &S1, const Priority &S2) { - return S1.Size < S2.Size; - } - - static Priority evaluate(CallBase *CB) { - Function *Callee = CB->getCalledFunction(); - return Priority(Callee->getInstructionCount()); - } - - int Size; -}; - -template <typename PriorityT> -class PriorityInlineOrder : public 
InlineOrder<std::pair<CallBase *, int>> { - using T = std::pair<CallBase *, int>; - using HeapT = std::pair<CallBase *, PriorityT>; - using reference = T &; - using const_reference = const T &; - - static bool cmp(const HeapT &P1, const HeapT &P2) { - return PriorityT::isMoreDesirable(P2.second, P1.second); - } - - // A call site could become less desirable for inlining because of the size - // growth from prior inlining into the callee. This method is used to lazily - // update the desirability of a call site if it's decreasing. It is only - // called on pop() or front(), not every time the desirability changes. When - // the desirability of the front call site decreases, an updated one would be - // pushed right back into the heap. For simplicity, those cases where - // the desirability of a call site increases are ignored here. - void adjust() { - bool Changed = false; - do { - CallBase *CB = Heap.front().first; - const PriorityT PreviousGoodness = Heap.front().second; - const PriorityT CurrentGoodness = PriorityT::evaluate(CB); - Changed = PriorityT::isMoreDesirable(PreviousGoodness, CurrentGoodness); - if (Changed) { - std::pop_heap(Heap.begin(), Heap.end(), cmp); - Heap.pop_back(); - Heap.push_back({CB, CurrentGoodness}); - std::push_heap(Heap.begin(), Heap.end(), cmp); - } - } while (Changed); - } - -public: - size_t size() override { return Heap.size(); } - - void push(const T &Elt) override { - CallBase *CB = Elt.first; - const int InlineHistoryID = Elt.second; - const PriorityT Goodness = PriorityT::evaluate(CB); - - Heap.push_back({CB, Goodness}); - std::push_heap(Heap.begin(), Heap.end(), cmp); - InlineHistoryMap[CB] = InlineHistoryID; - } - - T pop() override { - assert(size() > 0); - adjust(); - - CallBase *CB = Heap.front().first; - T Result = std::make_pair(CB, InlineHistoryMap[CB]); - InlineHistoryMap.erase(CB); - std::pop_heap(Heap.begin(), Heap.end(), cmp); - Heap.pop_back(); - return Result; - } - - const_reference front() override { - assert(size() > 0); - adjust(); - - CallBase *CB = Heap.front().first; - return *InlineHistoryMap.find(CB); - } - - void erase_if(function_ref<bool(T)> Pred) override { - auto PredWrapper = [=](HeapT P) -> bool { - return Pred(std::make_pair(P.first, 0)); - }; - Heap.erase(std::remove_if(Heap.begin(), Heap.end(), PredWrapper), - Heap.end()); - std::make_heap(Heap.begin(), Heap.end(), cmp); - } - -private: - SmallVector<HeapT, 16> Heap; - DenseMap<CallBase *, int> InlineHistoryMap; -}; - PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { @@ -868,7 +770,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // incrementally maknig a single function grow in a super linear fashion. std::unique_ptr<InlineOrder<std::pair<CallBase *, int>>> Calls; if (InlineEnablePriorityOrder) - Calls = std::make_unique<PriorityInlineOrder<Priority>>(); + Calls = std::make_unique<PriorityInlineOrder<InlineSizePriority>>(); else Calls = std::make_unique<DefaultInlineOrder<std::pair<CallBase *, int>>>(); assert(Calls != nullptr && "Expected an initialized InlineOrder"); @@ -972,8 +874,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, continue; } - auto Advice = Advisor.getAdvice(*CB, OnlyMandatory); + std::unique_ptr<InlineAdvice> Advice = + Advisor.getAdvice(*CB, OnlyMandatory); + // Check whether we want to inline this callsite. 
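// Illustrative sketch (not part of this commit): the lazy re-prioritisation idea
// behind the PriorityInlineOrder removed above (the class is now provided by the
// llvm/Analysis/InlineOrder.h header included earlier in this diff). A lower
// priority value is more desirable, and a cached priority can only go stale in
// the "worse" direction (e.g. a callee grew from prior inlining), so it suffices
// to re-check the front element before each pop. Names here are stand-ins.
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

template <typename T> class LazyMinHeap {
  using Entry = std::pair<T, int>; // (element, cached priority)
  std::vector<Entry> Heap;
  std::function<int(const T &)> Eval;
  static bool cmp(const Entry &A, const Entry &B) { return A.second > B.second; }

  void adjust() {
    // Re-sink the front element while its real priority is worse than cached.
    while (true) {
      int Current = Eval(Heap.front().first);
      if (Current <= Heap.front().second)
        break;
      T Elt = Heap.front().first;
      std::pop_heap(Heap.begin(), Heap.end(), cmp); // front moves to the back
      Heap.back() = {Elt, Current};                 // refresh the cached value
      std::push_heap(Heap.begin(), Heap.end(), cmp);
    }
  }

public:
  explicit LazyMinHeap(std::function<int(const T &)> E) : Eval(std::move(E)) {}
  bool empty() const { return Heap.empty(); }
  void push(const T &Elt) {
    Heap.push_back({Elt, Eval(Elt)});
    std::push_heap(Heap.begin(), Heap.end(), cmp);
  }
  T pop() {
    adjust();
    T Elt = Heap.front().first;
    std::pop_heap(Heap.begin(), Heap.end(), cmp);
    Heap.pop_back();
    return Elt;
  }
};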
+ if (!Advice) + continue; + if (!Advice->isInliningRecommended()) { Advice->recordUnattemptedInlining(); continue; @@ -1104,6 +1011,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, UR.InlinedInternalEdges.insert({&N, OldC}); } InlinedCallees.clear(); + + // Invalidate analyses for this function now so that we don't have to + // invalidate analyses for all functions in this SCC later. + FAM.invalidate(F, PreservedAnalyses::none()); } // Now that we've finished inlining all of the calls across this SCC, delete @@ -1147,10 +1058,12 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (!Changed) return PreservedAnalyses::all(); + PreservedAnalyses PA; // Even if we change the IR, we update the core CGSCC data structures and so // can preserve the proxy to the function analysis manager. - PreservedAnalyses PA; PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + // We have already invalidated all analyses on modified functions. + PA.preserveSet<AllAnalysesOn<Function>>(); return PA; } @@ -1173,7 +1086,11 @@ ModuleInlinerWrapperPass::ModuleInlinerWrapperPass(InlineParams Params, PreservedAnalyses ModuleInlinerWrapperPass::run(Module &M, ModuleAnalysisManager &MAM) { auto &IAA = MAM.getResult<InlineAdvisorAnalysis>(M); - if (!IAA.tryCreate(Params, Mode, CGSCCInlineReplayFile)) { + if (!IAA.tryCreate(Params, Mode, + {CGSCCInlineReplayFile, + CGSCCInlineReplayScope, + CGSCCInlineReplayFallback, + {CGSCCInlineReplayFormat}})) { M.getContext().emitError( "Could not setup Inlining Advisor for the requested " "mode and/or options"); @@ -1192,10 +1109,39 @@ PreservedAnalyses ModuleInlinerWrapperPass::run(Module &M, else MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor( createDevirtSCCRepeatedPass(std::move(PM), MaxDevirtIterations))); + + MPM.addPass(std::move(AfterCGMPM)); MPM.run(M, MAM); - IAA.clear(); + // Discard the InlineAdvisor, a subsequent inlining session should construct + // its own. + auto PA = PreservedAnalyses::all(); + PA.abandon<InlineAdvisorAnalysis>(); + return PA; +} - // The ModulePassManager has already taken care of invalidating analyses. - return PreservedAnalyses::all(); +void InlinerPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<InlinerPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + if (OnlyMandatory) + OS << "<only-mandatory>"; +} + +void ModuleInlinerWrapperPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + // Print some info about passes added to the wrapper. This is however + // incomplete as InlineAdvisorAnalysis part isn't included (which also depends + // on Params and Mode). + if (!MPM.isEmpty()) { + MPM.printPipeline(OS, MapClassName2PassName); + OS << ","; + } + OS << "cgscc("; + if (MaxDevirtIterations != 0) + OS << "devirt<" << MaxDevirtIterations << ">("; + PM.printPipeline(OS, MapClassName2PassName); + if (MaxDevirtIterations != 0) + OS << ")"; + OS << ")"; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp index db3b4384ce67..692e445cb7cb 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp @@ -201,21 +201,6 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { AlwaysPreserved.insert(V->getName()); } - // Mark all functions not in the api as internal. 
- IsWasm = Triple(M.getTargetTriple()).isOSBinFormatWasm(); - for (Function &I : M) { - if (!maybeInternalize(I, ComdatMap)) - continue; - Changed = true; - - if (ExternalNode) - // Remove a callgraph edge from the external node to this function. - ExternalNode->removeOneAbstractEdgeTo((*CG)[&I]); - - ++NumFunctions; - LLVM_DEBUG(dbgs() << "Internalizing func " << I.getName() << "\n"); - } - // Never internalize the llvm.used symbol. It is used to implement // attribute((used)). // FIXME: Shouldn't this just filter on llvm.metadata section?? @@ -237,6 +222,21 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { else AlwaysPreserved.insert("__stack_chk_guard"); + // Mark all functions not in the api as internal. + IsWasm = Triple(M.getTargetTriple()).isOSBinFormatWasm(); + for (Function &I : M) { + if (!maybeInternalize(I, ComdatMap)) + continue; + Changed = true; + + if (ExternalNode) + // Remove a callgraph edge from the external node to this function. + ExternalNode->removeOneAbstractEdgeTo((*CG)[&I]); + + ++NumFunctions; + LLVM_DEBUG(dbgs() << "Internalizing func " << I.getName() << "\n"); + } + // Mark all global variables with initializers that are not in the api as // internal as well. for (auto &GV : M.globals()) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp index a497c0390bce..d9a59dd35fde 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp @@ -283,3 +283,13 @@ PreservedAnalyses LoopExtractorPass::run(Module &M, ModuleAnalysisManager &AM) { PA.preserve<LoopAnalysis>(); return PA; } + +void LoopExtractorPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LoopExtractorPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (NumLoops == 1) + OS << "single"; + OS << ">"; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index b492b200c6d5..f78971f0e586 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -342,7 +342,8 @@ private: struct ScopedSaveAliaseesAndUsed { Module &M; SmallVector<GlobalValue *, 4> Used, CompilerUsed; - std::vector<std::pair<GlobalIndirectSymbol *, Function *>> FunctionAliases; + std::vector<std::pair<GlobalAlias *, Function *>> FunctionAliases; + std::vector<std::pair<GlobalIFunc *, Function *>> ResolverIFuncs; ScopedSaveAliaseesAndUsed(Module &M) : M(M) { // The users of this class want to replace all function references except @@ -362,13 +363,16 @@ struct ScopedSaveAliaseesAndUsed { if (GlobalVariable *GV = collectUsedGlobalVariables(M, CompilerUsed, true)) GV->eraseFromParent(); - for (auto &GIS : concat<GlobalIndirectSymbol>(M.aliases(), M.ifuncs())) { + for (auto &GA : M.aliases()) { // FIXME: This should look past all aliases not just interposable ones, // see discussion on D65118. 
- if (auto *F = - dyn_cast<Function>(GIS.getIndirectSymbol()->stripPointerCasts())) - FunctionAliases.push_back({&GIS, F}); + if (auto *F = dyn_cast<Function>(GA.getAliasee()->stripPointerCasts())) + FunctionAliases.push_back({&GA, F}); } + + for (auto &GI : M.ifuncs()) + if (auto *F = dyn_cast<Function>(GI.getResolver()->stripPointerCasts())) + ResolverIFuncs.push_back({&GI, F}); } ~ScopedSaveAliaseesAndUsed() { @@ -376,8 +380,15 @@ struct ScopedSaveAliaseesAndUsed { appendToCompilerUsed(M, CompilerUsed); for (auto P : FunctionAliases) - P.first->setIndirectSymbol( + P.first->setAliasee( ConstantExpr::getBitCast(P.second, P.first->getType())); + + for (auto P : ResolverIFuncs) { + // This does not preserve pointer casts that may have been stripped by the + // constructor, but the resolver's type is different from that of the + // ifunc anyway. + P.first->setResolver(P.second); + } } }; @@ -1550,17 +1561,28 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), ConstantInt::get(IntPtrTy, I)}), F->getType()); - if (Functions[I]->isExported()) { - if (IsJumpTableCanonical) { - ExportSummary->cfiFunctionDefs().insert(std::string(F->getName())); - } else { - GlobalAlias *JtAlias = GlobalAlias::create( - F->getValueType(), 0, GlobalValue::ExternalLinkage, - F->getName() + ".cfi_jt", CombinedGlobalElemPtr, &M); + + const bool IsExported = Functions[I]->isExported(); + if (!IsJumpTableCanonical) { + GlobalValue::LinkageTypes LT = IsExported + ? GlobalValue::ExternalLinkage + : GlobalValue::InternalLinkage; + GlobalAlias *JtAlias = GlobalAlias::create(F->getValueType(), 0, LT, + F->getName() + ".cfi_jt", + CombinedGlobalElemPtr, &M); + if (IsExported) JtAlias->setVisibility(GlobalValue::HiddenVisibility); + else + appendToUsed(M, {JtAlias}); + } + + if (IsExported) { + if (IsJumpTableCanonical) + ExportSummary->cfiFunctionDefs().insert(std::string(F->getName())); + else ExportSummary->cfiFunctionDecls().insert(std::string(F->getName())); - } } + if (!IsJumpTableCanonical) { if (F->hasExternalWeakLinkage()) replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr, @@ -1751,11 +1773,7 @@ static bool isDirectCall(Use& U) { void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsJumpTableCanonical) { SmallSetVector<Constant *, 4> Constants; - auto UI = Old->use_begin(), E = Old->use_end(); - for (; UI != E;) { - Use &U = *UI; - ++UI; - + for (Use &U : llvm::make_early_inc_range(Old->uses())) { // Skip block addresses if (isa<BlockAddress>(U.getUser())) continue; @@ -1792,12 +1810,11 @@ bool LowerTypeTestsModule::lower() { M.getFunction(Intrinsic::getName(Intrinsic::type_test)); if (DropTypeTests && TypeTestFunc) { - for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end(); - UI != UE;) { - auto *CI = cast<CallInst>((*UI++).getUser()); + for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) { + auto *CI = cast<CallInst>(U.getUser()); // Find and erase llvm.assume intrinsics for this llvm.type.test call. - for (auto CIU = CI->use_begin(), CIUE = CI->use_end(); CIU != CIUE;) - if (auto *Assume = dyn_cast<AssumeInst>((*CIU++).getUser())) + for (Use &CIU : llvm::make_early_inc_range(CI->uses())) + if (auto *Assume = dyn_cast<AssumeInst>(CIU.getUser())) Assume->eraseFromParent(); // If the assume was merged with another assume, we might have a use on a // phi (which will feed the assume). 
Simply replace the use on the phi @@ -1835,13 +1852,9 @@ bool LowerTypeTestsModule::lower() { return false; if (ImportSummary) { - if (TypeTestFunc) { - for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end(); - UI != UE;) { - auto *CI = cast<CallInst>((*UI++).getUser()); - importTypeTest(CI); - } - } + if (TypeTestFunc) + for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) + importTypeTest(cast<CallInst>(U.getUser())); if (ICallBranchFunnelFunc && !ICallBranchFunnelFunc->use_empty()) report_fatal_error( @@ -2100,11 +2113,11 @@ bool LowerTypeTestsModule::lower() { auto CI = cast<CallInst>(U.getUser()); std::vector<GlobalTypeMember *> Targets; - if (CI->getNumArgOperands() % 2 != 1) + if (CI->arg_size() % 2 != 1) report_fatal_error("number of arguments should be odd"); GlobalClassesTy::member_iterator CurSet; - for (unsigned I = 1; I != CI->getNumArgOperands(); I += 2) { + for (unsigned I = 1; I != CI->arg_size(); I += 2) { int64_t Offset; auto *Base = dyn_cast<GlobalObject>(GetPointerBaseWithConstantOffset( CI->getOperand(I), Offset, M.getDataLayout())); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp index 9e6dd879ac01..97ef872c5499 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -463,17 +463,15 @@ bool MergeFunctions::runOnModule(Module &M) { // Replace direct callers of Old with New. void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType()); - for (auto UI = Old->use_begin(), UE = Old->use_end(); UI != UE;) { - Use *U = &*UI; - ++UI; - CallBase *CB = dyn_cast<CallBase>(U->getUser()); - if (CB && CB->isCallee(U)) { + for (Use &U : llvm::make_early_inc_range(Old->uses())) { + CallBase *CB = dyn_cast<CallBase>(U.getUser()); + if (CB && CB->isCallee(&U)) { // Do not copy attributes from the called function to the call-site. // Function comparison ensures that the attributes are the same up to // type congruences in byval(), in which case we need to keep the byval // type of the call-site, not the callee function. remove(CB->getFunction()); - U->set(BitcastNew); + U.set(BitcastNew); } } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ModuleInliner.cpp new file mode 100644 index 000000000000..ebf080e87c3b --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ModuleInliner.cpp @@ -0,0 +1,354 @@ +//===- ModuleInliner.cpp - Code related to module inliner -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the mechanics required to implement inlining without +// missing any calls in the module level. It doesn't need any infromation about +// SCC or call graph, which is different from the SCC inliner. The decisions of +// which calls are profitable to inline are implemented elsewhere. 
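// Illustrative sketch (not part of this commit): the llvm::make_early_inc_range
// pattern the LowerTypeTests and MergeFunctions hunks above switch to. The
// iterator is advanced before the current element is handed out, so the loop
// body may erase that element without invalidating the iteration; std::list is
// used here as a stand-in for an LLVM use list.
#include <list>

template <typename ListT, typename Pred>
void eraseMatching(ListT &L, Pred ShouldErase) {
  for (auto It = L.begin(), End = L.end(); It != End;) {
    auto Cur = It++;          // advance first ("early increment")
    if (ShouldErase(*Cur))
      L.erase(Cur);           // safe: It already points past Cur
  }
}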
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/ModuleInliner.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InlineAdvisor.h" +#include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/InlineOrder.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" +#include <cassert> +#include <functional> + +using namespace llvm; + +#define DEBUG_TYPE "module-inline" + +STATISTIC(NumInlined, "Number of functions inlined"); +STATISTIC(NumDeleted, "Number of functions deleted because all callers found"); + +static cl::opt<bool> InlineEnablePriorityOrder( + "module-inline-enable-priority-order", cl::Hidden, cl::init(true), + cl::desc("Enable the priority inline order for the module inliner")); + +/// Return true if the specified inline history ID +/// indicates an inline history that includes the specified function. +static bool inlineHistoryIncludes( + Function *F, int InlineHistoryID, + const SmallVectorImpl<std::pair<Function *, int>> &InlineHistory) { + while (InlineHistoryID != -1) { + assert(unsigned(InlineHistoryID) < InlineHistory.size() && + "Invalid inline history ID"); + if (InlineHistory[InlineHistoryID].first == F) + return true; + InlineHistoryID = InlineHistory[InlineHistoryID].second; + } + return false; +} + +InlineAdvisor &ModuleInlinerPass::getAdvisor(const ModuleAnalysisManager &MAM, + FunctionAnalysisManager &FAM, + Module &M) { + if (OwnedAdvisor) + return *OwnedAdvisor; + + auto *IAA = MAM.getCachedResult<InlineAdvisorAnalysis>(M); + if (!IAA) { + // It should still be possible to run the inliner as a stand-alone module + // pass, for test scenarios. In that case, we default to the + // DefaultInlineAdvisor, which doesn't need to keep state between module + // pass runs. It also uses just the default InlineParams. In this case, we + // need to use the provided FAM, which is valid for the duration of the + // inliner pass, and thus the lifetime of the owned advisor. The one we + // would get from the MAM can be invalidated as a result of the inliner's + // activity. 
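// Illustrative sketch (not part of this commit): how the inline-history chain
// walked by inlineHistoryIncludes above detects (mutual) recursion. Each entry
// records (callee, index of the parent entry), with -1 terminating the chain;
// std::string stands in for a Function pointer.
#include <string>
#include <utility>
#include <vector>

using History = std::vector<std::pair<std::string, int>>;

bool historyIncludes(const std::string &Callee, int Id, const History &H) {
  while (Id != -1) {
    if (H[Id].first == Callee)
      return true; // Callee was already inlined somewhere up this chain.
    Id = H[Id].second;
  }
  return false;
}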
+ OwnedAdvisor = std::make_unique<DefaultInlineAdvisor>(M, FAM, Params); + + return *OwnedAdvisor; + } + assert(IAA->getAdvisor() && + "Expected a present InlineAdvisorAnalysis also have an " + "InlineAdvisor initialized"); + return *IAA->getAdvisor(); +} + +static bool isKnownLibFunction(Function &F, TargetLibraryInfo &TLI) { + LibFunc LF; + + // Either this is a normal library function or a "vectorizable" + // function. Not using the VFDatabase here because this query + // is related only to libraries handled via the TLI. + return TLI.getLibFunc(F, LF) || + TLI.isKnownVectorFunctionInLibrary(F.getName()); +} + +PreservedAnalyses ModuleInlinerPass::run(Module &M, + ModuleAnalysisManager &MAM) { + LLVM_DEBUG(dbgs() << "---- Module Inliner is Running ---- \n"); + + auto &IAA = MAM.getResult<InlineAdvisorAnalysis>(M); + if (!IAA.tryCreate(Params, Mode, {})) { + M.getContext().emitError( + "Could not setup Inlining Advisor for the requested " + "mode and/or options"); + return PreservedAnalyses::all(); + } + + bool Changed = false; + + ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M); + + FunctionAnalysisManager &FAM = + MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { + return FAM.getResult<TargetLibraryAnalysis>(F); + }; + + InlineAdvisor &Advisor = getAdvisor(MAM, FAM, M); + Advisor.onPassEntry(); + + auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(); }); + + // In the module inliner, a priority-based worklist is used for calls across + // the entire Module. With this module inliner, the inline order is not + // limited to bottom-up order. More globally scope inline order is enabled. + // Also, the inline deferral logic become unnecessary in this module inliner. + // It is possible to use other priority heuristics, e.g. profile-based + // heuristic. + // + // TODO: Here is a huge amount duplicate code between the module inliner and + // the SCC inliner, which need some refactoring. + std::unique_ptr<InlineOrder<std::pair<CallBase *, int>>> Calls; + if (InlineEnablePriorityOrder) + Calls = std::make_unique<PriorityInlineOrder<InlineSizePriority>>(); + else + Calls = std::make_unique<DefaultInlineOrder<std::pair<CallBase *, int>>>(); + assert(Calls != nullptr && "Expected an initialized InlineOrder"); + + // Populate the initial list of calls in this module. + for (Function &F : M) { + auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + // We want to generally process call sites top-down in order for + // simplifications stemming from replacing the call with the returned value + // after inlining to be visible to subsequent inlining decisions. + // FIXME: Using instructions sequence is a really bad way to do this. + // Instead we should do an actual RPO walk of the function body. 
+ for (Instruction &I : instructions(F)) + if (auto *CB = dyn_cast<CallBase>(&I)) + if (Function *Callee = CB->getCalledFunction()) { + if (!Callee->isDeclaration()) + Calls->push({CB, -1}); + else if (!isa<IntrinsicInst>(I)) { + using namespace ore; + setInlineRemark(*CB, "unavailable definition"); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) + << NV("Callee", Callee) << " will not be inlined into " + << NV("Caller", CB->getCaller()) + << " because its definition is unavailable" + << setIsVerbose(); + }); + } + } + } + if (Calls->empty()) + return PreservedAnalyses::all(); + + // When inlining a callee produces new call sites, we want to keep track of + // the fact that they were inlined from the callee. This allows us to avoid + // infinite inlining in some obscure cases. To represent this, we use an + // index into the InlineHistory vector. + SmallVector<std::pair<Function *, int>, 16> InlineHistory; + + // Track a set vector of inlined callees so that we can augment the caller + // with all of their edges in the call graph before pruning out the ones that + // got simplified away. + SmallSetVector<Function *, 4> InlinedCallees; + + // Track the dead functions to delete once finished with inlining calls. We + // defer deleting these to make it easier to handle the call graph updates. + SmallVector<Function *, 4> DeadFunctions; + + // Loop forward over all of the calls. + while (!Calls->empty()) { + // We expect the calls to typically be batched with sequences of calls that + // have the same caller, so we first set up some shared infrastructure for + // this caller. We also do any pruning we can at this layer on the caller + // alone. + Function &F = *Calls->front().first->getCaller(); + + LLVM_DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n" + << " Function size: " << F.getInstructionCount() + << "\n"); + + auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + + // Now process as many calls as we have within this caller in the sequence. + // We bail out as soon as the caller has to change so we can + // prepare the context of that new caller. + bool DidInline = false; + while (!Calls->empty() && Calls->front().first->getCaller() == &F) { + auto P = Calls->pop(); + CallBase *CB = P.first; + const int InlineHistoryID = P.second; + Function &Callee = *CB->getCalledFunction(); + + if (InlineHistoryID != -1 && + inlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) { + setInlineRemark(*CB, "recursive"); + continue; + } + + auto Advice = Advisor.getAdvice(*CB, /*OnlyMandatory*/ false); + // Check whether we want to inline this callsite. + if (!Advice->isInliningRecommended()) { + Advice->recordUnattemptedInlining(); + continue; + } + + // Setup the data structure used to plumb customization into the + // `InlineFunction` routine. + InlineFunctionInfo IFI( + /*cg=*/nullptr, GetAssumptionCache, PSI, + &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), + &FAM.getResult<BlockFrequencyAnalysis>(Callee)); + + InlineResult IR = + InlineFunction(*CB, IFI, &FAM.getResult<AAManager>(*CB->getCaller())); + if (!IR.isSuccess()) { + Advice->recordUnsuccessfulInlining(IR); + continue; + } + + DidInline = true; + InlinedCallees.insert(&Callee); + ++NumInlined; + + LLVM_DEBUG(dbgs() << " Size after inlining: " + << F.getInstructionCount() << "\n"); + + // Add any new callsites to defined functions to the worklist. 
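// Illustrative sketch (not part of this commit): when inlining a callee that was
// queued with history id ParentId exposes new call sites, they are queued under
// a fresh id whose parent is ParentId, so the recursion check above can see that
// they were "inlined from" that callee. Call and the string callee names are
// stand-ins; History matches the alias in the previous sketch.
#include <string>
#include <utility>
#include <vector>

using History = std::vector<std::pair<std::string, int>>;
struct Call { std::string Callee; int HistoryId; };

void queueInlinedCallSites(History &H, std::vector<Call> &Worklist,
                           const std::string &InlinedCallee, int ParentId,
                           const std::vector<std::string> &NewCallees) {
  int NewId = static_cast<int>(H.size());
  H.push_back({InlinedCallee, ParentId});
  for (const std::string &Callee : NewCallees)
    Worklist.push_back({Callee, NewId});
}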
+ if (!IFI.InlinedCallSites.empty()) { + int NewHistoryID = InlineHistory.size(); + InlineHistory.push_back({&Callee, InlineHistoryID}); + + for (CallBase *ICB : reverse(IFI.InlinedCallSites)) { + Function *NewCallee = ICB->getCalledFunction(); + if (!NewCallee) { + // Try to promote an indirect (virtual) call without waiting for + // the post-inline cleanup and the next DevirtSCCRepeatedPass + // iteration because the next iteration may not happen and we may + // miss inlining it. + if (tryPromoteCall(*ICB)) + NewCallee = ICB->getCalledFunction(); + } + if (NewCallee) + if (!NewCallee->isDeclaration()) + Calls->push({ICB, NewHistoryID}); + } + } + + // Merge the attributes based on the inlining. + AttributeFuncs::mergeAttributesForInlining(F, Callee); + + // For local functions, check whether this makes the callee trivially + // dead. In that case, we can drop the body of the function eagerly + // which may reduce the number of callers of other functions to one, + // changing inline cost thresholds. + bool CalleeWasDeleted = false; + if (Callee.hasLocalLinkage()) { + // To check this we also need to nuke any dead constant uses (perhaps + // made dead by this operation on other functions). + Callee.removeDeadConstantUsers(); + // if (Callee.use_empty() && !CG.isLibFunction(Callee)) { + if (Callee.use_empty() && !isKnownLibFunction(Callee, GetTLI(Callee))) { + Calls->erase_if([&](const std::pair<CallBase *, int> &Call) { + return Call.first->getCaller() == &Callee; + }); + // Clear the body and queue the function itself for deletion when we + // finish inlining. + // Note that after this point, it is an error to do anything other + // than use the callee's address or delete it. + Callee.dropAllReferences(); + assert(!is_contained(DeadFunctions, &Callee) && + "Cannot put cause a function to become dead twice!"); + DeadFunctions.push_back(&Callee); + CalleeWasDeleted = true; + } + } + if (CalleeWasDeleted) + Advice->recordInliningWithCalleeDeleted(); + else + Advice->recordInlining(); + } + + if (!DidInline) + continue; + Changed = true; + + InlinedCallees.clear(); + } + + // Now that we've finished inlining all of the calls across this module, + // delete all of the trivially dead functions. + // + // Note that this walks a pointer set which has non-deterministic order but + // that is OK as all we do is delete things and add pointers to unordered + // sets. + for (Function *DeadF : DeadFunctions) { + // Clear out any cached analyses. + FAM.clear(*DeadF, DeadF->getName()); + + // And delete the actual function from the module. + // The Advisor may use Function pointers to efficiently index various + // internal maps, e.g. for memoization. Function cleanup passes like + // argument promotion create new functions. It is possible for a new + // function to be allocated at the address of a deleted function. We could + // index using names, but that's inefficient. Alternatively, we let the + // Advisor free the functions when it sees fit. 
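// Illustrative sketch (not part of this commit): the clean-up performed above
// when a local callee loses its last caller — any of its own pending call sites
// are dropped from the worklist and the actual deletion is deferred. Fn and
// PendingCall are simplified stand-ins for Function and the worklist entries.
#include <algorithm>
#include <vector>

struct Fn { bool HasLocalLinkage = false; int Uses = 0; };
struct PendingCall { Fn *Caller; Fn *Callee; };

void retireDeadCallee(Fn &Callee, std::vector<PendingCall> &Worklist,
                      std::vector<Fn *> &DeadFunctions) {
  if (!Callee.HasLocalLinkage || Callee.Uses != 0)
    return;
  // Calls made *by* the dead callee will never be visited again.
  Worklist.erase(std::remove_if(Worklist.begin(), Worklist.end(),
                                [&](const PendingCall &C) {
                                  return C.Caller == &Callee;
                                }),
                 Worklist.end());
  DeadFunctions.push_back(&Callee); // body is dropped now, deleted later
}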
+ DeadF->getBasicBlockList().clear(); + M.getFunctionList().remove(DeadF); + + ++NumDeleted; + } + + if (!Changed) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 75eec25f5807..f342c35fa283 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/EnumeratedArray.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -33,6 +34,8 @@ #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" @@ -41,6 +44,8 @@ #include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "llvm/Transforms/Utils/CodeExtractor.h" +#include <algorithm> + using namespace llvm; using namespace omp; @@ -72,6 +77,46 @@ static cl::opt<bool> HideMemoryTransferLatency( " transfers"), cl::Hidden, cl::init(false)); +static cl::opt<bool> DisableOpenMPOptDeglobalization( + "openmp-opt-disable-deglobalization", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations involving deglobalization."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> DisableOpenMPOptSPMDization( + "openmp-opt-disable-spmdization", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations involving SPMD-ization."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> DisableOpenMPOptFolding( + "openmp-opt-disable-folding", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, + cl::init(false)); + +static cl::opt<bool> DisableOpenMPOptStateMachineRewrite( + "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore, + cl::desc("Disable OpenMP optimizations that replace the state machine."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> PrintModuleAfterOptimizations( + "openmp-opt-print-module", cl::ZeroOrMore, + cl::desc("Print the current module after OpenMP optimizations."), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> AlwaysInlineDeviceFunctions( + "openmp-opt-inline-device", cl::ZeroOrMore, + cl::desc("Inline all applicible functions on the device."), cl::Hidden, + cl::init(false)); + +static cl::opt<bool> + EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore, + cl::desc("Enables more verbose remarks."), cl::Hidden, + cl::init(false)); + +static cl::opt<unsigned> + SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, + cl::desc("Maximal number of attributor iterations."), + cl::init(256)); + STATISTIC(NumOpenMPRuntimeCallsDeduplicated, "Number of OpenMP runtime calls deduplicated"); STATISTIC(NumOpenMPParallelRegionsDeleted, @@ -328,7 +373,7 @@ struct OMPInformationCache : public InformationCache { if (F->arg_size() != RTFArgTypes.size()) return false; - auto RTFTyIt = RTFArgTypes.begin(); + auto *RTFTyIt = RTFArgTypes.begin(); for (Argument &Arg : F->args()) { if (Arg.getType() != *RTFTyIt) return false; @@ -503,7 +548,7 @@ struct KernelInfoState : AbstractState { /// State to track if we are in SPMD-mode, assumed or know, and why we decided /// we cannot be. 
If it is assumed, then RequiresFullRuntime should also be /// false. - BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker; + BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker; /// The __kmpc_target_init call in this kernel, if any. If we find more than /// one we abort as the kernel is malformed. @@ -542,7 +587,9 @@ struct KernelInfoState : AbstractState { /// See AbstractState::indicatePessimisticFixpoint(...) ChangeStatus indicatePessimisticFixpoint() override { IsAtFixpoint = true; + ReachingKernelEntries.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + ReachedKnownParallelRegions.indicatePessimisticFixpoint(); ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); return ChangeStatus::CHANGED; } @@ -550,6 +597,10 @@ struct KernelInfoState : AbstractState { /// See AbstractState::indicateOptimisticFixpoint(...) ChangeStatus indicateOptimisticFixpoint() override { IsAtFixpoint = true; + ReachingKernelEntries.indicateOptimisticFixpoint(); + SPMDCompatibilityTracker.indicateOptimisticFixpoint(); + ReachedKnownParallelRegions.indicateOptimisticFixpoint(); + ReachedUnknownParallelRegions.indicateOptimisticFixpoint(); return ChangeStatus::UNCHANGED; } @@ -569,6 +620,12 @@ struct KernelInfoState : AbstractState { return true; } + /// Returns true if this kernel contains any OpenMP parallel regions. + bool mayContainParallelRegion() { + return !ReachedKnownParallelRegions.empty() || + !ReachedUnknownParallelRegions.empty(); + } + /// Return empty set as the best state of potential values. static KernelInfoState getBestState() { return KernelInfoState(true); } @@ -584,12 +641,14 @@ struct KernelInfoState : AbstractState { // Do not merge two different _init and _deinit call sites. if (KIS.KernelInitCB) { if (KernelInitCB && KernelInitCB != KIS.KernelInitCB) - indicatePessimisticFixpoint(); + llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt " + "assumptions."); KernelInitCB = KIS.KernelInitCB; } if (KIS.KernelDeinitCB) { if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB) - indicatePessimisticFixpoint(); + llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt " + "assumptions."); KernelDeinitCB = KIS.KernelDeinitCB; } SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; @@ -1032,8 +1091,8 @@ private: Args.clear(); Args.push_back(OutlinedFn->getArg(0)); Args.push_back(OutlinedFn->getArg(1)); - for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); - U < E; ++U) + for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E; + ++U) Args.push_back(CI->getArgOperand(U)); CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI); @@ -1041,9 +1100,9 @@ private: NewCI->setDebugLoc(CI->getDebugLoc()); // Forward parameter attributes from the callback to the callee. - for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); - U < E; ++U) - for (const Attribute &A : CI->getAttributes().getParamAttributes(U)) + for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E; + ++U) + for (const Attribute &A : CI->getAttributes().getParamAttrs(U)) NewCI->addParamAttr( U - (CallbackFirstArgOperand - CallbackCalleeOperand), A); @@ -1563,13 +1622,13 @@ private: // TODO: Use dominance to find a good position instead. 
auto CanBeMoved = [this](CallBase &CB) { - unsigned NumArgs = CB.getNumArgOperands(); + unsigned NumArgs = CB.arg_size(); if (NumArgs == 0) return true; if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr) return false; - for (unsigned u = 1; u < NumArgs; ++u) - if (isa<Instruction>(CB.getArgOperand(u))) + for (unsigned U = 1; U < NumArgs; ++U) + if (isa<Instruction>(CB.getArgOperand(U))) return false; return true; }; @@ -1612,7 +1671,7 @@ private: // valid at the new location. For now we just pick a global one, either // existing and used by one of the calls, or created from scratch. if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) { - if (CI->getNumArgOperands() > 0 && + if (!CI->arg_empty() && CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) { Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F, /* GlobalOnly */ true); @@ -1695,8 +1754,8 @@ private: // Transitively search for more arguments by looking at the users of the // ones we know already. During the search the GTIdArgs vector is extended // so we cannot cache the size nor can we use a range based for. - for (unsigned u = 0; u < GTIdArgs.size(); ++u) - AddUserArgs(*GTIdArgs[u]); + for (unsigned U = 0; U < GTIdArgs.size(); ++U) + AddUserArgs(*GTIdArgs[U]); } /// Kernel (=GPU) optimizations and utility functions @@ -1822,6 +1881,10 @@ private: OMPRTL___kmpc_kernel_end_parallel); ExternalizationRAII BarrierSPMD(OMPInfoCache, OMPRTL___kmpc_barrier_simple_spmd); + ExternalizationRAII BarrierGeneric(OMPInfoCache, + OMPRTL___kmpc_barrier_simple_generic); + ExternalizationRAII ThreadId(OMPInfoCache, + OMPRTL___kmpc_get_hardware_thread_id_in_block); registerAAs(IsModulePass); @@ -1918,6 +1981,10 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() { if (!KernelParallelRFI) return Changed; + // If we have disabled state machine changes, exit + if (DisableOpenMPOptStateMachineRewrite) + return Changed; + for (Function *F : SCC) { // Check if the function is a use in a __kmpc_parallel_51 call at @@ -2509,9 +2576,8 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; - // Check if the edge into the successor block compares the __kmpc_target_init - // result with -1. If we are in non-SPMD-mode that signals only the main - // thread will execute the edge. + // Check if the edge into the successor block contains a condition that only + // lets the main thread execute it. auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { if (!Edge || !Edge->isConditional()) return false; @@ -2526,16 +2592,27 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { if (!C) return false; - // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) + // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) if (C->isAllOnesValue()) { auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0)); CB = CB ? 
OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr; if (!CB) return false; - const int InitIsSPMDArgNo = 1; - auto *IsSPMDModeCI = - dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo)); - return IsSPMDModeCI && IsSPMDModeCI->isZero(); + const int InitModeArgNo = 1; + auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo)); + return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC); + } + + if (C->isZero()) { + // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x() + if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) + if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x) + return true; + + // Match: 0 == llvm.amdgcn.workitem.id.x() + if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) + if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x) + return true; } return false; @@ -2544,15 +2621,14 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { // Merge all the predecessor states into the current basic block. A basic // block is executed by a single thread if all of its predecessors are. auto MergePredecessorStates = [&](BasicBlock *BB) { - if (pred_begin(BB) == pred_end(BB)) + if (pred_empty(BB)) return SingleThreadedBBs.contains(BB); bool IsInitialThread = true; - for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); - PredBB != PredEndBB; ++PredBB) { - if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()), + for (BasicBlock *PredBB : predecessors(BB)) { + if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()), BB)) - IsInitialThread &= SingleThreadedBBs.contains(*PredBB); + IsInitialThread &= SingleThreadedBBs.contains(PredBB); } return IsInitialThread; @@ -2684,9 +2760,8 @@ struct AAHeapToSharedFunction : public AAHeapToShared { ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0)); - LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in " - << CB->getCaller()->getName() << " with " - << AllocSize->getZExtValue() + LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB + << " with " << AllocSize->getZExtValue() << " bytes of shared memory\n"); // Create a new shared memory buffer of the same size as the allocation @@ -2735,7 +2810,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { const auto &ED = A.getAAFor<AAExecutionDomain>( *this, IRPosition::function(*F), DepClassTy::REQUIRED); if (CallBase *CB = dyn_cast<CallBase>(U)) - if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) || + if (!isa<ConstantInt>(CB->getArgOperand(0)) || !ED.isExecutedByInitialThreadOnly(*CB)) MallocCalls.erase(CB); } @@ -2770,9 +2845,17 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]" : "") + std::string(" #PRs: ") + - std::to_string(ReachedKnownParallelRegions.size()) + + (ReachedKnownParallelRegions.isValidState() + ? std::to_string(ReachedKnownParallelRegions.size()) + : "<invalid>") + ", #Unknown PRs: " + - std::to_string(ReachedUnknownParallelRegions.size()); + (ReachedUnknownParallelRegions.isValidState() + ? std::to_string(ReachedUnknownParallelRegions.size()) + : "<invalid>") + + ", #Reaching Kernels: " + + (ReachingKernelEntries.isValidState() + ? std::to_string(ReachingKernelEntries.size()) + : "<invalid>"); } /// Create an abstract attribute biew for the position \p IRP. 
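// Illustrative sketch (not part of this commit): the merge rule used by
// AAExecutionDomain above — a block runs only on the initial thread if, for every
// predecessor, either the incoming edge is guarded by a check such as
// `0 == llvm.nvvm.read.ptx.sreg.tid.x()` or the predecessor itself is already
// known to be single-threaded. Block is a stand-in for llvm::BasicBlock.
#include <set>
#include <vector>

struct Block {
  std::vector<Block *> Preds;
  // True if the branch from the corresponding predecessor into this block is
  // taken only by the initial (tid == 0) thread.
  std::vector<bool> EdgeGuardedByTidCheck;
};

bool executedByInitialThreadOnly(const Block &BB,
                                 const std::set<const Block *> &SingleThreaded) {
  if (BB.Preds.empty())
    return SingleThreaded.count(&BB) != 0; // e.g. the kernel entry block
  for (size_t I = 0; I < BB.Preds.size(); ++I)
    if (!BB.EdgeGuardedByTidCheck[I] && !SingleThreaded.count(BB.Preds[I]))
      return false;
  return true;
}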
@@ -2798,6 +2881,12 @@ struct AAKernelInfoFunction : AAKernelInfo { AAKernelInfoFunction(const IRPosition &IRP, Attributor &A) : AAKernelInfo(IRP, A) {} + SmallPtrSet<Instruction *, 4> GuardedInstructions; + + SmallPtrSetImpl<Instruction *> &getGuardedInstructions() { + return GuardedInstructions; + } + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { // This is a high-level transform that might change the constant arguments @@ -2844,8 +2933,11 @@ struct AAKernelInfoFunction : AAKernelInfo { }, Fn); - assert((KernelInitCB && KernelDeinitCB) && - "Kernel without __kmpc_target_init or __kmpc_target_deinit!"); + // Ignore kernels without initializers such as global constructors. + if (!KernelInitCB || !KernelDeinitCB) { + indicateOptimisticFixpoint(); + return; + } // For kernels we might need to initialize/finalize the IsSPMD state and // we need to register a simplification callback so that the Attributor @@ -2860,7 +2952,10 @@ struct AAKernelInfoFunction : AAKernelInfo { // state. As long as we are not in an invalid state, we will create a // custom state machine so the value should be a `i1 false`. If we are // in an invalid state, we won't change the value that is in the IR. - if (!isValidState()) + if (!ReachedKnownParallelRegions.isValidState()) + return nullptr; + // If we have disabled state machine rewrites, don't make a custom one. + if (DisableOpenMPOptStateMachineRewrite) return nullptr; if (AA) A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); @@ -2870,7 +2965,7 @@ struct AAKernelInfoFunction : AAKernelInfo { return FalseVal; }; - Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB = + Attributor::SimplifictionCallbackTy ModeSimplifyCB = [&](const IRPosition &IRP, const AbstractAttribute *AA, bool &UsedAssumedInformation) -> Optional<Value *> { // IRP represents the "SPMDCompatibilityTracker" argument of an @@ -2886,8 +2981,10 @@ struct AAKernelInfoFunction : AAKernelInfo { } else { UsedAssumedInformation = false; } - auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), - SPMDCompatibilityTracker.isAssumed()); + auto *Val = ConstantInt::getSigned( + IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()), + SPMDCompatibilityTracker.isAssumed() ? 
OMP_TGT_EXEC_MODE_SPMD + : OMP_TGT_EXEC_MODE_GENERIC); return Val; }; @@ -2912,8 +3009,8 @@ struct AAKernelInfoFunction : AAKernelInfo { return Val; }; - constexpr const int InitIsSPMDArgNo = 1; - constexpr const int DeinitIsSPMDArgNo = 1; + constexpr const int InitModeArgNo = 1; + constexpr const int DeinitModeArgNo = 1; constexpr const int InitUseStateMachineArgNo = 2; constexpr const int InitRequiresFullRuntimeArgNo = 3; constexpr const int DeinitRequiresFullRuntimeArgNo = 2; @@ -2921,11 +3018,11 @@ struct AAKernelInfoFunction : AAKernelInfo { IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), StateMachineSimplifyCB); A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo), - IsSPMDModeSimplifyCB); + IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo), + ModeSimplifyCB); A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo), - IsSPMDModeSimplifyCB); + IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo), + ModeSimplifyCB); A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelInitCB, InitRequiresFullRuntimeArgNo), @@ -2936,10 +3033,25 @@ struct AAKernelInfoFunction : AAKernelInfo { IsGenericModeSimplifyCB); // Check if we know we are in SPMD-mode already. - ConstantInt *IsSPMDArg = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); - if (IsSPMDArg && !IsSPMDArg->isZero()) + ConstantInt *ModeArg = + dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); + if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) SPMDCompatibilityTracker.indicateOptimisticFixpoint(); + // This is a generic region but SPMDization is disabled so stop tracking. + else if (DisableOpenMPOptSPMDization) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + } + + /// Sanitize the string \p S such that it is a suitable global symbol name. + static std::string sanitizeForGlobalName(std::string S) { + std::replace_if( + S.begin(), S.end(), + [](const char C) { + return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') || + (C >= '0' && C <= '9') || C == '_'); + }, + '.'); + return S; } /// Modify the IR based on the KernelInfoState as the fixpoint iteration is @@ -2950,19 +3062,16 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!KernelInitCB || !KernelDeinitCB) return ChangeStatus::UNCHANGED; - // Known SPMD-mode kernels need no manifest changes. - if (SPMDCompatibilityTracker.isKnown()) - return ChangeStatus::UNCHANGED; - // If we can we change the execution mode to SPMD-mode otherwise we build a // custom state machine. - if (!changeToSPMDMode(A)) - buildCustomStateMachine(A); + ChangeStatus Changed = ChangeStatus::UNCHANGED; + if (!changeToSPMDMode(A, Changed)) + return buildCustomStateMachine(A); - return ChangeStatus::CHANGED; + return Changed; } - bool changeToSPMDMode(Attributor &A) { + bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); if (!SPMDCompatibilityTracker.isAssumed()) { @@ -2994,38 +3103,259 @@ struct AAKernelInfoFunction : AAKernelInfo { return false; } - // Adjust the global exec mode flag that tells the runtime what mode this - // kernel is executed in. + // Check if the kernel is already in SPMD mode, if so, return success. 
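// Illustrative sketch (not part of this commit): the __kmpc_target_init/_deinit
// calls now take an i8 execution-mode bit mask instead of the old IsSPMD
// boolean, so the checks above test individual bits. The enumerator values
// below are assumptions for illustration only; the real flags live in the
// OpenMP runtime headers.
#include <cstdint>

enum ExecModeFlags : int8_t {
  ExecModeGeneric = 1 << 0, // assumed value of OMP_TGT_EXEC_MODE_GENERIC
  ExecModeSPMD = 1 << 1,    // assumed value of OMP_TGT_EXEC_MODE_SPMD
};

bool isSPMDMode(int8_t Mode) { return (Mode & ExecModeSPMD) != 0; }
bool isGenericMode(int8_t Mode) { return (Mode & ExecModeGeneric) != 0; }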
Function *Kernel = getAnchorScope(); GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( (Kernel->getName() + "_exec_mode").str()); assert(ExecMode && "Kernel without exec mode?"); - assert(ExecMode->getInitializer() && - ExecMode->getInitializer()->isOneValue() && - "Initially non-SPMD kernel has SPMD exec mode!"); + assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!"); // Set the global exec mode flag to indicate SPMD-Generic mode. - constexpr int SPMDGeneric = 2; - if (!ExecMode->getInitializer()->isZeroValue()) - ExecMode->setInitializer( - ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric)); + assert(isa<ConstantInt>(ExecMode->getInitializer()) && + "ExecMode is not an integer!"); + const int8_t ExecModeVal = + cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue(); + if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC) + return true; + + // We will now unconditionally modify the IR, indicate a change. + Changed = ChangeStatus::CHANGED; + + auto CreateGuardedRegion = [&](Instruction *RegionStartI, + Instruction *RegionEndI) { + LoopInfo *LI = nullptr; + DominatorTree *DT = nullptr; + MemorySSAUpdater *MSU = nullptr; + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + + BasicBlock *ParentBB = RegionStartI->getParent(); + Function *Fn = ParentBB->getParent(); + Module &M = *Fn->getParent(); + + // Create all the blocks and logic. + // ParentBB: + // goto RegionCheckTidBB + // RegionCheckTidBB: + // Tid = __kmpc_hardware_thread_id() + // if (Tid != 0) + // goto RegionBarrierBB + // RegionStartBB: + // <execute instructions guarded> + // goto RegionEndBB + // RegionEndBB: + // <store escaping values to shared mem> + // goto RegionBarrierBB + // RegionBarrierBB: + // __kmpc_simple_barrier_spmd() + // // second barrier is omitted if lacking escaping values. + // <load escaping values from shared mem> + // __kmpc_simple_barrier_spmd() + // goto RegionExitBB + // RegionExitBB: + // <execute rest of instructions> + + BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(), + DT, LI, MSU, "region.guarded.end"); + BasicBlock *RegionBarrierBB = + SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI, + MSU, "region.barrier"); + BasicBlock *RegionExitBB = + SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(), + DT, LI, MSU, "region.exit"); + BasicBlock *RegionStartBB = + SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded"); + + assert(ParentBB->getUniqueSuccessor() == RegionStartBB && + "Expected a different CFG"); + + BasicBlock *RegionCheckTidBB = SplitBlock( + ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid"); + + // Register basic blocks with the Attributor. + A.registerManifestAddedBasicBlock(*RegionEndBB); + A.registerManifestAddedBasicBlock(*RegionBarrierBB); + A.registerManifestAddedBasicBlock(*RegionExitBB); + A.registerManifestAddedBasicBlock(*RegionStartBB); + A.registerManifestAddedBasicBlock(*RegionCheckTidBB); + + bool HasBroadcastValues = false; + // Find escaping outputs from the guarded region to outside users and + // broadcast their values to them. 
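The block comment and SplitBlock scaffolding above carve a "guarded region" out of a basic block so that, after SPMDization, code that must run on a single thread still does, with escaping values broadcast through shared memory and ordered by two barriers. A hedged, host-side C++20 model of that shape (std::barrier and the variable names stand in for the __kmpc_* runtime pieces; this is an illustration of the control flow, not the device API):

    #include <barrier>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
      constexpr int NumThreads = 4;
      int SharedMem = 0;             // stands in for the shared-memory broadcast slot
      std::barrier Sync(NumThreads); // stands in for __kmpc_barrier_simple_spmd

      auto Worker = [&](int Tid) {
        if (Tid == 0)                // RegionCheckTidBB: only the main thread...
          SharedMem = 42;            // ...runs the guarded region and stores the result
        Sync.arrive_and_wait();      // first barrier: the store is visible to everyone
        int V = SharedMem;           // broadcast load that replaces the outside uses
        Sync.arrive_and_wait();      // second barrier: all workers have read the value
        std::printf("thread %d sees %d\n", Tid, V);
      };

      std::vector<std::thread> Threads;
      for (int T = 0; T < NumThreads; ++T)
        Threads.emplace_back(Worker, T);
      for (std::thread &T : Threads)
        T.join();
      return 0;
    }

As in the emitted IR, the second barrier is only needed when the region has escaping ("broadcast") values; a region with no outside users gets the tid check and a single barrier.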
+ for (Instruction &I : *RegionStartBB) { + SmallPtrSet<Instruction *, 4> OutsideUsers; + for (User *Usr : I.users()) { + Instruction &UsrI = *cast<Instruction>(Usr); + if (UsrI.getParent() != RegionStartBB) + OutsideUsers.insert(&UsrI); + } + + if (OutsideUsers.empty()) + continue; + + HasBroadcastValues = true; + + // Emit a global variable in shared memory to store the broadcasted + // value. + auto *SharedMem = new GlobalVariable( + M, I.getType(), /* IsConstant */ false, + GlobalValue::InternalLinkage, UndefValue::get(I.getType()), + sanitizeForGlobalName( + (I.getName() + ".guarded.output.alloc").str()), + nullptr, GlobalValue::NotThreadLocal, + static_cast<unsigned>(AddressSpace::Shared)); + + // Emit a store instruction to update the value. + new StoreInst(&I, SharedMem, RegionEndBB->getTerminator()); + + LoadInst *LoadI = new LoadInst(I.getType(), SharedMem, + I.getName() + ".guarded.output.load", + RegionBarrierBB->getTerminator()); + + // Emit a load instruction and replace uses of the output value. + for (Instruction *UsrI : OutsideUsers) + UsrI->replaceUsesOfWith(&I, LoadI); + } + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + + // Go to tid check BB in ParentBB. + const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); + ParentBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(ParentBB, ParentBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(Loc); + auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc); + Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr); + BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL); + + // Add check for Tid in RegionCheckTidBB + RegionCheckTidBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription LocRegionCheckTid( + InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid); + FunctionCallee HardwareTidFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_hardware_thread_id_in_block); + Value *Tid = + OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {}); + Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid); + OMPInfoCache.OMPBuilder.Builder + .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB) + ->setDebugLoc(DL); + + // First barrier for synchronization, ensures main thread has updated + // values. + FunctionCallee BarrierFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_barrier_simple_spmd); + OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy( + RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt())); + OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid}) + ->setDebugLoc(DL); + + // Second barrier ensures workers have read broadcast values. 
+ if (HasBroadcastValues) + CallInst::Create(BarrierFn, {Ident, Tid}, "", + RegionBarrierBB->getTerminator()) + ->setDebugLoc(DL); + }; + + auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; + SmallPtrSet<BasicBlock *, 8> Visited; + for (Instruction *GuardedI : SPMDCompatibilityTracker) { + BasicBlock *BB = GuardedI->getParent(); + if (!Visited.insert(BB).second) + continue; + + SmallVector<std::pair<Instruction *, Instruction *>> Reorders; + Instruction *LastEffect = nullptr; + BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend(); + while (++IP != IPEnd) { + if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory()) + continue; + Instruction *I = &*IP; + if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI)) + continue; + if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) { + LastEffect = nullptr; + continue; + } + if (LastEffect) + Reorders.push_back({I, LastEffect}); + LastEffect = &*IP; + } + for (auto &Reorder : Reorders) + Reorder.first->moveBefore(Reorder.second); + } + + SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions; + + for (Instruction *GuardedI : SPMDCompatibilityTracker) { + BasicBlock *BB = GuardedI->getParent(); + auto *CalleeAA = A.lookupAAFor<AAKernelInfo>( + IRPosition::function(*GuardedI->getFunction()), nullptr, + DepClassTy::NONE); + assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo"); + auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA); + // Continue if instruction is already guarded. + if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI)) + continue; + + Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr; + for (Instruction &I : *BB) { + // If instruction I needs to be guarded update the guarded region + // bounds. + if (SPMDCompatibilityTracker.contains(&I)) { + CalleeAAFunction.getGuardedInstructions().insert(&I); + if (GuardedRegionStart) + GuardedRegionEnd = &I; + else + GuardedRegionStart = GuardedRegionEnd = &I; + + continue; + } + + // Instruction I does not need guarding, store + // any region found and reset bounds. + if (GuardedRegionStart) { + GuardedRegions.push_back( + std::make_pair(GuardedRegionStart, GuardedRegionEnd)); + GuardedRegionStart = nullptr; + GuardedRegionEnd = nullptr; + } + } + } + + for (auto &GR : GuardedRegions) + CreateGuardedRegion(GR.first, GR.second); + + // Adjust the global exec mode flag that tells the runtime what mode this + // kernel is executed in. + assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC && + "Initially non-SPMD kernel has SPMD exec mode!"); + ExecMode->setInitializer( + ConstantInt::get(ExecMode->getInitializer()->getType(), + ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. 
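The loop above coalesces consecutive instructions flagged by SPMDCompatibilityTracker into (GuardedRegionStart, GuardedRegionEnd) pairs before CreateGuardedRegion is invoked on each pair. The same grouping step, reduced to a self-contained sketch over a vector of flags:

    #include <cassert>
    #include <utility>
    #include <vector>

    // Coalesce maximal runs of "needs guarding" elements into [start, end]
    // index pairs, mirroring how GuardedRegionStart/GuardedRegionEnd are
    // maintained in the hunk above.
    static std::vector<std::pair<int, int>>
    collectGuardedRegions(const std::vector<bool> &NeedsGuard) {
      std::vector<std::pair<int, int>> Regions;
      int Start = -1, End = -1;
      for (int I = 0, E = (int)NeedsGuard.size(); I < E; ++I) {
        if (NeedsGuard[I]) {
          if (Start < 0)
            Start = I;            // open a new region
          End = I;                // extend the current region
          continue;
        }
        if (Start >= 0) {         // an unguarded element closes the region
          Regions.push_back({Start, End});
          Start = End = -1;
        }
      }
      // The IR version can rely on the block terminator never being guarded,
      // so a region is always closed before the block ends; this final close
      // only exists to make the standalone sketch total.
      if (Start >= 0)
        Regions.push_back({Start, End});
      return Regions;
    }

    int main() {
      auto R = collectGuardedRegions({false, true, true, false, true});
      assert(R.size() == 2);
      assert(R[0] == std::make_pair(1, 2) && R[1] == std::make_pair(4, 4));
      return 0;
    }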
- const int InitIsSPMDArgNo = 1; - const int DeinitIsSPMDArgNo = 1; + const int InitModeArgNo = 1; + const int DeinitModeArgNo = 1; const int InitUseStateMachineArgNo = 2; const int InitRequiresFullRuntimeArgNo = 3; const int DeinitRequiresFullRuntimeArgNo = 2; auto &Ctx = getAnchorValue().getContext(); - A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo), - *ConstantInt::getBool(Ctx, 1)); + A.changeUseAfterManifest( + KernelInitCB->getArgOperandUse(InitModeArgNo), + *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), + OMP_TGT_EXEC_MODE_SPMD)); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *ConstantInt::getBool(Ctx, 0)); A.changeUseAfterManifest( - KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo), - *ConstantInt::getBool(Ctx, 1)); + KernelDeinitCB->getArgOperandUse(DeinitModeArgNo), + *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), + OMP_TGT_EXEC_MODE_SPMD)); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo), *ConstantInt::getBool(Ctx, 0)); @@ -3043,10 +3373,15 @@ struct AAKernelInfoFunction : AAKernelInfo { }; ChangeStatus buildCustomStateMachine(Attributor &A) { - assert(ReachedKnownParallelRegions.isValidState() && - "Custom state machine with invalid parallel region states?"); + // If we have disabled state machine rewrites, don't make a custom one + if (DisableOpenMPOptStateMachineRewrite) + return ChangeStatus::UNCHANGED; + + // Don't rewrite the state machine if we are not in a valid state. + if (!ReachedKnownParallelRegions.isValidState()) + return ChangeStatus::UNCHANGED; - const int InitIsSPMDArgNo = 1; + const int InitModeArgNo = 1; const int InitUseStateMachineArgNo = 2; // Check if the current configuration is non-SPMD and generic state machine. @@ -3055,14 +3390,14 @@ struct AAKernelInfoFunction : AAKernelInfo { // we give up. ConstantInt *UseStateMachine = dyn_cast<ConstantInt>( KernelInitCB->getArgOperand(InitUseStateMachineArgNo)); - ConstantInt *IsSPMD = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); + ConstantInt *Mode = + dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); // If we are stuck with generic mode, try to create a custom device (=GPU) // state machine which is specialized for the parallel regions that are // reachable by the kernel. - if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD || - !IsSPMD->isZero()) + if (!UseStateMachine || UseStateMachine->isZero() || !Mode || + (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) return ChangeStatus::UNCHANGED; // If not SPMD mode, indicate we use a custom state machine now. @@ -3075,8 +3410,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // happen if there simply are no parallel regions. In the resulting kernel // all worker threads will simply exit right away, leaving the main thread // to do the work alone. - if (ReachedKnownParallelRegions.empty() && - ReachedUnknownParallelRegions.empty()) { + if (!mayContainParallelRegion()) { ++NumOpenMPTargetRegionKernelsWithoutStateMachine; auto Remark = [&](OptimizationRemark OR) { @@ -3122,9 +3456,14 @@ struct AAKernelInfoFunction : AAKernelInfo { // Create all the blocks: // // InitCB = __kmpc_target_init(...) 
- // bool IsWorker = InitCB >= 0; + // BlockHwSize = + // __kmpc_get_hardware_num_threads_in_block(); + // WarpSize = __kmpc_get_warp_size(); + // BlockSize = BlockHwSize - WarpSize; + // if (InitCB >= BlockSize) return; + // IsWorkerCheckBB: bool IsWorker = InitCB >= 0; // if (IsWorker) { - // SMBeginBB: __kmpc_barrier_simple_spmd(...); + // SMBeginBB: __kmpc_barrier_simple_generic(...); // void *WorkFn; // bool Active = __kmpc_kernel_parallel(&WorkFn); // if (!WorkFn) return; @@ -3138,7 +3477,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // ((WorkFnTy*)WorkFn)(...); // SMEndParallelBB: __kmpc_kernel_end_parallel(...); // } - // SMDoneBB: __kmpc_barrier_simple_spmd(...); + // SMDoneBB: __kmpc_barrier_simple_generic(...); // goto SMBeginBB; // } // UserCodeEntryBB: // user code @@ -3150,6 +3489,8 @@ struct AAKernelInfoFunction : AAKernelInfo { BasicBlock *InitBB = KernelInitCB->getParent(); BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock( KernelInitCB->getNextNode(), "thread.user_code.check"); + BasicBlock *IsWorkerCheckBB = + BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB); BasicBlock *StateMachineBeginBB = BasicBlock::Create( Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB); BasicBlock *StateMachineFinishedBB = BasicBlock::Create( @@ -3166,6 +3507,7 @@ struct AAKernelInfoFunction : AAKernelInfo { Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB); A.registerManifestAddedBasicBlock(*InitBB); A.registerManifestAddedBasicBlock(*UserCodeEntryBB); + A.registerManifestAddedBasicBlock(*IsWorkerCheckBB); A.registerManifestAddedBasicBlock(*StateMachineBeginBB); A.registerManifestAddedBasicBlock(*StateMachineFinishedBB); A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB); @@ -3175,16 +3517,38 @@ struct AAKernelInfoFunction : AAKernelInfo { const DebugLoc &DLoc = KernelInitCB->getDebugLoc(); ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc); - InitBB->getTerminator()->eraseFromParent(); + + Module &M = *Kernel->getParent(); + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + FunctionCallee BlockHwSizeFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_hardware_num_threads_in_block); + FunctionCallee WarpSizeFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_warp_size); + Instruction *BlockHwSize = + CallInst::Create(BlockHwSizeFn, "block.hw_size", InitBB); + BlockHwSize->setDebugLoc(DLoc); + Instruction *WarpSize = CallInst::Create(WarpSizeFn, "warp.size", InitBB); + WarpSize->setDebugLoc(DLoc); + Instruction *BlockSize = + BinaryOperator::CreateSub(BlockHwSize, WarpSize, "block.size", InitBB); + BlockSize->setDebugLoc(DLoc); + Instruction *IsMainOrWorker = + ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, + BlockSize, "thread.is_main_or_worker", InitBB); + IsMainOrWorker->setDebugLoc(DLoc); + BranchInst::Create(IsWorkerCheckBB, StateMachineFinishedBB, IsMainOrWorker, + InitBB); + Instruction *IsWorker = ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB, ConstantInt::get(KernelInitCB->getType(), -1), - "thread.is_worker", InitBB); + "thread.is_worker", IsWorkerCheckBB); IsWorker->setDebugLoc(DLoc); - BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB); - - Module &M = *Kernel->getParent(); + BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, + IsWorkerCheckBB); // Create local storage for the work function pointer. 
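Putting the state-machine hunks above together, the following is a hedged, source-level rendering of the worker loop that gets emitted: surplus threads of the last warp exit immediately, the main thread falls through to the user code, and workers spin on a barrier, fetch a work function, and dispatch it through an if-cascade over the known parallel regions. All helper names and the work-function signature are placeholders, and the single-threaded stubs only model the control flow, not the real __kmpc_* runtime:

    #include <cstdint>
    #include <cstdio>

    using WorkFnTy = void (*)(uint16_t, uint32_t); // assumed wrapper signature

    static int JobsHandedOut = 0;
    static void parallelRegion0(uint16_t, uint32_t) { std::puts("parallel region 0"); }

    // Stubbed runtime: hand out one known job, then signal shutdown with nullptr.
    static bool kernelParallel(void **WorkFn) {
      *WorkFn = (JobsHandedOut++ == 0) ? (void *)&parallelRegion0 : nullptr;
      return *WorkFn != nullptr;               // "Active" for this worker
    }
    static void kernelEndParallel() {}
    static void barrier() {}                   // __kmpc_barrier_simple_generic stand-in

    static void workerStateMachine(uint32_t Tid) {
      for (;;) {
        barrier();                             // SMBeginBB: wait for work
        void *WorkFn = nullptr;
        bool Active = kernelParallel(&WorkFn);
        if (!WorkFn)                           // SMFinishedBB: kernel is shutting down
          return;
        if (Active) {
          // SMIfCascade*: compare against the statically known parallel regions,
          // fall back to an indirect call for unknown ones (mirrors the bitcast
          // plus icmp chain emitted above).
          if (WorkFn == (void *)&parallelRegion0)
            parallelRegion0(0, Tid);
          else
            ((WorkFnTy)WorkFn)(0, Tid);
          kernelEndParallel();                 // SMEndParallelBB
        }
        barrier();                             // SMDoneBB, then back to SMBeginBB
      }
    }

    int main() {
      // Threads with id >= BlockHwSize - WarpSize have already returned, and
      // the main thread (id == -1 from the init call) runs the user code;
      // every other thread enters the worker loop:
      workerStateMachine(/*Tid=*/1);
      return 0;
    }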
const DataLayout &DL = M.getDataLayout(); @@ -3194,7 +3558,6 @@ struct AAKernelInfoFunction : AAKernelInfo { "worker.work_fn.addr", &Kernel->getEntryBlock().front()); WorkFnAI->setDebugLoc(DLoc); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); OMPInfoCache.OMPBuilder.updateToLocation( OpenMPIRBuilder::LocationDescription( IRBuilder<>::InsertPoint(StateMachineBeginBB, @@ -3206,7 +3569,7 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionCallee BarrierFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( - M, OMPRTL___kmpc_barrier_simple_spmd); + M, OMPRTL___kmpc_barrier_simple_generic); CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB) ->setDebugLoc(DLoc); @@ -3258,8 +3621,8 @@ struct AAKernelInfoFunction : AAKernelInfo { // Now that we have most of the CFG skeleton it is time for the if-cascade // that checks the function pointer we got from the runtime against the // parallel regions we expect, if there are any. - for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) { - auto *ParallelRegion = ReachedKnownParallelRegions[i]; + for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) { + auto *ParallelRegion = ReachedKnownParallelRegions[I]; BasicBlock *PRExecuteBB = BasicBlock::Create( Ctx, "worker_state_machine.parallel_region.execute", Kernel, StateMachineEndParallelBB); @@ -3275,7 +3638,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // Check if we need to compare the pointer at all or if we can just // call the parallel region function. Value *IsPR; - if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) { + if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) { Instruction *CmpI = ICmpInst::Create( ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion, "worker.check_parallel_region", StateMachineIfCascadeCurrentBB); @@ -3339,8 +3702,21 @@ struct AAKernelInfoFunction : AAKernelInfo { if (llvm::all_of(Objects, [](const Value *Obj) { return isa<AllocaInst>(Obj); })) return true; + // Check for AAHeapToStack moved objects which must not be guarded. + auto &HS = A.getAAFor<AAHeapToStack>( + *this, IRPosition::function(*I.getFunction()), + DepClassTy::OPTIONAL); + if (llvm::all_of(Objects, [&HS](const Value *Obj) { + auto *CB = dyn_cast<CallBase>(Obj); + if (!CB) + return false; + return HS.isAssumedHeapToStack(*CB); + })) { + return true; + } } - // For now we give up on everything but stores. + + // Insert instruction that needs guarding. SPMDCompatibilityTracker.insert(&I); return true; }; @@ -3354,9 +3730,13 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!IsKernelEntry) { updateReachingKernelEntries(A); updateParallelLevels(A); + + if (!ParallelLevels.isValidState()) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); } // Callback to check a call instruction. 
+ bool AllParallelRegionStatesWereFixed = true; bool AllSPMDStatesWereFixed = true; auto CheckCallInst = [&](Instruction &I) { auto &CB = cast<CallBase>(I); @@ -3364,13 +3744,37 @@ struct AAKernelInfoFunction : AAKernelInfo { *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); getState() ^= CBAA.getState(); AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint(); + AllParallelRegionStatesWereFixed &= + CBAA.ReachedKnownParallelRegions.isAtFixpoint(); + AllParallelRegionStatesWereFixed &= + CBAA.ReachedUnknownParallelRegions.isAtFixpoint(); return true; }; bool UsedAssumedInformationInCheckCallInst = false; if (!A.checkForAllCallLikeInstructions( - CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) + CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) { + LLVM_DEBUG(dbgs() << TAG + << "Failed to visit all call-like instructions!\n";); return indicatePessimisticFixpoint(); + } + + // If we haven't used any assumed information for the reached parallel + // region states we can fix it. + if (!UsedAssumedInformationInCheckCallInst && + AllParallelRegionStatesWereFixed) { + ReachedKnownParallelRegions.indicateOptimisticFixpoint(); + ReachedUnknownParallelRegions.indicateOptimisticFixpoint(); + } + + // If we are sure there are no parallel regions in the kernel we do not + // want SPMD mode. + if (IsKernelEntry && ReachedUnknownParallelRegions.isAtFixpoint() && + ReachedKnownParallelRegions.isAtFixpoint() && + ReachedUnknownParallelRegions.isValidState() && + ReachedKnownParallelRegions.isValidState() && + !mayContainParallelRegion()) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); // If we haven't used any assumed information for the SPMD state we can fix // it. @@ -3469,14 +3873,14 @@ struct AAKernelInfoCallSite : AAKernelInfo { CallBase &CB = cast<CallBase>(getAssociatedValue()); Function *Callee = getAssociatedFunction(); - // Helper to lookup an assumption string. - auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) { - return Fn && hasAssumption(*Fn, AssumptionStr); - }; + auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>( + *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); // Check for SPMD-mode assumptions. - if (HasAssumption(Callee, "ompx_spmd_amenable")) + if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) { SPMDCompatibilityTracker.indicateOptimisticFixpoint(); + indicateOptimisticFixpoint(); + } // First weed out calls we do not care about, that is readonly/readnone // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a @@ -3498,14 +3902,16 @@ struct AAKernelInfoCallSite : AAKernelInfo { // Unknown callees might contain parallel regions, except if they have // an appropriate assumption attached. - if (!(HasAssumption(Callee, "omp_no_openmp") || - HasAssumption(Callee, "omp_no_parallelism"))) + if (!(AssumptionAA.hasAssumption("omp_no_openmp") || + AssumptionAA.hasAssumption("omp_no_parallelism"))) ReachedUnknownParallelRegions.insert(&CB); // If SPMDCompatibilityTracker is not fixed, we need to give up on the // idea we can run something unknown in SPMD-mode. - if (!SPMDCompatibilityTracker.isAtFixpoint()) + if (!SPMDCompatibilityTracker.isAtFixpoint()) { + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); + } // We have updated the state for this unknown call properly, there won't // be any change so we indicate a fixpoint. 
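The update logic above folds every call site's AAKernelInfo state into the function's state with `getState() ^= CBAA.getState()` and only fixes the reached-parallel-region sets once all contributing states are themselves at a fixpoint and no assumed information was used. A minimal model of that optimistic-fixpoint pattern (it mirrors the shape of the Attributor's boolean states, not the actual LLVM classes):

    #include <cassert>
    #include <initializer_list>

    struct BoolState {
      bool Known = false;   // proven facts only
      bool Assumed = true;  // optimistic assumption, may still be revoked
      bool isAtFixpoint() const { return Known == Assumed; }
      void indicateOptimisticFixpoint() { Known = Assumed; }
      void indicatePessimisticFixpoint() { Assumed = Known; }
      // "Meet" with a call-site state: assumed information can only shrink.
      BoolState &operator^=(const BoolState &Other) {
        Assumed = Assumed && Other.Assumed;
        return *this;
      }
    };

    int main() {
      BoolState Fn;                            // kernel state, optimistically compatible
      BoolState GoodCS, BadCS;
      GoodCS.indicateOptimisticFixpoint();     // call site proven compatible
      BadCS.indicatePessimisticFixpoint();     // call site proven incompatible

      bool AllFixed = true;
      for (const BoolState *CS : {&GoodCS, &BadCS}) {
        Fn ^= *CS;                             // getState() ^= CBAA.getState()
        AllFixed &= CS->isAtFixpoint();
      }
      if (AllFixed)
        Fn.indicateOptimisticFixpoint();       // nothing assumed is left in flight

      assert(Fn.isAtFixpoint() && !Fn.Assumed); // kernel ends up not compatible
      return 0;
    }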
@@ -3521,6 +3927,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { switch (RF) { // All the functions we know are compatible with SPMD mode. case OMPRTL___kmpc_is_spmd_exec_mode: + case OMPRTL___kmpc_distribute_static_fini: case OMPRTL___kmpc_for_static_fini: case OMPRTL___kmpc_global_thread_num: case OMPRTL___kmpc_get_hardware_num_threads_in_block: @@ -3531,6 +3938,10 @@ struct AAKernelInfoCallSite : AAKernelInfo { case OMPRTL___kmpc_end_master: case OMPRTL___kmpc_barrier: break; + case OMPRTL___kmpc_distribute_static_init_4: + case OMPRTL___kmpc_distribute_static_init_4u: + case OMPRTL___kmpc_distribute_static_init_8: + case OMPRTL___kmpc_distribute_static_init_8u: case OMPRTL___kmpc_for_static_init_4: case OMPRTL___kmpc_for_static_init_4u: case OMPRTL___kmpc_for_static_init_8: @@ -3548,6 +3959,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { case OMPScheduleType::DistributeChunked: break; default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); break; }; @@ -3580,7 +3992,7 @@ struct AAKernelInfoCallSite : AAKernelInfo { return; default: // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, - // generally. + // generally. However, they do not hide parallel regions. SPMDCompatibilityTracker.insert(&CB); break; } @@ -3700,6 +4112,9 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { } void initialize(Attributor &A) override { + if (DisableOpenMPOptFolding) + indicatePessimisticFixpoint(); + Function *Callee = getAssociatedFunction(); auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); @@ -3756,11 +4171,24 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { ChangeStatus Changed = ChangeStatus::UNCHANGED; if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) { - Instruction &CB = *getCtxI(); - A.changeValueAfterManifest(CB, **SimplifiedValue); - A.deleteAfterManifest(CB); + Instruction &I = *getCtxI(); + A.changeValueAfterManifest(I, **SimplifiedValue); + A.deleteAfterManifest(I); - LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with " + CallBase *CB = dyn_cast<CallBase>(&I); + auto Remark = [&](OptimizationRemark OR) { + if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue)) + return OR << "Replacing OpenMP runtime call " + << CB->getCalledFunction()->getName() << " with " + << ore::NV("FoldedValue", C->getZExtValue()) << "."; + return OR << "Replacing OpenMP runtime call " + << CB->getCalledFunction()->getName() << "."; + }; + + if (CB && EnableVerboseRemarks) + A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark); + + LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with " << **SimplifiedValue << "\n"); Changed = ChangeStatus::CHANGED; @@ -3994,7 +4422,6 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { DepClassTy::NONE, /* ForceUpdate */ false, /* UpdateAfterInit */ false); - registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id); registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode); registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level); @@ -4027,7 +4454,8 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F)); return false; }; - GlobalizationRFI.foreachUse(SCC, CreateAA); + if (!DisableOpenMPOptDeglobalization) + GlobalizationRFI.foreachUse(SCC, CreateAA); // Create an ExecutionDomain AA for every function and a HeapToStack AA for // every function if there is a device kernel. 
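Several of the rewrites in this file are now gated behind command-line switches (DisableOpenMPOptSPMDization, DisableOpenMPOptStateMachineRewrite, DisableOpenMPOptFolding, DisableOpenMPOptDeglobalization, AlwaysInlineDeviceFunctions, PrintModuleAfterOptimizations, SetFixpointIterations). Their cl::opt declarations sit earlier in OpenMPOpt.cpp and are not part of the hunks shown here; as a sketch of what such a declaration looks like, with the flag string and description wording being assumptions of this example rather than text taken from the patch:

    #include "llvm/Support/CommandLine.h"

    using namespace llvm;

    // Hypothetical stand-in for the flag declarations added elsewhere in this
    // file; the variable name matches the patch, the rest is illustrative.
    static cl::opt<bool> DisableOpenMPOptDeglobalization(
        "openmp-opt-disable-deglobalization", cl::Hidden,
        cl::desc("Disable OpenMP optimizations that turn globalized variables "
                 "back into stack or shared-memory storage."),
        cl::init(false));

Defaulting such switches to off keeps the new behavior on by default while still letting a user bisect a miscompile or performance regression to one specific rewrite.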
@@ -4039,7 +4467,8 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { continue; A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F)); - A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F)); + if (!DisableOpenMPOptDeglobalization) + A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F)); for (auto &I : instructions(*F)) { if (auto *LI = dyn_cast<LoadInst>(&I)) { @@ -4234,12 +4663,24 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { SetVector<Function *> Functions(SCC.begin(), SCC.end()); OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels); - unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; + unsigned MaxFixpointIterations = + (isOpenMPDevice(M)) ? SetFixpointIterations : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false, MaxFixpointIterations, OREGetter, DEBUG_TYPE); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Changed = OMPOpt.run(true); + + // Optionally inline device functions for potentially better performance. + if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M)) + for (Function &F : M) + if (!F.isDeclaration() && !Kernels.contains(&F) && + !F.hasFnAttribute(Attribute::NoInline)) + F.addFnAttr(Attribute::AlwaysInline); + + if (PrintModuleAfterOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M); + if (Changed) return PreservedAnalyses::none(); @@ -4286,12 +4727,17 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, /*CGSCC*/ Functions, Kernels); - unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; + unsigned MaxFixpointIterations = + (isOpenMPDevice(M)) ? SetFixpointIterations : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, MaxFixpointIterations, OREGetter, DEBUG_TYPE); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Changed = OMPOpt.run(false); + + if (PrintModuleAfterOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M); + if (Changed) return PreservedAnalyses::none(); @@ -4352,12 +4798,18 @@ struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { Allocator, /*CGSCC*/ Functions, Kernels); - unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; + unsigned MaxFixpointIterations = + (isOpenMPDevice(M)) ? SetFixpointIterations : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, MaxFixpointIterations, OREGetter, DEBUG_TYPE); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); - return OMPOpt.run(false); + bool Result = OMPOpt.run(false); + + if (PrintModuleAfterOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M); + + return Result; } bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp index d517de38ace3..7402e399a88a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -441,9 +441,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo( }; auto BBProfileCount = [BFI](BasicBlock *BB) { - return BFI->getBlockProfileCount(BB) - ? 
BFI->getBlockProfileCount(BB).getValue() - : 0; + return BFI->getBlockProfileCount(BB).getValueOr(0); }; // Use the same computeBBInlineCost function to compute the cost savings of @@ -1413,7 +1411,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { computeCallsiteToProfCountMap(Cloner.ClonedFunc, CallSiteToProfCountMap); uint64_t CalleeEntryCountV = - (CalleeEntryCount ? CalleeEntryCount.getCount() : 0); + (CalleeEntryCount ? CalleeEntryCount->getCount() : 0); bool AnyInline = false; for (User *User : Users) { @@ -1461,8 +1459,8 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { if (AnyInline) { Cloner.IsFunctionInlined = true; if (CalleeEntryCount) - Cloner.OrigFunc->setEntryCount( - CalleeEntryCount.setCount(CalleeEntryCountV)); + Cloner.OrigFunc->setEntryCount(Function::ProfileCount( + CalleeEntryCountV, CalleeEntryCount->getType())); OptimizationRemarkEmitter OrigFuncORE(Cloner.OrigFunc); OrigFuncORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", Cloner.OrigFunc) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp index aa916345954d..74f68531b89a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -437,6 +437,11 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createCFGSimplificationPass()); // Merge & remove BBs MPM.add(createReassociatePass()); // Reassociate expressions + // The matrix extension can introduce large vector operations early, which can + // benefit from running vector-combine early on. + if (EnableMatrix) + MPM.add(createVectorCombinePass()); + // Begin the loop pass pipeline. if (EnableSimpleLoopUnswitch) { // The simple loop unswitch pass relies on separate cleanup passes. Schedule @@ -1012,7 +1017,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { createPGOIndirectCallPromotionLegacyPass(true, !PGOSampleUse.empty())); // Propage constant function arguments by specializing the functions. - if (EnableFunctionSpecialization) + if (EnableFunctionSpecialization && OptLevel > 2) PM.add(createFunctionSpecializationPass()); // Propagate constants at call sites into the functions they call. 
This diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp index 081398a390fa..5779553ee732 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp @@ -135,6 +135,7 @@ PreservedAnalyses FunctionSpecializationPass::run(Module &M, return PA; } +namespace { struct FunctionSpecializationLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid FunctionSpecializationLegacyPass() : ModulePass(ID) {} @@ -175,6 +176,7 @@ struct FunctionSpecializationLegacyPass : public ModulePass { return runFunctionSpecialization(M, DL, GetTLI, GetTTI, GetAC, GetAnalysis); } }; +} // namespace char FunctionSpecializationLegacyPass::ID = 0; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp index 55b88ac14da5..bae9a1e27e75 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp @@ -32,7 +32,7 @@ ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite, if (CalleeName.empty()) return getHottestChildContext(CallSite); - uint32_t Hash = nodeHash(CalleeName, CallSite); + uint64_t Hash = nodeHash(CalleeName, CallSite); auto It = AllChildContext.find(Hash); if (It != AllChildContext.end()) return &It->second; @@ -64,8 +64,8 @@ ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) { ContextTrieNode &ContextTrieNode::moveToChildContext( const LineLocation &CallSite, ContextTrieNode &&NodeToMove, - StringRef ContextStrToRemove, bool DeleteNode) { - uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite); + uint32_t ContextFramesToRemove, bool DeleteNode) { + uint64_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite); assert(!AllChildContext.count(Hash) && "Node to remove must exist"); LineLocation OldCallSite = NodeToMove.CallSiteLoc; ContextTrieNode &OldParentContext = *NodeToMove.getParentContext(); @@ -86,10 +86,10 @@ ContextTrieNode &ContextTrieNode::moveToChildContext( FunctionSamples *FSamples = Node->getFunctionSamples(); if (FSamples) { - FSamples->getContext().promoteOnPath(ContextStrToRemove); + FSamples->getContext().promoteOnPath(ContextFramesToRemove); FSamples->getContext().setState(SyntheticContext); - LLVM_DEBUG(dbgs() << " Context promoted to: " << FSamples->getContext() - << "\n"); + LLVM_DEBUG(dbgs() << " Context promoted to: " + << FSamples->getContext().toString() << "\n"); } for (auto &It : Node->getAllChildContext()) { @@ -108,12 +108,12 @@ ContextTrieNode &ContextTrieNode::moveToChildContext( void ContextTrieNode::removeChildContext(const LineLocation &CallSite, StringRef CalleeName) { - uint32_t Hash = nodeHash(CalleeName, CallSite); + uint64_t Hash = nodeHash(CalleeName, CallSite); // Note this essentially calls dtor and destroys that child context AllChildContext.erase(Hash); } -std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { +std::map<uint64_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { return AllChildContext; } @@ -127,6 +127,15 @@ void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) { FuncSamples = FSamples; } +Optional<uint32_t> ContextTrieNode::getFunctionSize() const { return FuncSize; } + +void ContextTrieNode::addFunctionSize(uint32_t FSize) { + if (!FuncSize.hasValue()) + FuncSize = 0; + + FuncSize = 
FuncSize.getValue() + FSize; +} + LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; } ContextTrieNode *ContextTrieNode::getParentContext() const { @@ -137,9 +146,10 @@ void ContextTrieNode::setParentContext(ContextTrieNode *Parent) { ParentContext = Parent; } -void ContextTrieNode::dump() { +void ContextTrieNode::dumpNode() { dbgs() << "Node: " << FuncName << "\n" << " Callsite: " << CallSiteLoc << "\n" + << " Size: " << FuncSize << "\n" << " Children:\n"; for (auto &It : AllChildContext) { @@ -147,20 +157,38 @@ void ContextTrieNode::dump() { } } -uint32_t ContextTrieNode::nodeHash(StringRef ChildName, +void ContextTrieNode::dumpTree() { + dbgs() << "Context Profile Tree:\n"; + std::queue<ContextTrieNode *> NodeQueue; + NodeQueue.push(this); + + while (!NodeQueue.empty()) { + ContextTrieNode *Node = NodeQueue.front(); + NodeQueue.pop(); + Node->dumpNode(); + + for (auto &It : Node->getAllChildContext()) { + ContextTrieNode *ChildNode = &It.second; + NodeQueue.push(ChildNode); + } + } +} + +uint64_t ContextTrieNode::nodeHash(StringRef ChildName, const LineLocation &Callsite) { // We still use child's name for child hash, this is // because for children of root node, we don't have // different line/discriminator, and we'll rely on name // to differentiate children. - uint32_t NameHash = std::hash<std::string>{}(ChildName.str()); - uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator; + uint64_t NameHash = std::hash<std::string>{}(ChildName.str()); + uint64_t LocId = + (((uint64_t)Callsite.LineOffset) << 32) | Callsite.Discriminator; return NameHash + (LocId << 5) + LocId; } ContextTrieNode *ContextTrieNode::getOrCreateChildContext( const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) { - uint32_t Hash = nodeHash(CalleeName, CallSite); + uint64_t Hash = nodeHash(CalleeName, CallSite); auto It = AllChildContext.find(Hash); if (It != AllChildContext.end()) { assert(It->second.getFuncName() == CalleeName && @@ -177,13 +205,16 @@ ContextTrieNode *ContextTrieNode::getOrCreateChildContext( // Profiler tracker than manages profiles and its associated context SampleContextTracker::SampleContextTracker( - StringMap<FunctionSamples> &Profiles) { + SampleProfileMap &Profiles, + const DenseMap<uint64_t, StringRef> *GUIDToFuncNameMap) + : GUIDToFuncNameMap(GUIDToFuncNameMap) { for (auto &FuncSample : Profiles) { FunctionSamples *FSamples = &FuncSample.second; - SampleContext Context(FuncSample.first(), RawContext); - LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n"); + SampleContext Context = FuncSample.first; + LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context.toString() + << "\n"); if (!Context.isBaseContext()) - FuncToCtxtProfiles[Context.getNameWithoutContext()].push_back(FSamples); + FuncToCtxtProfiles[Context.getName()].insert(FSamples); ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); assert(!NewNode->getFunctionSamples() && "New node can't have sample profile"); @@ -200,6 +231,10 @@ SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, return nullptr; CalleeName = FunctionSamples::getCanonicalFnName(CalleeName); + // Convert real function names to MD5 names, if the input profile is + // MD5-based. + std::string FGUID; + CalleeName = getRepInFormat(CalleeName, FunctionSamples::UseMD5, FGUID); // For indirect call, CalleeName will be empty, in which case the context // profile for callee with largest total samples will be returned. 
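The ContextTrieNode changes above widen nodeHash from uint32_t to uint64_t and pack the call-site location as (LineOffset << 32) | Discriminator instead of (LineOffset << 16) | Discriminator, so discriminators larger than 16 bits can no longer alias a line offset. A self-contained sketch of the new packing and the collision it avoids:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <string>

    // Mirrors the updated ContextTrieNode::nodeHash: a full 64-bit location id
    // combined with a hash of the callee name.
    static uint64_t nodeHash(const std::string &ChildName, uint32_t LineOffset,
                             uint32_t Discriminator) {
      uint64_t NameHash = std::hash<std::string>{}(ChildName);
      uint64_t LocId = (static_cast<uint64_t>(LineOffset) << 32) | Discriminator;
      return NameHash + (LocId << 5) + LocId;
    }

    int main() {
      // With the old 32-bit packing, (LineOffset=1, Discriminator=0) and
      // (LineOffset=0, Discriminator=0x10000) encode to the same 0x10000.
      uint32_t Packed32A = (1u << 16) | 0u;
      uint32_t Packed32B = (0u << 16) | 0x10000u;
      assert(Packed32A == Packed32B);

      // The 64-bit packing keeps the two call sites distinct.
      assert(nodeHash("callee", 1, 0) != nodeHash("callee", 0, 0x10000));
      return 0;
    }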
@@ -207,7 +242,8 @@ SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, if (CalleeContext) { FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); LLVM_DEBUG(if (FSamples) { - dbgs() << " Callee context found: " << FSamples->getContext() << "\n"; + dbgs() << " Callee context found: " << FSamples->getContext().toString() + << "\n"; }); return FSamples; } @@ -285,6 +321,11 @@ FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func, FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, bool MergeContext) { LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n"); + // Convert real function names to MD5 names, if the input profile is + // MD5-based. + std::string FGUID; + Name = getRepInFormat(Name, FunctionSamples::UseMD5, FGUID); + // Base profile is top-level node (child of root node), so try to retrieve // existing top-level node for given function first. If it exists, it could be // that we've merged base profile before, or there's actually context-less @@ -299,14 +340,14 @@ FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, // into base profile. for (auto *CSamples : FuncToCtxtProfiles[Name]) { SampleContext &Context = CSamples->getContext(); - ContextTrieNode *FromNode = getContextFor(Context); - if (FromNode == Node) - continue; - // Skip inlined context profile and also don't re-merge any context if (Context.hasState(InlinedContext) || Context.hasState(MergedContext)) continue; + ContextTrieNode *FromNode = getContextFor(Context); + if (FromNode == Node) + continue; + ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode); assert((!Node || Node == &ToNode) && "Expect only one base profile"); Node = &ToNode; @@ -324,7 +365,7 @@ void SampleContextTracker::markContextSamplesInlined( const FunctionSamples *InlinedSamples) { assert(InlinedSamples && "Expect non-null inlined samples"); LLVM_DEBUG(dbgs() << "Marking context profile as inlined: " - << InlinedSamples->getContext() << "\n"); + << InlinedSamples->getContext().toString() << "\n"); InlinedSamples->getContext().setState(InlinedContext); } @@ -376,30 +417,23 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples(); assert(FromSamples && "Shouldn't promote a context without profile"); LLVM_DEBUG(dbgs() << " Found context tree root to promote: " - << FromSamples->getContext() << "\n"); + << FromSamples->getContext().toString() << "\n"); assert(!FromSamples->getContext().hasState(InlinedContext) && "Shouldn't promote inlined context profile"); - StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext(); + uint32_t ContextFramesToRemove = + FromSamples->getContext().getContextFrames().size() - 1; return promoteMergeContextSamplesTree(NodeToPromo, RootContext, - ContextStrToRemove); + ContextFramesToRemove); } -void SampleContextTracker::dump() { - dbgs() << "Context Profile Tree:\n"; - std::queue<ContextTrieNode *> NodeQueue; - NodeQueue.push(&RootContext); - - while (!NodeQueue.empty()) { - ContextTrieNode *Node = NodeQueue.front(); - NodeQueue.pop(); - Node->dump(); +void SampleContextTracker::dump() { RootContext.dumpTree(); } - for (auto &It : Node->getAllChildContext()) { - ContextTrieNode *ChildNode = &It.second; - NodeQueue.push(ChildNode); - } - } +StringRef SampleContextTracker::getFuncNameFor(ContextTrieNode *Node) const { + if (!FunctionSamples::UseMD5) + return Node->getFuncName(); + 
assert(GUIDToFuncNameMap && "GUIDToFuncNameMap needs to be populated first"); + return GUIDToFuncNameMap->lookup(std::stoull(Node->getFuncName().data())); } ContextTrieNode * @@ -444,11 +478,22 @@ ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { RootName = PrevDIL->getScope()->getSubprogram()->getName(); S.push_back(std::make_pair(LineLocation(0, 0), RootName)); + // Convert real function names to MD5 names, if the input profile is + // MD5-based. + std::list<std::string> MD5Names; + if (FunctionSamples::UseMD5) { + for (auto &Location : S) { + MD5Names.emplace_back(); + getRepInFormat(Location.second, FunctionSamples::UseMD5, MD5Names.back()); + Location.second = MD5Names.back(); + } + } + ContextTrieNode *ContextNode = &RootContext; int I = S.size(); while (--I >= 0 && ContextNode) { LineLocation &CallSite = S[I].first; - StringRef &CalleeName = S[I].second; + StringRef CalleeName = S[I].second; ContextNode = ContextNode->getChildContext(CallSite, CalleeName); } @@ -462,27 +507,18 @@ ContextTrieNode * SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, bool AllowCreate) { ContextTrieNode *ContextNode = &RootContext; - StringRef ContextRemain = Context; - StringRef ChildContext; - StringRef CalleeName; LineLocation CallSiteLoc(0, 0); - while (ContextNode && !ContextRemain.empty()) { - auto ContextSplit = SampleContext::splitContextString(ContextRemain); - ChildContext = ContextSplit.first; - ContextRemain = ContextSplit.second; - LineLocation NextCallSiteLoc(0, 0); - SampleContext::decodeContextString(ChildContext, CalleeName, - NextCallSiteLoc); - + for (auto &Callsite : Context.getContextFrames()) { // Create child node at parent line/disc location if (AllowCreate) { ContextNode = - ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName); + ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.FuncName); } else { - ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName); + ContextNode = + ContextNode->getChildContext(CallSiteLoc, Callsite.FuncName); } - CallSiteLoc = NextCallSiteLoc; + CallSiteLoc = Callsite.Location; } assert((!AllowCreate || ContextNode) && @@ -502,7 +538,7 @@ ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, ContextTrieNode &ToNode, - StringRef ContextStrToRemove) { + uint32_t ContextFramesToRemove) { FunctionSamples *FromSamples = FromNode.getFunctionSamples(); FunctionSamples *ToSamples = ToNode.getFunctionSamples(); if (FromSamples && ToSamples) { @@ -510,19 +546,21 @@ void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, ToSamples->merge(*FromSamples); ToSamples->getContext().setState(SyntheticContext); FromSamples->getContext().setState(MergedContext); + if (FromSamples->getContext().hasAttribute(ContextShouldBeInlined)) + ToSamples->getContext().setAttribute(ContextShouldBeInlined); } else if (FromSamples) { // Transfer FromSamples from FromNode to ToNode ToNode.setFunctionSamples(FromSamples); FromSamples->getContext().setState(SyntheticContext); - FromSamples->getContext().promoteOnPath(ContextStrToRemove); + FromSamples->getContext().promoteOnPath(ContextFramesToRemove); FromNode.setFunctionSamples(nullptr); } } ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent, - StringRef ContextStrToRemove) { - assert(!ContextStrToRemove.empty() && "Context to remove can't be empty"); + uint32_t 
ContextFramesToRemove) { + assert(ContextFramesToRemove && "Context to remove can't be empty"); // Ignore call site location if destination is top level under root LineLocation NewCallSiteLoc = LineLocation(0, 0); @@ -540,21 +578,21 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( // Do not delete node to move from its parent here because // caller is iterating over children of that parent node. ToNode = &ToNodeParent.moveToChildContext( - NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false); + NewCallSiteLoc, std::move(FromNode), ContextFramesToRemove, false); } else { // Destination node exists, merge samples for the context tree - mergeContextNode(FromNode, *ToNode, ContextStrToRemove); + mergeContextNode(FromNode, *ToNode, ContextFramesToRemove); LLVM_DEBUG({ if (ToNode->getFunctionSamples()) dbgs() << " Context promoted and merged to: " - << ToNode->getFunctionSamples()->getContext() << "\n"; + << ToNode->getFunctionSamples()->getContext().toString() << "\n"; }); // Recursively promote and merge children for (auto &It : FromNode.getAllChildContext()) { ContextTrieNode &FromChildNode = It.second; promoteMergeContextSamplesTree(FromChildNode, *ToNode, - ContextStrToRemove); + ContextFramesToRemove); } // Remove children once they're all merged diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp index 8e9c79fc7bbb..a961c47a7501 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -143,6 +143,12 @@ static cl::opt<bool> ProfileSampleAccurate( "callsite and function as having 0 samples. Otherwise, treat " "un-sampled callsites and functions conservatively as unknown. ")); +static cl::opt<bool> ProfileSampleBlockAccurate( + "profile-sample-block-accurate", cl::Hidden, cl::init(false), + cl::desc("If the sample profile is accurate, we will mark all un-sampled " + "branches and calls as having 0 samples. Otherwise, treat " + "them conservatively as unknown. ")); + static cl::opt<bool> ProfileAccurateForSymsInList( "profile-accurate-for-symsinlist", cl::Hidden, cl::ZeroOrMore, cl::init(true), @@ -214,6 +220,16 @@ static cl::opt<bool> CallsitePrioritizedInline( cl::desc("Use call site prioritized inlining for sample profile loader." 
"Currently only CSSPGO is supported.")); +static cl::opt<bool> UsePreInlinerDecision( + "sample-profile-use-preinliner", cl::Hidden, cl::ZeroOrMore, + cl::init(false), + cl::desc("Use the preinliner decisions stored in profile context.")); + +static cl::opt<bool> AllowRecursiveInline( + "sample-profile-recursive-inline", cl::Hidden, cl::ZeroOrMore, + cl::init(false), + cl::desc("Allow sample loader inliner to inline recursive calls.")); + static cl::opt<std::string> ProfileInlineReplayFile( "sample-profile-inline-replay", cl::init(""), cl::value_desc("filename"), cl::desc( @@ -221,6 +237,50 @@ static cl::opt<std::string> ProfileInlineReplayFile( "by inlining from sample profile loader."), cl::Hidden); +static cl::opt<ReplayInlinerSettings::Scope> ProfileInlineReplayScope( + "sample-profile-inline-replay-scope", + cl::init(ReplayInlinerSettings::Scope::Function), + cl::values(clEnumValN(ReplayInlinerSettings::Scope::Function, "Function", + "Replay on functions that have remarks associated " + "with them (default)"), + clEnumValN(ReplayInlinerSettings::Scope::Module, "Module", + "Replay on the entire module")), + cl::desc("Whether inline replay should be applied to the entire " + "Module or just the Functions (default) that are present as " + "callers in remarks during sample profile inlining."), + cl::Hidden); + +static cl::opt<ReplayInlinerSettings::Fallback> ProfileInlineReplayFallback( + "sample-profile-inline-replay-fallback", + cl::init(ReplayInlinerSettings::Fallback::Original), + cl::values( + clEnumValN( + ReplayInlinerSettings::Fallback::Original, "Original", + "All decisions not in replay send to original advisor (default)"), + clEnumValN(ReplayInlinerSettings::Fallback::AlwaysInline, + "AlwaysInline", "All decisions not in replay are inlined"), + clEnumValN(ReplayInlinerSettings::Fallback::NeverInline, "NeverInline", + "All decisions not in replay are not inlined")), + cl::desc("How sample profile inline replay treats sites that don't come " + "from the replay. 
Original: defers to original advisor, " + "AlwaysInline: inline all sites not in replay, NeverInline: " + "inline no sites not in replay"), + cl::Hidden); + +static cl::opt<CallSiteFormat::Format> ProfileInlineReplayFormat( + "sample-profile-inline-replay-format", + cl::init(CallSiteFormat::Format::LineColumnDiscriminator), + cl::values( + clEnumValN(CallSiteFormat::Format::Line, "Line", "<Line Number>"), + clEnumValN(CallSiteFormat::Format::LineColumn, "LineColumn", + "<Line Number>:<Column Number>"), + clEnumValN(CallSiteFormat::Format::LineDiscriminator, + "LineDiscriminator", "<Line Number>.<Discriminator>"), + clEnumValN(CallSiteFormat::Format::LineColumnDiscriminator, + "LineColumnDiscriminator", + "<Line Number>:<Column Number>.<Discriminator> (default)")), + cl::desc("How sample profile inline replay file is formatted"), cl::Hidden); + static cl::opt<unsigned> MaxNumPromotions("sample-profile-icp-max-prom", cl::init(3), cl::Hidden, cl::ZeroOrMore, @@ -358,10 +418,10 @@ public: std::function<AssumptionCache &(Function &)> GetAssumptionCache, std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo, std::function<const TargetLibraryInfo &(Function &)> GetTLI) - : SampleProfileLoaderBaseImpl(std::string(Name)), + : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)), GetAC(std::move(GetAssumptionCache)), GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)), - RemappingFilename(std::string(RemapName)), LTOPhase(LTOPhase) {} + LTOPhase(LTOPhase) {} bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr); bool runOnModule(Module &M, ModuleAnalysisManager *AM, @@ -377,7 +437,7 @@ protected: findFunctionSamples(const Instruction &I) const override; std::vector<const FunctionSamples *> findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const; - void findExternalInlineCandidate(const FunctionSamples *Samples, + void findExternalInlineCandidate(CallBase *CB, const FunctionSamples *Samples, DenseSet<GlobalValue::GUID> &InlinedGUIDs, const StringMap<Function *> &SymbolMap, uint64_t Threshold); @@ -385,8 +445,11 @@ protected: bool tryPromoteAndInlineCandidate( Function &F, InlineCandidate &Candidate, uint64_t SumOrigin, uint64_t &Sum, SmallVector<CallBase *, 8> *InlinedCallSites = nullptr); + bool inlineHotFunctions(Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs); + Optional<InlineCost> getExternalInlineAdvisorCost(CallBase &CB); + bool getExternalInlineAdvisorShouldInline(CallBase &CB); InlineCost shouldInlineCandidate(InlineCandidate &Candidate); bool getInlineCandidate(InlineCandidate *NewCandidate, CallBase *CB); bool @@ -417,9 +480,6 @@ protected: /// Profile tracker for different context. std::unique_ptr<SampleContextTracker> ContextTracker; - /// Name of the profile remapping file to load. - std::string RemappingFilename; - /// Flag indicating whether input profile is context-sensitive bool ProfileIsCS = false; @@ -464,7 +524,7 @@ protected: bool ProfAccForSymsInList; // External inline advisor used to replay inline decision from remarks. - std::unique_ptr<ReplayInlineAdvisor> ExternalInlineAdvisor; + std::unique_ptr<InlineAdvisor> ExternalInlineAdvisor; // A pseudo probe helper to correlate the imported sample counts. 
std::unique_ptr<PseudoProbeManager> ProbeManager; @@ -953,8 +1013,24 @@ void SampleProfileLoader::emitOptimizationRemarksForInlineCandidates( } void SampleProfileLoader::findExternalInlineCandidate( - const FunctionSamples *Samples, DenseSet<GlobalValue::GUID> &InlinedGUIDs, + CallBase *CB, const FunctionSamples *Samples, + DenseSet<GlobalValue::GUID> &InlinedGUIDs, const StringMap<Function *> &SymbolMap, uint64_t Threshold) { + + // If ExternalInlineAdvisor wants to inline an external function + // make sure it's imported + if (CB && getExternalInlineAdvisorShouldInline(*CB)) { + // Samples may not exist for replayed function, if so + // just add the direct GUID and move on + if (!Samples) { + InlinedGUIDs.insert( + FunctionSamples::getGUID(CB->getCalledFunction()->getName())); + return; + } + // Otherwise, drop the threshold to import everything that we can + Threshold = 0; + } + assert(Samples && "expect non-null caller profile"); // For AutoFDO profile, retrieve candidate profiles by walking over @@ -975,14 +1051,21 @@ void SampleProfileLoader::findExternalInlineCandidate( // For CSSPGO profile, retrieve candidate profile by walking over the // trie built for context profile. Note that also take call targets // even if callee doesn't have a corresponding context profile. - if (!CalleeSample || CalleeSample->getEntrySamples() < Threshold) + if (!CalleeSample) + continue; + + // If pre-inliner decision is used, honor that for importing as well. + bool PreInline = + UsePreInlinerDecision && + CalleeSample->getContext().hasAttribute(ContextShouldBeInlined); + if (!PreInline && CalleeSample->getEntrySamples() < Threshold) continue; StringRef Name = CalleeSample->getFuncName(); Function *Func = SymbolMap.lookup(Name); // Add to the import list only when it's defined out of module. if (!Func || Func->isDeclaration()) - InlinedGUIDs.insert(FunctionSamples::getGUID(Name)); + InlinedGUIDs.insert(FunctionSamples::getGUID(CalleeSample->getName())); // Import hot CallTargets, which may not be available in IR because full // profile annotation cannot be done until backend compilation in ThinLTO. @@ -992,7 +1075,7 @@ void SampleProfileLoader::findExternalInlineCandidate( StringRef CalleeName = CalleeSample->getFuncName(TS.getKey()); const Function *Callee = SymbolMap.lookup(CalleeName); if (!Callee || Callee->isDeclaration()) - InlinedGUIDs.insert(FunctionSamples::getGUID(CalleeName)); + InlinedGUIDs.insert(FunctionSamples::getGUID(TS.getKey())); } // Import hot child context profile associted with callees. 
Note that this @@ -1042,16 +1125,20 @@ bool SampleProfileLoader::inlineHotFunctions( for (auto &I : BB.getInstList()) { const FunctionSamples *FS = nullptr; if (auto *CB = dyn_cast<CallBase>(&I)) { - if (!isa<IntrinsicInst>(I) && (FS = findCalleeFunctionSamples(*CB))) { - assert((!FunctionSamples::UseMD5 || FS->GUIDToFuncNameMap) && - "GUIDToFuncNameMap has to be populated"); - AllCandidates.push_back(CB); - if (FS->getEntrySamples() > 0 || ProfileIsCS) - LocalNotInlinedCallSites.try_emplace(CB, FS); - if (callsiteIsHot(FS, PSI, ProfAccForSymsInList)) - Hot = true; - else if (shouldInlineColdCallee(*CB)) - ColdCandidates.push_back(CB); + if (!isa<IntrinsicInst>(I)) { + if ((FS = findCalleeFunctionSamples(*CB))) { + assert((!FunctionSamples::UseMD5 || FS->GUIDToFuncNameMap) && + "GUIDToFuncNameMap has to be populated"); + AllCandidates.push_back(CB); + if (FS->getEntrySamples() > 0 || ProfileIsCS) + LocalNotInlinedCallSites.try_emplace(CB, FS); + if (callsiteIsHot(FS, PSI, ProfAccForSymsInList)) + Hot = true; + else if (shouldInlineColdCallee(*CB)) + ColdCandidates.push_back(CB); + } else if (getExternalInlineAdvisorShouldInline(*CB)) { + AllCandidates.push_back(CB); + } } } } @@ -1078,7 +1165,7 @@ bool SampleProfileLoader::inlineHotFunctions( for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) { uint64_t SumOrigin = Sum; if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(FS, InlinedGUIDs, SymbolMap, + findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap, PSI->getOrCompHotCountThreshold()); continue; } @@ -1098,8 +1185,8 @@ bool SampleProfileLoader::inlineHotFunctions( LocalChanged = true; } } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(findCalleeFunctionSamples(*I), InlinedGUIDs, - SymbolMap, + findExternalInlineCandidate(I, findCalleeFunctionSamples(*I), + InlinedGUIDs, SymbolMap, PSI->getOrCompHotCountThreshold()); } } @@ -1184,8 +1271,8 @@ bool SampleProfileLoader::tryInlineCandidate( *CalledFunction); // The call to InlineFunction erases I, so we can't pass it here. - emitInlinedInto(*ORE, DLoc, BB, *CalledFunction, *BB->getParent(), Cost, - true, CSINLINE_DEBUG); + emitInlinedIntoBasedOnCost(*ORE, DLoc, BB, *CalledFunction, + *BB->getParent(), Cost, true, CSINLINE_DEBUG); // Now populate the list of newly exposed call sites. if (InlinedCallSites) { @@ -1228,7 +1315,9 @@ bool SampleProfileLoader::getInlineCandidate(InlineCandidate *NewCandidate, // Find the callee's profile. For indirect call, find hottest target profile. const FunctionSamples *CalleeSamples = findCalleeFunctionSamples(*CB); - if (!CalleeSamples) + // If ExternalInlineAdvisor wants to inline this site, do so even + // if Samples are not present. 
+ if (!CalleeSamples && !getExternalInlineAdvisorShouldInline(*CB)) return false; float Factor = 1.0; @@ -1247,19 +1336,34 @@ bool SampleProfileLoader::getInlineCandidate(InlineCandidate *NewCandidate, return true; } -InlineCost -SampleProfileLoader::shouldInlineCandidate(InlineCandidate &Candidate) { +Optional<InlineCost> +SampleProfileLoader::getExternalInlineAdvisorCost(CallBase &CB) { std::unique_ptr<InlineAdvice> Advice = nullptr; if (ExternalInlineAdvisor) { - Advice = ExternalInlineAdvisor->getAdvice(*Candidate.CallInstr); - if (!Advice->isInliningRecommended()) { - Advice->recordUnattemptedInlining(); - return InlineCost::getNever("not previously inlined"); + Advice = ExternalInlineAdvisor->getAdvice(CB); + if (Advice) { + if (!Advice->isInliningRecommended()) { + Advice->recordUnattemptedInlining(); + return InlineCost::getNever("not previously inlined"); + } + Advice->recordInlining(); + return InlineCost::getAlways("previously inlined"); } - Advice->recordInlining(); - return InlineCost::getAlways("previously inlined"); } + return {}; +} + +bool SampleProfileLoader::getExternalInlineAdvisorShouldInline(CallBase &CB) { + Optional<InlineCost> Cost = getExternalInlineAdvisorCost(CB); + return Cost ? !!Cost.getValue() : false; +} + +InlineCost +SampleProfileLoader::shouldInlineCandidate(InlineCandidate &Candidate) { + if (Optional<InlineCost> ReplayCost = + getExternalInlineAdvisorCost(*Candidate.CallInstr)) + return ReplayCost.getValue(); // Adjust threshold based on call site hotness, only do this for callsite // prioritized inliner because otherwise cost-benefit check is done earlier. int SampleThreshold = SampleColdCallSiteThreshold; @@ -1274,7 +1378,9 @@ SampleProfileLoader::shouldInlineCandidate(InlineCandidate &Candidate) { assert(Callee && "Expect a definition for inline candidate of direct call"); InlineParams Params = getInlineParams(); + // We will ignore the threshold from inline cost, so always get full cost. Params.ComputeFullInlineCost = true; + Params.AllowRecursiveCall = AllowRecursiveInline; // Checks if there is anything in the reachable portion of the callee at // this callsite that makes this inlining potentially illegal. Need to // set ComputeFullInlineCost, otherwise getInlineCost may return early @@ -1288,6 +1394,25 @@ SampleProfileLoader::shouldInlineCandidate(InlineCandidate &Candidate) { if (Cost.isNever() || Cost.isAlways()) return Cost; + // With CSSPGO, the preinliner in llvm-profgen can estimate global inline + // decisions based on hotness as well as accurate function byte sizes for + // given context using function/inlinee sizes from previous build. It + // stores the decision in profile, and also adjust/merge context profile + // aiming at better context-sensitive post-inline profile quality, assuming + // all inline decision estimates are going to be honored by compiler. Here + // we replay that inline decision under `sample-profile-use-preinliner`. + // Note that we don't need to handle negative decision from preinliner as + // context profile for not inlined calls are merged by preinliner already. + if (UsePreInlinerDecision && Candidate.CalleeSamples) { + // Once two node are merged due to promotion, we're losing some context + // so the original context-sensitive preinliner decision should be ignored + // for SyntheticContext. 
+ SampleContext &Context = Candidate.CalleeSamples->getContext(); + if (!Context.hasState(SyntheticContext) && + Context.hasAttribute(ContextShouldBeInlined)) + return InlineCost::getAlways("preinliner"); + } + // For old FDO inliner, we inline the call site as long as cost is not // "Never". The cost-benefit check is done earlier. if (!CallsitePrioritizedInline) { @@ -1357,7 +1482,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( for (const auto *FS : CalleeSamples) { // TODO: Consider disable pre-lTO ICP for MonoLTO as well if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(FS, InlinedGUIDs, SymbolMap, + findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap, PSI->getOrCompHotCountThreshold()); continue; } @@ -1405,8 +1530,9 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( Changed = true; } } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(Candidate.CalleeSamples, InlinedGUIDs, - SymbolMap, PSI->getOrCompHotCountThreshold()); + findExternalInlineCandidate(I, findCalleeFunctionSamples(*I), + InlinedGUIDs, SymbolMap, + PSI->getOrCompHotCountThreshold()); } } @@ -1494,7 +1620,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { {static_cast<uint32_t>(BlockWeights[BB])})); } } - } else if (OverwriteExistingWeights) { + } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) { // Set profile metadata (possibly annotated by LTO prelink) to zero or // clear it for cold code. for (auto &I : BB->getInstList()) { @@ -1792,11 +1918,13 @@ bool SampleProfileLoader::doInitialization(Module &M, } if (FAM && !ProfileInlineReplayFile.empty()) { - ExternalInlineAdvisor = std::make_unique<ReplayInlineAdvisor>( - M, *FAM, Ctx, /*OriginalAdvisor=*/nullptr, ProfileInlineReplayFile, + ExternalInlineAdvisor = getReplayInlineAdvisor( + M, *FAM, Ctx, /*OriginalAdvisor=*/nullptr, + ReplayInlinerSettings{ProfileInlineReplayFile, + ProfileInlineReplayScope, + ProfileInlineReplayFallback, + {ProfileInlineReplayFormat}}, /*EmitRemarks=*/false); - if (!ExternalInlineAdvisor->areReplayRemarksLoaded()) - ExternalInlineAdvisor.reset(); } // Apply tweaks if context-sensitive profile is available. @@ -1810,13 +1938,21 @@ bool SampleProfileLoader::doInitialization(Module &M, if (!CallsitePrioritizedInline.getNumOccurrences()) CallsitePrioritizedInline = true; + // For CSSPGO, use preinliner decision by default when available. + if (!UsePreInlinerDecision.getNumOccurrences()) + UsePreInlinerDecision = true; + + // For CSSPGO, we also allow recursive inline to best use context profile. + if (!AllowRecursiveInline.getNumOccurrences()) + AllowRecursiveInline = true; + // Enable iterative-BFI by default for CSSPGO. if (!UseIterativeBFIInference.getNumOccurrences()) UseIterativeBFIInference = true; // Tracker for profiles under different context - ContextTracker = - std::make_unique<SampleContextTracker>(Reader->getProfiles()); + ContextTracker = std::make_unique<SampleContextTracker>( + Reader->getProfiles(), &GUIDToFuncNameMap); } // Load pseudo probe descriptors for probe-based function samples. 
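A minimal sketch, not taken from this commit, of the cl::opt pattern the doInitialization() changes above rely on: when a context-sensitive (CSSPGO) profile is detected, defaults such as UsePreInlinerDecision and AllowRecursiveInline are flipped on, but only if the user did not set the flag explicitly. The option name use-example-heuristic below is invented for illustration.

#include "llvm/Support/CommandLine.h"

static llvm::cl::opt<bool> UseExampleHeuristic(
    "use-example-heuristic", llvm::cl::init(false), llvm::cl::Hidden,
    llvm::cl::desc("Illustrative boolean knob"));

static void applyContextSensitiveDefaults(bool ProfileIsCS) {
  // getNumOccurrences() is non-zero only when the flag appeared on the
  // command line, so a profile-driven default never overrides the user.
  if (ProfileIsCS && !UseExampleHeuristic.getNumOccurrences())
    UseExampleHeuristic = true;
}

The same getNumOccurrences() guard is what keeps CallsitePrioritizedInline, UsePreInlinerDecision, AllowRecursiveInline and UseIterativeBFIInference overridable from the command line even though the pass turns them on for CSSPGO.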
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp index 08d316337ef5..21395460bccb 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -415,9 +415,7 @@ void PseudoProbeUpdatePass::runOnFunction(Function &F, FunctionAnalysisManager &FAM) { BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); auto BBProfileCount = [&BFI](BasicBlock *BB) { - return BFI.getBlockProfileCount(BB) - ? BFI.getBlockProfileCount(BB).getValue() - : 0; + return BFI.getBlockProfileCount(BB).getValueOr(0); }; // Collect the sum of execution weight for each probe. diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp index 655a7a404951..0f2412dce1c9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripDeadPrototypes.cpp @@ -30,23 +30,20 @@ static bool stripDeadPrototypes(Module &M) { bool MadeChange = false; // Erase dead function prototypes. - for (Module::iterator I = M.begin(), E = M.end(); I != E; ) { - Function *F = &*I++; + for (Function &F : llvm::make_early_inc_range(M)) { // Function must be a prototype and unused. - if (F->isDeclaration() && F->use_empty()) { - F->eraseFromParent(); + if (F.isDeclaration() && F.use_empty()) { + F.eraseFromParent(); ++NumDeadPrototypes; MadeChange = true; } } // Erase dead global var prototypes. - for (Module::global_iterator I = M.global_begin(), E = M.global_end(); - I != E; ) { - GlobalVariable *GV = &*I++; + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) { // Global must be a prototype and unused. - if (GV->isDeclaration() && GV->use_empty()) - GV->eraseFromParent(); + if (GV.isDeclaration() && GV.use_empty()) + GV.eraseFromParent(); } // Return an indication of whether we changed anything or not. 
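The loops rewritten above (and the similar ones in simplifyExternals and the type-test scans further down) all switch to llvm::make_early_inc_range so the current element can be erased while iterating. Below is a small self-contained illustration of why that is safe; the std::list is only a stand-in for the module's function and global lists, chosen for the demo.

#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <cassert>
#include <list>

int main() {
  std::list<int> Vals = {1, 2, 3, 4, 5, 6};
  // The adaptor advances the wrapped iterator as soon as an element is
  // yielded, so erasing the element currently being visited cannot
  // invalidate the loop's position.
  for (int &N : llvm::make_early_inc_range(Vals))
    if (N % 2 == 0)
      Vals.erase(std::find(Vals.begin(), Vals.end(), N));
  assert(Vals == (std::list<int>{1, 3, 5}));
  return 0;
}

This is the same reason the old manual idiom post-incremented the iterator (Function *F = &*I++;) before touching F.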
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp index 168740a1158e..9d4e9464f361 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -214,13 +214,13 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) { findUsedValues(M.getGlobalVariable("llvm.compiler.used"), llvmUsedValues); for (GlobalVariable &GV : M.globals()) { - if (GV.hasLocalLinkage() && llvmUsedValues.count(&GV) == 0) + if (GV.hasLocalLinkage() && !llvmUsedValues.contains(&GV)) if (!PreserveDbgInfo || !GV.getName().startswith("llvm.dbg")) GV.setName(""); // Internal symbols can't participate in linkage } for (Function &I : M) { - if (I.hasLocalLinkage() && llvmUsedValues.count(&I) == 0) + if (I.hasLocalLinkage() && !llvmUsedValues.contains(&I)) if (!PreserveDbgInfo || !I.getName().startswith("llvm.dbg")) I.setName(""); // Internal symbols can't participate in linkage if (auto *Symtab = I.getValueSymbolTable()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index eea848d3eb2f..0cc1b37844f6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -164,8 +164,7 @@ void simplifyExternals(Module &M) { FunctionType *EmptyFT = FunctionType::get(Type::getVoidTy(M.getContext()), false); - for (auto I = M.begin(), E = M.end(); I != E;) { - Function &F = *I++; + for (Function &F : llvm::make_early_inc_range(M)) { if (F.isDeclaration() && F.use_empty()) { F.eraseFromParent(); continue; @@ -181,16 +180,15 @@ void simplifyExternals(Module &M) { F.getAddressSpace(), "", &M); NewF->copyAttributesFrom(&F); // Only copy function attribtues. - NewF->setAttributes( - AttributeList::get(M.getContext(), AttributeList::FunctionIndex, - F.getAttributes().getFnAttributes())); + NewF->setAttributes(AttributeList::get(M.getContext(), + AttributeList::FunctionIndex, + F.getAttributes().getFnAttrs())); NewF->takeName(&F); F.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, F.getType())); F.eraseFromParent(); } - for (auto I = M.global_begin(), E = M.global_end(); I != E;) { - GlobalVariable &GV = *I++; + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) { if (GV.isDeclaration() && GV.use_empty()) { GV.eraseFromParent(); continue; @@ -325,7 +323,8 @@ void splitAndWriteThinLTOBitcode( return true; if (auto *F = dyn_cast<Function>(GV)) return EligibleVirtualFns.count(F); - if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject())) + if (auto *GVar = + dyn_cast_or_null<GlobalVariable>(GV->getAliaseeObject())) return HasTypeMetadata(GVar); return false; })); @@ -354,7 +353,7 @@ void splitAndWriteThinLTOBitcode( // Remove all globals with type metadata, globals with comdats that live in // MergedM, and aliases pointing to such globals from the thin LTO module. 
filterModule(&M, [&](const GlobalValue *GV) { - if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject())) + if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getAliaseeObject())) if (HasTypeMetadata(GVar)) return false; if (const auto *C = GV->getComdat()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 7a8946110785..61054e7ae46f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -1288,7 +1288,7 @@ void DevirtModule::tryICallBranchFunnel( M.getDataLayout().getProgramAddressSpace(), "branch_funnel", &M); } - JT->addAttribute(1, Attribute::Nest); + JT->addParamAttr(0, Attribute::Nest); std::vector<Value *> JTArgs; JTArgs.push_back(JT->arg_begin()); @@ -1361,10 +1361,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, M.getContext(), ArrayRef<Attribute>{Attribute::get( M.getContext(), Attribute::Nest)})); for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I) - NewArgAttrs.push_back(Attrs.getParamAttributes(I)); + NewArgAttrs.push_back(Attrs.getParamAttrs(I)); NewCS->setAttributes( - AttributeList::get(M.getContext(), Attrs.getFnAttributes(), - Attrs.getRetAttributes(), NewArgAttrs)); + AttributeList::get(M.getContext(), Attrs.getFnAttrs(), + Attrs.getRetAttrs(), NewArgAttrs)); CB.replaceAllUsesWith(NewCS); CB.eraseFromParent(); @@ -1786,10 +1786,8 @@ void DevirtModule::scanTypeTestUsers( // points to a member of the type identifier %md. Group calls by (type ID, // offset) pair (effectively the identity of the virtual function) and store // to CallSlots. - for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); - I != E;) { - auto CI = dyn_cast<CallInst>(I->getUser()); - ++I; + for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) { + auto *CI = dyn_cast<CallInst>(U.getUser()); if (!CI) continue; @@ -1858,11 +1856,8 @@ void DevirtModule::scanTypeTestUsers( void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); - for (auto I = TypeCheckedLoadFunc->use_begin(), - E = TypeCheckedLoadFunc->use_end(); - I != E;) { - auto CI = dyn_cast<CallInst>(I->getUser()); - ++I; + for (Use &U : llvm::make_early_inc_range(TypeCheckedLoadFunc->uses())) { + auto *CI = dyn_cast<CallInst>(U.getUser()); if (!CI) continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index d01a021bf3f4..eb1b8a29cfc5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -939,7 +939,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { // add (xor X, LowMaskC), C --> sub (LowMaskC + C), X if (C2->isMask()) { KnownBits LHSKnown = computeKnownBits(X, 0, &Add); - if ((*C2 | LHSKnown.Zero).isAllOnesValue()) + if ((*C2 | LHSKnown.Zero).isAllOnes()) return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C2 + *C), X); } @@ -963,7 +963,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { } } - if (C->isOneValue() && Op0->hasOneUse()) { + if (C->isOne() && Op0->hasOneUse()) { // add (sext i1 X), 1 --> zext (not X) // TODO: The smallest IR representation is (select X, 0, 1), and that 
would // not require the one-use check. But we need to remove a transform in @@ -1355,6 +1355,17 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (match(RHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(LHS))))) return BinaryOperator::CreateAdd(A, Builder.CreateShl(LHS, 1, "reass.add")); + { + // (A + C1) + (C2 - B) --> (A - B) + (C1 + C2) + Constant *C1, *C2; + if (match(&I, m_c_Add(m_Add(m_Value(A), m_ImmConstant(C1)), + m_Sub(m_ImmConstant(C2), m_Value(B)))) && + (LHS->hasOneUse() || RHS->hasOneUse())) { + Value *Sub = Builder.CreateSub(A, B); + return BinaryOperator::CreateAdd(Sub, ConstantExpr::getAdd(C1, C2)); + } + } + // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); @@ -1817,12 +1828,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { if (match(Op0, m_AllOnes())) return BinaryOperator::CreateNot(Op1); - // (~X) - (~Y) --> Y - X - Value *X, *Y; - if (match(Op0, m_Not(m_Value(X))) && match(Op1, m_Not(m_Value(Y)))) - return BinaryOperator::CreateSub(Y, X); - // (X + -1) - Y --> ~Y + X + Value *X, *Y; if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes())))) return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X); @@ -1843,6 +1850,17 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return BinaryOperator::CreateSub(X, Add); } + // (~X) - (~Y) --> Y - X + // This is placed after the other reassociations and explicitly excludes a + // sub-of-sub pattern to avoid infinite looping. + if (isFreeToInvert(Op0, Op0->hasOneUse()) && + isFreeToInvert(Op1, Op1->hasOneUse()) && + !match(Op0, m_Sub(m_ImmConstant(), m_Value()))) { + Value *NotOp0 = Builder.CreateNot(Op0); + Value *NotOp1 = Builder.CreateNot(Op1); + return BinaryOperator::CreateSub(NotOp1, NotOp0); + } + auto m_AddRdx = [](Value *&Vec) { return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec))); }; @@ -1892,7 +1910,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); - if ((*Op0C | RHSKnown.Zero).isAllOnesValue()) + if ((*Op0C | RHSKnown.Zero).isAllOnes()) return BinaryOperator::CreateXor(Op1, Op0); } @@ -2039,12 +2057,31 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return BinaryOperator::CreateAnd( Op0, Builder.CreateNot(Y, Y->getName() + ".not")); + // ~X - Min/Max(~X, Y) -> ~Min/Max(X, ~Y) - X + // ~X - Min/Max(Y, ~X) -> ~Min/Max(X, ~Y) - X + // Min/Max(~X, Y) - ~X -> X - ~Min/Max(X, ~Y) + // Min/Max(Y, ~X) - ~X -> X - ~Min/Max(X, ~Y) + // As long as Y is freely invertible, this will be neutral or a win. + // Note: We don't generate the inverse max/min, just create the 'not' of + // it and let other folds do the rest. + if (match(Op0, m_Not(m_Value(X))) && + match(Op1, m_c_MaxOrMin(m_Specific(Op0), m_Value(Y))) && + !Op0->hasNUsesOrMore(3) && isFreeToInvert(Y, Y->hasOneUse())) { + Value *Not = Builder.CreateNot(Op1); + return BinaryOperator::CreateSub(Not, X); + } + if (match(Op1, m_Not(m_Value(X))) && + match(Op0, m_c_MaxOrMin(m_Specific(Op1), m_Value(Y))) && + !Op1->hasNUsesOrMore(3) && isFreeToInvert(Y, Y->hasOneUse())) { + Value *Not = Builder.CreateNot(Op0); + return BinaryOperator::CreateSub(X, Not); + } + + // TODO: This is the same logic as above but handles the cmp-select idioms + // for min/max, so the use checks are increased to account for the + // extra instructions. 
If we canonicalize to intrinsics, this block + // can likely be removed. { - // ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A - // ~A - Min/Max(O, ~A) -> Max/Min(A, ~O) - A - // Min/Max(~A, O) - ~A -> A - Max/Min(A, ~O) - // Min/Max(O, ~A) - ~A -> A - Max/Min(A, ~O) - // So long as O here is freely invertible, this will be neutral or a win. Value *LHS, *RHS, *A; Value *NotA = Op0, *MinMax = Op1; SelectPatternFlavor SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor; @@ -2057,12 +2094,10 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { match(NotA, m_Not(m_Value(A))) && (NotA == LHS || NotA == RHS)) { if (NotA == LHS) std::swap(LHS, RHS); - // LHS is now O above and expected to have at least 2 uses (the min/max) - // NotA is epected to have 2 uses from the min/max and 1 from the sub. + // LHS is now Y above and expected to have at least 2 uses (the min/max) + // NotA is expected to have 2 uses from the min/max and 1 from the sub. if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && !NotA->hasNUsesOrMore(4)) { - // Note: We don't generate the inverse max/min, just create the not of - // it and let other folds do the rest. Value *Not = Builder.CreateNot(MinMax); if (NotA == Op0) return BinaryOperator::CreateSub(Not, A); @@ -2119,7 +2154,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { unsigned BitWidth = Ty->getScalarSizeInBits(); unsigned Cttz = AddC->countTrailingZeros(); APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); - if ((HighMask & *AndC).isNullValue()) + if ((HighMask & *AndC).isZero()) return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC))); } @@ -2133,6 +2168,19 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return replaceInstUsesWith( I, Builder.CreateIntrinsic(Intrinsic::umin, {I.getType()}, {Op0, Y})); + // umax(X, Op1) - Op1 --> usub.sat(X, Op1) + // TODO: The one-use restriction is not strictly necessary, but it may + // require improving other pattern matching and/or codegen. + if (match(Op0, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op1))))) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op1})); + + // Op0 - umax(X, Op0) --> 0 - usub.sat(X, Op0) + if (match(Op1, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op0))))) { + Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op0}); + return BinaryOperator::CreateNeg(USub); + } + // C - ctpop(X) => ctpop(~X) if C is bitwidth if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) && match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))) @@ -2173,8 +2221,8 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { // TODO: We could propagate nsz/ninf from fdiv alone? 
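Hedged aside on the usub.sat rewrite added above (umax(X, Op1) - Op1 --> usub.sat(X, Op1)): the fold rests on a small arithmetic identity, which the standalone snippet below brute-forces at 8 bits, using ordinary unsigned math in place of the llvm.umax / llvm.usub.sat intrinsics.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Saturating unsigned subtract, the scalar meaning of llvm.usub.sat.
static uint8_t USubSat(uint8_t A, uint8_t B) { return A > B ? A - B : 0; }

int main() {
  for (unsigned X = 0; X <= 255; ++X)
    for (unsigned Y = 0; Y <= 255; ++Y)
      assert(uint8_t(std::max(X, Y) - Y) == USubSat(uint8_t(X), uint8_t(Y)));
  return 0;
}

The mirrored form Op0 - umax(X, Op0) --> 0 - usub.sat(X, Op0) follows by negating both sides, which is why that case is emitted as CreateNeg of the saturating subtract.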
FastMathFlags FMF = I.getFastMathFlags(); FastMathFlags OpFMF = FNegOp->getFastMathFlags(); - FDiv->setHasNoSignedZeros(FMF.noSignedZeros() & OpFMF.noSignedZeros()); - FDiv->setHasNoInfs(FMF.noInfs() & OpFMF.noInfs()); + FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros()); + FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs()); return FDiv; } // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 120852c44474..06c9bf650f37 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -185,14 +185,15 @@ enum MaskedICmpType { /// satisfies. static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, ICmpInst::Predicate Pred) { - ConstantInt *ACst = dyn_cast<ConstantInt>(A); - ConstantInt *BCst = dyn_cast<ConstantInt>(B); - ConstantInt *CCst = dyn_cast<ConstantInt>(C); + const APInt *ConstA = nullptr, *ConstB = nullptr, *ConstC = nullptr; + match(A, m_APInt(ConstA)); + match(B, m_APInt(ConstB)); + match(C, m_APInt(ConstC)); bool IsEq = (Pred == ICmpInst::ICMP_EQ); - bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2()); - bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2()); + bool IsAPow2 = ConstA && ConstA->isPowerOf2(); + bool IsBPow2 = ConstB && ConstB->isPowerOf2(); unsigned MaskVal = 0; - if (CCst && CCst->isZero()) { + if (ConstC && ConstC->isZero()) { // if C is zero, then both A and B qualify as mask MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed) : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed)); @@ -211,7 +212,7 @@ static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, if (IsAPow2) MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed) : (Mask_AllZeros | AMask_Mixed)); - } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) { + } else if (ConstA && ConstC && ConstC->isSubsetOf(*ConstA)) { MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed); } @@ -221,7 +222,7 @@ static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, if (IsBPow2) MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed) : (Mask_AllZeros | BMask_Mixed)); - } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) { + } else if (ConstB && ConstC && ConstC->isSubsetOf(*ConstB)) { MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed); } @@ -269,9 +270,9 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { - // vectors are not (yet?) supported. Don't support pointers either. - if (!LHS->getOperand(0)->getType()->isIntegerTy() || - !RHS->getOperand(0)->getType()->isIntegerTy()) + // Don't allow pointers. Splat vectors are fine. + if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() || + !RHS->getOperand(0)->getType()->isIntOrIntVectorTy()) return None; // Here comes the tricky part: @@ -367,9 +368,9 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, } else { return None; } + + assert(Ok && "Failed to find AND on the right side of the RHS icmp."); } - if (!Ok) - return None; if (L11 == A) { B = L12; @@ -619,8 +620,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // Remaining cases assume at least that B and D are constant, and depend on // their actual values. 
This isn't strictly necessary, just a "handle the // easy cases for now" decision. - ConstantInt *BCst, *DCst; - if (!match(B, m_ConstantInt(BCst)) || !match(D, m_ConstantInt(DCst))) + const APInt *ConstB, *ConstD; + if (!match(B, m_APInt(ConstB)) || !match(D, m_APInt(ConstD))) return nullptr; if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { @@ -629,11 +630,10 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) // Only valid if one of the masks is a superset of the other (check "B&D" is // the same as either B or D). - APInt NewMask = BCst->getValue() & DCst->getValue(); - - if (NewMask == BCst->getValue()) + APInt NewMask = *ConstB & *ConstD; + if (NewMask == *ConstB) return LHS; - else if (NewMask == DCst->getValue()) + else if (NewMask == *ConstD) return RHS; } @@ -642,11 +642,10 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) // Only valid if one of the masks is a superset of the other (check "B|D" is // the same as either B or D). - APInt NewMask = BCst->getValue() | DCst->getValue(); - - if (NewMask == BCst->getValue()) + APInt NewMask = *ConstB | *ConstD; + if (NewMask == *ConstB) return LHS; - else if (NewMask == DCst->getValue()) + else if (NewMask == *ConstD) return RHS; } @@ -661,23 +660,21 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D, having a single bit set. - ConstantInt *CCst, *ECst; - if (!match(C, m_ConstantInt(CCst)) || !match(E, m_ConstantInt(ECst))) + const APInt *OldConstC, *OldConstE; + if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE))) return nullptr; - if (PredL != NewCC) - CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); - if (PredR != NewCC) - ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + + const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC; + const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE; // If there is a conflict, we should actually return a false for the // whole construct. - if (((BCst->getValue() & DCst->getValue()) & - (CCst->getValue() ^ ECst->getValue())).getBoolValue()) + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) return ConstantInt::get(LHS->getType(), !IsAnd); Value *NewOr1 = Builder.CreateOr(B, D); - Value *NewOr2 = ConstantExpr::getOr(CCst, ECst); Value *NewAnd = Builder.CreateAnd(A, NewOr1); + Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); return Builder.CreateICmp(NewCC, NewAnd, NewOr2); } @@ -777,20 +774,6 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, return Builder.CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2)); } - // Special case: get the ordering right when the values wrap around zero. - // Ie, we assumed the constants were unsigned when swapping earlier. - if (C1->isNullValue() && C2->isAllOnesValue()) - std::swap(C1, C2); - - if (*C1 == *C2 - 1) { - // (X == 13 || X == 14) --> X - 13 <=u 1 - // (X != 13 && X != 14) --> X - 13 >u 1 - // An 'add' is the canonical IR form, so favor that over a 'sub'. - Value *Add = Builder.CreateAdd(X, ConstantInt::get(X->getType(), -(*C1))); - auto NewPred = JoinedByAnd ? 
ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE; - return Builder.CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1)); - } - return nullptr; } @@ -923,7 +906,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, if (!tryToDecompose(OtherICmp, X0, UnsetBitsMask)) return nullptr; - assert(!UnsetBitsMask.isNullValue() && "empty mask makes no sense."); + assert(!UnsetBitsMask.isZero() && "empty mask makes no sense."); // Are they working on the same value? Value *X; @@ -1113,8 +1096,8 @@ static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) { /// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01 /// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01 /// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer. -static Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd, - InstCombiner::BuilderTy &Builder) { +Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool IsAnd) { if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse()) return nullptr; @@ -1202,6 +1185,51 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, return Builder.CreateBinOp(Logic.getOpcode(), Cmp0, SubstituteCmp); } +/// Fold (icmp Pred1 V1, C1) & (icmp Pred2 V2, C2) +/// or (icmp Pred1 V1, C1) | (icmp Pred2 V2, C2) +/// into a single comparison using range-based reasoning. +static Value *foldAndOrOfICmpsUsingRanges( + ICmpInst::Predicate Pred1, Value *V1, const APInt &C1, + ICmpInst::Predicate Pred2, Value *V2, const APInt &C2, + IRBuilderBase &Builder, bool IsAnd) { + // Look through add of a constant offset on V1, V2, or both operands. This + // allows us to interpret the V + C' < C'' range idiom into a proper range. + const APInt *Offset1 = nullptr, *Offset2 = nullptr; + if (V1 != V2) { + Value *X; + if (match(V1, m_Add(m_Value(X), m_APInt(Offset1)))) + V1 = X; + if (match(V2, m_Add(m_Value(X), m_APInt(Offset2)))) + V2 = X; + } + + if (V1 != V2) + return nullptr; + + ConstantRange CR1 = ConstantRange::makeExactICmpRegion(Pred1, C1); + if (Offset1) + CR1 = CR1.subtract(*Offset1); + + ConstantRange CR2 = ConstantRange::makeExactICmpRegion(Pred2, C2); + if (Offset2) + CR2 = CR2.subtract(*Offset2); + + Optional<ConstantRange> CR = + IsAnd ? CR1.exactIntersectWith(CR2) : CR1.exactUnionWith(CR2); + if (!CR) + return nullptr; + + CmpInst::Predicate NewPred; + APInt NewC, Offset; + CR->getEquivalentICmp(NewPred, NewC, Offset); + + Type *Ty = V1->getType(); + Value *NewV = V1; + if (Offset != 0) + NewV = Builder.CreateAdd(NewV, ConstantInt::get(Ty, Offset)); + return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC)); +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &And) { @@ -1262,170 +1290,64 @@ Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder)) return X; - if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/true, Builder)) + if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/true)) return X; // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). 
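A standalone sanity check, for orientation only, of the kind of equivalence the new foldAndOrOfICmpsUsingRanges() above produces: two compares of the same value against constants collapse into one compare of an offset value. The 13/14 case comes from a comment in one of the special-case folds this patch removes.

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned V = 0; V <= 255; ++V) {
    uint8_t X = uint8_t(V);
    bool OrOfCmps = (X == 13) || (X == 14);  // union of {13} and {14}
    bool RangeCmp = uint8_t(X - 13) <= 1;    // one range, one compare
    assert(OrOfCmps == RangeCmp);
  }
  return 0;
}

ConstantRange::getEquivalentICmp() computes exactly this offset-plus-predicate form, which is why the long hand-written switches over predicate pairs can be deleted.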
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); - ConstantInt *LHSC, *RHSC; - if (!match(LHS->getOperand(1), m_ConstantInt(LHSC)) || - !match(RHS->getOperand(1), m_ConstantInt(RHSC))) - return nullptr; - - if (LHSC == RHSC && PredL == PredR) { - // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) - // where C is a power of 2 or - // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) || - (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) { - Value *NewOr = Builder.CreateOr(LHS0, RHS0); - return Builder.CreateICmp(PredL, NewOr, LHSC); - } + // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) + // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. + if (PredL == ICmpInst::ICMP_EQ && match(LHS->getOperand(1), m_ZeroInt()) && + PredR == ICmpInst::ICMP_EQ && match(RHS->getOperand(1), m_ZeroInt()) && + LHS0->getType() == RHS0->getType()) { + Value *NewOr = Builder.CreateOr(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewOr, + Constant::getNullValue(NewOr->getType())); } + const APInt *LHSC, *RHSC; + if (!match(LHS->getOperand(1), m_APInt(LHSC)) || + !match(RHS->getOperand(1), m_APInt(RHSC))) + return nullptr; + // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 // where CMAX is the all ones value for the truncated type, // iff the lower bits of C2 and CA are zero. if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() && RHS->hasOneUse()) { Value *V; - ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr; + const APInt *AndC, *SmallC = nullptr, *BigC = nullptr; // (trunc x) == C1 & (and x, CA) == C2 // (and x, CA) == C2 & (trunc x) == C1 if (match(RHS0, m_Trunc(m_Value(V))) && - match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + match(LHS0, m_And(m_Specific(V), m_APInt(AndC)))) { SmallC = RHSC; BigC = LHSC; } else if (match(LHS0, m_Trunc(m_Value(V))) && - match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + match(RHS0, m_And(m_Specific(V), m_APInt(AndC)))) { SmallC = LHSC; BigC = RHSC; } if (SmallC && BigC) { - unsigned BigBitSize = BigC->getType()->getBitWidth(); - unsigned SmallBitSize = SmallC->getType()->getBitWidth(); + unsigned BigBitSize = BigC->getBitWidth(); + unsigned SmallBitSize = SmallC->getBitWidth(); // Check that the low bits are zero. APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); - if ((Low & AndC->getValue()).isNullValue() && - (Low & BigC->getValue()).isNullValue()) { - Value *NewAnd = Builder.CreateAnd(V, Low | AndC->getValue()); - APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue(); - Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N); + if ((Low & *AndC).isZero() && (Low & *BigC).isZero()) { + Value *NewAnd = Builder.CreateAnd(V, Low | *AndC); + APInt N = SmallC->zext(BigBitSize) | *BigC; + Value *NewVal = ConstantInt::get(NewAnd->getType(), N); return Builder.CreateICmp(PredL, NewAnd, NewVal); } } } - // From here on, we only handle: - // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. - if (LHS0 != RHS0) - return nullptr; - - // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. - if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || - PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || - PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || - PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) - return nullptr; - - // We can't fold (ugt x, C) & (sgt x, C2). 
- if (!predicatesFoldable(PredL, PredR)) - return nullptr; - - // Ensure that the larger constant is on the RHS. - bool ShouldSwap; - if (CmpInst::isSigned(PredL) || - (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) - ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); - else - ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); - - if (ShouldSwap) { - std::swap(LHS, RHS); - std::swap(LHSC, RHSC); - std::swap(PredL, PredR); - } - - // At this point, we know we have two icmp instructions - // comparing a value against two constants and and'ing the result - // together. Because of the above check, we know that we only have - // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know - // (from the icmp folding check above), that the two constants - // are not equal and that the larger constant is on the RHS - assert(LHSC != RHSC && "Compares not folded above?"); - - switch (PredL) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_ULT: - // (X != 13 & X u< 14) -> X < 13 - if (LHSC->getValue() == (RHSC->getValue() - 1)) - return Builder.CreateICmpULT(LHS0, LHSC); - if (LHSC->isZero()) // (X != 0 & X u< C) -> X-1 u< C-1 - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - false, true); - break; // (X != 13 & X u< 15) -> no change - case ICmpInst::ICMP_SLT: - // (X != 13 & X s< 14) -> X < 13 - if (LHSC->getValue() == (RHSC->getValue() - 1)) - return Builder.CreateICmpSLT(LHS0, LHSC); - // (X != INT_MIN & X s< C) -> X-(INT_MIN+1) u< (C-(INT_MIN+1)) - if (LHSC->isMinValue(true)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - true, true); - break; // (X != 13 & X s< 15) -> no change - case ICmpInst::ICMP_NE: - // Potential folds for this case should already be handled. 
- break; - } - break; - case ICmpInst::ICMP_UGT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: - // (X u> 13 & X != 14) -> X u> 14 - if (RHSC->getValue() == (LHSC->getValue() + 1)) - return Builder.CreateICmp(PredL, LHS0, RHSC); - // X u> C & X != UINT_MAX -> (X-(C+1)) u< UINT_MAX-(C+1) - if (RHSC->isMaxValue(false)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - false, true); - break; // (X u> 13 & X != 15) -> no change - case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) u< 1 - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - false, true); - } - break; - case ICmpInst::ICMP_SGT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: - // (X s> 13 & X != 14) -> X s> 14 - if (RHSC->getValue() == (LHSC->getValue() + 1)) - return Builder.CreateICmp(PredL, LHS0, RHSC); - // X s> C & X != INT_MAX -> (X-(C+1)) u< INT_MAX-(C+1) - if (RHSC->isMaxValue(true)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - true, true); - break; // (X s> 13 & X != 15) -> no change - case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) u< 1 - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true, - true); - } - break; - } - - return nullptr; + return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC, + Builder, /* IsAnd */ true); } Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, @@ -1496,15 +1418,15 @@ static Instruction *reassociateFCmps(BinaryOperator &BO, std::swap(Op0, Op1); // Match inner binop and the predicate for combining 2 NAN checks into 1. - BinaryOperator *BO1; + Value *BO10, *BO11; FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO; if (!match(Op0, m_FCmp(Pred, m_Value(X), m_AnyZeroFP())) || Pred != NanPred || - !match(Op1, m_BinOp(BO1)) || BO1->getOpcode() != Opcode) + !match(Op1, m_BinOp(Opcode, m_Value(BO10), m_Value(BO11)))) return nullptr; // The inner logic op must have a matching fcmp operand. - Value *BO10 = BO1->getOperand(0), *BO11 = BO1->getOperand(1), *Y; + Value *Y; if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) || Pred != NanPred || X->getType() != Y->getType()) std::swap(BO10, BO11); @@ -1524,27 +1446,42 @@ static Instruction *reassociateFCmps(BinaryOperator &BO, return BinaryOperator::Create(Opcode, NewFCmp, BO11); } -/// Match De Morgan's Laws: +/// Match variations of De Morgan's Laws: /// (~A & ~B) == (~(A | B)) /// (~A | ~B) == (~(A & B)) static Instruction *matchDeMorgansLaws(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { - auto Opcode = I.getOpcode(); + const Instruction::BinaryOps Opcode = I.getOpcode(); assert((Opcode == Instruction::And || Opcode == Instruction::Or) && "Trying to match De Morgan's Laws with something other than and/or"); // Flip the logic operation. - Opcode = (Opcode == Instruction::And) ? Instruction::Or : Instruction::And; + const Instruction::BinaryOps FlippedOpcode = + (Opcode == Instruction::And) ? 
Instruction::Or : Instruction::And; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *A, *B; - if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) && - match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) && + if (match(Op0, m_OneUse(m_Not(m_Value(A)))) && + match(Op1, m_OneUse(m_Not(m_Value(B)))) && !InstCombiner::isFreeToInvert(A, A->hasOneUse()) && !InstCombiner::isFreeToInvert(B, B->hasOneUse())) { - Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan"); + Value *AndOr = + Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan"); return BinaryOperator::CreateNot(AndOr); } + // The 'not' ops may require reassociation. + // (A & ~B) & ~C --> A & ~(B | C) + // (~B & A) & ~C --> A & ~(B | C) + // (A | ~B) | ~C --> A | ~(B & C) + // (~B | A) | ~C --> A | ~(B & C) + Value *C; + if (match(Op0, m_OneUse(m_c_BinOp(Opcode, m_Value(A), m_Not(m_Value(B))))) && + match(Op1, m_Not(m_Value(C)))) { + Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C); + return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO)); + } + return nullptr; } @@ -1778,6 +1715,72 @@ Instruction *InstCombinerImpl::narrowMaskedBinOp(BinaryOperator &And) { return new ZExtInst(Builder.CreateAnd(NewBO, X), Ty); } +/// Try folding relatively complex patterns for both And and Or operations +/// with all And and Or swapped. +static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + const Instruction::BinaryOps Opcode = I.getOpcode(); + assert(Opcode == Instruction::And || Opcode == Instruction::Or); + + // Flip the logic operation. + const Instruction::BinaryOps FlippedOpcode = + (Opcode == Instruction::And) ? Instruction::Or : Instruction::And; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *A, *B, *C; + + // (~(A | B) & C) | ... --> ... + // (~(A & B) | C) & ... --> ... + // TODO: One use checks are conservative. We just need to check that a total + // number of multiple used values does not exceed reduction + // in operations. + if (match(Op0, m_c_BinOp(FlippedOpcode, + m_Not(m_BinOp(Opcode, m_Value(A), m_Value(B))), + m_Value(C)))) { + // (~(A | B) & C) | (~(A | C) & B) --> (B ^ C) & ~A + // (~(A & B) | C) & (~(A & C) | B) --> ~((B ^ C) & A) + if (match(Op1, + m_OneUse(m_c_BinOp(FlippedOpcode, + m_OneUse(m_Not(m_c_BinOp(Opcode, m_Specific(A), + m_Specific(C)))), + m_Specific(B))))) { + Value *Xor = Builder.CreateXor(B, C); + return (Opcode == Instruction::Or) + ? BinaryOperator::CreateAnd(Xor, Builder.CreateNot(A)) + : BinaryOperator::CreateNot(Builder.CreateAnd(Xor, A)); + } + + // (~(A | B) & C) | (~(B | C) & A) --> (A ^ C) & ~B + // (~(A & B) | C) & (~(B & C) | A) --> ~((A ^ C) & B) + if (match(Op1, + m_OneUse(m_c_BinOp(FlippedOpcode, + m_OneUse(m_Not(m_c_BinOp(Opcode, m_Specific(B), + m_Specific(C)))), + m_Specific(A))))) { + Value *Xor = Builder.CreateXor(A, C); + return (Opcode == Instruction::Or) + ? 
BinaryOperator::CreateAnd(Xor, Builder.CreateNot(B)) + : BinaryOperator::CreateNot(Builder.CreateAnd(Xor, B)); + } + + // (~(A | B) & C) | ~(A | C) --> ~((B & C) | A) + // (~(A & B) | C) & ~(A & C) --> ~((B | C) & A) + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(A), m_Specific(C))))))) + return BinaryOperator::CreateNot(Builder.CreateBinOp( + Opcode, Builder.CreateBinOp(FlippedOpcode, B, C), A)); + + // (~(A | B) & C) | ~(B | C) --> ~((A & C) | B) + // (~(A & B) | C) & ~(B & C) --> ~((A | C) & B) + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(B), m_Specific(C))))))) + return BinaryOperator::CreateNot(Builder.CreateBinOp( + Opcode, Builder.CreateBinOp(FlippedOpcode, A, C), B)); + } + + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -1803,6 +1806,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (Instruction *Xor = foldAndToXor(I, Builder)) return Xor; + if (Instruction *X = foldComplexAndOrPatterns(I, Builder)) + return X; + // (A|B)&(A|C) -> A|(B&C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -1883,7 +1889,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // (X + AddC) & LowMaskC --> X & LowMaskC unsigned Ctlz = C->countLeadingZeros(); APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz)); - if ((*AddC & LowMask).isNullValue()) + if ((*AddC & LowMask).isZero()) return BinaryOperator::CreateAnd(X, Op1); // If we are masking the result of the add down to exactly one bit and @@ -1896,44 +1902,37 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateXor(NewAnd, Op1); } } - } - ConstantInt *AndRHS; - if (match(Op1, m_ConstantInt(AndRHS))) { - const APInt &AndRHSMask = AndRHS->getValue(); - - // Optimize a variety of ((val OP C1) & C2) combinations... - if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth - // of X and OP behaves well when given trunc(C1) and X. - // TODO: Do this for vectors by using m_APInt instead of m_ConstantInt. - switch (Op0I->getOpcode()) { - default: - break; + // ((C1 OP zext(X)) & C2) -> zext((C1 OP X) & C2) if C2 fits in the + // bitwidth of X and OP behaves well when given trunc(C1) and X. + auto isSuitableBinOpcode = [](BinaryOperator *B) { + switch (B->getOpcode()) { case Instruction::Xor: case Instruction::Or: case Instruction::Mul: case Instruction::Add: case Instruction::Sub: - Value *X; - ConstantInt *C1; - // TODO: The one use restrictions could be relaxed a little if the AND - // is going to be removed. 
- if (match(Op0I, m_OneUse(m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), - m_ConstantInt(C1))))) { - if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { - auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); - Value *BinOp; - Value *Op0LHS = Op0I->getOperand(0); - if (isa<ZExtInst>(Op0LHS)) - BinOp = Builder.CreateBinOp(Op0I->getOpcode(), X, TruncC1); - else - BinOp = Builder.CreateBinOp(Op0I->getOpcode(), TruncC1, X); - auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType()); - auto *And = Builder.CreateAnd(BinOp, TruncC2); - return new ZExtInst(And, Ty); - } - } + return true; + default: + return false; + } + }; + BinaryOperator *BO; + if (match(Op0, m_OneUse(m_BinOp(BO))) && isSuitableBinOpcode(BO)) { + Value *X; + const APInt *C1; + // TODO: The one-use restrictions could be relaxed a little if the AND + // is going to be removed. + if (match(BO, m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), m_APInt(C1))) && + C->isIntN(X->getType()->getScalarSizeInBits())) { + unsigned XWidth = X->getType()->getScalarSizeInBits(); + Constant *TruncC1 = ConstantInt::get(X->getType(), C1->trunc(XWidth)); + Value *BinOp = isa<ZExtInst>(BO->getOperand(0)) + ? Builder.CreateBinOp(BO->getOpcode(), X, TruncC1) + : Builder.CreateBinOp(BO->getOpcode(), TruncC1, X); + Constant *TruncC = ConstantInt::get(X->getType(), C->trunc(XWidth)); + Value *And = Builder.CreateAnd(BinOp, TruncC); + return new ZExtInst(And, Ty); } } } @@ -2071,13 +2070,13 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, Op0, Constant::getNullValue(Ty)); - // and(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? X : 0. - if (match(&I, m_c_And(m_OneUse(m_AShr( - m_NSWSub(m_Value(Y), m_Value(X)), - m_SpecificInt(Ty->getScalarSizeInBits() - 1))), - m_Deferred(X)))) { - Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); - return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty)); + // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 + unsigned FullShift = Ty->getScalarSizeInBits() - 1; + if (match(&I, m_c_And(m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))), + m_Value(Y)))) { + Constant *Zero = ConstantInt::getNullValue(Ty); + Value *Cmp = Builder.CreateICmpSLT(X, Zero, "isneg"); + return SelectInst::Create(Cmp, Y, Zero); } // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions @@ -2284,28 +2283,38 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of /// B, it can be used as the condition operand of a select instruction. Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { - // Step 1: We may have peeked through bitcasts in the caller. + // We may have peeked through bitcasts in the caller. // Exit immediately if we don't have (vector) integer types. Type *Ty = A->getType(); if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy()) return nullptr; - // Step 2: We need 0 or all-1's bitmasks. - if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits()) - return nullptr; - - // Step 3: If B is the 'not' value of A, we have our answer. - if (match(A, m_Not(m_Specific(B)))) { + // If A is the 'not' operand of B and has enough signbits, we have our answer. + if (match(B, m_Not(m_Specific(A)))) { // If these are scalars or vectors of i1, A can be used directly. 
if (Ty->isIntOrIntVectorTy(1)) return A; - return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty)); + + // If we look through a vector bitcast, the caller will bitcast the operands + // to match the condition's number of bits (N x i1). + // To make this poison-safe, disallow bitcast from wide element to narrow + // element. That could allow poison in lanes where it was not present in the + // original code. + A = peekThroughBitcast(A); + if (A->getType()->isIntOrIntVectorTy()) { + unsigned NumSignBits = ComputeNumSignBits(A); + if (NumSignBits == A->getType()->getScalarSizeInBits() && + NumSignBits <= Ty->getScalarSizeInBits()) + return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType())); + } + return nullptr; } // If both operands are constants, see if the constants are inverse bitmasks. Constant *AConst, *BConst; if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) - if (AConst == ConstantExpr::getNot(BConst)) + if (AConst == ConstantExpr::getNot(BConst) && + ComputeNumSignBits(A) == Ty->getScalarSizeInBits()) return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty)); // Look for more complex patterns. The 'not' op may be hidden behind various @@ -2349,10 +2358,17 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, B = peekThroughBitcast(B, true); if (Value *Cond = getSelectCondition(A, B)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) + // If this is a vector, we may need to cast to match the condition's length. // The bitcasts will either all exist or all not exist. The builder will // not create unnecessary casts if the types already match. - Value *BitcastC = Builder.CreateBitCast(C, A->getType()); - Value *BitcastD = Builder.CreateBitCast(D, A->getType()); + Type *SelTy = A->getType(); + if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) { + unsigned Elts = VecTy->getElementCount().getKnownMinValue(); + Type *EltTy = Builder.getIntNTy(SelTy->getPrimitiveSizeInBits() / Elts); + SelTy = VectorType::get(EltTy, VecTy->getElementCount()); + } + Value *BitcastC = Builder.CreateBitCast(C, SelTy); + Value *BitcastD = Builder.CreateBitCast(D, SelTy); Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); return Builder.CreateBitCast(Select, OrigType); } @@ -2374,8 +2390,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); - auto *LHSC = dyn_cast<ConstantInt>(LHS1); - auto *RHSC = dyn_cast<ConstantInt>(RHS1); + const APInt *LHSC = nullptr, *RHSC = nullptr; + match(LHS1, m_APInt(LHSC)); + match(RHS1, m_APInt(RHSC)); // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) @@ -2389,40 +2406,41 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // This implies all values in the two ranges differ by exactly one bit. 
if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && - LHSC->getType() == RHSC->getType() && - LHSC->getValue() == (RHSC->getValue())) { + LHSC->getBitWidth() == RHSC->getBitWidth() && *LHSC == *RHSC) { Value *AddOpnd; - ConstantInt *LAddC, *RAddC; - if (match(LHS0, m_Add(m_Value(AddOpnd), m_ConstantInt(LAddC))) && - match(RHS0, m_Add(m_Specific(AddOpnd), m_ConstantInt(RAddC))) && - LAddC->getValue().ugt(LHSC->getValue()) && - RAddC->getValue().ugt(LHSC->getValue())) { + const APInt *LAddC, *RAddC; + if (match(LHS0, m_Add(m_Value(AddOpnd), m_APInt(LAddC))) && + match(RHS0, m_Add(m_Specific(AddOpnd), m_APInt(RAddC))) && + LAddC->ugt(*LHSC) && RAddC->ugt(*LHSC)) { - APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); + APInt DiffC = *LAddC ^ *RAddC; if (DiffC.isPowerOf2()) { - ConstantInt *MaxAddC = nullptr; - if (LAddC->getValue().ult(RAddC->getValue())) + const APInt *MaxAddC = nullptr; + if (LAddC->ult(*RAddC)) MaxAddC = RAddC; else MaxAddC = LAddC; - APInt RRangeLow = -RAddC->getValue(); - APInt RRangeHigh = RRangeLow + LHSC->getValue(); - APInt LRangeLow = -LAddC->getValue(); - APInt LRangeHigh = LRangeLow + LHSC->getValue(); + APInt RRangeLow = -*RAddC; + APInt RRangeHigh = RRangeLow + *LHSC; + APInt LRangeLow = -*LAddC; + APInt LRangeHigh = LRangeLow + *LHSC; APInt LowRangeDiff = RRangeLow ^ LRangeLow; APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow : RRangeLow - LRangeLow; if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(LHSC->getValue())) { - Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); + RangeDiff.ugt(*LHSC)) { + Type *Ty = AddOpnd->getType(); + Value *MaskC = ConstantInt::get(Ty, ~DiffC); Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC); - Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC); - return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC); + Value *NewAdd = Builder.CreateAdd(NewAnd, + ConstantInt::get(Ty, *MaxAddC)); + return Builder.CreateICmp(LHS->getPredicate(), NewAdd, + ConstantInt::get(Ty, *LHSC)); } } } @@ -2496,14 +2514,13 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder)) return X; - if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/false, Builder)) + if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/false)) return X; // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) - // TODO: Remove this when foldLogOpOfMaskedICmps can handle vectors. - if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_Zero()) && - PredR == ICmpInst::ICMP_NE && match(RHS1, m_Zero()) && - LHS0->getType()->isIntOrIntVectorTy() && + // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. + if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_ZeroInt()) && + PredR == ICmpInst::ICMP_NE && match(RHS1, m_ZeroInt()) && LHS0->getType() == RHS0->getType()) { Value *NewOr = Builder.CreateOr(LHS0, RHS0); return Builder.CreateICmp(PredL, NewOr, @@ -2514,114 +2531,8 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (!LHSC || !RHSC) return nullptr; - // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) - // iff C2 + CA == C1. 
- if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) { - ConstantInt *AddC; - if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC)))) - if (RHSC->getValue() + AddC->getValue() == LHSC->getValue()) - return Builder.CreateICmpULE(LHS0, LHSC); - } - - // From here on, we only handle: - // (icmp1 A, C1) | (icmp2 A, C2) --> something simpler. - if (LHS0 != RHS0) - return nullptr; - - // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. - if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || - PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || - PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || - PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) - return nullptr; - - // We can't fold (ugt x, C) | (sgt x, C2). - if (!predicatesFoldable(PredL, PredR)) - return nullptr; - - // Ensure that the larger constant is on the RHS. - bool ShouldSwap; - if (CmpInst::isSigned(PredL) || - (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) - ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); - else - ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); - - if (ShouldSwap) { - std::swap(LHS, RHS); - std::swap(LHSC, RHSC); - std::swap(PredL, PredR); - } - - // At this point, we know we have two icmp instructions - // comparing a value against two constants and or'ing the result - // together. Because of the above check, we know that we only have - // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the - // icmp folding check above), that the two constants are not - // equal. - assert(LHSC != RHSC && "Compares not folded above?"); - - switch (PredL) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: - // Potential folds for this case should already be handled. 
- break; - case ICmpInst::ICMP_UGT: - // (X == 0 || X u> C) -> (X-1) u>= C - if (LHSC->isMinValue(false)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1, - false, false); - // (X == 13 | X u> 14) -> no change - break; - case ICmpInst::ICMP_SGT: - // (X == INT_MIN || X s> C) -> (X-(INT_MIN+1)) u>= C-INT_MIN - if (LHSC->isMinValue(true)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1, - true, false); - // (X == 13 | X s> 14) -> no change - break; - } - break; - case ICmpInst::ICMP_ULT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change - // (X u< C || X == UINT_MAX) => (X-C) u>= UINT_MAX-C - if (RHSC->isMaxValue(false)) - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(), - false, false); - break; - case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 - assert(!RHSC->isMaxValue(false) && "Missed icmp simplification"); - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, - false, false); - } - break; - case ICmpInst::ICMP_SLT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: - // (X s< C || X == INT_MAX) => (X-C) u>= INT_MAX-C - if (RHSC->isMaxValue(true)) - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(), - true, false); - // (X s< 13 | X == 14) -> no change - break; - case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) u> 2 - assert(!RHSC->isMaxValue(true) && "Missed icmp simplification"); - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true, - false); - } - break; - } - return nullptr; + return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC, + Builder, /* IsAnd */ false); } // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches @@ -2647,6 +2558,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Xor = foldOrToXor(I, Builder)) return Xor; + if (Instruction *X = foldComplexAndOrPatterns(I, Builder)) + return X; + // (A&B)|(A&C) -> A&(B|C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -2684,69 +2598,63 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { Value *X, *Y; const APInt *CV; if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && - !CV->isAllOnesValue() && MaskedValueIsZero(Y, *CV, 0, &I)) { + !CV->isAllOnes() && MaskedValueIsZero(Y, *CV, 0, &I)) { // (X ^ C) | Y -> (X | Y) ^ C iff Y & C == 0 // The check for a 'not' op is for efficiency (if Y is known zero --> ~X). Value *Or = Builder.CreateOr(X, Y); return BinaryOperator::CreateXor(Or, ConstantInt::get(I.getType(), *CV)); } - // (A & C)|(B & D) + // (A & C) | (B & D) Value *A, *B, *C, *D; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { - // (A & C1)|(B & C2) - ConstantInt *C1, *C2; - if (match(C, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2))) { - Value *V1 = nullptr, *V2 = nullptr; - if ((C1->getValue() & C2->getValue()).isNullValue()) { - // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) - // iff (C1&C2) == 0 and (N&~C1) == 0 - if (match(A, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == B && - MaskedValueIsZero(V2, ~C1->getValue(), 0, &I)) || // (V|N) - (V2 == B && - MaskedValueIsZero(V1, ~C1->getValue(), 0, &I)))) // (N|V) - return BinaryOperator::CreateAnd(A, - Builder.getInt(C1->getValue()|C2->getValue())); - // Or commutes, try both ways. 
- if (match(B, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == A && - MaskedValueIsZero(V2, ~C2->getValue(), 0, &I)) || // (V|N) - (V2 == A && - MaskedValueIsZero(V1, ~C2->getValue(), 0, &I)))) // (N|V) - return BinaryOperator::CreateAnd(B, - Builder.getInt(C1->getValue()|C2->getValue())); - - // ((V|C3)&C1) | ((V|C4)&C2) --> (V|C3|C4)&(C1|C2) - // iff (C1&C2) == 0 and (C3&~C1) == 0 and (C4&~C2) == 0. - ConstantInt *C3 = nullptr, *C4 = nullptr; - if (match(A, m_Or(m_Value(V1), m_ConstantInt(C3))) && - (C3->getValue() & ~C1->getValue()).isNullValue() && - match(B, m_Or(m_Specific(V1), m_ConstantInt(C4))) && - (C4->getValue() & ~C2->getValue()).isNullValue()) { - V2 = Builder.CreateOr(V1, ConstantExpr::getOr(C3, C4), "bitfield"); - return BinaryOperator::CreateAnd(V2, - Builder.getInt(C1->getValue()|C2->getValue())); - } - } - if (C1->getValue() == ~C2->getValue()) { - Value *X; - - // ((X|B)&C1)|(B&C2) -> (X&C1) | B iff C1 == ~C2 + // (A & C0) | (B & C1) + const APInt *C0, *C1; + if (match(C, m_APInt(C0)) && match(D, m_APInt(C1))) { + Value *X; + if (*C0 == ~*C1) { + // ((X | B) & MaskC) | (B & ~MaskC) -> (X & MaskC) | B if (match(A, m_c_Or(m_Value(X), m_Specific(B)))) - return BinaryOperator::CreateOr(Builder.CreateAnd(X, C1), B); - // (A&C2)|((X|A)&C1) -> (X&C2) | A iff C1 == ~C2 + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *C0), B); + // (A & MaskC) | ((X | A) & ~MaskC) -> (X & ~MaskC) | A if (match(B, m_c_Or(m_Specific(A), m_Value(X)))) - return BinaryOperator::CreateOr(Builder.CreateAnd(X, C2), A); + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *C1), A); - // ((X^B)&C1)|(B&C2) -> (X&C1) ^ B iff C1 == ~C2 + // ((X ^ B) & MaskC) | (B & ~MaskC) -> (X & MaskC) ^ B if (match(A, m_c_Xor(m_Value(X), m_Specific(B)))) - return BinaryOperator::CreateXor(Builder.CreateAnd(X, C1), B); - // (A&C2)|((X^A)&C1) -> (X&C2) ^ A iff C1 == ~C2 + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *C0), B); + // (A & MaskC) | ((X ^ A) & ~MaskC) -> (X & ~MaskC) ^ A if (match(B, m_c_Xor(m_Specific(A), m_Value(X)))) - return BinaryOperator::CreateXor(Builder.CreateAnd(X, C2), A); + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *C1), A); + } + + if ((*C0 & *C1).isZero()) { + // ((X | B) & C0) | (B & C1) --> (X | B) & (C0 | C1) + // iff (C0 & C1) == 0 and (X & ~C0) == 0 + if (match(A, m_c_Or(m_Value(X), m_Specific(B))) && + MaskedValueIsZero(X, ~*C0, 0, &I)) { + Constant *C01 = ConstantInt::get(I.getType(), *C0 | *C1); + return BinaryOperator::CreateAnd(A, C01); + } + // (A & C0) | ((X | A) & C1) --> (X | A) & (C0 | C1) + // iff (C0 & C1) == 0 and (X & ~C1) == 0 + if (match(B, m_c_Or(m_Value(X), m_Specific(A))) && + MaskedValueIsZero(X, ~*C1, 0, &I)) { + Constant *C01 = ConstantInt::get(I.getType(), *C0 | *C1); + return BinaryOperator::CreateAnd(B, C01); + } + // ((X | C2) & C0) | ((X | C3) & C1) --> (X | C2 | C3) & (C0 | C1) + // iff (C0 & C1) == 0 and (C2 & ~C0) == 0 and (C3 & ~C1) == 0. + const APInt *C2, *C3; + if (match(A, m_Or(m_Value(X), m_APInt(C2))) && + match(B, m_Or(m_Specific(X), m_APInt(C3))) && + (*C2 & ~*C0).isZero() && (*C3 & ~*C1).isZero()) { + Value *Or = Builder.CreateOr(X, *C2 | *C3, "bitfield"); + Constant *C01 = ConstantInt::get(I.getType(), *C0 | *C1); + return BinaryOperator::CreateAnd(Or, C01); + } } } @@ -2801,6 +2709,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // A | ( A ^ B) -> A | B // A | (~A ^ B) -> A | ~B // (A & B) | (A ^ B) + // ~A | (A ^ B) -> ~(A & B) + // The swap above should always make Op0 the 'not' for the last case. 
if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) { if (Op0 == A || Op0 == B) return BinaryOperator::CreateOr(A, B); @@ -2809,6 +2719,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { match(Op0, m_And(m_Specific(B), m_Specific(A)))) return BinaryOperator::CreateOr(A, B); + if ((Op0->hasOneUse() || Op1->hasOneUse()) && + (match(Op0, m_Not(m_Specific(A))) || match(Op0, m_Not(m_Specific(B))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) { Value *Not = Builder.CreateNot(B, B->getName() + ".not"); return BinaryOperator::CreateOr(Not, Op0); @@ -3275,71 +3189,45 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) { return true; } -// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches -// here. We should standardize that construct where it is needed or choose some -// other way to ensure that commutated variants of patterns are not missed. -Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { - if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); - - if (SimplifyAssociativeOrCommutative(I)) - return &I; - - if (Instruction *X = foldVectorBinop(I)) - return X; - - if (Instruction *NewXor = foldXorToXor(I, Builder)) - return NewXor; - - // (A&B)^(A&C) -> A&(B^C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) - return replaceInstUsesWith(I, V); - - // See if we can simplify any instructions used by the instruction whose sole - // purpose is to compute bits we don't care about. - if (SimplifyDemandedInstructionBits(I)) - return &I; - - if (Value *V = SimplifyBSwap(I, Builder)) - return replaceInstUsesWith(I, V); - - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - Type *Ty = I.getType(); - - // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) - // This it a special case in haveNoCommonBitsSet, but the computeKnownBits - // calls in there are unnecessary as SimplifyDemandedInstructionBits should - // have already taken care of those cases. - Value *M; - if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()), - m_c_And(m_Deferred(M), m_Value())))) - return BinaryOperator::CreateOr(Op0, Op1); +Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { + Value *NotOp; + if (!match(&I, m_Not(m_Value(NotOp)))) + return nullptr; // Apply DeMorgan's Law for 'nand' / 'nor' logic with an inverted operand. - Value *X, *Y; - // We must eliminate the and/or (one-use) for these transforms to not increase // the instruction count. + // // ~(~X & Y) --> (X | ~Y) // ~(Y & ~X) --> (X | ~Y) - if (match(&I, m_Not(m_OneUse(m_c_And(m_Not(m_Value(X)), m_Value(Y)))))) { + // + // Note: The logical matches do not check for the commuted patterns because + // those are handled via SimplifySelectsFeedingBinaryOp(). 
+ Type *Ty = I.getType(); + Value *X, *Y; + if (match(NotOp, m_OneUse(m_c_And(m_Not(m_Value(X)), m_Value(Y))))) { Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); return BinaryOperator::CreateOr(X, NotY); } + if (match(NotOp, m_OneUse(m_LogicalAnd(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return SelectInst::Create(X, ConstantInt::getTrue(Ty), NotY); + } + // ~(~X | Y) --> (X & ~Y) // ~(Y | ~X) --> (X & ~Y) - if (match(&I, m_Not(m_OneUse(m_c_Or(m_Not(m_Value(X)), m_Value(Y)))))) { + if (match(NotOp, m_OneUse(m_c_Or(m_Not(m_Value(X)), m_Value(Y))))) { Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); return BinaryOperator::CreateAnd(X, NotY); } - - if (Instruction *Xor = visitMaskedMerge(I, Builder)) - return Xor; + if (match(NotOp, m_OneUse(m_LogicalOr(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return SelectInst::Create(X, NotY, ConstantInt::getFalse(Ty)); + } // Is this a 'not' (~) fed by a binary operator? BinaryOperator *NotVal; - if (match(&I, m_Not(m_BinOp(NotVal)))) { + if (match(NotOp, m_BinOp(NotVal))) { if (NotVal->getOpcode() == Instruction::And || NotVal->getOpcode() == Instruction::Or) { // Apply DeMorgan's Law when inverts are free: @@ -3411,9 +3299,164 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { NotVal); } - // Use DeMorgan and reassociation to eliminate a 'not' op. + // not (cmp A, B) = !cmp A, B + CmpInst::Predicate Pred; + if (match(NotOp, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) { + cast<CmpInst>(NotOp)->setPredicate(CmpInst::getInversePredicate(Pred)); + return replaceInstUsesWith(I, NotOp); + } + + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: + // ~min(~X, ~Y) --> max(X, Y) + // ~max(~X, Y) --> min(X, ~Y) + auto *II = dyn_cast<IntrinsicInst>(NotOp); + if (II && II->hasOneUse()) { + if (match(NotOp, m_MaxOrMin(m_Value(X), m_Value(Y))) && + isFreeToInvert(X, X->hasOneUse()) && + isFreeToInvert(Y, Y->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *NotX = Builder.CreateNot(X); + Value *NotY = Builder.CreateNot(Y); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, NotX, NotY); + return replaceInstUsesWith(I, InvMaxMin); + } + if (match(NotOp, m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y)))) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *NotY = Builder.CreateNot(Y); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY); + return replaceInstUsesWith(I, InvMaxMin); + } + } + + // TODO: Remove folds if we canonicalize to intrinsics (see above). + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: + // + // %notx = xor i32 %x, -1 + // %cmp1 = icmp sgt i32 %notx, %y + // %smax = select i1 %cmp1, i32 %notx, i32 %y + // %res = xor i32 %smax, -1 + // => + // %noty = xor i32 %y, -1 + // %cmp2 = icmp slt %x, %noty + // %res = select i1 %cmp2, i32 %x, i32 %noty + // + // Same is applicable for smin/umax/umin. + if (NotOp->hasOneUse()) { + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(NotOp, LHS, RHS).Flavor; + if (SelectPatternResult::isMinOrMax(SPF)) { + // It's possible we get here before the not has been simplified, so make + // sure the input to the not isn't freely invertible. 
+ if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) { + Value *NotY = Builder.CreateNot(RHS); + return SelectInst::Create( + Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); + } + + // It's possible we get here before the not has been simplified, so make + // sure the input to the not isn't freely invertible. + if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) { + Value *NotX = Builder.CreateNot(LHS); + return SelectInst::Create( + Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y); + } + + // If both sides are freely invertible, then we can get rid of the xor + // completely. + if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && + isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) { + Value *NotLHS = Builder.CreateNot(LHS); + Value *NotRHS = Builder.CreateNot(RHS); + return SelectInst::Create( + Builder.CreateICmp(getInverseMinMaxPred(SPF), NotLHS, NotRHS), + NotLHS, NotRHS); + } + } + + // Pull 'not' into operands of select if both operands are one-use compares + // or one is one-use compare and the other one is a constant. + // Inverting the predicates eliminates the 'not' operation. + // Example: + // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> + // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?) + // not (select ?, (cmp TPred, ?, ?), true --> + // select ?, (cmp InvTPred, ?, ?), false + if (auto *Sel = dyn_cast<SelectInst>(NotOp)) { + Value *TV = Sel->getTrueValue(); + Value *FV = Sel->getFalseValue(); + auto *CmpT = dyn_cast<CmpInst>(TV); + auto *CmpF = dyn_cast<CmpInst>(FV); + bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa<Constant>(TV); + bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa<Constant>(FV); + if (InvertibleT && InvertibleF) { + if (CmpT) + CmpT->setPredicate(CmpT->getInversePredicate()); + else + Sel->setTrueValue(ConstantExpr::getNot(cast<Constant>(TV))); + if (CmpF) + CmpF->setPredicate(CmpF->getInversePredicate()); + else + Sel->setFalseValue(ConstantExpr::getNot(cast<Constant>(FV))); + return replaceInstUsesWith(I, Sel); + } + } + } + + if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) + return NewXor; + + return nullptr; +} + +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. +Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { + if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldVectorBinop(I)) + return X; + + if (Instruction *NewXor = foldXorToXor(I, Builder)) + return NewXor; + + // (A&B)^(A&C) -> A&(B^C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + if (SimplifyDemandedInstructionBits(I)) + return &I; + + if (Value *V = SimplifyBSwap(I, Builder)) + return replaceInstUsesWith(I, V); + + if (Instruction *R = foldNot(I)) + return R; + + // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) + // This it a special case in haveNoCommonBitsSet, but the computeKnownBits + // calls in there are unnecessary as SimplifyDemandedInstructionBits should + // have already taken care of those cases. 
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *M; + if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()), + m_c_And(m_Deferred(M), m_Value())))) + return BinaryOperator::CreateOr(Op0, Op1); + + if (Instruction *Xor = visitMaskedMerge(I, Builder)) + return Xor; + + Value *X, *Y; Constant *C1; if (match(Op1, m_Constant(C1))) { + // Use DeMorgan and reassociation to eliminate a 'not' op. Constant *C2; if (match(Op0, m_OneUse(m_Or(m_Not(m_Value(X)), m_Constant(C2))))) { // (~X | C2) ^ C1 --> ((X & ~C2) ^ -1) ^ C1 --> (X & ~C2) ^ ~C1 @@ -3425,15 +3468,24 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { Value *Or = Builder.CreateOr(X, ConstantExpr::getNot(C2)); return BinaryOperator::CreateXor(Or, ConstantExpr::getNot(C1)); } - } - // not (cmp A, B) = !cmp A, B - CmpInst::Predicate Pred; - if (match(&I, m_Not(m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))))) { - cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred)); - return replaceInstUsesWith(I, Op0); + // Convert xor ([trunc] (ashr X, BW-1)), C => + // select(X >s -1, C, ~C) + // The ashr creates "AllZeroOrAllOne's", which then optionally inverses the + // constant depending on whether this input is less than 0. + const APInt *CA; + if (match(Op0, m_OneUse(m_TruncOrSelf( + m_AShr(m_Value(X), m_APIntAllowUndef(CA))))) && + *CA == X->getType()->getScalarSizeInBits() - 1 && + !match(C1, m_AllOnes())) { + assert(!C1->isZeroValue() && "Unexpected xor with 0"); + Value *ICmp = + Builder.CreateICmpSGT(X, Constant::getAllOnesValue(X->getType())); + return SelectInst::Create(ICmp, Op1, Builder.CreateNot(Op1)); + } } + Type *Ty = I.getType(); { const APInt *RHSC; if (match(Op1, m_APInt(RHSC))) { @@ -3456,13 +3508,13 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { // canonicalize to a 'not' before the shift to help SCEV and codegen: // (X << C) ^ RHSC --> ~X << C if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_APInt(C)))) && - *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).shl(*C)) { + *RHSC == APInt::getAllOnes(Ty->getScalarSizeInBits()).shl(*C)) { Value *NotX = Builder.CreateNot(X); return BinaryOperator::CreateShl(NotX, ConstantInt::get(Ty, *C)); } // (X >>u C) ^ RHSC --> ~X >>u C if (match(Op0, m_OneUse(m_LShr(m_Value(X), m_APInt(C)))) && - *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).lshr(*C)) { + *RHSC == APInt::getAllOnes(Ty->getScalarSizeInBits()).lshr(*C)) { Value *NotX = Builder.CreateNot(X); return BinaryOperator::CreateLShr(NotX, ConstantInt::get(Ty, *C)); } @@ -3572,101 +3624,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) return CastedXor; - // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: - // ~min(~X, ~Y) --> max(X, Y) - // ~max(~X, Y) --> min(X, ~Y) - auto *II = dyn_cast<IntrinsicInst>(Op0); - if (II && match(Op1, m_AllOnes())) { - if (match(Op0, m_MaxOrMin(m_Not(m_Value(X)), m_Not(m_Value(Y))))) { - Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); - Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y); - return replaceInstUsesWith(I, InvMaxMin); - } - if (match(Op0, m_OneUse(m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y))))) { - Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); - Value *NotY = Builder.CreateNot(Y); - Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY); - return replaceInstUsesWith(I, InvMaxMin); - } - } - - // TODO: Remove folds if we canonicalize to intrinsics (see 
above). - // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: - // - // %notx = xor i32 %x, -1 - // %cmp1 = icmp sgt i32 %notx, %y - // %smax = select i1 %cmp1, i32 %notx, i32 %y - // %res = xor i32 %smax, -1 - // => - // %noty = xor i32 %y, -1 - // %cmp2 = icmp slt %x, %noty - // %res = select i1 %cmp2, i32 %x, i32 %noty - // - // Same is applicable for smin/umax/umin. - if (match(Op1, m_AllOnes()) && Op0->hasOneUse()) { - Value *LHS, *RHS; - SelectPatternFlavor SPF = matchSelectPattern(Op0, LHS, RHS).Flavor; - if (SelectPatternResult::isMinOrMax(SPF)) { - // It's possible we get here before the not has been simplified, so make - // sure the input to the not isn't freely invertible. - if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) { - Value *NotY = Builder.CreateNot(RHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); - } - - // It's possible we get here before the not has been simplified, so make - // sure the input to the not isn't freely invertible. - if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) { - Value *NotX = Builder.CreateNot(LHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y); - } - - // If both sides are freely invertible, then we can get rid of the xor - // completely. - if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && - isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) { - Value *NotLHS = Builder.CreateNot(LHS); - Value *NotRHS = Builder.CreateNot(RHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), NotLHS, NotRHS), - NotLHS, NotRHS); - } - } - - // Pull 'not' into operands of select if both operands are one-use compares - // or one is one-use compare and the other one is a constant. - // Inverting the predicates eliminates the 'not' operation. - // Example: - // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> - // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?) 
- // not (select ?, (cmp TPred, ?, ?), true --> - // select ?, (cmp InvTPred, ?, ?), false - if (auto *Sel = dyn_cast<SelectInst>(Op0)) { - Value *TV = Sel->getTrueValue(); - Value *FV = Sel->getFalseValue(); - auto *CmpT = dyn_cast<CmpInst>(TV); - auto *CmpF = dyn_cast<CmpInst>(FV); - bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa<Constant>(TV); - bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa<Constant>(FV); - if (InvertibleT && InvertibleF) { - if (CmpT) - CmpT->setPredicate(CmpT->getInversePredicate()); - else - Sel->setTrueValue(ConstantExpr::getNot(cast<Constant>(TV))); - if (CmpF) - CmpF->setPredicate(CmpF->getInversePredicate()); - else - Sel->setFalseValue(ConstantExpr::getNot(cast<Constant>(FV))); - return replaceInstUsesWith(I, Sel); - } - } - } - - if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) - return NewXor; - if (Instruction *Abs = canonicalizeAbs(I, Builder)) return Abs; diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 726bb545be12..bfa7bfa2290a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -67,7 +67,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/Local.h" @@ -79,11 +78,12 @@ #include <utility> #include <vector> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" - STATISTIC(NumSimplified, "Number of library calls simplified"); static cl::opt<unsigned> GuardWideningWindow( @@ -513,7 +513,7 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { // If the input to cttz/ctlz is known to be non-zero, // then change the 'ZeroIsUndef' parameter to 'true' // because we know the zero behavior can't affect the result. - if (!Known.One.isNullValue() || + if (!Known.One.isZero() || isKnownNonZero(Op0, IC.getDataLayout(), 0, &IC.getAssumptionCache(), &II, &IC.getDominatorTree())) { if (!match(II.getArgOperand(1), m_One())) @@ -656,8 +656,8 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, // comparison to the first NumOperands. 
static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, unsigned NumOperands) { - assert(I.getNumArgOperands() >= NumOperands && "Not enough operands"); - assert(E.getNumArgOperands() >= NumOperands && "Not enough operands"); + assert(I.arg_size() >= NumOperands && "Not enough operands"); + assert(E.arg_size() >= NumOperands && "Not enough operands"); for (unsigned i = 0; i < NumOperands; i++) if (I.getArgOperand(i) != E.getArgOperand(i)) return false; @@ -682,11 +682,11 @@ removeTriviallyEmptyRange(IntrinsicInst &EndI, InstCombinerImpl &IC, BasicBlock::reverse_iterator BI(EndI), BE(EndI.getParent()->rend()); for (; BI != BE; ++BI) { if (auto *I = dyn_cast<IntrinsicInst>(&*BI)) { - if (isa<DbgInfoIntrinsic>(I) || + if (I->isDebugOrPseudoInst() || I->getIntrinsicID() == EndI.getIntrinsicID()) continue; if (IsStart(*I)) { - if (haveSameOperands(EndI, *I, EndI.getNumArgOperands())) { + if (haveSameOperands(EndI, *I, EndI.arg_size())) { IC.eraseInstFromFunction(*I); IC.eraseInstFromFunction(EndI); return true; @@ -710,7 +710,7 @@ Instruction *InstCombinerImpl::visitVAEndInst(VAEndInst &I) { } static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) { - assert(Call.getNumArgOperands() > 1 && "Need at least 2 args to swap"); + assert(Call.arg_size() > 1 && "Need at least 2 args to swap"); Value *Arg0 = Call.getArgOperand(0), *Arg1 = Call.getArgOperand(1); if (isa<Constant>(Arg0) && !isa<Constant>(Arg1)) { Call.setArgOperand(0, Arg1); @@ -754,6 +754,45 @@ static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI, ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } +/// Try to canonicalize min/max(X + C0, C1) as min/max(X, C1 - C0) + C0. This +/// can trigger other combines. +static Instruction *moveAddAfterMinMax(IntrinsicInst *II, + InstCombiner::BuilderTy &Builder) { + Intrinsic::ID MinMaxID = II->getIntrinsicID(); + assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin || + MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) && + "Expected a min or max intrinsic"); + + // TODO: Match vectors with undef elements, but undef may not propagate. + Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); + Value *X; + const APInt *C0, *C1; + if (!match(Op0, m_OneUse(m_Add(m_Value(X), m_APInt(C0)))) || + !match(Op1, m_APInt(C1))) + return nullptr; + + // Check for necessary no-wrap and overflow constraints. + bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin; + auto *Add = cast<BinaryOperator>(Op0); + if ((IsSigned && !Add->hasNoSignedWrap()) || + (!IsSigned && !Add->hasNoUnsignedWrap())) + return nullptr; + + // If the constant difference overflows, then instsimplify should reduce the + // min/max to the add or C1. + bool Overflow; + APInt CDiff = + IsSigned ? C1->ssub_ov(*C0, Overflow) : C1->usub_ov(*C0, Overflow); + assert(!Overflow && "Expected simplify of min/max"); + + // min/max (add X, C0), C1 --> add (min/max X, C1 - C0), C0 + // Note: the "mismatched" no-overflow setting does not propagate. + Constant *NewMinMaxC = ConstantInt::get(II->getType(), CDiff); + Value *NewMinMax = Builder.CreateBinaryIntrinsic(MinMaxID, X, NewMinMaxC); + return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1)) + : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1)); +} + /// If we have a clamp pattern like max (min X, 42), 41 -- where the output /// can only be one of two possible constant values -- turn that into a select /// of constants. 
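// [Editorial illustration -- not part of this commit.] A minimal standalone
// C++ sketch of the scalar identity behind the clamp-of-two fold documented
// just above, using the same constants as that comment (42 and 41): because
// the clamp bounds differ by exactly one, max(min(X, 42), 41) can only yield
// 41 or 42, so it collapses to a select-of-constants on X < 42. The loop
// range below is arbitrary; this is a plain scalar check, not LLVM code.
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const int32_t UpperC = 42; // bound used by the min, as in the comment above
  const int32_t LowerC = 41; // bound used by the max, exactly one below UpperC
  for (int32_t X = -1000; X <= 1000; ++X) {
    int32_t Clamped = std::max(std::min(X, UpperC), LowerC);
    int32_t Selected = (X < UpperC) ? LowerC : UpperC; // select-of-constants form
    assert(Clamped == Selected);
  }
  return 0;
}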
@@ -795,6 +834,63 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II, return SelectInst::Create(Cmp, ConstantInt::get(II->getType(), *C0), I1); } +/// Reduce a sequence of min/max intrinsics with a common operand. +static Instruction *factorizeMinMaxTree(IntrinsicInst *II) { + // Match 3 of the same min/max ops. Example: umin(umin(), umin()). + auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0)); + auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1)); + Intrinsic::ID MinMaxID = II->getIntrinsicID(); + if (!LHS || !RHS || LHS->getIntrinsicID() != MinMaxID || + RHS->getIntrinsicID() != MinMaxID || + (!LHS->hasOneUse() && !RHS->hasOneUse())) + return nullptr; + + Value *A = LHS->getArgOperand(0); + Value *B = LHS->getArgOperand(1); + Value *C = RHS->getArgOperand(0); + Value *D = RHS->getArgOperand(1); + + // Look for a common operand. + Value *MinMaxOp = nullptr; + Value *ThirdOp = nullptr; + if (LHS->hasOneUse()) { + // If the LHS is only used in this chain and the RHS is used outside of it, + // reuse the RHS min/max because that will eliminate the LHS. + if (D == A || C == A) { + // min(min(a, b), min(c, a)) --> min(min(c, a), b) + // min(min(a, b), min(a, d)) --> min(min(a, d), b) + MinMaxOp = RHS; + ThirdOp = B; + } else if (D == B || C == B) { + // min(min(a, b), min(c, b)) --> min(min(c, b), a) + // min(min(a, b), min(b, d)) --> min(min(b, d), a) + MinMaxOp = RHS; + ThirdOp = A; + } + } else { + assert(RHS->hasOneUse() && "Expected one-use operand"); + // Reuse the LHS. This will eliminate the RHS. + if (D == A || D == B) { + // min(min(a, b), min(c, a)) --> min(min(a, b), c) + // min(min(a, b), min(c, b)) --> min(min(a, b), c) + MinMaxOp = LHS; + ThirdOp = C; + } else if (C == A || C == B) { + // min(min(a, b), min(b, d)) --> min(min(a, b), d) + // min(min(a, b), min(c, b)) --> min(min(a, b), d) + MinMaxOp = LHS; + ThirdOp = D; + } + } + + if (!MinMaxOp || !ThirdOp) + return nullptr; + + Module *Mod = II->getModule(); + Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType()); + return CallInst::Create(MinMax, { MinMaxOp, ThirdOp }); +} + /// CallInst simplification. This mostly only handles folding of intrinsic /// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. @@ -896,7 +992,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (auto *IIFVTy = dyn_cast<FixedVectorType>(II->getType())) { auto VWidth = IIFVTy->getNumElements(); APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { if (V != II) return replaceInstUsesWith(*II, V); @@ -1007,21 +1103,45 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } - if (match(I0, m_Not(m_Value(X)))) { - // max (not X), (not Y) --> not (min X, Y) - Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); - if (match(I1, m_Not(m_Value(Y))) && + if (IID == Intrinsic::smax || IID == Intrinsic::smin) { + // smax (neg nsw X), (neg nsw Y) --> neg nsw (smin X, Y) + // smin (neg nsw X), (neg nsw Y) --> neg nsw (smax X, Y) + // TODO: Canonicalize neg after min/max if I1 is constant. 
+ if (match(I0, m_NSWNeg(m_Value(X))) && match(I1, m_NSWNeg(m_Value(Y))) && (I0->hasOneUse() || I1->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y); - return BinaryOperator::CreateNot(InvMaxMin); + return BinaryOperator::CreateNSWNeg(InvMaxMin); } - // max (not X), C --> not(min X, ~C) - if (match(I1, m_Constant(C)) && I0->hasOneUse()) { - Constant *NotC = ConstantExpr::getNot(C); - Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotC); + } + + // If we can eliminate ~A and Y is free to invert: + // max ~A, Y --> ~(min A, ~Y) + // + // Examples: + // max ~A, ~Y --> ~(min A, Y) + // max ~A, C --> ~(min A, ~C) + // max ~A, (max ~Y, ~Z) --> ~min( A, (min Y, Z)) + auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { + Value *A; + if (match(X, m_OneUse(m_Not(m_Value(A)))) && + !isFreeToInvert(A, A->hasOneUse()) && + isFreeToInvert(Y, Y->hasOneUse())) { + Value *NotY = Builder.CreateNot(Y); + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY); return BinaryOperator::CreateNot(InvMaxMin); } - } + return nullptr; + }; + + if (Instruction *I = moveNotAfterMinMax(I0, I1)) + return I; + if (Instruction *I = moveNotAfterMinMax(I1, I0)) + return I; + + if (Instruction *I = moveAddAfterMinMax(II, Builder)) + return I; // smax(X, -X) --> abs(X) // smin(X, -X) --> -abs(X) @@ -1051,11 +1171,17 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *Sel = foldClampRangeOfTwo(II, Builder)) return Sel; + if (Instruction *SAdd = matchSAddSubSat(*II)) + return SAdd; + if (match(I1, m_ImmConstant())) if (auto *Sel = dyn_cast<SelectInst>(I0)) if (Instruction *R = FoldOpIntoSelect(*II, Sel)) return R; + if (Instruction *NewMinMax = factorizeMinMaxTree(II)) + return NewMinMax; + break; } case Intrinsic::bswap: { @@ -1098,6 +1224,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Power->equalsInt(2)) return BinaryOperator::CreateFMulFMF(II->getArgOperand(0), II->getArgOperand(0), II); + + if (!Power->getValue()[0]) { + Value *X; + // If power is even: + // powi(-x, p) -> powi(x, p) + // powi(fabs(x), p) -> powi(x, p) + // powi(copysign(x, y), p) -> powi(x, p) + if (match(II->getArgOperand(0), m_FNeg(m_Value(X))) || + match(II->getArgOperand(0), m_FAbs(m_Value(X))) || + match(II->getArgOperand(0), + m_Intrinsic<Intrinsic::copysign>(m_Value(X), m_Value()))) + return replaceOperand(*II, 0, X); + } } break; @@ -1637,14 +1776,66 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } case Intrinsic::stackrestore: { - // If the save is right next to the restore, remove the restore. This can - // happen when variable allocas are DCE'd. 
+ enum class ClassifyResult { + None, + Alloca, + StackRestore, + CallWithSideEffects, + }; + auto Classify = [](const Instruction *I) { + if (isa<AllocaInst>(I)) + return ClassifyResult::Alloca; + + if (auto *CI = dyn_cast<CallInst>(I)) { + if (auto *II = dyn_cast<IntrinsicInst>(CI)) { + if (II->getIntrinsicID() == Intrinsic::stackrestore) + return ClassifyResult::StackRestore; + + if (II->mayHaveSideEffects()) + return ClassifyResult::CallWithSideEffects; + } else { + // Consider all non-intrinsic calls to be side effects + return ClassifyResult::CallWithSideEffects; + } + } + + return ClassifyResult::None; + }; + + // If the stacksave and the stackrestore are in the same BB, and there is + // no intervening call, alloca, or stackrestore of a different stacksave, + // remove the restore. This can happen when variable allocas are DCE'd. if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { - if (SS->getIntrinsicID() == Intrinsic::stacksave) { - // Skip over debug info. - if (SS->getNextNonDebugInstruction() == II) { - return eraseInstFromFunction(CI); + if (SS->getIntrinsicID() == Intrinsic::stacksave && + SS->getParent() == II->getParent()) { + BasicBlock::iterator BI(SS); + bool CannotRemove = false; + for (++BI; &*BI != II; ++BI) { + switch (Classify(&*BI)) { + case ClassifyResult::None: + // So far so good, look at next instructions. + break; + + case ClassifyResult::StackRestore: + // If we found an intervening stackrestore for a different + // stacksave, we can't remove the stackrestore. Otherwise, continue. + if (cast<IntrinsicInst>(*BI).getArgOperand(0) != SS) + CannotRemove = true; + break; + + case ClassifyResult::Alloca: + case ClassifyResult::CallWithSideEffects: + // If we found an alloca, a non-intrinsic call, or an intrinsic + // call with side effects, we can't remove the stackrestore. + CannotRemove = true; + break; + } + if (CannotRemove) + break; } + + if (!CannotRemove) + return eraseInstFromFunction(CI); } } @@ -1654,29 +1845,25 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Instruction *TI = II->getParent()->getTerminator(); bool CannotRemove = false; for (++BI; &*BI != TI; ++BI) { - if (isa<AllocaInst>(BI)) { + switch (Classify(&*BI)) { + case ClassifyResult::None: + // So far so good, look at next instructions. + break; + + case ClassifyResult::StackRestore: + // If there is a stackrestore below this one, remove this one. + return eraseInstFromFunction(CI); + + case ClassifyResult::Alloca: + case ClassifyResult::CallWithSideEffects: + // If we found an alloca, a non-intrinsic call, or an intrinsic call + // with side effects (such as llvm.stacksave and llvm.read_register), + // we can't remove the stack restore. CannotRemove = true; break; } - if (CallInst *BCI = dyn_cast<CallInst>(BI)) { - if (auto *II2 = dyn_cast<IntrinsicInst>(BCI)) { - // If there is a stackrestore below this one, remove this one. - if (II2->getIntrinsicID() == Intrinsic::stackrestore) - return eraseInstFromFunction(CI); - - // Bail if we cross over an intrinsic with side effects, such as - // llvm.stacksave, or llvm.read_register. - if (II2->mayHaveSideEffects()) { - CannotRemove = true; - break; - } - } else { - // If we found a non-intrinsic call, we can't remove the stack - // restore. 
- CannotRemove = true; - break; - } - } + if (CannotRemove) + break; } // If the stack restore is in a return, resume, or unwind block and if there @@ -1963,6 +2150,46 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::experimental_vector_reverse: { + Value *BO0, *BO1, *X, *Y; + Value *Vec = II->getArgOperand(0); + if (match(Vec, m_OneUse(m_BinOp(m_Value(BO0), m_Value(BO1))))) { + auto *OldBinOp = cast<BinaryOperator>(Vec); + if (match(BO0, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(X)))) { + // rev(binop rev(X), rev(Y)) --> binop X, Y + if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(Y)))) + return replaceInstUsesWith(CI, + BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, Y, OldBinOp, + OldBinOp->getName(), II)); + // rev(binop rev(X), BO1Splat) --> binop X, BO1Splat + if (isSplatValue(BO1)) + return replaceInstUsesWith(CI, + BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, BO1, + OldBinOp, OldBinOp->getName(), II)); + } + // rev(binop BO0Splat, rev(Y)) --> binop BO0Splat, Y + if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(Y))) && + isSplatValue(BO0)) + return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), BO0, Y, + OldBinOp, OldBinOp->getName(), II)); + } + // rev(unop rev(X)) --> unop X + if (match(Vec, m_OneUse(m_UnOp( + m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(X)))))) { + auto *OldUnOp = cast<UnaryOperator>(Vec); + auto *NewUnOp = UnaryOperator::CreateWithCopiedFlags( + OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), II); + return replaceInstUsesWith(CI, NewUnOp); + } + break; + } case Intrinsic::vector_reduce_or: case Intrinsic::vector_reduce_and: { // Canonicalize logical or/and reductions: @@ -1973,21 +2200,26 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // %val = bitcast <ReduxWidth x i1> to iReduxWidth // %res = cmp eq iReduxWidth %val, 11111 Value *Arg = II->getArgOperand(0); - Type *RetTy = II->getType(); - if (RetTy == Builder.getInt1Ty()) - if (auto *FVTy = dyn_cast<FixedVectorType>(Arg->getType())) { - Value *Res = Builder.CreateBitCast( - Arg, Builder.getIntNTy(FVTy->getNumElements())); - if (IID == Intrinsic::vector_reduce_and) { - Res = Builder.CreateICmpEQ( - Res, ConstantInt::getAllOnesValue(Res->getType())); - } else { - assert(IID == Intrinsic::vector_reduce_or && - "Expected or reduction."); - Res = Builder.CreateIsNotNull(Res); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateBitCast( + Vect, Builder.getIntNTy(FTy->getNumElements())); + if (IID == Intrinsic::vector_reduce_and) { + Res = Builder.CreateICmpEQ( + Res, ConstantInt::getAllOnesValue(Res->getType())); + } else { + assert(IID == Intrinsic::vector_reduce_or && + "Expected or reduction."); + Res = Builder.CreateIsNotNull(Res); + } + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); } - return replaceInstUsesWith(CI, Res); - } + } LLVM_FALLTHROUGH; } case Intrinsic::vector_reduce_add: { @@ -2017,12 +2249,117 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } LLVM_FALLTHROUGH; } - case Intrinsic::vector_reduce_mul: - case Intrinsic::vector_reduce_xor: - case Intrinsic::vector_reduce_umax: + case 
Intrinsic::vector_reduce_xor: { + if (IID == Intrinsic::vector_reduce_xor) { + // Exclusive disjunction reduction over the vector with + // (potentially-extended) i1 element type is actually a + // (potentially-extended) arithmetic `add` reduction over the original + // non-extended value: + // vector_reduce_xor(?ext(<n x i1>)) + // --> + // ?ext(vector_reduce_add(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateAddReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_mul: { + if (IID == Intrinsic::vector_reduce_mul) { + // Multiplicative reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially zero-extended) + // logical `and` reduction over the original non-extended value: + // vector_reduce_mul(?ext(<n x i1>)) + // --> + // zext(vector_reduce_and(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateAndReduce(Vect); + if (Res->getType() != II->getType()) + Res = Builder.CreateZExt(Res, II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } case Intrinsic::vector_reduce_umin: - case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_umax: { + if (IID == Intrinsic::vector_reduce_umin || + IID == Intrinsic::vector_reduce_umax) { + // UMin/UMax reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially-extended) + // logical `and`/`or` reduction over the original non-extended value: + // vector_reduce_u{min,max}(?ext(<n x i1>)) + // --> + // ?ext(vector_reduce_{and,or}(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = IID == Intrinsic::vector_reduce_umin + ? 
Builder.CreateAndReduce(Vect) + : Builder.CreateOrReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: { + if (IID == Intrinsic::vector_reduce_smin || + IID == Intrinsic::vector_reduce_smax) { + // SMin/SMax reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially-extended) + // logical `and`/`or` reduction over the original non-extended value: + // vector_reduce_s{min,max}(<n x i1>) + // --> + // vector_reduce_{or,and}(<n x i1>) + // and + // vector_reduce_s{min,max}(sext(<n x i1>)) + // --> + // sext(vector_reduce_{or,and}(<n x i1>)) + // and + // vector_reduce_s{min,max}(zext(<n x i1>)) + // --> + // zext(vector_reduce_{and,or}(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Instruction::CastOps ExtOpc = Instruction::CastOps::CastOpsEnd; + if (Arg != Vect) + ExtOpc = cast<CastInst>(Arg)->getOpcode(); + Value *Res = ((IID == Intrinsic::vector_reduce_smin) == + (ExtOpc == Instruction::CastOps::ZExt)) + ? Builder.CreateAndReduce(Vect) + : Builder.CreateOrReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(ExtOpc, Res, II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } case Intrinsic::vector_reduce_fmax: case Intrinsic::vector_reduce_fmin: case Intrinsic::vector_reduce_fadd: @@ -2228,7 +2565,7 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { } void InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { - unsigned NumArgs = Call.getNumArgOperands(); + unsigned NumArgs = Call.arg_size(); ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0)); ConstantInt *Op1C = (NumArgs == 1) ? nullptr : dyn_cast<ConstantInt>(Call.getOperand(1)); @@ -2239,55 +2576,46 @@ void InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, const TargetLibraryI if (isMallocLikeFn(&Call, TLI) && Op0C) { if (isOpNewLikeFn(&Call, TLI)) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableBytes( - Call.getContext(), Op0C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableBytes( + Call.getContext(), Op0C->getZExtValue())); else - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Op0C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op0C->getZExtValue())); } else if (isAlignedAllocLikeFn(&Call, TLI)) { if (Op1C) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Op1C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op1C->getZExtValue())); // Add alignment attribute if alignment is a power of two constant. 
if (Op0C && Op0C->getValue().ult(llvm::Value::MaximumAlignment) && isKnownNonZero(Call.getOperand(1), DL, 0, &AC, &Call, &DT)) { uint64_t AlignmentVal = Op0C->getZExtValue(); if (llvm::isPowerOf2_64(AlignmentVal)) { - Call.removeAttribute(AttributeList::ReturnIndex, Attribute::Alignment); - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithAlignment(Call.getContext(), - Align(AlignmentVal))); + Call.removeRetAttr(Attribute::Alignment); + Call.addRetAttr(Attribute::getWithAlignment(Call.getContext(), + Align(AlignmentVal))); } } } else if (isReallocLikeFn(&Call, TLI) && Op1C) { - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Op1C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op1C->getZExtValue())); } else if (isCallocLikeFn(&Call, TLI) && Op0C && Op1C) { bool Overflow; const APInt &N = Op0C->getValue(); APInt Size = N.umul_ov(Op1C->getValue(), Overflow); if (!Overflow) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Size.getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Size.getZExtValue())); } else if (isStrdupLikeFn(&Call, TLI)) { uint64_t Len = GetStringLength(Call.getOperand(0)); if (Len) { // strdup if (NumArgs == 1) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Len)); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Len)); // strndup else if (NumArgs == 2 && Op1C) - Call.addAttribute( - AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1))); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1))); } } } @@ -2489,7 +2817,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { // isKnownNonNull -> nonnull attribute if (!GCR.hasRetAttr(Attribute::NonNull) && isKnownNonZero(DerivedPtr, DL, 0, &AC, &Call, &DT)) { - GCR.addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + GCR.addRetAttr(Attribute::NonNull); // We discovered new fact, re-check users. Worklist.pushUsersToWorkList(GCR); } @@ -2646,19 +2974,19 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL)) return false; // Cannot transform this parameter value. - if (AttrBuilder(CallerPAL.getParamAttributes(i)) + if (AttrBuilder(CallerPAL.getParamAttrs(i)) .overlaps(AttributeFuncs::typeIncompatible(ParamTy))) return false; // Attribute not compatible with transformed value. if (Call.isInAllocaArgument(i)) return false; // Cannot transform to and from inalloca. - if (CallerPAL.hasParamAttribute(i, Attribute::SwiftError)) + if (CallerPAL.hasParamAttr(i, Attribute::SwiftError)) return false; // If the parameter is passed as a byval argument, then we have to have a // sized type and the sized type has to have the same size as the old type. 
- if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { + if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) { PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy); if (!ParamPTy || !ParamPTy->getElementType()->isSized()) return false; @@ -2699,7 +3027,7 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { // that are compatible with being a vararg call argument. unsigned SRetIdx; if (CallerPAL.hasAttrSomewhere(Attribute::StructRet, &SRetIdx) && - SRetIdx > FT->getNumParams()) + SRetIdx - AttributeList::FirstArgIndex >= FT->getNumParams()) return false; } @@ -2728,12 +3056,12 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { Args.push_back(NewArg); // Add any parameter attributes. - if (CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { - AttrBuilder AB(CallerPAL.getParamAttributes(i)); + if (CallerPAL.hasParamAttr(i, Attribute::ByVal)) { + AttrBuilder AB(CallerPAL.getParamAttrs(i)); AB.addByValAttr(NewArg->getType()->getPointerElementType()); ArgAttrs.push_back(AttributeSet::get(Ctx, AB)); } else - ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); + ArgAttrs.push_back(CallerPAL.getParamAttrs(i)); } // If the function takes more arguments than the call was taking, add them @@ -2760,12 +3088,12 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { Args.push_back(NewArg); // Add any parameter attributes. - ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); + ArgAttrs.push_back(CallerPAL.getParamAttrs(i)); } } } - AttributeSet FnAttrs = CallerPAL.getFnAttributes(); + AttributeSet FnAttrs = CallerPAL.getFnAttrs(); if (NewRetTy->isVoidTy()) Caller->setName(""); // Void type should not have a name. @@ -2866,7 +3194,7 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, for (FunctionType::param_iterator I = NestFTy->param_begin(), E = NestFTy->param_end(); I != E; ++NestArgNo, ++I) { - AttributeSet AS = NestAttrs.getParamAttributes(NestArgNo); + AttributeSet AS = NestAttrs.getParamAttrs(NestArgNo); if (AS.hasAttribute(Attribute::Nest)) { // Record the parameter type and any other attributes. NestTy = *I; @@ -2902,7 +3230,7 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, // Add the original argument and attributes. 
NewArgs.push_back(*I); - NewArgAttrs.push_back(Attrs.getParamAttributes(ArgNo)); + NewArgAttrs.push_back(Attrs.getParamAttrs(ArgNo)); ++ArgNo; ++I; @@ -2948,8 +3276,8 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, NestF : ConstantExpr::getBitCast(NestF, PointerType::getUnqual(NewFTy)); AttributeList NewPAL = - AttributeList::get(FTy->getContext(), Attrs.getFnAttributes(), - Attrs.getRetAttributes(), NewArgAttrs); + AttributeList::get(FTy->getContext(), Attrs.getFnAttrs(), + Attrs.getRetAttrs(), NewArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; Call.getOperandBundlesAsDefs(OpBundles); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 04877bec94ec..ca87477c5d81 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -333,7 +333,7 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { SrcTy->getNumElements() == DestTy->getNumElements() && SrcTy->getPrimitiveSizeInBits() == DestTy->getPrimitiveSizeInBits()) { Value *CastX = Builder.CreateCast(CI.getOpcode(), X, DestTy); - return new ShuffleVectorInst(CastX, UndefValue::get(DestTy), Mask); + return new ShuffleVectorInst(CastX, Mask); } } @@ -701,10 +701,10 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc, if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) && is_splat(Shuf->getShuffleMask()) && Shuf->getType() == Shuf->getOperand(0)->getType()) { - // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask - Constant *NarrowUndef = UndefValue::get(Trunc.getType()); + // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask + // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); - return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getShuffleMask()); + return new ShuffleVectorInst(NarrowOp, Shuf->getShuffleMask()); } return nullptr; @@ -961,14 +961,25 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { return BinaryOperator::CreateAdd(NarrowCtlz, WidthDiff); } } + + if (match(Src, m_VScale(DL))) { + if (Trunc.getFunction() && + Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + unsigned MaxVScale = Trunc.getFunction() + ->getFnAttribute(Attribute::VScaleRange) + .getVScaleRangeArgs() + .second; + if (MaxVScale > 0 && Log2_32(MaxVScale) < DestWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(Trunc, VScale); + } + } + } + return nullptr; } -/// Transform (zext icmp) to bitwise / integer operations in order to -/// eliminate it. If DoTransform is false, just test whether the given -/// (zext icmp) can be transformed. -Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, - bool DoTransform) { +Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. @@ -977,10 +988,8 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, // zext (x <s 0) to i32 --> x>>u31 true if signbit set. // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. 
- if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) || - (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) { - if (!DoTransform) return Cmp; - + if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) || + (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnes())) { Value *In = Cmp->getOperand(0); Value *Sh = ConstantInt::get(In->getType(), In->getType()->getScalarSizeInBits() - 1); @@ -1004,7 +1013,7 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) && + if ((Op1CV->isZero() || Op1CV->isPowerOf2()) && // This only works for EQ and NE Cmp->isEquality()) { // If Op1C some other power of two, convert: @@ -1012,10 +1021,8 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, APInt KnownZeroMask(~Known.Zero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? - if (!DoTransform) return Cmp; - bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE; - if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) { + if (!Op1CV->isZero() && (*Op1CV != KnownZeroMask)) { // (X&4) == 2 --> false // (X&4) != 2 --> true Constant *Res = ConstantInt::get(Zext.getType(), isNE); @@ -1031,7 +1038,7 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, In->getName() + ".lobit"); } - if (!Op1CV->isNullValue() == isNE) { // Toggle the low bit. + if (!Op1CV->isZero() == isNE) { // Toggle the low bit. Constant *One = ConstantInt::get(In->getType(), 1); In = Builder.CreateXor(In, One); } @@ -1053,9 +1060,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) && match(Cmp->getOperand(0), m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) { - if (!DoTransform) - return Cmp; - if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) X = Builder.CreateNot(X); Value *Lshr = Builder.CreateLShr(X, ShAmt); @@ -1077,8 +1081,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, APInt KnownBits = KnownLHS.Zero | KnownLHS.One; APInt UnknownBit = ~KnownBits; if (UnknownBit.countPopulation() == 1) { - if (!DoTransform) return Cmp; - Value *Result = Builder.CreateXor(LHS, RHS); // Mask off any bits that are set and won't be shifted away. @@ -1316,51 +1318,37 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src)) return transformZExtICmp(Cmp, CI); - BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src); - if (SrcI && SrcI->getOpcode() == Instruction::Or) { - // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) if at least one - // of the (zext icmp) can be eliminated. If so, immediately perform the - // according elimination. 
- ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0)); - ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1)); - if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() && - LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType() && - (transformZExtICmp(LHS, CI, false) || - transformZExtICmp(RHS, CI, false))) { - // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) - Value *LCast = Builder.CreateZExt(LHS, CI.getType(), LHS->getName()); - Value *RCast = Builder.CreateZExt(RHS, CI.getType(), RHS->getName()); - Value *Or = Builder.CreateOr(LCast, RCast, CI.getName()); - if (auto *OrInst = dyn_cast<Instruction>(Or)) - Builder.SetInsertPoint(OrInst); - - // Perform the elimination. - if (auto *LZExt = dyn_cast<ZExtInst>(LCast)) - transformZExtICmp(LHS, *LZExt); - if (auto *RZExt = dyn_cast<ZExtInst>(RCast)) - transformZExtICmp(RHS, *RZExt); - - return replaceInstUsesWith(CI, Or); - } - } - // zext(trunc(X) & C) -> (X & zext(C)). Constant *C; Value *X; - if (SrcI && - match(SrcI, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && + if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && X->getType() == CI.getType()) return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, CI.getType())); // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). Value *And; - if (SrcI && match(SrcI, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && + if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && X->getType() == CI.getType()) { Constant *ZC = ConstantExpr::getZExt(C, CI.getType()); return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); } + if (match(Src, m_VScale(DL))) { + if (CI.getFunction() && + CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + unsigned MaxVScale = CI.getFunction() + ->getFnAttribute(Attribute::VScaleRange) + .getVScaleRangeArgs() + .second; + unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); + if (MaxVScale > 0 && Log2_32(MaxVScale) < TypeWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } + } + } + return nullptr; } @@ -1605,6 +1593,32 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { return BinaryOperator::CreateAShr(A, NewShAmt); } + // Splatting a bit of constant-index across a value: + // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1 + // TODO: If the dest type is different, use a cast (adjust use check). 
+ if (match(Src, m_OneUse(m_AShr(m_Trunc(m_Value(X)), + m_SpecificInt(SrcBitSize - 1)))) && + X->getType() == DestTy) { + Constant *ShlAmtC = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); + Constant *AshrAmtC = ConstantInt::get(DestTy, DestBitSize - 1); + Value *Shl = Builder.CreateShl(X, ShlAmtC); + return BinaryOperator::CreateAShr(Shl, AshrAmtC); + } + + if (match(Src, m_VScale(DL))) { + if (CI.getFunction() && + CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + unsigned MaxVScale = CI.getFunction() + ->getFnAttribute(Attribute::VScaleRange) + .getVScaleRangeArgs() + .second; + if (MaxVScale > 0 && Log2_32(MaxVScale) < (SrcBitSize - 1)) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } + } + } + return nullptr; } @@ -2060,6 +2074,19 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); } + if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) { + // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use. + // While this can increase the number of instructions it doesn't actually + // increase the overall complexity since the arithmetic is just part of + // the GEP otherwise. + if (GEP->hasOneUse() && + isa<ConstantPointerNull>(GEP->getPointerOperand())) { + return replaceInstUsesWith(CI, + Builder.CreateIntCast(EmitGEPOffset(GEP), Ty, + /*isSigned=*/false)); + } + } + Value *Vec, *Scalar, *Index; if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)), m_Value(Scalar), m_Value(Index)))) && @@ -2133,9 +2160,9 @@ optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy, if (SrcElts > DestElts) { // If we're shrinking the number of elements (rewriting an integer // truncate), just shuffle in the elements corresponding to the least - // significant bits from the input and use undef as the second shuffle + // significant bits from the input and use poison as the second shuffle // input. - V2 = UndefValue::get(SrcTy); + V2 = PoisonValue::get(SrcTy); // Make sure the shuffle mask selects the "least significant bits" by // keeping elements from back of the src vector for big endian, and from the // front for little endian. @@ -2528,7 +2555,7 @@ Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI, // As long as the user is another old PHI node, then even if we don't // rewrite it, the PHI web we're considering won't have any users // outside itself, so it'll be dead. - if (OldPhiNodes.count(PHI) == 0) + if (!OldPhiNodes.contains(PHI)) return nullptr; } else { return nullptr; @@ -2736,6 +2763,30 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { if (auto *InsElt = dyn_cast<InsertElementInst>(Src)) return new BitCastInst(InsElt->getOperand(1), DestTy); } + + // Convert an artificial vector insert into more analyzable bitwise logic. + unsigned BitWidth = DestTy->getScalarSizeInBits(); + Value *X, *Y; + uint64_t IndexC; + if (match(Src, m_OneUse(m_InsertElt(m_OneUse(m_BitCast(m_Value(X))), + m_Value(Y), m_ConstantInt(IndexC)))) && + DestTy->isIntegerTy() && X->getType() == DestTy && + isDesirableIntType(BitWidth)) { + // Adjust for big endian - the LSBs are at the high index. + if (DL.isBigEndian()) + IndexC = SrcVTy->getNumElements() - 1 - IndexC; + + // We only handle (endian-normalized) insert to index 0. Any other insert + // would require a left-shift, so that is an extra instruction. 
+ if (IndexC == 0) { + // bitcast (inselt (bitcast X), Y, 0) --> or (and X, MaskC), (zext Y) + unsigned EltWidth = Y->getType()->getScalarSizeInBits(); + APInt MaskC = APInt::getHighBitsSet(BitWidth, BitWidth - EltWidth); + Value *AndX = Builder.CreateAnd(X, MaskC); + Value *ZextY = Builder.CreateZExt(Y, DestTy); + return BinaryOperator::CreateOr(AndX, ZextY); + } + } } if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index c5e14ebf3ae3..7a9e177f19da 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -78,15 +78,15 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { if (!ICmpInst::isSigned(Pred)) return false; - if (C.isNullValue()) + if (C.isZero()) return ICmpInst::isRelational(Pred); - if (C.isOneValue()) { + if (C.isOne()) { if (Pred == ICmpInst::ICMP_SLT) { Pred = ICmpInst::ICMP_SLE; return true; } - } else if (C.isAllOnesValue()) { + } else if (C.isAllOnes()) { if (Pred == ICmpInst::ICMP_SGT) { Pred = ICmpInst::ICMP_SGE; return true; @@ -541,7 +541,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, if (!CI->isNoopCast(DL)) return false; - if (Explored.count(CI->getOperand(0)) == 0) + if (!Explored.contains(CI->getOperand(0))) WorkList.push_back(CI->getOperand(0)); } @@ -553,7 +553,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, GEP->getType() != Start->getType()) return false; - if (Explored.count(GEP->getOperand(0)) == 0) + if (!Explored.contains(GEP->getOperand(0))) WorkList.push_back(GEP->getOperand(0)); } @@ -575,7 +575,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, // Explore the PHI nodes further. for (auto *PN : PHIs) for (Value *Op : PN->incoming_values()) - if (Explored.count(Op) == 0) + if (!Explored.contains(Op)) WorkList.push_back(Op); } @@ -589,7 +589,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, auto *Inst = dyn_cast<Instruction>(Val); if (Inst == Base || Inst == PHI || !Inst || !PHI || - Explored.count(PHI) == 0) + !Explored.contains(PHI)) continue; if (PHI->getParent() == Inst->getParent()) @@ -1147,12 +1147,12 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, }; // Don't bother doing any work for cases which InstSimplify handles. - if (AP2.isNullValue()) + if (AP2.isZero()) return nullptr; bool IsAShr = isa<AShrOperator>(I.getOperand(0)); if (IsAShr) { - if (AP2.isAllOnesValue()) + if (AP2.isAllOnes()) return nullptr; if (AP2.isNegative() != AP1.isNegative()) return nullptr; @@ -1178,7 +1178,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, if (IsAShr && AP1 == AP2.ashr(Shift)) { // There are multiple solutions if we are comparing against -1 and the LHS // of the ashr is not a power of two. - if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) + if (AP1.isAllOnes() && !AP2.isPowerOf2()) return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); } else if (AP1 == AP2.lshr(Shift)) { @@ -1206,7 +1206,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A, }; // Don't bother doing any work for cases which InstSimplify handles. 
- if (AP2.isNullValue()) + if (AP2.isZero()) return nullptr; unsigned AP2TrailingZeros = AP2.countTrailingZeros(); @@ -1270,9 +1270,8 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // This is only really a signed overflow check if the inputs have been // sign-extended; check for that condition. For example, if CI2 is 2^31 and // the operands of the add are 64 bits wide, we need at least 33 sign bits. - unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || - IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) + if (IC.ComputeMinSignedBits(A, 0, &I) > NewWidth || + IC.ComputeMinSignedBits(B, 0, &I) > NewWidth) return nullptr; // In order to replace the original add with a narrower @@ -1544,7 +1543,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); - if (C.isOneValue() && C.getBitWidth() > 1) { + if (C.isOne() && C.getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) @@ -1725,7 +1724,7 @@ Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp, // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is // preferable because it allows the C2 << Y expression to be hoisted out of a // loop if Y is invariant and X is not. - if (Shift->hasOneUse() && C1.isNullValue() && Cmp.isEquality() && + if (Shift->hasOneUse() && C1.isZero() && Cmp.isEquality() && !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { // Compute C2 << Y. Value *NewShift = @@ -1749,7 +1748,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1 // TODO: We canonicalize to the longer form for scalars because we have // better analysis/folds for icmp, and codegen may be better with icmp. - if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isNullValue() && + if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isZero() && match(And->getOperand(1), m_One())) return new TruncInst(And->getOperand(0), Cmp.getType()); @@ -1762,7 +1761,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, if (!And->hasOneUse()) return nullptr; - if (Cmp.isEquality() && C1.isNullValue()) { + if (Cmp.isEquality() && C1.isZero()) { // Restrict this fold to single-use 'and' (PR10267). // Replace (and X, (1 << size(X)-1) != 0) with X s< 0 if (C2->isSignMask()) { @@ -1812,7 +1811,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed - if (!Cmp.isSigned() && C1.isNullValue() && And->getOperand(0)->hasOneUse() && + if (!Cmp.isSigned() && C1.isZero() && And->getOperand(0)->hasOneUse() && match(And->getOperand(1), m_One())) { Constant *One = cast<Constant>(And->getOperand(1)); Value *Or = And->getOperand(0); @@ -1889,7 +1888,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, // X & -C == -C -> X > u ~C // X & -C != -C -> X <= u ~C // iff C is a power of 2 - if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) { + if (Cmp.getOperand(1) == Y && C.isNegatedPowerOf2()) { auto NewPred = Pred == CmpInst::ICMP_EQ ? 
CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); @@ -1899,7 +1898,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, // (X & C2) != 0 -> (trunc X) < 0 // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. const APInt *C2; - if (And->hasOneUse() && C.isNullValue() && match(Y, m_APInt(C2))) { + if (And->hasOneUse() && C.isZero() && match(Y, m_APInt(C2))) { int32_t ExactLogBase2 = C2->exactLogBase2(); if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); @@ -1920,7 +1919,7 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (C.isOneValue()) { + if (C.isOne()) { // icmp slt signum(V) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) @@ -1950,7 +1949,18 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, } } - if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse()) + // (X | (X-1)) s< 0 --> X s< 1 + // (X | (X-1)) s> -1 --> X s> 0 + Value *X; + bool TrueIfSigned; + if (isSignBitCheck(Pred, C, TrueIfSigned) && + match(Or, m_c_Or(m_Add(m_Value(X), m_AllOnes()), m_Deferred(X)))) { + auto NewPred = TrueIfSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT; + Constant *NewC = ConstantInt::get(X->getType(), TrueIfSigned ? 1 : 0); + return new ICmpInst(NewPred, X, NewC); + } + + if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse()) return nullptr; Value *P, *Q; @@ -2001,14 +2011,14 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, // If the multiply does not wrap, try to divide the compare constant by the // multiplication factor. 
- if (Cmp.isEquality() && !MulC->isNullValue()) { + if (Cmp.isEquality() && !MulC->isZero()) { // (mul nsw X, MulC) == C --> X == C /s MulC - if (Mul->hasNoSignedWrap() && C.srem(*MulC).isNullValue()) { + if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC)); return new ICmpInst(Pred, Mul->getOperand(0), NewC); } // (mul nuw X, MulC) == C --> X == C /u MulC - if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isNullValue()) { + if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC)); return new ICmpInst(Pred, Mul->getOperand(0), NewC); } @@ -2053,7 +2063,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C.isAllOnesValue()) { + if (C.isAllOnes()) { // (1 << Y) <= -1 -> Y == 31 if (Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); @@ -2227,8 +2237,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 Value *X = Shr->getOperand(0); CmpInst::Predicate Pred = Cmp.getPredicate(); - if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && - C.isNullValue()) + if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && C.isZero()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); const APInt *ShiftVal; @@ -2316,7 +2325,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, if (Shr->isExact()) return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); - if (C.isNullValue()) { + if (C.isZero()) { // == 0 is u< 1. if (Pred == CmpInst::ICMP_EQ) return new ICmpInst(CmpInst::ICMP_ULT, X, @@ -2355,7 +2364,7 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, return nullptr; const APInt *DivisorC; - if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC))) + if (!C.isZero() || !match(SRem->getOperand(1), m_Power2(DivisorC))) return nullptr; // Mask off the sign bit and the modulo bits (low-bits). @@ -2435,8 +2444,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // INT_MIN will also fail if the divisor is 1. Although folds of all these // division-by-constant cases should be present, we can not assert that they // have happened before we reach this icmp instruction. - if (C2->isNullValue() || C2->isOneValue() || - (DivIsSigned && C2->isAllOnesValue())) + if (C2->isZero() || C2->isOne() || (DivIsSigned && C2->isAllOnes())) return nullptr; // Compute Prod = C * C2. We are essentially solving an equation of @@ -2476,16 +2484,16 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } } else if (C2->isStrictlyPositive()) { // Divisor is > 0. - if (C.isNullValue()) { // (X / pos) op 0 + if (C.isZero()) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) LoBound = -(RangeSize - 1); HiBound = RangeSize; - } else if (C.isStrictlyPositive()) { // (X / pos) op pos + } else if (C.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); - } else { // (X / pos) op neg + } else { // (X / pos) op neg // e.g. 
X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) HiBound = Prod + 1; LoOverflow = HiOverflow = ProdOV ? -1 : 0; @@ -2497,7 +2505,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, } else if (C2->isNegative()) { // Divisor is < 0. if (Div->isExact()) RangeSize.negate(); - if (C.isNullValue()) { // (X / neg) op 0 + if (C.isZero()) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) LoBound = RangeSize + 1; HiBound = -RangeSize; @@ -2505,13 +2513,13 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, HiOverflow = 1; // [INTMIN+1, overflow) HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN } - } else if (C.isStrictlyPositive()) { // (X / neg) op pos + } else if (C.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) HiBound = Prod + 1; HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; - } else { // (X / neg) op neg + } else { // (X / neg) op neg LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) @@ -2581,42 +2589,54 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, const APInt &C) { Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); ICmpInst::Predicate Pred = Cmp.getPredicate(); - const APInt *C2; - APInt SubResult; + Type *Ty = Sub->getType(); - // icmp eq/ne (sub C, Y), C -> icmp eq/ne Y, 0 - if (match(X, m_APInt(C2)) && *C2 == C && Cmp.isEquality()) - return new ICmpInst(Cmp.getPredicate(), Y, - ConstantInt::get(Y->getType(), 0)); + // (SubC - Y) == C) --> Y == (SubC - C) + // (SubC - Y) != C) --> Y != (SubC - C) + Constant *SubC; + if (Cmp.isEquality() && match(X, m_ImmConstant(SubC))) { + return new ICmpInst(Pred, Y, + ConstantExpr::getSub(SubC, ConstantInt::get(Ty, C))); + } // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C) + const APInt *C2; + APInt SubResult; + ICmpInst::Predicate SwappedPred = Cmp.getSwappedPredicate(); + bool HasNSW = Sub->hasNoSignedWrap(); + bool HasNUW = Sub->hasNoUnsignedWrap(); if (match(X, m_APInt(C2)) && - ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) || - (Cmp.isSigned() && Sub->hasNoSignedWrap())) && + ((Cmp.isUnsigned() && HasNUW) || (Cmp.isSigned() && HasNSW)) && !subWithOverflow(SubResult, *C2, C, Cmp.isSigned())) - return new ICmpInst(Cmp.getSwappedPredicate(), Y, - ConstantInt::get(Y->getType(), SubResult)); + return new ICmpInst(SwappedPred, Y, ConstantInt::get(Ty, SubResult)); // The following transforms are only worth it if the only user of the subtract // is the icmp. + // TODO: This is an artificial restriction for all of the transforms below + // that only need a single replacement icmp. if (!Sub->hasOneUse()) return nullptr; + // X - Y == 0 --> X == Y. + // X - Y != 0 --> X != Y. 
+ if (Cmp.isEquality() && C.isZero()) + return new ICmpInst(Pred, X, Y); + if (Sub->hasNoSignedWrap()) { // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) - if (Pred == ICmpInst::ICMP_SGT && C.isAllOnesValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes()) return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) - if (Pred == ICmpInst::ICMP_SGT && C.isNullValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isZero()) return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) - if (Pred == ICmpInst::ICMP_SLT && C.isNullValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isZero()) return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) - if (Pred == ICmpInst::ICMP_SLT && C.isOneValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isOne()) return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } @@ -2634,7 +2654,12 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C) return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X); - return nullptr; + // We have handled special cases that reduce. + // Canonicalize any remaining sub to add as: + // (C2 - Y) > C --> (Y + ~C2) < ~C + Value *Add = Builder.CreateAdd(Y, ConstantInt::get(Ty, ~(*C2)), "notsub", + HasNUW, HasNSW); + return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C)); } /// Fold icmp (add X, Y), C. @@ -2723,6 +2748,14 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C), ConstantExpr::getNeg(cast<Constant>(Y))); + // The range test idiom can use either ult or ugt. Arbitrarily canonicalize + // to the ult form. + // X+C2 >u C -> X+(C2-C-1) <u ~C + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, + Builder.CreateAdd(X, ConstantInt::get(Ty, *C2 - C - 1)), + ConstantInt::get(Ty, ~C)); + return nullptr; } @@ -2830,8 +2863,7 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp, return nullptr; } -static Instruction *foldICmpBitCast(ICmpInst &Cmp, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0)); if (!Bitcast) return nullptr; @@ -2917,6 +2949,39 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, return new ICmpInst(Pred, BCSrcOp, Op1); } + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C)) || + !Bitcast->getType()->isIntegerTy() || + !Bitcast->getSrcTy()->isIntOrIntVectorTy()) + return nullptr; + + // If this is checking if all elements of a vector compare are set or not, + // invert the casted vector equality compare and test if all compare + // elements are clear or not. Compare against zero is generally easier for + // analysis and codegen. + // icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0 + // Example: are all elements equal? --> are zero elements not equal? + // TODO: Try harder to reduce compare of 2 freely invertible operands? 
+ if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() && + isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) { + Type *ScalarTy = Bitcast->getType(); + Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), ScalarTy); + return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(ScalarTy)); + } + + // If this is checking if all elements of an extended vector are clear or not, + // compare in a narrow type to eliminate the extend: + // icmp eq/ne (bitcast (ext X) to iN), 0 --> icmp eq/ne (bitcast X to iM), 0 + Value *X; + if (Cmp.isEquality() && C->isZero() && Bitcast->hasOneUse() && + match(BCSrcOp, m_ZExtOrSExt(m_Value(X)))) { + if (auto *VecTy = dyn_cast<FixedVectorType>(X->getType())) { + Type *NewType = Builder.getIntNTy(VecTy->getPrimitiveSizeInBits()); + Value *NewCast = Builder.CreateBitCast(X, NewType); + return new ICmpInst(Pred, NewCast, ConstantInt::getNullValue(NewType)); + } + } + // Folding: icmp <pred> iN X, C // where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN // and C is a splat of a K-bit pattern @@ -2924,12 +2989,6 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, // Into: // %E = extractelement <M x iK> %vec, i32 C' // icmp <pred> iK %E, trunc(C) - const APInt *C; - if (!match(Cmp.getOperand(1), m_APInt(C)) || - !Bitcast->getType()->isIntegerTy() || - !Bitcast->getSrcTy()->isIntOrIntVectorTy()) - return nullptr; - Value *Vec; ArrayRef<int> Mask; if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { @@ -3055,7 +3114,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( switch (BO->getOpcode()) { case Instruction::SRem: // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (C.isNullValue() && BO->hasOneUse()) { + if (C.isZero() && BO->hasOneUse()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName()); @@ -3069,7 +3128,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( if (Constant *BOC = dyn_cast<Constant>(BOp1)) { if (BO->hasOneUse()) return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC)); - } else if (C.isNullValue()) { + } else if (C.isZero()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. if (Value *NegVal = dyn_castNegVal(BOp1)) @@ -3090,25 +3149,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( // For the xor case, we can xor two constants together, eliminating // the explicit xor. return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); - } else if (C.isNullValue()) { + } else if (C.isZero()) { // Replace ((xor A, B) != 0) with (A != B) return new ICmpInst(Pred, BOp0, BOp1); } } break; - case Instruction::Sub: - if (BO->hasOneUse()) { - // Only check for constant LHS here, as constant RHS will be canonicalized - // to add and use the fold above. - if (Constant *BOC = dyn_cast<Constant>(BOp0)) { - // Replace ((sub BOC, B) != C) with (B != BOC-C). - return new ICmpInst(Pred, BOp1, ConstantExpr::getSub(BOC, RHS)); - } else if (C.isNullValue()) { - // Replace ((sub A, B) != 0) with (A != B). 
- return new ICmpInst(Pred, BOp0, BOp1); - } - } - break; case Instruction::Or: { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) { @@ -3132,7 +3178,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( break; } case Instruction::UDiv: - if (C.isNullValue()) { + if (C.isZero()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -3149,25 +3195,26 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { Type *Ty = II->getType(); unsigned BitWidth = C.getBitWidth(); + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + switch (II->getIntrinsicID()) { case Intrinsic::abs: // abs(A) == 0 -> A == 0 // abs(A) == INT_MIN -> A == INT_MIN - if (C.isNullValue() || C.isMinSignedValue()) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), - ConstantInt::get(Ty, C)); + if (C.isZero() || C.isMinSignedValue()) + return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C)); break; case Intrinsic::bswap: // bswap(A) == C -> A == bswap(C) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C.byteSwap())); case Intrinsic::ctlz: case Intrinsic::cttz: { // ctz(A) == bitwidth(A) -> A == 0 and likewise for != if (C == BitWidth) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::getNullValue(Ty)); // ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set @@ -3181,9 +3228,8 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( APInt Mask2 = IsTrailing ? APInt::getOneBitSet(BitWidth, Num) : APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); - return new ICmpInst(Cmp.getPredicate(), - Builder.CreateAnd(II->getArgOperand(0), Mask1), - ConstantInt::get(Ty, Mask2)); + return new ICmpInst(Pred, Builder.CreateAnd(II->getArgOperand(0), Mask1), + ConstantInt::get(Ty, Mask2)); } break; } @@ -3191,28 +3237,49 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != - bool IsZero = C.isNullValue(); + bool IsZero = C.isZero(); if (IsZero || C == BitWidth) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), - IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty)); + return new ICmpInst(Pred, II->getArgOperand(0), + IsZero ? Constant::getNullValue(Ty) + : Constant::getAllOnesValue(Ty)); break; } + case Intrinsic::fshl: + case Intrinsic::fshr: + if (II->getArgOperand(0) == II->getArgOperand(1)) { + // (rot X, ?) == 0/-1 --> X == 0/-1 + // TODO: This transform is safe to re-use undef elts in a vector, but + // the constant value passed in by the caller doesn't allow that. + if (C.isZero() || C.isAllOnes()) + return new ICmpInst(Pred, II->getArgOperand(0), Cmp.getOperand(1)); + + const APInt *RotAmtC; + // ror(X, RotAmtC) == C --> X == rol(C, RotAmtC) + // rol(X, RotAmtC) == C --> X == ror(C, RotAmtC) + if (match(II->getArgOperand(2), m_APInt(RotAmtC))) + return new ICmpInst(Pred, II->getArgOperand(0), + II->getIntrinsicID() == Intrinsic::fshl + ? 
ConstantInt::get(Ty, C.rotr(*RotAmtC)) + : ConstantInt::get(Ty, C.rotl(*RotAmtC))); + } + break; + case Intrinsic::uadd_sat: { // uadd.sat(a, b) == 0 -> (a | b) == 0 - if (C.isNullValue()) { + if (C.isZero()) { Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); - return new ICmpInst(Cmp.getPredicate(), Or, Constant::getNullValue(Ty)); + return new ICmpInst(Pred, Or, Constant::getNullValue(Ty)); } break; } case Intrinsic::usub_sat: { // usub.sat(a, b) == 0 -> a <= b - if (C.isNullValue()) { - ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + if (C.isZero()) { + ICmpInst::Predicate NewPred = + Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1)); } break; @@ -3224,6 +3291,42 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( return nullptr; } +/// Fold an icmp with LLVM intrinsics +static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { + assert(Cmp.isEquality()); + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0); + Value *Op1 = Cmp.getOperand(1); + const auto *IIOp0 = dyn_cast<IntrinsicInst>(Op0); + const auto *IIOp1 = dyn_cast<IntrinsicInst>(Op1); + if (!IIOp0 || !IIOp1 || IIOp0->getIntrinsicID() != IIOp1->getIntrinsicID()) + return nullptr; + + switch (IIOp0->getIntrinsicID()) { + case Intrinsic::bswap: + case Intrinsic::bitreverse: + // If both operands are byte-swapped or bit-reversed, just compare the + // original values. + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + case Intrinsic::fshl: + case Intrinsic::fshr: + // If both operands are rotated by same amount, just compare the + // original values. + if (IIOp0->getOperand(0) != IIOp0->getOperand(1)) + break; + if (IIOp1->getOperand(0) != IIOp1->getOperand(1)) + break; + if (IIOp0->getOperand(2) != IIOp1->getOperand(2)) + break; + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + default: + break; + } + + return nullptr; +} + /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, IntrinsicInst *II, @@ -3663,7 +3766,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, (WidestTy->getScalarSizeInBits() - 1) + (NarrowestTy->getScalarSizeInBits() - 1); APInt MaximalRepresentableShiftAmount = - APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits()); + APInt::getAllOnes(XShAmt->getType()->getScalarSizeInBits()); if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) return nullptr; @@ -3746,19 +3849,22 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, /// Fold /// (-1 u/ x) u< y -/// ((x * y) u/ x) != y +/// ((x * y) ?/ x) != y /// to -/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit +/// @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit /// Note that the comparison is commutative, while inverted (u>=, ==) predicate /// will mean that we are looking for the opposite answer. 
-Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { +Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) { ICmpInst::Predicate Pred; Value *X, *Y; Instruction *Mul; + Instruction *Div; bool NeedNegation; // Look for: (-1 u/ x) u</u>= y if (!I.isEquality() && - match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + match(&I, m_c_ICmp(Pred, + m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + m_Instruction(Div)), m_Value(Y)))) { Mul = nullptr; @@ -3773,13 +3879,16 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { default: return nullptr; // Wrong predicate. } - } else // Look for: ((x * y) u/ x) !=/== y + } else // Look for: ((x * y) / x) !=/== y if (I.isEquality() && - match(&I, m_c_ICmp(Pred, m_Value(Y), - m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), + match(&I, + m_c_ICmp(Pred, m_Value(Y), + m_CombineAnd( + m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), m_Value(X)), m_Instruction(Mul)), - m_Deferred(X)))))) { + m_Deferred(X))), + m_Instruction(Div))))) { NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ; } else return nullptr; @@ -3791,19 +3900,22 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { if (MulHadOtherUses) Builder.SetInsertPoint(Mul); - Function *F = Intrinsic::getDeclaration( - I.getModule(), Intrinsic::umul_with_overflow, X->getType()); - CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul"); + Function *F = Intrinsic::getDeclaration(I.getModule(), + Div->getOpcode() == Instruction::UDiv + ? Intrinsic::umul_with_overflow + : Intrinsic::smul_with_overflow, + X->getType()); + CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul"); // If the multiplication was used elsewhere, to ensure that we don't leave // "duplicate" instructions, replace uses of that original multiplication // with the multiplication result from the with.overflow intrinsic. if (MulHadOtherUses) - replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val")); + replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val")); - Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov"); + Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov"); if (NeedNegation) // This technically increases instruction count. - Res = Builder.CreateNot(Res, "umul.not.ov"); + Res = Builder.CreateNot(Res, "mul.not.ov"); // If we replaced the mul, erase it. Do this after all uses of Builder, // as the mul is used as insertion point. @@ -4079,8 +4191,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 && match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality()) if (!C->countTrailingZeros() || - (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || - (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) + (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || + (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) return new ICmpInst(Pred, X, Y); } @@ -4146,8 +4258,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, break; const APInt *C; - if (match(BO0->getOperand(1), m_APInt(C)) && !C->isNullValue() && - !C->isOneValue()) { + if (match(BO0->getOperand(1), m_APInt(C)) && !C->isZero() && + !C->isOne()) { // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask) // Mask = -1 >> count-trailing-zeros(C). 
if (unsigned TZs = C->countTrailingZeros()) { @@ -4200,7 +4312,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, } } - if (Value *V = foldUnsignedMultiplicationOverflowCheck(I)) + if (Value *V = foldMultiplicationOverflowCheck(I)) return replaceInstUsesWith(I, V); if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) @@ -4373,6 +4485,19 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } + { + // Similar to above, but specialized for constant because invert is needed: + // (X | C) == (Y | C) --> (X ^ Y) & ~C == 0 + Value *X, *Y; + Constant *C; + if (match(Op0, m_OneUse(m_Or(m_Value(X), m_Constant(C)))) && + match(Op1, m_OneUse(m_Or(m_Value(Y), m_Specific(C))))) { + Value *Xor = Builder.CreateXor(X, Y); + Value *And = Builder.CreateAnd(Xor, ConstantExpr::getNot(C)); + return new ICmpInst(Pred, And, Constant::getNullValue(And->getType())); + } + } + // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) ConstantInt *Cst1; @@ -4441,14 +4566,8 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } - // If both operands are byte-swapped or bit-reversed, just compare the - // original values. - // TODO: Move this to a function similar to foldICmpIntrinsicWithConstant() - // and handle more intrinsics. - if ((match(Op0, m_BSwap(m_Value(A))) && match(Op1, m_BSwap(m_Value(B)))) || - (match(Op0, m_BitReverse(m_Value(A))) && - match(Op1, m_BitReverse(m_Value(B))))) - return new ICmpInst(Pred, A, B); + if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I)) + return ICmp; // Canonicalize checking for a power-of-2-or-zero value: // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) @@ -4474,6 +4593,74 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1)); } + // Match icmp eq (trunc (lshr A, BW), (ashr (trunc A), BW-1)), which checks the + // top BW/2 + 1 bits are all the same. Create "A >=s INT_MIN && A <=s INT_MAX", + // which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps + // of instcombine. + unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + if (match(Op0, m_AShr(m_Trunc(m_Value(A)), m_SpecificInt(BitWidth - 1))) && + match(Op1, m_Trunc(m_LShr(m_Specific(A), m_SpecificInt(BitWidth)))) && + A->getType()->getScalarSizeInBits() == BitWidth * 2 && + (I.getOperand(0)->hasOneUse() || I.getOperand(1)->hasOneUse())) { + APInt C = APInt::getOneBitSet(BitWidth * 2, BitWidth - 1); + Value *Add = Builder.CreateAdd(A, ConstantInt::get(A->getType(), C)); + return new ICmpInst(Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULT + : ICmpInst::ICMP_UGE, + Add, ConstantInt::get(A->getType(), C.shl(1))); + } + + return nullptr; +} + +static Instruction *foldICmpWithTrunc(ICmpInst &ICmp, + InstCombiner::BuilderTy &Builder) { + const ICmpInst::Predicate Pred = ICmp.getPredicate(); + Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1); + + // Try to canonicalize trunc + compare-to-constant into a mask + cmp. + // The trunc masks high bits while the compare may effectively mask low bits. + Value *X; + const APInt *C; + if (!match(Op0, m_OneUse(m_Trunc(m_Value(X)))) || !match(Op1, m_APInt(C))) + return nullptr; + + unsigned SrcBits = X->getType()->getScalarSizeInBits(); + if (Pred == ICmpInst::ICMP_ULT) { + if (C->isPowerOf2()) { + // If C is a power-of-2 (one set bit): + // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?) 
+ Constant *MaskC = ConstantInt::get(X->getType(), (-*C).zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + Constant *Zero = ConstantInt::getNullValue(X->getType()); + return new ICmpInst(ICmpInst::ICMP_EQ, And, Zero); + } + // If C is a negative power-of-2 (high-bit mask): + // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?) + if (C->isNegatedPowerOf2()) { + Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC); + } + } + + if (Pred == ICmpInst::ICMP_UGT) { + // If C is a low-bit-mask (C+1 is a power-of-2): + // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?) + if (C->isMask()) { + Constant *MaskC = ConstantInt::get(X->getType(), (~*C).zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + Constant *Zero = ConstantInt::getNullValue(X->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + // If C is not-of-power-of-2 (one clear bit): + // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?) + if ((~*C).isPowerOf2()) { + Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC); + } + } + return nullptr; } @@ -4620,6 +4807,9 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); } + if (Instruction *R = foldICmpWithTrunc(ICmp, Builder)) + return R; + return foldICmpWithZextOrSext(ICmp, Builder); } @@ -4943,7 +5133,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { const APInt *RHS; if (!match(I.getOperand(1), m_APInt(RHS))) - return APInt::getAllOnesValue(BitWidth); + return APInt::getAllOnes(BitWidth); // If this is a normal comparison, it demands all bits. If it is a sign bit // comparison, it only demands the sign bit. @@ -4965,7 +5155,7 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros()); default: - return APInt::getAllOnesValue(BitWidth); + return APInt::getAllOnes(BitWidth); } } @@ -5129,8 +5319,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { Op0Known, 0)) return &I; - if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth), - Op1Known, 0)) + if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0)) return &I; // Given the known and unknown bits, compute a range that the LHS could be @@ -5280,7 +5469,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { // Check if the LHS is 8 >>u x and the result is a power of 2 like 1. const APInt *CI; - if (Op0KnownZeroInverted.isOneValue() && + if (Op0KnownZeroInverted.isOne() && match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) { // ((8 >>u X) & 1) == 0 -> X != 3 // ((8 >>u X) & 1) != 0 -> X == 3 @@ -5618,7 +5807,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) && V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) { Value *NewCmp = Builder.CreateCmp(Pred, V1, V2); - return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M); + return new ShuffleVectorInst(NewCmp, M); } // Try to canonicalize compare with splatted operand and splat constant. 
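An aside on the foldICmpWithTrunc helper introduced in the hunk above: its comments describe four rewrites of an unsigned compare of a truncated value against a constant into a single mask test on the wide value, with each mask computed in the narrow type and then zero-extended (the `(-*C).zext(SrcBits)`-style calls). Below is a hypothetical standalone C++ brute-force check of those identities; it is not taken from the patch, and the i16 source / i8 destination widths are chosen only to keep the exhaustive loop small.

// Brute-force check of the (trunc X) u</u> C rewrites sketched in the
// foldICmpWithTrunc comments above. Hypothetical standalone test, not part
// of the patch: X is modeled as u16, the truncated value as u8.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t Xw = 0; Xw <= 0xFFFF; ++Xw) {
    const uint16_t X = static_cast<uint16_t>(Xw); // wide source value
    const uint8_t T = static_cast<uint8_t>(X);    // models trunc i16 X to i8
    for (uint32_t Cw = 1; Cw <= 0xFF; ++Cw) {
      const uint8_t C = static_cast<uint8_t>(Cw);
      const uint8_t NegC = static_cast<uint8_t>(-C);   // -C computed in i8
      const uint8_t NotC = static_cast<uint8_t>(~C);   // ~C computed in i8
      const uint8_t CP1 = static_cast<uint8_t>(C + 1); // C+1 computed in i8
      const bool CIsPow2 = (C & (C - 1)) == 0;
      const bool CIsNegPow2 = NegC != 0 && (NegC & (NegC - 1)) == 0;
      const bool CIsMask = (CP1 & C) == 0;             // C is a low-bit mask
      const bool NotCIsPow2 = NotC != 0 && (NotC & (NotC - 1)) == 0;

      // (trunc X) u< C --> (X & zext(-C)) == 0, when C is a power of 2.
      if (CIsPow2)
        assert((T < C) == ((X & uint16_t(NegC)) == 0));
      // (trunc X) u< C --> (X & zext(C)) != zext(C), when C is a negated power of 2.
      if (CIsNegPow2)
        assert((T < C) == ((X & uint16_t(C)) != uint16_t(C)));
      // (trunc X) u> C --> (X & zext(~C)) != 0, when C is a low-bit mask.
      if (CIsMask)
        assert((T > C) == ((X & uint16_t(NotC)) != 0));
      // (trunc X) u> C --> (X & zext(C+1)) == zext(C+1), when ~C is a power of 2.
      if (NotCIsPow2)
        assert((T > C) == ((X & uint16_t(CP1)) == uint16_t(CP1)));
    }
  }
  return 0;
}

The identities are width-generic; this harness merely confirms them over every 16-bit value and 8-bit constant, mirroring how the patch builds each mask in the destination type before zero-extending it to the source type.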
@@ -5639,8 +5828,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, ScalarC); SmallVector<int, 8> NewM(M.size(), MaskSplatIndex); Value *NewCmp = Builder.CreateCmp(Pred, V1, C); - return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), - NewM); + return new ShuffleVectorInst(NewCmp, NewM); } return nullptr; @@ -5676,6 +5864,23 @@ static Instruction *foldICmpOfUAddOv(ICmpInst &I) { return ExtractValueInst::Create(UAddOv, 1); } +static Instruction *foldICmpInvariantGroup(ICmpInst &I) { + if (!I.getOperand(0)->getType()->isPointerTy() || + NullPointerIsDefined( + I.getParent()->getParent(), + I.getOperand(0)->getType()->getPointerAddressSpace())) { + return nullptr; + } + Instruction *Op; + if (match(I.getOperand(0), m_Instruction(Op)) && + match(I.getOperand(1), m_Zero()) && + Op->isLaunderOrStripInvariantGroup()) { + return ICmpInst::Create(Instruction::ICmp, I.getPredicate(), + Op->getOperand(0), I.getOperand(1)); + } + return nullptr; +} + Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { bool Changed = false; const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -5729,9 +5934,6 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; - if (Instruction *Res = foldICmpBinOp(I, Q)) - return Res; - if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; @@ -5777,6 +5979,15 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { } } + // The folds in here may rely on wrapping flags and special constants, so + // they can break up min/max idioms in some cases but not seemingly similar + // patterns. + // FIXME: It may be possible to enhance select folding to make this + // unnecessary. It may also be moot if we canonicalize to min/max + // intrinsics. + if (Instruction *Res = foldICmpBinOp(I, Q)) + return Res; + if (Instruction *Res = foldICmpInstWithConstant(I)) return Res; @@ -5788,13 +5999,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) return Res; - // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now. - if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op0)) + // Try to optimize 'icmp GEP, P' or 'icmp P, GEP'. + if (auto *GEP = dyn_cast<GEPOperator>(Op0)) if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I)) return NI; - if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op1)) - if (Instruction *NI = foldGEPICmp(GEP, Op0, - ICmpInst::getSwappedPredicate(I.getPredicate()), I)) + if (auto *GEP = dyn_cast<GEPOperator>(Op1)) + if (Instruction *NI = foldGEPICmp(GEP, Op0, I.getSwappedPredicate(), I)) return NI; // Try to optimize equality comparisons against alloca-based pointers. @@ -5808,7 +6018,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { return New; } - if (Instruction *Res = foldICmpBitCast(I, Builder)) + if (Instruction *Res = foldICmpBitCast(I)) return Res; // TODO: Hoist this above the min/max bailout. @@ -5910,6 +6120,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldVectorCmp(I, Builder)) return Res; + if (Instruction *Res = foldICmpInvariantGroup(I)) + return Res; + return Changed ? 
&I : nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index eaa53348028d..72e1b21e8d49 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -22,14 +22,15 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" using namespace llvm::PatternMatch; @@ -61,7 +62,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final : public InstCombiner, public InstVisitor<InstCombinerImpl, Instruction *> { public: - InstCombinerImpl(InstCombineWorklist &Worklist, BuilderTy &Builder, + InstCombinerImpl(InstructionWorklist &Worklist, BuilderTy &Builder, bool MinimizeSize, AAResults *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, @@ -190,6 +191,7 @@ public: private: void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); + bool isDesirableIntType(unsigned BitWidth) const; bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; bool shouldChangeType(Type *From, Type *To) const; Value *dyn_castNegVal(Value *V) const; @@ -240,15 +242,11 @@ private: /// /// \param ICI The icmp of the (zext icmp) pair we are interested in. /// \parem CI The zext of the (zext icmp) pair we are interested in. - /// \param DoTransform Pass false to just test whether the given (zext icmp) - /// would be transformed. Pass true to actually perform the transformation. /// /// \return null if the transformation cannot be performed. If the /// transformation can be performed the new instruction that replaces the - /// (zext icmp) pair will be returned (if \p DoTransform is false the - /// unmodified \p ICI will be returned in this case). - Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, - bool DoTransform = true); + /// (zext icmp) pair will be returned. + Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI); Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); @@ -319,13 +317,15 @@ private: Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); + Instruction *foldBitcastExtElt(ExtractElementInst &ExtElt); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); Instruction *narrowBinOp(TruncInst &Trunc); Instruction *narrowMaskedBinOp(BinaryOperator &And); Instruction *narrowMathIfNoOverflow(BinaryOperator &I); Instruction *narrowFunnelShift(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); - Instruction *matchSAddSubSat(SelectInst &MinMax1); + Instruction *matchSAddSubSat(Instruction &MinMax1); + Instruction *foldNot(BinaryOperator &I); void freelyInvertAllUsersOf(Value *V); @@ -347,6 +347,8 @@ private: Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or); Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor); + Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd); + /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). 
/// NOTE: Unlike most of instcombine, this returns a Value which should /// already be inserted into the function. @@ -623,6 +625,7 @@ public: Instruction *foldPHIArgGEPIntoPHI(PHINode &PN); Instruction *foldPHIArgLoadIntoPHI(PHINode &PN); Instruction *foldPHIArgZextsIntoPHI(PHINode &PN); + Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN); /// If an integer typed PHI has only one use which is an IntToPtr operation, /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise @@ -657,7 +660,7 @@ public: Instruction *foldSignBitTest(ICmpInst &I); Instruction *foldICmpWithZero(ICmpInst &Cmp); - Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); + Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp); Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); @@ -701,6 +704,7 @@ public: const APInt &C); Instruction *foldICmpEqIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, const APInt &C); + Instruction *foldICmpBitCast(ICmpInst &Cmp); // Helpers of visitSelectInst(). Instruction *foldSelectExtConst(SelectInst &Sel); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 80abc775299a..79a8a065d02a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -337,8 +337,7 @@ void PointerReplacer::replace(Instruction *I) { MemCpy->getIntrinsicID(), MemCpy->getRawDest(), MemCpy->getDestAlign(), SrcV, MemCpy->getSourceAlign(), MemCpy->getLength(), MemCpy->isVolatile()); - AAMDNodes AAMD; - MemCpy->getAAMetadata(AAMD); + AAMDNodes AAMD = MemCpy->getAAMetadata(); if (AAMD) NewI->setAAMetadata(AAMD); @@ -649,9 +648,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { if (NumElements == 1) { LoadInst *NewLoad = IC.combineLoadToNewType(LI, ST->getTypeAtIndex(0U), ".unpack"); - AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - NewLoad->setAAMetadata(AAMD); + NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( UndefValue::get(T), NewLoad, 0, Name)); } @@ -680,9 +677,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { ST->getElementType(i), Ptr, commonAlignment(Align, SL->getElementOffset(i)), Name + ".unpack"); // Propagate AA metadata. It'll still be valid on the narrowed load. 
- AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - L->setAAMetadata(AAMD); + L->setAAMetadata(LI.getAAMetadata()); V = IC.Builder.CreateInsertValue(V, L, i); } @@ -695,9 +690,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto NumElements = AT->getNumElements(); if (NumElements == 1) { LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack"); - AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - NewLoad->setAAMetadata(AAMD); + NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( UndefValue::get(T), NewLoad, 0, Name)); } @@ -729,9 +722,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, commonAlignment(Align, Offset), Name + ".unpack"); - AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - L->setAAMetadata(AAMD); + L->setAAMetadata(LI.getAAMetadata()); V = IC.Builder.CreateInsertValue(V, L, i); Offset += EltSize; } @@ -1208,9 +1199,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, SL->getElementOffset(i)); llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); - AAMDNodes AAMD; - SI.getAAMetadata(AAMD); - NS->setAAMetadata(AAMD); + NS->setAAMetadata(SI.getAAMetadata()); } return true; @@ -1256,9 +1245,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, Offset); Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); - AAMDNodes AAMD; - SI.getAAMetadata(AAMD); - NS->setAAMetadata(AAMD); + NS->setAAMetadata(SI.getAAMetadata()); Offset += EltSize; } @@ -1500,8 +1487,8 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { StoreInst *OtherStore = nullptr; if (OtherBr->isUnconditional()) { --BBI; - // Skip over debugging info. - while (isa<DbgInfoIntrinsic>(BBI) || + // Skip over debugging info and pseudo probes. + while (BBI->isDebugOrPseudoInst() || (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) { if (BBI==OtherBB->begin()) return false; @@ -1569,12 +1556,9 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { NewSI->setDebugLoc(MergedLoc); // If the two stores had AA tags, merge them. - AAMDNodes AATags; - SI.getAAMetadata(AATags); - if (AATags) { - OtherStore->getAAMetadata(AATags, /* Merge = */ true); - NewSI->setAAMetadata(AATags); - } + AAMDNodes AATags = SI.getAAMetadata(); + if (AATags) + NewSI->setAAMetadata(AATags.merge(OtherStore->getAAMetadata())); // Nuke the old stores. 
eraseInstFromFunction(SI); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 6f2a8ebf839a..779d298da7a4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -31,7 +31,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include <cassert> @@ -39,11 +38,12 @@ #include <cstdint> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" - /// The specific integer value is used in a context where it is known to be /// non-zero. If this allows us to simplify the computation, do so and return /// the new operand, otherwise return null. @@ -107,14 +107,19 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, // mul (select Cond, 1, -1), OtherOp --> select Cond, OtherOp, -OtherOp // mul OtherOp, (select Cond, 1, -1) --> select Cond, OtherOp, -OtherOp if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())), - m_Value(OtherOp)))) - return Builder.CreateSelect(Cond, OtherOp, Builder.CreateNeg(OtherOp)); - + m_Value(OtherOp)))) { + bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); + Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + return Builder.CreateSelect(Cond, OtherOp, Neg); + } // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp // mul OtherOp, (select Cond, -1, 1) --> select Cond, -OtherOp, OtherOp if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())), - m_Value(OtherOp)))) - return Builder.CreateSelect(Cond, Builder.CreateNeg(OtherOp), OtherOp); + m_Value(OtherOp)))) { + bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); + Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + return Builder.CreateSelect(Cond, Neg, OtherOp); + } // fmul (select Cond, 1.0, -1.0), OtherOp --> select Cond, OtherOp, -OtherOp // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp @@ -564,6 +569,16 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { return replaceInstUsesWith(I, NewPow); } + // powi(x, y) * powi(x, z) -> powi(x, y + z) + if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) && + Y->getType() == Z->getType()) { + auto *YZ = Builder.CreateAdd(Y, Z); + auto *NewPow = Builder.CreateIntrinsic( + Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); + return replaceInstUsesWith(I, NewPow); + } + // exp(X) * exp(Y) -> exp(X + Y) if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) { @@ -706,11 +721,11 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, assert(C1.getBitWidth() == C2.getBitWidth() && "Constant widths not equal"); // Bail if we will divide by zero. - if (C2.isNullValue()) + if (C2.isZero()) return false; // Bail if we would divide INT_MIN by -1. 
- if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue()) + if (IsSigned && C1.isMinSignedValue() && C2.isAllOnes()) return false; APInt Remainder(C1.getBitWidth(), /*val=*/0ULL, IsSigned); @@ -778,11 +793,12 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) && - *C1 != C1->getBitWidth() - 1) || - (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))))) { + C1->ult(C1->getBitWidth() - 1)) || + (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) && + C1->ult(C1->getBitWidth()))) { APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); APInt C1Shifted = APInt::getOneBitSet( - C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); + C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue())); // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of 1 << C1. if (isMultiple(*C2, C1Shifted, Quotient, IsSigned)) { @@ -803,7 +819,7 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } } - if (!C2->isNullValue()) // avoid X udiv 0 + if (!C2->isZero()) // avoid X udiv 0 if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I)) return FoldedDiv; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 37c7e6135501..7dc516c6fdc3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -215,6 +215,20 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { : Builder.CreateSExt(I->getOperand(0), I->getType(), I->getName() + ".neg"); break; + case Instruction::Select: { + // If both arms of the select are constants, we don't need to recurse. + // Therefore, this transform is not limited by uses. + auto *Sel = cast<SelectInst>(I); + Constant *TrueC, *FalseC; + if (match(Sel->getTrueValue(), m_ImmConstant(TrueC)) && + match(Sel->getFalseValue(), m_ImmConstant(FalseC))) { + Constant *NegTrueC = ConstantExpr::getNeg(TrueC); + Constant *NegFalseC = ConstantExpr::getNeg(FalseC); + return Builder.CreateSelect(Sel->getCondition(), NegTrueC, NegFalseC, + I->getName() + ".neg", /*MDFrom=*/I); + } + break; + } default: break; // Other instructions require recursive reasoning. } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 6c6351c70e3a..35739c3b9a21 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -299,6 +299,29 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { IntToPtr->getOperand(0)->getType()); } +// Remove RoundTrip IntToPtr/PtrToInt Cast on PHI-Operand and +// fold Phi-operand to bitcast. +Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) { + // convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] ) + // Make sure all uses of phi are ptr2int. + if (!all_of(PN.users(), [](User *U) { return isa<PtrToIntInst>(U); })) + return nullptr; + + // Iterating over all operands to check presence of target pointers for + // optimization. 
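A minimal IR sketch of the round-trip cast removal that foldPHIArgIntToPtrToPHI (comment above, code below) performs, assuming every use of the phi is a ptrtoint:
  %i   = ptrtoint i8* %a to i64
  %p   = inttoptr i64 %i to i8*
  %phi = phi i8* [ %p, %bb0 ], [ %b, %bb1 ]
  %r   = ptrtoint i8* %phi to i64
  ; the inttoptr(ptrtoint(%a)) operand collapses to %a:
  %phi = phi i8* [ %a, %bb0 ], [ %b, %bb1 ]
  %r   = ptrtoint i8* %phi to i64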
+ bool OperandWithRoundTripCast = false; + for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) { + if (auto *NewOp = + simplifyIntToPtrRoundTripCast(PN.getIncomingValue(OpNum))) { + PN.setIncomingValue(OpNum, NewOp); + OperandWithRoundTripCast = true; + } + } + if (!OperandWithRoundTripCast) + return nullptr; + return &PN; +} + /// If we have something like phi [insertvalue(a,b,0), insertvalue(c,d,0)], /// turn this into a phi[a,c] and phi[b,d] and a single insertvalue. Instruction * @@ -1306,6 +1329,9 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { if (Instruction *Result = foldPHIArgZextsIntoPHI(PN)) return Result; + if (Instruction *Result = foldPHIArgIntToPtrToPHI(PN)) + return Result; + // If all PHI operands are the same operation, pull them through the PHI, // reducing code size. if (isa<Instruction>(PN.getIncomingValue(0)) && diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 5bbc3c87ca4f..4a1e82ae9c1d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -38,15 +38,16 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" static Value *createMinMax(InstCombiner::BuilderTy &Builder, SelectPatternFlavor SPF, Value *A, Value *B) { @@ -165,7 +166,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // simplify/reduce the instructions. APInt TC = *SelTC; APInt FC = *SelFC; - if (!TC.isNullValue() && !FC.isNullValue()) { + if (!TC.isZero() && !FC.isZero()) { // If the select constants differ by exactly one bit and that's the same // bit that is masked and checked by the select condition, the select can // be replaced by bitwise logic to set/clear one bit of the constant result. @@ -202,7 +203,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // Determine which shift is needed to transform result of the 'and' into the // desired result. - const APInt &ValC = !TC.isNullValue() ? TC : FC; + const APInt &ValC = !TC.isZero() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); @@ -224,7 +225,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TC.isNullValue(); + bool ShouldNotVal = !TC.isZero(); ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); @@ -319,8 +320,16 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, Value *X, *Y; if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) && (TI->hasOneUse() || FI->hasOneUse())) { + // Intersect FMF from the fneg instructions and union those with the select. 
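The fneg-of-select hoist in foldSelectOpOp, whose flag handling continues below, can be sketched in IR as follows (flags are illustrative; the fnegs' flags are intersected and then unioned with the select's):
  %nx = fneg nnan nsz float %x
  %ny = fneg nnan float %y
  %s  = select i1 %c, float %nx, float %ny
  ; becomes roughly (nnan survives as the intersection of the fneg flags)
  %v  = select nnan i1 %c, float %x, float %y
  %s  = fneg nnan float %v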
+ FastMathFlags FMF = TI->getFastMathFlags(); + FMF &= FI->getFastMathFlags(); + FMF |= SI.getFastMathFlags(); Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); - return UnaryOperator::CreateFNegFMF(NewSel, TI); + if (auto *NewSelI = dyn_cast<Instruction>(NewSel)) + NewSelI->setFastMathFlags(FMF); + Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel); + NewFNeg->setFastMathFlags(FMF); + return NewFNeg; } // Min/max intrinsic with a common operand can have the common operand pulled @@ -420,10 +429,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, } static bool isSelect01(const APInt &C1I, const APInt &C2I) { - if (!C1I.isNullValue() && !C2I.isNullValue()) // One side must be zero. + if (!C1I.isZero() && !C2I.isZero()) // One side must be zero. return false; - return C1I.isOneValue() || C1I.isAllOnesValue() || - C2I.isOneValue() || C2I.isAllOnesValue(); + return C1I.isOne() || C1I.isAllOnes() || C2I.isOne() || C2I.isAllOnes(); } /// Try to fold the select into one of the operands to allow further @@ -715,6 +723,58 @@ static Instruction *foldSetClearBits(SelectInst &Sel, return nullptr; } +// select (x == 0), 0, x * y --> freeze(y) * x +// select (y == 0), 0, x * y --> freeze(x) * y +// select (x == 0), undef, x * y --> freeze(y) * x +// select (x == undef), 0, x * y --> freeze(y) * x +// Usage of mul instead of 0 will make the result more poisonous, +// so the operand that was not checked in the condition should be frozen. +// The latter folding is applied only when a constant compared with x is +// is a vector consisting of 0 and undefs. If a constant compared with x +// is a scalar undefined value or undefined vector then an expression +// should be already folded into a constant. +static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) { + auto *CondVal = SI.getCondition(); + auto *TrueVal = SI.getTrueValue(); + auto *FalseVal = SI.getFalseValue(); + Value *X, *Y; + ICmpInst::Predicate Predicate; + + // Assuming that constant compared with zero is not undef (but it may be + // a vector with some undef elements). Otherwise (when a constant is undef) + // the select expression should be already simplified. + if (!match(CondVal, m_ICmp(Predicate, m_Value(X), m_Zero())) || + !ICmpInst::isEquality(Predicate)) + return nullptr; + + if (Predicate == ICmpInst::ICMP_NE) + std::swap(TrueVal, FalseVal); + + // Check that TrueVal is a constant instead of matching it with m_Zero() + // to handle the case when it is a scalar undef value or a vector containing + // non-zero elements that are masked by undef elements in the compare + // constant. + auto *TrueValC = dyn_cast<Constant>(TrueVal); + if (TrueValC == nullptr || + !match(FalseVal, m_c_Mul(m_Specific(X), m_Value(Y))) || + !isa<Instruction>(FalseVal)) + return nullptr; + + auto *ZeroC = cast<Constant>(cast<Instruction>(CondVal)->getOperand(1)); + auto *MergedC = Constant::mergeUndefsWith(TrueValC, ZeroC); + // If X is compared with 0 then TrueVal could be either zero or undef. + // m_Zero match vectors containing some undef elements, but for scalars + // m_Undef should be used explicitly. + if (!match(MergedC, m_Zero()) && !match(MergedC, m_Undef())) + return nullptr; + + auto *FalseValI = cast<Instruction>(FalseVal); + auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"), + *FalseValI); + IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 
0 : 1, FrY); + return IC.replaceInstUsesWith(SI, FalseValI); +} + /// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b). /// There are 8 commuted/swapped variants of this pattern. /// TODO: Also support a - UMIN(a,b) patterns. @@ -1229,8 +1289,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // Iff -C1 s<= C2 s<= C0-C1 // Also ULT predicate can also be UGT iff C0 != -1 (+invert result) // SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.) -static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, - InstCombiner::BuilderTy &Builder) { +static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, + InstCombiner::BuilderTy &Builder) { Value *X = Sel0.getTrueValue(); Value *Sel1 = Sel0.getFalseValue(); @@ -1238,36 +1298,42 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, // Said condition must be one-use. if (!Cmp0.hasOneUse()) return nullptr; + ICmpInst::Predicate Pred0 = Cmp0.getPredicate(); Value *Cmp00 = Cmp0.getOperand(0); Constant *C0; if (!match(Cmp0.getOperand(1), m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))) return nullptr; - // Canonicalize Cmp0 into the form we expect. + + if (!isa<SelectInst>(Sel1)) { + Pred0 = ICmpInst::getInversePredicate(Pred0); + std::swap(X, Sel1); + } + + // Canonicalize Cmp0 into ult or uge. // FIXME: we shouldn't care about lanes that are 'undef' in the end? - switch (Cmp0.getPredicate()) { + switch (Pred0) { case ICmpInst::Predicate::ICMP_ULT: + case ICmpInst::Predicate::ICMP_UGE: + // Although icmp ult %x, 0 is an unusual thing to try and should generally + // have been simplified, it does not verify with undef inputs so ensure we + // are not in a strange state. + if (!match(C0, m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_NE, + APInt::getZero(C0->getType()->getScalarSizeInBits())))) + return nullptr; break; // Great! case ICmpInst::Predicate::ICMP_ULE: - // We'd have to increment C0 by one, and for that it must not have all-ones - // element, but then it would have been canonicalized to 'ult' before - // we get here. So we can't do anything useful with 'ule'. - return nullptr; case ICmpInst::Predicate::ICMP_UGT: - // We want to canonicalize it to 'ult', so we'll need to increment C0, - // which again means it must not have any all-ones elements. + // We want to canonicalize it to 'ult' or 'uge', so we'll need to increment + // C0, which again means it must not have any all-ones elements. if (!match(C0, - m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE, - APInt::getAllOnesValue( - C0->getType()->getScalarSizeInBits())))) + m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_NE, + APInt::getAllOnes(C0->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have all-ones element[s]. C0 = InstCombiner::AddOne(C0); - std::swap(X, Sel1); break; - case ICmpInst::Predicate::ICMP_UGE: - // The only way we'd get this predicate if this `icmp` has extra uses, - // but then we won't be able to do this fold. - return nullptr; default: return nullptr; // Unknown predicate. } @@ -1277,11 +1343,16 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!Sel1->hasOneUse()) return nullptr; + // If the types do not match, look through any truncs to the underlying + // instruction. 
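Stepping back to foldSelectZeroOrMul above, a minimal IR sketch of the rewrite it describes (the ".fr" name mirrors the freeze the code inserts):
  %c = icmp eq i32 %x, 0
  %m = mul i32 %x, %y
  %r = select i1 %c, i32 0, i32 %m
  ; becomes
  %y.fr = freeze i32 %y
  %m    = mul i32 %x, %y.fr
  ; and the select is replaced by %m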
+ if (Cmp00->getType() != X->getType() && X->hasOneUse()) + match(X, m_TruncOrSelf(m_Value(X))); + // We now can finish matching the condition of the outermost select: // it should either be the X itself, or an addition of some constant to X. Constant *C1; if (Cmp00 == X) - C1 = ConstantInt::getNullValue(Sel0.getType()); + C1 = ConstantInt::getNullValue(X->getType()); else if (!match(Cmp00, m_Add(m_Specific(X), m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1))))) @@ -1335,6 +1406,8 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, // The thresholds of this clamp-like pattern. auto *ThresholdLowIncl = ConstantExpr::getNeg(C1); auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1); + if (Pred0 == ICmpInst::Predicate::ICMP_UGE) + std::swap(ThresholdLowIncl, ThresholdHighExcl); // The fold has a precondition 1: C2 s>= ThresholdLow auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2, @@ -1347,15 +1420,29 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!match(Precond2, m_One())) return nullptr; + // If we are matching from a truncated input, we need to sext the + // ReplacementLow and ReplacementHigh values. Only do the transform if they + // are free to extend due to being constants. + if (X->getType() != Sel0.getType()) { + Constant *LowC, *HighC; + if (!match(ReplacementLow, m_ImmConstant(LowC)) || + !match(ReplacementHigh, m_ImmConstant(HighC))) + return nullptr; + ReplacementLow = ConstantExpr::getSExt(LowC, X->getType()); + ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType()); + } + // All good, finally emit the new pattern. Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl); Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl); Value *MaybeReplacedLow = Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X); - Instruction *MaybeReplacedHigh = - SelectInst::Create(ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow); - return MaybeReplacedHigh; + // Create the final select. If we looked through a truncate above, we will + // need to retruncate the result. + Value *MaybeReplacedHigh = Builder.CreateSelect( + ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow); + return Builder.CreateTrunc(MaybeReplacedHigh, Sel0.getType()); } // If we have @@ -1446,8 +1533,8 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, *this)) return NewAbs; - if (Instruction *NewAbs = canonicalizeClampLike(SI, *ICI, Builder)) - return NewAbs; + if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) + return replaceInstUsesWith(SI, V); if (Instruction *NewSel = tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) @@ -1816,9 +1903,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { m_Value(TrueVal), m_Value(FalseVal)))) return false; - auto IsZeroOrOne = [](const APInt &C) { - return C.isNullValue() || C.isOneValue(); - }; + auto IsZeroOrOne = [](const APInt &C) { return C.isZero() || C.isOne(); }; auto IsMinMax = [&](Value *Min, Value *Max) { APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); @@ -2182,7 +2267,7 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X, } /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. 
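matchSAddSubSat (updated below to also accept min/max intrinsics) recognizes clamp trees of roughly this shape and forms a saturating intrinsic; a hedged i8-in-i32 sketch:
  %ax = sext i8 %x to i32
  %ay = sext i8 %y to i32
  %s  = add i32 %ax, %ay
  %hi = call i32 @llvm.smin.i32(i32 %s, i32 127)
  %c  = call i32 @llvm.smax.i32(i32 %hi, i32 -128)
  ; becomes roughly
  %sat = call i8 @llvm.sadd.sat.i8(i8 %x, i8 %y)
  %c   = sext i8 %sat to i32
With this change the add's operands are narrowed via trunc (guarded by ComputeMinSignedBits) rather than by matching explicit sexts.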
-Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) { +Instruction *InstCombinerImpl::matchSAddSubSat(Instruction &MinMax1) { Type *Ty = MinMax1.getType(); // We are looking for a tree of: @@ -2212,23 +2297,14 @@ Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) { if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth)) return nullptr; - // Also make sure that the number of uses is as expected. The "3"s are for the - // the two items of min/max (the compare and the select). - if (MinMax2->hasNUsesOrMore(3) || AddSub->hasNUsesOrMore(3)) + // Also make sure that the number of uses is as expected. The 3 is for the + // the two items of the compare and the select, or 2 from a min/max. + unsigned ExpUses = isa<IntrinsicInst>(MinMax1) ? 2 : 3; + if (MinMax2->hasNUsesOrMore(ExpUses) || AddSub->hasNUsesOrMore(ExpUses)) return nullptr; // Create the new type (which can be a vector type) Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth); - // Match the two extends from the add/sub - Value *A, *B; - if(!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B))))) - return nullptr; - // And check the incoming values are of a type smaller than or equal to the - // size of the saturation. Otherwise the higher bits can cause different - // results. - if (A->getType()->getScalarSizeInBits() > NewBitWidth || - B->getType()->getScalarSizeInBits() > NewBitWidth) - return nullptr; Intrinsic::ID IntrinsicID; if (AddSub->getOpcode() == Instruction::Add) @@ -2238,10 +2314,16 @@ Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) { else return nullptr; + // The two operands of the add/sub must be nsw-truncatable to the NewTy. This + // is usually achieved via a sext from a smaller type. + if (ComputeMinSignedBits(AddSub->getOperand(0), 0, AddSub) > NewBitWidth || + ComputeMinSignedBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth) + return nullptr; + // Finally create and return the sat intrinsic, truncated to the new type Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy); - Value *AT = Builder.CreateSExt(A, NewTy); - Value *BT = Builder.CreateSExt(B, NewTy); + Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy); + Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy); Value *Sat = Builder.CreateCall(F, {AT, BT}); return CastInst::Create(Instruction::SExt, Sat, Ty); } @@ -2432,7 +2514,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { unsigned NumElts = VecTy->getNumElements(); APInt UndefElts(NumElts, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(NumElts)); + APInt AllOnesEltMask(APInt::getAllOnes(NumElts)); if (Value *V = SimplifyDemandedVectorElts(&Sel, AllOnesEltMask, UndefElts)) { if (V != &Sel) return replaceInstUsesWith(Sel, V); @@ -2754,11 +2836,16 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { /* IsAnd */ IsAnd)) return I; - if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) - if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) + if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) { + if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) { if (auto *V = foldAndOrOfICmpsOfAndWithPow2(ICmp0, ICmp1, &SI, IsAnd, /* IsLogical */ true)) return replaceInstUsesWith(SI, V); + + if (auto *V = foldEqOfParts(ICmp0, ICmp1, IsAnd)) + return replaceInstUsesWith(SI, V); + } + } } // select (select a, true, b), c, false -> select a, c, false @@ -2863,14 +2950,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } // Canonicalize select with fcmp to fabs(). 
-0.0 makes this tricky. We need - // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We - // also require nnan because we do not want to unintentionally change the - // sign of a NaN value. + // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X) - Instruction *FSub; if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(FalseVal))) && - match(TrueVal, m_Instruction(FSub)) && FSub->hasNoNaNs() && (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, &SI); return replaceInstUsesWith(SI, Fabs); @@ -2878,7 +2961,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X) if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(TrueVal))) && - match(FalseVal, m_Instruction(FSub)) && FSub->hasNoNaNs() && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, &SI); return replaceInstUsesWith(SI, Fabs); @@ -2886,11 +2968,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // With nnan and nsz: // (X < +/-0.0) ? -X : X --> fabs(X) // (X <= +/-0.0) ? -X : X --> fabs(X) - Instruction *FNeg; if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && - match(TrueVal, m_FNeg(m_Specific(FalseVal))) && - match(TrueVal, m_Instruction(FNeg)) && FNeg->hasNoNaNs() && - FNeg->hasNoSignedZeros() && SI.hasNoSignedZeros() && + match(TrueVal, m_FNeg(m_Specific(FalseVal))) && SI.hasNoSignedZeros() && (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, &SI); @@ -2900,9 +2979,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // (X > +/-0.0) ? X : -X --> fabs(X) // (X >= +/-0.0) ? 
X : -X --> fabs(X) if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && - match(FalseVal, m_FNeg(m_Specific(TrueVal))) && - match(FalseVal, m_Instruction(FNeg)) && FNeg->hasNoNaNs() && - FNeg->hasNoSignedZeros() && SI.hasNoSignedZeros() && + match(FalseVal, m_FNeg(m_Specific(TrueVal))) && SI.hasNoSignedZeros() && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE || Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, &SI); @@ -2920,6 +2997,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { return Add; if (Instruction *Or = foldSetClearBits(SI, Builder)) return Or; + if (Instruction *Mul = foldSelectZeroOrMul(SI, *this)) + return Mul; // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) auto *TI = dyn_cast<Instruction>(TrueVal); @@ -2939,8 +3018,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Gep->getNumOperands() != 2 || Gep->getPointerOperand() != Base || !Gep->hasOneUse()) return nullptr; - Type *ElementType = Gep->getResultElementType(); Value *Idx = Gep->getOperand(1); + if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType())) + return nullptr; + Type *ElementType = Gep->getResultElementType(); Value *NewT = Idx; Value *NewF = Constant::getNullValue(Idx->getType()); if (Swap) @@ -3188,9 +3269,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) { KnownBits Known(1); computeKnownBits(CondVal, Known, 0, &SI); - if (Known.One.isOneValue()) + if (Known.One.isOne()) return replaceInstUsesWith(SI, TrueVal); - if (Known.Zero.isOneValue()) + if (Known.Zero.isOne()) return replaceInstUsesWith(SI, FalseVal); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index ca5e473fdecb..06421d553915 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -41,7 +41,7 @@ bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1, (Sh0->getType()->getScalarSizeInBits() - 1) + (Sh1->getType()->getScalarSizeInBits() - 1); APInt MaximalRepresentableShiftAmount = - APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits()); + APInt::getAllOnes(ShAmt0->getType()->getScalarSizeInBits()); return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount); } @@ -172,8 +172,8 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( // There are many variants to this pattern: // a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt // b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt -// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt -// d) (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt +// c) (x & (-1 l>> MaskShAmt)) << ShiftShAmt +// d) (x & ((-1 << MaskShAmt) l>> MaskShAmt)) << ShiftShAmt // e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt // f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt // All these patterns can be simplified to just: @@ -213,11 +213,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); // (~(-1 << maskNbits)) auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes()); - // (-1 >> MaskShAmt) - auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt)); - // ((-1 << 
MaskShAmt) >> MaskShAmt) + // (-1 l>> MaskShAmt) + auto MaskC = m_LShr(m_AllOnes(), m_Value(MaskShAmt)); + // ((-1 << MaskShAmt) l>> MaskShAmt) auto MaskD = - m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); + m_LShr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); Value *X; Constant *NewMask; @@ -240,7 +240,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // that shall remain in the root value (OuterShift). // An extend of an undef value becomes zero because the high bits are never - // completely unknown. Replace the the `undef` shift amounts with final + // completely unknown. Replace the `undef` shift amounts with final // shift bitwidth to ensure that the value remains undef when creating the // subsequent shift op. SumOfShAmts = Constant::replaceUndefsWith( @@ -272,7 +272,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // shall be unset in the root value (OuterShift). // An extend of an undef value becomes zero because the high bits are never - // completely unknown. Replace the the `undef` shift amounts with negated + // completely unknown. Replace the `undef` shift amounts with negated // bitwidth of innermost shift to ensure that the value remains undef when // creating the subsequent shift op. unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits(); @@ -346,9 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, // TODO: Remove the one-use check if the other logic operand (Y) is constant. Value *X, *Y; auto matchFirstShift = [&](Value *V) { - BinaryOperator *BO; APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); - return match(V, m_BinOp(BO)) && BO->getOpcode() == ShiftOpcode && + return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) && match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && match(ConstantExpr::getAdd(C0, C1), m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); @@ -661,23 +660,22 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { - bool isLeftShift = I.getOpcode() == Instruction::Shl; - const APInt *Op1C; if (!match(Op1, m_APInt(Op1C))) return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. 
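As an aside on dropRedundantMaskingOfLeftShiftInput above, pattern c shown with equal shift amounts, so the masked-off high bits are exactly the ones the shl discards (the code establishes the general precondition):
  %mask = lshr i32 -1, %n
  %low  = and i32 %x, %mask
  %r    = shl i32 %low, %n
  ; simplifies to
  %r    = shl i32 %x, %n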
+ bool IsLeftShift = I.getOpcode() == Instruction::Shl; if (I.getOpcode() != Instruction::AShr && - canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { + canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) { LLVM_DEBUG( dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I << "\n"); return replaceInstUsesWith( - I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); + I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL)); } // See if we can simplify any instructions used by the instruction whose sole @@ -686,202 +684,72 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, unsigned TypeBits = Ty->getScalarSizeInBits(); assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); + (void)TypeBits; if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; - // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) - if (auto *TI = dyn_cast<TruncInst>(Op0)) { - // If 'shift2' is an ashr, we would have to get the sign bit into a funny - // place. Don't try to do this transformation in this case. Also, we - // require that the input operand is a shift-by-constant so that we have - // confidence that the shifts will get folded together. We could do this - // xform in more cases, but it is unlikely to be profitable. - const APInt *TrShiftAmt; - if (I.isLogicalShift() && - match(TI->getOperand(0), m_Shift(m_Value(), m_APInt(TrShiftAmt)))) { - auto *TrOp = cast<Instruction>(TI->getOperand(0)); - Type *SrcTy = TrOp->getType(); - - // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = ConstantExpr::getZExt(Op1, SrcTy); - // (shift2 (shift1 & 0x00FF), c2) - Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName()); - - // For logical shifts, the truncation has the effect of making the high - // part of the register be zeros. Emulate this by inserting an AND to - // clear the top bits as needed. This 'and' will usually be zapped by - // other xforms later if dead. - unsigned SrcSize = SrcTy->getScalarSizeInBits(); - Constant *MaskV = - ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, TypeBits)); - - // The mask we constructed says what the trunc would do if occurring - // between the shifts. We want to know the effect *after* the second - // shift. We know that it is a logical shift by a constant, so adjust the - // mask as appropriate. - MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt); - // shift1 & 0x00FF - Value *And = Builder.CreateAnd(NSh, MaskV, TI->getName()); - // Return the value truncated to the interesting size. - return new TruncInst(And, Ty); - } - } - - if (Op0->hasOneUse()) { - if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) { - // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) - Value *V1; - const APInt *CC; - switch (Op0BO->getOpcode()) { - default: break; - case Instruction::Add: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: { - // These operators commute. 
- // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C) - if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() && - match(Op0BO->getOperand(1), m_Shr(m_Value(V1), - m_Specific(Op1)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); - // (X + (Y << C)) - Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1, - Op0BO->getOperand(1)->getName()); - unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(Ty, Bits); - return BinaryOperator::CreateAnd(X, Mask); - } - - // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) - Value *Op0BOOp1 = Op0BO->getOperand(1); - if (isLeftShift && Op0BOOp1->hasOneUse() && - match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), - m_APInt(CC)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); - // X & (CC << C) - Value *XM = Builder.CreateAnd( - V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), - V1->getName() + ".mask"); - return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); - } - LLVM_FALLTHROUGH; - } - - case Instruction::Sub: { - // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) - if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && - match(Op0BO->getOperand(0), m_Shr(m_Value(V1), - m_Specific(Op1)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); - // (X + (Y << C)) - Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS, - Op0BO->getOperand(0)->getName()); - unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(Ty, Bits); - return BinaryOperator::CreateAnd(X, Mask); - } - - // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) - if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && - match(Op0BO->getOperand(0), - m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), - m_APInt(CC)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); - // X & (CC << C) - Value *XM = Builder.CreateAnd( - V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), - V1->getName() + ".mask"); - return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); - } - - break; - } - } + if (!Op0->hasOneUse()) + return nullptr; - // If the operand is a bitwise operator with a constant RHS, and the - // shift is the only use, we can pull it out of the shift. - const APInt *Op0C; - if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { - if (canShiftBinOpWithConstantRHS(I, Op0BO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(Op0BO->getOperand(1)), Op1); + if (auto *Op0BO = dyn_cast<BinaryOperator>(Op0)) { + // If the operand is a bitwise operator with a constant RHS, and the + // shift is the only use, we can pull it out of the shift. + const APInt *Op0C; + if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { + if (canShiftBinOpWithConstantRHS(I, Op0BO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(Op0BO->getOperand(1)), Op1); - Value *NewShift = + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); - NewShift->takeName(Op0BO); - - return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, - NewRHS); - } - } - - // If the operand is a subtract with a constant LHS, and the shift - // is the only use, we can pull it out of the shift. 
- // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2)) - if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub && - match(Op0BO->getOperand(0), m_APInt(Op0C))) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(Op0BO->getOperand(0)), Op1); - - Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1); NewShift->takeName(Op0BO); - return BinaryOperator::CreateSub(NewRHS, NewShift); + return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS); } } + } - // If we have a select that conditionally executes some binary operator, - // see if we can pull it the select and operator through the shift. - // - // For example, turning: - // shl (select C, (add X, C1), X), C2 - // Into: - // Y = shl X, C2 - // select C, (add Y, C1 << C2), Y - Value *Cond; - BinaryOperator *TBO; - Value *FalseVal; - if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), - m_Value(FalseVal)))) { - const APInt *C; - if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && - match(TBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, TBO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(TBO->getOperand(1)), Op1); - - Value *NewShift = - Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); - Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, - NewRHS); - return SelectInst::Create(Cond, NewOp, NewShift); - } + // If we have a select that conditionally executes some binary operator, + // see if we can pull it the select and operator through the shift. + // + // For example, turning: + // shl (select C, (add X, C1), X), C2 + // Into: + // Y = shl X, C2 + // select C, (add Y, C1 << C2), Y + Value *Cond; + BinaryOperator *TBO; + Value *FalseVal; + if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), + m_Value(FalseVal)))) { + const APInt *C; + if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && + match(TBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, TBO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(TBO->getOperand(1)), Op1); + + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); + Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewOp, NewShift); } + } - BinaryOperator *FBO; - Value *TrueVal; - if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), - m_OneUse(m_BinOp(FBO))))) { - const APInt *C; - if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && - match(FBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, FBO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(FBO->getOperand(1)), Op1); - - Value *NewShift = - Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); - Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, - NewRHS); - return SelectInst::Create(Cond, NewShift, NewOp); - } + BinaryOperator *FBO; + Value *TrueVal; + if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), + m_OneUse(m_BinOp(FBO))))) { + const APInt *C; + if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && + match(FBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, FBO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(FBO->getOperand(1)), Op1); + + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); + Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewShift, NewOp); } } @@ -908,41 +776,41 @@ 
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); - const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { - unsigned ShAmt = ShAmtAPInt->getZExtValue(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); - // shl (zext X), ShAmt --> zext (shl X, ShAmt) + // shl (zext X), C --> zext (shl X, C) // This is only valid if X would have zeros shifted out. Value *X; if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); - if (ShAmt < SrcWidth && - MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) - return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty); + if (ShAmtC < SrcWidth && + MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), 0, &I)) + return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty); } // (X >> C) << C --> X & (-1 << C) if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) { - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - const APInt *ShOp1; - if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1)))) && - ShOp1->ult(BitWidth)) { - unsigned ShrAmt = ShOp1->getZExtValue(); - if (ShrAmt < ShAmt) { - // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + const APInt *C1; + if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); return NewShl; } - if (ShrAmt > ShAmt) { - // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>?exact C1) << C --> X >>?exact (C1 - C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC); auto *NewShr = BinaryOperator::Create( cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); NewShr->setIsExact(true); @@ -950,49 +818,135 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { } } - if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(ShOp1)))) && - ShOp1->ult(BitWidth)) { - unsigned ShrAmt = ShOp1->getZExtValue(); - if (ShrAmt < ShAmt) { - // If C1 < C2: (X >>? C1) << C2 --> X << (C2 - C1) & (-1 << C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); Builder.Insert(NewShl); - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); } - if (ShrAmt > ShAmt) { - // If C1 > C2: (X >>? C1) << C2 --> X >>? 
(C1 - C2) & (-1 << C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>? C1) << C --> (X >>? (C1 - C)) & (-1 << C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC); auto *OldShr = cast<BinaryOperator>(Op0); auto *NewShr = BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff); NewShr->setIsExact(OldShr->isExact()); Builder.Insert(NewShr); - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask)); } } - if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { - unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Similar to above, but look through an intermediate trunc instruction. + BinaryOperator *Shr; + if (match(Op0, m_OneUse(m_Trunc(m_OneUse(m_BinOp(Shr))))) && + match(Shr, m_Shr(m_Value(X), m_APInt(C1)))) { + // The larger shift direction survives through the transform. + unsigned ShrAmtC = C1->getZExtValue(); + unsigned ShDiff = ShrAmtC > ShAmtC ? ShrAmtC - ShAmtC : ShAmtC - ShrAmtC; + Constant *ShiftDiffC = ConstantInt::get(X->getType(), ShDiff); + auto ShiftOpc = ShrAmtC > ShAmtC ? Shr->getOpcode() : Instruction::Shl; + + // If C1 > C: + // (trunc (X >> C1)) << C --> (trunc (X >> (C1 - C))) && (-1 << C) + // If C > C1: + // (trunc (X >> C1)) << C --> (trunc (X << (C - C1))) && (-1 << C) + Value *NewShift = Builder.CreateBinOp(ShiftOpc, X, ShiftDiffC, "sh.diff"); + Value *Trunc = Builder.CreateTrunc(NewShift, Ty, "tr.sh.diff"); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask)); + } + + if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { + unsigned AmtSum = ShAmtC + C1->getZExtValue(); // Oversized shifts are simplified to zero in InstSimplify. if (AmtSum < BitWidth) // (X << C1) << C2 --> X << (C1 + C2) return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum)); } + // If we have an opposite shift by the same amount, we may be able to + // reorder binops and shifts to eliminate math/logic. + auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) { + switch (BinOpcode) { + default: + return false; + case Instruction::Add: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Sub: + // NOTE: Sub is not commutable and the tranforms below may not be valid + // when the shift-right is operand 1 (RHS) of the sub. + return true; + } + }; + BinaryOperator *Op0BO; + if (match(Op0, m_OneUse(m_BinOp(Op0BO))) && + isSuitableBinOpcode(Op0BO->getOpcode())) { + // Commute so shift-right is on LHS of the binop. 
+ // (Y bop (X >> C)) << C -> ((X >> C) bop Y) << C + // (Y bop ((X >> C) & CC)) << C -> (((X >> C) & CC) bop Y) << C + Value *Shr = Op0BO->getOperand(0); + Value *Y = Op0BO->getOperand(1); + Value *X; + const APInt *CC; + if (Op0BO->isCommutative() && Y->hasOneUse() && + (match(Y, m_Shr(m_Value(), m_Specific(Op1))) || + match(Y, m_And(m_OneUse(m_Shr(m_Value(), m_Specific(Op1))), + m_APInt(CC))))) + std::swap(Shr, Y); + + // ((X >> C) bop Y) << C -> (X bop (Y << C)) & (~0 << C) + if (match(Shr, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + // (X bop (Y << C)) + Value *B = + Builder.CreateBinOp(Op0BO->getOpcode(), X, YS, Shr->getName()); + unsigned Op1Val = C->getLimitedValue(BitWidth); + APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val); + Constant *Mask = ConstantInt::get(Ty, Bits); + return BinaryOperator::CreateAnd(B, Mask); + } + + // (((X >> C) & CC) bop Y) << C -> (X & (CC << C)) bop (Y << C) + if (match(Shr, + m_OneUse(m_And(m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))), + m_APInt(CC))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + // X & (CC << C) + Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)), + X->getName() + ".mask"); + return BinaryOperator::Create(Op0BO->getOpcode(), M, YS); + } + } + + // (C1 - X) << C --> (C1 << C) - (X << C) + if (match(Op0, m_OneUse(m_Sub(m_APInt(C1), m_Value(X))))) { + Constant *NewLHS = ConstantInt::get(Ty, C1->shl(*C)); + Value *NewShift = Builder.CreateShl(X, Op1); + return BinaryOperator::CreateSub(NewLHS, NewShift); + } + // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && - MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) { + MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0, + &I)) { I.setHasNoUnsignedWrap(); return &I; } // If the shifted-out value is all signbits, then this is a NSW shift. 
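A concrete instance of the binop/shift reordering above, using add with C = 4 (the sub case is only valid with the shift-right on the LHS, as the NOTE says):
  %s = lshr i32 %x, 4
  %a = add i32 %s, %y
  %r = shl i32 %a, 4
  ; becomes
  %ys = shl i32 %y, 4
  %b  = add i32 %x, %ys
  %r  = and i32 %b, -16          ; -1 << 4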
- if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) { + if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) { I.setHasNoSignedWrap(); return &I; } @@ -1048,12 +1002,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); - const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { - unsigned ShAmt = ShAmtAPInt->getZExtValue(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); unsigned BitWidth = Ty->getScalarSizeInBits(); auto *II = dyn_cast<IntrinsicInst>(Op0); - if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt && + if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC && (II->getIntrinsicID() == Intrinsic::ctlz || II->getIntrinsicID() == Intrinsic::cttz || II->getIntrinsicID() == Intrinsic::ctpop)) { @@ -1067,78 +1021,81 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } Value *X; - const APInt *ShOp1; - if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { - if (ShOp1->ult(ShAmt)) { - unsigned ShlAmt = ShOp1->getZExtValue(); - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + const APInt *C1; + if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { + if (C1->ult(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShlAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) + // (X <<nuw C1) >>u C --> X >>u (C - C1) auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); NewLShr->setIsExact(I.isExact()); return NewLShr; } - // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2) + // (X << C1) >>u C --> (X >>u (C - C1)) & (-1 >> C) Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact()); - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); } - if (ShOp1->ugt(ShAmt)) { - unsigned ShlAmt = ShOp1->getZExtValue(); - Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + if (C1->ugt(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) + // (X <<nuw C1) >>u C --> X <<nuw (C1 - C) auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(true); return NewShl; } - // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2) + // (X << C1) >>u C --> X << (C1 - C) & (-1 >> C) Value *NewShl = Builder.CreateShl(X, ShiftDiff); - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); } - assert(*ShOp1 == ShAmt); + assert(*C1 == ShAmtC); // (X << C) >>u C --> X & (-1 >>u C) - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { - assert(ShAmt < X->getType()->getScalarSizeInBits() && + assert(ShAmtC < X->getType()->getScalarSizeInBits() && "Big shift not simplified to zero?"); // lshr (zext iM X to iN), C --> zext 
(lshr X, C) to iN - Value *NewLShr = Builder.CreateLShr(X, ShAmt); + Value *NewLShr = Builder.CreateLShr(X, ShAmtC); return new ZExtInst(NewLShr, Ty); } - if (match(Op0, m_SExt(m_Value(X))) && - (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { - // Are we moving the sign bit to the low bit and widening with high zeros? + if (match(Op0, m_SExt(m_Value(X)))) { unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits(); - if (ShAmt == BitWidth - 1) { - // lshr (sext i1 X to iN), N-1 --> zext X to iN - if (SrcTyBitWidth == 1) - return new ZExtInst(X, Ty); + // lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0) + if (SrcTyBitWidth == 1) { + auto *NewC = ConstantInt::get( + Ty, APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); + return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); + } - // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN - if (Op0->hasOneUse()) { + if ((!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType())) && + Op0->hasOneUse()) { + // Are we moving the sign bit to the low bit and widening with high + // zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN + if (ShAmtC == BitWidth - 1) { Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1); return new ZExtInst(NewLShr, Ty); } - } - // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN - if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) { - // The new shift amount can't be more than the narrow source type. - unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1); - Value *AShr = Builder.CreateAShr(X, NewShAmt); - return new ZExtInst(AShr, Ty); + // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN + if (ShAmtC == BitWidth - SrcTyBitWidth) { + // The new shift amount can't be more than the narrow source type. + unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1); + Value *AShr = Builder.CreateAShr(X, NewShAmt); + return new ZExtInst(AShr, Ty); + } } } Value *Y; - if (ShAmt == BitWidth - 1) { + if (ShAmtC == BitWidth - 1) { // lshr i32 or(X,-X), 31 --> zext (X != 0) if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X))))) return new ZExtInst(Builder.CreateIsNotNull(X), Ty); @@ -1150,32 +1107,55 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { // Check if a number is negative and odd: // lshr i32 (srem X, 2), 31 --> and (X >> 31), X if (match(Op0, m_OneUse(m_SRem(m_Value(X), m_SpecificInt(2))))) { - Value *Signbit = Builder.CreateLShr(X, ShAmt); + Value *Signbit = Builder.CreateLShr(X, ShAmtC); return BinaryOperator::CreateAnd(Signbit, X); } } - if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { - unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // (X >>u C1) >>u C --> X >>u (C1 + C) + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) { // Oversized shifts are simplified to zero in InstSimplify. + unsigned AmtSum = ShAmtC + C1->getZExtValue(); if (AmtSum < BitWidth) - // (X >>u C1) >>u C2 --> X >>u (C1 + C2) return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); } + Instruction *TruncSrc; + if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) && + match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned AmtSum = ShAmtC + C1->getZExtValue(); + + // If the combined shift fits in the source width: + // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C)), MaskC + // + // If the first shift covers the number of bits truncated, then the + // mask instruction is eliminated (and so the use check is relaxed). 
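A worked instance of the trunc-of-lshr fold described just above (names like %sum.shift are illustrative; the mask survives here because the first shift, 8, does not cover the 32 truncated bits):
  %s = lshr i64 %x, 8
  %t = trunc i64 %s to i32
  %r = lshr i32 %t, 2
  ; becomes
  %sum.shift = lshr i64 %x, 10
  %tr        = trunc i64 %sum.shift to i32
  %r         = and i32 %tr, 1073741823    ; -1 u>> 2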
+ if (AmtSum < SrcWidth && + (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) { + Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift"); + Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName()); + + // If the first shift does not cover the number of bits truncated, then + // we require a mask to get rid of high bits in the result. + APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC)); + } + } + // Look for a "splat" mul pattern - it replicates bits across each half of // a value, so a right shift is just a mask of the low bits: // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1 // TODO: Generalize to allow more than just half-width shifts? const APInt *MulC; if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) && - ShAmt * 2 == BitWidth && (*MulC - 1).isPowerOf2() && - MulC->logBase2() == ShAmt) + ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmtC) return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { I.setIsExact(); return &I; } @@ -1346,6 +1326,22 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { } } + // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)` + // as the pattern to splat the lowest bit. + // FIXME: iff X is already masked, we don't need the one-use check. + Value *X; + if (match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)) && + match(Op0, m_OneUse(m_Shl(m_Value(X), + m_SpecificIntAllowUndef(BitWidth - 1))))) { + Constant *Mask = ConstantInt::get(Ty, 1); + // Retain the knowledge about the ignored lanes. + Mask = Constant::mergeUndefsWith( + Constant::mergeUndefsWith(Mask, cast<Constant>(Op1)), + cast<Constant>(cast<Instruction>(Op0)->getOperand(1))); + X = Builder.CreateAnd(X, Mask); + return BinaryOperator::CreateNeg(X); + } + if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I)) return R; @@ -1354,7 +1350,6 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { return BinaryOperator::CreateLShr(Op0, Op1); // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1 - Value *X; if (match(Op0, m_OneUse(m_Not(m_Value(X))))) { // Note that we must drop 'exact'-ness of the shift! // Note that we can't keep undef's in -1 vector constant! diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 15b51ae8a5ee..e357a9da8b12 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -55,7 +55,7 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); KnownBits Known(BitWidth); - APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); + APInt DemandedMask(APInt::getAllOnes(BitWidth)); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 0, &Inst); @@ -124,7 +124,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } Known.resetAll(); - if (DemandedMask.isNullValue()) // Not demanding any bits from V. 
+ if (DemandedMask.isZero()) // Not demanding any bits from V. return UndefValue::get(VTy); if (Depth == MaxAnalysisRecursionDepth) @@ -274,8 +274,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // constant because that's a canonical 'not' op, and that is better for // combining, SCEV, and codegen. const APInt *C; - if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnesValue()) { - if ((*C | ~DemandedMask).isAllOnesValue()) { + if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) { + if ((*C | ~DemandedMask).isAllOnes()) { // Force bits to 1 to create a 'not' op. I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); return I; @@ -385,8 +385,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known = KnownBits::commonBits(LHSKnown, RHSKnown); break; } - case Instruction::ZExt: case Instruction::Trunc: { + // If we do not demand the high bits of a right-shifted and truncated value, + // then we may be able to truncate it before the shift. + Value *X; + const APInt *C; + if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // The shift amount must be valid (not poison) in the narrow type, and + // it must not be greater than the high bits demanded of the result. + if (C->ult(I->getType()->getScalarSizeInBits()) && + C->ule(DemandedMask.countLeadingZeros())) { + // trunc (lshr X, C) --> lshr (trunc X), C + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Trunc = Builder.CreateTrunc(X, I->getType()); + return Builder.CreateLShr(Trunc, C->getZExtValue()); + } + } + } + LLVM_FALLTHROUGH; + case Instruction::ZExt: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); @@ -516,8 +534,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I->getOperand(0); // We can't do this with the LHS for subtraction, unless we are only // demanding the LSB. - if ((I->getOpcode() == Instruction::Add || - DemandedFromOps.isOneValue()) && + if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) && DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); @@ -615,7 +632,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless // the shift amount is >= the size of the datatype, which is undefined. - if (DemandedMask.isOneValue()) { + if (DemandedMask.isOne()) { // Perform the logical shift right. Instruction *NewVal = BinaryOperator::CreateLShr( I->getOperand(0), I->getOperand(1), I->getName()); @@ -743,7 +760,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } case Instruction::URem: { KnownBits Known2(BitWidth); - APInt AllOnes = APInt::getAllOnesValue(BitWidth); + APInt AllOnes = APInt::getAllOnes(BitWidth); if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) || SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1)) return I; @@ -829,6 +846,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBitsComputed = true; break; } + case Intrinsic::umax: { + // UMax(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-zero bit of C. 
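To picture the new demanded-bits rule for Trunc in the hunk above (a sketch, not part of the patch; it assumes only the low 16 bits of the result are demanded):

  ; before
  %s = lshr i64 %x, 8
  %t = trunc i64 %s to i32
  ; after: legal because 8 < 32 and 8 <= the 16 leading zeros of the demanded mask
  %n = trunc i64 %x to i32
  %t = lshr i32 %n, 8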
+ const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getActiveBits()) + return II->getArgOperand(0); + break; + } + case Intrinsic::umin: { + // UMin(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-one bit of C. + // This comes from using DeMorgans on the above umax example. + const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getBitWidth() - C->countLeadingOnes()) + return II->getArgOperand(0); + break; + } default: { // Handle target specific intrinsics Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( @@ -1021,8 +1061,8 @@ Value *InstCombinerImpl::simplifyShrShlDemandedBits( Known.Zero.setLowBits(ShlAmt - 1); Known.Zero &= DemandedMask; - APInt BitMask1(APInt::getAllOnesValue(BitWidth)); - APInt BitMask2(APInt::getAllOnesValue(BitWidth)); + APInt BitMask1(APInt::getAllOnes(BitWidth)); + APInt BitMask2(APInt::getAllOnes(BitWidth)); bool isLshr = (Shr->getOpcode() == Instruction::LShr); BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) : @@ -1088,7 +1128,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, return nullptr; unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); - APInt EltMask(APInt::getAllOnesValue(VWidth)); + APInt EltMask(APInt::getAllOnes(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); if (match(V, m_Undef())) { @@ -1097,7 +1137,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, return nullptr; } - if (DemandedElts.isNullValue()) { // If nothing is demanded, provide poison. + if (DemandedElts.isZero()) { // If nothing is demanded, provide poison. UndefElts = EltMask; return PoisonValue::get(V->getType()); } @@ -1107,7 +1147,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, if (auto *C = dyn_cast<Constant>(V)) { // Check if this is identity. If so, return 0 since we are not simplifying // anything. - if (DemandedElts.isAllOnesValue()) + if (DemandedElts.isAllOnes()) return nullptr; Type *EltTy = cast<VectorType>(V->getType())->getElementType(); @@ -1260,7 +1300,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Handle trivial case of a splat. Only check the first element of LHS // operand. if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && - DemandedElts.isAllOnesValue()) { + DemandedElts.isAllOnes()) { if (!match(I->getOperand(1), m_Undef())) { I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType())); MadeChange = true; @@ -1515,8 +1555,8 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Subtlety: If we load from a pointer, the pointer must be valid // regardless of whether the element is demanded. Doing otherwise risks // segfaults which didn't exist in the original program. - APInt DemandedPtrs(APInt::getAllOnesValue(VWidth)), - DemandedPassThrough(DemandedElts); + APInt DemandedPtrs(APInt::getAllOnes(VWidth)), + DemandedPassThrough(DemandedElts); if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2))) for (unsigned i = 0; i < VWidth; i++) { Constant *CElt = CV->getAggregateElement(i); @@ -1568,7 +1608,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // If we've proven all of the lanes undef, return an undef value. // TODO: Intersect w/demanded lanes - if (UndefElts.isAllOnesValue()) + if (UndefElts.isAllOnes()) return UndefValue::get(I->getType());; return MadeChange ? 
I : nullptr; diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 32b15376f898..32e537897140 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -35,37 +35,46 @@ #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <cstdint> #include <iterator> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" - STATISTIC(NumAggregateReconstructionsSimplified, "Number of aggregate reconstructions turned into reuse of the " "original aggregate"); /// Return true if the value is cheaper to scalarize than it is to leave as a -/// vector operation. IsConstantExtractIndex indicates whether we are extracting -/// one known element from a vector constant. +/// vector operation. If the extract index \p EI is a constant integer then +/// some operations may be cheap to scalarize. /// /// FIXME: It's possible to create more instructions than previously existed. -static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) { +static bool cheapToScalarize(Value *V, Value *EI) { + ConstantInt *CEI = dyn_cast<ConstantInt>(EI); + // If we can pick a scalar constant value out of a vector, that is free. if (auto *C = dyn_cast<Constant>(V)) - return IsConstantExtractIndex || C->getSplatValue(); + return CEI || C->getSplatValue(); + + if (CEI && match(V, m_Intrinsic<Intrinsic::experimental_stepvector>())) { + ElementCount EC = cast<VectorType>(V->getType())->getElementCount(); + // Index needs to be lower than the minimum size of the vector, because + // for scalable vector, the vector size is known at run time. + return CEI->getValue().ult(EC.getKnownMinValue()); + } // An insertelement to the same constant index as our extract will simplify // to the scalar inserted element. An insertelement to a different constant // index is irrelevant to our extract. if (match(V, m_InsertElt(m_Value(), m_Value(), m_ConstantInt()))) - return IsConstantExtractIndex; + return CEI; if (match(V, m_OneUse(m_Load(m_Value())))) return true; @@ -75,14 +84,12 @@ static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) { Value *V0, *V1; if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1))))) - if (cheapToScalarize(V0, IsConstantExtractIndex) || - cheapToScalarize(V1, IsConstantExtractIndex)) + if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI)) return true; CmpInst::Predicate UnusedPred; if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1))))) - if (cheapToScalarize(V0, IsConstantExtractIndex) || - cheapToScalarize(V1, IsConstantExtractIndex)) + if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI)) return true; return false; @@ -119,7 +126,8 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, // and that it is a binary operation which is cheap to scalarize. // otherwise return nullptr. 
if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) || - !(isa<BinaryOperator>(PHIUser)) || !cheapToScalarize(PHIUser, true)) + !(isa<BinaryOperator>(PHIUser)) || + !cheapToScalarize(PHIUser, EI.getIndexOperand())) return nullptr; // Create a scalar PHI node that will replace the vector PHI node @@ -170,24 +178,46 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, return &EI; } -static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, - InstCombiner::BuilderTy &Builder, - bool IsBigEndian) { +Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { Value *X; uint64_t ExtIndexC; if (!match(Ext.getVectorOperand(), m_BitCast(m_Value(X))) || - !X->getType()->isVectorTy() || !match(Ext.getIndexOperand(), m_ConstantInt(ExtIndexC))) return nullptr; + ElementCount NumElts = + cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); + Type *DestTy = Ext.getType(); + bool IsBigEndian = DL.isBigEndian(); + + // If we are casting an integer to vector and extracting a portion, that is + // a shift-right and truncate. + // TODO: Allow FP dest type by casting the trunc to FP? + if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() && + isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) { + assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) && + "Expected fixed vector type for bitcast from scalar integer"); + + // Big endian requires adjusting the extract index since MSB is at index 0. + // LittleEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 X to i8 + // BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8 + if (IsBigEndian) + ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC; + unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits(); + if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) { + Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset"); + return new TruncInst(Lshr, DestTy); + } + } + + if (!X->getType()->isVectorTy()) + return nullptr; + // If this extractelement is using a bitcast from a vector of the same number // of elements, see if we can find the source element from the source vector: // extelt (bitcast VecX), IndexC --> bitcast X[IndexC] auto *SrcTy = cast<VectorType>(X->getType()); - Type *DestTy = Ext.getType(); ElementCount NumSrcElts = SrcTy->getElementCount(); - ElementCount NumElts = - cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); if (NumSrcElts == NumElts) if (Value *Elt = findScalarElement(X, ExtIndexC)) return new BitCastInst(Elt, DestTy); @@ -274,7 +304,7 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); // Conservatively assume that all elements are needed. - APInt UsedElts(APInt::getAllOnesValue(VWidth)); + APInt UsedElts(APInt::getAllOnes(VWidth)); switch (UserInstr->getOpcode()) { case Instruction::ExtractElement: { @@ -322,11 +352,11 @@ static APInt findDemandedEltsByAllUsers(Value *V) { if (Instruction *I = dyn_cast<Instruction>(U.getUser())) { UnionUsedElts |= findDemandedEltsBySingleUser(V, I); } else { - UnionUsedElts = APInt::getAllOnesValue(VWidth); + UnionUsedElts = APInt::getAllOnes(VWidth); break; } - if (UnionUsedElts.isAllOnesValue()) + if (UnionUsedElts.isAllOnes()) break; } @@ -388,7 +418,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { // If the input vector has multiple uses, simplify it based on a union // of all elements used. 
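For reference, the scalar-integer case added to foldBitcastExtElt above turns an extract from a bitcast of a scalar into a shift plus truncate (illustrative little-endian IR, not part of the patch; a non-zero shift requires the bitcast to have no other uses):

  ; before
  %v = bitcast i32 %x to <4 x i8>
  %e = extractelement <4 x i8> %v, i64 1
  ; after: element 1 occupies bits 8..15 on a little-endian target
  %s = lshr i32 %x, 8
  %e = trunc i32 %s to i8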
APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); - if (!DemandedElts.isAllOnesValue()) { + if (!DemandedElts.isAllOnes()) { APInt UndefElts(NumElts, 0); if (Value *V = SimplifyDemandedVectorElts( SrcVec, DemandedElts, UndefElts, 0 /* Depth */, @@ -402,7 +432,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { } } - if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian())) + if (Instruction *I = foldBitcastExtElt(EI)) return I; // If there's a vector PHI feeding a scalar use through this extractelement @@ -415,7 +445,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { // TODO come up with a n-ary matcher that subsumes both unary and // binary matchers. UnaryOperator *UO; - if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, IndexC)) { + if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, Index)) { // extelt (unop X), Index --> unop (extelt X, Index) Value *X = UO->getOperand(0); Value *E = Builder.CreateExtractElement(X, Index); @@ -423,7 +453,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { } BinaryOperator *BO; - if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, IndexC)) { + if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index)) { // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index) Value *X = BO->getOperand(0), *Y = BO->getOperand(1); Value *E0 = Builder.CreateExtractElement(X, Index); @@ -434,7 +464,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { Value *X, *Y; CmpInst::Predicate Pred; if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) && - cheapToScalarize(SrcVec, IndexC)) { + cheapToScalarize(SrcVec, Index)) { // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index) Value *E0 = Builder.CreateExtractElement(X, Index); Value *E1 = Builder.CreateExtractElement(Y, Index); @@ -651,8 +681,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) return; - auto *WideVec = - new ShuffleVectorInst(ExtVecOp, PoisonValue::get(ExtVecType), ExtendMask); + auto *WideVec = new ShuffleVectorInst(ExtVecOp, ExtendMask); // Insert the new shuffle after the vector operand of the extract is defined // (as long as it's not a PHI) or at the start of the basic block of the @@ -913,7 +942,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( "We don't store nullptr in SourceAggregate!"); assert((Describe(SourceAggregate) == AggregateDescription::Found) == (I.index() != 0) && - "SourceAggregate should be valid after the the first element,"); + "SourceAggregate should be valid after the first element,"); // For this element, is there a plausible source aggregate? // FIXME: we could special-case undef element, IFF we know that in the @@ -1179,7 +1208,7 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { if (!ElementPresent[i]) Mask[i] = -1; - return new ShuffleVectorInst(FirstIE, PoisonVec, Mask); + return new ShuffleVectorInst(FirstIE, Mask); } /// Try to fold an insert element into an existing splat shuffle by changing @@ -1208,15 +1237,15 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { // Replace the shuffle mask element at the index of this insert with a zero. 
// For example: - // inselt (shuf (inselt undef, X, 0), undef, <0,undef,0,undef>), X, 1 - // --> shuf (inselt undef, X, 0), undef, <0,0,0,undef> + // inselt (shuf (inselt undef, X, 0), _, <0,undef,0,undef>), X, 1 + // --> shuf (inselt undef, X, 0), poison, <0,0,0,undef> unsigned NumMaskElts = cast<FixedVectorType>(Shuf->getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts); for (unsigned i = 0; i != NumMaskElts; ++i) NewMask[i] = i == IdxC ? 0 : Shuf->getMaskValue(i); - return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask); + return new ShuffleVectorInst(Op0, NewMask); } /// Try to fold an extract+insert element into an existing identity shuffle by @@ -1348,6 +1377,10 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { NewShufElts[I] = ShufConstVec->getAggregateElement(I); NewMaskElts[I] = Mask[I]; } + + // Bail if we failed to find an element. + if (!NewShufElts[I]) + return nullptr; } // Create new operands for a shuffle that includes the constant of the @@ -1399,6 +1432,41 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { return nullptr; } +/// If both the base vector and the inserted element are extended from the same +/// type, do the insert element in the narrow source type followed by extend. +/// TODO: This can be extended to include other cast opcodes, but particularly +/// if we create a wider insertelement, make sure codegen is not harmed. +static Instruction *narrowInsElt(InsertElementInst &InsElt, + InstCombiner::BuilderTy &Builder) { + // We are creating a vector extend. If the original vector extend has another + // use, that would mean we end up with 2 vector extends, so avoid that. + // TODO: We could ease the use-clause to "if at least one op has one use" + // (assuming that the source types match - see next TODO comment). + Value *Vec = InsElt.getOperand(0); + if (!Vec->hasOneUse()) + return nullptr; + + Value *Scalar = InsElt.getOperand(1); + Value *X, *Y; + CastInst::CastOps CastOpcode; + if (match(Vec, m_FPExt(m_Value(X))) && match(Scalar, m_FPExt(m_Value(Y)))) + CastOpcode = Instruction::FPExt; + else if (match(Vec, m_SExt(m_Value(X))) && match(Scalar, m_SExt(m_Value(Y)))) + CastOpcode = Instruction::SExt; + else if (match(Vec, m_ZExt(m_Value(X))) && match(Scalar, m_ZExt(m_Value(Y)))) + CastOpcode = Instruction::ZExt; + else + return nullptr; + + // TODO: We can allow mismatched types by creating an intermediate cast. 
+ if (X->getType()->getScalarType() != Y->getType()) + return nullptr; + + // inselt (ext X), (ext Y), Index --> ext (inselt X, Y, Index) + Value *NewInsElt = Builder.CreateInsertElement(X, Y, InsElt.getOperand(2)); + return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType()); +} + Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { Value *VecOp = IE.getOperand(0); Value *ScalarOp = IE.getOperand(1); @@ -1495,7 +1563,7 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { if (auto VecTy = dyn_cast<FixedVectorType>(VecOp->getType())) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { if (V != &IE) return replaceInstUsesWith(IE, V); @@ -1518,6 +1586,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE)) return IdentityShuf; + if (Instruction *Ext = narrowInsElt(IE, Builder)) + return Ext; + return nullptr; } @@ -1924,8 +1995,8 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, // Splat from element 0. Any mask element that is undefined remains undefined. // For example: - // shuf (inselt undef, X, 2), undef, <2,2,undef> - // --> shuf (inselt undef, X, 0), undef, <0,0,undef> + // shuf (inselt undef, X, 2), _, <2,2,undef> + // --> shuf (inselt undef, X, 0), poison, <0,0,undef> unsigned NumMaskElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts, 0); @@ -1933,7 +2004,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, if (Mask[i] == UndefMaskElem) NewMask[i] = Mask[i]; - return new ShuffleVectorInst(NewIns, UndefVec, NewMask); + return new ShuffleVectorInst(NewIns, NewMask); } /// Try to fold shuffles that are the equivalent of a vector select. @@ -2197,12 +2268,8 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, SmallVector<int, 16> Mask; Shuf.getShuffleMask(Mask); - // The shuffle must not change vector sizes. - // TODO: This restriction could be removed if the insert has only one use - // (because the transform would require a new length-changing shuffle). int NumElts = Mask.size(); - if (NumElts != (int)(cast<FixedVectorType>(V0->getType())->getNumElements())) - return nullptr; + int InpNumElts = cast<FixedVectorType>(V0->getType())->getNumElements(); // This is a specialization of a fold in SimplifyDemandedVectorElts. We may // not be able to handle it there if the insertelement has >1 use. @@ -2219,11 +2286,16 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, if (match(V1, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { // Offset the index constant by the vector width because we are checking for // accesses to the 2nd vector input of the shuffle. - IdxC += NumElts; + IdxC += InpNumElts; // shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask if (!is_contained(Mask, (int)IdxC)) return IC.replaceOperand(Shuf, 1, X); } + // For the rest of the transform, the shuffle must not change vector sizes. + // TODO: This restriction could be removed if the insert has only one use + // (because the transform would require a new length-changing shuffle). 
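The narrowInsElt fold introduced above can be pictured with this IR sketch (illustrative, not part of the patch; it assumes the widened vector %wv has no other uses):

  ; before: base vector and inserted scalar are zero-extended from the same narrow type
  %wv = zext <4 x i16> %v to <4 x i32>
  %ws = zext i16 %s to i32
  %r  = insertelement <4 x i32> %wv, i32 %ws, i64 2
  ; after: insert in the narrow type, then extend once
  %i = insertelement <4 x i16> %v, i16 %s, i64 2
  %r = zext <4 x i16> %i to <4 x i32>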
+ if (NumElts != InpNumElts) + return nullptr; // shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC' auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) { @@ -2413,16 +2485,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (LHS == RHS) { assert(!match(RHS, m_Undef()) && "Shuffle with 2 undef ops not simplified?"); - // Remap any references to RHS to use LHS. - SmallVector<int, 16> Elts; - for (unsigned i = 0; i != VWidth; ++i) { - // Propagate undef elements or force mask to LHS. - if (Mask[i] < 0) - Elts.push_back(UndefMaskElem); - else - Elts.push_back(Mask[i] % LHSWidth); - } - return new ShuffleVectorInst(LHS, UndefValue::get(RHS->getType()), Elts); + return new ShuffleVectorInst(LHS, createUnaryMask(Mask, LHSWidth)); } // shuffle undef, x, mask --> shuffle x, undef, mask' @@ -2444,7 +2507,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return I; APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { if (V != &SVI) return replaceInstUsesWith(SVI, V); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 4e3b18e805ee..47b6dcb67a78 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -100,7 +100,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> @@ -109,11 +108,12 @@ #include <string> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace llvm::PatternMatch; -#define DEBUG_TYPE "instcombine" - STATISTIC(NumWorklistIterations, "Number of instruction combining iterations performed"); @@ -202,23 +202,37 @@ Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(&Builder, DL, GEP); } +/// Legal integers and common types are considered desirable. This is used to +/// avoid creating instructions with types that may not be supported well by the +/// the backend. +/// NOTE: This treats i8, i16 and i32 specially because they are common +/// types in frontend languages. +bool InstCombinerImpl::isDesirableIntType(unsigned BitWidth) const { + switch (BitWidth) { + case 8: + case 16: + case 32: + return true; + default: + return DL.isLegalInteger(BitWidth); + } +} + /// Return true if it is desirable to convert an integer computation from a /// given bit width to a new bit width. /// We don't want to convert from a legal to an illegal type or from a smaller -/// to a larger illegal type. A width of '1' is always treated as a legal type -/// because i1 is a fundamental type in IR, and there are many specialized -/// optimizations for i1 types. Widths of 8, 16 or 32 are equally treated as +/// to a larger illegal type. A width of '1' is always treated as a desirable +/// type because i1 is a fundamental type in IR, and there are many specialized +/// optimizations for i1 types. 
Common/desirable widths are equally treated as /// legal to convert to, in order to open up more combining opportunities. -/// NOTE: this treats i8, i16 and i32 specially, due to them being so common -/// from frontend languages. bool InstCombinerImpl::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth); - // Convert to widths of 8, 16 or 32 even if they are not legal types. Only - // shrink types, to prevent infinite loops. - if (ToWidth < FromWidth && (ToWidth == 8 || ToWidth == 16 || ToWidth == 32)) + // Convert to desirable widths even if they are not legal types. + // Only shrink types, to prevent infinite loops. + if (ToWidth < FromWidth && isDesirableIntType(ToWidth)) return true; // If this is a legal integer from type, and the result would be an illegal @@ -359,7 +373,8 @@ Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) { PtrToInt->getSrcTy()->getPointerAddressSpace() && DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) == DL.getTypeSizeInBits(PtrToInt->getDestTy())) { - return Builder.CreateBitCast(PtrToInt->getOperand(0), CastTy); + return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy, + "", PtrToInt); } } return nullptr; @@ -961,14 +976,14 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, assert(canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())) && "Expected constant-foldable intrinsic"); Intrinsic::ID IID = II->getIntrinsicID(); - if (II->getNumArgOperands() == 1) + if (II->arg_size() == 1) return Builder.CreateUnaryIntrinsic(IID, SO); // This works for real binary ops like min/max (where we always expect the // constant operand to be canonicalized as op1) and unary ops with a bonus // constant argument like ctlz/cttz. // TODO: Handle non-commutative binary intrinsics as below for binops. - assert(II->getNumArgOperands() == 2 && "Expected binary intrinsic"); + assert(II->arg_size() == 2 && "Expected binary intrinsic"); assert(isa<Constant>(II->getArgOperand(1)) && "Expected constant operand"); return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1)); } @@ -1058,7 +1073,7 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, // Compare for equality including undefs as equal. auto *Cmp = ConstantExpr::getCompare(ICmpInst::ICMP_EQ, ConstA, ConstB); const APInt *C; - return match(Cmp, m_APIntAllowUndef(C)) && C->isOneValue(); + return match(Cmp, m_APIntAllowUndef(C)) && C->isOne(); }; if ((areLooselyEqual(TV, Op0) && areLooselyEqual(FV, Op1)) || @@ -1120,9 +1135,11 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { BasicBlock *NonConstBB = nullptr; for (unsigned i = 0; i != NumPHIValues; ++i) { Value *InVal = PN->getIncomingValue(i); - // If I is a freeze instruction, count undef as a non-constant. - if (match(InVal, m_ImmConstant()) && - (!isa<FreezeInst>(I) || isGuaranteedNotToBeUndefOrPoison(InVal))) + // For non-freeze, require constant operand + // For freeze, require non-undef, non-poison operand + if (!isa<FreezeInst>(I) && match(InVal, m_ImmConstant())) + continue; + if (isa<FreezeInst>(I) && isGuaranteedNotToBeUndefOrPoison(InVal)) continue; if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. @@ -1268,61 +1285,19 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { /// specified offset. 
If so, fill them into NewIndices and return the resultant /// element type, otherwise return null. Type * -InstCombinerImpl::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, +InstCombinerImpl::FindElementAtOffset(PointerType *PtrTy, int64_t IntOffset, SmallVectorImpl<Value *> &NewIndices) { Type *Ty = PtrTy->getElementType(); if (!Ty->isSized()) return nullptr; - // Start with the index over the outer type. Note that the type size - // might be zero (even if the offset isn't zero) if the indexed type - // is something like [0 x {int, int}] - Type *IndexTy = DL.getIndexType(PtrTy); - int64_t FirstIdx = 0; - if (int64_t TySize = DL.getTypeAllocSize(Ty)) { - FirstIdx = Offset/TySize; - Offset -= FirstIdx*TySize; - - // Handle hosts where % returns negative instead of values [0..TySize). - if (Offset < 0) { - --FirstIdx; - Offset += TySize; - assert(Offset >= 0); - } - assert((uint64_t)Offset < (uint64_t)TySize && "Out of range offset"); - } - - NewIndices.push_back(ConstantInt::get(IndexTy, FirstIdx)); - - // Index into the types. If we fail, set OrigBase to null. - while (Offset) { - // Indexing into tail padding between struct/array elements. - if (uint64_t(Offset * 8) >= DL.getTypeSizeInBits(Ty)) - return nullptr; - - if (StructType *STy = dyn_cast<StructType>(Ty)) { - const StructLayout *SL = DL.getStructLayout(STy); - assert(Offset < (int64_t)SL->getSizeInBytes() && - "Offset must stay within the indexed type"); - - unsigned Elt = SL->getElementContainingOffset(Offset); - NewIndices.push_back(ConstantInt::get(Type::getInt32Ty(Ty->getContext()), - Elt)); - - Offset -= SL->getElementOffset(Elt); - Ty = STy->getElementType(Elt); - } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) { - uint64_t EltSize = DL.getTypeAllocSize(AT->getElementType()); - assert(EltSize && "Cannot index into a zero-sized array"); - NewIndices.push_back(ConstantInt::get(IndexTy,Offset/EltSize)); - Offset %= EltSize; - Ty = AT->getElementType(); - } else { - // Otherwise, we can't index into the middle of this atomic type, bail. - return nullptr; - } - } + APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), IntOffset); + SmallVector<APInt> Indices = DL.getGEPIndicesForOffset(Ty, Offset); + if (!Offset.isZero()) + return nullptr; + for (const APInt &Index : Indices) + NewIndices.push_back(Builder.getInt(Index)); return Ty; } @@ -1623,7 +1598,7 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { Value *XY = Builder.CreateBinOp(Opcode, X, Y); if (auto *BO = dyn_cast<BinaryOperator>(XY)) BO->copyIRFlags(&Inst); - return new ShuffleVectorInst(XY, UndefValue::get(XY->getType()), M); + return new ShuffleVectorInst(XY, M); }; // If both arguments of the binary operation are shuffles that use the same @@ -1754,25 +1729,20 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { Value *X; ArrayRef<int> MaskC; int SplatIndex; - BinaryOperator *BO; + Value *Y, *OtherOp; if (!match(LHS, m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(MaskC)))) || !match(MaskC, m_SplatOrUndefMask(SplatIndex)) || - X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(BO))) || - BO->getOpcode() != Opcode) + X->getType() != Inst.getType() || + !match(RHS, m_OneUse(m_BinOp(Opcode, m_Value(Y), m_Value(OtherOp))))) return nullptr; // FIXME: This may not be safe if the analysis allows undef elements. By // moving 'Y' before the splat shuffle, we are implicitly assuming // that it is not undef/poison at the splat index. 
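As a worked example of the offset decomposition that FindElementAtOffset above now delegates to getGEPIndicesForOffset (a sketch, not part of the patch): a byte offset of 12 into a pointer to [4 x { i32, i32 }] can decompose into the indices 0, 1, 1, since 0*32 + 1*8 + 4 == 12, i.e. the second i32 of element 1:

  %e = getelementptr [4 x { i32, i32 }], [4 x { i32, i32 }]* %p, i64 0, i64 1, i32 1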
- Value *Y, *OtherOp; - if (isSplatValue(BO->getOperand(0), SplatIndex)) { - Y = BO->getOperand(0); - OtherOp = BO->getOperand(1); - } else if (isSplatValue(BO->getOperand(1), SplatIndex)) { - Y = BO->getOperand(1); - OtherOp = BO->getOperand(0); - } else { + if (isSplatValue(OtherOp, SplatIndex)) { + std::swap(Y, OtherOp); + } else if (!isSplatValue(Y, SplatIndex)) { return nullptr; } @@ -1788,7 +1758,7 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { // dropped to be safe. if (isa<FPMathOperator>(R)) { R->copyFastMathFlags(&Inst); - R->andIRFlags(BO); + R->andIRFlags(RHS); } if (auto *NewInstBO = dyn_cast<BinaryOperator>(NewBO)) NewInstBO->copyIRFlags(R); @@ -1896,7 +1866,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); - if (Value *V = SimplifyGEPInst(GEPEltType, Ops, SQ.getWithInstruction(&GEP))) + if (Value *V = SimplifyGEPInst(GEPEltType, Ops, GEP.isInBounds(), + SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); // For vector geps, use the generic demanded vector support. @@ -1905,7 +1876,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (auto *GEPFVTy = dyn_cast<FixedVectorType>(GEPType)) { auto VWidth = GEPFVTy->getNumElements(); APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&GEP, AllOnesEltMask, UndefElts)) { if (V != &GEP) @@ -2117,10 +2088,12 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { // -- have to recreate %src & %gep // put NewSrc at same location as %src Builder.SetInsertPoint(cast<Instruction>(PtrOp)); - auto *NewSrc = cast<GetElementPtrInst>( - Builder.CreateGEP(GEPEltType, SO0, GO1, Src->getName())); - NewSrc->setIsInBounds(Src->isInBounds()); - auto *NewGEP = + Value *NewSrc = + Builder.CreateGEP(GEPEltType, SO0, GO1, Src->getName()); + // Propagate 'inbounds' if the new source was not constant-folded. + if (auto *NewSrcGEPI = dyn_cast<GetElementPtrInst>(NewSrc)) + NewSrcGEPI->setIsInBounds(Src->isInBounds()); + GetElementPtrInst *NewGEP = GetElementPtrInst::Create(GEPEltType, NewSrc, {SO1}); NewGEP->setIsInBounds(GEP.isInBounds()); return NewGEP; @@ -2128,18 +2101,6 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } } - - // Fold (gep(gep(Ptr,Idx0),Idx1) -> gep(Ptr,add(Idx0,Idx1)) - if (GO1->getType() == SO1->getType()) { - bool NewInBounds = GEP.isInBounds() && Src->isInBounds(); - auto *NewIdx = - Builder.CreateAdd(GO1, SO1, GEP.getName() + ".idx", - /*HasNUW*/ false, /*HasNSW*/ NewInBounds); - auto *NewGEP = GetElementPtrInst::Create( - GEPEltType, Src->getPointerOperand(), {NewIdx}); - NewGEP->setIsInBounds(NewInBounds); - return NewGEP; - } } // Note that if our source is a gep chain itself then we wait for that @@ -2647,6 +2608,13 @@ static bool isAllocSiteRemovable(Instruction *AI, Users.emplace_back(I); continue; } + + if (isReallocLikeFn(I, TLI, true)) { + Users.emplace_back(I); + Worklist.push_back(I); + continue; + } + return false; case Instruction::Store: { @@ -2834,15 +2802,33 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI, // At this point, we know that everything in FreeInstrBB can be moved // before TI. 
- for (BasicBlock::iterator It = FreeInstrBB->begin(), End = FreeInstrBB->end(); - It != End;) { - Instruction &Instr = *It++; + for (Instruction &Instr : llvm::make_early_inc_range(*FreeInstrBB)) { if (&Instr == FreeInstrBBTerminator) break; Instr.moveBefore(TI); } assert(FreeInstrBB->size() == 1 && "Only the branch instruction should remain"); + + // Now that we've moved the call to free before the NULL check, we have to + // remove any attributes on its parameter that imply it's non-null, because + // those attributes might have only been valid because of the NULL check, and + // we can get miscompiles if we keep them. This is conservative if non-null is + // also implied by something other than the NULL check, but it's guaranteed to + // be correct, and the conservativeness won't matter in practice, since the + // attributes are irrelevant for the call to free itself and the pointer + // shouldn't be used after the call. + AttributeList Attrs = FI.getAttributes(); + Attrs = Attrs.removeParamAttribute(FI.getContext(), 0, Attribute::NonNull); + Attribute Dereferenceable = Attrs.getParamAttr(0, Attribute::Dereferenceable); + if (Dereferenceable.isValid()) { + uint64_t Bytes = Dereferenceable.getDereferenceableBytes(); + Attrs = Attrs.removeParamAttribute(FI.getContext(), 0, + Attribute::Dereferenceable); + Attrs = Attrs.addDereferenceableOrNullParamAttr(FI.getContext(), 0, Bytes); + } + FI.setAttributes(Attrs); + return &FI; } @@ -2861,6 +2847,15 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI) { if (isa<ConstantPointerNull>(Op)) return eraseInstFromFunction(FI); + // If we had free(realloc(...)) with no intervening uses, then eliminate the + // realloc() entirely. + if (CallInst *CI = dyn_cast<CallInst>(Op)) { + if (CI->hasOneUse() && isReallocLikeFn(CI, &TLI, true)) { + return eraseInstFromFunction( + *replaceInstUsesWith(*CI, CI->getOperand(0))); + } + } + // If we optimize for code size, try to move the call to free before the null // test so that simplify cfg can remove the empty block and dead code // elimination the branch. I.e., helps to turn something like: @@ -2947,7 +2942,7 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) { auto GetLastSinkableStore = [](BasicBlock::iterator BBI) { auto IsNoopInstrForStoreMerging = [](BasicBlock::iterator BBI) { - return isa<DbgInfoIntrinsic>(BBI) || + return BBI->isDebugOrPseudoInst() || (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy()); }; @@ -3138,26 +3133,21 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { // checking for overflow. const APInt *C; if (match(WO->getRHS(), m_APInt(C))) { - // Compute the no-wrap range [X,Y) for LHS given RHS=C, then - // check for the inverted range using range offset trick (i.e. - // use a subtract to shift the range to bottom of either the - // signed or unsigned domain and then use a single compare to - // check range membership). + // Compute the no-wrap range for LHS given RHS=C, then construct an + // equivalent icmp, potentially using an offset. ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C, WO->getNoWrapKind()); - APInt Min = WO->isSigned() ? 
NWR.getSignedMin() : NWR.getUnsignedMin(); - NWR = NWR.subtract(Min); CmpInst::Predicate Pred; - APInt NewRHSC; - if (NWR.getEquivalentICmp(Pred, NewRHSC)) { - auto *OpTy = WO->getRHS()->getType(); - auto *NewLHS = Builder.CreateSub(WO->getLHS(), - ConstantInt::get(OpTy, Min)); - return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, - ConstantInt::get(OpTy, NewRHSC)); - } + APInt NewRHSC, Offset; + NWR.getEquivalentICmp(Pred, NewRHSC, Offset); + auto *OpTy = WO->getRHS()->getType(); + auto *NewLHS = WO->getLHS(); + if (Offset != 0) + NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset)); + return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, + ConstantInt::get(OpTy, NewRHSC)); } } } @@ -3183,9 +3173,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Instruction *NL = Builder.CreateLoad(EV.getType(), GEP); // Whatever aliasing information we had for the orignal load must also // hold for the smaller load, so propagate the annotations. - AAMDNodes Nodes; - L->getAAMetadata(Nodes); - NL->setAAMetadata(Nodes); + NL->setAAMetadata(L->getAAMetadata()); // Returning the load directly will cause the main loop to insert it in // the wrong spot, so use replaceInstUsesWith(). return replaceInstUsesWith(EV, NL); @@ -3568,8 +3556,14 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { // While we could change the other users of OrigOp to use freeze(OrigOp), that // potentially reduces their optimization potential, so let's only do this iff // the OrigOp is only used by the freeze. - if (!OrigOpInst || !OrigOpInst->hasOneUse() || isa<PHINode>(OrigOp) || - canCreateUndefOrPoison(dyn_cast<Operator>(OrigOp))) + if (!OrigOpInst || !OrigOpInst->hasOneUse() || isa<PHINode>(OrigOp)) + return nullptr; + + // We can't push the freeze through an instruction which can itself create + // poison. If the only source of new poison is flags, we can simply + // strip them (since we know the only use is the freeze and nothing can + // benefit from them.) + if (canCreateUndefOrPoison(cast<Operator>(OrigOp), /*ConsiderFlags*/ false)) return nullptr; // If operand is guaranteed not to be poison, there is no need to add freeze @@ -3585,6 +3579,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { return nullptr; } + OrigOpInst->dropPoisonGeneratingFlags(); + // If all operands are guaranteed to be non-poison, we can drop freeze. if (!MaybePoisonOperand) return OrigOp; @@ -3668,7 +3664,7 @@ Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { /// instruction past all of the instructions between it and the end of its /// block. static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { - assert(I->getSingleUndroppableUse() && "Invariants didn't hold!"); + assert(I->getUniqueUndroppableUser() && "Invariants didn't hold!"); BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -3822,51 +3818,71 @@ bool InstCombinerImpl::run() { // See if we can trivially sink this instruction to its user if we can // prove that the successor is not executed more frequently than our block. - if (EnableCodeSinking) - if (Use *SingleUse = I->getSingleUndroppableUse()) { - BasicBlock *BB = I->getParent(); - Instruction *UserInst = cast<Instruction>(SingleUse->getUser()); - BasicBlock *UserParent; - - // Get the block the use occurs in. 
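The with.overflow hunk above rewrites an overflow-flag check into a plain compare via the no-wrap range. A hedged example of the end result (not part of the patch; shown for the case where only the overflow bit is consumed, and the compiler may emit an equivalent predicate/offset form):

  %wo = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %x, i8 42)
  %ov = extractvalue { i8, i1 } %wo, 1
  ; after: the add overflows unsigned iff %x >= 214 (i.e. %x u> 213)
  %ov = icmp uge i8 %x, 214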
- if (PHINode *PN = dyn_cast<PHINode>(UserInst)) - UserParent = PN->getIncomingBlock(*SingleUse); - else - UserParent = UserInst->getParent(); - - // Try sinking to another block. If that block is unreachable, then do - // not bother. SimplifyCFG should handle it. - if (UserParent != BB && DT.isReachableFromEntry(UserParent)) { - // See if the user is one of our successors that has only one - // predecessor, so that we don't have to split the critical edge. - bool ShouldSink = UserParent->getUniquePredecessor() == BB; - // Another option where we can sink is a block that ends with a - // terminator that does not pass control to other block (such as - // return or unreachable). In this case: - // - I dominates the User (by SSA form); - // - the User will be executed at most once. - // So sinking I down to User is always profitable or neutral. - if (!ShouldSink) { - auto *Term = UserParent->getTerminator(); - ShouldSink = isa<ReturnInst>(Term) || isa<UnreachableInst>(Term); - } - if (ShouldSink) { - assert(DT.dominates(BB, UserParent) && - "Dominance relation broken?"); - // Okay, the CFG is simple enough, try to sink this instruction. - if (TryToSinkInstruction(I, UserParent)) { - LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); - MadeIRChange = true; - // We'll add uses of the sunk instruction below, but since sinking - // can expose opportunities for it's *operands* add them to the - // worklist - for (Use &U : I->operands()) - if (Instruction *OpI = dyn_cast<Instruction>(U.get())) - Worklist.push(OpI); - } + // Return the UserBlock if successful. + auto getOptionalSinkBlockForInst = + [this](Instruction *I) -> Optional<BasicBlock *> { + if (!EnableCodeSinking) + return None; + auto *UserInst = cast_or_null<Instruction>(I->getUniqueUndroppableUser()); + if (!UserInst) + return None; + + BasicBlock *BB = I->getParent(); + BasicBlock *UserParent = nullptr; + + // Special handling for Phi nodes - get the block the use occurs in. + if (PHINode *PN = dyn_cast<PHINode>(UserInst)) { + for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { + if (PN->getIncomingValue(i) == I) { + // Bail out if we have uses in different blocks. We don't do any + // sophisticated analysis (i.e finding NearestCommonDominator of these + // use blocks). + if (UserParent && UserParent != PN->getIncomingBlock(i)) + return None; + UserParent = PN->getIncomingBlock(i); } } + assert(UserParent && "expected to find user block!"); + } else + UserParent = UserInst->getParent(); + + // Try sinking to another block. If that block is unreachable, then do + // not bother. SimplifyCFG should handle it. + if (UserParent == BB || !DT.isReachableFromEntry(UserParent)) + return None; + + auto *Term = UserParent->getTerminator(); + // See if the user is one of our successors that has only one + // predecessor, so that we don't have to split the critical edge. + // Another option where we can sink is a block that ends with a + // terminator that does not pass control to other block (such as + // return or unreachable). In this case: + // - I dominates the User (by SSA form); + // - the User will be executed at most once. + // So sinking I down to User is always profitable or neutral. 
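The sinking criteria spelled out in the comments above can be pictured with a small IR example (illustrative, not part of the patch): %x has a single user in a successor block whose unique predecessor is the defining block, so it may be sunk next to that use.

  define i32 @f(i32 %a, i32 %b, i1 %c) {
  entry:
    %x = mul i32 %a, %b        ; only used in %then
    br i1 %c, label %then, label %else
  then:                        ; unique predecessor: %entry
    %r = add i32 %x, 1         ; sink %x right before this use
    ret i32 %r
  else:
    ret i32 0
  }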
+ if (UserParent->getUniquePredecessor() == BB || + (isa<ReturnInst>(Term) || isa<UnreachableInst>(Term))) { + assert(DT.dominates(BB, UserParent) && "Dominance relation broken?"); + return UserParent; } + return None; + }; + + auto OptBB = getOptionalSinkBlockForInst(I); + if (OptBB) { + auto *UserParent = *OptBB; + // Okay, the CFG is simple enough, try to sink this instruction. + if (TryToSinkInstruction(I, UserParent)) { + LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); + MadeIRChange = true; + // We'll add uses of the sunk instruction below, but since + // sinking can expose opportunities for it's *operands* add + // them to the worklist + for (Use &U : I->operands()) + if (Instruction *OpI = dyn_cast<Instruction>(U.get())) + Worklist.push(OpI); + } + } // Now that we have an instruction, try combining it to simplify it. Builder.SetInsertPoint(I); @@ -3994,13 +4010,13 @@ public: /// whose condition is a known constant, we only visit the reachable successors. static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, const TargetLibraryInfo *TLI, - InstCombineWorklist &ICWorklist) { + InstructionWorklist &ICWorklist) { bool MadeIRChange = false; SmallPtrSet<BasicBlock *, 32> Visited; SmallVector<BasicBlock*, 256> Worklist; Worklist.push_back(&F.front()); - SmallVector<Instruction*, 128> InstrsForInstCombineWorklist; + SmallVector<Instruction *, 128> InstrsForInstructionWorklist; DenseMap<Constant *, Constant *> FoldedConstants; AliasScopeTracker SeenAliasScopes; @@ -4011,25 +4027,23 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, if (!Visited.insert(BB).second) continue; - for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { - Instruction *Inst = &*BBI++; - + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { // ConstantProp instruction if trivially constant. - if (!Inst->use_empty() && - (Inst->getNumOperands() == 0 || isa<Constant>(Inst->getOperand(0)))) - if (Constant *C = ConstantFoldInstruction(Inst, DL, TLI)) { - LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *Inst + if (!Inst.use_empty() && + (Inst.getNumOperands() == 0 || isa<Constant>(Inst.getOperand(0)))) + if (Constant *C = ConstantFoldInstruction(&Inst, DL, TLI)) { + LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << Inst << '\n'); - Inst->replaceAllUsesWith(C); + Inst.replaceAllUsesWith(C); ++NumConstProp; - if (isInstructionTriviallyDead(Inst, TLI)) - Inst->eraseFromParent(); + if (isInstructionTriviallyDead(&Inst, TLI)) + Inst.eraseFromParent(); MadeIRChange = true; continue; } // See if we can constant fold its operands. - for (Use &U : Inst->operands()) { + for (Use &U : Inst.operands()) { if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U)) continue; @@ -4039,7 +4053,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, FoldRes = ConstantFoldConstant(C, DL, TLI); if (FoldRes != C) { - LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst + LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << Inst << "\n Old = " << *C << "\n New = " << *FoldRes << '\n'); U = FoldRes; @@ -4050,9 +4064,9 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // Skip processing debug and pseudo intrinsics in InstCombine. Processing // these call instructions consumes non-trivial amount of time and // provides no value for the optimization. 
- if (!Inst->isDebugOrPseudoInst()) { - InstrsForInstCombineWorklist.push_back(Inst); - SeenAliasScopes.analyse(Inst); + if (!Inst.isDebugOrPseudoInst()) { + InstrsForInstructionWorklist.push_back(&Inst); + SeenAliasScopes.analyse(&Inst); } } @@ -4097,8 +4111,8 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // of the function down. This jives well with the way that it adds all uses // of instructions to the worklist after doing a transformation, thus avoiding // some N^2 behavior in pathological cases. - ICWorklist.reserve(InstrsForInstCombineWorklist.size()); - for (Instruction *Inst : reverse(InstrsForInstCombineWorklist)) { + ICWorklist.reserve(InstrsForInstructionWorklist.size()); + for (Instruction *Inst : reverse(InstrsForInstructionWorklist)) { // DCE instruction if trivially dead. As we iterate in reverse program // order here, we will clean up whole chains of dead instructions. if (isInstructionTriviallyDead(Inst, TLI) || @@ -4118,7 +4132,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, } static bool combineInstructionsOverFunction( - Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, + Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 0d4ca0bcecfb..b56329ad76ae 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/BinaryFormat/MachO.h" @@ -47,6 +48,7 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" @@ -176,7 +178,15 @@ const char kAMDGPUAddressPrivateName[] = "llvm.amdgcn.is.private"; // Accesses sizes are powers of two: 1, 2, 4, 8, 16. static const size_t kNumberOfAccessSizes = 5; -static const unsigned kAllocaRzSize = 32; +static const uint64_t kAllocaRzSize = 32; + +// ASanAccessInfo implementation constants. +constexpr size_t kCompileKernelShift = 0; +constexpr size_t kCompileKernelMask = 0x1; +constexpr size_t kAccessSizeIndexShift = 1; +constexpr size_t kAccessSizeIndexMask = 0xf; +constexpr size_t kIsWriteShift = 5; +constexpr size_t kIsWriteMask = 0x1; // Command-line flags. 
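For reference, the ASanAccessInfo constants added in the hunk above pack (IsWrite, CompileKernel, AccessSizeIndex) into a single i32, which the patch later passes to the @llvm.asan.check.memaccess intrinsic when -asan-optimize-callbacks is enabled. A sketch (not part of the patch) for a 4-byte userspace write, i.e. AccessSizeIndex = 2:

  ; Packed = (IsWrite << 5) | (AccessSizeIndex << 1) | (CompileKernel << 0)
  ;        = (1 << 5)       | (2 << 1)               | 0                   = 36
  call void @llvm.asan.check.memaccess(i8* %addr, i32 36)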
@@ -203,6 +213,11 @@ static cl::opt<bool> ClInstrumentWrites( "asan-instrument-writes", cl::desc("instrument write instructions"), cl::Hidden, cl::init(true)); +static cl::opt<bool> + ClUseStackSafety("asan-use-stack-safety", cl::Hidden, cl::init(false), + cl::Hidden, cl::desc("Use Stack Safety analysis results"), + cl::Optional); + static cl::opt<bool> ClInstrumentAtomics( "asan-instrument-atomics", cl::desc("instrument atomic instructions (rmw, cmpxchg)"), cl::Hidden, @@ -348,6 +363,10 @@ static cl::opt<uint64_t> static cl::opt<bool> ClOpt("asan-opt", cl::desc("Optimize instrumentation"), cl::Hidden, cl::init(true)); +static cl::opt<bool> ClOptimizeCallbacks("asan-optimize-callbacks", + cl::desc("Optimize callbacks"), + cl::Hidden, cl::init(false)); + static cl::opt<bool> ClOptSameTemp( "asan-opt-same-temp", cl::desc("Instrument the same temp just once"), cl::Hidden, cl::init(true)); @@ -442,7 +461,7 @@ struct ShadowMapping { } // end anonymous namespace -static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, +static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, bool IsKasan) { bool IsAndroid = TargetTriple.isAndroid(); bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS(); @@ -559,6 +578,32 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, return Mapping; } +namespace llvm { +void getAddressSanitizerParams(const Triple &TargetTriple, int LongSize, + bool IsKasan, uint64_t *ShadowBase, + int *MappingScale, bool *OrShadowOffset) { + auto Mapping = getShadowMapping(TargetTriple, LongSize, IsKasan); + *ShadowBase = Mapping.Offset; + *MappingScale = Mapping.Scale; + *OrShadowOffset = Mapping.OrShadowOffset; +} + +ASanAccessInfo::ASanAccessInfo(int32_t Packed) + : Packed(Packed), + AccessSizeIndex((Packed >> kAccessSizeIndexShift) & kAccessSizeIndexMask), + IsWrite((Packed >> kIsWriteShift) & kIsWriteMask), + CompileKernel((Packed >> kCompileKernelShift) & kCompileKernelMask) {} + +ASanAccessInfo::ASanAccessInfo(bool IsWrite, bool CompileKernel, + uint8_t AccessSizeIndex) + : Packed((IsWrite << kIsWriteShift) + + (CompileKernel << kCompileKernelShift) + + (AccessSizeIndex << kAccessSizeIndexShift)), + AccessSizeIndex(AccessSizeIndex), IsWrite(IsWrite), + CompileKernel(CompileKernel) {} + +} // namespace llvm + static uint64_t getRedzoneSizeForScale(int MappingScale) { // Redzone used for stack and globals is at least 32 bytes. // For scales 6 and 7, the redzone has to be 64 and 128 bytes respectively. @@ -609,6 +654,7 @@ char ASanGlobalsMetadataWrapperPass::ID = 0; /// AddressSanitizer: instrument the code in module to find memory bugs. struct AddressSanitizer { AddressSanitizer(Module &M, const GlobalsMetadata *GlobalsMD, + const StackSafetyGlobalInfo *SSGI, bool CompileKernel = false, bool Recover = false, bool UseAfterScope = false, AsanDetectStackUseAfterReturnMode UseAfterReturn = @@ -619,10 +665,12 @@ struct AddressSanitizer { UseAfterScope(UseAfterScope || ClUseAfterScope), UseAfterReturn(ClUseAfterReturn.getNumOccurrences() ? 
ClUseAfterReturn : UseAfterReturn), - GlobalsMD(*GlobalsMD) { + GlobalsMD(*GlobalsMD), SSGI(SSGI) { C = &(M.getContext()); LongSize = M.getDataLayout().getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); + Int8PtrTy = Type::getInt8PtrTy(*C); + Int32Ty = Type::getInt32Ty(*C); TargetTriple = Triple(M.getTargetTriple()); Mapping = getShadowMapping(TargetTriple, LongSize, this->CompileKernel); @@ -646,7 +694,7 @@ struct AddressSanitizer { /// Check if we want (and can) handle this alloca. bool isInterestingAlloca(const AllocaInst &AI); - bool ignoreAccess(Value *Ptr); + bool ignoreAccess(Instruction *Inst, Value *Ptr); void getInterestingMemoryOperands( Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting); @@ -713,6 +761,8 @@ private: bool UseAfterScope; AsanDetectStackUseAfterReturnMode UseAfterReturn; Type *IntptrTy; + Type *Int8PtrTy; + Type *Int32Ty; ShadowMapping Mapping; FunctionCallee AsanHandleNoReturnFunc; FunctionCallee AsanPtrCmpFunction, AsanPtrSubFunction; @@ -729,6 +779,7 @@ private: FunctionCallee AsanMemmove, AsanMemcpy, AsanMemset; Value *LocalDynamicShadow = nullptr; const GlobalsMetadata &GlobalsMD; + const StackSafetyGlobalInfo *SSGI; DenseMap<const AllocaInst *, bool> ProcessedAllocas; FunctionCallee AMDGPUAddressShared; @@ -755,16 +806,22 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<ASanGlobalsMetadataWrapperPass>(); + if (ClUseStackSafety) + AU.addRequired<StackSafetyGlobalInfoWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } bool runOnFunction(Function &F) override { GlobalsMetadata &GlobalsMD = getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD(); + const StackSafetyGlobalInfo *const SSGI = + ClUseStackSafety + ? &getAnalysis<StackSafetyGlobalInfoWrapperPass>().getResult() + : nullptr; const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - AddressSanitizer ASan(*F.getParent(), &GlobalsMD, CompileKernel, Recover, - UseAfterScope, UseAfterReturn); + AddressSanitizer ASan(*F.getParent(), &GlobalsMD, SSGI, CompileKernel, + Recover, UseAfterScope, UseAfterReturn); return ASan.instrumentFunction(F, TLI); } @@ -1212,20 +1269,15 @@ GlobalsMetadata ASanGlobalsMetadataAnalysis::run(Module &M, return GlobalsMetadata(M); } -AddressSanitizerPass::AddressSanitizerPass( - bool CompileKernel, bool Recover, bool UseAfterScope, - AsanDetectStackUseAfterReturnMode UseAfterReturn) - : CompileKernel(CompileKernel), Recover(Recover), - UseAfterScope(UseAfterScope), UseAfterReturn(UseAfterReturn) {} - PreservedAnalyses AddressSanitizerPass::run(Function &F, AnalysisManager<Function> &AM) { auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); Module &M = *F.getParent(); if (auto *R = MAMProxy.getCachedResult<ASanGlobalsMetadataAnalysis>(M)) { const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); - AddressSanitizer Sanitizer(M, R, CompileKernel, Recover, UseAfterScope, - UseAfterReturn); + AddressSanitizer Sanitizer(M, R, nullptr, Options.CompileKernel, + Options.Recover, Options.UseAfterScope, + Options.UseAfterReturn); if (Sanitizer.instrumentFunction(F, TLI)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); @@ -1237,21 +1289,51 @@ PreservedAnalyses AddressSanitizerPass::run(Function &F, return PreservedAnalyses::all(); } +void AddressSanitizerPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<AddressSanitizerPass> 
*>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (Options.CompileKernel) + OS << "kernel"; + OS << ">"; +} + +void ModuleAddressSanitizerPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<ModuleAddressSanitizerPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (Options.CompileKernel) + OS << "kernel"; + OS << ">"; +} + ModuleAddressSanitizerPass::ModuleAddressSanitizerPass( - bool CompileKernel, bool Recover, bool UseGlobalGC, bool UseOdrIndicator, - AsanDtorKind DestructorKind) - : CompileKernel(CompileKernel), Recover(Recover), UseGlobalGC(UseGlobalGC), + const AddressSanitizerOptions &Options, bool UseGlobalGC, + bool UseOdrIndicator, AsanDtorKind DestructorKind) + : Options(Options), UseGlobalGC(UseGlobalGC), UseOdrIndicator(UseOdrIndicator), DestructorKind(DestructorKind) {} PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, - AnalysisManager<Module> &AM) { - GlobalsMetadata &GlobalsMD = AM.getResult<ASanGlobalsMetadataAnalysis>(M); - ModuleAddressSanitizer Sanitizer(M, &GlobalsMD, CompileKernel, Recover, - UseGlobalGC, UseOdrIndicator, - DestructorKind); - if (Sanitizer.instrumentModule(M)) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); + ModuleAnalysisManager &MAM) { + GlobalsMetadata &GlobalsMD = MAM.getResult<ASanGlobalsMetadataAnalysis>(M); + ModuleAddressSanitizer ModuleSanitizer(M, &GlobalsMD, Options.CompileKernel, + Options.Recover, UseGlobalGC, + UseOdrIndicator, DestructorKind); + bool Modified = false; + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + const StackSafetyGlobalInfo *const SSGI = + ClUseStackSafety ? &MAM.getResult<StackSafetyGlobalAnalysis>(M) : nullptr; + for (Function &F : M) { + AddressSanitizer FunctionSanitizer( + M, &GlobalsMD, SSGI, Options.CompileKernel, Options.Recover, + Options.UseAfterScope, Options.UseAfterReturn); + const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + Modified |= FunctionSanitizer.instrumentFunction(F, &TLI); + } + Modified |= ModuleSanitizer.instrumentModule(M); + return Modified ? PreservedAnalyses::none() : PreservedAnalyses::all(); } INITIALIZE_PASS(ASanGlobalsMetadataWrapperPass, "asan-globals-md", @@ -1266,6 +1348,7 @@ INITIALIZE_PASS_BEGIN( "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, false) INITIALIZE_PASS_DEPENDENCY(ASanGlobalsMetadataWrapperPass) +INITIALIZE_PASS_DEPENDENCY(StackSafetyGlobalInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END( AddressSanitizerLegacyPass, "asan", @@ -1404,7 +1487,7 @@ bool AddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { return IsInteresting; } -bool AddressSanitizer::ignoreAccess(Value *Ptr) { +bool AddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { // Instrument acesses from different address spaces only for AMDGPU. 
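The getAddressSanitizerParams() helper added above exposes the shadow mapping parameters (shadow base, mapping scale, and whether the base is OR-ed into the shifted address rather than added) so that other passes can reproduce the shadow computation. A minimal standalone sketch of that computation follows; the base and scale here are illustrative assumptions, since the real values are chosen per target triple by getShadowMapping():

// Sketch of the ASan shadow address computation; ShadowBase/MappingScale are
// assumed example values, not the ones getShadowMapping() would pick.
#include <cstdint>
#include <cstdio>

static uint64_t shadowFor(uint64_t Addr, uint64_t ShadowBase, int MappingScale,
                          bool OrShadowOffset) {
  uint64_t Shifted = Addr >> MappingScale;
  return OrShadowOffset ? (Shifted | ShadowBase) : (Shifted + ShadowBase);
}

int main() {
  const uint64_t ShadowBase = 0x7fff8000; // illustrative 64-bit shadow offset
  const int MappingScale = 3;             // 8 application bytes per shadow byte
  const uint64_t Addr = 0x602000000010ULL;
  std::printf("shadow(0x%llx) = 0x%llx\n", (unsigned long long)Addr,
              (unsigned long long)shadowFor(Addr, ShadowBase, MappingScale,
                                            /*OrShadowOffset=*/false));
  return 0;
}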
Type *PtrTy = cast<PointerType>(Ptr->getType()->getScalarType()); if (PtrTy->getPointerAddressSpace() != 0 && @@ -1425,6 +1508,10 @@ bool AddressSanitizer::ignoreAccess(Value *Ptr) { if (ClSkipPromotableAllocas && !isInterestingAlloca(*AI)) return true; + if (SSGI != nullptr && SSGI->stackAccessIsSafe(*Inst) && + findAllocaForValue(Ptr)) + return true; + return false; } @@ -1439,22 +1526,22 @@ void AddressSanitizer::getInterestingMemoryOperands( return; if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - if (!ClInstrumentReads || ignoreAccess(LI->getPointerOperand())) + if (!ClInstrumentReads || ignoreAccess(LI, LI->getPointerOperand())) return; Interesting.emplace_back(I, LI->getPointerOperandIndex(), false, LI->getType(), LI->getAlign()); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { - if (!ClInstrumentWrites || ignoreAccess(SI->getPointerOperand())) + if (!ClInstrumentWrites || ignoreAccess(LI, SI->getPointerOperand())) return; Interesting.emplace_back(I, SI->getPointerOperandIndex(), true, SI->getValueOperand()->getType(), SI->getAlign()); } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { - if (!ClInstrumentAtomics || ignoreAccess(RMW->getPointerOperand())) + if (!ClInstrumentAtomics || ignoreAccess(LI, RMW->getPointerOperand())) return; Interesting.emplace_back(I, RMW->getPointerOperandIndex(), true, RMW->getValOperand()->getType(), None); } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { - if (!ClInstrumentAtomics || ignoreAccess(XCHG->getPointerOperand())) + if (!ClInstrumentAtomics || ignoreAccess(LI, XCHG->getPointerOperand())) return; Interesting.emplace_back(I, XCHG->getPointerOperandIndex(), true, XCHG->getCompareOperand()->getType(), None); @@ -1469,7 +1556,7 @@ void AddressSanitizer::getInterestingMemoryOperands( return; auto BasePtr = CI->getOperand(OpOffset); - if (ignoreAccess(BasePtr)) + if (ignoreAccess(LI, BasePtr)) return; auto Ty = cast<PointerType>(BasePtr->getType())->getElementType(); MaybeAlign Alignment = Align(1); @@ -1479,9 +1566,9 @@ void AddressSanitizer::getInterestingMemoryOperands( Value *Mask = CI->getOperand(2 + OpOffset); Interesting.emplace_back(I, OpOffset, IsWrite, Ty, Alignment, Mask); } else { - for (unsigned ArgNo = 0; ArgNo < CI->getNumArgOperands(); ArgNo++) { + for (unsigned ArgNo = 0; ArgNo < CI->arg_size(); ArgNo++) { if (!ClInstrumentByval || !CI->isByValArgument(ArgNo) || - ignoreAccess(CI->getArgOperand(ArgNo))) + ignoreAccess(LI, CI->getArgOperand(ArgNo))) continue; Type *Ty = CI->getParamByValType(ArgNo); Interesting.emplace_back(I, ArgNo, false, Ty, Align(1)); @@ -1738,9 +1825,20 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, } IRBuilder<> IRB(InsertBefore); - Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); size_t AccessSizeIndex = TypeSizeToSizeIndex(TypeSize); + const ASanAccessInfo AccessInfo(IsWrite, CompileKernel, AccessSizeIndex); + + if (UseCalls && ClOptimizeCallbacks) { + const ASanAccessInfo AccessInfo(IsWrite, CompileKernel, AccessSizeIndex); + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + IRB.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::asan_check_memaccess), + {IRB.CreatePointerCast(Addr, Int8PtrTy), + ConstantInt::get(Int32Ty, AccessInfo.Packed)}); + return; + } + Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); if (UseCalls) { if (Exp == 0) IRB.CreateCall(AsanMemoryAccessCallback[IsWrite][0][AccessSizeIndex], @@ -1936,7 +2034,8 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const { // Globals from 
llvm.metadata aren't emitted, do not instrument them. if (Section == "llvm.metadata") return false; // Do not instrument globals from special LLVM sections. - if (Section.find("__llvm") != StringRef::npos || Section.find("__LLVM") != StringRef::npos) return false; + if (Section.contains("__llvm") || Section.contains("__LLVM")) + return false; // Do not instrument function pointers to initialization and termination // routines: dynamic linker will not properly handle redzones. @@ -2133,8 +2232,7 @@ Instruction *ModuleAddressSanitizer::CreateAsanModuleDtor(Module &M) { AsanDtorFunction = Function::createWithDefaultAttr( FunctionType::get(Type::getVoidTy(*C), false), GlobalValue::InternalLinkage, 0, kAsanModuleDtorName, &M); - AsanDtorFunction->addAttribute(AttributeList::FunctionIndex, - Attribute::NoUnwind); + AsanDtorFunction->addFnAttr(Attribute::NoUnwind); // Ensure Dtor cannot be discarded, even if in a comdat. appendToUsed(M, {AsanDtorFunction}); BasicBlock *AsanDtorBB = BasicBlock::Create(*C, "", AsanDtorFunction); @@ -2753,7 +2851,7 @@ void AddressSanitizer::markEscapedLocalAllocas(Function &F) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); if (II && II->getIntrinsicID() == Intrinsic::localescape) { // We found a call. Mark all the allocas passed in as uninteresting. - for (Value *Arg : II->arg_operands()) { + for (Value *Arg : II->args()) { AllocaInst *AI = dyn_cast<AllocaInst>(Arg->stripPointerCasts()); assert(AI && AI->isStaticAlloca() && "non-static alloca arg to localescape"); @@ -2774,6 +2872,8 @@ bool AddressSanitizer::suppressInstrumentationSiteForDebug(int &Instrumented) { bool AddressSanitizer::instrumentFunction(Function &F, const TargetLibraryInfo *TLI) { + if (F.empty()) + return false; if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false; if (F.getName().startswith("__asan_")) return false; @@ -2916,7 +3016,8 @@ bool AddressSanitizer::LooksLikeCodeInBug11395(Instruction *I) { if (LongSize != 32) return false; CallInst *CI = dyn_cast<CallInst>(I); if (!CI || !CI->isInlineAsm()) return false; - if (CI->getNumArgOperands() <= 5) return false; + if (CI->arg_size() <= 5) + return false; // We have inline assembly with quite a few arguments. return true; } @@ -3112,7 +3213,7 @@ Value *FunctionStackPoisoner::createAllocaForLayout( assert(Alloca->isStaticAlloca()); } assert((ClRealignStack & (ClRealignStack - 1)) == 0); - size_t FrameAlignment = std::max(L.FrameAlignment, (size_t)ClRealignStack); + uint64_t FrameAlignment = std::max(L.FrameAlignment, uint64_t(ClRealignStack)); Alloca->setAlignment(Align(FrameAlignment)); return IRB.CreatePointerCast(Alloca, IntptrTy); } @@ -3256,8 +3357,8 @@ void FunctionStackPoisoner::processStaticAllocas() { // Minimal header size (left redzone) is 4 pointers, // i.e. 32 bytes on 64-bit platforms and 16 bytes in 32-bit platforms. 
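Under the new -asan-optimize-callbacks mode, instrumentAddress() above packs the whole access description into a single 32-bit ASanAccessInfo value and hands it to the llvm.asan.check.memaccess intrinsic instead of materialising the check inline. A small self-contained sketch of that packing and unpacking follows; the shift and mask constants are placeholders chosen for illustration, since their real definitions live in the ASan headers rather than in this file:

// Round-trips an access description through a packed 32-bit value, mirroring
// the ASanAccessInfo constructors above.  Shift/mask values are assumptions.
#include <cassert>
#include <cstdint>

static constexpr int kAccessSizeIndexShift = 0; // assumed field layout
static constexpr int kAccessSizeIndexMask  = 0xf;
static constexpr int kIsWriteShift         = 4;
static constexpr int kIsWriteMask          = 0x1;
static constexpr int kCompileKernelShift   = 5;
static constexpr int kCompileKernelMask    = 0x1;

static int32_t packAccessInfo(bool IsWrite, bool CompileKernel,
                              uint8_t AccessSizeIndex) {
  return (IsWrite << kIsWriteShift) + (CompileKernel << kCompileKernelShift) +
         (AccessSizeIndex << kAccessSizeIndexShift);
}

int main() {
  // An 8-byte write in userspace ASan: AccessSizeIndex = log2(8) = 3.
  const int32_t Packed = packAccessInfo(/*IsWrite=*/true,
                                        /*CompileKernel=*/false, /*Index=*/3);
  assert(((Packed >> kAccessSizeIndexShift) & kAccessSizeIndexMask) == 3);
  assert(((Packed >> kIsWriteShift) & kIsWriteMask) == 1);
  assert(((Packed >> kCompileKernelShift) & kCompileKernelMask) == 0);
  return 0;
}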
- size_t Granularity = 1ULL << Mapping.Scale; - size_t MinHeaderSize = std::max((size_t)ASan.LongSize / 2, Granularity); + uint64_t Granularity = 1ULL << Mapping.Scale; + uint64_t MinHeaderSize = std::max((uint64_t)ASan.LongSize / 2, Granularity); const ASanStackFrameLayout &L = ComputeASanStackFrameLayout(SVD, Granularity, MinHeaderSize); @@ -3511,7 +3612,7 @@ void FunctionStackPoisoner::poisonAlloca(Value *V, uint64_t Size, void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { IRBuilder<> IRB(AI); - const unsigned Alignment = std::max(kAllocaRzSize, AI->getAlignment()); + const uint64_t Alignment = std::max(kAllocaRzSize, AI->getAlignment()); const uint64_t AllocaRedzoneMask = kAllocaRzSize - 1; Value *Zero = Constant::getNullValue(IntptrTy); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp index 9acd82c005e6..1a7f7a365ce4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp @@ -53,6 +53,8 @@ static bool runCGProfilePass( InstrProfSymtab Symtab; auto UpdateCounts = [&](TargetTransformInfo &TTI, Function *F, Function *CalledF, uint64_t NewCount) { + if (NewCount == 0) + return; if (!CalledF || !TTI.isLoweredToCall(CalledF) || CalledF->hasDLLImportStorageClass()) return; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index 3b4d80dc8023..497aac30c3f6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -1553,11 +1553,11 @@ static bool negateICmpIfUsedByBranchOrSelectOnly(ICmpInst *ICmp, SI->swapValues(); SI->swapProfMetadata(); if (Scope->TrueBiasedSelects.count(SI)) { - assert(Scope->FalseBiasedSelects.count(SI) == 0 && + assert(!Scope->FalseBiasedSelects.contains(SI) && "Must not be already in"); Scope->FalseBiasedSelects.insert(SI); } else if (Scope->FalseBiasedSelects.count(SI)) { - assert(Scope->TrueBiasedSelects.count(SI) == 0 && + assert(!Scope->TrueBiasedSelects.contains(SI) && "Must not be already in"); Scope->TrueBiasedSelects.insert(SI); } @@ -1592,7 +1592,7 @@ static void insertTrivialPHIs(CHRScope *Scope, SmallVector<Instruction *, 8> Users; for (User *U : I.users()) { if (auto *UI = dyn_cast<Instruction>(U)) { - if (BlocksInScope.count(UI->getParent()) == 0 && + if (!BlocksInScope.contains(UI->getParent()) && // Unless there's already a phi for I at the exit block. !(isa<PHINode>(UI) && UI->getParent() == ExitBlock)) { CHR_DEBUG(dbgs() << "V " << I << "\n"); @@ -1752,7 +1752,7 @@ void CHR::transformScopes(CHRScope *Scope, DenseSet<PHINode *> &TrivialPHIs) { // Create the combined branch condition and constant-fold the branches/selects // in the hot path. fixupBranchesAndSelects(Scope, PreEntryBlock, MergedBr, - ProfileCount ? ProfileCount.getValue() : 0); + ProfileCount.getValueOr(0)); } // A helper for transformScopes. 
Clone the blocks in the scope (excluding the diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 63aa84e4a77c..38c219ce3465 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -18,6 +18,9 @@ /// The analysis is based on automatic propagation of data flow labels (also /// known as taint labels) through a program as it performs computation. /// +/// Argument and return value labels are passed through TLS variables +/// __dfsan_arg_tls and __dfsan_retval_tls. +/// /// Each byte of application memory is backed by a shadow memory byte. The /// shadow byte can represent up to 8 labels. On Linux/x86_64, memory is then /// laid out as follows: @@ -144,20 +147,22 @@ static cl::opt<bool> ClPreserveAlignment( // to the "native" (i.e. unsanitized) ABI. Unless the ABI list contains // additional annotations for those functions, a call to one of those functions // will produce a warning message, as the labelling behaviour of the function is -// unknown. The other supported annotations are "functional" and "discard", -// which are described below under DataFlowSanitizer::WrapperKind. +// unknown. The other supported annotations for uninstrumented functions are +// "functional" and "discard", which are described below under +// DataFlowSanitizer::WrapperKind. +// Functions will often be labelled with both "uninstrumented" and one of +// "functional" or "discard". This will leave the function unchanged by this +// pass, and create a wrapper function that will call the original. +// +// Instrumented functions can also be annotated as "force_zero_labels", which +// will make all shadow and return values set zero labels. +// Functions should never be labelled with both "force_zero_labels" and +// "uninstrumented" or any of the unistrumented wrapper kinds. static cl::list<std::string> ClABIListFiles( "dfsan-abilist", cl::desc("File listing native ABI functions and how the pass treats them"), cl::Hidden); -// Controls whether the pass uses IA_Args or IA_TLS as the ABI for instrumented -// functions (see DataFlowSanitizer::InstrumentedABI below). -static cl::opt<bool> - ClArgsABI("dfsan-args-abi", - cl::desc("Use the argument ABI rather than the TLS ABI"), - cl::Hidden); - // Controls whether the pass includes or ignores the labels of pointers in load // instructions. static cl::opt<bool> ClCombinePointerLabelsOnLoad( @@ -349,18 +354,18 @@ transformFunctionAttributes(const TransformedFunction &TransformedFunction, for (unsigned I = 0, IE = TransformedFunction.ArgumentIndexMapping.size(); I < IE; ++I) { unsigned TransformedIndex = TransformedFunction.ArgumentIndexMapping[I]; - ArgumentAttributes[TransformedIndex] = CallSiteAttrs.getParamAttributes(I); + ArgumentAttributes[TransformedIndex] = CallSiteAttrs.getParamAttrs(I); } // Copy annotations on varargs arguments. 
for (unsigned I = TransformedFunction.OriginalType->getNumParams(), IE = CallSiteAttrs.getNumAttrSets(); I < IE; ++I) { - ArgumentAttributes.push_back(CallSiteAttrs.getParamAttributes(I)); + ArgumentAttributes.push_back(CallSiteAttrs.getParamAttrs(I)); } - return AttributeList::get(Ctx, CallSiteAttrs.getFnAttributes(), - CallSiteAttrs.getRetAttributes(), + return AttributeList::get(Ctx, CallSiteAttrs.getFnAttrs(), + CallSiteAttrs.getRetAttrs(), llvm::makeArrayRef(ArgumentAttributes)); } @@ -372,17 +377,6 @@ class DataFlowSanitizer { enum { OriginWidthBits = 32, OriginWidthBytes = OriginWidthBits / 8 }; - /// Which ABI should be used for instrumented functions? - enum InstrumentedABI { - /// Argument and return value labels are passed through additional - /// arguments and by modifying the return type. - IA_Args, - - /// Argument and return value labels are passed through TLS variables - /// __dfsan_arg_tls and __dfsan_retval_tls. - IA_TLS - }; - /// How should calls to uninstrumented functions be handled? enum WrapperKind { /// This function is present in an uninstrumented form but we don't know @@ -400,9 +394,7 @@ class DataFlowSanitizer { /// Instead of calling the function, a custom wrapper __dfsw_F is called, /// where F is the name of the function. This function may wrap the - /// original function or provide its own implementation. This is similar to - /// the IA_Args ABI, except that IA_Args uses a struct return type to - /// pass the return value shadow in a register, while WK_Custom uses an + /// original function or provide its own implementation. WK_Custom uses an /// extra pointer argument to return the shadow. This allows the wrapped /// form of the function type to be expressed in C. WK_Custom @@ -469,10 +461,9 @@ class DataFlowSanitizer { getShadowOriginAddress(Value *Addr, Align InstAlignment, Instruction *Pos); bool isInstrumented(const Function *F); bool isInstrumented(const GlobalAlias *GA); - FunctionType *getArgsFunctionType(FunctionType *T); + bool isForceZeroLabels(const Function *F); FunctionType *getTrampolineFunctionType(FunctionType *T); TransformedFunction getCustomFunctionType(FunctionType *T); - InstrumentedABI getInstrumentedABI(); WrapperKind getWrapperKind(Function *F); void addGlobalNameSuffix(GlobalValue *GV); Function *buildWrapperFunction(Function *F, StringRef NewFName, @@ -496,18 +487,11 @@ class DataFlowSanitizer { /// Returns whether the pass tracks origins. Supports only TLS ABI mode. bool shouldTrackOrigins(); - /// Returns whether the pass tracks labels for struct fields and array - /// indices. Supports only TLS ABI mode. - bool shouldTrackFieldsAndIndices(); - /// Returns a zero constant with the shadow type of OrigTy. /// /// getZeroShadow({T1,T2,...}) = {getZeroShadow(T1),getZeroShadow(T2,...} /// getZeroShadow([n x T]) = [n x getZeroShadow(T)] /// getZeroShadow(other type) = i16(0) - /// - /// Note that a zero shadow is always i16(0) when shouldTrackFieldsAndIndices - /// returns false. Constant *getZeroShadow(Type *OrigTy); /// Returns a zero constant with the shadow type of V's type. Constant *getZeroShadow(Value *V); @@ -520,9 +504,6 @@ class DataFlowSanitizer { /// getShadowTy({T1,T2,...}) = {getShadowTy(T1),getShadowTy(T2),...} /// getShadowTy([n x T]) = [n x getShadowTy(T)] /// getShadowTy(other type) = i16 - /// - /// Note that a shadow type is always i16 when shouldTrackFieldsAndIndices - /// returns false. Type *getShadowTy(Type *OrigTy); /// Returns the shadow type of of V's type. 
Type *getShadowTy(Value *V); @@ -539,8 +520,8 @@ struct DFSanFunction { DataFlowSanitizer &DFS; Function *F; DominatorTree DT; - DataFlowSanitizer::InstrumentedABI IA; bool IsNativeABI; + bool IsForceZeroLabels; AllocaInst *LabelReturnAlloca = nullptr; AllocaInst *OriginReturnAlloca = nullptr; DenseMap<Value *, Value *> ValShadowMap; @@ -571,8 +552,10 @@ struct DFSanFunction { DenseMap<Value *, Value *> CachedCollapsedShadows; DenseMap<Value *, std::set<Value *>> ShadowElements; - DFSanFunction(DataFlowSanitizer &DFS, Function *F, bool IsNativeABI) - : DFS(DFS), F(F), IA(DFS.getInstrumentedABI()), IsNativeABI(IsNativeABI) { + DFSanFunction(DataFlowSanitizer &DFS, Function *F, bool IsNativeABI, + bool IsForceZeroLabels) + : DFS(DFS), F(F), IsNativeABI(IsNativeABI), + IsForceZeroLabels(IsForceZeroLabels) { DT.recalculate(*F); } @@ -787,17 +770,6 @@ DataFlowSanitizer::DataFlowSanitizer( SpecialCaseList::createOrDie(AllABIListFiles, *vfs::getRealFileSystem())); } -FunctionType *DataFlowSanitizer::getArgsFunctionType(FunctionType *T) { - SmallVector<Type *, 4> ArgTypes(T->param_begin(), T->param_end()); - ArgTypes.append(T->getNumParams(), PrimitiveShadowTy); - if (T->isVarArg()) - ArgTypes.push_back(PrimitiveShadowPtrTy); - Type *RetType = T->getReturnType(); - if (!RetType->isVoidTy()) - RetType = StructType::get(RetType, PrimitiveShadowTy); - return FunctionType::get(RetType, ArgTypes, T->isVarArg()); -} - FunctionType *DataFlowSanitizer::getTrampolineFunctionType(FunctionType *T) { assert(!T->isVarArg()); SmallVector<Type *, 4> ArgTypes; @@ -861,9 +833,6 @@ TransformedFunction DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { } bool DataFlowSanitizer::isZeroShadow(Value *V) { - if (!shouldTrackFieldsAndIndices()) - return ZeroPrimitiveShadow == V; - Type *T = V->getType(); if (!isa<ArrayType>(T) && !isa<StructType>(T)) { if (const ConstantInt *CI = dyn_cast<ConstantInt>(V)) @@ -880,19 +849,11 @@ bool DataFlowSanitizer::hasLoadSizeForFastPath(uint64_t Size) { } bool DataFlowSanitizer::shouldTrackOrigins() { - static const bool ShouldTrackOrigins = - ClTrackOrigins && getInstrumentedABI() == DataFlowSanitizer::IA_TLS; + static const bool ShouldTrackOrigins = ClTrackOrigins; return ShouldTrackOrigins; } -bool DataFlowSanitizer::shouldTrackFieldsAndIndices() { - return getInstrumentedABI() == DataFlowSanitizer::IA_TLS; -} - Constant *DataFlowSanitizer::getZeroShadow(Type *OrigTy) { - if (!shouldTrackFieldsAndIndices()) - return ZeroPrimitiveShadow; - if (!isa<ArrayType>(OrigTy) && !isa<StructType>(OrigTy)) return ZeroPrimitiveShadow; Type *ShadowTy = getShadowTy(OrigTy); @@ -992,8 +953,6 @@ Value *DFSanFunction::collapseToPrimitiveShadow(Value *Shadow, if (!isa<ArrayType>(ShadowTy) && !isa<StructType>(ShadowTy)) return Shadow; - assert(DFS.shouldTrackFieldsAndIndices()); - // Checks if the cached collapsed shadow value dominates Pos. Value *&CS = CachedCollapsedShadows[Shadow]; if (CS && DT.dominates(CS, Pos)) @@ -1007,9 +966,6 @@ Value *DFSanFunction::collapseToPrimitiveShadow(Value *Shadow, } Type *DataFlowSanitizer::getShadowTy(Type *OrigTy) { - if (!shouldTrackFieldsAndIndices()) - return PrimitiveShadowTy; - if (!OrigTy->isSized()) return PrimitiveShadowTy; if (isa<IntegerType>(OrigTy)) @@ -1107,8 +1063,8 @@ bool DataFlowSanitizer::isInstrumented(const GlobalAlias *GA) { return !ABIList.isIn(*GA, "uninstrumented"); } -DataFlowSanitizer::InstrumentedABI DataFlowSanitizer::getInstrumentedABI() { - return ClArgsABI ? 
IA_Args : IA_TLS; +bool DataFlowSanitizer::isForceZeroLabels(const Function *F) { + return ABIList.isIn(*F, "force_zero_labels"); } DataFlowSanitizer::WrapperKind DataFlowSanitizer::getWrapperKind(Function *F) { @@ -1139,7 +1095,7 @@ void DataFlowSanitizer::addGlobalNameSuffix(GlobalValue *GV) { Pos = Asm.find("@"); if (Pos == std::string::npos) - report_fatal_error("unsupported .symver: " + Asm); + report_fatal_error(Twine("unsupported .symver: ", Asm)); Asm.replace(Pos, 1, Suffix + "@"); GV->getParent()->setModuleInlineAsm(Asm); @@ -1154,14 +1110,12 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, Function *NewF = Function::Create(NewFT, NewFLink, F->getAddressSpace(), NewFName, F->getParent()); NewF->copyAttributesFrom(F); - NewF->removeAttributes( - AttributeList::ReturnIndex, + NewF->removeRetAttrs( AttributeFuncs::typeIncompatible(NewFT->getReturnType())); BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", NewF); if (F->isVarArg()) { - NewF->removeAttributes(AttributeList::FunctionIndex, - AttrBuilder().addAttribute("split-stack")); + NewF->removeFnAttrs(AttrBuilder().addAttribute("split-stack")); CallInst::Create(DFSanVarargWrapperFn, IRBuilder<>(BB).CreateGlobalStringPtr(F->getName()), "", BB); @@ -1199,7 +1153,8 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, // F is called by a wrapped custom function with primitive shadows. So // its arguments and return value need conversion. - DFSanFunction DFSF(*this, F, /*IsNativeABI=*/true); + DFSanFunction DFSF(*this, F, /*IsNativeABI=*/true, + /*ForceZeroLabels=*/false); Function::arg_iterator ValAI = F->arg_begin(), ShadowAI = AI; ++ValAI; for (unsigned N = FT->getNumParams(); N != 0; ++ValAI, ++ShadowAI, --N) { @@ -1238,23 +1193,17 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { { AttributeList AL; - AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, - Attribute::NoUnwind); - AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, - Attribute::ReadOnly); - AL = AL.addAttribute(M.getContext(), AttributeList::ReturnIndex, - Attribute::ZExt); + AL = AL.addFnAttribute(M.getContext(), Attribute::NoUnwind); + AL = AL.addFnAttribute(M.getContext(), Attribute::ReadOnly); + AL = AL.addRetAttribute(M.getContext(), Attribute::ZExt); DFSanUnionLoadFn = Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy, AL); } { AttributeList AL; - AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, - Attribute::NoUnwind); - AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, - Attribute::ReadOnly); - AL = AL.addAttribute(M.getContext(), AttributeList::ReturnIndex, - Attribute::ZExt); + AL = AL.addFnAttribute(M.getContext(), Attribute::NoUnwind); + AL = AL.addFnAttribute(M.getContext(), Attribute::ReadOnly); + AL = AL.addRetAttribute(M.getContext(), Attribute::ZExt); DFSanLoadLabelAndOriginFn = Mod->getOrInsertFunction( "__dfsan_load_label_and_origin", DFSanLoadLabelAndOriginFnTy, AL); } @@ -1274,8 +1223,7 @@ void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { { AttributeList AL; AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); - AL = AL.addAttribute(M.getContext(), AttributeList::ReturnIndex, - Attribute::ZExt); + AL = AL.addRetAttribute(M.getContext(), Attribute::ZExt); DFSanChainOriginFn = Mod->getOrInsertFunction("__dfsan_chain_origin", DFSanChainOriginFnTy, AL); } @@ -1283,8 +1231,7 @@ void 
DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { AttributeList AL; AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); AL = AL.addParamAttribute(M.getContext(), 1, Attribute::ZExt); - AL = AL.addAttribute(M.getContext(), AttributeList::ReturnIndex, - Attribute::ZExt); + AL = AL.addRetAttribute(M.getContext(), Attribute::ZExt); DFSanChainOriginIfTaintedFn = Mod->getOrInsertFunction( "__dfsan_chain_origin_if_tainted", DFSanChainOriginIfTaintedFnTy, AL); } @@ -1409,34 +1356,32 @@ bool DataFlowSanitizer::runImpl(Module &M) { std::vector<Function *> FnsToInstrument; SmallPtrSet<Function *, 2> FnsWithNativeABI; + SmallPtrSet<Function *, 2> FnsWithForceZeroLabel; for (Function &F : M) if (!F.isIntrinsic() && !DFSanRuntimeFunctions.contains(&F)) FnsToInstrument.push_back(&F); // Give function aliases prefixes when necessary, and build wrappers where the // instrumentedness is inconsistent. - for (Module::alias_iterator AI = M.alias_begin(), AE = M.alias_end(); - AI != AE;) { - GlobalAlias *GA = &*AI; - ++AI; + for (GlobalAlias &GA : llvm::make_early_inc_range(M.aliases())) { // Don't stop on weak. We assume people aren't playing games with the // instrumentedness of overridden weak aliases. - auto *F = dyn_cast<Function>(GA->getBaseObject()); + auto *F = dyn_cast<Function>(GA.getAliaseeObject()); if (!F) continue; - bool GAInst = isInstrumented(GA), FInst = isInstrumented(F); + bool GAInst = isInstrumented(&GA), FInst = isInstrumented(F); if (GAInst && FInst) { - addGlobalNameSuffix(GA); + addGlobalNameSuffix(&GA); } else if (GAInst != FInst) { // Non-instrumented alias of an instrumented function, or vice versa. // Replace the alias with a native-ABI wrapper of the aliasee. The pass // below will take care of instrumenting it. Function *NewF = - buildWrapperFunction(F, "", GA->getLinkage(), F->getFunctionType()); - GA->replaceAllUsesWith(ConstantExpr::getBitCast(NewF, GA->getType())); - NewF->takeName(GA); - GA->eraseFromParent(); + buildWrapperFunction(F, "", GA.getLinkage(), F->getFunctionType()); + GA.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, GA.getType())); + NewF->takeName(&GA); + GA.eraseFromParent(); FnsToInstrument.push_back(NewF); } } @@ -1456,50 +1401,17 @@ bool DataFlowSanitizer::runImpl(Module &M) { FT->getReturnType()->isVoidTy()); if (isInstrumented(&F)) { + if (isForceZeroLabels(&F)) + FnsWithForceZeroLabel.insert(&F); + // Instrumented functions get a '.dfsan' suffix. This allows us to more // easily identify cases of mismatching ABIs. This naming scheme is // mangling-compatible (see Itanium ABI), using a vendor-specific suffix. 
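The naming scheme described in the comment above, together with the dfsw$/dfso$ wrapper prefixes used a little further down, can be summarised with a tiny example (the function name is hypothetical):

// Instrumented definitions get a ".dfsan" suffix; calls to custom-wrapped
// uninstrumented functions are routed through a "dfsw$" wrapper ("dfso$"
// when origin tracking is enabled).  The function name here is made up.
#include <iostream>
#include <string>

int main() {
  const std::string Name = "my_memcpy";
  const bool TrackOrigins = false;
  std::cout << Name + ".dfsan" << "\n";                           // instrumented body
  std::cout << (TrackOrigins ? "dfso$" : "dfsw$") + Name << "\n"; // custom wrapper
  return 0;
}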
- if (getInstrumentedABI() == IA_Args && !IsZeroArgsVoidRet) { - FunctionType *NewFT = getArgsFunctionType(FT); - Function *NewF = Function::Create(NewFT, F.getLinkage(), - F.getAddressSpace(), "", &M); - NewF->copyAttributesFrom(&F); - NewF->removeAttributes( - AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewFT->getReturnType())); - for (Function::arg_iterator FArg = F.arg_begin(), - NewFArg = NewF->arg_begin(), - FArgEnd = F.arg_end(); - FArg != FArgEnd; ++FArg, ++NewFArg) { - FArg->replaceAllUsesWith(&*NewFArg); - } - NewF->getBasicBlockList().splice(NewF->begin(), F.getBasicBlockList()); - - for (Function::user_iterator UI = F.user_begin(), UE = F.user_end(); - UI != UE;) { - BlockAddress *BA = dyn_cast<BlockAddress>(*UI); - ++UI; - if (BA) { - BA->replaceAllUsesWith( - BlockAddress::get(NewF, BA->getBasicBlock())); - delete BA; - } - } - F.replaceAllUsesWith( - ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT))); - NewF->takeName(&F); - F.eraseFromParent(); - *FI = NewF; - addGlobalNameSuffix(NewF); - } else { - addGlobalNameSuffix(&F); - } + addGlobalNameSuffix(&F); } else if (!IsZeroArgsVoidRet || getWrapperKind(&F) == WK_Custom) { // Build a wrapper function for F. The wrapper simply calls F, and is // added to FnsToInstrument so that any instrumentation according to its // WrapperKind is done in the second pass below. - FunctionType *NewFT = - getInstrumentedABI() == IA_Args ? getArgsFunctionType(FT) : FT; // If the function being wrapped has local linkage, then preserve the // function's linkage in the wrapper function. @@ -1511,9 +1423,8 @@ bool DataFlowSanitizer::runImpl(Module &M) { &F, (shouldTrackOrigins() ? std::string("dfso$") : std::string("dfsw$")) + std::string(F.getName()), - WrapperLinkage, NewFT); - if (getInstrumentedABI() == IA_TLS) - NewF->removeAttributes(AttributeList::FunctionIndex, ReadOnlyNoneAttrs); + WrapperLinkage, FT); + NewF->removeFnAttrs(ReadOnlyNoneAttrs); Value *WrappedFnCst = ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT)); @@ -1552,7 +1463,8 @@ bool DataFlowSanitizer::runImpl(Module &M) { removeUnreachableBlocks(*F); - DFSanFunction DFSF(*this, F, FnsWithNativeABI.count(F)); + DFSanFunction DFSF(*this, F, FnsWithNativeABI.count(F), + FnsWithForceZeroLabel.count(F)); // DFSanVisitor may create new basic blocks, which confuses df_iterator. // Build a copy of the list before iterating over it. 
@@ -1649,23 +1561,14 @@ Value *DFSanFunction::getOrigin(Value *V) { if (Argument *A = dyn_cast<Argument>(V)) { if (IsNativeABI) return DFS.ZeroOrigin; - switch (IA) { - case DataFlowSanitizer::IA_TLS: { - if (A->getArgNo() < DFS.NumOfElementsInArgOrgTLS) { - Instruction *ArgOriginTLSPos = &*F->getEntryBlock().begin(); - IRBuilder<> IRB(ArgOriginTLSPos); - Value *ArgOriginPtr = getArgOriginTLS(A->getArgNo(), IRB); - Origin = IRB.CreateLoad(DFS.OriginTy, ArgOriginPtr); - } else { - // Overflow - Origin = DFS.ZeroOrigin; - } - break; - } - case DataFlowSanitizer::IA_Args: { + if (A->getArgNo() < DFS.NumOfElementsInArgOrgTLS) { + Instruction *ArgOriginTLSPos = &*F->getEntryBlock().begin(); + IRBuilder<> IRB(ArgOriginTLSPos); + Value *ArgOriginPtr = getArgOriginTLS(A->getArgNo(), IRB); + Origin = IRB.CreateLoad(DFS.OriginTy, ArgOriginPtr); + } else { + // Overflow Origin = DFS.ZeroOrigin; - break; - } } } else { Origin = DFS.ZeroOrigin; @@ -1716,25 +1619,14 @@ Value *DFSanFunction::getShadowForTLSArgument(Argument *A) { Value *DFSanFunction::getShadow(Value *V) { if (!isa<Argument>(V) && !isa<Instruction>(V)) return DFS.getZeroShadow(V); + if (IsForceZeroLabels) + return DFS.getZeroShadow(V); Value *&Shadow = ValShadowMap[V]; if (!Shadow) { if (Argument *A = dyn_cast<Argument>(V)) { if (IsNativeABI) return DFS.getZeroShadow(V); - switch (IA) { - case DataFlowSanitizer::IA_TLS: { - Shadow = getShadowForTLSArgument(A); - break; - } - case DataFlowSanitizer::IA_Args: { - unsigned ArgIdx = A->getArgNo() + F->arg_size() / 2; - Function::arg_iterator Arg = F->arg_begin(); - std::advance(Arg, ArgIdx); - Shadow = &*Arg; - assert(Shadow->getType() == DFS.PrimitiveShadowTy); - break; - } - } + Shadow = getShadowForTLSArgument(A); NonZeroChecks.push_back(Shadow); } else { Shadow = DFS.getZeroShadow(V); @@ -1745,8 +1637,6 @@ Value *DFSanFunction::getShadow(Value *V) { void DFSanFunction::setShadow(Instruction *I, Value *Shadow) { assert(!ValShadowMap.count(I)); - assert(DFS.shouldTrackFieldsAndIndices() || - Shadow->getType() == DFS.PrimitiveShadowTy); ValShadowMap[I] = Shadow; } @@ -2124,7 +2014,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( IRB.CreateCall(DFS.DFSanLoadLabelAndOriginFn, {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), ConstantInt::get(DFS.IntptrTy, Size)}); - Call->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); + Call->addRetAttr(Attribute::ZExt); return {IRB.CreateTrunc(IRB.CreateLShr(Call, DFS.OriginWidthBits), DFS.PrimitiveShadowTy), IRB.CreateTrunc(Call, DFS.OriginTy)}; @@ -2171,7 +2061,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( IRBuilder<> IRB(Pos); CallInst *FallbackCall = IRB.CreateCall( DFS.DFSanUnionLoadFn, {ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)}); - FallbackCall->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); + FallbackCall->addRetAttr(Attribute::ZExt); return {FallbackCall, Origin}; } @@ -2563,15 +2453,12 @@ void DFSanVisitor::visitBinaryOperator(BinaryOperator &BO) { } void DFSanVisitor::visitBitCastInst(BitCastInst &BCI) { - if (DFSF.DFS.getInstrumentedABI() == DataFlowSanitizer::IA_TLS) { - // Special case: if this is the bitcast (there is exactly 1 allowed) between - // a musttail call and a ret, don't instrument. New instructions are not - // allowed after a musttail call. - if (auto *CI = dyn_cast<CallInst>(BCI.getOperand(0))) - if (CI->isMustTailCall()) - return; - } - // TODO: handle musttail call returns for IA_Args. 
+ // Special case: if this is the bitcast (there is exactly 1 allowed) between + // a musttail call and a ret, don't instrument. New instructions are not + // allowed after a musttail call. + if (auto *CI = dyn_cast<CallInst>(BCI.getOperand(0))) + if (CI->isMustTailCall()) + return; visitInstOperands(BCI); } @@ -2629,11 +2516,6 @@ void DFSanVisitor::visitShuffleVectorInst(ShuffleVectorInst &I) { } void DFSanVisitor::visitExtractValueInst(ExtractValueInst &I) { - if (!DFSF.DFS.shouldTrackFieldsAndIndices()) { - visitInstOperands(I); - return; - } - IRBuilder<> IRB(&I); Value *Agg = I.getAggregateOperand(); Value *AggShadow = DFSF.getShadow(Agg); @@ -2643,11 +2525,6 @@ void DFSanVisitor::visitExtractValueInst(ExtractValueInst &I) { } void DFSanVisitor::visitInsertValueInst(InsertValueInst &I) { - if (!DFSF.DFS.shouldTrackFieldsAndIndices()) { - visitInstOperands(I); - return; - } - IRBuilder<> IRB(&I); Value *AggShadow = DFSF.getShadow(I.getAggregateOperand()); Value *InsShadow = DFSF.getShadow(I.getInsertedValueOperand()); @@ -2798,41 +2675,22 @@ static bool isAMustTailRetVal(Value *RetVal) { void DFSanVisitor::visitReturnInst(ReturnInst &RI) { if (!DFSF.IsNativeABI && RI.getReturnValue()) { - switch (DFSF.IA) { - case DataFlowSanitizer::IA_TLS: { - // Don't emit the instrumentation for musttail call returns. - if (isAMustTailRetVal(RI.getReturnValue())) - return; - - Value *S = DFSF.getShadow(RI.getReturnValue()); - IRBuilder<> IRB(&RI); - Type *RT = DFSF.F->getFunctionType()->getReturnType(); - unsigned Size = - getDataLayout().getTypeAllocSize(DFSF.DFS.getShadowTy(RT)); - if (Size <= RetvalTLSSize) { - // If the size overflows, stores nothing. At callsite, oversized return - // shadows are set to zero. - IRB.CreateAlignedStore(S, DFSF.getRetvalTLS(RT, IRB), - ShadowTLSAlignment); - } - if (DFSF.DFS.shouldTrackOrigins()) { - Value *O = DFSF.getOrigin(RI.getReturnValue()); - IRB.CreateStore(O, DFSF.getRetvalOriginTLS()); - } - break; - } - case DataFlowSanitizer::IA_Args: { - // TODO: handle musttail call returns for IA_Args. - - IRBuilder<> IRB(&RI); - Type *RT = DFSF.F->getFunctionType()->getReturnType(); - Value *InsVal = - IRB.CreateInsertValue(UndefValue::get(RT), RI.getReturnValue(), 0); - Value *InsShadow = - IRB.CreateInsertValue(InsVal, DFSF.getShadow(RI.getReturnValue()), 1); - RI.setOperand(0, InsShadow); - break; + // Don't emit the instrumentation for musttail call returns. + if (isAMustTailRetVal(RI.getReturnValue())) + return; + + Value *S = DFSF.getShadow(RI.getReturnValue()); + IRBuilder<> IRB(&RI); + Type *RT = DFSF.F->getFunctionType()->getReturnType(); + unsigned Size = getDataLayout().getTypeAllocSize(DFSF.DFS.getShadowTy(RT)); + if (Size <= RetvalTLSSize) { + // If the size overflows, stores nothing. At callsite, oversized return + // shadows are set to zero. + IRB.CreateAlignedStore(S, DFSF.getRetvalTLS(RT, IRB), ShadowTLSAlignment); } + if (DFSF.DFS.shouldTrackOrigins()) { + Value *O = DFSF.getOrigin(RI.getReturnValue()); + IRB.CreateStore(O, DFSF.getRetvalOriginTLS()); } } } @@ -2953,8 +2811,7 @@ bool DFSanVisitor::visitWrappedCallBase(Function &F, CallBase &CB) { // Custom functions returning non-void will write to the return label. 
if (!FT->getReturnType()->isVoidTy()) { - CustomFn->removeAttributes(AttributeList::FunctionIndex, - DFSF.DFS.ReadOnlyNoneAttrs); + CustomFn->removeFnAttrs(DFSF.DFS.ReadOnlyNoneAttrs); } } @@ -3056,32 +2913,30 @@ void DFSanVisitor::visitCallBase(CallBase &CB) { const bool ShouldTrackOrigins = DFSF.DFS.shouldTrackOrigins(); FunctionType *FT = CB.getFunctionType(); - if (DFSF.DFS.getInstrumentedABI() == DataFlowSanitizer::IA_TLS) { - // Stores argument shadows. - unsigned ArgOffset = 0; - const DataLayout &DL = getDataLayout(); - for (unsigned I = 0, N = FT->getNumParams(); I != N; ++I) { - if (ShouldTrackOrigins) { - // Ignore overflowed origins - Value *ArgShadow = DFSF.getShadow(CB.getArgOperand(I)); - if (I < DFSF.DFS.NumOfElementsInArgOrgTLS && - !DFSF.DFS.isZeroShadow(ArgShadow)) - IRB.CreateStore(DFSF.getOrigin(CB.getArgOperand(I)), - DFSF.getArgOriginTLS(I, IRB)); - } + const DataLayout &DL = getDataLayout(); - unsigned Size = - DL.getTypeAllocSize(DFSF.DFS.getShadowTy(FT->getParamType(I))); - // Stop storing if arguments' size overflows. Inside a function, arguments - // after overflow have zero shadow values. - if (ArgOffset + Size > ArgTLSSize) - break; - IRB.CreateAlignedStore( - DFSF.getShadow(CB.getArgOperand(I)), - DFSF.getArgTLS(FT->getParamType(I), ArgOffset, IRB), - ShadowTLSAlignment); - ArgOffset += alignTo(Size, ShadowTLSAlignment); + // Stores argument shadows. + unsigned ArgOffset = 0; + for (unsigned I = 0, N = FT->getNumParams(); I != N; ++I) { + if (ShouldTrackOrigins) { + // Ignore overflowed origins + Value *ArgShadow = DFSF.getShadow(CB.getArgOperand(I)); + if (I < DFSF.DFS.NumOfElementsInArgOrgTLS && + !DFSF.DFS.isZeroShadow(ArgShadow)) + IRB.CreateStore(DFSF.getOrigin(CB.getArgOperand(I)), + DFSF.getArgOriginTLS(I, IRB)); } + + unsigned Size = + DL.getTypeAllocSize(DFSF.DFS.getShadowTy(FT->getParamType(I))); + // Stop storing if arguments' size overflows. Inside a function, arguments + // after overflow have zero shadow values. + if (ArgOffset + Size > ArgTLSSize) + break; + IRB.CreateAlignedStore(DFSF.getShadow(CB.getArgOperand(I)), + DFSF.getArgTLS(FT->getParamType(I), ArgOffset, IRB), + ShadowTLSAlignment); + ArgOffset += alignTo(Size, ShadowTLSAlignment); } Instruction *Next = nullptr; @@ -3099,99 +2954,31 @@ void DFSanVisitor::visitCallBase(CallBase &CB) { Next = CB.getNextNode(); } - if (DFSF.DFS.getInstrumentedABI() == DataFlowSanitizer::IA_TLS) { - // Don't emit the epilogue for musttail call returns. - if (isa<CallInst>(CB) && cast<CallInst>(CB).isMustTailCall()) - return; - - // Loads the return value shadow. - IRBuilder<> NextIRB(Next); - const DataLayout &DL = getDataLayout(); - unsigned Size = DL.getTypeAllocSize(DFSF.DFS.getShadowTy(&CB)); - if (Size > RetvalTLSSize) { - // Set overflowed return shadow to be zero. - DFSF.setShadow(&CB, DFSF.DFS.getZeroShadow(&CB)); - } else { - LoadInst *LI = NextIRB.CreateAlignedLoad( - DFSF.DFS.getShadowTy(&CB), DFSF.getRetvalTLS(CB.getType(), NextIRB), - ShadowTLSAlignment, "_dfsret"); - DFSF.SkipInsts.insert(LI); - DFSF.setShadow(&CB, LI); - DFSF.NonZeroChecks.push_back(LI); - } - - if (ShouldTrackOrigins) { - LoadInst *LI = NextIRB.CreateLoad( - DFSF.DFS.OriginTy, DFSF.getRetvalOriginTLS(), "_dfsret_o"); - DFSF.SkipInsts.insert(LI); - DFSF.setOrigin(&CB, LI); - } - } - } - - // Do all instrumentation for IA_Args down here to defer tampering with the - // CFG in a way that SplitEdge may be able to detect. 
- if (DFSF.DFS.getInstrumentedABI() == DataFlowSanitizer::IA_Args) { - // TODO: handle musttail call returns for IA_Args. - - FunctionType *NewFT = DFSF.DFS.getArgsFunctionType(FT); - Value *Func = - IRB.CreateBitCast(CB.getCalledOperand(), PointerType::getUnqual(NewFT)); - - const unsigned NumParams = FT->getNumParams(); - - // Copy original arguments. - auto *ArgIt = CB.arg_begin(), *ArgEnd = CB.arg_end(); - std::vector<Value *> Args(NumParams); - std::copy_n(ArgIt, NumParams, Args.begin()); - - // Add shadow arguments by transforming original arguments. - std::generate_n(std::back_inserter(Args), NumParams, - [&]() { return DFSF.getShadow(*ArgIt++); }); - - if (FT->isVarArg()) { - unsigned VarArgSize = CB.arg_size() - NumParams; - ArrayType *VarArgArrayTy = - ArrayType::get(DFSF.DFS.PrimitiveShadowTy, VarArgSize); - AllocaInst *VarArgShadow = - new AllocaInst(VarArgArrayTy, getDataLayout().getAllocaAddrSpace(), - "", &DFSF.F->getEntryBlock().front()); - Args.push_back(IRB.CreateConstGEP2_32(VarArgArrayTy, VarArgShadow, 0, 0)); - - // Copy remaining var args. - unsigned GepIndex = 0; - std::for_each(ArgIt, ArgEnd, [&](Value *Arg) { - IRB.CreateStore( - DFSF.getShadow(Arg), - IRB.CreateConstGEP2_32(VarArgArrayTy, VarArgShadow, 0, GepIndex++)); - Args.push_back(Arg); - }); - } + // Don't emit the epilogue for musttail call returns. + if (isa<CallInst>(CB) && cast<CallInst>(CB).isMustTailCall()) + return; - CallBase *NewCB; - if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) { - NewCB = IRB.CreateInvoke(NewFT, Func, II->getNormalDest(), - II->getUnwindDest(), Args); + // Loads the return value shadow. + IRBuilder<> NextIRB(Next); + unsigned Size = DL.getTypeAllocSize(DFSF.DFS.getShadowTy(&CB)); + if (Size > RetvalTLSSize) { + // Set overflowed return shadow to be zero. 
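With the IA_Args path gone, the call-site code above is the only argument-passing scheme left: each argument's shadow is stored into __dfsan_arg_tls at an aligned offset, and the stores simply stop once the offsets would run past the end of the TLS buffer, so the callee sees zero shadows for any later arguments. A small sketch of that layout bookkeeping, with the buffer size and alignment treated as assumed example values:

// Mirrors the ArgOffset bookkeeping in visitCallBase() above.  ArgTLSSize and
// ShadowTLSAlignment are assumed values for illustration only.
#include <cstdio>
#include <vector>

static unsigned alignTo(unsigned Value, unsigned Align) {
  return (Value + Align - 1) / Align * Align;
}

int main() {
  const unsigned ArgTLSSize = 800;       // assumed size of __dfsan_arg_tls
  const unsigned ShadowTLSAlignment = 8; // assumed alignment
  const std::vector<unsigned> ShadowSizes = {2, 16, 2}; // hypothetical arg shadows
  unsigned ArgOffset = 0;
  for (unsigned I = 0; I < ShadowSizes.size(); ++I) {
    if (ArgOffset + ShadowSizes[I] > ArgTLSSize)
      break; // later arguments get zero shadows inside the callee
    std::printf("arg %u shadow at __dfsan_arg_tls + %u\n", I, ArgOffset);
    ArgOffset += alignTo(ShadowSizes[I], ShadowTLSAlignment);
  }
  return 0;
}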
+ DFSF.setShadow(&CB, DFSF.DFS.getZeroShadow(&CB)); } else { - NewCB = IRB.CreateCall(NewFT, Func, Args); - } - NewCB->setCallingConv(CB.getCallingConv()); - NewCB->setAttributes(CB.getAttributes().removeAttributes( - *DFSF.DFS.Ctx, AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCB->getType()))); - - if (Next) { - ExtractValueInst *ExVal = ExtractValueInst::Create(NewCB, 0, "", Next); - DFSF.SkipInsts.insert(ExVal); - ExtractValueInst *ExShadow = ExtractValueInst::Create(NewCB, 1, "", Next); - DFSF.SkipInsts.insert(ExShadow); - DFSF.setShadow(ExVal, ExShadow); - DFSF.NonZeroChecks.push_back(ExShadow); - - CB.replaceAllUsesWith(ExVal); + LoadInst *LI = NextIRB.CreateAlignedLoad( + DFSF.DFS.getShadowTy(&CB), DFSF.getRetvalTLS(CB.getType(), NextIRB), + ShadowTLSAlignment, "_dfsret"); + DFSF.SkipInsts.insert(LI); + DFSF.setShadow(&CB, LI); + DFSF.NonZeroChecks.push_back(LI); } - CB.eraseFromParent(); + if (ShouldTrackOrigins) { + LoadInst *LI = NextIRB.CreateLoad(DFSF.DFS.OriginTy, + DFSF.getRetvalOriginTLS(), "_dfsret_o"); + DFSF.SkipInsts.insert(LI); + DFSF.setOrigin(&CB, LI); + } } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index c99f2e66b1cc..325089fc4402 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -86,7 +86,7 @@ GCOVOptions GCOVOptions::getDefault() { Options.Atomic = AtomicCounter; if (DefaultGCOVVersion.size() != 4) { - llvm::report_fatal_error(std::string("Invalid -default-gcov-version: ") + + llvm::report_fatal_error(Twine("Invalid -default-gcov-version: ") + DefaultGCOVVersion); } memcpy(Options.Version, DefaultGCOVVersion.c_str(), 4); @@ -1373,12 +1373,16 @@ Function *GCOVProfiler::insertReset( BasicBlock *Entry = BasicBlock::Create(*Ctx, "entry", ResetF); IRBuilder<> Builder(Entry); + LLVMContext &C = Entry->getContext(); // Zero out the counters. 
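In the reset function below, each counter array is now cleared with a single memset whose length is computed from the array type; for a hypothetical [5 x i64] counter array the arithmetic works out as follows:

// Quick check of the byte count passed to CreateMemSet below, for an assumed
// [5 x i64] counter array.
#include <cstdio>

int main() {
  const unsigned NumElements = 5;        // assumed number of counters
  const unsigned ScalarSizeInBits = 64;  // i64 counters
  std::printf("memset length = %u bytes\n",
              NumElements * ScalarSizeInBits / 8); // prints 40
  return 0;
}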
for (const auto &I : CountersBySP) { GlobalVariable *GV = I.first; - Constant *Null = Constant::getNullValue(GV->getValueType()); - Builder.CreateStore(Null, GV); + auto *GVTy = cast<ArrayType>(GV->getValueType()); + Builder.CreateMemSet(GV, Constant::getNullValue(Type::getInt8Ty(C)), + GVTy->getNumElements() * + GVTy->getElementType()->getScalarSizeInBits() / 8, + GV->getAlign()); } Type *RetTy = ResetF->getReturnType(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 60a4ee8811fb..62c265e40dab 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -17,7 +17,10 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/StackSafetyAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/BinaryFormat/ELF.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -26,6 +29,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" @@ -41,6 +45,7 @@ #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/PassRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -115,6 +120,17 @@ static cl::opt<bool> cl::Hidden, cl::desc("Use Stack Safety analysis results"), cl::Optional); +static cl::opt<size_t> ClMaxLifetimes( + "hwasan-max-lifetimes-for-alloca", cl::Hidden, cl::init(3), + cl::ReallyHidden, + cl::desc("How many lifetime ends to handle for a single alloca."), + cl::Optional); + +static cl::opt<bool> + ClUseAfterScope("hwasan-use-after-scope", + cl::desc("detect use after scope within function"), + cl::Hidden, cl::init(false)); + static cl::opt<bool> ClUARRetagToZero( "hwasan-uar-retag-to-zero", cl::desc("Clear alloca tags before returning from the function to allow " @@ -220,9 +236,21 @@ bool shouldUseStackSafetyAnalysis(const Triple &TargetTriple, return shouldInstrumentStack(TargetTriple) && mightUseStackSafetyAnalysis(DisableOptimization); } + +bool shouldDetectUseAfterScope(const Triple &TargetTriple) { + return ClUseAfterScope && shouldInstrumentStack(TargetTriple); +} + /// An instrumentation pass implementing detection of addressability bugs /// using tagged pointers. 
class HWAddressSanitizer { +private: + struct AllocaInfo { + AllocaInst *AI; + SmallVector<IntrinsicInst *, 2> LifetimeStart; + SmallVector<IntrinsicInst *, 2> LifetimeEnd; + }; + public: HWAddressSanitizer(Module &M, bool CompileKernel, bool Recover, const StackSafetyGlobalInfo *SSI) @@ -237,7 +265,11 @@ public: void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; } - bool sanitizeFunction(Function &F); + DenseMap<AllocaInst *, AllocaInst *> padInterestingAllocas( + const MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument); + bool sanitizeFunction(Function &F, + llvm::function_ref<const DominatorTree &()> GetDT, + llvm::function_ref<const PostDominatorTree &()> GetPDT); void initializeModule(); void createHwasanCtorComdat(); @@ -250,23 +282,34 @@ public: void untagPointerOperand(Instruction *I, Value *Addr); Value *memToShadow(Value *Shadow, IRBuilder<> &IRB); + + int64_t getAccessInfo(bool IsWrite, unsigned AccessSizeIndex); + void instrumentMemAccessOutline(Value *Ptr, bool IsWrite, + unsigned AccessSizeIndex, + Instruction *InsertBefore); void instrumentMemAccessInline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, Instruction *InsertBefore); + bool ignoreMemIntrinsic(MemIntrinsic *MI); void instrumentMemIntrinsic(MemIntrinsic *MI); bool instrumentMemAccess(InterestingMemoryOperand &O); - bool ignoreAccess(Value *Ptr); + bool ignoreAccess(Instruction *Inst, Value *Ptr); void getInterestingMemoryOperands( Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting); bool isInterestingAlloca(const AllocaInst &AI); - bool tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); + void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); + static bool isStandardLifetime(const AllocaInfo &AllocaInfo, + const DominatorTree &DT); bool instrumentStack( - SmallVectorImpl<AllocaInst *> &Allocas, + MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument, + SmallVector<Instruction *, 4> &UnrecognizedLifetimes, DenseMap<AllocaInst *, std::vector<DbgVariableIntrinsic *>> &AllocaDbgMap, - SmallVectorImpl<Instruction *> &RetVec, Value *StackTag); + SmallVectorImpl<Instruction *> &RetVec, Value *StackTag, + llvm::function_ref<const DominatorTree &()> GetDT, + llvm::function_ref<const PostDominatorTree &()> GetPDT); Value *readRegister(IRBuilder<> &IRB, StringRef Name); bool instrumentLandingPads(SmallVectorImpl<Instruction *> &RetVec); Value *getNextTagWithCall(IRBuilder<> &IRB); @@ -313,8 +356,9 @@ private: bool WithFrameRecord; void init(Triple &TargetTriple, bool InstrumentWithCalls); - unsigned getObjectAlignment() const { return 1U << Scale; } + uint64_t getObjectAlignment() const { return 1ULL << Scale; } }; + ShadowMapping Mapping; Type *VoidTy = Type::getVoidTy(M.getContext()); @@ -331,6 +375,7 @@ private: bool InstrumentLandingPads; bool InstrumentWithCalls; bool InstrumentStack; + bool DetectUseAfterScope; bool UsePageAliases; bool HasMatchAllTag = false; @@ -377,14 +422,21 @@ public: } bool runOnFunction(Function &F) override { - if (shouldUseStackSafetyAnalysis(Triple(F.getParent()->getTargetTriple()), - DisableOptimization)) { + auto TargetTriple = Triple(F.getParent()->getTargetTriple()); + if (shouldUseStackSafetyAnalysis(TargetTriple, DisableOptimization)) { // We cannot call getAnalysis in doInitialization, that would cause a // crash as the required analyses are not initialized yet. 
HWASan->setSSI( &getAnalysis<StackSafetyGlobalInfoWrapperPass>().getResult()); } - return HWASan->sanitizeFunction(F); + return HWASan->sanitizeFunction( + F, + [&]() -> const DominatorTree & { + return getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + }, + [&]() -> const PostDominatorTree & { + return getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + }); } bool doFinalization(Module &M) override { @@ -399,6 +451,8 @@ public: // This is so we don't need to plumb TargetTriple all the way to here. if (mightUseStackSafetyAnalysis(DisableOptimization)) AU.addRequired<StackSafetyGlobalInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); } private: @@ -417,6 +471,8 @@ INITIALIZE_PASS_BEGIN( "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, false) INITIALIZE_PASS_DEPENDENCY(StackSafetyGlobalInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_END( HWAddressSanitizerLegacyPass, "hwasan", "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, @@ -430,25 +486,41 @@ llvm::createHWAddressSanitizerLegacyPassPass(bool CompileKernel, bool Recover, DisableOptimization); } -HWAddressSanitizerPass::HWAddressSanitizerPass(bool CompileKernel, bool Recover, - bool DisableOptimization) - : CompileKernel(CompileKernel), Recover(Recover), - DisableOptimization(DisableOptimization) {} - PreservedAnalyses HWAddressSanitizerPass::run(Module &M, ModuleAnalysisManager &MAM) { const StackSafetyGlobalInfo *SSI = nullptr; - if (shouldUseStackSafetyAnalysis(llvm::Triple(M.getTargetTriple()), - DisableOptimization)) + auto TargetTriple = llvm::Triple(M.getTargetTriple()); + if (shouldUseStackSafetyAnalysis(TargetTriple, Options.DisableOptimization)) SSI = &MAM.getResult<StackSafetyGlobalAnalysis>(M); - HWAddressSanitizer HWASan(M, CompileKernel, Recover, SSI); + + HWAddressSanitizer HWASan(M, Options.CompileKernel, Options.Recover, SSI); bool Modified = false; - for (Function &F : M) - Modified |= HWASan.sanitizeFunction(F); + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + for (Function &F : M) { + Modified |= HWASan.sanitizeFunction( + F, + [&]() -> const DominatorTree & { + return FAM.getResult<DominatorTreeAnalysis>(F); + }, + [&]() -> const PostDominatorTree & { + return FAM.getResult<PostDominatorTreeAnalysis>(F); + }); + } if (Modified) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } +void HWAddressSanitizerPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<HWAddressSanitizerPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (Options.CompileKernel) + OS << "kernel;"; + if (Options.Recover) + OS << "recover"; + OS << ">"; +} void HWAddressSanitizer::createHwasanCtorComdat() { std::tie(HwasanCtorFunction, std::ignore) = @@ -566,6 +638,7 @@ void HWAddressSanitizer::initializeModule() { UsePageAliases = shouldUsePageAliases(TargetTriple); InstrumentWithCalls = shouldInstrumentWithCalls(TargetTriple); InstrumentStack = shouldInstrumentStack(TargetTriple); + DetectUseAfterScope = shouldDetectUseAfterScope(TargetTriple); PointerTagShift = IsX86_64 ? 57 : 56; TagMaskByte = IsX86_64 ? 
0x3F : 0xFF; @@ -712,7 +785,7 @@ Value *HWAddressSanitizer::getShadowNonTls(IRBuilder<> &IRB) { } } -bool HWAddressSanitizer::ignoreAccess(Value *Ptr) { +bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { // Do not instrument acesses from different address spaces; we cannot deal // with them. Type *PtrTy = cast<PointerType>(Ptr->getType()->getScalarType()); @@ -726,6 +799,12 @@ bool HWAddressSanitizer::ignoreAccess(Value *Ptr) { if (Ptr->isSwiftError()) return true; + if (findAllocaForValue(Ptr)) { + if (!InstrumentStack) + return true; + if (SSI && SSI->stackAccessIsSafe(*Inst)) + return true; + } return false; } @@ -740,29 +819,29 @@ void HWAddressSanitizer::getInterestingMemoryOperands( return; if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - if (!ClInstrumentReads || ignoreAccess(LI->getPointerOperand())) + if (!ClInstrumentReads || ignoreAccess(I, LI->getPointerOperand())) return; Interesting.emplace_back(I, LI->getPointerOperandIndex(), false, LI->getType(), LI->getAlign()); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { - if (!ClInstrumentWrites || ignoreAccess(SI->getPointerOperand())) + if (!ClInstrumentWrites || ignoreAccess(I, SI->getPointerOperand())) return; Interesting.emplace_back(I, SI->getPointerOperandIndex(), true, SI->getValueOperand()->getType(), SI->getAlign()); } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { - if (!ClInstrumentAtomics || ignoreAccess(RMW->getPointerOperand())) + if (!ClInstrumentAtomics || ignoreAccess(I, RMW->getPointerOperand())) return; Interesting.emplace_back(I, RMW->getPointerOperandIndex(), true, RMW->getValOperand()->getType(), None); } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { - if (!ClInstrumentAtomics || ignoreAccess(XCHG->getPointerOperand())) + if (!ClInstrumentAtomics || ignoreAccess(I, XCHG->getPointerOperand())) return; Interesting.emplace_back(I, XCHG->getPointerOperandIndex(), true, XCHG->getCompareOperand()->getType(), None); } else if (auto CI = dyn_cast<CallInst>(I)) { - for (unsigned ArgNo = 0; ArgNo < CI->getNumArgOperands(); ArgNo++) { + for (unsigned ArgNo = 0; ArgNo < CI->arg_size(); ArgNo++) { if (!ClInstrumentByval || !CI->isByValArgument(ArgNo) || - ignoreAccess(CI->getArgOperand(ArgNo))) + ignoreAccess(I, CI->getArgOperand(ArgNo))) continue; Type *Ty = CI->getParamByValType(ArgNo); Interesting.emplace_back(I, ArgNo, false, Ty, Align(1)); @@ -809,30 +888,38 @@ Value *HWAddressSanitizer::memToShadow(Value *Mem, IRBuilder<> &IRB) { return IRB.CreateGEP(Int8Ty, ShadowBase, Shadow); } +int64_t HWAddressSanitizer::getAccessInfo(bool IsWrite, + unsigned AccessSizeIndex) { + return (CompileKernel << HWASanAccessInfo::CompileKernelShift) + + (HasMatchAllTag << HWASanAccessInfo::HasMatchAllShift) + + (MatchAllTag << HWASanAccessInfo::MatchAllShift) + + (Recover << HWASanAccessInfo::RecoverShift) + + (IsWrite << HWASanAccessInfo::IsWriteShift) + + (AccessSizeIndex << HWASanAccessInfo::AccessSizeShift); +} + +void HWAddressSanitizer::instrumentMemAccessOutline(Value *Ptr, bool IsWrite, + unsigned AccessSizeIndex, + Instruction *InsertBefore) { + assert(!UsePageAliases); + const int64_t AccessInfo = getAccessInfo(IsWrite, AccessSizeIndex); + IRBuilder<> IRB(InsertBefore); + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy); + IRB.CreateCall(Intrinsic::getDeclaration( + M, UseShortGranules + ? 
Intrinsic::hwasan_check_memaccess_shortgranules + : Intrinsic::hwasan_check_memaccess), + {ShadowBase, Ptr, ConstantInt::get(Int32Ty, AccessInfo)}); +} + void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, Instruction *InsertBefore) { assert(!UsePageAliases); - const int64_t AccessInfo = - (CompileKernel << HWASanAccessInfo::CompileKernelShift) + - (HasMatchAllTag << HWASanAccessInfo::HasMatchAllShift) + - (MatchAllTag << HWASanAccessInfo::MatchAllShift) + - (Recover << HWASanAccessInfo::RecoverShift) + - (IsWrite << HWASanAccessInfo::IsWriteShift) + - (AccessSizeIndex << HWASanAccessInfo::AccessSizeShift); + const int64_t AccessInfo = getAccessInfo(IsWrite, AccessSizeIndex); IRBuilder<> IRB(InsertBefore); - if (OutlinedChecks) { - Module *M = IRB.GetInsertBlock()->getParent()->getParent(); - Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy); - IRB.CreateCall(Intrinsic::getDeclaration( - M, UseShortGranules - ? Intrinsic::hwasan_check_memaccess_shortgranules - : Intrinsic::hwasan_check_memaccess), - {ShadowBase, Ptr, ConstantInt::get(Int32Ty, AccessInfo)}); - return; - } - Value *PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy); Value *PtrTag = IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift), IRB.getInt8Ty()); @@ -908,6 +995,16 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, cast<BranchInst>(CheckFailTerm)->setSuccessor(0, CheckTerm->getParent()); } +bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) { + if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { + return (!ClInstrumentWrites || ignoreAccess(MTI, MTI->getDest())) && + (!ClInstrumentReads || ignoreAccess(MTI, MTI->getSource())); + } + if (isa<MemSetInst>(MI)) + return !ClInstrumentWrites || ignoreAccess(MI, MI->getDest()); + return false; +} + void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { IRBuilder<> IRB(MI); if (isa<MemTransferInst>(MI)) { @@ -943,6 +1040,8 @@ bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) { if (InstrumentWithCalls) { IRB.CreateCall(HwasanMemoryAccessCallback[O.IsWrite][AccessSizeIndex], IRB.CreatePointerCast(Addr, IntptrTy)); + } else if (OutlinedChecks) { + instrumentMemAccessOutline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn()); } else { instrumentMemAccessInline(Addr, O.IsWrite, AccessSizeIndex, O.getInsn()); } @@ -968,7 +1067,7 @@ static uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { return SizeInBytes * ArraySize; } -bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, +void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size) { size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); if (!UseShortGranules) @@ -999,7 +1098,6 @@ bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, AlignedSize - 1)); } } - return true; } unsigned HWAddressSanitizer::retagMask(unsigned AllocaNo) { @@ -1231,17 +1329,53 @@ bool HWAddressSanitizer::instrumentLandingPads( return true; } +static bool +maybeReachableFromEachOther(const SmallVectorImpl<IntrinsicInst *> &Insts, + const DominatorTree &DT) { + // If we have too many lifetime ends, give up, as the algorithm below is N^2. 
+ if (Insts.size() > ClMaxLifetimes) + return true; + for (size_t I = 0; I < Insts.size(); ++I) { + for (size_t J = 0; J < Insts.size(); ++J) { + if (I == J) + continue; + if (isPotentiallyReachable(Insts[I], Insts[J], nullptr, &DT)) + return true; + } + } + return false; +} + +// static +bool HWAddressSanitizer::isStandardLifetime(const AllocaInfo &AllocaInfo, + const DominatorTree &DT) { + // An alloca that has exactly one start and end in every possible execution. + // If it has multiple ends, they have to be unreachable from each other, so + // at most one of them is actually used for each execution of the function. + return AllocaInfo.LifetimeStart.size() == 1 && + (AllocaInfo.LifetimeEnd.size() == 1 || + (AllocaInfo.LifetimeEnd.size() > 0 && + !maybeReachableFromEachOther(AllocaInfo.LifetimeEnd, DT))); +} + bool HWAddressSanitizer::instrumentStack( - SmallVectorImpl<AllocaInst *> &Allocas, + MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument, + SmallVector<Instruction *, 4> &UnrecognizedLifetimes, DenseMap<AllocaInst *, std::vector<DbgVariableIntrinsic *>> &AllocaDbgMap, - SmallVectorImpl<Instruction *> &RetVec, Value *StackTag) { + SmallVectorImpl<Instruction *> &RetVec, Value *StackTag, + llvm::function_ref<const DominatorTree &()> GetDT, + llvm::function_ref<const PostDominatorTree &()> GetPDT) { // Ideally, we want to calculate tagged stack base pointer, and rewrite all // alloca addresses using that. Unfortunately, offsets are not known yet // (unless we use ASan-style mega-alloca). Instead we keep the base tag in a // temp, shift-OR it into each alloca address and xor with the retag mask. // This generates one extra instruction per alloca use. - for (unsigned N = 0; N < Allocas.size(); ++N) { - auto *AI = Allocas[N]; + unsigned int I = 0; + + for (auto &KV : AllocasToInstrument) { + auto N = I++; + auto *AI = KV.first; + AllocaInfo &Info = KV.second; IRBuilder<> IRB(AI->getNextNode()); // Replace uses of the alloca with tagged address. @@ -1268,17 +1402,40 @@ bool HWAddressSanitizer::instrumentStack( } size_t Size = getAllocaSizeInBytes(*AI); - tagAlloca(IRB, AI, Tag, Size); - - for (auto RI : RetVec) { - IRB.SetInsertPoint(RI); - - // Re-tag alloca memory with the special UAR tag. 
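The helpers above drive the new use-after-scope handling: an alloca qualifies for scoped tagging only when it has exactly one lifetime.start and lifetime.end markers that can never reach one another, and the pairwise reachability scan gives up past ClMaxLifetimes ends to avoid quadratic work. A minimal standalone sketch of that decision shape (plain C++; the reachability predicate and the cap are caller-supplied stand-ins for isPotentiallyReachable and ClMaxLifetimes, which are not reproduced here):

#include <cstddef>
#include <cstdio>
#include <vector>

// Conservatively answer "maybe reachable" (true) once the set is too large
// for an O(N^2) scan; otherwise test every ordered pair with the caller's
// reachability predicate.
template <typename T, typename Reachable>
bool maybeReachableFromEachOther(const std::vector<T> &Ends,
                                 Reachable IsReachable, std::size_t MaxPairs) {
  if (Ends.size() > MaxPairs)
    return true; // give up rather than pay the quadratic cost
  for (std::size_t I = 0; I < Ends.size(); ++I)
    for (std::size_t J = 0; J < Ends.size(); ++J)
      if (I != J && IsReachable(Ends[I], Ends[J]))
        return true;
  return false;
}

// "Standard" lifetime: exactly one start, and either a single end or several
// ends of which at most one can execute in any given run of the function.
template <typename T, typename Reachable>
bool isStandardLifetime(const std::vector<T> &Starts,
                        const std::vector<T> &Ends, Reachable IsReachable,
                        std::size_t MaxPairs) {
  return Starts.size() == 1 &&
         (Ends.size() == 1 ||
          (!Ends.empty() &&
           !maybeReachableFromEachOther(Ends, IsReachable, MaxPairs)));
}

int main() {
  // Toy "CFG points" 0..3; "reachable" here simply means a later point.
  auto Reach = [](int A, int B) { return A < B; };
  std::vector<int> Starts = {0};
  std::vector<int> OneEnd = {3};      // single end: standard
  std::vector<int> TwoEnds = {2, 3};  // 2 can reach 3: not standard
  std::printf("%d %d\n", isStandardLifetime(Starts, OneEnd, Reach, 8),
              isStandardLifetime(Starts, TwoEnds, Reach, 8)); // prints 1 0
}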
- Value *Tag = getUARTag(IRB, StackTag); - tagAlloca(IRB, AI, Tag, alignTo(Size, Mapping.getObjectAlignment())); + size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); + bool StandardLifetime = + UnrecognizedLifetimes.empty() && isStandardLifetime(Info, GetDT()); + if (DetectUseAfterScope && StandardLifetime) { + IntrinsicInst *Start = Info.LifetimeStart[0]; + IRB.SetInsertPoint(Start->getNextNode()); + auto TagEnd = [&](Instruction *Node) { + IRB.SetInsertPoint(Node); + Value *UARTag = getUARTag(IRB, StackTag); + tagAlloca(IRB, AI, UARTag, AlignedSize); + }; + tagAlloca(IRB, AI, Tag, Size); + if (!forAllReachableExits(GetDT(), GetPDT(), Start, Info.LifetimeEnd, + RetVec, TagEnd)) { + for (auto *End : Info.LifetimeEnd) + End->eraseFromParent(); + } + } else { + tagAlloca(IRB, AI, Tag, Size); + for (auto *RI : RetVec) { + IRB.SetInsertPoint(RI); + Value *UARTag = getUARTag(IRB, StackTag); + tagAlloca(IRB, AI, UARTag, AlignedSize); + } + if (!StandardLifetime) { + for (auto &II : Info.LifetimeStart) + II->eraseFromParent(); + for (auto &II : Info.LifetimeEnd) + II->eraseFromParent(); + } } } - + for (auto &I : UnrecognizedLifetimes) + I->eraseFromParent(); return true; } @@ -1300,7 +1457,42 @@ bool HWAddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { !(SSI && SSI->isSafe(AI)); } -bool HWAddressSanitizer::sanitizeFunction(Function &F) { +DenseMap<AllocaInst *, AllocaInst *> HWAddressSanitizer::padInterestingAllocas( + const MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument) { + DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap; + for (auto &KV : AllocasToInstrument) { + AllocaInst *AI = KV.first; + uint64_t Size = getAllocaSizeInBytes(*AI); + uint64_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); + AI->setAlignment( + Align(std::max(AI->getAlignment(), Mapping.getObjectAlignment()))); + if (Size != AlignedSize) { + Type *AllocatedType = AI->getAllocatedType(); + if (AI->isArrayAllocation()) { + uint64_t ArraySize = + cast<ConstantInt>(AI->getArraySize())->getZExtValue(); + AllocatedType = ArrayType::get(AllocatedType, ArraySize); + } + Type *TypeWithPadding = StructType::get( + AllocatedType, ArrayType::get(Int8Ty, AlignedSize - Size)); + auto *NewAI = new AllocaInst( + TypeWithPadding, AI->getType()->getAddressSpace(), nullptr, "", AI); + NewAI->takeName(AI); + NewAI->setAlignment(AI->getAlign()); + NewAI->setUsedWithInAlloca(AI->isUsedWithInAlloca()); + NewAI->setSwiftError(AI->isSwiftError()); + NewAI->copyMetadata(*AI); + auto *Bitcast = new BitCastInst(NewAI, AI->getType(), "", AI); + AI->replaceAllUsesWith(Bitcast); + AllocaToPaddedAllocaMap[AI] = NewAI; + } + } + return AllocaToPaddedAllocaMap; +} + +bool HWAddressSanitizer::sanitizeFunction( + Function &F, llvm::function_ref<const DominatorTree &()> GetDT, + llvm::function_ref<const PostDominatorTree &()> GetPDT) { if (&F == HwasanCtorFunction) return false; @@ -1311,18 +1503,36 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { SmallVector<InterestingMemoryOperand, 16> OperandsToInstrument; SmallVector<MemIntrinsic *, 16> IntrinToInstrument; - SmallVector<AllocaInst *, 8> AllocasToInstrument; + MapVector<AllocaInst *, AllocaInfo> AllocasToInstrument; SmallVector<Instruction *, 8> RetVec; SmallVector<Instruction *, 8> LandingPadVec; + SmallVector<Instruction *, 4> UnrecognizedLifetimes; DenseMap<AllocaInst *, std::vector<DbgVariableIntrinsic *>> AllocaDbgMap; for (auto &BB : F) { for (auto &Inst : BB) { - if (InstrumentStack) + if (InstrumentStack) { if (AllocaInst *AI = 
dyn_cast<AllocaInst>(&Inst)) { if (isInterestingAlloca(*AI)) - AllocasToInstrument.push_back(AI); + AllocasToInstrument.insert({AI, {}}); + continue; + } + auto *II = dyn_cast<IntrinsicInst>(&Inst); + if (II && (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end)) { + AllocaInst *AI = findAllocaForValue(II->getArgOperand(1)); + if (!AI) { + UnrecognizedLifetimes.push_back(&Inst); + continue; + } + if (!isInterestingAlloca(*AI)) + continue; + if (II->getIntrinsicID() == Intrinsic::lifetime_start) + AllocasToInstrument[AI].LifetimeStart.push_back(II); + else + AllocasToInstrument[AI].LifetimeEnd.push_back(II); continue; } + } if (isa<ReturnInst>(Inst) || isa<ResumeInst>(Inst) || isa<CleanupReturnInst>(Inst)) @@ -1343,7 +1553,8 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { getInterestingMemoryOperands(&Inst, OperandsToInstrument); if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&Inst)) - IntrinToInstrument.push_back(MI); + if (!ignoreMemIntrinsic(MI)) + IntrinToInstrument.push_back(MI); } } @@ -1377,38 +1588,14 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { if (!AllocasToInstrument.empty()) { Value *StackTag = ClGenerateTagsWithCalls ? nullptr : getStackBaseTag(EntryIRB); - instrumentStack(AllocasToInstrument, AllocaDbgMap, RetVec, StackTag); + instrumentStack(AllocasToInstrument, UnrecognizedLifetimes, AllocaDbgMap, + RetVec, StackTag, GetDT, GetPDT); } // Pad and align each of the allocas that we instrumented to stop small // uninteresting allocas from hiding in instrumented alloca's padding and so // that we have enough space to store real tags for short granules. - DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap; - for (AllocaInst *AI : AllocasToInstrument) { - uint64_t Size = getAllocaSizeInBytes(*AI); - uint64_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); - AI->setAlignment( - Align(std::max(AI->getAlignment(), Mapping.getObjectAlignment()))); - if (Size != AlignedSize) { - Type *AllocatedType = AI->getAllocatedType(); - if (AI->isArrayAllocation()) { - uint64_t ArraySize = - cast<ConstantInt>(AI->getArraySize())->getZExtValue(); - AllocatedType = ArrayType::get(AllocatedType, ArraySize); - } - Type *TypeWithPadding = StructType::get( - AllocatedType, ArrayType::get(Int8Ty, AlignedSize - Size)); - auto *NewAI = new AllocaInst( - TypeWithPadding, AI->getType()->getAddressSpace(), nullptr, "", AI); - NewAI->takeName(AI); - NewAI->setAlignment(AI->getAlign()); - NewAI->setUsedWithInAlloca(AI->isUsedWithInAlloca()); - NewAI->setSwiftError(AI->isSwiftError()); - NewAI->copyMetadata(*AI); - auto *Bitcast = new BitCastInst(NewAI, AI->getType(), "", AI); - AI->replaceAllUsesWith(Bitcast); - AllocaToPaddedAllocaMap[AI] = NewAI; - } - } + DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap = + padInterestingAllocas(AllocasToInstrument); if (!AllocaToPaddedAllocaMap.empty()) { for (auto &BB : F) { @@ -1434,13 +1621,11 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { // dynamic allocas. 
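Whether done by the new padInterestingAllocas helper or by the inline block it replaces above, the padding step is the same arithmetic: round each instrumented alloca up to the shadow granule and, when that grows the object, append an i8 tail so short granules have room for the real tag and small untracked allocas cannot hide in the slack. A self-contained model of the size computation (the granule value 16 is illustrative only; the pass reads it from Mapping.getObjectAlignment()):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Round Size up to the next multiple of Align; equivalent to llvm::alignTo
// for power-of-two alignments.
static uint64_t alignTo(uint64_t Size, uint64_t Align) {
  assert(Align && (Align & (Align - 1)) == 0 && "power-of-two alignment");
  return (Size + Align - 1) & ~(Align - 1);
}

int main() {
  const uint64_t Granule = 16; // illustrative stand-in for the shadow granule
  const uint64_t Sizes[] = {1, 15, 16, 17, 40};
  for (uint64_t Size : Sizes) {
    uint64_t Aligned = alignTo(Size, Granule);
    // When Size != Aligned, the pass wraps the alloca in { T, [Aligned-Size x i8] }
    // so the tail of the granule exists and can carry the short-granule tag.
    std::printf("size %2llu -> padded %2llu (tail of %llu bytes)\n",
                (unsigned long long)Size, (unsigned long long)Aligned,
                (unsigned long long)(Aligned - Size));
  }
}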
if (EntryIRB.GetInsertBlock() != &F.getEntryBlock()) { InsertPt = &*F.getEntryBlock().begin(); - for (auto II = EntryIRB.GetInsertBlock()->begin(), - IE = EntryIRB.GetInsertBlock()->end(); - II != IE;) { - Instruction *I = &*II++; - if (auto *AI = dyn_cast<AllocaInst>(I)) + for (Instruction &I : + llvm::make_early_inc_range(*EntryIRB.GetInsertBlock())) { + if (auto *AI = dyn_cast<AllocaInst>(&I)) if (isa<ConstantInt>(AI->getArraySize())) - I->moveBefore(InsertPt); + I.moveBefore(InsertPt); } } @@ -1586,9 +1771,10 @@ void HWAddressSanitizer::instrumentGlobals() { Hasher.update(M.getSourceFileName()); MD5::MD5Result Hash; Hasher.final(Hash); - uint8_t Tag = Hash[0] & TagMaskByte; + uint8_t Tag = Hash[0]; for (GlobalVariable *GV : Globals) { + Tag &= TagMaskByte; // Skip tag 0 in order to avoid collisions with untagged memory. if (Tag == 0) Tag = 1; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp index 071feb876540..3ea314329079 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp @@ -1,9 +1,8 @@ //===- InstrOrderFile.cpp ---- Late IR instrumentation for order file ----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index 0d257bb6bd52..ad21fec269ec 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -446,13 +446,12 @@ bool InstrProfiling::lowerIntrinsics(Function *F) { bool MadeChange = false; PromotionCandidates.clear(); for (BasicBlock &BB : *F) { - for (auto I = BB.begin(), E = BB.end(); I != E;) { - auto Instr = I++; - InstrProfIncrementInst *Inc = castToIncrementInst(&*Instr); + for (Instruction &Instr : llvm::make_early_inc_range(BB)) { + InstrProfIncrementInst *Inc = castToIncrementInst(&Instr); if (Inc) { lowerIncrement(Inc); MadeChange = true; - } else if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(Instr)) { + } else if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(&Instr)) { lowerValueProfileInst(Ind); MadeChange = true; } @@ -520,6 +519,14 @@ void InstrProfiling::promoteCounterLoadStores(Function *F) { } } +static bool needsRuntimeHookUnconditionally(const Triple &TT) { + // On Fuchsia, we only need runtime hook if any counters are present. + if (TT.isOSFuchsia()) + return false; + + return true; +} + /// Check if the module contains uses of any profiling intrinsics. static bool containsProfilingIntrinsics(Module &M) { if (auto *F = M.getFunction( @@ -548,8 +555,11 @@ bool InstrProfiling::run( UsedVars.clear(); TT = Triple(M.getTargetTriple()); + bool MadeChange = false; + // Emit the runtime hook even if no counters are present. 
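The runtime-hook changes starting just below all serve one goal: something in the emitted object must reference the profiling runtime. On Linux the -u<hook_var> linker flag is expected to do it, on other ELF targets the hook variable itself is kept alive via llvm.compiler.used, and elsewhere a small "user" function that loads the variable is synthesized. A rough source-level analogue of that last case (the __llvm_profile_runtime names follow the conventional expansions of the helpers used in the hunk; the variable is defined locally only so the sketch links — in a real build the profiling runtime defines it):

extern "C" {
// Normally *defined* by the profiling runtime (e.g. libclang_rt.profile);
// instrumented objects only need to reference it so the linker pulls that
// runtime in. Defined here solely so this standalone sketch links.
int __llvm_profile_runtime = 0;

// Rough equivalent of the synthesized "user" function for targets where
// listing the variable in llvm.compiler.used is not an option: keeping this
// function alive keeps a load of the hook variable alive as well.
int __llvm_profile_runtime_user() { return __llvm_profile_runtime; }
}

int main() { return __llvm_profile_runtime_user(); }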
- bool MadeChange = emitRuntimeHook(); + if (needsRuntimeHookUnconditionally(TT)) + MadeChange = emitRuntimeHook(); // Improve compile time by avoiding linear scans when there is no work. GlobalVariable *CoverageNamesVar = @@ -588,6 +598,7 @@ bool InstrProfiling::run( emitVNodes(); emitNameData(); + emitRuntimeHook(); emitRegistration(); emitUses(); emitInitialization(); @@ -692,7 +703,6 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { LoadInst *LI = dyn_cast<LoadInst>(&I); if (!LI) { IRBuilder<> Builder(&I); - Type *Int64Ty = Type::getInt64Ty(M->getContext()); GlobalVariable *Bias = M->getGlobalVariable(getInstrProfCounterBiasVarName()); if (!Bias) { // Compiler must define this variable when runtime counter relocation @@ -747,14 +757,18 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { } /// Get the name of a profiling variable for a particular function. -static std::string getVarName(InstrProfIncrementInst *Inc, StringRef Prefix) { +static std::string getVarName(InstrProfIncrementInst *Inc, StringRef Prefix, + bool &Renamed) { StringRef NamePrefix = getInstrProfNameVarPrefix(); StringRef Name = Inc->getName()->getName().substr(NamePrefix.size()); Function *F = Inc->getParent()->getParent(); Module *M = F->getParent(); if (!DoHashBasedCounterSplit || !isIRPGOFlagSet(M) || - !canRenameComdatFunc(*F)) + !canRenameComdatFunc(*F)) { + Renamed = false; return (Prefix + Name).str(); + } + Renamed = true; uint64_t FuncHash = Inc->getHash()->getZExtValue(); SmallVector<char, 24> HashPostfix; if (Name.endswith((Twine(".") + Twine(FuncHash)).toStringRef(HashPostfix))) @@ -848,6 +862,15 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { GlobalValue::LinkageTypes Linkage = NamePtr->getLinkage(); GlobalValue::VisibilityTypes Visibility = NamePtr->getVisibility(); + // Due to the limitation of binder as of 2021/09/28, the duplicate weak + // symbols in the same csect won't be discarded. When there are duplicate weak + // symbols, we can NOT guarantee that the relocations get resolved to the + // intended weak symbol, so we can not ensure the correctness of the relative + // CounterPtr, so we have to use private linkage for counter and data symbols. + if (TT.isOSBinFormatXCOFF()) { + Linkage = GlobalValue::PrivateLinkage; + Visibility = GlobalValue::DefaultVisibility; + } // Move the name variable to the right section. Place them in a COMDAT group // if the associated function is a COMDAT. This will make sure that only one // copy of counters of the COMDAT function will be emitted after linking. Keep @@ -867,8 +890,11 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { // discarded. 
bool DataReferencedByCode = profDataReferencedByCode(*M); bool NeedComdat = needsComdatForCounter(*Fn, *M); - std::string CntsVarName = getVarName(Inc, getInstrProfCountersVarPrefix()); - std::string DataVarName = getVarName(Inc, getInstrProfDataVarPrefix()); + bool Renamed; + std::string CntsVarName = + getVarName(Inc, getInstrProfCountersVarPrefix(), Renamed); + std::string DataVarName = + getVarName(Inc, getInstrProfDataVarPrefix(), Renamed); auto MaybeSetComdat = [&](GlobalVariable *GV) { bool UseComdat = (NeedComdat || TT.isOSBinFormatELF()); if (UseComdat) { @@ -909,7 +935,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { ArrayType *ValuesTy = ArrayType::get(Type::getInt64Ty(Ctx), NS); auto *ValuesVar = new GlobalVariable( *M, ValuesTy, false, Linkage, Constant::getNullValue(ValuesTy), - getVarName(Inc, getInstrProfValuesVarPrefix())); + getVarName(Inc, getInstrProfValuesVarPrefix(), Renamed)); ValuesVar->setVisibility(Visibility); ValuesVar->setSection( getInstrProfSectionName(IPSK_vals, TT.getObjectFormat())); @@ -920,6 +946,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { } // Create data variable. + auto *IntPtrTy = M->getDataLayout().getIntPtrType(M->getContext()); auto *Int16Ty = Type::getInt16Ty(Ctx); auto *Int16ArrayTy = ArrayType::get(Int16Ty, IPVK_Last + 1); Type *DataTypes[] = { @@ -936,10 +963,6 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) Int16ArrayVals[Kind] = ConstantInt::get(Int16Ty, PD.NumValueSites[Kind]); - Constant *DataVals[] = { -#define INSTR_PROF_DATA(Type, LLVMType, Name, Init) Init, -#include "llvm/ProfileData/InstrProfData.inc" - }; // If the data variable is not referenced by code (if we don't emit // @llvm.instrprof.value.profile, NS will be 0), and the counter keeps the // data variable live under linker GC, the data variable can be private. This @@ -947,14 +970,30 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { // // On COFF, a comdat leader cannot be local so we require DataReferencedByCode // to be false. - if (NS == 0 && (TT.isOSBinFormatELF() || - (!DataReferencedByCode && TT.isOSBinFormatCOFF()))) { + // + // If profd is in a deduplicate comdat, NS==0 with a hash suffix guarantees + // that other copies must have the same CFG and cannot have value profiling. + // If no hash suffix, other profd copies may be referenced by code. + if (NS == 0 && !(DataReferencedByCode && NeedComdat && !Renamed) && + (TT.isOSBinFormatELF() || + (!DataReferencedByCode && TT.isOSBinFormatCOFF()))) { Linkage = GlobalValue::PrivateLinkage; Visibility = GlobalValue::DefaultVisibility; } auto *Data = - new GlobalVariable(*M, DataTy, false, Linkage, - ConstantStruct::get(DataTy, DataVals), DataVarName); + new GlobalVariable(*M, DataTy, false, Linkage, nullptr, DataVarName); + // Reference the counter variable with a label difference (link-time + // constant). 
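The comment above introduces the part of the hunk, continued just below, that stores the counter reference in the profile data record as a label difference — ptrtoint(CounterPtr) minus ptrtoint(Data) — i.e. a link-time-constant offset relative to the record itself rather than an absolute pointer. A standalone sketch of how such a self-relative field is consumed (ToyProfData and its field name are inventions for illustration, not the real __llvm_profile_data layout):

#include <cstdint>
#include <cstdio>

// Toy stand-in for a profile data record whose counter reference is stored as
// a self-relative offset instead of an absolute pointer.
struct ToyProfData {
  int64_t RelativeCounterPtr; // &Counters - &ThisRecord, a link-time constant
};

static uint64_t Counters[4]; // toy counter storage

// Recover the absolute counter pointer: record address plus stored delta.
static uint64_t *countersFor(ToyProfData &D) {
  return reinterpret_cast<uint64_t *>(reinterpret_cast<char *>(&D) +
                                      D.RelativeCounterPtr);
}

int main() {
  ToyProfData D;
  D.RelativeCounterPtr =
      static_cast<int64_t>(reinterpret_cast<intptr_t>(&Counters[0]) -
                           reinterpret_cast<intptr_t>(&D));
  countersFor(D)[0] = 42; // writes Counters[0]
  std::printf("%llu\n", (unsigned long long)Counters[0]); // prints 42
}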
+ auto *RelativeCounterPtr = + ConstantExpr::getSub(ConstantExpr::getPtrToInt(CounterPtr, IntPtrTy), + ConstantExpr::getPtrToInt(Data, IntPtrTy)); + + Constant *DataVals[] = { +#define INSTR_PROF_DATA(Type, LLVMType, Name, Init) Init, +#include "llvm/ProfileData/InstrProfData.inc" + }; + Data->setInitializer(ConstantStruct::get(DataTy, DataVals)); + Data->setVisibility(Visibility); Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat())); Data->setAlignment(Align(INSTR_PROF_DATA_ALIGNMENT)); @@ -1035,7 +1074,7 @@ void InstrProfiling::emitNameData() { std::string CompressedNameStr; if (Error E = collectPGOFuncNameStrings(ReferencedNames, CompressedNameStr, DoInstrProfNameCompression)) { - report_fatal_error(toString(std::move(E)), false); + report_fatal_error(Twine(toString(std::move(E))), false); } auto &Ctx = M->getContext(); @@ -1102,9 +1141,9 @@ void InstrProfiling::emitRegistration() { } bool InstrProfiling::emitRuntimeHook() { - // We expect the linker to be invoked with -u<hook_var> flag for Linux or - // Fuchsia, in which case there is no need to emit the user function. - if (TT.isOSLinux() || TT.isOSFuchsia()) + // We expect the linker to be invoked with -u<hook_var> flag for Linux + // in which case there is no need to emit the external variable. + if (TT.isOSLinux()) return false; // If the module's provided its own runtime, we don't need to do anything. @@ -1117,23 +1156,28 @@ bool InstrProfiling::emitRuntimeHook() { new GlobalVariable(*M, Int32Ty, false, GlobalValue::ExternalLinkage, nullptr, getInstrProfRuntimeHookVarName()); - // Make a function that uses it. - auto *User = Function::Create(FunctionType::get(Int32Ty, false), - GlobalValue::LinkOnceODRLinkage, - getInstrProfRuntimeHookVarUseFuncName(), M); - User->addFnAttr(Attribute::NoInline); - if (Options.NoRedZone) - User->addFnAttr(Attribute::NoRedZone); - User->setVisibility(GlobalValue::HiddenVisibility); - if (TT.supportsCOMDAT()) - User->setComdat(M->getOrInsertComdat(User->getName())); - - IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", User)); - auto *Load = IRB.CreateLoad(Int32Ty, Var); - IRB.CreateRet(Load); - - // Mark the user variable as used so that it isn't stripped out. - CompilerUsedVars.push_back(User); + if (TT.isOSBinFormatELF()) { + // Mark the user variable as used so that it isn't stripped out. + CompilerUsedVars.push_back(Var); + } else { + // Make a function that uses it. + auto *User = Function::Create(FunctionType::get(Int32Ty, false), + GlobalValue::LinkOnceODRLinkage, + getInstrProfRuntimeHookVarUseFuncName(), M); + User->addFnAttr(Attribute::NoInline); + if (Options.NoRedZone) + User->addFnAttr(Attribute::NoRedZone); + User->setVisibility(GlobalValue::HiddenVisibility); + if (TT.supportsCOMDAT()) + User->setComdat(M->getOrInsertComdat(User->getName())); + + IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", User)); + auto *Load = IRB.CreateLoad(Int32Ty, Var); + IRB.CreateRet(Load); + + // Mark the function as used so that it isn't stripped out. + CompilerUsedVars.push_back(User); + } return true; } @@ -1142,12 +1186,12 @@ void InstrProfiling::emitUses() { // GlobalOpt/ConstantMerge) may not discard associated sections as a unit, so // we conservatively retain all unconditionally in the compiler. // - // On ELF, the linker can guarantee the associated sections will be retained - // or discarded as a unit, so llvm.compiler.used is sufficient. Similarly on - // COFF, if prof data is not referenced by code we use one comdat and ensure - // this GC property as well. 
Otherwise, we have to conservatively make all of - // the sections retained by the linker. - if (TT.isOSBinFormatELF() || + // On ELF and Mach-O, the linker can guarantee the associated sections will be + // retained or discarded as a unit, so llvm.compiler.used is sufficient. + // Similarly on COFF, if prof data is not referenced by code we use one comdat + // and ensure this GC property as well. Otherwise, we have to conservatively + // make all of the sections retained by the linker. + if (TT.isOSBinFormatELF() || TT.isOSBinFormatMachO() || (TT.isOSBinFormatCOFF() && !profDataReferencedByCode(*M))) appendToCompilerUsed(*M, CompilerUsedVars); else diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp index 0e6a404a9e0b..727672fa0605 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" @@ -107,6 +108,10 @@ static cl::opt<int> cl::desc("granularity of memprof shadow mapping"), cl::Hidden, cl::init(DefaultShadowGranularity)); +static cl::opt<bool> ClStack("memprof-instrument-stack", + cl::desc("Instrument scalar stack variables"), + cl::Hidden, cl::init(false)); + // Debug flags. static cl::opt<int> ClDebug("memprof-debug", cl::desc("debug"), cl::Hidden, @@ -123,6 +128,8 @@ static cl::opt<int> ClDebugMax("memprof-debug-max", cl::desc("Debug max inst"), STATISTIC(NumInstrumentedReads, "Number of instrumented reads"); STATISTIC(NumInstrumentedWrites, "Number of instrumented writes"); +STATISTIC(NumSkippedStackReads, "Number of non-instrumented stack reads"); +STATISTIC(NumSkippedStackWrites, "Number of non-instrumented stack writes"); namespace { @@ -255,8 +262,6 @@ PreservedAnalyses MemProfilerPass::run(Function &F, if (Profiler.instrumentFunction(F)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); - - return PreservedAnalyses::all(); } ModuleMemProfilerPass::ModuleMemProfilerPass() {} @@ -448,6 +453,15 @@ void MemProfiler::instrumentMaskedLoadOrStore(const DataLayout &DL, Value *Mask, void MemProfiler::instrumentMop(Instruction *I, const DataLayout &DL, InterestingMemoryAccess &Access) { + // Skip instrumentation of stack accesses unless requested. 
+ if (!ClStack && isa<AllocaInst>(getUnderlyingObject(Access.Addr))) { + if (Access.IsWrite) + ++NumSkippedStackWrites; + else + ++NumSkippedStackReads; + return; + } + if (Access.IsWrite) NumInstrumentedWrites++; else diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 4e755bab15f3..4d15b784f486 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -673,14 +673,27 @@ PreservedAnalyses MemorySanitizerPass::run(Function &F, return PreservedAnalyses::all(); } -PreservedAnalyses MemorySanitizerPass::run(Module &M, - ModuleAnalysisManager &AM) { +PreservedAnalyses +ModuleMemorySanitizerPass::run(Module &M, ModuleAnalysisManager &AM) { if (Options.Kernel) return PreservedAnalyses::all(); insertModuleCtor(M); return PreservedAnalyses::none(); } +void MemorySanitizerPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<MemorySanitizerPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (Options.Recover) + OS << "recover;"; + if (Options.Kernel) + OS << "kernel;"; + OS << "track-origins=" << Options.TrackOrigins; + OS << ">"; +} + char MemorySanitizerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(MemorySanitizerLegacyPass, "msan", @@ -1695,7 +1708,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (FArgEagerCheck) { *ShadowPtr = getCleanShadow(V); setOrigin(A, getCleanOrigin()); - continue; + break; } else if (FArgByVal) { Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset); // ByVal pointer itself has clean shadow. We copy the actual @@ -1745,8 +1758,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { break; } - if (!FArgEagerCheck) - ArgOffset += alignTo(Size, kShadowTLSAlignment); + ArgOffset += alignTo(Size, kShadowTLSAlignment); } assert(*ShadowPtr && "Could not find shadow for an argument"); return *ShadowPtr; @@ -2661,7 +2673,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { RetTy->isX86_MMXTy())) return false; - unsigned NumArgOperands = I.getNumArgOperands(); + unsigned NumArgOperands = I.arg_size(); for (unsigned i = 0; i < NumArgOperands; ++i) { Type *Ty = I.getArgOperand(i)->getType(); if (Ty != RetTy) @@ -2688,7 +2700,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// We special-case intrinsics where this approach fails. See llvm.bswap /// handling as an example of that. bool handleUnknownIntrinsic(IntrinsicInst &I) { - unsigned NumArgOperands = I.getNumArgOperands(); + unsigned NumArgOperands = I.arg_size(); if (NumArgOperands == 0) return false; @@ -2762,10 +2774,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *CopyOp, *ConvertOp; assert((!HasRoundingMode || - isa<ConstantInt>(I.getArgOperand(I.getNumArgOperands() - 1))) && + isa<ConstantInt>(I.getArgOperand(I.arg_size() - 1))) && "Invalid rounding mode"); - switch (I.getNumArgOperands() - HasRoundingMode) { + switch (I.arg_size() - HasRoundingMode) { case 2: CopyOp = I.getArgOperand(0); ConvertOp = I.getArgOperand(1); @@ -2854,7 +2866,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // size, and the rest is ignored. 
Behavior is defined even if shift size is // greater than register (or field) width. void handleVectorShiftIntrinsic(IntrinsicInst &I, bool Variable) { - assert(I.getNumArgOperands() == 2); + assert(I.arg_size() == 2); IRBuilder<> IRB(&I); // If any of the S2 bits are poisoned, the whole thing is poisoned. // Otherwise perform the same shift on S1. @@ -2919,7 +2931,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // to sext(Sa != zeroinitializer), sext(Sb != zeroinitializer). // EltSizeInBits is used only for x86mmx arguments. void handleVectorPackIntrinsic(IntrinsicInst &I, unsigned EltSizeInBits = 0) { - assert(I.getNumArgOperands() == 2); + assert(I.arg_size() == 2); bool isX86_MMX = I.getOperand(0)->getType()->isX86_MMXTy(); IRBuilder<> IRB(&I); Value *S1 = getShadow(&I, 0); @@ -3653,9 +3665,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { .addAttribute(Attribute::ArgMemOnly) .addAttribute(Attribute::Speculatable); - Call->removeAttributes(AttributeList::FunctionIndex, B); + Call->removeFnAttrs(B); if (Function *Func = Call->getCalledFunction()) { - Func->removeAttributes(AttributeList::FunctionIndex, B); + Func->removeFnAttrs(B); } maybeMarkSanitizerLibraryCallNoBuiltin(Call, TLI); @@ -3696,42 +3708,48 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (EagerCheck) { insertShadowCheck(A, &CB); - continue; - } - if (ByVal) { - // ByVal requires some special handling as it's too big for a single - // load - assert(A->getType()->isPointerTy() && - "ByVal argument is not a pointer!"); - Size = DL.getTypeAllocSize(CB.getParamByValType(i)); - if (ArgOffset + Size > kParamTLSSize) break; - const MaybeAlign ParamAlignment(CB.getParamAlign(i)); - MaybeAlign Alignment = llvm::None; - if (ParamAlignment) - Alignment = std::min(*ParamAlignment, kShadowTLSAlignment); - Value *AShadowPtr = - getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), Alignment, - /*isStore*/ false) - .first; - - Store = IRB.CreateMemCpy(ArgShadowBase, Alignment, AShadowPtr, - Alignment, Size); - // TODO(glider): need to copy origins. - } else { - // Any other parameters mean we need bit-grained tracking of uninit data Size = DL.getTypeAllocSize(A->getType()); - if (ArgOffset + Size > kParamTLSSize) break; - Store = IRB.CreateAlignedStore(ArgShadow, ArgShadowBase, - kShadowTLSAlignment); - Constant *Cst = dyn_cast<Constant>(ArgShadow); - if (Cst && Cst->isNullValue()) ArgIsInitialized = true; + } else { + if (ByVal) { + // ByVal requires some special handling as it's too big for a single + // load + assert(A->getType()->isPointerTy() && + "ByVal argument is not a pointer!"); + Size = DL.getTypeAllocSize(CB.getParamByValType(i)); + if (ArgOffset + Size > kParamTLSSize) + break; + const MaybeAlign ParamAlignment(CB.getParamAlign(i)); + MaybeAlign Alignment = llvm::None; + if (ParamAlignment) + Alignment = std::min(*ParamAlignment, kShadowTLSAlignment); + Value *AShadowPtr = + getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), Alignment, + /*isStore*/ false) + .first; + + Store = IRB.CreateMemCpy(ArgShadowBase, Alignment, AShadowPtr, + Alignment, Size); + // TODO(glider): need to copy origins. 
+ } else { + // Any other parameters mean we need bit-grained tracking of uninit + // data + Size = DL.getTypeAllocSize(A->getType()); + if (ArgOffset + Size > kParamTLSSize) + break; + Store = IRB.CreateAlignedStore(ArgShadow, ArgShadowBase, + kShadowTLSAlignment); + Constant *Cst = dyn_cast<Constant>(ArgShadow); + if (Cst && Cst->isNullValue()) + ArgIsInitialized = true; + } + if (MS.TrackOrigins && !ArgIsInitialized) + IRB.CreateStore(getOrigin(A), + getOriginPtrForArgument(A, IRB, ArgOffset)); + (void)Store; + assert(Store != nullptr); + LLVM_DEBUG(dbgs() << " Param:" << *Store << "\n"); } - if (MS.TrackOrigins && !ArgIsInitialized) - IRB.CreateStore(getOrigin(A), - getOriginPtrForArgument(A, IRB, ArgOffset)); - (void)Store; - assert(Size != 0 && Store != nullptr); - LLVM_DEBUG(dbgs() << " Param:" << *Store << "\n"); + assert(Size != 0); ArgOffset += alignTo(Size, kShadowTLSAlignment); } LLVM_DEBUG(dbgs() << " done with call args\n"); @@ -3807,7 +3825,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (isAMustTailRetVal(RetVal)) return; Value *ShadowPtr = getShadowPtrForRetval(RetVal, IRB); bool HasNoUndef = - F.hasAttribute(AttributeList::ReturnIndex, Attribute::NoUndef); + F.hasRetAttribute(Attribute::NoUndef); bool StoreShadow = !(ClEagerChecks && HasNoUndef); // FIXME: Consider using SpecialCaseList to specify a list of functions that // must always return fully initialized values. For now, we hardcode "main". @@ -4176,7 +4194,7 @@ struct VarArgAMD64Helper : public VarArgHelper { MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) { AMD64FpEndOffset = AMD64FpEndOffsetSSE; - for (const auto &Attr : F.getAttributes().getFnAttributes()) { + for (const auto &Attr : F.getAttributes().getFnAttrs()) { if (Attr.isStringAttribute() && (Attr.getKindAsString() == "target-features")) { if (Attr.getValueAsString().contains("-sse")) @@ -5330,6 +5348,9 @@ bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) { if (!CompileKernel && F.getName() == kMsanModuleCtorName) return false; + if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) + return false; + MemorySanitizerVisitor Visitor(F, *this, TLI); // Clear out readonly/readnone attributes. @@ -5339,7 +5360,7 @@ bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) { .addAttribute(Attribute::WriteOnly) .addAttribute(Attribute::ArgMemOnly) .addAttribute(Attribute::Speculatable); - F.removeAttributes(AttributeList::FunctionIndex, B); + F.removeFnAttrs(B); return Visitor.runOnFunction(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 3d9261eb99ba..af5946325bbb 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -110,6 +110,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -198,12 +199,14 @@ static cl::opt<bool> "warnings about missing profile data for " "functions.")); +namespace llvm { // Command line option to enable/disable the warning about a hash mismatch in // the profile data. 
-static cl::opt<bool> +cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden, cl::desc("Use this option to turn off/on " "warnings about profile cfg mismatch.")); +} // namespace llvm // Command line option to enable/disable the warning about a hash mismatch in // the profile data for Comdat functions, which often turns out to be false @@ -462,7 +465,10 @@ public: private: bool runOnModule(Module &M) override { createProfileFileNameVar(M, InstrProfileOutput); - createIRLevelProfileFlagVar(M, /* IsCS */ true, PGOInstrumentEntry); + // The variable in a comdat may be discarded by LTO. Ensure the + // declaration will be retained. + appendToCompilerUsed( + M, createIRLevelProfileFlagVar(M, /*IsCS=*/true, PGOInstrumentEntry)); return false; } std::string InstrProfileOutput; @@ -1610,7 +1616,7 @@ static bool InstrumentAllFunctions( // For the context-sensitve instrumentation, we should have a separated pass // (before LTO/ThinLTO linking) to create these variables. if (!IsCS) - createIRLevelProfileFlagVar(M, /* IsCS */ false, PGOInstrumentEntry); + createIRLevelProfileFlagVar(M, /*IsCS=*/false, PGOInstrumentEntry); std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; collectComdatMembers(M, ComdatMembers); @@ -1630,7 +1636,10 @@ static bool InstrumentAllFunctions( PreservedAnalyses PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &AM) { createProfileFileNameVar(M, CSInstrName); - createIRLevelProfileFlagVar(M, /* IsCS */ true, PGOInstrumentEntry); + // The variable in a comdat may be discarded by LTO. Ensure the declaration + // will be retained. + appendToCompilerUsed( + M, createIRLevelProfileFlagVar(M, /*IsCS=*/true, PGOInstrumentEntry)); return PreservedAnalyses::all(); } @@ -1677,7 +1686,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI, BlockFrequencyInfo NBFI(F, NBPI, LI); #ifndef NDEBUG auto BFIEntryCount = F.getEntryCount(); - assert(BFIEntryCount.hasValue() && (BFIEntryCount.getCount() > 0) && + assert(BFIEntryCount.hasValue() && (BFIEntryCount->getCount() > 0) && "Invalid BFI Entrycount"); #endif auto SumCount = APFloat::getZero(APFloat::IEEEdouble()); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 7607464cc0b9..da8ee1f15bf8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -55,6 +55,16 @@ const char SanCovTraceConstCmp1[] = "__sanitizer_cov_trace_const_cmp1"; const char SanCovTraceConstCmp2[] = "__sanitizer_cov_trace_const_cmp2"; const char SanCovTraceConstCmp4[] = "__sanitizer_cov_trace_const_cmp4"; const char SanCovTraceConstCmp8[] = "__sanitizer_cov_trace_const_cmp8"; +const char SanCovLoad1[] = "__sanitizer_cov_load1"; +const char SanCovLoad2[] = "__sanitizer_cov_load2"; +const char SanCovLoad4[] = "__sanitizer_cov_load4"; +const char SanCovLoad8[] = "__sanitizer_cov_load8"; +const char SanCovLoad16[] = "__sanitizer_cov_load16"; +const char SanCovStore1[] = "__sanitizer_cov_store1"; +const char SanCovStore2[] = "__sanitizer_cov_store2"; +const char SanCovStore4[] = "__sanitizer_cov_store4"; +const char SanCovStore8[] = "__sanitizer_cov_store8"; +const char SanCovStore16[] = "__sanitizer_cov_store16"; const char SanCovTraceDiv4[] = "__sanitizer_cov_trace_div4"; const char SanCovTraceDiv8[] = "__sanitizer_cov_trace_div8"; const char 
SanCovTraceGep[] = "__sanitizer_cov_trace_gep"; @@ -122,6 +132,14 @@ static cl::opt<bool> ClDIVTracing("sanitizer-coverage-trace-divs", cl::desc("Tracing of DIV instructions"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClLoadTracing("sanitizer-coverage-trace-loads", + cl::desc("Tracing of load instructions"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> ClStoreTracing("sanitizer-coverage-trace-stores", + cl::desc("Tracing of store instructions"), + cl::Hidden, cl::init(false)); + static cl::opt<bool> ClGEPTracing("sanitizer-coverage-trace-geps", cl::desc("Tracing of GEP instructions"), cl::Hidden, cl::init(false)); @@ -175,9 +193,11 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { Options.PCTable |= ClCreatePCTable; Options.NoPrune |= !ClPruneBlocks; Options.StackDepth |= ClStackDepth; + Options.TraceLoads |= ClLoadTracing; + Options.TraceStores |= ClStoreTracing; if (!Options.TracePCGuard && !Options.TracePC && !Options.Inline8bitCounters && !Options.StackDepth && - !Options.InlineBoolFlag) + !Options.InlineBoolFlag && !Options.TraceLoads && !Options.TraceStores) Options.TracePCGuard = true; // TracePCGuard is default. return Options; } @@ -207,6 +227,8 @@ private: ArrayRef<BinaryOperator *> DivTraceTargets); void InjectTraceForGep(Function &F, ArrayRef<GetElementPtrInst *> GepTraceTargets); + void InjectTraceForLoadsAndStores(Function &F, ArrayRef<LoadInst *> Loads, + ArrayRef<StoreInst *> Stores); void InjectTraceForSwitch(Function &F, ArrayRef<Instruction *> SwitchTraceTargets); bool InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks, @@ -234,14 +256,17 @@ private: std::string getSectionEnd(const std::string &Section) const; FunctionCallee SanCovTracePCIndir; FunctionCallee SanCovTracePC, SanCovTracePCGuard; - FunctionCallee SanCovTraceCmpFunction[4]; - FunctionCallee SanCovTraceConstCmpFunction[4]; - FunctionCallee SanCovTraceDivFunction[2]; + std::array<FunctionCallee, 4> SanCovTraceCmpFunction; + std::array<FunctionCallee, 4> SanCovTraceConstCmpFunction; + std::array<FunctionCallee, 5> SanCovLoadFunction; + std::array<FunctionCallee, 5> SanCovStoreFunction; + std::array<FunctionCallee, 2> SanCovTraceDivFunction; FunctionCallee SanCovTraceGepFunction; FunctionCallee SanCovTraceSwitchFunction; GlobalVariable *SanCovLowestStack; - Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy, - *Int16Ty, *Int8Ty, *Int8PtrTy, *Int1Ty, *Int1PtrTy; + Type *Int128PtrTy, *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, + *Int32PtrTy, *Int16PtrTy, *Int16Ty, *Int8Ty, *Int8PtrTy, *Int1Ty, + *Int1PtrTy; Module *CurModule; std::string CurModuleUniqueId; Triple TargetTriple; @@ -411,7 +436,9 @@ bool ModuleSanitizerCoverage::instrumentModule( IntptrPtrTy = PointerType::getUnqual(IntptrTy); Type *VoidTy = Type::getVoidTy(*C); IRBuilder<> IRB(*C); + Int128PtrTy = PointerType::getUnqual(IRB.getInt128Ty()); Int64PtrTy = PointerType::getUnqual(IRB.getInt64Ty()); + Int16PtrTy = PointerType::getUnqual(IRB.getInt16Ty()); Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty()); Int1PtrTy = PointerType::getUnqual(IRB.getInt1Ty()); @@ -452,6 +479,28 @@ bool ModuleSanitizerCoverage::instrumentModule( SanCovTraceConstCmpFunction[3] = M.getOrInsertFunction(SanCovTraceConstCmp8, VoidTy, Int64Ty, Int64Ty); + // Loads. 
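The load/store tracing added in this file keeps five size-specific callbacks (1-, 2-, 4-, 8- and 16-byte accesses) in small arrays — initialized just below — and later picks one by the access width, skipping anything that does not match. A standalone sketch of that size-to-callback dispatch (the printf stubs stand in for the real __sanitizer_cov_load*/__sanitizer_cov_store* hooks, which live in the sanitizer runtime):

#include <cstdint>
#include <cstdio>

// Stub callbacks standing in for __sanitizer_cov_load1..16.
static void load1(void *) { std::puts("load1"); }
static void load2(void *) { std::puts("load2"); }
static void load4(void *) { std::puts("load4"); }
static void load8(void *) { std::puts("load8"); }
static void load16(void *) { std::puts("load16"); }

using Callback = void (*)(void *);
static const Callback LoadCallbacks[5] = {load1, load2, load4, load8, load16};

// Map an access size in bits to an index into the 5-entry table, or -1 when
// no matching callback exists (mirrors the CallbackIdx selection in the pass).
static int callbackIdx(uint64_t TypeSizeInBits) {
  return TypeSizeInBits == 8     ? 0
         : TypeSizeInBits == 16  ? 1
         : TypeSizeInBits == 32  ? 2
         : TypeSizeInBits == 64  ? 3
         : TypeSizeInBits == 128 ? 4
                                 : -1;
}

int main() {
  uint32_t X = 0;
  int Idx = callbackIdx(sizeof(X) * 8); // 32-bit access -> index 2
  if (Idx >= 0)
    LoadCallbacks[Idx](&X); // prints "load4"
}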
+ SanCovLoadFunction[0] = M.getOrInsertFunction(SanCovLoad1, VoidTy, Int8PtrTy); + SanCovLoadFunction[1] = + M.getOrInsertFunction(SanCovLoad2, VoidTy, Int16PtrTy); + SanCovLoadFunction[2] = + M.getOrInsertFunction(SanCovLoad4, VoidTy, Int32PtrTy); + SanCovLoadFunction[3] = + M.getOrInsertFunction(SanCovLoad8, VoidTy, Int64PtrTy); + SanCovLoadFunction[4] = + M.getOrInsertFunction(SanCovLoad16, VoidTy, Int128PtrTy); + // Stores. + SanCovStoreFunction[0] = + M.getOrInsertFunction(SanCovStore1, VoidTy, Int8PtrTy); + SanCovStoreFunction[1] = + M.getOrInsertFunction(SanCovStore2, VoidTy, Int16PtrTy); + SanCovStoreFunction[2] = + M.getOrInsertFunction(SanCovStore4, VoidTy, Int32PtrTy); + SanCovStoreFunction[3] = + M.getOrInsertFunction(SanCovStore8, VoidTy, Int64PtrTy); + SanCovStoreFunction[4] = + M.getOrInsertFunction(SanCovStore16, VoidTy, Int128PtrTy); + { AttributeList AL; AL = AL.addParamAttribute(*C, 0, Attribute::ZExt); @@ -632,6 +681,8 @@ void ModuleSanitizerCoverage::instrumentFunction( SmallVector<Instruction *, 8> SwitchTraceTargets; SmallVector<BinaryOperator *, 8> DivTraceTargets; SmallVector<GetElementPtrInst *, 8> GepTraceTargets; + SmallVector<LoadInst *, 8> Loads; + SmallVector<StoreInst *, 8> Stores; const DominatorTree *DT = DTCallback(F); const PostDominatorTree *PDT = PDTCallback(F); @@ -661,6 +712,12 @@ void ModuleSanitizerCoverage::instrumentFunction( if (Options.TraceGep) if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&Inst)) GepTraceTargets.push_back(GEP); + if (Options.TraceLoads) + if (LoadInst *LI = dyn_cast<LoadInst>(&Inst)) + Loads.push_back(LI); + if (Options.TraceStores) + if (StoreInst *SI = dyn_cast<StoreInst>(&Inst)) + Stores.push_back(SI); if (Options.StackDepth) if (isa<InvokeInst>(Inst) || (isa<CallInst>(Inst) && !isa<IntrinsicInst>(Inst))) @@ -674,6 +731,7 @@ void ModuleSanitizerCoverage::instrumentFunction( InjectTraceForSwitch(F, SwitchTraceTargets); InjectTraceForDiv(F, DivTraceTargets); InjectTraceForGep(F, GepTraceTargets); + InjectTraceForLoadsAndStores(F, Loads, Stores); } GlobalVariable *ModuleSanitizerCoverage::CreateFunctionLocalArrayInSection( @@ -857,6 +915,40 @@ void ModuleSanitizerCoverage::InjectTraceForGep( } } +void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( + Function &, ArrayRef<LoadInst *> Loads, ArrayRef<StoreInst *> Stores) { + auto CallbackIdx = [&](const Value *Ptr) -> int { + auto ElementTy = cast<PointerType>(Ptr->getType())->getElementType(); + uint64_t TypeSize = DL->getTypeStoreSizeInBits(ElementTy); + return TypeSize == 8 ? 0 + : TypeSize == 16 ? 1 + : TypeSize == 32 ? 2 + : TypeSize == 64 ? 3 + : TypeSize == 128 ? 
4 + : -1; + }; + Type *PointerType[5] = {Int8PtrTy, Int16PtrTy, Int32PtrTy, Int64PtrTy, + Int128PtrTy}; + for (auto LI : Loads) { + IRBuilder<> IRB(LI); + auto Ptr = LI->getPointerOperand(); + int Idx = CallbackIdx(Ptr); + if (Idx < 0) + continue; + IRB.CreateCall(SanCovLoadFunction[Idx], + IRB.CreatePointerCast(Ptr, PointerType[Idx])); + } + for (auto SI : Stores) { + IRBuilder<> IRB(SI); + auto Ptr = SI->getPointerOperand(); + int Idx = CallbackIdx(Ptr); + if (Idx < 0) + continue; + IRB.CreateCall(SanCovStoreFunction[Idx], + IRB.CreatePointerCast(Ptr, PointerType[Idx])); + } +} + void ModuleSanitizerCoverage::InjectTraceForCmp( Function &, ArrayRef<Instruction *> CmpTraceTargets) { for (auto I : CmpTraceTargets) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 063999a68236..f98e39d751f4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -206,8 +206,8 @@ PreservedAnalyses ThreadSanitizerPass::run(Function &F, return PreservedAnalyses::all(); } -PreservedAnalyses ThreadSanitizerPass::run(Module &M, - ModuleAnalysisManager &MAM) { +PreservedAnalyses ModuleThreadSanitizerPass::run(Module &M, + ModuleAnalysisManager &MAM) { insertModuleCtor(M); return PreservedAnalyses::none(); } @@ -249,8 +249,7 @@ void ThreadSanitizer::initialize(Module &M) { IRBuilder<> IRB(M.getContext()); AttributeList Attr; - Attr = Attr.addAttribute(M.getContext(), AttributeList::FunctionIndex, - Attribute::NoUnwind); + Attr = Attr.addFnAttribute(M.getContext(), Attribute::NoUnwind); // Initialize the callbacks. TsanFuncEntry = M.getOrInsertFunction("__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); @@ -563,6 +562,12 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, // all. if (F.hasFnAttribute(Attribute::Naked)) return false; + + // __attribute__(disable_sanitizer_instrumentation) prevents all kinds of + // instrumentation. + if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) + return false; + initialize(*F.getParent()); SmallVector<InstructionInfo, 8> AllLoadsAndStores; SmallVector<Instruction*, 8> LocalLoadsAndStores; @@ -580,7 +585,8 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, AtomicAccesses.push_back(&Inst); else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) LocalLoadsAndStores.push_back(&Inst); - else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) { + else if ((isa<CallInst>(Inst) && !isa<DbgInfoIntrinsic>(Inst)) || + isa<InvokeInst>(Inst)) { if (CallInst *CI = dyn_cast<CallInst>(&Inst)) maybeMarkSanitizerLibraryCallNoBuiltin(CI, &TLI); if (isa<MemIntrinsic>(Inst)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp index 06b12149f597..1ca6ddabac5b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -103,9 +103,8 @@ CallInst *BundledRetainClaimRVs::insertRVCallWithColors( Instruction *InsertPt, CallBase *AnnotatedCall, const DenseMap<BasicBlock *, ColorVector> &BlockColors) { IRBuilder<> Builder(InsertPt); - bool IsRetainRV = objcarc::hasAttachedCallOpBundle(AnnotatedCall, true); - Function *Func = EP.get(IsRetainRV ? 
ARCRuntimeEntryPointKind::RetainRV - : ARCRuntimeEntryPointKind::ClaimRV); + Function *Func = *objcarc::getAttachedARCFunction(AnnotatedCall); + assert(Func && "operand isn't a Function"); Type *ParamTy = Func->getArg(0)->getType(); Value *CallArg = Builder.CreateBitCast(AnnotatedCall, ParamTy); auto *Call = @@ -115,16 +114,28 @@ CallInst *BundledRetainClaimRVs::insertRVCallWithColors( } BundledRetainClaimRVs::~BundledRetainClaimRVs() { - if (ContractPass) { - // At this point, we know that the annotated calls can't be tail calls as - // they are followed by marker instructions and retainRV/claimRV calls. Mark - // them as notail, so that the backend knows these calls can't be tail - // calls. - for (auto P : RVCalls) - if (auto *CI = dyn_cast<CallInst>(P.second)) + for (auto P : RVCalls) { + if (ContractPass) { + CallBase *CB = P.second; + // At this point, we know that the annotated calls can't be tail calls + // as they are followed by marker instructions and retainRV/claimRV + // calls. Mark them as notail so that the backend knows these calls + // can't be tail calls. + if (auto *CI = dyn_cast<CallInst>(CB)) CI->setTailCallKind(CallInst::TCK_NoTail); - } else { - for (auto P : RVCalls) + + if (UseMarker) { + // Remove the retainRV/claimRV function operand from the operand bundle + // to reflect the fact that the backend is responsible for emitting only + // the marker instruction, but not the retainRV/claimRV call. + OperandBundleDef OB("clang.arc.attachedcall", None); + auto *NewCB = CallBase::Create(CB, OB, CB); + CB->replaceAllUsesWith(NewCB); + CB->eraseFromParent(); + } + } + + if (!ContractPass || !UseMarker) EraseInstruction(P.first); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h index 1f9d76969bfd..2b47bec7ffe8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h @@ -105,8 +105,8 @@ CallInst *createCallInstWithColors( class BundledRetainClaimRVs { public: - BundledRetainClaimRVs(ARCRuntimeEntryPoints &P, bool ContractPass) - : EP(P), ContractPass(ContractPass) {} + BundledRetainClaimRVs(bool ContractPass, bool UseMarker) + : ContractPass(ContractPass), UseMarker(UseMarker) {} ~BundledRetainClaimRVs(); /// Insert a retainRV/claimRV call to the normal destination blocks of invokes @@ -155,8 +155,10 @@ private: /// A map of inserted retainRV/claimRV calls to annotated calls/invokes. DenseMap<CallInst *, CallBase *> RVCalls; - ARCRuntimeEntryPoints &EP; bool ContractPass; + + /// Indicates whether the target uses a special inline-asm marker. 
+ bool UseMarker; }; } // end namespace objcarc diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp index 6a928f2c7ffb..210ec60f2f87 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp @@ -64,30 +64,29 @@ bool OptimizeBB(BasicBlock *BB) { bool Changed = false; Instruction *Push = nullptr; - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { - Instruction *Inst = &*I++; - switch (GetBasicARCInstKind(Inst)) { + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { + switch (GetBasicARCInstKind(&Inst)) { case ARCInstKind::AutoreleasepoolPush: - Push = Inst; + Push = &Inst; break; case ARCInstKind::AutoreleasepoolPop: // If this pop matches a push and nothing in between can autorelease, // zap the pair. - if (Push && cast<CallInst>(Inst)->getArgOperand(0) == Push) { + if (Push && cast<CallInst>(&Inst)->getArgOperand(0) == Push) { Changed = true; LLVM_DEBUG(dbgs() << "ObjCARCAPElim::OptimizeBB: Zapping push pop " "autorelease pair:\n" " Pop: " - << *Inst << "\n" + << Inst << "\n" << " Push: " << *Push << "\n"); - Inst->eraseFromParent(); + Inst.eraseFromParent(); Push->eraseFromParent(); } Push = nullptr; break; case ARCInstKind::CallOrUser: - if (MayAutorelease(cast<CallBase>(*Inst))) + if (MayAutorelease(cast<CallBase>(Inst))) Push = nullptr; break; default: diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp index 62161b5b6b40..c2ed94e8e1f6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -226,13 +226,6 @@ static StoreInst *findSafeStoreForStoreStrongContraction(LoadInst *Load, // of Inst. ARCInstKind Class = GetBasicARCInstKind(Inst); - // If Inst is an unrelated retain, we don't care about it. - // - // TODO: This is one area where the optimization could be made more - // aggressive. - if (IsRetain(Class)) - continue; - // If we have seen the store, but not the release... if (Store) { // We need to make sure that it is safe to move the release from its @@ -248,8 +241,18 @@ static StoreInst *findSafeStoreForStoreStrongContraction(LoadInst *Load, return nullptr; } - // Ok, now we know we have not seen a store yet. See if Inst can write to - // our load location, if it can not, just ignore the instruction. + // Ok, now we know we have not seen a store yet. + + // If Inst is a retain, we don't care about it as it doesn't prevent moving + // the load to the store. + // + // TODO: This is one area where the optimization could be made more + // aggressive. + if (IsRetain(Class)) + continue; + + // See if Inst can write to our load location, if it can not, just ignore + // the instruction. if (!isModSet(AA->getModRefInfo(Inst, Loc))) continue; @@ -431,13 +434,21 @@ bool ObjCARCContract::tryToPeepholeInstruction( LLVM_FALLTHROUGH; case ARCInstKind::RetainRV: case ARCInstKind::ClaimRV: { - // If we're compiling for a target which needs a special inline-asm - // marker to do the return value optimization and the retainRV/claimRV call - // wasn't bundled with a call, insert the marker now. + bool IsInstContainedInBundle = BundledInsts->contains(Inst); + + // Return now if the target doesn't need a special inline-asm marker. 
Return + // true if this is a bundled retainRV/claimRV call, which is going to be + // erased at the end of this pass, to avoid undoing objc-arc-expand and + // replacing uses of the retainRV/claimRV call's argument with its result. if (!RVInstMarker) - return false; + return IsInstContainedInBundle; + + // The target needs a special inline-asm marker. - if (BundledInsts->contains(Inst)) + // We don't have to emit the marker if this is a bundled call since the + // backend is responsible for emitting it. Return false to undo + // objc-arc-expand. + if (IsInstContainedInBundle) return false; BasicBlock::iterator BBI = Inst->getIterator(); @@ -537,7 +548,7 @@ bool ObjCARCContract::run(Function &F, AAResults *A, DominatorTree *D) { AA = A; DT = D; PA.setAA(A); - BundledRetainClaimRVs BRV(EP, true); + BundledRetainClaimRVs BRV(true, RVInstMarker); BundledInsts = &BRV; std::pair<bool, bool> R = BundledInsts->insertAfterInvokes(F, DT); diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp index d2121dcebe91..6b074ac5adab 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp @@ -56,12 +56,10 @@ static bool runImpl(Function &F) { LLVM_DEBUG(dbgs() << "ObjCARCExpand: Visiting Function: " << F.getName() << "\n"); - for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E; ++I) { - Instruction *Inst = &*I; + for (Instruction &Inst : instructions(&F)) { + LLVM_DEBUG(dbgs() << "ObjCARCExpand: Visiting: " << Inst << "\n"); - LLVM_DEBUG(dbgs() << "ObjCARCExpand: Visiting: " << *Inst << "\n"); - - switch (GetBasicARCInstKind(Inst)) { + switch (GetBasicARCInstKind(&Inst)) { case ARCInstKind::Retain: case ARCInstKind::RetainRV: case ARCInstKind::Autorelease: @@ -73,12 +71,12 @@ static bool runImpl(Function &F) { // harder. Undo any uses of this optimization that the front-end // emitted here. We'll redo them in the contract pass. Changed = true; - Value *Value = cast<CallInst>(Inst)->getArgOperand(0); - LLVM_DEBUG(dbgs() << "ObjCARCExpand: Old = " << *Inst + Value *Value = cast<CallInst>(&Inst)->getArgOperand(0); + LLVM_DEBUG(dbgs() << "ObjCARCExpand: Old = " << Inst << "\n" " New = " << *Value << "\n"); - Inst->replaceAllUsesWith(Value); + Inst.replaceAllUsesWith(Value); break; } default: diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index ada6aa8d9b6d..0fa4904456cd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -2229,13 +2229,12 @@ void ObjCARCOpt::OptimizeWeakCalls(Function &F) { // Then, for each destroyWeak with an alloca operand, check to see if // the alloca and all its users can be zapped. 
- for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E; ) { - Instruction *Inst = &*I++; - ARCInstKind Class = GetBasicARCInstKind(Inst); + for (Instruction &Inst : llvm::make_early_inc_range(instructions(F))) { + ARCInstKind Class = GetBasicARCInstKind(&Inst); if (Class != ARCInstKind::DestroyWeak) continue; - CallInst *Call = cast<CallInst>(Inst); + CallInst *Call = cast<CallInst>(&Inst); Value *Arg = Call->getArgOperand(0); if (AllocaInst *Alloca = dyn_cast<AllocaInst>(Arg)) { for (User *U : Alloca->users()) { @@ -2250,8 +2249,8 @@ void ObjCARCOpt::OptimizeWeakCalls(Function &F) { } } Changed = true; - for (auto UI = Alloca->user_begin(), UE = Alloca->user_end(); UI != UE;) { - CallInst *UserInst = cast<CallInst>(*UI++); + for (User *U : llvm::make_early_inc_range(Alloca->users())) { + CallInst *UserInst = cast<CallInst>(U); switch (GetBasicARCInstKind(UserInst)) { case ARCInstKind::InitWeak: case ARCInstKind::StoreWeak: @@ -2462,7 +2461,7 @@ bool ObjCARCOpt::run(Function &F, AAResults &AA) { return false; Changed = CFGChanged = false; - BundledRetainClaimRVs BRV(EP, false); + BundledRetainClaimRVs BRV(false, objcarc::getRVInstMarker(*F.getParent())); BundledInsts = &BRV; LLVM_DEBUG(dbgs() << "<<< ObjCARCOpt: Visiting Function: " << F.getName() diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h index a63e356ce1fc..6d0a67c91cfa 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h @@ -56,7 +56,8 @@ class ProvenanceAnalysis { CachedResultsTy CachedResults; - DenseMap<const Value *, WeakTrackingVH> UnderlyingObjCPtrCache; + DenseMap<const Value *, std::pair<WeakVH, WeakTrackingVH>> + UnderlyingObjCPtrCache; bool relatedCheck(const Value *A, const Value *B); bool relatedSelect(const SelectInst *A, const Value *B); diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp index 6fdfe787d438..fe637ee066a4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp @@ -58,11 +58,11 @@ bool PAEval::runOnFunction(Function &F) { for (auto &Arg : F.args()) insertIfNamed(Values, &Arg); - for (auto I = inst_begin(F), E = inst_end(F); I != E; ++I) { - insertIfNamed(Values, &*I); + for (Instruction &I : instructions(F)) { + insertIfNamed(Values, &I); - for (auto &Op : I->operands()) - insertIfNamed(Values, Op); + for (auto &Op : I.operands()) + insertIfNamed(Values, Op); } ProvenanceAnalysis PA; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp index 6f3fdb88eda5..b693acceb3f6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -538,7 +538,7 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { // that have no side effects and do not influence the control flow or return // value of the function, and may therefore be deleted safely. // NOTE: We reuse the Worklist vector here for memory efficiency. - for (Instruction &I : instructions(F)) { + for (Instruction &I : llvm::reverse(instructions(F))) { // Check if the instruction is alive. 
if (isLive(&I)) continue; @@ -554,9 +554,11 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { // Prepare to delete. Worklist.push_back(&I); salvageDebugInfo(I); - I.dropAllReferences(); } + for (Instruction *&I : Worklist) + I->dropAllReferences(); + for (Instruction *&I : Worklist) { ++NumRemoved; I->eraseFromParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp index c06125788f37..6c2467db79f7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -53,7 +53,7 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { // in the def-use chain needs to be changed. auto *J = dyn_cast<Instruction>(JU); if (J && J->getType()->isIntOrIntVectorTy() && - !DB.getDemandedBits(J).isAllOnesValue()) { + !DB.getDemandedBits(J).isAllOnes()) { Visited.insert(J); WorkList.push_back(J); } @@ -84,7 +84,7 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { // that in the def-use chain needs to be changed. auto *K = dyn_cast<Instruction>(KU); if (K && Visited.insert(K).second && K->getType()->isIntOrIntVectorTy() && - !DB.getDemandedBits(K).isAllOnesValue()) + !DB.getDemandedBits(K).isAllOnes()) WorkList.push_back(K); } } @@ -103,12 +103,9 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { // Remove instructions that are dead, either because they were not reached // during analysis or have no demanded bits. if (DB.isInstructionDead(&I) || - (I.getType()->isIntOrIntVectorTy() && - DB.getDemandedBits(&I).isNullValue() && + (I.getType()->isIntOrIntVectorTy() && DB.getDemandedBits(&I).isZero() && wouldInstructionBeTriviallyDead(&I))) { - salvageDebugInfo(I); Worklist.push_back(&I); - I.dropAllReferences(); Changed = true; continue; } @@ -155,6 +152,11 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { } } + for (Instruction *&I : llvm::reverse(Worklist)) { + salvageDebugInfo(*I); + I->dropAllReferences(); + } + for (Instruction *&I : Worklist) { ++NumRemoved; I->eraseFromParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp index 2eb94b721d96..95de59fa8262 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -467,7 +467,7 @@ static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallBase &CB, BasicBlock *StopAt = CSDTNode ? CSDTNode->getIDom()->getBlock() : nullptr; SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS; - for (auto *Pred : make_range(Preds.rbegin(), Preds.rend())) { + for (auto *Pred : llvm::reverse(Preds)) { ConditionsTy Conditions; // Record condition on edge BB(CS) <- Pred recordCondition(CB, Pred, CB.getParent(), Conditions); @@ -505,8 +505,7 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI, DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy); bool Changed = false; - for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) { - BasicBlock &BB = *BI++; + for (BasicBlock &BB : llvm::make_early_inc_range(F)) { auto II = BB.getFirstNonPHIOrDbg()->getIterator(); auto IE = BB.getTerminator()->getIterator(); // Iterate until we reach the terminator instruction. 
tryToSplitCallSite diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 535f50d4f904..27f54f8026e1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -762,7 +762,7 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx, cast<PointerType>(Ty)->getAddressSpace()); Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", InsertionPt); - Mat = GetElementPtrInst::Create(Int8PtrTy->getElementType(), Base, + Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Offset, "mat_gep", InsertionPt); Mat = new BitCastInst(Mat, Ty, "mat_bitcast", InsertionPt); } else @@ -819,10 +819,9 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, // Aside from constant GEPs, only constant cast expressions are collected. assert(ConstExpr->isCast() && "ConstExpr should be a cast"); - Instruction *ConstExprInst = ConstExpr->getAsInstruction(); + Instruction *ConstExprInst = ConstExpr->getAsInstruction( + findMatInsertPt(ConstUser.Inst, ConstUser.OpndIdx)); ConstExprInst->setOperand(0, Mat); - ConstExprInst->insertBefore(findMatInsertPt(ConstUser.Inst, - ConstUser.OpndIdx)); // Use the same debug location as the instruction we are about to update. ConstExprInst->setDebugLoc(ConstUser.Inst->getDebugLoc()); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index efd1c025d0cd..7f2d5d7d9987 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstraintSystem.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -268,6 +269,31 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { continue; WorkList.emplace_back(DT.getNode(&BB)); + // True as long as long as the current instruction is guaranteed to execute. + bool GuaranteedToExecute = true; + // Scan BB for assume calls. + // TODO: also use this scan to queue conditions to simplify, so we can + // interleave facts from assumes and conditions to simplify in a single + // basic block. And to skip another traversal of each basic block when + // simplifying. + for (Instruction &I : BB) { + Value *Cond; + // For now, just handle assumes with a single compare as condition. + if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) && + isa<CmpInst>(Cond)) { + if (GuaranteedToExecute) { + // The assume is guaranteed to execute when BB is entered, hence Cond + // holds on entry to BB. + WorkList.emplace_back(DT.getNode(&BB), cast<CmpInst>(Cond), false); + } else { + // Otherwise the condition only holds in the successors. 
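The assume-scanning code above (and the successor loop just below) keys everything off a single flag: conditions from llvm.assume calls that are guaranteed to execute once the block is entered hold for the whole block, while assumes that sit behind an instruction which might not transfer control (a call that may throw, for example) only hold in the block's successors. A small standalone model of that control flow, with made-up Inst fields rather than the real IR classes:

#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for the instructions of one basic block.
struct Inst {
  bool IsAssume;              // models a call to llvm.assume(Cond)
  bool TransfersToSuccessor;  // false for calls that may throw or not return
  std::string Cond;           // the assumed condition, when IsAssume is true
};

int main() {
  std::vector<Inst> BB = {
      {true, true, "x > 0"},   // assume seen before anything that can throw
      {false, false, ""},      // e.g. a call that may not return
      {true, true, "y < 10"},  // assume that is no longer guaranteed to run
  };

  bool GuaranteedToExecute = true;
  for (const Inst &I : BB) {
    if (I.IsAssume) {
      if (GuaranteedToExecute)
        std::cout << I.Cond << " holds on entry to the block\n";
      else
        std::cout << I.Cond << " holds only in the block's successors\n";
    }
    // Once some instruction may fail to reach the next one, later assumes
    // are no longer guaranteed to execute when the block is entered.
    GuaranteedToExecute &= I.TransfersToSuccessor;
  }
}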
+ for (BasicBlock *Succ : successors(&BB)) + WorkList.emplace_back(DT.getNode(Succ), cast<CmpInst>(Cond), false); + } + } + GuaranteedToExecute &= isGuaranteedToTransferExecutionToSuccessor(&I); + } + auto *Br = dyn_cast<BranchInst>(BB.getTerminator()); if (!Br || !Br->isConditional()) continue; @@ -395,8 +421,13 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { for (auto &E : reverse(DFSInStack)) dbgs() << " C " << *E.Condition << " " << E.IsNot << "\n"; }); - Cmp->replaceAllUsesWith( - ConstantInt::getTrue(F.getParent()->getContext())); + Cmp->replaceUsesWithIf( + ConstantInt::getTrue(F.getParent()->getContext()), [](Use &U) { + // Conditions in an assume trivially simplify to true. Skip uses + // in assume calls to not destroy the available information. + auto *II = dyn_cast<IntrinsicInst>(U.getUser()); + return !II || II->getIntrinsicID() != Intrinsic::assume; + }); NumCondsRemoved++; Changed = true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 36cbd42a5fdd..ca9567dc7ac8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -67,6 +67,7 @@ STATISTIC(NumUDivURemsNarrowed, STATISTIC(NumAShrs, "Number of ashr converted to lshr"); STATISTIC(NumSRems, "Number of srem converted to urem"); STATISTIC(NumSExt, "Number of sext converted to zext"); +STATISTIC(NumSICmps, "Number of signed icmp preds simplified to unsigned"); STATISTIC(NumAnd, "Number of ands removed"); STATISTIC(NumNW, "Number of no-wrap deductions"); STATISTIC(NumNSW, "Number of no-signed-wrap deductions"); @@ -295,11 +296,34 @@ static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) { return true; } +static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) { + // Only for signed relational comparisons of scalar integers. + if (Cmp->getType()->isVectorTy() || + !Cmp->getOperand(0)->getType()->isIntegerTy()) + return false; + + if (!Cmp->isSigned()) + return false; + + ICmpInst::Predicate UnsignedPred = + ConstantRange::getEquivalentPredWithFlippedSignedness( + Cmp->getPredicate(), LVI->getConstantRange(Cmp->getOperand(0), Cmp), + LVI->getConstantRange(Cmp->getOperand(1), Cmp)); + + if (UnsignedPred == ICmpInst::Predicate::BAD_ICMP_PREDICATE) + return false; + + ++NumSICmps; + Cmp->setPredicate(UnsignedPred); + + return true; +} + /// See if LazyValueInfo's ability to exploit edge conditions or range /// information is sufficient to prove this comparison. Even for local /// conditions, this can sometimes prove conditions instcombine can't by /// exploiting range information. -static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) { +static bool constantFoldCmp(CmpInst *Cmp, LazyValueInfo *LVI) { Value *Op0 = Cmp->getOperand(0); auto *C = dyn_cast<Constant>(Cmp->getOperand(1)); if (!C) @@ -318,6 +342,17 @@ static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) { return true; } +static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) { + if (constantFoldCmp(Cmp, LVI)) + return true; + + if (auto *ICmp = dyn_cast<ICmpInst>(Cmp)) + if (processICmp(ICmp, LVI)) + return true; + + return false; +} + /// Simplify a switch instruction by removing cases which can never fire. If the /// uselessness of a case could be determined locally then constant propagation /// would already have figured it out. 
Instead, walk the predecessors and @@ -341,7 +376,13 @@ static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI, // ConstantFoldTerminator() as the underlying SwitchInst can be changed. SwitchInstProfUpdateWrapper SI(*I); - for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { + APInt Low = + APInt::getSignedMaxValue(Cond->getType()->getScalarSizeInBits()); + APInt High = + APInt::getSignedMinValue(Cond->getType()->getScalarSizeInBits()); + + SwitchInst::CaseIt CI = SI->case_begin(); + for (auto CE = SI->case_end(); CI != CE;) { ConstantInt *Case = CI->getCaseValue(); LazyValueInfo::Tristate State = LVI->getPredicateAt(CmpInst::ICMP_EQ, Cond, Case, I, @@ -374,9 +415,28 @@ static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI, break; } + // Get Lower/Upper bound from switch cases. + Low = APIntOps::smin(Case->getValue(), Low); + High = APIntOps::smax(Case->getValue(), High); + // Increment the case iterator since we didn't delete it. ++CI; } + + // Try to simplify default case as unreachable + if (CI == SI->case_end() && SI->getNumCases() != 0 && + !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg())) { + const ConstantRange SIRange = + LVI->getConstantRange(SI->getCondition(), SI); + + // If the numbered switch cases cover the entire range of the condition, + // then the default case is not reachable. + if (SIRange.getSignedMin() == Low && SIRange.getSignedMax() == High && + SI->getNumCases() == High - Low + 1) { + createUnreachableSwitchDefault(SI, &DTU); + Changed = true; + } + } } if (Changed) @@ -690,7 +750,7 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { // sdiv/srem is UB if divisor is -1 and divident is INT_MIN, so unless we can // prove that such a combination is impossible, we need to bump the bitwidth. - if (CRs[1]->contains(APInt::getAllOnesValue(OrigWidth)) && + if (CRs[1]->contains(APInt::getAllOnes(OrigWidth)) && CRs[0]->contains( APInt::getSignedMinValue(MinSignedBits).sextOrSelf(OrigWidth))) ++MinSignedBits; @@ -1023,49 +1083,48 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, // blocks. 
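The processSwitch change earlier in this hunk collects the signed minimum and maximum of the case values while it prunes unreachable cases, and afterwards marks the default destination unreachable when LazyValueInfo reports that the condition's range is exactly covered by a contiguous run of cases. A standalone sketch of just that final test, using plain 64-bit integers in place of APInt and ConstantRange (the function name and the simplifications are illustrative; the real code also skips switches whose default is already unreachable):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// True if every value the condition can take (CondMin..CondMax, known from
// value analysis) is handled by some case: the cases span exactly that range
// and are contiguous, i.e. NumCases == High - Low + 1.
bool defaultIsDead(const std::vector<int64_t> &Cases, int64_t CondMin,
                   int64_t CondMax) {
  if (Cases.empty())
    return false;
  int64_t Low = *std::min_element(Cases.begin(), Cases.end());
  int64_t High = *std::max_element(Cases.begin(), Cases.end());
  return CondMin == Low && CondMax == High &&
         (int64_t)Cases.size() == High - Low + 1;
}

int main() {
  // Condition known to lie in [0, 3]; cases 0..3 all present.
  std::cout << defaultIsDead({0, 1, 2, 3}, 0, 3) << '\n'; // 1
  // A hole at 2 means the default edge can still be taken.
  std::cout << defaultIsDead({0, 1, 3}, 0, 3) << '\n';    // 0
}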
for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { bool BBChanged = false; - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { - Instruction *II = &*BI++; - switch (II->getOpcode()) { + for (Instruction &II : llvm::make_early_inc_range(*BB)) { + switch (II.getOpcode()) { case Instruction::Select: - BBChanged |= processSelect(cast<SelectInst>(II), LVI); + BBChanged |= processSelect(cast<SelectInst>(&II), LVI); break; case Instruction::PHI: - BBChanged |= processPHI(cast<PHINode>(II), LVI, DT, SQ); + BBChanged |= processPHI(cast<PHINode>(&II), LVI, DT, SQ); break; case Instruction::ICmp: case Instruction::FCmp: - BBChanged |= processCmp(cast<CmpInst>(II), LVI); + BBChanged |= processCmp(cast<CmpInst>(&II), LVI); break; case Instruction::Load: case Instruction::Store: - BBChanged |= processMemAccess(II, LVI); + BBChanged |= processMemAccess(&II, LVI); break; case Instruction::Call: case Instruction::Invoke: - BBChanged |= processCallSite(cast<CallBase>(*II), LVI); + BBChanged |= processCallSite(cast<CallBase>(II), LVI); break; case Instruction::SRem: case Instruction::SDiv: - BBChanged |= processSDivOrSRem(cast<BinaryOperator>(II), LVI); + BBChanged |= processSDivOrSRem(cast<BinaryOperator>(&II), LVI); break; case Instruction::UDiv: case Instruction::URem: - BBChanged |= processUDivOrURem(cast<BinaryOperator>(II), LVI); + BBChanged |= processUDivOrURem(cast<BinaryOperator>(&II), LVI); break; case Instruction::AShr: - BBChanged |= processAShr(cast<BinaryOperator>(II), LVI); + BBChanged |= processAShr(cast<BinaryOperator>(&II), LVI); break; case Instruction::SExt: - BBChanged |= processSExt(cast<SExtInst>(II), LVI); + BBChanged |= processSExt(cast<SExtInst>(&II), LVI); break; case Instruction::Add: case Instruction::Sub: case Instruction::Mul: case Instruction::Shl: - BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI); + BBChanged |= processBinOp(cast<BinaryOperator>(&II), LVI); break; case Instruction::And: - BBChanged |= processAnd(cast<BinaryOperator>(II), LVI); + BBChanged |= processAnd(cast<BinaryOperator>(&II), LVI); break; } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 90679bcac4b7..8c4523206070 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -1,9 +1,8 @@ //===- DFAJumpThreading.cpp - Threads a switch statement inside a loop ----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -84,8 +83,6 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <deque> -#include <unordered_map> -#include <unordered_set> using namespace llvm; @@ -147,8 +144,7 @@ private: Stack.push_back(SIToUnfold); while (!Stack.empty()) { - SelectInstToUnfold SIToUnfold = Stack.back(); - Stack.pop_back(); + SelectInstToUnfold SIToUnfold = Stack.pop_back_val(); std::vector<SelectInstToUnfold> NewSIsToUnfold; std::vector<BasicBlock *> NewBBs; @@ -174,6 +170,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } @@ -350,7 +347,7 @@ struct ClonedBlock { typedef std::deque<BasicBlock *> PathType; typedef std::vector<PathType> PathsType; -typedef std::set<const BasicBlock *> VisitedBlocks; +typedef SmallPtrSet<const BasicBlock *, 8> VisitedBlocks; typedef std::vector<ClonedBlock> CloneList; // This data structure keeps track of all blocks that have been cloned. If two @@ -493,7 +490,7 @@ private: } bool isPredictableValue(Value *InpVal, SmallSet<Value *, 16> &SeenValues) { - if (SeenValues.find(InpVal) != SeenValues.end()) + if (SeenValues.contains(InpVal)) return true; if (isa<ConstantInt>(InpVal)) @@ -508,7 +505,7 @@ private: void addInstToQueue(Value *Val, std::deque<Instruction *> &Q, SmallSet<Value *, 16> &SeenValues) { - if (SeenValues.find(Val) != SeenValues.end()) + if (SeenValues.contains(Val)) return; if (Instruction *I = dyn_cast<Instruction>(Val)) Q.push_back(I); @@ -533,7 +530,7 @@ private: return false; if (isa<PHINode>(SIUse) && - SIBB->getSingleSuccessor() != dyn_cast<Instruction>(SIUse)->getParent()) + SIBB->getSingleSuccessor() != cast<Instruction>(SIUse)->getParent()) return false; // If select will not be sunk during unfolding, and it is in the same basic @@ -621,13 +618,9 @@ private: // Some blocks have multiple edges to the same successor, and this set // is used to prevent a duplicate path from being generated SmallSet<BasicBlock *, 4> Successors; - - for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) { - BasicBlock *Succ = *SI; - - if (Successors.find(Succ) != Successors.end()) + for (BasicBlock *Succ : successors(BB)) { + if (!Successors.insert(Succ).second) continue; - Successors.insert(Succ); // Found a cycle through the SwitchBlock if (Succ == SwitchBlock) { @@ -636,7 +629,7 @@ private: } // We have encountered a cycle, do not get caught in it - if (Visited.find(Succ) != Visited.end()) + if (Visited.contains(Succ)) continue; PathsType SuccPaths = paths(Succ, Visited, PathDepth + 1); @@ -668,15 +661,14 @@ private: SmallSet<Value *, 16> SeenValues; while (!Stack.empty()) { - PHINode *CurPhi = Stack.back(); - Stack.pop_back(); + PHINode *CurPhi = Stack.pop_back_val(); Res[CurPhi->getParent()] = CurPhi; SeenValues.insert(CurPhi); for (Value *Incoming : CurPhi->incoming_values()) { if (Incoming == FirstDef || isa<ConstantInt>(Incoming) || - SeenValues.find(Incoming) != SeenValues.end()) { + SeenValues.contains(Incoming)) { continue; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index d22b3f409585..a8ec8bb97970 100644 
--- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -13,10 +13,10 @@ // in between both MemoryDefs. A bit more concretely: // // For all MemoryDefs StartDef: -// 1. Get the next dominating clobbering MemoryDef (EarlierAccess) by walking +// 1. Get the next dominating clobbering MemoryDef (MaybeDeadAccess) by walking // upwards. -// 2. Check that there are no reads between EarlierAccess and the StartDef by -// checking all uses starting at EarlierAccess and walking until we see +// 2. Check that there are no reads between MaybeDeadAccess and the StartDef by +// checking all uses starting at MaybeDeadAccess and walking until we see // StartDef. // 3. For each found CurrentDef, check that: // 1. There are no barrier instructions between CurrentDef and StartDef (like @@ -56,6 +56,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" @@ -78,6 +79,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> @@ -122,7 +124,7 @@ EnablePartialStoreMerging("enable-dse-partial-store-merging", static cl::opt<unsigned> MemorySSAScanLimit("dse-memoryssa-scanlimit", cl::init(150), cl::Hidden, cl::desc("The number of memory instructions to scan for " - "dead store elimination (default = 100)")); + "dead store elimination (default = 150)")); static cl::opt<unsigned> MemorySSAUpwardsStepLimit( "dse-memoryssa-walklimit", cl::init(90), cl::Hidden, cl::desc("The maximum number of steps while walking upwards to find " @@ -203,39 +205,6 @@ static bool hasAnalyzableMemoryWrite(Instruction *I, return false; } -/// Return a Location stored to by the specified instruction. If isRemovable -/// returns true, this function and getLocForRead completely describe the memory -/// operations for this instruction. -static MemoryLocation getLocForWrite(Instruction *Inst, - const TargetLibraryInfo &TLI) { - if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) - return MemoryLocation::get(SI); - - // memcpy/memmove/memset. - if (auto *MI = dyn_cast<AnyMemIntrinsic>(Inst)) - return MemoryLocation::getForDest(MI); - - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { - switch (II->getIntrinsicID()) { - default: - return MemoryLocation(); // Unhandled intrinsic. - case Intrinsic::init_trampoline: - return MemoryLocation::getAfter(II->getArgOperand(0)); - case Intrinsic::masked_store: - return MemoryLocation::getForArgument(II, 1, TLI); - case Intrinsic::lifetime_end: { - uint64_t Len = cast<ConstantInt>(II->getArgOperand(0))->getZExtValue(); - return MemoryLocation(II->getArgOperand(1), Len); - } - } - } - if (auto *CB = dyn_cast<CallBase>(Inst)) - // All the supported TLI functions so far happen to have dest as their - // first argument. - return MemoryLocation::getAfter(CB->getArgOperand(0)); - return MemoryLocation(); -} - /// If the value of this instruction and the memory it writes to is unused, may /// we delete this instruction? static bool isRemovable(Instruction *I) { @@ -333,147 +302,146 @@ enum OverwriteResult { } // end anonymous namespace /// Check if two instruction are masked stores that completely -/// overwrite one another. 
More specifically, \p Later has to -/// overwrite \p Earlier. -static OverwriteResult isMaskedStoreOverwrite(const Instruction *Later, - const Instruction *Earlier, +/// overwrite one another. More specifically, \p KillingI has to +/// overwrite \p DeadI. +static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI, + const Instruction *DeadI, BatchAAResults &AA) { - const auto *IIL = dyn_cast<IntrinsicInst>(Later); - const auto *IIE = dyn_cast<IntrinsicInst>(Earlier); - if (IIL == nullptr || IIE == nullptr) + const auto *KillingII = dyn_cast<IntrinsicInst>(KillingI); + const auto *DeadII = dyn_cast<IntrinsicInst>(DeadI); + if (KillingII == nullptr || DeadII == nullptr) return OW_Unknown; - if (IIL->getIntrinsicID() != Intrinsic::masked_store || - IIE->getIntrinsicID() != Intrinsic::masked_store) + if (KillingII->getIntrinsicID() != Intrinsic::masked_store || + DeadII->getIntrinsicID() != Intrinsic::masked_store) return OW_Unknown; // Pointers. - Value *LP = IIL->getArgOperand(1)->stripPointerCasts(); - Value *EP = IIE->getArgOperand(1)->stripPointerCasts(); - if (LP != EP && !AA.isMustAlias(LP, EP)) + Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts(); + Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts(); + if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr)) return OW_Unknown; // Masks. - // TODO: check that Later's mask is a superset of the Earlier's mask. - if (IIL->getArgOperand(3) != IIE->getArgOperand(3)) + // TODO: check that KillingII's mask is a superset of the DeadII's mask. + if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3)) return OW_Unknown; return OW_Complete; } -/// Return 'OW_Complete' if a store to the 'Later' location completely -/// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the -/// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the -/// beginning of the 'Earlier' location is overwritten by 'Later'. -/// 'OW_PartialEarlierWithFullLater' means that an earlier (big) store was -/// overwritten by a latter (smaller) store which doesn't write outside the big +/// Return 'OW_Complete' if a store to the 'KillingLoc' location completely +/// overwrites a store to the 'DeadLoc' location, 'OW_End' if the end of the +/// 'DeadLoc' location is completely overwritten by 'KillingLoc', 'OW_Begin' +/// if the beginning of the 'DeadLoc' location is overwritten by 'KillingLoc'. +/// 'OW_PartialEarlierWithFullLater' means that a dead (big) store was +/// overwritten by a killing (smaller) store which doesn't write outside the big /// store's memory locations. Returns 'OW_Unknown' if nothing can be determined. -/// NOTE: This function must only be called if both \p Later and \p Earlier -/// write to the same underlying object with valid \p EarlierOff and \p -/// LaterOff. -static OverwriteResult isPartialOverwrite(const MemoryLocation &Later, - const MemoryLocation &Earlier, - int64_t EarlierOff, int64_t LaterOff, - Instruction *DepWrite, +/// NOTE: This function must only be called if both \p KillingLoc and \p +/// DeadLoc belong to the same underlying object with valid \p KillingOff and +/// \p DeadOff. 
+static OverwriteResult isPartialOverwrite(const MemoryLocation &KillingLoc, + const MemoryLocation &DeadLoc, + int64_t KillingOff, int64_t DeadOff, + Instruction *DeadI, InstOverlapIntervalsTy &IOL) { - const uint64_t LaterSize = Later.Size.getValue(); - const uint64_t EarlierSize = Earlier.Size.getValue(); + const uint64_t KillingSize = KillingLoc.Size.getValue(); + const uint64_t DeadSize = DeadLoc.Size.getValue(); // We may now overlap, although the overlap is not complete. There might also // be other incomplete overlaps, and together, they might cover the complete - // earlier write. + // dead store. // Note: The correctness of this logic depends on the fact that this function // is not even called providing DepWrite when there are any intervening reads. if (EnablePartialOverwriteTracking && - LaterOff < int64_t(EarlierOff + EarlierSize) && - int64_t(LaterOff + LaterSize) >= EarlierOff) { + KillingOff < int64_t(DeadOff + DeadSize) && + int64_t(KillingOff + KillingSize) >= DeadOff) { // Insert our part of the overlap into the map. - auto &IM = IOL[DepWrite]; - LLVM_DEBUG(dbgs() << "DSE: Partial overwrite: Earlier [" << EarlierOff - << ", " << int64_t(EarlierOff + EarlierSize) - << ") Later [" << LaterOff << ", " - << int64_t(LaterOff + LaterSize) << ")\n"); + auto &IM = IOL[DeadI]; + LLVM_DEBUG(dbgs() << "DSE: Partial overwrite: DeadLoc [" << DeadOff << ", " + << int64_t(DeadOff + DeadSize) << ") KillingLoc [" + << KillingOff << ", " << int64_t(KillingOff + KillingSize) + << ")\n"); // Make sure that we only insert non-overlapping intervals and combine // adjacent intervals. The intervals are stored in the map with the ending // offset as the key (in the half-open sense) and the starting offset as // the value. - int64_t LaterIntStart = LaterOff, LaterIntEnd = LaterOff + LaterSize; + int64_t KillingIntStart = KillingOff; + int64_t KillingIntEnd = KillingOff + KillingSize; - // Find any intervals ending at, or after, LaterIntStart which start - // before LaterIntEnd. - auto ILI = IM.lower_bound(LaterIntStart); - if (ILI != IM.end() && ILI->second <= LaterIntEnd) { + // Find any intervals ending at, or after, KillingIntStart which start + // before KillingIntEnd. + auto ILI = IM.lower_bound(KillingIntStart); + if (ILI != IM.end() && ILI->second <= KillingIntEnd) { // This existing interval is overlapped with the current store somewhere - // in [LaterIntStart, LaterIntEnd]. Merge them by erasing the existing + // in [KillingIntStart, KillingIntEnd]. Merge them by erasing the existing // intervals and adjusting our start and end. - LaterIntStart = std::min(LaterIntStart, ILI->second); - LaterIntEnd = std::max(LaterIntEnd, ILI->first); + KillingIntStart = std::min(KillingIntStart, ILI->second); + KillingIntEnd = std::max(KillingIntEnd, ILI->first); ILI = IM.erase(ILI); // Continue erasing and adjusting our end in case other previous // intervals are also overlapped with the current store. 
// - // |--- ealier 1 ---| |--- ealier 2 ---| - // |------- later---------| + // |--- dead 1 ---| |--- dead 2 ---| + // |------- killing---------| // - while (ILI != IM.end() && ILI->second <= LaterIntEnd) { - assert(ILI->second > LaterIntStart && "Unexpected interval"); - LaterIntEnd = std::max(LaterIntEnd, ILI->first); + while (ILI != IM.end() && ILI->second <= KillingIntEnd) { + assert(ILI->second > KillingIntStart && "Unexpected interval"); + KillingIntEnd = std::max(KillingIntEnd, ILI->first); ILI = IM.erase(ILI); } } - IM[LaterIntEnd] = LaterIntStart; + IM[KillingIntEnd] = KillingIntStart; ILI = IM.begin(); - if (ILI->second <= EarlierOff && - ILI->first >= int64_t(EarlierOff + EarlierSize)) { - LLVM_DEBUG(dbgs() << "DSE: Full overwrite from partials: Earlier [" - << EarlierOff << ", " - << int64_t(EarlierOff + EarlierSize) - << ") Composite Later [" << ILI->second << ", " + if (ILI->second <= DeadOff && ILI->first >= int64_t(DeadOff + DeadSize)) { + LLVM_DEBUG(dbgs() << "DSE: Full overwrite from partials: DeadLoc [" + << DeadOff << ", " << int64_t(DeadOff + DeadSize) + << ") Composite KillingLoc [" << ILI->second << ", " << ILI->first << ")\n"); ++NumCompletePartials; return OW_Complete; } } - // Check for an earlier store which writes to all the memory locations that - // the later store writes to. - if (EnablePartialStoreMerging && LaterOff >= EarlierOff && - int64_t(EarlierOff + EarlierSize) > LaterOff && - uint64_t(LaterOff - EarlierOff) + LaterSize <= EarlierSize) { - LLVM_DEBUG(dbgs() << "DSE: Partial overwrite an earlier load [" - << EarlierOff << ", " - << int64_t(EarlierOff + EarlierSize) - << ") by a later store [" << LaterOff << ", " - << int64_t(LaterOff + LaterSize) << ")\n"); + // Check for a dead store which writes to all the memory locations that + // the killing store writes to. + if (EnablePartialStoreMerging && KillingOff >= DeadOff && + int64_t(DeadOff + DeadSize) > KillingOff && + uint64_t(KillingOff - DeadOff) + KillingSize <= DeadSize) { + LLVM_DEBUG(dbgs() << "DSE: Partial overwrite a dead load [" << DeadOff + << ", " << int64_t(DeadOff + DeadSize) + << ") by a killing store [" << KillingOff << ", " + << int64_t(KillingOff + KillingSize) << ")\n"); // TODO: Maybe come up with a better name? return OW_PartialEarlierWithFullLater; } - // Another interesting case is if the later store overwrites the end of the - // earlier store. + // Another interesting case is if the killing store overwrites the end of the + // dead store. // - // |--earlier--| - // |-- later --| + // |--dead--| + // |-- killing --| // - // In this case we may want to trim the size of earlier to avoid generating - // writes to addresses which will definitely be overwritten later + // In this case we may want to trim the size of dead store to avoid + // generating stores to addresses which will definitely be overwritten killing + // store. if (!EnablePartialOverwriteTracking && - (LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + EarlierSize) && - int64_t(LaterOff + LaterSize) >= int64_t(EarlierOff + EarlierSize))) + (KillingOff > DeadOff && KillingOff < int64_t(DeadOff + DeadSize) && + int64_t(KillingOff + KillingSize) >= int64_t(DeadOff + DeadSize))) return OW_End; - // Finally, we also need to check if the later store overwrites the beginning - // of the earlier store. + // Finally, we also need to check if the killing store overwrites the + // beginning of the dead store. 
// - // |--earlier--| - // |-- later --| + // |--dead--| + // |-- killing --| // // In this case we may want to move the destination address and trim the size - // of earlier to avoid generating writes to addresses which will definitely - // be overwritten later. + // of dead store to avoid generating stores to addresses which will definitely + // be overwritten killing store. if (!EnablePartialOverwriteTracking && - (LaterOff <= EarlierOff && int64_t(LaterOff + LaterSize) > EarlierOff)) { - assert(int64_t(LaterOff + LaterSize) < int64_t(EarlierOff + EarlierSize) && + (KillingOff <= DeadOff && int64_t(KillingOff + KillingSize) > DeadOff)) { + assert(int64_t(KillingOff + KillingSize) < int64_t(DeadOff + DeadSize) && "Expect to be handled as OW_Complete"); return OW_Begin; } @@ -505,7 +473,12 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI, BasicBlock::iterator SecondBBI(SecondI); BasicBlock *FirstBB = FirstI->getParent(); BasicBlock *SecondBB = SecondI->getParent(); - MemoryLocation MemLoc = MemoryLocation::get(SecondI); + MemoryLocation MemLoc; + if (auto *MemSet = dyn_cast<MemSetInst>(SecondI)) + MemLoc = MemoryLocation::getForDest(MemSet); + else + MemLoc = MemoryLocation::get(SecondI); + auto *MemLocPtr = const_cast<Value *>(MemLoc.Ptr); // Start checking the SecondBB. @@ -568,11 +541,11 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI, return true; } -static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierStart, - uint64_t &EarlierSize, int64_t LaterStart, - uint64_t LaterSize, bool IsOverwriteEnd) { - auto *EarlierIntrinsic = cast<AnyMemIntrinsic>(EarlierWrite); - Align PrefAlign = EarlierIntrinsic->getDestAlign().valueOrOne(); +static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, + uint64_t &DeadSize, int64_t KillingStart, + uint64_t KillingSize, bool IsOverwriteEnd) { + auto *DeadIntrinsic = cast<AnyMemIntrinsic>(DeadI); + Align PrefAlign = DeadIntrinsic->getDestAlign().valueOrOne(); // We assume that memet/memcpy operates in chunks of the "largest" native // type size and aligned on the same value. That means optimal start and size @@ -593,19 +566,19 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierStart, // Compute start and size of the region to remove. Make sure 'PrefAlign' is // maintained on the remaining store. if (IsOverwriteEnd) { - // Calculate required adjustment for 'LaterStart'in order to keep remaining - // store size aligned on 'PerfAlign'. + // Calculate required adjustment for 'KillingStart' in order to keep + // remaining store size aligned on 'PerfAlign'. 
uint64_t Off = - offsetToAlignment(uint64_t(LaterStart - EarlierStart), PrefAlign); - ToRemoveStart = LaterStart + Off; - if (EarlierSize <= uint64_t(ToRemoveStart - EarlierStart)) + offsetToAlignment(uint64_t(KillingStart - DeadStart), PrefAlign); + ToRemoveStart = KillingStart + Off; + if (DeadSize <= uint64_t(ToRemoveStart - DeadStart)) return false; - ToRemoveSize = EarlierSize - uint64_t(ToRemoveStart - EarlierStart); + ToRemoveSize = DeadSize - uint64_t(ToRemoveStart - DeadStart); } else { - ToRemoveStart = EarlierStart; - assert(LaterSize >= uint64_t(EarlierStart - LaterStart) && + ToRemoveStart = DeadStart; + assert(KillingSize >= uint64_t(DeadStart - KillingStart) && "Not overlapping accesses?"); - ToRemoveSize = LaterSize - uint64_t(EarlierStart - LaterStart); + ToRemoveSize = KillingSize - uint64_t(DeadStart - KillingStart); // Calculate required adjustment for 'ToRemoveSize'in order to keep // start of the remaining store aligned on 'PerfAlign'. uint64_t Off = offsetToAlignment(ToRemoveSize, PrefAlign); @@ -619,10 +592,10 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierStart, } assert(ToRemoveSize > 0 && "Shouldn't reach here if nothing to remove"); - assert(EarlierSize > ToRemoveSize && "Can't remove more than original size"); + assert(DeadSize > ToRemoveSize && "Can't remove more than original size"); - uint64_t NewSize = EarlierSize - ToRemoveSize; - if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(EarlierWrite)) { + uint64_t NewSize = DeadSize - ToRemoveSize; + if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(DeadI)) { // When shortening an atomic memory intrinsic, the newly shortened // length must remain an integer multiple of the element size. const uint32_t ElementSize = AMI->getElementSizeInBytes(); @@ -631,65 +604,62 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierStart, } LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n OW " - << (IsOverwriteEnd ? "END" : "BEGIN") << ": " - << *EarlierWrite << "\n KILLER [" << ToRemoveStart << ", " + << (IsOverwriteEnd ? 
"END" : "BEGIN") << ": " << *DeadI + << "\n KILLER [" << ToRemoveStart << ", " << int64_t(ToRemoveStart + ToRemoveSize) << ")\n"); - Value *EarlierWriteLength = EarlierIntrinsic->getLength(); - Value *TrimmedLength = - ConstantInt::get(EarlierWriteLength->getType(), NewSize); - EarlierIntrinsic->setLength(TrimmedLength); - EarlierIntrinsic->setDestAlignment(PrefAlign); + Value *DeadWriteLength = DeadIntrinsic->getLength(); + Value *TrimmedLength = ConstantInt::get(DeadWriteLength->getType(), NewSize); + DeadIntrinsic->setLength(TrimmedLength); + DeadIntrinsic->setDestAlignment(PrefAlign); if (!IsOverwriteEnd) { - Value *OrigDest = EarlierIntrinsic->getRawDest(); + Value *OrigDest = DeadIntrinsic->getRawDest(); Type *Int8PtrTy = - Type::getInt8PtrTy(EarlierIntrinsic->getContext(), + Type::getInt8PtrTy(DeadIntrinsic->getContext(), OrigDest->getType()->getPointerAddressSpace()); Value *Dest = OrigDest; if (OrigDest->getType() != Int8PtrTy) - Dest = CastInst::CreatePointerCast(OrigDest, Int8PtrTy, "", EarlierWrite); + Dest = CastInst::CreatePointerCast(OrigDest, Int8PtrTy, "", DeadI); Value *Indices[1] = { - ConstantInt::get(EarlierWriteLength->getType(), ToRemoveSize)}; + ConstantInt::get(DeadWriteLength->getType(), ToRemoveSize)}; Instruction *NewDestGEP = GetElementPtrInst::CreateInBounds( - Type::getInt8Ty(EarlierIntrinsic->getContext()), - Dest, Indices, "", EarlierWrite); - NewDestGEP->setDebugLoc(EarlierIntrinsic->getDebugLoc()); + Type::getInt8Ty(DeadIntrinsic->getContext()), Dest, Indices, "", DeadI); + NewDestGEP->setDebugLoc(DeadIntrinsic->getDebugLoc()); if (NewDestGEP->getType() != OrigDest->getType()) NewDestGEP = CastInst::CreatePointerCast(NewDestGEP, OrigDest->getType(), - "", EarlierWrite); - EarlierIntrinsic->setDest(NewDestGEP); + "", DeadI); + DeadIntrinsic->setDest(NewDestGEP); } - // Finally update start and size of earlier access. + // Finally update start and size of dead access. if (!IsOverwriteEnd) - EarlierStart += ToRemoveSize; - EarlierSize = NewSize; + DeadStart += ToRemoveSize; + DeadSize = NewSize; return true; } -static bool tryToShortenEnd(Instruction *EarlierWrite, - OverlapIntervalsTy &IntervalMap, - int64_t &EarlierStart, uint64_t &EarlierSize) { - if (IntervalMap.empty() || !isShortenableAtTheEnd(EarlierWrite)) +static bool tryToShortenEnd(Instruction *DeadI, OverlapIntervalsTy &IntervalMap, + int64_t &DeadStart, uint64_t &DeadSize) { + if (IntervalMap.empty() || !isShortenableAtTheEnd(DeadI)) return false; OverlapIntervalsTy::iterator OII = --IntervalMap.end(); - int64_t LaterStart = OII->second; - uint64_t LaterSize = OII->first - LaterStart; + int64_t KillingStart = OII->second; + uint64_t KillingSize = OII->first - KillingStart; - assert(OII->first - LaterStart >= 0 && "Size expected to be positive"); + assert(OII->first - KillingStart >= 0 && "Size expected to be positive"); - if (LaterStart > EarlierStart && - // Note: "LaterStart - EarlierStart" is known to be positive due to + if (KillingStart > DeadStart && + // Note: "KillingStart - KillingStart" is known to be positive due to // preceding check. - (uint64_t)(LaterStart - EarlierStart) < EarlierSize && - // Note: "EarlierSize - (uint64_t)(LaterStart - EarlierStart)" is known to + (uint64_t)(KillingStart - DeadStart) < DeadSize && + // Note: "DeadSize - (uint64_t)(KillingStart - DeadStart)" is known to // be non negative due to preceding checks. 
- LaterSize >= EarlierSize - (uint64_t)(LaterStart - EarlierStart)) { - if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart, - LaterSize, true)) { + KillingSize >= DeadSize - (uint64_t)(KillingStart - DeadStart)) { + if (tryToShorten(DeadI, DeadStart, DeadSize, KillingStart, KillingSize, + true)) { IntervalMap.erase(OII); return true; } @@ -697,28 +667,28 @@ static bool tryToShortenEnd(Instruction *EarlierWrite, return false; } -static bool tryToShortenBegin(Instruction *EarlierWrite, +static bool tryToShortenBegin(Instruction *DeadI, OverlapIntervalsTy &IntervalMap, - int64_t &EarlierStart, uint64_t &EarlierSize) { - if (IntervalMap.empty() || !isShortenableAtTheBeginning(EarlierWrite)) + int64_t &DeadStart, uint64_t &DeadSize) { + if (IntervalMap.empty() || !isShortenableAtTheBeginning(DeadI)) return false; OverlapIntervalsTy::iterator OII = IntervalMap.begin(); - int64_t LaterStart = OII->second; - uint64_t LaterSize = OII->first - LaterStart; + int64_t KillingStart = OII->second; + uint64_t KillingSize = OII->first - KillingStart; - assert(OII->first - LaterStart >= 0 && "Size expected to be positive"); + assert(OII->first - KillingStart >= 0 && "Size expected to be positive"); - if (LaterStart <= EarlierStart && - // Note: "EarlierStart - LaterStart" is known to be non negative due to + if (KillingStart <= DeadStart && + // Note: "DeadStart - KillingStart" is known to be non negative due to // preceding check. - LaterSize > (uint64_t)(EarlierStart - LaterStart)) { - // Note: "LaterSize - (uint64_t)(EarlierStart - LaterStart)" is known to be - // positive due to preceding checks. - assert(LaterSize - (uint64_t)(EarlierStart - LaterStart) < EarlierSize && + KillingSize > (uint64_t)(DeadStart - KillingStart)) { + // Note: "KillingSize - (uint64_t)(DeadStart - DeadStart)" is known to + // be positive due to preceding checks. 
+ assert(KillingSize - (uint64_t)(DeadStart - KillingStart) < DeadSize && "Should have been handled as OW_Complete"); - if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart, - LaterSize, false)) { + if (tryToShorten(DeadI, DeadStart, DeadSize, KillingStart, KillingSize, + false)) { IntervalMap.erase(OII); return true; } @@ -726,71 +696,48 @@ static bool tryToShortenBegin(Instruction *EarlierWrite, return false; } -static bool removePartiallyOverlappedStores(const DataLayout &DL, - InstOverlapIntervalsTy &IOL, - const TargetLibraryInfo &TLI) { - bool Changed = false; - for (auto OI : IOL) { - Instruction *EarlierWrite = OI.first; - MemoryLocation Loc = getLocForWrite(EarlierWrite, TLI); - assert(isRemovable(EarlierWrite) && "Expect only removable instruction"); - - const Value *Ptr = Loc.Ptr->stripPointerCasts(); - int64_t EarlierStart = 0; - uint64_t EarlierSize = Loc.Size.getValue(); - GetPointerBaseWithConstantOffset(Ptr, EarlierStart, DL); - OverlapIntervalsTy &IntervalMap = OI.second; - Changed |= - tryToShortenEnd(EarlierWrite, IntervalMap, EarlierStart, EarlierSize); - if (IntervalMap.empty()) - continue; - Changed |= - tryToShortenBegin(EarlierWrite, IntervalMap, EarlierStart, EarlierSize); - } - return Changed; -} - -static Constant *tryToMergePartialOverlappingStores( - StoreInst *Earlier, StoreInst *Later, int64_t InstWriteOffset, - int64_t DepWriteOffset, const DataLayout &DL, BatchAAResults &AA, - DominatorTree *DT) { - - if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && - DL.typeSizeEqualsStoreSize(Earlier->getValueOperand()->getType()) && - Later && isa<ConstantInt>(Later->getValueOperand()) && - DL.typeSizeEqualsStoreSize(Later->getValueOperand()->getType()) && - memoryIsNotModifiedBetween(Earlier, Later, AA, DL, DT)) { +static Constant * +tryToMergePartialOverlappingStores(StoreInst *KillingI, StoreInst *DeadI, + int64_t KillingOffset, int64_t DeadOffset, + const DataLayout &DL, BatchAAResults &AA, + DominatorTree *DT) { + + if (DeadI && isa<ConstantInt>(DeadI->getValueOperand()) && + DL.typeSizeEqualsStoreSize(DeadI->getValueOperand()->getType()) && + KillingI && isa<ConstantInt>(KillingI->getValueOperand()) && + DL.typeSizeEqualsStoreSize(KillingI->getValueOperand()->getType()) && + memoryIsNotModifiedBetween(DeadI, KillingI, AA, DL, DT)) { // If the store we find is: // a) partially overwritten by the store to 'Loc' - // b) the later store is fully contained in the earlier one and + // b) the killing store is fully contained in the dead one and // c) they both have a constant value // d) none of the two stores need padding - // Merge the two stores, replacing the earlier store's value with a + // Merge the two stores, replacing the dead store's value with a // merge of both values. 
// TODO: Deal with other constant types (vectors, etc), and probably // some mem intrinsics (if needed) - APInt EarlierValue = - cast<ConstantInt>(Earlier->getValueOperand())->getValue(); - APInt LaterValue = cast<ConstantInt>(Later->getValueOperand())->getValue(); - unsigned LaterBits = LaterValue.getBitWidth(); - assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); - LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); + APInt DeadValue = cast<ConstantInt>(DeadI->getValueOperand())->getValue(); + APInt KillingValue = + cast<ConstantInt>(KillingI->getValueOperand())->getValue(); + unsigned KillingBits = KillingValue.getBitWidth(); + assert(DeadValue.getBitWidth() > KillingValue.getBitWidth()); + KillingValue = KillingValue.zext(DeadValue.getBitWidth()); // Offset of the smaller store inside the larger store - unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; - unsigned LShiftAmount = DL.isBigEndian() ? EarlierValue.getBitWidth() - - BitOffsetDiff - LaterBits - : BitOffsetDiff; - APInt Mask = APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, - LShiftAmount + LaterBits); + unsigned BitOffsetDiff = (KillingOffset - DeadOffset) * 8; + unsigned LShiftAmount = + DL.isBigEndian() ? DeadValue.getBitWidth() - BitOffsetDiff - KillingBits + : BitOffsetDiff; + APInt Mask = APInt::getBitsSet(DeadValue.getBitWidth(), LShiftAmount, + LShiftAmount + KillingBits); // Clear the bits we'll be replacing, then OR with the smaller // store, shifted appropriately. - APInt Merged = (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); - LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *Earlier - << "\n Later: " << *Later + APInt Merged = (DeadValue & ~Mask) | (KillingValue << LShiftAmount); + LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n Dead: " << *DeadI + << "\n Killing: " << *KillingI << "\n Merged Value: " << Merged << '\n'); - return ConstantInt::get(Earlier->getValueOperand()->getType(), Merged); + return ConstantInt::get(DeadI->getValueOperand()->getType(), Merged); } return nullptr; } @@ -819,14 +766,17 @@ bool isNoopIntrinsic(Instruction *I) { } // Check if we can ignore \p D for DSE. -bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) { +bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller, + const TargetLibraryInfo &TLI) { Instruction *DI = D->getMemoryInst(); // Calls that only access inaccessible memory cannot read or write any memory // locations we consider for elimination. if (auto *CB = dyn_cast<CallBase>(DI)) - if (CB->onlyAccessesInaccessibleMemory()) + if (CB->onlyAccessesInaccessibleMemory()) { + if (isAllocLikeFn(DI, &TLI)) + return false; return true; - + } // We can eliminate stores to locations not visible to the caller across // throwing instructions. if (DI->mayThrow() && !DefVisibleToCaller) @@ -841,7 +791,7 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) { return true; // Skip intrinsics that do not really read or modify memory. - if (isNoopIntrinsic(D->getMemoryInst())) + if (isNoopIntrinsic(DI)) return true; return false; @@ -850,6 +800,7 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) { struct DSEState { Function &F; AliasAnalysis &AA; + EarliestEscapeInfo EI; /// The single BatchAA instance that is used to cache AA queries. It will /// not be invalidated over the whole run. This is safe, because: @@ -892,30 +843,29 @@ struct DSEState { /// basic block. DenseMap<BasicBlock *, InstOverlapIntervalsTy> IOLs; + // Class contains self-reference, make sure it's not copied/moved. 
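DSEState now owns an EarliestEscapeInfo and hands a pointer to it to its BatchAAResults member, which is why the comment above is followed by deleted copy operations just below: a copied or moved state object would leave the cached pointer aimed at the original. A minimal standalone illustration of the idiom, with toy class names rather than the LLVM ones:

#include <iostream>

struct EscapeInfo {
  int Data = 0;
};

// Caches queries and remembers where its escape information lives.
struct BatchQueries {
  explicit BatchQueries(EscapeInfo *EI) : EI(EI) {}
  EscapeInfo *EI;
};

// BatchQueries points back into the owning object, so a copy would still
// reference the original's EscapeInfo. Deleting the copy operations turns
// that bug into a compile error.
struct State {
  EscapeInfo EI;
  BatchQueries AA{&EI}; // self-reference into our own member

  State() = default;
  State(const State &) = delete;
  State &operator=(const State &) = delete;
};

int main() {
  State S;
  S.EI.Data = 42;
  std::cout << S.AA.EI->Data << '\n'; // 42
  // State T = S;                     // does not compile, by design
}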
+ DSEState(const DSEState &) = delete; + DSEState &operator=(const DSEState &) = delete; + DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, PostDominatorTree &PDT, const TargetLibraryInfo &TLI, const LoopInfo &LI) - : F(F), AA(AA), BatchAA(AA), MSSA(MSSA), DT(DT), PDT(PDT), TLI(TLI), - DL(F.getParent()->getDataLayout()), LI(LI) {} - - static DSEState get(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, - DominatorTree &DT, PostDominatorTree &PDT, - const TargetLibraryInfo &TLI, const LoopInfo &LI) { - DSEState State(F, AA, MSSA, DT, PDT, TLI, LI); + : F(F), AA(AA), EI(DT, LI), BatchAA(AA, &EI), MSSA(MSSA), DT(DT), + PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) { // Collect blocks with throwing instructions not modeled in MemorySSA and // alloc-like objects. unsigned PO = 0; for (BasicBlock *BB : post_order(&F)) { - State.PostOrderNumbers[BB] = PO++; + PostOrderNumbers[BB] = PO++; for (Instruction &I : *BB) { MemoryAccess *MA = MSSA.getMemoryAccess(&I); if (I.mayThrow() && !MA) - State.ThrowingBlocks.insert(I.getParent()); + ThrowingBlocks.insert(I.getParent()); auto *MD = dyn_cast_or_null<MemoryDef>(MA); - if (MD && State.MemDefs.size() < MemorySSADefsPerBlockLimit && - (State.getLocForWriteEx(&I) || State.isMemTerminatorInst(&I))) - State.MemDefs.push_back(MD); + if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit && + (getLocForWriteEx(&I) || isMemTerminatorInst(&I))) + MemDefs.push_back(MD); } } @@ -925,131 +875,134 @@ struct DSEState { if (AI.hasPassPointeeByValueCopyAttr()) { // For byval, the caller doesn't know the address of the allocation. if (AI.hasByValAttr()) - State.InvisibleToCallerBeforeRet.insert({&AI, true}); - State.InvisibleToCallerAfterRet.insert({&AI, true}); + InvisibleToCallerBeforeRet.insert({&AI, true}); + InvisibleToCallerAfterRet.insert({&AI, true}); } // Collect whether there is any irreducible control flow in the function. - State.ContainsIrreducibleLoops = mayContainIrreducibleControl(F, &LI); - - return State; + ContainsIrreducibleLoops = mayContainIrreducibleControl(F, &LI); } - /// Return 'OW_Complete' if a store to the 'Later' location (by \p LaterI - /// instruction) completely overwrites a store to the 'Earlier' location. - /// (by \p EarlierI instruction). - /// Return OW_MaybePartial if \p Later does not completely overwrite - /// \p Earlier, but they both write to the same underlying object. In that - /// case, use isPartialOverwrite to check if \p Later partially overwrites - /// \p Earlier. Returns 'OW_Unknown' if nothing can be determined. - OverwriteResult - isOverwrite(const Instruction *LaterI, const Instruction *EarlierI, - const MemoryLocation &Later, const MemoryLocation &Earlier, - int64_t &EarlierOff, int64_t &LaterOff) { + /// Return 'OW_Complete' if a store to the 'KillingLoc' location (by \p + /// KillingI instruction) completely overwrites a store to the 'DeadLoc' + /// location (by \p DeadI instruction). + /// Return OW_MaybePartial if \p KillingI does not completely overwrite + /// \p DeadI, but they both write to the same underlying object. In that + /// case, use isPartialOverwrite to check if \p KillingI partially overwrites + /// \p DeadI. Returns 'OW_Unknown' if nothing can be determined. + OverwriteResult isOverwrite(const Instruction *KillingI, + const Instruction *DeadI, + const MemoryLocation &KillingLoc, + const MemoryLocation &DeadLoc, + int64_t &KillingOff, int64_t &DeadOff) { // AliasAnalysis does not always account for loops. 
Limit overwrite checks - // to dependencies for which we can guarantee they are independant of any + // to dependencies for which we can guarantee they are independent of any // loops they are in. - if (!isGuaranteedLoopIndependent(EarlierI, LaterI, Earlier)) + if (!isGuaranteedLoopIndependent(DeadI, KillingI, DeadLoc)) return OW_Unknown; // FIXME: Vet that this works for size upper-bounds. Seems unlikely that we'll // get imprecise values here, though (except for unknown sizes). - if (!Later.Size.isPrecise() || !Earlier.Size.isPrecise()) { + if (!KillingLoc.Size.isPrecise() || !DeadLoc.Size.isPrecise()) { // In case no constant size is known, try to an IR values for the number // of bytes written and check if they match. - const auto *LaterMemI = dyn_cast<MemIntrinsic>(LaterI); - const auto *EarlierMemI = dyn_cast<MemIntrinsic>(EarlierI); - if (LaterMemI && EarlierMemI) { - const Value *LaterV = LaterMemI->getLength(); - const Value *EarlierV = EarlierMemI->getLength(); - if (LaterV == EarlierV && BatchAA.isMustAlias(Earlier, Later)) + const auto *KillingMemI = dyn_cast<MemIntrinsic>(KillingI); + const auto *DeadMemI = dyn_cast<MemIntrinsic>(DeadI); + if (KillingMemI && DeadMemI) { + const Value *KillingV = KillingMemI->getLength(); + const Value *DeadV = DeadMemI->getLength(); + if (KillingV == DeadV && BatchAA.isMustAlias(DeadLoc, KillingLoc)) return OW_Complete; } // Masked stores have imprecise locations, but we can reason about them // to some extent. - return isMaskedStoreOverwrite(LaterI, EarlierI, BatchAA); + return isMaskedStoreOverwrite(KillingI, DeadI, BatchAA); } - const uint64_t LaterSize = Later.Size.getValue(); - const uint64_t EarlierSize = Earlier.Size.getValue(); + const uint64_t KillingSize = KillingLoc.Size.getValue(); + const uint64_t DeadSize = DeadLoc.Size.getValue(); // Query the alias information - AliasResult AAR = BatchAA.alias(Later, Earlier); + AliasResult AAR = BatchAA.alias(KillingLoc, DeadLoc); // If the start pointers are the same, we just have to compare sizes to see if - // the later store was larger than the earlier store. + // the killing store was larger than the dead store. if (AAR == AliasResult::MustAlias) { - // Make sure that the Later size is >= the Earlier size. - if (LaterSize >= EarlierSize) + // Make sure that the KillingSize size is >= the DeadSize size. + if (KillingSize >= DeadSize) return OW_Complete; } // If we hit a partial alias we may have a full overwrite if (AAR == AliasResult::PartialAlias && AAR.hasOffset()) { int32_t Off = AAR.getOffset(); - if (Off >= 0 && (uint64_t)Off + EarlierSize <= LaterSize) + if (Off >= 0 && (uint64_t)Off + DeadSize <= KillingSize) return OW_Complete; } - // Check to see if the later store is to the entire object (either a global, - // an alloca, or a byval/inalloca argument). If so, then it clearly + // Check to see if the killing store is to the entire object (either a + // global, an alloca, or a byval/inalloca argument). If so, then it clearly // overwrites any other store to the same object. 
- const Value *P1 = Earlier.Ptr->stripPointerCasts(); - const Value *P2 = Later.Ptr->stripPointerCasts(); - const Value *UO1 = getUnderlyingObject(P1), *UO2 = getUnderlyingObject(P2); + const Value *DeadPtr = DeadLoc.Ptr->stripPointerCasts(); + const Value *KillingPtr = KillingLoc.Ptr->stripPointerCasts(); + const Value *DeadUndObj = getUnderlyingObject(DeadPtr); + const Value *KillingUndObj = getUnderlyingObject(KillingPtr); // If we can't resolve the same pointers to the same object, then we can't // analyze them at all. - if (UO1 != UO2) + if (DeadUndObj != KillingUndObj) return OW_Unknown; - // If the "Later" store is to a recognizable object, get its size. - uint64_t ObjectSize = getPointerSize(UO2, DL, TLI, &F); - if (ObjectSize != MemoryLocation::UnknownSize) - if (ObjectSize == LaterSize && ObjectSize >= EarlierSize) + // If the KillingI store is to a recognizable object, get its size. + uint64_t KillingUndObjSize = getPointerSize(KillingUndObj, DL, TLI, &F); + if (KillingUndObjSize != MemoryLocation::UnknownSize) + if (KillingUndObjSize == KillingSize && KillingUndObjSize >= DeadSize) return OW_Complete; // Okay, we have stores to two completely different pointers. Try to // decompose the pointer into a "base + constant_offset" form. If the base // pointers are equal, then we can reason about the two stores. - EarlierOff = 0; - LaterOff = 0; - const Value *BP1 = GetPointerBaseWithConstantOffset(P1, EarlierOff, DL); - const Value *BP2 = GetPointerBaseWithConstantOffset(P2, LaterOff, DL); - - // If the base pointers still differ, we have two completely different stores. - if (BP1 != BP2) + DeadOff = 0; + KillingOff = 0; + const Value *DeadBasePtr = + GetPointerBaseWithConstantOffset(DeadPtr, DeadOff, DL); + const Value *KillingBasePtr = + GetPointerBaseWithConstantOffset(KillingPtr, KillingOff, DL); + + // If the base pointers still differ, we have two completely different + // stores. + if (DeadBasePtr != KillingBasePtr) return OW_Unknown; - // The later access completely overlaps the earlier store if and only if - // both start and end of the earlier one is "inside" the later one: - // |<->|--earlier--|<->| - // |-------later-------| + // The killing access completely overlaps the dead store if and only if + // both start and end of the dead one is "inside" the killing one: + // |<->|--dead--|<->| + // |-----killing------| // Accesses may overlap if and only if start of one of them is "inside" // another one: - // |<->|--earlier--|<----->| - // |-------later-------| + // |<->|--dead--|<-------->| + // |-------killing--------| // OR - // |----- earlier -----| - // |<->|---later---|<----->| + // |-------dead-------| + // |<->|---killing---|<----->| // // We have to be careful here as *Off is signed while *.Size is unsigned. - // Check if the earlier access starts "not before" the later one. - if (EarlierOff >= LaterOff) { - // If the earlier access ends "not after" the later access then the earlier - // one is completely overwritten by the later one. - if (uint64_t(EarlierOff - LaterOff) + EarlierSize <= LaterSize) + // Check if the dead access starts "not before" the killing one. + if (DeadOff >= KillingOff) { + // If the dead access ends "not after" the killing access then the + // dead one is completely overwritten by the killing one. + if (uint64_t(DeadOff - KillingOff) + DeadSize <= KillingSize) return OW_Complete; - // If start of the earlier access is "before" end of the later access then - // accesses overlap. 
- else if ((uint64_t)(EarlierOff - LaterOff) < LaterSize) + // If start of the dead access is "before" end of the killing access + // then accesses overlap. + else if ((uint64_t)(DeadOff - KillingOff) < KillingSize) return OW_MaybePartial; } - // If start of the later access is "before" end of the earlier access then + // If start of the killing access is "before" end of the dead access then // accesses overlap. - else if ((uint64_t)(LaterOff - EarlierOff) < EarlierSize) { + else if ((uint64_t)(KillingOff - DeadOff) < DeadSize) { return OW_MaybePartial; } @@ -1106,8 +1059,13 @@ struct DSEState { LibFunc LF; if (TLI.getLibFunc(*CB, LF) && TLI.has(LF)) { switch (LF) { - case LibFunc_strcpy: case LibFunc_strncpy: + if (const auto *Len = dyn_cast<ConstantInt>(CB->getArgOperand(2))) + return MemoryLocation(CB->getArgOperand(0), + LocationSize::precise(Len->getZExtValue()), + CB->getAAMetadata()); + LLVM_FALLTHROUGH; + case LibFunc_strcpy: case LibFunc_strcat: case LibFunc_strncat: return {MemoryLocation::getAfter(CB->getArgOperand(0))}; @@ -1145,8 +1103,8 @@ struct DSEState { int64_t InstWriteOffset, DepWriteOffset; if (auto CC = getLocForWriteEx(UseInst)) - return isOverwrite(UseInst, DefInst, *CC, DefLoc, DepWriteOffset, - InstWriteOffset) == OW_Complete; + return isOverwrite(UseInst, DefInst, *CC, DefLoc, InstWriteOffset, + DepWriteOffset) == OW_Complete; return false; } @@ -1248,9 +1206,10 @@ struct DSEState { const Value *LocUO = getUnderlyingObject(Loc.Ptr); return BatchAA.isMustAlias(TermLoc.Ptr, LocUO); } - int64_t InstWriteOffset, DepWriteOffset; - return isOverwrite(MaybeTerm, AccessI, TermLoc, Loc, DepWriteOffset, - InstWriteOffset) == OW_Complete; + int64_t InstWriteOffset = 0; + int64_t DepWriteOffset = 0; + return isOverwrite(MaybeTerm, AccessI, TermLoc, Loc, InstWriteOffset, + DepWriteOffset) == OW_Complete; } // Returns true if \p Use may read from \p DefLoc. @@ -1270,10 +1229,6 @@ struct DSEState { if (CB->onlyAccessesInaccessibleMemory()) return false; - // NOTE: For calls, the number of stores removed could be slightly improved - // by using AA.callCapturesBefore(UseInst, DefLoc, &DT), but that showed to - // be expensive compared to the benefits in practice. For now, avoid more - // expensive analysis to limit compile-time. return isRefSet(BatchAA.getModRefInfo(UseInst, DefLoc)); } @@ -1329,15 +1284,15 @@ struct DSEState { return IsGuaranteedLoopInvariantBase(Ptr); } - // Find a MemoryDef writing to \p DefLoc and dominating \p StartAccess, with - // no read access between them or on any other path to a function exit block - // if \p DefLoc is not accessible after the function returns. If there is no - // such MemoryDef, return None. The returned value may not (completely) - // overwrite \p DefLoc. Currently we bail out when we encounter an aliasing - // MemoryUse (read). + // Find a MemoryDef writing to \p KillingLoc and dominating \p StartAccess, + // with no read access between them or on any other path to a function exit + // block if \p KillingLoc is not accessible after the function returns. If + // there is no such MemoryDef, return None. The returned value may not + // (completely) overwrite \p KillingLoc. Currently we bail out when we + // encounter an aliasing MemoryUse (read). 
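// Illustrative sketch (not part of the patch), hypothetical C++: the simplest case
// the walk documented above looks for. The store "X = A" is found by walking up
// from the killing store, and no instruction in between may read X, so DSE can
// delete it.
int X;
void example(int A, int B) {
  X = A;         // dead: overwritten below, never read in between
  int C = B + 1; // neither reads nor writes X
  X = C;         // killing store
}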
Optional<MemoryAccess *> getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess, - const MemoryLocation &DefLoc, const Value *DefUO, + const MemoryLocation &KillingLoc, const Value *KillingUndObj, unsigned &ScanLimit, unsigned &WalkerStepLimit, bool IsMemTerm, unsigned &PartialLimit) { if (ScanLimit == 0 || WalkerStepLimit == 0) { @@ -1389,19 +1344,20 @@ struct DSEState { MemoryDef *CurrentDef = cast<MemoryDef>(Current); Instruction *CurrentI = CurrentDef->getMemoryInst(); - if (canSkipDef(CurrentDef, !isInvisibleToCallerBeforeRet(DefUO))) + if (canSkipDef(CurrentDef, !isInvisibleToCallerBeforeRet(KillingUndObj), + TLI)) continue; // Before we try to remove anything, check for any extra throwing // instructions that block us from DSEing - if (mayThrowBetween(KillingI, CurrentI, DefUO)) { + if (mayThrowBetween(KillingI, CurrentI, KillingUndObj)) { LLVM_DEBUG(dbgs() << " ... skip, may throw!\n"); return None; } // Check for anything that looks like it will be a barrier to further // removal - if (isDSEBarrier(DefUO, CurrentI)) { + if (isDSEBarrier(KillingUndObj, CurrentI)) { LLVM_DEBUG(dbgs() << " ... skip, barrier\n"); return None; } @@ -1410,14 +1366,14 @@ struct DSEState { // clobber, bail out, as the path is not profitable. We skip this check // for intrinsic calls, because the code knows how to handle memcpy // intrinsics. - if (!isa<IntrinsicInst>(CurrentI) && isReadClobber(DefLoc, CurrentI)) + if (!isa<IntrinsicInst>(CurrentI) && isReadClobber(KillingLoc, CurrentI)) return None; // Quick check if there are direct uses that are read-clobbers. - if (any_of(Current->uses(), [this, &DefLoc, StartAccess](Use &U) { + if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) { if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(U.getUser())) return !MSSA.dominates(StartAccess, UseOrDef) && - isReadClobber(DefLoc, UseOrDef->getMemoryInst()); + isReadClobber(KillingLoc, UseOrDef->getMemoryInst()); return false; })) { LLVM_DEBUG(dbgs() << " ... found a read clobber\n"); @@ -1450,9 +1406,10 @@ struct DSEState { if (!isMemTerminator(*CurrentLoc, CurrentI, KillingI)) continue; } else { - int64_t InstWriteOffset, DepWriteOffset; - auto OR = isOverwrite(KillingI, CurrentI, DefLoc, *CurrentLoc, - DepWriteOffset, InstWriteOffset); + int64_t KillingOffset = 0; + int64_t DeadOffset = 0; + auto OR = isOverwrite(KillingI, CurrentI, KillingLoc, *CurrentLoc, + KillingOffset, DeadOffset); // If Current does not write to the same object as KillingDef, check // the next candidate. if (OR == OW_Unknown) @@ -1473,30 +1430,25 @@ struct DSEState { }; // Accesses to objects accessible after the function returns can only be - // eliminated if the access is killed along all paths to the exit. Collect + // eliminated if the access is dead along all paths to the exit. Collect // the blocks with killing (=completely overwriting MemoryDefs) and check if - // they cover all paths from EarlierAccess to any function exit. + // they cover all paths from MaybeDeadAccess to any function exit. 
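// Illustrative sketch (not part of the patch), hypothetical C++: G is a global, so
// it stays visible after the function returns, and the first store can only be
// removed when every path to the exit overwrites it.
int G;
void notDead(bool Cond, int A, int B) {
  G = A;   // kept: the path where Cond is false never overwrites G
  if (Cond)
    G = B;
}
void isDead(bool Cond, int A, int B) {
  G = A;   // removable: both paths to the exit overwrite G
  if (Cond)
    G = B;
  else
    G = B + 1;
}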
SmallPtrSet<Instruction *, 16> KillingDefs; KillingDefs.insert(KillingDef->getMemoryInst()); - MemoryAccess *EarlierAccess = Current; - Instruction *EarlierMemInst = - cast<MemoryDef>(EarlierAccess)->getMemoryInst(); - LLVM_DEBUG(dbgs() << " Checking for reads of " << *EarlierAccess << " (" - << *EarlierMemInst << ")\n"); + MemoryAccess *MaybeDeadAccess = Current; + MemoryLocation MaybeDeadLoc = *CurrentLoc; + Instruction *MaybeDeadI = cast<MemoryDef>(MaybeDeadAccess)->getMemoryInst(); + LLVM_DEBUG(dbgs() << " Checking for reads of " << *MaybeDeadAccess << " (" + << *MaybeDeadI << ")\n"); SmallSetVector<MemoryAccess *, 32> WorkList; auto PushMemUses = [&WorkList](MemoryAccess *Acc) { for (Use &U : Acc->uses()) WorkList.insert(cast<MemoryAccess>(U.getUser())); }; - PushMemUses(EarlierAccess); - - // Optimistically collect all accesses for reads. If we do not find any - // read clobbers, add them to the cache. - SmallPtrSet<MemoryAccess *, 16> KnownNoReads; - if (!EarlierMemInst->mayReadFromMemory()) - KnownNoReads.insert(EarlierAccess); - // Check if EarlierDef may be read. + PushMemUses(MaybeDeadAccess); + + // Check if DeadDef may be read. for (unsigned I = 0; I < WorkList.size(); I++) { MemoryAccess *UseAccess = WorkList[I]; @@ -1508,7 +1460,6 @@ struct DSEState { } --ScanLimit; NumDomMemDefChecks++; - KnownNoReads.insert(UseAccess); if (isa<MemoryPhi>(UseAccess)) { if (any_of(KillingDefs, [this, UseAccess](Instruction *KI) { @@ -1535,7 +1486,7 @@ struct DSEState { // A memory terminator kills all preceeding MemoryDefs and all succeeding // MemoryAccesses. We do not have to check it's users. - if (isMemTerminator(*CurrentLoc, EarlierMemInst, UseInst)) { + if (isMemTerminator(MaybeDeadLoc, MaybeDeadI, UseInst)) { LLVM_DEBUG( dbgs() << " ... skipping, memterminator invalidates following accesses\n"); @@ -1548,14 +1499,14 @@ struct DSEState { continue; } - if (UseInst->mayThrow() && !isInvisibleToCallerBeforeRet(DefUO)) { + if (UseInst->mayThrow() && !isInvisibleToCallerBeforeRet(KillingUndObj)) { LLVM_DEBUG(dbgs() << " ... found throwing instruction\n"); return None; } // Uses which may read the original MemoryDef mean we cannot eliminate the // original MD. Stop walk. - if (isReadClobber(*CurrentLoc, UseInst)) { + if (isReadClobber(MaybeDeadLoc, UseInst)) { LLVM_DEBUG(dbgs() << " ... found read clobber\n"); return None; } @@ -1563,16 +1514,16 @@ struct DSEState { // If this worklist walks back to the original memory access (and the // pointer is not guarenteed loop invariant) then we cannot assume that a // store kills itself. - if (EarlierAccess == UseAccess && - !isGuaranteedLoopInvariant(CurrentLoc->Ptr)) { + if (MaybeDeadAccess == UseAccess && + !isGuaranteedLoopInvariant(MaybeDeadLoc.Ptr)) { LLVM_DEBUG(dbgs() << " ... found not loop invariant self access\n"); return None; } - // Otherwise, for the KillingDef and EarlierAccess we only have to check + // Otherwise, for the KillingDef and MaybeDeadAccess we only have to check // if it reads the memory location. // TODO: It would probably be better to check for self-reads before // calling the function. - if (KillingDef == UseAccess || EarlierAccess == UseAccess) { + if (KillingDef == UseAccess || MaybeDeadAccess == UseAccess) { LLVM_DEBUG(dbgs() << " ... skipping killing def/dom access\n"); continue; } @@ -1581,18 +1532,18 @@ struct DSEState { // the original location. Otherwise we have to check uses of *all* // MemoryDefs we discover, including non-aliasing ones. 
Otherwise we might // miss cases like the following - // 1 = Def(LoE) ; <----- EarlierDef stores [0,1] + // 1 = Def(LoE) ; <----- DeadDef stores [0,1] // 2 = Def(1) ; (2, 1) = NoAlias, stores [2,3] // Use(2) ; MayAlias 2 *and* 1, loads [0, 3]. // (The Use points to the *first* Def it may alias) // 3 = Def(1) ; <---- Current (3, 2) = NoAlias, (3,1) = MayAlias, // stores [0,1] if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess)) { - if (isCompleteOverwrite(*CurrentLoc, EarlierMemInst, UseInst)) { + if (isCompleteOverwrite(MaybeDeadLoc, MaybeDeadI, UseInst)) { BasicBlock *MaybeKillingBlock = UseInst->getParent(); if (PostOrderNumbers.find(MaybeKillingBlock)->second < - PostOrderNumbers.find(EarlierAccess->getBlock())->second) { - if (!isInvisibleToCallerAfterRet(DefUO)) { + PostOrderNumbers.find(MaybeDeadAccess->getBlock())->second) { + if (!isInvisibleToCallerAfterRet(KillingUndObj)) { LLVM_DEBUG(dbgs() << " ... found killing def " << *UseInst << "\n"); KillingDefs.insert(UseInst); @@ -1608,9 +1559,9 @@ struct DSEState { } // For accesses to locations visible after the function returns, make sure - // that the location is killed (=overwritten) along all paths from - // EarlierAccess to the exit. - if (!isInvisibleToCallerAfterRet(DefUO)) { + // that the location is dead (=overwritten) along all paths from + // MaybeDeadAccess to the exit. + if (!isInvisibleToCallerAfterRet(KillingUndObj)) { SmallPtrSet<BasicBlock *, 16> KillingBlocks; for (Instruction *KD : KillingDefs) KillingBlocks.insert(KD->getParent()); @@ -1619,25 +1570,24 @@ struct DSEState { // Find the common post-dominator of all killing blocks. BasicBlock *CommonPred = *KillingBlocks.begin(); - for (auto I = std::next(KillingBlocks.begin()), E = KillingBlocks.end(); - I != E; I++) { + for (BasicBlock *BB : llvm::drop_begin(KillingBlocks)) { if (!CommonPred) break; - CommonPred = PDT.findNearestCommonDominator(CommonPred, *I); + CommonPred = PDT.findNearestCommonDominator(CommonPred, BB); } // If CommonPred is in the set of killing blocks, just check if it - // post-dominates EarlierAccess. + // post-dominates MaybeDeadAccess. if (KillingBlocks.count(CommonPred)) { - if (PDT.dominates(CommonPred, EarlierAccess->getBlock())) - return {EarlierAccess}; + if (PDT.dominates(CommonPred, MaybeDeadAccess->getBlock())) + return {MaybeDeadAccess}; return None; } - // If the common post-dominator does not post-dominate EarlierAccess, - // there is a path from EarlierAccess to an exit not going through a + // If the common post-dominator does not post-dominate MaybeDeadAccess, + // there is a path from MaybeDeadAccess to an exit not going through a // killing block. - if (PDT.dominates(CommonPred, EarlierAccess->getBlock())) { + if (PDT.dominates(CommonPred, MaybeDeadAccess->getBlock())) { SetVector<BasicBlock *> WorkList; // If CommonPred is null, there are multiple exits from the function. @@ -1650,16 +1600,16 @@ struct DSEState { NumCFGTries++; // Check if all paths starting from an exit node go through one of the - // killing blocks before reaching EarlierAccess. + // killing blocks before reaching MaybeDeadAccess. for (unsigned I = 0; I < WorkList.size(); I++) { NumCFGChecks++; BasicBlock *Current = WorkList[I]; if (KillingBlocks.count(Current)) continue; - if (Current == EarlierAccess->getBlock()) + if (Current == MaybeDeadAccess->getBlock()) return None; - // EarlierAccess is reachable from the entry, so we don't have to + // MaybeDeadAccess is reachable from the entry, so we don't have to // explore unreachable blocks further. 
if (!DT.isReachableFromEntry(Current)) continue; @@ -1671,14 +1621,14 @@ struct DSEState { return None; } NumCFGSuccess++; - return {EarlierAccess}; + return {MaybeDeadAccess}; } return None; } - // No aliasing MemoryUses of EarlierAccess found, EarlierAccess is + // No aliasing MemoryUses of MaybeDeadAccess found, MaybeDeadAccess is // potentially dead. - return {EarlierAccess}; + return {MaybeDeadAccess}; } // Delete dead memory defs @@ -1701,6 +1651,7 @@ struct DSEState { if (MemoryDef *MD = dyn_cast<MemoryDef>(MA)) { SkipStores.insert(MD); } + Updater.removeMemoryAccess(MA); } @@ -1715,47 +1666,49 @@ struct DSEState { NowDeadInsts.push_back(OpI); } + EI.removeInstruction(DeadInst); DeadInst->eraseFromParent(); } } - // Check for any extra throws between SI and NI that block DSE. This only - // checks extra maythrows (those that aren't MemoryDef's). MemoryDef that may - // throw are handled during the walk from one def to the next. - bool mayThrowBetween(Instruction *SI, Instruction *NI, - const Value *SILocUnd) { - // First see if we can ignore it by using the fact that SI is an + // Check for any extra throws between \p KillingI and \p DeadI that block + // DSE. This only checks extra maythrows (those that aren't MemoryDef's). + // MemoryDef that may throw are handled during the walk from one def to the + // next. + bool mayThrowBetween(Instruction *KillingI, Instruction *DeadI, + const Value *KillingUndObj) { + // First see if we can ignore it by using the fact that KillingI is an // alloca/alloca like object that is not visible to the caller during // execution of the function. - if (SILocUnd && isInvisibleToCallerBeforeRet(SILocUnd)) + if (KillingUndObj && isInvisibleToCallerBeforeRet(KillingUndObj)) return false; - if (SI->getParent() == NI->getParent()) - return ThrowingBlocks.count(SI->getParent()); + if (KillingI->getParent() == DeadI->getParent()) + return ThrowingBlocks.count(KillingI->getParent()); return !ThrowingBlocks.empty(); } - // Check if \p NI acts as a DSE barrier for \p SI. The following instructions - // act as barriers: - // * A memory instruction that may throw and \p SI accesses a non-stack + // Check if \p DeadI acts as a DSE barrier for \p KillingI. The following + // instructions act as barriers: + // * A memory instruction that may throw and \p KillingI accesses a non-stack // object. // * Atomic stores stronger that monotonic. - bool isDSEBarrier(const Value *SILocUnd, Instruction *NI) { - // If NI may throw it acts as a barrier, unless we are to an alloca/alloca - // like object that does not escape. - if (NI->mayThrow() && !isInvisibleToCallerBeforeRet(SILocUnd)) + bool isDSEBarrier(const Value *KillingUndObj, Instruction *DeadI) { + // If DeadI may throw it acts as a barrier, unless we are to an + // alloca/alloca like object that does not escape. + if (DeadI->mayThrow() && !isInvisibleToCallerBeforeRet(KillingUndObj)) return true; - // If NI is an atomic load/store stronger than monotonic, do not try to + // If DeadI is an atomic load/store stronger than monotonic, do not try to // eliminate/reorder it. 
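// Illustrative sketch (not part of the patch), hypothetical C++: even though the
// release store below is later overwritten, removing or reordering it would drop
// the synchronization it provides, which is why atomics stronger than monotonic
// (C++ "relaxed") are treated as DSE barriers.
#include <atomic>
std::atomic<bool> Ready{false};
int Payload;
void producer() {
  Payload = 42;
  Ready.store(true, std::memory_order_release);  // must stay: publishes Payload
  Ready.store(false, std::memory_order_relaxed); // overwrites the flag later
}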
- if (NI->isAtomic()) { - if (auto *LI = dyn_cast<LoadInst>(NI)) + if (DeadI->isAtomic()) { + if (auto *LI = dyn_cast<LoadInst>(DeadI)) return isStrongerThanMonotonic(LI->getOrdering()); - if (auto *SI = dyn_cast<StoreInst>(NI)) + if (auto *SI = dyn_cast<StoreInst>(DeadI)) return isStrongerThanMonotonic(SI->getOrdering()); - if (auto *ARMW = dyn_cast<AtomicRMWInst>(NI)) + if (auto *ARMW = dyn_cast<AtomicRMWInst>(DeadI)) return isStrongerThanMonotonic(ARMW->getOrdering()); - if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(NI)) + if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(DeadI)) return isStrongerThanMonotonic(CmpXchg->getSuccessOrdering()) || isStrongerThanMonotonic(CmpXchg->getFailureOrdering()); llvm_unreachable("other instructions should be skipped in MemorySSA"); @@ -1776,7 +1729,6 @@ struct DSEState { continue; Instruction *DefI = Def->getMemoryInst(); - SmallVector<const Value *, 4> Pointers; auto DefLoc = getLocForWriteEx(DefI); if (!DefLoc) continue; @@ -1787,7 +1739,7 @@ struct DSEState { // uncommon. If it turns out to be important, we can use // getUnderlyingObjects here instead. const Value *UO = getUnderlyingObject(DefLoc->Ptr); - if (!UO || !isInvisibleToCallerAfterRet(UO)) + if (!isInvisibleToCallerAfterRet(UO)) continue; if (isWriteAtEndOfFunction(Def)) { @@ -1804,8 +1756,7 @@ struct DSEState { /// \returns true if \p Def is a no-op store, either because it /// directly stores back a loaded value or stores zero to a calloced object. - bool storeIsNoop(MemoryDef *Def, const MemoryLocation &DefLoc, - const Value *DefUO) { + bool storeIsNoop(MemoryDef *Def, const Value *DefUO) { StoreInst *Store = dyn_cast<StoreInst>(Def->getMemoryInst()); MemSetInst *MemSet = dyn_cast<MemSetInst>(Def->getMemoryInst()); Constant *StoredConstant = nullptr; @@ -1816,13 +1767,78 @@ struct DSEState { if (StoredConstant && StoredConstant->isNullValue()) { auto *DefUOInst = dyn_cast<Instruction>(DefUO); - if (DefUOInst && isCallocLikeFn(DefUOInst, &TLI)) { - auto *UnderlyingDef = cast<MemoryDef>(MSSA.getMemoryAccess(DefUOInst)); - // If UnderlyingDef is the clobbering access of Def, no instructions - // between them can modify the memory location. - auto *ClobberDef = - MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def); - return UnderlyingDef == ClobberDef; + if (DefUOInst) { + if (isCallocLikeFn(DefUOInst, &TLI)) { + auto *UnderlyingDef = + cast<MemoryDef>(MSSA.getMemoryAccess(DefUOInst)); + // If UnderlyingDef is the clobbering access of Def, no instructions + // between them can modify the memory location. 
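// Illustrative sketch (not part of the patch), hypothetical C++: the memset stores
// the value calloc already wrote, and nothing between the two can modify the
// buffer, so the memset is a removable no-op store. The hunk just below also
// performs the reverse rewrite, turning malloc followed by a full-size memset(0)
// into a single calloc.
#include <cstddef>
#include <cstdlib>
#include <cstring>
int *makeBuffer(std::size_t N) {
  int *P = static_cast<int *>(std::calloc(N, sizeof(int)));
  if (P)
    std::memset(P, 0, N * sizeof(int)); // no-op: calloc memory is already zero
  return P;
}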
+ auto *ClobberDef = + MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def); + return UnderlyingDef == ClobberDef; + } + + if (MemSet) { + if (F.hasFnAttribute(Attribute::SanitizeMemory) || + F.hasFnAttribute(Attribute::SanitizeAddress) || + F.hasFnAttribute(Attribute::SanitizeHWAddress) || + F.getName() == "calloc") + return false; + auto *Malloc = const_cast<CallInst *>(dyn_cast<CallInst>(DefUOInst)); + if (!Malloc) + return false; + auto *InnerCallee = Malloc->getCalledFunction(); + if (!InnerCallee) + return false; + LibFunc Func; + if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) || + Func != LibFunc_malloc) + return false; + + auto shouldCreateCalloc = [](CallInst *Malloc, CallInst *Memset) { + // Check for br(icmp ptr, null), truebb, falsebb) pattern at the end + // of malloc block + auto *MallocBB = Malloc->getParent(), + *MemsetBB = Memset->getParent(); + if (MallocBB == MemsetBB) + return true; + auto *Ptr = Memset->getArgOperand(0); + auto *TI = MallocBB->getTerminator(); + ICmpInst::Predicate Pred; + BasicBlock *TrueBB, *FalseBB; + if (!match(TI, m_Br(m_ICmp(Pred, m_Specific(Ptr), m_Zero()), TrueBB, + FalseBB))) + return false; + if (Pred != ICmpInst::ICMP_EQ || MemsetBB != FalseBB) + return false; + return true; + }; + + if (Malloc->getOperand(0) == MemSet->getLength()) { + if (shouldCreateCalloc(Malloc, MemSet) && + DT.dominates(Malloc, MemSet) && + memoryIsNotModifiedBetween(Malloc, MemSet, BatchAA, DL, &DT)) { + IRBuilder<> IRB(Malloc); + const auto &DL = Malloc->getModule()->getDataLayout(); + if (auto *Calloc = + emitCalloc(ConstantInt::get(IRB.getIntPtrTy(DL), 1), + Malloc->getArgOperand(0), IRB, TLI)) { + MemorySSAUpdater Updater(&MSSA); + auto *LastDef = cast<MemoryDef>( + Updater.getMemorySSA()->getMemoryAccess(Malloc)); + auto *NewAccess = Updater.createMemoryAccessAfter( + cast<Instruction>(Calloc), LastDef, LastDef); + auto *NewAccessMD = cast<MemoryDef>(NewAccess); + Updater.insertDef(NewAccessMD, /*RenameUses=*/true); + Updater.removeMemoryAccess(Malloc); + Malloc->replaceAllUsesWith(Calloc); + Malloc->eraseFromParent(); + return true; + } + return false; + } + } + } } } @@ -1875,6 +1891,76 @@ struct DSEState { return false; } + + bool removePartiallyOverlappedStores(InstOverlapIntervalsTy &IOL) { + bool Changed = false; + for (auto OI : IOL) { + Instruction *DeadI = OI.first; + MemoryLocation Loc = *getLocForWriteEx(DeadI); + assert(isRemovable(DeadI) && "Expect only removable instruction"); + + const Value *Ptr = Loc.Ptr->stripPointerCasts(); + int64_t DeadStart = 0; + uint64_t DeadSize = Loc.Size.getValue(); + GetPointerBaseWithConstantOffset(Ptr, DeadStart, DL); + OverlapIntervalsTy &IntervalMap = OI.second; + Changed |= tryToShortenEnd(DeadI, IntervalMap, DeadStart, DeadSize); + if (IntervalMap.empty()) + continue; + Changed |= tryToShortenBegin(DeadI, IntervalMap, DeadStart, DeadSize); + } + return Changed; + } + + /// Eliminates writes to locations where the value that is being written + /// is already stored at the same location. 
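// Illustrative sketch (not part of the patch), hypothetical C++: the scalar store
// writes the same byte value that the preceding memset already stored to the same
// location, so it is redundant in the sense described above.
#include <cstring>
char Buf[16];
void initBuf() {
  std::memset(Buf, 0, sizeof(Buf));
  Buf[3] = 0; // redundant: re-stores the value the memset already wrote
}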
+ bool eliminateRedundantStoresOfExistingValues() { + bool MadeChange = false; + LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs that write the " + "already existing value\n"); + for (auto *Def : MemDefs) { + if (SkipStores.contains(Def) || MSSA.isLiveOnEntryDef(Def) || + !isRemovable(Def->getMemoryInst())) + continue; + auto *UpperDef = dyn_cast<MemoryDef>(Def->getDefiningAccess()); + if (!UpperDef || MSSA.isLiveOnEntryDef(UpperDef)) + continue; + + Instruction *DefInst = Def->getMemoryInst(); + Instruction *UpperInst = UpperDef->getMemoryInst(); + auto IsRedundantStore = [this, DefInst, + UpperInst](MemoryLocation UpperLoc) { + if (DefInst->isIdenticalTo(UpperInst)) + return true; + if (auto *MemSetI = dyn_cast<MemSetInst>(UpperInst)) { + if (auto *SI = dyn_cast<StoreInst>(DefInst)) { + auto MaybeDefLoc = getLocForWriteEx(DefInst); + if (!MaybeDefLoc) + return false; + int64_t InstWriteOffset = 0; + int64_t DepWriteOffset = 0; + auto OR = isOverwrite(UpperInst, DefInst, UpperLoc, *MaybeDefLoc, + InstWriteOffset, DepWriteOffset); + Value *StoredByte = isBytewiseValue(SI->getValueOperand(), DL); + return StoredByte && StoredByte == MemSetI->getOperand(1) && + OR == OW_Complete; + } + } + return false; + }; + + auto MaybeUpperLoc = getLocForWriteEx(UpperInst); + if (!MaybeUpperLoc || !IsRedundantStore(*MaybeUpperLoc) || + isReadClobber(*MaybeUpperLoc, DefInst)) + continue; + LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *DefInst + << '\n'); + deleteDeadInstruction(DefInst); + NumRedundantStores++; + MadeChange = true; + } + return MadeChange; + } }; static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, @@ -1883,68 +1969,64 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, const LoopInfo &LI) { bool MadeChange = false; - DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI, LI); + DSEState State(F, AA, MSSA, DT, PDT, TLI, LI); // For each store: for (unsigned I = 0; I < State.MemDefs.size(); I++) { MemoryDef *KillingDef = State.MemDefs[I]; if (State.SkipStores.count(KillingDef)) continue; - Instruction *SI = KillingDef->getMemoryInst(); + Instruction *KillingI = KillingDef->getMemoryInst(); - Optional<MemoryLocation> MaybeSILoc; - if (State.isMemTerminatorInst(SI)) - MaybeSILoc = State.getLocForTerminator(SI).map( + Optional<MemoryLocation> MaybeKillingLoc; + if (State.isMemTerminatorInst(KillingI)) + MaybeKillingLoc = State.getLocForTerminator(KillingI).map( [](const std::pair<MemoryLocation, bool> &P) { return P.first; }); else - MaybeSILoc = State.getLocForWriteEx(SI); + MaybeKillingLoc = State.getLocForWriteEx(KillingI); - if (!MaybeSILoc) { + if (!MaybeKillingLoc) { LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for " - << *SI << "\n"); + << *KillingI << "\n"); continue; } - MemoryLocation SILoc = *MaybeSILoc; - assert(SILoc.Ptr && "SILoc should not be null"); - const Value *SILocUnd = getUnderlyingObject(SILoc.Ptr); - - MemoryAccess *Current = KillingDef; + MemoryLocation KillingLoc = *MaybeKillingLoc; + assert(KillingLoc.Ptr && "KillingLoc should not be null"); + const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr); LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " - << *Current << " (" << *SI << ")\n"); + << *KillingDef << " (" << *KillingI << ")\n"); unsigned ScanLimit = MemorySSAScanLimit; unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit; unsigned PartialLimit = MemorySSAPartialStoreLimit; // Worklist of MemoryAccesses that may be killed by KillingDef. 
SetVector<MemoryAccess *> ToCheck; - - if (SILocUnd) - ToCheck.insert(KillingDef->getDefiningAccess()); + ToCheck.insert(KillingDef->getDefiningAccess()); bool Shortend = false; - bool IsMemTerm = State.isMemTerminatorInst(SI); + bool IsMemTerm = State.isMemTerminatorInst(KillingI); // Check if MemoryAccesses in the worklist are killed by KillingDef. for (unsigned I = 0; I < ToCheck.size(); I++) { - Current = ToCheck[I]; + MemoryAccess *Current = ToCheck[I]; if (State.SkipStores.count(Current)) continue; - Optional<MemoryAccess *> Next = State.getDomMemoryDef( - KillingDef, Current, SILoc, SILocUnd, ScanLimit, WalkerStepLimit, - IsMemTerm, PartialLimit); + Optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef( + KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit, + WalkerStepLimit, IsMemTerm, PartialLimit); - if (!Next) { + if (!MaybeDeadAccess) { LLVM_DEBUG(dbgs() << " finished walk\n"); continue; } - MemoryAccess *EarlierAccess = *Next; - LLVM_DEBUG(dbgs() << " Checking if we can kill " << *EarlierAccess); - if (isa<MemoryPhi>(EarlierAccess)) { + MemoryAccess *DeadAccess = *MaybeDeadAccess; + LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess); + if (isa<MemoryPhi>(DeadAccess)) { LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n"); - for (Value *V : cast<MemoryPhi>(EarlierAccess)->incoming_values()) { + for (Value *V : cast<MemoryPhi>(DeadAccess)->incoming_values()) { MemoryAccess *IncomingAccess = cast<MemoryAccess>(V); BasicBlock *IncomingBlock = IncomingAccess->getBlock(); - BasicBlock *PhiBlock = EarlierAccess->getBlock(); + BasicBlock *PhiBlock = DeadAccess->getBlock(); // We only consider incoming MemoryAccesses that come before the // MemoryPhi. Otherwise we could discover candidates that do not @@ -1955,72 +2037,73 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, } continue; } - auto *NextDef = cast<MemoryDef>(EarlierAccess); - Instruction *NI = NextDef->getMemoryInst(); - LLVM_DEBUG(dbgs() << " (" << *NI << ")\n"); - ToCheck.insert(NextDef->getDefiningAccess()); + auto *DeadDefAccess = cast<MemoryDef>(DeadAccess); + Instruction *DeadI = DeadDefAccess->getMemoryInst(); + LLVM_DEBUG(dbgs() << " (" << *DeadI << ")\n"); + ToCheck.insert(DeadDefAccess->getDefiningAccess()); NumGetDomMemoryDefPassed++; if (!DebugCounter::shouldExecute(MemorySSACounter)) continue; - MemoryLocation NILoc = *State.getLocForWriteEx(NI); + MemoryLocation DeadLoc = *State.getLocForWriteEx(DeadI); if (IsMemTerm) { - const Value *NIUnd = getUnderlyingObject(NILoc.Ptr); - if (SILocUnd != NIUnd) + const Value *DeadUndObj = getUnderlyingObject(DeadLoc.Ptr); + if (KillingUndObj != DeadUndObj) continue; - LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI - << "\n KILLER: " << *SI << '\n'); - State.deleteDeadInstruction(NI); + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI + << "\n KILLER: " << *KillingI << '\n'); + State.deleteDeadInstruction(DeadI); ++NumFastStores; MadeChange = true; } else { - // Check if NI overwrites SI. - int64_t InstWriteOffset, DepWriteOffset; - OverwriteResult OR = State.isOverwrite(SI, NI, SILoc, NILoc, - DepWriteOffset, InstWriteOffset); + // Check if DeadI overwrites KillingI. 
+ int64_t KillingOffset = 0; + int64_t DeadOffset = 0; + OverwriteResult OR = State.isOverwrite( + KillingI, DeadI, KillingLoc, DeadLoc, KillingOffset, DeadOffset); if (OR == OW_MaybePartial) { auto Iter = State.IOLs.insert( std::make_pair<BasicBlock *, InstOverlapIntervalsTy>( - NI->getParent(), InstOverlapIntervalsTy())); + DeadI->getParent(), InstOverlapIntervalsTy())); auto &IOL = Iter.first->second; - OR = isPartialOverwrite(SILoc, NILoc, DepWriteOffset, InstWriteOffset, - NI, IOL); + OR = isPartialOverwrite(KillingLoc, DeadLoc, KillingOffset, + DeadOffset, DeadI, IOL); } if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) { - auto *Earlier = dyn_cast<StoreInst>(NI); - auto *Later = dyn_cast<StoreInst>(SI); + auto *DeadSI = dyn_cast<StoreInst>(DeadI); + auto *KillingSI = dyn_cast<StoreInst>(KillingI); // We are re-using tryToMergePartialOverlappingStores, which requires - // Earlier to domiante Later. + // DeadSI to dominate DeadSI. // TODO: implement tryToMergeParialOverlappingStores using MemorySSA. - if (Earlier && Later && DT.dominates(Earlier, Later)) { + if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) { if (Constant *Merged = tryToMergePartialOverlappingStores( - Earlier, Later, InstWriteOffset, DepWriteOffset, State.DL, + KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL, State.BatchAA, &DT)) { // Update stored value of earlier store to merged constant. - Earlier->setOperand(0, Merged); + DeadSI->setOperand(0, Merged); ++NumModifiedStores; MadeChange = true; Shortend = true; - // Remove later store and remove any outstanding overlap intervals - // for the updated store. - State.deleteDeadInstruction(Later); - auto I = State.IOLs.find(Earlier->getParent()); + // Remove killing store and remove any outstanding overlap + // intervals for the updated store. + State.deleteDeadInstruction(KillingSI); + auto I = State.IOLs.find(DeadSI->getParent()); if (I != State.IOLs.end()) - I->second.erase(Earlier); + I->second.erase(DeadSI); break; } } } if (OR == OW_Complete) { - LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI - << "\n KILLER: " << *SI << '\n'); - State.deleteDeadInstruction(NI); + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI + << "\n KILLER: " << *KillingI << '\n'); + State.deleteDeadInstruction(DeadI); ++NumFastStores; MadeChange = true; } @@ -2028,10 +2111,11 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, } // Check if the store is a no-op. 
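// Illustrative sketch (not part of the patch), hypothetical C++: the store writes
// back the value just loaded from the same location, which is the pattern the
// no-op check below recognizes.
void writeBack(int *P) {
  int V = *P;
  *P = V; // no-op store, removable
}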
- if (!Shortend && isRemovable(SI) && - State.storeIsNoop(KillingDef, SILoc, SILocUnd)) { - LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *SI << '\n'); - State.deleteDeadInstruction(SI); + if (!Shortend && isRemovable(KillingI) && + State.storeIsNoop(KillingDef, KillingUndObj)) { + LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *KillingI + << '\n'); + State.deleteDeadInstruction(KillingI); NumRedundantStores++; MadeChange = true; continue; @@ -2040,8 +2124,9 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, if (EnablePartialOverwriteTracking) for (auto &KV : State.IOLs) - MadeChange |= removePartiallyOverlappedStores(State.DL, KV.second, TLI); + MadeChange |= State.removePartiallyOverlappedStores(KV.second); + MadeChange |= State.eliminateRedundantStoresOfExistingValues(); MadeChange |= State.eliminateDeadWritesAtEndOfFunction(); return MadeChange; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 978c6a77b8dc..90f71f7729a7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -293,7 +293,7 @@ static unsigned getHashValueImpl(SimpleValue Val) { // TODO: Extend this to handle intrinsics with >2 operands where the 1st // 2 operands are commutative. auto *II = dyn_cast<IntrinsicInst>(Inst); - if (II && II->isCommutative() && II->getNumArgOperands() == 2) { + if (II && II->isCommutative() && II->arg_size() == 2) { Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); if (LHS > RHS) std::swap(LHS, RHS); @@ -363,7 +363,7 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) { auto *LII = dyn_cast<IntrinsicInst>(LHSI); auto *RII = dyn_cast<IntrinsicInst>(RHSI); if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() && - LII->isCommutative() && LII->getNumArgOperands() == 2) { + LII->isCommutative() && LII->arg_size() == 2) { return LII->getArgOperand(0) == RII->getArgOperand(1) && LII->getArgOperand(1) == RII->getArgOperand(0); } @@ -1265,6 +1265,12 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { continue; } + // Skip pseudoprobe intrinsics, for the same reason as assume intrinsics. + if (match(&Inst, m_Intrinsic<Intrinsic::pseudoprobe>())) { + LLVM_DEBUG(dbgs() << "EarlyCSE skipping pseudoprobe: " << Inst << '\n'); + continue; + } + // We can skip all invariant.start intrinsics since they only read memory, // and we can forward values across it. For invariant starts without // invariant ends, we can use the fact that the invariantness never ends to @@ -1642,6 +1648,16 @@ PreservedAnalyses EarlyCSEPass::run(Function &F, return PA; } +void EarlyCSEPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<EarlyCSEPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (UseMemorySSA) + OS << "memssa"; + OS << ">"; +} + namespace { /// A simple and fast domtree-based CSE pass. 
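// Illustrative sketch (not part of the patch): the EarlyCSE hunks above hash and
// compare two-operand commutative intrinsic calls after canonicalizing the operand
// order, so op(a, b) and op(b, a) land in the same bucket. A standalone, simplified
// analogue:
#include <cstddef>
#include <functional>
#include <utility>
static std::size_t hashCommutativePair(const void *LHS, const void *RHS) {
  if (std::less<const void *>()(RHS, LHS))
    std::swap(LHS, RHS);          // canonical operand order first
  std::hash<const void *> H;
  return H(LHS) * 31 + H(RHS);    // then hash the ordered pair
}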
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp index 8a5d4f568774..a98bb8358aef 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -256,7 +256,7 @@ void Float2IntPass::walkForwards() { Op = [](ArrayRef<ConstantRange> Ops) { assert(Ops.size() == 1 && "FNeg is a unary operator!"); unsigned Size = Ops[0].getBitWidth(); - auto Zero = ConstantRange(APInt::getNullValue(Size)); + auto Zero = ConstantRange(APInt::getZero(Size)); return Zero.sub(Ops[0]); }; break; @@ -372,7 +372,7 @@ bool Float2IntPass::validateAndTransform() { // If it does, transformation would be illegal. // // Don't count the roots, as they terminate the graphs. - if (Roots.count(I) == 0) { + if (!Roots.contains(I)) { // Set the type of the conversion while we're here. if (!ConvertedToTy) ConvertedToTy = I->getType(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp index 16368aec7c3f..00506fb86006 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp @@ -126,7 +126,7 @@ static cl::opt<uint32_t> MaxBBSpeculations( "into) when deducing if a value is fully available or not in GVN " "(default = 600)")); -struct llvm::GVN::Expression { +struct llvm::GVNPass::Expression { uint32_t opcode; bool commutative = false; Type *type = nullptr; @@ -155,17 +155,18 @@ struct llvm::GVN::Expression { namespace llvm { -template <> struct DenseMapInfo<GVN::Expression> { - static inline GVN::Expression getEmptyKey() { return ~0U; } - static inline GVN::Expression getTombstoneKey() { return ~1U; } +template <> struct DenseMapInfo<GVNPass::Expression> { + static inline GVNPass::Expression getEmptyKey() { return ~0U; } + static inline GVNPass::Expression getTombstoneKey() { return ~1U; } - static unsigned getHashValue(const GVN::Expression &e) { + static unsigned getHashValue(const GVNPass::Expression &e) { using llvm::hash_value; return static_cast<unsigned>(hash_value(e)); } - static bool isEqual(const GVN::Expression &LHS, const GVN::Expression &RHS) { + static bool isEqual(const GVNPass::Expression &LHS, + const GVNPass::Expression &RHS) { return LHS == RHS; } }; @@ -246,7 +247,7 @@ struct llvm::gvn::AvailableValue { /// Emit code at the specified insertion point to adjust the value defined /// here to the specified type. This handles various coercion cases. Value *MaterializeAdjustedValue(LoadInst *Load, Instruction *InsertPt, - GVN &gvn) const; + GVNPass &gvn) const; }; /// Represents an AvailableValue which can be rematerialized at the end of @@ -276,7 +277,7 @@ struct llvm::gvn::AvailableValueInBlock { /// Emit code at the end of this block to adjust the value defined here to /// the specified type. This handles various coercion cases. 
- Value *MaterializeAdjustedValue(LoadInst *Load, GVN &gvn) const { + Value *MaterializeAdjustedValue(LoadInst *Load, GVNPass &gvn) const { return AV.MaterializeAdjustedValue(Load, BB->getTerminator(), gvn); } }; @@ -285,7 +286,7 @@ struct llvm::gvn::AvailableValueInBlock { // ValueTable Internal Functions //===----------------------------------------------------------------------===// -GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { +GVNPass::Expression GVNPass::ValueTable::createExpr(Instruction *I) { Expression e; e.type = I->getType(); e.opcode = I->getOpcode(); @@ -330,9 +331,8 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { return e; } -GVN::Expression GVN::ValueTable::createCmpExpr(unsigned Opcode, - CmpInst::Predicate Predicate, - Value *LHS, Value *RHS) { +GVNPass::Expression GVNPass::ValueTable::createCmpExpr( + unsigned Opcode, CmpInst::Predicate Predicate, Value *LHS, Value *RHS) { assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && "Not a comparison!"); Expression e; @@ -350,7 +350,8 @@ GVN::Expression GVN::ValueTable::createCmpExpr(unsigned Opcode, return e; } -GVN::Expression GVN::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { +GVNPass::Expression +GVNPass::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { assert(EI && "Not an ExtractValueInst?"); Expression e; e.type = EI->getType(); @@ -382,20 +383,21 @@ GVN::Expression GVN::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { // ValueTable External Functions //===----------------------------------------------------------------------===// -GVN::ValueTable::ValueTable() = default; -GVN::ValueTable::ValueTable(const ValueTable &) = default; -GVN::ValueTable::ValueTable(ValueTable &&) = default; -GVN::ValueTable::~ValueTable() = default; -GVN::ValueTable &GVN::ValueTable::operator=(const GVN::ValueTable &Arg) = default; +GVNPass::ValueTable::ValueTable() = default; +GVNPass::ValueTable::ValueTable(const ValueTable &) = default; +GVNPass::ValueTable::ValueTable(ValueTable &&) = default; +GVNPass::ValueTable::~ValueTable() = default; +GVNPass::ValueTable & +GVNPass::ValueTable::operator=(const GVNPass::ValueTable &Arg) = default; /// add - Insert a value into the table with a specified value number. -void GVN::ValueTable::add(Value *V, uint32_t num) { +void GVNPass::ValueTable::add(Value *V, uint32_t num) { valueNumbering.insert(std::make_pair(V, num)); if (PHINode *PN = dyn_cast<PHINode>(V)) NumberingPhi[num] = PN; } -uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { +uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) { if (AA->doesNotAccessMemory(C)) { Expression exp = createExpr(C); uint32_t e = assignExpNewValueNum(exp).first; @@ -421,13 +423,12 @@ uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { // a normal load or store instruction. 
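// Illustrative sketch (not part of the patch): createExpr() above builds a key
// from the opcode and the operands' value numbers (the real key also includes the
// type), sorting the operands of commutative instructions, and the table hands out
// one number per distinct key. A hypothetical, simplified standalone analogue:
#include <cstdint>
#include <map>
#include <tuple>
#include <utility>
#include <vector>
struct Expr {
  unsigned Opcode = 0;
  bool Commutative = false;
  std::vector<uint32_t> Ops; // operand value numbers
  bool operator<(const Expr &O) const {
    return std::tie(Opcode, Ops) < std::tie(O.Opcode, O.Ops);
  }
};
struct SimpleValueTable {
  std::map<Expr, uint32_t> Numbering;
  uint32_t Next = 1;
  uint32_t lookupOrAdd(Expr E) {
    if (E.Commutative && E.Ops.size() == 2 && E.Ops[0] > E.Ops[1])
      std::swap(E.Ops[0], E.Ops[1]); // canonicalize, as createExpr() does
    auto It = Numbering.insert({E, Next});
    if (It.second)
      ++Next;                        // fresh number for a new expression
    return It.first->second;
  }
};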
CallInst *local_cdep = dyn_cast<CallInst>(local_dep.getInst()); - if (!local_cdep || - local_cdep->getNumArgOperands() != C->getNumArgOperands()) { + if (!local_cdep || local_cdep->arg_size() != C->arg_size()) { valueNumbering[C] = nextValueNumber; return nextValueNumber++; } - for (unsigned i = 0, e = C->getNumArgOperands(); i < e; ++i) { + for (unsigned i = 0, e = C->arg_size(); i < e; ++i) { uint32_t c_vn = lookupOrAdd(C->getArgOperand(i)); uint32_t cd_vn = lookupOrAdd(local_cdep->getArgOperand(i)); if (c_vn != cd_vn) { @@ -477,11 +478,11 @@ uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { return nextValueNumber++; } - if (cdep->getNumArgOperands() != C->getNumArgOperands()) { + if (cdep->arg_size() != C->arg_size()) { valueNumbering[C] = nextValueNumber; return nextValueNumber++; } - for (unsigned i = 0, e = C->getNumArgOperands(); i < e; ++i) { + for (unsigned i = 0, e = C->arg_size(); i < e; ++i) { uint32_t c_vn = lookupOrAdd(C->getArgOperand(i)); uint32_t cd_vn = lookupOrAdd(cdep->getArgOperand(i)); if (c_vn != cd_vn) { @@ -500,11 +501,13 @@ uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { } /// Returns true if a value number exists for the specified value. -bool GVN::ValueTable::exists(Value *V) const { return valueNumbering.count(V) != 0; } +bool GVNPass::ValueTable::exists(Value *V) const { + return valueNumbering.count(V) != 0; +} /// lookup_or_add - Returns the value number for the specified value, assigning /// it a new number if it did not have one before. -uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { +uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) { DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); if (VI != valueNumbering.end()) return VI->second; @@ -581,7 +584,7 @@ uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { /// Returns the value number of the specified value. Fails if /// the value has not yet been numbered. -uint32_t GVN::ValueTable::lookup(Value *V, bool Verify) const { +uint32_t GVNPass::ValueTable::lookup(Value *V, bool Verify) const { DenseMap<Value*, uint32_t>::const_iterator VI = valueNumbering.find(V); if (Verify) { assert(VI != valueNumbering.end() && "Value not numbered?"); @@ -594,15 +597,15 @@ uint32_t GVN::ValueTable::lookup(Value *V, bool Verify) const { /// assigning it a new number if it did not have one before. Useful when /// we deduced the result of a comparison, but don't immediately have an /// instruction realizing that comparison to hand. -uint32_t GVN::ValueTable::lookupOrAddCmp(unsigned Opcode, - CmpInst::Predicate Predicate, - Value *LHS, Value *RHS) { +uint32_t GVNPass::ValueTable::lookupOrAddCmp(unsigned Opcode, + CmpInst::Predicate Predicate, + Value *LHS, Value *RHS) { Expression exp = createCmpExpr(Opcode, Predicate, LHS, RHS); return assignExpNewValueNum(exp).first; } /// Remove all entries from the ValueTable. -void GVN::ValueTable::clear() { +void GVNPass::ValueTable::clear() { valueNumbering.clear(); expressionNumbering.clear(); NumberingPhi.clear(); @@ -614,7 +617,7 @@ void GVN::ValueTable::clear() { } /// Remove a value from the value numbering. -void GVN::ValueTable::erase(Value *V) { +void GVNPass::ValueTable::erase(Value *V) { uint32_t Num = valueNumbering.lookup(V); valueNumbering.erase(V); // If V is PHINode, V <--> value number is an one-to-one mapping. @@ -624,7 +627,7 @@ void GVN::ValueTable::erase(Value *V) { /// verifyRemoved - Verify that the value is removed from all internal data /// structures. 
-void GVN::ValueTable::verifyRemoved(const Value *V) const { +void GVNPass::ValueTable::verifyRemoved(const Value *V) const { for (DenseMap<Value*, uint32_t>::const_iterator I = valueNumbering.begin(), E = valueNumbering.end(); I != E; ++I) { assert(I->first != V && "Inst still occurs in value numbering map!"); @@ -635,28 +638,28 @@ void GVN::ValueTable::verifyRemoved(const Value *V) const { // GVN Pass //===----------------------------------------------------------------------===// -bool GVN::isPREEnabled() const { +bool GVNPass::isPREEnabled() const { return Options.AllowPRE.getValueOr(GVNEnablePRE); } -bool GVN::isLoadPREEnabled() const { +bool GVNPass::isLoadPREEnabled() const { return Options.AllowLoadPRE.getValueOr(GVNEnableLoadPRE); } -bool GVN::isLoadInLoopPREEnabled() const { +bool GVNPass::isLoadInLoopPREEnabled() const { return Options.AllowLoadInLoopPRE.getValueOr(GVNEnableLoadInLoopPRE); } -bool GVN::isLoadPRESplitBackedgeEnabled() const { +bool GVNPass::isLoadPRESplitBackedgeEnabled() const { return Options.AllowLoadPRESplitBackedge.getValueOr( GVNEnableSplitBackedgeInLoadPRE); } -bool GVN::isMemDepEnabled() const { +bool GVNPass::isMemDepEnabled() const { return Options.AllowMemDep.getValueOr(GVNEnableMemDep); } -PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { +PreservedAnalyses GVNPass::run(Function &F, FunctionAnalysisManager &AM) { // FIXME: The order of evaluation of these 'getResult' calls is very // significant! Re-ordering these variables will cause GVN when run alone to // be less effective! We should fix memdep and basic-aa to not exhibit this @@ -684,8 +687,26 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { return PA; } +void GVNPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<GVNPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + + OS << "<"; + if (Options.AllowPRE != None) + OS << (Options.AllowPRE.getValue() ? "" : "no-") << "pre;"; + if (Options.AllowLoadPRE != None) + OS << (Options.AllowLoadPRE.getValue() ? "" : "no-") << "load-pre;"; + if (Options.AllowLoadPRESplitBackedge != None) + OS << (Options.AllowLoadPRESplitBackedge.getValue() ? "" : "no-") + << "split-backedge-load-pre;"; + if (Options.AllowMemDep != None) + OS << (Options.AllowMemDep.getValue() ? "" : "no-") << "memdep"; + OS << ">"; +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -LLVM_DUMP_METHOD void GVN::dump(DenseMap<uint32_t, Value*>& d) const { +LLVM_DUMP_METHOD void GVNPass::dump(DenseMap<uint32_t, Value *> &d) const { errs() << "{\n"; for (auto &I : d) { errs() << I.first << "\n"; @@ -835,7 +856,7 @@ static bool IsValueFullyAvailableInBlock( static Value * ConstructSSAForLoadSet(LoadInst *Load, SmallVectorImpl<AvailableValueInBlock> &ValuesPerBlock, - GVN &gvn) { + GVNPass &gvn) { // Check for the fully redundant, dominating load case. In this case, we can // just use the dominating value directly. 
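// Illustrative sketch (not part of the patch), hypothetical C++: the second load
// of *P is fully redundant because a dominating load of the same pointer is
// available and no store to *P appears in between.
int reuseLoad(int *P, int X) {
  int A = *P; // dominating load
  int B = A + X;
  int C = *P; // GVN forwards A here instead of reloading
  return B + C;
}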
if (ValuesPerBlock.size() == 1 && @@ -878,7 +899,7 @@ ConstructSSAForLoadSet(LoadInst *Load, Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, Instruction *InsertPt, - GVN &gvn) const { + GVNPass &gvn) const { Value *Res; Type *LoadTy = Load->getType(); const DataLayout &DL = Load->getModule()->getDataLayout(); @@ -1002,8 +1023,8 @@ static void reportMayClobberedLoad(LoadInst *Load, MemDepResult DepInfo, ORE->emit(R); } -bool GVN::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, - Value *Address, AvailableValue &Res) { +bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, + Value *Address, AvailableValue &Res) { assert((DepInfo.isDef() || DepInfo.isClobber()) && "expected a local dependence"); assert(Load->isUnordered() && "rules below are incorrect for ordered access"); @@ -1137,9 +1158,9 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, return false; } -void GVN::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, - AvailValInBlkVect &ValuesPerBlock, - UnavailBlkVect &UnavailableBlocks) { +void GVNPass::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, + AvailValInBlkVect &ValuesPerBlock, + UnavailBlkVect &UnavailableBlocks) { // Filter out useless results (non-locals, etc). Keep track of the blocks // where we have a value available in repl, also keep track of whether we see // dependencies that produce an unknown value for the load (such as a call @@ -1182,7 +1203,7 @@ void GVN::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, "post condition violation"); } -void GVN::eliminatePartiallyRedundantLoad( +void GVNPass::eliminatePartiallyRedundantLoad( LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, MapVector<BasicBlock *, Value *> &AvailableLoads) { for (const auto &AvailableLoad : AvailableLoads) { @@ -1212,8 +1233,7 @@ void GVN::eliminatePartiallyRedundantLoad( } // Transfer the old load's AA tags to the new load. - AAMDNodes Tags; - Load->getAAMetadata(Tags); + AAMDNodes Tags = Load->getAAMetadata(); if (Tags) NewLoad->setAAMetadata(Tags); @@ -1257,8 +1277,8 @@ void GVN::eliminatePartiallyRedundantLoad( }); } -bool GVN::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, - UnavailBlkVect &UnavailableBlocks) { +bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, + UnavailBlkVect &UnavailableBlocks) { // Okay, we have *some* definitions of the value. This means that the value // is available in some of our (transitive) predecessors. Lets think about // doing PRE of this load. This will involve inserting a new load into the @@ -1498,8 +1518,9 @@ bool GVN::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, return true; } -bool GVN::performLoopLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, - UnavailBlkVect &UnavailableBlocks) { +bool GVNPass::performLoopLoadPRE(LoadInst *Load, + AvailValInBlkVect &ValuesPerBlock, + UnavailBlkVect &UnavailableBlocks) { if (!LI) return false; @@ -1590,7 +1611,7 @@ static void reportLoadElim(LoadInst *Load, Value *AvailableValue, /// Attempt to eliminate a load whose dependencies are /// non-local by performing PHI construction. -bool GVN::processNonLocalLoad(LoadInst *Load) { +bool GVNPass::processNonLocalLoad(LoadInst *Load) { // non-local speculations are not allowed under asan. 
if (Load->getParent()->getParent()->hasFnAttribute( Attribute::SanitizeAddress) || @@ -1622,10 +1643,8 @@ bool GVN::processNonLocalLoad(LoadInst *Load) { // If this load follows a GEP, see if we can PRE the indices before analyzing. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Load->getOperand(0))) { - for (GetElementPtrInst::op_iterator OI = GEP->idx_begin(), - OE = GEP->idx_end(); - OI != OE; ++OI) - if (Instruction *I = dyn_cast<Instruction>(OI->get())) + for (Use &U : GEP->indices()) + if (Instruction *I = dyn_cast<Instruction>(U.get())) Changed |= performScalarPRE(I); } @@ -1673,8 +1692,11 @@ bool GVN::processNonLocalLoad(LoadInst *Load) { if (!isLoadInLoopPREEnabled() && LI && LI->getLoopFor(Load->getParent())) return Changed; - return Changed || PerformLoadPRE(Load, ValuesPerBlock, UnavailableBlocks) || - performLoopLoadPRE(Load, ValuesPerBlock, UnavailableBlocks); + if (performLoopLoadPRE(Load, ValuesPerBlock, UnavailableBlocks) || + PerformLoadPRE(Load, ValuesPerBlock, UnavailableBlocks)) + return true; + + return Changed; } static bool impliesEquivalanceIfTrue(CmpInst* Cmp) { @@ -1738,7 +1760,7 @@ static bool hasUsersIn(Value *V, BasicBlock *BB) { return false; } -bool GVN::processAssumeIntrinsic(AssumeInst *IntrinsicI) { +bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { Value *V = IntrinsicI->getArgOperand(0); if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) { @@ -1882,7 +1904,7 @@ static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { /// Attempt to eliminate a load, first by eliminating it /// locally, and then attempting non-local elimination if that fails. -bool GVN::processLoad(LoadInst *L) { +bool GVNPass::processLoad(LoadInst *L) { if (!MD) return false; @@ -1936,7 +1958,7 @@ bool GVN::processLoad(LoadInst *L) { /// Return a pair the first field showing the value number of \p Exp and the /// second field showing whether it is a value number newly created. std::pair<uint32_t, bool> -GVN::ValueTable::assignExpNewValueNum(Expression &Exp) { +GVNPass::ValueTable::assignExpNewValueNum(Expression &Exp) { uint32_t &e = expressionNumbering[Exp]; bool CreateNewValNum = !e; if (CreateNewValNum) { @@ -1951,8 +1973,8 @@ GVN::ValueTable::assignExpNewValueNum(Expression &Exp) { /// Return whether all the values related with the same \p num are /// defined in \p BB. -bool GVN::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB, - GVN &Gvn) { +bool GVNPass::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB, + GVNPass &Gvn) { LeaderTableEntry *Vals = &Gvn.LeaderTable[Num]; while (Vals && Vals->BB == BB) Vals = Vals->Next; @@ -1960,9 +1982,9 @@ bool GVN::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB, } /// Wrap phiTranslateImpl to provide caching functionality. -uint32_t GVN::ValueTable::phiTranslate(const BasicBlock *Pred, - const BasicBlock *PhiBlock, uint32_t Num, - GVN &Gvn) { +uint32_t GVNPass::ValueTable::phiTranslate(const BasicBlock *Pred, + const BasicBlock *PhiBlock, + uint32_t Num, GVNPass &Gvn) { auto FindRes = PhiTranslateTable.find({Num, Pred}); if (FindRes != PhiTranslateTable.end()) return FindRes->second; @@ -1973,9 +1995,10 @@ uint32_t GVN::ValueTable::phiTranslate(const BasicBlock *Pred, // Return true if the value number \p Num and NewNum have equal value. // Return false if the result is unknown. 
-bool GVN::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum, - const BasicBlock *Pred, - const BasicBlock *PhiBlock, GVN &Gvn) { +bool GVNPass::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum, + const BasicBlock *Pred, + const BasicBlock *PhiBlock, + GVNPass &Gvn) { CallInst *Call = nullptr; LeaderTableEntry *Vals = &Gvn.LeaderTable[Num]; while (Vals) { @@ -2008,9 +2031,9 @@ bool GVN::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum, /// Translate value number \p Num using phis, so that it has the values of /// the phis in BB. -uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred, - const BasicBlock *PhiBlock, - uint32_t Num, GVN &Gvn) { +uint32_t GVNPass::ValueTable::phiTranslateImpl(const BasicBlock *Pred, + const BasicBlock *PhiBlock, + uint32_t Num, GVNPass &Gvn) { if (PHINode *PN = NumberingPhi[Num]) { for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { if (PN->getParent() == PhiBlock && PN->getIncomingBlock(i) == Pred) @@ -2063,8 +2086,8 @@ uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred, /// Erase stale entry from phiTranslate cache so phiTranslate can be computed /// again. -void GVN::ValueTable::eraseTranslateCacheEntry(uint32_t Num, - const BasicBlock &CurrBlock) { +void GVNPass::ValueTable::eraseTranslateCacheEntry( + uint32_t Num, const BasicBlock &CurrBlock) { for (const BasicBlock *Pred : predecessors(&CurrBlock)) PhiTranslateTable.erase({Num, Pred}); } @@ -2074,7 +2097,7 @@ void GVN::ValueTable::eraseTranslateCacheEntry(uint32_t Num, // and then scan the list to find one whose block dominates the block in // question. This is fast because dominator tree queries consist of only // a few comparisons of DFS numbers. -Value *GVN::findLeader(const BasicBlock *BB, uint32_t num) { +Value *GVNPass::findLeader(const BasicBlock *BB, uint32_t num) { LeaderTableEntry Vals = LeaderTable[num]; if (!Vals.Val) return nullptr; @@ -2113,7 +2136,7 @@ static bool isOnlyReachableViaThisEdge(const BasicBlockEdge &E, return Pred != nullptr; } -void GVN::assignBlockRPONumber(Function &F) { +void GVNPass::assignBlockRPONumber(Function &F) { BlockRPONumber.clear(); uint32_t NextBlockNumber = 1; ReversePostOrderTraversal<Function *> RPOT(&F); @@ -2122,7 +2145,7 @@ void GVN::assignBlockRPONumber(Function &F) { InvalidBlockRPONumbers = false; } -bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const { +bool GVNPass::replaceOperandsForInBlockEquality(Instruction *Instr) const { bool Changed = false; for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) { Value *Operand = Instr->getOperand(OpNum); @@ -2142,8 +2165,9 @@ bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const { /// 'RHS' everywhere in the scope. Returns whether a change was made. /// If DominatesByEdge is false, then it means that we will propagate the RHS /// value starting from the end of Root.Start. 
-bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, - bool DominatesByEdge) { +bool GVNPass::propagateEquality(Value *LHS, Value *RHS, + const BasicBlockEdge &Root, + bool DominatesByEdge) { SmallVector<std::pair<Value*, Value*>, 4> Worklist; Worklist.push_back(std::make_pair(LHS, RHS)); bool Changed = false; @@ -2291,7 +2315,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, /// When calculating availability, handle an instruction /// by inserting it into the appropriate sets -bool GVN::processInstruction(Instruction *I) { +bool GVNPass::processInstruction(Instruction *I) { // Ignore dbg info intrinsics. if (isa<DbgInfoIntrinsic>(I)) return false; @@ -2432,10 +2456,10 @@ bool GVN::processInstruction(Instruction *I) { } /// runOnFunction - This is the main transformation entry point for a function. -bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, - const TargetLibraryInfo &RunTLI, AAResults &RunAA, - MemoryDependenceResults *RunMD, LoopInfo *LI, - OptimizationRemarkEmitter *RunORE, MemorySSA *MSSA) { +bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, + const TargetLibraryInfo &RunTLI, AAResults &RunAA, + MemoryDependenceResults *RunMD, LoopInfo *LI, + OptimizationRemarkEmitter *RunORE, MemorySSA *MSSA) { AC = &RunAC; DT = &RunDT; VN.setDomTree(DT); @@ -2457,10 +2481,8 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); // Merge unconditional branches, allowing PRE to catch more // optimization opportunities. - for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ) { - BasicBlock *BB = &*FI++; - - bool removedBlock = MergeBlockIntoPredecessor(BB, &DTU, LI, MSSAU, MD); + for (BasicBlock &BB : llvm::make_early_inc_range(F)) { + bool removedBlock = MergeBlockIntoPredecessor(&BB, &DTU, LI, MSSAU, MD); if (removedBlock) ++NumGVNBlocks; @@ -2502,7 +2524,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, return Changed; } -bool GVN::processBlock(BasicBlock *BB) { +bool GVNPass::processBlock(BasicBlock *BB) { // FIXME: Kill off InstrsToErase by doing erasing eagerly in a helper function // (and incrementing BI before processing an instruction). assert(InstrsToErase.empty() && @@ -2563,8 +2585,8 @@ bool GVN::processBlock(BasicBlock *BB) { } // Instantiate an expression in a predecessor that lacked it. -bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, - BasicBlock *Curr, unsigned int ValNo) { +bool GVNPass::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, + BasicBlock *Curr, unsigned int ValNo) { // Because we are going top-down through the block, all value numbers // will be available in the predecessor by the time we need them. Any // that weren't originally present will have been instantiated earlier @@ -2612,7 +2634,7 @@ bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, return true; } -bool GVN::performScalarPRE(Instruction *CurInst) { +bool GVNPass::performScalarPRE(Instruction *CurInst) { if (isa<AllocaInst>(CurInst) || CurInst->isTerminator() || isa<PHINode>(CurInst) || CurInst->getType()->isVoidTy() || CurInst->mayReadFromMemory() || CurInst->mayHaveSideEffects() || @@ -2797,7 +2819,7 @@ bool GVN::performScalarPRE(Instruction *CurInst) { /// Perform a purely local form of PRE that looks for diamond /// control flow patterns and attempts to perform simple PRE at the join point. 
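// llvm::make_early_inc_range, used in the runImpl hunk above, advances the
// iterator before the loop body runs, so the body may erase the current
// element. Generic sketch (assumes llvm/ADT/STLExtras.h and
// llvm/Transforms/Utils/Local.h):
for (Instruction &I : llvm::make_early_inc_range(BB))
  if (isInstructionTriviallyDead(&I))
    I.eraseFromParent();   // safe: the range iterator already moved past I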
-bool GVN::performPRE(Function &F) { +bool GVNPass::performPRE(Function &F) { bool Changed = false; for (BasicBlock *CurrentBlock : depth_first(&F.getEntryBlock())) { // Nothing to PRE in the entry block. @@ -2824,7 +2846,7 @@ bool GVN::performPRE(Function &F) { /// Split the critical edge connecting the given two blocks, and return /// the block inserted to the critical edge. -BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) { +BasicBlock *GVNPass::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) { // GVN does not require loop-simplify, do not try to preserve it if it is not // possible. BasicBlock *BB = SplitCriticalEdge( @@ -2840,7 +2862,7 @@ BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) { /// Split critical edges found during the previous /// iteration that may enable further optimization. -bool GVN::splitCriticalEdges() { +bool GVNPass::splitCriticalEdges() { if (toSplit.empty()) return false; @@ -2860,7 +2882,7 @@ bool GVN::splitCriticalEdges() { } /// Executes one iteration of GVN -bool GVN::iterateOnFunction(Function &F) { +bool GVNPass::iterateOnFunction(Function &F) { cleanupGlobalSets(); // Top-down walk of the dominator tree @@ -2876,7 +2898,7 @@ bool GVN::iterateOnFunction(Function &F) { return Changed; } -void GVN::cleanupGlobalSets() { +void GVNPass::cleanupGlobalSets() { VN.clear(); LeaderTable.clear(); BlockRPONumber.clear(); @@ -2887,7 +2909,7 @@ void GVN::cleanupGlobalSets() { /// Verify that the specified instruction does not occur in our /// internal data structures. -void GVN::verifyRemoved(const Instruction *Inst) const { +void GVNPass::verifyRemoved(const Instruction *Inst) const { VN.verifyRemoved(Inst); // Walk through the value number scope to make sure the instruction isn't @@ -2907,7 +2929,7 @@ void GVN::verifyRemoved(const Instruction *Inst) const { /// function is to add all these blocks to "DeadBlocks". For the dead blocks' /// live successors, update their phi nodes by replacing the operands /// corresponding to dead blocks with UndefVal. -void GVN::addDeadBlock(BasicBlock *BB) { +void GVNPass::addDeadBlock(BasicBlock *BB) { SmallVector<BasicBlock *, 4> NewDead; SmallSetVector<BasicBlock *, 4> DF; @@ -2995,7 +3017,7 @@ void GVN::addDeadBlock(BasicBlock *BB) { // dead blocks with "UndefVal" in an hope these PHIs will optimized away. // // Return true iff *NEW* dead code are found. -bool GVN::processFoldableCondBr(BranchInst *BI) { +bool GVNPass::processFoldableCondBr(BranchInst *BI) { if (!BI || BI->isUnconditional()) return false; @@ -3023,7 +3045,7 @@ bool GVN::processFoldableCondBr(BranchInst *BI) { // associated val-num. As it normally has far more live instructions than dead // instructions, it makes more sense just to "fabricate" a val-number for the // dead code than checking if instruction involved is dead or not. 
-void GVN::assignValNumForDeadCode() { +void GVNPass::assignValNumForDeadCode() { for (BasicBlock *BB : DeadBlocks) { for (Instruction &Inst : *BB) { unsigned ValNum = VN.lookupOrAdd(&Inst); @@ -3078,7 +3100,7 @@ public: } private: - GVN Impl; + GVNPass Impl; }; char GVNLegacyPass::ID = 0; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp index 790d71992da4..fdc3afd9348a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -169,7 +169,7 @@ class InsnInfo { public: // Inserts I and its value number in VNtoScalars. - void insert(Instruction *I, GVN::ValueTable &VN) { + void insert(Instruction *I, GVNPass::ValueTable &VN) { // Scalar instruction. unsigned V = VN.lookupOrAdd(I); VNtoScalars[{V, InvalidVN}].push_back(I); @@ -184,7 +184,7 @@ class LoadInfo { public: // Insert Load and the value number of its memory address in VNtoLoads. - void insert(LoadInst *Load, GVN::ValueTable &VN) { + void insert(LoadInst *Load, GVNPass::ValueTable &VN) { if (Load->isSimple()) { unsigned V = VN.lookupOrAdd(Load->getPointerOperand()); VNtoLoads[{V, InvalidVN}].push_back(Load); @@ -201,7 +201,7 @@ class StoreInfo { public: // Insert the Store and a hash number of the store address and the stored // value in VNtoStores. - void insert(StoreInst *Store, GVN::ValueTable &VN) { + void insert(StoreInst *Store, GVNPass::ValueTable &VN) { if (!Store->isSimple()) return; // Hash the store address and the stored value. @@ -221,7 +221,7 @@ class CallInfo { public: // Insert Call and its value numbering in one of the VNtoCalls* containers. - void insert(CallInst *Call, GVN::ValueTable &VN) { + void insert(CallInst *Call, GVNPass::ValueTable &VN) { // A call that doesNotAccessMemory is handled as a Scalar, // onlyReadsMemory will be handled as a Load instruction, // all other calls will be handled as stores. @@ -274,7 +274,7 @@ public: unsigned int rank(const Value *V) const; private: - GVN::ValueTable VN; + GVNPass::ValueTable VN; DominatorTree *DT; PostDominatorTree *PDT; AliasAnalysis *AA; @@ -377,12 +377,12 @@ private: if (!Root) return; // Depth first walk on PDom tree to fill the CHIargs at each PDF. - RenameStackType RenameStack; for (auto Node : depth_first(Root)) { BasicBlock *BB = Node->getBlock(); if (!BB) continue; + RenameStackType RenameStack; // Collect all values in BB and push to stack. fillRenameStack(BB, ValueBBs, RenameStack); @@ -827,6 +827,8 @@ void GVNHoist::fillRenameStack(BasicBlock *BB, InValuesType &ValueBBs, auto it1 = ValueBBs.find(BB); if (it1 != ValueBBs.end()) { // Iterate in reverse order to keep lower ranked values on the top. 
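// GVNHoist keys its InsnInfo/LoadInfo/StoreInfo/CallInfo maps on GVN value
// numbers, so the table type above becomes GVNPass::ValueTable. Hedged setup
// sketch (the analysis pointers are assumed to be the ones GVNHoist already
// holds):
GVNPass::ValueTable VN;
VN.setDomTree(DT);
VN.setAliasAnalysis(AA);
VN.setMemDep(MD);
unsigned AddrVN = VN.lookupOrAdd(Load->getPointerOperand());
// Two simple loads whose addresses share AddrVN end up under the same
// {AddrVN, InvalidVN} key and become hoisting candidates.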
+ LLVM_DEBUG(dbgs() << "\nVisiting: " << BB->getName() + << " for pushing instructions on stack";); for (std::pair<VNType, Instruction *> &VI : reverse(it1->second)) { // Get the value of instruction I LLVM_DEBUG(dbgs() << "\nPushing on stack: " << *VI.second); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp index 61eb4ce0ed46..82b81003ef21 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -46,6 +46,7 @@ #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" @@ -105,8 +106,10 @@ static void setCondition(Instruction *I, Value *NewCond) { } // Eliminates the guard instruction properly. -static void eliminateGuard(Instruction *GuardInst) { +static void eliminateGuard(Instruction *GuardInst, MemorySSAUpdater *MSSAU) { GuardInst->eraseFromParent(); + if (MSSAU) + MSSAU->removeMemoryAccess(GuardInst); ++GuardsEliminated; } @@ -114,6 +117,7 @@ class GuardWideningImpl { DominatorTree &DT; PostDominatorTree *PDT; LoopInfo &LI; + MemorySSAUpdater *MSSAU; /// Together, these describe the region of interest. This might be all of /// the blocks within a function, or only a given loop's blocks and preheader. @@ -269,12 +273,12 @@ class GuardWideningImpl { } public: - explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree *PDT, - LoopInfo &LI, DomTreeNode *Root, + LoopInfo &LI, MemorySSAUpdater *MSSAU, + DomTreeNode *Root, std::function<bool(BasicBlock*)> BlockFilter) - : DT(DT), PDT(PDT), LI(LI), Root(Root), BlockFilter(BlockFilter) - {} + : DT(DT), PDT(PDT), LI(LI), MSSAU(MSSAU), Root(Root), + BlockFilter(BlockFilter) {} /// The entry point for this pass. bool run(); @@ -313,7 +317,7 @@ bool GuardWideningImpl::run() { if (!WidenedGuards.count(I)) { assert(isa<ConstantInt>(getCondition(I)) && "Should be!"); if (isSupportedGuardInstruction(I)) - eliminateGuard(I); + eliminateGuard(I, MSSAU); else { assert(isa<BranchInst>(I) && "Eliminated something other than guard or branch?"); @@ -514,27 +518,20 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, ConstantRange CR1 = ConstantRange::makeExactICmpRegion(Pred1, RHS1->getValue()); - // SubsetIntersect is a subset of the actual mathematical intersection of - // CR0 and CR1, while SupersetIntersect is a superset of the actual - // mathematical intersection. If these two ConstantRanges are equal, then - // we know we were able to represent the actual mathematical intersection - // of CR0 and CR1, and can use the same to generate an icmp instruction. - // // Given what we're doing here and the semantics of guards, it would - // actually be correct to just use SubsetIntersect, but that may be too + // be correct to use a subset intersection, but that may be too // aggressive in cases we care about. 
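// The hunk that follows replaces the hand-rolled SubsetIntersect /
// SupersetIntersect comparison with ConstantRange::exactIntersectWith, which
// returns a range only when the true intersection is exactly representable.
// Self-contained illustration with example values (assumes
// llvm/IR/ConstantRange.h):
ConstantRange A =
    ConstantRange::makeExactICmpRegion(CmpInst::ICMP_ULT, APInt(32, 100));
ConstantRange B =
    ConstantRange::makeExactICmpRegion(CmpInst::ICMP_ULT, APInt(32, 50));
if (Optional<ConstantRange> Intersect = A.exactIntersectWith(B)) {
  APInt NewRHS;
  CmpInst::Predicate Pred;
  if (Intersect->getEquivalentICmp(Pred, NewRHS)) {
    // Pred == ICMP_ULT, NewRHS == 50: both guard conditions collapse into a
    // single `icmp ult %x, 50`, which is what widenCondCommon then emits.
  }
}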
- auto SubsetIntersect = CR0.inverse().unionWith(CR1.inverse()).inverse(); - auto SupersetIntersect = CR0.intersectWith(CR1); - - APInt NewRHSAP; - CmpInst::Predicate Pred; - if (SubsetIntersect == SupersetIntersect && - SubsetIntersect.getEquivalentICmp(Pred, NewRHSAP)) { - if (InsertPt) { - ConstantInt *NewRHS = ConstantInt::get(Cond0->getContext(), NewRHSAP); - Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); + if (Optional<ConstantRange> Intersect = CR0.exactIntersectWith(CR1)) { + APInt NewRHSAP; + CmpInst::Predicate Pred; + if (Intersect->getEquivalentICmp(Pred, NewRHSAP)) { + if (InsertPt) { + ConstantInt *NewRHS = + ConstantInt::get(Cond0->getContext(), NewRHSAP); + Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); + } + return true; } - return true; } } } @@ -766,12 +763,18 @@ PreservedAnalyses GuardWideningPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - if (!GuardWideningImpl(DT, &PDT, LI, DT.getRootNode(), - [](BasicBlock*) { return true; } ).run()) + auto *MSSAA = AM.getCachedResult<MemorySSAAnalysis>(F); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (MSSAA) + MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAA->getMSSA()); + if (!GuardWideningImpl(DT, &PDT, LI, MSSAU ? MSSAU.get() : nullptr, + DT.getRootNode(), [](BasicBlock *) { return true; }) + .run()) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); + PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -784,11 +787,17 @@ PreservedAnalyses GuardWideningPass::run(Loop &L, LoopAnalysisManager &AM, auto BlockFilter = [&](BasicBlock *BB) { return BB == RootBB || L.contains(BB); }; - if (!GuardWideningImpl(AR.DT, nullptr, AR.LI, AR.DT.getNode(RootBB), - BlockFilter).run()) + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (AR.MSSA) + MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA); + if (!GuardWideningImpl(AR.DT, nullptr, AR.LI, MSSAU ? MSSAU.get() : nullptr, + AR.DT.getNode(RootBB), BlockFilter).run()) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; } namespace { @@ -805,8 +814,14 @@ struct GuardWideningLegacyPass : public FunctionPass { auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - return GuardWideningImpl(DT, &PDT, LI, DT.getRootNode(), - [](BasicBlock*) { return true; } ).run(); + auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (MSSAWP) + MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); + return GuardWideningImpl(DT, &PDT, LI, MSSAU ? MSSAU.get() : nullptr, + DT.getRootNode(), + [](BasicBlock *) { return true; }) + .run(); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -814,6 +829,7 @@ struct GuardWideningLegacyPass : public FunctionPass { AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<PostDominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); } }; @@ -833,13 +849,18 @@ struct LoopGuardWideningLegacyPass : public LoopPass { auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); auto *PDT = PDTWP ? 
&PDTWP->getPostDomTree() : nullptr; + auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (MSSAWP) + MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); + BasicBlock *RootBB = L->getLoopPredecessor(); if (!RootBB) RootBB = L->getHeader(); auto BlockFilter = [&](BasicBlock *BB) { return BB == RootBB || L->contains(BB); }; - return GuardWideningImpl(DT, PDT, LI, + return GuardWideningImpl(DT, PDT, LI, MSSAU ? MSSAU.get() : nullptr, DT.getNode(RootBB), BlockFilter).run(); } @@ -847,6 +868,7 @@ struct LoopGuardWideningLegacyPass : public LoopPass { AU.setPreservesCFG(); getLoopAnalysisUsage(AU); AU.addPreserved<PostDominatorTreeWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); } }; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 9ee2a2d0bf08..ae2fe2767074 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -89,6 +89,7 @@ #include <utility> using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "indvars" @@ -155,6 +156,10 @@ class IndVarSimplify { bool rewriteNonIntegerIVs(Loop *L); bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); + /// Try to improve our exit conditions by converting condition from signed + /// to unsigned or rotating computation out of the loop. + /// (See inline comment about why this is duplicated from simplifyAndExtend) + bool canonicalizeExitCondition(Loop *L); /// Try to eliminate loop exits based on analyzeable exit counts bool optimizeLoopExits(Loop *L, SCEVExpander &Rewriter); /// Try to form loop invariant tests for loop exits by changing how many @@ -494,6 +499,7 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { MadeAnyChanges = true; PN.setIncomingValue(IncomingValIdx, ExitVal->getIncomingValue(PreheaderIdx)); + SE->forgetValue(&PN); } } } @@ -541,18 +547,18 @@ static void visitIVCast(CastInst *Cast, WideIVInfo &WI, return; } - if (!WI.WidestNativeType) { + if (!WI.WidestNativeType || + Width > SE->getTypeSizeInBits(WI.WidestNativeType)) { WI.WidestNativeType = SE->getEffectiveSCEVType(Ty); WI.IsSigned = IsSigned; return; } - // We extend the IV to satisfy the sign of its first user, arbitrarily. - if (WI.IsSigned != IsSigned) - return; - - if (Width > SE->getTypeSizeInBits(WI.WidestNativeType)) - WI.WidestNativeType = SE->getEffectiveSCEVType(Ty); + // We extend the IV to satisfy the sign of its user(s), or 'signed' + // if there are multiple users with both sign- and zero extensions, + // in order not to introduce nondeterministic behaviour based on the + // unspecified order of a PHI nodes' users-iterator. + WI.IsSigned |= IsSigned; } //===----------------------------------------------------------------------===// @@ -1274,9 +1280,9 @@ bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { // Skip debug info intrinsics. 
do { --I; - } while (isa<DbgInfoIntrinsic>(I) && I != Preheader->begin()); + } while (I->isDebugOrPseudoInst() && I != Preheader->begin()); - if (isa<DbgInfoIntrinsic>(I) && I == Preheader->begin()) + if (I->isDebugOrPseudoInst() && I == Preheader->begin()) Done = true; } else { Done = true; @@ -1309,6 +1315,18 @@ static void foldExit(const Loop *L, BasicBlock *ExitingBB, bool IsTaken, replaceExitCond(BI, NewCond, DeadInsts); } +static void replaceLoopPHINodesWithPreheaderValues( + Loop *L, SmallVectorImpl<WeakTrackingVH> &DeadInsts) { + assert(L->isLoopSimplifyForm() && "Should only do it in simplify form!"); + auto *LoopPreheader = L->getLoopPreheader(); + auto *LoopHeader = L->getHeader(); + for (auto &PN : LoopHeader->phis()) { + auto *PreheaderIncoming = PN.getIncomingValueForBlock(LoopPreheader); + PN.replaceAllUsesWith(PreheaderIncoming); + DeadInsts.emplace_back(&PN); + } +} + static void replaceWithInvariantCond( const Loop *L, BasicBlock *ExitingBB, ICmpInst::Predicate InvariantPred, const SCEV *InvariantLHS, const SCEV *InvariantRHS, SCEVExpander &Rewriter, @@ -1333,7 +1351,6 @@ static bool optimizeLoopExitWithUnknownExitCount( SmallVectorImpl<WeakTrackingVH> &DeadInsts) { ICmpInst::Predicate Pred; Value *LHS, *RHS; - using namespace PatternMatch; BasicBlock *TrueSucc, *FalseSucc; if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) @@ -1394,6 +1411,140 @@ static bool optimizeLoopExitWithUnknownExitCount( return true; } +bool IndVarSimplify::canonicalizeExitCondition(Loop *L) { + // Note: This is duplicating a particular part on SimplifyIndVars reasoning. + // We need to duplicate it because given icmp zext(small-iv), C, IVUsers + // never reaches the icmp since the zext doesn't fold to an AddRec unless + // it already has flags. The alternative to this would be to extending the + // set of "interesting" IV users to include the icmp, but doing that + // regresses results in practice by querying SCEVs before trip counts which + // rely on them which results in SCEV caching sub-optimal answers. The + // concern about caching sub-optimal results is why we only query SCEVs of + // the loop invariant RHS here. + SmallVector<BasicBlock*, 16> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + bool Changed = false; + for (auto *ExitingBB : ExitingBlocks) { + auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); + if (!BI) + continue; + assert(BI->isConditional() && "exit branch must be conditional"); + + auto *ICmp = dyn_cast<ICmpInst>(BI->getCondition()); + if (!ICmp || !ICmp->hasOneUse()) + continue; + + auto *LHS = ICmp->getOperand(0); + auto *RHS = ICmp->getOperand(1); + // For the range reasoning, avoid computing SCEVs in the loop to avoid + // poisoning cache with sub-optimal results. For the must-execute case, + // this is a neccessary precondition for correctness. 
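// Hand-written illustration (not from the diff) of what the new
// replaceLoopPHINodesWithPreheaderValues above buys: once an exit count of
// zero proves the backedge is never taken, a header PHI such as
//   %iv = phi i32 [ 0, %preheader ], [ %iv.next, %latch ]
// is replaced everywhere by its preheader value (0) and queued in DeadInsts,
// instead of only folding the exit branch as before.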
+ if (!L->isLoopInvariant(RHS)) { + if (!L->isLoopInvariant(LHS)) + continue; + // Same logic applies for the inverse case + std::swap(LHS, RHS); + } + + // Match (icmp signed-cond zext, RHS) + Value *LHSOp = nullptr; + if (!match(LHS, m_ZExt(m_Value(LHSOp))) || !ICmp->isSigned()) + continue; + + const DataLayout &DL = ExitingBB->getModule()->getDataLayout(); + const unsigned InnerBitWidth = DL.getTypeSizeInBits(LHSOp->getType()); + const unsigned OuterBitWidth = DL.getTypeSizeInBits(RHS->getType()); + auto FullCR = ConstantRange::getFull(InnerBitWidth); + FullCR = FullCR.zeroExtend(OuterBitWidth); + auto RHSCR = SE->getUnsignedRange(SE->applyLoopGuards(SE->getSCEV(RHS), L)); + if (FullCR.contains(RHSCR)) { + // We have now matched icmp signed-cond zext(X), zext(Y'), and can thus + // replace the signed condition with the unsigned version. + ICmp->setPredicate(ICmp->getUnsignedPredicate()); + Changed = true; + // Note: No SCEV invalidation needed. We've changed the predicate, but + // have not changed exit counts, or the values produced by the compare. + continue; + } + } + + // Now that we've canonicalized the condition to match the extend, + // see if we can rotate the extend out of the loop. + for (auto *ExitingBB : ExitingBlocks) { + auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); + if (!BI) + continue; + assert(BI->isConditional() && "exit branch must be conditional"); + + auto *ICmp = dyn_cast<ICmpInst>(BI->getCondition()); + if (!ICmp || !ICmp->hasOneUse() || !ICmp->isUnsigned()) + continue; + + bool Swapped = false; + auto *LHS = ICmp->getOperand(0); + auto *RHS = ICmp->getOperand(1); + if (L->isLoopInvariant(LHS) == L->isLoopInvariant(RHS)) + // Nothing to rotate + continue; + if (L->isLoopInvariant(LHS)) { + // Same logic applies for the inverse case until we actually pick + // which operand of the compare to update. + Swapped = true; + std::swap(LHS, RHS); + } + assert(!L->isLoopInvariant(LHS) && L->isLoopInvariant(RHS)); + + // Match (icmp unsigned-cond zext, RHS) + // TODO: Extend to handle corresponding sext/signed-cmp case + // TODO: Extend to other invertible functions + Value *LHSOp = nullptr; + if (!match(LHS, m_ZExt(m_Value(LHSOp)))) + continue; + + // In general, we only rotate if we can do so without increasing the number + // of instructions. The exception is when we have an zext(add-rec). The + // reason for allowing this exception is that we know we need to get rid + // of the zext for SCEV to be able to compute a trip count for said loops; + // we consider the new trip count valuable enough to increase instruction + // count by one. + if (!LHS->hasOneUse() && !isa<SCEVAddRecExpr>(SE->getSCEV(LHSOp))) + continue; + + // Given a icmp unsigned-cond zext(Op) where zext(trunc(RHS)) == RHS + // replace with an icmp of the form icmp unsigned-cond Op, trunc(RHS) + // when zext is loop varying and RHS is loop invariant. This converts + // loop varying work to loop-invariant work. + auto doRotateTransform = [&]() { + assert(ICmp->isUnsigned() && "must have proven unsigned already"); + auto *NewRHS = + CastInst::Create(Instruction::Trunc, RHS, LHSOp->getType(), "", + L->getLoopPreheader()->getTerminator()); + ICmp->setOperand(Swapped ? 1 : 0, LHSOp); + ICmp->setOperand(Swapped ? 
0 : 1, NewRHS); + if (LHS->use_empty()) + DeadInsts.push_back(LHS); + }; + + + const DataLayout &DL = ExitingBB->getModule()->getDataLayout(); + const unsigned InnerBitWidth = DL.getTypeSizeInBits(LHSOp->getType()); + const unsigned OuterBitWidth = DL.getTypeSizeInBits(RHS->getType()); + auto FullCR = ConstantRange::getFull(InnerBitWidth); + FullCR = FullCR.zeroExtend(OuterBitWidth); + auto RHSCR = SE->getUnsignedRange(SE->applyLoopGuards(SE->getSCEV(RHS), L)); + if (FullCR.contains(RHSCR)) { + doRotateTransform(); + Changed = true; + // Note, we are leaving SCEV in an unfortunately imprecise case here + // as rotation tends to reveal information about trip counts not + // previously visible. + continue; + } + } + + return Changed; +} + bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { SmallVector<BasicBlock*, 16> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); @@ -1499,20 +1650,18 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { // If we know we'd exit on the first iteration, rewrite the exit to // reflect this. This does not imply the loop must exit through this // exit; there may be an earlier one taken on the first iteration. - // TODO: Given we know the backedge can't be taken, we should go ahead - // and break it. Or at least, kill all the header phis and simplify. + // We know that the backedge can't be taken, so we replace all + // the header PHIs with values coming from the preheader. if (ExitCount->isZero()) { foldExit(L, ExitingBB, true, DeadInsts); + replaceLoopPHINodesWithPreheaderValues(L, DeadInsts); Changed = true; continue; } - // If we end up with a pointer exit count, bail. Note that we can end up - // with a pointer exit count for one exiting block, and not for another in - // the same loop. - if (!ExitCount->getType()->isIntegerTy() || - !MaxExitCount->getType()->isIntegerTy()) - continue; + assert(ExitCount->getType()->isIntegerTy() && + MaxExitCount->getType()->isIntegerTy() && + "Exit counts must be integers"); Type *WiderType = SE->getWiderType(MaxExitCount->getType(), ExitCount->getType()); @@ -1569,14 +1718,11 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // through *explicit* control flow. We have to eliminate the possibility of // implicit exits (see below) before we know it's truly exact. const SCEV *ExactBTC = SE->getBackedgeTakenCount(L); - if (isa<SCEVCouldNotCompute>(ExactBTC) || - !SE->isLoopInvariant(ExactBTC, L) || - !isSafeToExpand(ExactBTC, *SE)) + if (isa<SCEVCouldNotCompute>(ExactBTC) || !isSafeToExpand(ExactBTC, *SE)) return false; - // If we end up with a pointer exit count, bail. It may be unsized. - if (!ExactBTC->getType()->isIntegerTy()) - return false; + assert(SE->isLoopInvariant(ExactBTC, L) && "BTC must be loop invariant"); + assert(ExactBTC->getType()->isIntegerTy() && "BTC must be integer"); auto BadExit = [&](BasicBlock *ExitingBB) { // If our exiting block exits multiple loops, we can only rewrite the @@ -1603,15 +1749,12 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { return true; const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); - if (isa<SCEVCouldNotCompute>(ExitCount) || - !SE->isLoopInvariant(ExitCount, L) || - !isSafeToExpand(ExitCount, *SE)) - return true; - - // If we end up with a pointer exit count, bail. It may be unsized. 
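// Worked example (hand-written, hedged) for the two canonicalizeExitCondition
// phases above, assuming %n is loop-invariant and SCEV's loop-guard range
// proves it fits in 32 bits:
//
//   %iv      = phi i32 [ 0, %preheader ], [ %iv.next, %loop ]
//   %iv.wide = zext i32 %iv to i64
//   %cmp     = icmp slt i64 %iv.wide, %n
//
// Phase 1: both %iv.wide and %n are provably non-negative in i64, so the
// signed predicate can be flipped to its unsigned twin (slt -> ult) without
// changing exit counts.
// Phase 2: with an unsigned compare, the extend is rotated out of the loop:
//
//   preheader:  %n.trunc = trunc i64 %n to i32
//   loop:       %cmp     = icmp ult i32 %iv, %n.trunc
//
// turning loop-varying work into loop-invariant work and letting SCEV see a
// trip count on the narrow induction variable.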
- if (!ExitCount->getType()->isIntegerTy()) + if (isa<SCEVCouldNotCompute>(ExitCount) || !isSafeToExpand(ExitCount, *SE)) return true; + assert(SE->isLoopInvariant(ExitCount, L) && + "Exit count must be loop invariant"); + assert(ExitCount->getType()->isIntegerTy() && "Exit count must be integer"); return false; }; @@ -1781,7 +1924,11 @@ bool IndVarSimplify::run(Loop *L) { } // Eliminate redundant IV cycles. - NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); + NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts, TTI); + + // Try to convert exit conditions to unsigned and rotate computation + // out of the loop. Note: Handles invalidation internally if needed. + Changed |= canonicalizeExitCondition(L); // Try to eliminate loop exits based on analyzeable exit counts if (optimizeLoopExits(L, Rewriter)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index f7d631f5e785..883d4afff3bd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -96,10 +96,13 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" @@ -115,6 +118,7 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -146,6 +150,14 @@ static const unsigned UninitializedAddressSpace = namespace { using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>; +// Different from ValueToAddrSpaceMapTy, where a new addrspace is inferred on +// the *def* of a value, PredicatedAddrSpaceMapTy is map where a new +// addrspace is inferred on the *use* of a pointer. This map is introduced to +// infer addrspace from the addrspace predicate assumption built from assume +// intrinsic. In that scenario, only specific uses (under valid assumption +// context) could be inferred with a new addrspace. +using PredicatedAddrSpaceMapTy = + DenseMap<std::pair<const Value *, const Value *>, unsigned>; using PostorderStackTy = llvm::SmallVector<PointerIntPair<Value *, 1, bool>, 4>; class InferAddressSpaces : public FunctionPass { @@ -160,6 +172,8 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); } @@ -167,6 +181,8 @@ public: }; class InferAddressSpacesImpl { + AssumptionCache &AC; + DominatorTree *DT = nullptr; const TargetTransformInfo *TTI = nullptr; const DataLayout *DL = nullptr; @@ -174,21 +190,24 @@ class InferAddressSpacesImpl { /// possible. unsigned FlatAddrSpace = 0; - // Returns the new address space of V if updated; otherwise, returns None. - Optional<unsigned> - updateAddressSpace(const Value &V, - const ValueToAddrSpaceMapTy &InferredAddrSpace) const; + // Try to update the address space of V. If V is updated, returns true and + // false otherwise. 
+ bool updateAddressSpace(const Value &V, + ValueToAddrSpaceMapTy &InferredAddrSpace, + PredicatedAddrSpaceMapTy &PredicatedAS) const; // Tries to infer the specific address space of each address expression in // Postorder. void inferAddressSpaces(ArrayRef<WeakTrackingVH> Postorder, - ValueToAddrSpaceMapTy *InferredAddrSpace) const; + ValueToAddrSpaceMapTy &InferredAddrSpace, + PredicatedAddrSpaceMapTy &PredicatedAS) const; bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const; Value *cloneInstructionWithNewAddressSpace( Instruction *I, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, SmallVectorImpl<const Use *> *UndefUsesToFix) const; // Changes the flat address expressions in function F to point to specific @@ -196,7 +215,8 @@ class InferAddressSpacesImpl { // all flat expressions in the use-def graph of function F. bool rewriteWithNewAddressSpaces( const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, - const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const; + const ValueToAddrSpaceMapTy &InferredAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const; void appendsFlatAddressExpressionToPostorderStack( Value *V, PostorderStackTy &PostorderStack, @@ -211,14 +231,18 @@ class InferAddressSpacesImpl { std::vector<WeakTrackingVH> collectFlatAddressExpressions(Function &F) const; Value *cloneValueWithNewAddressSpace( - Value *V, unsigned NewAddrSpace, - const ValueToValueMapTy &ValueWithNewAddrSpace, - SmallVectorImpl<const Use *> *UndefUsesToFix) const; + Value *V, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, + SmallVectorImpl<const Use *> *UndefUsesToFix) const; unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const; + unsigned getPredicatedAddrSpace(const Value &V, Value *Opnd) const; + public: - InferAddressSpacesImpl(const TargetTransformInfo *TTI, unsigned FlatAddrSpace) - : TTI(TTI), FlatAddrSpace(FlatAddrSpace) {} + InferAddressSpacesImpl(AssumptionCache &AC, DominatorTree *DT, + const TargetTransformInfo *TTI, unsigned FlatAddrSpace) + : AC(AC), DT(DT), TTI(TTI), FlatAddrSpace(FlatAddrSpace) {} bool run(Function &F); }; @@ -232,8 +256,12 @@ void initializeInferAddressSpacesPass(PassRegistry &); } // end namespace llvm -INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", - false, false) +INITIALIZE_PASS_BEGIN(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", + false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", + false, false) // Check whether that's no-op pointer bicast using a pair of // `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over @@ -505,6 +533,7 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { static Value *operandWithNewAddressSpaceOrCreateUndef( const Use &OperandUse, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, SmallVectorImpl<const Use *> *UndefUsesToFix) { Value *Operand = OperandUse.get(); @@ -517,6 +546,18 @@ static Value *operandWithNewAddressSpaceOrCreateUndef( if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) return NewOperand; + Instruction *Inst = cast<Instruction>(OperandUse.getUser()); + auto I = PredicatedAS.find(std::make_pair(Inst, Operand)); 
+ if (I != PredicatedAS.end()) { + // Insert an addrspacecast on that operand before the user. + unsigned NewAS = I->second; + Type *NewPtrTy = PointerType::getWithSamePointeeType( + cast<PointerType>(Operand->getType()), NewAS); + auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy); + NewI->insertBefore(Inst); + return NewI; + } + UndefUsesToFix->push_back(&OperandUse); return UndefValue::get(NewPtrTy); } @@ -536,6 +577,7 @@ static Value *operandWithNewAddressSpaceOrCreateUndef( Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( Instruction *I, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, SmallVectorImpl<const Use *> *UndefUsesToFix) const { Type *NewPtrType = PointerType::getWithSamePointeeType( cast<PointerType>(I->getType()), NewAddrSpace); @@ -557,7 +599,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( assert(II->getIntrinsicID() == Intrinsic::ptrmask); Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef( II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace, - UndefUsesToFix); + PredicatedAS, UndefUsesToFix); Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr); if (Rewrite) { @@ -586,7 +628,8 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( NewPointerOperands.push_back(nullptr); else NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef( - OperandUse, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix)); + OperandUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, + UndefUsesToFix)); } switch (I->getOpcode()) { @@ -708,9 +751,8 @@ static Value *cloneConstantExprWithNewAddressSpace( if (CE->getOpcode() == Instruction::GetElementPtr) { // Needs to specify the source type while constructing a getelementptr // constant expression. - return CE->getWithOperands( - NewOperands, TargetType, /*OnlyIfReduced=*/false, - NewOperands[0]->getType()->getPointerElementType()); + return CE->getWithOperands(NewOperands, TargetType, /*OnlyIfReduced=*/false, + cast<GEPOperator>(CE)->getSourceElementType()); } return CE->getWithOperands(NewOperands, TargetType); @@ -724,6 +766,7 @@ static Value *cloneConstantExprWithNewAddressSpace( Value *InferAddressSpacesImpl::cloneValueWithNewAddressSpace( Value *V, unsigned NewAddrSpace, const ValueToValueMapTy &ValueWithNewAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, SmallVectorImpl<const Use *> *UndefUsesToFix) const { // All values in Postorder are flat address expressions. assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace && @@ -731,7 +774,7 @@ Value *InferAddressSpacesImpl::cloneValueWithNewAddressSpace( if (Instruction *I = dyn_cast<Instruction>(V)) { Value *NewV = cloneInstructionWithNewAddressSpace( - I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix); + I, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, UndefUsesToFix); if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) { if (NewI->getParent() == nullptr) { NewI->insertBefore(I); @@ -779,46 +822,43 @@ bool InferAddressSpacesImpl::run(Function &F) { // Runs a data-flow analysis to refine the address spaces of every expression // in Postorder. ValueToAddrSpaceMapTy InferredAddrSpace; - inferAddressSpaces(Postorder, &InferredAddrSpace); + PredicatedAddrSpaceMapTy PredicatedAS; + inferAddressSpaces(Postorder, InferredAddrSpace, PredicatedAS); // Changes the address spaces of the flat address expressions who are inferred // to point to a specific address space. 
- return rewriteWithNewAddressSpaces(*TTI, Postorder, InferredAddrSpace, &F); + return rewriteWithNewAddressSpaces(*TTI, Postorder, InferredAddrSpace, + PredicatedAS, &F); } // Constants need to be tracked through RAUW to handle cases with nested // constant expressions, so wrap values in WeakTrackingVH. void InferAddressSpacesImpl::inferAddressSpaces( ArrayRef<WeakTrackingVH> Postorder, - ValueToAddrSpaceMapTy *InferredAddrSpace) const { + ValueToAddrSpaceMapTy &InferredAddrSpace, + PredicatedAddrSpaceMapTy &PredicatedAS) const { SetVector<Value *> Worklist(Postorder.begin(), Postorder.end()); // Initially, all expressions are in the uninitialized address space. for (Value *V : Postorder) - (*InferredAddrSpace)[V] = UninitializedAddressSpace; + InferredAddrSpace[V] = UninitializedAddressSpace; while (!Worklist.empty()) { Value *V = Worklist.pop_back_val(); - // Tries to update the address space of the stack top according to the + // Try to update the address space of the stack top according to the // address spaces of its operands. - LLVM_DEBUG(dbgs() << "Updating the address space of\n " << *V << '\n'); - Optional<unsigned> NewAS = updateAddressSpace(*V, *InferredAddrSpace); - if (!NewAS.hasValue()) + if (!updateAddressSpace(*V, InferredAddrSpace, PredicatedAS)) continue; - // If any updates are made, grabs its users to the worklist because - // their address spaces can also be possibly updated. - LLVM_DEBUG(dbgs() << " to " << NewAS.getValue() << '\n'); - (*InferredAddrSpace)[V] = NewAS.getValue(); for (Value *User : V->users()) { // Skip if User is already in the worklist. if (Worklist.count(User)) continue; - auto Pos = InferredAddrSpace->find(User); + auto Pos = InferredAddrSpace.find(User); // Our algorithm only updates the address spaces of flat address // expressions, which are those in InferredAddrSpace. - if (Pos == InferredAddrSpace->end()) + if (Pos == InferredAddrSpace.end()) continue; // Function updateAddressSpace moves the address space down a lattice @@ -832,10 +872,37 @@ void InferAddressSpacesImpl::inferAddressSpaces( } } -Optional<unsigned> InferAddressSpacesImpl::updateAddressSpace( - const Value &V, const ValueToAddrSpaceMapTy &InferredAddrSpace) const { +unsigned InferAddressSpacesImpl::getPredicatedAddrSpace(const Value &V, + Value *Opnd) const { + const Instruction *I = dyn_cast<Instruction>(&V); + if (!I) + return UninitializedAddressSpace; + + Opnd = Opnd->stripInBoundsOffsets(); + for (auto &AssumeVH : AC.assumptionsFor(Opnd)) { + if (!AssumeVH) + continue; + CallInst *CI = cast<CallInst>(AssumeVH); + if (!isValidAssumeForContext(CI, I, DT)) + continue; + + const Value *Ptr; + unsigned AS; + std::tie(Ptr, AS) = TTI->getPredicatedAddrSpace(CI->getArgOperand(0)); + if (Ptr) + return AS; + } + + return UninitializedAddressSpace; +} + +bool InferAddressSpacesImpl::updateAddressSpace( + const Value &V, ValueToAddrSpaceMapTy &InferredAddrSpace, + PredicatedAddrSpaceMapTy &PredicatedAS) const { assert(InferredAddrSpace.count(&V)); + LLVM_DEBUG(dbgs() << "Updating the address space of\n " << V << '\n'); + // The new inferred address space equals the join of the address spaces // of all its pointer operands. unsigned NewAS = UninitializedAddressSpace; @@ -861,7 +928,7 @@ Optional<unsigned> InferAddressSpacesImpl::updateAddressSpace( // address space is known. 
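// getPredicatedAddrSpace above only trusts llvm.assume calls that
// isValidAssumeForContext proves dominate the use, which is why the result
// is keyed on (user, operand) pairs rather than on the pointer's def.
// Hand-written IR sketch; the predicate intrinsic is target-specific and
// AMDGPU's is.shared is used here purely as an example:
//   %is = call i1 @llvm.amdgcn.is.shared(i8* %p8)   ; %p8 bitcast of %p
//   call void @llvm.assume(i1 %is)
//   %v = load i32, i32* %p                          ; flat (AS 0) load
// For this load, and only for uses the assume dominates, the pass inserts
// `addrspacecast i32* %p to i32 addrspace(3)*` right before the user and
// rewrites the access to the LDS/shared address space; other uses of %p
// keep the flat address space.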
if ((C1 && Src0AS == UninitializedAddressSpace) || (C0 && Src1AS == UninitializedAddressSpace)) - return None; + return false; if (C0 && isSafeToCastConstAddrSpace(C0, Src1AS)) NewAS = Src1AS; @@ -878,10 +945,23 @@ Optional<unsigned> InferAddressSpacesImpl::updateAddressSpace( // Otherwise, infer the address space from its pointer operands. for (Value *PtrOperand : getPointerOperands(V, *DL, TTI)) { auto I = InferredAddrSpace.find(PtrOperand); - unsigned OperandAS = - I != InferredAddrSpace.end() - ? I->second - : PtrOperand->getType()->getPointerAddressSpace(); + unsigned OperandAS; + if (I == InferredAddrSpace.end()) { + OperandAS = PtrOperand->getType()->getPointerAddressSpace(); + if (OperandAS == FlatAddrSpace) { + // Check AC for assumption dominating V. + unsigned AS = getPredicatedAddrSpace(V, PtrOperand); + if (AS != UninitializedAddressSpace) { + LLVM_DEBUG(dbgs() + << " deduce operand AS from the predicate addrspace " + << AS << '\n'); + OperandAS = AS; + // Record this use with the predicated AS. + PredicatedAS[std::make_pair(&V, PtrOperand)] = OperandAS; + } + } + } else + OperandAS = I->second; // join(flat, *) = flat. So we can break if NewAS is already flat. NewAS = joinAddressSpaces(NewAS, OperandAS); @@ -894,8 +974,13 @@ Optional<unsigned> InferAddressSpacesImpl::updateAddressSpace( unsigned OldAS = InferredAddrSpace.lookup(&V); assert(OldAS != FlatAddrSpace); if (OldAS == NewAS) - return None; - return NewAS; + return false; + + // If any updates are made, grabs its users to the worklist because + // their address spaces can also be possibly updated. + LLVM_DEBUG(dbgs() << " to " << NewAS << '\n'); + InferredAddrSpace[&V] = NewAS; + return true; } /// \p returns true if \p U is the pointer operand of a memory instruction with @@ -1026,7 +1111,8 @@ static Value::use_iterator skipToNextUser(Value::use_iterator I, bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, - const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const { + const ValueToAddrSpaceMapTy &InferredAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const { // For each address expression to be modified, creates a clone of it with its // pointer operands converted to the new address space. 
Since the pointer // operands are converted, the clone is naturally in the new address space by @@ -1042,8 +1128,9 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( continue; if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { - Value *New = cloneValueWithNewAddressSpace( - V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); + Value *New = + cloneValueWithNewAddressSpace(V, NewAddrSpace, ValueWithNewAddrSpace, + PredicatedAS, &UndefUsesToFix); if (New) ValueWithNewAddrSpace[V] = New; } @@ -1155,8 +1242,9 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) { unsigned NewAS = NewV->getType()->getPointerAddressSpace(); if (ASC->getDestAddressSpace() == NewAS) { - if (ASC->getType()->getPointerElementType() != - NewV->getType()->getPointerElementType()) { + if (!cast<PointerType>(ASC->getType()) + ->hasSameElementTypeAs( + cast<PointerType>(NewV->getType()))) { NewV = CastInst::Create(Instruction::BitCast, NewV, ASC->getType(), "", ASC); } @@ -1199,7 +1287,10 @@ bool InferAddressSpaces::runOnFunction(Function &F) { if (skipFunction(F)) return false; + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; return InferAddressSpacesImpl( + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), DT, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F), FlatAddrSpace) .run(F); @@ -1217,11 +1308,14 @@ InferAddressSpacesPass::InferAddressSpacesPass(unsigned AddressSpace) PreservedAnalyses InferAddressSpacesPass::run(Function &F, FunctionAnalysisManager &AM) { bool Changed = - InferAddressSpacesImpl(&AM.getResult<TargetIRAnalysis>(F), FlatAddrSpace) + InferAddressSpacesImpl(AM.getResult<AssumptionAnalysis>(F), + AM.getCachedResult<DominatorTreeAnalysis>(F), + &AM.getResult<TargetIRAnalysis>(F), FlatAddrSpace) .run(F); if (Changed) { PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); + PA.preserve<DominatorTreeAnalysis>(); return PA; } return PreservedAnalyses::all(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp index 9dc3b0351346..fe9a7211967c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -331,7 +331,7 @@ bool JumpThreading::runOnFunction(Function &F) { BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - bool Changed = Impl.runImpl(F, TLI, LVI, AA, &DTU, F.hasProfileData(), + bool Changed = Impl.runImpl(F, TLI, TTI, LVI, AA, &DTU, F.hasProfileData(), std::move(BFI), std::move(BPI)); if (PrintLVIAfterJumpThreading) { dbgs() << "LVI for function '" << F.getName() << "':\n"; @@ -360,7 +360,7 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - bool Changed = runImpl(F, &TLI, &LVI, &AA, &DTU, F.hasProfileData(), + bool Changed = runImpl(F, &TLI, &TTI, &LVI, &AA, &DTU, F.hasProfileData(), std::move(BFI), std::move(BPI)); if (PrintLVIAfterJumpThreading) { @@ -377,12 +377,14 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, } bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, - LazyValueInfo *LVI_, AliasAnalysis *AA_, - DomTreeUpdater *DTU_, bool HasProfileData_, + TargetTransformInfo *TTI_, LazyValueInfo *LVI_, + AliasAnalysis *AA_, DomTreeUpdater *DTU_, + bool HasProfileData_, std::unique_ptr<BlockFrequencyInfo> BFI_, 
std::unique_ptr<BranchProbabilityInfo> BPI_) { LLVM_DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = TLI_; + TTI = TTI_; LVI = LVI_; AA = AA_; DTU = DTU_; @@ -514,7 +516,8 @@ static void replaceFoldableUses(Instruction *Cond, Value *ToVal) { /// Return the cost of duplicating a piece of this block from first non-phi /// and before StopAt instruction to thread across it. Stop scanning the block /// when exceeding the threshold. If duplication is impossible, returns ~0U. -static unsigned getJumpThreadDuplicationCost(BasicBlock *BB, +static unsigned getJumpThreadDuplicationCost(const TargetTransformInfo *TTI, + BasicBlock *BB, Instruction *StopAt, unsigned Threshold) { assert(StopAt->getParent() == BB && "Not an instruction from proper BB?"); @@ -550,26 +553,21 @@ static unsigned getJumpThreadDuplicationCost(BasicBlock *BB, if (Size > Threshold) return Size; - // Debugger intrinsics don't incur code size. - if (isa<DbgInfoIntrinsic>(I)) continue; - - // Pseudo-probes don't incur code size. - if (isa<PseudoProbeInst>(I)) - continue; - - // If this is a pointer->pointer bitcast, it is free. - if (isa<BitCastInst>(I) && I->getType()->isPointerTy()) - continue; - - // Freeze instruction is free, too. - if (isa<FreezeInst>(I)) - continue; - // Bail out if this instruction gives back a token type, it is not possible // to duplicate it if it is used outside this BB. if (I->getType()->isTokenTy() && I->isUsedOutsideOfBlock(BB)) return ~0U; + // Blocks with NoDuplicate are modelled as having infinite cost, so they + // are never duplicated. + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (CI->cannotDuplicate() || CI->isConvergent()) + return ~0U; + + if (TTI->getUserCost(&*I, TargetTransformInfo::TCK_SizeAndLatency) + == TargetTransformInfo::TCC_Free) + continue; + // All other instructions count for at least one unit. ++Size; @@ -578,11 +576,7 @@ static unsigned getJumpThreadDuplicationCost(BasicBlock *BB, // as having cost of 2 total, and if they are a vector intrinsic, we model // them as having cost 1. if (const CallInst *CI = dyn_cast<CallInst>(I)) { - if (CI->cannotDuplicate() || CI->isConvergent()) - // Blocks with NoDuplicate are modelled as having infinite cost, so they - // are never duplicated. - return ~0U; - else if (!isa<IntrinsicInst>(CI)) + if (!isa<IntrinsicInst>(CI)) Size += 3; else if (!CI->getType()->isVectorTy()) Size += 1; @@ -1363,8 +1357,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { // If all of the loads and stores that feed the value have the same AA tags, // then we can propagate them onto any newly inserted loads. - AAMDNodes AATags; - LoadI->getAAMetadata(AATags); + AAMDNodes AATags = LoadI->getAAMetadata(); SmallPtrSet<BasicBlock*, 8> PredsScanned; @@ -2235,10 +2228,10 @@ bool JumpThreadingPass::maybethreadThroughTwoBasicBlocks(BasicBlock *BB, } // Compute the cost of duplicating BB and PredBB. - unsigned BBCost = - getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold); + unsigned BBCost = getJumpThreadDuplicationCost( + TTI, BB, BB->getTerminator(), BBDupThreshold); unsigned PredBBCost = getJumpThreadDuplicationCost( - PredBB, PredBB->getTerminator(), BBDupThreshold); + TTI, PredBB, PredBB->getTerminator(), BBDupThreshold); // Give up if costs are too high. 
We need to check BBCost and PredBBCost // individually before checking their sum because getJumpThreadDuplicationCost @@ -2346,8 +2339,8 @@ bool JumpThreadingPass::tryThreadEdge( return false; } - unsigned JumpThreadCost = - getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold); + unsigned JumpThreadCost = getJumpThreadDuplicationCost( + TTI, BB, BB->getTerminator(), BBDupThreshold); if (JumpThreadCost > BBDupThreshold) { LLVM_DEBUG(dbgs() << " Not threading BB '" << BB->getName() << "' - Cost is too high: " << JumpThreadCost << "\n"); @@ -2615,8 +2608,8 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( return false; } - unsigned DuplicationCost = - getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold); + unsigned DuplicationCost = getJumpThreadDuplicationCost( + TTI, BB, BB->getTerminator(), BBDupThreshold); if (DuplicationCost > BBDupThreshold) { LLVM_DEBUG(dbgs() << " Not duplicating BB '" << BB->getName() << "' - Cost is too high: " << DuplicationCost << "\n"); @@ -3032,7 +3025,8 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard, ValueToValueMapTy UnguardedMapping, GuardedMapping; Instruction *AfterGuard = Guard->getNextNode(); - unsigned Cost = getJumpThreadDuplicationCost(BB, AfterGuard, BBDupThreshold); + unsigned Cost = + getJumpThreadDuplicationCost(TTI, BB, AfterGuard, BBDupThreshold); if (Cost > BBDupThreshold) return false; // Duplicate all instructions before the guard and the guard itself to the diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp index 30058df3ded5..bf714d167670 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp @@ -117,13 +117,6 @@ static cl::opt<uint32_t> MaxNumUsesTraversed( cl::desc("Max num uses visited for identifying load " "invariance in loop using invariant start (default = 8)")); -// Default value of zero implies we use the regular alias set tracker mechanism -// instead of the cross product using AA to identify aliasing of the memory -// location we are interested in. -static cl::opt<int> -LICMN2Theshold("licm-n2-threshold", cl::Hidden, cl::init(0), - cl::desc("How many instruction to cross product using AA")); - // Experimental option to allow imprecision in LICM in pathological cases, in // exchange for faster compile. This is to be removed if MemorySSA starts to // address the same issue. 
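// The JumpThreading cost walk above now asks TTI which instructions are free
// to duplicate instead of special-casing bitcasts, freezes and debug/pseudo
// intrinsics by hand. Condensed sketch of the per-instruction logic, using
// the same names as getJumpThreadDuplicationCost:
if (const CallInst *CI = dyn_cast<CallInst>(&I))
  if (CI->cannotDuplicate() || CI->isConvergent())
    return ~0U;                            // never duplicate such blocks
if (TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency) ==
    TargetTransformInfo::TCC_Free)
  continue;                                // covers the old ad-hoc cases
++Size;                                    // everything else costs >= 1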
This flag applies only when LICM uses MemorySSA @@ -151,7 +144,8 @@ cl::opt<unsigned> llvm::SetLicmMssaNoAccForPromotionCap( static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, - TargetTransformInfo *TTI, bool &FreeInLoop); + TargetTransformInfo *TTI, bool &FreeInLoop, + bool LoopNestMode); static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU, ScalarEvolution *SE, @@ -180,7 +174,7 @@ static Instruction *cloneInstructionInExitBlock( const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU); static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, - AliasSetTracker *AST, MemorySSAUpdater *MSSAU); + MemorySSAUpdater *MSSAU); static void moveInstructionBefore(Instruction &I, Instruction &Dest, ICFLoopSafetyInfo &SafetyInfo, @@ -206,9 +200,6 @@ struct LoopInvariantCodeMotion { private: unsigned LicmMssaOptCap; unsigned LicmMssaNoAccForPromotionCap; - - std::unique_ptr<AliasSetTracker> - collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AAResults *AA); }; struct LegacyLICMPass : public LoopPass { @@ -228,9 +219,7 @@ struct LegacyLICMPass : public LoopPass { << L->getHeader()->getNameOrAsOperand() << "\n"); auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); - MemorySSA *MSSA = EnableMSSALoopDependency - ? (&getAnalysis<MemorySSAWrapperPass>().getMSSA()) - : nullptr; + MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); bool hasProfileData = L->getHeader()->getParent()->hasProfileData(); BlockFrequencyInfo *BFI = hasProfileData ? &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() @@ -257,10 +246,8 @@ struct LegacyLICMPass : public LoopPass { AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); - if (EnableMSSALoopDependency) { - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); getLoopAnalysisUsage(AU); LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); @@ -275,6 +262,9 @@ private: PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { + if (!AR.MSSA) + report_fatal_error("LICM requires MemorySSA (loop-mssa)"); + // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations // but ORE cannot be preserved (see comment before the pass definition). @@ -289,8 +279,7 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<DominatorTreeAnalysis>(); PA.preserve<LoopAnalysis>(); - if (AR.MSSA) - PA.preserve<MemorySSAAnalysis>(); + PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -298,6 +287,9 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { + if (!AR.MSSA) + report_fatal_error("LNICM requires MemorySSA (loop-mssa)"); + // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations // but ORE cannot be preserved (see comment before the pass definition). 
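The hunks above make MemorySSA mandatory for LICM and LNICM under the new pass manager and drop the conditional preservation logic. A minimal sketch of the same contract for a hypothetical loop pass (the pass name and body are illustrative, not part of this commit):

#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"

namespace {
// Hypothetical pass, used only to illustrate the MemorySSA-only contract.
struct MssaOnlyLoopPass : llvm::PassInfoMixin<MssaOnlyLoopPass> {
  llvm::PreservedAnalyses run(llvm::Loop &L, llvm::LoopAnalysisManager &AM,
                              llvm::LoopStandardAnalysisResults &AR,
                              llvm::LPMUpdater &) {
    // Mirror the new LICM/LNICM behavior: refuse to run without MemorySSA,
    // which is only available when scheduled inside a loop-mssa adaptor.
    if (!AR.MSSA)
      llvm::report_fatal_error("MssaOnlyLoopPass requires MemorySSA (loop-mssa)");

    // ... transform the loop using AR.MSSA ...

    auto PA = llvm::getLoopPassPreservedAnalyses();
    // MemorySSA is kept up to date, so it can be preserved unconditionally.
    PA.preserve<llvm::MemorySSAAnalysis>();
    return PA;
  }
};
} // namespace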
@@ -316,8 +308,7 @@ PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, PA.preserve<DominatorTreeAnalysis>(); PA.preserve<LoopAnalysis>(); - if (AR.MSSA) - PA.preserve<MemorySSAAnalysis>(); + PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -386,10 +377,6 @@ bool LoopInvariantCodeMotion::runOnLoop( return false; } - std::unique_ptr<AliasSetTracker> CurAST; - std::unique_ptr<MemorySSAUpdater> MSSAU; - std::unique_ptr<SinkAndHoistLICMFlags> Flags; - // Don't sink stores from loops with coroutine suspend instructions. // LICM would sink instructions into the default destination of // the coroutine switch. The default destination of the switch is to @@ -406,17 +393,9 @@ bool LoopInvariantCodeMotion::runOnLoop( }); }); - if (!MSSA) { - LLVM_DEBUG(dbgs() << "LICM: Using Alias Set Tracker.\n"); - CurAST = collectAliasInfoForLoop(L, LI, AA); - Flags = std::make_unique<SinkAndHoistLICMFlags>( - LicmMssaOptCap, LicmMssaNoAccForPromotionCap, /*IsSink=*/true); - } else { - LLVM_DEBUG(dbgs() << "LICM: Using MemorySSA.\n"); - MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); - Flags = std::make_unique<SinkAndHoistLICMFlags>( - LicmMssaOptCap, LicmMssaNoAccForPromotionCap, /*IsSink=*/true, L, MSSA); - } + MemorySSAUpdater MSSAU(MSSA); + SinkAndHoistLICMFlags Flags(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, + /*IsSink=*/true, L, MSSA); // Get the preheader block to move instructions into... BasicBlock *Preheader = L->getLoopPreheader(); @@ -435,14 +414,16 @@ bool LoopInvariantCodeMotion::runOnLoop( // us to sink instructions in one pass, without iteration. After sinking // instructions, we perform another pass to hoist them out of the loop. if (L->hasDedicatedExits()) - Changed |= - sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, TLI, TTI, L, - CurAST.get(), MSSAU.get(), &SafetyInfo, *Flags.get(), ORE); - Flags->setIsSink(false); + Changed |= LoopNestMode + ? sinkRegionForLoopNest(DT->getNode(L->getHeader()), AA, LI, + DT, BFI, TLI, TTI, L, &MSSAU, + &SafetyInfo, Flags, ORE) + : sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, + TLI, TTI, L, &MSSAU, &SafetyInfo, Flags, ORE); + Flags.setIsSink(false); if (Preheader) Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, TLI, L, - CurAST.get(), MSSAU.get(), SE, &SafetyInfo, - *Flags.get(), ORE, LoopNestMode); + &MSSAU, SE, &SafetyInfo, Flags, ORE, LoopNestMode); // Now that all loop invariants have been removed from the loop, promote any // memory references to scalars that we can. @@ -452,7 +433,7 @@ bool LoopInvariantCodeMotion::runOnLoop( // preheader for SSA updater, so also avoid sinking when no preheader // is available. 
if (!DisablePromotion && Preheader && L->hasDedicatedExits() && - !Flags->tooManyMemoryAccesses() && !HasCoroSuspendInst) { + !Flags.tooManyMemoryAccesses() && !HasCoroSuspendInst) { // Figure out the loop exits and their insertion points SmallVector<BasicBlock *, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); @@ -466,55 +447,29 @@ bool LoopInvariantCodeMotion::runOnLoop( SmallVector<Instruction *, 8> InsertPts; SmallVector<MemoryAccess *, 8> MSSAInsertPts; InsertPts.reserve(ExitBlocks.size()); - if (MSSAU) - MSSAInsertPts.reserve(ExitBlocks.size()); + MSSAInsertPts.reserve(ExitBlocks.size()); for (BasicBlock *ExitBlock : ExitBlocks) { InsertPts.push_back(&*ExitBlock->getFirstInsertionPt()); - if (MSSAU) - MSSAInsertPts.push_back(nullptr); + MSSAInsertPts.push_back(nullptr); } PredIteratorCache PIC; + // Promoting one set of accesses may make the pointers for another set + // loop invariant, so run this in a loop (with the MaybePromotable set + // decreasing in size over time). bool Promoted = false; - if (CurAST.get()) { - // Loop over all of the alias sets in the tracker object. - for (AliasSet &AS : *CurAST) { - // We can promote this alias set if it has a store, if it is a "Must" - // alias set, if the pointer is loop invariant, and if we are not - // eliminating any volatile loads or stores. - if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || - !L->isLoopInvariant(AS.begin()->getValue())) - continue; - - assert( - !AS.empty() && - "Must alias set should have at least one pointer element in it!"); - - SmallSetVector<Value *, 8> PointerMustAliases; - for (const auto &ASI : AS) - PointerMustAliases.insert(ASI.getValue()); - - Promoted |= promoteLoopAccessesToScalars( - PointerMustAliases, ExitBlocks, InsertPts, MSSAInsertPts, PIC, LI, - DT, TLI, L, CurAST.get(), MSSAU.get(), &SafetyInfo, ORE); + bool LocalPromoted; + do { + LocalPromoted = false; + for (const SmallSetVector<Value *, 8> &PointerMustAliases : + collectPromotionCandidates(MSSA, AA, L)) { + LocalPromoted |= promoteLoopAccessesToScalars( + PointerMustAliases, ExitBlocks, InsertPts, MSSAInsertPts, PIC, + LI, DT, TLI, L, &MSSAU, &SafetyInfo, ORE); } - } else { - // Promoting one set of accesses may make the pointers for another set - // loop invariant, so run this in a loop (with the MaybePromotable set - // decreasing in size over time). 
- bool LocalPromoted; - do { - LocalPromoted = false; - for (const SmallSetVector<Value *, 8> &PointerMustAliases : - collectPromotionCandidates(MSSA, AA, L)) { - LocalPromoted |= promoteLoopAccessesToScalars( - PointerMustAliases, ExitBlocks, InsertPts, MSSAInsertPts, PIC, - LI, DT, TLI, L, /*AST*/nullptr, MSSAU.get(), &SafetyInfo, ORE); - } - Promoted |= LocalPromoted; - } while (LocalPromoted); - } + Promoted |= LocalPromoted; + } while (LocalPromoted); // Once we have promoted values across the loop body we have to // recursively reform LCSSA as any nested loop may now have values defined @@ -536,8 +491,8 @@ bool LoopInvariantCodeMotion::runOnLoop( assert((L->isOutermost() || L->getParentLoop()->isLCSSAForm(*DT)) && "Parent loop not left in LCSSA form after LICM!"); - if (MSSAU.get() && VerifyMemorySSA) - MSSAU->getMemorySSA()->verifyMemorySSA(); + if (VerifyMemorySSA) + MSSA->verifyMemorySSA(); if (Changed && SE) SE->forgetLoopDispositions(L); @@ -552,17 +507,15 @@ bool LoopInvariantCodeMotion::runOnLoop( bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, - Loop *CurLoop, AliasSetTracker *CurAST, - MemorySSAUpdater *MSSAU, ICFLoopSafetyInfo *SafetyInfo, + Loop *CurLoop, MemorySSAUpdater *MSSAU, + ICFLoopSafetyInfo *SafetyInfo, SinkAndHoistLICMFlags &Flags, - OptimizationRemarkEmitter *ORE) { + OptimizationRemarkEmitter *ORE, Loop *OutermostLoop) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && - CurLoop != nullptr && SafetyInfo != nullptr && + CurLoop != nullptr && MSSAU != nullptr && SafetyInfo != nullptr && "Unexpected input to sinkRegion."); - assert(((CurAST != nullptr) ^ (MSSAU != nullptr)) && - "Either AliasSetTracker or MemorySSA should be initialized."); // We want to visit children before parents. We will enque all the parents // before their children in the worklist and process the worklist in reverse @@ -587,7 +540,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, salvageKnowledge(&I); salvageDebugInfo(I); ++II; - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(I, *SafetyInfo, MSSAU); Changed = true; continue; } @@ -598,26 +551,46 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, // operands of the instruction are loop invariant. // bool FreeInLoop = false; + bool LoopNestMode = OutermostLoop != nullptr; if (!I.mayHaveSideEffects() && - isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, - ORE)) { + isNotUsedOrFreeInLoop(I, LoopNestMode ? 
OutermostLoop : CurLoop, + SafetyInfo, TTI, FreeInLoop, LoopNestMode) && + canSinkOrHoistInst(I, AA, DT, CurLoop, /*CurAST*/nullptr, MSSAU, true, + &Flags, ORE)) { if (sink(I, LI, DT, BFI, CurLoop, SafetyInfo, MSSAU, ORE)) { if (!FreeInLoop) { ++II; salvageDebugInfo(I); - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(I, *SafetyInfo, MSSAU); } Changed = true; } } } } - if (MSSAU && VerifyMemorySSA) + if (VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); return Changed; } +bool llvm::sinkRegionForLoopNest( + DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, + BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, + Loop *CurLoop, MemorySSAUpdater *MSSAU, ICFLoopSafetyInfo *SafetyInfo, + SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE) { + + bool Changed = false; + SmallPriorityWorklist<Loop *, 4> Worklist; + Worklist.insert(CurLoop); + appendLoopsToWorklist(*CurLoop, Worklist); + while (!Worklist.empty()) { + Loop *L = Worklist.pop_back_val(); + Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, TLI, + TTI, L, MSSAU, SafetyInfo, Flags, ORE, CurLoop); + } + return Changed; +} + namespace { // This is a helper class for hoistRegion to make it able to hoist control flow // in order to be able to hoist phis. The way this works is that we initially @@ -820,9 +793,8 @@ public: if (HoistTarget == InitialPreheader) { // Phis in the loop header now need to use the new preheader. InitialPreheader->replaceSuccessorsPhiUsesWith(HoistCommonSucc); - if (MSSAU) - MSSAU->wireOldPredecessorsToNewImmediatePredecessor( - HoistTarget->getSingleSuccessor(), HoistCommonSucc, {HoistTarget}); + MSSAU->wireOldPredecessorsToNewImmediatePredecessor( + HoistTarget->getSingleSuccessor(), HoistCommonSucc, {HoistTarget}); // The new preheader dominates the loop header. DomTreeNode *PreheaderNode = DT->getNode(HoistCommonSucc); DomTreeNode *HeaderNode = DT->getNode(CurLoop->getHeader()); @@ -884,16 +856,14 @@ static bool worthSinkOrHoistInst(Instruction &I, BasicBlock *DstBlock, bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, - ScalarEvolution *SE, ICFLoopSafetyInfo *SafetyInfo, + MemorySSAUpdater *MSSAU, ScalarEvolution *SE, + ICFLoopSafetyInfo *SafetyInfo, SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE, bool LoopNestMode) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && - CurLoop != nullptr && SafetyInfo != nullptr && + CurLoop != nullptr && MSSAU != nullptr && SafetyInfo != nullptr && "Unexpected input to hoistRegion."); - assert(((CurAST != nullptr) ^ (MSSAU != nullptr)) && - "Either AliasSetTracker or MemorySSA should be initialized."); ControlFlowHoister CFH(LI, DT, CurLoop, MSSAU); @@ -913,8 +883,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, if (!LoopNestMode && inSubLoop(BB, CurLoop, LI)) continue; - for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) { - Instruction &I = *II++; + for (Instruction &I : llvm::make_early_inc_range(*BB)) { // Try constant folding this instruction. If all the operands are // constants, it is technically hoistable, but it would be better to // just fold it. 
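The iteration change above (llvm::make_early_inc_range(*BB)) advances the iterator before the loop body runs, so the current instruction can be erased safely. A small stand-alone sketch of the idiom, with an invented helper name:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/Local.h"

// Erase trivially dead instructions from BB; the early-increment range keeps
// the traversal valid even though the current instruction is removed.
static bool eraseDeadInstructions(llvm::BasicBlock &BB) {
  bool Changed = false;
  for (llvm::Instruction &I : llvm::make_early_inc_range(BB)) {
    if (llvm::isInstructionTriviallyDead(&I)) {
      I.eraseFromParent();
      Changed = true;
    }
  }
  return Changed;
}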
@@ -922,12 +891,10 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, &I, I.getModule()->getDataLayout(), TLI)) { LLVM_DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C << '\n'); - if (CurAST) - CurAST->copyValue(&I, C); // FIXME MSSA: Such replacements may make accesses unoptimized (D51960). I.replaceAllUsesWith(C); if (isInstructionTriviallyDead(&I, TLI)) - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(I, *SafetyInfo, MSSAU); Changed = true; continue; } @@ -940,8 +907,8 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, // and we have accurately duplicated the control flow from the loop header // to that block. if (CurLoop->hasLoopInvariantOperands(&I) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, - ORE) && + canSinkOrHoistInst(I, AA, DT, CurLoop, /*CurAST*/ nullptr, MSSAU, + true, &Flags, ORE) && worthSinkOrHoistInst(I, CurLoop->getLoopPreheader(), ORE, BFI) && isSafeToExecuteUnconditionally( I, DT, TLI, CurLoop, SafetyInfo, ORE, @@ -970,7 +937,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, SafetyInfo->insertInstructionTo(Product, I.getParent()); Product->insertAfter(&I); I.replaceAllUsesWith(Product); - eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(I, *SafetyInfo, MSSAU); hoist(*ReciprocalDivisor, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, MSSAU, SE, ORE); @@ -1049,7 +1016,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, } } } - if (MSSAU && VerifyMemorySSA) + if (VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); // Now that we've finished hoisting make sure that LI and DT are still @@ -1101,6 +1068,10 @@ static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, return false; Addr = BC->getOperand(0); } + // If we've ended up at a global/constant, bail. We shouldn't be looking at + // uselists for non-local Values in a loop pass. + if (isa<Constant>(Addr)) + return false; unsigned UsesVisited = 0; // Traverse all uses of the load operand value, to see if invariant.start is @@ -1273,7 +1244,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // writes to this memory in the loop, we can hoist or sink. if (AAResults::onlyAccessesArgPointees(Behavior)) { // TODO: expand to writeable arguments - for (Value *Op : CI->arg_operands()) + for (Value *Op : CI->args()) if (Op->getType()->isPointerTy()) { bool Invalidated; if (CurAST) @@ -1443,7 +1414,8 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, /// (e.g., a GEP can be folded into a load as an addressing mode in the loop). 
static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, - TargetTransformInfo *TTI, bool &FreeInLoop) { + TargetTransformInfo *TTI, bool &FreeInLoop, + bool LoopNestMode) { const auto &BlockColors = SafetyInfo->getBlockColors(); bool IsFree = isFreeInLoop(I, CurLoop, TTI); for (const User *U : I.users()) { @@ -1460,6 +1432,15 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, if (!BlockColors.empty() && BlockColors.find(const_cast<BasicBlock *>(BB))->second.size() != 1) return false; + + if (LoopNestMode) { + while (isa<PHINode>(UI) && UI->hasOneUser() && + UI->getNumOperands() == 1) { + if (!CurLoop->contains(UI)) + break; + UI = cast<Instruction>(UI->user_back()); + } + } } if (CurLoop->contains(UI)) { @@ -1546,9 +1527,7 @@ static Instruction *cloneInstructionInExitBlock( } static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, - AliasSetTracker *AST, MemorySSAUpdater *MSSAU) { - if (AST) - AST->deleteValue(&I); + MemorySSAUpdater *MSSAU) { if (MSSAU) MSSAU->removeMemoryAccess(&I); SafetyInfo.removeInstruction(&I); @@ -1599,8 +1578,7 @@ static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) { // predecessor fairly simple. if (!SafetyInfo->getBlockColors().empty() && BB->getFirstNonPHI()->isEHPad()) return false; - for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { - BasicBlock *BBPred = *PI; + for (BasicBlock *BBPred : predecessors(BB)) { if (isa<IndirectBrInst>(BBPred->getTerminator()) || isa<CallBrInst>(BBPred->getTerminator())) return false; @@ -1786,7 +1764,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, Instruction *New = sinkThroughTriviallyReplaceablePHI( PN, &I, LI, SunkCopies, SafetyInfo, CurLoop, MSSAU); PN->replaceAllUsesWith(New); - eraseInstruction(*PN, *SafetyInfo, nullptr, nullptr); + eraseInstruction(*PN, *SafetyInfo, nullptr); Changed = true; } return Changed; @@ -1875,11 +1853,10 @@ class LoopPromoter : public LoadAndStorePromoter { SmallVectorImpl<Instruction *> &LoopInsertPts; SmallVectorImpl<MemoryAccess *> &MSSAInsertPts; PredIteratorCache &PredCache; - AliasSetTracker *AST; MemorySSAUpdater *MSSAU; LoopInfo &LI; DebugLoc DL; - int Alignment; + Align Alignment; bool UnorderedAtomic; AAMDNodes AATags; ICFLoopSafetyInfo &SafetyInfo; @@ -1907,13 +1884,13 @@ public: SmallVectorImpl<BasicBlock *> &LEB, SmallVectorImpl<Instruction *> &LIP, SmallVectorImpl<MemoryAccess *> &MSSAIP, PredIteratorCache &PIC, - AliasSetTracker *ast, MemorySSAUpdater *MSSAU, LoopInfo &li, - DebugLoc dl, int alignment, bool UnorderedAtomic, - const AAMDNodes &AATags, ICFLoopSafetyInfo &SafetyInfo) + MemorySSAUpdater *MSSAU, LoopInfo &li, DebugLoc dl, + Align Alignment, bool UnorderedAtomic, const AAMDNodes &AATags, + ICFLoopSafetyInfo &SafetyInfo) : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), LoopExitBlocks(LEB), LoopInsertPts(LIP), MSSAInsertPts(MSSAIP), - PredCache(PIC), AST(ast), MSSAU(MSSAU), LI(li), DL(std::move(dl)), - Alignment(alignment), UnorderedAtomic(UnorderedAtomic), AATags(AATags), + PredCache(PIC), MSSAU(MSSAU), LI(li), DL(std::move(dl)), + Alignment(Alignment), UnorderedAtomic(UnorderedAtomic), AATags(AATags), SafetyInfo(SafetyInfo) {} bool isInstInList(Instruction *I, @@ -1940,39 +1917,29 @@ public: StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); if (UnorderedAtomic) NewSI->setOrdering(AtomicOrdering::Unordered); - NewSI->setAlignment(Align(Alignment)); + 
NewSI->setAlignment(Alignment); NewSI->setDebugLoc(DL); if (AATags) NewSI->setAAMetadata(AATags); - if (MSSAU) { - MemoryAccess *MSSAInsertPoint = MSSAInsertPts[i]; - MemoryAccess *NewMemAcc; - if (!MSSAInsertPoint) { - NewMemAcc = MSSAU->createMemoryAccessInBB( - NewSI, nullptr, NewSI->getParent(), MemorySSA::Beginning); - } else { - NewMemAcc = - MSSAU->createMemoryAccessAfter(NewSI, nullptr, MSSAInsertPoint); - } - MSSAInsertPts[i] = NewMemAcc; - MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); - // FIXME: true for safety, false may still be correct. + MemoryAccess *MSSAInsertPoint = MSSAInsertPts[i]; + MemoryAccess *NewMemAcc; + if (!MSSAInsertPoint) { + NewMemAcc = MSSAU->createMemoryAccessInBB( + NewSI, nullptr, NewSI->getParent(), MemorySSA::Beginning); + } else { + NewMemAcc = + MSSAU->createMemoryAccessAfter(NewSI, nullptr, MSSAInsertPoint); } + MSSAInsertPts[i] = NewMemAcc; + MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); + // FIXME: true for safety, false may still be correct. } } - void replaceLoadWithValue(LoadInst *LI, Value *V) const override { - // Update alias analysis. - if (AST) - AST->copyValue(LI, V); - } void instructionDeleted(Instruction *I) const override { SafetyInfo.removeInstruction(I); - if (AST) - AST->deleteValue(I); - if (MSSAU) - MSSAU->removeMemoryAccess(I); + MSSAU->removeMemoryAccess(I); } }; @@ -2023,8 +1990,8 @@ bool llvm::promoteLoopAccessesToScalars( SmallVectorImpl<Instruction *> &InsertPts, SmallVectorImpl<MemoryAccess *> &MSSAInsertPts, PredIteratorCache &PIC, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - Loop *CurLoop, AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, - ICFLoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { + Loop *CurLoop, MemorySSAUpdater *MSSAU, ICFLoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(LI != nullptr && DT != nullptr && CurLoop != nullptr && SafetyInfo != nullptr && @@ -2189,9 +2156,9 @@ bool llvm::promoteLoopAccessesToScalars( // Merge the AA tags. if (LoopUses.empty()) { // On the first load/store, just take its AA tags. - UI->getAAMetadata(AATags); + AATags = UI->getAAMetadata(); } else if (AATags) { - UI->getAAMetadata(AATags, /* Merge = */ true); + AATags = AATags.merge(UI->getAAMetadata()); } LoopUses.push_back(UI); @@ -2256,9 +2223,8 @@ bool llvm::promoteLoopAccessesToScalars( SmallVector<PHINode *, 16> NewPHIs; SSAUpdater SSA(&NewPHIs); LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, - InsertPts, MSSAInsertPts, PIC, CurAST, MSSAU, *LI, DL, - Alignment.value(), SawUnorderedAtomic, AATags, - *SafetyInfo); + InsertPts, MSSAInsertPts, PIC, MSSAU, *LI, DL, + Alignment, SawUnorderedAtomic, AATags, *SafetyInfo); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. 
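The promotion hunk above switches from the out-parameter form of getAAMetadata to the value-returning form combined with AAMDNodes::merge. A compact sketch of that pattern over an arbitrary list of accesses (the helper name is invented):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Metadata.h"

// Seed the AA tags from the first access, then intersect with every later
// one; merge() keeps only the metadata the accesses have in common.
static llvm::AAMDNodes combineAccessAATags(
    llvm::ArrayRef<llvm::Instruction *> Accesses) {
  llvm::AAMDNodes AATags;
  bool First = true;
  for (llvm::Instruction *I : Accesses) {
    if (First) {
      AATags = I->getAAMetadata();
      First = false;
    } else if (AATags) {
      AATags = AATags.merge(I->getAAMetadata());
    }
  }
  return AATags;
}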
@@ -2273,24 +2239,22 @@ bool llvm::promoteLoopAccessesToScalars( PreheaderLoad->setAAMetadata(AATags); SSA.AddAvailableValue(Preheader, PreheaderLoad); - if (MSSAU) { - MemoryAccess *PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB( - PreheaderLoad, nullptr, PreheaderLoad->getParent(), MemorySSA::End); - MemoryUse *NewMemUse = cast<MemoryUse>(PreheaderLoadMemoryAccess); - MSSAU->insertUse(NewMemUse, /*RenameUses=*/true); - } + MemoryAccess *PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB( + PreheaderLoad, nullptr, PreheaderLoad->getParent(), MemorySSA::End); + MemoryUse *NewMemUse = cast<MemoryUse>(PreheaderLoadMemoryAccess); + MSSAU->insertUse(NewMemUse, /*RenameUses=*/true); - if (MSSAU && VerifyMemorySSA) + if (VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); // Rewrite all the loads in the loop and remember all the definitions from // stores in the loop. Promoter.run(LoopUses); - if (MSSAU && VerifyMemorySSA) + if (VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); // If the SSAUpdater didn't use the load in the preheader, just zap it now. if (PreheaderLoad->use_empty()) - eraseInstruction(*PreheaderLoad, *SafetyInfo, CurAST, MSSAU); + eraseInstruction(*PreheaderLoad, *SafetyInfo, MSSAU); return true; } @@ -2356,71 +2320,10 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { return Result; } -/// Returns an owning pointer to an alias set which incorporates aliasing info -/// from L and all subloops of L. -std::unique_ptr<AliasSetTracker> -LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, - AAResults *AA) { - auto CurAST = std::make_unique<AliasSetTracker>(*AA); - - // Add everything from all the sub loops. - for (Loop *InnerL : L->getSubLoops()) - for (BasicBlock *BB : InnerL->blocks()) - CurAST->add(*BB); - - // And merge in this loop (without anything from inner loops). - for (BasicBlock *BB : L->blocks()) - if (LI->getLoopFor(BB) == L) - CurAST->add(*BB); - - return CurAST; -} - static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop, AAResults *AA) { - // First check to see if any of the basic blocks in CurLoop invalidate *V. - bool isInvalidatedAccordingToAST = CurAST->getAliasSetFor(MemLoc).isMod(); - - if (!isInvalidatedAccordingToAST || !LICMN2Theshold) - return isInvalidatedAccordingToAST; - - // Check with a diagnostic analysis if we can refine the information above. - // This is to identify the limitations of using the AST. - // The alias set mechanism used by LICM has a major weakness in that it - // combines all things which may alias into a single set *before* asking - // modref questions. As a result, a single readonly call within a loop will - // collapse all loads and stores into a single alias set and report - // invalidation if the loop contains any store. For example, readonly calls - // with deopt states have this form and create a general alias set with all - // loads and stores. In order to get any LICM in loops containing possible - // deopt states we need a more precise invalidation of checking the mod ref - // info of each instruction within the loop and LI. This has a complexity of - // O(N^2), so currently, it is used only as a diagnostic tool since the - // default value of LICMN2Threshold is zero. - - // Don't look at nested loops. 
- if (CurLoop->begin() != CurLoop->end()) - return true; - - int N = 0; - for (BasicBlock *BB : CurLoop->getBlocks()) - for (Instruction &I : *BB) { - if (N >= LICMN2Theshold) { - LLVM_DEBUG(dbgs() << "Alasing N2 threshold exhausted for " - << *(MemLoc.Ptr) << "\n"); - return true; - } - N++; - auto Res = AA->getModRefInfo(&I, MemLoc); - if (isModSet(Res)) { - LLVM_DEBUG(dbgs() << "Aliasing failed on " << I << " for " - << *(MemLoc.Ptr) << "\n"); - return true; - } - } - LLVM_DEBUG(dbgs() << "Aliasing okay for " << *(MemLoc.Ptr) << "\n"); - return false; + return CurAST->getAliasSetFor(MemLoc).isMod(); } bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp index 993b154dc9a8..d438d56e38ca 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopBoundSplit.h" +#include "llvm/ADT/Sequence.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" @@ -39,10 +40,12 @@ struct ConditionInfo { ICmpInst::Predicate Pred; /// AddRec llvm value Value *AddRecValue; + /// Non PHI AddRec llvm value + Value *NonPHIAddRecValue; /// Bound llvm value Value *BoundValue; /// AddRec SCEV - const SCEV *AddRecSCEV; + const SCEVAddRecExpr *AddRecSCEV; /// Bound SCEV const SCEV *BoundSCEV; @@ -54,19 +57,31 @@ struct ConditionInfo { } // namespace static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp, - ConditionInfo &Cond) { + ConditionInfo &Cond, const Loop &L) { Cond.ICmp = ICmp; if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue), m_Value(Cond.BoundValue)))) { - Cond.AddRecSCEV = SE.getSCEV(Cond.AddRecValue); - Cond.BoundSCEV = SE.getSCEV(Cond.BoundValue); + const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue); + const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue); + const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); + const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV); // Locate AddRec in LHSSCEV and Bound in RHSSCEV. - if (isa<SCEVAddRecExpr>(Cond.BoundSCEV) && - !isa<SCEVAddRecExpr>(Cond.AddRecSCEV)) { + if (!LHSAddRecSCEV && RHSAddRecSCEV) { std::swap(Cond.AddRecValue, Cond.BoundValue); - std::swap(Cond.AddRecSCEV, Cond.BoundSCEV); + std::swap(AddRecSCEV, BoundSCEV); Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred); } + + Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); + Cond.BoundSCEV = BoundSCEV; + Cond.NonPHIAddRecValue = Cond.AddRecValue; + + // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with + // value from backedge. + if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) { + PHINode *PN = cast<PHINode>(Cond.AddRecValue); + Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch()); + } } } @@ -118,21 +133,20 @@ static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE, static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE, ICmpInst *ICmp, ConditionInfo &Cond, bool IsExitCond) { - analyzeICmp(SE, ICmp, Cond); + analyzeICmp(SE, ICmp, Cond, L); // The BoundSCEV should be evaluated at loop entry. 
if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L)) return false; - const SCEVAddRecExpr *AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Cond.AddRecSCEV); // Allowed AddRec as induction variable. - if (!AddRecSCEV) + if (!Cond.AddRecSCEV) return false; - if (!AddRecSCEV->isAffine()) + if (!Cond.AddRecSCEV->isAffine()) return false; - const SCEV *StepRecSCEV = AddRecSCEV->getStepRecurrence(SE); + const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE); // Allowed constant step. if (!isa<SCEVConstant>(StepRecSCEV)) return false; @@ -264,6 +278,14 @@ static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE, SplitCandidateCond.BoundSCEV->getType()) continue; + // After transformation, we assume the split condition of the pre-loop is + // always true. In order to guarantee it, we need to check the start value + // of the split cond AddRec satisfies the split condition. + if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred, + SplitCandidateCond.AddRecSCEV->getStart(), + SplitCandidateCond.BoundSCEV)) + continue; + SplitCandidateCond.BI = BI; return BI; } @@ -341,13 +363,45 @@ static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, ".split", &LI, &DT, PostLoopBlocks); remapInstructionsInBlocks(PostLoopBlocks, VMap); - // Add conditional branch to check we can skip post-loop in its preheader. BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader(); - IRBuilder<> Builder(PostLoopPreHeader); + IRBuilder<> Builder(&PostLoopPreHeader->front()); + + // Update phi nodes in header of post-loop. + bool isExitingLatch = + (L.getExitingBlock() == L.getLoopLatch()) ? true : false; + Value *ExitingCondLCSSAPhi = nullptr; + for (PHINode &PN : L.getHeader()->phis()) { + // Create LCSSA phi node in preheader of post-loop. + PHINode *LCSSAPhi = + Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); + LCSSAPhi->setDebugLoc(PN.getDebugLoc()); + // If the exiting block is loop latch, the phi does not have the update at + // last iteration. In this case, update lcssa phi with value from backedge. + LCSSAPhi->addIncoming( + isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN, + L.getExitingBlock()); + + // Update the start value of phi node in post-loop with the LCSSA phi node. + PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); + PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi); + + // Find PHI with exiting condition from pre-loop. The PHI should be + // SCEVAddRecExpr and have same incoming value from backedge with + // ExitingCond. + if (!SE.isSCEVable(PN.getType())) + continue; + + const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); + if (PhiSCEV && ExitingCond.NonPHIAddRecValue == + PN.getIncomingValueForBlock(L.getLoopLatch())) + ExitingCondLCSSAPhi = LCSSAPhi; + } + + // Add conditional branch to check we can skip post-loop in its preheader. Instruction *OrigBI = PostLoopPreHeader->getTerminator(); ICmpInst::Predicate Pred = ICmpInst::ICMP_NE; Value *Cond = - Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue); + Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue); Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock()); OrigBI->eraseFromParent(); @@ -368,21 +422,6 @@ static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, // Replace exiting bound value of pre-loop NewBound. ExitingCond.ICmp->setOperand(1, NewBoundValue); - // Replace IV's start value of post-loop by NewBound. 
- for (PHINode &PN : L.getHeader()->phis()) { - // Find PHI with exiting condition from pre-loop. - if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) { - for (Value *Op : PN.incoming_values()) { - if (Op == ExitingCond.AddRecValue) { - // Find cloned PHI for post-loop. - PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); - PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, - NewBoundValue); - } - } - } - } - // Replace SplitCandidateCond.BI's condition of pre-loop by True. LLVMContext &Context = PreHeader->getContext(); SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context)); @@ -398,6 +437,30 @@ static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, else ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); + // Update phi node in exit block of post-loop. + Builder.SetInsertPoint(&PostLoopPreHeader->front()); + for (PHINode &PN : PostLoop->getExitBlock()->phis()) { + for (auto i : seq<int>(0, PN.getNumOperands())) { + // Check incoming block is pre-loop's exiting block. + if (PN.getIncomingBlock(i) == L.getExitingBlock()) { + Value *IncomingValue = PN.getIncomingValue(i); + + // Create LCSSA phi node for incoming value. + PHINode *LCSSAPhi = + Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); + LCSSAPhi->setDebugLoc(PN.getDebugLoc()); + LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i)); + + // Replace pre-loop's exiting block by post-loop's preheader. + PN.setIncomingBlock(i, PostLoopPreHeader); + // Replace incoming value by LCSSAPhi. + PN.setIncomingValue(i, LCSSAPhi); + // Add a new incoming value with post-loop's exiting block. + PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock()); + } + } + } + // Update dominator tree. DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock()); DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader); @@ -406,10 +469,7 @@ static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, SE.forgetLoop(&L); // Canonicalize loops. - // TODO: Try to update LCSSA information according to above change. - formLCSSA(L, DT, &LI, &SE); simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true); - formLCSSA(*PostLoop, DT, &LI, &SE); simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true); // Add new post-loop to loop pass manager. 
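The LoopBoundSplit rework above wires values from the pre-loop into the post-loop through single-entry ".lcssa" phis created at the top of the post-loop preheader. A minimal sketch of that construction, with an invented helper name:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

// Create a one-entry LCSSA phi in the post-loop preheader that forwards
// LeavingValue from the pre-loop's exiting block.
static llvm::PHINode *createSingleEntryLCSSAPhi(llvm::BasicBlock *PostLoopPreHeader,
                                                llvm::Value *LeavingValue,
                                                llvm::BasicBlock *ExitingBlock) {
  llvm::IRBuilder<> Builder(&PostLoopPreHeader->front());
  llvm::PHINode *LCSSAPhi = Builder.CreatePHI(
      LeavingValue->getType(), /*NumReservedValues=*/1,
      LeavingValue->getName() + ".lcssa");
  LCSSAPhi->addIncoming(LeavingValue, ExitingBlock);
  return LCSSAPhi;
}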
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index a5d7835bd094..77d76609c926 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -29,6 +29,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -127,6 +128,8 @@ public: AU.addPreserved<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequiredID(LoopSimplifyID); + AU.addPreservedID(LoopSimplifyID); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addPreserved<ScalarEvolutionWrapperPass>(); @@ -143,6 +146,7 @@ INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch", diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index f7e8442fae81..5814e2f043d5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -36,6 +36,8 @@ using namespace llvm; #define DEBUG_TYPE "loop-delete" STATISTIC(NumDeleted, "Number of loops deleted"); +STATISTIC(NumBackedgesBroken, + "Number of loops for which we managed to break the backedge"); static cl::opt<bool> EnableSymbolicExecution( "loop-deletion-enable-symbolic-execution", cl::Hidden, cl::init(true), @@ -191,6 +193,20 @@ getValueOnFirstIteration(Value *V, DenseMap<Value *, Value *> &FirstIterValue, Value *RHS = getValueOnFirstIteration(BO->getOperand(1), FirstIterValue, SQ); FirstIterV = SimplifyBinOp(BO->getOpcode(), LHS, RHS, SQ); + } else if (auto *Cmp = dyn_cast<ICmpInst>(V)) { + Value *LHS = + getValueOnFirstIteration(Cmp->getOperand(0), FirstIterValue, SQ); + Value *RHS = + getValueOnFirstIteration(Cmp->getOperand(1), FirstIterValue, SQ); + FirstIterV = SimplifyICmpInst(Cmp->getPredicate(), LHS, RHS, SQ); + } else if (auto *Select = dyn_cast<SelectInst>(V)) { + Value *Cond = + getValueOnFirstIteration(Select->getCondition(), FirstIterValue, SQ); + if (auto *C = dyn_cast<ConstantInt>(Cond)) { + auto *Selected = C->isAllOnesValue() ? 
Select->getTrueValue() + : Select->getFalseValue(); + FirstIterV = getValueOnFirstIteration(Selected, FirstIterValue, SQ); + } } if (!FirstIterV) FirstIterV = V; @@ -314,22 +330,20 @@ static bool canProveExitOnFirstIteration(Loop *L, DominatorTree &DT, } using namespace PatternMatch; - ICmpInst::Predicate Pred; - Value *LHS, *RHS; + Value *Cond; BasicBlock *IfTrue, *IfFalse; auto *Term = BB->getTerminator(); - if (match(Term, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), + if (match(Term, m_Br(m_Value(Cond), m_BasicBlock(IfTrue), m_BasicBlock(IfFalse)))) { - if (!LHS->getType()->isIntegerTy()) { + auto *ICmp = dyn_cast<ICmpInst>(Cond); + if (!ICmp || !ICmp->getType()->isIntegerTy()) { MarkAllSuccessorsLive(BB); continue; } // Can we prove constant true or false for this condition? - LHS = getValueOnFirstIteration(LHS, FirstIterValue, SQ); - RHS = getValueOnFirstIteration(RHS, FirstIterValue, SQ); - auto *KnownCondition = SimplifyICmpInst(Pred, LHS, RHS, SQ); - if (!KnownCondition) { + auto *KnownCondition = getValueOnFirstIteration(ICmp, FirstIterValue, SQ); + if (KnownCondition == ICmp) { // Failed to simplify. MarkAllSuccessorsLive(BB); continue; @@ -393,14 +407,25 @@ breakBackedgeIfNotTaken(Loop *L, DominatorTree &DT, ScalarEvolution &SE, if (!L->getLoopLatch()) return LoopDeletionResult::Unmodified; - auto *BTC = SE.getBackedgeTakenCount(L); - if (!isa<SCEVCouldNotCompute>(BTC) && SE.isKnownNonZero(BTC)) - return LoopDeletionResult::Unmodified; - if (!BTC->isZero() && !canProveExitOnFirstIteration(L, DT, LI)) - return LoopDeletionResult::Unmodified; + auto *BTC = SE.getSymbolicMaxBackedgeTakenCount(L); + if (BTC->isZero()) { + // SCEV knows this backedge isn't taken! + breakLoopBackedge(L, DT, SE, LI, MSSA); + ++NumBackedgesBroken; + return LoopDeletionResult::Deleted; + } - breakLoopBackedge(L, DT, SE, LI, MSSA); - return LoopDeletionResult::Deleted; + // If SCEV leaves open the possibility of a zero trip count, see if + // symbolically evaluating the first iteration lets us prove the backedge + // unreachable. + if (isa<SCEVCouldNotCompute>(BTC) || !SE.isKnownNonZero(BTC)) + if (canProveExitOnFirstIteration(L, DT, LI)) { + breakLoopBackedge(L, DT, SE, LI, MSSA); + ++NumBackedgesBroken; + return LoopDeletionResult::Deleted; + } + + return LoopDeletionResult::Unmodified; } /// Remove a loop if it is dead. 
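The rewritten breakBackedgeIfNotTaken above first asks SCEV for the symbolic maximum backedge-taken count and removes the backedge outright when that count is provably zero, falling back to symbolic evaluation of the first iteration otherwise. A small sketch of just the SCEV fast path (function name is invented):

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"

// Returns true when SCEV already proves the backedge can never be taken.
static bool backedgeProvablyNotTaken(llvm::Loop *L, llvm::ScalarEvolution &SE) {
  // As in the pass, only loops with a single latch are handled.
  if (!L->getLoopLatch())
    return false;
  const llvm::SCEV *BTC = SE.getSymbolicMaxBackedgeTakenCount(L);
  return BTC->isZero();
}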
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index bac3dc0f3fb9..0f4c767c1e4c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -1057,8 +1057,8 @@ PreservedAnalyses LoopDistributePass::run(Function &F, auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, nullptr}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, + TLI, TTI, nullptr, nullptr, nullptr}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index f54289f85ef5..965d1575518e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -27,6 +27,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopFlatten.h" + +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -49,11 +51,13 @@ #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" -#define DEBUG_TYPE "loop-flatten" - using namespace llvm; using namespace llvm::PatternMatch; +#define DEBUG_TYPE "loop-flatten" + +STATISTIC(NumFlattened, "Number of loops flattened"); + static cl::opt<unsigned> RepeatedInstructionThreshold( "loop-flatten-cost-threshold", cl::Hidden, cl::init(2), cl::desc("Limit on the cost of instructions that can be repeated due to " @@ -90,9 +94,33 @@ struct FlattenInfo { // Whether this holds the flatten info before or after widening. bool Widened = false; + // Holds the old/narrow induction phis, i.e. the Phis before IV widening has + // been applied. This bookkeeping is used so we can skip some checks on these + // phi nodes. + PHINode *NarrowInnerInductionPHI = nullptr; + PHINode *NarrowOuterInductionPHI = nullptr; + FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; + + bool isNarrowInductionPhi(PHINode *Phi) { + // This can't be the narrow phi if we haven't widened the IV first. + if (!Widened) + return false; + return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi; + } }; +static bool +setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, + SmallPtrSetImpl<Instruction *> &IterationInstructions) { + TripCount = TC; + IterationInstructions.insert(Increment); + LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump()); + LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); + LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); + return true; +} + // Finds the induction variable, increment and trip count for a simple loop that // we can flatten. static bool findLoopComponents( @@ -164,36 +192,68 @@ static bool findLoopComponents( return false; } // The trip count is the RHS of the compare. 
If this doesn't match the trip - // count computed by SCEV then this is either because the trip count variable - // has been widened (then leave the trip count as it is), or because it is a - // constant and another transformation has changed the compare, e.g. - // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten - // the loop (yet). - TripCount = Compare->getOperand(1); + // count computed by SCEV then this is because the trip count variable + // has been widened so the types don't match, or because it is a constant and + // another transformation has changed the compare (e.g. icmp ult %inc, + // tripcount -> icmp ult %j, tripcount-1), or both. + Value *RHS = Compare->getOperand(1); + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) { + LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n"); + return false; + } + // The use of the Extend=false flag on getTripCountFromExitCount was added + // during a refactoring to preserve existing behavior. However, there's + // nothing obvious in the surrounding code when handles the overflow case. + // FIXME: audit code to establish whether there's a latent bug here. const SCEV *SCEVTripCount = - SE->getTripCountFromExitCount(SE->getBackedgeTakenCount(L)); - if (SE->getSCEV(TripCount) != SCEVTripCount) { - if (!IsWidened) { - LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); - return false; - } - auto TripCountInst = dyn_cast<Instruction>(TripCount); - if (!TripCountInst) { - LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); - return false; + SE->getTripCountFromExitCount(BackedgeTakenCount, false); + const SCEV *SCEVRHS = SE->getSCEV(RHS); + if (SCEVRHS == SCEVTripCount) + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); + ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS); + if (ConstantRHS) { + const SCEV *BackedgeTCExt = nullptr; + if (IsWidened) { + const SCEV *SCEVTripCountExt; + // Find the extended backedge taken count and extended trip count using + // SCEV. One of these should now match the RHS of the compare. + BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); + SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false); + if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } } - if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) || - SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { - LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); - return false; + // If the RHS of the compare is equal to the backedge taken count we need + // to add one to get the trip count. 
+ if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) { + ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1); + Value *NewRHS = ConstantInt::get( + ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue()); + return setLoopComponents(NewRHS, TripCount, Increment, + IterationInstructions); } + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); } - IterationInstructions.insert(Increment); - LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump()); - LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); - - LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); - return true; + // If the RHS isn't a constant then check that the reason it doesn't match + // the SCEV trip count is because the RHS is a ZExt or SExt instruction + // (and take the trip count to be the RHS). + if (!IsWidened) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + auto *TripCountInst = dyn_cast<Instruction>(RHS); + if (!TripCountInst) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) || + SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { + LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); + return false; + } + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); } static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { @@ -221,6 +281,8 @@ static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { // them specially when doing the transformation. if (&InnerPHI == FI.InnerInductionPHI) continue; + if (FI.isNarrowInductionPhi(&InnerPHI)) + continue; // Each inner loop PHI node must have two incoming values/blocks - one // from the pre-header, and one from the latch. @@ -266,6 +328,8 @@ static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { } for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) { + if (FI.isNarrowInductionPhi(&OuterPHI)) + continue; if (!SafeOuterPHIs.count(&OuterPHI)) { LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump()); return false; @@ -356,18 +420,25 @@ static bool checkIVUsers(FlattenInfo &FI) { if (U == FI.InnerIncrement) continue; - // After widening the IVs, a trunc instruction might have been introduced, so - // look through truncs. + // After widening the IVs, a trunc instruction might have been introduced, + // so look through truncs. if (isa<TruncInst>(U)) { if (!U->hasOneUse()) return false; U = *U->user_begin(); } + // If the use is in the compare (which is also the condition of the inner + // branch) then the compare has been altered by another transformation e.g + // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is + // a constant. Ignore this use as the compare gets removed later anyway. + if (U == FI.InnerBranch->getCondition()) + continue; + LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); - Value *MatchedMul; - Value *MatchedItCount; + Value *MatchedMul = nullptr; + Value *MatchedItCount = nullptr; bool IsAdd = match(U, m_c_Add(m_Specific(FI.InnerInductionPHI), m_Value(MatchedMul))) && match(MatchedMul, m_c_Mul(m_Specific(FI.OuterInductionPHI), @@ -375,11 +446,23 @@ static bool checkIVUsers(FlattenInfo &FI) { // Matches the same pattern as above, except it also looks for truncs // on the phi, which can be the result of widening the induction variables. 
- bool IsAddTrunc = match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)), - m_Value(MatchedMul))) && - match(MatchedMul, - m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), - m_Value(MatchedItCount))); + bool IsAddTrunc = + match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)), + m_Value(MatchedMul))) && + match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), + m_Value(MatchedItCount))); + + if (!MatchedItCount) + return false; + // Look through extends if the IV has been widened. + if (FI.Widened && + (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) { + assert(MatchedItCount->getType() == FI.InnerInductionPHI->getType() && + "Unexpected type mismatch in types after widening"); + MatchedItCount = isa<SExtInst>(MatchedItCount) + ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0) + : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0); + } if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { LLVM_DEBUG(dbgs() << "Use is optimisable\n"); @@ -451,17 +534,27 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, for (Value *V : FI.LinearIVUses) { for (Value *U : V->users()) { if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) { - // The IV is used as the operand of a GEP, and the IV is at least as - // wide as the address space of the GEP. In this case, the GEP would - // wrap around the address space before the IV increment wraps, which - // would be UB. - if (GEP->isInBounds() && - V->getType()->getIntegerBitWidth() >= - DL.getPointerTypeSizeInBits(GEP->getType())) { - LLVM_DEBUG( - dbgs() << "use of linear IV would be UB if overflow occurred: "; - GEP->dump()); - return OverflowResult::NeverOverflows; + for (Value *GEPUser : U->users()) { + Instruction *GEPUserInst = dyn_cast<Instruction>(GEPUser); + if (!isa<LoadInst>(GEPUserInst) && + !(isa<StoreInst>(GEPUserInst) && + GEP == GEPUserInst->getOperand(1))) + continue; + if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, + FI.InnerLoop)) + continue; + // The IV is used as the operand of a GEP which dominates the loop + // latch, and the IV is at least as wide as the address space of the + // GEP. In this case, the GEP would wrap around the address space + // before the IV increment wraps, which would be UB. + if (GEP->isInBounds() && + V->getType()->getIntegerBitWidth() >= + DL.getPointerTypeSizeInBits(GEP->getType())) { + LLVM_DEBUG( + dbgs() << "use of linear IV would be UB if overflow occurred: "; + GEP->dump()); + return OverflowResult::NeverOverflows; + } } } } @@ -518,7 +611,7 @@ static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, - const TargetTransformInfo *TTI) { + const TargetTransformInfo *TTI, LPMUpdater *U) { Function *F = FI.OuterLoop->getHeader()->getParent(); LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n"); { @@ -574,7 +667,13 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, // deleted, and any information that have about the outer loop invalidated. SE->forgetLoop(FI.OuterLoop); SE->forgetLoop(FI.InnerLoop); + if (U) + U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName()); LI->erase(FI.InnerLoop); + + // Increment statistic value. 
+ NumFlattened++; + return true; } @@ -605,14 +704,11 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, } SCEVExpander Rewriter(*SE, DL, "loopflatten"); - SmallVector<WideIVInfo, 2> WideIVs; SmallVector<WeakTrackingVH, 4> DeadInsts; - WideIVs.push_back( {FI.InnerInductionPHI, MaxLegalType, false }); - WideIVs.push_back( {FI.OuterInductionPHI, MaxLegalType, false }); unsigned ElimExt = 0; unsigned Widened = 0; - for (const auto &WideIV : WideIVs) { + auto CreateWideIV = [&] (WideIVInfo WideIV, bool &Deleted) -> bool { PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened, true /* HasGuards */, true /* UsePostIncrementRanges */); @@ -620,17 +716,35 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, return false; LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump()); LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump()); - RecursivelyDeleteDeadPHINode(WideIV.NarrowIV); - } - // After widening, rediscover all the loop components. + Deleted = RecursivelyDeleteDeadPHINode(WideIV.NarrowIV); + return true; + }; + + bool Deleted; + if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false }, Deleted)) + return false; + // Add the narrow phi to list, so that it will be adjusted later when the + // the transformation is performed. + if (!Deleted) + FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI); + + if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false }, Deleted)) + return false; + assert(Widened && "Widened IV expected"); FI.Widened = true; + + // Save the old/narrow induction phis, which we need to ignore in CheckPHIs. + FI.NarrowInnerInductionPHI = FI.InnerInductionPHI; + FI.NarrowOuterInductionPHI = FI.OuterInductionPHI; + + // After widening, rediscover all the loop components. return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI); } static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, - const TargetTransformInfo *TTI) { + const TargetTransformInfo *TTI, LPMUpdater *U) { LLVM_DEBUG( dbgs() << "Loop flattening running on outer loop " << FI.OuterLoop->getHeader()->getName() << " and inner loop " @@ -641,12 +755,30 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, return false; // Check if we can widen the induction variables to avoid overflow checks. - if (CanWidenIV(FI, DT, LI, SE, AC, TTI)) - return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI); - - // Check if the new iteration variable might overflow. In this case, we - // need to version the loop, and select the original version at runtime if - // the iteration space is too large. + bool CanFlatten = CanWidenIV(FI, DT, LI, SE, AC, TTI); + + // It can happen that after widening of the IV, flattening may not be + // possible/happening, e.g. when it is deemed unprofitable. So bail here if + // that is the case. + // TODO: IV widening without performing the actual flattening transformation + // is not ideal. While this codegen change should not matter much, it is an + // unnecessary change which is better to avoid. It's unlikely this happens + // often, because if it's unprofitibale after widening, it should be + // unprofitabe before widening as checked in the first round of checks. But + // 'RepeatedInstructionThreshold' is set to only 2, which can probably be + // relaxed. Because this is making a code change (the IV widening, but not + // the flattening), we return true here. 
+ if (FI.Widened && !CanFlatten) + return true; + + // If we have widened and can perform the transformation, do that here. + if (CanFlatten) + return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U); + + // Otherwise, if we haven't widened the IV, check if the new iteration + // variable might overflow. In this case, we need to version the loop, and + // select the original version at runtime if the iteration space is too + // large. // TODO: We currently don't version the loop. OverflowResult OR = checkOverflow(FI, DT, AC); if (OR == OverflowResult::AlwaysOverflowsHigh || @@ -659,18 +791,18 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, } LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n"); - return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI); + return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U); } bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, TargetTransformInfo *TTI) { + AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U) { bool Changed = false; for (Loop *InnerLoop : LN.getLoops()) { auto *OuterLoop = InnerLoop->getParentLoop(); if (!OuterLoop) continue; FlattenInfo FI(OuterLoop, InnerLoop); - Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI); + Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U); } return Changed; } @@ -685,12 +817,12 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. - Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI); + Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U); if (!Changed) return PreservedAnalyses::all(); - return PreservedAnalyses::none(); + return getLoopPassPreservedAnalyses(); } namespace { @@ -735,7 +867,7 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) { bool Changed = false; for (Loop *L : *LI) { auto LN = LoopNest::getLoopNest(*L, *SE); - Changed |= Flatten(*LN, DT, LI, SE, AC, TTI); + Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr); } return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index a153f393448c..42da86a9ecf5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -217,15 +217,15 @@ private: bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount); bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount); - bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize, + bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment, Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, const SCEV *BECount, - bool NegStride, bool IsLoopMemset = false); + bool IsNegStride, bool IsLoopMemset = false); bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount); bool processLoopStoreOfLoopLoad(Value *DestPtr, Value *SourcePtr, - unsigned StoreSize, MaybeAlign StoreAlign, + const SCEV *StoreSize, MaybeAlign StoreAlign, MaybeAlign LoadAlign, Instruction *TheStore, Instruction *TheLoad, const SCEVAddRecExpr *StoreEv, @@ -625,8 +625,8 @@ bool LoopIdiomRecognize::runOnLoopBlock( // We can only promote stores in this block if they are 
unconditionally // executed in the loop. For a block to be unconditionally executed, it has // to dominate all the exit blocks of the loop. Verify this now. - for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) - if (!DT->dominates(BB, ExitBlocks[i])) + for (BasicBlock *ExitBlock : ExitBlocks) + if (!DT->dominates(BB, ExitBlock)) return false; bool MadeChange = false; @@ -750,16 +750,13 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, bool Changed = false; // For stores that start but don't end a link in the chain: - for (SetVector<StoreInst *>::iterator it = Heads.begin(), e = Heads.end(); - it != e; ++it) { - if (Tails.count(*it)) + for (StoreInst *I : Heads) { + if (Tails.count(I)) continue; // We found a store instr that starts a chain. Now follow the chain and try // to transform it. SmallPtrSet<Instruction *, 8> AdjacentStores; - StoreInst *I = *it; - StoreInst *HeadStore = I; unsigned StoreSize = 0; @@ -784,12 +781,14 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, if (StoreSize != Stride && StoreSize != -Stride) continue; - bool NegStride = StoreSize == -Stride; + bool IsNegStride = StoreSize == -Stride; - if (processLoopStridedStore(StorePtr, StoreSize, + Type *IntIdxTy = DL->getIndexType(StorePtr->getType()); + const SCEV *StoreSizeSCEV = SE->getConstant(IntIdxTy, StoreSize); + if (processLoopStridedStore(StorePtr, StoreSizeSCEV, MaybeAlign(HeadStore->getAlignment()), StoredVal, HeadStore, AdjacentStores, StoreEv, - BECount, NegStride)) { + BECount, IsNegStride)) { TransformedStores.insert(AdjacentStores.begin(), AdjacentStores.end()); Changed = true; } @@ -857,15 +856,15 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI, // Check if the stride matches the size of the memcpy. If so, then we know // that every byte is touched in the loop. - const SCEVConstant *StoreStride = + const SCEVConstant *ConstStoreStride = dyn_cast<SCEVConstant>(StoreEv->getOperand(1)); - const SCEVConstant *LoadStride = + const SCEVConstant *ConstLoadStride = dyn_cast<SCEVConstant>(LoadEv->getOperand(1)); - if (!StoreStride || !LoadStride) + if (!ConstStoreStride || !ConstLoadStride) return false; - APInt StoreStrideValue = StoreStride->getAPInt(); - APInt LoadStrideValue = LoadStride->getAPInt(); + APInt StoreStrideValue = ConstStoreStride->getAPInt(); + APInt LoadStrideValue = ConstLoadStride->getAPInt(); // Huge stride value - give up if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64) return false; @@ -875,7 +874,7 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI, return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI) << ore::NV("Inst", "memcpy") << " in " << ore::NV("Function", MCI->getFunction()) - << " function will not be hoised: " + << " function will not be hoisted: " << ore::NV("Reason", "memcpy size is not equal to stride"); }); return false; @@ -887,16 +886,17 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI, if (StoreStrideInt != LoadStrideInt) return false; - return processLoopStoreOfLoopLoad(Dest, Source, (unsigned)SizeInBytes, - MCI->getDestAlign(), MCI->getSourceAlign(), - MCI, MCI, StoreEv, LoadEv, BECount); + return processLoopStoreOfLoopLoad( + Dest, Source, SE->getConstant(Dest->getType(), SizeInBytes), + MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI, StoreEv, LoadEv, + BECount); } /// processLoopMemSet - See if this memset can be promoted to a large memset. 
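A rough sketch of the source pattern this routine targets (p, n and m are invented names used only for illustration). With the non-constant-size handling added further down in this hunk, a loop that memsets one chunk per iteration and advances the pointer by exactly the memset length, for example

    // each iteration clears one chunk of m bytes; the pointer stride equals
    // the memset length and both are invariant in the loop
    for (size_t i = 0; i != n; ++i)
      memset(p + i * m, 0, m);

can, once SCEV proves the stride equals the length, be folded into a single call covering the whole range, roughly memset(p, 0, n * m).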
bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, const SCEV *BECount) { - // We can only handle non-volatile memsets with a constant size. - if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength())) + // We can only handle non-volatile memsets. + if (MSI->isVolatile()) return false; // If we're not allowed to hack on memset, we fail. @@ -909,23 +909,72 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, // loop, which indicates a strided store. If we have something else, it's a // random store we can't handle. const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer)); - if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine()) + if (!Ev || Ev->getLoop() != CurLoop) return false; - - // Reject memsets that are so large that they overflow an unsigned. - uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue(); - if ((SizeInBytes >> 32) != 0) + if (!Ev->isAffine()) { + LLVM_DEBUG(dbgs() << " Pointer is not affine, abort\n"); return false; + } - // Check to see if the stride matches the size of the memset. If so, then we - // know that every byte is touched in the loop. - const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1)); - if (!ConstStride) + const SCEV *PointerStrideSCEV = Ev->getOperand(1); + const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength()); + if (!PointerStrideSCEV || !MemsetSizeSCEV) return false; - APInt Stride = ConstStride->getAPInt(); - if (SizeInBytes != Stride && SizeInBytes != -Stride) - return false; + bool IsNegStride = false; + const bool IsConstantSize = isa<ConstantInt>(MSI->getLength()); + + if (IsConstantSize) { + // Memset size is constant. + // Check if the pointer stride matches the memset size. If so, then + // we know that every byte is touched in the loop. + LLVM_DEBUG(dbgs() << " memset size is constant\n"); + uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue(); + const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1)); + if (!ConstStride) + return false; + + APInt Stride = ConstStride->getAPInt(); + if (SizeInBytes != Stride && SizeInBytes != -Stride) + return false; + + IsNegStride = SizeInBytes == -Stride; + } else { + // Memset size is non-constant. + // Check if the pointer stride matches the memset size. + // To be conservative, the pass would not promote pointers that aren't in + // address space zero. Also, the pass only handles memset length and stride + // that are invariant for the top level loop. + LLVM_DEBUG(dbgs() << " memset size is non-constant\n"); + if (Pointer->getType()->getPointerAddressSpace() != 0) { + LLVM_DEBUG(dbgs() << " pointer is not in address space zero, " + << "abort\n"); + return false; + } + if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop)) { + LLVM_DEBUG(dbgs() << " memset size is not a loop-invariant, " + << "abort\n"); + return false; + } + + // Compare positive direction PointerStrideSCEV with MemsetSizeSCEV + IsNegStride = PointerStrideSCEV->isNonConstantNegative(); + const SCEV *PositiveStrideSCEV = + IsNegStride ? SE->getNegativeSCEV(PointerStrideSCEV) + : PointerStrideSCEV; + LLVM_DEBUG(dbgs() << " MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n" + << " PositiveStrideSCEV: " << *PositiveStrideSCEV + << "\n"); + + if (PositiveStrideSCEV != MemsetSizeSCEV) { + // TODO: folding can be done to the SCEVs + // The folding is to fold expressions that is covered by the loop guard + // at loop entry. After the folding, compare again and proceed + // optimization if equal. 
+ LLVM_DEBUG(dbgs() << " SCEV don't match, abort\n"); + return false; + } + } // Verify that the memset value is loop invariant. If not, we can't promote // the memset. @@ -935,10 +984,10 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, SmallPtrSet<Instruction *, 1> MSIs; MSIs.insert(MSI); - bool NegStride = SizeInBytes == -Stride; - return processLoopStridedStore( - Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()), - SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true); + return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()), + MaybeAlign(MSI->getDestAlignment()), + SplatValue, MSI, MSIs, Ev, BECount, + IsNegStride, /*IsLoopMemset=*/true); } /// mayLoopAccessLocation - Return true if the specified loop might access the @@ -946,9 +995,9 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, /// argument specifies what the verboten forms of access are (read or write). static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, - const SCEV *BECount, unsigned StoreSize, + const SCEV *BECount, const SCEV *StoreSizeSCEV, AliasAnalysis &AA, - SmallPtrSetImpl<Instruction *> &IgnoredStores) { + SmallPtrSetImpl<Instruction *> &IgnoredInsts) { // Get the location that may be stored across the loop. Since the access is // strided positively through memory, we say that the modified location starts // at the pointer and has infinite size. @@ -956,9 +1005,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, // If the loop iterates a fixed number of times, we can refine the access size // to be exactly the size of the memset, which is (BECount+1)*StoreSize - if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount)) + const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount); + const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV); + if (BECst && ConstSize) AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) * - StoreSize); + ConstSize->getValue()->getZExtValue()); // TODO: For this to be really effective, we have to dive into the pointer // operand in the store. Store to &A[i] of 100 will always return may alias @@ -966,14 +1017,12 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, // which will then no-alias a store to &A[100]. MemoryLocation StoreLoc(Ptr, AccessSize); - for (Loop::block_iterator BI = L->block_begin(), E = L->block_end(); BI != E; - ++BI) - for (Instruction &I : **BI) - if (IgnoredStores.count(&I) == 0 && + for (BasicBlock *B : L->blocks()) + for (Instruction &I : *B) + if (!IgnoredInsts.contains(&I) && isModOrRefSet( intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) return true; - return false; } @@ -981,57 +1030,67 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, // we're trying to memset. Therefore, we need to recompute the base pointer, // which is just Start - BECount*Size. 
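As a worked example of the formula in the comment above (the array and counts are hypothetical): a loop that stores 4-byte elements to A[i] for i = 99 down to 0 has Start = &A[99], a backedge-taken count of 99 and a store size of 4, so

    // base = Start - BECount * StoreSize
    //      = &A[99] - 99 * 4 bytes
    //      = &A[0]

which is the lowest address written and therefore the correct base for the emitted memset/memcpy.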
static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, - Type *IntPtr, unsigned StoreSize, + Type *IntPtr, const SCEV *StoreSizeSCEV, ScalarEvolution *SE) { const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr); - if (StoreSize != 1) - Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize), + if (!StoreSizeSCEV->isOne()) { + // index = back edge count * store size + Index = SE->getMulExpr(Index, + SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr), SCEV::FlagNUW); + } + // base pointer = start - index * store size return SE->getMinusSCEV(Start, Index); } -/// Compute the number of bytes as a SCEV from the backedge taken count. -/// -/// This also maps the SCEV into the provided type and tries to handle the -/// computation in a way that will fold cleanly. -static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, - unsigned StoreSize, Loop *CurLoop, - const DataLayout *DL, ScalarEvolution *SE) { - const SCEV *NumBytesS; - // The # stored bytes is (BECount+1)*Size. Expand the trip count out to +/// Compute trip count from the backedge taken count. +static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr, + Loop *CurLoop, const DataLayout *DL, + ScalarEvolution *SE) { + const SCEV *TripCountS = nullptr; + // The # stored bytes is (BECount+1). Expand the trip count out to // pointer size if it isn't already. // // If we're going to need to zero extend the BE count, check if we can add // one to it prior to zero extending without overflow. Provided this is safe, // it allows better simplification of the +1. - if (DL->getTypeSizeInBits(BECount->getType()).getFixedSize() < - DL->getTypeSizeInBits(IntPtr).getFixedSize() && + if (DL->getTypeSizeInBits(BECount->getType()) < + DL->getTypeSizeInBits(IntPtr) && SE->isLoopEntryGuardedByCond( CurLoop, ICmpInst::ICMP_NE, BECount, SE->getNegativeSCEV(SE->getOne(BECount->getType())))) { - NumBytesS = SE->getZeroExtendExpr( + TripCountS = SE->getZeroExtendExpr( SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW), IntPtr); } else { - NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), - SE->getOne(IntPtr), SCEV::FlagNUW); + TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), + SE->getOne(IntPtr), SCEV::FlagNUW); } - // And scale it based on the store size. - if (StoreSize != 1) { - NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize), - SCEV::FlagNUW); - } - return NumBytesS; + return TripCountS; +} + +/// Compute the number of bytes as a SCEV from the backedge taken count. +/// +/// This also maps the SCEV into the provided type and tries to handle the +/// computation in a way that will fold cleanly. +static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, + const SCEV *StoreSizeSCEV, Loop *CurLoop, + const DataLayout *DL, ScalarEvolution *SE) { + const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE); + + return SE->getMulExpr(TripCountSCEV, + SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr), + SCEV::FlagNUW); } /// processLoopStridedStore - We see a strided store of some value. If we can /// transform this into a memset or memset_pattern in the loop preheader, do so. 
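Loosely, the two shapes this routine can produce (a sketch with invented names; whether memset_pattern16 is emitted depends on the target library actually providing it):

    for (long i = 0; i < n; ++i) dst[i] = 0;           // bytewise value -> plain memset
    for (long i = 0; i < n; ++i) dst[i] = 0x01020304;  // wider repeating value -> memset_pattern16 where available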
bool LoopIdiomRecognize::processLoopStridedStore( - Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment, + Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment, Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, - const SCEV *BECount, bool NegStride, bool IsLoopMemset) { + const SCEV *BECount, bool IsNegStride, bool IsLoopMemset) { Value *SplatValue = isBytewiseValue(StoredVal, *DL); Constant *PatternValue = nullptr; @@ -1056,8 +1115,8 @@ bool LoopIdiomRecognize::processLoopStridedStore( bool Changed = false; const SCEV *Start = Ev->getStart(); // Handle negative strided loops. - if (NegStride) - Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE); + if (IsNegStride) + Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE); // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. @@ -1082,7 +1141,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( Changed = true; if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount, - StoreSize, *AA, Stores)) + StoreSizeSCEV, *AA, Stores)) return Changed; if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset)) @@ -1091,7 +1150,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( // Okay, everything looks good, insert the memset. const SCEV *NumBytesS = - getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE); + getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE); // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. @@ -1138,13 +1197,20 @@ bool LoopIdiomRecognize::processLoopStridedStore( << "\n"); ORE.emit([&]() { - return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStridedStore", - NewCall->getDebugLoc(), Preheader) - << "Transformed loop-strided store in " - << ore::NV("Function", TheStore->getFunction()) - << " function into a call to " - << ore::NV("NewFunction", NewCall->getCalledFunction()) - << "() intrinsic"; + OptimizationRemark R(DEBUG_TYPE, "ProcessLoopStridedStore", + NewCall->getDebugLoc(), Preheader); + R << "Transformed loop-strided store in " + << ore::NV("Function", TheStore->getFunction()) + << " function into a call to " + << ore::NV("NewFunction", NewCall->getCalledFunction()) + << "() intrinsic"; + if (!Stores.empty()) + R << ore::setExtraArgs(); + for (auto *I : Stores) { + R << ore::NV("FromBlock", I->getParent()->getName()) + << ore::NV("ToBlock", Preheader->getName()); + } + return R; }); // Okay, the memset has been formed. Zap the original store and anything that @@ -1181,16 +1247,63 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // random load we can't handle. 
Value *LoadPtr = LI->getPointerOperand(); const SCEVAddRecExpr *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr)); - return processLoopStoreOfLoopLoad(StorePtr, LoadPtr, StoreSize, + + const SCEV *StoreSizeSCEV = SE->getConstant(StorePtr->getType(), StoreSize); + return processLoopStoreOfLoopLoad(StorePtr, LoadPtr, StoreSizeSCEV, SI->getAlign(), LI->getAlign(), SI, LI, StoreEv, LoadEv, BECount); } +class MemmoveVerifier { +public: + explicit MemmoveVerifier(const Value &LoadBasePtr, const Value &StoreBasePtr, + const DataLayout &DL) + : DL(DL), LoadOff(0), StoreOff(0), + BP1(llvm::GetPointerBaseWithConstantOffset( + LoadBasePtr.stripPointerCasts(), LoadOff, DL)), + BP2(llvm::GetPointerBaseWithConstantOffset( + StoreBasePtr.stripPointerCasts(), StoreOff, DL)), + IsSameObject(BP1 == BP2) {} + + bool loadAndStoreMayFormMemmove(unsigned StoreSize, bool IsNegStride, + const Instruction &TheLoad, + bool IsMemCpy) const { + if (IsMemCpy) { + // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr + // for negative stride. + if ((!IsNegStride && LoadOff <= StoreOff) || + (IsNegStride && LoadOff >= StoreOff)) + return false; + } else { + // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr + // for negative stride. LoadBasePtr shouldn't overlap with StoreBasePtr. + int64_t LoadSize = + DL.getTypeSizeInBits(TheLoad.getType()).getFixedSize() / 8; + if (BP1 != BP2 || LoadSize != int64_t(StoreSize)) + return false; + if ((!IsNegStride && LoadOff < StoreOff + int64_t(StoreSize)) || + (IsNegStride && LoadOff + LoadSize > StoreOff)) + return false; + } + return true; + } + +private: + const DataLayout &DL; + int64_t LoadOff; + int64_t StoreOff; + const Value *BP1; + const Value *BP2; + +public: + const bool IsSameObject; +}; + bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( - Value *DestPtr, Value *SourcePtr, unsigned StoreSize, MaybeAlign StoreAlign, - MaybeAlign LoadAlign, Instruction *TheStore, Instruction *TheLoad, - const SCEVAddRecExpr *StoreEv, const SCEVAddRecExpr *LoadEv, - const SCEV *BECount) { + Value *DestPtr, Value *SourcePtr, const SCEV *StoreSizeSCEV, + MaybeAlign StoreAlign, MaybeAlign LoadAlign, Instruction *TheStore, + Instruction *TheLoad, const SCEVAddRecExpr *StoreEv, + const SCEVAddRecExpr *LoadEv, const SCEV *BECount) { // FIXME: until llvm.memcpy.inline supports dynamic sizes, we need to // conservatively bail here, since otherwise we may have to transform @@ -1213,11 +1326,18 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( Type *IntIdxTy = Builder.getIntNTy(DL->getIndexSizeInBits(StrAS)); APInt Stride = getStoreStride(StoreEv); - bool NegStride = StoreSize == -Stride; + const SCEVConstant *ConstStoreSize = dyn_cast<SCEVConstant>(StoreSizeSCEV); + + // TODO: Deal with non-constant size; Currently expect constant store size + assert(ConstStoreSize && "store size is expected to be a constant"); + + int64_t StoreSize = ConstStoreSize->getValue()->getZExtValue(); + bool IsNegStride = StoreSize == -Stride; // Handle negative strided loops. - if (NegStride) - StrStart = getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSize, SE); + if (IsNegStride) + StrStart = + getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE); // Okay, we have a strided store "p[i]" of a loaded value. We can turn // this into a memcpy in the loop preheader now if we want. However, this @@ -1237,24 +1357,24 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // the return value will read this comment, and leave them alone. 
Changed = true; - SmallPtrSet<Instruction *, 2> Stores; - Stores.insert(TheStore); + SmallPtrSet<Instruction *, 2> IgnoredInsts; + IgnoredInsts.insert(TheStore); bool IsMemCpy = isa<MemCpyInst>(TheStore); const StringRef InstRemark = IsMemCpy ? "memcpy" : "load and store"; - bool UseMemMove = + bool LoopAccessStore = mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, - StoreSize, *AA, Stores); - if (UseMemMove) { + StoreSizeSCEV, *AA, IgnoredInsts); + if (LoopAccessStore) { // For memmove case it's not enough to guarantee that loop doesn't access // TheStore and TheLoad. Additionally we need to make sure that TheStore is // the only user of TheLoad. if (!TheLoad->hasOneUse()) return Changed; - Stores.insert(TheLoad); + IgnoredInsts.insert(TheLoad); if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, - BECount, StoreSize, *AA, Stores)) { + BECount, StoreSizeSCEV, *AA, IgnoredInsts)) { ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore", TheStore) @@ -1265,15 +1385,16 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( }); return Changed; } - Stores.erase(TheLoad); + IgnoredInsts.erase(TheLoad); } const SCEV *LdStart = LoadEv->getStart(); unsigned LdAS = SourcePtr->getType()->getPointerAddressSpace(); // Handle negative strided loops. - if (NegStride) - LdStart = getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSize, SE); + if (IsNegStride) + LdStart = + getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE); // For a memcpy, we have to make sure that the input array is not being // mutated by the loop. @@ -1283,42 +1404,40 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // If the store is a memcpy instruction, we must check if it will write to // the load memory locations. So remove it from the ignored stores. if (IsMemCpy) - Stores.erase(TheStore); + IgnoredInsts.erase(TheStore); + MemmoveVerifier Verifier(*LoadBasePtr, *StoreBasePtr, *DL); if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, - StoreSize, *AA, Stores)) { - ORE.emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad) - << ore::NV("Inst", InstRemark) << " in " - << ore::NV("Function", TheStore->getFunction()) - << " function will not be hoisted: " - << ore::NV("Reason", "The loop may access load location"); - }); - return Changed; - } - if (UseMemMove) { - // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr for - // negative stride. LoadBasePtr shouldn't overlap with StoreBasePtr. - int64_t LoadOff = 0, StoreOff = 0; - const Value *BP1 = llvm::GetPointerBaseWithConstantOffset( - LoadBasePtr->stripPointerCasts(), LoadOff, *DL); - const Value *BP2 = llvm::GetPointerBaseWithConstantOffset( - StoreBasePtr->stripPointerCasts(), StoreOff, *DL); - int64_t LoadSize = - DL->getTypeSizeInBits(TheLoad->getType()).getFixedSize() / 8; - if (BP1 != BP2 || LoadSize != int64_t(StoreSize)) + StoreSizeSCEV, *AA, IgnoredInsts)) { + if (!IsMemCpy) { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", + TheLoad) + << ore::NV("Inst", InstRemark) << " in " + << ore::NV("Function", TheStore->getFunction()) + << " function will not be hoisted: " + << ore::NV("Reason", "The loop may access load location"); + }); return Changed; - if ((!NegStride && LoadOff < StoreOff + int64_t(StoreSize)) || - (NegStride && LoadOff + LoadSize > StoreOff)) + } + // At this point loop may access load only for memcpy in same underlying + // object. 
If that's not the case bail out. + if (!Verifier.IsSameObject) return Changed; } + bool UseMemMove = IsMemCpy ? Verifier.IsSameObject : LoopAccessStore; + if (UseMemMove) + if (!Verifier.loadAndStoreMayFormMemmove(StoreSize, IsNegStride, *TheLoad, + IsMemCpy)) + return Changed; + if (avoidLIRForMultiBlockLoop()) return Changed; // Okay, everything is safe, we can transform this! const SCEV *NumBytesS = - getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE); + getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE); Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator()); @@ -1380,11 +1499,14 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( << ore::NV("NewFunction", NewCall->getCalledFunction()) << "() intrinsic from " << ore::NV("Inst", InstRemark) << " instruction in " << ore::NV("Function", TheStore->getFunction()) - << " function"; + << " function" + << ore::setExtraArgs() + << ore::NV("FromBlock", TheStore->getParent()->getName()) + << ore::NV("ToBlock", Preheader->getName()); }); - // Okay, the memcpy has been formed. Zap the original store and anything that - // feeds into it. + // Okay, a new call to memcpy/memmove has been formed. Zap the original store + // and anything that feeds into it. if (MSSAU) MSSAU->removeMemoryAccess(TheStore, true); deleteDeadInstruction(TheStore); @@ -1549,24 +1671,22 @@ static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB, // step 4: Find the instruction which count the population: cnt2 = cnt1 + 1 { CountInst = nullptr; - for (BasicBlock::iterator Iter = LoopEntry->getFirstNonPHI()->getIterator(), - IterE = LoopEntry->end(); - Iter != IterE; Iter++) { - Instruction *Inst = &*Iter; - if (Inst->getOpcode() != Instruction::Add) + for (Instruction &Inst : llvm::make_range( + LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) { + if (Inst.getOpcode() != Instruction::Add) continue; - ConstantInt *Inc = dyn_cast<ConstantInt>(Inst->getOperand(1)); + ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1)); if (!Inc || !Inc->isOne()) continue; - PHINode *Phi = getRecurrenceVar(Inst->getOperand(0), Inst, LoopEntry); + PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry); if (!Phi) continue; // Check if the result of the instruction is live of the loop. bool LiveOutLoop = false; - for (User *U : Inst->users()) { + for (User *U : Inst.users()) { if ((cast<Instruction>(U))->getParent() != LoopEntry) { LiveOutLoop = true; break; @@ -1574,7 +1694,7 @@ static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB, } if (LiveOutLoop) { - CountInst = Inst; + CountInst = &Inst; CountPhi = Phi; break; } @@ -1675,22 +1795,20 @@ static bool detectShiftUntilZeroIdiom(Loop *CurLoop, const DataLayout &DL, // plus "cnt0". Currently it is not optimized. 
// This step could be used to detect POPCNT instruction: // cnt.next = cnt + (x.next & 1) - for (BasicBlock::iterator Iter = LoopEntry->getFirstNonPHI()->getIterator(), - IterE = LoopEntry->end(); - Iter != IterE; Iter++) { - Instruction *Inst = &*Iter; - if (Inst->getOpcode() != Instruction::Add) + for (Instruction &Inst : llvm::make_range( + LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) { + if (Inst.getOpcode() != Instruction::Add) continue; - ConstantInt *Inc = dyn_cast<ConstantInt>(Inst->getOperand(1)); + ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1)); if (!Inc || (!Inc->isOne() && !Inc->isMinusOne())) continue; - PHINode *Phi = getRecurrenceVar(Inst->getOperand(0), Inst, LoopEntry); + PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry); if (!Phi) continue; - CntInst = Inst; + CntInst = &Inst; CntPhi = Phi; break; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index 3153a8721193..b9e63a4bc06f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -105,9 +105,7 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, if (!V || !LI.replacementPreservesLCSSAForm(&I, V)) continue; - for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); - UI != UE;) { - Use &U = *UI++; + for (Use &U : llvm::make_early_inc_range(I.uses())) { auto *UserI = cast<Instruction>(U.getUser()); U.set(V); @@ -195,15 +193,10 @@ public: const TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( *L->getHeader()->getParent()); - MemorySSA *MSSA = nullptr; - Optional<MemorySSAUpdater> MSSAU; - if (EnableMSSALoopDependency) { - MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSAU = MemorySSAUpdater(MSSA); - } + MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MemorySSAUpdater MSSAU(MSSA); - return simplifyLoopInst(*L, DT, LI, AC, TLI, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); + return simplifyLoopInst(*L, DT, LI, AC, TLI, &MSSAU); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -211,10 +204,8 @@ public: AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.setPreservesCFG(); - if (EnableMSSALoopDependency) { - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); getLoopAnalysisUsage(AU); } }; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 34545f35b3c3..9f605b4ac4ad 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1710,16 +1710,12 @@ bool LoopInterchangeTransform::adjustLoopBranches() { auto &OuterInnerReductions = LIL.getOuterInnerReductions(); // Now update the reduction PHIs in the inner and outer loop headers. 
SmallVector<PHINode *, 4> InnerLoopPHIs, OuterLoopPHIs; - for (PHINode &PHI : InnerLoopHeader->phis()) { - if (OuterInnerReductions.find(&PHI) == OuterInnerReductions.end()) - continue; - InnerLoopPHIs.push_back(cast<PHINode>(&PHI)); - } - for (PHINode &PHI : OuterLoopHeader->phis()) { - if (OuterInnerReductions.find(&PHI) == OuterInnerReductions.end()) - continue; - OuterLoopPHIs.push_back(cast<PHINode>(&PHI)); - } + for (PHINode &PHI : InnerLoopHeader->phis()) + if (OuterInnerReductions.contains(&PHI)) + InnerLoopPHIs.push_back(cast<PHINode>(&PHI)); + for (PHINode &PHI : OuterLoopHeader->phis()) + if (OuterInnerReductions.contains(&PHI)) + OuterLoopPHIs.push_back(cast<PHINode>(&PHI)); // Now move the remaining reduction PHIs from outer to inner loop header and // vice versa. The PHI nodes must be part of a reduction across the inner and @@ -1767,6 +1763,7 @@ bool LoopInterchangeTransform::adjustLoopLinks() { return Changed; } +namespace { /// Main LoopInterchange Pass. struct LoopInterchangeLegacyPass : public LoopPass { static char ID; @@ -1795,6 +1792,7 @@ struct LoopInterchangeLegacyPass : public LoopPass { return LoopInterchange(SE, LI, DI, DT, ORE).run(L); } }; +} // namespace char LoopInterchangeLegacyPass::ID = 0; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index aaf586173e44..21d59936616b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -34,7 +34,6 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -109,8 +108,8 @@ struct StoreToLoadForwardingCandidate { // Currently we only support accesses with unit stride. FIXME: we should be // able to handle non unit stirde as well as long as the stride is equal to // the dependence distance. - if (getPtrStride(PSE, LoadPtr, L) != 1 || - getPtrStride(PSE, StorePtr, L) != 1) + if (getPtrStride(PSE, LoadType, LoadPtr, L) != 1 || + getPtrStride(PSE, LoadType, StorePtr, L) != 1) return false; auto &DL = Load->getParent()->getModule()->getDataLayout(); @@ -718,15 +717,12 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F, auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); auto *BFI = (PSI && PSI->hasProfileSummary()) ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; - MemorySSA *MSSA = EnableMSSALoopDependency - ? 
&AM.getResult<MemorySSAAnalysis>(F).getMSSA() - : nullptr; auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); bool Changed = eliminateLoadsAcrossLoops( F, LI, DT, BFI, PSI, &SE, &AC, [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, MSSA}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, + TLI, TTI, nullptr, nullptr, nullptr}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index f4fce4871331..3df4cfe8e4c1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -10,6 +10,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" @@ -44,6 +45,18 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, return PA; } +void PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, + LPMUpdater &>::printPipeline(raw_ostream &OS, + function_ref<StringRef(StringRef)> + MapClassName2PassName) { + for (unsigned Idx = 0, Size = LoopPasses.size(); Idx != Size; ++Idx) { + auto *P = LoopPasses[Idx].get(); + P->printPipeline(OS, MapClassName2PassName); + if (Idx + 1 < Size) + OS << ","; + } +} + // Run both loop passes and loop-nest passes on top-level loop \p L. PreservedAnalyses LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM, @@ -112,12 +125,6 @@ LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM, // notify the updater, otherwise U.ParentL might gets outdated and triggers // assertion failures in addSiblingLoops and addChildLoops. U.setParentLoop(L.getParentLoop()); - - // FIXME: Historically, the pass managers all called the LLVM context's - // yield function here. We don't have a generic way to acquire the - // context and it isn't yet clear what the right pattern is for yielding - // in the new pass manager so it is currently omitted. - // ...getContext().yield(); } return PA; } @@ -161,17 +168,17 @@ LoopPassManager::runWithoutLoopNestPasses(Loop &L, LoopAnalysisManager &AM, // notify the updater, otherwise U.ParentL might gets outdated and triggers // assertion failures in addSiblingLoops and addChildLoops. U.setParentLoop(L.getParentLoop()); - - // FIXME: Historically, the pass managers all called the LLVM context's - // yield function here. We don't have a generic way to acquire the - // context and it isn't yet clear what the right pattern is for yielding - // in the new pass manager so it is currently omitted. - // ...getContext().yield(); } return PA; } } // namespace llvm +void FunctionToLoopPassAdaptor::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + OS << (UseMemorySSA ? 
"loop-mssa(" : "loop("); + Pass->printPipeline(OS, MapClassName2PassName); + OS << ")"; +} PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, FunctionAnalysisManager &AM) { // Before we even compute any loop analyses, first run a miniature function @@ -201,6 +208,10 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, BlockFrequencyInfo *BFI = UseBlockFrequencyInfo && F.hasProfileData() ? (&AM.getResult<BlockFrequencyAnalysis>(F)) : nullptr; + BranchProbabilityInfo *BPI = + UseBranchProbabilityInfo && F.hasProfileData() + ? (&AM.getResult<BranchProbabilityAnalysis>(F)) + : nullptr; LoopStandardAnalysisResults LAR = {AM.getResult<AAManager>(F), AM.getResult<AssumptionAnalysis>(F), AM.getResult<DominatorTreeAnalysis>(F), @@ -209,6 +220,7 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, AM.getResult<TargetLibraryAnalysis>(F), AM.getResult<TargetIRAnalysis>(F), BFI, + BPI, MSSA}; // Setup the loop analysis manager from its proxy. It is important that @@ -285,6 +297,10 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, else PI.runAfterPass<Loop>(*Pass, *L, PassPA); + if (LAR.MSSA && !PassPA.getChecker<MemorySSAAnalysis>().preserved()) + report_fatal_error("Loop pass manager using MemorySSA contains a pass " + "that does not preserve MemorySSA"); + #ifndef NDEBUG // LoopAnalysisResults should always be valid. // Note that we don't LAR.SE.verify() because that can change observed SE @@ -325,6 +341,8 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, PA.preserve<ScalarEvolutionAnalysis>(); if (UseBlockFrequencyInfo && F.hasProfileData()) PA.preserve<BlockFrequencyAnalysis>(); + if (UseBranchProbabilityInfo && F.hasProfileData()) + PA.preserve<BranchProbabilityAnalysis>(); if (UseMemorySSA) PA.preserve<MemorySSAAnalysis>(); return PA; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp index 4f97641e2027..aa7e79a589f2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -183,6 +183,8 @@ #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/Function.h" @@ -254,7 +256,7 @@ class LoopPredication { DominatorTree *DT; ScalarEvolution *SE; LoopInfo *LI; - BranchProbabilityInfo *BPI; + MemorySSAUpdater *MSSAU; Loop *L; const DataLayout *DL; @@ -302,16 +304,15 @@ class LoopPredication { // If the loop always exits through another block in the loop, we should not // predicate based on the latch check. For example, the latch check can be a // very coarse grained check and there can be more fine grained exit checks - // within the loop. We identify such unprofitable loops through BPI. + // within the loop. 
bool isLoopProfitableToPredicate(); bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter); public: - LoopPredication(AliasAnalysis *AA, DominatorTree *DT, - ScalarEvolution *SE, LoopInfo *LI, - BranchProbabilityInfo *BPI) - : AA(AA), DT(DT), SE(SE), LI(LI), BPI(BPI) {}; + LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE, + LoopInfo *LI, MemorySSAUpdater *MSSAU) + : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){}; bool runOnLoop(Loop *L); }; @@ -325,6 +326,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<BranchProbabilityInfoWrapperPass>(); getLoopAnalysisUsage(AU); + AU.addPreserved<MemorySSAWrapperPass>(); } bool runOnLoop(Loop *L, LPPassManager &LPM) override { @@ -333,10 +335,12 @@ public: auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - BranchProbabilityInfo &BPI = - getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (MSSAWP) + MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - LoopPredication LP(AA, DT, SE, LI, &BPI); + LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr); return LP.runOnLoop(L); } }; @@ -358,16 +362,18 @@ Pass *llvm::createLoopPredicationPass() { PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { - Function *F = L.getHeader()->getParent(); - // For the new PM, we also can't use BranchProbabilityInfo as an analysis - // pass. Function analyses need to be preserved across loop transformations - // but BPI is not preserved, hence a newly built one is needed. - BranchProbabilityInfo BPI(*F, AR.LI, &AR.TLI, &AR.DT, nullptr); - LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, &BPI); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (AR.MSSA) + MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA); + LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, + MSSAU ? 
MSSAU.get() : nullptr); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; } Optional<LoopICmp> @@ -809,7 +815,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = Guard->getOperand(0); Guard->setOperand(0, AllChecks); - RecursivelyDeleteTriviallyDeadInstructions(OldCond); + RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); return true; @@ -835,7 +841,7 @@ bool LoopPredication::widenWidenableBranchGuardConditions( Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = BI->getCondition(); BI->setCondition(AllChecks); - RecursivelyDeleteTriviallyDeadInstructions(OldCond); + RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); assert(isGuardAsWidenableBranch(BI) && "Stopped being a guard after transform?"); @@ -912,7 +918,7 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { bool LoopPredication::isLoopProfitableToPredicate() { - if (SkipProfitabilityChecks || !BPI) + if (SkipProfitabilityChecks) return true; SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges; @@ -934,8 +940,61 @@ bool LoopPredication::isLoopProfitableToPredicate() { "expected to be an exiting block with 2 succs!"); unsigned LatchBrExitIdx = LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0; + // We compute branch probabilities without BPI. We do not rely on BPI since + // Loop predication is usually run in an LPM and BPI is only preserved + // lossily within loop pass managers, while BPI has an inherent notion of + // being complete for an entire function. + + // If the latch exits into a deoptimize or an unreachable block, do not + // predicate on that latch check. + auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx); + if (isa<UnreachableInst>(LatchTerm) || + LatchExitBlock->getTerminatingDeoptimizeCall()) + return false; + + auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) { + if (!ProfileData || !ProfileData->getOperand(0)) + return false; + if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0))) + if (!MDS->getString().equals("branch_weights")) + return false; + if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors()) + return false; + return true; + }; + MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof); + // Latch terminator has no valid profile data, so nothing to check + // profitability on. 
+ if (!IsValidProfileData(LatchProfileData, LatchTerm)) + return true; + + auto ComputeBranchProbability = + [&](const BasicBlock *ExitingBlock, + const BasicBlock *ExitBlock) -> BranchProbability { + auto *Term = ExitingBlock->getTerminator(); + MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof); + unsigned NumSucc = Term->getNumSuccessors(); + if (IsValidProfileData(ProfileData, Term)) { + uint64_t Numerator = 0, Denominator = 0, ProfVal = 0; + for (unsigned i = 0; i < NumSucc; i++) { + ConstantInt *CI = + mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1)); + ProfVal = CI->getValue().getZExtValue(); + if (Term->getSuccessor(i) == ExitBlock) + Numerator += ProfVal; + Denominator += ProfVal; + } + return BranchProbability::getBranchProbability(Numerator, Denominator); + } else { + assert(LatchBlock != ExitingBlock && + "Latch term should always have profile data!"); + // No profile data, so we choose the weight as 1/num_of_succ(Src) + return BranchProbability::getBranchProbability(1, NumSucc); + } + }; + BranchProbability LatchExitProbability = - BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx); + ComputeBranchProbability(LatchBlock, LatchExitBlock); // Protect against degenerate inputs provided by the user. Providing a value // less than one, can invert the definition of profitable loop predication. @@ -948,18 +1007,18 @@ bool LoopPredication::isLoopProfitableToPredicate() { LLVM_DEBUG(dbgs() << "The value is set to 1.0\n"); ScaleFactor = 1.0; } - const auto LatchProbabilityThreshold = - LatchExitProbability * ScaleFactor; + const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor; for (const auto &ExitEdge : ExitEdges) { BranchProbability ExitingBlockProbability = - BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second); + ComputeBranchProbability(ExitEdge.first, ExitEdge.second); // Some exiting edge has higher probability than the latch exiting edge. // No longer profitable to predicate. if (ExitingBlockProbability > LatchProbabilityThreshold) return false; } - // Using BPI, we have concluded that the most probable way to exit from the + + // We have concluded that the most probable way to exit from the // loop is through the latch (or there's no profile information and all // exits are equally likely). return true; @@ -1071,28 +1130,26 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // widen so that we gain ability to analyze it's exit count and perform this // transform. TODO: It'd be nice to know for sure the exit became // analyzeable after dropping widenability. 
- { - bool Invalidate = false; + bool ChangedLoop = false; - for (auto *ExitingBB : ExitingBlocks) { - if (LI->getLoopFor(ExitingBB) != L) - continue; + for (auto *ExitingBB : ExitingBlocks) { + if (LI->getLoopFor(ExitingBB) != L) + continue; - auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); - if (!BI) - continue; + auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); + if (!BI) + continue; - Use *Cond, *WC; - BasicBlock *IfTrueBB, *IfFalseBB; - if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) && - L->contains(IfTrueBB)) { - WC->set(ConstantInt::getTrue(IfTrueBB->getContext())); - Invalidate = true; - } + Use *Cond, *WC; + BasicBlock *IfTrueBB, *IfFalseBB; + if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) && + L->contains(IfTrueBB)) { + WC->set(ConstantInt::getTrue(IfTrueBB->getContext())); + ChangedLoop = true; } - if (Invalidate) - SE->forgetLoop(L); } + if (ChangedLoop) + SE->forgetLoop(L); // The use of umin(all analyzeable exits) instead of latch is subtle, but // important for profitability. We may have a loop which hasn't been fully @@ -1104,18 +1161,24 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() || !SE->isLoopInvariant(MinEC, L) || !isSafeToExpandAt(MinEC, WidenableBR, *SE)) - return false; + return ChangedLoop; // Subtlety: We need to avoid inserting additional uses of the WC. We know // that it can only have one transitive use at the moment, and thus moving // that use to just before the branch and inserting code before it and then // modifying the operand is legal. auto *IP = cast<Instruction>(WidenableBR->getCondition()); + // Here we unconditionally modify the IR, so after this point we should return + // only `true`! IP->moveBefore(WidenableBR); + if (MSSAU) + if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP)) + MSSAU->moveToPlace(MUD, WidenableBR->getParent(), + MemorySSA::BeforeTerminator); Rewriter.setInsertPoint(IP); IRBuilder<> B(IP); - bool Changed = false; + bool InvalidateLoop = false; Value *MinECV = nullptr; // lazily generated if needed for (BasicBlock *ExitingBB : ExitingBlocks) { // If our exiting block exits multiple loops, we can only rewrite the @@ -1172,16 +1235,18 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { Value *OldCond = BI->getCondition(); BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue)); - Changed = true; + InvalidateLoop = true; } - if (Changed) + if (InvalidateLoop) // We just mutated a bunch of loop exits changing there exit counts // widely. We need to force recomputation of the exit counts given these // changes. Note that all of the inserted exits are never taken, and // should be removed next time the CFG is modified. SE->forgetLoop(L); - return Changed; + + // Always return `true` since we have moved the WidenableBR's condition. 
+ return true; } bool LoopPredication::runOnLoop(Loop *Loop) { @@ -1242,5 +1307,8 @@ bool LoopPredication::runOnLoop(Loop *Loop) { for (auto *Guard : GuardsAsWidenableBranches) Changed |= widenWidenableBranchGuardConditions(Guard, Expander); Changed |= predicateLoopExits(L, Expander); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp index 6d5b19443c76..5ba137b1c85f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -99,8 +99,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - if (EnableMSSALoopDependency) - AU.addPreserved<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); getLoopAnalysisUsage(AU); // Lazy BFI and BPI are marked as preserved here so LoopRotate @@ -121,13 +120,11 @@ public: auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); Optional<MemorySSAUpdater> MSSAU; - if (EnableMSSALoopDependency) { - // Not requiring MemorySSA and getting it only if available will split - // the loop pass pipeline when LoopRotate is being run first. - auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - if (MSSAA) - MSSAU = MemorySSAUpdater(&MSSAA->getMSSA()); - } + // Not requiring MemorySSA and getting it only if available will split + // the loop pass pipeline when LoopRotate is being run first. + auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + if (MSSAA) + MSSAU = MemorySSAUpdater(&MSSAA->getMSSA()); // Vectorization requires loop-rotation. Use default threshold for loops the // user explicitly marked for vectorization, even when header duplication is // disabled. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index cc6d11220807..a87843d658a9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -733,13 +733,12 @@ public: DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); Optional<MemorySSAUpdater> MSSAU; - if (EnableMSSALoopDependency) { - MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSAU = MemorySSAUpdater(MSSA); - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - } + if (MSSAA) + MSSAU = MemorySSAUpdater(&MSSAA->getMSSA()); + if (MSSAA && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); bool DeleteCurrentLoop = false; bool Changed = simplifyLoopCFG( *L, DT, LI, SE, MSSAU.hasValue() ? 
MSSAU.getPointer() : nullptr, @@ -750,10 +749,7 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { - if (EnableMSSALoopDependency) { - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addPreserved<MemorySSAWrapperPass>(); AU.addPreserved<DependenceAnalysisWrapperPass>(); getLoopAnalysisUsage(AU); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp index a01287f587d7..c9c9e60d0921 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -323,15 +323,14 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, // Traverse preheader's instructions in reverse order becaue if A depends // on B (A appears after B), A needs to be sinked first before B can be // sinked. - for (auto II = Preheader->rbegin(), E = Preheader->rend(); II != E;) { - Instruction *I = &*II++; + for (Instruction &I : llvm::make_early_inc_range(llvm::reverse(*Preheader))) { // No need to check for instruction's operands are loop invariant. - assert(L.hasLoopInvariantOperands(I) && + assert(L.hasLoopInvariantOperands(&I) && "Insts in a loop's preheader should have loop invariant operands!"); - if (!canSinkOrHoistInst(*I, &AA, &DT, &L, CurAST, MSSAU.get(), false, + if (!canSinkOrHoistInst(I, &AA, &DT, &L, CurAST, MSSAU.get(), false, LICMFlags.get())) continue; - if (sinkInstruction(L, *I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI, + if (sinkInstruction(L, I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI, MSSAU.get())) Changed = true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 404852f1dd4d..a9a2266e1196 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -136,6 +136,12 @@ using namespace llvm; /// worst cases before LSR burns too much compile time and stack space. static const unsigned MaxIVUsers = 200; +/// Limit the size of expression that SCEV-based salvaging will attempt to +/// translate into a DIExpression. +/// Choose a maximum size such that debuginfo is not excessively increased and +/// the salvaging is not too expensive for the compiler. +static const unsigned MaxSCEVSalvageExpressionSize = 64; + // Temporary flag to cleanup congruent phis after LSR phi expansion. // It's currently disabled until we can determine whether it's truly useful or // not. The flag should be removed after the v3.0 release. @@ -689,7 +695,7 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS, const APInt &RA = RC->getAPInt(); // Handle x /s -1 as x * -1, to give ScalarEvolution a chance to do // some folding. - if (RA.isAllOnesValue()) { + if (RA.isAllOnes()) { if (LHS->getType()->isPointerTy()) return nullptr; return SE.getMulExpr(LHS, RC); @@ -2816,9 +2822,7 @@ static const SCEV *getExprBase(const SCEV *S) { // there's nothing more complex. // FIXME: not sure if we want to recognize negation. 
const SCEVAddExpr *Add = cast<SCEVAddExpr>(S); - for (std::reverse_iterator<SCEVAddExpr::op_iterator> I(Add->op_end()), - E(Add->op_begin()); I != E; ++I) { - const SCEV *SubExpr = *I; + for (const SCEV *SubExpr : reverse(Add->operands())) { if (SubExpr->getSCEVType() == scAddExpr) return getExprBase(SubExpr); @@ -3150,7 +3154,7 @@ void LSRInstance::CollectChains() { void LSRInstance::FinalizeChain(IVChain &Chain) { assert(!Chain.Incs.empty() && "empty IV chains are not allowed"); LLVM_DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n"); - + for (const IVInc &Inc : Chain) { LLVM_DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n"); auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand); @@ -3385,7 +3389,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { void LSRInstance::InsertInitialFormula(const SCEV *S, LSRUse &LU, size_t LUIdx) { // Mark uses whose expressions cannot be expanded. - if (!isSafeToExpand(S, SE)) + if (!isSafeToExpand(S, SE, /*CanonicalMode*/ false)) LU.RigidFormula = true; Formula F; @@ -3934,6 +3938,9 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, // Check each interesting stride. for (int64_t Factor : Factors) { + // Check that Factor can be represented by IntTy + if (!ConstantInt::isValueValidForType(IntTy, Factor)) + continue; // Check that the multiplication doesn't overflow. if (Base.BaseOffset == std::numeric_limits<int64_t>::min() && Factor == -1) continue; @@ -4082,6 +4089,14 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { if (DstTy->isPointerTy()) return; + // It is invalid to extend a pointer type so exit early if ScaledReg or + // any of the BaseRegs are pointers. + if (Base.ScaledReg && Base.ScaledReg->getType()->isPointerTy()) + return; + if (any_of(Base.BaseRegs, + [](const SCEV *S) { return S->getType()->isPointerTy(); })) + return; + for (Type *SrcTy : Types) { if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) { Formula F = Base; @@ -5689,23 +5704,6 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, } } -#ifndef NDEBUG - // All dominating loops must have preheaders, or SCEVExpander may not be able - // to materialize an AddRecExpr whose Start is an outer AddRecExpr. - // - // IVUsers analysis should only create users that are dominated by simple loop - // headers. Since this loop should dominate all of its users, its user list - // should be empty if this loop itself is not within a simple loop nest. 
- for (DomTreeNode *Rung = DT.getNode(L->getLoopPreheader()); - Rung; Rung = Rung->getIDom()) { - BasicBlock *BB = Rung->getBlock(); - const Loop *DomLoop = LI.getLoopFor(BB); - if (DomLoop && DomLoop->getHeader() == BB) { - assert(DomLoop->getLoopPreheader() && "LSR needs a simplified loop nest"); - } - } -#endif // DEBUG - LLVM_DEBUG(dbgs() << "\nLSR on loop "; L->getHeader()->printAsOperand(dbgs(), /*PrintType=*/false); dbgs() << ":\n"); @@ -5870,6 +5868,7 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const { AU.addPreserved<MemorySSAWrapperPass>(); } +namespace { struct SCEVDbgValueBuilder { SCEVDbgValueBuilder() = default; SCEVDbgValueBuilder(const SCEVDbgValueBuilder &Base) { @@ -6117,14 +6116,15 @@ struct DVIRecoveryRec { Metadata *LocationOp; const llvm::SCEV *SCEV; }; +} // namespace -static bool RewriteDVIUsingIterCount(DVIRecoveryRec CachedDVI, +static void RewriteDVIUsingIterCount(DVIRecoveryRec CachedDVI, const SCEVDbgValueBuilder &IterationCount, ScalarEvolution &SE) { // LSR may add locations to previously single location-op DVIs which // are currently not supported. if (CachedDVI.DVI->getNumVariableLocationOps() != 1) - return false; + return; // SCEVs for SSA values are most frquently of the form // {start,+,stride}, but sometimes they are ({start,+,stride} + %a + ..). @@ -6132,48 +6132,70 @@ static bool RewriteDVIUsingIterCount(DVIRecoveryRec CachedDVI, // SCEVs have not been observed to result in debuginfo-lossy optimisations, // so its not expected this point will be reached. if (!isa<SCEVAddRecExpr>(CachedDVI.SCEV)) - return false; + return; LLVM_DEBUG(dbgs() << "scev-salvage: Value to salvage SCEV: " << *CachedDVI.SCEV << '\n'); const auto *Rec = cast<SCEVAddRecExpr>(CachedDVI.SCEV); if (!Rec->isAffine()) - return false; + return; + + if (CachedDVI.SCEV->getExpressionSize() > MaxSCEVSalvageExpressionSize) + return; // Initialise a new builder with the iteration count expression. In // combination with the value's SCEV this enables recovery. 
SCEVDbgValueBuilder RecoverValue(IterationCount); if (!RecoverValue.SCEVToValueExpr(*Rec, SE)) - return false; + return; LLVM_DEBUG(dbgs() << "scev-salvage: Updating: " << *CachedDVI.DVI << '\n'); RecoverValue.applyExprToDbgValue(*CachedDVI.DVI, CachedDVI.Expr); LLVM_DEBUG(dbgs() << "scev-salvage: to: " << *CachedDVI.DVI << '\n'); - return true; } -static bool +static void RewriteDVIUsingOffset(DVIRecoveryRec &DVIRec, llvm::PHINode &IV, + int64_t Offset) { + assert(!DVIRec.DVI->hasArgList() && "Expected single location-op dbg.value."); + DbgValueInst *DVI = DVIRec.DVI; + SmallVector<uint64_t, 8> Ops; + DIExpression::appendOffset(Ops, Offset); + DIExpression *Expr = DIExpression::prependOpcodes(DVIRec.Expr, Ops, true); + LLVM_DEBUG(dbgs() << "scev-salvage: Updating: " << *DVIRec.DVI << '\n'); + DVI->setExpression(Expr); + llvm::Value *ValIV = dyn_cast<llvm::Value>(&IV); + DVI->replaceVariableLocationOp( + 0u, llvm::MetadataAsValue::get(DVI->getContext(), + llvm::ValueAsMetadata::get(ValIV))); + LLVM_DEBUG(dbgs() << "scev-salvage: updated with offset to IV: " + << *DVIRec.DVI << '\n'); +} + +static void DbgRewriteSalvageableDVIs(llvm::Loop *L, ScalarEvolution &SE, llvm::PHINode *LSRInductionVar, SmallVector<DVIRecoveryRec, 2> &DVIToUpdate) { if (DVIToUpdate.empty()) - return false; + return; const llvm::SCEV *SCEVInductionVar = SE.getSCEV(LSRInductionVar); assert(SCEVInductionVar && "Anticipated a SCEV for the post-LSR induction variable"); - bool Changed = false; if (const SCEVAddRecExpr *IVAddRec = dyn_cast<SCEVAddRecExpr>(SCEVInductionVar)) { if (!IVAddRec->isAffine()) - return false; + return; + if (IVAddRec->getExpressionSize() > MaxSCEVSalvageExpressionSize) + return; + + // The iteration count is required to recover location values. SCEVDbgValueBuilder IterCountExpr; IterCountExpr.pushValue(LSRInductionVar); if (!IterCountExpr.SCEVToIterCountExpr(*IVAddRec, SE)) - return false; + return; LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV: " << *SCEVInductionVar << '\n'); @@ -6196,14 +6218,26 @@ DbgRewriteSalvageableDVIs(llvm::Loop *L, ScalarEvolution &SE, DVIRec.DVI->setExpression(DVIRec.Expr); } - Changed |= RewriteDVIUsingIterCount(DVIRec, IterCountExpr, SE); + LLVM_DEBUG(dbgs() << "scev-salvage: value to recover SCEV: " + << *DVIRec.SCEV << '\n'); + + // Create a simple expression if the IV and value to salvage SCEVs + // start values differ by only a constant value. + if (Optional<APInt> Offset = + SE.computeConstantDifference(DVIRec.SCEV, SCEVInductionVar)) { + if (Offset.getValue().getMinSignedBits() <= 64) + RewriteDVIUsingOffset(DVIRec, *LSRInductionVar, + Offset.getValue().getSExtValue()); + } else { + RewriteDVIUsingIterCount(DVIRec, IterCountExpr, SE); + } } } - return Changed; } /// Identify and cache salvageable DVI locations and expressions along with the -/// corresponding SCEV(s). Also ensure that the DVI is not deleted before +/// corresponding SCEV(s). Also ensure that the DVI is not deleted between +/// cacheing and salvaging. static void DbgGatherSalvagableDVI(Loop *L, ScalarEvolution &SE, SmallVector<DVIRecoveryRec, 2> &SalvageableDVISCEVs, @@ -6214,6 +6248,9 @@ DbgGatherSalvagableDVI(Loop *L, ScalarEvolution &SE, if (!DVI) continue; + if (DVI->isUndef()) + continue; + if (DVI->hasArgList()) continue; @@ -6221,6 +6258,16 @@ DbgGatherSalvagableDVI(Loop *L, ScalarEvolution &SE, !SE.isSCEVable(DVI->getVariableLocationOp(0)->getType())) continue; + // SCEVUnknown wraps an llvm::Value, it does not have a start and stride. 
+ // Therefore no translation to DIExpression is performed. + const SCEV *S = SE.getSCEV(DVI->getVariableLocationOp(0)); + if (isa<SCEVUnknown>(S)) + continue; + + // Avoid wasting resources generating an expression containing undef. + if (SE.containsUndefs(S)) + continue; + SalvageableDVISCEVs.push_back( {DVI, DVI->getExpression(), DVI->getRawLocation(), SE.getSCEV(DVI->getVariableLocationOp(0))}); @@ -6234,33 +6281,32 @@ DbgGatherSalvagableDVI(Loop *L, ScalarEvolution &SE, /// surviving subsequent transforms. static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE, const LSRInstance &LSR) { - // For now, just pick the first IV generated and inserted. Ideally pick an IV - // that is unlikely to be optimised away by subsequent transforms. + + auto IsSuitableIV = [&](PHINode *P) { + if (!SE.isSCEVable(P->getType())) + return false; + if (const SCEVAddRecExpr *Rec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(P))) + return Rec->isAffine() && !SE.containsUndefs(SE.getSCEV(P)); + return false; + }; + + // For now, just pick the first IV that was generated and inserted by + // ScalarEvolution. Ideally pick an IV that is unlikely to be optimised away + // by subsequent transforms. for (const WeakVH &IV : LSR.getScalarEvolutionIVs()) { if (!IV) continue; - assert(isa<PHINode>(&*IV) && "Expected PhI node."); - if (SE.isSCEVable((*IV).getType())) { - PHINode *Phi = dyn_cast<PHINode>(&*IV); - LLVM_DEBUG(dbgs() << "scev-salvage: IV : " << *IV - << "with SCEV: " << *SE.getSCEV(Phi) << "\n"); - return Phi; - } - } + // There should only be PHI node IVs. + PHINode *P = cast<PHINode>(&*IV); - for (PHINode &Phi : L.getHeader()->phis()) { - if (!SE.isSCEVable(Phi.getType())) - continue; - - const llvm::SCEV *PhiSCEV = SE.getSCEV(&Phi); - if (const llvm::SCEVAddRecExpr *Rec = dyn_cast<SCEVAddRecExpr>(PhiSCEV)) - if (!Rec->isAffine()) - continue; + if (IsSuitableIV(P)) + return P; + } - LLVM_DEBUG(dbgs() << "scev-salvage: Selected IV from loop header: " << Phi - << " with SCEV: " << *PhiSCEV << "\n"); - return &Phi; + for (PHINode &P : L.getHeader()->phis()) { + if (IsSuitableIV(&P)) + return &P; } return nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 71eb393fcdd7..1ecbb86724e1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -286,8 +286,8 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, AssumptionCache &AC, DependenceInfo &DI, OptimizationRemarkEmitter &ORE, int OptLevel) { TargetTransformInfo::UnrollingPreferences UP = - gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, OptLevel, None, - None, None, None, None, None); + gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, ORE, OptLevel, + None, None, None, None, None, None); TargetTransformInfo::PeelingPreferences PP = gatherPeelingPreferences(L, SE, TTI, None, None); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 49501f324a49..67702520511b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -184,7 +184,8 @@ static const unsigned NoThreshold = std::numeric_limits<unsigned>::max(); /// flags, TTI overrides and user specified parameters.
TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, - BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, int OptLevel, + BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, + OptimizationRemarkEmitter &ORE, int OptLevel, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, Optional<bool> UserUpperBound, Optional<unsigned> UserFullUnrollMaxCount) { @@ -214,7 +215,7 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze; // Override with any target specific settings - TTI.getUnrollingPreferences(L, SE, UP); + TTI.getUnrollingPreferences(L, SE, UP, &ORE); // Apply size attributes bool OptForSize = L->getHeader()->getParent()->hasOptSize() || @@ -318,6 +319,16 @@ struct EstimatedUnrollCost { unsigned RolledDynamicCost; }; +struct PragmaInfo { + PragmaInfo(bool UUC, bool PFU, unsigned PC, bool PEU) + : UserUnrollCount(UUC), PragmaFullUnroll(PFU), PragmaCount(PC), + PragmaEnableUnroll(PEU) {} + const bool UserUnrollCount; + const bool PragmaFullUnroll; + const unsigned PragmaCount; + const bool PragmaEnableUnroll; +}; + } // end anonymous namespace /// Figure out if the loop is worth full unrolling. @@ -746,13 +757,132 @@ public: // Returns loop size estimation for unrolled loop, given the unrolling // configuration specified by UP. - uint64_t getUnrolledLoopSize(TargetTransformInfo::UnrollingPreferences &UP) { + uint64_t + getUnrolledLoopSize(const TargetTransformInfo::UnrollingPreferences &UP, + const unsigned CountOverwrite = 0) const { assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!"); - return (uint64_t)(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns; + if (CountOverwrite) + return static_cast<uint64_t>(LoopSize - UP.BEInsns) * CountOverwrite + + UP.BEInsns; + else + return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count + + UP.BEInsns; } }; +static Optional<unsigned> +shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo, + const unsigned TripMultiple, const unsigned TripCount, + const UnrollCostEstimator UCE, + const TargetTransformInfo::UnrollingPreferences &UP) { + + // Using unroll pragma + // 1st priority is unroll count set by "unroll-count" option. + + if (PInfo.UserUnrollCount) { + if (UP.AllowRemainder && + UCE.getUnrolledLoopSize(UP, (unsigned)UnrollCount) < UP.Threshold) + return (unsigned)UnrollCount; + } + + // 2nd priority is unroll count set by pragma. + if (PInfo.PragmaCount > 0) { + if ((UP.AllowRemainder || (TripMultiple % PInfo.PragmaCount == 0)) && + UCE.getUnrolledLoopSize(UP, PInfo.PragmaCount) < PragmaUnrollThreshold) + return PInfo.PragmaCount; + } + + if (PInfo.PragmaFullUnroll && TripCount != 0) { + if (UCE.getUnrolledLoopSize(UP, TripCount) < PragmaUnrollThreshold) + return TripCount; + } + // if didn't return until here, should continue to other priorties + return None; +} + +static Optional<unsigned> shouldFullUnroll( + Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, + ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, + const unsigned FullUnrollTripCount, const UnrollCostEstimator UCE, + const TargetTransformInfo::UnrollingPreferences &UP) { + + if (FullUnrollTripCount && FullUnrollTripCount <= UP.FullUnrollMaxCount) { + // When computing the unrolled size, note that BEInsns are not replicated + // like the rest of the loop body. 
+ if (UCE.getUnrolledLoopSize(UP) < UP.Threshold) { + return FullUnrollTripCount; + + } else { + // The loop isn't that small, but we still can fully unroll it if that + // helps to remove a significant number of instructions. + // To check that, run additional analysis on the loop. + if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( + L, FullUnrollTripCount, DT, SE, EphValues, TTI, + UP.Threshold * UP.MaxPercentThresholdBoost / 100, + UP.MaxIterationsCountToAnalyze)) { + unsigned Boost = + getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost); + if (Cost->UnrolledCost < UP.Threshold * Boost / 100) { + return FullUnrollTripCount; + } + } + } + } + return None; +} + +static Optional<unsigned> +shouldPartialUnroll(const unsigned LoopSize, const unsigned TripCount, + const UnrollCostEstimator UCE, + const TargetTransformInfo::UnrollingPreferences &UP) { + + unsigned count = UP.Count; + if (TripCount) { + if (!UP.Partial) { + LLVM_DEBUG(dbgs() << " will not try to unroll partially because " + << "-unroll-allow-partial not given\n"); + count = 0; + return count; + } + if (count == 0) + count = TripCount; + if (UP.PartialThreshold != NoThreshold) { + // Reduce unroll count to be modulo of TripCount for partial unrolling. + if (UCE.getUnrolledLoopSize(UP, count) > UP.PartialThreshold) + count = (std::max(UP.PartialThreshold, UP.BEInsns + 1) - UP.BEInsns) / + (LoopSize - UP.BEInsns); + if (count > UP.MaxCount) + count = UP.MaxCount; + while (count != 0 && TripCount % count != 0) + count--; + if (UP.AllowRemainder && count <= 1) { + // If there is no Count that is modulo of TripCount, set Count to + // largest power-of-two factor that satisfies the threshold limit. + // As we'll create fixup loop, do the type of unrolling only if + // remainder loop is allowed. + count = UP.DefaultUnrollRuntimeCount; + while (count != 0 && + UCE.getUnrolledLoopSize(UP, count) > UP.PartialThreshold) + count >>= 1; + } + if (count < 2) { + count = 0; + } + } else { + count = TripCount; + } + if (count > UP.MaxCount) + count = UP.MaxCount; + + LLVM_DEBUG(dbgs() << " partially unrolling with count: " << count << "\n"); + + return count; + } + + // if didn't return until here, should continue to other priorties + return None; +} // Returns true if unroll count was set explicitly. // Calculates unroll count and writes it to UP.Count. // Unless IgnoreUser is true, will also use metadata and command-line options @@ -770,7 +900,18 @@ bool llvm::computeUnrollCount( TargetTransformInfo::PeelingPreferences &PP, bool &UseUpperBound) { UnrollCostEstimator UCE(*L, LoopSize); + Optional<unsigned> UnrollFactor; + + const bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0; + const bool PragmaFullUnroll = hasUnrollFullPragma(L); + const unsigned PragmaCount = unrollCountPragmaValue(L); + const bool PragmaEnableUnroll = hasUnrollEnablePragma(L); + const bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll || + PragmaEnableUnroll || UserUnrollCount; + + PragmaInfo PInfo(UserUnrollCount, PragmaFullUnroll, PragmaCount, + PragmaEnableUnroll); // Use an explicit peel count that has been specified for testing. In this // case it's not permitted to also specify an explicit unroll count. if (PP.PeelCount) { @@ -782,47 +923,29 @@ bool llvm::computeUnrollCount( UP.Runtime = false; return true; } - // Check for explicit Count. // 1st priority is unroll count set by "unroll-count" option. 
- bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0; - if (UserUnrollCount) { - UP.Count = UnrollCount; - UP.AllowExpensiveTripCount = true; - UP.Force = true; - if (UP.AllowRemainder && UCE.getUnrolledLoopSize(UP) < UP.Threshold) - return true; - } - // 2nd priority is unroll count set by pragma. - unsigned PragmaCount = unrollCountPragmaValue(L); - if (PragmaCount > 0) { - UP.Count = PragmaCount; - UP.Runtime = true; - UP.AllowExpensiveTripCount = true; - UP.Force = true; - if ((UP.AllowRemainder || (TripMultiple % PragmaCount == 0)) && - UCE.getUnrolledLoopSize(UP) < PragmaUnrollThreshold) - return true; - } - bool PragmaFullUnroll = hasUnrollFullPragma(L); - if (PragmaFullUnroll && TripCount != 0) { - UP.Count = TripCount; - if (UCE.getUnrolledLoopSize(UP) < PragmaUnrollThreshold) - return false; - } + UnrollFactor = shouldPragmaUnroll(L, PInfo, TripMultiple, TripCount, UCE, UP); + + if (UnrollFactor) { + UP.Count = *UnrollFactor; - bool PragmaEnableUnroll = hasUnrollEnablePragma(L); - bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll || - PragmaEnableUnroll || UserUnrollCount; - - if (ExplicitUnroll && TripCount != 0) { - // If the loop has an unrolling pragma, we want to be more aggressive with - // unrolling limits. Set thresholds to at least the PragmaUnrollThreshold - // value which is larger than the default limits. - UP.Threshold = std::max<unsigned>(UP.Threshold, PragmaUnrollThreshold); - UP.PartialThreshold = - std::max<unsigned>(UP.PartialThreshold, PragmaUnrollThreshold); + if (UserUnrollCount || (PragmaCount > 0)) { + UP.AllowExpensiveTripCount = true; + UP.Force = true; + } + UP.Runtime |= (PragmaCount > 0); + return ExplicitUnroll; + } else { + if (ExplicitUnroll && TripCount != 0) { + // If the loop has an unrolling pragma, we want to be more aggressive with + // unrolling limits. Set thresholds to at least the PragmaUnrollThreshold + // value which is larger than the default limits. + UP.Threshold = std::max<unsigned>(UP.Threshold, PragmaUnrollThreshold); + UP.PartialThreshold = + std::max<unsigned>(UP.PartialThreshold, PragmaUnrollThreshold); + } } // 3rd priority is full unroll count. @@ -852,71 +975,55 @@ bool llvm::computeUnrollCount( unsigned FullUnrollTripCount = ExactTripCount ? ExactTripCount : FullUnrollMaxTripCount; UP.Count = FullUnrollTripCount; - if (FullUnrollTripCount && FullUnrollTripCount <= UP.FullUnrollMaxCount) { - // When computing the unrolled size, note that BEInsns are not replicated - // like the rest of the loop body. - if (UCE.getUnrolledLoopSize(UP) < UP.Threshold) { - UseUpperBound = (FullUnrollMaxTripCount == FullUnrollTripCount); - return ExplicitUnroll; - } else { - // The loop isn't that small, but we still can fully unroll it if that - // helps to remove a significant number of instructions. - // To check that, run additional analysis on the loop. 
- if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( - L, FullUnrollTripCount, DT, SE, EphValues, TTI, - UP.Threshold * UP.MaxPercentThresholdBoost / 100, - UP.MaxIterationsCountToAnalyze)) { - unsigned Boost = - getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost); - if (Cost->UnrolledCost < UP.Threshold * Boost / 100) { - UseUpperBound = (FullUnrollMaxTripCount == FullUnrollTripCount); - return ExplicitUnroll; - } - } - } + + UnrollFactor = + shouldFullUnroll(L, TTI, DT, SE, EphValues, FullUnrollTripCount, UCE, UP); + + // if shouldFullUnroll can do the unrolling, some side parameteres should be + // set + if (UnrollFactor) { + UP.Count = *UnrollFactor; + UseUpperBound = (FullUnrollMaxTripCount == FullUnrollTripCount); + TripCount = FullUnrollTripCount; + TripMultiple = UP.UpperBound ? 1 : TripMultiple; + return ExplicitUnroll; + } else { + UP.Count = FullUnrollTripCount; } // 4th priority is loop peeling. - computePeelCount(L, LoopSize, PP, TripCount, SE, UP.Threshold); + computePeelCount(L, LoopSize, PP, TripCount, DT, SE, UP.Threshold); if (PP.PeelCount) { UP.Runtime = false; UP.Count = 1; return ExplicitUnroll; } + // Before starting partial unrolling, set up.partial to true, + // if user explicitly asked for unrolling + if (TripCount) + UP.Partial |= ExplicitUnroll; + // 5th priority is partial unrolling. // Try partial unroll only when TripCount could be statically calculated. - if (TripCount) { - UP.Partial |= ExplicitUnroll; - if (!UP.Partial) { - LLVM_DEBUG(dbgs() << " will not try to unroll partially because " - << "-unroll-allow-partial not given\n"); - UP.Count = 0; - return false; - } - if (UP.Count == 0) - UP.Count = TripCount; + UnrollFactor = shouldPartialUnroll(LoopSize, TripCount, UCE, UP); + + if (UnrollFactor) { + UP.Count = *UnrollFactor; + + if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount && + UP.Count != TripCount) + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "FullUnrollAsDirectedTooLarge", + L->getStartLoc(), L->getHeader()) + << "Unable to fully unroll loop as directed by unroll pragma " + "because " + "unrolled size is too large."; + }); + if (UP.PartialThreshold != NoThreshold) { - // Reduce unroll count to be modulo of TripCount for partial unrolling. - if (UCE.getUnrolledLoopSize(UP) > UP.PartialThreshold) - UP.Count = - (std::max(UP.PartialThreshold, UP.BEInsns + 1) - UP.BEInsns) / - (LoopSize - UP.BEInsns); - if (UP.Count > UP.MaxCount) - UP.Count = UP.MaxCount; - while (UP.Count != 0 && TripCount % UP.Count != 0) - UP.Count--; - if (UP.AllowRemainder && UP.Count <= 1) { - // If there is no Count that is modulo of TripCount, set Count to - // largest power-of-two factor that satisfies the threshold limit. - // As we'll create fixup loop, do the type of unrolling only if - // remainder loop is allowed. 
- UP.Count = UP.DefaultUnrollRuntimeCount; - while (UP.Count != 0 && - UCE.getUnrolledLoopSize(UP) > UP.PartialThreshold) - UP.Count >>= 1; - } - if (UP.Count < 2) { + if (UP.Count == 0) { if (PragmaEnableUnroll) ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, @@ -926,25 +1033,8 @@ bool llvm::computeUnrollCount( "pragma " "because unrolled size is too large."; }); - UP.Count = 0; } - } else { - UP.Count = TripCount; } - if (UP.Count > UP.MaxCount) - UP.Count = UP.MaxCount; - if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount && - UP.Count != TripCount) - ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, - "FullUnrollAsDirectedTooLarge", - L->getStartLoc(), L->getHeader()) - << "Unable to fully unroll loop as directed by unroll pragma " - "because " - "unrolled size is too large."; - }); - LLVM_DEBUG(dbgs() << " partially unrolling with count: " << UP.Count - << "\n"); return ExplicitUnroll; } assert(TripCount == 0 && @@ -981,8 +1071,6 @@ bool llvm::computeUnrollCount( UP.AllowExpensiveTripCount = true; } } - - // Reduce count based on the type of unrolling and the threshold values. UP.Runtime |= PragmaEnableUnroll || PragmaCount > 0 || UserUnrollCount; if (!UP.Runtime) { LLVM_DEBUG( @@ -1017,7 +1105,7 @@ bool llvm::computeUnrollCount( using namespace ore; - if (PragmaCount > 0 && !UP.AllowRemainder) + if (unrollCountPragmaValue(L) > 0 && !UP.AllowRemainder) ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "DifferentUnrollCountFromDirected", @@ -1079,7 +1167,7 @@ static LoopUnrollResult tryToUnrollLoop( bool NotDuplicatable; bool Convergent; TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( - L, SE, TTI, BFI, PSI, OptLevel, ProvidedThreshold, ProvidedCount, + L, SE, TTI, BFI, PSI, ORE, OptLevel, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, ProvidedFullUnrollMaxCount); TargetTransformInfo::PeelingPreferences PP = gatherPeelingPreferences( @@ -1529,3 +1617,25 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, return getLoopPassPreservedAnalyses(); } + +void LoopUnrollPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LoopUnrollPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (UnrollOpts.AllowPartial != None) + OS << (UnrollOpts.AllowPartial.getValue() ? "" : "no-") << "partial;"; + if (UnrollOpts.AllowPeeling != None) + OS << (UnrollOpts.AllowPeeling.getValue() ? "" : "no-") << "peeling;"; + if (UnrollOpts.AllowRuntime != None) + OS << (UnrollOpts.AllowRuntime.getValue() ? "" : "no-") << "runtime;"; + if (UnrollOpts.AllowUpperBound != None) + OS << (UnrollOpts.AllowUpperBound.getValue() ? "" : "no-") << "upperbound;"; + if (UnrollOpts.AllowProfileBasedPeeling != None) + OS << (UnrollOpts.AllowProfileBasedPeeling.getValue() ? 
"" : "no-") + << "profile-peeling;"; + if (UnrollOpts.FullUnrollMaxCount != None) + OS << "full-unroll-max=" << UnrollOpts.FullUnrollMaxCount << ";"; + OS << "O" << UnrollOpts.OptLevel; + OS << ">"; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index 9a854ff80246..76bb5497c2c2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -232,10 +232,8 @@ namespace { AU.addPreserved<LazyBranchProbabilityInfoPass>(); AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - if (EnableMSSALoopDependency) { - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); if (HasBranchDivergence) AU.addRequired<LegacyDivergenceAnalysis>(); getLoopAnalysisUsage(AU); @@ -539,11 +537,8 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPMRef) { LPM = &LPMRef; DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - if (EnableMSSALoopDependency) { - MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); - assert(DT && "Cannot update MemorySSA without a valid DomTree."); - } + MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); CurrentLoop = L; Function *F = CurrentLoop->getHeader()->getParent(); @@ -551,19 +546,19 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPMRef) { if (SanitizeMemory) SafetyInfo.computeLoopSafetyInfo(L); - if (MSSA && VerifyMemorySSA) + if (VerifyMemorySSA) MSSA->verifyMemorySSA(); bool Changed = false; do { assert(CurrentLoop->isLCSSAForm(*DT)); - if (MSSA && VerifyMemorySSA) + if (VerifyMemorySSA) MSSA->verifyMemorySSA(); RedoLoop = false; Changed |= processCurrentLoop(); } while (RedoLoop); - if (MSSA && VerifyMemorySSA) + if (VerifyMemorySSA) MSSA->verifyMemorySSA(); return Changed; @@ -1312,8 +1307,7 @@ void LoopUnswitch::splitExitEdges( for (unsigned I = 0, E = ExitBlocks.size(); I != E; ++I) { BasicBlock *ExitBlock = ExitBlocks[I]; - SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBlock), - pred_end(ExitBlock)); + SmallVector<BasicBlock *, 4> Preds(predecessors(ExitBlock)); // Although SplitBlockPredecessors doesn't preserve loop-simplify in // general, if we call it on all predecessors of all exits then it does. 
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp index bd3001988369..186065db327e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -55,11 +55,17 @@ static bool replaceConditionalBranchesOnConstant(Instruction *II, Value *NewValue, DomTreeUpdater *DTU) { bool HasDeadBlocks = false; - SmallSetVector<Instruction *, 8> Worklist; + SmallSetVector<Instruction *, 8> UnsimplifiedUsers; replaceAndRecursivelySimplify(II, NewValue, nullptr, nullptr, nullptr, - &Worklist); - for (auto I : Worklist) { - BranchInst *BI = dyn_cast<BranchInst>(I); + &UnsimplifiedUsers); + // UnsimplifiedUsers can contain PHI nodes that may be removed when + // replacing the branch instructions, so use a value handle worklist + // to handle those possibly removed instructions. + SmallVector<WeakVH, 8> Worklist(UnsimplifiedUsers.begin(), + UnsimplifiedUsers.end()); + + for (auto &VH : Worklist) { + BranchInst *BI = dyn_cast_or_null<BranchInst>(VH); if (!BI) continue; if (BI->isUnconditional()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index ead8082f3036..1c186e9a0488 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -357,11 +357,10 @@ static bool lowerExpectIntrinsic(Function &F) { // Remove llvm.expect intrinsics. Iterate backwards in order // to process select instructions before the intrinsic gets // removed. - for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE;) { - Instruction *Inst = &*BI++; - CallInst *CI = dyn_cast<CallInst>(Inst); + for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(BB))) { + CallInst *CI = dyn_cast<CallInst>(&Inst); if (!CI) { - if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) { + if (SelectInst *SI = dyn_cast<SelectInst>(&Inst)) { if (handleBrSelExpect(*SI)) ExpectIntrinsicsHandled++; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 42c183a6408e..4e4097e13271 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -900,8 +900,7 @@ public: // UndefedInsts and then check that we in fact remove them. 
SmallSet<Instruction *, 16> UndefedInsts; for (auto *Inst : reverse(ToRemove)) { - for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { - Use &U = *I++; + for (Use &U : llvm::make_early_inc_range(Inst->uses())) { if (auto *Undefed = dyn_cast<Instruction>(U.getUser())) UndefedInsts.insert(Undefed); U.set(UndefValue::get(Inst->getType())); @@ -981,8 +980,9 @@ public: Value *EltPtr = createElementPtr(Ptr, EltTy, Builder); MatrixTy Result; for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { - Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, - Shape.getStride(), EltTy, Builder); + Value *GEP = computeVectorAddr( + EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), + Stride, Shape.getStride(), EltTy, Builder); Value *Vector = Builder.CreateAlignedLoad( VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign), IsVolatile, "col.load"); @@ -1071,9 +1071,11 @@ public: auto VType = cast<VectorType>(Ty); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); for (auto Vec : enumerate(StoreVal.vectors())) { - Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), - Stride, StoreVal.getStride(), - VType->getElementType(), Builder); + Value *GEP = computeVectorAddr( + EltPtr, + Builder.getIntN(Stride->getType()->getScalarSizeInBits(), + Vec.index()), + Stride, StoreVal.getStride(), VType->getElementType(), Builder); Builder.CreateAlignedStore(Vec.value(), GEP, getAlignForIndex(Vec.index(), Stride, VType->getElementType(), @@ -2261,6 +2263,16 @@ PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F, return PreservedAnalyses::all(); } +void LowerMatrixIntrinsicsPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + if (Minimal) + OS << "minimal"; + OS << ">"; +} + namespace { class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 9afbe0e9a2a5..67335a45fb58 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -22,7 +22,6 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" -#include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" @@ -67,9 +66,10 @@ using namespace llvm; #define DEBUG_TYPE "memcpyopt" -static cl::opt<bool> - EnableMemorySSA("enable-memcpyopt-memoryssa", cl::init(true), cl::Hidden, - cl::desc("Use MemorySSA-backed MemCpyOpt.")); +static cl::opt<bool> EnableMemCpyOptWithoutLibcalls( + "enable-memcpyopt-without-libcalls", cl::init(false), cl::Hidden, + cl::ZeroOrMore, + cl::desc("Enable memcpyopt even when libcalls are disabled")); STATISTIC(NumMemCpyInstr, "Number of memcpy instructions deleted"); STATISTIC(NumMemSetInfer, "Number of memsets inferred"); @@ -282,13 +282,9 @@ private: AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); - if (!EnableMemorySSA) - AU.addRequired<MemoryDependenceWrapperPass>(); - AU.addPreserved<MemoryDependenceWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); 
AU.addPreserved<AAResultsWrapperPass>(); - if (EnableMemorySSA) - AU.addRequired<MemorySSAWrapperPass>(); + AU.addRequired<MemorySSAWrapperPass>(); AU.addPreserved<MemorySSAWrapperPass>(); } }; @@ -304,7 +300,6 @@ INITIALIZE_PASS_BEGIN(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) @@ -329,10 +324,7 @@ static bool mayBeVisibleThroughUnwinding(Value *V, Instruction *Start, } void MemCpyOptPass::eraseInstruction(Instruction *I) { - if (MSSAU) - MSSAU->removeMemoryAccess(I); - if (MD) - MD->removeInstruction(I); + MSSAU->removeMemoryAccess(I); I->eraseFromParent(); } @@ -394,14 +386,12 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, // memsets. MemoryDef *LastMemDef = nullptr; for (++BI; !BI->isTerminator(); ++BI) { - if (MSSAU) { - auto *CurrentAcc = cast_or_null<MemoryUseOrDef>( - MSSAU->getMemorySSA()->getMemoryAccess(&*BI)); - if (CurrentAcc) { - MemInsertPoint = CurrentAcc; - if (auto *CurrentDef = dyn_cast<MemoryDef>(CurrentAcc)) - LastMemDef = CurrentDef; - } + auto *CurrentAcc = cast_or_null<MemoryUseOrDef>( + MSSAU->getMemorySSA()->getMemoryAccess(&*BI)); + if (CurrentAcc) { + MemInsertPoint = CurrentAcc; + if (auto *CurrentDef = dyn_cast<MemoryDef>(CurrentAcc)) + LastMemDef = CurrentDef; } // Calls that only access inaccessible memory do not block merging @@ -503,19 +493,17 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, if (!Range.TheStores.empty()) AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc()); - if (MSSAU) { - assert(LastMemDef && MemInsertPoint && - "Both LastMemDef and MemInsertPoint need to be set"); - auto *NewDef = - cast<MemoryDef>(MemInsertPoint->getMemoryInst() == &*BI - ? MSSAU->createMemoryAccessBefore( - AMemSet, LastMemDef, MemInsertPoint) - : MSSAU->createMemoryAccessAfter( - AMemSet, LastMemDef, MemInsertPoint)); - MSSAU->insertDef(NewDef, /*RenameUses=*/true); - LastMemDef = NewDef; - MemInsertPoint = NewDef; - } + assert(LastMemDef && MemInsertPoint && + "Both LastMemDef and MemInsertPoint need to be set"); + auto *NewDef = + cast<MemoryDef>(MemInsertPoint->getMemoryInst() == &*BI + ? MSSAU->createMemoryAccessBefore( + AMemSet, LastMemDef, MemInsertPoint) + : MSSAU->createMemoryAccessAfter( + AMemSet, LastMemDef, MemInsertPoint)); + MSSAU->insertDef(NewDef, /*RenameUses=*/true); + LastMemDef = NewDef; + MemInsertPoint = NewDef; // Zap all the stores. for (Instruction *SI : Range.TheStores) @@ -624,17 +612,15 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { // TODO: Simplify this once P will be determined by MSSA, in which case the // discrepancy can no longer occur. 
MemoryUseOrDef *MemInsertPoint = nullptr; - if (MSSAU) { - if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(P)) { - MemInsertPoint = cast<MemoryUseOrDef>(--MA->getIterator()); - } else { - const Instruction *ConstP = P; - for (const Instruction &I : make_range(++ConstP->getReverseIterator(), - ++LI->getReverseIterator())) { - if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(&I)) { - MemInsertPoint = MA; - break; - } + if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(P)) { + MemInsertPoint = cast<MemoryUseOrDef>(--MA->getIterator()); + } else { + const Instruction *ConstP = P; + for (const Instruction &I : make_range(++ConstP->getReverseIterator(), + ++LI->getReverseIterator())) { + if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(&I)) { + MemInsertPoint = MA; + break; } } } @@ -643,12 +629,10 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { for (auto *I : llvm::reverse(ToLift)) { LLVM_DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n"); I->moveBefore(P); - if (MSSAU) { - assert(MemInsertPoint && "Must have found insert point"); - if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(I)) { - MSSAU->moveAfter(MA, MemInsertPoint); - MemInsertPoint = MA; - } + assert(MemInsertPoint && "Must have found insert point"); + if (MemoryUseOrDef *MA = MSSAU->getMemorySSA()->getMemoryAccess(I)) { + MSSAU->moveAfter(MA, MemInsertPoint); + MemInsertPoint = MA; } } @@ -682,7 +666,13 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { LI->getParent() == SI->getParent()) { auto *T = LI->getType(); - if (T->isAggregateType()) { + // Don't introduce calls to memcpy/memmove intrinsics out of thin air if + // the corresponding libcalls are not available. + // TODO: We should really distinguish between libcall availability and + // our ability to introduce intrinsics. + if (T->isAggregateType() && + (EnableMemCpyOptWithoutLibcalls || + (TLI->has(LibFunc_memcpy) && TLI->has(LibFunc_memmove)))) { MemoryLocation LoadLoc = MemoryLocation::get(LI); // We use alias analysis to check if an instruction may store to @@ -712,9 +702,10 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (P) { // If we load from memory that may alias the memory we store to, // memmove must be used to preserve semantic. If not, memcpy can - // be used. + // be used. Also, if we load from constant memory, memcpy can be used + // as the constant memory won't be modified. 
bool UseMemMove = false; - if (!AA->isNoAlias(MemoryLocation::get(SI), LoadLoc)) + if (isModSet(AA->getModRefInfo(SI, LoadLoc))) UseMemMove = true; uint64_t Size = DL.getTypeStoreSize(T); @@ -733,13 +724,10 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " << *M << "\n"); - if (MSSAU) { - auto *LastDef = - cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); - auto *NewAccess = - MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); - MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); - } + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); eraseInstruction(SI); eraseInstruction(LI); @@ -755,38 +743,21 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // happen to be using a load-store pair to implement it, rather than // a memcpy. CallInst *C = nullptr; - if (EnableMemorySSA) { - if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>( - MSSA->getWalker()->getClobberingMemoryAccess(LI))) { - // The load most post-dom the call. Limit to the same block for now. - // TODO: Support non-local call-slot optimization? - if (LoadClobber->getBlock() == SI->getParent()) - C = dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst()); - } - } else { - MemDepResult ldep = MD->getDependency(LI); - if (ldep.isClobber() && !isa<MemCpyInst>(ldep.getInst())) - C = dyn_cast<CallInst>(ldep.getInst()); + if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>( + MSSA->getWalker()->getClobberingMemoryAccess(LI))) { + // The load most post-dom the call. Limit to the same block for now. + // TODO: Support non-local call-slot optimization? + if (LoadClobber->getBlock() == SI->getParent()) + C = dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst()); } if (C) { // Check that nothing touches the dest of the "copy" between // the call and the store. MemoryLocation StoreLoc = MemoryLocation::get(SI); - if (EnableMemorySSA) { - if (accessedBetween(*AA, StoreLoc, MSSA->getMemoryAccess(C), - MSSA->getMemoryAccess(SI))) - C = nullptr; - } else { - for (BasicBlock::iterator I = --SI->getIterator(), - E = C->getIterator(); - I != E; --I) { - if (isModOrRefSet(AA->getModRefInfo(&*I, StoreLoc))) { - C = nullptr; - break; - } - } - } + if (accessedBetween(*AA, StoreLoc, MSSA->getMemoryAccess(C), + MSSA->getMemoryAccess(SI))) + C = nullptr; } if (C) { @@ -805,6 +776,13 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { } } + // The following code creates memset intrinsics out of thin air. Don't do + // this if the corresponding libfunc is not available. + // TODO: We should really distinguish between libcall availability and + // our ability to introduce intrinsics. + if (!(TLI->has(LibFunc_memset) || EnableMemCpyOptWithoutLibcalls)) + return false; + // There are two cases that are interesting for this code to handle: memcpy // and memset. Right now we only handle memset. 
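In the MemCpyOpt hunks above, alias checks such as !AA->isNoAlias(StoreLoc, LoadLoc) are replaced by mod/ref queries, which additionally treat reads of constant memory as safe for memcpy, and new intrinsics are only introduced when the matching libcall exists (or -enable-memcpyopt-without-libcalls is set). A rough sketch of the memmove-versus-memcpy decision follows; the helper name promoteLoadStorePair and its parameters are invented for illustration.

// Assumption-laden sketch: decide between memmove and memcpy the way the
// hunks above do, by asking whether the store may modify the loaded bytes.
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static Instruction *promoteLoadStorePair(StoreInst *SI, LoadInst *LI,
                                         uint64_t Size, AAResults &AA) {
  IRBuilder<> Builder(SI);
  MemoryLocation LoadLoc = MemoryLocation::get(LI);
  // isModSet() is true when the store can write the bytes we read; the ranges
  // may then overlap and only memmove preserves the semantics. A load from
  // constant memory is never modified, so it safely becomes a memcpy.
  if (isModSet(AA.getModRefInfo(SI, LoadLoc)))
    return Builder.CreateMemMove(SI->getPointerOperand(), SI->getAlign(),
                                 LI->getPointerOperand(), LI->getAlign(), Size);
  return Builder.CreateMemCpy(SI->getPointerOperand(), SI->getAlign(),
                              LI->getPointerOperand(), LI->getAlign(), Size);
}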
@@ -831,13 +809,12 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); - if (MSSAU) { - assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI))); - auto *LastDef = - cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); - auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); - MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); - } + // The newly inserted memset is immediately overwritten by the original + // store, so we do not need to rename uses. + auto *StoreDef = cast<MemoryDef>(MSSA->getMemoryAccess(SI)); + auto *NewAccess = MSSAU->createMemoryAccessBefore( + M, StoreDef->getDefiningAccess(), StoreDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/false); eraseInstruction(SI); NumMemSetInfer++; @@ -1033,11 +1010,6 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, cast<AllocaInst>(cpyDest)->setAlignment(srcAlign); } - // Drop any cached information about the call, because we may have changed - // its dependence information by changing its parameter. - if (MD) - MD->removeInstruction(C); - // Update AA metadata // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be // handled here, but combineMetadata doesn't support them yet @@ -1086,28 +1058,19 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // // TODO: If the code between M and MDep is transparent to the destination "c", // then we could still perform the xform by moving M up to the first memcpy. - if (EnableMemorySSA) { - // TODO: It would be sufficient to check the MDep source up to the memcpy - // size of M, rather than MDep. - if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep), - MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M))) - return false; - } else { - // NOTE: This is conservative, it will stop on any read from the source loc, - // not just the defining memcpy. - MemDepResult SourceDep = - MD->getPointerDependencyFrom(MemoryLocation::getForSource(MDep), false, - M->getIterator(), M->getParent()); - if (!SourceDep.isClobber() || SourceDep.getInst() != MDep) - return false; - } + // TODO: It would be sufficient to check the MDep source up to the memcpy + // size of M, rather than MDep. + if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep), + MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M))) + return false; // If the dest of the second might alias the source of the first, then the - // source and dest might overlap. We still want to eliminate the intermediate - // value, but we have to generate a memmove instead of memcpy. + // source and dest might overlap. In addition, if the source of the first + // points to constant memory, they won't overlap by definition. Otherwise, we + // still want to eliminate the intermediate value, but we have to generate a + // memmove instead of memcpy. bool UseMemMove = false; - if (!AA->isNoAlias(MemoryLocation::getForDest(M), - MemoryLocation::getForSource(MDep))) + if (isModSet(AA->getModRefInfo(M, MemoryLocation::getForSource(MDep)))) UseMemMove = true; // If all checks passed, then we can transform M. 
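With MemoryDependenceAnalysis gone from this pass, every spot above and below that materializes a new memset or memcpy also has to register it in MemorySSA. A compressed, assumed-correct illustration of that bookkeeping follows; registerNewDef and its parameters are invented for this sketch, not taken from the commit.

// Sketch of the MemorySSAUpdater pattern used throughout the hunks: anchor
// the new instruction's MemoryDef next to an existing MemoryDef and let the
// updater wire up defining accesses and, if requested, rename uses.
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"

using namespace llvm;

static void registerNewDef(Instruction *NewMemInst, Instruction *Anchor,
                           MemorySSA &MSSA, MemorySSAUpdater &MSSAU) {
  // The anchor instruction must already be modelled as a MemoryDef.
  auto *AnchorDef = cast<MemoryDef>(MSSA.getMemoryAccess(Anchor));
  // Place the new access just before the anchor, defined by whatever defined
  // the anchor, then insert it into the SSA form.
  MemoryUseOrDef *NewAccess = MSSAU.createMemoryAccessBefore(
      NewMemInst, AnchorDef->getDefiningAccess(), AnchorDef);
  MSSAU.insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true);
}

Whether RenameUses is needed depends on whether existing uses can observe the new definition; the memset-before-store hunk above passes false because the original store immediately clobbers the freshly inserted memset.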
@@ -1134,12 +1097,10 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, MDep->getRawSource(), MDep->getSourceAlign(), M->getLength(), M->isVolatile()); - if (MSSAU) { - assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M))); - auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); - auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); - MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); - } + assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M))); + auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); // Remove the instruction we're replacing. eraseInstruction(M); @@ -1169,30 +1130,16 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, // Check that src and dst of the memcpy aren't the same. While memcpy // operands cannot partially overlap, exact equality is allowed. - if (!AA->isNoAlias(MemoryLocation(MemCpy->getSource(), - LocationSize::precise(1)), - MemoryLocation(MemCpy->getDest(), - LocationSize::precise(1)))) + if (isModSet(AA->getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy)))) return false; - if (EnableMemorySSA) { - // We know that dst up to src_size is not written. We now need to make sure - // that dst up to dst_size is not accessed. (If we did not move the memset, - // checking for reads would be sufficient.) - if (accessedBetween(*AA, MemoryLocation::getForDest(MemSet), - MSSA->getMemoryAccess(MemSet), - MSSA->getMemoryAccess(MemCpy))) { - return false; - } - } else { - // We have already checked that dst up to src_size is not accessed. We - // need to make sure that there are no accesses up to dst_size either. - MemDepResult DstDepInfo = MD->getPointerDependencyFrom( - MemoryLocation::getForDest(MemSet), false, MemCpy->getIterator(), - MemCpy->getParent()); - if (DstDepInfo.getInst() != MemSet) - return false; - } + // We know that dst up to src_size is not written. We now need to make sure + // that dst up to dst_size is not accessed. (If we did not move the memset, + // checking for reads would be sufficient.) + if (accessedBetween(*AA, MemoryLocation::getForDest(MemSet), + MSSA->getMemoryAccess(MemSet), + MSSA->getMemoryAccess(MemCpy))) + return false; // Use the same i8* dest as the memcpy, killing the memset dest if different. Value *Dest = MemCpy->getRawDest(); @@ -1242,18 +1189,16 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, SrcSize), MemSet->getOperand(1), MemsetLen, MaybeAlign(Align)); - if (MSSAU) { - assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) && - "MemCpy must be a MemoryDef"); - // The new memset is inserted after the memcpy, but it is known that its - // defining access is the memset about to be removed which immediately - // precedes the memcpy. - auto *LastDef = - cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); - auto *NewAccess = MSSAU->createMemoryAccessBefore( - NewMemSet, LastDef->getDefiningAccess(), LastDef); - MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); - } + assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) && + "MemCpy must be a MemoryDef"); + // The new memset is inserted after the memcpy, but it is known that its + // defining access is the memset about to be removed which immediately + // precedes the memcpy. 
+ auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); + auto *NewAccess = MSSAU->createMemoryAccessBefore( + NewMemSet, LastDef->getDefiningAccess(), LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); eraseInstruction(MemSet); return true; @@ -1261,23 +1206,8 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, /// Determine whether the instruction has undefined content for the given Size, /// either because it was freshly alloca'd or started its lifetime. -static bool hasUndefContents(Instruction *I, Value *Size) { - if (isa<AllocaInst>(I)) - return true; - - if (ConstantInt *CSize = dyn_cast<ConstantInt>(Size)) { - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) - if (II->getIntrinsicID() == Intrinsic::lifetime_start) - if (ConstantInt *LTSize = dyn_cast<ConstantInt>(II->getArgOperand(0))) - if (LTSize->getZExtValue() >= CSize->getZExtValue()) - return true; - } - - return false; -} - -static bool hasUndefContentsMSSA(MemorySSA *MSSA, AliasAnalysis *AA, Value *V, - MemoryDef *Def, Value *Size) { +static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V, + MemoryDef *Def, Value *Size) { if (MSSA->isLiveOnEntryDef(Def)) return isa<AllocaInst>(getUnderlyingObject(V)); @@ -1351,19 +1281,12 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, // easily represent this location, we use the full 0..CopySize range. MemoryLocation MemCpyLoc = MemoryLocation::getForSource(MemCpy); bool CanReduceSize = false; - if (EnableMemorySSA) { - MemoryUseOrDef *MemSetAccess = MSSA->getMemoryAccess(MemSet); - MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( - MemSetAccess->getDefiningAccess(), MemCpyLoc); - if (auto *MD = dyn_cast<MemoryDef>(Clobber)) - if (hasUndefContentsMSSA(MSSA, AA, MemCpy->getSource(), MD, CopySize)) - CanReduceSize = true; - } else { - MemDepResult DepInfo = MD->getPointerDependencyFrom( - MemCpyLoc, true, MemSet->getIterator(), MemSet->getParent()); - if (DepInfo.isDef() && hasUndefContents(DepInfo.getInst(), CopySize)) + MemoryUseOrDef *MemSetAccess = MSSA->getMemoryAccess(MemSet); + MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( + MemSetAccess->getDefiningAccess(), MemCpyLoc); + if (auto *MD = dyn_cast<MemoryDef>(Clobber)) + if (hasUndefContents(MSSA, AA, MemCpy->getSource(), MD, CopySize)) CanReduceSize = true; - } if (!CanReduceSize) return false; @@ -1375,12 +1298,10 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, Instruction *NewM = Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), CopySize, MaybeAlign(MemCpy->getDestAlignment())); - if (MSSAU) { - auto *LastDef = - cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); - auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); - MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); - } + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); return true; } @@ -1410,151 +1331,90 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { Instruction *NewM = Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), MaybeAlign(M->getDestAlignment()), false); - if (MSSAU) { - auto *LastDef = - cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); - auto *NewAccess = - 
MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); - MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); - } + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); + auto *NewAccess = + MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); eraseInstruction(M); ++NumCpyToSet; return true; } - if (EnableMemorySSA) { - MemoryUseOrDef *MA = MSSA->getMemoryAccess(M); - MemoryAccess *AnyClobber = MSSA->getWalker()->getClobberingMemoryAccess(MA); - MemoryLocation DestLoc = MemoryLocation::getForDest(M); - const MemoryAccess *DestClobber = - MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc); - - // Try to turn a partially redundant memset + memcpy into - // memcpy + smaller memset. We don't need the memcpy size for this. - // The memcpy most post-dom the memset, so limit this to the same basic - // block. A non-local generalization is likely not worthwhile. - if (auto *MD = dyn_cast<MemoryDef>(DestClobber)) - if (auto *MDep = dyn_cast_or_null<MemSetInst>(MD->getMemoryInst())) - if (DestClobber->getBlock() == M->getParent()) - if (processMemSetMemCpyDependence(M, MDep)) - return true; - - MemoryAccess *SrcClobber = MSSA->getWalker()->getClobberingMemoryAccess( - AnyClobber, MemoryLocation::getForSource(M)); - - // There are four possible optimizations we can do for memcpy: - // a) memcpy-memcpy xform which exposes redundance for DSE. - // b) call-memcpy xform for return slot optimization. - // c) memcpy from freshly alloca'd space or space that has just started - // its lifetime copies undefined data, and we can therefore eliminate - // the memcpy in favor of the data that was already at the destination. - // d) memcpy from a just-memset'd source can be turned into memset. - if (auto *MD = dyn_cast<MemoryDef>(SrcClobber)) { - if (Instruction *MI = MD->getMemoryInst()) { - if (ConstantInt *CopySize = dyn_cast<ConstantInt>(M->getLength())) { - if (auto *C = dyn_cast<CallInst>(MI)) { - // The memcpy must post-dom the call. Limit to the same block for - // now. Additionally, we need to ensure that there are no accesses - // to dest between the call and the memcpy. Accesses to src will be - // checked by performCallSlotOptzn(). - // TODO: Support non-local call-slot optimization? - if (C->getParent() == M->getParent() && - !accessedBetween(*AA, DestLoc, MD, MA)) { - // FIXME: Can we pass in either of dest/src alignment here instead - // of conservatively taking the minimum? 
- Align Alignment = std::min(M->getDestAlign().valueOrOne(), - M->getSourceAlign().valueOrOne()); - if (performCallSlotOptzn( - M, M, M->getDest(), M->getSource(), - TypeSize::getFixed(CopySize->getZExtValue()), Alignment, - C)) { - LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n" - << " call: " << *C << "\n" - << " memcpy: " << *M << "\n"); - eraseInstruction(M); - ++NumMemCpyInstr; - return true; - } - } - } - } - if (auto *MDep = dyn_cast<MemCpyInst>(MI)) - return processMemCpyMemCpyDependence(M, MDep); - if (auto *MDep = dyn_cast<MemSetInst>(MI)) { - if (performMemCpyToMemSetOptzn(M, MDep)) { - LLVM_DEBUG(dbgs() << "Converted memcpy to memset\n"); - eraseInstruction(M); - ++NumCpyToSet; - return true; - } - } - } - - if (hasUndefContentsMSSA(MSSA, AA, M->getSource(), MD, M->getLength())) { - LLVM_DEBUG(dbgs() << "Removed memcpy from undef\n"); - eraseInstruction(M); - ++NumMemCpyInstr; - return true; - } - } - } else { - MemDepResult DepInfo = MD->getDependency(M); - - // Try to turn a partially redundant memset + memcpy into - // memcpy + smaller memset. We don't need the memcpy size for this. - if (DepInfo.isClobber()) - if (MemSetInst *MDep = dyn_cast<MemSetInst>(DepInfo.getInst())) + MemoryUseOrDef *MA = MSSA->getMemoryAccess(M); + MemoryAccess *AnyClobber = MSSA->getWalker()->getClobberingMemoryAccess(MA); + MemoryLocation DestLoc = MemoryLocation::getForDest(M); + const MemoryAccess *DestClobber = + MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc); + + // Try to turn a partially redundant memset + memcpy into + // memcpy + smaller memset. We don't need the memcpy size for this. + // The memcpy most post-dom the memset, so limit this to the same basic + // block. A non-local generalization is likely not worthwhile. + if (auto *MD = dyn_cast<MemoryDef>(DestClobber)) + if (auto *MDep = dyn_cast_or_null<MemSetInst>(MD->getMemoryInst())) + if (DestClobber->getBlock() == M->getParent()) if (processMemSetMemCpyDependence(M, MDep)) return true; - // There are four possible optimizations we can do for memcpy: - // a) memcpy-memcpy xform which exposes redundance for DSE. - // b) call-memcpy xform for return slot optimization. - // c) memcpy from freshly alloca'd space or space that has just started - // its lifetime copies undefined data, and we can therefore eliminate - // the memcpy in favor of the data that was already at the destination. - // d) memcpy from a just-memset'd source can be turned into memset. - if (ConstantInt *CopySize = dyn_cast<ConstantInt>(M->getLength())) { - if (DepInfo.isClobber()) { - if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) { - // FIXME: Can we pass in either of dest/src alignment here instead - // of conservatively taking the minimum? - Align Alignment = std::min(M->getDestAlign().valueOrOne(), - M->getSourceAlign().valueOrOne()); - if (performCallSlotOptzn(M, M, M->getDest(), M->getSource(), - TypeSize::getFixed(CopySize->getZExtValue()), - Alignment, C)) { - eraseInstruction(M); - ++NumMemCpyInstr; - return true; + MemoryAccess *SrcClobber = MSSA->getWalker()->getClobberingMemoryAccess( + AnyClobber, MemoryLocation::getForSource(M)); + + // There are four possible optimizations we can do for memcpy: + // a) memcpy-memcpy xform which exposes redundance for DSE. + // b) call-memcpy xform for return slot optimization. 
+ // c) memcpy from freshly alloca'd space or space that has just started + // its lifetime copies undefined data, and we can therefore eliminate + // the memcpy in favor of the data that was already at the destination. + // d) memcpy from a just-memset'd source can be turned into memset. + if (auto *MD = dyn_cast<MemoryDef>(SrcClobber)) { + if (Instruction *MI = MD->getMemoryInst()) { + if (ConstantInt *CopySize = dyn_cast<ConstantInt>(M->getLength())) { + if (auto *C = dyn_cast<CallInst>(MI)) { + // The memcpy must post-dom the call. Limit to the same block for + // now. Additionally, we need to ensure that there are no accesses + // to dest between the call and the memcpy. Accesses to src will be + // checked by performCallSlotOptzn(). + // TODO: Support non-local call-slot optimization? + if (C->getParent() == M->getParent() && + !accessedBetween(*AA, DestLoc, MD, MA)) { + // FIXME: Can we pass in either of dest/src alignment here instead + // of conservatively taking the minimum? + Align Alignment = std::min(M->getDestAlign().valueOrOne(), + M->getSourceAlign().valueOrOne()); + if (performCallSlotOptzn( + M, M, M->getDest(), M->getSource(), + TypeSize::getFixed(CopySize->getZExtValue()), Alignment, + C)) { + LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n" + << " call: " << *C << "\n" + << " memcpy: " << *M << "\n"); + eraseInstruction(M); + ++NumMemCpyInstr; + return true; + } } } } - } - - MemoryLocation SrcLoc = MemoryLocation::getForSource(M); - MemDepResult SrcDepInfo = MD->getPointerDependencyFrom( - SrcLoc, true, M->getIterator(), M->getParent()); - - if (SrcDepInfo.isClobber()) { - if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst())) + if (auto *MDep = dyn_cast<MemCpyInst>(MI)) return processMemCpyMemCpyDependence(M, MDep); - } else if (SrcDepInfo.isDef()) { - if (hasUndefContents(SrcDepInfo.getInst(), M->getLength())) { - eraseInstruction(M); - ++NumMemCpyInstr; - return true; - } - } - - if (SrcDepInfo.isClobber()) - if (MemSetInst *MDep = dyn_cast<MemSetInst>(SrcDepInfo.getInst())) + if (auto *MDep = dyn_cast<MemSetInst>(MI)) { if (performMemCpyToMemSetOptzn(M, MDep)) { + LLVM_DEBUG(dbgs() << "Converted memcpy to memset\n"); eraseInstruction(M); ++NumCpyToSet; return true; } + } + } + + if (hasUndefContents(MSSA, AA, M->getSource(), MD, M->getLength())) { + LLVM_DEBUG(dbgs() << "Removed memcpy from undef\n"); + eraseInstruction(M); + ++NumMemCpyInstr; + return true; + } } return false; @@ -1563,12 +1423,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { /// Transforms memmove calls to memcpy calls when the src/dst are guaranteed /// not to alias. bool MemCpyOptPass::processMemMove(MemMoveInst *M) { - if (!TLI->has(LibFunc_memmove)) - return false; - - // See if the pointers alias. - if (!AA->isNoAlias(MemoryLocation::getForDest(M), - MemoryLocation::getForSource(M))) + // See if the source could be modified by this memmove potentially. + if (isModSet(AA->getModRefInfo(M, MemoryLocation::getForSource(M)))) return false; LLVM_DEBUG(dbgs() << "MemCpyOptPass: Optimizing memmove -> memcpy: " << *M @@ -1584,11 +1440,6 @@ bool MemCpyOptPass::processMemMove(MemMoveInst *M) { // For MemorySSA nothing really changes (except that memcpy may imply stricter // aliasing guarantees). - // MemDep may have over conservative information about this instruction, just - // conservatively flush it from the cache. 
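
The rewritten processMemMove above now asks alias analysis whether the memmove could modify its own source (isModSet on the source location) instead of requiring a memmove libcall and a plain no-alias check. A minimal C++ sketch of the kind of call that qualifies; the function and variable names are illustrative, not from the commit:

    #include <cstring>

    // The local buffer never escapes, so alias analysis can prove that `Dst`
    // cannot point into `Src`; the memmove therefore cannot modify its own
    // source and may be rewritten as the cheaper memcpy.
    void copyOut(char *Dst) {
      char Src[16] = "payload";
      std::memmove(Dst, Src, sizeof(Src));
    }
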
- if (MD) - MD->removeInstruction(M); - ++NumMoveToCpy; return true; } @@ -1601,22 +1452,14 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { Type *ByValTy = CB.getParamByValType(ArgNo); TypeSize ByValSize = DL.getTypeAllocSize(ByValTy); MemoryLocation Loc(ByValArg, LocationSize::precise(ByValSize)); + MemoryUseOrDef *CallAccess = MSSA->getMemoryAccess(&CB); + if (!CallAccess) + return false; MemCpyInst *MDep = nullptr; - if (EnableMemorySSA) { - MemoryUseOrDef *CallAccess = MSSA->getMemoryAccess(&CB); - if (!CallAccess) - return false; - MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( - CallAccess->getDefiningAccess(), Loc); - if (auto *MD = dyn_cast<MemoryDef>(Clobber)) - MDep = dyn_cast_or_null<MemCpyInst>(MD->getMemoryInst()); - } else { - MemDepResult DepInfo = MD->getPointerDependencyFrom( - Loc, true, CB.getIterator(), CB.getParent()); - if (!DepInfo.isClobber()) - return false; - MDep = dyn_cast<MemCpyInst>(DepInfo.getInst()); - } + MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( + CallAccess->getDefiningAccess(), Loc); + if (auto *MD = dyn_cast<MemoryDef>(Clobber)) + MDep = dyn_cast_or_null<MemCpyInst>(MD->getMemoryInst()); // If the byval argument isn't fed by a memcpy, ignore it. If it is fed by // a memcpy, see if we can byval from the source of the memcpy instead of the @@ -1655,19 +1498,9 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { // *b = 42; // foo(*a) // It would be invalid to transform the second memcpy into foo(*b). - if (EnableMemorySSA) { - if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep), - MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB))) - return false; - } else { - // NOTE: This is conservative, it will stop on any read from the source loc, - // not just the defining memcpy. - MemDepResult SourceDep = MD->getPointerDependencyFrom( - MemoryLocation::getForSource(MDep), false, - CB.getIterator(), MDep->getParent()); - if (!SourceDep.isClobber() || SourceDep.getInst() != MDep) - return false; - } + if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep), + MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB))) + return false; Value *TmpCast = MDep->getSource(); if (MDep->getSource()->getType() != ByValArg->getType()) { @@ -1734,47 +1567,33 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) { } PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { - auto *MD = !EnableMemorySSA ? &AM.getResult<MemoryDependenceAnalysis>(F) - : AM.getCachedResult<MemoryDependenceAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto *AA = &AM.getResult<AAManager>(F); auto *AC = &AM.getResult<AssumptionAnalysis>(F); auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); - auto *MSSA = EnableMemorySSA ? &AM.getResult<MemorySSAAnalysis>(F) - : AM.getCachedResult<MemorySSAAnalysis>(F); + auto *MSSA = &AM.getResult<MemorySSAAnalysis>(F); - bool MadeChange = - runImpl(F, MD, &TLI, AA, AC, DT, MSSA ? 
&MSSA->getMSSA() : nullptr); + bool MadeChange = runImpl(F, &TLI, AA, AC, DT, &MSSA->getMSSA()); if (!MadeChange) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); - if (MD) - PA.preserve<MemoryDependenceAnalysis>(); - if (MSSA) - PA.preserve<MemorySSAAnalysis>(); + PA.preserve<MemorySSAAnalysis>(); return PA; } -bool MemCpyOptPass::runImpl(Function &F, MemoryDependenceResults *MD_, - TargetLibraryInfo *TLI_, AliasAnalysis *AA_, - AssumptionCache *AC_, DominatorTree *DT_, - MemorySSA *MSSA_) { +bool MemCpyOptPass::runImpl(Function &F, TargetLibraryInfo *TLI_, + AliasAnalysis *AA_, AssumptionCache *AC_, + DominatorTree *DT_, MemorySSA *MSSA_) { bool MadeChange = false; - MD = MD_; TLI = TLI_; AA = AA_; AC = AC_; DT = DT_; MSSA = MSSA_; MemorySSAUpdater MSSAU_(MSSA_); - MSSAU = MSSA_ ? &MSSAU_ : nullptr; - // If we don't have at least memset and memcpy, there is little point of doing - // anything here. These are required by a freestanding implementation, so if - // even they are disabled, there is no point in trying hard. - if (!TLI->has(LibFunc_memset) || !TLI->has(LibFunc_memcpy)) - return false; + MSSAU = &MSSAU_; while (true) { if (!iterateOnFunction(F)) @@ -1782,10 +1601,9 @@ bool MemCpyOptPass::runImpl(Function &F, MemoryDependenceResults *MD_, MadeChange = true; } - if (MSSA_ && VerifyMemorySSA) + if (VerifyMemorySSA) MSSA_->verifyMemorySSA(); - MD = nullptr; return MadeChange; } @@ -1794,17 +1612,11 @@ bool MemCpyOptLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) return false; - auto *MDWP = !EnableMemorySSA - ? &getAnalysis<MemoryDependenceWrapperPass>() - : getAnalysisIfAvailable<MemoryDependenceWrapperPass>(); auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *MSSAWP = EnableMemorySSA - ? &getAnalysis<MemorySSAWrapperPass>() - : getAnalysisIfAvailable<MemorySSAWrapperPass>(); + auto *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - return Impl.runImpl(F, MDWP ? & MDWP->getMemDep() : nullptr, TLI, AA, AC, DT, - MSSAWP ? &MSSAWP->getMSSA() : nullptr); + return Impl.runImpl(F, TLI, AA, AC, DT, MSSA); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp index f13f24ad2027..aac0deea5be3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -66,15 +66,6 @@ namespace { #define DEBUG_TYPE "mergeicmps" -// Returns true if the instruction is a simple load or a simple store -static bool isSimpleLoadOrStore(const Instruction *I) { - if (const LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->isSimple(); - if (const StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->isSimple(); - return false; -} - // A BCE atom "Binary Compare Expression Atom" represents an integer load // that is a constant offset from a base value, e.g. `a` or `o.c` in the example // at the top. 
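
The BCE-atom comment above is the core idea of this pass: each block compares one integer loaded at a constant offset from each of two base pointers, and adjacent blocks whose offsets line up can be folded into one wide comparison. A hypothetical C++ source pattern that produces such a chain (the struct and function are invented for illustration, not part of the commit):

    struct Triple { int A; int B; int C; };

    // Three 4-byte equality compares at offsets 0, 4 and 8 from both bases
    // form a contiguous chain, so the whole function can be folded into a
    // single 12-byte comparison, roughly: memcmp(&X, &Y, 12) == 0.
    bool equalTriples(const Triple &X, const Triple &Y) {
      return X.A == Y.A && X.B == Y.B && X.C == Y.C;
    }
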
@@ -154,6 +145,10 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { return {}; } Value *const Addr = LoadI->getOperand(0); + if (Addr->getType()->getPointerAddressSpace() != 0) { + LLVM_DEBUG(dbgs() << "from non-zero AddressSpace\n"); + return {}; + } auto *const GEP = dyn_cast<GetElementPtrInst>(Addr); if (!GEP) return {}; @@ -234,6 +229,8 @@ class BCECmpBlock { InstructionSet BlockInsts; // The block requires splitting. bool RequireSplit = false; + // Original order of this block in the chain. + unsigned OrigOrder = 0; private: BCECmp Cmp; @@ -244,14 +241,13 @@ bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, // If this instruction may clobber the loads and is in middle of the BCE cmp // block instructions, then bail for now. if (Inst->mayWriteToMemory()) { - // Bail if this is not a simple load or store - if (!isSimpleLoadOrStore(Inst)) - return false; - // Disallow stores that might alias the BCE operands - MemoryLocation LLoc = MemoryLocation::get(Cmp.Lhs.LoadI); - MemoryLocation RLoc = MemoryLocation::get(Cmp.Rhs.LoadI); - if (isModSet(AA.getModRefInfo(Inst, LLoc)) || - isModSet(AA.getModRefInfo(Inst, RLoc))) + auto MayClobber = [&](LoadInst *LI) { + // If a potentially clobbering instruction comes before the load, + // we can still safely sink the load. + return !Inst->comesBefore(LI) && + isModSet(AA.getModRefInfo(Inst, MemoryLocation::get(LI))); + }; + if (MayClobber(Cmp.Lhs.LoadI) || MayClobber(Cmp.Rhs.LoadI)) return false; } // Make sure this instruction does not use any of the BCE cmp block @@ -386,39 +382,83 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, << Comparison.Rhs().BaseId << " + " << Comparison.Rhs().Offset << "\n"); LLVM_DEBUG(dbgs() << "\n"); + Comparison.OrigOrder = Comparisons.size(); Comparisons.push_back(std::move(Comparison)); } // A chain of comparisons. class BCECmpChain { - public: - BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, - AliasAnalysis &AA); - - int size() const { return Comparisons_.size(); } +public: + using ContiguousBlocks = std::vector<BCECmpBlock>; -#ifdef MERGEICMPS_DOT_ON - void dump() const; -#endif // MERGEICMPS_DOT_ON + BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, + AliasAnalysis &AA); bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, DomTreeUpdater &DTU); -private: - static bool IsContiguous(const BCECmpBlock &First, - const BCECmpBlock &Second) { - return First.Lhs().BaseId == Second.Lhs().BaseId && - First.Rhs().BaseId == Second.Rhs().BaseId && - First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && - First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; + bool atLeastOneMerged() const { + return any_of(MergedBlocks_, + [](const auto &Blocks) { return Blocks.size() > 1; }); } +private: PHINode &Phi_; - std::vector<BCECmpBlock> Comparisons_; + // The list of all blocks in the chain, grouped by contiguity. 
+ std::vector<ContiguousBlocks> MergedBlocks_; // The original entry block (before sorting); BasicBlock *EntryBlock_; }; +static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { + return First.Lhs().BaseId == Second.Lhs().BaseId && + First.Rhs().BaseId == Second.Rhs().BaseId && + First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && + First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; +} + +static unsigned getMinOrigOrder(const BCECmpChain::ContiguousBlocks &Blocks) { + unsigned MinOrigOrder = std::numeric_limits<unsigned>::max(); + for (const BCECmpBlock &Block : Blocks) + MinOrigOrder = std::min(MinOrigOrder, Block.OrigOrder); + return MinOrigOrder; +} + +/// Given a chain of comparison blocks, groups the blocks into contiguous +/// ranges that can be merged together into a single comparison. +static std::vector<BCECmpChain::ContiguousBlocks> +mergeBlocks(std::vector<BCECmpBlock> &&Blocks) { + std::vector<BCECmpChain::ContiguousBlocks> MergedBlocks; + + // Sort to detect continuous offsets. + llvm::sort(Blocks, + [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { + return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) < + std::tie(RhsBlock.Lhs(), RhsBlock.Rhs()); + }); + + BCECmpChain::ContiguousBlocks *LastMergedBlock = nullptr; + for (BCECmpBlock &Block : Blocks) { + if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) { + MergedBlocks.emplace_back(); + LastMergedBlock = &MergedBlocks.back(); + } else { + LLVM_DEBUG(dbgs() << "Merging block " << Block.BB->getName() << " into " + << LastMergedBlock->back().BB->getName() << "\n"); + } + LastMergedBlock->push_back(std::move(Block)); + } + + // While we allow reordering for merging, do not reorder unmerged comparisons. + // Doing so may introduce branch on poison. + llvm::sort(MergedBlocks, [](const BCECmpChain::ContiguousBlocks &LhsBlocks, + const BCECmpChain::ContiguousBlocks &RhsBlocks) { + return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks); + }); + + return MergedBlocks; +} + BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, AliasAnalysis &AA) : Phi_(Phi) { @@ -498,47 +538,9 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, return; } EntryBlock_ = Comparisons[0].BB; - Comparisons_ = std::move(Comparisons); -#ifdef MERGEICMPS_DOT_ON - errs() << "BEFORE REORDERING:\n\n"; - dump(); -#endif // MERGEICMPS_DOT_ON - // Reorder blocks by LHS. We can do that without changing the - // semantics because we are only accessing dereferencable memory. 
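
The mergeBlocks helper above sorts the compares and then groups runs that satisfy areContiguous. As a concrete illustration with invented numbers: 4-byte compares of (a+0, b+0), (a+8, b+8) and (a+4, b+4) sort to offsets 0, 4, 8; each compare ends exactly where the next begins on both sides, so all three land in one ContiguousBlocks group, while a following compare of (a+12, b+16) would start a new group because the right-hand offsets no longer line up. A simplified stand-in for that test (BaseId equality omitted for brevity; names are illustrative):

    #include <cstdint>

    // A later compare extends the current group only if it starts exactly
    // where the previous one ends, on both the left- and right-hand bases.
    struct Atom { uint64_t Offset; };
    struct Cmp { Atom Lhs, Rhs; uint64_t SizeBytes; };

    static bool extendsGroup(const Cmp &Prev, const Cmp &Next) {
      return Prev.Lhs.Offset + Prev.SizeBytes == Next.Lhs.Offset &&
             Prev.Rhs.Offset + Prev.SizeBytes == Next.Rhs.Offset;
    }
    // extendsGroup({{0}, {0}, 4}, {{4}, {4}, 4})  -> true  (same group)
    // extendsGroup({{4}, {4}, 4}, {{8}, {12}, 4}) -> false (right side jumps)
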
- llvm::sort(Comparisons_, - [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { - return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) < - std::tie(RhsBlock.Lhs(), RhsBlock.Rhs()); - }); -#ifdef MERGEICMPS_DOT_ON - errs() << "AFTER REORDERING:\n\n"; - dump(); -#endif // MERGEICMPS_DOT_ON + MergedBlocks_ = mergeBlocks(std::move(Comparisons)); } -#ifdef MERGEICMPS_DOT_ON -void BCECmpChain::dump() const { - errs() << "digraph dag {\n"; - errs() << " graph [bgcolor=transparent];\n"; - errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n"; - errs() << " edge [color=black];\n"; - for (size_t I = 0; I < Comparisons_.size(); ++I) { - const auto &Comparison = Comparisons_[I]; - errs() << " \"" << I << "\" [label=\"%" - << Comparison.Lhs().Base()->getName() << " + " - << Comparison.Lhs().Offset << " == %" - << Comparison.Rhs().Base()->getName() << " + " - << Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8) - << " bytes)\"];\n"; - const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB); - if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n"; - errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n"; - } - errs() << " \"Phi\" [label=\"Phi\"];\n"; - errs() << "}\n\n"; -} -#endif // MERGEICMPS_DOT_ON - namespace { // A class to compute the name of a set of merged basic blocks. @@ -661,47 +663,18 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, DomTreeUpdater &DTU) { - assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain"); - // First pass to check if there is at least one merge. If not, we don't do - // anything and we keep analysis passes intact. - const auto AtLeastOneMerged = [this]() { - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) - return true; - } - return false; - }; - if (!AtLeastOneMerged()) - return false; - + assert(atLeastOneMerged() && "simplifying trivial BCECmpChain"); LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block " << EntryBlock_->getName() << "\n"); // Effectively merge blocks. We go in the reverse direction from the phi block // so that the next block is always available to branch to. - const auto mergeRange = [this, &TLI, &AA, &DTU](int I, int Num, - BasicBlock *InsertBefore, - BasicBlock *Next) { - return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num), - InsertBefore, Next, Phi_, TLI, AA, DTU); - }; - int NumMerged = 1; + BasicBlock *InsertBefore = EntryBlock_; BasicBlock *NextCmpBlock = Phi_.getParent(); - for (int I = static_cast<int>(Comparisons_.size()) - 2; I >= 0; --I) { - if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) { - LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName() - << " into " << Comparisons_[I + 1].BB->getName() - << "\n"); - ++NumMerged; - } else { - NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock, NextCmpBlock); - NumMerged = 1; - } + for (const auto &Blocks : reverse(MergedBlocks_)) { + InsertBefore = NextCmpBlock = mergeComparisons( + Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU); } - // Insert the entry block for the new chain before the old entry block. - // If the old entry block was the function entry, this ensures that the new - // entry can become the function entry. 
- NextCmpBlock = mergeRange(0, NumMerged, EntryBlock_, NextCmpBlock); // Replace the original cmp chain with the new cmp chain by pointing all // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp @@ -729,13 +702,16 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, // Delete merged blocks. This also removes incoming values in phi. SmallVector<BasicBlock *, 16> DeadBlocks; - for (auto &Cmp : Comparisons_) { - LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n"); - DeadBlocks.push_back(Cmp.BB); + for (const auto &Blocks : MergedBlocks_) { + for (const BCECmpBlock &Block : Blocks) { + LLVM_DEBUG(dbgs() << "Deleting merged block " << Block.BB->getName() + << "\n"); + DeadBlocks.push_back(Block.BB); + } } DeleteDeadBlocks(DeadBlocks, &DTU); - Comparisons_.clear(); + MergedBlocks_.clear(); return true; } @@ -835,8 +811,8 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA, if (Blocks.empty()) return false; BCECmpChain CmpChain(Blocks, Phi, AA); - if (CmpChain.size() < 2) { - LLVM_DEBUG(dbgs() << "skip: only one compare block\n"); + if (!CmpChain.atLeastOneMerged()) { + LLVM_DEBUG(dbgs() << "skip: nothing merged\n"); return false; } @@ -862,9 +838,9 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI, bool MadeChange = false; - for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { + for (BasicBlock &BB : llvm::drop_begin(F)) { // A Phi operation is always first in a basic block. - if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) + if (auto *const Phi = dyn_cast<PHINode>(&*BB.begin())) MadeChange |= processPhi(*Phi, TLI, AA, DTU); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 033fc168a67f..734532a6670c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -420,3 +420,12 @@ MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserveSet<CFGAnalyses>(); return PA; } + +void MergedLoadStoreMotionPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<MergedLoadStoreMotionPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + OS << (Options.SplitFooterBB ? 
"" : "no-") << "split-footer-bb"; + OS << ">"; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index ded5caf53b5a..6dca30d9876e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -282,8 +282,12 @@ NaryReassociatePass::matchAndReassociateMinOrMax(Instruction *I, m_Value(LHS), m_Value(RHS)); if (match(I, MinMaxMatcher)) { OrigSCEV = SE->getSCEV(I); - return dyn_cast_or_null<Instruction>( - tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS)); + if (auto *NewMinMax = dyn_cast_or_null<Instruction>( + tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS))) + return NewMinMax; + if (auto *NewMinMax = dyn_cast_or_null<Instruction>( + tryReassociateMinOrMax(I, MinMaxMatcher, RHS, LHS))) + return NewMinMax; } return nullptr; } @@ -596,58 +600,60 @@ Value *NaryReassociatePass::tryReassociateMinOrMax(Instruction *I, Value *LHS, Value *RHS) { Value *A = nullptr, *B = nullptr; MaxMinT m_MaxMin(m_Value(A), m_Value(B)); - for (unsigned int i = 0; i < 2; ++i) { - if (!LHS->hasNUsesOrMore(3) && match(LHS, m_MaxMin)) { - const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B); - const SCEV *RHSExpr = SE->getSCEV(RHS); - for (unsigned int j = 0; j < 2; ++j) { - if (j == 0) { - if (BExpr == RHSExpr) - continue; - // Transform 'I = (A op B) op RHS' to 'I = (A op RHS) op B' on the - // first iteration. - std::swap(BExpr, RHSExpr); - } else { - if (AExpr == RHSExpr) - continue; - // Transform 'I = (A op RHS) op B' 'I = (B op RHS) op A' on the second - // iteration. - std::swap(AExpr, RHSExpr); - } - - // The optimization is profitable only if LHS can be removed in the end. - // In other words LHS should be used (directly or indirectly) by I only. - if (llvm::any_of(LHS->users(), [&](auto *U) { - return U != I && !(U->hasOneUser() && *U->users().begin() == I); - })) - continue; - - SCEVExpander Expander(*SE, *DL, "nary-reassociate"); - SmallVector<const SCEV *, 2> Ops1{ BExpr, AExpr }; - const SCEVTypes SCEVType = convertToSCEVype(m_MaxMin); - const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1); - - Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I); - - if (!R1MinMax) - continue; - - LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax - << "\n"); - - R1Expr = SE->getUnknown(R1MinMax); - SmallVector<const SCEV *, 2> Ops2{ RHSExpr, R1Expr }; - const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2); - - Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I); - NewMinMax->setName(Twine(I->getName()).concat(".nary")); - - LLVM_DEBUG(dbgs() << "NARY: Deleting: " << *I << "\n" - << "NARY: Inserting: " << *NewMinMax << "\n"); - return NewMinMax; - } - } - std::swap(LHS, RHS); + + if (LHS->hasNUsesOrMore(3) || + // The optimization is profitable only if LHS can be removed in the end. + // In other words LHS should be used (directly or indirectly) by I only. 
+ llvm::any_of(LHS->users(), + [&](auto *U) { + return U != I && + !(U->hasOneUser() && *U->users().begin() == I); + }) || + !match(LHS, m_MaxMin)) + return nullptr; + + auto tryCombination = [&](Value *A, const SCEV *AExpr, Value *B, + const SCEV *BExpr, Value *C, + const SCEV *CExpr) -> Value * { + SmallVector<const SCEV *, 2> Ops1{BExpr, AExpr}; + const SCEVTypes SCEVType = convertToSCEVype(m_MaxMin); + const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1); + + Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I); + + if (!R1MinMax) + return nullptr; + + LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax << "\n"); + + SmallVector<const SCEV *, 2> Ops2{SE->getUnknown(C), + SE->getUnknown(R1MinMax)}; + const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2); + + SCEVExpander Expander(*SE, *DL, "nary-reassociate"); + Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I); + NewMinMax->setName(Twine(I->getName()).concat(".nary")); + + LLVM_DEBUG(dbgs() << "NARY: Deleting: " << *I << "\n" + << "NARY: Inserting: " << *NewMinMax << "\n"); + return NewMinMax; + }; + + const SCEV *AExpr = SE->getSCEV(A); + const SCEV *BExpr = SE->getSCEV(B); + const SCEV *RHSExpr = SE->getSCEV(RHS); + + if (BExpr != RHSExpr) { + // Try (A op RHS) op B + if (auto *NewMinMax = tryCombination(A, AExpr, RHS, RHSExpr, B, BExpr)) + return NewMinMax; + } + + if (AExpr != RHSExpr) { + // Try (RHS op B) op A + if (auto *NewMinMax = tryCombination(RHS, RHSExpr, B, BExpr, A, AExpr)) + return NewMinMax; } + return nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp index a137d13c6ea0..91215cd19e2b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -1194,9 +1194,10 @@ NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { SimplifyCastInst(CI->getOpcode(), E->getOperand(0), CI->getType(), SQ); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; - } else if (isa<GetElementPtrInst>(I)) { - Value *V = SimplifyGEPInst( - E->getType(), ArrayRef<Value *>(E->op_begin(), E->op_end()), SQ); + } else if (auto *GEPI = dyn_cast<GetElementPtrInst>(I)) { + Value *V = SimplifyGEPInst(GEPI->getSourceElementType(), + ArrayRef<Value *>(E->op_begin(), E->op_end()), + GEPI->isInBounds(), SQ); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } else if (AllConstant) { @@ -1818,7 +1819,7 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { // See if we know something about the comparison itself, like it is the target // of an assume. auto *CmpPI = PredInfo->getPredicateInfoFor(I); - if (dyn_cast_or_null<PredicateAssume>(CmpPI)) + if (isa_and_nonnull<PredicateAssume>(CmpPI)) return ExprResult::some( createConstantExpression(ConstantInt::getTrue(CI->getType()))); @@ -3606,7 +3607,7 @@ void NewGVN::convertClassToDFSOrdered( // Skip uses in unreachable blocks, as we're going // to delete them. 
- if (ReachableBlocks.count(IBlock) == 0) + if (!ReachableBlocks.contains(IBlock)) continue; DomTreeNode *DomNode = DT->getNode(IBlock); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 7872c553b412..44027ccd92ca 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -82,7 +82,7 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, // Add attribute "readnone" so that backend can use a native sqrt instruction // for this call. - Call->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); + Call->addFnAttr(Attribute::ReadNone); // Insert a FP compare instruction and use it as the CurrBB branch condition. Builder.SetInsertPoint(CurrBBTerm); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp index 888edc4d69a8..b0fb8daaba8f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -140,7 +140,7 @@ XorOpnd::XorOpnd(Value *V) { // view the operand as "V | 0" SymbolicPart = V; - ConstPart = APInt::getNullValue(V->getType()->getScalarSizeInBits()); + ConstPart = APInt::getZero(V->getType()->getScalarSizeInBits()); isOr = true; } @@ -1279,10 +1279,10 @@ static Value *OptimizeAndOrXor(unsigned Opcode, /// be returned. static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, const APInt &ConstOpnd) { - if (ConstOpnd.isNullValue()) + if (ConstOpnd.isZero()) return nullptr; - if (ConstOpnd.isAllOnesValue()) + if (ConstOpnd.isAllOnes()) return Opnd; Instruction *I = BinaryOperator::CreateAnd( @@ -1304,7 +1304,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, // = ((x | c1) ^ c1) ^ (c1 ^ c2) // = (x & ~c1) ^ (c1 ^ c2) // It is useful only when c1 == c2. - if (!Opnd1->isOrExpr() || Opnd1->getConstPart().isNullValue()) + if (!Opnd1->isOrExpr() || Opnd1->getConstPart().isZero()) return false; if (!Opnd1->getValue()->hasOneUse()) @@ -1361,7 +1361,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt C3((~C1) ^ C2); // Do not increase code size! - if (!C3.isNullValue() && !C3.isAllOnesValue()) { + if (!C3.isZero() && !C3.isAllOnes()) { int NewInstNum = ConstOpnd.getBoolValue() ? 1 : 2; if (NewInstNum > DeadInstNum) return false; @@ -1377,7 +1377,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt C3 = C1 ^ C2; // Do not increase code size - if (!C3.isNullValue() && !C3.isAllOnesValue()) { + if (!C3.isZero() && !C3.isAllOnes()) { int NewInstNum = ConstOpnd.getBoolValue() ? 
1 : 2; if (NewInstNum > DeadInstNum) return false; @@ -1468,8 +1468,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, Value *CV; // Step 3.1: Try simplifying "CurrOpnd ^ ConstOpnd" - if (!ConstOpnd.isNullValue() && - CombineXorOpnd(I, CurrOpnd, ConstOpnd, CV)) { + if (!ConstOpnd.isZero() && CombineXorOpnd(I, CurrOpnd, ConstOpnd, CV)) { Changed = true; if (CV) *CurrOpnd = XorOpnd(CV); @@ -1510,7 +1509,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, ValueEntry VE(getRank(O.getValue()), O.getValue()); Ops.push_back(VE); } - if (!ConstOpnd.isNullValue()) { + if (!ConstOpnd.isZero()) { Value *C = ConstantInt::get(Ty, ConstOpnd); ValueEntry VE(getRank(C), C); Ops.push_back(VE); @@ -1519,7 +1518,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, if (Sz == 1) return Ops.back().Op; if (Sz == 0) { - assert(ConstOpnd.isNullValue()); + assert(ConstOpnd.isZero()); return ConstantInt::get(Ty, ConstOpnd); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index bc0fecc972fc..2d3490b2d29e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -755,7 +755,7 @@ public: } bool operator==(const BDVState &Other) const { - return OriginalValue == OriginalValue && BaseValue == Other.BaseValue && + return OriginalValue == Other.OriginalValue && BaseValue == Other.BaseValue && Status == Other.Status; } @@ -910,7 +910,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { #ifndef NDEBUG VerifyStates(); LLVM_DEBUG(dbgs() << "States after initialization:\n"); - for (auto Pair : States) { + for (const auto &Pair : States) { LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n"); } #endif @@ -1002,7 +1002,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { #ifndef NDEBUG VerifyStates(); LLVM_DEBUG(dbgs() << "States after meet iteration:\n"); - for (auto Pair : States) { + for (const auto &Pair : States) { LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n"); } #endif @@ -1163,7 +1163,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // llvm::Value of the correct type (and still remain pure). // This will remove the need to add bitcasts. assert(Base->stripPointerCasts() == OldBase->stripPointerCasts() && - "Sanity -- findBaseOrBDV should be pure!"); + "findBaseOrBDV should be pure!"); #endif } Value *Base = BlockToValue[InBB]; @@ -1377,11 +1377,11 @@ static AttributeList legalizeCallAttributes(LLVMContext &Ctx, return AL; // Remove the readonly, readnone, and statepoint function attributes. - AttrBuilder FnAttrs = AL.getFnAttributes(); + AttrBuilder FnAttrs = AL.getFnAttrs(); for (auto Attr : FnAttrsToStrip) FnAttrs.removeAttribute(Attr); - for (Attribute A : AL.getFnAttributes()) { + for (Attribute A : AL.getFnAttrs()) { if (isStatepointDirectiveAttr(A)) FnAttrs.remove(A); } @@ -1533,9 +1533,8 @@ static StringRef getDeoptLowering(CallBase *Call) { // FIXME: Calls have a *really* confusing interface around attributes // with values. 
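
Several files in this commit replace index-based attribute queries (AttributeList::FunctionIndex) with the dedicated function-attribute helpers, as in getDeoptLowering here. A small sketch of the new style; the helper names come from the diff itself, while the wrapper functions are invented for illustration:

    #include "llvm/ADT/StringRef.h"
    #include "llvm/IR/Attributes.h"
    #include "llvm/IR/InstrTypes.h"

    using namespace llvm;

    // Query a string function attribute on a call site via hasFnAttr/getFnAttr
    // rather than indexing the attribute list with AttributeList::FunctionIndex.
    static StringRef lookupStringFnAttr(const CallBase &Call, StringRef Kind) {
      const AttributeList &AL = Call.getAttributes();
      if (AL.hasFnAttr(Kind))
        return AL.getFnAttr(Kind).getValueAsString();
      return StringRef();
    }

    // Adding a function-level attribute goes through addFnAttr on the call.
    static void markReadNone(CallBase &Call) {
      Call.addFnAttr(Attribute::ReadNone);
    }
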
const AttributeList &CSAS = Call->getAttributes(); - if (CSAS.hasAttribute(AttributeList::FunctionIndex, DeoptLowering)) - return CSAS.getAttribute(AttributeList::FunctionIndex, DeoptLowering) - .getValueAsString(); + if (CSAS.hasFnAttr(DeoptLowering)) + return CSAS.getFnAttr(DeoptLowering).getValueAsString(); Function *F = Call->getCalledFunction(); assert(F && F->hasFnAttribute(DeoptLowering)); return F->getFnAttribute(DeoptLowering).getValueAsString(); @@ -1801,7 +1800,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ CallInst *GCResult = Builder.CreateGCResult(Token, Call->getType(), Name); GCResult->setAttributes( AttributeList::get(GCResult->getContext(), AttributeList::ReturnIndex, - Call->getAttributes().getRetAttributes())); + Call->getAttributes().getRetAttrs())); // We cannot RAUW or delete CS.getInstruction() because it could be in the // live set of some other safepoint, in which case that safepoint's @@ -1855,7 +1854,7 @@ makeStatepointExplicit(DominatorTree &DT, CallBase *Call, // It receives iterator to the statepoint gc relocates and emits a store to the // assigned location (via allocaMap) for the each one of them. It adds the // visited values into the visitedLiveValues set, which we will later use them -// for sanity checking. +// for validation checking. static void insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, DenseMap<Value *, AllocaInst *> &AllocaMap, @@ -2454,7 +2453,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, SmallVectorImpl<CallBase *> &ToUpdate, DefiningValueMapTy &DVCache) { #ifndef NDEBUG - // sanity check the input + // Validate the input std::set<CallBase *> Uniqued; Uniqued.insert(ToUpdate.begin(), ToUpdate.end()); assert(Uniqued.size() == ToUpdate.size() && "no duplicates please!"); @@ -2620,9 +2619,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // we just grab that. llvm::append_range(Live, Info.StatepointToken->gc_args()); #ifndef NDEBUG - // Do some basic sanity checks on our liveness results before performing - // relocation. Relocation can and will turn mistakes in liveness results - // into non-sensical code which is must harder to debug. + // Do some basic validation checking on our liveness results before + // performing relocation. Relocation can and will turn mistakes in liveness + // results into non-sensical code which is must harder to debug. 
// TODO: It would be nice to test consistency as well assert(DT.isReachableFromEntry(Info.StatepointToken->getParent()) && "statepoint must be reachable or liveness is meaningless"); @@ -2641,7 +2640,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, unique_unsorted(Live); #ifndef NDEBUG - // sanity check + // Validation check for (auto *Ptr : Live) assert(isHandledGCPointerType(Ptr->getType()) && "must be a gc pointer type"); @@ -2656,18 +2655,19 @@ template <typename AttrHolder> static void RemoveNonValidAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH, unsigned Index) { AttrBuilder R; - if (AH.getDereferenceableBytes(Index)) + AttributeSet AS = AH.getAttributes().getAttributes(Index); + if (AS.getDereferenceableBytes()) R.addAttribute(Attribute::get(Ctx, Attribute::Dereferenceable, - AH.getDereferenceableBytes(Index))); - if (AH.getDereferenceableOrNullBytes(Index)) + AS.getDereferenceableBytes())); + if (AS.getDereferenceableOrNullBytes()) R.addAttribute(Attribute::get(Ctx, Attribute::DereferenceableOrNull, - AH.getDereferenceableOrNullBytes(Index))); + AS.getDereferenceableOrNullBytes())); for (auto Attr : ParamAttrsToStrip) - if (AH.getAttributes().hasAttribute(Index, Attr)) + if (AS.hasAttribute(Attr)) R.addAttribute(Attr); if (!R.empty()) - AH.setAttributes(AH.getAttributes().removeAttributes(Ctx, Index, R)); + AH.setAttributes(AH.getAttributes().removeAttributesAtIndex(Ctx, Index, R)); } static void stripNonValidAttributesFromPrototype(Function &F) { @@ -3016,7 +3016,7 @@ static SetVector<Value *> computeKillSet(BasicBlock *BB) { #ifndef NDEBUG /// Check that the items in 'Live' dominate 'TI'. This is used as a basic -/// sanity check for the liveness computation. +/// validation check for the liveness computation. static void checkBasicSSA(DominatorTree &DT, SetVector<Value *> &Live, Instruction *TI, bool TermOkay = false) { for (Value *V : Live) { @@ -3103,7 +3103,7 @@ static void computeLiveInValues(DominatorTree &DT, Function &F, } // while (!Worklist.empty()) #ifndef NDEBUG - // Sanity check our output against SSA properties. This helps catch any + // Verify our output against SSA properties. This helps catch any // missing kills during the above iteration. 
for (BasicBlock &BB : F) checkBasicSSA(DT, Data, BB); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp index b09f896d0157..28e00c873361 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -490,17 +490,17 @@ bool llvm::runIPSCCP( AttrBuilder AttributesToRemove; AttributesToRemove.addAttribute(Attribute::ArgMemOnly); AttributesToRemove.addAttribute(Attribute::InaccessibleMemOrArgMemOnly); - F.removeAttributes(AttributeList::FunctionIndex, AttributesToRemove); + F.removeFnAttrs(AttributesToRemove); for (User *U : F.users()) { auto *CB = dyn_cast<CallBase>(U); if (!CB || CB->getCalledFunction() != &F) continue; - CB->removeAttributes(AttributeList::FunctionIndex, - AttributesToRemove); + CB->removeFnAttrs(AttributesToRemove); } } + MadeChanges |= ReplacedPointerArg; } SmallPtrSet<Value *, 32> InsertedValues; @@ -540,14 +540,13 @@ bool llvm::runIPSCCP( DTU.deleteBB(DeadBB); for (BasicBlock &BB : F) { - for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { - Instruction *Inst = &*BI++; - if (Solver.getPredicateInfoFor(Inst)) { - if (auto *II = dyn_cast<IntrinsicInst>(Inst)) { + for (Instruction &Inst : llvm::make_early_inc_range(BB)) { + if (Solver.getPredicateInfoFor(&Inst)) { + if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { if (II->getIntrinsicID() == Intrinsic::ssa_copy) { Value *Op = II->getOperand(0); - Inst->replaceAllUsesWith(Op); - Inst->eraseFromParent(); + Inst.replaceAllUsesWith(Op); + Inst.eraseFromParent(); } } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp index fe160d5415bd..31c8999c3724 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp @@ -122,7 +122,7 @@ namespace { class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter { std::string Prefix; - const Twine getNameWithPrefix(const Twine &Name) const { + Twine getNameWithPrefix(const Twine &Name) const { return Name.isTriviallyEmpty() ? Name : Prefix + Name; } @@ -1275,8 +1275,7 @@ static void speculatePHINodeLoads(PHINode &PN) { // Get the AA tags and alignment to use from one of the loads. It does not // matter which one we get and if any differ. - AAMDNodes AATags; - SomeLoad->getAAMetadata(AATags); + AAMDNodes AATags = SomeLoad->getAAMetadata(); Align Alignment = SomeLoad->getAlign(); // Rewrite all loads of the PN to use the new PHI. @@ -1330,14 +1329,21 @@ static void speculatePHINodeLoads(PHINode &PN) { /// %V = select i1 %cond, i32 %V1, i32 %V2 /// /// We can do this to a select if its only uses are loads and if the operand -/// to the select can be loaded unconditionally. +/// to the select can be loaded unconditionally. If found an intervening bitcast +/// with a single use of the load, allow the promotion. 
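
The doc comment above (and the new bitcast handling in this hunk) describes speculating the loads feeding a select of pointers. A C-level picture of the basic rewrite, valid only when both pointers can be loaded unconditionally as the comment requires; the names are illustrative, not from the commit:

    // Before: one load through a selected address.
    int loadSelected(bool Cond, int *P, int *Q) {
      int *Addr = Cond ? P : Q; // select of the two pointers
      return *Addr;             // single load, the only user of the select
    }

    // After speculation: two unconditional loads, then a select of the values.
    int loadSelectedSpeculated(bool Cond, int *P, int *Q) {
      int V1 = *P;              // speculated load of the true operand
      int V2 = *Q;              // speculated load of the false operand
      return Cond ? V1 : V2;
    }
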
static bool isSafeSelectToSpeculate(SelectInst &SI) { Value *TValue = SI.getTrueValue(); Value *FValue = SI.getFalseValue(); const DataLayout &DL = SI.getModule()->getDataLayout(); for (User *U : SI.users()) { - LoadInst *LI = dyn_cast<LoadInst>(U); + LoadInst *LI; + BitCastInst *BC = dyn_cast<BitCastInst>(U); + if (BC && BC->hasOneUse()) + LI = dyn_cast<LoadInst>(*BC->user_begin()); + else + LI = dyn_cast<LoadInst>(U); + if (!LI || !LI->isSimple()) return false; @@ -1363,13 +1369,27 @@ static void speculateSelectInstLoads(SelectInst &SI) { Value *FV = SI.getFalseValue(); // Replace the loads of the select with a select of two loads. while (!SI.use_empty()) { - LoadInst *LI = cast<LoadInst>(SI.user_back()); + LoadInst *LI; + BitCastInst *BC = dyn_cast<BitCastInst>(SI.user_back()); + if (BC) { + assert(BC->hasOneUse() && "Bitcast should have a single use."); + LI = cast<LoadInst>(BC->user_back()); + } else { + LI = cast<LoadInst>(SI.user_back()); + } + assert(LI->isSimple() && "We only speculate simple loads"); IRB.SetInsertPoint(LI); - LoadInst *TL = IRB.CreateLoad(LI->getType(), TV, + Value *NewTV = + BC ? IRB.CreateBitCast(TV, BC->getType(), TV->getName() + ".sroa.cast") + : TV; + Value *NewFV = + BC ? IRB.CreateBitCast(FV, BC->getType(), FV->getName() + ".sroa.cast") + : FV; + LoadInst *TL = IRB.CreateLoad(LI->getType(), NewTV, LI->getName() + ".sroa.speculate.load.true"); - LoadInst *FL = IRB.CreateLoad(LI->getType(), FV, + LoadInst *FL = IRB.CreateLoad(LI->getType(), NewFV, LI->getName() + ".sroa.speculate.load.false"); NumLoadsSpeculated += 2; @@ -1377,8 +1397,7 @@ static void speculateSelectInstLoads(SelectInst &SI) { TL->setAlignment(LI->getAlign()); FL->setAlignment(LI->getAlign()); - AAMDNodes Tags; - LI->getAAMetadata(Tags); + AAMDNodes Tags = LI->getAAMetadata(); if (Tags) { TL->setAAMetadata(Tags); FL->setAAMetadata(Tags); @@ -1390,6 +1409,8 @@ static void speculateSelectInstLoads(SelectInst &SI) { LLVM_DEBUG(dbgs() << " speculated to: " << *V << "\n"); LI->replaceAllUsesWith(V); LI->eraseFromParent(); + if (BC) + BC->eraseFromParent(); } SI.eraseFromParent(); } @@ -1462,76 +1483,6 @@ static Value *getNaturalGEPWithType(IRBuilderTy &IRB, const DataLayout &DL, return buildGEP(IRB, BasePtr, Indices, NamePrefix); } -/// Recursively compute indices for a natural GEP. -/// -/// This is the recursive step for getNaturalGEPWithOffset that walks down the -/// element types adding appropriate indices for the GEP. -static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL, - Value *Ptr, Type *Ty, APInt &Offset, - Type *TargetTy, - SmallVectorImpl<Value *> &Indices, - const Twine &NamePrefix) { - if (Offset == 0) - return getNaturalGEPWithType(IRB, DL, Ptr, Ty, TargetTy, Indices, - NamePrefix); - - // We can't recurse through pointer types. - if (Ty->isPointerTy()) - return nullptr; - - // We try to analyze GEPs over vectors here, but note that these GEPs are - // extremely poorly defined currently. The long-term goal is to remove GEPing - // over a vector from the IR completely. - if (VectorType *VecTy = dyn_cast<VectorType>(Ty)) { - unsigned ElementSizeInBits = - DL.getTypeSizeInBits(VecTy->getScalarType()).getFixedSize(); - if (ElementSizeInBits % 8 != 0) { - // GEPs over non-multiple of 8 size vector elements are invalid. 
- return nullptr; - } - APInt ElementSize(Offset.getBitWidth(), ElementSizeInBits / 8); - APInt NumSkippedElements = Offset.sdiv(ElementSize); - if (NumSkippedElements.ugt(cast<FixedVectorType>(VecTy)->getNumElements())) - return nullptr; - Offset -= NumSkippedElements * ElementSize; - Indices.push_back(IRB.getInt(NumSkippedElements)); - return getNaturalGEPRecursively(IRB, DL, Ptr, VecTy->getElementType(), - Offset, TargetTy, Indices, NamePrefix); - } - - if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { - Type *ElementTy = ArrTy->getElementType(); - APInt ElementSize(Offset.getBitWidth(), - DL.getTypeAllocSize(ElementTy).getFixedSize()); - APInt NumSkippedElements = Offset.sdiv(ElementSize); - if (NumSkippedElements.ugt(ArrTy->getNumElements())) - return nullptr; - - Offset -= NumSkippedElements * ElementSize; - Indices.push_back(IRB.getInt(NumSkippedElements)); - return getNaturalGEPRecursively(IRB, DL, Ptr, ElementTy, Offset, TargetTy, - Indices, NamePrefix); - } - - StructType *STy = dyn_cast<StructType>(Ty); - if (!STy) - return nullptr; - - const StructLayout *SL = DL.getStructLayout(STy); - uint64_t StructOffset = Offset.getZExtValue(); - if (StructOffset >= SL->getSizeInBytes()) - return nullptr; - unsigned Index = SL->getElementContainingOffset(StructOffset); - Offset -= APInt(Offset.getBitWidth(), SL->getElementOffset(Index)); - Type *ElementTy = STy->getElementType(Index); - if (Offset.uge(DL.getTypeAllocSize(ElementTy).getFixedSize())) - return nullptr; // The offset points into alignment padding. - - Indices.push_back(IRB.getInt32(Index)); - return getNaturalGEPRecursively(IRB, DL, Ptr, ElementTy, Offset, TargetTy, - Indices, NamePrefix); -} - /// Get a natural GEP from a base pointer to a particular offset and /// resulting in a particular type. /// @@ -1556,18 +1507,15 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL, Type *ElementTy = Ty->getElementType(); if (!ElementTy->isSized()) return nullptr; // We can't GEP through an unsized element. - if (isa<ScalableVectorType>(ElementTy)) + + SmallVector<APInt> IntIndices = DL.getGEPIndicesForOffset(ElementTy, Offset); + if (Offset != 0) return nullptr; - APInt ElementSize(Offset.getBitWidth(), - DL.getTypeAllocSize(ElementTy).getFixedSize()); - if (ElementSize == 0) - return nullptr; // Zero-length arrays can't help us build a natural GEP. - APInt NumSkippedElements = Offset.sdiv(ElementSize); - - Offset -= NumSkippedElements * ElementSize; - Indices.push_back(IRB.getInt(NumSkippedElements)); - return getNaturalGEPRecursively(IRB, DL, Ptr, ElementTy, Offset, TargetTy, - Indices, NamePrefix); + + for (const APInt &Index : IntIndices) + Indices.push_back(IRB.getInt(Index)); + return getNaturalGEPWithType(IRB, DL, Ptr, ElementTy, TargetTy, Indices, + NamePrefix); } /// Compute an adjusted pointer from Ptr by Offset bytes where the @@ -1588,6 +1536,15 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL, static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, APInt Offset, Type *PointerTy, const Twine &NamePrefix) { + // Create i8 GEP for opaque pointers. + if (Ptr->getType()->isOpaquePointerTy()) { + if (Offset != 0) + Ptr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(Offset), + NamePrefix + "sroa_idx"); + return IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, PointerTy, + NamePrefix + "sroa_cast"); + } + // Even though we don't look through PHI nodes, we could be called on an // instruction in an unreachable block, which may be on a cycle. 
SmallPtrSet<Value *, 4> Visited; @@ -1851,13 +1808,13 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) { if (!II->isLifetimeStartOrEnd() && !II->isDroppable()) return false; - } else if (U->get()->getType()->getPointerElementType()->isStructTy()) { - // Disable vector promotion when there are loads or stores of an FCA. - return false; } else if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) { if (LI->isVolatile()) return false; Type *LTy = LI->getType(); + // Disable vector promotion when there are loads or stores of an FCA. + if (LTy->isStructTy()) + return false; if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) { assert(LTy->isIntegerTy()); LTy = SplitIntTy; @@ -1868,6 +1825,9 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, if (SI->isVolatile()) return false; Type *STy = SI->getValueOperand()->getType(); + // Disable vector promotion when there are loads or stores of an FCA. + if (STy->isStructTy()) + return false; if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) { assert(STy->isIntegerTy()); STy = SplitIntTy; @@ -2282,7 +2242,7 @@ class llvm::sroa::AllocaSliceRewriter const DataLayout &DL; AllocaSlices &AS; - SROA &Pass; + SROAPass &Pass; AllocaInst &OldAI, &NewAI; const uint64_t NewAllocaBeginOffset, NewAllocaEndOffset; Type *NewAllocaTy; @@ -2330,7 +2290,7 @@ class llvm::sroa::AllocaSliceRewriter IRBuilderTy IRB; public: - AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROA &Pass, + AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROAPass &Pass, AllocaInst &OldAI, AllocaInst &NewAI, uint64_t NewAllocaBeginOffset, uint64_t NewAllocaEndOffset, bool IsIntegerPromotable, @@ -2510,8 +2470,7 @@ private: Value *OldOp = LI.getOperand(0); assert(OldOp == OldPtr); - AAMDNodes AATags; - LI.getAAMetadata(AATags); + AAMDNodes AATags = LI.getAAMetadata(); unsigned AS = LI.getPointerAddressSpace(); @@ -2675,9 +2634,7 @@ private: Value *OldOp = SI.getOperand(1); assert(OldOp == OldPtr); - AAMDNodes AATags; - SI.getAAMetadata(AATags); - + AAMDNodes AATags = SI.getAAMetadata(); Value *V = SI.getValueOperand(); // Strip all inbounds GEPs and pointer casts to try to dig out any root @@ -2743,7 +2700,9 @@ private: deleteIfTriviallyDead(OldOp); LLVM_DEBUG(dbgs() << " to: " << *NewSI << "\n"); - return NewSI->getPointerOperand() == &NewAI && !SI.isVolatile(); + return NewSI->getPointerOperand() == &NewAI && + NewSI->getValueOperand()->getType() == NewAllocaTy && + !SI.isVolatile(); } /// Compute an integer value from splatting an i8 across the given @@ -2784,8 +2743,7 @@ private: LLVM_DEBUG(dbgs() << " original: " << II << "\n"); assert(II.getRawDest() == OldPtr); - AAMDNodes AATags; - II.getAAMetadata(AATags); + AAMDNodes AATags = II.getAAMetadata(); // If the memset has a variable size, it cannot be split, just adjust the // pointer to the new alloca. @@ -2913,8 +2871,7 @@ private: LLVM_DEBUG(dbgs() << " original: " << II << "\n"); - AAMDNodes AATags; - II.getAAMetadata(AATags); + AAMDNodes AATags = II.getAAMetadata(); bool IsDest = &II.getRawDestUse() == OldUse; assert((IsDest && II.getRawDest() == OldPtr) || @@ -3421,9 +3378,7 @@ private: // We have an aggregate being loaded, split it apart. 
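
The comment just above marks the point where an aggregate (first-class struct) load is split into per-field operations by the LoadOpSplitter. A rough source-level equivalent of what the split produces; the struct and function are invented for illustration:

    struct Pair { float X; float Y; };

    // A whole-struct load such as `Pair P = *Src;` is rewritten into one
    // scalar load per field (each inheriting the original load's AA metadata)
    // and the pieces are then stitched back together.
    Pair copyPair(const Pair *Src) {
      Pair P;
      P.X = Src->X; // split load of field 0
      P.Y = Src->Y; // split load of field 1
      return P;
    }
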
LLVM_DEBUG(dbgs() << " original: " << LI << "\n"); - AAMDNodes AATags; - LI.getAAMetadata(AATags); - LoadOpSplitter Splitter(&LI, *U, LI.getType(), AATags, + LoadOpSplitter Splitter(&LI, *U, LI.getType(), LI.getAAMetadata(), getAdjustedAlignment(&LI, 0), DL); Value *V = UndefValue::get(LI.getType()); Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca"); @@ -3474,9 +3429,7 @@ private: // We have an aggregate being stored, split it apart. LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); - AAMDNodes AATags; - SI.getAAMetadata(AATags); - StoreOpSplitter Splitter(&SI, *U, V->getType(), AATags, + StoreOpSplitter Splitter(&SI, *U, V->getType(), SI.getAAMetadata(), getAdjustedAlignment(&SI, 0), DL); Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); Visited.erase(&SI); @@ -3802,7 +3755,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, /// there all along. /// /// \returns true if any changes are made. -bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { +bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { LLVM_DEBUG(dbgs() << "Pre-splitting loads and stores\n"); // Track the loads and stores which are candidates for pre-splitting here, in @@ -4282,8 +4235,8 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { /// appropriate new offsets. It also evaluates how successful the rewrite was /// at enabling promotion and if it was successful queues the alloca to be /// promoted. -AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, - Partition &P) { +AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, + Partition &P) { // Try to compute a friendly type for this partition of the alloca. This // won't always succeed, in which case we fall back to a legal integer type // or an i8 array of an appropriate size. @@ -4434,7 +4387,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, /// Walks the slices of an alloca and form partitions based on them, /// rewriting each of their uses. -bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { +bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { if (AS.begin() == AS.end()) return false; @@ -4605,7 +4558,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } /// Clobber a use with undef, deleting the used value if it becomes dead. -void SROA::clobberUse(Use &U) { +void SROAPass::clobberUse(Use &U) { Value *OldV = U; // Replace the use with an undef value. U = UndefValue::get(OldV->getType()); @@ -4624,7 +4577,7 @@ void SROA::clobberUse(Use &U) { /// This analyzes the alloca to ensure we can reason about it, builds /// the slices of the alloca, and then hands it off to be split and /// rewritten as needed. -bool SROA::runOnAlloca(AllocaInst &AI) { +bool SROAPass::runOnAlloca(AllocaInst &AI) { LLVM_DEBUG(dbgs() << "SROA alloca: " << AI << "\n"); ++NumAllocasAnalyzed; @@ -4698,7 +4651,7 @@ bool SROA::runOnAlloca(AllocaInst &AI) { /// /// We also record the alloca instructions deleted here so that they aren't /// subsequently handed to mem2reg to promote. -bool SROA::deleteDeadInstructions( +bool SROAPass::deleteDeadInstructions( SmallPtrSetImpl<AllocaInst *> &DeletedAllocas) { bool Changed = false; while (!DeadInsts.empty()) { @@ -4737,7 +4690,7 @@ bool SROA::deleteDeadInstructions( /// This attempts to promote whatever allocas have been identified as viable in /// the PromotableAllocas list. If that list is empty, there is nothing to do. 
/// This function returns whether any promotion occurred. -bool SROA::promoteAllocas(Function &F) { +bool SROAPass::promoteAllocas(Function &F) { if (PromotableAllocas.empty()) return false; @@ -4749,8 +4702,8 @@ bool SROA::promoteAllocas(Function &F) { return true; } -PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, - AssumptionCache &RunAC) { +PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT, + AssumptionCache &RunAC) { LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n"); C = &F.getContext(); DT = &RunDT; @@ -4804,7 +4757,7 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, return PA; } -PreservedAnalyses SROA::run(Function &F, FunctionAnalysisManager &AM) { +PreservedAnalyses SROAPass::run(Function &F, FunctionAnalysisManager &AM) { return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F), AM.getResult<AssumptionAnalysis>(F)); } @@ -4815,7 +4768,7 @@ PreservedAnalyses SROA::run(Function &F, FunctionAnalysisManager &AM) { /// SROA pass. class llvm::sroa::SROALegacyPass : public FunctionPass { /// The SROA implementation. - SROA Impl; + SROAPass Impl; public: static char ID; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index ca288a533f46..1284bae820a4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -873,13 +873,11 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI, auto &DL = F.getParent()->getDataLayout(); while (MadeChange) { MadeChange = false; - for (Function::iterator I = F.begin(); I != F.end();) { - BasicBlock *BB = &*I++; + for (BasicBlock &BB : llvm::make_early_inc_range(F)) { bool ModifiedDTOnIteration = false; - MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL, + MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL, DTU.hasValue() ? DTU.getPointer() : nullptr); - // Restart BB iteration if the dominator tree of the Function was changed if (ModifiedDTOnIteration) break; @@ -933,7 +931,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, if (II) { // The scalarization code below does not work for scalable vectors. if (isa<ScalableVectorType>(II->getType()) || - any_of(II->arg_operands(), + any_of(II->args(), [](Value *V) { return isa<ScalableVectorType>(V->getType()); })) return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 8ef6b69673be..6b7419abe1d1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -66,6 +66,15 @@ static cl::opt<bool> namespace { +BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) { + BasicBlock *BB = Itr->getParent(); + if (isa<PHINode>(Itr)) + Itr = BB->getFirstInsertionPt(); + if (Itr != BB->end()) + Itr = skipDebugIntrinsics(Itr); + return Itr; +} + // Used to store the scattered form of a vector. using ValueVector = SmallVector<Value *, 8>; @@ -371,10 +380,11 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) { return Scatterer(Point->getParent(), Point->getIterator(), UndefValue::get(V->getType())); // Put the scattered form of an instruction directly after the - // instruction. + // instruction, skipping over PHI nodes and debug intrinsics. 
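// Illustrative sketch (not part of this commit): several hunks above, for
// example the ScalarizeMaskedMemIntrin loop, replace manual iterator
// increments with llvm::make_early_inc_range, which advances the iterator
// before each body runs so the current element can be erased or replaced
// safely. dropTriviallyDead is a hypothetical helper, not code from this
// change:
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static bool dropTriviallyDead(BasicBlock &BB) {
  bool Changed = false;
  for (Instruction &I : llvm::make_early_inc_range(BB)) {
    // Erasing I would invalidate a plain iterator; the early-inc range has
    // already stepped past it.
    if (isInstructionTriviallyDead(&I)) {
      I.eraseFromParent();
      Changed = true;
    }
  }
  return Changed;
}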
BasicBlock *BB = VOp->getParent(); - return Scatterer(BB, std::next(BasicBlock::iterator(VOp)), - V, &Scattered[V]); + return Scatterer( + BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, + &Scattered[V]); } // In the fallback case, just put the scattered before Point and // keep the result local to Point. @@ -530,7 +540,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { return false; unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); - unsigned NumArgs = CI.getNumArgOperands(); + unsigned NumArgs = CI.arg_size(); ValueVector ScalarOperands(NumArgs); SmallVector<Scatterer, 8> Scattered(NumArgs); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index f216956406b6..ffa2f9adb978 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -1164,8 +1164,11 @@ bool SeparateConstOffsetFromGEP::run(Function &F) { DL = &F.getParent()->getDataLayout(); bool Changed = false; for (BasicBlock &B : F) { - for (BasicBlock::iterator I = B.begin(), IE = B.end(); I != IE;) - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I++)) + if (!DT->isReachableFromEntry(&B)) + continue; + + for (Instruction &I : llvm::make_early_inc_range(B)) + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&I)) Changed |= splitGEP(GEP); // No need to split GEP ConstantExprs because all its indices are constant // already. @@ -1258,10 +1261,8 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) { DominatingSubs.clear(); for (const auto Node : depth_first(DT)) { BasicBlock *BB = Node->getBlock(); - for (auto I = BB->begin(); I != BB->end(); ) { - Instruction *Cur = &*I++; - Changed |= reuniteExts(Cur); - } + for (Instruction &I : llvm::make_early_inc_range(*BB)) + Changed |= reuniteExts(&I); } return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index b1c105258027..a27da047bfd3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -28,6 +28,7 @@ #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -49,7 +50,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GenericDomTree.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -81,6 +81,7 @@ static cl::opt<bool> EnableNonTrivialUnswitch( static cl::opt<int> UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, + cl::ZeroOrMore, cl::desc("The cost threshold for unswitching a loop.")); static cl::opt<bool> EnableUnswitchCostMultiplier( @@ -108,6 +109,10 @@ static cl::opt<unsigned> cl::desc("Max number of memory uses to explore during " "partial unswitching analysis"), cl::init(100), cl::Hidden); +static cl::opt<bool> FreezeLoopUnswitchCond( + "freeze-loop-unswitch-cond", cl::init(false), cl::Hidden, + cl::desc("If enabled, the freeze instruction 
will be added to condition " + "of loop unswitch to prevent miscompilation.")); /// Collect all of the loop invariant input values transitively used by the /// homogeneous instruction graph from a given root. @@ -195,15 +200,15 @@ static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, /// Copy a set of loop invariant values \p ToDuplicate and insert them at the /// end of \p BB and conditionally branch on the copied condition. We only /// branch on a single value. -static void buildPartialUnswitchConditionalBranch(BasicBlock &BB, - ArrayRef<Value *> Invariants, - bool Direction, - BasicBlock &UnswitchedSucc, - BasicBlock &NormalSucc) { +static void buildPartialUnswitchConditionalBranch( + BasicBlock &BB, ArrayRef<Value *> Invariants, bool Direction, + BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, bool InsertFreeze) { IRBuilder<> IRB(&BB); Value *Cond = Direction ? IRB.CreateOr(Invariants) : IRB.CreateAnd(Invariants); + if (InsertFreeze) + Cond = IRB.CreateFreeze(Cond, Cond->getName() + ".fr"); IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, Direction ? &NormalSucc : &UnswitchedSucc); } @@ -564,7 +569,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, "Must have an `and` of `i1`s or `select i1 X, Y, false`s for the" " condition!"); buildPartialUnswitchConditionalBranch(*OldPH, Invariants, ExitDirection, - *UnswitchedBB, *NewPH); + *UnswitchedBB, *NewPH, false); } // Update the dominator tree with the added edge. @@ -2123,6 +2128,13 @@ static void unswitchNontrivialInvariants( SE->forgetTopmostLoop(&L); } + bool InsertFreeze = false; + if (FreezeLoopUnswitchCond) { + ICFLoopSafetyInfo SafetyInfo; + SafetyInfo.computeLoopSafetyInfo(&L); + InsertFreeze = !SafetyInfo.isGuaranteedToExecute(TI, &DT, &L); + } + // If the edge from this terminator to a successor dominates that successor, // store a map from each block in its dominator subtree to it. This lets us // tell when cloning for a particular successor if a block is dominated by @@ -2197,6 +2209,11 @@ static void unswitchNontrivialInvariants( BasicBlock *ClonedPH = ClonedPHs.begin()->second; BI->setSuccessor(ClonedSucc, ClonedPH); BI->setSuccessor(1 - ClonedSucc, LoopPH); + if (InsertFreeze) { + auto Cond = BI->getCondition(); + if (!isGuaranteedNotToBeUndefOrPoison(Cond, &AC, BI, &DT)) + BI->setCondition(new FreezeInst(Cond, Cond->getName() + ".fr", BI)); + } DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); } else { assert(SI && "Must either be a branch or switch!"); @@ -2211,6 +2228,11 @@ static void unswitchNontrivialInvariants( else Case.setSuccessor(ClonedPHs.find(Case.getCaseSuccessor())->second); + if (InsertFreeze) { + auto Cond = SI->getCondition(); + if (!isGuaranteedNotToBeUndefOrPoison(Cond, &AC, SI, &DT)) + SI->setCondition(new FreezeInst(Cond, Cond->getName() + ".fr", SI)); + } // We need to use the set to populate domtree updates as even when there // are multiple cases pointing at the same successor we only want to // remove and insert one edge in the domtree. @@ -2291,7 +2313,7 @@ static void unswitchNontrivialInvariants( *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, L, MSSAU); else buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction, - *ClonedPH, *LoopPH); + *ClonedPH, *LoopPH, InsertFreeze); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); if (MSSAU) { @@ -2370,7 +2392,9 @@ static void unswitchNontrivialInvariants( ConstantInt *ContinueReplacement = Direction ? 
ConstantInt::getFalse(BI->getContext()) : ConstantInt::getTrue(BI->getContext()); - for (Value *Invariant : Invariants) + for (Value *Invariant : Invariants) { + assert(!isa<Constant>(Invariant) && + "Should not be replacing constant values!"); // Use make_early_inc_range here as set invalidates the iterator. for (Use &U : llvm::make_early_inc_range(Invariant->uses())) { Instruction *UserI = dyn_cast<Instruction>(U.getUser()); @@ -2385,6 +2409,7 @@ static void unswitchNontrivialInvariants( DT.dominates(ClonedPH, UserI->getParent())) U.set(UnswitchedReplacement); } + } } // We can change which blocks are exit blocks of all the cloned sibling @@ -2727,6 +2752,9 @@ static bool unswitchBestCondition( Cond = CondNext; BI->setCondition(Cond); + if (isa<Constant>(Cond)) + continue; + if (L.isLoopInvariant(BI->getCondition())) { UnswitchCandidates.push_back({BI, {BI->getCondition()}}); continue; @@ -3121,6 +3149,17 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, return PA; } +void SimpleLoopUnswitchPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<SimpleLoopUnswitchPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + + OS << "<"; + OS << (NonTrivial ? "" : "no-") << "nontrivial;"; + OS << (Trivial ? "" : "no-") << "trivial"; + OS << ">"; +} + namespace { class SimpleLoopUnswitchLegacyPass : public LoopPass { @@ -3140,10 +3179,8 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - if (EnableMSSALoopDependency) { - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); getLoopAnalysisUsage(AU); } }; @@ -3164,12 +3201,8 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - MemorySSA *MSSA = nullptr; - Optional<MemorySSAUpdater> MSSAU; - if (EnableMSSALoopDependency) { - MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSAU = MemorySSAUpdater(MSSA); - } + MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MemorySSAUpdater MSSAU(MSSA); auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; @@ -3197,15 +3230,13 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { LPM.markLoopAsDeleted(L); }; - if (MSSA && VerifyMemorySSA) + if (VerifyMemorySSA) MSSA->verifyMemorySSA(); - bool Changed = - unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial, UnswitchCB, SE, - MSSAU.hasValue() ? 
MSSAU.getPointer() : nullptr, - DestroyLoopCB); + bool Changed = unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial, + UnswitchCB, SE, &MSSAU, DestroyLoopCB); - if (MSSA && VerifyMemorySSA) + if (VerifyMemorySSA) MSSA->verifyMemorySSA(); // Historically this pass has had issues with the dominator tree so verify it diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 09d59b0e884a..86d3620c312e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -224,7 +224,11 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, SmallVector<WeakVH, 16> LoopHeaders(UniqueLoopHeaders.begin(), UniqueLoopHeaders.end()); + unsigned IterCnt = 0; + (void)IterCnt; while (LocalChange) { + assert(IterCnt++ < 1000 && + "Sanity: iterative simplification didn't converge!"); LocalChange = false; // Loop over all of the basic blocks and remove them if they are unneeded. @@ -319,6 +323,21 @@ SimplifyCFGPass::SimplifyCFGPass(const SimplifyCFGOptions &Opts) applyCommandLineOverridesToOptions(Options); } +void SimplifyCFGPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<SimplifyCFGPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << "<"; + OS << "bonus-inst-threshold=" << Options.BonusInstThreshold << ";"; + OS << (Options.ForwardSwitchCondToPhi ? "" : "no-") << "forward-switch-cond;"; + OS << (Options.ConvertSwitchToLookupTable ? "" : "no-") + << "switch-to-lookup;"; + OS << (Options.NeedCanonicalLoop ? "" : "no-") << "keep-loops;"; + OS << (Options.HoistCommonInsts ? "" : "no-") << "hoist-common-insts;"; + OS << (Options.SinkCommonInsts ? "" : "no-") << "sink-common-insts"; + OS << ">"; +} + PreservedAnalyses SimplifyCFGPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index dfa30418ea01..06169a7834f6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -268,7 +268,7 @@ bool SpeculativeExecutionPass::considerHoistingFromTo( if (const auto *DVI = dyn_cast<DbgVariableIntrinsic>(U)) { return all_of(DVI->location_ops(), [&NotHoisted](Value *V) { if (const auto *I = dyn_cast_or_null<Instruction>(V)) { - if (NotHoisted.count(I) == 0) + if (!NotHoisted.contains(I)) return true; } return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 20b8b982e14b..b47378808216 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -607,7 +607,7 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis, if (IndexOffset == 1) return C.Stride; // Common case 2: if (i' - i) is -1, Bump = -S. - if (IndexOffset.isAllOnesValue()) + if (IndexOffset.isAllOnes()) return Builder.CreateNeg(C.Stride); // Otherwise, Bump = (i' - i) * sext/trunc(S). 
Note that (i' - i) and S may @@ -620,7 +620,7 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis, ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2()); return Builder.CreateShl(ExtendedStride, Exponent); } - if ((-IndexOffset).isPowerOf2()) { + if (IndexOffset.isNegatedPowerOf2()) { // If (i - i') is a power of 2, Bump = -sext/trunc(S) << log(i' - i). ConstantInt *Exponent = ConstantInt::get(DeltaType, (-IndexOffset).logBase2()); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 846a9321f53e..3bcf92e28a21 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -262,7 +262,7 @@ static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) { // Note that this runs whether we know an alloca has escaped or not. If // it has, then we can't trust Tracker.AllocaUsers to be accurate. bool SafeToTail = true; - for (auto &Arg : CI->arg_operands()) { + for (auto &Arg : CI->args()) { if (isa<Constant>(Arg.getUser())) continue; if (Argument *A = dyn_cast<Argument>(Arg.getUser())) @@ -584,8 +584,8 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) { // call instruction into the newly created temporarily variable. void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(CallInst *CI, int OpndIdx) { - PointerType *ArgTy = cast<PointerType>(CI->getArgOperand(OpndIdx)->getType()); - Type *AggTy = ArgTy->getElementType(); + Type *AggTy = CI->getParamByValType(OpndIdx); + assert(AggTy); const DataLayout &DL = F.getParent()->getDataLayout(); // Get alignment of byVal operand. @@ -611,8 +611,8 @@ void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(CallInst *CI, // into the corresponding function argument location. void TailRecursionEliminator::copyLocalTempOfByValueOperandIntoArguments( CallInst *CI, int OpndIdx) { - PointerType *ArgTy = cast<PointerType>(CI->getArgOperand(OpndIdx)->getType()); - Type *AggTy = ArgTy->getElementType(); + Type *AggTy = CI->getParamByValType(OpndIdx); + assert(AggTy); const DataLayout &DL = F.getParent()->getDataLayout(); // Get alignment of byVal operand. @@ -667,7 +667,7 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) { createTailRecurseLoopHeader(CI); // Copy values of ByVal operands into local temporarily variables. - for (unsigned I = 0, E = CI->getNumArgOperands(); I != E; ++I) { + for (unsigned I = 0, E = CI->arg_size(); I != E; ++I) { if (CI->isByValArgument(I)) copyByValueOperandIntoLocalTemp(CI, I); } @@ -675,7 +675,7 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) { // Ok, now that we know we have a pseudo-entry block WITH all of the // required PHI nodes, add entries into the PHI node for the actual // parameters passed into the tail-recursive call. 
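// Illustrative sketch (not part of this commit): the two
// TailRecursionElimination hunks above stop reading the byval element type
// off the pointer operand's pointee type (which opaque pointers no longer
// carry) and take it from the byval attribute instead. getByValElementType
// is a hypothetical helper, not code from this change:
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Type.h"
using namespace llvm;

// Returns the byval element type of argument ArgNo, or null when the
// argument carries no byval attribute.
static Type *getByValElementType(const CallBase &CB, unsigned ArgNo) {
  return CB.getParamByValType(ArgNo); // was: pointer type -> getElementType()
}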
- for (unsigned I = 0, E = CI->getNumArgOperands(); I != E; ++I) { + for (unsigned I = 0, E = CI->arg_size(); I != E; ++I) { if (CI->isByValArgument(I)) { copyLocalTempOfByValueOperandIntoArguments(CI, I); ArgumentPHIs[I]->addIncoming(F.getArg(I), BB); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp index 8cd16ca3906f..fdc914a72bfd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp @@ -63,6 +63,9 @@ static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) { auto Int64Ty = Builder.getInt64Ty(); auto M = Builder.GetInsertBlock()->getModule(); auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty); + if (!M->getModuleFlag("amdgpu_hostcall")) { + M->addModuleFlag(llvm::Module::Override, "amdgpu_hostcall", 1); + } return Builder.CreateCall(Fn, Version); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp index 01912297324a..cbc508bb863a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp @@ -33,14 +33,14 @@ static inline bool CompareVars(const ASanStackVariableDescription &a, // We also force minimal alignment for all vars to kMinAlignment so that vars // with e.g. alignment 1 and alignment 16 do not get reordered by CompareVars. -static const size_t kMinAlignment = 16; +static const uint64_t kMinAlignment = 16; // We want to add a full redzone after every variable. // The larger the variable Size the larger is the redzone. // The resulting frame size is a multiple of Alignment. -static size_t VarAndRedzoneSize(size_t Size, size_t Granularity, - size_t Alignment) { - size_t Res = 0; +static uint64_t VarAndRedzoneSize(uint64_t Size, uint64_t Granularity, + uint64_t Alignment) { + uint64_t Res = 0; if (Size <= 4) Res = 16; else if (Size <= 16) Res = 32; else if (Size <= 128) Res = Size + 32; @@ -52,7 +52,7 @@ static size_t VarAndRedzoneSize(size_t Size, size_t Granularity, ASanStackFrameLayout ComputeASanStackFrameLayout(SmallVectorImpl<ASanStackVariableDescription> &Vars, - size_t Granularity, size_t MinHeaderSize) { + uint64_t Granularity, uint64_t MinHeaderSize) { assert(Granularity >= 8 && Granularity <= 64 && (Granularity & (Granularity - 1)) == 0); assert(MinHeaderSize >= 16 && (MinHeaderSize & (MinHeaderSize - 1)) == 0 && @@ -67,22 +67,22 @@ ComputeASanStackFrameLayout(SmallVectorImpl<ASanStackVariableDescription> &Vars, ASanStackFrameLayout Layout; Layout.Granularity = Granularity; Layout.FrameAlignment = std::max(Granularity, Vars[0].Alignment); - size_t Offset = std::max(std::max(MinHeaderSize, Granularity), - Vars[0].Alignment); + uint64_t Offset = + std::max(std::max(MinHeaderSize, Granularity), Vars[0].Alignment); assert((Offset % Granularity) == 0); for (size_t i = 0; i < NumVars; i++) { bool IsLast = i == NumVars - 1; - size_t Alignment = std::max(Granularity, Vars[i].Alignment); + uint64_t Alignment = std::max(Granularity, Vars[i].Alignment); (void)Alignment; // Used only in asserts. - size_t Size = Vars[i].Size; + uint64_t Size = Vars[i].Size; assert((Alignment & (Alignment - 1)) == 0); assert(Layout.FrameAlignment >= Alignment); assert((Offset % Alignment) == 0); assert(Size > 0); - size_t NextAlignment = IsLast ? 
Granularity - : std::max(Granularity, Vars[i + 1].Alignment); - size_t SizeWithRedzone = VarAndRedzoneSize(Size, Granularity, - NextAlignment); + uint64_t NextAlignment = + IsLast ? Granularity : std::max(Granularity, Vars[i + 1].Alignment); + uint64_t SizeWithRedzone = + VarAndRedzoneSize(Size, Granularity, NextAlignment); Vars[i].Offset = Offset; Offset += SizeWithRedzone; } @@ -118,7 +118,7 @@ GetShadowBytes(const SmallVectorImpl<ASanStackVariableDescription> &Vars, assert(Vars.size() > 0); SmallVector<uint8_t, 64> SB; SB.clear(); - const size_t Granularity = Layout.Granularity; + const uint64_t Granularity = Layout.Granularity; SB.resize(Vars[0].Offset / Granularity, kAsanStackLeftRedzoneMagic); for (const auto &Var : Vars) { SB.resize(Var.Offset / Granularity, kAsanStackMidRedzoneMagic); @@ -135,13 +135,13 @@ SmallVector<uint8_t, 64> GetShadowBytesAfterScope( const SmallVectorImpl<ASanStackVariableDescription> &Vars, const ASanStackFrameLayout &Layout) { SmallVector<uint8_t, 64> SB = GetShadowBytes(Vars, Layout); - const size_t Granularity = Layout.Granularity; + const uint64_t Granularity = Layout.Granularity; for (const auto &Var : Vars) { assert(Var.LifetimeSize <= Var.Size); - const size_t LifetimeShadowSize = + const uint64_t LifetimeShadowSize = (Var.LifetimeSize + Granularity - 1) / Granularity; - const size_t Offset = Var.Offset / Granularity; + const uint64_t Offset = Var.Offset / Granularity; std::fill(SB.begin() + Offset, SB.begin() + Offset + LifetimeShadowSize, kAsanStackUseAfterScopeMagic); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp index d689e04da36f..f910f7c3c31f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp @@ -67,7 +67,8 @@ bool isUsefullToPreserve(Attribute::AttrKind Kind) { /// This function will try to transform the given knowledge into a more /// canonical one. the canonical knowledge maybe the given one. 
-RetainedKnowledge canonicalizedKnowledge(RetainedKnowledge RK, DataLayout DL) { +RetainedKnowledge canonicalizedKnowledge(RetainedKnowledge RK, + const DataLayout &DL) { switch (RK.AttrKind) { default: return RK; @@ -103,7 +104,7 @@ struct AssumeBuilderState { Module *M; using MapKey = std::pair<Value *, Attribute::AttrKind>; - SmallMapVector<MapKey, unsigned, 8> AssumedKnowledgeMap; + SmallMapVector<MapKey, uint64_t, 8> AssumedKnowledgeMap; Instruction *InstBeingModified = nullptr; AssumptionCache* AC = nullptr; DominatorTree* DT = nullptr; @@ -196,28 +197,27 @@ struct AssumeBuilderState { (!ShouldPreserveAllAttributes && !isUsefullToPreserve(Attr.getKindAsEnum()))) return; - unsigned AttrArg = 0; + uint64_t AttrArg = 0; if (Attr.isIntAttribute()) AttrArg = Attr.getValueAsInt(); addKnowledge({Attr.getKindAsEnum(), AttrArg, WasOn}); } void addCall(const CallBase *Call) { - auto addAttrList = [&](AttributeList AttrList) { - for (unsigned Idx = AttributeList::FirstArgIndex; - Idx < AttrList.getNumAttrSets(); Idx++) - for (Attribute Attr : AttrList.getAttributes(Idx)) { + auto addAttrList = [&](AttributeList AttrList, unsigned NumArgs) { + for (unsigned Idx = 0; Idx < NumArgs; Idx++) + for (Attribute Attr : AttrList.getParamAttrs(Idx)) { bool IsPoisonAttr = Attr.hasAttribute(Attribute::NonNull) || Attr.hasAttribute(Attribute::Alignment); - if (!IsPoisonAttr || Call->isPassingUndefUB(Idx - 1)) - addAttribute(Attr, Call->getArgOperand(Idx - 1)); + if (!IsPoisonAttr || Call->isPassingUndefUB(Idx)) + addAttribute(Attr, Call->getArgOperand(Idx)); } - for (Attribute Attr : AttrList.getFnAttributes()) + for (Attribute Attr : AttrList.getFnAttrs()) addAttribute(Attr, nullptr); }; - addAttrList(Call->getAttributes()); + addAttrList(Call->getAttributes(), Call->arg_size()); if (Function *Fn = Call->getCalledFunction()) - addAttrList(Fn->getAttributes()); + addAttrList(Fn->getAttributes(), Fn->arg_size()); } AssumeInst *build() { @@ -261,8 +261,7 @@ struct AssumeBuilderState { addKnowledge({Attribute::NonNull, 0u, Pointer}); } if (MA.valueOrOne() > 1) - addKnowledge( - {Attribute::Alignment, unsigned(MA.valueOrOne().value()), Pointer}); + addKnowledge({Attribute::Alignment, MA.valueOrOne().value(), Pointer}); } void addInstruction(Instruction *I) { @@ -392,7 +391,7 @@ struct AssumeSimplify { void dropRedundantKnowledge() { struct MapValue { IntrinsicInst *Assume; - unsigned ArgValue; + uint64_t ArgValue; CallInst::BundleOpInfo *BOI; }; buildMapping(false); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index ee933b638a23..6469c899feea 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -39,6 +39,7 @@ #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" @@ -52,6 +53,12 @@ using namespace llvm; #define DEBUG_TYPE "basicblock-utils" +static cl::opt<unsigned> MaxDeoptOrUnreachableSuccessorCheckDepth( + "max-deopt-or-unreachable-succ-check-depth", cl::init(8), cl::Hidden, + cl::desc("Set the maximum path length when checking whether a basic block " + "is followed by a block that either has a terminating " + "deoptimizing call or is terminated with an unreachable")); + void llvm::DetatchDeadBlocks( ArrayRef<BasicBlock 
*> BBs, SmallVectorImpl<DominatorTree::UpdateType> *Updates, @@ -230,7 +237,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, if (DTU) { SmallPtrSet<BasicBlock *, 2> SuccsOfBB(succ_begin(BB), succ_end(BB)); SmallPtrSet<BasicBlock *, 2> SuccsOfPredBB(succ_begin(PredBB), - succ_begin(PredBB)); + succ_end(PredBB)); Updates.reserve(Updates.size() + 2 * SuccsOfBB.size() + 1); // Add insert edges first. Experimentally, for the particular case of two // blocks that can be merged, with a single successor and single predecessor @@ -485,6 +492,20 @@ void llvm::ReplaceInstWithInst(BasicBlock::InstListType &BIL, BI = New; } +bool llvm::IsBlockFollowedByDeoptOrUnreachable(const BasicBlock *BB) { + // Remember visited blocks to avoid infinite loop + SmallPtrSet<const BasicBlock *, 8> VisitedBlocks; + unsigned Depth = 0; + while (BB && Depth++ < MaxDeoptOrUnreachableSuccessorCheckDepth && + VisitedBlocks.insert(BB).second) { + if (BB->getTerminatingDeoptimizeCall() || + isa<UnreachableInst>(BB->getTerminator())) + return true; + BB = BB->getUniqueSuccessor(); + } + return false; +} + void llvm::ReplaceInstWithInst(Instruction *From, Instruction *To) { BasicBlock::iterator BI(From); ReplaceInstWithInst(From->getParent()->getInstList(), BI, To); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index 35e22f7a57e2..957935398972 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -96,9 +96,9 @@ static bool setDoesNotThrow(Function &F) { } static bool setRetDoesNotAlias(Function &F) { - if (F.hasAttribute(AttributeList::ReturnIndex, Attribute::NoAlias)) + if (F.hasRetAttribute(Attribute::NoAlias)) return false; - F.addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); + F.addRetAttr(Attribute::NoAlias); ++NumNoAlias; return true; } @@ -145,8 +145,8 @@ static bool setSignExtendedArg(Function &F, unsigned ArgNo) { static bool setRetNoUndef(Function &F) { if (!F.getReturnType()->isVoidTy() && - !F.hasAttribute(AttributeList::ReturnIndex, Attribute::NoUndef)) { - F.addAttribute(AttributeList::ReturnIndex, Attribute::NoUndef); + !F.hasRetAttribute(Attribute::NoUndef)) { + F.addRetAttr(Attribute::NoUndef); ++NumNoUndef; return true; } @@ -174,7 +174,10 @@ static bool setArgNoUndef(Function &F, unsigned ArgNo) { } static bool setRetAndArgsNoUndef(Function &F) { - return setRetNoUndef(F) | setArgsNoUndef(F); + bool UndefAdded = false; + UndefAdded |= setRetNoUndef(F); + UndefAdded |= setArgsNoUndef(F); + return UndefAdded; } static bool setReturnedArg(Function &F, unsigned ArgNo) { @@ -1268,7 +1271,7 @@ Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - Type *I8Ptr = B.getInt8PtrTy(); + Type *I8Ptr = Dst->getType(); return emitLibCall(LibFunc_strcpy, I8Ptr, {I8Ptr, I8Ptr}, {castToCStr(Dst, B), castToCStr(Src, B)}, B, TLI); } @@ -1453,9 +1456,8 @@ static Value *emitUnaryFloatFnCallHelper(Value *Op, StringRef Name, // The incoming attribute set may have come from a speculatable intrinsic, but // is being replaced with a library call which is not allowed to be // speculatable. 
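// Illustrative sketch (not part of this commit): the BuildLibCalls hunks
// around this point move from index-based AttributeList calls
// (AttributeList::ReturnIndex, AttributeList::FunctionIndex) to the dedicated
// return/function accessors. markReturnNoAlias is a hypothetical helper, not
// code from this change:
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
using namespace llvm;

static void markReturnNoAlias(Function &F) {
  if (!F.hasRetAttribute(Attribute::NoAlias)) // was: hasAttribute(ReturnIndex, ...)
    F.addRetAttr(Attribute::NoAlias);         // was: addAttribute(ReturnIndex, ...)
  F.removeFnAttr(Attribute::Speculatable);    // was: removeAttribute(FunctionIndex, ...)
}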
- CI->setAttributes(Attrs.removeAttribute(B.getContext(), - AttributeList::FunctionIndex, - Attribute::Speculatable)); + CI->setAttributes( + Attrs.removeFnAttribute(B.getContext(), Attribute::Speculatable)); if (const Function *F = dyn_cast<Function>(Callee.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1498,9 +1500,8 @@ static Value *emitBinaryFloatFnCallHelper(Value *Op1, Value *Op2, // The incoming attribute set may have come from a speculatable intrinsic, but // is being replaced with a library call which is not allowed to be // speculatable. - CI->setAttributes(Attrs.removeAttribute(B.getContext(), - AttributeList::FunctionIndex, - Attribute::Speculatable)); + CI->setAttributes( + Attrs.removeFnAttribute(B.getContext(), Attribute::Speculatable)); if (const Function *F = dyn_cast<Function>(Callee.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1655,8 +1656,8 @@ Value *llvm::emitMalloc(Value *Num, IRBuilderBase &B, const DataLayout &DL, return CI; } -Value *llvm::emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs, - IRBuilderBase &B, const TargetLibraryInfo &TLI) { +Value *llvm::emitCalloc(Value *Num, Value *Size, IRBuilderBase &B, + const TargetLibraryInfo &TLI) { if (!TLI.has(LibFunc_calloc)) return nullptr; @@ -1664,8 +1665,8 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs, StringRef CallocName = TLI.getName(LibFunc_calloc); const DataLayout &DL = M->getDataLayout(); IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); - FunctionCallee Calloc = M->getOrInsertFunction( - CallocName, Attrs, B.getInt8PtrTy(), PtrType, PtrType); + FunctionCallee Calloc = + M->getOrInsertFunction(CallocName, B.getInt8PtrTy(), PtrType, PtrType); inferLibFuncAttributes(M, CallocName, TLI); CallInst *CI = B.CreateCall(Calloc, {Num, Size}, CallocName); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index 87868251036c..ebe19f1751e5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -424,6 +424,21 @@ bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee, *FailureReason = "Argument type mismatch"; return false; } + // Make sure that the callee and call agree on byval/inalloca. The types do + // not have to match. + + if (Callee->hasParamAttribute(I, Attribute::ByVal) != + CB.getAttributes().hasParamAttr(I, Attribute::ByVal)) { + if (FailureReason) + *FailureReason = "byval mismatch"; + return false; + } + if (Callee->hasParamAttribute(I, Attribute::InAlloca) != + CB.getAttributes().hasParamAttr(I, Attribute::InAlloca)) { + if (FailureReason) + *FailureReason = "inalloca mismatch"; + return false; + } } for (; I < NumArgs; I++) { // Vararg functions can have more arguments than parameters. @@ -485,18 +500,19 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee, CB.setArgOperand(ArgNo, Cast); // Remove any incompatible attributes for the argument. - AttrBuilder ArgAttrs(CallerPAL.getParamAttributes(ArgNo)); + AttrBuilder ArgAttrs(CallerPAL.getParamAttrs(ArgNo)); ArgAttrs.remove(AttributeFuncs::typeIncompatible(FormalTy)); - // If byval is used, this must be a pointer type, and the byval type must - // match the element type. Update it if present. + // We may have a different byval/inalloca type. 
if (ArgAttrs.getByValType()) ArgAttrs.addByValAttr(Callee->getParamByValType(ArgNo)); + if (ArgAttrs.getInAllocaType()) + ArgAttrs.addInAllocaAttr(Callee->getParamInAllocaType(ArgNo)); NewArgAttrs.push_back(AttributeSet::get(Ctx, ArgAttrs)); AttributeChanged = true; } else - NewArgAttrs.push_back(CallerPAL.getParamAttributes(ArgNo)); + NewArgAttrs.push_back(CallerPAL.getParamAttrs(ArgNo)); } // If the return type of the call site doesn't match that of the callee, cast @@ -511,7 +527,7 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee, // Set the new callsite attribute. if (AttributeChanged) - CB.setAttributes(AttributeList::get(Ctx, CallerPAL.getFnAttributes(), + CB.setAttributes(AttributeList::get(Ctx, CallerPAL.getFnAttrs(), AttributeSet::get(Ctx, RAttrs), NewArgAttrs)); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp index 1f649fe6c748..049c7d113521 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp @@ -33,7 +33,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/IVDescriptors.h" -#include "llvm/Analysis/IVUsers.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp index 0ac9a5aaa425..048e691e33cf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -62,7 +62,7 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, NewBB->getInstList().push_back(NewInst); VMap[&I] = NewInst; // Add instruction map to value. - hasCalls |= (isa<CallInst>(I) && !isa<DbgInfoIntrinsic>(I)); + hasCalls |= (isa<CallInst>(I) && !I.isDebugOrPseudoInst()); if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { if (!AI->isStaticAlloca()) { hasDynamicAllocas = true; @@ -116,13 +116,13 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, for (const Argument &OldArg : OldFunc->args()) { if (Argument *NewArg = dyn_cast<Argument>(VMap[&OldArg])) { NewArgAttrs[NewArg->getArgNo()] = - OldAttrs.getParamAttributes(OldArg.getArgNo()); + OldAttrs.getParamAttrs(OldArg.getArgNo()); } } NewFunc->setAttributes( - AttributeList::get(NewFunc->getContext(), OldAttrs.getFnAttributes(), - OldAttrs.getRetAttributes(), NewArgAttrs)); + AttributeList::get(NewFunc->getContext(), OldAttrs.getFnAttrs(), + OldAttrs.getRetAttrs(), NewArgAttrs)); // Everything else beyond this point deals with function instructions, // so if we are dealing with a function declaration, we're done. @@ -410,7 +410,7 @@ void PruningFunctionCloner::CloneBlock( NewInst->setName(II->getName() + NameSuffix); VMap[&*II] = NewInst; // Add instruction map to value. 
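// Illustrative sketch (not part of this commit): the CloneFunction hunks here
// widen the "is this a real call?" test from !isa<DbgInfoIntrinsic> to
// !Instruction::isDebugOrPseudoInst(), which also excludes pseudo-probe
// intrinsics. countRealCalls is a hypothetical helper, not code from this
// change:
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"
using namespace llvm;

static unsigned countRealCalls(const BasicBlock &BB) {
  unsigned N = 0;
  for (const Instruction &I : BB)
    if (isa<CallInst>(I) && !I.isDebugOrPseudoInst())
      ++N;
  return N;
}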
NewBB->getInstList().push_back(NewInst); - hasCalls |= (isa<CallInst>(II) && !isa<DbgInfoIntrinsic>(II)); + hasCalls |= (isa<CallInst>(II) && !II->isDebugOrPseudoInst()); if (CodeInfo) { CodeInfo->OrigVMap[&*II] = NewInst; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 9edc52b53550..96aff563aa9b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -434,6 +434,7 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { } // Now add the old exit block to the outline region. Blocks.insert(CommonExitBlock); + OldTargets.push_back(NewExitBlock); return CommonExitBlock; } @@ -885,7 +886,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // "target-features" attribute allowing it to be lowered. // FIXME: This should be changed to check to see if a specific // attribute can not be inherited. - for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) { + for (const auto &Attr : oldFunction->getAttributes().getFnAttrs()) { if (Attr.isStringAttribute()) { if (Attr.getKindAsString() == "thunk") continue; @@ -943,6 +944,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: case Attribute::Cold: + case Attribute::DisableSanitizerInstrumentation: case Attribute::Hot: case Attribute::NoRecurse: case Attribute::InlineHint: @@ -1044,9 +1046,8 @@ static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks, const SetVector<Value *> &SunkAllocas, SetVector<Value *> &LifetimesStart) { for (BasicBlock *BB : Blocks) { - for (auto It = BB->begin(), End = BB->end(); It != End;) { - auto *II = dyn_cast<IntrinsicInst>(&*It); - ++It; + for (Instruction &I : llvm::make_early_inc_range(*BB)) { + auto *II = dyn_cast<IntrinsicInst>(&I); if (!II || !II->isLifetimeStartOrEnd()) continue; @@ -1247,45 +1248,57 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // not in the region to be extracted. std::map<BasicBlock *, BasicBlock *> ExitBlockMap; + // Iterate over the previously collected targets, and create new blocks inside + // the function to branch to. unsigned switchVal = 0; + for (BasicBlock *OldTarget : OldTargets) { + if (Blocks.count(OldTarget)) + continue; + BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; + if (NewTarget) + continue; + + // If we don't already have an exit stub for this non-extracted + // destination, create one now! + NewTarget = BasicBlock::Create(Context, + OldTarget->getName() + ".exitStub", + newFunction); + unsigned SuccNum = switchVal++; + + Value *brVal = nullptr; + assert(NumExitBlocks < 0xffff && "too many exit blocks for switch"); + switch (NumExitBlocks) { + case 0: + case 1: break; // No value needed. + case 2: // Conditional branch, return a bool + brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); + break; + default: + brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); + break; + } + + ReturnInst::Create(Context, brVal, NewTarget); + + // Update the switch instruction. 
+ TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), + SuccNum), + OldTarget); + } + for (BasicBlock *Block : Blocks) { Instruction *TI = Block->getTerminator(); - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - if (!Blocks.count(TI->getSuccessor(i))) { - BasicBlock *OldTarget = TI->getSuccessor(i); - // add a new basic block which returns the appropriate value - BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; - if (!NewTarget) { - // If we don't already have an exit stub for this non-extracted - // destination, create one now! - NewTarget = BasicBlock::Create(Context, - OldTarget->getName() + ".exitStub", - newFunction); - unsigned SuccNum = switchVal++; - - Value *brVal = nullptr; - switch (NumExitBlocks) { - case 0: - case 1: break; // No value needed. - case 2: // Conditional branch, return a bool - brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); - break; - default: - brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); - break; - } - - ReturnInst::Create(Context, brVal, NewTarget); - - // Update the switch instruction. - TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), - SuccNum), - OldTarget); - } - - // rewrite the original branch instruction with this new target - TI->setSuccessor(i, NewTarget); - } + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + if (Blocks.count(TI->getSuccessor(i))) + continue; + BasicBlock *OldTarget = TI->getSuccessor(i); + // add a new basic block which returns the appropriate value + BasicBlock *NewTarget = ExitBlockMap[OldTarget]; + assert(NewTarget && "Unknown target block!"); + + // rewrite the original branch instruction with this new target + TI->setSuccessor(i, NewTarget); + } } // Store the arguments right after the definition of output value. @@ -1388,12 +1401,17 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); + auto newFuncIt = newFunction->front().getIterator(); for (BasicBlock *Block : Blocks) { // Delete the basic block from the old function, and the list of blocks oldBlocks.remove(Block); // Insert this basic block into the new function - newBlocks.push_back(Block); + // Insert the original blocks after the entry block created + // for the new function. The entry block may be followed + // by a set of exit blocks at this point, but these exit + // blocks better be placed at the end of the new function. + newFuncIt = newBlocks.insertAfter(newFuncIt, Block); } } @@ -1569,6 +1587,13 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, Function * CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) { + ValueSet Inputs, Outputs; + return extractCodeRegion(CEAC, Inputs, Outputs); +} + +Function * +CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, + ValueSet &inputs, ValueSet &outputs) { if (!isEligible()) return nullptr; @@ -1593,11 +1618,8 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) { // Remove @llvm.assume calls that will be moved to the new function from the // old function's assumption cache. 
for (BasicBlock *Block : Blocks) { - for (auto It = Block->begin(), End = Block->end(); It != End;) { - Instruction *I = &*It; - ++It; - - if (auto *AI = dyn_cast<AssumeInst>(I)) { + for (Instruction &I : llvm::make_early_inc_range(*Block)) { + if (auto *AI = dyn_cast<AssumeInst>(&I)) { if (AC) AC->unregisterAssumption(AI); AI->eraseFromParent(); @@ -1627,6 +1649,16 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) { } NumExitBlocks = ExitBlocks.size(); + for (BasicBlock *Block : Blocks) { + Instruction *TI = Block->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + if (Blocks.count(TI->getSuccessor(i))) + continue; + BasicBlock *OldTarget = TI->getSuccessor(i); + OldTargets.push_back(OldTarget); + } + } + // If we have to split PHI nodes of the entry or exit blocks, do so now. severSplitPHINodesOfEntry(header); severSplitPHINodesOfExits(ExitBlocks); @@ -1657,7 +1689,7 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) { } newFuncRoot->getInstList().push_back(BranchI); - ValueSet inputs, outputs, SinkingCands, HoistingCands; + ValueSet SinkingCands, HoistingCands; BasicBlock *CommonExit = nullptr; findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); assert(HoistingCands.empty() || CommonExit); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp index ce982c7403aa..648f4e64a4d2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp @@ -309,7 +309,7 @@ collectInstructionsInBetween(Instruction &StartInst, const Instruction &EndInst, bool llvm::isSafeToMoveBefore(Instruction &I, Instruction &InsertPoint, DominatorTree &DT, const PostDominatorTree *PDT, - DependenceInfo *DI) { + DependenceInfo *DI, bool CheckForEntireBlock) { // Skip tests when we don't have PDT or DI if (!PDT || !DI) return false; @@ -332,16 +332,24 @@ bool llvm::isSafeToMoveBefore(Instruction &I, Instruction &InsertPoint, if (!isControlFlowEquivalent(I, InsertPoint, DT, *PDT)) return reportInvalidCandidate(I, NotControlFlowEquivalent); - if (!DT.dominates(&InsertPoint, &I)) + if (isReachedBefore(&I, &InsertPoint, &DT, PDT)) for (const Use &U : I.uses()) if (auto *UserInst = dyn_cast<Instruction>(U.getUser())) if (UserInst != &InsertPoint && !DT.dominates(&InsertPoint, U)) return false; - if (!DT.dominates(&I, &InsertPoint)) + if (isReachedBefore(&InsertPoint, &I, &DT, PDT)) for (const Value *Op : I.operands()) - if (auto *OpInst = dyn_cast<Instruction>(Op)) - if (&InsertPoint == OpInst || !DT.dominates(OpInst, &InsertPoint)) + if (auto *OpInst = dyn_cast<Instruction>(Op)) { + if (&InsertPoint == OpInst) + return false; + // If OpInst is an instruction that appears earlier in the same BB as + // I, then it is okay to move since OpInst will still be available. 
+ if (CheckForEntireBlock && I.getParent() == OpInst->getParent() && + DT.dominates(OpInst, &I)) + continue; + if (!DT.dominates(OpInst, &InsertPoint)) return false; + } DT.updateDFSNumbers(); const bool MoveForward = domTreeLevelBefore(&DT, &I, &InsertPoint); @@ -393,7 +401,8 @@ bool llvm::isSafeToMoveBefore(BasicBlock &BB, Instruction &InsertPoint, if (BB.getTerminator() == &I) return true; - return isSafeToMoveBefore(I, InsertPoint, DT, PDT, DI); + return isSafeToMoveBefore(I, InsertPoint, DT, PDT, DI, + /*CheckForEntireBlock=*/true); }); } @@ -401,11 +410,9 @@ void llvm::moveInstructionsToTheBeginning(BasicBlock &FromBB, BasicBlock &ToBB, DominatorTree &DT, const PostDominatorTree &PDT, DependenceInfo &DI) { - for (auto It = ++FromBB.rbegin(); It != FromBB.rend();) { + for (Instruction &I : + llvm::make_early_inc_range(llvm::drop_begin(llvm::reverse(FromBB)))) { Instruction *MovePos = ToBB.getFirstNonPHIOrDbg(); - Instruction &I = *It; - // Increment the iterator before modifying FromBB. - ++It; if (isSafeToMoveBefore(I, *MovePos, DT, &PDT, &DI)) I.moveBefore(MovePos); @@ -423,3 +430,47 @@ void llvm::moveInstructionsToTheEnd(BasicBlock &FromBB, BasicBlock &ToBB, I.moveBefore(MovePos); } } + +bool llvm::nonStrictlyPostDominate(const BasicBlock *ThisBlock, + const BasicBlock *OtherBlock, + const DominatorTree *DT, + const PostDominatorTree *PDT) { + assert(isControlFlowEquivalent(*ThisBlock, *OtherBlock, *DT, *PDT) && + "ThisBlock and OtherBlock must be CFG equivalent!"); + const BasicBlock *CommonDominator = + DT->findNearestCommonDominator(ThisBlock, OtherBlock); + if (CommonDominator == nullptr) + return false; + + /// Recursively check the predecessors of \p ThisBlock up to + /// their common dominator, and see if any of them post-dominates + /// \p OtherBlock. + SmallVector<const BasicBlock *, 8> WorkList; + SmallPtrSet<const BasicBlock *, 8> Visited; + WorkList.push_back(ThisBlock); + while (!WorkList.empty()) { + const BasicBlock *CurBlock = WorkList.back(); + WorkList.pop_back(); + Visited.insert(CurBlock); + if (PDT->dominates(CurBlock, OtherBlock)) + return true; + + for (auto *Pred : predecessors(CurBlock)) { + if (Pred == CommonDominator || Visited.count(Pred)) + continue; + WorkList.push_back(Pred); + } + } + return false; +} + +bool llvm::isReachedBefore(const Instruction *I0, const Instruction *I1, + const DominatorTree *DT, + const PostDominatorTree *PDT) { + const BasicBlock *BB0 = I0->getParent(); + const BasicBlock *BB1 = I1->getParent(); + if (BB0 == BB1) + return DT->dominates(I0, I1); + + return nonStrictlyPostDominate(BB1, BB0, DT, PDT); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp index 30c3fa521d52..fc7083b0c30d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp @@ -457,14 +457,14 @@ static bool checkInstructions(const DebugInstMap &DILocsBefore, } // This checks the preservation of original debug variable intrinsics. 
-static bool checkVars(const DebugVarMap &DIFunctionsBefore, - const DebugVarMap &DIFunctionsAfter, +static bool checkVars(const DebugVarMap &DIVarsBefore, + const DebugVarMap &DIVarsAfter, StringRef NameOfWrappedPass, StringRef FileNameFromCU, bool ShouldWriteIntoJSON, llvm::json::Array &Bugs) { bool Preserved = true; - for (const auto &V : DIFunctionsBefore) { - auto VarIt = DIFunctionsAfter.find(V.first); - if (VarIt == DIFunctionsAfter.end()) + for (const auto &V : DIVarsBefore) { + auto VarIt = DIVarsAfter.find(V.first); + if (VarIt == DIVarsAfter.end()) continue; unsigned NumOfDbgValsAfter = VarIt->second; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 31d03e1e86af..e3e8f63383df 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -89,7 +89,7 @@ static bool runOnFunction(Function &F, bool PostInlining) { insertCall(F, EntryFunc, &*F.begin()->getFirstInsertionPt(), DL); Changed = true; - F.removeAttribute(AttributeList::FunctionIndex, EntryAttr); + F.removeFnAttr(EntryAttr); } if (!ExitFunc.empty()) { @@ -111,7 +111,7 @@ static bool runOnFunction(Function &F, bool PostInlining) { insertCall(F, ExitFunc, T, DL); Changed = true; } - F.removeAttribute(AttributeList::FunctionIndex, ExitAttr); + F.removeFnAttr(ExitAttr); } return Changed; @@ -183,3 +183,13 @@ llvm::EntryExitInstrumenterPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserveSet<CFGAnalyses>(); return PA; } + +void llvm::EntryExitInstrumenterPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<llvm::EntryExitInstrumenterPass> *>(this) + ->printPipeline(OS, MapClassName2PassName); + OS << "<"; + if (PostInlining) + OS << "post-inline"; + OS << ">"; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp index 463c223d9e8f..9c8aed94708e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp @@ -128,11 +128,6 @@ isSimpleEnoughValueToCommit(Constant *C, /// globals and GEP's of globals. This should be kept up to date with /// CommitValueTo. static bool isSimpleEnoughPointerToCommit(Constant *C, const DataLayout &DL) { - // Conservatively, avoid aggregate types. This is because we don't - // want to worry about them partially overlapping other stores. - if (!cast<PointerType>(C->getType())->getElementType()->isSingleValueType()) - return false; - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) // Do not allow weak/*_odr/linkonce linkage or external globals. return GV->hasUniqueInitializer(); @@ -284,7 +279,7 @@ bool Evaluator::getFormalParams(CallBase &CB, Function *F, return false; auto *FTy = F->getFunctionType(); - if (FTy->getNumParams() > CB.getNumArgOperands()) { + if (FTy->getNumParams() > CB.arg_size()) { LLVM_DEBUG(dbgs() << "Too few arguments for function.\n"); return false; } @@ -343,7 +338,10 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB, Ptr = FoldedPtr; LLVM_DEBUG(dbgs() << "; To: " << *Ptr << "\n"); } - if (!isSimpleEnoughPointerToCommit(Ptr, DL)) { + // Conservatively, avoid aggregate types. This is because we don't + // want to worry about them partially overlapping other stores. 
+ if (!SI->getValueOperand()->getType()->isSingleValueType() || + !isSimpleEnoughPointerToCommit(Ptr, DL)) { // If this is too complex for us to commit, reject it. LLVM_DEBUG( dbgs() << "Pointer is too complex for us to evaluate store."); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp index 10f48fe827f4..8de3ce876bab 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp @@ -124,7 +124,7 @@ static void reconnectChildLoops(LoopInfo &LI, Loop *ParentLoop, Loop *NewLoop, // children to a new vector. auto FirstChild = std::partition( CandidateLoops.begin(), CandidateLoops.end(), [&](Loop *L) { - return L == NewLoop || Blocks.count(L->getHeader()) == 0; + return L == NewLoop || !Blocks.contains(L->getHeader()); }); SmallVector<Loop *, 8> ChildLoops(FirstChild, CandidateLoops.end()); CandidateLoops.erase(FirstChild, CandidateLoops.end()); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp index dbcacc20b589..ddd3f597ae01 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp @@ -162,7 +162,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { // of \param BB (BB4) and should not have address-taken. // There should exist only one such unconditional // branch among the predecessors. - if (UnCondBlock || !PP || (Preds.count(PP) == 0) || + if (UnCondBlock || !PP || !Preds.contains(PP) || Pred->hasAddressTaken()) return false; @@ -215,7 +215,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { // PS is the successor which is not BB. Check successors to identify // the last conditional branch. - if (Preds.count(PS) == 0) { + if (!Preds.contains(PS)) { // Case 2. LastCondBlock = Pred; } else { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp index 2696557a719f..326864803d7c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -110,7 +110,7 @@ int FunctionComparator::cmpAttrs(const AttributeList L, if (int Res = cmpNumbers(L.getNumAttrSets(), R.getNumAttrSets())) return Res; - for (unsigned i = L.index_begin(), e = L.index_end(); i != e; ++i) { + for (unsigned i : L.indexes()) { AttributeSet LAS = L.getAttributes(i); AttributeSet RAS = R.getAttributes(i); AttributeSet::iterator LI = LAS.begin(), LE = LAS.end(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp index f782396be7b6..9bfc73e4ba6c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp @@ -105,8 +105,10 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, // value, not an aggregate), keep more specific information about // stores. 
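// Illustrative sketch (not part of this commit): the FixIrreducible and
// FlattenCFG hunks above replace "Set.count(X) == 0" with the more direct
// "!Set.contains(X)"; the two forms are equivalent for LLVM's set types.
// isNewBlock is a hypothetical helper, not code from this change:
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/BasicBlock.h"
using namespace llvm;

static bool isNewBlock(const SmallPtrSetImpl<const BasicBlock *> &Seen,
                       const BasicBlock *BB) {
  return !Seen.contains(BB); // was commonly written as Seen.count(BB) == 0
}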
if (GS.StoredType != GlobalStatus::Stored) { - if (const GlobalVariable *GV = - dyn_cast<GlobalVariable>(SI->getOperand(1))) { + const Value *Ptr = SI->getPointerOperand(); + if (isa<ConstantExpr>(Ptr)) + Ptr = Ptr->stripPointerCasts(); + if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) { Value *StoredVal = SI->getOperand(0); if (Constant *C = dyn_cast<Constant>(StoredVal)) { @@ -125,9 +127,9 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, GS.StoredType = GlobalStatus::InitializerStored; } else if (GS.StoredType < GlobalStatus::StoredOnce) { GS.StoredType = GlobalStatus::StoredOnce; - GS.StoredOnceValue = StoredVal; + GS.StoredOnceStore = SI; } else if (GS.StoredType == GlobalStatus::StoredOnce && - GS.StoredOnceValue == StoredVal) { + GS.getStoredOnceValue() == StoredVal) { // noop. } else { GS.StoredType = GlobalStatus::Stored; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp index a1e160d144dc..047bf5569ded 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp @@ -47,7 +47,7 @@ static void addVariantDeclaration(CallInst &CI, const ElementCount &VF, // Add function declaration. Type *RetTy = ToVectorTy(CI.getType(), VF); SmallVector<Type *, 4> Tys; - for (Value *ArgOperand : CI.arg_operands()) + for (Value *ArgOperand : CI.args()) Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); assert(!CI.getFunctionType()->isVarArg() && "VarArg functions are not supported."); @@ -94,8 +94,8 @@ static void addMappingsFromTLI(const TargetLibraryInfo &TLI, CallInst &CI) { const std::string TLIName = std::string(TLI.getVectorizedFunction(ScalarName, VF)); if (!TLIName.empty()) { - std::string MangledName = VFABI::mangleTLIVectorName( - TLIName, ScalarName, CI.getNumArgOperands(), VF); + std::string MangledName = + VFABI::mangleTLIVectorName(TLIName, ScalarName, CI.arg_size(), VF); if (!OriginalSetOfMappings.count(MangledName)) { Mappings.push_back(MangledName); ++NumCallInjected; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp index 792aa8208f27..f4776589910f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -539,12 +539,10 @@ static Value *getUnwindDestToken(Instruction *EHPad, static BasicBlock *HandleCallsInBlockInlinedThroughInvoke( BasicBlock *BB, BasicBlock *UnwindEdge, UnwindDestMemoTy *FuncletUnwindMap = nullptr) { - for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { - Instruction *I = &*BBI++; - + for (Instruction &I : llvm::make_early_inc_range(*BB)) { // We only need to check for function calls: inlined invoke // instructions require no special handling. - CallInst *CI = dyn_cast<CallInst>(I); + CallInst *CI = dyn_cast<CallInst>(&I); if (!CI || CI->doesNotThrow()) continue; @@ -830,6 +828,7 @@ static void PropagateCallSiteMetadata(CallBase &CB, Function::iterator FStart, } } +namespace { /// Utility for cloning !noalias and !alias.scope metadata. When a code region /// using scoped alias metadata is inlined, the aliasing relationships may not /// hold between the two version. It is necessary to create a deep clone of the @@ -851,6 +850,7 @@ public: /// metadata. 
void remap(Function::iterator FStart, Function::iterator FEnd); }; +} // namespace ScopedAliasMetadataDeepCloner::ScopedAliasMetadataDeepCloner( const Function *F) { @@ -1179,14 +1179,8 @@ static bool MayContainThrowingOrExitingCall(Instruction *Begin, assert(Begin->getParent() == End->getParent() && "Expected to be in same basic block!"); - unsigned NumInstChecked = 0; - // Check that all instructions in the range [Begin, End) are guaranteed to - // transfer execution to successor. - for (auto &I : make_range(Begin->getIterator(), End->getIterator())) - if (NumInstChecked++ > InlinerAttributeWindow || - !isGuaranteedToTransferExecutionToSuccessor(&I)) - return true; - return false; + return !llvm::isGuaranteedToTransferExecutionToSuccessor( + Begin->getIterator(), End->getIterator(), InlinerAttributeWindow + 1); } static AttrBuilder IdentifyValidAttributes(CallBase &CB) { @@ -1259,8 +1253,7 @@ static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { // existing attribute value (i.e. attributes such as dereferenceable, // dereferenceable_or_null etc). See AttrBuilder::merge for more details. AttributeList AL = NewRetVal->getAttributes(); - AttributeList NewAL = - AL.addAttributes(Context, AttributeList::ReturnIndex, Valid); + AttributeList NewAL = AL.addRetAttributes(Context, Valid); NewRetVal->setAttributes(NewAL); } } @@ -1376,13 +1369,13 @@ static void UpdateCallGraphAfterInlining(CallBase &CB, CallerNode->removeCallEdgeFor(*cast<CallBase>(&CB)); } -static void HandleByValArgumentInit(Value *Dst, Value *Src, Module *M, - BasicBlock *InsertBlock, +static void HandleByValArgumentInit(Type *ByValType, Value *Dst, Value *Src, + Module *M, BasicBlock *InsertBlock, InlineFunctionInfo &IFI) { - Type *AggTy = cast<PointerType>(Src->getType())->getElementType(); IRBuilder<> Builder(InsertBlock, InsertBlock->begin()); - Value *Size = Builder.getInt64(M->getDataLayout().getTypeStoreSize(AggTy)); + Value *Size = + Builder.getInt64(M->getDataLayout().getTypeStoreSize(ByValType)); // Always generate a memcpy of alignment 1 here because we don't know // the alignment of the src pointer. Other optimizations can infer @@ -1393,13 +1386,13 @@ static void HandleByValArgumentInit(Value *Dst, Value *Src, Module *M, /// When inlining a call site that has a byval argument, /// we have to make the implicit memcpy explicit by adding it. -static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, +static Value *HandleByValArgument(Type *ByValType, Value *Arg, + Instruction *TheCall, const Function *CalledFunc, InlineFunctionInfo &IFI, unsigned ByValAlignment) { - PointerType *ArgTy = cast<PointerType>(Arg->getType()); - Type *AggTy = ArgTy->getElementType(); - + assert(cast<PointerType>(Arg->getType()) + ->isOpaqueOrPointeeTypeMatches(ByValType)); Function *Caller = TheCall->getFunction(); const DataLayout &DL = Caller->getParent()->getDataLayout(); @@ -1427,7 +1420,7 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, } // Create the alloca. If we have DataLayout, use nice alignment. 
- Align Alignment(DL.getPrefTypeAlignment(AggTy)); + Align Alignment(DL.getPrefTypeAlignment(ByValType)); // If the byval had an alignment specified, we *must* use at least that // alignment, as it is required by the byval argument (and uses of the @@ -1435,7 +1428,7 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, Alignment = max(Alignment, MaybeAlign(ByValAlignment)); Value *NewAlloca = - new AllocaInst(AggTy, DL.getAllocaAddrSpace(), nullptr, Alignment, + new AllocaInst(ByValType, DL.getAllocaAddrSpace(), nullptr, Alignment, Arg->getName(), &*Caller->begin()->begin()); IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca)); @@ -1607,8 +1600,7 @@ static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, const ProfileCount &CalleeEntryCount, const CallBase &TheCall, ProfileSummaryInfo *PSI, BlockFrequencyInfo *CallerBFI) { - if (!CalleeEntryCount.hasValue() || CalleeEntryCount.isSynthetic() || - CalleeEntryCount.getCount() < 1) + if (CalleeEntryCount.isSynthetic() || CalleeEntryCount.getCount() < 1) return; auto CallSiteCount = PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None; int64_t CallCount = @@ -1617,40 +1609,39 @@ static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, } void llvm::updateProfileCallee( - Function *Callee, int64_t entryDelta, + Function *Callee, int64_t EntryDelta, const ValueMap<const Value *, WeakTrackingVH> *VMap) { auto CalleeCount = Callee->getEntryCount(); if (!CalleeCount.hasValue()) return; - uint64_t priorEntryCount = CalleeCount.getCount(); - uint64_t newEntryCount; + const uint64_t PriorEntryCount = CalleeCount->getCount(); // Since CallSiteCount is an estimate, it could exceed the original callee // count and has to be set to 0 so guard against underflow. - if (entryDelta < 0 && static_cast<uint64_t>(-entryDelta) > priorEntryCount) - newEntryCount = 0; - else - newEntryCount = priorEntryCount + entryDelta; + const uint64_t NewEntryCount = + (EntryDelta < 0 && static_cast<uint64_t>(-EntryDelta) > PriorEntryCount) + ? 0 + : PriorEntryCount + EntryDelta; // During inlining ? if (VMap) { - uint64_t cloneEntryCount = priorEntryCount - newEntryCount; + uint64_t CloneEntryCount = PriorEntryCount - NewEntryCount; for (auto Entry : *VMap) if (isa<CallInst>(Entry.first)) if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second)) - CI->updateProfWeight(cloneEntryCount, priorEntryCount); + CI->updateProfWeight(CloneEntryCount, PriorEntryCount); } - if (entryDelta) { - Callee->setEntryCount(newEntryCount); + if (EntryDelta) { + Callee->setEntryCount(NewEntryCount); for (BasicBlock &BB : *Callee) // No need to update the callsite if it is pruned during inlining. if (!VMap || VMap->count(&BB)) for (Instruction &I : BB) if (CallInst *CI = dyn_cast<CallInst>(&I)) - CI->updateProfWeight(newEntryCount, priorEntryCount); + CI->updateProfWeight(NewEntryCount, PriorEntryCount); } } @@ -1672,66 +1663,69 @@ void llvm::updateProfileCallee( /// 3. Otherwise, a call to objc_retain is inserted if the call in the caller is /// a retainRV call. 
static void -inlineRetainOrClaimRVCalls(CallBase &CB, +inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind, const SmallVectorImpl<ReturnInst *> &Returns) { Module *Mod = CB.getModule(); - bool IsRetainRV = objcarc::hasAttachedCallOpBundle(&CB, true), + assert(objcarc::isRetainOrClaimRV(RVCallKind) && "unexpected ARC function"); + bool IsRetainRV = RVCallKind == objcarc::ARCInstKind::RetainRV, IsClaimRV = !IsRetainRV; for (auto *RI : Returns) { Value *RetOpnd = objcarc::GetRCIdentityRoot(RI->getOperand(0)); - BasicBlock::reverse_iterator I = ++(RI->getIterator().getReverse()); - BasicBlock::reverse_iterator EI = RI->getParent()->rend(); bool InsertRetainCall = IsRetainRV; IRBuilder<> Builder(RI->getContext()); // Walk backwards through the basic block looking for either a matching // autoreleaseRV call or an unannotated call. - for (; I != EI;) { - auto CurI = I++; - + auto InstRange = llvm::make_range(++(RI->getIterator().getReverse()), + RI->getParent()->rend()); + for (Instruction &I : llvm::make_early_inc_range(InstRange)) { // Ignore casts. - if (isa<CastInst>(*CurI)) + if (isa<CastInst>(I)) continue; - if (auto *II = dyn_cast<IntrinsicInst>(&*CurI)) { - if (II->getIntrinsicID() == Intrinsic::objc_autoreleaseReturnValue && - II->hasNUses(0) && - objcarc::GetRCIdentityRoot(II->getOperand(0)) == RetOpnd) { - // If we've found a matching authoreleaseRV call: - // - If claimRV is attached to the call, insert a call to objc_release - // and erase the autoreleaseRV call. - // - If retainRV is attached to the call, just erase the autoreleaseRV - // call. - if (IsClaimRV) { - Builder.SetInsertPoint(II); - Function *IFn = - Intrinsic::getDeclaration(Mod, Intrinsic::objc_release); - Value *BC = - Builder.CreateBitCast(RetOpnd, IFn->getArg(0)->getType()); - Builder.CreateCall(IFn, BC, ""); - } - II->eraseFromParent(); - InsertRetainCall = false; - } - } else if (auto *CI = dyn_cast<CallInst>(&*CurI)) { - if (objcarc::GetRCIdentityRoot(CI) == RetOpnd && - !objcarc::hasAttachedCallOpBundle(CI)) { - // If we've found an unannotated call that defines RetOpnd, add a - // "clang.arc.attachedcall" operand bundle. - Value *BundleArgs[] = {ConstantInt::get( - Builder.getInt64Ty(), - objcarc::getAttachedCallOperandBundleEnum(IsRetainRV))}; - OperandBundleDef OB("clang.arc.attachedcall", BundleArgs); - auto *NewCall = CallBase::addOperandBundle( - CI, LLVMContext::OB_clang_arc_attachedcall, OB, CI); - NewCall->copyMetadata(*CI); - CI->replaceAllUsesWith(NewCall); - CI->eraseFromParent(); - InsertRetainCall = false; + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { + if (II->getIntrinsicID() != Intrinsic::objc_autoreleaseReturnValue || + !II->hasNUses(0) || + objcarc::GetRCIdentityRoot(II->getOperand(0)) != RetOpnd) + break; + + // If we've found a matching authoreleaseRV call: + // - If claimRV is attached to the call, insert a call to objc_release + // and erase the autoreleaseRV call. + // - If retainRV is attached to the call, just erase the autoreleaseRV + // call. 
+ if (IsClaimRV) { + Builder.SetInsertPoint(II); + Function *IFn = + Intrinsic::getDeclaration(Mod, Intrinsic::objc_release); + Value *BC = Builder.CreateBitCast(RetOpnd, IFn->getArg(0)->getType()); + Builder.CreateCall(IFn, BC, ""); } + II->eraseFromParent(); + InsertRetainCall = false; + break; } + auto *CI = dyn_cast<CallInst>(&I); + + if (!CI) + break; + + if (objcarc::GetRCIdentityRoot(CI) != RetOpnd || + objcarc::hasAttachedCallOpBundle(CI)) + break; + + // If we've found an unannotated call that defines RetOpnd, add a + // "clang.arc.attachedcall" operand bundle. + Value *BundleArgs[] = {*objcarc::getAttachedARCFunction(&CB)}; + OperandBundleDef OB("clang.arc.attachedcall", BundleArgs); + auto *NewCall = CallBase::addOperandBundle( + CI, LLVMContext::OB_clang_arc_attachedcall, OB, CI); + NewCall->copyMetadata(*CI); + CI->replaceAllUsesWith(NewCall); + CI->eraseFromParent(); + InsertRetainCall = false; break; } @@ -1895,8 +1889,13 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, { // Scope to destroy VMap after cloning. ValueToValueMapTy VMap; + struct ByValInit { + Value *Dst; + Value *Src; + Type *Ty; + }; // Keep a list of pair (dst, src) to emit byval initializations. - SmallVector<std::pair<Value*, Value*>, 4> ByValInit; + SmallVector<ByValInit, 4> ByValInits; // When inlining a function that contains noalias scope metadata, // this metadata needs to be cloned so that the inlined blocks @@ -1921,10 +1920,12 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // or readnone, because the copy would be unneeded: the callee doesn't // modify the struct. if (CB.isByValArgument(ArgNo)) { - ActualArg = HandleByValArgument(ActualArg, &CB, CalledFunc, IFI, + ActualArg = HandleByValArgument(CB.getParamByValType(ArgNo), ActualArg, + &CB, CalledFunc, IFI, CalledFunc->getParamAlignment(ArgNo)); if (ActualArg != *AI) - ByValInit.push_back(std::make_pair(ActualArg, (Value*) *AI)); + ByValInits.push_back( + {ActualArg, (Value *)*AI, CB.getParamByValType(ArgNo)}); } VMap[&*I] = ActualArg; @@ -1953,8 +1954,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, FirstNewBlock = LastBlock; ++FirstNewBlock; // Insert retainRV/clainRV runtime calls. - if (objcarc::hasAttachedCallOpBundle(&CB)) - inlineRetainOrClaimRVCalls(CB, Returns); + objcarc::ARCInstKind RVCallKind = objcarc::getAttachedARCFunctionKind(&CB); + if (RVCallKind != objcarc::ARCInstKind::None) + inlineRetainOrClaimRVCalls(CB, RVCallKind, Returns); // Updated caller/callee profiles only when requested. For sample loader // inlining, the context-sensitive inlinee profile doesn't need to be @@ -1966,13 +1968,14 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI, CalledFunc->front()); - updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), CB, - IFI.PSI, IFI.CallerBFI); + if (auto Profile = CalledFunc->getEntryCount()) + updateCallProfile(CalledFunc, VMap, *Profile, CB, IFI.PSI, + IFI.CallerBFI); } // Inject byval arguments initialization. 
- for (std::pair<Value*, Value*> &Init : ByValInit) - HandleByValArgumentInit(Init.first, Init.second, Caller->getParent(), + for (ByValInit &Init : ByValInits) + HandleByValArgumentInit(Init.Ty, Init.Dst, Init.Src, Caller->getParent(), &*FirstNewBlock, IFI); Optional<OperandBundleUse> ParentDeopt = @@ -2100,9 +2103,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, SmallVector<Value*,4> VarArgsToForward; SmallVector<AttributeSet, 4> VarArgsAttrs; for (unsigned i = CalledFunc->getFunctionType()->getNumParams(); - i < CB.getNumArgOperands(); i++) { + i < CB.arg_size(); i++) { VarArgsToForward.push_back(CB.getArgOperand(i)); - VarArgsAttrs.push_back(CB.getAttributes().getParamAttributes(i)); + VarArgsAttrs.push_back(CB.getAttributes().getParamAttrs(i)); } bool InlinedMustTailCalls = false, InlinedDeoptimizeCalls = false; @@ -2117,8 +2120,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E; ++BB) { - for (auto II = BB->begin(); II != BB->end();) { - Instruction &I = *II++; + for (Instruction &I : llvm::make_early_inc_range(*BB)) { CallInst *CI = dyn_cast<CallInst>(&I); if (!CI) continue; @@ -2135,15 +2137,15 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (!Attrs.isEmpty() || !VarArgsAttrs.empty()) { for (unsigned ArgNo = 0; ArgNo < CI->getFunctionType()->getNumParams(); ++ArgNo) - ArgAttrs.push_back(Attrs.getParamAttributes(ArgNo)); + ArgAttrs.push_back(Attrs.getParamAttrs(ArgNo)); } // Add VarArg attributes. ArgAttrs.append(VarArgsAttrs.begin(), VarArgsAttrs.end()); - Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttributes(), - Attrs.getRetAttributes(), ArgAttrs); + Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttrs(), + Attrs.getRetAttrs(), ArgAttrs); // Add VarArgs to existing parameters. - SmallVector<Value *, 6> Params(CI->arg_operands()); + SmallVector<Value *, 6> Params(CI->args()); Params.append(VarArgsToForward.begin(), VarArgsToForward.end()); CallInst *NewCI = CallInst::Create( CI->getFunctionType(), CI->getCalledOperand(), Params, "", CI); @@ -2295,8 +2297,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, BB != E; ++BB) { // Add bundle operands to any top-level call sites. SmallVector<OperandBundleDef, 1> OpBundles; - for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E;) { - CallBase *I = dyn_cast<CallBase>(&*BBI++); + for (Instruction &II : llvm::make_early_inc_range(*BB)) { + CallBase *I = dyn_cast<CallBase>(&II); if (!I) continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp index 277fd903e9aa..668626fef933 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -309,7 +309,7 @@ static void computeBlocksDominatingExits( // worklist, unless we visited it already. BasicBlock *IDomBB = DT.getNode(BB)->getIDom()->getBlock(); - // Exit blocks can have an immediate dominator not beloinging to the + // Exit blocks can have an immediate dominator not belonging to the // loop. For an exit block to be immediately dominated by another block // outside the loop, it implies not all paths from that dominator, to the // exit block, go through the loop. 
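The surrounding hunks repeatedly migrate call-site and attribute accessors to the newer, index-free forms: getNumArgOperands()/arg_operands() become arg_size()/args(), getParamAttributes()/getFnAttributes()/getRetAttributes() become getParamAttrs()/getFnAttrs()/getRetAttrs(), and removeAttribute(AttributeList::FunctionIndex, ...) becomes removeFnAttr(). A minimal sketch of the new-style usage, with the old spellings noted in comments; the wrapper function and the attribute kind chosen here are illustrative only, not part of the change:

#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"

using namespace llvm;

static void inspectCallSite(CallBase &CB, Function &F) {
  // Old: CB.getNumArgOperands() and CB.arg_operands().
  unsigned NumArgs = CB.arg_size();
  for (Value *Arg : CB.args())
    (void)Arg;

  // Old: AL.getParamAttributes(I), AL.getFnAttributes(), AL.getRetAttributes().
  AttributeList AL = CB.getAttributes();
  for (unsigned I = 0; I != NumArgs; ++I)
    (void)AL.getParamAttrs(I);
  (void)AL.getFnAttrs();
  (void)AL.getRetAttrs();

  // Old: F.removeAttribute(AttributeList::FunctionIndex, <kind>).
  F.removeFnAttr(Attribute::NoInline);
}

The CallInst-level counterpart appears later in Local.cpp, where the per-index addAttribute(AttributeList::FunctionIndex, Attribute::NoBuiltin) becomes addFnAttr(Attribute::NoBuiltin).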
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 7e5832148bc0..6958a89f5be6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -304,7 +304,7 @@ void LibCallsShrinkWrap::checkCandidate(CallInst &CI) { if (!TLI.getLibFunc(*Callee, Func) || !TLI.has(Func)) return; - if (CI.getNumArgOperands() == 0) + if (CI.arg_empty()) return; // TODO: Handle long double in other formats. Type *ArgType = CI.getArgOperand(0)->getType(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp index d03d76f57ca1..74ab37fadf36 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp @@ -1413,8 +1413,6 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { if (auto *AI = dyn_cast_or_null<AllocaInst>(DII->getVariableLocationOp(0))) { if (Optional<TypeSize> FragmentSize = AI->getAllocationSizeInBits(DL)) { - assert(ValueSize.isScalable() == FragmentSize->isScalable() && - "Both sizes should agree on the scalable flag."); return TypeSize::isKnownGE(ValueSize, *FragmentSize); } } @@ -1733,9 +1731,11 @@ void llvm::salvageDebugInfo(Instruction &I) { void llvm::salvageDebugInfoForDbgValues( Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers) { - // This is an arbitrary chosen limit on the maximum number of values we can - // salvage up to in a DIArgList, used for performance reasons. + // These are arbitrary chosen limits on the maximum number of values and the + // maximum size of a debug expression we can salvage up to, used for + // performance reasons. const unsigned MaxDebugArgs = 16; + const unsigned MaxExpressionSize = 128; bool Salvaged = false; for (auto *DII : DbgUsers) { @@ -1752,23 +1752,30 @@ void llvm::salvageDebugInfoForDbgValues( // must be updated in the DIExpression and potentially have additional // values added; thus we call salvageDebugInfoImpl for each `I` instance in // DIILocation. + Value *Op0 = nullptr; DIExpression *SalvagedExpr = DII->getExpression(); auto LocItr = find(DIILocation, &I); while (SalvagedExpr && LocItr != DIILocation.end()) { + SmallVector<uint64_t, 16> Ops; unsigned LocNo = std::distance(DIILocation.begin(), LocItr); - SalvagedExpr = salvageDebugInfoImpl(I, SalvagedExpr, StackValue, LocNo, - AdditionalValues); + uint64_t CurrentLocOps = SalvagedExpr->getNumLocationOperands(); + Op0 = salvageDebugInfoImpl(I, CurrentLocOps, Ops, AdditionalValues); + if (!Op0) + break; + SalvagedExpr = + DIExpression::appendOpsToArg(SalvagedExpr, Ops, LocNo, StackValue); LocItr = std::find(++LocItr, DIILocation.end(), &I); } // salvageDebugInfoImpl should fail on examining the first element of // DbgUsers, or none of them. 
- if (!SalvagedExpr) + if (!Op0) break; - DII->replaceVariableLocationOp(&I, I.getOperand(0)); - if (AdditionalValues.empty()) { + DII->replaceVariableLocationOp(&I, Op0); + bool IsValidSalvageExpr = SalvagedExpr->getNumElements() <= MaxExpressionSize; + if (AdditionalValues.empty() && IsValidSalvageExpr) { DII->setExpression(SalvagedExpr); - } else if (isa<DbgValueInst>(DII) && + } else if (isa<DbgValueInst>(DII) && IsValidSalvageExpr && DII->getNumVariableLocationOps() + AdditionalValues.size() <= MaxDebugArgs) { DII->addVariableLocationOps(AdditionalValues, SalvagedExpr); @@ -1793,16 +1800,16 @@ void llvm::salvageDebugInfoForDbgValues( } } -bool getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL, - uint64_t CurrentLocOps, - SmallVectorImpl<uint64_t> &Opcodes, - SmallVectorImpl<Value *> &AdditionalValues) { +Value *getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL, + uint64_t CurrentLocOps, + SmallVectorImpl<uint64_t> &Opcodes, + SmallVectorImpl<Value *> &AdditionalValues) { unsigned BitWidth = DL.getIndexSizeInBits(GEP->getPointerAddressSpace()); // Rewrite a GEP into a DIExpression. MapVector<Value *, APInt> VariableOffsets; APInt ConstantOffset(BitWidth, 0); if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) - return false; + return nullptr; if (!VariableOffsets.empty() && !CurrentLocOps) { Opcodes.insert(Opcodes.begin(), {dwarf::DW_OP_LLVM_arg, 0}); CurrentLocOps = 1; @@ -1816,7 +1823,7 @@ bool getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL, dwarf::DW_OP_plus}); } DIExpression::appendOffset(Opcodes, ConstantOffset.getSExtValue()); - return true; + return GEP->getOperand(0); } uint64_t getDwarfOpForBinOp(Instruction::BinaryOps Opcode) { @@ -1849,14 +1856,14 @@ uint64_t getDwarfOpForBinOp(Instruction::BinaryOps Opcode) { } } -bool getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps, - SmallVectorImpl<uint64_t> &Opcodes, - SmallVectorImpl<Value *> &AdditionalValues) { +Value *getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps, + SmallVectorImpl<uint64_t> &Opcodes, + SmallVectorImpl<Value *> &AdditionalValues) { // Handle binary operations with constant integer operands as a special case. auto *ConstInt = dyn_cast<ConstantInt>(BI->getOperand(1)); // Values wider than 64 bits cannot be represented within a DIExpression. if (ConstInt && ConstInt->getBitWidth() > 64) - return false; + return nullptr; Instruction::BinaryOps BinOpcode = BI->getOpcode(); // Push any Constant Int operand onto the expression stack. @@ -1867,7 +1874,7 @@ bool getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps, if (BinOpcode == Instruction::Add || BinOpcode == Instruction::Sub) { uint64_t Offset = BinOpcode == Instruction::Add ? Val : -int64_t(Val); DIExpression::appendOffset(Opcodes, Offset); - return true; + return BI->getOperand(0); } Opcodes.append({dwarf::DW_OP_constu, Val}); } else { @@ -1883,62 +1890,51 @@ bool getSalvageOpsForBinOp(BinaryOperator *BI, uint64_t CurrentLocOps, // representation in a DIExpression. 
uint64_t DwarfBinOp = getDwarfOpForBinOp(BinOpcode); if (!DwarfBinOp) - return false; + return nullptr; Opcodes.push_back(DwarfBinOp); - - return true; + return BI->getOperand(0); } -DIExpression * -llvm::salvageDebugInfoImpl(Instruction &I, DIExpression *SrcDIExpr, - bool WithStackValue, unsigned LocNo, - SmallVectorImpl<Value *> &AdditionalValues) { - uint64_t CurrentLocOps = SrcDIExpr->getNumLocationOperands(); +Value *llvm::salvageDebugInfoImpl(Instruction &I, uint64_t CurrentLocOps, + SmallVectorImpl<uint64_t> &Ops, + SmallVectorImpl<Value *> &AdditionalValues) { auto &M = *I.getModule(); auto &DL = M.getDataLayout(); - // Apply a vector of opcodes to the source DIExpression. - auto doSalvage = [&](SmallVectorImpl<uint64_t> &Ops) -> DIExpression * { - DIExpression *DIExpr = SrcDIExpr; - if (!Ops.empty()) { - DIExpr = DIExpression::appendOpsToArg(DIExpr, Ops, LocNo, WithStackValue); - } - return DIExpr; - }; - - // initializer-list helper for applying operators to the source DIExpression. - auto applyOps = [&](ArrayRef<uint64_t> Opcodes) { - SmallVector<uint64_t, 8> Ops(Opcodes.begin(), Opcodes.end()); - return doSalvage(Ops); - }; - if (auto *CI = dyn_cast<CastInst>(&I)) { + Value *FromValue = CI->getOperand(0); // No-op casts are irrelevant for debug info. - if (CI->isNoopCast(DL)) - return SrcDIExpr; + if (CI->isNoopCast(DL)) { + return FromValue; + } Type *Type = CI->getType(); + if (Type->isPointerTy()) + Type = DL.getIntPtrType(Type); // Casts other than Trunc, SExt, or ZExt to scalar types cannot be salvaged. if (Type->isVectorTy() || - !(isa<TruncInst>(&I) || isa<SExtInst>(&I) || isa<ZExtInst>(&I))) + !(isa<TruncInst>(&I) || isa<SExtInst>(&I) || isa<ZExtInst>(&I) || + isa<IntToPtrInst>(&I) || isa<PtrToIntInst>(&I))) return nullptr; - Value *FromValue = CI->getOperand(0); - unsigned FromTypeBitSize = FromValue->getType()->getScalarSizeInBits(); + llvm::Type *FromType = FromValue->getType(); + if (FromType->isPointerTy()) + FromType = DL.getIntPtrType(FromType); + + unsigned FromTypeBitSize = FromType->getScalarSizeInBits(); unsigned ToTypeBitSize = Type->getScalarSizeInBits(); - return applyOps(DIExpression::getExtOps(FromTypeBitSize, ToTypeBitSize, - isa<SExtInst>(&I))); + auto ExtOps = DIExpression::getExtOps(FromTypeBitSize, ToTypeBitSize, + isa<SExtInst>(&I)); + Ops.append(ExtOps.begin(), ExtOps.end()); + return FromValue; } - SmallVector<uint64_t, 8> Ops; - if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { - if (getSalvageOpsForGEP(GEP, DL, CurrentLocOps, Ops, AdditionalValues)) - return doSalvage(Ops); - } else if (auto *BI = dyn_cast<BinaryOperator>(&I)) { - if (getSalvageOpsForBinOp(BI, CurrentLocOps, Ops, AdditionalValues)) - return doSalvage(Ops); - } + if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) + return getSalvageOpsForGEP(GEP, DL, CurrentLocOps, Ops, AdditionalValues); + if (auto *BI = dyn_cast<BinaryOperator>(&I)) + return getSalvageOpsForBinOp(BI, CurrentLocOps, Ops, AdditionalValues); + // *Not* to do: we should not attempt to salvage load instructions, // because the validity and lifetime of a dbg.value containing // DW_OP_deref becomes difficult to analyze. See PR40628 for examples. 
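The hunk above reshapes the salvage interface: salvageDebugInfoImpl now returns the value that should replace I as a location operand (nullptr on failure) and appends the required DWARF opcodes to a caller-provided vector, leaving the caller to fold them into the DIExpression via DIExpression::appendOpsToArg. A minimal sketch of that calling pattern for a single debug use, assuming I is location operand 0 of the intrinsic and ignoring the additional-values and expression-size limits handled in salvageDebugInfoForDbgValues; the helper name is illustrative:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

static bool salvageSingleLocation(Instruction &I, DbgVariableIntrinsic &DII) {
  SmallVector<uint64_t, 16> Ops;
  SmallVector<Value *, 4> AdditionalValues;
  DIExpression *Expr = DII.getExpression();
  // The impl reports which value stands in for I and which ops to append.
  Value *NewOp = salvageDebugInfoImpl(I, Expr->getNumLocationOperands(), Ops,
                                      AdditionalValues);
  if (!NewOp || !AdditionalValues.empty())
    return false;
  // Fold the collected opcodes into the expression at location operand 0
  // (the surrounding code chooses StackValue per intrinsic kind) and point
  // the intrinsic at the salvaged value.
  Expr = DIExpression::appendOpsToArg(Expr, Ops, /*ArgNo=*/0,
                                      /*StackValue=*/true);
  DII.replaceVariableLocationOp(&I, NewOp);
  DII.setExpression(Expr);
  return true;
}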
@@ -2194,6 +2190,26 @@ void llvm::changeToCall(InvokeInst *II, DomTreeUpdater *DTU) { DTU->applyUpdates({{DominatorTree::Delete, BB, UnwindDestBB}}); } +void llvm::createUnreachableSwitchDefault(SwitchInst *Switch, + DomTreeUpdater *DTU) { + LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); + auto *BB = Switch->getParent(); + auto *OrigDefaultBlock = Switch->getDefaultDest(); + OrigDefaultBlock->removePredecessor(BB); + BasicBlock *NewDefaultBlock = BasicBlock::Create( + BB->getContext(), BB->getName() + ".unreachabledefault", BB->getParent(), + OrigDefaultBlock); + new UnreachableInst(Switch->getContext(), NewDefaultBlock); + Switch->setDefaultDest(&*NewDefaultBlock); + if (DTU) { + SmallVector<DominatorTree::UpdateType, 2> Updates; + Updates.push_back({DominatorTree::Insert, BB, &*NewDefaultBlock}); + if (!is_contained(successors(BB), OrigDefaultBlock)) + Updates.push_back({DominatorTree::Delete, BB, &*OrigDefaultBlock}); + DTU->applyUpdates(Updates); + } +} + BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, BasicBlock *UnwindEdge, DomTreeUpdater *DTU) { @@ -2669,9 +2685,7 @@ static unsigned replaceDominatedUsesWith(Value *From, Value *To, assert(From->getType() == To->getType()); unsigned Count = 0; - for (Value::use_iterator UI = From->use_begin(), UE = From->use_end(); - UI != UE;) { - Use &U = *UI++; + for (Use &U : llvm::make_early_inc_range(From->uses())) { if (!Dominates(Root, U)) continue; U.set(To); @@ -2687,9 +2701,7 @@ unsigned llvm::replaceNonLocalUsesWith(Instruction *From, Value *To) { auto *BB = From->getParent(); unsigned Count = 0; - for (Value::use_iterator UI = From->use_begin(), UE = From->use_end(); - UI != UE;) { - Use &U = *UI++; + for (Use &U : llvm::make_early_inc_range(From->uses())) { auto *I = cast<Instruction>(U.getUser()); if (I->getParent() == BB) continue; @@ -3171,7 +3183,7 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( // Now, is the bit permutation correct for a bswap or a bitreverse? We can // only byteswap values with an even number of bytes. - APInt DemandedMask = APInt::getAllOnesValue(DemandedBW); + APInt DemandedMask = APInt::getAllOnes(DemandedBW); bool OKForBSwap = MatchBSwaps && (DemandedBW % 16) == 0; bool OKForBitReverse = MatchBitReversals; for (unsigned BitIdx = 0; @@ -3208,7 +3220,7 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( Instruction *Result = CallInst::Create(F, Provider, "rev", I); InsertedInsts.push_back(Result); - if (!DemandedMask.isAllOnesValue()) { + if (!DemandedMask.isAllOnes()) { auto *Mask = ConstantInt::get(DemandedTy, DemandedMask); Result = BinaryOperator::Create(Instruction::And, Result, Mask, "mask", I); InsertedInsts.push_back(Result); @@ -3235,7 +3247,7 @@ void llvm::maybeMarkSanitizerLibraryCallNoBuiltin( if (F && !F->hasLocalLinkage() && F->hasName() && TLI->getLibFunc(F->getName(), Func) && TLI->hasOptimizedCodeGen(Func) && !F->doesNotAccessMemory()) - CI->addAttribute(AttributeList::FunctionIndex, Attribute::NoBuiltin); + CI->addFnAttr(Attribute::NoBuiltin); } bool llvm::canReplaceOperandWithVariable(const Instruction *I, unsigned OpIdx) { @@ -3263,7 +3275,7 @@ bool llvm::canReplaceOperandWithVariable(const Instruction *I, unsigned OpIdx) { if (CB.isBundleOperand(OpIdx)) return false; - if (OpIdx < CB.getNumArgOperands()) { + if (OpIdx < CB.arg_size()) { // Some variadic intrinsics require constants in the variadic arguments, // which currently aren't markable as immarg. 
if (isa<IntrinsicInst>(CB) && diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp index cd1f6f0c78a5..f3cf42be8ba1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -73,57 +74,39 @@ static cl::opt<unsigned> UnrollForcePeelCount( "unroll-force-peel-count", cl::init(0), cl::Hidden, cl::desc("Force a peel count regardless of profiling information.")); -static cl::opt<bool> UnrollPeelMultiDeoptExit( - "unroll-peel-multi-deopt-exit", cl::init(true), cl::Hidden, - cl::desc("Allow peeling of loops with multiple deopt exits.")); - static const char *PeeledCountMetaData = "llvm.loop.peeled.count"; -// Designates that a Phi is estimated to become invariant after an "infinite" -// number of loop iterations (i.e. only may become an invariant if the loop is -// fully unrolled). -static const unsigned InfiniteIterationsToInvariance = - std::numeric_limits<unsigned>::max(); - // Check whether we are capable of peeling this loop. bool llvm::canPeel(Loop *L) { // Make sure the loop is in simplified form if (!L->isLoopSimplifyForm()) return false; - if (UnrollPeelMultiDeoptExit) { - SmallVector<BasicBlock *, 4> Exits; - L->getUniqueNonLatchExitBlocks(Exits); - - if (!Exits.empty()) { - // Latch's terminator is a conditional branch, Latch is exiting and - // all non Latch exits ends up with deoptimize. - const BasicBlock *Latch = L->getLoopLatch(); - const BranchInst *T = dyn_cast<BranchInst>(Latch->getTerminator()); - return T && T->isConditional() && L->isLoopExiting(Latch) && - all_of(Exits, [](const BasicBlock *BB) { - return BB->getTerminatingDeoptimizeCall(); - }); - } - } - - // Only peel loops that contain a single exit - if (!L->getExitingBlock() || !L->getUniqueExitBlock()) - return false; - // Don't try to peel loops where the latch is not the exiting block. // This can be an indication of two different things: // 1) The loop is not rotated. // 2) The loop contains irreducible control flow that involves the latch. const BasicBlock *Latch = L->getLoopLatch(); - if (Latch != L->getExitingBlock()) + if (!L->isLoopExiting(Latch)) return false; // Peeling is only supported if the latch is a branch. if (!isa<BranchInst>(Latch->getTerminator())) return false; - return true; + SmallVector<BasicBlock *, 4> Exits; + L->getUniqueNonLatchExitBlocks(Exits); + // The latch must either be the only exiting block or all non-latch exit + // blocks have either a deopt or unreachable terminator or compose a chain of + // blocks where the last one is either deopt or unreachable terminated. Both + // deopt and unreachable terminators are a strong indication they are not + // taken. Note that this is a profitability check, not a legality check. Also + // note that LoopPeeling currently can only update the branch weights of latch + // blocks and branch weights to blocks with deopt or unreachable do not need + // updating. 
+ return all_of(Exits, [](const BasicBlock *BB) { + return IsBlockFollowedByDeoptOrUnreachable(BB); + }); } // This function calculates the number of iterations after which the given Phi @@ -139,9 +122,9 @@ bool llvm::canPeel(Loop *L) { // %x = phi(0, %a), <-- becomes invariant starting from 3rd iteration. // %y = phi(0, 5), // %a = %y + 1. -static unsigned calculateIterationsToInvariance( +static Optional<unsigned> calculateIterationsToInvariance( PHINode *Phi, Loop *L, BasicBlock *BackEdge, - SmallDenseMap<PHINode *, unsigned> &IterationsToInvariance) { + SmallDenseMap<PHINode *, Optional<unsigned> > &IterationsToInvariance) { assert(Phi->getParent() == L->getHeader() && "Non-loop Phi should not be checked for turning into invariant."); assert(BackEdge == L->getLoopLatch() && "Wrong latch?"); @@ -154,29 +137,90 @@ static unsigned calculateIterationsToInvariance( Value *Input = Phi->getIncomingValueForBlock(BackEdge); // Place infinity to map to avoid infinite recursion for cycled Phis. Such // cycles can never stop on an invariant. - IterationsToInvariance[Phi] = InfiniteIterationsToInvariance; - unsigned ToInvariance = InfiniteIterationsToInvariance; + IterationsToInvariance[Phi] = None; + Optional<unsigned> ToInvariance = None; if (L->isLoopInvariant(Input)) ToInvariance = 1u; else if (PHINode *IncPhi = dyn_cast<PHINode>(Input)) { // Only consider Phis in header block. if (IncPhi->getParent() != L->getHeader()) - return InfiniteIterationsToInvariance; + return None; // If the input becomes an invariant after X iterations, then our Phi // becomes an invariant after X + 1 iterations. - unsigned InputToInvariance = calculateIterationsToInvariance( + auto InputToInvariance = calculateIterationsToInvariance( IncPhi, L, BackEdge, IterationsToInvariance); - if (InputToInvariance != InfiniteIterationsToInvariance) - ToInvariance = InputToInvariance + 1u; + if (InputToInvariance) + ToInvariance = *InputToInvariance + 1u; } // If we found that this Phi lies in an invariant chain, update the map. - if (ToInvariance != InfiniteIterationsToInvariance) + if (ToInvariance) IterationsToInvariance[Phi] = ToInvariance; return ToInvariance; } +// Try to find any invariant memory reads that will become dereferenceable in +// the remainder loop after peeling. The load must also be used (transitively) +// by an exit condition. Returns the number of iterations to peel off (at the +// moment either 0 or 1). +static unsigned peelToTurnInvariantLoadsDerefencebale(Loop &L, + DominatorTree &DT) { + // Skip loops with a single exiting block, because there should be no benefit + // for the heuristic below. + if (L.getExitingBlock()) + return 0; + + // All non-latch exit blocks must have an UnreachableInst terminator. + // Otherwise the heuristic below may not be profitable. + SmallVector<BasicBlock *, 4> Exits; + L.getUniqueNonLatchExitBlocks(Exits); + if (any_of(Exits, [](const BasicBlock *BB) { + return !isa<UnreachableInst>(BB->getTerminator()); + })) + return 0; + + // Now look for invariant loads that dominate the latch and are not known to + // be dereferenceable. If there are such loads and no writes, they will become + // dereferenceable in the loop if the first iteration is peeled off. Also + // collect the set of instructions controlled by such loads. Only peel if an + // exit condition uses (transitively) such a load. 
+ BasicBlock *Header = L.getHeader(); + BasicBlock *Latch = L.getLoopLatch(); + SmallPtrSet<Value *, 8> LoadUsers; + const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); + for (BasicBlock *BB : L.blocks()) { + for (Instruction &I : *BB) { + if (I.mayWriteToMemory()) + return 0; + + auto Iter = LoadUsers.find(&I); + if (Iter != LoadUsers.end()) { + for (Value *U : I.users()) + LoadUsers.insert(U); + } + // Do not look for reads in the header; they can already be hoisted + // without peeling. + if (BB == Header) + continue; + if (auto *LI = dyn_cast<LoadInst>(&I)) { + Value *Ptr = LI->getPointerOperand(); + if (DT.dominates(BB, Latch) && L.isLoopInvariant(Ptr) && + !isDereferenceablePointer(Ptr, LI->getType(), DL, LI, &DT)) + for (Value *U : I.users()) + LoadUsers.insert(U); + } + } + } + SmallVector<BasicBlock *> ExitingBlocks; + L.getExitingBlocks(ExitingBlocks); + if (any_of(ExitingBlocks, [&LoadUsers](BasicBlock *Exiting) { + return LoadUsers.contains(Exiting->getTerminator()); + })) + return 1; + return 0; +} + // Return the number of iterations to peel off that make conditions in the // body true/false. For example, if we peel 2 iterations off the loop below, // the condition i < 2 can be evaluated at compile time. @@ -292,8 +336,8 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, // Return the number of iterations we want to peel off. void llvm::computePeelCount(Loop *L, unsigned LoopSize, TargetTransformInfo::PeelingPreferences &PP, - unsigned &TripCount, ScalarEvolution &SE, - unsigned Threshold) { + unsigned &TripCount, DominatorTree &DT, + ScalarEvolution &SE, unsigned Threshold) { assert(LoopSize > 0 && "Zero loop size is not allowed!"); // Save the PP.PeelCount value set by the target in // TTI.getPeelingPreferences or by the flag -unroll-peel-count. @@ -337,7 +381,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, // First, check that we can peel at least one iteration. if (2 * LoopSize <= Threshold && UnrollPeelMaxCount > 0) { // Store the pre-calculated values here. - SmallDenseMap<PHINode *, unsigned> IterationsToInvariance; + SmallDenseMap<PHINode *, Optional<unsigned> > IterationsToInvariance; // Now go through all Phis to calculate their the number of iterations they // need to become invariants. // Start the max computation with the UP.PeelCount value set by the target @@ -347,10 +391,10 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, assert(BackEdge && "Loop is not in simplified form?"); for (auto BI = L->getHeader()->begin(); isa<PHINode>(&*BI); ++BI) { PHINode *Phi = cast<PHINode>(&*BI); - unsigned ToInvariance = calculateIterationsToInvariance( + auto ToInvariance = calculateIterationsToInvariance( Phi, L, BackEdge, IterationsToInvariance); - if (ToInvariance != InfiniteIterationsToInvariance) - DesiredPeelCount = std::max(DesiredPeelCount, ToInvariance); + if (ToInvariance) + DesiredPeelCount = std::max(DesiredPeelCount, *ToInvariance); } // Pay respect to limitations implied by loop size and the max peel count. @@ -360,6 +404,9 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, DesiredPeelCount = std::max(DesiredPeelCount, countToEliminateCompares(*L, MaxPeelCount, SE)); + if (DesiredPeelCount == 0) + DesiredPeelCount = peelToTurnInvariantLoadsDerefencebale(*L, DT); + if (DesiredPeelCount > 0) { DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount); // Consider max peel count limitation. 
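The new peelToTurnInvariantLoadsDerefencebale heuristic above peels one iteration when a loop-invariant address is loaded on every iteration but is not provably dereferenceable at the load, since executing the first iteration up front establishes that fact for the remainder loop (subject to the additional no-write, dominance, and exit-condition checks shown in the hunk). A minimal sketch of the core per-load test under those assumptions; the helper name is illustrative:

#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"

using namespace llvm;

static bool loadBecomesDereferenceableAfterPeeling(LoadInst &LI, const Loop &L,
                                                   const DominatorTree &DT) {
  const DataLayout &DL = LI.getModule()->getDataLayout();
  Value *Ptr = LI.getPointerOperand();
  // Invariant address (same location every iteration) that is not already
  // known dereferenceable at this load: peeling iteration 0 proves it.
  return L.isLoopInvariant(Ptr) &&
         !isDereferenceablePointer(Ptr, LI.getType(), DL, &LI, &DT);
}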
@@ -679,34 +726,27 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> ExitEdges; L->getExitEdges(ExitEdges); - DenseMap<BasicBlock *, BasicBlock *> ExitIDom; + // Remember dominators of blocks we might reach through exits to change them + // later. Immediate dominator of such block might change, because we add more + // routes which can lead to the exit: we can reach it from the peeled + // iterations too. + DenseMap<BasicBlock *, BasicBlock *> NonLoopBlocksIDom; if (DT) { - // We'd like to determine the idom of exit block after peeling one - // iteration. - // Let Exit is exit block. - // Let ExitingSet - is a set of predecessors of Exit block. They are exiting - // blocks. - // Let Latch' and ExitingSet' are copies after a peeling. - // We'd like to find an idom'(Exit) - idom of Exit after peeling. - // It is an evident that idom'(Exit) will be the nearest common dominator - // of ExitingSet and ExitingSet'. - // idom(Exit) is a nearest common dominator of ExitingSet. - // idom(Exit)' is a nearest common dominator of ExitingSet'. - // Taking into account that we have a single Latch, Latch' will dominate - // Header and idom(Exit). - // So the idom'(Exit) is nearest common dominator of idom(Exit)' and Latch'. - // All these basic blocks are in the same loop, so what we find is - // (nearest common dominator of idom(Exit) and Latch)'. - // In the loop below we remember nearest common dominator of idom(Exit) and - // Latch to update idom of Exit later. - assert(L->hasDedicatedExits() && "No dedicated exits?"); - for (auto Edge : ExitEdges) { - if (ExitIDom.count(Edge.second)) - continue; - BasicBlock *BB = DT->findNearestCommonDominator( - DT->getNode(Edge.second)->getIDom()->getBlock(), Latch); - assert(L->contains(BB) && "IDom is not in a loop"); - ExitIDom[Edge.second] = BB; + for (auto *BB : L->blocks()) { + auto *BBDomNode = DT->getNode(BB); + SmallVector<BasicBlock *, 16> ChildrenToUpdate; + for (auto *ChildDomNode : BBDomNode->children()) { + auto *ChildBB = ChildDomNode->getBlock(); + if (!L->contains(ChildBB)) + ChildrenToUpdate.push_back(ChildBB); + } + // The new idom of the block will be the nearest common dominator + // of all copies of the previous idom. This is equivalent to the + // nearest common dominator of the previous idom and the first latch, + // which dominates all copies of the previous idom. + BasicBlock *NewIDom = DT->findNearestCommonDominator(BB, Latch); + for (auto *ChildBB : ChildrenToUpdate) + NonLoopBlocksIDom[ChildBB] = NewIDom; } } @@ -795,13 +835,11 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, remapInstructionsInBlocks(NewBlocks, VMap); if (DT) { - // Latches of the cloned loops dominate over the loop exit, so idom of the - // latter is the first cloned loop body, as original PreHeader dominates - // the original loop body. + // Update IDoms of the blocks reachable through exits. 
if (Iter == 0) - for (auto Exit : ExitIDom) - DT->changeImmediateDominator(Exit.first, - cast<BasicBlock>(LVMap[Exit.second])); + for (auto BBIDom : NonLoopBlocksIDom) + DT->changeImmediateDominator(BBIDom.first, + cast<BasicBlock>(LVMap[BBIDom.second])); #ifdef EXPENSIVE_CHECKS assert(DT->verify(DominatorTree::VerificationLevel::Fast)); #endif diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index ff7905bed91d..c66fd7bb0588 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -103,6 +103,7 @@ static void InsertNewValueIntoMap(ValueToValueMapTy &VM, Value *K, Value *V) { static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, BasicBlock *OrigPreheader, ValueToValueMapTy &ValueMap, + ScalarEvolution *SE, SmallVectorImpl<PHINode*> *InsertedPHIs) { // Remove PHI node entries that are no longer live. BasicBlock::iterator I, E = OrigHeader->end(); @@ -125,19 +126,15 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, // The value now exits in two versions: the initial value in the preheader // and the loop "next" value in the original header. SSA.Initialize(OrigHeaderVal->getType(), OrigHeaderVal->getName()); + // Force re-computation of OrigHeaderVal, as some users now need to use the + // new PHI node. + if (SE) + SE->forgetValue(OrigHeaderVal); SSA.AddAvailableValue(OrigHeader, OrigHeaderVal); SSA.AddAvailableValue(OrigPreheader, OrigPreHeaderVal); // Visit each use of the OrigHeader instruction. - for (Value::use_iterator UI = OrigHeaderVal->use_begin(), - UE = OrigHeaderVal->use_end(); - UI != UE;) { - // Grab the use before incrementing the iterator. - Use &U = *UI; - - // Increment the iterator before removing the use from the list. - ++UI; - + for (Use &U : llvm::make_early_inc_range(OrigHeaderVal->uses())) { // SSAUpdater can't handle a non-PHI use in the same block as an // earlier def. We can easily handle those cases manually. Instruction *UserInst = cast<Instruction>(U.getUser()); @@ -399,9 +396,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { D->getExpression()}; }; SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics; - for (auto I = std::next(OrigPreheader->rbegin()), E = OrigPreheader->rend(); - I != E; ++I) { - if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&*I)) + for (Instruction &I : llvm::drop_begin(llvm::reverse(*OrigPreheader))) { + if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&I)) DbgIntrinsics.insert(makeHash(DII)); else break; @@ -563,7 +559,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { SmallVector<PHINode*, 2> InsertedPHIs; // If there were any uses of instructions in the duplicated block outside the // loop, update them, inserting PHI nodes as required - RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap, + RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap, SE, &InsertedPHIs); // Attach dbg.value intrinsics to the new phis if that phi uses a value that @@ -621,7 +617,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // one predecessor. Note that Exit could be an exit block for multiple // nested loops, causing both of the edges to now be critical and need to // be split. 
- SmallVector<BasicBlock *, 4> ExitPreds(pred_begin(Exit), pred_end(Exit)); + SmallVector<BasicBlock *, 4> ExitPreds(predecessors(Exit)); bool SplitLatchEdge = false; for (BasicBlock *ExitPred : ExitPreds) { // We only need to split loop exit edges. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp index d2fd32c98d73..d14c006c8032 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -779,8 +779,7 @@ namespace { AU.addPreserved<DependenceAnalysisWrapperPass>(); AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. AU.addPreserved<BranchProbabilityInfoWrapperPass>(); - if (EnableMSSALoopDependency) - AU.addPreserved<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); } /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees. @@ -814,12 +813,10 @@ bool LoopSimplify::runOnFunction(Function &F) { &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); MemorySSA *MSSA = nullptr; std::unique_ptr<MemorySSAUpdater> MSSAU; - if (EnableMSSALoopDependency) { - auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - if (MSSAAnalysis) { - MSSA = &MSSAAnalysis->getMSSA(); - MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); - } + auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + if (MSSAAnalysis) { + MSSA = &MSSAAnalysis->getMSSA(); + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); } bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp index a91bf7b7af13..b0c622b98d5e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -224,13 +224,12 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); SmallVector<WeakTrackingVH, 16> DeadInsts; for (BasicBlock *BB : L->getBlocks()) { - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { - Instruction *Inst = &*I++; - if (Value *V = SimplifyInstruction(Inst, {DL, nullptr, DT, AC})) - if (LI->replacementPreservesLCSSAForm(Inst, V)) - Inst->replaceAllUsesWith(V); - if (isInstructionTriviallyDead(Inst)) - DeadInsts.emplace_back(Inst); + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { + if (Value *V = SimplifyInstruction(&Inst, {DL, nullptr, DT, AC})) + if (LI->replacementPreservesLCSSAForm(&Inst, V)) + Inst.replaceAllUsesWith(V); + if (isInstructionTriviallyDead(&Inst)) + DeadInsts.emplace_back(&Inst); } // We can't do recursive deletion until we're done iterating, as we might // have a phi which (potentially indirectly) uses instructions later in @@ -515,6 +514,10 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, SmallVector<MDNode *, 6> LoopLocalNoAliasDeclScopes; identifyNoAliasScopesToClone(L->getBlocks(), LoopLocalNoAliasDeclScopes); + // We place the unrolled iterations immediately after the original loop + // latch. This is a reasonable default placement if we don't have block + // frequencies, and if we do, well the layout will be adjusted later. 
+ auto BlockInsertPt = std::next(LatchBlock->getIterator()); for (unsigned It = 1; It != ULO.Count; ++It) { SmallVector<BasicBlock *, 8> NewBlocks; SmallDenseMap<const Loop *, Loop *, 4> NewLoops; @@ -523,7 +526,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { ValueToValueMapTy VMap; BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It)); - Header->getParent()->getBasicBlockList().push_back(New); + Header->getParent()->getBasicBlockList().insert(BlockInsertPt, New); assert((*BB != Header || LI->getLoopFor(*BB) == L) && "Header should not be in a sub-loop"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 6749d3db743c..a92cb6a313d3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" @@ -35,6 +36,7 @@ #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/UnrollLoop.h" @@ -167,8 +169,11 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, // Add the branch to the exit block (around the unrolled loop) B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader); InsertPt->eraseFromParent(); - if (DT) - DT->changeImmediateDominator(OriginalLoopLatchExit, PrologExit); + if (DT) { + auto *NewDom = DT->findNearestCommonDominator(OriginalLoopLatchExit, + PrologExit); + DT->changeImmediateDominator(OriginalLoopLatchExit, NewDom); + } } /// Connect the unrolling epilog code to the original loop. @@ -215,7 +220,10 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, // PN = PHI [I, Latch] // ... // Exit: - // EpilogPN = PHI [PN, EpilogPreHeader] + // EpilogPN = PHI [PN, EpilogPreHeader], [X, Exit2], [Y, Exit2.epil] + // + // Exits from non-latch blocks point to the original exit block and the + // epilogue edges have already been added. // // There is EpilogPreHeader incoming block instead of NewExit as // NewExit was spilt 1 more time to get EpilogPreHeader. @@ -282,8 +290,10 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, // Add the branch to the exit block (around the unrolling loop) B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit); InsertPt->eraseFromParent(); - if (DT) - DT->changeImmediateDominator(Exit, NewExit); + if (DT) { + auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit); + DT->changeImmediateDominator(Exit, NewDom); + } // Split the main loop exit to maintain canonicalization guarantees. SmallVector<BasicBlock*, 4> NewExitPreds{Latch}; @@ -291,17 +301,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, PreserveLCSSA); } -/// Create a clone of the blocks in a loop and connect them together. 
-/// If CreateRemainderLoop is false, loop structure will not be cloned, -/// otherwise a new loop will be created including all cloned blocks, and the -/// iterator of it switches to count NewIter down to 0. +/// Create a clone of the blocks in a loop and connect them together. A new +/// loop will be created including all cloned blocks, and the iterator of the +/// new loop switched to count NewIter down to 0. /// The cloned blocks should be inserted between InsertTop and InsertBot. -/// If loop structure is cloned InsertTop should be new preheader, InsertBot -/// new loop exit. -/// Return the new cloned loop that is created when CreateRemainderLoop is true. +/// InsertTop should be new preheader, InsertBot new loop exit. +/// Returns the new cloned loop that is created. static Loop * -CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, - const bool UseEpilogRemainder, const bool UnrollRemainder, +CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder, + const bool UnrollRemainder, BasicBlock *InsertTop, BasicBlock *InsertBot, BasicBlock *Preheader, std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, @@ -315,8 +323,6 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, Loop *ParentLoop = L->getParentLoop(); NewLoopsMap NewLoops; NewLoops[ParentLoop] = ParentLoop; - if (!CreateRemainderLoop) - NewLoops[L] = ParentLoop; // For each block in the original loop, create a new copy, // and update the value map with the newly created values. @@ -324,11 +330,7 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." + suffix, F); NewBlocks.push_back(NewBB); - // If we're unrolling the outermost loop, there's no remainder loop, - // and this block isn't in a nested loop, then the new block is not - // in any loop. Otherwise, add it to loopinfo. - if (CreateRemainderLoop || LI->getLoopFor(*BB) != L || ParentLoop) - addClonedBlockToLoopInfo(*BB, NewBB, LI, NewLoops); + addClonedBlockToLoopInfo(*BB, NewBB, LI, NewLoops); VMap[*BB] = NewBB; if (Header == *BB) { @@ -349,27 +351,24 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, } if (Latch == *BB) { - // For the last block, if CreateRemainderLoop is false, create a direct - // jump to InsertBot. If not, create a loop back to cloned head. + // For the last block, create a loop back to cloned head. VMap.erase((*BB)->getTerminator()); + // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count. + // Subtle: NewIter can be 0 if we wrapped when computing the trip count, + // thus we must compare the post-increment (wrapping) value. 
BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]); BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator()); IRBuilder<> Builder(LatchBR); - if (!CreateRemainderLoop) { - Builder.CreateBr(InsertBot); - } else { - PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, - suffix + ".iter", - FirstLoopBB->getFirstNonPHI()); - Value *IdxSub = - Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), - NewIdx->getName() + ".sub"); - Value *IdxCmp = - Builder.CreateIsNotNull(IdxSub, NewIdx->getName() + ".cmp"); - Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot); - NewIdx->addIncoming(NewIter, InsertTop); - NewIdx->addIncoming(IdxSub, NewBB); - } + PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, + suffix + ".iter", + FirstLoopBB->getFirstNonPHI()); + auto *Zero = ConstantInt::get(NewIdx->getType(), 0); + auto *One = ConstantInt::get(NewIdx->getType(), 1); + Value *IdxNext = Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next"); + Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp"); + Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot); + NewIdx->addIncoming(Zero, InsertTop); + NewIdx->addIncoming(IdxNext, NewBB); LatchBR->eraseFromParent(); } } @@ -378,99 +377,45 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, // cloned loop. for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { PHINode *NewPHI = cast<PHINode>(VMap[&*I]); - if (!CreateRemainderLoop) { - if (UseEpilogRemainder) { - unsigned idx = NewPHI->getBasicBlockIndex(Preheader); - NewPHI->setIncomingBlock(idx, InsertTop); - NewPHI->removeIncomingValue(Latch, false); - } else { - VMap[&*I] = NewPHI->getIncomingValueForBlock(Preheader); - cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); - } - } else { - unsigned idx = NewPHI->getBasicBlockIndex(Preheader); - NewPHI->setIncomingBlock(idx, InsertTop); - BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); - idx = NewPHI->getBasicBlockIndex(Latch); - Value *InVal = NewPHI->getIncomingValue(idx); - NewPHI->setIncomingBlock(idx, NewLatch); - if (Value *V = VMap.lookup(InVal)) - NewPHI->setIncomingValue(idx, V); - } - } - if (CreateRemainderLoop) { - Loop *NewLoop = NewLoops[L]; - assert(NewLoop && "L should have been cloned"); - MDNode *LoopID = NewLoop->getLoopID(); - - // Only add loop metadata if the loop is not going to be completely - // unrolled. - if (UnrollRemainder) - return NewLoop; - - Optional<MDNode *> NewLoopID = makeFollowupLoopID( - LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); - if (NewLoopID.hasValue()) { - NewLoop->setLoopID(NewLoopID.getValue()); - - // Do not setLoopAlreadyUnrolled if loop attributes have been defined - // explicitly. - return NewLoop; - } - - // Add unroll disable metadata to disable future unrolling for this loop. - NewLoop->setLoopAlreadyUnrolled(); - return NewLoop; + unsigned idx = NewPHI->getBasicBlockIndex(Preheader); + NewPHI->setIncomingBlock(idx, InsertTop); + BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); + idx = NewPHI->getBasicBlockIndex(Latch); + Value *InVal = NewPHI->getIncomingValue(idx); + NewPHI->setIncomingBlock(idx, NewLatch); + if (Value *V = VMap.lookup(InVal)) + NewPHI->setIncomingValue(idx, V); } - else - return nullptr; -} -/// Returns true if we can safely unroll a multi-exit/exiting loop. OtherExits -/// is populated with all the loop exit blocks other than the LatchExit block. 
-static bool canSafelyUnrollMultiExitLoop(Loop *L, BasicBlock *LatchExit, - bool PreserveLCSSA, - bool UseEpilogRemainder) { + Loop *NewLoop = NewLoops[L]; + assert(NewLoop && "L should have been cloned"); + MDNode *LoopID = NewLoop->getLoopID(); - // We currently have some correctness constrains in unrolling a multi-exit - // loop. Check for these below. + // Only add loop metadata if the loop is not going to be completely + // unrolled. + if (UnrollRemainder) + return NewLoop; - // We rely on LCSSA form being preserved when the exit blocks are transformed. - if (!PreserveLCSSA) - return false; + Optional<MDNode *> NewLoopID = makeFollowupLoopID( + LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); + if (NewLoopID.hasValue()) { + NewLoop->setLoopID(NewLoopID.getValue()); - // TODO: Support multiple exiting blocks jumping to the `LatchExit` when - // UnrollRuntimeMultiExit is true. This will need updating the logic in - // connectEpilog/connectProlog. - if (!LatchExit->getSinglePredecessor()) { - LLVM_DEBUG( - dbgs() << "Bailout for multi-exit handling when latch exit has >1 " - "predecessor.\n"); - return false; + // Do not setLoopAlreadyUnrolled if loop attributes have been defined + // explicitly. + return NewLoop; } - // FIXME: We bail out of multi-exit unrolling when epilog loop is generated - // and L is an inner loop. This is because in presence of multiple exits, the - // outer loop is incorrect: we do not add the EpilogPreheader and exit to the - // outer loop. This is automatically handled in the prolog case, so we do not - // have that bug in prolog generation. - if (UseEpilogRemainder && L->getParentLoop()) - return false; - // All constraints have been satisfied. - return true; + // Add unroll disable metadata to disable future unrolling for this loop. + NewLoop->setLoopAlreadyUnrolled(); + return NewLoop; } /// Returns true if we can profitably unroll the multi-exit loop L. Currently, /// we return true only if UnrollRuntimeMultiExit is set to true. static bool canProfitablyUnrollMultiExitLoop( Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, BasicBlock *LatchExit, - bool PreserveLCSSA, bool UseEpilogRemainder) { - -#if !defined(NDEBUG) - assert(canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA, - UseEpilogRemainder) && - "Should be safe to unroll before checking profitability!"); -#endif + bool UseEpilogRemainder) { // Priority goes to UnrollRuntimeMultiExit if it's supplied. if (UnrollRuntimeMultiExit.getNumOccurrences()) @@ -523,24 +468,56 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop, uint64_t TrueWeight, FalseWeight; BranchInst *LatchBR = cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator()); - if (LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) { - uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() - ? FalseWeight - : TrueWeight; - assert(UnrollFactor > 1); - uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight; - BasicBlock *Header = RemainderLoop->getHeader(); - BasicBlock *Latch = RemainderLoop->getLoopLatch(); - auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator()); - unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1); - MDBuilder MDB(RemainderLatchBR->getContext()); - MDNode *WeightNode = - HeaderIdx ? 
MDB.createBranchWeights(ExitWeight, BackEdgeWeight) - : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); - RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); - } + if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) + return; + uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() + ? FalseWeight + : TrueWeight; + assert(UnrollFactor > 1); + uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight; + BasicBlock *Header = RemainderLoop->getHeader(); + BasicBlock *Latch = RemainderLoop->getLoopLatch(); + auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator()); + unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1); + MDBuilder MDB(RemainderLatchBR->getContext()); + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) + : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); } +/// Calculate ModVal = (BECount + 1) % Count on the abstract integer domain +/// accounting for the possibility of unsigned overflow in the 2s complement +/// domain. Preconditions: +/// 1) TripCount = BECount + 1 (allowing overflow) +/// 2) Log2(Count) <= BitWidth(BECount) +static Value *CreateTripRemainder(IRBuilder<> &B, Value *BECount, + Value *TripCount, unsigned Count) { + // Note that TripCount is BECount + 1. + if (isPowerOf2_32(Count)) + // If the expression is zero, then either: + // 1. There are no iterations to be run in the prolog/epilog loop. + // OR + // 2. The addition computing TripCount overflowed. + // + // If (2) is true, we know that TripCount really is (1 << BEWidth) and so + // the number of iterations that remain to be run in the original loop is a + // multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth (a + // precondition of this method). + return B.CreateAnd(TripCount, Count - 1, "xtraiter"); + + // As (BECount + 1) can potentially unsigned overflow we count + // (BECount % Count) + 1 which is overflow safe as BECount % Count < Count. + Constant *CountC = ConstantInt::get(BECount->getType(), Count); + Value *ModValTmp = B.CreateURem(BECount, CountC); + Value *ModValAdd = B.CreateAdd(ModValTmp, + ConstantInt::get(ModValTmp->getType(), 1)); + // At that point (BECount % Count) + 1 could be equal to Count. + // To handle this case we need to take mod by Count one more time. + return B.CreateURem(ModValAdd, CountC, "xtraiter"); +} + + /// Insert code in the prolog/epilog code when unrolling a loop with a /// run-time trip-count. /// @@ -624,19 +601,22 @@ bool llvm::UnrollRuntimeLoopRemainder( // These are exit blocks other than the target of the latch exiting block. SmallVector<BasicBlock *, 4> OtherExits; L->getUniqueNonLatchExitBlocks(OtherExits); - bool isMultiExitUnrollingEnabled = - canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA, - UseEpilogRemainder) && - canProfitablyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, - UseEpilogRemainder); - // Support only single exit and exiting block unless multi-exit loop unrolling is enabled. - if (!isMultiExitUnrollingEnabled && - (!L->getExitingBlock() || OtherExits.size())) { - LLVM_DEBUG( - dbgs() - << "Multiple exit/exiting blocks in loop and multi-exit unrolling not " - "enabled!\n"); - return false; + // Support only single exit and exiting block unless multi-exit loop + // unrolling is enabled. + if (!L->getExitingBlock() || OtherExits.size()) { + // We rely on LCSSA form being preserved when the exit blocks are transformed. 
+ // (Note that only an off-by-default mode of the old PM disables PreserveLCCA.) + if (!PreserveLCSSA) + return false; + + if (!canProfitablyUnrollMultiExitLoop(L, OtherExits, LatchExit, + UseEpilogRemainder)) { + LLVM_DEBUG( + dbgs() + << "Multiple exit/exiting blocks in loop and multi-exit unrolling not " + "enabled!\n"); + return false; + } } // Use Scalar Evolution to compute the trip count. This allows more loops to // be unrolled than relying on induction var simplification. @@ -659,6 +639,7 @@ bool llvm::UnrollRuntimeLoopRemainder( unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth(); // Add 1 since the backedge count doesn't include the first loop iteration. + // (Note that overflow can occur, this is handled explicitly below) const SCEV *TripCountSC = SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1)); if (isa<SCEVCouldNotCompute>(TripCountSC)) { @@ -706,8 +687,7 @@ bool llvm::UnrollRuntimeLoopRemainder( NewPreHeader = SplitBlock(PreHeader, PreHeader->getTerminator(), DT, LI); NewPreHeader->setName(PreHeader->getName() + ".new"); // Split LatchExit to create phi nodes from branch above. - SmallVector<BasicBlock*, 4> Preds(predecessors(LatchExit)); - NewExit = SplitBlockPredecessors(LatchExit, Preds, ".unr-lcssa", DT, LI, + NewExit = SplitBlockPredecessors(LatchExit, {Latch}, ".unr-lcssa", DT, LI, nullptr, PreserveLCSSA); // NewExit gets its DebugLoc from LatchExit, which is not part of the // original Loop. @@ -717,6 +697,21 @@ bool llvm::UnrollRuntimeLoopRemainder( // Split NewExit to insert epilog remainder loop. EpilogPreHeader = SplitBlock(NewExit, NewExitTerminator, DT, LI); EpilogPreHeader->setName(Header->getName() + ".epil.preheader"); + + // If the latch exits from multiple level of nested loops, then + // by assumption there must be another loop exit which branches to the + // outer loop and we must adjust the loop for the newly inserted blocks + // to account for the fact that our epilogue is still in the same outer + // loop. Note that this leaves loopinfo temporarily out of sync with the + // CFG until the actual epilogue loop is inserted. + if (auto *ParentL = L->getParentLoop()) + if (LI->getLoopFor(LatchExit) != ParentL) { + LI->removeBlock(NewExit); + ParentL->addBasicBlockToLoop(NewExit, *LI); + LI->removeBlock(EpilogPreHeader); + ParentL->addBasicBlockToLoop(EpilogPreHeader, *LI); + } + } else { // If prolog remainder // Split the original preheader twice to insert prolog remainder loop @@ -751,35 +746,8 @@ bool llvm::UnrollRuntimeLoopRemainder( Value *BECount = Expander.expandCodeFor(BECountSC, BECountSC->getType(), PreHeaderBR); IRBuilder<> B(PreHeaderBR); - Value *ModVal; - // Calculate ModVal = (BECount + 1) % Count. - // Note that TripCount is BECount + 1. - if (isPowerOf2_32(Count)) { - // When Count is power of 2 we don't BECount for epilog case, however we'll - // need it for a branch around unrolling loop for prolog case. - ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter"); - // 1. There are no iterations to be run in the prolog/epilog loop. - // OR - // 2. The addition computing TripCount overflowed. - // - // If (2) is true, we know that TripCount really is (1 << BEWidth) and so - // the number of iterations that remain to be run in the original loop is a - // multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth (we - // explicitly check this above). 
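The same overflow reasoning, reduced to plain integers: the sketch below is an illustration rather than the LLVM code, with an 8-bit type standing in for the IV width so the wrap is easy to hit. It checks that both strategies used by CreateTripRemainder -- the mask for a power-of-two Count, and the (BECount % Count + 1) % Count form otherwise -- produce the mathematically correct (BECount + 1) % Count even when BECount + 1 wraps to 0.

    #include <cassert>
    #include <cstdint>

    // Plain-integer model of CreateTripRemainder: compute (BECount + 1) % Count
    // in the mathematical sense even though BECount + 1 may wrap. Count is
    // assumed to be at least 2 with Log2(Count) <= 8, matching the stated
    // preconditions.
    static unsigned tripRemainder(uint8_t BECount, unsigned Count) {
      uint8_t TripCount = BECount + 1;            // may wrap to 0
      if ((Count & (Count - 1)) == 0)             // power of two: the mask is
        return TripCount & (Count - 1);           // exact even for a wrapped 0
      // Otherwise: BECount % Count < Count, so the +1 below cannot wrap, and a
      // second modulo folds the one case where the sum equals Count.
      return (BECount % Count + 1) % Count;
    }

    int main() {
      assert(tripRemainder(9, 4) == 2);     // (9 + 1) % 4
      assert(tripRemainder(255, 4) == 0);   // real trip count is 256; 256 % 4 == 0
      assert(tripRemainder(9, 3) == 1);     // (9 + 1) % 3
      assert(tripRemainder(255, 3) == 1);   // real trip count is 256; 256 % 3 == 1
      return 0;
    }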
- } else { - // As (BECount + 1) can potentially unsigned overflow we count - // (BECount % Count) + 1 which is overflow safe as BECount % Count < Count. - Value *ModValTmp = B.CreateURem(BECount, - ConstantInt::get(BECount->getType(), - Count)); - Value *ModValAdd = B.CreateAdd(ModValTmp, - ConstantInt::get(ModValTmp->getType(), 1)); - // At that point (BECount % Count) + 1 could be equal to Count. - // To handle this case we need to take mod by Count one more time. - ModVal = B.CreateURem(ModValAdd, - ConstantInt::get(BECount->getType(), Count), - "xtraiter"); - } + Value * const ModVal = CreateTripRemainder(B, BECount, TripCount, Count); + Value *BranchVal = UseEpilogRemainder ? B.CreateICmpULT(BECount, ConstantInt::get(BECount->getType(), @@ -810,18 +778,13 @@ bool llvm::UnrollRuntimeLoopRemainder( std::vector<BasicBlock *> NewBlocks; ValueToValueMapTy VMap; - // For unroll factor 2 remainder loop will have 1 iterations. - // Do not create 1 iteration loop. - bool CreateRemainderLoop = (Count != 2); - // Clone all the basic blocks in the loop. If Count is 2, we don't clone // the loop, otherwise we create a cloned loop to execute the extra // iterations. This function adds the appropriate CFG connections. BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit; BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader; Loop *remainderLoop = CloneLoopBlocks( - L, ModVal, CreateRemainderLoop, UseEpilogRemainder, UnrollRemainder, - InsertTop, InsertBot, + L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI); // Assign the maximum possible trip count as the back edge weight for the @@ -840,36 +803,33 @@ bool llvm::UnrollRuntimeLoopRemainder( // work is to update the phi nodes in the original loop, and take in the // values from the cloned region. for (auto *BB : OtherExits) { - for (auto &II : *BB) { - - // Given we preserve LCSSA form, we know that the values used outside the - // loop will be used through these phi nodes at the exit blocks that are - // transformed below. - if (!isa<PHINode>(II)) - break; - PHINode *Phi = cast<PHINode>(&II); - unsigned oldNumOperands = Phi->getNumIncomingValues(); + // Given we preserve LCSSA form, we know that the values used outside the + // loop will be used through these phi nodes at the exit blocks that are + // transformed below. + for (PHINode &PN : BB->phis()) { + unsigned oldNumOperands = PN.getNumIncomingValues(); // Add the incoming values from the remainder code to the end of the phi // node. - for (unsigned i =0; i < oldNumOperands; i++){ - Value *newVal = VMap.lookup(Phi->getIncomingValue(i)); - // newVal can be a constant or derived from values outside the loop, and - // hence need not have a VMap value. Also, since lookup already generated - // a default "null" VMap entry for this value, we need to populate that - // VMap entry correctly, with the mapped entry being itself. - if (!newVal) { - newVal = Phi->getIncomingValue(i); - VMap[Phi->getIncomingValue(i)] = Phi->getIncomingValue(i); - } - Phi->addIncoming(newVal, - cast<BasicBlock>(VMap[Phi->getIncomingBlock(i)])); + for (unsigned i = 0; i < oldNumOperands; i++){ + auto *PredBB =PN.getIncomingBlock(i); + if (PredBB == Latch) + // The latch exit is handled seperately, see connectX + continue; + if (!L->contains(PredBB)) + // Even if we had dedicated exits, the code above inserted an + // extra branch which can reach the latch exit. 
+ continue; + + auto *V = PN.getIncomingValue(i); + if (Instruction *I = dyn_cast<Instruction>(V)) + if (L->contains(I)) + V = VMap.lookup(I); + PN.addIncoming(V, cast<BasicBlock>(VMap[PredBB])); } } #if defined(EXPENSIVE_CHECKS) && !defined(NDEBUG) for (BasicBlock *SuccBB : successors(BB)) { - assert(!(any_of(OtherExits, - [SuccBB](BasicBlock *EB) { return EB == SuccBB; }) || - SuccBB == LatchExit) && + assert(!(llvm::is_contained(OtherExits, SuccBB) || SuccBB == LatchExit) && "Breaks the definition of dedicated exits!"); } #endif @@ -931,23 +891,22 @@ bool llvm::UnrollRuntimeLoopRemainder( PreserveLCSSA); // Update counter in loop for unrolling. - // I should be multiply of Count. + // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count. + // Subtle: TestVal can be 0 if we wrapped when computing the trip count, + // thus we must compare the post-increment (wrapping) value. IRBuilder<> B2(NewPreHeader->getTerminator()); Value *TestVal = B2.CreateSub(TripCount, ModVal, "unroll_iter"); BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); - B2.SetInsertPoint(LatchBR); PHINode *NewIdx = PHINode::Create(TestVal->getType(), 2, "niter", Header->getFirstNonPHI()); - Value *IdxSub = - B2.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), - NewIdx->getName() + ".nsub"); - Value *IdxCmp; - if (LatchBR->getSuccessor(0) == Header) - IdxCmp = B2.CreateIsNotNull(IdxSub, NewIdx->getName() + ".ncmp"); - else - IdxCmp = B2.CreateIsNull(IdxSub, NewIdx->getName() + ".ncmp"); - NewIdx->addIncoming(TestVal, NewPreHeader); - NewIdx->addIncoming(IdxSub, Latch); + B2.SetInsertPoint(LatchBR); + auto *Zero = ConstantInt::get(NewIdx->getType(), 0); + auto *One = ConstantInt::get(NewIdx->getType(), 1); + Value *IdxNext = B2.CreateAdd(NewIdx, One, NewIdx->getName() + ".next"); + auto Pred = LatchBR->getSuccessor(0) == Header ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; + Value *IdxCmp = B2.CreateICmp(Pred, IdxNext, TestVal, NewIdx->getName() + ".ncmp"); + NewIdx->addIncoming(Zero, NewPreHeader); + NewIdx->addIncoming(IdxNext, Latch); LatchBR->setCondition(IdxCmp); } else { // Connect the prolog code to the original loop and update the @@ -960,12 +919,49 @@ bool llvm::UnrollRuntimeLoopRemainder( // of its parent loops, so the Scalar Evolution pass needs to be run again. SE->forgetTopmostLoop(L); - // Verify that the Dom Tree is correct. + // Verify that the Dom Tree and Loop Info are correct. #if defined(EXPENSIVE_CHECKS) && !defined(NDEBUG) - if (DT) + if (DT) { assert(DT->verify(DominatorTree::VerificationLevel::Full)); + LI->verify(*DT); + } #endif + // For unroll factor 2 remainder loop will have 1 iteration. + if (Count == 2 && DT && LI && SE) { + // TODO: This code could probably be pulled out into a helper function + // (e.g. breakLoopBackedgeAndSimplify) and reused in loop-deletion. 
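Concretely, the overall shape that runtime unrolling by two with an epilogue produces -- and why the Count == 2 remainder never needs a backedge -- looks like this at the source level. This is an illustrative sketch, not generated code; the comments borrow the xtraiter/unroll_iter names from the IR values above.

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Main loop runs TripCount - (TripCount % 2) iterations two at a time; the
    // remainder handles at most one leftover iteration, so it is straight-line.
    static int sumUnrolledByTwo(const std::vector<int> &V) {
      std::size_t N = V.size();
      std::size_t Rem = N % 2;          // xtraiter
      std::size_t Main = N - Rem;       // unroll_iter
      int Sum = 0;
      for (std::size_t I = 0; I < Main; I += 2)
        Sum += V[I] + V[I + 1];         // two unrolled copies of the body
      if (Rem)                          // epilogue guard; body runs at most once
        Sum += V[N - 1];
      return Sum;
    }

    int main() {
      assert(sumUnrolledByTwo({1, 2, 3, 4, 5}) == 15);
      assert(sumUnrolledByTwo({1, 2, 3, 4}) == 10);
      return 0;
    }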
+ BasicBlock *RemainderLatch = remainderLoop->getLoopLatch(); + assert(RemainderLatch); + SmallVector<BasicBlock*> RemainderBlocks(remainderLoop->getBlocks().begin(), + remainderLoop->getBlocks().end()); + breakLoopBackedge(remainderLoop, *DT, *SE, *LI, nullptr); + remainderLoop = nullptr; + + // Simplify loop values after breaking the backedge + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SmallVector<WeakTrackingVH, 16> DeadInsts; + for (BasicBlock *BB : RemainderBlocks) { + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { + if (Value *V = SimplifyInstruction(&Inst, {DL, nullptr, DT, AC})) + if (LI->replacementPreservesLCSSAForm(&Inst, V)) + Inst.replaceAllUsesWith(V); + if (isInstructionTriviallyDead(&Inst)) + DeadInsts.emplace_back(&Inst); + } + // We can't do recursive deletion until we're done iterating, as we might + // have a phi which (potentially indirectly) uses instructions later in + // the block we're iterating through. + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts); + } + + // Merge latch into exit block. + auto *ExitBB = RemainderLatch->getSingleSuccessor(); + assert(ExitBB && "required after breaking cond br backedge"); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + MergeBlockIntoPredecessor(ExitBB, &DTU, LI); + } + // Canonicalize to LoopSimplifyForm both original and remainder loops. We // cannot rely on the LoopUnrollPass to do this because it only does // canonicalization for parent/subloops and not the sibling loops. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp index e4d78f9ada08..f0f079335683 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -612,10 +612,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, for (auto *Block : L->blocks()) for (Instruction &I : *Block) { auto *Undef = UndefValue::get(I.getType()); - for (Value::use_iterator UI = I.use_begin(), E = I.use_end(); - UI != E;) { - Use &U = *UI; - ++UI; + for (Use &U : llvm::make_early_inc_range(I.uses())) { if (auto *Usr = dyn_cast<Instruction>(U.getUser())) if (L->contains(Usr->getParent())) continue; @@ -710,21 +707,58 @@ void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE, SE.forgetLoop(L); - // Note: By splitting the backedge, and then explicitly making it unreachable - // we gracefully handle corner cases such as non-bottom tested loops and the - // like. We also have the benefit of being able to reuse existing well tested - // code. It might be worth special casing the common bottom tested case at - // some point to avoid code churn. - std::unique_ptr<MemorySSAUpdater> MSSAU; if (MSSA) MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); - auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get()); + // Update the CFG and domtree. We chose to special case a couple of + // of common cases for code quality and test readability reasons. 
+ [&]() -> void { + if (auto *BI = dyn_cast<BranchInst>(Latch->getTerminator())) { + if (!BI->isConditional()) { + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); + (void)changeToUnreachable(BI, /*PreserveLCSSA*/ true, &DTU, + MSSAU.get()); + return; + } + + // Conditional latch/exit - note that latch can be shared by inner + // and outer loop so the other target doesn't need to an exit + if (L->isLoopExiting(Latch)) { + // TODO: Generalize ConstantFoldTerminator so that it can be used + // here without invalidating LCSSA or MemorySSA. (Tricky case for + // LCSSA: header is an exit block of a preceeding sibling loop w/o + // dedicated exits.) + const unsigned ExitIdx = L->contains(BI->getSuccessor(0)) ? 1 : 0; + BasicBlock *ExitBB = BI->getSuccessor(ExitIdx); + + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); + Header->removePredecessor(Latch, true); + + IRBuilder<> Builder(BI); + auto *NewBI = Builder.CreateBr(ExitBB); + // Transfer the metadata to the new branch instruction (minus the + // loop info since this is no longer a loop) + NewBI->copyMetadata(*BI, {LLVMContext::MD_dbg, + LLVMContext::MD_annotation}); + + BI->eraseFromParent(); + DTU.applyUpdates({{DominatorTree::Delete, Latch, Header}}); + if (MSSA) + MSSAU->applyUpdates({{DominatorTree::Delete, Latch, Header}}, DT); + return; + } + } - DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); - (void)changeToUnreachable(BackedgeBB->getTerminator(), - /*PreserveLCSSA*/ true, &DTU, MSSAU.get()); + // General case. By splitting the backedge, and then explicitly making it + // unreachable we gracefully handle corner cases such as switch and invoke + // termiantors. + auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get()); + + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); + (void)changeToUnreachable(BackedgeBB->getTerminator(), + /*PreserveLCSSA*/ true, &DTU, MSSAU.get()); + }(); // Erase (and destroy) this loop instance. Handles relinking sub-loops // and blocks within the loop as needed. 
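For context on what the new conditional-latch fast path of breakLoopBackedge amounts to at the source level, here is an illustrative C++ sketch (not the transform itself): the body is kept, and the latch's conditional branch back to the header becomes an unconditional branch to the exit. This is only meaning-preserving when the backedge is known never to be taken, which is exactly the situation in the Count == 2 remainder path above and in loop deletion.

    #include <cassert>

    static int beforeBreak(int X) {
      do {
        X += 1;                // loop body
      } while (X < 0);         // conditional latch; never taken for X >= 0
      return X;
    }

    static int afterBreak(int X) {
      X += 1;                  // body preserved
      return X;                // backedge replaced by a branch to the exit
    }

    int main() {
      assert(beforeBreak(5) == afterBreak(5));
      return 0;
    }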
@@ -852,32 +886,37 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, return true; } -Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, - Value *Right) { - CmpInst::Predicate Pred; +CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { switch (RK) { default: llvm_unreachable("Unknown min/max recurrence kind"); case RecurKind::UMin: - Pred = CmpInst::ICMP_ULT; - break; + return CmpInst::ICMP_ULT; case RecurKind::UMax: - Pred = CmpInst::ICMP_UGT; - break; + return CmpInst::ICMP_UGT; case RecurKind::SMin: - Pred = CmpInst::ICMP_SLT; - break; + return CmpInst::ICMP_SLT; case RecurKind::SMax: - Pred = CmpInst::ICMP_SGT; - break; + return CmpInst::ICMP_SGT; case RecurKind::FMin: - Pred = CmpInst::FCMP_OLT; - break; + return CmpInst::FCMP_OLT; case RecurKind::FMax: - Pred = CmpInst::FCMP_OGT; - break; + return CmpInst::FCMP_OGT; } +} +Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal, + RecurKind RK, Value *Left, Value *Right) { + if (auto VTy = dyn_cast<VectorType>(Left->getType())) + StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal); + Value *Cmp = + Builder.CreateCmp(CmpInst::ICMP_NE, Left, StartVal, "rdx.select.cmp"); + return Builder.CreateSelect(Cmp, Left, Right, "rdx.select"); +} + +Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, + Value *Right) { + CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK); Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp"); Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select"); return Select; @@ -955,15 +994,50 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); } +Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder, + const TargetTransformInfo *TTI, + Value *Src, + const RecurrenceDescriptor &Desc, + PHINode *OrigPhi) { + assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind( + Desc.getRecurrenceKind()) && + "Unexpected reduction kind"); + Value *InitVal = Desc.getRecurrenceStartValue(); + Value *NewVal = nullptr; + + // First use the original phi to determine the new value we're trying to + // select from in the loop. + SelectInst *SI = nullptr; + for (auto *U : OrigPhi->users()) { + if ((SI = dyn_cast<SelectInst>(U))) + break; + } + assert(SI && "One user of the original phi should be a select"); + + if (SI->getTrueValue() == OrigPhi) + NewVal = SI->getFalseValue(); + else { + assert(SI->getFalseValue() == OrigPhi && + "At least one input to the select should be the original Phi"); + NewVal = SI->getTrueValue(); + } + + // Create a splat vector with the new value and compare this to the vector + // we want to reduce. + ElementCount EC = cast<VectorType>(Src->getType())->getElementCount(); + Value *Right = Builder.CreateVectorSplat(EC, InitVal); + Value *Cmp = + Builder.CreateCmp(CmpInst::ICMP_NE, Src, Right, "rdx.select.cmp"); + + // If any predicate is true it means that we want to select the new value. 
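The hunk continues below with the OR-reduction and the final select; what that computes is easier to see on scalars. A minimal model follows (plain C++; InitVal and NewVal echo the names in the patch, everything else is illustrative): as read from createSelectCmpTargetReduction, the loop-carried recurrence has the form r = cond ? NewVal : r with a loop-invariant NewVal, so every vector lane's partial result is either InitVal (the condition never fired in that lane) or NewVal, and an any-of test against InitVal decides the final value.

    #include <array>
    #include <cassert>

    int main() {
      const int InitVal = -1, NewVal = 42;
      std::array<int, 4> Partial = {-1, 42, -1, -1};   // per-lane partial results

      bool AnyChanged = false;
      for (int Lane : Partial)
        AnyChanged = AnyChanged || (Lane != InitVal);  // rdx.select.cmp + or-reduce

      int Result = AnyChanged ? NewVal : InitVal;      // rdx.select
      assert(Result == NewVal);
      return 0;
    }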
+ Cmp = Builder.CreateOrReduce(Cmp); + return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select"); +} + Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, const TargetTransformInfo *TTI, Value *Src, RecurKind RdxKind, ArrayRef<Value *> RedOps) { - TargetTransformInfo::ReductionFlags RdxFlags; - RdxFlags.IsMaxOp = RdxKind == RecurKind::SMax || RdxKind == RecurKind::UMax || - RdxKind == RecurKind::FMax; - RdxFlags.IsSigned = RdxKind == RecurKind::SMax || RdxKind == RecurKind::SMin; - auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); switch (RdxKind) { case RecurKind::Add: @@ -1000,14 +1074,19 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *llvm::createTargetReduction(IRBuilderBase &B, const TargetTransformInfo *TTI, - const RecurrenceDescriptor &Desc, - Value *Src) { + const RecurrenceDescriptor &Desc, Value *Src, + PHINode *OrigPhi) { // TODO: Support in-order reductions based on the recurrence descriptor. // All ops in the reduction inherit fast-math-flags from the recurrence // descriptor. IRBuilderBase::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(Desc.getFastMathFlags()); - return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind()); + + RecurKind RK = Desc.getRecurrenceKind(); + if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) + return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi); + + return createSimpleTargetReduction(B, TTI, Src, RK); } Value *llvm::createOrderedReduction(IRBuilderBase &B, @@ -1081,58 +1160,6 @@ bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE, // As a side effect, reduces the amount of IV processing within the loop. //===----------------------------------------------------------------------===// -// Return true if the SCEV expansion generated by the rewriter can replace the -// original value. SCEV guarantees that it produces the same value, but the way -// it is produced may be illegal IR. Ideally, this function will only be -// called for verification. -static bool isValidRewrite(ScalarEvolution *SE, Value *FromVal, Value *ToVal) { - // If an SCEV expression subsumed multiple pointers, its expansion could - // reassociate the GEP changing the base pointer. This is illegal because the - // final address produced by a GEP chain must be inbounds relative to its - // underlying object. Otherwise basic alias analysis, among other things, - // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid - // producing an expression involving multiple pointers. Until then, we must - // bail out here. - // - // Retrieve the pointer operand of the GEP. Don't use getUnderlyingObject - // because it understands lcssa phis while SCEV does not. - Value *FromPtr = FromVal; - Value *ToPtr = ToVal; - if (auto *GEP = dyn_cast<GEPOperator>(FromVal)) - FromPtr = GEP->getPointerOperand(); - - if (auto *GEP = dyn_cast<GEPOperator>(ToVal)) - ToPtr = GEP->getPointerOperand(); - - if (FromPtr != FromVal || ToPtr != ToVal) { - // Quickly check the common case - if (FromPtr == ToPtr) - return true; - - // SCEV may have rewritten an expression that produces the GEP's pointer - // operand. That's ok as long as the pointer operand has the same base - // pointer. Unlike getUnderlyingObject(), getPointerBase() will find the - // base of a recurrence. This handles the case in which SCEV expansion - // converts a pointer type recurrence into a nonrecurrent pointer base - // indexed by an integer recurrence. 
- - // If the GEP base pointer is a vector of pointers, abort. - if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy()) - return false; - - const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr)); - const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr)); - if (FromBase == ToBase) - return true; - - LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: GEP rewrite bail out " - << *FromBase << " != " << *ToBase << "\n"); - - return false; - } - return true; -} - static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) { SmallPtrSet<const Instruction *, 8> Visited; SmallVector<const Instruction *, 8> WorkList; @@ -1165,9 +1192,6 @@ struct RewritePhi { Instruction *ExpansionPoint; // Where we'd like to expand that SCEV? bool HighCost; // Is this expansion a high-cost? - Value *Expansion = nullptr; - bool ValidRewrite = false; - RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt, bool H) : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt), @@ -1204,8 +1228,6 @@ static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) // phase later. Skip it in the loop invariant check below. bool found = false; for (const RewritePhi &Phi : RewritePhiSet) { - if (!Phi.ValidRewrite) - continue; unsigned i = Phi.Ith; if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) { found = true; @@ -1264,13 +1286,6 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI, if (!SE->isSCEVable(PN->getType())) continue; - // It's necessary to tell ScalarEvolution about this explicitly so that - // it can walk the def-use list and forget all SCEVs, as it may not be - // watching the PHI itself. Once the new exit value is in place, there - // may not be a def-use connection between the loop and every instruction - // which got a SCEVAddRecExpr for that loop. - SE->forgetValue(PN); - // Iterate over all of the values in all the PHI nodes. for (unsigned i = 0; i != NumPreds; ++i) { // If the value being merged in is not integer or is not defined @@ -1339,61 +1354,49 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI, } } - // Now that we've done preliminary filtering and billed all the SCEV's, - // we can perform the last sanity check - the expansion must be valid. - for (RewritePhi &Phi : RewritePhiSet) { - Phi.Expansion = Rewriter.expandCodeFor(Phi.ExpansionSCEV, Phi.PN->getType(), - Phi.ExpansionPoint); + // TODO: evaluate whether it is beneficial to change how we calculate + // high-cost: if we have SCEV 'A' which we know we will expand, should we + // calculate the cost of other SCEV's after expanding SCEV 'A', thus + // potentially giving cost bonus to those other SCEV's? - LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " - << *(Phi.Expansion) << '\n' - << " LoopVal = " << *(Phi.ExpansionPoint) << "\n"); + bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); + int NumReplaced = 0; + + // Transformation. + for (const RewritePhi &Phi : RewritePhiSet) { + PHINode *PN = Phi.PN; - // FIXME: isValidRewrite() is a hack. it should be an assert, eventually. - Phi.ValidRewrite = isValidRewrite(SE, Phi.ExpansionPoint, Phi.Expansion); - if (!Phi.ValidRewrite) { - DeadInsts.push_back(Phi.Expansion); + // Only do the rewrite when the ExitValue can be expanded cheaply. + // If LoopCanBeDel is true, rewrite exit value aggressively. 
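For readers less familiar with this utility: rewriteLoopExitValues replaces the value an LCSSA exit phi receives from the loop with an equivalent expression that SCEV can expand outside the loop. A toy illustration of the idea (plain C++, not the pass itself):

    #include <cassert>

    // The value carried out of the loop is the recurrence {0,+,2} taken N
    // times, so the exit phi can instead use the closed form 2 * N computed
    // outside the loop.
    static int loopExitValue(int N) {
      int Sum = 0;
      for (int I = 0; I < N; ++I)
        Sum += 2;                       // {0,+,2}
      return Sum;                       // value carried by the exit phi
    }

    int main() {
      for (int N = 0; N < 100; ++N)
        assert(loopExitValue(N) == 2 * N);   // the expanded replacement
      return 0;
    }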
+ if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) continue; - } + + Value *ExitVal = Rewriter.expandCodeFor( + Phi.ExpansionSCEV, Phi.PN->getType(), Phi.ExpansionPoint); + + LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " << *ExitVal + << '\n' + << " LoopVal = " << *(Phi.ExpansionPoint) << "\n"); #ifndef NDEBUG // If we reuse an instruction from a loop which is neither L nor one of // its containing loops, we end up breaking LCSSA form for this loop by // creating a new use of its instruction. - if (auto *ExitInsn = dyn_cast<Instruction>(Phi.Expansion)) + if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal)) if (auto *EVL = LI->getLoopFor(ExitInsn->getParent())) if (EVL != L) assert(EVL->contains(L) && "LCSSA breach detected!"); #endif - } - - // TODO: after isValidRewrite() is an assertion, evaluate whether - // it is beneficial to change how we calculate high-cost: - // if we have SCEV 'A' which we know we will expand, should we calculate - // the cost of other SCEV's after expanding SCEV 'A', - // thus potentially giving cost bonus to those other SCEV's? - - bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); - int NumReplaced = 0; - - // Transformation. - for (const RewritePhi &Phi : RewritePhiSet) { - if (!Phi.ValidRewrite) - continue; - - PHINode *PN = Phi.PN; - Value *ExitVal = Phi.Expansion; - - // Only do the rewrite when the ExitValue can be expanded cheaply. - // If LoopCanBeDel is true, rewrite exit value aggressively. - if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) { - DeadInsts.push_back(ExitVal); - continue; - } NumReplaced++; Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith)); PN->setIncomingValue(Phi.Ith, ExitVal); + // It's necessary to tell ScalarEvolution about this explicitly so that + // it can walk the def-use list and forget all SCEVs, as it may not be + // watching the PHI itself. Once the new exit value is in place, there + // may not be a def-use connection between the loop and every instruction + // which got a SCEVAddRecExpr for that loop. + SE->forgetValue(PN); // If this instruction is dead now, delete it. Don't do it now to avoid // invalidating iterators. @@ -1554,7 +1557,7 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, return ChecksWithBounds; } -std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks( +Value *llvm::addRuntimeChecks( Instruction *Loc, Loop *TheLoop, const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, SCEVExpander &Exp) { @@ -1563,22 +1566,10 @@ std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks( auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp); LLVMContext &Ctx = Loc->getContext(); - Instruction *FirstInst = nullptr; IRBuilder<> ChkBuilder(Loc); // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; - // FIXME: this helper is currently a duplicate of the one in - // LoopVectorize.cpp. - auto GetFirstInst = [](Instruction *FirstInst, Value *V, - Instruction *Loc) -> Instruction * { - if (FirstInst) - return FirstInst; - if (Instruction *I = dyn_cast<Instruction>(V)) - return I->getParent() == Loc->getParent() ? 
I : nullptr; - return nullptr; - }; - for (const auto &Check : ExpandedChecks) { const PointerBounds &A = Check.first, &B = Check.second; // Check if two pointers (A and B) conflict where conflict is computed as: @@ -1607,30 +1598,16 @@ std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks( // bound1 = (A.Start < B.End) // IsConflict = bound0 & bound1 Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0"); - FirstInst = GetFirstInst(FirstInst, Cmp0, Loc); Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1"); - FirstInst = GetFirstInst(FirstInst, Cmp1, Loc); Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); - FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); if (MemoryRuntimeCheck) { IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); - FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); } MemoryRuntimeCheck = IsConflict; } - if (!MemoryRuntimeCheck) - return std::make_pair(nullptr, nullptr); - - // We have to do this trickery because the IRBuilder might fold the check to a - // constant expression in which case there is no Instruction anchored in a - // the block. - Instruction *Check = - BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx)); - ChkBuilder.Insert(Check, "memcheck.conflict"); - FirstInst = GetFirstInst(FirstInst, Check, Loc); - return std::make_pair(FirstInst, Check); + return MemoryRuntimeCheck; } Optional<IVConditionInfo> llvm::hasPartialIVCondition(Loop &L, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 8a89158788cf..771b7d25b0f2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -14,9 +14,9 @@ #include "llvm/Transforms/Utils/LoopVersioning.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Dominators.h" @@ -52,8 +52,7 @@ void LoopVersioning::versionLoop( assert(VersionedLoop->isLoopSimplifyForm() && "Loop is not in loop-simplify form"); - Instruction *FirstCheckInst; - Instruction *MemRuntimeCheck; + Value *MemRuntimeCheck; Value *SCEVRuntimeCheck; Value *RuntimeCheck = nullptr; @@ -64,8 +63,8 @@ void LoopVersioning::versionLoop( SCEVExpander Exp2(*RtPtrChecking.getSE(), VersionedLoop->getHeader()->getModule()->getDataLayout(), "induction"); - std::tie(FirstCheckInst, MemRuntimeCheck) = addRuntimeChecks( - RuntimeCheckBB->getTerminator(), VersionedLoop, AliasChecks, Exp2); + MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(), + VersionedLoop, AliasChecks, Exp2); SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), "scev.check"); @@ -354,14 +353,11 @@ PreservedAnalyses LoopVersioningPass::run(Function &F, auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - MemorySSA *MSSA = EnableMSSALoopDependency - ? 
&AM.getResult<MemorySSAAnalysis>(F).getMSSA() - : nullptr; auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, MSSA}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, + TLI, TTI, nullptr, nullptr, nullptr}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 616b4e8eb01c..8dc4702993c3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -442,7 +442,7 @@ void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, /* DestAlign */ Memcpy->getDestAlign().valueOrOne(), /* SrcIsVolatile */ Memcpy->isVolatile(), /* DstIsVolatile */ Memcpy->isVolatile(), - /* TargetTransfomrInfo */ TTI); + /* TargetTransformInfo */ TTI); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp index ec8d7a7074cd..aff9d1311688 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -524,16 +524,14 @@ bool LowerSwitch(Function &F, LazyValueInfo *LVI, AssumptionCache *AC) { bool Changed = false; SmallPtrSet<BasicBlock *, 8> DeleteList; - for (Function::iterator I = F.begin(), E = F.end(); I != E;) { - BasicBlock *Cur = - &*I++; // Advance over block so we don't traverse new blocks - + // We use make_early_inc_range here so that we don't traverse new blocks. + for (BasicBlock &Cur : llvm::make_early_inc_range(F)) { // If the block is a dead Default block that will be deleted later, don't // waste time processing it. - if (DeleteList.count(Cur)) + if (DeleteList.count(&Cur)) continue; - if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { + if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur.getTerminator())) { Changed = true; ProcessSwitchInst(SI, DeleteList, AC, LVI); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 2aef37205c53..bb5ff59cba4b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -125,7 +125,7 @@ Function *llvm::createSanitizerCtor(Module &M, StringRef CtorName) { Function *Ctor = Function::createWithDefaultAttr( FunctionType::get(Type::getVoidTy(M.getContext()), false), GlobalValue::InternalLinkage, 0, CtorName, &M); - Ctor->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); + Ctor->addFnAttr(Attribute::NoUnwind); BasicBlock *CtorBB = BasicBlock::Create(M.getContext(), "", Ctor); ReturnInst::Create(M.getContext(), CtorBB); // Ensure Ctor cannot be discarded, even if in a comdat. @@ -165,7 +165,7 @@ llvm::getOrCreateSanitizerCtorAndInitFunctions( if (Function *Ctor = M.getFunction(CtorName)) // FIXME: Sink this logic into the module, similar to the handling of // globals. This will make moving to a concurrent model much easier. 
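Stepping back to the addRuntimeChecks change a few hunks above (the helper now returns the final conflict predicate as a plain Value rather than a pair of instructions): the predicate it builds is a pairwise range-overlap test, modelled here on plain integers as an illustrative sketch.

    #include <cassert>
    #include <cstdint>

    // Two half-open byte ranges [Start0, End0) and [Start1, End1) conflict iff
    // Start0 < End1 and Start1 < End0 -- the bound0/bound1 comparisons above.
    static bool foundConflict(uint64_t Start0, uint64_t End0,
                              uint64_t Start1, uint64_t End1) {
      return (Start0 < End1) && (Start1 < End0);
    }

    int main() {
      assert(!foundConflict(0, 16, 16, 32));   // adjacent ranges: no conflict
      assert(foundConflict(0, 16, 8, 24));     // overlapping ranges: conflict
      return 0;
    }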
- if (Ctor->arg_size() == 0 || + if (Ctor->arg_empty() || Ctor->getReturnType() == Type::getVoidTy(M.getContext())) return {Ctor, declareSanitizerInitFunction(M, InitName, InitArgTypes)}; @@ -297,7 +297,6 @@ void VFABI::setVectorVariantNames( "vector function declaration is missing."); } #endif - CI->addAttribute( - AttributeList::FunctionIndex, + CI->addFnAttr( Attribute::get(M->getContext(), MappingsAttrName, Buffer.str())); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 427028066026..b35ab57e0d87 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -70,7 +70,8 @@ bool llvm::isAllocaPromotable(const AllocaInst *AI) { if (LI->isVolatile()) return false; } else if (const StoreInst *SI = dyn_cast<StoreInst>(U)) { - if (SI->getOperand(0) == AI) + if (SI->getValueOperand() == AI || + SI->getValueOperand()->getType() != AI->getAllocatedType()) return false; // Don't allow a store OF the AI, only INTO the AI. // Note that atomic stores can be transformed; atomic semantics do // not have any meaning for a local alloca. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp index 3127432dc6c9..65207056a3f4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp @@ -181,9 +181,7 @@ static bool convertToRelativeLookupTables( bool Changed = false; - for (auto GVI = M.global_begin(), E = M.global_end(); GVI != E;) { - GlobalVariable &GV = *GVI++; - + for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) { if (!shouldConvertToRelLookupTable(M, GV)) continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 4cf99abcc10f..d7e8eaf677c6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -540,8 +540,14 @@ void SCCPInstVisitor::markArgInFuncSpecialization(Function *F, Argument *A, E = F->arg_end(); I != E; ++I, ++J) if (J != A && ValueState.count(I)) { - ValueState[J] = ValueState[I]; - pushToWorkList(ValueState[J], J); + // Note: This previously looked like this: + // ValueState[J] = ValueState[I]; + // This is incorrect because the DenseMap class may resize the underlying + // memory when inserting `J`, which will invalidate the reference to `I`. + // Instead, we make sure `J` exists, then set it to `I` afterwards. + auto &NewValue = ValueState[J]; + NewValue = ValueState[I]; + pushToWorkList(NewValue, J); } } @@ -802,6 +808,9 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { return; ValueLatticeElement OpSt = getValueState(I.getOperand(0)); + if (OpSt.isUnknownOrUndef()) + return; + if (Constant *OpC = getConstant(OpSt)) { // Fold the constant as we build. 
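The markArgInFuncSpecialization fix above is an instance of a general C++ hazard: a reference obtained from a container can dangle once a later insertion makes the container grow (DenseMap, like std::vector below, does not give the reference stability that std::map does). A minimal illustration of the bug shape and of the fixed ordering used in the patch:

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> State(3, 7);

      // Buggy ordering (do not do this): Src may dangle after the insertion.
      //   int &Src = State[0];
      //   State.push_back(0);     // may reallocate and invalidate Src
      //   State.back() = Src;     // potential use-after-free

      // Fixed ordering, mirroring the patch: create the destination slot
      // first, then read the source through a fresh access.
      State.push_back(0);
      int &Dst = State.back();
      Dst = State[0];
      assert(Dst == 7);
      return 0;
    }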
Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL); @@ -809,9 +818,14 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { return; // Propagate constant value markConstant(&I, C); - } else if (OpSt.isConstantRange() && I.getDestTy()->isIntegerTy()) { + } else if (I.getDestTy()->isIntegerTy()) { auto &LV = getValueState(&I); - ConstantRange OpRange = OpSt.getConstantRange(); + ConstantRange OpRange = + OpSt.isConstantRange() + ? OpSt.getConstantRange() + : ConstantRange::getFull( + I.getOperand(0)->getType()->getScalarSizeInBits()); + Type *DestTy = I.getDestTy(); // Vectors where all elements have the same known constant range are treated // as a single constant range in the lattice. When bitcasting such vectors, @@ -826,7 +840,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { ConstantRange Res = OpRange.castOp(I.getOpcode(), DL.getTypeSizeInBits(DestTy)); mergeInValue(LV, &I, ValueLatticeElement::getRange(Res)); - } else if (!OpSt.isUnknownOrUndef()) + } else markOverdefined(&I); } @@ -1183,10 +1197,10 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { // a declaration, maybe we can constant fold it. if (F && F->isDeclaration() && canConstantFoldCallTo(&CB, F)) { SmallVector<Constant *, 8> Operands; - for (auto AI = CB.arg_begin(), E = CB.arg_end(); AI != E; ++AI) { - if (AI->get()->getType()->isStructTy()) + for (const Use &A : CB.args()) { + if (A.get()->getType()->isStructTy()) return markOverdefined(&CB); // Can't handle struct args. - ValueLatticeElement State = getValueState(*AI); + ValueLatticeElement State = getValueState(A); if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp index 917d5e0a1ef0..7de76b86817b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp @@ -65,12 +65,6 @@ void SSAUpdaterBulk::AddUse(unsigned Var, Use *U) { Rewrites[Var].Uses.push_back(U); } -/// Return true if the SSAUpdater already has a value for the specified variable -/// in the specified block. -bool SSAUpdaterBulk::HasValueForBlock(unsigned Var, BasicBlock *BB) { - return (Var < Rewrites.size()) ? Rewrites[Var].Defines.count(BB) : false; -} - // Compute value at the given block BB. We either should already know it, or we // should be able to recursively reach it going up dominator tree. Value *SSAUpdaterBulk::computeValueAt(BasicBlock *BB, RewriteInfo &R, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 3978e1e29825..a042146d7ace 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -747,9 +747,8 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { // so that pointer operands are inserted first, which the code below relies on // to form more involved GEPs. SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops; - for (std::reverse_iterator<SCEVAddExpr::op_iterator> I(S->op_end()), - E(S->op_begin()); I != E; ++I) - OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I)); + for (const SCEV *Op : reverse(S->operands())) + OpsAndLoops.push_back(std::make_pair(getRelevantLoop(Op), Op)); // Sort by loop. 
Use a stable sort so that constants follow non-constants and // pointer operands precede non-pointer operands. @@ -765,7 +764,11 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { // This is the first operand. Just expand it. Sum = expand(Op); ++I; - } else if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) { + continue; + } + + assert(!Op->getType()->isPointerTy() && "Only first op can be pointer"); + if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) { // The running sum expression is a pointer. Try to form a getelementptr // at this level with that as the base. SmallVector<const SCEV *, 4> NewOps; @@ -779,16 +782,6 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { NewOps.push_back(X); } Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum); - } else if (PointerType *PTy = dyn_cast<PointerType>(Op->getType())) { - // The running sum is an integer, and there's a pointer at this level. - // Try to form a getelementptr. If the running sum is instructions, - // use a SCEVUnknown to avoid re-analyzing them. - SmallVector<const SCEV *, 4> NewOps; - NewOps.push_back(isa<Instruction>(Sum) ? SE.getUnknown(Sum) : - SE.getSCEV(Sum)); - for (++I; I != E && I->first == CurLoop; ++I) - NewOps.push_back(I->second); - Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op)); } else if (Op->isNonConstantNegative()) { // Instead of doing a negate and add, just do a subtract. Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty, false); @@ -817,9 +810,8 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { // Collect all the mul operands in a loop, along with their associated loops. // Iterate in reverse so that constants are emitted last, all else equal. SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops; - for (std::reverse_iterator<SCEVMulExpr::op_iterator> I(S->op_end()), - E(S->op_begin()); I != E; ++I) - OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I)); + for (const SCEV *Op : reverse(S->operands())) + OpsAndLoops.push_back(std::make_pair(getRelevantLoop(Op), Op)); // Sort by loop. Use a stable sort so that constants follow non-constants. llvm::stable_sort(OpsAndLoops, LoopCompare(SE.DT)); @@ -923,28 +915,6 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS())); } -/// Move parts of Base into Rest to leave Base with the minimal -/// expression that provides a pointer operand suitable for a -/// GEP expansion. -static void ExposePointerBase(const SCEV *&Base, const SCEV *&Rest, - ScalarEvolution &SE) { - while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Base)) { - Base = A->getStart(); - Rest = SE.getAddExpr(Rest, - SE.getAddRecExpr(SE.getConstant(A->getType(), 0), - A->getStepRecurrence(SE), - A->getLoop(), - A->getNoWrapFlags(SCEV::FlagNW))); - } - if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(Base)) { - Base = A->getOperand(A->getNumOperands()-1); - SmallVector<const SCEV *, 8> NewAddOps(A->operands()); - NewAddOps.back() = Rest; - Rest = SE.getAddExpr(NewAddOps); - ExposePointerBase(Base, Rest, SE); - } -} - /// Determine if this is a well-behaved chain of instructions leading back to /// the PHI. If so, it may be reused by expanded expressions. bool SCEVExpander::isNormalAddRecExprPHI(PHINode *PN, Instruction *IncV, @@ -1125,22 +1095,6 @@ Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L, return IncV; } -/// Hoist the addrec instruction chain rooted in the loop phi above the -/// position. 
This routine assumes that this is possible (has been checked). -void SCEVExpander::hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist, - Instruction *Pos, PHINode *LoopPhi) { - do { - if (DT->dominates(InstToHoist, Pos)) - break; - // Make sure the increment is where we want it. But don't move it - // down past a potential existing post-inc user. - fixupInsertPoints(InstToHoist); - InstToHoist->moveBefore(Pos); - Pos = InstToHoist; - InstToHoist = cast<Instruction>(InstToHoist->getOperand(0)); - } while (InstToHoist != LoopPhi); -} - /// Check whether we can cheaply express the requested SCEV in terms of /// the available PHI SCEV by truncation and/or inversion of the step. static bool canBeCheaplyTransformed(ScalarEvolution &SE, @@ -1264,8 +1218,6 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, if (LSRMode) { if (!isExpandedAddRecExprPHI(&PN, TempIncV, L)) continue; - if (L == IVIncInsertLoop && !hoistIVInc(TempIncV, IVIncInsertPos)) - continue; } else { if (!isNormalAddRecExprPHI(&PN, TempIncV, L)) continue; @@ -1293,11 +1245,6 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, } if (AddRecPhiMatch) { - // Potentially, move the increment. We have made sure in - // isExpandedAddRecExprPHI or hoistIVInc that this is possible. - if (L == IVIncInsertLoop) - hoistBeforePos(&SE.DT, IncV, IVIncInsertPos, AddRecPhiMatch); - // Ok, the add recurrence looks usable. // Remember this PHI, even in post-inc mode. InsertedValues.insert(AddRecPhiMatch); @@ -1597,29 +1544,17 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { // {X,+,F} --> X + {0,+,F} if (!S->getStart()->isZero()) { + if (PointerType *PTy = dyn_cast<PointerType>(S->getType())) { + Value *StartV = expand(SE.getPointerBase(S)); + assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!"); + return expandAddToGEP(SE.removePointerBase(S), PTy, Ty, StartV); + } + SmallVector<const SCEV *, 4> NewOps(S->operands()); NewOps[0] = SE.getConstant(Ty, 0); const SCEV *Rest = SE.getAddRecExpr(NewOps, L, S->getNoWrapFlags(SCEV::FlagNW)); - // Turn things like ptrtoint+arithmetic+inttoptr into GEP. See the - // comments on expandAddToGEP for details. - const SCEV *Base = S->getStart(); - // Dig into the expression to find the pointer base for a GEP. - const SCEV *ExposedRest = Rest; - ExposePointerBase(Base, ExposedRest, SE); - // If we found a pointer, expand the AddRec with a GEP. - if (PointerType *PTy = dyn_cast<PointerType>(Base->getType())) { - // Make sure the Base isn't something exotic, such as a multiplied - // or divided pointer value. In those cases, the result type isn't - // actually a pointer type. - if (!isa<SCEVMulExpr>(Base) && !isa<SCEVUDivExpr>(Base)) { - Value *StartV = expand(Base); - assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!"); - return expandAddToGEP(ExposedRest, PTy, Ty, StartV); - } - } - // Just do a normal add. Pre-expand the operands to suppress folding. // // The LHS and RHS values are factored out of the expand call to make the @@ -1898,6 +1833,22 @@ Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root) { return V; } +/// Check whether value has nuw/nsw/exact set but SCEV does not. +/// TODO: In reality it is better to check the poison recursively +/// but this is better than nothing. 
+static bool SCEVLostPoisonFlags(const SCEV *S, const Instruction *I) { + if (isa<OverflowingBinaryOperator>(I)) { + if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) { + if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap()) + return true; + if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap()) + return true; + } + } else if (isa<PossiblyExactOperator>(I) && I->isExact()) + return true; + return false; +} + ScalarEvolution::ValueOffsetPair SCEVExpander::FindValueInExprValueMap(const SCEV *S, const Instruction *InsertPt) { @@ -1907,19 +1858,22 @@ SCEVExpander::FindValueInExprValueMap(const SCEV *S, if (CanonicalMode || !SE.containsAddRecurrence(S)) { // If S is scConstant, it may be worse to reuse an existing Value. if (S->getSCEVType() != scConstant && Set) { - // Choose a Value from the set which dominates the insertPt. - // insertPt should be inside the Value's parent loop so as not to break + // Choose a Value from the set which dominates the InsertPt. + // InsertPt should be inside the Value's parent loop so as not to break // the LCSSA form. for (auto const &VOPair : *Set) { Value *V = VOPair.first; ConstantInt *Offset = VOPair.second; - Instruction *EntInst = nullptr; - if (V && isa<Instruction>(V) && (EntInst = cast<Instruction>(V)) && - S->getType() == V->getType() && - EntInst->getFunction() == InsertPt->getFunction() && + Instruction *EntInst = dyn_cast_or_null<Instruction>(V); + if (!EntInst) + continue; + + assert(EntInst->getFunction() == InsertPt->getFunction()); + if (S->getType() == V->getType() && SE.DT.dominates(EntInst, InsertPt) && (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || - SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) + SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt)) && + !SCEVLostPoisonFlags(S, EntInst)) return {V, Offset}; } } @@ -2068,7 +2022,9 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, Phis.push_back(&PN); if (TTI) - llvm::sort(Phis, [](Value *LHS, Value *RHS) { + // Use stable_sort to preserve order of equivalent PHIs, so the order + // of the sorted Phis is the same from run to run on the same loop. + llvm::stable_sort(Phis, [](Value *LHS, Value *RHS) { // Put pointers at the back and make sure pointer < pointer = false. if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy(); @@ -2524,18 +2480,14 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, IntegerType *Ty = IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy)); - Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? ARTy : Ty; Value *StepValue = expandCodeForImpl(Step, Ty, Loc, false); Value *NegStepValue = expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc, false); - Value *StartValue = expandCodeForImpl( - isa<PointerType>(ARExpandTy) ? Start - : SE.getPtrToIntExpr(Start, ARExpandTy), - ARExpandTy, Loc, false); + Value *StartValue = expandCodeForImpl(Start, ARTy, Loc, false); ConstantInt *Zero = - ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits)); + ConstantInt::get(Loc->getContext(), APInt::getZero(DstBits)); Builder.SetInsertPoint(Loc); // Compute |Step| @@ -2544,25 +2496,33 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, // Get the backedge taken count and truncate or extended to the AR type. 
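[Editor's illustration] The SCEVLostPoisonFlags helper introduced above boils down to comparing a few flags on the original value against the expression being reused. A minimal standalone sketch of that comparison, using plain structs instead of the LLVM Instruction/SCEV classes (all names below are illustrative, not part of the SCEVExpander API):

#include <iostream>

// Illustrative flag model only; the real check reads nuw/nsw from the
// Instruction and from the SCEVNAryExpr it is being matched against.
struct ValueFlags { bool NUW = false, NSW = false, Exact = false; };
struct ExprFlags  { bool NUW = false, NSW = false; };  // SCEV has no 'exact'

// True when reusing the expression would drop poison-generating flags that
// the original value carried.
bool lostPoisonFlags(const ValueFlags &V, const ExprFlags &E) {
  if (V.NSW && !E.NSW) return true;
  if (V.NUW && !E.NUW) return true;
  return V.Exact;  // SCEV never models 'exact', so it is always lost
}

int main() {
  ValueFlags V; V.NUW = true;
  std::cout << lostPoisonFlags(V, ExprFlags{}) << '\n';  // prints 1
}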
Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty); - auto *MulF = Intrinsic::getDeclaration(Loc->getModule(), - Intrinsic::umul_with_overflow, Ty); // Compute |Step| * Backedge - CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul"); - Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result"); - Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow"); + Value *MulV, *OfMul; + if (Step->isOne()) { + // Special-case Step of one. Potentially-costly `umul_with_overflow` isn't + // needed, there is never an overflow, so to avoid artificially inflating + // the cost of the check, directly emit the optimized IR. + MulV = TruncTripCount; + OfMul = ConstantInt::getFalse(MulV->getContext()); + } else { + auto *MulF = Intrinsic::getDeclaration(Loc->getModule(), + Intrinsic::umul_with_overflow, Ty); + CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul"); + MulV = Builder.CreateExtractValue(Mul, 0, "mul.result"); + OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow"); + } // Compute: // Start + |Step| * Backedge < Start // Start - |Step| * Backedge > Start Value *Add = nullptr, *Sub = nullptr; - if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARExpandTy)) { - const SCEV *MulS = SE.getSCEV(MulV); - const SCEV *NegMulS = SE.getNegativeSCEV(MulS); - Add = Builder.CreateBitCast(expandAddToGEP(MulS, ARPtrTy, Ty, StartValue), - ARPtrTy); - Sub = Builder.CreateBitCast( - expandAddToGEP(NegMulS, ARPtrTy, Ty, StartValue), ARPtrTy); + if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARTy)) { + StartValue = InsertNoopCastOfTo( + StartValue, Builder.getInt8PtrTy(ARPtrTy->getAddressSpace())); + Value *NegMulV = Builder.CreateNeg(MulV); + Add = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, MulV); + Sub = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, NegMulV); } else { Add = Builder.CreateAdd(StartValue, MulV); Sub = Builder.CreateSub(StartValue, MulV); @@ -2686,9 +2646,11 @@ namespace { // perfectly reduced form, which can't be guaranteed. struct SCEVFindUnsafe { ScalarEvolution &SE; + bool CanonicalMode; bool IsUnsafe; - SCEVFindUnsafe(ScalarEvolution &se): SE(se), IsUnsafe(false) {} + SCEVFindUnsafe(ScalarEvolution &SE, bool CanonicalMode) + : SE(SE), CanonicalMode(CanonicalMode), IsUnsafe(false) {} bool follow(const SCEV *S) { if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { @@ -2704,6 +2666,14 @@ struct SCEVFindUnsafe { IsUnsafe = true; return false; } + + // For non-affine addrecs or in non-canonical mode we need a preheader + // to insert into. 
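[Editor's illustration] The Step->isOne() special case added above avoids emitting llvm.umul.with.overflow when the multiply can never overflow, and instead feeds a constant-false overflow bit to the rest of the check. A standalone sketch of the same idea in plain C++ (the function name is made up for illustration):

#include <cstdint>
#include <utility>

// Returns {Step * TripCount, overflowed}. A step of one can never overflow,
// so the comparatively expensive overflow check is skipped entirely.
std::pair<uint64_t, bool> mulWithOverflowCheck(uint64_t Step,
                                               uint64_t TripCount) {
  if (Step == 1)
    return {TripCount, false};
  uint64_t Product = Step * TripCount;         // wraps modulo 2^64
  bool Overflow = Step != 0 && Product / Step != TripCount;
  return {Product, Overflow};
}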
+ if (!AR->getLoop()->getLoopPreheader() && + (!CanonicalMode || !AR->isAffine())) { + IsUnsafe = true; + return false; + } } return true; } @@ -2712,8 +2682,8 @@ struct SCEVFindUnsafe { } namespace llvm { -bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE) { - SCEVFindUnsafe Search(SE); +bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE, bool CanonicalMode) { + SCEVFindUnsafe Search(SE, CanonicalMode); visitAll(S, Search); return !Search.IsUnsafe; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index d86ecbb6db00..f467de5f924e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GuardUtils.h" @@ -159,6 +160,13 @@ static cl::opt<unsigned> cl::desc("Maximum cost of combining conditions when " "folding branches")); +static cl::opt<unsigned> BranchFoldToCommonDestVectorMultiplier( + "simplifycfg-branch-fold-common-dest-vector-multiplier", cl::Hidden, + cl::init(2), + cl::desc("Multiplier to apply to threshold when determining whether or not " + "to fold branch to common destination when vector operations are " + "present")); + STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps"); STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping"); @@ -272,7 +280,6 @@ public: } bool simplifyOnce(BasicBlock *BB); - bool simplifyOnceImpl(BasicBlock *BB); bool run(BasicBlock *BB); // Helper to set Resimplify and return change indication. @@ -2051,7 +2058,7 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB, unsigned NumPHIdValues = 0; for (auto *I : *LRI) for (auto *V : PHIOperands[I]) { - if (InstructionsToSink.count(V) == 0) + if (!InstructionsToSink.contains(V)) ++NumPHIdValues; // FIXME: this check is overly optimistic. We may end up not sinking // said instruction, due to the very same profitability check. @@ -2257,6 +2264,23 @@ static Value *isSafeToSpeculateStore(Instruction *I, BasicBlock *BrBB, return SI->getValueOperand(); return nullptr; // Unknown store. } + + if (auto *LI = dyn_cast<LoadInst>(&CurI)) { + if (LI->getPointerOperand() == StorePtr && LI->getType() == StoreTy && + LI->isSimple()) { + // Local objects (created by an `alloca` instruction) are always + // writable, so once we are past a read from a location it is valid to + // also write to that same location. + // If the address of the local object never escapes the function, that + // means it's never concurrently read or written, hence moving the store + // from under the condition will not introduce a data race. + auto *AI = dyn_cast<AllocaInst>(getUnderlyingObject(StorePtr)); + if (AI && !PointerMayBeCaptured(AI, false, true)) + // Found a previous load, return it. + return LI; + } + // The load didn't work out, but we may still find a store. 
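[Editor's illustration] The new load case in isSafeToSpeculateStore above rests on a simple observation: once a non-escaping alloca has been read, it is also safe to write that location unconditionally. Roughly the source-level shape of what this enables (illustration only, not the pass's actual output):

// Before: the store to the local is guarded by the branch.
int before(bool Cond) {
  int Local = 0;        // address never escapes this function
  int Old = Local;      // unconditional read of the same location
  if (Cond)
    Local = 42;         // conditional store under the branch
  return Local + Old;
}

// After speculation the store becomes unconditional; the previously loaded
// value is reused when the condition is false, so behaviour is unchanged.
int after(bool Cond) {
  int Local = 0;
  int Old = Local;
  Local = Cond ? 42 : Old;
  return Local + Old;
}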
+ } } return nullptr; @@ -2552,17 +2576,17 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { int Size = 0; SmallPtrSet<const Value *, 32> EphValues; - auto IsEphemeral = [&](const Value *V) { - if (isa<AssumeInst>(V)) + auto IsEphemeral = [&](const Instruction *I) { + if (isa<AssumeInst>(I)) return true; - return isSafeToSpeculativelyExecute(V) && - all_of(V->users(), + return !I->mayHaveSideEffects() && !I->isTerminator() && + all_of(I->users(), [&](const User *U) { return EphValues.count(U); }); }; // Walk the loop in reverse so that we can identify ephemeral values properly // (values only feeding assumes). - for (Instruction &I : reverse(BB->instructionsWithoutDebug())) { + for (Instruction &I : reverse(BB->instructionsWithoutDebug(false))) { // Can't fold blocks that contain noduplicate or convergent calls. if (CallInst *CI = dyn_cast<CallInst>(&I)) if (CI->cannotDuplicate() || CI->isConvergent()) @@ -2595,8 +2619,10 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { /// If we have a conditional branch on a PHI node value that is defined in the /// same block as the branch and if any PHI entries are constants, thread edges /// corresponding to that entry to be branches to their ultimate destination. -static bool FoldCondBranchOnPHI(BranchInst *BI, DomTreeUpdater *DTU, - const DataLayout &DL, AssumptionCache *AC) { +static Optional<bool> FoldCondBranchOnPHIImpl(BranchInst *BI, + DomTreeUpdater *DTU, + const DataLayout &DL, + AssumptionCache *AC) { BasicBlock *BB = BI->getParent(); PHINode *PN = dyn_cast<PHINode>(BI->getCondition()); // NOTE: we currently cannot transform this case if the PHI node is used @@ -2710,13 +2736,25 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, DomTreeUpdater *DTU, DTU->applyUpdates(Updates); } - // Recurse, simplifying any other constants. - return FoldCondBranchOnPHI(BI, DTU, DL, AC) || true; + // Signal repeat, simplifying any other constants. + return None; } return false; } +static bool FoldCondBranchOnPHI(BranchInst *BI, DomTreeUpdater *DTU, + const DataLayout &DL, AssumptionCache *AC) { + Optional<bool> Result; + bool EverChanged = false; + do { + // Note that None means "we changed things, but recurse further." + Result = FoldCondBranchOnPHIImpl(BI, DTU, DL, AC); + EverChanged |= Result == None || *Result; + } while (Result == None); + return EverChanged; +} + /// Given a BB that starts with the specified two-entry PHI node, /// see if we can eliminate it. static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, @@ -2852,8 +2890,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // instructions. for (BasicBlock *IfBlock : IfBlocks) for (BasicBlock::iterator I = IfBlock->begin(); !I->isTerminator(); ++I) - if (!AggressiveInsts.count(&*I) && !isa<DbgInfoIntrinsic>(I) && - !isa<PseudoProbeInst>(I)) { + if (!AggressiveInsts.count(&*I) && !I->isDebugOrPseudoInst()) { // This is not an aggressive instruction that we can promote. // Because of this, we won't be able to get rid of the control flow, so // the xform is not worth it. @@ -3112,6 +3149,14 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, return true; } +/// Return if an instruction's type or any of its operands' types are a vector +/// type. 
+static bool isVectorOp(Instruction &I) { + return I.getType()->isVectorTy() || any_of(I.operands(), [](Use &U) { + return U->getType()->isVectorTy(); + }); +} + /// If this basic block is simple enough, and if a predecessor branches to us /// and one of our successors, fold the block into the predecessor and use /// logical operations to pick the right destination. @@ -3196,6 +3241,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU, // number of the bonus instructions we'll need to create when cloning into // each predecessor does not exceed a certain threshold. unsigned NumBonusInsts = 0; + bool SawVectorOp = false; const unsigned PredCount = Preds.size(); for (Instruction &I : *BB) { // Don't check the branch condition comparison itself. @@ -3207,13 +3253,19 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU, // I must be safe to execute unconditionally. if (!isSafeToSpeculativelyExecute(&I)) return false; + SawVectorOp |= isVectorOp(I); // Account for the cost of duplicating this instruction into each - // predecessor. - NumBonusInsts += PredCount; - // Early exits once we reach the limit. - if (NumBonusInsts > BonusInstThreshold) - return false; + // predecessor. Ignore free instructions. + if (!TTI || + TTI->getUserCost(&I, CostKind) != TargetTransformInfo::TCC_Free) { + NumBonusInsts += PredCount; + + // Early exits once we reach the limit. + if (NumBonusInsts > + BonusInstThreshold * BranchFoldToCommonDestVectorMultiplier) + return false; + } auto IsBCSSAUse = [BB, &I](Use &U) { auto *UI = cast<Instruction>(U.getUser()); @@ -3226,6 +3278,10 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU, if (!all_of(I.uses(), IsBCSSAUse)) return false; } + if (NumBonusInsts > + BonusInstThreshold * + (SawVectorOp ? BranchFoldToCommonDestVectorMultiplier : 1)) + return false; // Ok, we have the budget. Perform the transformation. for (BasicBlock *PredBlock : Preds) { @@ -3358,7 +3414,7 @@ static bool mergeConditionalStoreToAddress( InstructionCost Cost = 0; InstructionCost Budget = PHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic; - for (auto &I : BB->instructionsWithoutDebug()) { + for (auto &I : BB->instructionsWithoutDebug(false)) { // Consider terminator instruction to be free. if (I.isTerminator()) continue; @@ -3431,10 +3487,7 @@ static bool mergeConditionalStoreToAddress( /*BranchWeights=*/nullptr, DTU); QB.SetInsertPoint(T); StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address)); - AAMDNodes AAMD; - PStore->getAAMetadata(AAMD, /*Merge=*/false); - PStore->getAAMetadata(AAMD, /*Merge=*/true); - SI->setAAMetadata(AAMD); + SI->setAAMetadata(PStore->getAAMetadata().merge(QStore->getAAMetadata())); // Choose the minimum alignment. If we could prove both stores execute, we // could use biggest one. In this case, though, we only know that one of the // stores executes. And we don't know it's safe to take the alignment from a @@ -3684,7 +3737,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // fold the conditions into logical ops and one cond br. // Ignore dbg intrinsics. 
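[Editor's illustration] The FoldCondBranchOnPHI wrapper above replaces self-recursion with a small driver loop in which None means "changed something, run the worker again". A self-contained sketch of that driver pattern using std::optional (names are illustrative, not SimplifyCFG's interface):

#include <optional>

// Worker returns std::nullopt to request another iteration, or a definite
// "did we change anything" answer once it is done.
template <typename StepFn> bool runToFixpoint(StepFn Step) {
  bool EverChanged = false;
  std::optional<bool> Result;
  do {
    Result = Step();
    EverChanged |= !Result.has_value() || *Result;
  } while (!Result.has_value());
  return EverChanged;
}

int main() {
  int Rounds = 0;
  bool Changed = runToFixpoint([&]() -> std::optional<bool> {
    if (++Rounds < 3)
      return std::nullopt;  // keep iterating
    return false;           // done, nothing changed this round
  });
  return Changed ? 0 : 1;   // Changed is true: earlier rounds reported work
}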
- if (&*BB->instructionsWithoutDebug().begin() != BI) + if (&*BB->instructionsWithoutDebug(false).begin() != BI) return false; int PBIOp, BIOp; @@ -4729,29 +4782,6 @@ static bool CasesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) { return true; } -static void createUnreachableSwitchDefault(SwitchInst *Switch, - DomTreeUpdater *DTU) { - LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); - auto *BB = Switch->getParent(); - BasicBlock *NewDefaultBlock = SplitBlockPredecessors( - Switch->getDefaultDest(), Switch->getParent(), "", DTU); - auto *OrigDefaultBlock = Switch->getDefaultDest(); - Switch->setDefaultDest(&*NewDefaultBlock); - if (DTU) - DTU->applyUpdates({{DominatorTree::Insert, BB, &*NewDefaultBlock}, - {DominatorTree::Delete, BB, OrigDefaultBlock}}); - SplitBlock(&*NewDefaultBlock, &NewDefaultBlock->front(), DTU); - SmallVector<DominatorTree::UpdateType, 2> Updates; - if (DTU) - for (auto *Successor : successors(NewDefaultBlock)) - Updates.push_back({DominatorTree::Delete, NewDefaultBlock, Successor}); - auto *NewTerminator = NewDefaultBlock->getTerminator(); - new UnreachableInst(Switch->getContext(), NewTerminator); - EraseTerminatorAndDCECond(NewTerminator); - if (DTU) - DTU->applyUpdates(Updates); -} - /// Turn a switch with two reachable destinations into an integer range /// comparison and branch. bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI, @@ -5057,9 +5087,10 @@ static bool ValidLookupTableConstant(Constant *C, const TargetTransformInfo &TTI return false; if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { - if (!CE->isGEPWithNoNotionalOverIndexing()) - return false; - if (!ValidLookupTableConstant(CE->getOperand(0), TTI)) + // Pointer casts and in-bounds GEPs will not prohibit the backend from + // materializing the array of constants. + Constant *StrippedC = cast<Constant>(CE->stripInBoundsConstantOffsets()); + if (StrippedC == C || !ValidLookupTableConstant(StrippedC, TTI)) return false; } @@ -5129,7 +5160,7 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, // which we can constant-propagate the CaseVal, continue to its successor. SmallDenseMap<Value *, Constant *> ConstantPool; ConstantPool.insert(std::make_pair(SI->getCondition(), CaseVal)); - for (Instruction &I :CaseDest->instructionsWithoutDebug()) { + for (Instruction &I : CaseDest->instructionsWithoutDebug(false)) { if (I.isTerminator()) { // If the terminator is a simple branch, continue to the next block. if (I.getNumSuccessors() != 1 || I.isExceptionalTerminator()) @@ -5622,8 +5653,32 @@ bool SwitchLookupTable::WouldFitInRegister(const DataLayout &DL, return DL.fitsInLegalInteger(TableSize * IT->getBitWidth()); } +static bool isTypeLegalForLookupTable(Type *Ty, const TargetTransformInfo &TTI, + const DataLayout &DL) { + // Allow any legal type. + if (TTI.isTypeLegal(Ty)) + return true; + + auto *IT = dyn_cast<IntegerType>(Ty); + if (!IT) + return false; + + // Also allow power of 2 integer types that have at least 8 bits and fit in + // a register. These types are common in frontend languages and targets + // usually support loads of these types. + // TODO: We could relax this to any integer that fits in a register and rely + // on ABI alignment and padding in the table to allow the load to be widened. + // Or we could widen the constants and truncate the load. 
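[Editor's illustration] isTypeLegalForLookupTable, whose body continues just below, reduces to a power-of-two width test plus a DataLayout query. A hedged standalone sketch of that predicate (the DataLayout query is modelled as a callback here and is not part of the real signature):

// Accept any power-of-two integer width of at least 8 bits that also fits in
// a legal integer register for the target.
bool typeOkForLookupTable(unsigned BitWidth,
                          bool (*FitsInLegalInteger)(unsigned)) {
  bool IsPowerOfTwo = BitWidth != 0 && (BitWidth & (BitWidth - 1)) == 0;
  return BitWidth >= 8 && IsPowerOfTwo && FitsInLegalInteger(BitWidth);
}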
+ unsigned BitWidth = IT->getBitWidth(); + return BitWidth >= 8 && isPowerOf2_32(BitWidth) && + DL.fitsInLegalInteger(IT->getBitWidth()); +} + /// Determine whether a lookup table should be built for this switch, based on /// the number of cases, size of the table, and the types of the results. +// TODO: We could support larger than legal types by limiting based on the +// number of loads required and/or table size. If the constants are small we +// could use smaller table entries and extend after the load. static bool ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize, const TargetTransformInfo &TTI, const DataLayout &DL, @@ -5637,7 +5692,7 @@ ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize, Type *Ty = I.second; // Saturate this flag to true. - HasIllegalType = HasIllegalType || !TTI.isTypeLegal(Ty); + HasIllegalType = HasIllegalType || !isTypeLegalForLookupTable(Ty, TTI, DL); // Saturate this flag to false. AllTablesFitInRegister = @@ -6120,7 +6175,7 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // If the block only contains the switch, see if we can fold the block // away into any preds. - if (SI == &*BB->instructionsWithoutDebug().begin()) + if (SI == &*BB->instructionsWithoutDebug(false).begin()) if (FoldValueComparisonIntoPredecessors(SI, Builder)) return requestResimplify(); } @@ -6264,12 +6319,9 @@ static bool TryToMergeLandingPad(LandingPadInst *LPad, BranchInst *BI, // The debug info in OtherPred doesn't cover the merged control flow that // used to go through BB. We need to delete it or update it. - for (auto I = OtherPred->begin(), E = OtherPred->end(); I != E;) { - Instruction &Inst = *I; - I++; + for (Instruction &Inst : llvm::make_early_inc_range(*OtherPred)) if (isa<DbgInfoIntrinsic>(Inst)) Inst.eraseFromParent(); - } SmallPtrSet<BasicBlock *, 16> Succs(succ_begin(BB), succ_end(BB)); for (BasicBlock *Succ : Succs) { @@ -6356,6 +6408,11 @@ static BasicBlock *allPredecessorsComeFromSameSource(BasicBlock *BB) { } bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { + assert( + !isa<ConstantInt>(BI->getCondition()) && + BI->getSuccessor(0) != BI->getSuccessor(1) && + "Tautological conditional branch should have been eliminated already."); + BasicBlock *BB = BI->getParent(); if (!Options.SimplifyCondBranch) return false; @@ -6470,19 +6527,21 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu if (C->isNullValue() || isa<UndefValue>(C)) { // Only look at the first use, avoid hurting compile time with long uselists - User *Use = *I->user_begin(); + auto *Use = cast<Instruction>(*I->user_begin()); + // Bail out if Use is not in the same BB as I or Use == I or Use comes + // before I in the block. The latter two can be the case if Use is a PHI + // node. + if (Use->getParent() != I->getParent() || Use == I || Use->comesBefore(I)) + return false; // Now make sure that there are no instructions in between that can alter // control flow (eg. calls) - for (BasicBlock::iterator - i = ++BasicBlock::iterator(I), - UI = BasicBlock::iterator(dyn_cast<Instruction>(Use)); - i != UI; ++i) { - if (i == I->getParent()->end()) - return false; - if (!isGuaranteedToTransferExecutionToSuccessor(&*i)) - return false; - } + auto InstrRange = + make_range(std::next(I->getIterator()), Use->getIterator()); + if (any_of(InstrRange, [](Instruction &I) { + return !isGuaranteedToTransferExecutionToSuccessor(&I); + })) + return false; // Look through GEPs. 
A load from a GEP derived from NULL is still undefined if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Use)) @@ -6558,21 +6617,51 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB, // destination from conditional branches. if (BI->isUnconditional()) Builder.CreateUnreachable(); - else + else { + // Preserve guarding condition in assume, because it might not be + // inferrable from any dominating condition. + Value *Cond = BI->getCondition(); + if (BI->getSuccessor(0) == BB) + Builder.CreateAssumption(Builder.CreateNot(Cond)); + else + Builder.CreateAssumption(Cond); Builder.CreateBr(BI->getSuccessor(0) == BB ? BI->getSuccessor(1) : BI->getSuccessor(0)); + } BI->eraseFromParent(); if (DTU) DTU->applyUpdates({{DominatorTree::Delete, Predecessor, BB}}); return true; + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(T)) { + // Redirect all branches leading to UB into + // a newly created unreachable block. + BasicBlock *Unreachable = BasicBlock::Create( + Predecessor->getContext(), "unreachable", BB->getParent(), BB); + Builder.SetInsertPoint(Unreachable); + // The new block contains only one instruction: Unreachable + Builder.CreateUnreachable(); + for (auto &Case : SI->cases()) + if (Case.getCaseSuccessor() == BB) { + BB->removePredecessor(Predecessor); + Case.setSuccessor(Unreachable); + } + if (SI->getDefaultDest() == BB) { + BB->removePredecessor(Predecessor); + SI->setDefaultDest(Unreachable); + } + + if (DTU) + DTU->applyUpdates( + { { DominatorTree::Insert, Predecessor, Unreachable }, + { DominatorTree::Delete, Predecessor, BB } }); + return true; } - // TODO: SwitchInst. } return false; } -bool SimplifyCFGOpt::simplifyOnceImpl(BasicBlock *BB) { +bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) { bool Changed = false; assert(BB && BB->getParent() && "Block not embedded in function!"); @@ -6596,7 +6685,8 @@ bool SimplifyCFGOpt::simplifyOnceImpl(BasicBlock *BB) { Changed |= EliminateDuplicatePHINodes(BB); // Check for and remove branches that will always cause undefined behavior. - Changed |= removeUndefIntroducingPredecessor(BB, DTU); + if (removeUndefIntroducingPredecessor(BB, DTU)) + return requestResimplify(); // Merge basic blocks into their predecessor if there is only one distinct // pred, and if there is only one distinct successor of the predecessor, and @@ -6621,7 +6711,8 @@ bool SimplifyCFGOpt::simplifyOnceImpl(BasicBlock *BB) { // eliminate it, do so now. 
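[Editor's illustration] removeUndefIntroducingPredecessor above now materialises the branch condition as an assumption before deleting the edge, so facts implied by the removed branch are not lost to later passes. A rough source-level analogue (clang's __builtin_assume is used purely for illustration; the pass emits llvm.assume directly):

#include <cstdio>

static void goodPath() { std::puts("taken"); }

// If the true edge of 'if (Cond)' led straight to undefined behaviour, the
// branch is removed but '!Cond' is kept as an assumption for later passes.
void keepGuardAsAssume(bool Cond) {
#if defined(__clang__)
  __builtin_assume(!Cond);
#endif
  goodPath();  // the surviving successor
}

int main() { keepGuardAsAssume(false); }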
if (auto *PN = dyn_cast<PHINode>(BB->begin())) if (PN->getNumIncomingValues() == 2) - Changed |= FoldTwoEntryPHINode(PN, TTI, DTU, DL); + if (FoldTwoEntryPHINode(PN, TTI, DTU, DL)) + return true; } Instruction *Terminator = BB->getTerminator(); @@ -6650,12 +6741,6 @@ bool SimplifyCFGOpt::simplifyOnceImpl(BasicBlock *BB) { return Changed; } -bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) { - bool Changed = simplifyOnceImpl(BB); - - return Changed; -} - bool SimplifyCFGOpt::run(BasicBlock *BB) { bool Changed = false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index bd30be011472..5b7fd4349c6c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -942,6 +942,7 @@ bool simplifyLoopIVs(Loop *L, ScalarEvolution *SE, DominatorTree *DT, } // namespace llvm +namespace { //===----------------------------------------------------------------------===// // Widen Induction Variables - Extend the width of an IV to cover its // widest uses. @@ -1072,7 +1073,7 @@ protected: private: SmallVector<NarrowIVDefUse, 8> NarrowIVUsers; }; - +} // namespace /// Determine the insertion point for this user. By default, insert immediately /// before the user. SCEVExpander or LICM will hoist loop invariants out of the diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index b8e0f63c481d..e190a1294eb3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -142,12 +142,10 @@ static void annotateDereferenceableBytes(CallInst *CI, unsigned AS = CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace(); if (!llvm::NullPointerIsDefined(F, AS) || CI->paramHasAttr(ArgNo, Attribute::NonNull)) - DerefBytes = std::max(CI->getDereferenceableOrNullBytes( - ArgNo + AttributeList::FirstArgIndex), + DerefBytes = std::max(CI->getParamDereferenceableOrNullBytes(ArgNo), DereferenceableBytes); - - if (CI->getDereferenceableBytes(ArgNo + AttributeList::FirstArgIndex) < - DerefBytes) { + + if (CI->getParamDereferenceableBytes(ArgNo) < DerefBytes) { CI->removeParamAttr(ArgNo, Attribute::Dereferenceable); if (!llvm::NullPointerIsDefined(F, AS) || CI->paramHasAttr(ArgNo, Attribute::NonNull)) @@ -512,14 +510,18 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) { B.CreateMemCpy(Dst, Align(1), Src, Align(1), ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len)); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return Dst; } Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) { Function *Callee = CI->getCalledFunction(); Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + + // stpcpy(d,s) -> strcpy(d,s) if the result is not used. + if (CI->use_empty()) + return emitStrCpy(Dst, Src, B, TLI); + if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) Value *StrLen = emitStrLen(Src, B, DL, TLI); return StrLen ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr; @@ -541,8 +543,7 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) { // copy for us. 
Make a memcpy to copy the nul byte with align = 1. CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1), LenV); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return DstEnd; } @@ -577,9 +578,9 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilderBase &B) { if (SrcLen == 0) { // strncpy(x, "", y) -> memset(x, '\0', y) Align MemSetAlign = - CI->getAttributes().getParamAttributes(0).getAlignment().valueOrOne(); + CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne(); CallInst *NewCI = B.CreateMemSet(Dst, B.getInt8('\0'), Size, MemSetAlign); - AttrBuilder ArgAttrs(CI->getAttributes().getParamAttributes(0)); + AttrBuilder ArgAttrs(CI->getAttributes().getParamAttrs(0)); NewCI->setAttributes(NewCI->getAttributes().addParamAttributes( CI->getContext(), 0, ArgAttrs)); return Dst; @@ -604,8 +605,7 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilderBase &B) { CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1), ConstantInt::get(DL.getIntPtrType(PT), Len)); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return Dst; } @@ -1082,8 +1082,7 @@ Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilderBase &B) { CallInst *NewCI = B.CreateMemCpy(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), Size); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return CI->getArgOperand(0); } @@ -1136,8 +1135,7 @@ Value *LibCallSimplifier::optimizeMemPCpy(CallInst *CI, IRBuilderBase &B) { // any return attributes are compliant. // TODO: Attach return value attributes to the 1st operand to preserve them? NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, N); } @@ -1151,70 +1149,21 @@ Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilderBase &B) { CallInst *NewCI = B.CreateMemMove(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), Size); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return CI->getArgOperand(0); } -/// Fold memset[_chk](malloc(n), 0, n) --> calloc(1, n). -Value *LibCallSimplifier::foldMallocMemset(CallInst *Memset, IRBuilderBase &B) { - // This has to be a memset of zeros (bzero). - auto *FillValue = dyn_cast<ConstantInt>(Memset->getArgOperand(1)); - if (!FillValue || FillValue->getZExtValue() != 0) - return nullptr; - - // TODO: We should handle the case where the malloc has more than one use. - // This is necessary to optimize common patterns such as when the result of - // the malloc is checked against null or when a memset intrinsic is used in - // place of a memset library call. 
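[Editor's illustration] The stpcpy simplifications handled a little earlier in this hunk have straightforward source-level equivalents. The two helpers below only illustrate the equivalences the folds rely on and are not part of SimplifyLibCalls:

#include <cstring>

// stpcpy(d, s) with the end pointer unused behaves like strcpy(d, s).
char *stpcpyResultUnused(char *Dst, const char *Src) {
  std::strcpy(Dst, Src);
  return Dst;
}

// stpcpy(x, x) simply returns a pointer to x's terminating NUL byte.
char *stpcpySelf(char *X) {
  return X + std::strlen(X);
}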
- auto *Malloc = dyn_cast<CallInst>(Memset->getArgOperand(0)); - if (!Malloc || !Malloc->hasOneUse()) - return nullptr; - - // Is the inner call really malloc()? - Function *InnerCallee = Malloc->getCalledFunction(); - if (!InnerCallee) - return nullptr; - - LibFunc Func; - if (!TLI->getLibFunc(*InnerCallee, Func) || !TLI->has(Func) || - Func != LibFunc_malloc) - return nullptr; - - // The memset must cover the same number of bytes that are malloc'd. - if (Memset->getArgOperand(2) != Malloc->getArgOperand(0)) - return nullptr; - - // Replace the malloc with a calloc. We need the data layout to know what the - // actual size of a 'size_t' parameter is. - B.SetInsertPoint(Malloc->getParent(), ++Malloc->getIterator()); - const DataLayout &DL = Malloc->getModule()->getDataLayout(); - IntegerType *SizeType = DL.getIntPtrType(B.GetInsertBlock()->getContext()); - if (Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1), - Malloc->getArgOperand(0), - Malloc->getAttributes(), B, *TLI)) { - substituteInParent(Malloc, Calloc); - return Calloc; - } - - return nullptr; -} - Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilderBase &B) { Value *Size = CI->getArgOperand(2); annotateNonNullAndDereferenceable(CI, 0, Size, DL); if (isa<IntrinsicInst>(CI)) return nullptr; - if (auto *Calloc = foldMallocMemset(CI, B)) - return Calloc; - // memset(p, v, n) -> llvm.memset(align 1 p, v, n) Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val, Size, Align(1)); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return CI->getArgOperand(0); } @@ -1346,13 +1295,13 @@ Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilderBase &B) { B.setFastMathFlags(CI->getFastMathFlags()); Value *Real, *Imag; - if (CI->getNumArgOperands() == 1) { + if (CI->arg_size() == 1) { Value *Op = CI->getArgOperand(0); assert(Op->getType()->isArrayTy() && "Unexpected signature for cabs!"); Real = B.CreateExtractValue(Op, 0, "real"); Imag = B.CreateExtractValue(Op, 1, "imag"); } else { - assert(CI->getNumArgOperands() == 2 && "Unexpected signature for cabs!"); + assert(CI->arg_size() == 2 && "Unexpected signature for cabs!"); Real = CI->getArgOperand(0); Imag = CI->getArgOperand(1); } @@ -2333,7 +2282,7 @@ Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilderBase &B, // Proceedings of PACT'98, Oct. 1998, IEEE if (!CI->hasFnAttr(Attribute::Cold) && isReportingError(Callee, CI, StreamArg)) { - CI->addAttribute(AttributeList::FunctionIndex, Attribute::Cold); + CI->addFnAttr(Attribute::Cold); } return nullptr; @@ -2349,7 +2298,7 @@ static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg) { // These functions might be considered cold, but only if their stream // argument is stderr. - if (StreamArg >= (int)CI->getNumArgOperands()) + if (StreamArg >= (int)CI->arg_size()) return false; LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg)); if (!LI) @@ -2381,7 +2330,7 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { return emitPutChar(B.getInt32(FormatStr[0]), B, TLI); // Try to remove call or emit putchar/puts. 
- if (FormatStr == "%s" && CI->getNumArgOperands() > 1) { + if (FormatStr == "%s" && CI->arg_size() > 1) { StringRef OperandStr; if (!getConstantStringInfo(CI->getOperand(1), OperandStr)) return nullptr; @@ -2402,7 +2351,7 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { // printf("foo\n") --> puts("foo") if (FormatStr.back() == '\n' && - FormatStr.find('%') == StringRef::npos) { // No format characters. + !FormatStr.contains('%')) { // No format characters. // Create a string literal with no \n on it. We expect the constant merge // pass to be run after this pass, to merge duplicate strings. FormatStr = FormatStr.drop_back(); @@ -2412,12 +2361,12 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { // Optimize specific format strings. // printf("%c", chr) --> putchar(chr) - if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && + if (FormatStr == "%c" && CI->arg_size() > 1 && CI->getArgOperand(1)->getType()->isIntegerTy()) return emitPutChar(CI->getArgOperand(1), B, TLI); // printf("%s\n", str) --> puts(str) - if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && + if (FormatStr == "%s\n" && CI->arg_size() > 1 && CI->getArgOperand(1)->getType()->isPointerTy()) return emitPutS(CI->getArgOperand(1), B, TLI); return nullptr; @@ -2469,10 +2418,10 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, // If we just have a format string (nothing else crazy) transform it. Value *Dest = CI->getArgOperand(0); - if (CI->getNumArgOperands() == 2) { + if (CI->arg_size() == 2) { // Make sure there's no % in the constant array. We could try to handle // %% -> % in the future if we cared. - if (FormatStr.find('%') != StringRef::npos) + if (FormatStr.contains('%')) return nullptr; // we found a format specifier, bail out. // sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1) @@ -2485,8 +2434,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, // The remaining optimizations require the format string to be "%s" or "%c" // and have an extra operand. - if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) + if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->arg_size() < 3) return nullptr; // Decode the second character of the format string. @@ -2597,10 +2545,10 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, return nullptr; // If we just have a format string (nothing else crazy) transform it. - if (CI->getNumArgOperands() == 3) { + if (CI->arg_size() == 3) { // Make sure there's no % in the constant array. We could try to handle // %% -> % in the future if we cared. - if (FormatStr.find('%') != StringRef::npos) + if (FormatStr.contains('%')) return nullptr; // we found a format specifier, bail out. if (N == 0) @@ -2619,8 +2567,7 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, // The remaining optimizations require the format string to be "%s" or "%c" // and have an extra operand. - if (FormatStr.size() == 2 && FormatStr[0] == '%' && - CI->getNumArgOperands() == 4) { + if (FormatStr.size() == 2 && FormatStr[0] == '%' && CI->arg_size() == 4) { // Decode the second character of the format string. if (FormatStr[1] == 'c') { @@ -2688,9 +2635,9 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, return nullptr; // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) - if (CI->getNumArgOperands() == 2) { + if (CI->arg_size() == 2) { // Could handle %% -> % if we cared. 
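[Editor's illustration] optimizePrintFString above rewrites a handful of fixed-format printf calls into cheaper libcalls. The equivalences it relies on, written out as ordinary C++ purely for illustration:

#include <cstdio>

int main() {
  std::puts("foo");            // printf("foo\n")        -> puts("foo")
  std::putchar('x');           // printf("%c", 'x')      -> putchar('x')
  std::puts("bar");            // printf("%s\n", "bar")  -> puts(str)
}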
- if (FormatStr.find('%') != StringRef::npos) + if (FormatStr.contains('%')) return nullptr; // We found a format specifier. return emitFWrite( @@ -2701,8 +2648,7 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, // The remaining optimizations require the format string to be "%s" or "%c" // and have an extra operand. - if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) + if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->arg_size() < 3) return nullptr; // Decode the second character of the format string. @@ -3066,7 +3012,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) { return optimizeLog(CI, Builder); case Intrinsic::sqrt: return optimizeSqrt(CI, Builder); - // TODO: Use foldMallocMemset() with memset intrinsic. case Intrinsic::memset: return optimizeMemSet(CI, Builder); case Intrinsic::memcpy: @@ -3266,8 +3211,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, B.CreateMemCpy(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), CI->getArgOperand(2)); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return CI->getArgOperand(0); } return nullptr; @@ -3280,8 +3224,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, B.CreateMemMove(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), CI->getArgOperand(2)); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return CI->getArgOperand(0); } return nullptr; @@ -3289,15 +3232,12 @@ Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, IRBuilderBase &B) { - // TODO: Try foldMallocMemset() here. - if (isFortifiedCallFoldable(CI, 3, 2)) { Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), Align(1)); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes(AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return CI->getArgOperand(0); } return nullptr; @@ -3311,9 +3251,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemPCpyChk(CallInst *CI, CI->getArgOperand(2), B, DL, TLI)) { CallInst *NewCI = cast<CallInst>(Call); NewCI->setAttributes(CI->getAttributes()); - NewCI->removeAttributes( - AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewCI->getType())); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); return NewCI; } return nullptr; @@ -3354,7 +3292,11 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, else return nullptr; - Type *SizeTTy = DL.getIntPtrType(CI->getContext()); + // FIXME: There is really no guarantee that sizeof(size_t) is equal to + // sizeof(int*) for every target. So the assumption used here to derive the + // SizeTBits based on the size of an integer pointer in address space zero + // isn't always valid. 
+ Type *SizeTTy = DL.getIntPtrType(CI->getContext(), /*AddressSpace=*/0); Value *LenV = ConstantInt::get(SizeTTy, Len); Value *Ret = emitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI); // If the function was an __stpcpy_chk, and we were able to fold it into diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp index 32f2f4e233b2..7e12bbd2851c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp @@ -24,7 +24,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalObject.h" -#include "llvm/IR/GlobalIndirectSymbol.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instruction.h" @@ -65,9 +64,8 @@ static void addNonConstUser(ClusterMapType &GVtoClusterMap, if (const Instruction *I = dyn_cast<Instruction>(U)) { const GlobalValue *F = I->getParent()->getParent(); GVtoClusterMap.unionSets(GV, F); - } else if (isa<GlobalIndirectSymbol>(U) || isa<Function>(U) || - isa<GlobalVariable>(U)) { - GVtoClusterMap.unionSets(GV, cast<GlobalValue>(U)); + } else if (const GlobalValue *GVU = dyn_cast<GlobalValue>(U)) { + GVtoClusterMap.unionSets(GV, GVU); } else { llvm_unreachable("Underimplemented use case"); } @@ -91,6 +89,13 @@ static void addAllGlobalValueUsers(ClusterMapType &GVtoClusterMap, } } +static const GlobalObject *getGVPartitioningRoot(const GlobalValue *GV) { + const GlobalObject *GO = GV->getAliaseeObject(); + if (const auto *GI = dyn_cast_or_null<GlobalIFunc>(GO)) + GO = GI->getResolverFunction(); + return GO; +} + // Find partitions for module in the way that no locals need to be // globalized. // Try to balance pack those partitions into N files since this roughly equals @@ -123,12 +128,11 @@ static void findPartitions(Module &M, ClusterIDMapType &ClusterIDMap, Member = &GV; } - // For aliases we should not separate them from their aliasees regardless - // of linkage. - if (auto *GIS = dyn_cast<GlobalIndirectSymbol>(&GV)) { - if (const GlobalObject *Base = GIS->getBaseObject()) - GVtoClusterMap.unionSets(&GV, Base); - } + // Aliases should not be separated from their aliasees and ifuncs should + // not be separated from their resolvers regardless of linkage. + if (const GlobalObject *Root = getGVPartitioningRoot(&GV)) + if (&GV != Root) + GVtoClusterMap.unionSets(&GV, Root); if (const Function *F = dyn_cast<Function>(&GV)) { for (const BasicBlock &BB : *F) { @@ -225,9 +229,8 @@ static void externalize(GlobalValue *GV) { // Returns whether GV should be in partition (0-based) I of N. 
static bool isInPartition(const GlobalValue *GV, unsigned I, unsigned N) { - if (auto *GIS = dyn_cast<GlobalIndirectSymbol>(GV)) - if (const GlobalObject *Base = GIS->getBaseObject()) - GV = Base; + if (const GlobalObject *Root = getGVPartitioningRoot(GV)) + GV = Root; StringRef Name; if (const Comdat *C = GV->getComdat()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp index ec4ea848a5d4..6a0eb34a7999 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp @@ -184,7 +184,7 @@ performOnModule(Module &M) { std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error); if (!Error.empty()) - report_fatal_error("unable to transforn " + C.getName() + " in " + + report_fatal_error(Twine("unable to transforn ") + C.getName() + " in " + M.getModuleIdentifier() + ": " + Error); if (C.getName() == Name) @@ -256,11 +256,11 @@ bool RewriteMapParser::parse(const std::string &MapFile, MemoryBuffer::getFile(MapFile); if (!Mapping) - report_fatal_error("unable to read rewrite map '" + MapFile + "': " + - Mapping.getError().message()); + report_fatal_error(Twine("unable to read rewrite map '") + MapFile + + "': " + Mapping.getError().message()); if (!parse(*Mapping, DL)) - report_fatal_error("unable to parse rewrite map '" + MapFile + "'"); + report_fatal_error(Twine("unable to parse rewrite map '") + MapFile + "'"); return true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp index 6336af25ef98..dbe3cc93e72b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp @@ -403,19 +403,10 @@ int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, if (Offset == -1) return Offset; - unsigned AS = Src->getType()->getPointerAddressSpace(); // Otherwise, see if we can constant fold a load from the constant with the // offset applied as appropriate. - if (Offset) { - Src = ConstantExpr::getBitCast(Src, - Type::getInt8PtrTy(Src->getContext(), AS)); - Constant *OffsetCst = - ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset); - Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), - Src, OffsetCst); - } - Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); - if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL)) + unsigned IndexSize = DL.getIndexTypeSizeInBits(Src->getType()); + if (ConstantFoldLoadFromConstPtr(Src, LoadTy, APInt(IndexSize, Offset), DL)) return Offset; return -1; } @@ -584,19 +575,11 @@ T *getMemInstValueForLoadHelper(MemIntrinsic *SrcInst, unsigned Offset, MemTransferInst *MTI = cast<MemTransferInst>(SrcInst); Constant *Src = cast<Constant>(MTI->getSource()); - unsigned AS = Src->getType()->getPointerAddressSpace(); // Otherwise, see if we can constant fold a load from the constant with the // offset applied as appropriate. 
- if (Offset) { - Src = ConstantExpr::getBitCast(Src, - Type::getInt8PtrTy(Src->getContext(), AS)); - Constant *OffsetCst = - ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset); - Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), - Src, OffsetCst); - } - Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); - return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL); + unsigned IndexSize = DL.getIndexTypeSizeInBits(Src->getType()); + return ConstantFoldLoadFromConstPtr( + Src, LoadTy, APInt(IndexSize, Offset), DL); } /// This function is called when we have a diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp index f3afd42e6163..c3eafd6b2492 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -26,7 +26,8 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalIndirectSymbol.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalIFunc.h" #include "llvm/IR/GlobalObject.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InlineAsm.h" @@ -68,7 +69,7 @@ struct WorklistEntry { enum EntryKind { MapGlobalInit, MapAppendingVar, - MapGlobalIndirectSymbol, + MapAliasOrIFunc, RemapFunction }; struct GVInitTy { @@ -79,8 +80,8 @@ struct WorklistEntry { GlobalVariable *GV; Constant *InitPrefix; }; - struct GlobalIndirectSymbolTy { - GlobalIndirectSymbol *GIS; + struct AliasOrIFuncTy { + GlobalValue *GV; Constant *Target; }; @@ -91,7 +92,7 @@ struct WorklistEntry { union { GVInitTy GVInit; AppendingGVTy AppendingGV; - GlobalIndirectSymbolTy GlobalIndirectSymbol; + AliasOrIFuncTy AliasOrIFunc; Function *RemapF; } Data; }; @@ -163,8 +164,8 @@ public: bool IsOldCtorDtor, ArrayRef<Constant *> NewMembers, unsigned MCID); - void scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, Constant &Target, - unsigned MCID); + void scheduleMapAliasOrIFunc(GlobalValue &GV, Constant &Target, + unsigned MCID); void scheduleRemapFunction(Function &F, unsigned MCID); void flush(); @@ -873,10 +874,17 @@ void Mapper::flush() { E.AppendingGVIsOldCtorDtor, makeArrayRef(NewInits)); break; } - case WorklistEntry::MapGlobalIndirectSymbol: - E.Data.GlobalIndirectSymbol.GIS->setIndirectSymbol( - mapConstant(E.Data.GlobalIndirectSymbol.Target)); + case WorklistEntry::MapAliasOrIFunc: { + GlobalValue *GV = E.Data.AliasOrIFunc.GV; + Constant *Target = mapConstant(E.Data.AliasOrIFunc.Target); + if (auto *GA = dyn_cast<GlobalAlias>(GV)) + GA->setAliasee(Target); + else if (auto *GI = dyn_cast<GlobalIFunc>(GV)) + GI->setResolver(Target); + else + llvm_unreachable("Not alias or ifunc"); break; + } case WorklistEntry::RemapFunction: remapFunction(*E.Data.RemapF); break; @@ -944,12 +952,13 @@ void Mapper::remapInstruction(Instruction *I) { LLVMContext &C = CB->getContext(); AttributeList Attrs = CB->getAttributes(); for (unsigned i = 0; i < Attrs.getNumAttrSets(); ++i) { - for (Attribute::AttrKind TypedAttr : - {Attribute::ByVal, Attribute::StructRet, Attribute::ByRef, - Attribute::InAlloca}) { - if (Type *Ty = Attrs.getAttribute(i, TypedAttr).getValueAsType()) { - Attrs = Attrs.replaceAttributeType(C, i, TypedAttr, - TypeMapper->remapType(Ty)); + for (int AttrIdx = Attribute::FirstTypeAttr; + AttrIdx <= Attribute::LastTypeAttr; AttrIdx++) { + Attribute::AttrKind TypedAttr = (Attribute::AttrKind)AttrIdx; + if (Type 
*Ty = + Attrs.getAttributeAtIndex(i, TypedAttr).getValueAsType()) { + Attrs = Attrs.replaceAttributeTypeAtIndex(C, i, TypedAttr, + TypeMapper->remapType(Ty)); break; } } @@ -1068,16 +1077,18 @@ void Mapper::scheduleMapAppendingVariable(GlobalVariable &GV, AppendingInits.append(NewMembers.begin(), NewMembers.end()); } -void Mapper::scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, - Constant &Target, unsigned MCID) { - assert(AlreadyScheduled.insert(&GIS).second && "Should not reschedule"); +void Mapper::scheduleMapAliasOrIFunc(GlobalValue &GV, Constant &Target, + unsigned MCID) { + assert(AlreadyScheduled.insert(&GV).second && "Should not reschedule"); + assert((isa<GlobalAlias>(GV) || isa<GlobalIFunc>(GV)) && + "Should be alias or ifunc"); assert(MCID < MCs.size() && "Invalid mapping context"); WorklistEntry WE; - WE.Kind = WorklistEntry::MapGlobalIndirectSymbol; + WE.Kind = WorklistEntry::MapAliasOrIFunc; WE.MCID = MCID; - WE.Data.GlobalIndirectSymbol.GIS = &GIS; - WE.Data.GlobalIndirectSymbol.Target = &Target; + WE.Data.AliasOrIFunc.GV = &GV; + WE.Data.AliasOrIFunc.Target = &Target; Worklist.push_back(WE); } @@ -1174,10 +1185,14 @@ void ValueMapper::scheduleMapAppendingVariable(GlobalVariable &GV, GV, InitPrefix, IsOldCtorDtor, NewMembers, MCID); } -void ValueMapper::scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, - Constant &Target, - unsigned MCID) { - getAsMapper(pImpl)->scheduleMapGlobalIndirectSymbol(GIS, Target, MCID); +void ValueMapper::scheduleMapGlobalAlias(GlobalAlias &GA, Constant &Aliasee, + unsigned MCID) { + getAsMapper(pImpl)->scheduleMapAliasOrIFunc(GA, Aliasee, MCID); +} + +void ValueMapper::scheduleMapGlobalIFunc(GlobalIFunc &GI, Constant &Resolver, + unsigned MCID) { + getAsMapper(pImpl)->scheduleMapAliasOrIFunc(GI, Resolver, MCID); } void ValueMapper::scheduleRemapFunction(Function &F, unsigned MCID) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 3b90997100f1..5a4a2f0924f6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -694,31 +694,16 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { }); for (Instruction &I : make_range(getBoundaryInstrs(Chain))) { - if (isa<LoadInst>(I) || isa<StoreInst>(I)) { - if (!is_contained(Chain, &I)) - MemoryInstrs.push_back(&I); - else - ChainInstrs.push_back(&I); - } else if (isa<IntrinsicInst>(&I) && - cast<IntrinsicInst>(&I)->getIntrinsicID() == - Intrinsic::sideeffect) { - // Ignore llvm.sideeffect calls. - } else if (isa<IntrinsicInst>(&I) && - cast<IntrinsicInst>(&I)->getIntrinsicID() == - Intrinsic::pseudoprobe) { - // Ignore llvm.pseudoprobe calls. - } else if (isa<IntrinsicInst>(&I) && - cast<IntrinsicInst>(&I)->getIntrinsicID() == Intrinsic::assume) { - // Ignore llvm.assume calls. 
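[Editor's illustration] The vectorizeStoreChain and vectorizeLoadChain hunks below replace `return f(a) | f(b);` with two explicit `|=` statements. The likely motivation is evaluation order: C++ does not specify which operand of `|` is evaluated first, so side-effecting recursive calls could run in either order and make the produced IR non-deterministic. A small self-contained demonstration of the forced ordering (names are illustrative):

#include <iostream>

static int Counter = 0;
static bool tag(int &Slot) { Slot = ++Counter; return Slot % 2 == 1; }

int main() {
  int A = 0, B = 0;
  bool Changed = false;
  // With 'tag(A) | tag(B)' the compiler may evaluate either call first.
  // Splitting into two statements pins the order: A is always tagged first.
  Changed |= tag(A);
  Changed |= tag(B);
  std::cout << A << ' ' << B << ' ' << Changed << '\n';  // prints: 1 2 1
}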
- } else if (IsLoadChain && (I.mayWriteToMemory() || I.mayThrow())) { - LLVM_DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I - << '\n'); - break; - } else if (!IsLoadChain && (I.mayReadOrWriteMemory() || I.mayThrow())) { - LLVM_DEBUG(dbgs() << "LSV: Found may-read/write/throw operation: " << I - << '\n'); + if ((isa<LoadInst>(I) || isa<StoreInst>(I)) && is_contained(Chain, &I)) { + ChainInstrs.push_back(&I); + continue; + } + if (I.mayThrow()) { + LLVM_DEBUG(dbgs() << "LSV: Found may-throw operation: " << I << '\n'); break; } + if (I.mayReadOrWriteMemory()) + MemoryInstrs.push_back(&I); } // Loop until we find an instruction in ChainInstrs that we can't vectorize. @@ -751,26 +736,28 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { return LI->hasMetadata(LLVMContext::MD_invariant_load); }; - // We can ignore the alias as long as the load comes before the store, - // because that means we won't be moving the load past the store to - // vectorize it (the vectorized load is inserted at the location of the - // first load in the chain). - if (isa<StoreInst>(MemInstr) && ChainLoad && - (IsInvariantLoad(ChainLoad) || ChainLoad->comesBefore(MemInstr))) - continue; - - // Same case, but in reverse. - if (MemLoad && isa<StoreInst>(ChainInstr) && - (IsInvariantLoad(MemLoad) || MemLoad->comesBefore(ChainInstr))) - continue; + if (IsLoadChain) { + // We can ignore the alias as long as the load comes before the store, + // because that means we won't be moving the load past the store to + // vectorize it (the vectorized load is inserted at the location of the + // first load in the chain). + if (ChainInstr->comesBefore(MemInstr) || + (ChainLoad && IsInvariantLoad(ChainLoad))) + continue; + } else { + // Same case, but in reverse. + if (MemInstr->comesBefore(ChainInstr) || + (MemLoad && IsInvariantLoad(MemLoad))) + continue; + } - if (!AA.isNoAlias(MemoryLocation::get(MemInstr), - MemoryLocation::get(ChainInstr))) { + ModRefInfo MR = + AA.getModRefInfo(MemInstr, MemoryLocation::get(ChainInstr)); + if (IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)) { LLVM_DEBUG({ dbgs() << "LSV: Found alias:\n" - " Aliasing instruction and pointer:\n" + " Aliasing instruction:\n" << " " << *MemInstr << '\n' - << " " << *getLoadStorePointerOperand(MemInstr) << '\n' << " Aliased instruction and pointer:\n" << " " << *ChainInstr << '\n' << " " << *getLoadStorePointerOperand(ChainInstr) << '\n'; @@ -1085,9 +1072,12 @@ bool Vectorizer::vectorizeStoreChain( if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." 
" Creating two separate arrays.\n"); - return vectorizeStoreChain(Chain.slice(0, TargetVF), - InstructionsProcessed) | - vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed); + bool Vectorized = false; + Vectorized |= + vectorizeStoreChain(Chain.slice(0, TargetVF), InstructionsProcessed); + Vectorized |= + vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed); + return Vectorized; } LLVM_DEBUG({ @@ -1104,8 +1094,10 @@ bool Vectorizer::vectorizeStoreChain( if (accessIsMisaligned(SzInBytes, AS, Alignment)) { if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { auto Chains = splitOddVectorElts(Chain, Sz); - return vectorizeStoreChain(Chains.first, InstructionsProcessed) | - vectorizeStoreChain(Chains.second, InstructionsProcessed); + bool Vectorized = false; + Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed); + Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed); + return Vectorized; } Align NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(), @@ -1119,15 +1111,17 @@ bool Vectorizer::vectorizeStoreChain( if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) { auto Chains = splitOddVectorElts(Chain, Sz); - return vectorizeStoreChain(Chains.first, InstructionsProcessed) | - vectorizeStoreChain(Chains.second, InstructionsProcessed); + bool Vectorized = false; + Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed); + Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed); + return Vectorized; } BasicBlock::iterator First, Last; std::tie(First, Last) = getBoundaryInstrs(Chain); Builder.SetInsertPoint(&*Last); - Value *Vec = UndefValue::get(VecTy); + Value *Vec = PoisonValue::get(VecTy); if (VecStoreTy) { unsigned VecWidth = VecStoreTy->getNumElements(); @@ -1237,8 +1231,12 @@ bool Vectorizer::vectorizeLoadChain( if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." 
" Creating two separate arrays.\n"); - return vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed) | - vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed); + bool Vectorized = false; + Vectorized |= + vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed); + Vectorized |= + vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed); + return Vectorized; } // We won't try again to vectorize the elements of the chain, regardless of @@ -1249,8 +1247,10 @@ bool Vectorizer::vectorizeLoadChain( if (accessIsMisaligned(SzInBytes, AS, Alignment)) { if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { auto Chains = splitOddVectorElts(Chain, Sz); - return vectorizeLoadChain(Chains.first, InstructionsProcessed) | - vectorizeLoadChain(Chains.second, InstructionsProcessed); + bool Vectorized = false; + Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed); + Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed); + return Vectorized; } Align NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(), @@ -1264,8 +1264,10 @@ bool Vectorizer::vectorizeLoadChain( if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) { auto Chains = splitOddVectorElts(Chain, Sz); - return vectorizeLoadChain(Chains.first, InstructionsProcessed) | - vectorizeLoadChain(Chains.second, InstructionsProcessed); + bool Vectorized = false; + Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed); + Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed); + return Vectorized; } LLVM_DEBUG({ diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 3c484fb0d28a..805011191da0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -419,7 +419,8 @@ static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, return false; } -int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) const { +int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy, + Value *Ptr) const { const ValueToValueMap &Strides = getSymbolicStrides() ? 
*getSymbolicStrides() : ValueToValueMap(); @@ -428,7 +429,8 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) const { llvm::shouldOptimizeForSize(TheLoop->getHeader(), PSI, BFI, PGSOQueryType::IRPass); bool CanAddPredicate = !OptForSize; - int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, CanAddPredicate, false); + int Stride = getPtrStride(PSE, AccessTy, Ptr, TheLoop, Strides, + CanAddPredicate, false); if (Stride == 1 || Stride == -1) return Stride; return 0; @@ -747,7 +749,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (CI) { auto *SE = PSE.getSE(); Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI); - for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) + for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) if (hasVectorInstrinsicScalarOpd(IntrinID, i)) { if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(i)), TheLoop)) { reportVectorizationFailure("Found unvectorizable intrinsic", diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 5c4c4fdfa3f7..a7d6609f8c56 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -268,12 +268,6 @@ class LoopVectorizationPlanner { /// A builder used to construct the current plan. VPBuilder Builder; - /// The best number of elements of the vector types used in the - /// transformed loop. BestVF = None means that vectorization is - /// disabled. - Optional<ElementCount> BestVF = None; - unsigned BestUF = 0; - public: LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, @@ -295,12 +289,13 @@ public: /// VF and its cost. VectorizationFactor planInVPlanNativePath(ElementCount UserVF); - /// Finalize the best decision and dispose of all other VPlans. - void setBestPlan(ElementCount VF, unsigned UF); + /// Return the best VPlan for \p VF. + VPlan &getBestPlanFor(ElementCount VF) const; /// Generate the IR code for the body of the vectorized loop according to the - /// best selected VPlan. - void executePlan(InnerLoopVectorizer &LB, DominatorTree *DT); + /// best selected \p VF, \p UF and VPlan \p BestPlan. + void executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan, + InnerLoopVectorizer &LB, DominatorTree *DT); #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printPlans(raw_ostream &O); @@ -308,12 +303,9 @@ public: /// Look through the existing plans and return true if we have one with all /// the vectorization factors in question. - bool hasPlanWithVFs(const ArrayRef<ElementCount> VFs) const { - return any_of(VPlans, [&](const VPlanPtr &Plan) { - return all_of(VFs, [&](const ElementCount &VF) { - return Plan->hasVF(VF); - }); - }); + bool hasPlanWithVF(ElementCount VF) const { + return any_of(VPlans, + [&](const VPlanPtr &Plan) { return Plan->hasVF(VF); }); } /// Test a \p Predicate on a \p Range of VF's. Return the value of applying @@ -351,13 +343,14 @@ private: /// legal to vectorize the loop. This method creates VPlans using VPRecipes. void buildVPlansWithVPRecipes(ElementCount MinVF, ElementCount MaxVF); - /// Adjust the recipes for any inloop reductions. The chain of instructions - /// leading from the loop exit instr to the phi need to be converted to - /// reductions, with one operand being vector and the other being the scalar - /// reduction chain. 
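// Minimal sketch (illustrative stand-in types, not the LLVM classes) of the
// planner API reshape above: the cached BestVF/BestUF members are gone, and
// callers now ask for the plan covering a given VF and pass VF/UF explicitly
// when executing it.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

struct PlanSketch {
  std::vector<unsigned> VFs;                      // factors this plan can emit code for
  bool hasVF(unsigned VF) const {
    return std::find(VFs.begin(), VFs.end(), VF) != VFs.end();
  }
  void execute(unsigned VF, unsigned UF) const {  // stands in for executePlan()
    std::cout << "executing plan with VF=" << VF << " UF=" << UF << '\n';
  }
};

class PlannerSketch {
  std::vector<std::unique_ptr<PlanSketch>> Plans;
public:
  void add(std::vector<unsigned> VFs) {
    Plans.push_back(std::make_unique<PlanSketch>(PlanSketch{std::move(VFs)}));
  }
  bool hasPlanWithVF(unsigned VF) const {
    return std::any_of(Plans.begin(), Plans.end(),
                       [VF](const std::unique_ptr<PlanSketch> &P) { return P->hasVF(VF); });
  }
  const PlanSketch &getBestPlanFor(unsigned VF) const {
    for (const std::unique_ptr<PlanSketch> &P : Plans)
      if (P->hasVF(VF))
        return *P;
    std::abort();                                 // no plan covers the requested VF
  }
};

int main() {
  PlannerSketch LVP;
  LVP.add({2, 4});
  LVP.add({8});
  if (LVP.hasPlanWithVF(4))
    LVP.getBestPlanFor(4).execute(/*VF=*/4, /*UF=*/2);
}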
- void adjustRecipesForInLoopReductions(VPlanPtr &Plan, - VPRecipeBuilder &RecipeBuilder, - ElementCount MinVF); + // Adjust the recipes for reductions. For in-loop reductions the chain of + // instructions leading from the loop exit instr to the phi need to be + // converted to reductions, with one operand being vector and the other being + // the scalar reduction chain. For other reductions, a select is introduced + // between the phi and live-out recipes when folding the tail. + void adjustRecipesForReductions(VPBasicBlock *LatchVPBB, VPlanPtr &Plan, + VPRecipeBuilder &RecipeBuilder, + ElementCount MinVF); }; } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 671bc6b5212b..23bb6f0860c9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -87,7 +87,6 @@ #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -332,8 +331,8 @@ static cl::opt<bool> cl::desc("Prefer in-loop vector reductions, " "overriding the targets preference.")); -cl::opt<bool> EnableStrictReductions( - "enable-strict-reductions", cl::init(false), cl::Hidden, +static cl::opt<bool> ForceOrderedReductions( + "force-ordered-reductions", cl::init(false), cl::Hidden, cl::desc("Enable the vectorisation of loops with in-order (strict) " "FP reductions")); @@ -545,7 +544,8 @@ public: /// vectorized loop. void vectorizeMemoryInstruction(Instruction *Instr, VPTransformState &State, VPValue *Def, VPValue *Addr, - VPValue *StoredValue, VPValue *BlockInMask); + VPValue *StoredValue, VPValue *BlockInMask, + bool ConsecutiveStride, bool Reverse); /// Set the debug location in the builder \p Ptr using the debug location in /// \p V. If \p Ptr is None then it uses the class member's Builder. @@ -590,12 +590,11 @@ protected: /// Handle all cross-iteration phis in the header. void fixCrossIterationPHIs(VPTransformState &State); - /// Fix a first-order recurrence. This is the second phase of vectorizing - /// this phi node. + /// Create the exit value of first order recurrences in the middle block and + /// update their users. void fixFirstOrderRecurrence(VPWidenPHIRecipe *PhiR, VPTransformState &State); - /// Fix a reduction cross-iteration phi. This is the second phase of - /// vectorizing this phi node. + /// Create code for the loop exit value of the reduction. void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State); /// Clear NSW/NUW flags from reduction instructions if necessary. @@ -621,9 +620,9 @@ protected: /// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...) /// to each vector element of Val. The sequence starts at StartIndex. /// \p Opcode is relevant for FP induction variable. - virtual Value *getStepVector(Value *Val, int StartIdx, Value *Step, - Instruction::BinaryOps Opcode = - Instruction::BinaryOpsEnd); + virtual Value * + getStepVector(Value *Val, Value *StartIdx, Value *Step, + Instruction::BinaryOps Opcode = Instruction::BinaryOpsEnd); /// Compute scalar induction steps. 
\p ScalarIV is the scalar induction /// variable on which to base the steps, \p Step is the size of the step, and @@ -890,9 +889,9 @@ public: private: Value *getBroadcastInstrs(Value *V) override; - Value *getStepVector(Value *Val, int StartIdx, Value *Step, - Instruction::BinaryOps Opcode = - Instruction::BinaryOpsEnd) override; + Value *getStepVector( + Value *Val, Value *StartIdx, Value *Step, + Instruction::BinaryOps Opcode = Instruction::BinaryOpsEnd) override; Value *reverseVector(Value *Vec) override; }; @@ -911,10 +910,9 @@ struct EpilogueLoopVectorizationInfo { Value *TripCount = nullptr; Value *VectorTripCount = nullptr; - EpilogueLoopVectorizationInfo(unsigned MVF, unsigned MUF, unsigned EVF, - unsigned EUF) - : MainLoopVF(ElementCount::getFixed(MVF)), MainLoopUF(MUF), - EpilogueVF(ElementCount::getFixed(EVF)), EpilogueUF(EUF) { + EpilogueLoopVectorizationInfo(ElementCount MVF, unsigned MUF, + ElementCount EVF, unsigned EUF) + : MainLoopVF(MVF), MainLoopUF(MUF), EpilogueVF(EVF), EpilogueUF(EUF) { assert(EUF == 1 && "A high UF for the epilogue loop is likely not beneficial."); } @@ -1105,11 +1103,10 @@ static OptimizationRemarkAnalysis createLVAnalysis(const char *PassName, } /// Return a value for Step multiplied by VF. -static Value *createStepForVF(IRBuilder<> &B, Constant *Step, ElementCount VF) { - assert(isa<ConstantInt>(Step) && "Expected an integer step"); - Constant *StepVal = ConstantInt::get( - Step->getType(), - cast<ConstantInt>(Step)->getSExtValue() * VF.getKnownMinValue()); +static Value *createStepForVF(IRBuilder<> &B, Type *Ty, ElementCount VF, + int64_t Step) { + assert(Ty->isIntegerTy() && "Expected an integer step"); + Constant *StepVal = ConstantInt::get(Ty, Step * VF.getKnownMinValue()); return VF.isScalable() ? B.CreateVScale(StepVal) : StepVal; } @@ -1121,6 +1118,13 @@ Value *getRuntimeVF(IRBuilder<> &B, Type *Ty, ElementCount VF) { return VF.isScalable() ? B.CreateVScale(EC) : EC; } +static Value *getRuntimeVFAsFloat(IRBuilder<> &B, Type *FTy, ElementCount VF) { + assert(FTy->isFloatingPointTy() && "Expected floating point type!"); + Type *IntTy = IntegerType::get(FTy->getContext(), FTy->getScalarSizeInBits()); + Value *RuntimeVF = getRuntimeVF(B, IntTy, VF); + return B.CreateUIToFP(RuntimeVF, FTy); +} + void reportVectorizationFailure(const StringRef DebugMsg, const StringRef OREMsg, const StringRef ORETag, OptimizationRemarkEmitter *ORE, Loop *TheLoop, @@ -1319,8 +1323,7 @@ public: /// the IsOrdered flag of RdxDesc is set and we do not allow reordering /// of FP operations. bool useOrderedReductions(const RecurrenceDescriptor &RdxDesc) { - return EnableStrictReductions && !Hints->allowReordering() && - RdxDesc.isOrdered(); + return !Hints->allowReordering() && RdxDesc.isOrdered(); } /// \returns The smallest bitwidth each instruction can be represented with. @@ -1495,14 +1498,14 @@ public: /// Returns true if the target machine supports masked store operation /// for the given \p DataType and kind of access to \p Ptr. bool isLegalMaskedStore(Type *DataType, Value *Ptr, Align Alignment) const { - return Legal->isConsecutivePtr(Ptr) && + return Legal->isConsecutivePtr(DataType, Ptr) && TTI.isLegalMaskedStore(DataType, Alignment); } /// Returns true if the target machine supports masked load operation /// for the given \p DataType and kind of access to \p Ptr. 
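// Conceptual sketch (plain arithmetic, not the SCEV-based implementation) of
// why isConsecutivePtr and getPtrStride now take the accessed type: with
// opaque pointers the element size must come from the access itself, and a
// pointer is "consecutive" only if its byte step is exactly +/- one element
// of that type.
#include <iostream>

int strideInElements(long long ByteStep, long long AccessTypeSize) {
  if (AccessTypeSize == 0 || ByteStep % AccessTypeSize != 0)
    return 0;                                    // not consecutive for this access type
  long long S = ByteStep / AccessTypeSize;
  return (S == 1 || S == -1) ? static_cast<int>(S) : 0;
}

int main() {
  std::cout << strideInElements(4, 4) << ' '     // i32 access advancing 4 bytes: stride 1
            << strideInElements(-8, 8) << ' '    // i64 access advancing -8 bytes: stride -1
            << strideInElements(8, 4) << '\n';   // i32 access advancing 8 bytes: 0, not consecutive
}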
bool isLegalMaskedLoad(Type *DataType, Value *Ptr, Align Alignment) const { - return Legal->isConsecutivePtr(Ptr) && + return Legal->isConsecutivePtr(DataType, Ptr) && TTI.isLegalMaskedLoad(DataType, Alignment); } @@ -1539,7 +1542,7 @@ public: // through scalar predication or masked load/store or masked gather/scatter. // Superset of instructions that return true for isScalarWithPredication. bool isPredicatedInst(Instruction *I) { - if (!blockNeedsPredication(I->getParent())) + if (!blockNeedsPredicationForAnyReason(I->getParent())) return false; // Loads and stores that need some form of masked operation are predicated // instructions. @@ -1593,7 +1596,10 @@ public: /// Returns true if all loop blocks should be masked to fold tail loop. bool foldTailByMasking() const { return FoldTailByMasking; } - bool blockNeedsPredication(BasicBlock *BB) const { + /// Returns true if the instructions in this block requires predication + /// for any reason, e.g. because tail folding now requires a predicate + /// or because the block in the original loop was predicated. + bool blockNeedsPredicationForAnyReason(BasicBlock *BB) const { return foldTailByMasking() || Legal->blockNeedsPredication(BB); } @@ -1928,7 +1934,7 @@ class GeneratedRTChecks { /// The value representing the result of the generated memory runtime checks. /// If it is nullptr, either no memory runtime checks have been generated or /// they have been used. - Instruction *MemRuntimeCheckCond = nullptr; + Value *MemRuntimeCheckCond = nullptr; DominatorTree *DT; LoopInfo *LI; @@ -1971,7 +1977,7 @@ public: MemCheckBlock = SplitBlock(Pred, Pred->getTerminator(), DT, LI, nullptr, "vector.memcheck"); - std::tie(std::ignore, MemRuntimeCheckCond) = + MemRuntimeCheckCond = addRuntimeChecks(MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(), MemCheckExp); assert(MemRuntimeCheckCond && @@ -2030,7 +2036,6 @@ public: if (MemCheckExp.isInsertedInstruction(&I)) continue; SE.forgetValue(&I); - SE.eraseValueFromMap(&I); I.eraseFromParent(); } } @@ -2289,9 +2294,11 @@ void InnerLoopVectorizer::createVectorIntOrFpInductionPHI( Step = Builder.CreateTrunc(Step, TruncType); Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); } + + Value *Zero = getSignedIntOrFpConstant(Start->getType(), 0); Value *SplatStart = Builder.CreateVectorSplat(VF, Start); Value *SteppedStart = - getStepVector(SplatStart, 0, Step, II.getInductionOpcode()); + getStepVector(SplatStart, Zero, Step, II.getInductionOpcode()); // We create vector phi nodes for both integer and floating-point induction // variables. Here, we determine the kind of arithmetic we will perform. @@ -2308,12 +2315,11 @@ void InnerLoopVectorizer::createVectorIntOrFpInductionPHI( // Multiply the vectorization factor by the step using integer or // floating-point arithmetic as appropriate. Type *StepType = Step->getType(); + Value *RuntimeVF; if (Step->getType()->isFloatingPointTy()) - StepType = IntegerType::get(StepType->getContext(), - StepType->getScalarSizeInBits()); - Value *RuntimeVF = getRuntimeVF(Builder, StepType, VF); - if (Step->getType()->isFloatingPointTy()) - RuntimeVF = Builder.CreateSIToFP(RuntimeVF, Step->getType()); + RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, VF); + else + RuntimeVF = getRuntimeVF(Builder, StepType, VF); Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF); // Create a vector splat to use in the induction update. 
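// Sketch of the getRuntimeVFAsFloat helper used for the FP induction update
// above: the runtime VF is first formed as an integer (vscale times the
// minimum lane count for scalable vectors) and only then converted to the
// induction's floating-point type, mirroring the getRuntimeVF + uitofp
// sequence. Function names here are illustrative stand-ins.
#include <cstdint>
#include <iostream>

uint64_t runtimeVF(uint64_t VScale, uint64_t MinLanes, bool Scalable) {
  return Scalable ? VScale * MinLanes : MinLanes;
}

double runtimeVFAsFloat(uint64_t VScale, uint64_t MinLanes, bool Scalable) {
  return static_cast<double>(runtimeVF(VScale, MinLanes, Scalable));  // the "uitofp" step
}

int main() {
  const double Step = 0.5;  // FP induction step
  // Per vector iteration the FP induction advances by Step * runtime VF.
  double Mul = Step * runtimeVFAsFloat(/*VScale=*/4, /*MinLanes=*/2, /*Scalable=*/true);
  std::cout << "vector-loop FP induction increment: " << Mul << '\n';  // 4
}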
@@ -2388,9 +2394,13 @@ void InnerLoopVectorizer::recordVectorLoopValueForInductionCast( if (isa<TruncInst>(EntryVal)) return; - const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts(); - if (Casts.empty()) + if (!CastDef) { + assert(ID.getCastInsts().empty() && + "there are casts for ID, but no CastDef"); return; + } + assert(!ID.getCastInsts().empty() && + "there is a CastDef, but no casts for ID"); // Only the first Cast instruction in the Casts vector is of interest. // The rest of the Casts (if exist) have no uses outside the // induction update chain itself. @@ -2462,9 +2472,14 @@ void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, Value *Start, Value *Broadcasted = getBroadcastInstrs(ScalarIV); for (unsigned Part = 0; Part < UF; ++Part) { assert(!VF.isScalable() && "scalable vectors not yet supported."); + Value *StartIdx; + if (Step->getType()->isFloatingPointTy()) + StartIdx = getRuntimeVFAsFloat(Builder, Step->getType(), VF * Part); + else + StartIdx = getRuntimeVF(Builder, Step->getType(), VF * Part); + Value *EntryPart = - getStepVector(Broadcasted, VF.getKnownMinValue() * Part, Step, - ID.getInductionOpcode()); + getStepVector(Broadcasted, StartIdx, Step, ID.getInductionOpcode()); State.set(Def, EntryPart, Part); if (Trunc) addMetadata(EntryPart, Trunc); @@ -2520,7 +2535,8 @@ void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, Value *Start, buildScalarSteps(ScalarIV, Step, EntryVal, ID, Def, CastDef, State); } -Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, +Value *InnerLoopVectorizer::getStepVector(Value *Val, Value *StartIdx, + Value *Step, Instruction::BinaryOps BinOp) { // Create and check the types. auto *ValVTy = cast<VectorType>(Val->getType()); @@ -2543,12 +2559,11 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, } Value *InitVec = Builder.CreateStepVector(InitVecValVTy); - // Add on StartIdx - Value *StartIdxSplat = Builder.CreateVectorSplat( - VLen, ConstantInt::get(InitVecValSTy, StartIdx)); - InitVec = Builder.CreateAdd(InitVec, StartIdxSplat); + // Splat the StartIdx + Value *StartIdxSplat = Builder.CreateVectorSplat(VLen, StartIdx); if (STy->isIntegerTy()) { + InitVec = Builder.CreateAdd(InitVec, StartIdxSplat); Step = Builder.CreateVectorSplat(VLen, Step); assert(Step->getType() == Val->getType() && "Invalid step vec"); // FIXME: The newly created binary instructions should contain nsw/nuw flags, @@ -2561,6 +2576,8 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) && "Binary Opcode should be specified for FP induction"); InitVec = Builder.CreateUIToFP(InitVec, ValVTy); + InitVec = Builder.CreateFAdd(InitVec, StartIdxSplat); + Step = Builder.CreateVectorSplat(VLen, Step); Value *MulOp = Builder.CreateFMul(InitVec, Step); return Builder.CreateBinOp(BinOp, Val, MulOp, "induction"); @@ -2609,8 +2626,7 @@ void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, } for (unsigned Part = 0; Part < UF; ++Part) { - Value *StartIdx0 = - createStepForVF(Builder, ConstantInt::get(IntStepTy, Part), VF); + Value *StartIdx0 = createStepForVF(Builder, IntStepTy, VF, Part); if (!IsUniform && VF.isScalable()) { auto *SplatStartIdx = Builder.CreateVectorSplat(VF, StartIdx0); @@ -2838,12 +2854,25 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( auto *SubVT = VectorType::get(ScalarTy, VF); // Vectorize the interleaved store group. 
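// Lane-by-lane restatement (plain C++ arrays, not the vector IR the real
// code builds) of the getStepVector change above: StartIdx is now an
// arbitrary runtime value rather than a compile-time int, and for an FP
// induction lane i becomes Val[i] op (StartIdx + i) * Step.
#include <array>
#include <cstddef>
#include <iostream>

template <std::size_t N>
std::array<double, N> stepVector(std::array<double, N> Val, double StartIdx,
                                 double Step) {
  for (std::size_t I = 0; I < N; ++I)
    Val[I] += (StartIdx + static_cast<double>(I)) * Step;  // FAdd(Val, FMul(InitVec, Step))
  return Val;
}

int main() {
  std::array<double, 4> SplatStart{10.0, 10.0, 10.0, 10.0};  // broadcast start value
  for (double V : stepVector(SplatStart, /*StartIdx=*/4.0, /*Step=*/0.5))
    std::cout << V << ' ';                                   // 12 12.5 13 13.5
  std::cout << '\n';
}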
+ MaskForGaps = createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group); + assert((!MaskForGaps || useMaskedInterleavedAccesses(*TTI)) && + "masked interleaved groups are not allowed."); + assert((!MaskForGaps || !VF.isScalable()) && + "masking gaps for scalable vectors is not yet supported."); for (unsigned Part = 0; Part < UF; Part++) { // Collect the stored vector from each member. SmallVector<Value *, 4> StoredVecs; for (unsigned i = 0; i < InterleaveFactor; i++) { - // Interleaved store group doesn't allow a gap, so each index has a member - assert(Group->getMember(i) && "Fail to get a member from an interleaved store group"); + assert((Group->getMember(i) || MaskForGaps) && + "Fail to get a member from an interleaved store group"); + Instruction *Member = Group->getMember(i); + + // Skip the gaps in the group. + if (!Member) { + Value *Undef = PoisonValue::get(SubVT); + StoredVecs.push_back(Undef); + continue; + } Value *StoredVec = State.get(StoredValues[i], Part); @@ -2867,16 +2896,21 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( "interleaved.vec"); Instruction *NewStoreInstr; - if (BlockInMask) { - Value *BlockInMaskPart = State.get(BlockInMask, Part); - Value *ShuffledMask = Builder.CreateShuffleVector( - BlockInMaskPart, - createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()), - "interleaved.mask"); - NewStoreInstr = Builder.CreateMaskedStore( - IVec, AddrParts[Part], Group->getAlign(), ShuffledMask); - } - else + if (BlockInMask || MaskForGaps) { + Value *GroupMask = MaskForGaps; + if (BlockInMask) { + Value *BlockInMaskPart = State.get(BlockInMask, Part); + Value *ShuffledMask = Builder.CreateShuffleVector( + BlockInMaskPart, + createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()), + "interleaved.mask"); + GroupMask = MaskForGaps ? Builder.CreateBinOp(Instruction::And, + ShuffledMask, MaskForGaps) + : ShuffledMask; + } + NewStoreInstr = Builder.CreateMaskedStore(IVec, AddrParts[Part], + Group->getAlign(), GroupMask); + } else NewStoreInstr = Builder.CreateAlignedStore(IVec, AddrParts[Part], Group->getAlign()); @@ -2886,7 +2920,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( void InnerLoopVectorizer::vectorizeMemoryInstruction( Instruction *Instr, VPTransformState &State, VPValue *Def, VPValue *Addr, - VPValue *StoredValue, VPValue *BlockInMask) { + VPValue *StoredValue, VPValue *BlockInMask, bool ConsecutiveStride, + bool Reverse) { // Attempt to issue a wide load. LoadInst *LI = dyn_cast<LoadInst>(Instr); StoreInst *SI = dyn_cast<StoreInst>(Instr); @@ -2895,31 +2930,11 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction( assert((!SI || StoredValue) && "No stored value provided for widened store"); assert((!LI || !StoredValue) && "Stored value provided for widened load"); - LoopVectorizationCostModel::InstWidening Decision = - Cost->getWideningDecision(Instr, VF); - assert((Decision == LoopVectorizationCostModel::CM_Widen || - Decision == LoopVectorizationCostModel::CM_Widen_Reverse || - Decision == LoopVectorizationCostModel::CM_GatherScatter) && - "CM decision is not to widen the memory instruction"); - Type *ScalarDataTy = getLoadStoreType(Instr); auto *DataTy = VectorType::get(ScalarDataTy, VF); const Align Alignment = getLoadStoreAlignment(Instr); - - // Determine if the pointer operand of the access is either consecutive or - // reverse consecutive. 
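// Worked example (illustrative values) of the gap masking added for
// interleaved stores above: a member missing from the group contributes
// poison lanes, and those lanes are switched off by AND-ing the replicated
// block mask with the gap mask before the masked store is emitted.
#include <iostream>
#include <vector>

int main() {
  const unsigned VF = 4, Factor = 3;
  std::vector<bool> HasMember = {true, false, true};        // member 1 is a gap
  std::vector<bool> BlockMask = {true, true, true, false};  // e.g. a tail-folded last lane

  std::vector<bool> GroupMask(VF * Factor);
  for (unsigned Lane = 0; Lane < VF; ++Lane)
    for (unsigned M = 0; M < Factor; ++M) {
      bool MaskForGaps = HasMember[M];           // what createBitMaskForGaps encodes
      bool Replicated = BlockMask[Lane];         // what createReplicatedMask encodes
      GroupMask[Lane * Factor + M] = Replicated && MaskForGaps;
    }

  for (bool B : GroupMask) std::cout << B;       // 101101101000
  std::cout << '\n';
}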
- bool Reverse = (Decision == LoopVectorizationCostModel::CM_Widen_Reverse); - bool ConsecutiveStride = - Reverse || (Decision == LoopVectorizationCostModel::CM_Widen); - bool CreateGatherScatter = - (Decision == LoopVectorizationCostModel::CM_GatherScatter); - - // Either Ptr feeds a vector load/store, or a vector GEP should feed a vector - // gather/scatter. Otherwise Decision should have been to Scalarize. - assert((ConsecutiveStride || CreateGatherScatter) && - "The instruction should be scalarized"); - (void)ConsecutiveStride; + bool CreateGatherScatter = !ConsecutiveStride; VectorParts BlockInMaskParts(UF); bool isMaskRequired = BlockInMask; @@ -2953,7 +2968,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction( if (isMaskRequired) // Reverse of a null all-one mask is a null mask. BlockInMaskParts[Part] = reverseVector(BlockInMaskParts[Part]); } else { - Value *Increment = createStepForVF(Builder, Builder.getInt32(Part), VF); + Value *Increment = + createStepForVF(Builder, Builder.getInt32Ty(), VF, Part); PartPtr = cast<GetElementPtrInst>( Builder.CreateGEP(ScalarDataTy, Ptr, Increment)); PartPtr->setIsInBounds(InBounds); @@ -3172,7 +3188,7 @@ Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { Type *Ty = TC->getType(); // This is where we can make the step a runtime constant. - Value *Step = createStepForVF(Builder, ConstantInt::get(Ty, UF), VF); + Value *Step = createStepForVF(Builder, Ty, VF, UF); // If the tail is to be folded by masking, round the number of iterations N // up to a multiple of Step instead of rounding down. This is done by first @@ -3262,8 +3278,7 @@ void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, // If tail is to be folded, vector loop takes care of all iterations. Value *CheckMinIters = Builder.getFalse(); if (!Cost->foldTailByMasking()) { - Value *Step = - createStepForVF(Builder, ConstantInt::get(Count->getType(), UF), VF); + Value *Step = createStepForVF(Builder, Count->getType(), VF, UF); CheckMinIters = Builder.CreateICmp(P, Count, Step, "min.iters.check"); } // Create new preheader for vector loop. @@ -3433,7 +3448,7 @@ Value *InnerLoopVectorizer::emitTransformedIndex( assert(isa<SCEVConstant>(Step) && "Expected constant step for pointer induction"); return B.CreateGEP( - StartValue->getType()->getPointerElementType(), StartValue, + ID.getElementType(), StartValue, CreateMul(Index, Exp.expandCodeFor(Step, Index->getType()->getScalarType(), GetInsertPoint()))); @@ -3739,7 +3754,7 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { // The loop step is equal to the vectorization factor (num of SIMD elements) // times the unroll factor (num of SIMD instructions). Builder.SetInsertPoint(&*Lp->getHeader()->getFirstInsertionPt()); - Value *Step = createStepForVF(Builder, ConstantInt::get(IdxTy, UF), VF); + Value *Step = createStepForVF(Builder, IdxTy, VF, UF); Value *CountRoundDown = getOrCreateVectorTripCount(Lp); Induction = createInductionVariable(Lp, StartIdx, CountRoundDown, Step, @@ -3857,21 +3872,19 @@ struct CSEDenseMapInfo { static void cse(BasicBlock *BB) { // Perform simple cse. SmallDenseMap<Instruction *, Instruction *, 4, CSEDenseMapInfo> CSEMap; - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { - Instruction *In = &*I++; - - if (!CSEDenseMapInfo::canHandle(In)) + for (Instruction &In : llvm::make_early_inc_range(*BB)) { + if (!CSEDenseMapInfo::canHandle(&In)) continue; // Check if we can replace this instruction with any of the // visited instructions. 
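// Plain-C++ sketch of the "early increment" idiom that
// llvm::make_early_inc_range gives the cse() loop above: the iterator is
// advanced before the current element may be erased, so deleting a
// duplicate during the walk stays safe.
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>

int main() {
  std::list<std::string> Block = {"a+b", "c*d", "a+b", "e-f", "c*d"};
  std::unordered_map<std::string, int> CSEMap;   // expression -> seen marker

  for (auto It = Block.begin(); It != Block.end();) {
    auto Cur = It++;                             // early increment
    if (!CSEMap.emplace(*Cur, 1).second)
      Block.erase(Cur);                          // duplicate: erase while iterating
  }

  for (const std::string &S : Block) std::cout << S << ' ';  // a+b c*d e-f
  std::cout << '\n';
}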
- if (Instruction *V = CSEMap.lookup(In)) { - In->replaceAllUsesWith(V); - In->eraseFromParent(); + if (Instruction *V = CSEMap.lookup(&In)) { + In.replaceAllUsesWith(V); + In.eraseFromParent(); continue; } - CSEMap[In] = In; + CSEMap[&In] = &In; } } @@ -3881,7 +3894,7 @@ LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF, Function *F = CI->getCalledFunction(); Type *ScalarRetTy = CI->getType(); SmallVector<Type *, 4> Tys, ScalarTys; - for (auto &ArgOp : CI->arg_operands()) + for (auto &ArgOp : CI->args()) ScalarTys.push_back(ArgOp->getType()); // Estimate cost of scalarized vector call. The source operands are assumed @@ -3940,7 +3953,7 @@ LoopVectorizationCostModel::getVectorIntrinsicCost(CallInst *CI, if (auto *FPMO = dyn_cast<FPMathOperator>(CI)) FMF = FPMO->getFastMathFlags(); - SmallVector<const Value *> Arguments(CI->arg_begin(), CI->arg_end()); + SmallVector<const Value *> Arguments(CI->args()); FunctionType *FTy = CI->getCalledFunction()->getFunctionType(); SmallVector<Type *> ParamTys; std::transform(FTy->param_begin(), FTy->param_end(), @@ -3974,7 +3987,8 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &State) { // If the value wasn't vectorized, we must maintain the original scalar // type. The absence of the value from State indicates that it // wasn't vectorized. - VPValue *Def = State.Plan->getVPValue(KV.first); + // FIXME: Should not rely on getVPValue at this point. + VPValue *Def = State.Plan->getVPValue(KV.first, true); if (!State.hasAnyVectorValue(Def)) continue; for (unsigned Part = 0; Part < UF; ++Part) { @@ -4081,7 +4095,8 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &State) { // If the value wasn't vectorized, we must maintain the original scalar // type. The absence of the value from State indicates that it // wasn't vectorized. - VPValue *Def = State.Plan->getVPValue(KV.first); + // FIXME: Should not rely on getVPValue at this point. + VPValue *Def = State.Plan->getVPValue(KV.first, true); if (!State.hasAnyVectorValue(Def)) continue; for (unsigned Part = 0; Part < UF; ++Part) { @@ -4222,17 +4237,12 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(VPWidenPHIRecipe *PhiR, // After execution completes the vector loop, we extract the next value of // the recurrence (x) to use as the initial value in the scalar loop. - auto *IdxTy = Builder.getInt32Ty(); - auto *VecPhi = cast<PHINode>(State.get(PhiR, 0)); - - // Fix the latch value of the new recurrence in the vector loop. - VPValue *PreviousDef = PhiR->getBackedgeValue(); - Value *Incoming = State.get(PreviousDef, UF - 1); - VecPhi->addIncoming(Incoming, LI->getLoopFor(LoopVectorBody)->getLoopLatch()); - // Extract the last vector element in the middle block. This will be the // initial value for the recurrence when jumping to the scalar loop. + VPValue *PreviousDef = PhiR->getBackedgeValue(); + Value *Incoming = State.get(PreviousDef, UF - 1); auto *ExtractForScalar = Incoming; + auto *IdxTy = Builder.getInt32Ty(); if (VF.isVector()) { auto *One = ConstantInt::get(IdxTy, 1); Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); @@ -4283,8 +4293,7 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(VPWidenPHIRecipe *PhiR, // and thus no phis which needed updated. 
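// The fixFirstOrderRecurrence change above computes the scalar loop's resume
// value directly in the middle block: it is simply the last lane of the last
// unrolled part of the recurrence. Concrete values below are made up for
// illustration.
#include <iostream>
#include <vector>

int main() {
  const unsigned VF = 4, UF = 2;
  // Recurrence value per unrolled part after the final vector iteration.
  std::vector<std::vector<int>> Parts = {{10, 11, 12, 13}, {14, 15, 16, 17}};

  int ExtractForScalar = Parts[UF - 1][VF - 1];  // extractelement of part UF-1, lane VF-1
  std::cout << "scalar epilogue resumes the recurrence with " << ExtractForScalar
            << '\n';                             // 17
}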
if (!Cost->requiresScalarEpilogue(VF)) for (PHINode &LCSSAPhi : LoopExitBlock->phis()) - if (any_of(LCSSAPhi.incoming_values(), - [Phi](Value *V) { return V == Phi; })) + if (llvm::is_contained(LCSSAPhi.incoming_values(), Phi)) LCSSAPhi.addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); } @@ -4301,29 +4310,13 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); setDebugLocFromInst(ReductionStartValue); - VPValue *LoopExitInstDef = State.Plan->getVPValue(LoopExitInst); + VPValue *LoopExitInstDef = PhiR->getBackedgeValue(); // This is the vector-clone of the value that leaves the loop. Type *VecTy = State.get(LoopExitInstDef, 0)->getType(); // Wrap flags are in general invalid after vectorization, clear them. clearReductionWrapFlags(RdxDesc, State); - // Fix the vector-loop phi. - - // Reductions do not have to start at zero. They can start with - // any loop invariant values. - BasicBlock *VectorLoopLatch = LI->getLoopFor(LoopVectorBody)->getLoopLatch(); - - unsigned LastPartForNewPhi = PhiR->isOrdered() ? 1 : UF; - for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { - Value *VecRdxPhi = State.get(PhiR->getVPSingleValue(), Part); - Value *Val = State.get(PhiR->getBackedgeValue(), Part); - if (PhiR->isOrdered()) - Val = State.get(PhiR->getBackedgeValue(), UF - 1); - - cast<PHINode>(VecRdxPhi)->addIncoming(Val, VectorLoopLatch); - } - // Before each round, move the insertion point right between // the PHIs and the values we are going to write. // This allows us to write both PHINodes and the extractelement @@ -4361,7 +4354,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, RdxDesc.getOpcode(), PhiTy, TargetTransformInfo::ReductionFlags())) { auto *VecRdxPhi = - cast<PHINode>(State.get(PhiR->getVPSingleValue(), Part)); + cast<PHINode>(State.get(PhiR, Part)); VecRdxPhi->setIncomingValueForBlock( LI->getLoopFor(LoopVectorBody)->getLoopLatch(), Sel); } @@ -4382,13 +4375,10 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, Value *Trunc = Builder.CreateTrunc(RdxParts[Part], RdxVecTy); Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy) : Builder.CreateZExt(Trunc, VecTy); - for (Value::user_iterator UI = RdxParts[Part]->user_begin(); - UI != RdxParts[Part]->user_end();) - if (*UI != Trunc) { - (*UI++)->replaceUsesOfWith(RdxParts[Part], Extnd); + for (User *U : llvm::make_early_inc_range(RdxParts[Part]->users())) + if (U != Trunc) { + U->replaceUsesOfWith(RdxParts[Part], Extnd); RdxParts[Part] = Extnd; - } else { - ++UI; } } Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); @@ -4421,9 +4411,11 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, if (Op != Instruction::ICmp && Op != Instruction::FCmp) { ReducedPartRdx = Builder.CreateBinOp( (Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx"); - } else { + } else if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) + ReducedPartRdx = createSelectCmpOp(Builder, ReductionStartValue, RK, + ReducedPartRdx, RdxPart); + else ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart); - } } } @@ -4431,7 +4423,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // target reduction in the loop using a Reduction recipe. 
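// Sketch of how the unrolled parts of a reduction are merged above: ordinary
// reductions use a binary-op chain, while the new select-cmp case keeps
// whichever part has moved away from the start value. combineSelectCmp below
// is an assumption-level restatement of createSelectCmpOp's intent, not the
// LLVM code.
#include <iostream>
#include <vector>

int combineSelectCmp(int StartValue, int Acc, int Part) {
  return Part != StartValue ? Part : Acc;
}

int main() {
  const int StartValue = -1;
  std::vector<int> AddParts = {6, 7, 8};     // partial sums per unrolled part
  std::vector<int> SelParts = {-1, 42, -1};  // only part 1 ever took the selected value

  int AddRdx = 0;
  for (int P : AddParts) AddRdx += P;        // the "bin.rdx" chain

  int SelRdx = StartValue;
  for (int P : SelParts) SelRdx = combineSelectCmp(StartValue, SelRdx, P);

  std::cout << AddRdx << ' ' << SelRdx << '\n';  // 21 42
}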
if (VF.isVector() && !PhiR->isInLoop()) { ReducedPartRdx = - createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx); + createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, OrigPhi); // If the reduction can be performed in a smaller type, we need to extend // the reduction to the wider type before we branch to the original loop. if (PhiTy != RdxDesc.getRecurrenceType()) @@ -4456,8 +4448,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // fixFirstOrderRecurrence for a more complete explaination of the logic. if (!Cost->requiresScalarEpilogue(VF)) for (PHINode &LCSSAPhi : LoopExitBlock->phis()) - if (any_of(LCSSAPhi.incoming_values(), - [LoopExitInst](Value *V) { return V == LoopExitInst; })) + if (llvm::is_contained(LCSSAPhi.incoming_values(), LoopExitInst)) LCSSAPhi.addIncoming(ReducedPartRdx, LoopMiddleBlock); // Fix the scalar loop reduction variable with the incoming reduction sum @@ -4488,7 +4479,8 @@ void InnerLoopVectorizer::clearReductionWrapFlags(const RecurrenceDescriptor &Rd Instruction *Cur = Worklist.pop_back_val(); if (isa<OverflowingBinaryOperator>(Cur)) for (unsigned Part = 0; Part < UF; ++Part) { - Value *V = State.get(State.Plan->getVPValue(Cur), Part); + // FIXME: Should not rely on getVPValue at this point. + Value *V = State.get(State.Plan->getVPValue(Cur, true), Part); cast<Instruction>(V)->dropPoisonGeneratingFlags(); } @@ -4519,11 +4511,12 @@ void InnerLoopVectorizer::fixLCSSAPHIs(VPTransformState &State) { // Can be a loop invariant incoming value or the last scalar value to be // extracted from the vectorized loop. + // FIXME: Should not rely on getVPValue at this point. Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); Value *lastIncomingValue = OrigLoop->isLoopInvariant(IncomingValue) ? IncomingValue - : State.get(State.Plan->getVPValue(IncomingValue), + : State.get(State.Plan->getVPValue(IncomingValue, true), VPIteration(UF - 1, Lane)); LCSSAPhi.addIncoming(lastIncomingValue, LoopMiddleBlock); } @@ -4763,10 +4756,18 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, } for (unsigned Part = 0; Part < UF; ++Part) { - Value *PartStart = createStepForVF( - Builder, ConstantInt::get(PtrInd->getType(), Part), VF); + Value *PartStart = + createStepForVF(Builder, PtrInd->getType(), VF, Part); if (NeedsVectorIndex) { + // Here we cache the whole vector, which means we can support the + // extraction of any lane. However, in some cases the extractelement + // instruction that is generated for scalar uses of this vector (e.g. + // a load instruction) is not folded away. Therefore we still + // calculate values for the first n lanes to avoid redundant moves + // (when extracting the 0th element) and to produce scalar code (i.e. + // additional add/gep instructions instead of expensive extractelement + // instructions) when extracting higher-order elements. Value *PartStartSplat = Builder.CreateVectorSplat(VF, PartStart); Value *Indices = Builder.CreateAdd(PartStartSplat, UnitStepVec); Value *GlobalIndices = Builder.CreateAdd(PtrIndSplat, Indices); @@ -4774,9 +4775,6 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, emitTransformedIndex(Builder, GlobalIndices, PSE.getSE(), DL, II); SclrGep->setName("next.gep"); State.set(PhiR, SclrGep, Part); - // We've cached the whole vector, which means we can support the - // extraction of any lane. 
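// After the parts are merged, createTargetReduction (called above) collapses
// the remaining vector into the scalar value fed to the LCSSA phi; for an
// add reduction that is one horizontal sum, sketched here with
// std::accumulate on made-up values.
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> ReducedPartRdx = {3, 1, 4, 1};  // merged vector reduction value
  int Scalar = std::accumulate(ReducedPartRdx.begin(), ReducedPartRdx.end(), 0);
  std::cout << "loop-exit reduction value: " << Scalar << '\n';  // 9
}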
- continue; } for (unsigned Lane = 0; Lane < Lanes; ++Lane) { @@ -4813,7 +4811,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, Value *NumUnrolledElems = Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, State.UF)); Value *InductionGEP = GetElementPtrInst::Create( - ScStValueType->getPointerElementType(), NewPointerPhi, + II.getElementType(), NewPointerPhi, Builder.CreateMul(ScalarStepValue, NumUnrolledElems), "ptr.ind", InductionLoc); NewPointerPhi->addIncoming(InductionGEP, LoopLatch); @@ -4832,7 +4830,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, Builder.CreateAdd(StartOffset, Builder.CreateStepVector(VecPhiType)); Value *GEP = Builder.CreateGEP( - ScStValueType->getPointerElementType(), NewPointerPhi, + II.getElementType(), NewPointerPhi, Builder.CreateMul( StartOffset, Builder.CreateVectorSplat(State.VF, ScalarStepValue), "vector.gep")); @@ -4979,7 +4977,7 @@ void InnerLoopVectorizer::widenCallInstruction(CallInst &I, VPValue *Def, auto *CI = cast<CallInst>(&I); SmallVector<Type *, 4> Tys; - for (Value *ArgOperand : CI->arg_operands()) + for (Value *ArgOperand : CI->args()) Tys.push_back(ToVectorTy(ArgOperand->getType(), VF.getKnownMinValue())); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); @@ -5128,8 +5126,14 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { Instruction *Update = cast<Instruction>( cast<PHINode>(Ptr)->getIncomingValueForBlock(Latch)); - ScalarPtrs.insert(Update); - return; + + // If there is more than one user of Update (Ptr), we shouldn't assume it + // will be scalar after vectorisation as other users of the instruction + // may require widening. Otherwise, add it to ScalarPtrs. + if (Update->hasOneUse() && cast<Value>(*Update->user_begin()) == Ptr) { + ScalarPtrs.insert(Update); + return; + } } // We only care about bitcast and getelementptr instructions contained in // the loop. @@ -5142,12 +5146,11 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { if (Worklist.count(I)) return; - // If all users of the pointer will be memory accesses and scalar, place the - // pointer in ScalarPtrs. Otherwise, place the pointer in - // PossibleNonScalarPtrs. - if (llvm::all_of(I->users(), [&](User *U) { - return (isa<LoadInst>(U) || isa<StoreInst>(U)) && - isScalarUse(cast<Instruction>(U), Ptr); + // If the use of the pointer will be a scalar use, and all users of the + // pointer are memory accesses, place the pointer in ScalarPtrs. Otherwise, + // place the pointer in PossibleNonScalarPtrs. + if (isScalarUse(MemAccess, Ptr) && llvm::all_of(I->users(), [&](User *U) { + return isa<LoadInst>(U) || isa<StoreInst>(U); })) ScalarPtrs.insert(I); else @@ -5254,7 +5257,7 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { } bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I) const { - if (!blockNeedsPredication(I->getParent())) + if (!blockNeedsPredicationForAnyReason(I->getParent())) return false; switch(I->getOpcode()) { default: @@ -5297,12 +5300,20 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( // Check if masking is required. // A Group may need masking for one of two reasons: it resides in a block that - // needs predication, or it was decided to use masking to deal with gaps. + // needs predication, or it was decided to use masking to deal with gaps + // (either a gap at the end of a load-access that may result in a speculative + // load, or any gaps in a store-access). 
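// The collectLoopScalars change above only treats the pointer-induction
// update as scalar after vectorization when the induction phi is its sole
// user; any additional user may still need the widened value. ValueStub is a
// made-up stand-in, not LLVM's Value class.
#include <iostream>
#include <vector>

struct ValueStub {
  std::vector<const ValueStub *> Users;
  bool hasOneUse() const { return Users.size() == 1; }
};

bool updateStaysScalar(const ValueStub &Update, const ValueStub &Ptr) {
  return Update.hasOneUse() && Update.Users.front() == &Ptr;
}

int main() {
  ValueStub Ptr, OtherUser, Update;
  Update.Users = {&Ptr};
  std::cout << updateStaysScalar(Update, Ptr) << ' ';    // 1: only the phi uses it
  Update.Users = {&Ptr, &OtherUser};
  std::cout << updateStaysScalar(Update, Ptr) << '\n';   // 0: a widened form may be needed
}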
bool PredicatedAccessRequiresMasking = - Legal->blockNeedsPredication(I->getParent()) && Legal->isMaskRequired(I); - bool AccessWithGapsRequiresMasking = - Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed(); - if (!PredicatedAccessRequiresMasking && !AccessWithGapsRequiresMasking) + blockNeedsPredicationForAnyReason(I->getParent()) && + Legal->isMaskRequired(I); + bool LoadAccessWithGapsRequiresEpilogMasking = + isa<LoadInst>(I) && Group->requiresScalarEpilogue() && + !isScalarEpilogueAllowed(); + bool StoreAccessWithGapsRequiresMasking = + isa<StoreInst>(I) && (Group->getNumMembers() < Group->getFactor()); + if (!PredicatedAccessRequiresMasking && + !LoadAccessWithGapsRequiresEpilogMasking && + !StoreAccessWithGapsRequiresMasking) return true; // If masked interleaving is required, we expect that the user/target had @@ -5311,6 +5322,9 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( assert(useMaskedInterleavedAccesses(TTI) && "Masked interleave-groups for predicated accesses are not enabled."); + if (Group->isReverse()) + return false; + auto *Ty = getLoadStoreType(I); const Align Alignment = getLoadStoreAlignment(I); return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty, Alignment) @@ -5320,14 +5334,13 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( bool LoopVectorizationCostModel::memoryInstructionCanBeWidened( Instruction *I, ElementCount VF) { // Get and ensure we have a valid memory instruction. - LoadInst *LI = dyn_cast<LoadInst>(I); - StoreInst *SI = dyn_cast<StoreInst>(I); - assert((LI || SI) && "Invalid memory instruction"); + assert((isa<LoadInst, StoreInst>(I)) && "Invalid memory instruction"); auto *Ptr = getLoadStorePointerOperand(I); + auto *ScalarTy = getLoadStoreType(I); // In order to be widened, the pointer should be consecutive, first of all. - if (!Legal->isConsecutivePtr(Ptr)) + if (!Legal->isConsecutivePtr(ScalarTy, Ptr)) return false; // If the instruction is a store located in a predicated block, it will be @@ -5338,7 +5351,6 @@ bool LoopVectorizationCostModel::memoryInstructionCanBeWidened( // If the instruction's allocated size doesn't equal it's type size, it // requires padding and will be scalarized. auto &DL = I->getModule()->getDataLayout(); - auto *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType(); if (hasIrregularType(ScalarTy, DL)) return false; @@ -5369,12 +5381,14 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { return (!I || !TheLoop->contains(I)); }; + // Worklist containing uniform instructions demanding lane 0. SetVector<Instruction *> Worklist; BasicBlock *Latch = TheLoop->getLoopLatch(); - // Instructions that are scalar with predication must not be considered - // uniform after vectorization, because that would create an erroneous - // replicating region where only a single instance out of VF should be formed. + // Add uniform instructions demanding lane 0 to the worklist. Instructions + // that are scalar with predication must not be considered uniform after + // vectorization, because that would create an erroneous replicating region + // where only a single instance out of VF should be formed. // TODO: optimize such seldom cases if found important, see PR40816. auto addToWorklistIfAllowed = [&](Instruction *I) -> void { if (isOutOfScope(I)) { @@ -5448,6 +5462,15 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { } } + // ExtractValue instructions must be uniform, because the operands are + // known to be loop-invariant. 
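// The masking decision for interleaved groups above, restated as plain
// booleans (illustrative struct, not the cost-model classes): masking is
// needed when the block is predicated, when a load group has a tail gap and
// no scalar epilogue is allowed, or when a store group has interior gaps;
// reversed groups are additionally rejected once masking is required.
#include <iostream>

struct GroupAccess {
  bool IsLoad, BlockNeedsPredication, MaskRequired;
  bool RequiresScalarEpilogue, ScalarEpilogueAllowed;
  unsigned NumMembers, Factor;
};

bool needsMaskedInterleave(const GroupAccess &A) {
  bool Predicated = A.BlockNeedsPredication && A.MaskRequired;
  bool LoadTailGap = A.IsLoad && A.RequiresScalarEpilogue && !A.ScalarEpilogueAllowed;
  bool StoreGaps = !A.IsLoad && A.NumMembers < A.Factor;
  return Predicated || LoadTailGap || StoreGaps;
}

int main() {
  GroupAccess StoreWithGap{false, false, false, false, true, 2, 3};
  GroupAccess FullLoad{true, false, false, false, true, 4, 4};
  std::cout << needsMaskedInterleave(StoreWithGap) << ' '   // 1: gaps in a store group
            << needsMaskedInterleave(FullLoad) << '\n';     // 0: widened without a mask
}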
+ if (auto *EVI = dyn_cast<ExtractValueInst>(&I)) { + assert(isOutOfScope(EVI->getAggregateOperand()) && + "Expected aggregate value to be loop invariant"); + addToWorklistIfAllowed(EVI); + continue; + } + // If there's no pointer operand, there's nothing to do. auto *Ptr = getLoadStorePointerOperand(&I); if (!Ptr) @@ -5580,13 +5603,8 @@ bool LoopVectorizationCostModel::runtimeChecksRequired() { ElementCount LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { - if (!TTI.supportsScalableVectors() && !ForceTargetSupportsScalableVectors) { - reportVectorizationInfo( - "Disabling scalable vectorization, because target does not " - "support scalable vectors.", - "ScalableVectorsUnsupported", ORE, TheLoop); + if (!TTI.supportsScalableVectors() && !ForceTargetSupportsScalableVectors) return ElementCount::getScalable(0); - } if (Hints->isScalableVectorizationDisabled()) { reportVectorizationInfo("Scalable vectorization is explicitly disabled", @@ -5594,6 +5612,8 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { return ElementCount::getScalable(0); } + LLVM_DEBUG(dbgs() << "LV: Scalable vectorization is available\n"); + auto MaxScalableVF = ElementCount::getScalable( std::numeric_limits<ElementCount::ScalarTy>::max()); @@ -5629,6 +5649,13 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { // Limit MaxScalableVF by the maximum safe dependence distance. Optional<unsigned> MaxVScale = TTI.getMaxVScale(); + if (!MaxVScale && TheFunction->hasFnAttribute(Attribute::VScaleRange)) { + unsigned VScaleMax = TheFunction->getFnAttribute(Attribute::VScaleRange) + .getVScaleRangeArgs() + .second; + if (VScaleMax > 0) + MaxVScale = VScaleMax; + } MaxScalableVF = ElementCount::getScalable( MaxVScale ? (MaxSafeElements / MaxVScale.getValue()) : 0); if (!MaxScalableVF) @@ -5696,17 +5723,32 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(unsigned ConstTripCount, return MaxSafeFixedVF; } - LLVM_DEBUG(dbgs() << "LV: User VF=" << UserVF - << " is unsafe. Ignoring scalable UserVF.\n"); - ORE->emit([&]() { - return OptimizationRemarkAnalysis(DEBUG_TYPE, "VectorizationFactor", - TheLoop->getStartLoc(), - TheLoop->getHeader()) - << "User-specified vectorization factor " - << ore::NV("UserVectorizationFactor", UserVF) - << " is unsafe. Ignoring the hint to let the compiler pick a " - "suitable VF."; - }); + if (!TTI.supportsScalableVectors() && !ForceTargetSupportsScalableVectors) { + LLVM_DEBUG(dbgs() << "LV: User VF=" << UserVF + << " is ignored because scalable vectors are not " + "available.\n"); + ORE->emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "VectorizationFactor", + TheLoop->getStartLoc(), + TheLoop->getHeader()) + << "User-specified vectorization factor " + << ore::NV("UserVectorizationFactor", UserVF) + << " is ignored because the target does not support scalable " + "vectors. The compiler will pick a more suitable value."; + }); + } else { + LLVM_DEBUG(dbgs() << "LV: User VF=" << UserVF + << " is unsafe. Ignoring scalable UserVF.\n"); + ORE->emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "VectorizationFactor", + TheLoop->getStartLoc(), + TheLoop->getHeader()) + << "User-specified vectorization factor " + << ore::NV("UserVectorizationFactor", UserVF) + << " is unsafe. 
Ignoring the hint to let the compiler pick a " + "more suitable value."; + }); + } } LLVM_DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType @@ -5987,19 +6029,27 @@ bool LoopVectorizationCostModel::isMoreProfitable( return RTCostA < RTCostB; } - // When set to preferred, for now assume vscale may be larger than 1, so - // that scalable vectorization is slightly favorable over fixed-width - // vectorization. + // Improve estimate for the vector width if it is scalable. + unsigned EstimatedWidthA = A.Width.getKnownMinValue(); + unsigned EstimatedWidthB = B.Width.getKnownMinValue(); + if (Optional<unsigned> VScale = TTI.getVScaleForTuning()) { + if (A.Width.isScalable()) + EstimatedWidthA *= VScale.getValue(); + if (B.Width.isScalable()) + EstimatedWidthB *= VScale.getValue(); + } + + // When set to preferred, for now assume vscale may be larger than 1 (or the + // one being tuned for), so that scalable vectorization is slightly favorable + // over fixed-width vectorization. if (Hints->isScalableVectorizationPreferred()) if (A.Width.isScalable() && !B.Width.isScalable()) - return (CostA * B.Width.getKnownMinValue()) <= - (CostB * A.Width.getKnownMinValue()); + return (CostA * B.Width.getFixedValue()) <= (CostB * EstimatedWidthA); // To avoid the need for FP division: // (CostA / A.Width) < (CostB / B.Width) // <=> (CostA * B.Width) < (CostB * A.Width) - return (CostA * B.Width.getKnownMinValue()) < - (CostB * A.Width.getKnownMinValue()); + return (CostA * EstimatedWidthB) < (CostB * EstimatedWidthA); } VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( @@ -6029,11 +6079,22 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( VectorizationCostTy C = expectedCost(i, &InvalidCosts); VectorizationFactor Candidate(i, C.first); - LLVM_DEBUG( - dbgs() << "LV: Vector loop of width " << i << " costs: " - << (Candidate.Cost / Candidate.Width.getKnownMinValue()) - << (i.isScalable() ? " (assuming a minimum vscale of 1)" : "") - << ".\n"); + +#ifndef NDEBUG + unsigned AssumedMinimumVscale = 1; + if (Optional<unsigned> VScale = TTI.getVScaleForTuning()) + AssumedMinimumVscale = VScale.getValue(); + unsigned Width = + Candidate.Width.isScalable() + ? Candidate.Width.getKnownMinValue() * AssumedMinimumVscale + : Candidate.Width.getFixedValue(); + LLVM_DEBUG(dbgs() << "LV: Vector loop of width " << i + << " costs: " << (Candidate.Cost / Width)); + if (i.isScalable()) + LLVM_DEBUG(dbgs() << " (assuming a minimum vscale of " + << AssumedMinimumVscale << ")"); + LLVM_DEBUG(dbgs() << ".\n"); +#endif if (!C.second && !ForceVectorization) { LLVM_DEBUG( @@ -6197,15 +6258,6 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor( return Result; } - // FIXME: This can be fixed for scalable vectors later, because at this stage - // the LoopVectorizer will only consider vectorizing a loop with scalable - // vectors when the loop has a hint to enable vectorization for a given VF. - if (MainLoopVF.isScalable()) { - LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization for scalable vectors not " - "yet supported.\n"); - return Result; - } - // Not really a cost consideration, but check for unsupported cases here to // simplify the logic. 
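// The cost comparison above in miniature (illustrative types): scalable
// widths are scaled by the target's vscale-for-tuning estimate, and per-lane
// costs are compared by cross-multiplying so no FP division is needed.
#include <iostream>

struct VFWidth { unsigned MinLanes; bool Scalable; };
struct CandidateVF { VFWidth Width; unsigned Cost; };

bool isMoreProfitable(const CandidateVF &A, const CandidateVF &B,
                      unsigned VScaleForTuning) {
  unsigned WidthA = A.Width.MinLanes * (A.Width.Scalable ? VScaleForTuning : 1);
  unsigned WidthB = B.Width.MinLanes * (B.Width.Scalable ? VScaleForTuning : 1);
  // (CostA / WidthA) < (CostB / WidthB)  <=>  CostA * WidthB < CostB * WidthA
  return A.Cost * WidthB < B.Cost * WidthA;
}

int main() {
  CandidateVF Fixed8{{8, false}, 40};      // 5.0 per lane
  CandidateVF Scalable4{{4, true}, 36};    // 4.5 per lane when vscale == 2
  std::cout << isMoreProfitable(Scalable4, Fixed8, /*VScaleForTuning=*/2) << '\n';  // 1
}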
if (!isCandidateForEpilogueVectorization(*TheLoop, MainLoopVF)) { @@ -6217,9 +6269,9 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor( if (EpilogueVectorizationForceVF > 1) { LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization factor is forced.\n";); - if (LVP.hasPlanWithVFs( - {MainLoopVF, ElementCount::getFixed(EpilogueVectorizationForceVF)})) - return {ElementCount::getFixed(EpilogueVectorizationForceVF), 0}; + ElementCount ForcedEC = ElementCount::getFixed(EpilogueVectorizationForceVF); + if (LVP.hasPlanWithVF(ForcedEC)) + return {ForcedEC, 0}; else { LLVM_DEBUG( dbgs() @@ -6236,14 +6288,24 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor( return Result; } - if (!isEpilogueVectorizationProfitable(MainLoopVF)) + auto FixedMainLoopVF = ElementCount::getFixed(MainLoopVF.getKnownMinValue()); + if (MainLoopVF.isScalable()) + LLVM_DEBUG( + dbgs() << "LEV: Epilogue vectorization using scalable vectors not " + "yet supported. Converting to fixed-width (VF=" + << FixedMainLoopVF << ") instead\n"); + + if (!isEpilogueVectorizationProfitable(FixedMainLoopVF)) { + LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization is not profitable for " + "this loop\n"); return Result; + } for (auto &NextVF : ProfitableVFs) - if (ElementCount::isKnownLT(NextVF.Width, MainLoopVF) && + if (ElementCount::isKnownLT(NextVF.Width, FixedMainLoopVF) && (Result.Width.getFixedValue() == 1 || isMoreProfitable(NextVF, Result)) && - LVP.hasPlanWithVFs({MainLoopVF, NextVF.Width})) + LVP.hasPlanWithVF(NextVF.Width)) Result = NextVF; if (Result != VectorizationFactor::Disabled()) @@ -6486,6 +6548,22 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, unsigned StoresIC = IC / (NumStores ? NumStores : 1); unsigned LoadsIC = IC / (NumLoads ? NumLoads : 1); + // There is little point in interleaving for reductions containing selects + // and compares when VF=1 since it may just create more overhead than it's + // worth for loops with small trip counts. This is because we still have to + // do the final reduction after the loop. + bool HasSelectCmpReductions = + HasReductions && + any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool { + const RecurrenceDescriptor &RdxDesc = Reduction.second; + return RecurrenceDescriptor::isSelectCmpRecurrenceKind( + RdxDesc.getRecurrenceKind()); + }); + if (HasSelectCmpReductions) { + LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n"); + return 1; + } + // If we have a scalar reduction (vector reductions are already dealt with // by this point), we can increase the critical path length if the loop // we're interleaving is inside another loop. For tree-wise reductions @@ -6771,7 +6849,7 @@ void LoopVectorizationCostModel::collectInstsToScalarize(ElementCount VF) { // determine if it would be better to not if-convert the blocks they are in. // If so, we also record the instructions to scalarize. 
for (BasicBlock *BB : TheLoop->blocks()) { - if (!blockNeedsPredication(BB)) + if (!blockNeedsPredicationForAnyReason(BB)) continue; for (Instruction &I : *BB) if (isScalarWithPredication(&I)) { @@ -6866,7 +6944,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( if (isScalarWithPredication(I) && !I->getType()->isVoidTy()) { ScalarCost += TTI.getScalarizationOverhead( cast<VectorType>(ToVectorTy(I->getType(), VF)), - APInt::getAllOnesValue(VF.getFixedValue()), true, false); + APInt::getAllOnes(VF.getFixedValue()), true, false); ScalarCost += VF.getFixedValue() * TTI.getCFInstrCost(Instruction::PHI, TTI::TCK_RecipThroughput); @@ -6885,7 +6963,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( else if (needsExtract(J, VF)) { ScalarCost += TTI.getScalarizationOverhead( cast<VectorType>(ToVectorTy(J->getType(), VF)), - APInt::getAllOnesValue(VF.getFixedValue()), false, true); + APInt::getAllOnes(VF.getFixedValue()), false, true); } } @@ -7031,7 +7109,7 @@ LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, auto *Vec_i1Ty = VectorType::get(IntegerType::getInt1Ty(ValTy->getContext()), VF); Cost += TTI.getScalarizationOverhead( - Vec_i1Ty, APInt::getAllOnesValue(VF.getKnownMinValue()), + Vec_i1Ty, APInt::getAllOnes(VF.getKnownMinValue()), /*Insert=*/false, /*Extract=*/true); Cost += TTI.getCFInstrCost(Instruction::Br, TTI::TCK_RecipThroughput); @@ -7051,7 +7129,7 @@ LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, auto *VectorTy = cast<VectorType>(ToVectorTy(ValTy, VF)); Value *Ptr = getLoadStorePointerOperand(I); unsigned AS = getLoadStoreAddressSpace(I); - int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); + int ConsecutiveStride = Legal->isConsecutivePtr(ValTy, Ptr); enum TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && @@ -7132,18 +7210,16 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, unsigned InterleaveFactor = Group->getFactor(); auto *WideVecTy = VectorType::get(ValTy, VF * InterleaveFactor); - // Holds the indices of existing members in an interleaved load group. - // An interleaved store group doesn't need this as it doesn't allow gaps. + // Holds the indices of existing members in the interleaved group. SmallVector<unsigned, 4> Indices; - if (isa<LoadInst>(I)) { - for (unsigned i = 0; i < InterleaveFactor; i++) - if (Group->getMember(i)) - Indices.push_back(i); - } + for (unsigned IF = 0; IF < InterleaveFactor; IF++) + if (Group->getMember(IF)) + Indices.push_back(IF); // Calculate the cost of the whole interleaved group. 
bool UseMaskForGaps = - Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed(); + (Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed()) || + (isa<StoreInst>(I) && (Group->getNumMembers() < Group->getFactor())); InstructionCost Cost = TTI.getInterleavedMemoryOpCost( I->getOpcode(), WideVecTy, Group->getFactor(), Indices, Group->getAlign(), AS, TTI::TCK_RecipThroughput, Legal->isMaskRequired(I), UseMaskForGaps); @@ -7225,8 +7301,41 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy); Instruction *Op0, *Op1; - if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) && - !TheLoop->isLoopInvariant(RedOp)) { + if (RedOp && + match(RedOp, + m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) && + match(Op0, m_ZExtOrSExt(m_Value())) && + Op0->getOpcode() == Op1->getOpcode() && + Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() && + !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1) && + (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) { + + // Matched reduce(ext(mul(ext(A), ext(B))) + // Note that the extend opcodes need to all match, or if A==B they will have + // been converted to zext(mul(sext(A), sext(A))) as it is known positive, + // which is equally fine. + bool IsUnsigned = isa<ZExtInst>(Op0); + auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy); + auto *MulType = VectorType::get(Op0->getType(), VectorTy); + + InstructionCost ExtCost = + TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType, + TTI::CastContextHint::None, CostKind, Op0); + InstructionCost MulCost = + TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind); + InstructionCost Ext2Cost = + TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, MulType, + TTI::CastContextHint::None, CostKind, RedOp); + + InstructionCost RedCost = TTI.getExtendedAddReductionCost( + /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, + CostKind); + + if (RedCost.isValid() && + RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost) + return I == RetI ? RedCost : 0; + } else if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) && + !TheLoop->isLoopInvariant(RedOp)) { // Matched reduce(ext(A)) bool IsUnsigned = isa<ZExtInst>(RedOp); auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy); @@ -7260,7 +7369,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost) return I == RetI ? 
RedCost : 0; - } else { + } else if (!match(I, m_ZExtOrSExt(m_Value()))) { // Matched reduce(mul()) InstructionCost MulCost = TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); @@ -7319,9 +7428,14 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, Type *VectorTy; InstructionCost C = getInstructionCost(I, VF, VectorTy); - bool TypeNotScalarized = - VF.isVector() && VectorTy->isVectorTy() && - TTI.getNumberOfParts(VectorTy) < VF.getKnownMinValue(); + bool TypeNotScalarized = false; + if (VF.isVector() && VectorTy->isVectorTy()) { + unsigned NumParts = TTI.getNumberOfParts(VectorTy); + if (NumParts) + TypeNotScalarized = NumParts < VF.getKnownMinValue(); + else + C = InstructionCost::getInvalid(); + } return VectorizationCostTy(C, TypeNotScalarized); } @@ -7342,8 +7456,8 @@ LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I, if (!RetTy->isVoidTy() && (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore())) Cost += TTI.getScalarizationOverhead( - cast<VectorType>(RetTy), APInt::getAllOnesValue(VF.getKnownMinValue()), - true, false); + cast<VectorType>(RetTy), APInt::getAllOnes(VF.getKnownMinValue()), true, + false); // Some targets keep addresses scalar. if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing()) @@ -7355,7 +7469,7 @@ LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I, // Collect operands to consider. CallInst *CI = dyn_cast<CallInst>(I); - Instruction::op_range Ops = CI ? CI->arg_operands() : I->operands(); + Instruction::op_range Ops = CI ? CI->args() : I->operands(); // Skip operands that do not require extraction/scalarization and do not incur // any overhead. @@ -7406,8 +7520,8 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) { // We assume that widening is the best solution when possible. if (memoryInstructionCanBeWidened(&I, VF)) { InstructionCost Cost = getConsecutiveMemOpCost(&I, VF); - int ConsecutiveStride = - Legal->isConsecutivePtr(getLoadStorePointerOperand(&I)); + int ConsecutiveStride = Legal->isConsecutivePtr( + getLoadStoreType(&I), getLoadStorePointerOperand(&I)); assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && "Expected consecutive stride."); InstWidening Decision = @@ -7594,8 +7708,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF); return ( TTI.getScalarizationOverhead( - Vec_i1Ty, APInt::getAllOnesValue(VF.getFixedValue()), false, - true) + + Vec_i1Ty, APInt::getAllOnes(VF.getFixedValue()), false, true) + (TTI.getCFInstrCost(Instruction::Br, CostKind) * VF.getFixedValue())); } else if (I->getParent() == TheLoop->getLoopLatch() || VF.isScalar()) // The back-edge branch will remain, as will all scalar branches. @@ -7908,7 +8021,7 @@ bool LoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { // Check if the pointer operand of a load or store instruction is // consecutive. if (auto *Ptr = getLoadStorePointerOperand(Inst)) - return Legal->isConsecutivePtr(Ptr); + return Legal->isConsecutivePtr(getLoadStoreType(Inst), Ptr); return false; } @@ -8034,7 +8147,7 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { return None; // Invalidate interleave groups if all blocks of loop will be predicated. 
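The reworked cost classification above treats a widened type as "not scalarized" when the target legalizes it in fewer parts than there are lanes, and as having an invalid cost when it cannot be legalized at all (zero parts). A small stand-alone sketch of that decision, with illustrative names only:

#include <cstdio>
#include <limits>

// Illustrative result holder; not the LLVM VectorizationCostTy.
struct WidenedTypeCost {
  unsigned Cost; // UINT_MAX stands in for an invalid InstructionCost here
  bool TypeNotScalarized;
};

static WidenedTypeCost classify(unsigned NumParts, unsigned KnownMinVF,
                                unsigned RawCost) {
  WidenedTypeCost R{RawCost, false};
  if (NumParts == 0)
    R.Cost = std::numeric_limits<unsigned>::max(); // type cannot be legalized
  else
    R.TypeNotScalarized = NumParts < KnownMinVF;   // fewer parts than lanes
  return R;
}

int main() {
  WidenedTypeCost A = classify(/*NumParts=*/1, /*KnownMinVF=*/4, /*RawCost=*/10);
  WidenedTypeCost B = classify(/*NumParts=*/0, /*KnownMinVF=*/4, /*RawCost=*/10);
  std::printf("A: cost=%u notScalarized=%d\n", A.Cost, (int)A.TypeNotScalarized);
  std::printf("B: cost=%u notScalarized=%d\n", B.Cost, (int)B.TypeNotScalarized);
  return 0;
}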
- if (CM.blockNeedsPredication(OrigLoop->getHeader()) && + if (CM.blockNeedsPredicationForAnyReason(OrigLoop->getHeader()) && !useMaskedInterleavedAccesses(*TTI)) { LLVM_DEBUG( dbgs() @@ -8120,28 +8233,30 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { return SelectedVF; } -void LoopVectorizationPlanner::setBestPlan(ElementCount VF, unsigned UF) { - LLVM_DEBUG(dbgs() << "Setting best plan to VF=" << VF << ", UF=" << UF - << '\n'); - BestVF = VF; - BestUF = UF; +VPlan &LoopVectorizationPlanner::getBestPlanFor(ElementCount VF) const { + assert(count_if(VPlans, + [VF](const VPlanPtr &Plan) { return Plan->hasVF(VF); }) == + 1 && + "Best VF has not a single VPlan."); - erase_if(VPlans, [VF](const VPlanPtr &Plan) { - return !Plan->hasVF(VF); - }); - assert(VPlans.size() == 1 && "Best VF has not a single VPlan."); + for (const VPlanPtr &Plan : VPlans) { + if (Plan->hasVF(VF)) + return *Plan.get(); + } + llvm_unreachable("No plan found!"); } -void LoopVectorizationPlanner::executePlan(InnerLoopVectorizer &ILV, +void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, + VPlan &BestVPlan, + InnerLoopVectorizer &ILV, DominatorTree *DT) { + LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << BestVF << ", UF=" << BestUF + << '\n'); + // Perform the actual loop transformation. // 1. Create a new empty loop. Unlink the old loop and connect the new one. - assert(BestVF.hasValue() && "Vectorization Factor is missing"); - assert(VPlans.size() == 1 && "Not a single VPlan to execute."); - - VPTransformState State{ - *BestVF, BestUF, LI, DT, ILV.Builder, &ILV, VPlans.front().get()}; + VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan}; State.CFG.PrevBB = ILV.createVectorizedLoopSkeleton(); State.TripCount = ILV.getOrCreateTripCount(nullptr); State.CanonicalIV = ILV.Induction; @@ -8157,7 +8272,7 @@ void LoopVectorizationPlanner::executePlan(InnerLoopVectorizer &ILV, //===------------------------------------------------===// // 2. Copy and widen instructions from the old loop into the new loop. - VPlans.front()->execute(&State); + BestVPlan.execute(&State); // 3. Fix the vectorized code: take care of header phi's, live-outs, // predication, updating analyses. @@ -8237,21 +8352,19 @@ Value *InnerLoopUnroller::reverseVector(Value *Vec) { return Vec; } Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } -Value *InnerLoopUnroller::getStepVector(Value *Val, int StartIdx, Value *Step, +Value *InnerLoopUnroller::getStepVector(Value *Val, Value *StartIdx, + Value *Step, Instruction::BinaryOps BinOp) { // When unrolling and the VF is 1, we only need to add a simple scalar. Type *Ty = Val->getType(); assert(!Ty->isVectorTy() && "Val must be a scalar"); if (Ty->isFloatingPointTy()) { - Constant *C = ConstantFP::get(Ty, (double)StartIdx); - // Floating-point operations inherit FMF via the builder's flags. 
- Value *MulOp = Builder.CreateFMul(C, Step); + Value *MulOp = Builder.CreateFMul(StartIdx, Step); return Builder.CreateBinOp(BinOp, Val, MulOp); } - Constant *C = ConstantInt::get(Ty, StartIdx); - return Builder.CreateAdd(Val, Builder.CreateMul(C, Step), "induction"); + return Builder.CreateAdd(Val, Builder.CreateMul(StartIdx, Step), "induction"); } static void AddRuntimeUnrollDisableMetaData(Loop *L) { @@ -8326,7 +8439,9 @@ BasicBlock *EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() { OldInduction = Legal->getPrimaryInduction(); Type *IdxTy = Legal->getWidestInductionType(); Value *StartIdx = ConstantInt::get(IdxTy, 0); - Constant *Step = ConstantInt::get(IdxTy, VF.getKnownMinValue() * UF); + + IRBuilder<> B(&*Lp->getLoopPreheader()->getFirstInsertionPt()); + Value *Step = getRuntimeVF(B, IdxTy, VF * UF); Value *CountRoundDown = getOrCreateVectorTripCount(Lp); EPI.VectorTripCount = CountRoundDown; Induction = @@ -8344,9 +8459,9 @@ BasicBlock *EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() { void EpilogueVectorizerMainLoop::printDebugTracesAtStart() { LLVM_DEBUG({ dbgs() << "Create Skeleton for epilogue vectorized loop (first pass)\n" - << "Main Loop VF:" << EPI.MainLoopVF.getKnownMinValue() + << "Main Loop VF:" << EPI.MainLoopVF << ", Main Loop UF:" << EPI.MainLoopUF - << ", Epilogue Loop VF:" << EPI.EpilogueVF.getKnownMinValue() + << ", Epilogue Loop VF:" << EPI.EpilogueVF << ", Epilogue Loop UF:" << EPI.EpilogueUF << "\n"; }); } @@ -8361,8 +8476,7 @@ BasicBlock *EpilogueVectorizerMainLoop::emitMinimumIterationCountCheck( Loop *L, BasicBlock *Bypass, bool ForEpilogue) { assert(L && "Expected valid Loop."); assert(Bypass && "Expected valid bypass basic block."); - unsigned VFactor = - ForEpilogue ? EPI.EpilogueVF.getKnownMinValue() : VF.getKnownMinValue(); + ElementCount VFactor = ForEpilogue ? EPI.EpilogueVF : VF; unsigned UFactor = ForEpilogue ? EPI.EpilogueUF : UF; Value *Count = getOrCreateTripCount(L); // Reuse existing vector loop preheader for TC checks. @@ -8376,7 +8490,7 @@ BasicBlock *EpilogueVectorizerMainLoop::emitMinimumIterationCountCheck( ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT; Value *CheckMinIters = Builder.CreateICmp( - P, Count, ConstantInt::get(Count->getType(), VFactor * UFactor), + P, Count, createStepForVF(Builder, Count->getType(), VFactor, UFactor), "min.iters.check"); if (!ForEpilogue) @@ -8528,11 +8642,11 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck( auto P = Cost->requiresScalarEpilogue(EPI.EpilogueVF) ? 
ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT; - Value *CheckMinIters = Builder.CreateICmp( - P, Count, - ConstantInt::get(Count->getType(), - EPI.EpilogueVF.getKnownMinValue() * EPI.EpilogueUF), - "min.epilog.iters.check"); + Value *CheckMinIters = + Builder.CreateICmp(P, Count, + createStepForVF(Builder, Count->getType(), + EPI.EpilogueVF, EPI.EpilogueUF), + "min.epilog.iters.check"); ReplaceInstWithInst( Insert->getTerminator(), @@ -8545,7 +8659,7 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck( void EpilogueVectorizerEpilogueLoop::printDebugTracesAtStart() { LLVM_DEBUG({ dbgs() << "Create Skeleton for epilogue vectorized loop (second pass)\n" - << "Epilogue Loop VF:" << EPI.EpilogueVF.getKnownMinValue() + << "Epilogue Loop VF:" << EPI.EpilogueVF << ", Epilogue Loop UF:" << EPI.EpilogueUF << "\n"; }); } @@ -8643,7 +8757,7 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { VPValue *BlockMask = nullptr; if (OrigLoop->getHeader() == BB) { - if (!CM.blockNeedsPredication(BB)) + if (!CM.blockNeedsPredicationForAnyReason(BB)) return BlockMaskCache[BB] = BlockMask; // Loop incoming mask is all-one. // Create the block in mask as the first non-phi instruction in the block. @@ -8658,9 +8772,9 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { if (Legal->getPrimaryInduction()) IV = Plan->getOrAddVPValue(Legal->getPrimaryInduction()); else { - auto IVRecipe = new VPWidenCanonicalIVRecipe(); + auto *IVRecipe = new VPWidenCanonicalIVRecipe(); Builder.getInsertBlock()->insert(IVRecipe, NewInsertionPoint); - IV = IVRecipe->getVPSingleValue(); + IV = IVRecipe; } VPValue *BTC = Plan->getOrCreateBackedgeTakenCount(); bool TailFolded = !CM.isScalarEpilogueAllowed(); @@ -8723,12 +8837,21 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, if (Legal->isMaskRequired(I)) Mask = createBlockInMask(I->getParent(), Plan); + // Determine if the pointer operand of the access is either consecutive or + // reverse consecutive. + LoopVectorizationCostModel::InstWidening Decision = + CM.getWideningDecision(I, Range.Start); + bool Reverse = Decision == LoopVectorizationCostModel::CM_Widen_Reverse; + bool Consecutive = + Reverse || Decision == LoopVectorizationCostModel::CM_Widen; + if (LoadInst *Load = dyn_cast<LoadInst>(I)) - return new VPWidenMemoryInstructionRecipe(*Load, Operands[0], Mask); + return new VPWidenMemoryInstructionRecipe(*Load, Operands[0], Mask, + Consecutive, Reverse); StoreInst *Store = cast<StoreInst>(I); return new VPWidenMemoryInstructionRecipe(*Store, Operands[1], Operands[0], - Mask); + Mask, Consecutive, Reverse); } VPWidenIntOrFpInductionRecipe * @@ -8844,7 +8967,7 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, if (!LoopVectorizationPlanner::getDecisionAndClampRange(willWiden, Range)) return nullptr; - ArrayRef<VPValue *> Ops = Operands.take_front(CI->getNumArgOperands()); + ArrayRef<VPValue *> Ops = Operands.take_front(CI->arg_size()); return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end())); } @@ -9183,6 +9306,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( RecipeBuilder.recordRecipeOf(R); // For min/max reducitons, where we have a pair of icmp/select, we also // need to record the ICmp recipe, so it can be removed later. 
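For context on the block-in mask built above when tail folding: conceptually, a lane of the widened canonical IV is active while its index does not exceed the backedge-taken count. The sketch below is plain C++ and only illustrates that idea; it is not the VPlan API.

#include <cstdio>
#include <vector>

// Lane i is active when Base + i does not exceed the backedge-taken count.
static std::vector<bool> activeLaneMask(unsigned Base, unsigned VF,
                                        unsigned BackedgeTakenCount) {
  std::vector<bool> Mask(VF);
  for (unsigned I = 0; I < VF; ++I)
    Mask[I] = (Base + I) <= BackedgeTakenCount;
  return Mask;
}

int main() {
  // Trip count 11 => backedge-taken count 10. With VF=4 the last vector
  // iteration starts at lane base 8, so only lanes 8..10 stay active.
  for (bool B : activeLaneMask(/*Base=*/8, /*VF=*/4, /*BackedgeTakenCount=*/10))
    std::printf("%d ", (int)B); // prints: 1 1 1 0
  std::printf("\n");
  return 0;
}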
+ assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) && + "Only min/max recurrences allowed for inloop reductions"); if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) RecipeBuilder.recordRecipeOf(cast<Instruction>(R->getOperand(0))); } @@ -9211,22 +9336,27 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // visit each basic block after having visited its predecessor basic blocks. // --------------------------------------------------------------------------- - // Create a dummy pre-entry VPBasicBlock to start building the VPlan. auto Plan = std::make_unique<VPlan>(); - VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry"); - Plan->setEntry(VPBB); // Scan the body of the loop in a topological order to visit each basic block // after having visited its predecessor basic blocks. LoopBlocksDFS DFS(OrigLoop); DFS.perform(LI); + VPBasicBlock *VPBB = nullptr; + VPBasicBlock *HeaderVPBB = nullptr; + SmallVector<VPWidenIntOrFpInductionRecipe *> InductionsToMove; for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { // Relevant instructions from basic block BB will be grouped into VPRecipe // ingredients and fill a new VPBasicBlock. unsigned VPBBsForBB = 0; auto *FirstVPBBForBB = new VPBasicBlock(BB->getName()); - VPBlockUtils::insertBlockAfter(FirstVPBBForBB, VPBB); + if (VPBB) + VPBlockUtils::insertBlockAfter(FirstVPBBForBB, VPBB); + else { + Plan->setEntry(FirstVPBBForBB); + HeaderVPBB = FirstVPBBForBB; + } VPBB = FirstVPBBForBB; Builder.setInsertPoint(VPBB); @@ -9268,6 +9398,17 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( Plan->addVPValue(UV, Def); } + if (isa<VPWidenIntOrFpInductionRecipe>(Recipe) && + HeaderVPBB->getFirstNonPhi() != VPBB->end()) { + // Keep track of VPWidenIntOrFpInductionRecipes not in the phi section + // of the header block. That can happen for truncates of induction + // variables. Those recipes are moved to the phi section of the header + // block after applying SinkAfter, which relies on the original + // position of the trunc. + assert(isa<TruncInst>(Instr)); + InductionsToMove.push_back( + cast<VPWidenIntOrFpInductionRecipe>(Recipe)); + } RecipeBuilder.setRecipe(Instr, Recipe); VPBB->appendRecipe(Recipe); continue; @@ -9285,17 +9426,11 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( } } + assert(isa<VPBasicBlock>(Plan->getEntry()) && + !Plan->getEntry()->getEntryBasicBlock()->empty() && + "entry block must be set to a non-empty VPBasicBlock"); RecipeBuilder.fixHeaderPhis(); - // Discard empty dummy pre-entry VPBasicBlock. Note that other VPBasicBlocks - // may also be empty, such as the last one VPBB, reflecting original - // basic-blocks with no recipes. - VPBasicBlock *PreEntry = cast<VPBasicBlock>(Plan->getEntry()); - assert(PreEntry->empty() && "Expecting empty pre-entry block."); - VPBlockBase *Entry = Plan->setEntry(PreEntry->getSingleSuccessor()); - VPBlockUtils::disconnectBlocks(PreEntry, Entry); - delete PreEntry; - // --------------------------------------------------------------------------- // Transform initial VPlan: Apply previously taken decisions, in order, to // bring the VPlan to its final state. @@ -9364,6 +9499,14 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( } } + // Now that sink-after is done, move induction recipes for optimized truncates + // to the phi section of the header block. + for (VPWidenIntOrFpInductionRecipe *Ind : InductionsToMove) + Ind->moveBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi()); + + // Adjust the recipes for any inloop reductions. 
+ adjustRecipesForReductions(VPBB, Plan, RecipeBuilder, Range.Start); + // Introduce a recipe to combine the incoming and previous values of a // first-order recurrence. for (VPRecipeBase &R : Plan->getEntry()->getEntryBasicBlock()->phis()) { @@ -9371,16 +9514,20 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( if (!RecurPhi) continue; + VPRecipeBase *PrevRecipe = RecurPhi->getBackedgeRecipe(); + VPBasicBlock *InsertBlock = PrevRecipe->getParent(); + auto *Region = GetReplicateRegion(PrevRecipe); + if (Region) + InsertBlock = cast<VPBasicBlock>(Region->getSingleSuccessor()); + if (Region || PrevRecipe->isPhi()) + Builder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi()); + else + Builder.setInsertPoint(InsertBlock, std::next(PrevRecipe->getIterator())); + auto *RecurSplice = cast<VPInstruction>( Builder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice, {RecurPhi, RecurPhi->getBackedgeValue()})); - VPRecipeBase *PrevRecipe = RecurPhi->getBackedgeRecipe(); - if (auto *Region = GetReplicateRegion(PrevRecipe)) { - VPBasicBlock *Succ = cast<VPBasicBlock>(Region->getSingleSuccessor()); - RecurSplice->moveBefore(*Succ, Succ->getFirstNonPhi()); - } else - RecurSplice->moveAfter(PrevRecipe); RecurPhi->replaceAllUsesWith(RecurSplice); // Set the first operand of RecurSplice to RecurPhi again, after replacing // all users. @@ -9418,22 +9565,9 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( } } - // Adjust the recipes for any inloop reductions. - adjustRecipesForInLoopReductions(Plan, RecipeBuilder, Range.Start); - - // Finally, if tail is folded by masking, introduce selects between the phi - // and the live-out instruction of each reduction, at the end of the latch. - if (CM.foldTailByMasking() && !Legal->getReductionVars().empty()) { - Builder.setInsertPoint(VPBB); - auto *Cond = RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan); - for (auto &Reduction : Legal->getReductionVars()) { - if (CM.isInLoopReduction(Reduction.first)) - continue; - VPValue *Phi = Plan->getOrAddVPValue(Reduction.first); - VPValue *Red = Plan->getOrAddVPValue(Reduction.second.getLoopExitInstr()); - Builder.createNaryOp(Instruction::Select, {Cond, Red, Phi}); - } - } + // From this point onwards, VPlan-to-VPlan transformations may change the plan + // in ways that accessing values using original IR values is incorrect. + Plan->disableValue2VPValue(); VPlanTransforms::sinkScalarOperands(*Plan); VPlanTransforms::mergeReplicateRegions(*Plan); @@ -9451,6 +9585,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( RSO.flush(); Plan->setName(PlanName); + assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid"); return Plan; } @@ -9489,12 +9624,14 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { return Plan; } -// Adjust the recipes for any inloop reductions. The chain of instructions -// leading from the loop exit instr to the phi need to be converted to -// reductions, with one operand being vector and the other being the scalar -// reduction chain. -void LoopVectorizationPlanner::adjustRecipesForInLoopReductions( - VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) { +// Adjust the recipes for reductions. For in-loop reductions the chain of +// instructions leading from the loop exit instr to the phi need to be converted +// to reductions, with one operand being vector and the other being the scalar +// reduction chain. 
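A brief illustration of the FirstOrderRecurrenceSplice operation placed above: as a sketch (plain C++, not the VPlan recipe itself), splicing the previous iteration's vector with the current one takes the last element of the previous vector followed by all but the last element of the current vector.

#include <cstdio>
#include <vector>

// Splice of the previous and current vector values of a first-order
// recurrence: last element of the previous vector, then all but the last
// element of the current one.
static std::vector<int> spliceRecurrence(const std::vector<int> &Prev,
                                         const std::vector<int> &Cur) {
  std::vector<int> Out;
  Out.push_back(Prev.back());
  Out.insert(Out.end(), Cur.begin(), Cur.end() - 1);
  return Out;
}

int main() {
  std::vector<int> Prev = {0, 1, 2, 3}, Cur = {4, 5, 6, 7};
  for (int V : spliceRecurrence(Prev, Cur))
    std::printf("%d ", V); // prints: 3 4 5 6
  std::printf("\n");
  return 0;
}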
For other reductions, a select is introduced between the phi +// and live-out recipes when folding the tail. +void LoopVectorizationPlanner::adjustRecipesForReductions( + VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, + ElementCount MinVF) { for (auto &Reduction : CM.getInLoopReductionChains()) { PHINode *Phi = Reduction.first; RecurrenceDescriptor &RdxDesc = Legal->getReductionVars()[Phi]; @@ -9514,6 +9651,8 @@ void LoopVectorizationPlanner::adjustRecipesForInLoopReductions( VPValue *ChainOp = Plan->getVPValue(Chain); unsigned FirstOpId; + assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) && + "Only min/max recurrences allowed for inloop reductions"); if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { assert(isa<VPWidenSelectRecipe>(WidenRecipe) && "Expected to replace a VPWidenSelectSC"); @@ -9551,6 +9690,21 @@ void LoopVectorizationPlanner::adjustRecipesForInLoopReductions( Chain = R; } } + + // If tail is folded by masking, introduce selects between the phi + // and the live-out instruction of each reduction, at the end of the latch. + if (CM.foldTailByMasking()) { + for (VPRecipeBase &R : Plan->getEntry()->getEntryBasicBlock()->phis()) { + VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R); + if (!PhiR || PhiR->isInLoop()) + continue; + Builder.setInsertPoint(LatchVPBB); + VPValue *Cond = + RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan); + VPValue *Red = PhiR->getBackedgeValue(); + Builder.createNaryOp(Instruction::Select, {Cond, Red, PhiR}); + } + } } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -9565,9 +9719,22 @@ void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent, O << ", "; Mask->printAsOperand(O, SlotTracker); } - for (unsigned i = 0; i < IG->getFactor(); ++i) - if (Instruction *I = IG->getMember(i)) - O << "\n" << Indent << " " << VPlanIngredient(I) << " " << i; + + unsigned OpIdx = 0; + for (unsigned i = 0; i < IG->getFactor(); ++i) { + if (!IG->getMember(i)) + continue; + if (getNumStoreOperands() > 0) { + O << "\n" << Indent << " store "; + getOperand(1 + OpIdx)->printAsOperand(O, SlotTracker); + O << " to index " << i; + } else { + O << "\n" << Indent << " "; + getVPValue(OpIdx)->printAsOperand(O, SlotTracker); + O << " = load from index " << i; + } + ++OpIdx; + } } #endif @@ -9651,17 +9818,20 @@ void VPInterleaveRecipe::execute(VPTransformState &State) { void VPReductionRecipe::execute(VPTransformState &State) { assert(!State.Instance && "Reduction being replicated."); Value *PrevInChain = State.get(getChainOp(), 0); + RecurKind Kind = RdxDesc->getRecurrenceKind(); + bool IsOrdered = State.ILV->useOrderedReductions(*RdxDesc); + // Propagate the fast-math flags carried by the underlying instruction. 
+ IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); + State.Builder.setFastMathFlags(RdxDesc->getFastMathFlags()); for (unsigned Part = 0; Part < State.UF; ++Part) { - RecurKind Kind = RdxDesc->getRecurrenceKind(); - bool IsOrdered = State.ILV->useOrderedReductions(*RdxDesc); Value *NewVecOp = State.get(getVecOp(), Part); if (VPValue *Cond = getCondOp()) { Value *NewCond = State.get(Cond, Part); VectorType *VecTy = cast<VectorType>(NewVecOp->getType()); - Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity( + Value *Iden = RdxDesc->getRecurrenceIdentity( Kind, VecTy->getElementType(), RdxDesc->getFastMathFlags()); - Constant *IdenVec = - ConstantVector::getSplat(VecTy->getElementCount(), Iden); + Value *IdenVec = + State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden); Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, IdenVec); NewVecOp = Select; } @@ -9673,8 +9843,8 @@ void VPReductionRecipe::execute(VPTransformState &State) { PrevInChain); else NewRed = State.Builder.CreateBinOp( - (Instruction::BinaryOps)getUnderlyingInstr()->getOpcode(), - PrevInChain, NewVecOp); + (Instruction::BinaryOps)RdxDesc->getOpcode(Kind), PrevInChain, + NewVecOp); PrevInChain = NewRed; } else { PrevInChain = State.get(getChainOp(), Part); @@ -9686,11 +9856,10 @@ void VPReductionRecipe::execute(VPTransformState &State) { NewRed, PrevInChain); } else if (IsOrdered) NextInChain = NewRed; - else { + else NextInChain = State.Builder.CreateBinOp( - (Instruction::BinaryOps)getUnderlyingInstr()->getOpcode(), NewRed, + (Instruction::BinaryOps)RdxDesc->getOpcode(Kind), NewRed, PrevInChain); - } State.set(this, NextInChain, Part); } } @@ -9803,7 +9972,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { VPValue *StoredValue = isStore() ? getStoredValue() : nullptr; State.ILV->vectorizeMemoryInstruction( &Ingredient, State, StoredValue ? nullptr : getVPSingleValue(), getAddr(), - StoredValue, getMask()); + StoredValue, getMask(), Consecutive, Reverse); } // Determine how to lower the scalar epilogue, which depends on 1) optimising @@ -9969,7 +10138,7 @@ static bool processLoopInVPlanNativePath( VectorizationFactor::Disabled() == VF) return false; - LVP.setBestPlan(VF.Width, 1); + VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); { GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, @@ -9978,7 +10147,7 @@ static bool processLoopInVPlanNativePath( &CM, BFI, PSI, Checks); LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \"" << L->getHeader()->getParent()->getName() << "\"\n"); - LVP.executePlan(LB, DT); + LVP.executePlan(VF.Width, 1, BestPlan, LB, DT); } // Mark the loop as already vectorized to avoid vectorizing again. @@ -10149,7 +10318,13 @@ bool LoopVectorizePass::processLoop(Loop *L) { return false; } - if (!LVL.canVectorizeFPMath(EnableStrictReductions)) { + bool AllowOrderedReductions; + // If the flag is set, use that instead and override the TTI behaviour. 
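The conditional reduction path above selects between the vector operand and a splat of the recurrence identity, so masked-out lanes cannot change the result. A scalar sketch of that behaviour for an integer add reduction (illustrative names, not the IRBuilder API):

#include <cstdio>
#include <vector>

// Masked-out lanes are replaced by the identity of the recurrence (0 for an
// integer add), so they contribute nothing to the running reduction.
static int maskedAddReduce(const std::vector<int> &Vec,
                           const std::vector<bool> &Mask, int ChainIn) {
  const int Identity = 0;
  int Acc = ChainIn;
  for (size_t I = 0; I < Vec.size(); ++I)
    Acc += Mask[I] ? Vec[I] : Identity;
  return Acc;
}

int main() {
  int R = maskedAddReduce({1, 2, 3, 4}, {true, false, true, false},
                          /*ChainIn=*/10);
  std::printf("%d\n", R); // prints: 14
  return 0;
}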
+ if (ForceOrderedReductions.getNumOccurrences() > 0) + AllowOrderedReductions = ForceOrderedReductions; + else + AllowOrderedReductions = TTI->enableOrderedReductions(); + if (!LVL.canVectorizeFPMath(AllowOrderedReductions)) { ORE->emit([&]() { auto *ExactFPMathInst = Requirements.getExactFPInst(); return OptimizationRemarkAnalysisFPCommute(DEBUG_TYPE, "CantReorderFPOps", @@ -10294,7 +10469,6 @@ bool LoopVectorizePass::processLoop(Loop *L) { F->getParent()->getDataLayout()); if (!VF.Width.isScalar() || IC > 1) Checks.Create(L, *LVL.getLAI(), PSE.getUnionPredicate()); - LVP.setBestPlan(VF.Width, IC); using namespace ore; if (!VectorizeLoop) { @@ -10303,7 +10477,9 @@ bool LoopVectorizePass::processLoop(Loop *L) { // interleave it. InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, ORE, IC, &LVL, &CM, BFI, PSI, Checks); - LVP.executePlan(Unroller, DT); + + VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); + LVP.executePlan(VF.Width, IC, BestPlan, Unroller, DT); ORE->emit([&]() { return OptimizationRemark(LV_NAME, "Interleaved", L->getStartLoc(), @@ -10322,14 +10498,13 @@ bool LoopVectorizePass::processLoop(Loop *L) { // The first pass vectorizes the main loop and creates a scalar epilogue // to be vectorized by executing the plan (potentially with a different // factor) again shortly afterwards. - EpilogueLoopVectorizationInfo EPI(VF.Width.getKnownMinValue(), IC, - EpilogueVF.Width.getKnownMinValue(), - 1); + EpilogueLoopVectorizationInfo EPI(VF.Width, IC, EpilogueVF.Width, 1); EpilogueVectorizerMainLoop MainILV(L, PSE, LI, DT, TLI, TTI, AC, ORE, EPI, &LVL, &CM, BFI, PSI, Checks); - LVP.setBestPlan(EPI.MainLoopVF, EPI.MainLoopUF); - LVP.executePlan(MainILV, DT); + VPlan &BestMainPlan = LVP.getBestPlanFor(EPI.MainLoopVF); + LVP.executePlan(EPI.MainLoopVF, EPI.MainLoopUF, BestMainPlan, MainILV, + DT); ++LoopsVectorized; simplifyLoop(L, DT, LI, SE, AC, nullptr, false /* PreserveLCSSA */); @@ -10337,13 +10512,15 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Second pass vectorizes the epilogue and adjusts the control flow // edges from the first pass. - LVP.setBestPlan(EPI.EpilogueVF, EPI.EpilogueUF); EPI.MainLoopVF = EPI.EpilogueVF; EPI.MainLoopUF = EPI.EpilogueUF; EpilogueVectorizerEpilogueLoop EpilogILV(L, PSE, LI, DT, TLI, TTI, AC, ORE, EPI, &LVL, &CM, BFI, PSI, Checks); - LVP.executePlan(EpilogILV, DT); + + VPlan &BestEpiPlan = LVP.getBestPlanFor(EPI.EpilogueVF); + LVP.executePlan(EPI.EpilogueVF, EPI.EpilogueUF, BestEpiPlan, EpilogILV, + DT); ++LoopsEpilogueVectorized; if (!MainILV.areSafetyChecksAdded()) @@ -10351,7 +10528,9 @@ bool LoopVectorizePass::processLoop(Loop *L) { } else { InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, IC, &LVL, &CM, BFI, PSI, Checks); - LVP.executePlan(LB, DT); + + VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); + LVP.executePlan(VF.Width, IC, BestPlan, LB, DT); ++LoopsVectorized; // Add metadata to disable runtime unrolling a scalar loop when there @@ -10469,15 +10648,12 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DB = AM.getResult<DemandedBitsAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - MemorySSA *MSSA = EnableMSSALoopDependency - ? 
&AM.getResult<MemorySSAAnalysis>(F).getMSSA() - : nullptr; auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, MSSA}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, + TLI, TTI, nullptr, nullptr, nullptr}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }; auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); @@ -10501,3 +10677,14 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, PA.preserveSet<CFGAnalyses>(); return PA; } + +void LoopVectorizePass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LoopVectorizePass> *>(this)->printPipeline( + OS, MapClassName2PassName); + + OS << "<"; + OS << (InterleaveOnlyWhenForced ? "" : "no-") << "interleave-forced-only;"; + OS << (VectorizeOnlyWhenForced ? "" : "no-") << "vectorize-forced-only;"; + OS << ">"; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 1d06bc7d79a7..e3ef0b794f68 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/PriorityQueue.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" @@ -200,12 +201,39 @@ static bool isValidElementType(Type *Ty) { !Ty->isPPC_FP128Ty(); } +/// \returns True if the value is a constant (but not globals/constant +/// expressions). +static bool isConstant(Value *V) { + return isa<Constant>(V) && !isa<ConstantExpr>(V) && !isa<GlobalValue>(V); +} + +/// Checks if \p V is one of vector-like instructions, i.e. undef, +/// insertelement/extractelement with constant indices for fixed vector type or +/// extractvalue instruction. +static bool isVectorLikeInstWithConstOps(Value *V) { + if (!isa<InsertElementInst, ExtractElementInst>(V) && + !isa<ExtractValueInst, UndefValue>(V)) + return false; + auto *I = dyn_cast<Instruction>(V); + if (!I || isa<ExtractValueInst>(I)) + return true; + if (!isa<FixedVectorType>(I->getOperand(0)->getType())) + return false; + if (isa<ExtractElementInst>(I)) + return isConstant(I->getOperand(1)); + assert(isa<InsertElementInst>(V) && "Expected only insertelement."); + return isConstant(I->getOperand(2)); +} + /// \returns true if all of the instructions in \p VL are in the same block or /// false otherwise. static bool allSameBlock(ArrayRef<Value *> VL) { Instruction *I0 = dyn_cast<Instruction>(VL[0]); if (!I0) return false; + if (all_of(VL, isVectorLikeInstWithConstOps)) + return true; + BasicBlock *BB = I0->getParent(); for (int I = 1, E = VL.size(); I < E; I++) { auto *II = dyn_cast<Instruction>(VL[I]); @@ -218,12 +246,6 @@ static bool allSameBlock(ArrayRef<Value *> VL) { return true; } -/// \returns True if the value is a constant (but not globals/constant -/// expressions). -static bool isConstant(Value *V) { - return isa<Constant>(V) && !isa<ConstantExpr>(V) && !isa<GlobalValue>(V); -} - /// \returns True if all of the values in \p VL are constants (but not /// globals/constant expressions). 
static bool allConstant(ArrayRef<Value *> VL) { @@ -232,12 +254,21 @@ static bool allConstant(ArrayRef<Value *> VL) { return all_of(VL, isConstant); } -/// \returns True if all of the values in \p VL are identical. +/// \returns True if all of the values in \p VL are identical or some of them +/// are UndefValue. static bool isSplat(ArrayRef<Value *> VL) { - for (unsigned i = 1, e = VL.size(); i < e; ++i) - if (VL[i] != VL[0]) + Value *FirstNonUndef = nullptr; + for (Value *V : VL) { + if (isa<UndefValue>(V)) + continue; + if (!FirstNonUndef) { + FirstNonUndef = V; + continue; + } + if (V != FirstNonUndef) return false; - return true; + } + return FirstNonUndef != nullptr; } /// \returns True if \p I is commutative, handles CmpInst and BinaryOperator. @@ -295,8 +326,10 @@ static bool isCommutative(Instruction *I) { /// TODO: Can we split off and reuse the shuffle mask detection from /// TargetTransformInfo::getInstructionThroughput? static Optional<TargetTransformInfo::ShuffleKind> -isShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { +isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { auto *EI0 = cast<ExtractElementInst>(VL[0]); + if (isa<ScalableVectorType>(EI0->getVectorOperandType())) + return None; unsigned Size = cast<FixedVectorType>(EI0->getVectorOperandType())->getNumElements(); Value *Vec1 = nullptr; @@ -504,7 +537,7 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, case Instruction::Call: { CallInst *CI = cast<CallInst>(UserInst); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { + for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) { if (hasVectorInstrinsicScalarOpd(ID, i)) return (CI->getArgOperand(i) == Scalar); } @@ -535,13 +568,67 @@ static bool isSimple(Instruction *I) { return true; } +/// Shuffles \p Mask in accordance with the given \p SubMask. +static void addMask(SmallVectorImpl<int> &Mask, ArrayRef<int> SubMask) { + if (SubMask.empty()) + return; + if (Mask.empty()) { + Mask.append(SubMask.begin(), SubMask.end()); + return; + } + SmallVector<int> NewMask(SubMask.size(), UndefMaskElem); + int TermValue = std::min(Mask.size(), SubMask.size()); + for (int I = 0, E = SubMask.size(); I < E; ++I) { + if (SubMask[I] >= TermValue || SubMask[I] == UndefMaskElem || + Mask[SubMask[I]] >= TermValue) + continue; + NewMask[I] = Mask[SubMask[I]]; + } + Mask.swap(NewMask); +} + +/// Order may have elements assigned special value (size) which is out of +/// bounds. Such indices only appear on places which correspond to undef values +/// (see canReuseExtract for details) and used in order to avoid undef values +/// have effect on operands ordering. +/// The first loop below simply finds all unused indices and then the next loop +/// nest assigns these indices for undef values positions. 
+/// As an example below Order has two undef positions and they have assigned +/// values 3 and 7 respectively: +/// before: 6 9 5 4 9 2 1 0 +/// after: 6 3 5 4 7 2 1 0 +static void fixupOrderingIndices(SmallVectorImpl<unsigned> &Order) { + const unsigned Sz = Order.size(); + SmallBitVector UsedIndices(Sz); + SmallVector<int> MaskedIndices; + for (unsigned I = 0; I < Sz; ++I) { + if (Order[I] < Sz) + UsedIndices.set(Order[I]); + else + MaskedIndices.push_back(I); + } + if (MaskedIndices.empty()) + return; + SmallVector<int> AvailableIndices(MaskedIndices.size()); + unsigned Cnt = 0; + int Idx = UsedIndices.find_first(); + do { + AvailableIndices[Cnt] = Idx; + Idx = UsedIndices.find_next(Idx); + ++Cnt; + } while (Idx > 0); + assert(Cnt == MaskedIndices.size() && "Non-synced masked/available indices."); + for (int I = 0, E = MaskedIndices.size(); I < E; ++I) + Order[MaskedIndices[I]] = AvailableIndices[I]; +} + namespace llvm { static void inversePermutation(ArrayRef<unsigned> Indices, SmallVectorImpl<int> &Mask) { Mask.clear(); const unsigned E = Indices.size(); - Mask.resize(E, E + 1); + Mask.resize(E, UndefMaskElem); for (unsigned I = 0; I < E; ++I) Mask[Indices[I]] = I; } @@ -581,6 +668,22 @@ static Optional<int> getInsertIndex(Value *InsertInst, unsigned Offset) { return Index; } +/// Reorders the list of scalars in accordance with the given \p Order and then +/// the \p Mask. \p Order - is the original order of the scalars, need to +/// reorder scalars into an unordered state at first according to the given +/// order. Then the ordered scalars are shuffled once again in accordance with +/// the provided mask. +static void reorderScalars(SmallVectorImpl<Value *> &Scalars, + ArrayRef<int> Mask) { + assert(!Mask.empty() && "Expected non-empty mask."); + SmallVector<Value *> Prev(Scalars.size(), + UndefValue::get(Scalars.front()->getType())); + Prev.swap(Scalars); + for (unsigned I = 0, E = Prev.size(); I < E; ++I) + if (Mask[I] != UndefMaskElem) + Scalars[Mask[I]] = Prev[I]; +} + namespace slpvectorizer { /// Bottom Up SLP Vectorizer. @@ -645,13 +748,12 @@ public: void buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst = None); - /// Construct a vectorizable tree that starts at \p Roots, ignoring users for - /// the purpose of scheduling and extraction in the \p UserIgnoreLst taking - /// into account (and updating it, if required) list of externally used - /// values stored in \p ExternallyUsedValues. - void buildTree(ArrayRef<Value *> Roots, - ExtraValueToDebugLocsMap &ExternallyUsedValues, - ArrayRef<Value *> UserIgnoreLst = None); + /// Builds external uses of the vectorized scalars, i.e. the list of + /// vectorized scalars to be extracted, their lanes and their scalar users. \p + /// ExternallyUsedValues contains additional list of external uses to handle + /// vectorization of reductions. + void + buildExternalUses(const ExtraValueToDebugLocsMap &ExternallyUsedValues = {}); /// Clear the internal data structures that are created by 'buildTree'. void deleteTree() { @@ -659,8 +761,6 @@ public: ScalarToTreeEntry.clear(); MustGather.clear(); ExternalUses.clear(); - NumOpsWantToKeepOrder.clear(); - NumOpsWantToKeepOriginalOrder = 0; for (auto &Iter : BlocksSchedules) { BlockScheduling *BS = Iter.second.get(); BS->clear(); @@ -674,103 +774,28 @@ public: /// Perform LICM and CSE on the newly generated gather sequences. void optimizeGatherSequence(); - /// \returns The best order of instructions for vectorization. 
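A stand-alone sketch of the fixupOrderingIndices idea documented above, reproducing the worked example from the comment; the helper name and types here are illustrative, not the LLVM ones.

#include <cstdio>
#include <vector>

// Assumes the in-bounds entries of Order are distinct, as in the LLVM helper.
static void fixupOrderingSketch(std::vector<unsigned> &Order) {
  const unsigned Sz = Order.size();
  std::vector<bool> Used(Sz, false);
  std::vector<unsigned> MaskedPositions;
  for (unsigned I = 0; I < Sz; ++I) {
    if (Order[I] < Sz)
      Used[Order[I]] = true;
    else
      MaskedPositions.push_back(I); // placeholder position
  }
  // Hand the never-used indices to the placeholder positions, in order.
  unsigned NextFree = 0;
  for (unsigned Pos : MaskedPositions) {
    while (Used[NextFree])
      ++NextFree;
    Order[Pos] = NextFree;
    Used[NextFree] = true;
  }
}

int main() {
  // The example from the comment: 6 9 5 4 9 2 1 0  ->  6 3 5 4 7 2 1 0
  std::vector<unsigned> Order = {6, 9, 5, 4, 9, 2, 1, 0};
  fixupOrderingSketch(Order);
  for (unsigned V : Order)
    std::printf("%u ", V);
  std::printf("\n");
  return 0;
}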
- Optional<ArrayRef<unsigned>> bestOrder() const { - assert(llvm::all_of( - NumOpsWantToKeepOrder, - [this](const decltype(NumOpsWantToKeepOrder)::value_type &D) { - return D.getFirst().size() == - VectorizableTree[0]->Scalars.size(); - }) && - "All orders must have the same size as number of instructions in " - "tree node."); - auto I = std::max_element( - NumOpsWantToKeepOrder.begin(), NumOpsWantToKeepOrder.end(), - [](const decltype(NumOpsWantToKeepOrder)::value_type &D1, - const decltype(NumOpsWantToKeepOrder)::value_type &D2) { - return D1.second < D2.second; - }); - if (I == NumOpsWantToKeepOrder.end() || - I->getSecond() <= NumOpsWantToKeepOriginalOrder) - return None; - - return makeArrayRef(I->getFirst()); - } - - /// Builds the correct order for root instructions. - /// If some leaves have the same instructions to be vectorized, we may - /// incorrectly evaluate the best order for the root node (it is built for the - /// vector of instructions without repeated instructions and, thus, has less - /// elements than the root node). This function builds the correct order for - /// the root node. - /// For example, if the root node is \<a+b, a+c, a+d, f+e\>, then the leaves - /// are \<a, a, a, f\> and \<b, c, d, e\>. When we try to vectorize the first - /// leaf, it will be shrink to \<a, b\>. If instructions in this leaf should - /// be reordered, the best order will be \<1, 0\>. We need to extend this - /// order for the root node. For the root node this order should look like - /// \<3, 0, 1, 2\>. This function extends the order for the reused - /// instructions. - void findRootOrder(OrdersType &Order) { - // If the leaf has the same number of instructions to vectorize as the root - // - order must be set already. - unsigned RootSize = VectorizableTree[0]->Scalars.size(); - if (Order.size() == RootSize) - return; - SmallVector<unsigned, 4> RealOrder(Order.size()); - std::swap(Order, RealOrder); - SmallVector<int, 4> Mask; - inversePermutation(RealOrder, Mask); - Order.assign(Mask.begin(), Mask.end()); - // The leaf has less number of instructions - need to find the true order of - // the root. - // Scan the nodes starting from the leaf back to the root. - const TreeEntry *PNode = VectorizableTree.back().get(); - SmallVector<const TreeEntry *, 4> Nodes(1, PNode); - SmallPtrSet<const TreeEntry *, 4> Visited; - while (!Nodes.empty() && Order.size() != RootSize) { - const TreeEntry *PNode = Nodes.pop_back_val(); - if (!Visited.insert(PNode).second) - continue; - const TreeEntry &Node = *PNode; - for (const EdgeInfo &EI : Node.UserTreeIndices) - if (EI.UserTE) - Nodes.push_back(EI.UserTE); - if (Node.ReuseShuffleIndices.empty()) - continue; - // Build the order for the parent node. - OrdersType NewOrder(Node.ReuseShuffleIndices.size(), RootSize); - SmallVector<unsigned, 4> OrderCounter(Order.size(), 0); - // The algorithm of the order extension is: - // 1. Calculate the number of the same instructions for the order. - // 2. Calculate the index of the new order: total number of instructions - // with order less than the order of the current instruction + reuse - // number of the current instruction. - // 3. The new order is just the index of the instruction in the original - // vector of the instructions. 
- for (unsigned I : Node.ReuseShuffleIndices) - ++OrderCounter[Order[I]]; - SmallVector<unsigned, 4> CurrentCounter(Order.size(), 0); - for (unsigned I = 0, E = Node.ReuseShuffleIndices.size(); I < E; ++I) { - unsigned ReusedIdx = Node.ReuseShuffleIndices[I]; - unsigned OrderIdx = Order[ReusedIdx]; - unsigned NewIdx = 0; - for (unsigned J = 0; J < OrderIdx; ++J) - NewIdx += OrderCounter[J]; - NewIdx += CurrentCounter[OrderIdx]; - ++CurrentCounter[OrderIdx]; - assert(NewOrder[NewIdx] == RootSize && - "The order index should not be written already."); - NewOrder[NewIdx] = I; - } - std::swap(Order, NewOrder); - } - assert(Order.size() == RootSize && - "Root node is expected or the size of the order must be the same as " - "the number of elements in the root node."); - assert(llvm::all_of(Order, - [RootSize](unsigned Val) { return Val != RootSize; }) && - "All indices must be initialized"); - } + /// Checks if the specified gather tree entry \p TE can be represented as a + /// shuffled vector entry + (possibly) permutation with other gathers. It + /// implements the checks only for possibly ordered scalars (Loads, + /// ExtractElement, ExtractValue), which can be part of the graph. + Optional<OrdersType> findReusedOrderedScalars(const TreeEntry &TE); + + /// Reorders the current graph to the most profitable order starting from the + /// root node to the leaf nodes. The best order is chosen only from the nodes + /// of the same size (vectorization factor). Smaller nodes are considered + /// parts of subgraph with smaller VF and they are reordered independently. We + /// can make it because we still need to extend smaller nodes to the wider VF + /// and we can merge reordering shuffles with the widening shuffles. + void reorderTopToBottom(); + + /// Reorders the current graph to the most profitable order starting from + /// leaves to the root. It allows to rotate small subgraphs and reduce the + /// number of reshuffles if the leaf nodes use the same order. In this case we + /// can merge the orders and just shuffle user node instead of shuffling its + /// operands. Plus, even the leaf nodes have different orders, it allows to + /// sink reordering in the graph closer to the root node and merge it later + /// during analysis. + void reorderBottomToTop(bool IgnoreReorder = false); /// \return The vector element size in bits to use when vectorizing the /// expression tree ending at \p V. If V is a store, the size is the width of @@ -793,6 +818,10 @@ public: return MinVecRegSize; } + unsigned getMinVF(unsigned Sz) const { + return std::max(2U, getMinVecRegSize() / Sz); + } + unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const { unsigned MaxVF = MaxVFOption.getNumOccurrences() ? MaxVFOption : TTI->getMaximumVF(ElemWidth, Opcode); @@ -809,7 +838,7 @@ public: /// \returns True if the VectorizableTree is both tiny and not fully /// vectorizable. We do not vectorize such trees. - bool isTreeTinyAndNotFullyVectorizable() const; + bool isTreeTinyAndNotFullyVectorizable(bool ForReduction = false) const; /// Assume that a legal-sized 'or'-reduction of shifted/zexted loaded values /// can be load combined in the backend. Load combining may not be allowed in @@ -1578,10 +1607,12 @@ private: Value *vectorizeTree(ArrayRef<Value *> VL); /// \returns the scalarization cost for this type. Scalarization in this - /// context means the creation of vectors from a group of scalars. 
- InstructionCost - getGatherCost(FixedVectorType *Ty, - const DenseSet<unsigned> &ShuffledIndices) const; + /// context means the creation of vectors from a group of scalars. If \p + /// NeedToShuffle is true, need to add a cost of reshuffling some of the + /// vector elements. + InstructionCost getGatherCost(FixedVectorType *Ty, + const DenseSet<unsigned> &ShuffledIndices, + bool NeedToShuffle) const; /// Checks if the gathered \p VL can be represented as shuffle(s) of previous /// tree entries. @@ -1605,7 +1636,7 @@ private: /// \returns whether the VectorizableTree is fully vectorizable and will /// be beneficial even the tree height is tiny. - bool isFullyVectorizableTinyTree() const; + bool isFullyVectorizableTinyTree(bool ForReduction) const; /// Reorder commutative or alt operands to get better probability of /// generating vectorized code. @@ -1621,14 +1652,43 @@ private: /// \returns true if the scalars in VL are equal to this entry. bool isSame(ArrayRef<Value *> VL) const { - if (VL.size() == Scalars.size()) - return std::equal(VL.begin(), VL.end(), Scalars.begin()); - return VL.size() == ReuseShuffleIndices.size() && - std::equal( - VL.begin(), VL.end(), ReuseShuffleIndices.begin(), - [this](Value *V, int Idx) { return V == Scalars[Idx]; }); + auto &&IsSame = [VL](ArrayRef<Value *> Scalars, ArrayRef<int> Mask) { + if (Mask.size() != VL.size() && VL.size() == Scalars.size()) + return std::equal(VL.begin(), VL.end(), Scalars.begin()); + return VL.size() == Mask.size() && + std::equal(VL.begin(), VL.end(), Mask.begin(), + [Scalars](Value *V, int Idx) { + return (isa<UndefValue>(V) && + Idx == UndefMaskElem) || + (Idx != UndefMaskElem && V == Scalars[Idx]); + }); + }; + if (!ReorderIndices.empty()) { + // TODO: implement matching if the nodes are just reordered, still can + // treat the vector as the same if the list of scalars matches VL + // directly, without reordering. + SmallVector<int> Mask; + inversePermutation(ReorderIndices, Mask); + if (VL.size() == Scalars.size()) + return IsSame(Scalars, Mask); + if (VL.size() == ReuseShuffleIndices.size()) { + ::addMask(Mask, ReuseShuffleIndices); + return IsSame(Scalars, Mask); + } + return false; + } + return IsSame(Scalars, ReuseShuffleIndices); } + /// \return Final vectorization factor for the node. Defined by the total + /// number of vectorized scalars, including those, used several times in the + /// entry and counted in the \a ReuseShuffleIndices, if any. + unsigned getVectorFactor() const { + if (!ReuseShuffleIndices.empty()) + return ReuseShuffleIndices.size(); + return Scalars.size(); + }; + /// A vector of scalars. ValueList Scalars; @@ -1701,6 +1761,12 @@ private: } } + /// Reorders operands of the node to the given mask \p Mask. + void reorderOperands(ArrayRef<int> Mask) { + for (ValueList &Operand : Operands) + reorderScalars(Operand, Mask); + } + /// \returns the \p OpIdx operand of this TreeEntry. ValueList &getOperand(unsigned OpIdx) { assert(OpIdx < Operands.size() && "Off bounds"); @@ -1760,19 +1826,14 @@ private: return AltOp ? AltOp->getOpcode() : 0; } - /// Update operations state of this entry if reorder occurred. - bool updateStateIfReorder() { - if (ReorderIndices.empty()) - return false; - InstructionsState S = getSameOpcode(Scalars, ReorderIndices.front()); - setOperations(S); - return true; - } - /// When ReuseShuffleIndices is empty it just returns position of \p V - /// within vector of Scalars. Otherwise, try to remap on its reuse index. 
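The masked comparison used by TreeEntry::isSame above either pairs an undef value with an undef mask element or requires the mask to point back at the matching stored scalar. A plain-C++ sketch with nullptr standing in for UndefValue and -1 for UndefMaskElem:

#include <cstdio>
#include <vector>

// nullptr plays the role of UndefValue; -1 plays the role of UndefMaskElem.
static bool isSameUnderMask(const std::vector<const int *> &VL,
                            const std::vector<const int *> &Scalars,
                            const std::vector<int> &Mask) {
  if (VL.size() != Mask.size())
    return false;
  for (size_t I = 0; I < VL.size(); ++I) {
    if (VL[I] == nullptr && Mask[I] == -1)
      continue; // undef value paired with an undef mask element is fine
    if (Mask[I] == -1 || Scalars[Mask[I]] != VL[I])
      return false;
  }
  return true;
}

int main() {
  int A = 1, B = 2, C = 3, D = 4;
  std::vector<const int *> Scalars = {&A, &B, &C, &D};
  std::vector<const int *> VL = {&B, &A, nullptr, &D};
  std::vector<int> Mask = {1, 0, -1, 3};
  std::printf("%d\n", (int)isSameUnderMask(VL, Scalars, Mask)); // prints: 1
  return 0;
}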
+ /// When ReuseReorderShuffleIndices is empty it just returns position of \p + /// V within vector of Scalars. Otherwise, try to remap on its reuse index. int findLaneForValue(Value *V) const { unsigned FoundLane = std::distance(Scalars.begin(), find(Scalars, V)); assert(FoundLane < Scalars.size() && "Couldn't find extract lane"); + if (!ReorderIndices.empty()) + FoundLane = ReorderIndices[FoundLane]; + assert(FoundLane < Scalars.size() && "Couldn't find extract lane"); if (!ReuseShuffleIndices.empty()) { FoundLane = std::distance(ReuseShuffleIndices.begin(), find(ReuseShuffleIndices, FoundLane)); @@ -1856,7 +1917,7 @@ private: TreeEntry *newTreeEntry(ArrayRef<Value *> VL, Optional<ScheduleData *> Bundle, const InstructionsState &S, const EdgeInfo &UserTreeIdx, - ArrayRef<unsigned> ReuseShuffleIndices = None, + ArrayRef<int> ReuseShuffleIndices = None, ArrayRef<unsigned> ReorderIndices = None) { TreeEntry::EntryState EntryState = Bundle ? TreeEntry::Vectorize : TreeEntry::NeedToGather; @@ -1869,7 +1930,7 @@ private: Optional<ScheduleData *> Bundle, const InstructionsState &S, const EdgeInfo &UserTreeIdx, - ArrayRef<unsigned> ReuseShuffleIndices = None, + ArrayRef<int> ReuseShuffleIndices = None, ArrayRef<unsigned> ReorderIndices = None) { assert(((!Bundle && EntryState == TreeEntry::NeedToGather) || (Bundle && EntryState != TreeEntry::NeedToGather)) && @@ -1877,12 +1938,25 @@ private: VectorizableTree.push_back(std::make_unique<TreeEntry>(VectorizableTree)); TreeEntry *Last = VectorizableTree.back().get(); Last->Idx = VectorizableTree.size() - 1; - Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); Last->State = EntryState; Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(), ReuseShuffleIndices.end()); - Last->ReorderIndices.append(ReorderIndices.begin(), ReorderIndices.end()); - Last->setOperations(S); + if (ReorderIndices.empty()) { + Last->Scalars.assign(VL.begin(), VL.end()); + Last->setOperations(S); + } else { + // Reorder scalars and build final mask. + Last->Scalars.assign(VL.size(), nullptr); + transform(ReorderIndices, Last->Scalars.begin(), + [VL](unsigned Idx) -> Value * { + if (Idx >= VL.size()) + return UndefValue::get(VL.front()->getType()); + return VL[Idx]; + }); + InstructionsState S = getSameOpcode(Last->Scalars); + Last->setOperations(S); + Last->ReorderIndices.append(ReorderIndices.begin(), ReorderIndices.end()); + } if (Last->State != TreeEntry::NeedToGather) { for (Value *V : VL) { assert(!getTreeEntry(V) && "Scalar already in tree!"); @@ -1965,12 +2039,9 @@ private: if (result.hasValue()) { return result.getValue(); } - MemoryLocation Loc2 = getLocation(Inst2, AA); bool aliased = true; - if (Loc1.Ptr && Loc2.Ptr && isSimple(Inst1) && isSimple(Inst2)) { - // Do the alias check. - aliased = !AA->isNoAlias(Loc1, Loc2); - } + if (Loc1.Ptr && isSimple(Inst1)) + aliased = isModOrRefSet(AA->getModRefInfo(Inst2, Loc1)); // Store the result in the cache. result = aliased; return aliased; @@ -2434,14 +2505,6 @@ private: } }; - /// Contains orders of operations along with the number of bundles that have - /// operations in this order. It stores only those orders that require - /// reordering, if reordering is not required it is counted using \a - /// NumOpsWantToKeepOriginalOrder. - DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo> NumOpsWantToKeepOrder; - /// Number of bundles that do not require reordering. - unsigned NumOpsWantToKeepOriginalOrder = 0; - // Analysis and block reference. 
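newTreeEntry above now materializes the scalar list already reordered: position I receives VL[ReorderIndices[I]], and an out-of-range index becomes an undef placeholder. A small sketch of that mapping (ints instead of IR values, -1 as the undef stand-in):

#include <cstdio>
#include <vector>

// -1 stands in for the undef placeholder created for out-of-range indices.
static std::vector<int> reorderScalarList(const std::vector<int> &VL,
                                          const std::vector<unsigned> &Order) {
  std::vector<int> Out(Order.size());
  for (size_t I = 0; I < Order.size(); ++I)
    Out[I] = Order[I] < VL.size() ? VL[Order[I]] : -1;
  return Out;
}

int main() {
  std::vector<int> VL = {10, 20, 30, 40};
  std::vector<unsigned> Order = {2, 0, 3, 1};
  for (int V : reorderScalarList(VL, Order))
    std::printf("%d ", V); // prints: 30 10 40 20
  std::printf("\n");
  return 0;
}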
Function *F; ScalarEvolution *SE; @@ -2540,10 +2603,8 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { std::string getNodeLabel(const TreeEntry *Entry, const BoUpSLP *R) { std::string Str; raw_string_ostream OS(Str); - if (isSplat(Entry->Scalars)) { - OS << "<splat> " << *Entry->Scalars[0]; - return Str; - } + if (isSplat(Entry->Scalars)) + OS << "<splat> "; for (auto V : Entry->Scalars) { OS << *V; if (llvm::any_of(R->ExternalUses, [&](const BoUpSLP::ExternalUser &EU) { @@ -2594,21 +2655,539 @@ void BoUpSLP::eraseInstructions(ArrayRef<Value *> AV) { }; } -void BoUpSLP::buildTree(ArrayRef<Value *> Roots, - ArrayRef<Value *> UserIgnoreLst) { - ExtraValueToDebugLocsMap ExternallyUsedValues; - buildTree(Roots, ExternallyUsedValues, UserIgnoreLst); +/// Reorders the given \p Reuses mask according to the given \p Mask. \p Reuses +/// contains original mask for the scalars reused in the node. Procedure +/// transform this mask in accordance with the given \p Mask. +static void reorderReuses(SmallVectorImpl<int> &Reuses, ArrayRef<int> Mask) { + assert(!Mask.empty() && Reuses.size() == Mask.size() && + "Expected non-empty mask."); + SmallVector<int> Prev(Reuses.begin(), Reuses.end()); + Prev.swap(Reuses); + for (unsigned I = 0, E = Prev.size(); I < E; ++I) + if (Mask[I] != UndefMaskElem) + Reuses[Mask[I]] = Prev[I]; } -void BoUpSLP::buildTree(ArrayRef<Value *> Roots, - ExtraValueToDebugLocsMap &ExternallyUsedValues, - ArrayRef<Value *> UserIgnoreLst) { - deleteTree(); - UserIgnoreList = UserIgnoreLst; - if (!allSameType(Roots)) +/// Reorders the given \p Order according to the given \p Mask. \p Order - is +/// the original order of the scalars. Procedure transforms the provided order +/// in accordance with the given \p Mask. If the resulting \p Order is just an +/// identity order, \p Order is cleared. +static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) { + assert(!Mask.empty() && "Expected non-empty mask."); + SmallVector<int> MaskOrder; + if (Order.empty()) { + MaskOrder.resize(Mask.size()); + std::iota(MaskOrder.begin(), MaskOrder.end(), 0); + } else { + inversePermutation(Order, MaskOrder); + } + reorderReuses(MaskOrder, Mask); + if (ShuffleVectorInst::isIdentityMask(MaskOrder)) { + Order.clear(); return; - buildTree_rec(Roots, 0, EdgeInfo()); + } + Order.assign(Mask.size(), Mask.size()); + for (unsigned I = 0, E = Mask.size(); I < E; ++I) + if (MaskOrder[I] != UndefMaskElem) + Order[MaskOrder[I]] = I; + fixupOrderingIndices(Order); +} +Optional<BoUpSLP::OrdersType> +BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { + assert(TE.State == TreeEntry::NeedToGather && "Expected gather node only."); + unsigned NumScalars = TE.Scalars.size(); + OrdersType CurrentOrder(NumScalars, NumScalars); + SmallVector<int> Positions; + SmallBitVector UsedPositions(NumScalars); + const TreeEntry *STE = nullptr; + // Try to find all gathered scalars that are gets vectorized in other + // vectorize node. Here we can have only one single tree vector node to + // correctly identify order of the gathered scalars. + for (unsigned I = 0; I < NumScalars; ++I) { + Value *V = TE.Scalars[I]; + if (!isa<LoadInst, ExtractElementInst, ExtractValueInst>(V)) + continue; + if (const auto *LocalSTE = getTreeEntry(V)) { + if (!STE) + STE = LocalSTE; + else if (STE != LocalSTE) + // Take the order only from the single vector node. 
+ return None; + unsigned Lane = + std::distance(STE->Scalars.begin(), find(STE->Scalars, V)); + if (Lane >= NumScalars) + return None; + if (CurrentOrder[Lane] != NumScalars) { + if (Lane != I) + continue; + UsedPositions.reset(CurrentOrder[Lane]); + } + // The partial identity (where only some elements of the gather node are + // in the identity order) is good. + CurrentOrder[Lane] = I; + UsedPositions.set(I); + } + } + // Need to keep the order if we have a vector entry and at least 2 scalars or + // the vectorized entry has just 2 scalars. + if (STE && (UsedPositions.count() > 1 || STE->Scalars.size() == 2)) { + auto &&IsIdentityOrder = [NumScalars](ArrayRef<unsigned> CurrentOrder) { + for (unsigned I = 0; I < NumScalars; ++I) + if (CurrentOrder[I] != I && CurrentOrder[I] != NumScalars) + return false; + return true; + }; + if (IsIdentityOrder(CurrentOrder)) { + CurrentOrder.clear(); + return CurrentOrder; + } + auto *It = CurrentOrder.begin(); + for (unsigned I = 0; I < NumScalars;) { + if (UsedPositions.test(I)) { + ++I; + continue; + } + if (*It == NumScalars) { + *It = I; + ++I; + } + ++It; + } + return CurrentOrder; + } + return None; +} + +void BoUpSLP::reorderTopToBottom() { + // Maps VF to the graph nodes. + DenseMap<unsigned, SmallPtrSet<TreeEntry *, 4>> VFToOrderedEntries; + // ExtractElement gather nodes which can be vectorized and need to handle + // their ordering. + DenseMap<const TreeEntry *, OrdersType> GathersToOrders; + // Find all reorderable nodes with the given VF. + // Currently the are vectorized loads,extracts + some gathering of extracts. + for_each(VectorizableTree, [this, &VFToOrderedEntries, &GathersToOrders]( + const std::unique_ptr<TreeEntry> &TE) { + // No need to reorder if need to shuffle reuses, still need to shuffle the + // node. + if (!TE->ReuseShuffleIndices.empty()) + return; + if (TE->State == TreeEntry::Vectorize && + isa<LoadInst, ExtractElementInst, ExtractValueInst, StoreInst, + InsertElementInst>(TE->getMainOp()) && + !TE->isAltShuffle()) { + VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); + return; + } + if (TE->State == TreeEntry::NeedToGather) { + if (TE->getOpcode() == Instruction::ExtractElement && + !TE->isAltShuffle() && + isa<FixedVectorType>(cast<ExtractElementInst>(TE->getMainOp()) + ->getVectorOperandType()) && + allSameType(TE->Scalars) && allSameBlock(TE->Scalars)) { + // Check that gather of extractelements can be represented as + // just a shuffle of a single vector. + OrdersType CurrentOrder; + bool Reuse = + canReuseExtract(TE->Scalars, TE->getMainOp(), CurrentOrder); + if (Reuse || !CurrentOrder.empty()) { + VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); + GathersToOrders.try_emplace(TE.get(), CurrentOrder); + return; + } + } + if (Optional<OrdersType> CurrentOrder = + findReusedOrderedScalars(*TE.get())) { + VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); + GathersToOrders.try_emplace(TE.get(), *CurrentOrder); + } + } + }); + + // Reorder the graph nodes according to their vectorization factor. + for (unsigned VF = VectorizableTree.front()->Scalars.size(); VF > 1; + VF /= 2) { + auto It = VFToOrderedEntries.find(VF); + if (It == VFToOrderedEntries.end()) + continue; + // Try to find the most profitable order. We just are looking for the most + // used order and reorder scalar elements in the nodes according to this + // mostly used order. 
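The "most used order" selection described in the comment above can be sketched on its own with plain std containers. The container choice, the sample orders, and the vote counts are illustrative; what the sketch keeps is the counting of identical orders and the tie break toward the identity (empty) order, mirroring the loop that follows.

#include <cstdio>
#include <map>
#include <vector>

int main() {
  using Order = std::vector<unsigned>;
  // Orders contributed by the reorderable nodes of one vectorization factor.
  std::vector<Order> NodeOrders = {
      {1, 0, 3, 2}, {1, 0, 3, 2}, {} /*identity*/, {3, 2, 1, 0}};

  std::map<Order, unsigned> OrdersUses;
  for (const Order &O : NodeOrders)
    ++OrdersUses[O];                        // count how often each order occurs

  Order Best;
  unsigned Cnt = 0;
  for (const auto &P : OrdersUses)
    if (Cnt < P.second || (Cnt == P.second && P.first.empty())) {
      Best = P.first;                       // most used order wins,
      Cnt = P.second;                       // ties prefer the identity order
    }

  std::printf("best order has %u votes and %zu elements\n", Cnt, Best.size());
  return 0; // prints: best order has 2 votes and 4 elements
}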
+ const SmallPtrSetImpl<TreeEntry *> &OrderedEntries = It->getSecond(); + // All operands are reordered and used only in this node - propagate the + // most used order to the user node. + MapVector<OrdersType, unsigned, + DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>> + OrdersUses; + SmallPtrSet<const TreeEntry *, 4> VisitedOps; + for (const TreeEntry *OpTE : OrderedEntries) { + // No need to reorder this nodes, still need to extend and to use shuffle, + // just need to merge reordering shuffle and the reuse shuffle. + if (!OpTE->ReuseShuffleIndices.empty()) + continue; + // Count number of orders uses. + const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & { + if (OpTE->State == TreeEntry::NeedToGather) + return GathersToOrders.find(OpTE)->second; + return OpTE->ReorderIndices; + }(); + // Stores actually store the mask, not the order, need to invert. + if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && + OpTE->getOpcode() == Instruction::Store && !Order.empty()) { + SmallVector<int> Mask; + inversePermutation(Order, Mask); + unsigned E = Order.size(); + OrdersType CurrentOrder(E, E); + transform(Mask, CurrentOrder.begin(), [E](int Idx) { + return Idx == UndefMaskElem ? E : static_cast<unsigned>(Idx); + }); + fixupOrderingIndices(CurrentOrder); + ++OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second; + } else { + ++OrdersUses.insert(std::make_pair(Order, 0)).first->second; + } + } + // Set order of the user node. + if (OrdersUses.empty()) + continue; + // Choose the most used order. + ArrayRef<unsigned> BestOrder = OrdersUses.front().first; + unsigned Cnt = OrdersUses.front().second; + for (const auto &Pair : drop_begin(OrdersUses)) { + if (Cnt < Pair.second || (Cnt == Pair.second && Pair.first.empty())) { + BestOrder = Pair.first; + Cnt = Pair.second; + } + } + // Set order of the user node. + if (BestOrder.empty()) + continue; + SmallVector<int> Mask; + inversePermutation(BestOrder, Mask); + SmallVector<int> MaskOrder(BestOrder.size(), UndefMaskElem); + unsigned E = BestOrder.size(); + transform(BestOrder, MaskOrder.begin(), [E](unsigned I) { + return I < E ? static_cast<int>(I) : UndefMaskElem; + }); + // Do an actual reordering, if profitable. + for (std::unique_ptr<TreeEntry> &TE : VectorizableTree) { + // Just do the reordering for the nodes with the given VF. + if (TE->Scalars.size() != VF) { + if (TE->ReuseShuffleIndices.size() == VF) { + // Need to reorder the reuses masks of the operands with smaller VF to + // be able to find the match between the graph nodes and scalar + // operands of the given node during vectorization/cost estimation. + assert(all_of(TE->UserTreeIndices, + [VF, &TE](const EdgeInfo &EI) { + return EI.UserTE->Scalars.size() == VF || + EI.UserTE->Scalars.size() == + TE->Scalars.size(); + }) && + "All users must be of VF size."); + // Update ordering of the operands with the smaller VF than the given + // one. + reorderReuses(TE->ReuseShuffleIndices, Mask); + } + continue; + } + if (TE->State == TreeEntry::Vectorize && + isa<ExtractElementInst, ExtractValueInst, LoadInst, StoreInst, + InsertElementInst>(TE->getMainOp()) && + !TE->isAltShuffle()) { + // Build correct orders for extract{element,value}, loads and + // stores. + reorderOrder(TE->ReorderIndices, Mask); + if (isa<InsertElementInst, StoreInst>(TE->getMainOp())) + TE->reorderOperands(Mask); + } else { + // Reorder the node and its operands. 
+ TE->reorderOperands(Mask); + assert(TE->ReorderIndices.empty() && + "Expected empty reorder sequence."); + reorderScalars(TE->Scalars, Mask); + } + if (!TE->ReuseShuffleIndices.empty()) { + // Apply reversed order to keep the original ordering of the reused + // elements to avoid extra reorder indices shuffling. + OrdersType CurrentOrder; + reorderOrder(CurrentOrder, MaskOrder); + SmallVector<int> NewReuses; + inversePermutation(CurrentOrder, NewReuses); + addMask(NewReuses, TE->ReuseShuffleIndices); + TE->ReuseShuffleIndices.swap(NewReuses); + } + } + } +} + +void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { + SetVector<TreeEntry *> OrderedEntries; + DenseMap<const TreeEntry *, OrdersType> GathersToOrders; + // Find all reorderable leaf nodes with the given VF. + // Currently the are vectorized loads,extracts without alternate operands + + // some gathering of extracts. + SmallVector<TreeEntry *> NonVectorized; + for_each(VectorizableTree, [this, &OrderedEntries, &GathersToOrders, + &NonVectorized]( + const std::unique_ptr<TreeEntry> &TE) { + if (TE->State != TreeEntry::Vectorize) + NonVectorized.push_back(TE.get()); + // No need to reorder if need to shuffle reuses, still need to shuffle the + // node. + if (!TE->ReuseShuffleIndices.empty()) + return; + if (TE->State == TreeEntry::Vectorize && + isa<LoadInst, ExtractElementInst, ExtractValueInst>(TE->getMainOp()) && + !TE->isAltShuffle()) { + OrderedEntries.insert(TE.get()); + return; + } + if (TE->State == TreeEntry::NeedToGather) { + if (TE->getOpcode() == Instruction::ExtractElement && + !TE->isAltShuffle() && + isa<FixedVectorType>(cast<ExtractElementInst>(TE->getMainOp()) + ->getVectorOperandType()) && + allSameType(TE->Scalars) && allSameBlock(TE->Scalars)) { + // Check that gather of extractelements can be represented as + // just a shuffle of a single vector with a single user only. + OrdersType CurrentOrder; + bool Reuse = + canReuseExtract(TE->Scalars, TE->getMainOp(), CurrentOrder); + if ((Reuse || !CurrentOrder.empty()) && + !any_of(VectorizableTree, + [&TE](const std::unique_ptr<TreeEntry> &Entry) { + return Entry->State == TreeEntry::NeedToGather && + Entry.get() != TE.get() && + Entry->isSame(TE->Scalars); + })) { + OrderedEntries.insert(TE.get()); + GathersToOrders.try_emplace(TE.get(), CurrentOrder); + return; + } + } + if (Optional<OrdersType> CurrentOrder = + findReusedOrderedScalars(*TE.get())) { + OrderedEntries.insert(TE.get()); + GathersToOrders.try_emplace(TE.get(), *CurrentOrder); + } + } + }); + + // Checks if the operands of the users are reordarable and have only single + // use. 
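The reorderReuses helper used repeatedly above is easiest to see with concrete numbers. The following standalone sketch (with -1 standing in for UndefMaskElem, and plain std::vector instead of SmallVector) permutes an existing reuse mask through a new shuffle mask the way the helper does, so the combined effect stays a single shuffle.

#include <cstdio>
#include <vector>

static void reorderReusesSketch(std::vector<int> &Reuses,
                                const std::vector<int> &Mask) {
  std::vector<int> Prev = Reuses;           // keep the original mask as default
  for (unsigned I = 0, E = Prev.size(); I < E; ++I)
    if (Mask[I] != -1)
      Reuses[Mask[I]] = Prev[I];            // move entry I to its new slot
}

int main() {
  std::vector<int> Reuses = {0, 0, 1, 2};   // lanes 0,0,1,2 of the vector
  std::vector<int> Mask = {1, 0, 3, 2};     // newly chosen element order
  reorderReusesSketch(Reuses, Mask);
  for (int R : Reuses)
    std::printf("%d ", R);                  // prints: 0 0 2 1
  std::printf("\n");
  return 0;
}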
+ auto &&CheckOperands = + [this, &NonVectorized](const auto &Data, + SmallVectorImpl<TreeEntry *> &GatherOps) { + for (unsigned I = 0, E = Data.first->getNumOperands(); I < E; ++I) { + if (any_of(Data.second, + [I](const std::pair<unsigned, TreeEntry *> &OpData) { + return OpData.first == I && + OpData.second->State == TreeEntry::Vectorize; + })) + continue; + ArrayRef<Value *> VL = Data.first->getOperand(I); + const TreeEntry *TE = nullptr; + const auto *It = find_if(VL, [this, &TE](Value *V) { + TE = getTreeEntry(V); + return TE; + }); + if (It != VL.end() && TE->isSame(VL)) + return false; + TreeEntry *Gather = nullptr; + if (count_if(NonVectorized, [VL, &Gather](TreeEntry *TE) { + assert(TE->State != TreeEntry::Vectorize && + "Only non-vectorized nodes are expected."); + if (TE->isSame(VL)) { + Gather = TE; + return true; + } + return false; + }) > 1) + return false; + if (Gather) + GatherOps.push_back(Gather); + } + return true; + }; + // 1. Propagate order to the graph nodes, which use only reordered nodes. + // I.e., if the node has operands, that are reordered, try to make at least + // one operand order in the natural order and reorder others + reorder the + // user node itself. + SmallPtrSet<const TreeEntry *, 4> Visited; + while (!OrderedEntries.empty()) { + // 1. Filter out only reordered nodes. + // 2. If the entry has multiple uses - skip it and jump to the next node. + MapVector<TreeEntry *, SmallVector<std::pair<unsigned, TreeEntry *>>> Users; + SmallVector<TreeEntry *> Filtered; + for (TreeEntry *TE : OrderedEntries) { + if (!(TE->State == TreeEntry::Vectorize || + (TE->State == TreeEntry::NeedToGather && + GathersToOrders.count(TE))) || + TE->UserTreeIndices.empty() || !TE->ReuseShuffleIndices.empty() || + !all_of(drop_begin(TE->UserTreeIndices), + [TE](const EdgeInfo &EI) { + return EI.UserTE == TE->UserTreeIndices.front().UserTE; + }) || + !Visited.insert(TE).second) { + Filtered.push_back(TE); + continue; + } + // Build a map between user nodes and their operands order to speedup + // search. The graph currently does not provide this dependency directly. + for (EdgeInfo &EI : TE->UserTreeIndices) { + TreeEntry *UserTE = EI.UserTE; + auto It = Users.find(UserTE); + if (It == Users.end()) + It = Users.insert({UserTE, {}}).first; + It->second.emplace_back(EI.EdgeIdx, TE); + } + } + // Erase filtered entries. + for_each(Filtered, + [&OrderedEntries](TreeEntry *TE) { OrderedEntries.remove(TE); }); + for (const auto &Data : Users) { + // Check that operands are used only in the User node. + SmallVector<TreeEntry *> GatherOps; + if (!CheckOperands(Data, GatherOps)) { + for_each(Data.second, + [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) { + OrderedEntries.remove(Op.second); + }); + continue; + } + // All operands are reordered and used only in this node - propagate the + // most used order to the user node. + MapVector<OrdersType, unsigned, + DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>> + OrdersUses; + SmallPtrSet<const TreeEntry *, 4> VisitedOps; + for (const auto &Op : Data.second) { + TreeEntry *OpTE = Op.second; + if (!OpTE->ReuseShuffleIndices.empty() || + (IgnoreReorder && OpTE == VectorizableTree.front().get())) + continue; + const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & { + if (OpTE->State == TreeEntry::NeedToGather) + return GathersToOrders.find(OpTE)->second; + return OpTE->ReorderIndices; + }(); + // Stores actually store the mask, not the order, need to invert. 
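The mask/order inversion applied to stores in the comment above can be worked through standalone. inversePermutation is assumed here to place position I at Mask[Order[I]]; the sketch then rebuilds an order from the mask the same way the code that follows does, with -1 standing in for UndefMaskElem (fixupOrderingIndices, which runs afterwards in the real code, is omitted).

#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> Order = {2, 0, 3, 1}; // a store "order" is really a mask
  unsigned E = Order.size();

  std::vector<int> Mask(E, -1);
  for (unsigned I = 0; I < E; ++I)            // inverse permutation
    Mask[Order[I]] = (int)I;

  std::vector<unsigned> CurrentOrder(E, E);
  for (unsigned I = 0; I < E; ++I)            // back to an order, undef -> E
    CurrentOrder[I] = Mask[I] == -1 ? E : (unsigned)Mask[I];

  for (unsigned V : CurrentOrder)
    std::printf("%u ", V);                    // prints: 1 3 0 2
  std::printf("\n");
  return 0;
}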
+ if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && + OpTE->getOpcode() == Instruction::Store && !Order.empty()) { + SmallVector<int> Mask; + inversePermutation(Order, Mask); + unsigned E = Order.size(); + OrdersType CurrentOrder(E, E); + transform(Mask, CurrentOrder.begin(), [E](int Idx) { + return Idx == UndefMaskElem ? E : static_cast<unsigned>(Idx); + }); + fixupOrderingIndices(CurrentOrder); + ++OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second; + } else { + ++OrdersUses.insert(std::make_pair(Order, 0)).first->second; + } + if (VisitedOps.insert(OpTE).second) + OrdersUses.insert(std::make_pair(OrdersType(), 0)).first->second += + OpTE->UserTreeIndices.size(); + assert(OrdersUses[{}] > 0 && "Counter cannot be less than 0."); + --OrdersUses[{}]; + } + // If no orders - skip current nodes and jump to the next one, if any. + if (OrdersUses.empty()) { + for_each(Data.second, + [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) { + OrderedEntries.remove(Op.second); + }); + continue; + } + // Choose the best order. + ArrayRef<unsigned> BestOrder = OrdersUses.front().first; + unsigned Cnt = OrdersUses.front().second; + for (const auto &Pair : drop_begin(OrdersUses)) { + if (Cnt < Pair.second || (Cnt == Pair.second && Pair.first.empty())) { + BestOrder = Pair.first; + Cnt = Pair.second; + } + } + // Set order of the user node (reordering of operands and user nodes). + if (BestOrder.empty()) { + for_each(Data.second, + [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) { + OrderedEntries.remove(Op.second); + }); + continue; + } + // Erase operands from OrderedEntries list and adjust their orders. + VisitedOps.clear(); + SmallVector<int> Mask; + inversePermutation(BestOrder, Mask); + SmallVector<int> MaskOrder(BestOrder.size(), UndefMaskElem); + unsigned E = BestOrder.size(); + transform(BestOrder, MaskOrder.begin(), [E](unsigned I) { + return I < E ? static_cast<int>(I) : UndefMaskElem; + }); + for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) { + TreeEntry *TE = Op.second; + OrderedEntries.remove(TE); + if (!VisitedOps.insert(TE).second) + continue; + if (!TE->ReuseShuffleIndices.empty() && TE->ReorderIndices.empty()) { + // Just reorder reuses indices. + reorderReuses(TE->ReuseShuffleIndices, Mask); + continue; + } + // Gathers are processed separately. + if (TE->State != TreeEntry::Vectorize) + continue; + assert((BestOrder.size() == TE->ReorderIndices.size() || + TE->ReorderIndices.empty()) && + "Non-matching sizes of user/operand entries."); + reorderOrder(TE->ReorderIndices, Mask); + } + // For gathers just need to reorder its scalars. + for (TreeEntry *Gather : GatherOps) { + assert(Gather->ReorderIndices.empty() && + "Unexpected reordering of gathers."); + if (!Gather->ReuseShuffleIndices.empty()) { + // Just reorder reuses indices. + reorderReuses(Gather->ReuseShuffleIndices, Mask); + continue; + } + reorderScalars(Gather->Scalars, Mask); + OrderedEntries.remove(Gather); + } + // Reorder operands of the user node and set the ordering for the user + // node itself. 
+ if (Data.first->State != TreeEntry::Vectorize || + !isa<ExtractElementInst, ExtractValueInst, LoadInst>( + Data.first->getMainOp()) || + Data.first->isAltShuffle()) + Data.first->reorderOperands(Mask); + if (!isa<InsertElementInst, StoreInst>(Data.first->getMainOp()) || + Data.first->isAltShuffle()) { + reorderScalars(Data.first->Scalars, Mask); + reorderOrder(Data.first->ReorderIndices, MaskOrder); + if (Data.first->ReuseShuffleIndices.empty() && + !Data.first->ReorderIndices.empty() && + !Data.first->isAltShuffle()) { + // Insert user node to the list to try to sink reordering deeper in + // the graph. + OrderedEntries.insert(Data.first); + } + } else { + reorderOrder(Data.first->ReorderIndices, Mask); + } + } + } + // If the reordering is unnecessary, just remove the reorder. + if (IgnoreReorder && !VectorizableTree.front()->ReorderIndices.empty() && + VectorizableTree.front()->ReuseShuffleIndices.empty()) + VectorizableTree.front()->ReorderIndices.clear(); +} + +void BoUpSLP::buildExternalUses( + const ExtraValueToDebugLocsMap &ExternallyUsedValues) { // Collect the values that we need to extract from the tree. for (auto &TEPtr : VectorizableTree) { TreeEntry *Entry = TEPtr.get(); @@ -2636,6 +3215,9 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, if (!UserInst) continue; + if (isDeleted(UserInst)) + continue; + // Skip in-tree scalars that become vectors if (TreeEntry *UseEntry = getTreeEntry(U)) { Value *UseScalar = UseEntry->Scalars[0]; @@ -2664,14 +3246,120 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, } } +void BoUpSLP::buildTree(ArrayRef<Value *> Roots, + ArrayRef<Value *> UserIgnoreLst) { + deleteTree(); + UserIgnoreList = UserIgnoreLst; + if (!allSameType(Roots)) + return; + buildTree_rec(Roots, 0, EdgeInfo()); +} + +namespace { +/// Tracks the state we can represent the loads in the given sequence. +enum class LoadsState { Gather, Vectorize, ScatterVectorize }; +} // anonymous namespace + +/// Checks if the given array of loads can be represented as a vectorized, +/// scatter or just simple gather. +static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, + const TargetTransformInfo &TTI, + const DataLayout &DL, ScalarEvolution &SE, + SmallVectorImpl<unsigned> &Order, + SmallVectorImpl<Value *> &PointerOps) { + // Check that a vectorized load would load the same memory as a scalar + // load. For example, we don't want to vectorize loads that are smaller + // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM + // treats loading/storing it as an i8 struct. If we vectorize loads/stores + // from such a struct, we read/write packed bits disagreeing with the + // unvectorized version. + Type *ScalarTy = VL0->getType(); + + if (DL.getTypeSizeInBits(ScalarTy) != DL.getTypeAllocSizeInBits(ScalarTy)) + return LoadsState::Gather; + + // Make sure all loads in the bundle are simple - we can't vectorize + // atomic or volatile loads. + PointerOps.clear(); + PointerOps.resize(VL.size()); + auto *POIter = PointerOps.begin(); + for (Value *V : VL) { + auto *L = cast<LoadInst>(V); + if (!L->isSimple()) + return LoadsState::Gather; + *POIter = L->getPointerOperand(); + ++POIter; + } + + Order.clear(); + // Check the order of pointer operands. 
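A standalone approximation of the classification canVectorizeLoads performs (the pointer-order check continues just below): treat each pointer as an element offset from a common base, sort the offsets, and call the bundle Vectorize when they are consecutive, ScatterVectorize when a masked gather is legal, and Gather otherwise. The classify function, the GatherLegal flag and the sample offsets are illustrative stand-ins for sortPtrAccesses, isLegalMaskedGather and getPointersDiff.

#include <algorithm>
#include <cstdio>
#include <vector>

enum class LoadsState { Gather, Vectorize, ScatterVectorize };

static LoadsState classify(std::vector<long> Offsets, bool GatherLegal) {
  std::sort(Offsets.begin(), Offsets.end());
  long Diff = Offsets.back() - Offsets.front();
  if ((unsigned long)Diff == Offsets.size() - 1)  // consecutive elements
    return LoadsState::Vectorize;
  return GatherLegal ? LoadsState::ScatterVectorize : LoadsState::Gather;
}

int main() {
  std::printf("%d\n", (int)classify({3, 1, 0, 2}, false));  // Vectorize
  std::printf("%d\n", (int)classify({0, 4, 8, 12}, true));  // ScatterVectorize
  std::printf("%d\n", (int)classify({0, 4, 8, 12}, false)); // Gather
  return 0;
}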
+ if (llvm::sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order)) { + Value *Ptr0; + Value *PtrN; + if (Order.empty()) { + Ptr0 = PointerOps.front(); + PtrN = PointerOps.back(); + } else { + Ptr0 = PointerOps[Order.front()]; + PtrN = PointerOps[Order.back()]; + } + Optional<int> Diff = + getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE); + // Check that the sorted loads are consecutive. + if (static_cast<unsigned>(*Diff) == VL.size() - 1) + return LoadsState::Vectorize; + Align CommonAlignment = cast<LoadInst>(VL0)->getAlign(); + for (Value *V : VL) + CommonAlignment = + commonAlignment(CommonAlignment, cast<LoadInst>(V)->getAlign()); + if (TTI.isLegalMaskedGather(FixedVectorType::get(ScalarTy, VL.size()), + CommonAlignment)) + return LoadsState::ScatterVectorize; + } + + return LoadsState::Gather; +} + void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, const EdgeInfo &UserTreeIdx) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); + SmallVector<int> ReuseShuffleIndicies; + SmallVector<Value *> UniqueValues; + auto &&TryToFindDuplicates = [&VL, &ReuseShuffleIndicies, &UniqueValues, + &UserTreeIdx, + this](const InstructionsState &S) { + // Check that every instruction appears once in this bundle. + DenseMap<Value *, unsigned> UniquePositions; + for (Value *V : VL) { + auto Res = UniquePositions.try_emplace(V, UniqueValues.size()); + ReuseShuffleIndicies.emplace_back(isa<UndefValue>(V) ? -1 + : Res.first->second); + if (Res.second) + UniqueValues.emplace_back(V); + } + size_t NumUniqueScalarValues = UniqueValues.size(); + if (NumUniqueScalarValues == VL.size()) { + ReuseShuffleIndicies.clear(); + } else { + LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n"); + if (NumUniqueScalarValues <= 1 || + !llvm::isPowerOf2_32(NumUniqueScalarValues)) { + LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + return false; + } + VL = UniqueValues; + } + return true; + }; + InstructionsState S = getSameOpcode(VL); if (Depth == RecursionMaxDepth) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } @@ -2680,7 +3368,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, isa<ScalableVectorType>( cast<ExtractElementInst>(S.OpValue)->getVectorOperandType())) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to scalable vector type.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } @@ -2700,9 +3390,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } // If all of the operands are identical or constant we have a simple solution. - if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.getOpcode()) { + // If we deal with insert/extract instructions, they all must have constant + // indices, otherwise we should gather them, not try to vectorize. + if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.getOpcode() || + (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>(S.MainOp) && + !all_of(VL, isVectorLikeInstWithConstOps))) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. 
\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } @@ -2724,7 +3420,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.OpValue << ".\n"); if (!E->isSame(VL)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } // Record the reuse of the tree node. FIXME, currently this is only used to @@ -2743,7 +3441,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (getTreeEntry(I)) { LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V << ") is already in tree.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } } @@ -2754,7 +3454,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *V : VL) { if (MustGather.count(V) || is_contained(UserIgnoreList, V)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } } @@ -2773,28 +3475,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } // Check that every instruction appears once in this bundle. - SmallVector<unsigned, 4> ReuseShuffleIndicies; - SmallVector<Value *, 4> UniqueValues; - DenseMap<Value *, unsigned> UniquePositions; - for (Value *V : VL) { - auto Res = UniquePositions.try_emplace(V, UniqueValues.size()); - ReuseShuffleIndicies.emplace_back(Res.first->second); - if (Res.second) - UniqueValues.emplace_back(V); - } - size_t NumUniqueScalarValues = UniqueValues.size(); - if (NumUniqueScalarValues == VL.size()) { - ReuseShuffleIndicies.clear(); - } else { - LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n"); - if (NumUniqueScalarValues <= 1 || - !llvm::isPowerOf2_32(NumUniqueScalarValues)) { - LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); - return; - } - VL = UniqueValues; - } + if (!TryToFindDuplicates(S)) + return; auto &BSRef = BlocksSchedules[BB]; if (!BSRef) @@ -2867,7 +3549,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, bool Reuse = canReuseExtract(VL, VL0, CurrentOrder); if (Reuse) { LLVM_DEBUG(dbgs() << "SLP: Reusing or shuffling extract sequence.\n"); - ++NumOpsWantToKeepOriginalOrder; newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); // This is a special case, as it does not gather, but at the same time @@ -2885,12 +3566,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, dbgs() << " " << Idx; dbgs() << "\n"; }); + fixupOrderingIndices(CurrentOrder); // Insert new order with initial value 0, if it does not exist, // otherwise return the iterator to the existing one. newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies, CurrentOrder); - findRootOrder(CurrentOrder); - ++NumOpsWantToKeepOrder[CurrentOrder]; // This is a special case, as it does not gather, but at the same time // we are not extending buildTree_rec() towards the operands. 
ValueList Op0; @@ -2910,8 +3590,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Check that we have a buildvector and not a shuffle of 2 or more // different vectors. ValueSet SourceVectors; - for (Value *V : VL) + int MinIdx = std::numeric_limits<int>::max(); + for (Value *V : VL) { SourceVectors.insert(cast<Instruction>(V)->getOperand(0)); + Optional<int> Idx = *getInsertIndex(V, 0); + if (!Idx || *Idx == UndefMaskElem) + continue; + MinIdx = std::min(MinIdx, *Idx); + } if (count_if(VL, [&SourceVectors](Value *V) { return !SourceVectors.contains(V); @@ -2919,13 +3605,35 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Found 2nd source vector - cancel. LLVM_DEBUG(dbgs() << "SLP: Gather of insertelement vectors with " "different source vectors.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); BS.cancelScheduling(VL, VL0); return; } - TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx); + auto OrdCompare = [](const std::pair<int, int> &P1, + const std::pair<int, int> &P2) { + return P1.first > P2.first; + }; + PriorityQueue<std::pair<int, int>, SmallVector<std::pair<int, int>>, + decltype(OrdCompare)> + Indices(OrdCompare); + for (int I = 0, E = VL.size(); I < E; ++I) { + Optional<int> Idx = *getInsertIndex(VL[I], 0); + if (!Idx || *Idx == UndefMaskElem) + continue; + Indices.emplace(*Idx, I); + } + OrdersType CurrentOrder(VL.size(), VL.size()); + bool IsIdentity = true; + for (int I = 0, E = VL.size(); I < E; ++I) { + CurrentOrder[Indices.top().second] = I; + IsIdentity &= Indices.top().second == I; + Indices.pop(); + } + if (IsIdentity) + CurrentOrder.clear(); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + None, CurrentOrder); LLVM_DEBUG(dbgs() << "SLP: added inserts bundle.\n"); constexpr int NumOps = 2; @@ -2936,7 +3644,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TE->setOperand(I, VectorOperands[I]); } - buildTree_rec(VectorOperands[NumOps - 1], Depth + 1, {TE, 0}); + buildTree_rec(VectorOperands[NumOps - 1], Depth + 1, {TE, NumOps - 1}); return; } case Instruction::Load: { @@ -2946,90 +3654,52 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // treats loading/storing it as an i8 struct. If we vectorize loads/stores // from such a struct, we read/write packed bits disagreeing with the // unvectorized version. - Type *ScalarTy = VL0->getType(); - - if (DL->getTypeSizeInBits(ScalarTy) != - DL->getTypeAllocSizeInBits(ScalarTy)) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); - return; - } - - // Make sure all loads in the bundle are simple - we can't vectorize - // atomic or volatile loads. - SmallVector<Value *, 4> PointerOps(VL.size()); - auto POIter = PointerOps.begin(); - for (Value *V : VL) { - auto *L = cast<LoadInst>(V); - if (!L->isSimple()) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); - return; - } - *POIter = L->getPointerOperand(); - ++POIter; - } - + SmallVector<Value *> PointerOps; OrdersType CurrentOrder; - // Check the order of pointer operands. 
- if (llvm::sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, CurrentOrder)) { - Value *Ptr0; - Value *PtrN; + TreeEntry *TE = nullptr; + switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, CurrentOrder, + PointerOps)) { + case LoadsState::Vectorize: if (CurrentOrder.empty()) { - Ptr0 = PointerOps.front(); - PtrN = PointerOps.back(); + // Original loads are consecutive and does not require reordering. + TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); } else { - Ptr0 = PointerOps[CurrentOrder.front()]; - PtrN = PointerOps[CurrentOrder.back()]; - } - Optional<int> Diff = getPointersDiff( - ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE); - // Check that the sorted loads are consecutive. - if (static_cast<unsigned>(*Diff) == VL.size() - 1) { - if (CurrentOrder.empty()) { - // Original loads are consecutive and does not require reordering. - ++NumOpsWantToKeepOriginalOrder; - TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, - UserTreeIdx, ReuseShuffleIndicies); - TE->setOperandsInOrder(); - LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); - } else { - // Need to reorder. - TreeEntry *TE = - newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies, CurrentOrder); - TE->setOperandsInOrder(); - LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n"); - findRootOrder(CurrentOrder); - ++NumOpsWantToKeepOrder[CurrentOrder]; - } - return; - } - Align CommonAlignment = cast<LoadInst>(VL0)->getAlign(); - for (Value *V : VL) - CommonAlignment = - commonAlignment(CommonAlignment, cast<LoadInst>(V)->getAlign()); - if (TTI->isLegalMaskedGather(FixedVectorType::get(ScalarTy, VL.size()), - CommonAlignment)) { - // Vectorizing non-consecutive loads with `llvm.masked.gather`. - TreeEntry *TE = newTreeEntry(VL, TreeEntry::ScatterVectorize, Bundle, - S, UserTreeIdx, ReuseShuffleIndicies); - TE->setOperandsInOrder(); - buildTree_rec(PointerOps, Depth + 1, {TE, 0}); - LLVM_DEBUG(dbgs() - << "SLP: added a vector of non-consecutive loads.\n"); - return; + fixupOrderingIndices(CurrentOrder); + // Need to reorder. + TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies, CurrentOrder); + LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n"); } + TE->setOperandsInOrder(); + break; + case LoadsState::ScatterVectorize: + // Vectorizing non-consecutive loads with `llvm.masked.gather`. 
+ TE = newTreeEntry(VL, TreeEntry::ScatterVectorize, Bundle, S, + UserTreeIdx, ReuseShuffleIndicies); + TE->setOperandsInOrder(); + buildTree_rec(PointerOps, Depth + 1, {TE, 0}); + LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n"); + break; + case LoadsState::Gather: + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); +#ifndef NDEBUG + Type *ScalarTy = VL0->getType(); + if (DL->getTypeSizeInBits(ScalarTy) != + DL->getTypeAllocSizeInBits(ScalarTy)) + LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); + else if (any_of(VL, [](Value *V) { + return !cast<LoadInst>(V)->isSimple(); + })) + LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); + else + LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); +#endif // NDEBUG + break; } - - LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); return; } case Instruction::ZExt: @@ -3213,15 +3883,40 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); - TE->setOperandsInOrder(); - for (unsigned i = 0, e = 2; i < e; ++i) { - ValueList Operands; - // Prepare the operand vector. - for (Value *V : VL) - Operands.push_back(cast<Instruction>(V)->getOperand(i)); - - buildTree_rec(Operands, Depth + 1, {TE, i}); + SmallVector<ValueList, 2> Operands(2); + // Prepare the operand vector for pointer operands. + for (Value *V : VL) + Operands.front().push_back( + cast<GetElementPtrInst>(V)->getPointerOperand()); + TE->setOperand(0, Operands.front()); + // Need to cast all indices to the same type before vectorization to + // avoid crash. + // Required to be able to find correct matches between different gather + // nodes and reuse the vectorized values rather than trying to gather them + // again. + int IndexIdx = 1; + Type *VL0Ty = VL0->getOperand(IndexIdx)->getType(); + Type *Ty = all_of(VL, + [VL0Ty, IndexIdx](Value *V) { + return VL0Ty == cast<GetElementPtrInst>(V) + ->getOperand(IndexIdx) + ->getType(); + }) + ? VL0Ty + : DL->getIndexType(cast<GetElementPtrInst>(VL0) + ->getPointerOperandType() + ->getScalarType()); + // Prepare the operand vector. + for (Value *V : VL) { + auto *Op = cast<Instruction>(V)->getOperand(IndexIdx); + auto *CI = cast<ConstantInt>(Op); + Operands.back().push_back(ConstantExpr::getIntegerCast( + CI, Ty, CI->getValue().isSignBitSet())); } + TE->setOperand(IndexIdx, Operands.back()); + + for (unsigned I = 0, Ops = Operands.size(); I < Ops; ++I) + buildTree_rec(Operands[I], Depth + 1, {TE, I}); return; } case Instruction::Store: { @@ -3276,21 +3971,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (static_cast<unsigned>(*Dist) == VL.size() - 1) { if (CurrentOrder.empty()) { // Original stores are consecutive and does not require reordering. 
- ++NumOpsWantToKeepOriginalOrder; TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); TE->setOperandsInOrder(); buildTree_rec(Operands, Depth + 1, {TE, 0}); LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); } else { + fixupOrderingIndices(CurrentOrder); TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies, CurrentOrder); TE->setOperandsInOrder(); buildTree_rec(Operands, Depth + 1, {TE, 0}); LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled stores.\n"); - findRootOrder(CurrentOrder); - ++NumOpsWantToKeepOrder[CurrentOrder]; } return; } @@ -3321,7 +4014,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } Function *F = CI->getCalledFunction(); - unsigned NumArgs = CI->getNumArgOperands(); + unsigned NumArgs = CI->arg_size(); SmallVector<Value*, 4> ScalarArgs(NumArgs, nullptr); for (unsigned j = 0; j != NumArgs; ++j) if (hasVectorInstrinsicScalarOpd(ID, j)) @@ -3373,7 +4066,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); TE->setOperandsInOrder(); - for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { + for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) { + // For scalar operands no need to to create an entry since no need to + // vectorize it. + if (hasVectorInstrinsicScalarOpd(ID, i)) + continue; ValueList Operands; // Prepare the operand vector. for (Value *V : VL) { @@ -3548,7 +4245,7 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy, FastMathFlags FMF; if (auto *FPCI = dyn_cast<FPMathOperator>(CI)) FMF = FPCI->getFastMathFlags(); - SmallVector<const Value *> Arguments(CI->arg_begin(), CI->arg_end()); + SmallVector<const Value *> Arguments(CI->args()); IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, VecTys, FMF, dyn_cast<IntrinsicInst>(CI)); auto IntrinsicCost = @@ -3621,25 +4318,42 @@ computeExtractCost(ArrayRef<Value *> VL, FixedVectorType *VecTy, return Cost; } -/// Shuffles \p Mask in accordance with the given \p SubMask. -static void addMask(SmallVectorImpl<int> &Mask, ArrayRef<int> SubMask) { - if (SubMask.empty()) - return; - if (Mask.empty()) { - Mask.append(SubMask.begin(), SubMask.end()); - return; - } - SmallVector<int, 4> NewMask(SubMask.size(), SubMask.size()); - int TermValue = std::min(Mask.size(), SubMask.size()); - for (int I = 0, E = SubMask.size(); I < E; ++I) { - if (SubMask[I] >= TermValue || SubMask[I] == UndefMaskElem || - Mask[SubMask[I]] >= TermValue) { - NewMask[I] = UndefMaskElem; - continue; +/// Build shuffle mask for shuffle graph entries and lists of main and alternate +/// operations operands. 
+static void +buildSuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, + ArrayRef<int> ReusesIndices, + const function_ref<bool(Instruction *)> IsAltOp, + SmallVectorImpl<int> &Mask, + SmallVectorImpl<Value *> *OpScalars = nullptr, + SmallVectorImpl<Value *> *AltScalars = nullptr) { + unsigned Sz = VL.size(); + Mask.assign(Sz, UndefMaskElem); + SmallVector<int> OrderMask; + if (!ReorderIndices.empty()) + inversePermutation(ReorderIndices, OrderMask); + for (unsigned I = 0; I < Sz; ++I) { + unsigned Idx = I; + if (!ReorderIndices.empty()) + Idx = OrderMask[I]; + auto *OpInst = cast<Instruction>(VL[Idx]); + if (IsAltOp(OpInst)) { + Mask[I] = Sz + Idx; + if (AltScalars) + AltScalars->push_back(OpInst); + } else { + Mask[I] = Idx; + if (OpScalars) + OpScalars->push_back(OpInst); } - NewMask[I] = Mask[SubMask[I]]; } - Mask.swap(NewMask); + if (!ReusesIndices.empty()) { + SmallVector<int> NewMask(ReusesIndices.size(), UndefMaskElem); + transform(ReusesIndices, NewMask.begin(), [&Mask](int Idx) { + return Idx != UndefMaskElem ? Mask[Idx] : UndefMaskElem; + }); + Mask.swap(NewMask); + } } InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, @@ -3661,13 +4375,10 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, if (MinBWs.count(VL[0])) VecTy = FixedVectorType::get( IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size()); - auto *FinalVecTy = VecTy; + unsigned EntryVF = E->getVectorFactor(); + auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF); - unsigned ReuseShuffleNumbers = E->ReuseShuffleIndices.size(); bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); - if (NeedToShuffleReuses) - FinalVecTy = - FixedVectorType::get(VecTy->getElementType(), ReuseShuffleNumbers); // FIXME: it tries to fix a problem with MSVC buildbots. TargetTransformInfo &TTIRef = *TTI; auto &&AdjustExtractsCost = [this, &TTIRef, CostKind, VL, VecTy, @@ -3785,7 +4496,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // shuffle of a single/two vectors the scalars are extracted from. SmallVector<int> Mask; Optional<TargetTransformInfo::ShuffleKind> ShuffleKind = - isShuffle(VL, Mask); + isFixedVectorShuffle(VL, Mask); if (ShuffleKind.hasValue()) { // Found the bunch of extractelement instructions that must be gathered // into a vector and can be represented as a permutation elements in a @@ -3803,6 +4514,92 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, if (NeedToShuffleReuses) ReuseShuffleCost = TTI->getShuffleCost( TTI::SK_PermuteSingleSrc, FinalVecTy, E->ReuseShuffleIndices); + // Improve gather cost for gather of loads, if we can group some of the + // loads into vector loads. 
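The slicing strategy announced in the comment above (and implemented just below) can be sketched on its own: for a gathered bundle of loads, try the largest half-sized vector factor first, mark any slice that could itself be loaded as a vector, and stop at the first VF that yields something. CanVectorizeSlice is a stand-in predicate, not an LLVM API, and the already-vectorized bookkeeping of the real loop is simplified.

#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const unsigned Size = 8;
  // Pretend only the slice [4, 8) of the bundle is a consecutive load run.
  auto CanVectorizeSlice = [](unsigned Begin, unsigned VF) {
    return Begin == 4 && VF == 4;
  };

  std::vector<std::pair<unsigned, unsigned>> VectorizedSlices;
  unsigned StartIdx = 0;
  for (unsigned VF = Size / 2; VF >= 2; VF /= 2) {
    for (unsigned Cnt = StartIdx; Cnt + VF <= Size; Cnt += VF)
      if (CanVectorizeSlice(Cnt, VF)) {
        VectorizedSlices.emplace_back(Cnt, VF);
        if (Cnt == StartIdx)
          StartIdx += VF;      // the leading part is done, skip it next round
      }
    if (StartIdx >= Size || !VectorizedSlices.empty())
      break;                   // whole bundle covered, or something was found
  }

  for (auto &S : VectorizedSlices)
    std::printf("vectorize %u loads at offset %u\n", S.second, S.first);
  return 0; // prints: vectorize 4 loads at offset 4
}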
+ if (VL.size() > 2 && E->getOpcode() == Instruction::Load && + !E->isAltShuffle()) { + BoUpSLP::ValueSet VectorizedLoads; + unsigned StartIdx = 0; + unsigned VF = VL.size() / 2; + unsigned VectorizedCnt = 0; + unsigned ScatterVectorizeCnt = 0; + const unsigned Sz = DL->getTypeSizeInBits(E->getMainOp()->getType()); + for (unsigned MinVF = getMinVF(2 * Sz); VF >= MinVF; VF /= 2) { + for (unsigned Cnt = StartIdx, End = VL.size(); Cnt + VF <= End; + Cnt += VF) { + ArrayRef<Value *> Slice = VL.slice(Cnt, VF); + if (!VectorizedLoads.count(Slice.front()) && + !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) { + SmallVector<Value *> PointerOps; + OrdersType CurrentOrder; + LoadsState LS = canVectorizeLoads(Slice, Slice.front(), *TTI, *DL, + *SE, CurrentOrder, PointerOps); + switch (LS) { + case LoadsState::Vectorize: + case LoadsState::ScatterVectorize: + // Mark the vectorized loads so that we don't vectorize them + // again. + if (LS == LoadsState::Vectorize) + ++VectorizedCnt; + else + ++ScatterVectorizeCnt; + VectorizedLoads.insert(Slice.begin(), Slice.end()); + // If we vectorized initial block, no need to try to vectorize it + // again. + if (Cnt == StartIdx) + StartIdx += VF; + break; + case LoadsState::Gather: + break; + } + } + } + // Check if the whole array was vectorized already - exit. + if (StartIdx >= VL.size()) + break; + // Found vectorizable parts - exit. + if (!VectorizedLoads.empty()) + break; + } + if (!VectorizedLoads.empty()) { + InstructionCost GatherCost = 0; + unsigned NumParts = TTI->getNumberOfParts(VecTy); + bool NeedInsertSubvectorAnalysis = + !NumParts || (VL.size() / VF) > NumParts; + // Get the cost for gathered loads. + for (unsigned I = 0, End = VL.size(); I < End; I += VF) { + if (VectorizedLoads.contains(VL[I])) + continue; + GatherCost += getGatherCost(VL.slice(I, VF)); + } + // The cost for vectorized loads. + InstructionCost ScalarsCost = 0; + for (Value *V : VectorizedLoads) { + auto *LI = cast<LoadInst>(V); + ScalarsCost += TTI->getMemoryOpCost( + Instruction::Load, LI->getType(), LI->getAlign(), + LI->getPointerAddressSpace(), CostKind, LI); + } + auto *LI = cast<LoadInst>(E->getMainOp()); + auto *LoadTy = FixedVectorType::get(LI->getType(), VF); + Align Alignment = LI->getAlign(); + GatherCost += + VectorizedCnt * + TTI->getMemoryOpCost(Instruction::Load, LoadTy, Alignment, + LI->getPointerAddressSpace(), CostKind, LI); + GatherCost += ScatterVectorizeCnt * + TTI->getGatherScatterOpCost( + Instruction::Load, LoadTy, LI->getPointerOperand(), + /*VariableMask=*/false, Alignment, CostKind, LI); + if (NeedInsertSubvectorAnalysis) { + // Add the cost for the subvectors insert. 
+ for (int I = VF, E = VL.size(); I < E; I += VF) + GatherCost += TTI->getShuffleCost(TTI::SK_InsertSubvector, VecTy, + None, I, LoadTy); + } + return ReuseShuffleCost + GatherCost - ScalarsCost; + } + } return ReuseShuffleCost + getGatherCost(VL); } InstructionCost CommonCost = 0; @@ -3852,7 +4649,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, ++Idx; } } - Idx = ReuseShuffleNumbers; + Idx = EntryVF; for (Value *V : VL) { if (ShuffleOrOp == Instruction::ExtractElement) { auto *EE = cast<ExtractElementInst>(V); @@ -3895,29 +4692,33 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, return CommonCost; } case Instruction::InsertElement: { + assert(E->ReuseShuffleIndices.empty() && + "Unique insertelements only are expected."); auto *SrcVecTy = cast<FixedVectorType>(VL0->getType()); unsigned const NumElts = SrcVecTy->getNumElements(); unsigned const NumScalars = VL.size(); - APInt DemandedElts = APInt::getNullValue(NumElts); + APInt DemandedElts = APInt::getZero(NumElts); // TODO: Add support for Instruction::InsertValue. - unsigned Offset = UINT_MAX; + SmallVector<int> Mask; + if (!E->ReorderIndices.empty()) { + inversePermutation(E->ReorderIndices, Mask); + Mask.append(NumElts - NumScalars, UndefMaskElem); + } else { + Mask.assign(NumElts, UndefMaskElem); + std::iota(Mask.begin(), std::next(Mask.begin(), NumScalars), 0); + } + unsigned Offset = *getInsertIndex(VL0, 0); bool IsIdentity = true; - SmallVector<int> ShuffleMask(NumElts, UndefMaskElem); + SmallVector<int> PrevMask(NumElts, UndefMaskElem); + Mask.swap(PrevMask); for (unsigned I = 0; I < NumScalars; ++I) { - Optional<int> InsertIdx = getInsertIndex(VL[I], 0); + Optional<int> InsertIdx = getInsertIndex(VL[PrevMask[I]], 0); if (!InsertIdx || *InsertIdx == UndefMaskElem) continue; - unsigned Idx = *InsertIdx; - DemandedElts.setBit(Idx); - if (Idx < Offset) { - Offset = Idx; - IsIdentity &= I == 0; - } else { - assert(Idx >= Offset && "Failed to find vector index offset"); - IsIdentity &= Idx - Offset == I; - } - ShuffleMask[Idx] = I; + DemandedElts.setBit(*InsertIdx); + IsIdentity &= *InsertIdx - Offset == I; + Mask[*InsertIdx - Offset] = I; } assert(Offset < NumElts && "Failed to find vector index offset"); @@ -3932,8 +4733,23 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, TargetTransformInfo::SK_PermuteSingleSrc, FixedVectorType::get(SrcVecTy->getElementType(), Sz)); } else if (!IsIdentity) { - Cost += TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, SrcVecTy, - ShuffleMask); + auto *FirstInsert = + cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { + return !is_contained(E->Scalars, + cast<Instruction>(V)->getOperand(0)); + })); + if (isa<UndefValue>(FirstInsert->getOperand(0))) { + Cost += TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, SrcVecTy, Mask); + } else { + SmallVector<int> InsertMask(NumElts); + std::iota(InsertMask.begin(), InsertMask.end(), 0); + for (unsigned I = 0; I < NumElts; I++) { + if (Mask[I] != UndefMaskElem) + InsertMask[Offset + I] = NumElts + I; + } + Cost += + TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVecTy, InsertMask); + } } return Cost; @@ -3955,7 +4771,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, TTI::getCastContextHint(VL0), CostKind, VL0); if (NeedToShuffleReuses) { - CommonCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; } // Calculate the cost of this instruction. 
@@ -3980,7 +4796,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, Builder.getInt1Ty(), CmpInst::BAD_ICMP_PREDICATE, CostKind, VL0); if (NeedToShuffleReuses) { - CommonCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; } auto *MaskTy = FixedVectorType::get(Builder.getInt1Ty(), VL.size()); InstructionCost ScalarCost = VecTy->getNumElements() * ScalarEltCost; @@ -4085,7 +4901,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, TTI->getArithmeticInstrCost(E->getOpcode(), ScalarTy, CostKind, Op1VK, Op2VK, Op1VP, Op2VP, Operands, VL0); if (NeedToShuffleReuses) { - CommonCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; } InstructionCost ScalarCost = VecTy->getNumElements() * ScalarEltCost; InstructionCost VecCost = @@ -4103,7 +4919,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, InstructionCost ScalarEltCost = TTI->getArithmeticInstrCost( Instruction::Add, ScalarTy, CostKind, Op1VK, Op2VK); if (NeedToShuffleReuses) { - CommonCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; } InstructionCost ScalarCost = VecTy->getNumElements() * ScalarEltCost; InstructionCost VecCost = TTI->getArithmeticInstrCost( @@ -4117,7 +4933,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, InstructionCost ScalarEltCost = TTI->getMemoryOpCost( Instruction::Load, ScalarTy, Alignment, 0, CostKind, VL0); if (NeedToShuffleReuses) { - CommonCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; } InstructionCost ScalarLdCost = VecTy->getNumElements() * ScalarEltCost; InstructionCost VecLdCost; @@ -4160,7 +4976,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, InstructionCost ScalarEltCost = TTI->getIntrinsicInstrCost(CostAttrs, CostKind); if (NeedToShuffleReuses) { - CommonCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; } InstructionCost ScalarCallCost = VecTy->getNumElements() * ScalarEltCost; @@ -4215,14 +5031,16 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, TTI::CastContextHint::None, CostKind); } - SmallVector<int> Mask(E->Scalars.size()); - for (unsigned I = 0, End = E->Scalars.size(); I < End; ++I) { - auto *OpInst = cast<Instruction>(E->Scalars[I]); - assert(E->isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode"); - Mask[I] = I + (OpInst->getOpcode() == E->getAltOpcode() ? 
End : 0); - } - VecCost += - TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, 0); + SmallVector<int> Mask; + buildSuffleEntryMask( + E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, + [E](Instruction *I) { + assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); + return I->getOpcode() == E->getAltOpcode(); + }, + Mask); + CommonCost = + TTI->getShuffleCost(TargetTransformInfo::SK_Select, FinalVecTy, Mask); LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); return CommonCost + VecCost - ScalarCost; } @@ -4231,13 +5049,30 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, } } -bool BoUpSLP::isFullyVectorizableTinyTree() const { +bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const { LLVM_DEBUG(dbgs() << "SLP: Check whether the tree with height " << VectorizableTree.size() << " is fully vectorizable .\n"); + auto &&AreVectorizableGathers = [this](const TreeEntry *TE, unsigned Limit) { + SmallVector<int> Mask; + return TE->State == TreeEntry::NeedToGather && + !any_of(TE->Scalars, + [this](Value *V) { return EphValues.contains(V); }) && + (allConstant(TE->Scalars) || isSplat(TE->Scalars) || + TE->Scalars.size() < Limit || + (TE->getOpcode() == Instruction::ExtractElement && + isFixedVectorShuffle(TE->Scalars, Mask)) || + (TE->State == TreeEntry::NeedToGather && + TE->getOpcode() == Instruction::Load && !TE->isAltShuffle())); + }; + // We only handle trees of heights 1 and 2. if (VectorizableTree.size() == 1 && - VectorizableTree[0]->State == TreeEntry::Vectorize) + (VectorizableTree[0]->State == TreeEntry::Vectorize || + (ForReduction && + AreVectorizableGathers(VectorizableTree[0].get(), + VectorizableTree[0]->Scalars.size()) && + VectorizableTree[0]->getVectorFactor() > 2))) return true; if (VectorizableTree.size() != 2) @@ -4249,19 +5084,14 @@ bool BoUpSLP::isFullyVectorizableTinyTree() const { // or they are extractelements, which form shuffle. SmallVector<int> Mask; if (VectorizableTree[0]->State == TreeEntry::Vectorize && - (allConstant(VectorizableTree[1]->Scalars) || - isSplat(VectorizableTree[1]->Scalars) || - (VectorizableTree[1]->State == TreeEntry::NeedToGather && - VectorizableTree[1]->Scalars.size() < - VectorizableTree[0]->Scalars.size()) || - (VectorizableTree[1]->State == TreeEntry::NeedToGather && - VectorizableTree[1]->getOpcode() == Instruction::ExtractElement && - isShuffle(VectorizableTree[1]->Scalars, Mask)))) + AreVectorizableGathers(VectorizableTree[1].get(), + VectorizableTree[0]->Scalars.size())) return true; // Gathering cost would be too much for tiny trees. if (VectorizableTree[0]->State == TreeEntry::NeedToGather || - VectorizableTree[1]->State == TreeEntry::NeedToGather) + (VectorizableTree[1]->State == TreeEntry::NeedToGather && + VectorizableTree[0]->State != TreeEntry::ScatterVectorize)) return false; return true; @@ -4330,7 +5160,7 @@ bool BoUpSLP::isLoadCombineCandidate() const { return true; } -bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const { +bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const { // No need to vectorize inserts of gathered values. if (VectorizableTree.size() == 2 && isa<InsertElementInst>(VectorizableTree[0]->Scalars[0]) && @@ -4344,7 +5174,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const { // If we have a tiny tree (a tree whose size is less than MinTreeSize), we // can vectorize it if we can prove it fully vectorizable. 
- if (isFullyVectorizableTinyTree()) + if (isFullyVectorizableTinyTree(ForReduction)) return false; assert(VectorizableTree.empty() @@ -4496,7 +5326,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { // If found user is an insertelement, do not calculate extract cost but try // to detect it as a final shuffled/identity match. - if (EU.User && isa<InsertElementInst>(EU.User)) { + if (isa_and_nonnull<InsertElementInst>(EU.User)) { if (auto *FTy = dyn_cast<FixedVectorType>(EU.User->getType())) { Optional<int> InsertIdx = getInsertIndex(EU.User, 0); if (!InsertIdx || *InsertIdx == UndefMaskElem) @@ -4508,8 +5338,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { return false; auto *IE1 = cast<InsertElementInst>(VU); auto *IE2 = cast<InsertElementInst>(V); - // Go though of insertelement instructions trying to find either VU as - // the original vector for IE2 or V as the original vector for IE1. + // Go through of insertelement instructions trying to find either VU + // as the original vector for IE2 or V as the original vector for IE1. do { if (IE1 == VU || IE2 == V) return true; @@ -4525,7 +5355,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { VF.push_back(FTy->getNumElements()); ShuffleMask.emplace_back(VF.back(), UndefMaskElem); FirstUsers.push_back(EU.User); - DemandedElts.push_back(APInt::getNullValue(VF.back())); + DemandedElts.push_back(APInt::getZero(VF.back())); VecId = FirstUsers.size() - 1; } else { VecId = std::distance(FirstUsers.begin(), It); @@ -4705,18 +5535,11 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, } else { // Try to find nodes with the same vector factor. assert(UsedTEs.size() == 2 && "Expected at max 2 permuted entries."); - // FIXME: Shall be replaced by GetVF function once non-power-2 patch is - // landed. - auto &&GetVF = [](const TreeEntry *TE) { - if (!TE->ReuseShuffleIndices.empty()) - return TE->ReuseShuffleIndices.size(); - return TE->Scalars.size(); - }; DenseMap<int, const TreeEntry *> VFToTE; for (const TreeEntry *TE : UsedTEs.front()) - VFToTE.try_emplace(GetVF(TE), TE); + VFToTE.try_emplace(TE->getVectorFactor(), TE); for (const TreeEntry *TE : UsedTEs.back()) { - auto It = VFToTE.find(GetVF(TE)); + auto It = VFToTE.find(TE->getVectorFactor()); if (It != VFToTE.end()) { VF = It->first; Entries.push_back(It->second); @@ -4757,16 +5580,17 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, InstructionCost BoUpSLP::getGatherCost(FixedVectorType *Ty, - const DenseSet<unsigned> &ShuffledIndices) const { + const DenseSet<unsigned> &ShuffledIndices, + bool NeedToShuffle) const { unsigned NumElts = Ty->getNumElements(); - APInt DemandedElts = APInt::getNullValue(NumElts); + APInt DemandedElts = APInt::getZero(NumElts); for (unsigned I = 0; I < NumElts; ++I) if (!ShuffledIndices.count(I)) DemandedElts.setBit(I); InstructionCost Cost = TTI->getScalarizationOverhead(Ty, DemandedElts, /*Insert*/ true, /*Extract*/ false); - if (!ShuffledIndices.empty()) + if (NeedToShuffle) Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty); return Cost; } @@ -4777,6 +5601,7 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const { if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) ScalarTy = SI->getValueOperand()->getType(); auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); + bool DuplicateNonConst = false; // Find the cost of inserting/extracting values from the vector. 
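A standalone sketch of the gather-cost bookkeeping that follows: constants and repeated scalars skip the per-element insert, and the extra permute is only charged when a non-constant value really is duplicated. The struct layout and the unit costs are made-up numbers; only the counting mirrors the code.

#include <cstdio>
#include <set>
#include <vector>

int main() {
  struct Scalar { int Val; bool IsConstant; };
  std::vector<Scalar> VL = {{5, true}, {7, false}, {7, false}, {9, false}};

  std::set<int> Unique;
  unsigned DemandedInserts = 0;
  bool DuplicateNonConst = false;
  for (unsigned I = VL.size(); I > 0; --I) {  // reverse, as in the real code
    const Scalar &S = VL[I - 1];
    if (S.IsConstant)
      continue;                               // folded into the constant vector
    if (!Unique.insert(S.Val).second) {
      DuplicateNonConst = true;               // produced by the final shuffle
      continue;
    }
    ++DemandedInserts;
  }

  const unsigned InsertCost = 1, ShuffleCost = 1;   // assumed unit costs
  unsigned Cost = DemandedInserts * InsertCost +
                  (DuplicateNonConst ? ShuffleCost : 0);
  std::printf("gather cost = %u\n", Cost);          // 2 inserts + 1 shuffle = 3
  return 0;
}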
// Check if the same elements are inserted several times and count them as // shuffle candidates. @@ -4785,12 +5610,17 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const { // Iterate in reverse order to consider insert elements with the high cost. for (unsigned I = VL.size(); I > 0; --I) { unsigned Idx = I - 1; - if (isConstant(VL[Idx])) + // No need to shuffle duplicates for constants. + if (isConstant(VL[Idx])) { + ShuffledElements.insert(Idx); continue; - if (!UniqueElements.insert(VL[Idx]).second) + } + if (!UniqueElements.insert(VL[Idx]).second) { + DuplicateNonConst = true; ShuffledElements.insert(Idx); + } } - return getGatherCost(VecTy, ShuffledElements); + return getGatherCost(VecTy, ShuffledElements, DuplicateNonConst); } // Perform operand reordering on the instructions in VL and return the reordered @@ -5006,17 +5836,18 @@ Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { // block: // %phi = phi <2 x > { .., %entry} {%shuffle, %block} - // %2 = shuffle <2 x > %phi, %poison, <4 x > <0, 0, 1, 1> + // %2 = shuffle <2 x > %phi, poison, <4 x > <1, 1, 0, 0> // ... (use %2) - // %shuffle = shuffle <2 x> %2, poison, <2 x> {0, 2} + // %shuffle = shuffle <2 x> %2, poison, <2 x> {2, 0} // br %block - SmallVector<int> UniqueIdxs; + SmallVector<int> UniqueIdxs(VF, UndefMaskElem); SmallSet<int, 4> UsedIdxs; int Pos = 0; int Sz = VL.size(); for (int Idx : E->ReuseShuffleIndices) { - if (Idx != Sz && UsedIdxs.insert(Idx).second) - UniqueIdxs.emplace_back(Pos); + if (Idx != Sz && Idx != UndefMaskElem && + UsedIdxs.insert(Idx).second) + UniqueIdxs[Idx] = Pos; ++Pos; } assert(VF >= UsedIdxs.size() && "Expected vectorization factor " @@ -5047,11 +5878,9 @@ Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { }).base()); VF = std::max<unsigned>(VF, PowerOf2Ceil(NumValues)); int UniqueVals = 0; - bool HasUndefs = false; for (Value *V : VL.drop_back(VL.size() - VF)) { if (isa<UndefValue>(V)) { ReuseShuffleIndicies.emplace_back(UndefMaskElem); - HasUndefs = true; continue; } if (isConstant(V)) { @@ -5066,15 +5895,10 @@ Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { ++UniqueVals; } } - if (HasUndefs && UniqueVals == 1 && UniqueValues.size() == 1) { + if (UniqueVals == 1 && UniqueValues.size() == 1) { // Emit pure splat vector. - // FIXME: why it is not identified as an identity. 
- unsigned NumUndefs = count(ReuseShuffleIndicies, UndefMaskElem); - if (NumUndefs == ReuseShuffleIndicies.size() - 1) - ReuseShuffleIndicies.append(VF - ReuseShuffleIndicies.size(), - UndefMaskElem); - else - ReuseShuffleIndicies.assign(VF, 0); + ReuseShuffleIndicies.append(VF - ReuseShuffleIndicies.size(), + UndefMaskElem); } else if (UniqueValues.size() >= VF - 1 || UniqueValues.size() <= 1) { ReuseShuffleIndicies.clear(); UniqueValues.clear(); @@ -5107,12 +5931,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); - unsigned VF = E->Scalars.size(); - if (NeedToShuffleReuses) - VF = E->ReuseShuffleIndices.size(); + unsigned VF = E->getVectorFactor(); ShuffleInstructionBuilder ShuffleBuilder(Builder, VF); if (E->State == TreeEntry::NeedToGather) { - setInsertPointAfterBundle(E); + if (E->getMainOp()) + setInsertPointAfterBundle(E); Value *Vec; SmallVector<int> Mask; SmallVector<const TreeEntry *> Entries; @@ -5152,13 +5975,17 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size()); switch (ShuffleOrOp) { case Instruction::PHI: { + assert( + (E->ReorderIndices.empty() || E != VectorizableTree.front().get()) && + "PHI reordering is free."); auto *PH = cast<PHINode>(VL0); Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues()); Value *V = NewPhi; - if (NeedToShuffleReuses) - V = Builder.CreateShuffleVector(V, E->ReuseShuffleIndices, "shuffle"); + ShuffleBuilder.addInversedMask(E->ReorderIndices); + ShuffleBuilder.addMask(E->ReuseShuffleIndices); + V = ShuffleBuilder.finalize(V); E->VectorizedValue = V; @@ -5209,53 +6036,48 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return NewV; } case Instruction::InsertElement: { - Builder.SetInsertPoint(VL0); + assert(E->ReuseShuffleIndices.empty() && "All inserts should be unique"); + Builder.SetInsertPoint(cast<Instruction>(E->Scalars.back())); Value *V = vectorizeTree(E->getOperand(1)); + // Create InsertVector shuffle if necessary + auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { + return !is_contained(E->Scalars, cast<Instruction>(V)->getOperand(0)); + })); const unsigned NumElts = - cast<FixedVectorType>(VL0->getType())->getNumElements(); + cast<FixedVectorType>(FirstInsert->getType())->getNumElements(); const unsigned NumScalars = E->Scalars.size(); + unsigned Offset = *getInsertIndex(VL0, 0); + assert(Offset < NumElts && "Failed to find vector index offset"); + + // Create shuffle to resize vector + SmallVector<int> Mask; + if (!E->ReorderIndices.empty()) { + inversePermutation(E->ReorderIndices, Mask); + Mask.append(NumElts - NumScalars, UndefMaskElem); + } else { + Mask.assign(NumElts, UndefMaskElem); + std::iota(Mask.begin(), std::next(Mask.begin(), NumScalars), 0); + } // Create InsertVector shuffle if necessary - Instruction *FirstInsert = nullptr; bool IsIdentity = true; - unsigned Offset = UINT_MAX; + SmallVector<int> PrevMask(NumElts, UndefMaskElem); + Mask.swap(PrevMask); for (unsigned I = 0; I < NumScalars; ++I) { - Value *Scalar = E->Scalars[I]; - if (!FirstInsert && - !is_contained(E->Scalars, cast<Instruction>(Scalar)->getOperand(0))) - FirstInsert = cast<Instruction>(Scalar); + Value *Scalar = E->Scalars[PrevMask[I]]; Optional<int> InsertIdx = getInsertIndex(Scalar, 0); if (!InsertIdx || *InsertIdx == UndefMaskElem) continue; - unsigned Idx = *InsertIdx; - if 
(Idx < Offset) { - Offset = Idx; - IsIdentity &= I == 0; - } else { - assert(Idx >= Offset && "Failed to find vector index offset"); - IsIdentity &= Idx - Offset == I; - } - } - assert(Offset < NumElts && "Failed to find vector index offset"); - - // Create shuffle to resize vector - SmallVector<int> Mask(NumElts, UndefMaskElem); - if (!IsIdentity) { - for (unsigned I = 0; I < NumScalars; ++I) { - Value *Scalar = E->Scalars[I]; - Optional<int> InsertIdx = getInsertIndex(Scalar, 0); - if (!InsertIdx || *InsertIdx == UndefMaskElem) - continue; - Mask[*InsertIdx - Offset] = I; - } - } else { - std::iota(Mask.begin(), std::next(Mask.begin(), NumScalars), 0); + IsIdentity &= *InsertIdx - Offset == I; + Mask[*InsertIdx - Offset] = I; } if (!IsIdentity || NumElts != NumScalars) V = Builder.CreateShuffleVector(V, Mask); - if (NumElts != NumScalars) { + if ((!IsIdentity || Offset != 0 || + !isa<UndefValue>(FirstInsert->getOperand(0))) && + NumElts != NumScalars) { SmallVector<int> InsertMask(NumElts); std::iota(InsertMask.begin(), InsertMask.end(), 0); for (unsigned I = 0; I < NumElts; I++) { @@ -5295,6 +6117,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { auto *CI = cast<CastInst>(VL0); Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -5317,6 +6140,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); Value *V = Builder.CreateCmp(P0, L, R); propagateIRFlags(V, E->Scalars, VL0); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -5337,6 +6161,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *V = Builder.CreateSelect(Cond, True, False); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -5360,6 +6185,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -5403,6 +6229,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -5414,9 +6241,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Load: { // Loads are inserted at the head of the tree because we don't want to // sink them all the way down past store instructions. - bool IsReorder = E->updateStateIfReorder(); - if (IsReorder) - VL0 = E->getMainOp(); setInsertPointAfterBundle(E); LoadInst *LI = cast<LoadInst>(VL0); @@ -5457,9 +6281,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::Store: { - bool IsReorder = !E->ReorderIndices.empty(); - auto *SI = cast<StoreInst>( - IsReorder ? 
E->Scalars[E->ReorderIndices.front()] : VL0); + auto *SI = cast<StoreInst>(VL0); unsigned AS = SI->getPointerAddressSpace(); setInsertPointAfterBundle(E); @@ -5491,37 +6313,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::GetElementPtr: { + auto *GEP0 = cast<GetElementPtrInst>(VL0); setInsertPointAfterBundle(E); Value *Op0 = vectorizeTree(E->getOperand(0)); - std::vector<Value *> OpVecs; - for (int j = 1, e = cast<GetElementPtrInst>(VL0)->getNumOperands(); j < e; - ++j) { - ValueList &VL = E->getOperand(j); - // Need to cast all elements to the same type before vectorization to - // avoid crash. - Type *VL0Ty = VL0->getOperand(j)->getType(); - Type *Ty = llvm::all_of( - VL, [VL0Ty](Value *V) { return VL0Ty == V->getType(); }) - ? VL0Ty - : DL->getIndexType(cast<GetElementPtrInst>(VL0) - ->getPointerOperandType() - ->getScalarType()); - for (Value *&V : VL) { - auto *CI = cast<ConstantInt>(V); - V = ConstantExpr::getIntegerCast(CI, Ty, - CI->getValue().isSignBitSet()); - } - Value *OpVec = vectorizeTree(VL); + SmallVector<Value *> OpVecs; + for (int J = 1, N = GEP0->getNumOperands(); J < N; ++J) { + Value *OpVec = vectorizeTree(E->getOperand(J)); OpVecs.push_back(OpVec); } - Value *V = Builder.CreateGEP( - cast<GetElementPtrInst>(VL0)->getSourceElementType(), Op0, OpVecs); + Value *V = Builder.CreateGEP(GEP0->getSourceElementType(), Op0, OpVecs); if (Instruction *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -5548,7 +6355,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { std::vector<Value *> OpVecs; SmallVector<Type *, 2> TysForDecl = {FixedVectorType::get(CI->getType(), E->Scalars.size())}; - for (int j = 0, e = CI->getNumArgOperands(); j < e; ++j) { + for (int j = 0, e = CI->arg_size(); j < e; ++j) { ValueList OpVL; // Some intrinsics have scalar arguments. This argument should not be // vectorized. @@ -5594,6 +6401,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } propagateIRFlags(V, E->Scalars, VL0); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -5641,19 +6449,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // Also, gather up main and alt scalar ops to propagate IR flags to // each vector operation. 
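Several cases above now call ShuffleBuilder.addInversedMask(E->ReorderIndices), and the insertelement path calls inversePermutation directly, so a bundle that was vectorized in a reordered sequence is shuffled back to the original scalar order before the reuse indices are applied. The key identity is that a mask M with M[Order[i]] = i undoes the permutation described by Order. A small self-contained illustration, assuming the convention that lane i of the reordered vector holds original element Order[i]; the helper names below are placeholders, not LLVM APIs:

#include <cassert>
#include <cstdio>
#include <vector>

// Build the mask that undoes a permutation: if lane I of the vector holds
// original element Order[I], applying Mask as a shuffle restores the
// original element order.
static std::vector<int> invertOrder(const std::vector<unsigned> &Order) {
  std::vector<int> Mask(Order.size(), -1);
  for (unsigned I = 0; I < Order.size(); ++I)
    Mask[Order[I]] = int(I);
  return Mask;
}

int main() {
  // Scalars were vectorized in the order {2, 0, 3, 1}.
  std::vector<unsigned> Order = {2, 0, 3, 1};
  std::vector<int> Vec = {30, 10, 40, 20}; // lane I holds (Order[I] + 1) * 10
  std::vector<int> Mask = invertOrder(Order);
  std::vector<int> Restored(Vec.size());
  for (unsigned I = 0; I < Vec.size(); ++I)
    Restored[I] = Vec[Mask[I]]; // single-source shuffle: result[I] = Vec[Mask[I]]
  for (int V : Restored)
    std::printf("%d ", V);      // prints 10 20 30 40
  std::printf("\n");
  assert(Restored[0] == 10 && Restored[3] == 40);
  return 0;
}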
ValueList OpScalars, AltScalars; - unsigned Sz = E->Scalars.size(); - SmallVector<int> Mask(Sz); - for (unsigned I = 0; I < Sz; ++I) { - auto *OpInst = cast<Instruction>(E->Scalars[I]); - assert(E->isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode"); - if (OpInst->getOpcode() == E->getAltOpcode()) { - Mask[I] = Sz + I; - AltScalars.push_back(E->Scalars[I]); - } else { - Mask[I] = I; - OpScalars.push_back(E->Scalars[I]); - } - } + SmallVector<int> Mask; + buildSuffleEntryMask( + E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, + [E](Instruction *I) { + assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); + return I->getOpcode() == E->getAltOpcode(); + }, + Mask, &OpScalars, &AltScalars); propagateIRFlags(V0, OpScalars); propagateIRFlags(V1, AltScalars); @@ -5661,7 +6464,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *V = Builder.CreateShuffleVector(V0, V1, Mask); if (Instruction *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); E->VectorizedValue = V; @@ -5836,7 +6638,9 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { LLVM_DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n"); // It is legal to delete users in the ignorelist. - assert((getTreeEntry(U) || is_contained(UserIgnoreList, U)) && + assert((getTreeEntry(U) || is_contained(UserIgnoreList, U) || + (isa_and_nonnull<Instruction>(U) && + isDeleted(cast<Instruction>(U)))) && "Deleting out-of-tree value"); } } @@ -5911,27 +6715,28 @@ void BoUpSLP::optimizeGatherSequence() { "Worklist not sorted properly!"); BasicBlock *BB = (*I)->getBlock(); // For all instructions in blocks containing gather sequences: - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e;) { - Instruction *In = &*it++; - if (isDeleted(In)) + for (Instruction &In : llvm::make_early_inc_range(*BB)) { + if (isDeleted(&In)) continue; - if (!isa<InsertElementInst>(In) && !isa<ExtractElementInst>(In)) + if (!isa<InsertElementInst>(&In) && !isa<ExtractElementInst>(&In) && + !isa<ShuffleVectorInst>(&In)) continue; // Check if we can replace this instruction with any of the // visited instructions. + bool Replaced = false; for (Instruction *v : Visited) { - if (In->isIdenticalTo(v) && - DT->dominates(v->getParent(), In->getParent())) { - In->replaceAllUsesWith(v); - eraseInstruction(In); - In = nullptr; + if (In.isIdenticalTo(v) && + DT->dominates(v->getParent(), In.getParent())) { + In.replaceAllUsesWith(v); + eraseInstruction(&In); + Replaced = true; break; } } - if (In) { - assert(!is_contained(Visited, In)); - Visited.push_back(In); + if (!Replaced) { + assert(!is_contained(Visited, &In)); + Visited.push_back(&In); } } } @@ -5944,7 +6749,9 @@ void BoUpSLP::optimizeGatherSequence() { Optional<BoUpSLP::ScheduleData *> BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, const InstructionsState &S) { - if (isa<PHINode>(S.OpValue) || isa<InsertElementInst>(S.OpValue)) + // No need to schedule PHIs, insertelement, extractelement and extractvalue + // instructions. + if (isa<PHINode>(S.OpValue) || isVectorLikeInstWithConstOps(S.OpValue)) return nullptr; // Initialize the instruction bundle. 
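The removed loop above built the main/alternate blend mask by hand; the new buildSuffleEntryMask call also folds in the reorder and reuse indices and splits the scalars into OpScalars/AltScalars for flag propagation, but it relies on the same two-source shufflevector convention: mask values below the width pick lanes of the first vector (the main-opcode result) and values in [width, 2*width) pick lanes of the second (the alternate-opcode result). A minimal standalone sketch of just that lane selection:

#include <cstdio>
#include <vector>

// Two-source shuffle convention: entry I is I for the first source and
// Sz + I for the second source.
static std::vector<int> buildAltMask(const std::vector<bool> &IsAlt) {
  const unsigned Sz = unsigned(IsAlt.size());
  std::vector<int> Mask(Sz);
  for (unsigned I = 0; I < Sz; ++I)
    Mask[I] = IsAlt[I] ? int(Sz + I) : int(I);
  return Mask;
}

int main() {
  std::vector<int> MainVec = {5, 7, 9, 11};   // e.g. lane-wise a + b
  std::vector<int> AltVec = {-1, -3, -5, -7}; // e.g. lane-wise a - b
  // Lanes 1 and 3 use the alternate opcode.
  std::vector<int> Mask = buildAltMask({false, true, false, true});
  const unsigned Sz = unsigned(MainVec.size());
  for (unsigned I = 0; I < Sz; ++I) {
    int Lane = Mask[I] < int(Sz) ? MainVec[Mask[I]] : AltVec[Mask[I] - Sz];
    std::printf("%d ", Lane); // prints 5 -3 9 -7
  }
  std::printf("\n");
  return 0;
}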
@@ -6040,7 +6847,7 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, Value *OpValue) { - if (isa<PHINode>(OpValue) || isa<InsertElementInst>(OpValue)) + if (isa<PHINode>(OpValue) || isVectorLikeInstWithConstOps(OpValue)) return; ScheduleData *Bundle = getScheduleData(OpValue); @@ -6080,8 +6887,9 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, return true; Instruction *I = dyn_cast<Instruction>(V); assert(I && "bundle member must be an instruction"); - assert(!isa<PHINode>(I) && !isa<InsertElementInst>(I) && - "phi nodes/insertelements don't need to be scheduled"); + assert(!isa<PHINode>(I) && !isVectorLikeInstWithConstOps(I) && + "phi nodes/insertelements/extractelements/extractvalues don't need to " + "be scheduled"); auto &&CheckSheduleForI = [this, &S](Instruction *I) -> bool { ScheduleData *ISD = getScheduleData(I); if (!ISD) @@ -6351,7 +7159,7 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { for (auto *I = BS->ScheduleStart; I != BS->ScheduleEnd; I = I->getNextNode()) { BS->doForAllOpcodes(I, [this, &Idx, &NumToSchedule, BS](ScheduleData *SD) { - assert((isa<InsertElementInst>(SD->Inst) || + assert((isVectorLikeInstWithConstOps(SD->Inst) || SD->isPartOfBundle() == (getTreeEntry(SD->Inst) != nullptr)) && "scheduler and vectorizer bundle mismatch"); SD->FirstInBundle->SchedulingPriority = Idx++; @@ -6694,9 +7502,7 @@ struct SLPVectorizer : public FunctionPass { initializeSLPVectorizerPass(*PassRegistry::getPassRegistry()); } - bool doInitialization(Module &M) override { - return false; - } + bool doInitialization(Module &M) override { return false; } bool runOnFunction(Function &F) override { if (skipFunction(F)) @@ -6831,44 +7637,6 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, return Changed; } -/// Order may have elements assigned special value (size) which is out of -/// bounds. Such indices only appear on places which correspond to undef values -/// (see canReuseExtract for details) and used in order to avoid undef values -/// have effect on operands ordering. -/// The first loop below simply finds all unused indices and then the next loop -/// nest assigns these indices for undef values positions. -/// As an example below Order has two undef positions and they have assigned -/// values 3 and 7 respectively: -/// before: 6 9 5 4 9 2 1 0 -/// after: 6 3 5 4 7 2 1 0 -/// \returns Fixed ordering. 
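The helper removed just below, fixupOrderingIndices, rewrote out-of-range order entries (the "undef" placeholders equal to the order size) with whichever indices were still unused, exactly as in the 6 9 5 4 9 2 1 0 -> 6 3 5 4 7 2 1 0 example in its comment; its callers now go through the new reorderTopToBottom/reorderBottomToTop path instead. A compact standalone restatement of what the removed code did (not a line-for-line copy of it):

#include <cassert>
#include <vector>

// Replace out-of-bounds entries (== Order.size(), i.e. undef placeholders)
// with the indices that do not appear anywhere else in Order.
static std::vector<unsigned> fixupOrdering(std::vector<unsigned> Order) {
  const unsigned Sz = unsigned(Order.size());
  std::vector<bool> Used(Sz, false);
  std::vector<unsigned> UndefSlots;
  for (unsigned I = 0; I < Sz; ++I) {
    if (Order[I] < Sz)
      Used[Order[I]] = true;
    else
      UndefSlots.push_back(I);
  }
  unsigned NextFree = 0;
  for (unsigned Slot : UndefSlots) {
    while (Used[NextFree])
      ++NextFree;
    Order[Slot] = NextFree;
    Used[NextFree] = true;
  }
  return Order;
}

int main() {
  std::vector<unsigned> Fixed = fixupOrdering({6, 9, 5, 4, 9, 2, 1, 0});
  assert((Fixed == std::vector<unsigned>{6, 3, 5, 4, 7, 2, 1, 0}));
  return 0;
}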
-static BoUpSLP::OrdersType fixupOrderingIndices(ArrayRef<unsigned> Order) { - BoUpSLP::OrdersType NewOrder(Order.begin(), Order.end()); - const unsigned Sz = NewOrder.size(); - SmallBitVector UsedIndices(Sz); - SmallVector<int> MaskedIndices; - for (int I = 0, E = NewOrder.size(); I < E; ++I) { - if (NewOrder[I] < Sz) - UsedIndices.set(NewOrder[I]); - else - MaskedIndices.push_back(I); - } - if (MaskedIndices.empty()) - return NewOrder; - SmallVector<int> AvailableIndices(MaskedIndices.size()); - unsigned Cnt = 0; - int Idx = UsedIndices.find_first(); - do { - AvailableIndices[Cnt] = Idx; - Idx = UsedIndices.find_next(Idx); - ++Cnt; - } while (Idx > 0); - assert(Cnt == MaskedIndices.size() && "Non-synced masked/available indices."); - for (int I = 0, E = MaskedIndices.size(); I < E; ++I) - NewOrder[MaskedIndices[I]] = AvailableIndices[I]; - return NewOrder; -} - bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, unsigned Idx) { LLVM_DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << Chain.size() @@ -6884,19 +7652,13 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, << "\n"); R.buildTree(Chain); - Optional<ArrayRef<unsigned>> Order = R.bestOrder(); - // TODO: Handle orders of size less than number of elements in the vector. - if (Order && Order->size() == Chain.size()) { - // TODO: reorder tree nodes without tree rebuilding. - SmallVector<Value *, 4> ReorderedOps(Chain.size()); - transform(fixupOrderingIndices(*Order), ReorderedOps.begin(), - [Chain](const unsigned Idx) { return Chain[Idx]; }); - R.buildTree(ReorderedOps); - } if (R.isTreeTinyAndNotFullyVectorizable()) return false; if (R.isLoadCombineCandidate()) return false; + R.reorderTopToBottom(); + R.reorderBottomToTop(); + R.buildExternalUses(); R.computeMinimumValueSizes(); @@ -7019,7 +7781,7 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, unsigned EltSize = R.getVectorElementSize(Operands[0]); unsigned MaxElts = llvm::PowerOf2Floor(MaxVecRegSize / EltSize); - unsigned MinVF = std::max(2U, R.getMinVecRegSize() / EltSize); + unsigned MinVF = R.getMinVF(EltSize); unsigned MaxVF = std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts); @@ -7092,11 +7854,11 @@ bool SLPVectorizerPass::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) { if (!A || !B) return false; Value *VL[] = {A, B}; - return tryToVectorizeList(VL, R, /*AllowReorder=*/true); + return tryToVectorizeList(VL, R); } bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, - bool AllowReorder) { + bool LimitForRegisterSize) { if (VL.size() < 2) return false; @@ -7130,7 +7892,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, } unsigned Sz = R.getVectorElementSize(I0); - unsigned MinVF = std::max(2U, R.getMinVecRegSize() / Sz); + unsigned MinVF = R.getMinVF(Sz); unsigned MaxVF = std::max<unsigned>(PowerOf2Floor(VL.size()), MinVF); MaxVF = std::min(R.getMaximumVF(Sz, S.getOpcode()), MaxVF); if (MaxVF < 2) { @@ -7168,7 +7930,8 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, if (!isPowerOf2_32(OpsWidth)) continue; - if ((VF > MinVF && OpsWidth <= VF / 2) || (VF == MinVF && OpsWidth < 2)) + if ((LimitForRegisterSize && OpsWidth < MaxVF) || + (VF > MinVF && OpsWidth <= VF / 2) || (VF == MinVF && OpsWidth < 2)) break; ArrayRef<Value *> Ops = VL.slice(I, OpsWidth); @@ -7183,18 +7946,11 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, << "\n"); R.buildTree(Ops); - 
if (AllowReorder) { - Optional<ArrayRef<unsigned>> Order = R.bestOrder(); - if (Order) { - // TODO: reorder tree nodes without tree rebuilding. - SmallVector<Value *, 4> ReorderedOps(Ops.size()); - transform(fixupOrderingIndices(*Order), ReorderedOps.begin(), - [Ops](const unsigned Idx) { return Ops[Idx]; }); - R.buildTree(ReorderedOps); - } - } if (R.isTreeTinyAndNotFullyVectorizable()) continue; + R.reorderTopToBottom(); + R.reorderBottomToTop(); + R.buildExternalUses(); R.computeMinimumValueSizes(); InstructionCost Cost = R.getTreeCost(); @@ -7387,10 +8143,20 @@ class HorizontalReduction { Value *RHS, const Twine &Name, bool UseSelect) { unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); switch (Kind) { - case RecurKind::Add: - case RecurKind::Mul: case RecurKind::Or: + if (UseSelect && + LHS->getType() == CmpInst::makeCmpResultType(LHS->getType())) + return Builder.CreateSelect(LHS, Builder.getTrue(), RHS, Name); + return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, + Name); case RecurKind::And: + if (UseSelect && + LHS->getType() == CmpInst::makeCmpResultType(LHS->getType())) + return Builder.CreateSelect(LHS, RHS, Builder.getFalse(), Name); + return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, + Name); + case RecurKind::Add: + case RecurKind::Mul: case RecurKind::Xor: case RecurKind::FAdd: case RecurKind::FMul: @@ -7434,8 +8200,12 @@ class HorizontalReduction { static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, Value *RHS, const Twine &Name, const ReductionOpsListType &ReductionOps) { - bool UseSelect = ReductionOps.size() == 2; - assert((!UseSelect || isa<SelectInst>(ReductionOps[1][0])) && + bool UseSelect = ReductionOps.size() == 2 || + // Logical or/and. + (ReductionOps.size() == 1 && + isa<SelectInst>(ReductionOps.front().front())); + assert((!UseSelect || ReductionOps.size() != 2 || + isa<SelectInst>(ReductionOps[1][0])) && "Expected cmp + select pairs for reduction"); Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name, UseSelect); if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) { @@ -7573,10 +8343,10 @@ class HorizontalReduction { /// Checks if the instruction is in basic block \p BB. /// For a cmp+sel min/max reduction check that both ops are in \p BB. static bool hasSameParent(Instruction *I, BasicBlock *BB) { - if (isCmpSelMinMax(I)) { + if (isCmpSelMinMax(I) || (isBoolLogicOp(I) && isa<SelectInst>(I))) { auto *Sel = cast<SelectInst>(I); - auto *Cmp = cast<Instruction>(Sel->getCondition()); - return Sel->getParent() == BB && Cmp->getParent() == BB; + auto *Cmp = dyn_cast<Instruction>(Sel->getCondition()); + return Sel->getParent() == BB && Cmp && Cmp->getParent() == BB; } return I->getParent() == BB; } @@ -7758,13 +8528,13 @@ public: } /// Attempt to vectorize the tree found by matchAssociativeReduction. - bool tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI) { + Value *tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI) { // If there are a sufficient number of reduction values, reduce // to a nearby power-of-2. We can safely generate oversized // vectors and rely on the backend to split them to legal sizes. unsigned NumReducedVals = ReducedVals.size(); if (NumReducedVals < 4) - return false; + return nullptr; // Intersect the fast-math-flags from all reduction operations. 
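The createOp change above emits boolean or/and reductions in select form when the scalar source used selects: a | b becomes select(a, true, b) and a & b becomes select(a, b, false). The values are identical for ordinary booleans; the select form matters in IR because it does not propagate poison from the second operand once the first operand already decides the result, which is why the later comment insists on keeping selects for poison-safe reductions. A trivial standalone truth-table check of the value equivalence (plain bools, the poison aspect only exists at the IR level):

#include <cassert>

int main() {
  for (bool A : {false, true}) {
    for (bool B : {false, true}) {
      bool OrViaSelect  = A ? true : B;  // select(A, true, B)
      bool AndViaSelect = A ? B : false; // select(A, B, false)
      assert(OrViaSelect == (A || B));
      assert(AndViaSelect == (A && B));
    }
  }
  return 0;
}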
FastMathFlags RdxFMF; @@ -7838,22 +8608,14 @@ public: unsigned i = 0; while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) { ArrayRef<Value *> VL(&ReducedVals[i], ReduxWidth); - V.buildTree(VL, ExternallyUsedValues, IgnoreList); - Optional<ArrayRef<unsigned>> Order = V.bestOrder(); - if (Order) { - assert(Order->size() == VL.size() && - "Order size must be the same as number of vectorized " - "instructions."); - // TODO: reorder tree nodes without tree rebuilding. - SmallVector<Value *, 4> ReorderedOps(VL.size()); - transform(fixupOrderingIndices(*Order), ReorderedOps.begin(), - [VL](const unsigned Idx) { return VL[Idx]; }); - V.buildTree(ReorderedOps, ExternallyUsedValues, IgnoreList); - } - if (V.isTreeTinyAndNotFullyVectorizable()) + V.buildTree(VL, IgnoreList); + if (V.isTreeTinyAndNotFullyVectorizable(/*ForReduction=*/true)) break; if (V.isLoadCombineReductionCandidate(RdxKind)) break; + V.reorderTopToBottom(); + V.reorderBottomToTop(/*IgnoreReorder=*/true); + V.buildExternalUses(ExternallyUsedValues); // For a poison-safe boolean logic reduction, do not replace select // instructions with logic ops. All reduced values will be frozen (see @@ -7873,7 +8635,7 @@ public: InstructionCost Cost = TreeCost + ReductionCost; if (!Cost.isValid()) { LLVM_DEBUG(dbgs() << "Encountered invalid baseline cost.\n"); - return false; + return nullptr; } if (Cost >= -SLPCostThreshold) { V.getORE()->emit([&]() { @@ -7953,7 +8715,7 @@ public: // vector reductions. V.eraseInstructions(IgnoreList); } - return VectorizedTree != nullptr; + return VectorizedTree; } unsigned numReductionValues() const { return ReducedVals.size(); } @@ -7963,6 +8725,7 @@ private: InstructionCost getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal, unsigned ReduxWidth, FastMathFlags FMF) { + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Type *ScalarTy = FirstReducedVal->getType(); FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth); InstructionCost VectorCost, ScalarCost; @@ -7975,33 +8738,39 @@ private: case RecurKind::FAdd: case RecurKind::FMul: { unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind); - VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF); - ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy); + VectorCost = + TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind); + ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind); break; } case RecurKind::FMax: case RecurKind::FMin: { + auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, - /*unsigned=*/false); - ScalarCost = - TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy) + - TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, - CmpInst::makeCmpResultType(ScalarTy)); + /*unsigned=*/false, CostKind); + CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); + ScalarCost = TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy, + SclCondTy, RdxPred, CostKind) + + TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, + SclCondTy, RdxPred, CostKind); break; } case RecurKind::SMax: case RecurKind::SMin: case RecurKind::UMax: case RecurKind::UMin: { + auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); bool IsUnsigned = RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin; - VectorCost = 
TTI->getMinMaxReductionCost(VectorTy, VecCondTy, IsUnsigned); - ScalarCost = - TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy) + - TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, - CmpInst::makeCmpResultType(ScalarTy)); + VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, IsUnsigned, + CostKind); + CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); + ScalarCost = TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy, + SclCondTy, RdxPred, CostKind) + + TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, + SclCondTy, RdxPred, CostKind); break; } default: @@ -8023,6 +8792,7 @@ private: assert(isPowerOf2_32(ReduxWidth) && "We only handle power-of-two reductions for now"); + ++NumVectorInstructions; return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind, ReductionOps.back()); } @@ -8232,32 +9002,45 @@ static bool tryToVectorizeHorReductionOrInstOperands( // Skip the analysis of CmpInsts.Compiler implements postanalysis of the // CmpInsts so we can skip extra attempts in // tryToVectorizeHorReductionOrInstOperands and save compile time. - SmallVector<std::pair<Instruction *, unsigned>, 8> Stack(1, {Root, 0}); + std::queue<std::pair<Instruction *, unsigned>> Stack; + Stack.emplace(Root, 0); SmallPtrSet<Value *, 8> VisitedInstrs; + SmallVector<WeakTrackingVH> PostponedInsts; bool Res = false; + auto &&TryToReduce = [TTI, &P, &R](Instruction *Inst, Value *&B0, + Value *&B1) -> Value * { + bool IsBinop = matchRdxBop(Inst, B0, B1); + bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value())); + if (IsBinop || IsSelect) { + HorizontalReduction HorRdx; + if (HorRdx.matchAssociativeReduction(P, Inst)) + return HorRdx.tryToReduce(R, TTI); + } + return nullptr; + }; while (!Stack.empty()) { Instruction *Inst; unsigned Level; - std::tie(Inst, Level) = Stack.pop_back_val(); + std::tie(Inst, Level) = Stack.front(); + Stack.pop(); // Do not try to analyze instruction that has already been vectorized. // This may happen when we vectorize instruction operands on a previous // iteration while stack was populated before that happened. if (R.isDeleted(Inst)) continue; - Value *B0, *B1; - bool IsBinop = matchRdxBop(Inst, B0, B1); - bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value())); - if (IsBinop || IsSelect) { - HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, Inst)) { - if (HorRdx.tryToReduce(R, TTI)) { - Res = true; - // Set P to nullptr to avoid re-analysis of phi node in - // matchAssociativeReduction function unless this is the root node. - P = nullptr; - continue; - } + Value *B0 = nullptr, *B1 = nullptr; + if (Value *V = TryToReduce(Inst, B0, B1)) { + Res = true; + // Set P to nullptr to avoid re-analysis of phi node in + // matchAssociativeReduction function unless this is the root node. + P = nullptr; + if (auto *I = dyn_cast<Instruction>(V)) { + // Try to find another reduction. + Stack.emplace(I, Level); + continue; } + } else { + bool IsBinop = B0 && B1; if (P && IsBinop) { Inst = dyn_cast<Instruction>(B0); if (Inst == P) @@ -8269,14 +9052,14 @@ static bool tryToVectorizeHorReductionOrInstOperands( continue; } } - } - // Set P to nullptr to avoid re-analysis of phi node in - // matchAssociativeReduction function unless this is the root node. - P = nullptr; - // Do not try to vectorize CmpInst operands, this is done separately. 
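In the traversal above, the worklist changes from a LIFO SmallVector (pop_back_val) to a FIFO std::queue, reduction matching is factored into the TryToReduce lambda, and, as the following lines show, instructions that do not form a reduction are pushed onto PostponedInsts and only handed to the plain vectorizer after the whole scan, so reductions get the first chance. A schematic standalone version of that control flow; the items, the "reduces" test and the depth limit are made up for illustration:

#include <cstdio>
#include <queue>
#include <utility>
#include <vector>

int main() {
  std::queue<std::pair<int, unsigned>> Worklist; // (item, level), FIFO order
  Worklist.emplace(0, 0);
  std::vector<int> Postponed;

  auto TryToReduce = [](int Item) { return Item % 3 == 0; }; // stand-in test

  while (!Worklist.empty()) {
    auto [Item, Level] = Worklist.front();
    Worklist.pop();
    if (TryToReduce(Item)) {
      std::printf("reduced %d at level %u\n", Item, Level);
    } else {
      Postponed.push_back(Item); // binop-like candidate, retried later
    }
    // Enqueue "operands" up to a small depth limit, standing in for the
    // recursion bound of the original traversal.
    if (Level < 2)
      for (int Op : {2 * Item + 1, 2 * Item + 2})
        Worklist.emplace(Op, Level + 1);
  }
  for (int Item : Postponed)
    std::printf("postponed vectorization attempt for %d\n", Item);
  return 0;
}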
- if (!isa<CmpInst>(Inst) && Vectorize(Inst, R)) { - Res = true; - continue; + // Set P to nullptr to avoid re-analysis of phi node in + // matchAssociativeReduction function unless this is the root node. + P = nullptr; + // Do not try to vectorize CmpInst operands, this is done separately. + // Final attempt for binop args vectorization should happen after the loop + // to try to find reductions. + if (!isa<CmpInst>(Inst)) + PostponedInsts.push_back(Inst); } // Try to vectorize operands. @@ -8290,8 +9073,13 @@ static bool tryToVectorizeHorReductionOrInstOperands( // separately. if (!isa<PHINode>(I) && !isa<CmpInst>(I) && !R.isDeleted(I) && I->getParent() == BB) - Stack.emplace_back(I, Level); + Stack.emplace(I, Level); } + // Try to vectorized binops where reductions were not found. + for (Value *V : PostponedInsts) + if (auto *Inst = dyn_cast<Instruction>(V)) + if (!R.isDeleted(Inst)) + Res |= Vectorize(Inst, R); return Res; } @@ -8326,7 +9114,7 @@ bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, LLVM_DEBUG(dbgs() << "SLP: array mappable to vector: " << *IVI << "\n"); // Aggregate value is unlikely to be processed in vector register, we need to // extract scalars into scalar registers, so NeedExtraction is set true. - return tryToVectorizeList(BuildVectorOpds, R, /*AllowReorder=*/false); + return tryToVectorizeList(BuildVectorOpds, R); } bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI, @@ -8337,11 +9125,11 @@ bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI, if (!findBuildAggregate(IEI, TTI, BuildVectorOpds, BuildVectorInsts) || (llvm::all_of(BuildVectorOpds, [](Value *V) { return isa<ExtractElementInst>(V); }) && - isShuffle(BuildVectorOpds, Mask))) + isFixedVectorShuffle(BuildVectorOpds, Mask))) return false; LLVM_DEBUG(dbgs() << "SLP: array mappable to vector: " << *IEI << "\n"); - return tryToVectorizeList(BuildVectorInsts, R, /*AllowReorder=*/true); + return tryToVectorizeList(BuildVectorInsts, R); } bool SLPVectorizerPass::vectorizeSimpleInstructions( @@ -8382,6 +9170,78 @@ bool SLPVectorizerPass::vectorizeSimpleInstructions( return OpsChanged; } +template <typename T> +static bool +tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, + function_ref<unsigned(T *)> Limit, + function_ref<bool(T *, T *)> Comparator, + function_ref<bool(T *, T *)> AreCompatible, + function_ref<bool(ArrayRef<T *>, bool)> TryToVectorize, + bool LimitForRegisterSize) { + bool Changed = false; + // Sort by type, parent, operands. + stable_sort(Incoming, Comparator); + + // Try to vectorize elements base on their type. + SmallVector<T *> Candidates; + for (auto *IncIt = Incoming.begin(), *E = Incoming.end(); IncIt != E;) { + // Look for the next elements with the same type, parent and operand + // kinds. + auto *SameTypeIt = IncIt; + while (SameTypeIt != E && AreCompatible(*SameTypeIt, *IncIt)) + ++SameTypeIt; + + // Try to vectorize them. + unsigned NumElts = (SameTypeIt - IncIt); + LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize starting at nodes (" + << NumElts << ")\n"); + // The vectorization is a 3-state attempt: + // 1. Try to vectorize instructions with the same/alternate opcodes with the + // size of maximal register at first. + // 2. Try to vectorize remaining instructions with the same type, if + // possible. This may result in the better vectorization results rather than + // if we try just to vectorize instructions with the same/alternate opcodes. + // 3. 
Final attempt to try to vectorize all instructions with the + // same/alternate ops only, this may result in some extra final + // vectorization. + if (NumElts > 1 && + TryToVectorize(makeArrayRef(IncIt, NumElts), LimitForRegisterSize)) { + // Success start over because instructions might have been changed. + Changed = true; + } else if (NumElts < Limit(*IncIt) && + (Candidates.empty() || + Candidates.front()->getType() == (*IncIt)->getType())) { + Candidates.append(IncIt, std::next(IncIt, NumElts)); + } + // Final attempt to vectorize instructions with the same types. + if (Candidates.size() > 1 && + (SameTypeIt == E || (*SameTypeIt)->getType() != (*IncIt)->getType())) { + if (TryToVectorize(Candidates, /*LimitForRegisterSize=*/false)) { + // Success start over because instructions might have been changed. + Changed = true; + } else if (LimitForRegisterSize) { + // Try to vectorize using small vectors. + for (auto *It = Candidates.begin(), *End = Candidates.end(); + It != End;) { + auto *SameTypeIt = It; + while (SameTypeIt != End && AreCompatible(*SameTypeIt, *It)) + ++SameTypeIt; + unsigned NumElts = (SameTypeIt - It); + if (NumElts > 1 && TryToVectorize(makeArrayRef(It, NumElts), + /*LimitForRegisterSize=*/false)) + Changed = true; + It = SameTypeIt; + } + } + Candidates.clear(); + } + + // Start over at the next instruction of a different type (or the end). + IncIt = SameTypeIt; + } + return Changed; +} + bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { bool Changed = false; SmallVector<Value *, 4> Incoming; @@ -8390,11 +9250,89 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // node. Allows better to identify the chains that can be vectorized in the // better way. DenseMap<Value *, SmallVector<Value *, 4>> PHIToOpcodes; + auto PHICompare = [this, &PHIToOpcodes](Value *V1, Value *V2) { + assert(isValidElementType(V1->getType()) && + isValidElementType(V2->getType()) && + "Expected vectorizable types only."); + // It is fine to compare type IDs here, since we expect only vectorizable + // types, like ints, floats and pointers, we don't care about other type. + if (V1->getType()->getTypeID() < V2->getType()->getTypeID()) + return true; + if (V1->getType()->getTypeID() > V2->getType()->getTypeID()) + return false; + ArrayRef<Value *> Opcodes1 = PHIToOpcodes[V1]; + ArrayRef<Value *> Opcodes2 = PHIToOpcodes[V2]; + if (Opcodes1.size() < Opcodes2.size()) + return true; + if (Opcodes1.size() > Opcodes2.size()) + return false; + for (int I = 0, E = Opcodes1.size(); I < E; ++I) { + // Undefs are compatible with any other value. 
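The tryToVectorizeSequence template introduced above is the generic form of the old per-kind loops: stable-sort the seeds so compatible ones become adjacent, hand each maximal compatible run to the vectorization callback, and pool short runs of the same type for a later retry (optionally re-split when LimitForRegisterSize is set). The sketch below keeps only the first stage, with strings standing in for instructions and trivial comparator, compatibility and callback lambdas:

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Order the seeds so that compatible ones become adjacent, then try to
// vectorize each maximal compatible run.
template <typename T, typename CompareFn, typename CompatFn, typename TryFn>
static bool vectorizeRuns(std::vector<T> &Seeds, CompareFn Less,
                          CompatFn Compatible, TryFn TryToVectorize) {
  bool Changed = false;
  std::stable_sort(Seeds.begin(), Seeds.end(), Less);
  auto RunBegin = Seeds.begin();
  while (RunBegin != Seeds.end()) {
    auto RunEnd = RunBegin;
    while (RunEnd != Seeds.end() && Compatible(*RunEnd, *RunBegin))
      ++RunEnd;
    if (RunEnd - RunBegin > 1)
      Changed |= TryToVectorize(&*RunBegin, unsigned(RunEnd - RunBegin));
    RunBegin = RunEnd; // next run starts at the first incompatible element
  }
  return Changed;
}

int main() {
  std::vector<std::string> Seeds = {"i32.add", "f32.mul", "i32.add",
                                    "i32.sub", "f32.mul"};
  bool Changed = vectorizeRuns(
      Seeds,
      [](const std::string &A, const std::string &B) { return A < B; },
      [](const std::string &A, const std::string &B) { return A == B; },
      [](const std::string *Run, unsigned N) {
        std::printf("try %u x %s\n", N, Run->c_str());
        return N >= 2; // pretend runs of two or more vectorize
      });
  std::printf("changed: %d\n", Changed);
  return 0;
}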
+ if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) + continue; + if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I])) + if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { + DomTreeNodeBase<BasicBlock> *NodeI1 = DT->getNode(I1->getParent()); + DomTreeNodeBase<BasicBlock> *NodeI2 = DT->getNode(I2->getParent()); + if (!NodeI1) + return NodeI2 != nullptr; + if (!NodeI2) + return false; + assert((NodeI1 == NodeI2) == + (NodeI1->getDFSNumIn() == NodeI2->getDFSNumIn()) && + "Different nodes should have different DFS numbers"); + if (NodeI1 != NodeI2) + return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); + InstructionsState S = getSameOpcode({I1, I2}); + if (S.getOpcode()) + continue; + return I1->getOpcode() < I2->getOpcode(); + } + if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) + continue; + if (Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID()) + return true; + if (Opcodes1[I]->getValueID() > Opcodes2[I]->getValueID()) + return false; + } + return false; + }; + auto AreCompatiblePHIs = [&PHIToOpcodes](Value *V1, Value *V2) { + if (V1 == V2) + return true; + if (V1->getType() != V2->getType()) + return false; + ArrayRef<Value *> Opcodes1 = PHIToOpcodes[V1]; + ArrayRef<Value *> Opcodes2 = PHIToOpcodes[V2]; + if (Opcodes1.size() != Opcodes2.size()) + return false; + for (int I = 0, E = Opcodes1.size(); I < E; ++I) { + // Undefs are compatible with any other value. + if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) + continue; + if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I])) + if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { + if (I1->getParent() != I2->getParent()) + return false; + InstructionsState S = getSameOpcode({I1, I2}); + if (S.getOpcode()) + continue; + return false; + } + if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) + continue; + if (Opcodes1[I]->getValueID() != Opcodes2[I]->getValueID()) + return false; + } + return true; + }; + auto Limit = [&R](Value *V) { + unsigned EltSize = R.getVectorElementSize(V); + return std::max(2U, R.getMaxVecRegSize() / EltSize); + }; - bool HaveVectorizedPhiNodes = true; - while (HaveVectorizedPhiNodes) { - HaveVectorizedPhiNodes = false; - + bool HaveVectorizedPhiNodes = false; + do { // Collect the incoming values from the PHIs. Incoming.clear(); for (Instruction &I : *BB) { @@ -8432,132 +9370,15 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } } - // Sort by type, parent, operands. - stable_sort(Incoming, [this, &PHIToOpcodes](Value *V1, Value *V2) { - assert(isValidElementType(V1->getType()) && - isValidElementType(V2->getType()) && - "Expected vectorizable types only."); - // It is fine to compare type IDs here, since we expect only vectorizable - // types, like ints, floats and pointers, we don't care about other type. - if (V1->getType()->getTypeID() < V2->getType()->getTypeID()) - return true; - if (V1->getType()->getTypeID() > V2->getType()->getTypeID()) - return false; - ArrayRef<Value *> Opcodes1 = PHIToOpcodes[V1]; - ArrayRef<Value *> Opcodes2 = PHIToOpcodes[V2]; - if (Opcodes1.size() < Opcodes2.size()) - return true; - if (Opcodes1.size() > Opcodes2.size()) - return false; - for (int I = 0, E = Opcodes1.size(); I < E; ++I) { - // Undefs are compatible with any other value. 
- if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) - continue; - if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I])) - if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { - DomTreeNodeBase<BasicBlock> *NodeI1 = DT->getNode(I1->getParent()); - DomTreeNodeBase<BasicBlock> *NodeI2 = DT->getNode(I2->getParent()); - if (!NodeI1) - return NodeI2 != nullptr; - if (!NodeI2) - return false; - assert((NodeI1 == NodeI2) == - (NodeI1->getDFSNumIn() == NodeI2->getDFSNumIn()) && - "Different nodes should have different DFS numbers"); - if (NodeI1 != NodeI2) - return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); - InstructionsState S = getSameOpcode({I1, I2}); - if (S.getOpcode()) - continue; - return I1->getOpcode() < I2->getOpcode(); - } - if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) - continue; - if (Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID()) - return true; - if (Opcodes1[I]->getValueID() > Opcodes2[I]->getValueID()) - return false; - } - return false; - }); - - auto &&AreCompatiblePHIs = [&PHIToOpcodes](Value *V1, Value *V2) { - if (V1 == V2) - return true; - if (V1->getType() != V2->getType()) - return false; - ArrayRef<Value *> Opcodes1 = PHIToOpcodes[V1]; - ArrayRef<Value *> Opcodes2 = PHIToOpcodes[V2]; - if (Opcodes1.size() != Opcodes2.size()) - return false; - for (int I = 0, E = Opcodes1.size(); I < E; ++I) { - // Undefs are compatible with any other value. - if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) - continue; - if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I])) - if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { - if (I1->getParent() != I2->getParent()) - return false; - InstructionsState S = getSameOpcode({I1, I2}); - if (S.getOpcode()) - continue; - return false; - } - if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) - continue; - if (Opcodes1[I]->getValueID() != Opcodes2[I]->getValueID()) - return false; - } - return true; - }; - - // Try to vectorize elements base on their type. - SmallVector<Value *, 4> Candidates; - for (SmallVector<Value *, 4>::iterator IncIt = Incoming.begin(), - E = Incoming.end(); - IncIt != E;) { - - // Look for the next elements with the same type, parent and operand - // kinds. - SmallVector<Value *, 4>::iterator SameTypeIt = IncIt; - while (SameTypeIt != E && AreCompatiblePHIs(*SameTypeIt, *IncIt)) { - VisitedInstrs.insert(*SameTypeIt); - ++SameTypeIt; - } - - // Try to vectorize them. - unsigned NumElts = (SameTypeIt - IncIt); - LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize starting at PHIs (" - << NumElts << ")\n"); - // The order in which the phi nodes appear in the program does not matter. - // So allow tryToVectorizeList to reorder them if it is beneficial. This - // is done when there are exactly two elements since tryToVectorizeList - // asserts that there are only two values when AllowReorder is true. - if (NumElts > 1 && tryToVectorizeList(makeArrayRef(IncIt, NumElts), R, - /*AllowReorder=*/true)) { - // Success start over because instructions might have been changed. - HaveVectorizedPhiNodes = true; - Changed = true; - } else if (NumElts < 4 && - (Candidates.empty() || - Candidates.front()->getType() == (*IncIt)->getType())) { - Candidates.append(IncIt, std::next(IncIt, NumElts)); - } - // Final attempt to vectorize phis with the same types. 
- if (SameTypeIt == E || (*SameTypeIt)->getType() != (*IncIt)->getType()) { - if (Candidates.size() > 1 && - tryToVectorizeList(Candidates, R, /*AllowReorder=*/true)) { - // Success start over because instructions might have been changed. - HaveVectorizedPhiNodes = true; - Changed = true; - } - Candidates.clear(); - } - - // Start over at the next instruction of a different type (or the end). - IncIt = SameTypeIt; - } - } + HaveVectorizedPhiNodes = tryToVectorizeSequence<Value>( + Incoming, Limit, PHICompare, AreCompatiblePHIs, + [this, &R](ArrayRef<Value *> Candidates, bool LimitForRegisterSize) { + return tryToVectorizeList(Candidates, R, LimitForRegisterSize); + }, + /*LimitForRegisterSize=*/true); + Changed |= HaveVectorizedPhiNodes; + VisitedInstrs.insert(Incoming.begin(), Incoming.end()); + } while (HaveVectorizedPhiNodes); VisitedInstrs.clear(); @@ -8810,6 +9631,10 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { return V1->getValueOperand()->getValueID() == V2->getValueOperand()->getValueID(); }; + auto Limit = [&R, this](StoreInst *SI) { + unsigned EltSize = DL->getTypeSizeInBits(SI->getValueOperand()->getType()); + return R.getMinVF(EltSize); + }; // Attempt to sort and vectorize each of the store-groups. for (auto &Pair : Stores) { @@ -8819,33 +9644,15 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { LLVM_DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << Pair.second.size() << ".\n"); - stable_sort(Pair.second, StoreSorter); - - // Try to vectorize elements based on their compatibility. - for (ArrayRef<StoreInst *>::iterator IncIt = Pair.second.begin(), - E = Pair.second.end(); - IncIt != E;) { - - // Look for the next elements with the same type. - ArrayRef<StoreInst *>::iterator SameTypeIt = IncIt; - Type *EltTy = (*IncIt)->getPointerOperand()->getType(); - - while (SameTypeIt != E && AreCompatibleStores(*SameTypeIt, *IncIt)) - ++SameTypeIt; - - // Try to vectorize them. - unsigned NumElts = (SameTypeIt - IncIt); - LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize starting at stores (" - << NumElts << ")\n"); - if (NumElts > 1 && !EltTy->getPointerElementType()->isVectorTy() && - vectorizeStores(makeArrayRef(IncIt, NumElts), R)) { - // Success start over because instructions might have been changed. - Changed = true; - } + if (!isValidElementType(Pair.second.front()->getValueOperand()->getType())) + continue; - // Start over at the next instruction of a different type (or the end). - IncIt = SameTypeIt; - } + Changed |= tryToVectorizeSequence<StoreInst>( + Pair.second, Limit, StoreSorter, AreCompatibleStores, + [this, &R](ArrayRef<StoreInst *> Candidates, bool) { + return vectorizeStores(Candidates, R); + }, + /*LimitForRegisterSize=*/false); } return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp index 5f39fe1c17a3..638467f94e1c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -815,6 +815,28 @@ void VPlan::execute(VPTransformState *State) { for (VPBlockBase *Block : depth_first(Entry)) Block->execute(State); + // Fix the latch value of reduction and first-order recurrences phis in the + // vector loop. 
+ VPBasicBlock *Header = Entry->getEntryBasicBlock(); + for (VPRecipeBase &R : Header->phis()) { + auto *PhiR = dyn_cast<VPWidenPHIRecipe>(&R); + if (!PhiR || !(isa<VPFirstOrderRecurrencePHIRecipe>(&R) || + isa<VPReductionPHIRecipe>(&R))) + continue; + // For first-order recurrences and in-order reduction phis, only a single + // part is generated, which provides the last part from the previous + // iteration. Otherwise all UF parts are generated. + bool SinglePartNeeded = isa<VPFirstOrderRecurrencePHIRecipe>(&R) || + cast<VPReductionPHIRecipe>(&R)->isOrdered(); + unsigned LastPartForNewPhi = SinglePartNeeded ? 1 : State->UF; + for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { + Value *VecPhi = State->get(PhiR, Part); + Value *Val = State->get(PhiR->getBackedgeValue(), + SinglePartNeeded ? State->UF - 1 : Part); + cast<PHINode>(VecPhi)->addIncoming(Val, VectorLatchBB); + } + } + // Setup branch terminator successors for VPBBs in VPBBsToFix based on // VPBB's successors. for (auto VPBB : State->CFG.VPBBsToFix) { @@ -862,6 +884,13 @@ void VPlan::print(raw_ostream &O) const { VPSlotTracker SlotTracker(this); O << "VPlan '" << Name << "' {"; + + if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) { + O << "\nLive-in "; + BackedgeTakenCount->printAsOperand(O, SlotTracker); + O << " = backedge-taken count\n"; + } + for (const VPBlockBase *Block : depth_first(getEntry())) { O << '\n'; Block->print(O, "", SlotTracker); @@ -920,12 +949,12 @@ void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopPreHeaderBB, } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -const Twine VPlanPrinter::getUID(const VPBlockBase *Block) { +Twine VPlanPrinter::getUID(const VPBlockBase *Block) { return (isa<VPRegionBlock>(Block) ? "cluster_N" : "N") + Twine(getOrCreateBID(Block)); } -const Twine VPlanPrinter::getOrCreateName(const VPBlockBase *Block) { +Twine VPlanPrinter::getOrCreateName(const VPBlockBase *Block) { const std::string &Name = Block->getName(); if (!Name.empty()) return Name; @@ -1235,7 +1264,7 @@ void VPWidenCanonicalIVRecipe::execute(VPTransformState &State) { VF.isScalar() ? Indices.back() : ConstantVector::get(Indices); // Add the consecutive indices to the vector value. Value *CanonicalVectorIV = Builder.CreateAdd(VStart, VStep, "vec.iv"); - State.set(getVPSingleValue(), CanonicalVectorIV, Part); + State.set(this, CanonicalVectorIV, Part); } } @@ -1243,7 +1272,7 @@ void VPWidenCanonicalIVRecipe::execute(VPTransformState &State) { void VPWidenCanonicalIVRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "EMIT "; - getVPSingleValue()->printAsOperand(O, SlotTracker); + printAsOperand(O, SlotTracker); O << " = WIDEN-CANONICAL-INDUCTION"; } #endif @@ -1306,12 +1335,16 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) { PHINode::Create(VecTy, 2, "vec.phi", &*HeaderBB->getFirstInsertionPt()); State.set(this, EntryPart, Part); } + + // Reductions do not have to start at zero. They can start with + // any loop invariant values. VPValue *StartVPV = getStartValue(); Value *StartV = StartVPV->getLiveInIRValue(); Value *Iden = nullptr; RecurKind RK = RdxDesc.getRecurrenceKind(); - if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK)) { + if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) || + RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) { // MinMax reduction have the start value as their identify. 
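The VPlan::execute hunk above wires the latch (backedge) value into each widened header phi after all blocks have executed. For first-order recurrence phis and in-order (ordered) reduction phis only a single part exists, so that one phi receives the last unrolled part of the backedge value; every other phi gets one incoming value per unrolled part. The part arithmetic in isolation, with a made-up unroll factor:

#include <cstdio>

int main() {
  const unsigned UF = 4; // unroll factor of the vector loop
  for (bool SinglePartNeeded : {false, true}) {
    unsigned LastPartForNewPhi = SinglePartNeeded ? 1 : UF;
    for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) {
      unsigned BackedgePart = SinglePartNeeded ? UF - 1 : Part;
      std::printf("phi part %u <- backedge part %u%s\n", Part, BackedgePart,
                  SinglePartNeeded ? " (single-part phi)" : "");
    }
  }
  return 0;
}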
if (ScalarPHI) { Iden = StartV; @@ -1322,12 +1355,11 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) { Builder.CreateVectorSplat(State.VF, StartV, "minmax.ident"); } } else { - Constant *IdenC = RecurrenceDescriptor::getRecurrenceIdentity( - RK, VecTy->getScalarType(), RdxDesc.getFastMathFlags()); - Iden = IdenC; + Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(), + RdxDesc.getFastMathFlags()); if (!ScalarPHI) { - Iden = ConstantVector::getSplat(State.VF, IdenC); + Iden = Builder.CreateVectorSplat(State.VF, Iden); IRBuilderBase::InsertPointGuard IPBuilder(Builder); Builder.SetInsertPoint(State.CFG.VectorPreHeader->getTerminator()); Constant *Zero = Builder.getInt32(0); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h index bdf09d15c27f..00ee31007cb7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h @@ -1312,7 +1312,7 @@ public: // The first operand is the address, followed by the stored values, followed // by an optional mask. return ArrayRef<VPValue *>(op_begin(), getNumOperands()) - .slice(1, getNumOperands() - (HasMask ? 2 : 1)); + .slice(1, getNumStoreOperands()); } /// Generate the wide load or store, and shuffles. @@ -1325,6 +1325,12 @@ public: #endif const InterleaveGroup<Instruction> *getInterleaveGroup() { return IG; } + + /// Returns the number of stored operands of this interleave group. Returns 0 + /// for load interleave groups. + unsigned getNumStoreOperands() const { + return getNumOperands() - (HasMask ? 2 : 1); + } }; /// A recipe to represent inloop reduction operations, performing a reduction on @@ -1508,6 +1514,12 @@ public: class VPWidenMemoryInstructionRecipe : public VPRecipeBase { Instruction &Ingredient; + // Whether the loaded-from / stored-to addresses are consecutive. + bool Consecutive; + + // Whether the consecutive loaded/stored addresses are in reverse order. + bool Reverse; + void setMask(VPValue *Mask) { if (!Mask) return; @@ -1519,16 +1531,21 @@ class VPWidenMemoryInstructionRecipe : public VPRecipeBase { } public: - VPWidenMemoryInstructionRecipe(LoadInst &Load, VPValue *Addr, VPValue *Mask) - : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr}), Ingredient(Load) { + VPWidenMemoryInstructionRecipe(LoadInst &Load, VPValue *Addr, VPValue *Mask, + bool Consecutive, bool Reverse) + : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr}), Ingredient(Load), + Consecutive(Consecutive), Reverse(Reverse) { + assert((Consecutive || !Reverse) && "Reverse implies consecutive"); new VPValue(VPValue::VPVMemoryInstructionSC, &Load, this); setMask(Mask); } VPWidenMemoryInstructionRecipe(StoreInst &Store, VPValue *Addr, - VPValue *StoredValue, VPValue *Mask) + VPValue *StoredValue, VPValue *Mask, + bool Consecutive, bool Reverse) : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr, StoredValue}), - Ingredient(Store) { + Ingredient(Store), Consecutive(Consecutive), Reverse(Reverse) { + assert((Consecutive || !Reverse) && "Reverse implies consecutive"); setMask(Mask); } @@ -1558,6 +1575,13 @@ public: return getOperand(1); // Stored value is the 2nd, mandatory operand. } + // Return whether the loaded-from / stored-to addresses are consecutive. + bool isConsecutive() const { return Consecutive; } + + // Return whether the consecutive loaded/stored addresses are in reverse + // order. + bool isReverse() const { return Reverse; } + /// Generate the wide load/store. 
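The interleave-group recipe above documents its operand layout as the address, then the stored values, then an optional mask, which is where getNumStoreOperands() = getNumOperands() - (HasMask ? 2 : 1) comes from and why it is 0 for load groups. The same hunk also threads Consecutive/Reverse flags through the widened memory recipe, asserting that Reverse is only meaningful together with Consecutive. A tiny mock of the operand-count arithmetic, with strings standing in for VPValues:

#include <cassert>
#include <string>
#include <vector>

// Operand layout: [ address, stored value 0, ..., stored value N-1, mask? ]
struct InterleaveRecipeSketch {
  std::vector<std::string> Operands;
  bool HasMask;

  unsigned getNumStoreOperands() const {
    return unsigned(Operands.size()) - (HasMask ? 2u : 1u);
  }
};

int main() {
  InterleaveRecipeSketch MaskedStore{{"addr", "val0", "val1", "mask"}, true};
  InterleaveRecipeSketch Load{{"addr"}, false};
  assert(MaskedStore.getNumStoreOperands() == 2);
  assert(Load.getNumStoreOperands() == 0); // load groups store nothing
  return 0;
}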
void execute(VPTransformState &State) override; @@ -1569,11 +1593,11 @@ public: }; /// A Recipe for widening the canonical induction variable of the vector loop. -class VPWidenCanonicalIVRecipe : public VPRecipeBase { +class VPWidenCanonicalIVRecipe : public VPRecipeBase, public VPValue { public: - VPWidenCanonicalIVRecipe() : VPRecipeBase(VPWidenCanonicalIVSC, {}) { - new VPValue(nullptr, this); - } + VPWidenCanonicalIVRecipe() + : VPRecipeBase(VPWidenCanonicalIVSC, {}), + VPValue(VPValue::VPVWidenCanonicalIVSC, nullptr, this) {} ~VPWidenCanonicalIVRecipe() override = default; @@ -2094,6 +2118,10 @@ class VPlan { /// Holds the VPLoopInfo analysis for this VPlan. VPLoopInfo VPLInfo; + /// Indicates whether it is safe use the Value2VPValue mapping or if the + /// mapping cannot be used any longer, because it is stale. + bool Value2VPValueEnabled = true; + public: VPlan(VPBlockBase *Entry = nullptr) : Entry(Entry) { if (Entry) @@ -2135,6 +2163,10 @@ public: return BackedgeTakenCount; } + /// Mark the plan to indicate that using Value2VPValue is not safe any + /// longer, because it may be stale. + void disableValue2VPValue() { Value2VPValueEnabled = false; } + void addVF(ElementCount VF) { VFs.insert(VF); } bool hasVF(ElementCount VF) { return VFs.count(VF); } @@ -2148,6 +2180,8 @@ public: void addExternalDef(VPValue *VPVal) { VPExternalDefs.insert(VPVal); } void addVPValue(Value *V) { + assert(Value2VPValueEnabled && + "IR value to VPValue mapping may be out of date!"); assert(V && "Trying to add a null Value to VPlan"); assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); VPValue *VPV = new VPValue(V); @@ -2156,25 +2190,39 @@ public: } void addVPValue(Value *V, VPValue *VPV) { + assert(Value2VPValueEnabled && "Value2VPValue mapping may be out of date!"); assert(V && "Trying to add a null Value to VPlan"); assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); Value2VPValue[V] = VPV; } - VPValue *getVPValue(Value *V) { + /// Returns the VPValue for \p V. \p OverrideAllowed can be used to disable + /// checking whether it is safe to query VPValues using IR Values. + VPValue *getVPValue(Value *V, bool OverrideAllowed = false) { + assert((OverrideAllowed || isa<Constant>(V) || Value2VPValueEnabled) && + "Value2VPValue mapping may be out of date!"); assert(V && "Trying to get the VPValue of a null Value"); assert(Value2VPValue.count(V) && "Value does not exist in VPlan"); return Value2VPValue[V]; } - VPValue *getOrAddVPValue(Value *V) { + /// Gets the VPValue or adds a new one (if none exists yet) for \p V. \p + /// OverrideAllowed can be used to disable checking whether it is safe to + /// query VPValues using IR Values. + VPValue *getOrAddVPValue(Value *V, bool OverrideAllowed = false) { + assert((OverrideAllowed || isa<Constant>(V) || Value2VPValueEnabled) && + "Value2VPValue mapping may be out of date!"); assert(V && "Trying to get or add the VPValue of a null Value"); if (!Value2VPValue.count(V)) addVPValue(V); return getVPValue(V); } - void removeVPValueFor(Value *V) { Value2VPValue.erase(V); } + void removeVPValueFor(Value *V) { + assert(Value2VPValueEnabled && + "IR value to VPValue mapping may be out of date!"); + Value2VPValue.erase(V); + } /// Return the VPLoopInfo analysis for this VPlan. VPLoopInfo &getVPLoopInfo() { return VPLInfo; } @@ -2244,9 +2292,9 @@ class VPlanPrinter { return BlockID.count(Block) ? 
BlockID[Block] : BlockID[Block] = BID++; } - const Twine getOrCreateName(const VPBlockBase *Block); + Twine getOrCreateName(const VPBlockBase *Block); - const Twine getUID(const VPBlockBase *Block); + Twine getUID(const VPBlockBase *Block); /// Print the information related to a CFG edge between two VPBlockBases. void drawEdge(const VPBlockBase *From, const VPBlockBase *To, bool Hidden, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index c05a8408e1fd..ded5bc04beb5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -31,19 +31,18 @@ void VPlanTransforms::VPInstructionsToVPRecipes( VPBasicBlock *VPBB = Base->getEntryBasicBlock(); // Introduce each ingredient into VPlan. - for (auto I = VPBB->begin(), E = VPBB->end(); I != E;) { - VPRecipeBase *Ingredient = &*I++; - VPValue *VPV = Ingredient->getVPSingleValue(); + for (VPRecipeBase &Ingredient : llvm::make_early_inc_range(*VPBB)) { + VPValue *VPV = Ingredient.getVPSingleValue(); Instruction *Inst = cast<Instruction>(VPV->getUnderlyingValue()); if (DeadInstructions.count(Inst)) { VPValue DummyValue; VPV->replaceAllUsesWith(&DummyValue); - Ingredient->eraseFromParent(); + Ingredient.eraseFromParent(); continue; } VPRecipeBase *NewRecipe = nullptr; - if (auto *VPPhi = dyn_cast<VPWidenPHIRecipe>(Ingredient)) { + if (auto *VPPhi = dyn_cast<VPWidenPHIRecipe>(&Ingredient)) { auto *Phi = cast<PHINode>(VPPhi->getUnderlyingValue()); InductionDescriptor II = Inductions.lookup(Phi); if (II.getKind() == InductionDescriptor::IK_IntInduction || @@ -55,25 +54,25 @@ void VPlanTransforms::VPInstructionsToVPRecipes( continue; } } else { - assert(isa<VPInstruction>(Ingredient) && + assert(isa<VPInstruction>(&Ingredient) && "only VPInstructions expected here"); assert(!isa<PHINode>(Inst) && "phis should be handled above"); // Create VPWidenMemoryInstructionRecipe for loads and stores. 
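The VPlan changes above guard the IR-value-to-VPValue map with a Value2VPValueEnabled flag: once recipes may have rewritten the IR, the map is declared stale via disableValue2VPValue(), ordinary lookups assert, and callers must either pass OverrideAllowed or query something immutable (constants are exempted in the real code). A generic sketch of that guarded-map pattern with STL types, not the VPlan class itself:

#include <cassert>
#include <map>
#include <string>

class GuardedMapSketch {
  std::map<std::string, int> Value2Id;
  bool MappingEnabled = true;

public:
  void disableMapping() { MappingEnabled = false; }

  void add(const std::string &V, int Id) {
    assert(MappingEnabled && "mapping may be out of date!");
    Value2Id[V] = Id;
  }

  int get(const std::string &V, bool OverrideAllowed = false) const {
    assert((OverrideAllowed || MappingEnabled) &&
           "mapping may be out of date!");
    return Value2Id.at(V);
  }
};

int main() {
  GuardedMapSketch Plan;
  Plan.add("x", 1);
  Plan.disableMapping();
  // Plan.get("x") would assert here; an explicit override is still allowed.
  return Plan.get("x", /*OverrideAllowed=*/true) == 1 ? 0 : 1;
}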
      if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) {
        NewRecipe = new VPWidenMemoryInstructionRecipe(
            *Load, Plan->getOrAddVPValue(getLoadStorePointerOperand(Inst)),
-            nullptr /*Mask*/);
+            nullptr /*Mask*/, false /*Consecutive*/, false /*Reverse*/);
      } else if (StoreInst *Store = dyn_cast<StoreInst>(Inst)) {
        NewRecipe = new VPWidenMemoryInstructionRecipe(
            *Store, Plan->getOrAddVPValue(getLoadStorePointerOperand(Inst)),
-            Plan->getOrAddVPValue(Store->getValueOperand()),
-            nullptr /*Mask*/);
+            Plan->getOrAddVPValue(Store->getValueOperand()), nullptr /*Mask*/,
+            false /*Consecutive*/, false /*Reverse*/);
      } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) {
        NewRecipe = new VPWidenGEPRecipe(
            GEP, Plan->mapToVPValues(GEP->operands()), OrigLoop);
      } else if (CallInst *CI = dyn_cast<CallInst>(Inst)) {
-        NewRecipe = new VPWidenCallRecipe(
-            *CI, Plan->mapToVPValues(CI->arg_operands()));
+        NewRecipe =
+            new VPWidenCallRecipe(*CI, Plan->mapToVPValues(CI->args()));
      } else if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) {
        bool InvariantCond =
            SE.isLoopInvariant(SE.getSCEV(SI->getOperand(0)), OrigLoop);
@@ -85,13 +84,13 @@ void VPlanTransforms::VPInstructionsToVPRecipes(
      }
    }
 
-    NewRecipe->insertBefore(Ingredient);
+    NewRecipe->insertBefore(&Ingredient);
    if (NewRecipe->getNumDefinedValues() == 1)
      VPV->replaceAllUsesWith(NewRecipe->getVPSingleValue());
    else
      assert(NewRecipe->getNumDefinedValues() == 0 &&
            "Only recpies with zero or one defined values expected");
-    Ingredient->eraseFromParent();
+    Ingredient.eraseFromParent();
    Plan->removeVPValueFor(Inst);
    for (auto *Def : NewRecipe->definedValues()) {
      Plan->addVPValue(Inst, Def);
@@ -106,44 +105,76 @@ bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) {
   bool Changed = false;
   // First, collect the operands of all predicated replicate recipes as seeds
   // for sinking.
-  SetVector<VPValue *> WorkList;
+  SetVector<std::pair<VPBasicBlock *, VPValue *>> WorkList;
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) {
     for (auto &Recipe : *VPBB) {
       auto *RepR = dyn_cast<VPReplicateRecipe>(&Recipe);
       if (!RepR || !RepR->isPredicated())
         continue;
-      WorkList.insert(RepR->op_begin(), RepR->op_end());
+      for (VPValue *Op : RepR->operands())
+        WorkList.insert(std::make_pair(RepR->getParent(), Op));
     }
   }
 
   // Try to sink each replicate recipe in the worklist.
   while (!WorkList.empty()) {
-    auto *C = WorkList.pop_back_val();
+    VPBasicBlock *SinkTo;
+    VPValue *C;
+    std::tie(SinkTo, C) = WorkList.pop_back_val();
     auto *SinkCandidate = dyn_cast_or_null<VPReplicateRecipe>(C->Def);
-    if (!SinkCandidate || SinkCandidate->isUniform())
-      continue;
-
-    // All users of SinkCandidate must be in the same block in order to perform
-    // sinking. Therefore the destination block for sinking must match the block
-    // containing the first user.
-    auto *FirstUser = dyn_cast<VPRecipeBase>(*SinkCandidate->user_begin());
-    if (!FirstUser)
-      continue;
-    VPBasicBlock *SinkTo = FirstUser->getParent();
-    if (SinkCandidate->getParent() == SinkTo ||
+    if (!SinkCandidate || SinkCandidate->isUniform() ||
+        SinkCandidate->getParent() == SinkTo ||
         SinkCandidate->mayHaveSideEffects() ||
         SinkCandidate->mayReadOrWriteMemory())
       continue;
 
-    // All recipe users of the sink candidate must be in the same block SinkTo.
-    if (any_of(SinkCandidate->users(), [SinkTo](VPUser *U) {
-          auto *UI = dyn_cast<VPRecipeBase>(U);
-          return !UI || UI->getParent() != SinkTo;
-        }))
+    bool NeedsDuplicating = false;
+    // All recipe users of the sink candidate must be in the same block SinkTo
+    // or all users outside of SinkTo must be uniform-after-vectorization (
+    // i.e., only first lane is used) . In the latter case, we need to duplicate
+    // SinkCandidate. At the moment, we identify such UAV's by looking for the
+    // address operands of widened memory recipes.
+    auto CanSinkWithUser = [SinkTo, &NeedsDuplicating,
+                            SinkCandidate](VPUser *U) {
+      auto *UI = dyn_cast<VPRecipeBase>(U);
+      if (!UI)
+        return false;
+      if (UI->getParent() == SinkTo)
+        return true;
+      auto *WidenI = dyn_cast<VPWidenMemoryInstructionRecipe>(UI);
+      if (WidenI && WidenI->getAddr() == SinkCandidate) {
+        NeedsDuplicating = true;
+        return true;
+      }
+      return false;
+    };
+    if (!all_of(SinkCandidate->users(), CanSinkWithUser))
       continue;
+    if (NeedsDuplicating) {
+      Instruction *I = cast<Instruction>(SinkCandidate->getUnderlyingValue());
+      auto *Clone =
+          new VPReplicateRecipe(I, SinkCandidate->operands(), true, false);
+      // TODO: add ".cloned" suffix to name of Clone's VPValue.
+
+      Clone->insertBefore(SinkCandidate);
+      SmallVector<VPUser *, 4> Users(SinkCandidate->user_begin(),
+                                     SinkCandidate->user_end());
+      for (auto *U : Users) {
+        auto *UI = cast<VPRecipeBase>(U);
+        if (UI->getParent() == SinkTo)
+          continue;
+
+        for (unsigned Idx = 0; Idx != UI->getNumOperands(); Idx++) {
+          if (UI->getOperand(Idx) != SinkCandidate)
+            continue;
+          UI->setOperand(Idx, Clone);
+        }
+      }
+    }
     SinkCandidate->moveBefore(*SinkTo, SinkTo->getFirstNonPhi());
-    WorkList.insert(SinkCandidate->op_begin(), SinkCandidate->op_end());
+    for (VPValue *Op : SinkCandidate->operands())
+      WorkList.insert(std::make_pair(SinkTo, Op));
     Changed = true;
   }
   return Changed;
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
index 6eec8d14de4a..6d6ea4eb30f1 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
@@ -128,3 +128,33 @@ void VPlanVerifier::verifyHierarchicalCFG(
   assert(!TopRegion->getParent() && "VPlan Top Region should have no parent.");
   verifyRegionRec(TopRegion);
 }
+
+bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) {
+  auto Iter = depth_first(
+      VPBlockRecursiveTraversalWrapper<const VPBlockBase *>(Plan.getEntry()));
+  for (const VPBasicBlock *VPBB :
+       VPBlockUtils::blocksOnly<const VPBasicBlock>(Iter)) {
+    // Verify that phi-like recipes are at the beginning of the block, with no
+    // other recipes in between.
+    auto RecipeI = VPBB->begin();
+    auto End = VPBB->end();
+    while (RecipeI != End && RecipeI->isPhi())
+      RecipeI++;
+
+    while (RecipeI != End) {
+      if (RecipeI->isPhi() && !isa<VPBlendRecipe>(&*RecipeI)) {
+        errs() << "Found phi-like recipe after non-phi recipe";
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+        errs() << ": ";
+        RecipeI->dump();
+        errs() << "after\n";
+        std::prev(RecipeI)->dump();
+#endif
+        return false;
+      }
+      RecipeI++;
+    }
+  }
+  return true;
+}
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h
index 8e8de441648a..839c24e2c9f4 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h
+++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h
@@ -26,6 +26,7 @@
 namespace llvm {
 class VPRegionBlock;
+class VPlan;
 
 /// Struct with utility functions that can be used to check the consistency and
 /// invariants of a VPlan, including the components of its H-CFG.
@@ -35,6 +36,12 @@ struct VPlanVerifier {
   /// 1. Region/Block verification: Check the Region/Block verification
   /// invariants for every region in the H-CFG.
   void verifyHierarchicalCFG(const VPRegionBlock *TopRegion) const;
+
+  /// Verify invariants for general VPlans. Currently it checks the following:
+  /// 1. all phi-like recipes must be at the beginning of a block, with no other
+  /// recipes in between. Note that currently there is still an exception for
+  /// VPBlendRecipes.
+  static bool verifyPlanIsValid(const VPlan &Plan);
 };
 } // namespace llvm
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d18bcd34620c..57b11e9414ba 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -31,10 +31,12 @@
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Vectorize.h"
 
+#define DEBUG_TYPE "vector-combine"
+#include "llvm/Transforms/Utils/InstructionWorklist.h"
+
 using namespace llvm;
 using namespace llvm::PatternMatch;
 
-#define DEBUG_TYPE "vector-combine"
 STATISTIC(NumVecLoad, "Number of vector loads formed");
 STATISTIC(NumVecCmp, "Number of vector compares formed");
 STATISTIC(NumVecBO, "Number of vector binops formed");
@@ -61,8 +63,10 @@ namespace {
 class VectorCombine {
 public:
   VectorCombine(Function &F, const TargetTransformInfo &TTI,
-                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
-      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}
+                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
+                bool ScalarizationOnly)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
+        ScalarizationOnly(ScalarizationOnly) {}
 
   bool run();
 
@@ -74,12 +78,18 @@ private:
   AAResults &AA;
   AssumptionCache &AC;
 
+  /// If true only perform scalarization combines and do not introduce new
+  /// vector operations.
+  bool ScalarizationOnly;
+
+  InstructionWorklist Worklist;
+
   bool vectorizeLoadInsert(Instruction &I);
   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                         ExtractElementInst *Ext1,
                                         unsigned PreferredExtractIndex) const;
   bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
-                             unsigned Opcode,
+                             const Instruction &I,
                              ExtractElementInst *&ConvertToShuffle,
                              unsigned PreferredExtractIndex);
   void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
@@ -92,14 +102,27 @@ private:
   bool foldExtractedCmps(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool foldShuffleOfBinops(Instruction &I);
+
+  void replaceValue(Value &Old, Value &New) {
+    Old.replaceAllUsesWith(&New);
+    New.takeName(&Old);
+    if (auto *NewI = dyn_cast<Instruction>(&New)) {
+      Worklist.pushUsersToWorkList(*NewI);
+      Worklist.pushValue(NewI);
+    }
+    Worklist.pushValue(&Old);
+  }
+
+  void eraseInstruction(Instruction &I) {
+    for (Value *Op : I.operands())
+      Worklist.pushValue(Op);
+    Worklist.remove(&I);
+    I.eraseFromParent();
+  }
 };
 } // namespace
 
-static void replaceValue(Value &Old, Value &New) {
-  Old.replaceAllUsesWith(&New);
-  New.takeName(&Old);
-}
-
 bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
   // Match insert into fixed vector of scalar value.
   // TODO: Handle non-zero insert index.
@@ -284,12 +307,13 @@ ExtractElementInst *VectorCombine::getShuffleExtract(
 /// \p ConvertToShuffle to that extract instruction.
 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                           ExtractElementInst *Ext1,
-                                          unsigned Opcode,
+                                          const Instruction &I,
                                           ExtractElementInst *&ConvertToShuffle,
                                           unsigned PreferredExtractIndex) {
   assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
          isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
+  unsigned Opcode = I.getOpcode();
   Type *ScalarTy = Ext0->getType();
   auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
   InstructionCost ScalarOpCost, VectorOpCost;
@@ -302,10 +326,11 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
   } else {
     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
            "Expected a compare");
-    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
-                                          CmpInst::makeCmpResultType(ScalarTy));
-    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
-                                          CmpInst::makeCmpResultType(VecTy));
+    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+    ScalarOpCost = TTI.getCmpSelInstrCost(
+        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
+    VectorOpCost = TTI.getCmpSelInstrCost(
+        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
   }
 
   // Get cost estimates for the extract elements. These costs will factor into
@@ -480,8 +505,7 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
                m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
 
   ExtractElementInst *ExtractToChange;
-  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
-                            InsertIndex))
+  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
     return false;
 
   if (ExtractToChange) {
@@ -501,6 +525,8 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
   else
     foldExtExtBinop(Ext0, Ext1, I);
 
+  Worklist.push(Ext0);
+  Worklist.push(Ext1);
   return true;
 }
 
@@ -623,8 +649,11 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   unsigned Opcode = I.getOpcode();
   InstructionCost ScalarOpCost, VectorOpCost;
   if (IsCmp) {
-    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
-    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
+    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
+    ScalarOpCost = TTI.getCmpSelInstrCost(
+        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
+    VectorOpCost = TTI.getCmpSelInstrCost(
+        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
   } else {
     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
@@ -724,7 +753,10 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   InstructionCost OldCost =
       TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
   OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
-  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
+  OldCost +=
+      TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
+                             CmpInst::makeCmpResultType(I0->getType()), Pred) *
+      2;
   OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
 
   // The proposed vector pattern is:
@@ -733,7 +765,8 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
   int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
   auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
-  InstructionCost NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
+  InstructionCost NewCost = TTI.getCmpSelInstrCost(
+      CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
   ShufMask[CheapIndex] = ExpensiveIndex;
   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
@@ -774,18 +807,98 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
   });
 }
 
+/// Helper class to indicate whether a vector index can be safely scalarized and
+/// if a freeze needs to be inserted.
+class ScalarizationResult {
+  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
+
+  StatusTy Status;
+  Value *ToFreeze;
+
+  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
+      : Status(Status), ToFreeze(ToFreeze) {}
+
+public:
+  ScalarizationResult(const ScalarizationResult &Other) = default;
+  ~ScalarizationResult() {
+    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+  }
+
+  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
+  static ScalarizationResult safe() { return {StatusTy::Safe}; }
+  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
+    return {StatusTy::SafeWithFreeze, ToFreeze};
+  }
+
+  /// Returns true if the index can be scalarize without requiring a freeze.
+  bool isSafe() const { return Status == StatusTy::Safe; }
+  /// Returns true if the index cannot be scalarized.
+  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
+  /// Returns true if the index can be scalarize, but requires inserting a
+  /// freeze.
+  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
+
+  /// Reset the state of Unsafe and clear ToFreze if set.
+  void discard() {
+    ToFreeze = nullptr;
+    Status = StatusTy::Unsafe;
+  }
+
+  /// Freeze the ToFreeze and update the use in \p User to use it.
+  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
+    assert(isSafeWithFreeze() &&
+           "should only be used when freezing is required");
+    assert(is_contained(ToFreeze->users(), &UserI) &&
+           "UserI must be a user of ToFreeze");
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(cast<Instruction>(&UserI));
+    Value *Frozen =
+        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
+    for (Use &U : make_early_inc_range((UserI.operands())))
+      if (U.get() == ToFreeze)
+        U.set(Frozen);
+
+    ToFreeze = nullptr;
+  }
+};
+
 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
 /// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
-                               Instruction *CtxI, AssumptionCache &AC) {
-  if (auto *C = dyn_cast<ConstantInt>(Idx))
-    return C->getValue().ult(VecTy->getNumElements());
+static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
+                                              Value *Idx, Instruction *CtxI,
+                                              AssumptionCache &AC,
+                                              const DominatorTree &DT) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx)) {
+    if (C->getValue().ult(VecTy->getNumElements()))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
 
-  APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
-  APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
+  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
+  APInt Zero(IntWidth, 0);
+  APInt MaxElts(IntWidth, VecTy->getNumElements());
   ConstantRange ValidIndices(Zero, MaxElts);
-  ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
-  return ValidIndices.contains(IdxRange);
+  ConstantRange IdxRange(IntWidth, true);
+
+  if (isGuaranteedNotToBePoison(Idx, &AC)) {
+    if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, &DT)))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
+
+  // If the index may be poison, check if we can insert a freeze before the
+  // range of the index is restricted.
+  Value *IdxBase;
+  ConstantInt *CI;
+  if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.binaryAnd(CI->getValue());
+  } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.urem(CI->getValue());
+  }
+
+  if (ValidIndices.contains(IdxRange))
+    return ScalarizationResult::safeWithFreeze(IdxBase);
+  return ScalarizationResult::unsafe();
 }
 
 /// The memory operation on a vector of \p ScalarType had alignment of
@@ -833,12 +946,17 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   // modified between, vector type matches store size, and index is inbounds.
   if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
       !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-      !canScalarizeAccess(VecTy, Idx, Load, AC) ||
-      SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
+      SrcAddr != SI->getPointerOperand()->stripPointerCasts())
+    return false;
+
+  auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
+  if (ScalarizableIdx.isUnsafe() ||
       isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                            MemoryLocation::get(SI), AA))
     return false;
 
+  if (ScalarizableIdx.isSafeWithFreeze())
+    ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
   Value *GEP = Builder.CreateInBoundsGEP(
       SI->getValueOperand()->getType(), SI->getPointerOperand(),
       {ConstantInt::get(Idx->getType(), 0), Idx});
@@ -849,8 +967,7 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
        DL);
    NSI->setAlignment(ScalarOpAlignment);
    replaceValue(I, *NSI);
-    // Need erasing the store manually.
-    I.eraseFromParent();
+    eraseInstruction(I);
    return true;
  }
 
@@ -860,11 +977,10 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
 /// Try to scalarize vector loads feeding extractelement instructions.
 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   Value *Ptr;
-  Value *Idx;
-  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx))))
+  if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
-  auto *LI = cast<LoadInst>(I.getOperand(0));
+  auto *LI = cast<LoadInst>(&I);
   const DataLayout &DL = I.getModule()->getDataLayout();
   if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(LI->getType()))
     return false;
@@ -909,8 +1025,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     else if (LastCheckedInst->comesBefore(UI))
       LastCheckedInst = UI;
 
-    if (!canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC))
+    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+    if (!ScalarIdx.isSafe()) {
+      // TODO: Freeze index if it is safe to do so.
+      ScalarIdx.discard();
       return false;
+    }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     OriginalCost +=
@@ -946,6 +1066,60 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
+/// "binop (shuffle), (shuffle)".
+bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
+  auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
+  if (!VecTy)
+    return false;
+
+  BinaryOperator *B0, *B1;
+  ArrayRef<int> Mask;
+  if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
+                           m_Mask(Mask))) ||
+      B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
+    return false;
+
+  // Try to replace a binop with a shuffle if the shuffle is not costly.
+  // The new shuffle will choose from a single, common operand, so it may be
+  // cheaper than the existing two-operand shuffle.
+  SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
+  Instruction::BinaryOps Opcode = B0->getOpcode();
+  InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
+  InstructionCost ShufCost = TTI.getShuffleCost(
+      TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
+  if (ShufCost > BinopCost)
+    return false;
+
+  // If we have something like "add X, Y" and "add Z, X", swap ops to match.
+  Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
+  Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
+  if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
+    std::swap(X, Y);
+
+  Value *Shuf0, *Shuf1;
+  if (X == Z) {
+    // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
+    Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
+    Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
+  } else if (Y == W) {
+    // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
+    Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
+    Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
+  } else {
+    return false;
+  }
+
+  Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
+  // Intersect flags from the old binops.
+  if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
+    NewInst->copyIRFlags(B0);
+    NewInst->andIRFlags(B1);
+  }
+  replaceValue(I, *NewBO);
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -957,29 +1131,43 @@ bool VectorCombine::run() {
     return false;
 
   bool MadeChange = false;
+  auto FoldInst = [this, &MadeChange](Instruction &I) {
+    Builder.SetInsertPoint(&I);
+    if (!ScalarizationOnly) {
+      MadeChange |= vectorizeLoadInsert(I);
+      MadeChange |= foldExtractExtract(I);
+      MadeChange |= foldBitcastShuf(I);
+      MadeChange |= foldExtractedCmps(I);
+      MadeChange |= foldShuffleOfBinops(I);
+    }
+    MadeChange |= scalarizeBinopOrCmp(I);
+    MadeChange |= scalarizeLoadExtract(I);
+    MadeChange |= foldSingleElementStore(I);
+  };
   for (BasicBlock &BB : F) {
     // Ignore unreachable basic blocks.
     if (!DT.isReachableFromEntry(&BB))
       continue;
     // Use early increment range so that we can erase instructions in loop.
     for (Instruction &I : make_early_inc_range(BB)) {
-      if (isa<DbgInfoIntrinsic>(I))
+      if (I.isDebugOrPseudoInst())
         continue;
-      Builder.SetInsertPoint(&I);
-      MadeChange |= vectorizeLoadInsert(I);
-      MadeChange |= foldExtractExtract(I);
-      MadeChange |= foldBitcastShuf(I);
-      MadeChange |= scalarizeBinopOrCmp(I);
-      MadeChange |= foldExtractedCmps(I);
-      MadeChange |= scalarizeLoadExtract(I);
-      MadeChange |= foldSingleElementStore(I);
+      FoldInst(I);
     }
   }
 
-  // We're done with transforms, so remove dead instructions.
-  if (MadeChange)
-    for (BasicBlock &BB : F)
-      SimplifyInstructionsInBlock(&BB);
+  while (!Worklist.isEmpty()) {
+    Instruction *I = Worklist.removeOne();
+    if (!I)
+      continue;
+
+    if (isInstructionTriviallyDead(I)) {
+      eraseInstruction(*I);
+      continue;
+    }
+
+    FoldInst(*I);
+  }
 
   return MadeChange;
 }
@@ -1014,7 +1202,7 @@ public:
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
-    VectorCombine Combiner(F, TTI, DT, AA, AC);
+    VectorCombine Combiner(F, TTI, DT, AA, AC, false);
    return Combiner.run();
  }
 };
@@ -1038,7 +1226,7 @@ PreservedAnalyses VectorCombinePass::run(Function &F,
   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   AAResults &AA = FAM.getResult<AAManager>(F);
-  VectorCombine Combiner(F, TTI, DT, AA, AC);
+  VectorCombine Combiner(F, TTI, DT, AA, AC, ScalarizationOnly);
   if (!Combiner.run())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;