diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2019-10-23 17:51:42 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2019-10-23 17:51:42 +0000 |
commit | 1d5ae1026e831016fc29fd927877c86af904481f (patch) | |
tree | 2cdfd12620fcfa5d9e4a0389f85368e8e36f63f9 /lib/Transforms | |
parent | e6d1592492a3a379186bfb02bd0f4eda0669c0d5 (diff) |
Notes
Diffstat (limited to 'lib/Transforms')
160 files changed, 16942 insertions, 4892 deletions
diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 06222d7e7e44..a24de3ca213f 100644 --- a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -121,14 +121,13 @@ static bool foldGuardedRotateToFunnelShift(Instruction &I) { BasicBlock *GuardBB = Phi.getIncomingBlock(RotSrc == P1); BasicBlock *RotBB = Phi.getIncomingBlock(RotSrc != P1); Instruction *TermI = GuardBB->getTerminator(); - BasicBlock *TrueBB, *FalseBB; ICmpInst::Predicate Pred; - if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), TrueBB, - FalseBB))) + BasicBlock *PhiBB = Phi.getParent(); + if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), + m_SpecificBB(PhiBB), m_SpecificBB(RotBB)))) return false; - BasicBlock *PhiBB = Phi.getParent(); - if (Pred != CmpInst::ICMP_EQ || TrueBB != PhiBB || FalseBB != RotBB) + if (Pred != CmpInst::ICMP_EQ) return false; // We matched a variation of this IR pattern: @@ -251,6 +250,72 @@ static bool foldAnyOrAllBitsSet(Instruction &I) { return true; } +// Try to recognize below function as popcount intrinsic. +// This is the "best" algorithm from +// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel +// Also used in TargetLowering::expandCTPOP(). +// +// int popcount(unsigned int i) { +// i = i - ((i >> 1) & 0x55555555); +// i = (i & 0x33333333) + ((i >> 2) & 0x33333333); +// i = ((i + (i >> 4)) & 0x0F0F0F0F); +// return (i * 0x01010101) >> 24; +// } +static bool tryToRecognizePopCount(Instruction &I) { + if (I.getOpcode() != Instruction::LShr) + return false; + + Type *Ty = I.getType(); + if (!Ty->isIntOrIntVectorTy()) + return false; + + unsigned Len = Ty->getScalarSizeInBits(); + // FIXME: fix Len == 8 and other irregular type lengths. 
+ if (!(Len <= 128 && Len > 8 && Len % 8 == 0)) + return false; + + APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55)); + APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33)); + APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F)); + APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01)); + APInt MaskShift = APInt(Len, Len - 8); + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *MulOp0; + // Matching "(i * 0x01010101...) >> 24". + if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) && + match(Op1, m_SpecificInt(MaskShift))) { + Value *ShiftOp0; + // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)". + if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)), + m_Deferred(ShiftOp0)), + m_SpecificInt(Mask0F)))) { + Value *AndOp0; + // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)". + if (match(ShiftOp0, + m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)), + m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)), + m_SpecificInt(Mask33))))) { + Value *Root, *SubOp1; + // Matching "i - ((i >> 1) & 0x55555555...)". + if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) && + match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)), + m_SpecificInt(Mask55)))) { + LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n"); + IRBuilder<> Builder(&I); + Function *Func = Intrinsic::getDeclaration( + I.getModule(), Intrinsic::ctpop, I.getType()); + I.replaceAllUsesWith(Builder.CreateCall(Func, {Root})); + return true; + } + } + } + } + + return false; +} + /// This is the entry point for folds that could be implemented in regular /// InstCombine, but they are separated because they are not expected to /// occur frequently and/or have more than a constant-length pattern match. 
@@ -269,6 +334,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { for (Instruction &I : make_range(BB.rbegin(), BB.rend())) { MadeChange |= foldAnyOrAllBitsSet(I); MadeChange |= foldGuardedRotateToFunnelShift(I); + MadeChange |= tryToRecognizePopCount(I); } } @@ -303,7 +369,7 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage( } bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); return runImpl(F, TLI, DT); } diff --git a/lib/Transforms/Coroutines/CoroCleanup.cpp b/lib/Transforms/Coroutines/CoroCleanup.cpp index 1fb0a114d0c7..c3e05577f044 100644 --- a/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -73,6 +73,8 @@ bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) { II->replaceAllUsesWith(ConstantInt::getTrue(Context)); break; case Intrinsic::coro_id: + case Intrinsic::coro_id_retcon: + case Intrinsic::coro_id_retcon_once: II->replaceAllUsesWith(ConstantTokenNone::get(Context)); break; case Intrinsic::coro_subfn_addr: @@ -111,8 +113,9 @@ struct CoroCleanup : FunctionPass { bool doInitialization(Module &M) override { if (coro::declaresIntrinsics(M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr", "llvm.coro.free", - "llvm.coro.id"})) - L = llvm::make_unique<Lowerer>(M); + "llvm.coro.id", "llvm.coro.id.retcon", + "llvm.coro.id.retcon.once"})) + L = std::make_unique<Lowerer>(M); return false; } diff --git a/lib/Transforms/Coroutines/CoroEarly.cpp b/lib/Transforms/Coroutines/CoroEarly.cpp index 692697d6f32e..55993d33ee4e 100644 --- a/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/lib/Transforms/Coroutines/CoroEarly.cpp @@ -91,13 +91,14 @@ void Lowerer::lowerCoroDone(IntrinsicInst *II) { Value *Operand = II->getArgOperand(0); // ResumeFnAddr is the first pointer sized 
element of the coroutine frame. + static_assert(coro::Shape::SwitchFieldIndex::Resume == 0, + "resume function not at offset zero"); auto *FrameTy = Int8Ptr; PointerType *FramePtrTy = FrameTy->getPointerTo(); Builder.SetInsertPoint(II); auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy); - auto *Gep = Builder.CreateConstInBoundsGEP1_32(FrameTy, BCI, 0); - auto *Load = Builder.CreateLoad(FrameTy, Gep); + auto *Load = Builder.CreateLoad(BCI); auto *Cond = Builder.CreateICmpEQ(Load, NullPtr); II->replaceAllUsesWith(Cond); @@ -189,6 +190,10 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { } } break; + case Intrinsic::coro_id_retcon: + case Intrinsic::coro_id_retcon_once: + F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT); + break; case Intrinsic::coro_resume: lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex); break; @@ -231,11 +236,18 @@ struct CoroEarly : public FunctionPass { // This pass has work to do only if we find intrinsics we are going to lower // in the module. bool doInitialization(Module &M) override { - if (coro::declaresIntrinsics( - M, {"llvm.coro.id", "llvm.coro.destroy", "llvm.coro.done", - "llvm.coro.end", "llvm.coro.noop", "llvm.coro.free", - "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.suspend"})) - L = llvm::make_unique<Lowerer>(M); + if (coro::declaresIntrinsics(M, {"llvm.coro.id", + "llvm.coro.id.retcon", + "llvm.coro.id.retcon.once", + "llvm.coro.destroy", + "llvm.coro.done", + "llvm.coro.end", + "llvm.coro.noop", + "llvm.coro.free", + "llvm.coro.promise", + "llvm.coro.resume", + "llvm.coro.suspend"})) + L = std::make_unique<Lowerer>(M); return false; } diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp index 6707aa1c827d..aca77119023b 100644 --- a/lib/Transforms/Coroutines/CoroElide.cpp +++ b/lib/Transforms/Coroutines/CoroElide.cpp @@ -286,7 +286,7 @@ struct CoroElide : FunctionPass { bool doInitialization(Module &M) override { if (coro::declaresIntrinsics(M, {"llvm.coro.id"})) - 
L = llvm::make_unique<Lowerer>(M); + L = std::make_unique<Lowerer>(M); return false; } diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp index 58bf22bee29b..2c42cf8a6d25 100644 --- a/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/lib/Transforms/Coroutines/CoroFrame.cpp @@ -18,6 +18,7 @@ #include "CoroInternal.h" #include "llvm/ADT/BitVector.h" +#include "llvm/Analysis/PtrUseVisitor.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/CFG.h" @@ -28,6 +29,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/circular_raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" using namespace llvm; @@ -120,6 +122,15 @@ struct SuspendCrossingInfo { return false; BasicBlock *UseBB = I->getParent(); + + // As a special case, treat uses by an llvm.coro.suspend.retcon + // as if they were uses in the suspend's single predecessor: the + // uses conceptually occur before the suspend. + if (isa<CoroSuspendRetconInst>(I)) { + UseBB = UseBB->getSinglePredecessor(); + assert(UseBB && "should have split coro.suspend into its own block"); + } + return hasPathCrossingSuspendPoint(DefBB, UseBB); } @@ -128,7 +139,17 @@ struct SuspendCrossingInfo { } bool isDefinitionAcrossSuspend(Instruction &I, User *U) const { - return isDefinitionAcrossSuspend(I.getParent(), U); + auto *DefBB = I.getParent(); + + // As a special case, treat values produced by an llvm.coro.suspend.* + // as if they were defined in the single successor: the uses + // conceptually occur after the suspend. 
+ if (isa<AnyCoroSuspendInst>(I)) { + DefBB = DefBB->getSingleSuccessor(); + assert(DefBB && "should have split coro.suspend into its own block"); + } + + return isDefinitionAcrossSuspend(DefBB, U); } }; } // end anonymous namespace @@ -183,9 +204,10 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) B.Suspend = true; B.Kills |= B.Consumes; }; - for (CoroSuspendInst *CSI : Shape.CoroSuspends) { + for (auto *CSI : Shape.CoroSuspends) { markSuspendBlock(CSI); - markSuspendBlock(CSI->getCoroSave()); + if (auto *Save = CSI->getCoroSave()) + markSuspendBlock(Save); } // Iterate propagating consumes and kills until they stop changing. @@ -261,11 +283,13 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) // We build up the list of spills for every case where a use is separated // from the definition by a suspend point. +static const unsigned InvalidFieldIndex = ~0U; + namespace { class Spill { Value *Def = nullptr; Instruction *User = nullptr; - unsigned FieldNo = 0; + unsigned FieldNo = InvalidFieldIndex; public: Spill(Value *Def, llvm::User *U) : Def(Def), User(cast<Instruction>(U)) {} @@ -280,11 +304,11 @@ public: // the definition the first time they encounter it. Consider refactoring // SpillInfo into two arrays to normalize the spill representation. 
unsigned fieldIndex() const { - assert(FieldNo && "Accessing unassigned field"); + assert(FieldNo != InvalidFieldIndex && "Accessing unassigned field"); return FieldNo; } void setFieldIndex(unsigned FieldNumber) { - assert(!FieldNo && "Reassigning field number"); + assert(FieldNo == InvalidFieldIndex && "Reassigning field number"); FieldNo = FieldNumber; } }; @@ -376,18 +400,30 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, SmallString<32> Name(F.getName()); Name.append(".Frame"); StructType *FrameTy = StructType::create(C, Name); - auto *FramePtrTy = FrameTy->getPointerTo(); - auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy, - /*isVarArg=*/false); - auto *FnPtrTy = FnTy->getPointerTo(); - - // Figure out how wide should be an integer type storing the suspend index. - unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size())); - Type *PromiseType = Shape.PromiseAlloca - ? Shape.PromiseAlloca->getType()->getElementType() - : Type::getInt1Ty(C); - SmallVector<Type *, 8> Types{FnPtrTy, FnPtrTy, PromiseType, - Type::getIntNTy(C, IndexBits)}; + SmallVector<Type *, 8> Types; + + AllocaInst *PromiseAlloca = Shape.getPromiseAlloca(); + + if (Shape.ABI == coro::ABI::Switch) { + auto *FramePtrTy = FrameTy->getPointerTo(); + auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy, + /*IsVarArg=*/false); + auto *FnPtrTy = FnTy->getPointerTo(); + + // Figure out how wide should be an integer type storing the suspend index. + unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size())); + Type *PromiseType = PromiseAlloca + ? 
PromiseAlloca->getType()->getElementType() + : Type::getInt1Ty(C); + Type *IndexType = Type::getIntNTy(C, IndexBits); + Types.push_back(FnPtrTy); + Types.push_back(FnPtrTy); + Types.push_back(PromiseType); + Types.push_back(IndexType); + } else { + assert(PromiseAlloca == nullptr && "lowering doesn't support promises"); + } + Value *CurrentDef = nullptr; Padder.addTypes(Types); @@ -399,7 +435,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, CurrentDef = S.def(); // PromiseAlloca was already added to Types array earlier. - if (CurrentDef == Shape.PromiseAlloca) + if (CurrentDef == PromiseAlloca) continue; uint64_t Count = 1; @@ -430,9 +466,80 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, } FrameTy->setBody(Types); + switch (Shape.ABI) { + case coro::ABI::Switch: + break; + + // Remember whether the frame is inline in the storage. + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: { + auto &Layout = F.getParent()->getDataLayout(); + auto Id = Shape.getRetconCoroId(); + Shape.RetconLowering.IsFrameInlineInStorage + = (Layout.getTypeAllocSize(FrameTy) <= Id->getStorageSize() && + Layout.getABITypeAlignment(FrameTy) <= Id->getStorageAlignment()); + break; + } + } + return FrameTy; } +// We use a pointer use visitor to discover if there are any writes into an +// alloca that dominates CoroBegin. If that is the case, insertSpills will copy +// the value from the alloca into the coroutine frame spill slot corresponding +// to that alloca. +namespace { +struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> { + using Base = PtrUseVisitor<AllocaUseVisitor>; + AllocaUseVisitor(const DataLayout &DL, const DominatorTree &DT, + const CoroBeginInst &CB) + : PtrUseVisitor(DL), DT(DT), CoroBegin(CB) {} + + // We are only interested in uses that dominate coro.begin. 
+ void visit(Instruction &I) { + if (DT.dominates(&I, &CoroBegin)) + Base::visit(I); + } + // We need to provide this overload as PtrUseVisitor uses a pointer based + // visiting function. + void visit(Instruction *I) { return visit(*I); } + + void visitLoadInst(LoadInst &) {} // Good. Nothing to do. + + // If the use is an operand, the pointer escaped and anything can write into + // that memory. If the use is the pointer, we are definitely writing into the + // alloca and therefore we need to copy. + void visitStoreInst(StoreInst &SI) { PI.setAborted(&SI); } + + // Any other instruction that is not filtered out by PtrUseVisitor, will + // result in the copy. + void visitInstruction(Instruction &I) { PI.setAborted(&I); } + +private: + const DominatorTree &DT; + const CoroBeginInst &CoroBegin; +}; +} // namespace +static bool mightWriteIntoAllocaPtr(AllocaInst &A, const DominatorTree &DT, + const CoroBeginInst &CB) { + const DataLayout &DL = A.getModule()->getDataLayout(); + AllocaUseVisitor Visitor(DL, DT, CB); + auto PtrI = Visitor.visitPtr(A); + if (PtrI.isEscaped() || PtrI.isAborted()) { + auto *PointerEscapingInstr = PtrI.getEscapingInst() + ? PtrI.getEscapingInst() + : PtrI.getAbortingInst(); + if (PointerEscapingInstr) { + LLVM_DEBUG( + dbgs() << "AllocaInst copy was triggered by instruction: " + << *PointerEscapingInstr << "\n"); + } + return true; + } + return false; +} + // We need to make room to insert a spill after initial PHIs, but before // catchswitch instruction. Placing it before violates the requirement that // catchswitch, like all other EHPads must be the first nonPHI in a block. 
@@ -476,7 +583,7 @@ static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) { // whatever // // -static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { +static Instruction *insertSpills(const SpillInfo &Spills, coro::Shape &Shape) { auto *CB = Shape.CoroBegin; LLVMContext &C = CB->getContext(); IRBuilder<> Builder(CB->getNextNode()); @@ -484,11 +591,14 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { PointerType *FramePtrTy = FrameTy->getPointerTo(); auto *FramePtr = cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr")); + DominatorTree DT(*CB->getFunction()); Value *CurrentValue = nullptr; BasicBlock *CurrentBlock = nullptr; Value *CurrentReload = nullptr; - unsigned Index = 0; // Proper field number will be read from field definition. + + // Proper field number will be read from field definition. + unsigned Index = InvalidFieldIndex; // We need to keep track of any allocas that need "spilling" // since they will live in the coroutine frame now, all access to them @@ -496,9 +606,11 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { // we remember allocas and their indices to be handled once we processed // all the spills. SmallVector<std::pair<AllocaInst *, unsigned>, 4> Allocas; - // Promise alloca (if present) has a fixed field number (Shape::PromiseField) - if (Shape.PromiseAlloca) - Allocas.emplace_back(Shape.PromiseAlloca, coro::Shape::PromiseField); + // Promise alloca (if present) has a fixed field number. + if (auto *PromiseAlloca = Shape.getPromiseAlloca()) { + assert(Shape.ABI == coro::ABI::Switch); + Allocas.emplace_back(PromiseAlloca, coro::Shape::SwitchFieldIndex::Promise); + } // Create a GEP with the given index into the coroutine frame for the original // value Orig. 
Appends an extra 0 index for array-allocas, preserving the @@ -526,7 +638,7 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { // Create a load instruction to reload the spilled value from the coroutine // frame. auto CreateReload = [&](Instruction *InsertBefore) { - assert(Index && "accessing unassigned field number"); + assert(Index != InvalidFieldIndex && "accessing unassigned field number"); Builder.SetInsertPoint(InsertBefore); auto *G = GetFramePointer(Index, CurrentValue); @@ -558,29 +670,45 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { // coroutine frame. Instruction *InsertPt = nullptr; - if (isa<Argument>(CurrentValue)) { + if (auto Arg = dyn_cast<Argument>(CurrentValue)) { // For arguments, we will place the store instruction right after // the coroutine frame pointer instruction, i.e. bitcast of // coro.begin from i8* to %f.frame*. InsertPt = FramePtr->getNextNode(); + + // If we're spilling an Argument, make sure we clear 'nocapture' + // from the coroutine function. + Arg->getParent()->removeParamAttr(Arg->getArgNo(), + Attribute::NoCapture); + } else if (auto *II = dyn_cast<InvokeInst>(CurrentValue)) { // If we are spilling the result of the invoke instruction, split the // normal edge and insert the spill in the new block. auto NewBB = SplitEdge(II->getParent(), II->getNormalDest()); InsertPt = NewBB->getTerminator(); - } else if (dyn_cast<PHINode>(CurrentValue)) { + } else if (isa<PHINode>(CurrentValue)) { // Skip the PHINodes and EH pads instructions. BasicBlock *DefBlock = cast<Instruction>(E.def())->getParent(); if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator())) InsertPt = splitBeforeCatchSwitch(CSI); else InsertPt = &*DefBlock->getFirstInsertionPt(); + } else if (auto CSI = dyn_cast<AnyCoroSuspendInst>(CurrentValue)) { + // Don't spill immediately after a suspend; splitting assumes + // that the suspend will be followed by a branch. 
+ InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHI(); } else { + auto *I = cast<Instruction>(E.def()); + assert(!I->isTerminator() && "unexpected terminator"); // For all other values, the spill is placed immediately after // the definition. - assert(!cast<Instruction>(E.def())->isTerminator() && - "unexpected terminator"); - InsertPt = cast<Instruction>(E.def())->getNextNode(); + if (DT.dominates(CB, I)) { + InsertPt = I->getNextNode(); + } else { + // Unless, it is not dominated by CoroBegin, then it will be + // inserted immediately after CoroFrame is computed. + InsertPt = FramePtr->getNextNode(); + } } Builder.SetInsertPoint(InsertPt); @@ -613,21 +741,53 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { } BasicBlock *FramePtrBB = FramePtr->getParent(); - Shape.AllocaSpillBlock = - FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB"); - Shape.AllocaSpillBlock->splitBasicBlock(&Shape.AllocaSpillBlock->front(), - "PostSpill"); - Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front()); + auto SpillBlock = + FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB"); + SpillBlock->splitBasicBlock(&SpillBlock->front(), "PostSpill"); + Shape.AllocaSpillBlock = SpillBlock; // If we found any allocas, replace all of their remaining uses with Geps. + // Note: we cannot do it indiscriminately as some of the uses may not be + // dominated by CoroBegin. 
+ bool MightNeedToCopy = false; + Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front()); + SmallVector<Instruction *, 4> UsersToUpdate; for (auto &P : Allocas) { - auto *G = GetFramePointer(P.second, P.first); + AllocaInst *const A = P.first; + UsersToUpdate.clear(); + for (User *U : A->users()) { + auto *I = cast<Instruction>(U); + if (DT.dominates(CB, I)) + UsersToUpdate.push_back(I); + else + MightNeedToCopy = true; + } + if (!UsersToUpdate.empty()) { + auto *G = GetFramePointer(P.second, A); + G->takeName(A); + for (Instruction *I : UsersToUpdate) + I->replaceUsesOfWith(A, G); + } + } + // If we discovered such uses not dominated by CoroBegin, see if any of them + // preceed coro begin and have instructions that can modify the + // value of the alloca and therefore would require a copying the value into + // the spill slot in the coroutine frame. + if (MightNeedToCopy) { + Builder.SetInsertPoint(FramePtr->getNextNode()); + + for (auto &P : Allocas) { + AllocaInst *const A = P.first; + if (mightWriteIntoAllocaPtr(*A, DT, *CB)) { + if (A->isArrayAllocation()) + report_fatal_error( + "Coroutines cannot handle copying of array allocas yet"); - // We are not using ReplaceInstWithInst(P.first, cast<Instruction>(G)) here, - // as we are changing location of the instruction. - G->takeName(P.first); - P.first->replaceAllUsesWith(G); - P.first->eraseFromParent(); + auto *G = GetFramePointer(P.second, A); + auto *Value = Builder.CreateLoad(A); + Builder.CreateStore(Value, G); + } + } } return FramePtr; } @@ -829,52 +989,6 @@ static void rewriteMaterializableInstructions(IRBuilder<> &IRB, } } -// Move early uses of spilled variable after CoroBegin. -// For example, if a parameter had address taken, we may end up with the code -// like: -// define @f(i32 %n) { -// %n.addr = alloca i32 -// store %n, %n.addr -// ... 
-// call @coro.begin -// we need to move the store after coro.begin -static void moveSpillUsesAfterCoroBegin(Function &F, SpillInfo const &Spills, - CoroBeginInst *CoroBegin) { - DominatorTree DT(F); - SmallVector<Instruction *, 8> NeedsMoving; - - Value *CurrentValue = nullptr; - - for (auto const &E : Spills) { - if (CurrentValue == E.def()) - continue; - - CurrentValue = E.def(); - - for (User *U : CurrentValue->users()) { - Instruction *I = cast<Instruction>(U); - if (!DT.dominates(CoroBegin, I)) { - LLVM_DEBUG(dbgs() << "will move: " << *I << "\n"); - - // TODO: Make this more robust. Currently if we run into a situation - // where simple instruction move won't work we panic and - // report_fatal_error. - for (User *UI : I->users()) { - if (!DT.dominates(CoroBegin, cast<Instruction>(UI))) - report_fatal_error("cannot move instruction since its users are not" - " dominated by CoroBegin"); - } - - NeedsMoving.push_back(I); - } - } - } - - Instruction *InsertPt = CoroBegin->getNextNode(); - for (Instruction *I : NeedsMoving) - I->moveBefore(InsertPt); -} - // Splits the block at a particular instruction unless it is the first // instruction in the block with a single predecessor. static BasicBlock *splitBlockIfNotFirst(Instruction *I, const Twine &Name) { @@ -895,21 +1009,337 @@ static void splitAround(Instruction *I, const Twine &Name) { splitBlockIfNotFirst(I->getNextNode(), "After" + Name); } +static bool isSuspendBlock(BasicBlock *BB) { + return isa<AnyCoroSuspendInst>(BB->front()); +} + +typedef SmallPtrSet<BasicBlock*, 8> VisitedBlocksSet; + +/// Does control flow starting at the given block ever reach a suspend +/// instruction before reaching a block in VisitedOrFreeBBs? +static bool isSuspendReachableFrom(BasicBlock *From, + VisitedBlocksSet &VisitedOrFreeBBs) { + // Eagerly try to add this block to the visited set. If it's already + // there, stop recursing; this path doesn't reach a suspend before + // either looping or reaching a freeing block. 
+ if (!VisitedOrFreeBBs.insert(From).second) + return false; + + // We assume that we'll already have split suspends into their own blocks. + if (isSuspendBlock(From)) + return true; + + // Recurse on the successors. + for (auto Succ : successors(From)) { + if (isSuspendReachableFrom(Succ, VisitedOrFreeBBs)) + return true; + } + + return false; +} + +/// Is the given alloca "local", i.e. bounded in lifetime to not cross a +/// suspend point? +static bool isLocalAlloca(CoroAllocaAllocInst *AI) { + // Seed the visited set with all the basic blocks containing a free + // so that we won't pass them up. + VisitedBlocksSet VisitedOrFreeBBs; + for (auto User : AI->users()) { + if (auto FI = dyn_cast<CoroAllocaFreeInst>(User)) + VisitedOrFreeBBs.insert(FI->getParent()); + } + + return !isSuspendReachableFrom(AI->getParent(), VisitedOrFreeBBs); +} + +/// After we split the coroutine, will the given basic block be along +/// an obvious exit path for the resumption function? +static bool willLeaveFunctionImmediatelyAfter(BasicBlock *BB, + unsigned depth = 3) { + // If we've bottomed out our depth count, stop searching and assume + // that the path might loop back. + if (depth == 0) return false; + + // If this is a suspend block, we're about to exit the resumption function. + if (isSuspendBlock(BB)) return true; + + // Recurse into the successors. + for (auto Succ : successors(BB)) { + if (!willLeaveFunctionImmediatelyAfter(Succ, depth - 1)) + return false; + } + + // If none of the successors leads back in a loop, we're on an exit/abort. + return true; +} + +static bool localAllocaNeedsStackSave(CoroAllocaAllocInst *AI) { + // Look for a free that isn't sufficiently obviously followed by + // either a suspend or a termination, i.e. something that will leave + // the coro resumption frame. 
+ for (auto U : AI->users()) { + auto FI = dyn_cast<CoroAllocaFreeInst>(U); + if (!FI) continue; + + if (!willLeaveFunctionImmediatelyAfter(FI->getParent())) + return true; + } + + // If we never found one, we don't need a stack save. + return false; +} + +/// Turn each of the given local allocas into a normal (dynamic) alloca +/// instruction. +static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas, + SmallVectorImpl<Instruction*> &DeadInsts) { + for (auto AI : LocalAllocas) { + auto M = AI->getModule(); + IRBuilder<> Builder(AI); + + // Save the stack depth. Try to avoid doing this if the stackrestore + // is going to immediately precede a return or something. + Value *StackSave = nullptr; + if (localAllocaNeedsStackSave(AI)) + StackSave = Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::stacksave)); + + // Allocate memory. + auto Alloca = Builder.CreateAlloca(Builder.getInt8Ty(), AI->getSize()); + Alloca->setAlignment(MaybeAlign(AI->getAlignment())); + + for (auto U : AI->users()) { + // Replace gets with the allocation. + if (isa<CoroAllocaGetInst>(U)) { + U->replaceAllUsesWith(Alloca); + + // Replace frees with stackrestores. This is safe because + // alloca.alloc is required to obey a stack discipline, although we + // don't enforce that structurally. + } else { + auto FI = cast<CoroAllocaFreeInst>(U); + if (StackSave) { + Builder.SetInsertPoint(FI); + Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::stackrestore), + StackSave); + } + } + DeadInsts.push_back(cast<Instruction>(U)); + } + + DeadInsts.push_back(AI); + } +} + +/// Turn the given coro.alloca.alloc call into a dynamic allocation. +/// This happens during the all-instructions iteration, so it must not +/// delete the call. 
+static Instruction *lowerNonLocalAlloca(CoroAllocaAllocInst *AI, + coro::Shape &Shape, + SmallVectorImpl<Instruction*> &DeadInsts) { + IRBuilder<> Builder(AI); + auto Alloc = Shape.emitAlloc(Builder, AI->getSize(), nullptr); + + for (User *U : AI->users()) { + if (isa<CoroAllocaGetInst>(U)) { + U->replaceAllUsesWith(Alloc); + } else { + auto FI = cast<CoroAllocaFreeInst>(U); + Builder.SetInsertPoint(FI); + Shape.emitDealloc(Builder, Alloc, nullptr); + } + DeadInsts.push_back(cast<Instruction>(U)); + } + + // Push this on last so that it gets deleted after all the others. + DeadInsts.push_back(AI); + + // Return the new allocation value so that we can check for needed spills. + return cast<Instruction>(Alloc); +} + +/// Get the current swifterror value. +static Value *emitGetSwiftErrorValue(IRBuilder<> &Builder, Type *ValueTy, + coro::Shape &Shape) { + // Make a fake function pointer as a sort of intrinsic. + auto FnTy = FunctionType::get(ValueTy, {}, false); + auto Fn = ConstantPointerNull::get(FnTy->getPointerTo()); + + auto Call = Builder.CreateCall(Fn, {}); + Shape.SwiftErrorOps.push_back(Call); + + return Call; +} + +/// Set the given value as the current swifterror value. +/// +/// Returns a slot that can be used as a swifterror slot. +static Value *emitSetSwiftErrorValue(IRBuilder<> &Builder, Value *V, + coro::Shape &Shape) { + // Make a fake function pointer as a sort of intrinsic. + auto FnTy = FunctionType::get(V->getType()->getPointerTo(), + {V->getType()}, false); + auto Fn = ConstantPointerNull::get(FnTy->getPointerTo()); + + auto Call = Builder.CreateCall(Fn, { V }); + Shape.SwiftErrorOps.push_back(Call); + + return Call; +} + +/// Set the swifterror value from the given alloca before a call, +/// then put in back in the alloca afterwards. +/// +/// Returns an address that will stand in for the swifterror slot +/// until splitting. 
+static Value *emitSetAndGetSwiftErrorValueAround(Instruction *Call, + AllocaInst *Alloca, + coro::Shape &Shape) { + auto ValueTy = Alloca->getAllocatedType(); + IRBuilder<> Builder(Call); + + // Load the current value from the alloca and set it as the + // swifterror value. + auto ValueBeforeCall = Builder.CreateLoad(ValueTy, Alloca); + auto Addr = emitSetSwiftErrorValue(Builder, ValueBeforeCall, Shape); + + // Move to after the call. Since swifterror only has a guaranteed + // value on normal exits, we can ignore implicit and explicit unwind + // edges. + if (isa<CallInst>(Call)) { + Builder.SetInsertPoint(Call->getNextNode()); + } else { + auto Invoke = cast<InvokeInst>(Call); + Builder.SetInsertPoint(Invoke->getNormalDest()->getFirstNonPHIOrDbg()); + } + + // Get the current swifterror value and store it to the alloca. + auto ValueAfterCall = emitGetSwiftErrorValue(Builder, ValueTy, Shape); + Builder.CreateStore(ValueAfterCall, Alloca); + + return Addr; +} + +/// Eliminate a formerly-swifterror alloca by inserting the get/set +/// intrinsics and attempting to MemToReg the alloca away. +static void eliminateSwiftErrorAlloca(Function &F, AllocaInst *Alloca, + coro::Shape &Shape) { + for (auto UI = Alloca->use_begin(), UE = Alloca->use_end(); UI != UE; ) { + // We're likely changing the use list, so use a mutation-safe + // iteration pattern. + auto &Use = *UI; + ++UI; + + // swifterror values can only be used in very specific ways. + // We take advantage of that here. + auto User = Use.getUser(); + if (isa<LoadInst>(User) || isa<StoreInst>(User)) + continue; + + assert(isa<CallInst>(User) || isa<InvokeInst>(User)); + auto Call = cast<Instruction>(User); + + auto Addr = emitSetAndGetSwiftErrorValueAround(Call, Alloca, Shape); + + // Use the returned slot address as the call argument. + Use.set(Addr); + } + + // All the uses should be loads and stores now. 
+ assert(isAllocaPromotable(Alloca)); +} + +/// "Eliminate" a swifterror argument by reducing it to the alloca case +/// and then loading and storing in the prologue and epilog. +/// +/// The argument keeps the swifterror flag. +static void eliminateSwiftErrorArgument(Function &F, Argument &Arg, + coro::Shape &Shape, + SmallVectorImpl<AllocaInst*> &AllocasToPromote) { + IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg()); + + auto ArgTy = cast<PointerType>(Arg.getType()); + auto ValueTy = ArgTy->getElementType(); + + // Reduce to the alloca case: + + // Create an alloca and replace all uses of the arg with it. + auto Alloca = Builder.CreateAlloca(ValueTy, ArgTy->getAddressSpace()); + Arg.replaceAllUsesWith(Alloca); + + // Set an initial value in the alloca. swifterror is always null on entry. + auto InitialValue = Constant::getNullValue(ValueTy); + Builder.CreateStore(InitialValue, Alloca); + + // Find all the suspends in the function and save and restore around them. + for (auto Suspend : Shape.CoroSuspends) { + (void) emitSetAndGetSwiftErrorValueAround(Suspend, Alloca, Shape); + } + + // Find all the coro.ends in the function and restore the error value. + for (auto End : Shape.CoroEnds) { + Builder.SetInsertPoint(End); + auto FinalValue = Builder.CreateLoad(ValueTy, Alloca); + (void) emitSetSwiftErrorValue(Builder, FinalValue, Shape); + } + + // Now we can use the alloca logic. + AllocasToPromote.push_back(Alloca); + eliminateSwiftErrorAlloca(F, Alloca, Shape); +} + +/// Eliminate all problematic uses of swifterror arguments and allocas +/// from the function. We'll fix them up later when splitting the function. +static void eliminateSwiftError(Function &F, coro::Shape &Shape) { + SmallVector<AllocaInst*, 4> AllocasToPromote; + + // Look for a swifterror argument. + for (auto &Arg : F.args()) { + if (!Arg.hasSwiftErrorAttr()) continue; + + eliminateSwiftErrorArgument(F, Arg, Shape, AllocasToPromote); + break; + } + + // Look for swifterror allocas. 
+ for (auto &Inst : F.getEntryBlock()) { + auto Alloca = dyn_cast<AllocaInst>(&Inst); + if (!Alloca || !Alloca->isSwiftError()) continue; + + // Clear the swifterror flag. + Alloca->setSwiftError(false); + + AllocasToPromote.push_back(Alloca); + eliminateSwiftErrorAlloca(F, Alloca, Shape); + } + + // If we have any allocas to promote, compute a dominator tree and + // promote them en masse. + if (!AllocasToPromote.empty()) { + DominatorTree DT(F); + PromoteMemToReg(AllocasToPromote, DT); + } +} + void coro::buildCoroutineFrame(Function &F, Shape &Shape) { // Lower coro.dbg.declare to coro.dbg.value, since we are going to rewrite // access to local variables. LowerDbgDeclare(F); - Shape.PromiseAlloca = Shape.CoroBegin->getId()->getPromise(); - if (Shape.PromiseAlloca) { - Shape.CoroBegin->getId()->clearPromise(); + eliminateSwiftError(F, Shape); + + if (Shape.ABI == coro::ABI::Switch && + Shape.SwitchLowering.PromiseAlloca) { + Shape.getSwitchCoroId()->clearPromise(); } // Make sure that all coro.save, coro.suspend and the fallthrough coro.end // intrinsics are in their own blocks to simplify the logic of building up // SuspendCrossing data. - for (CoroSuspendInst *CSI : Shape.CoroSuspends) { - splitAround(CSI->getCoroSave(), "CoroSave"); + for (auto *CSI : Shape.CoroSuspends) { + if (auto *Save = CSI->getCoroSave()) + splitAround(Save, "CoroSave"); splitAround(CSI, "CoroSuspend"); } @@ -926,6 +1356,8 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { IRBuilder<> Builder(F.getContext()); SpillInfo Spills; + SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas; + SmallVector<Instruction*, 4> DeadInstructions; for (int Repeat = 0; Repeat < 4; ++Repeat) { // See if there are materializable instructions across suspend points. @@ -955,11 +1387,40 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { // of the Coroutine Frame. 
if (isCoroutineStructureIntrinsic(I) || &I == Shape.CoroBegin) continue; + // The Coroutine Promise always included into coroutine frame, no need to // check for suspend crossing. - if (Shape.PromiseAlloca == &I) + if (Shape.ABI == coro::ABI::Switch && + Shape.SwitchLowering.PromiseAlloca == &I) continue; + // Handle alloca.alloc specially here. + if (auto AI = dyn_cast<CoroAllocaAllocInst>(&I)) { + // Check whether the alloca's lifetime is bounded by suspend points. + if (isLocalAlloca(AI)) { + LocalAllocas.push_back(AI); + continue; + } + + // If not, do a quick rewrite of the alloca and then add spills of + // the rewritten value. The rewrite doesn't invalidate anything in + // Spills because the other alloca intrinsics have no other operands + // besides AI, and it doesn't invalidate the iteration because we delay + // erasing AI. + auto Alloc = lowerNonLocalAlloca(AI, Shape, DeadInstructions); + + for (User *U : Alloc->users()) { + if (Checker.isDefinitionAcrossSuspend(*Alloc, U)) + Spills.emplace_back(Alloc, U); + } + continue; + } + + // Ignore alloca.get; we process this as part of coro.alloca.alloc. + if (isa<CoroAllocaGetInst>(I)) { + continue; + } + for (User *U : I.users()) if (Checker.isDefinitionAcrossSuspend(I, U)) { // We cannot spill a token. 
@@ -970,7 +1431,10 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { } } LLVM_DEBUG(dump("Spills", Spills)); - moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin); Shape.FrameTy = buildFrameType(F, Shape, Spills); Shape.FramePtr = insertSpills(Spills, Shape); + lowerLocalAllocas(LocalAllocas, DeadInstructions); + + for (auto I : DeadInstructions) + I->eraseFromParent(); } diff --git a/lib/Transforms/Coroutines/CoroInstr.h b/lib/Transforms/Coroutines/CoroInstr.h index 5e19d7642e38..de2d2920cb15 100644 --- a/lib/Transforms/Coroutines/CoroInstr.h +++ b/lib/Transforms/Coroutines/CoroInstr.h @@ -27,6 +27,7 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/raw_ostream.h" namespace llvm { @@ -77,10 +78,8 @@ public: } }; -/// This represents the llvm.coro.alloc instruction. -class LLVM_LIBRARY_VISIBILITY CoroIdInst : public IntrinsicInst { - enum { AlignArg, PromiseArg, CoroutineArg, InfoArg }; - +/// This represents a common base class for llvm.coro.id instructions. +class LLVM_LIBRARY_VISIBILITY AnyCoroIdInst : public IntrinsicInst { public: CoroAllocInst *getCoroAlloc() { for (User *U : users()) @@ -97,6 +96,24 @@ public: llvm_unreachable("no coro.begin associated with coro.id"); } + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + auto ID = I->getIntrinsicID(); + return ID == Intrinsic::coro_id || + ID == Intrinsic::coro_id_retcon || + ID == Intrinsic::coro_id_retcon_once; + } + + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.id instruction. 
+class LLVM_LIBRARY_VISIBILITY CoroIdInst : public AnyCoroIdInst { + enum { AlignArg, PromiseArg, CoroutineArg, InfoArg }; + +public: AllocaInst *getPromise() const { Value *Arg = getArgOperand(PromiseArg); return isa<ConstantPointerNull>(Arg) @@ -182,6 +199,80 @@ public: } }; +/// This represents either the llvm.coro.id.retcon or +/// llvm.coro.id.retcon.once instruction. +class LLVM_LIBRARY_VISIBILITY AnyCoroIdRetconInst : public AnyCoroIdInst { + enum { SizeArg, AlignArg, StorageArg, PrototypeArg, AllocArg, DeallocArg }; + +public: + void checkWellFormed() const; + + uint64_t getStorageSize() const { + return cast<ConstantInt>(getArgOperand(SizeArg))->getZExtValue(); + } + + uint64_t getStorageAlignment() const { + return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue(); + } + + Value *getStorage() const { + return getArgOperand(StorageArg); + } + + /// Return the prototype for the continuation function. The type, + /// attributes, and calling convention of the continuation function(s) + /// are taken from this declaration. + Function *getPrototype() const { + return cast<Function>(getArgOperand(PrototypeArg)->stripPointerCasts()); + } + + /// Return the function to use for allocating memory. + Function *getAllocFunction() const { + return cast<Function>(getArgOperand(AllocArg)->stripPointerCasts()); + } + + /// Return the function to use for deallocating memory. + Function *getDeallocFunction() const { + return cast<Function>(getArgOperand(DeallocArg)->stripPointerCasts()); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + auto ID = I->getIntrinsicID(); + return ID == Intrinsic::coro_id_retcon + || ID == Intrinsic::coro_id_retcon_once; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.id.retcon instruction. 
+class LLVM_LIBRARY_VISIBILITY CoroIdRetconInst + : public AnyCoroIdRetconInst { +public: + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_id_retcon; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.id.retcon.once instruction. +class LLVM_LIBRARY_VISIBILITY CoroIdRetconOnceInst + : public AnyCoroIdRetconInst { +public: + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_id_retcon_once; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + /// This represents the llvm.coro.frame instruction. class LLVM_LIBRARY_VISIBILITY CoroFrameInst : public IntrinsicInst { public: @@ -215,7 +306,9 @@ class LLVM_LIBRARY_VISIBILITY CoroBeginInst : public IntrinsicInst { enum { IdArg, MemArg }; public: - CoroIdInst *getId() const { return cast<CoroIdInst>(getArgOperand(IdArg)); } + AnyCoroIdInst *getId() const { + return cast<AnyCoroIdInst>(getArgOperand(IdArg)); + } Value *getMem() const { return getArgOperand(MemArg); } @@ -261,8 +354,22 @@ public: } }; +class LLVM_LIBRARY_VISIBILITY AnyCoroSuspendInst : public IntrinsicInst { +public: + CoroSaveInst *getCoroSave() const; + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_suspend || + I->getIntrinsicID() == Intrinsic::coro_suspend_retcon; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + /// This represents the llvm.coro.suspend instruction. 
-class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public IntrinsicInst { +class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public AnyCoroSuspendInst { enum { SaveArg, FinalArg }; public: @@ -273,6 +380,7 @@ public: assert(isa<ConstantTokenNone>(Arg)); return nullptr; } + bool isFinal() const { return cast<Constant>(getArgOperand(FinalArg))->isOneValue(); } @@ -286,6 +394,37 @@ public: } }; +inline CoroSaveInst *AnyCoroSuspendInst::getCoroSave() const { + if (auto Suspend = dyn_cast<CoroSuspendInst>(this)) + return Suspend->getCoroSave(); + return nullptr; +} + +/// This represents the llvm.coro.suspend.retcon instruction. +class LLVM_LIBRARY_VISIBILITY CoroSuspendRetconInst : public AnyCoroSuspendInst { +public: + op_iterator value_begin() { return arg_begin(); } + const_op_iterator value_begin() const { return arg_begin(); } + + op_iterator value_end() { return arg_end(); } + const_op_iterator value_end() const { return arg_end(); } + + iterator_range<op_iterator> value_operands() { + return make_range(value_begin(), value_end()); + } + iterator_range<const_op_iterator> value_operands() const { + return make_range(value_begin(), value_end()); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_suspend_retcon; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + /// This represents the llvm.coro.size instruction. class LLVM_LIBRARY_VISIBILITY CoroSizeInst : public IntrinsicInst { public: @@ -317,6 +456,60 @@ public: } }; +/// This represents the llvm.coro.alloca.alloc instruction. 
+class LLVM_LIBRARY_VISIBILITY CoroAllocaAllocInst : public IntrinsicInst { + enum { SizeArg, AlignArg }; +public: + Value *getSize() const { + return getArgOperand(SizeArg); + } + unsigned getAlignment() const { + return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue(); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_alloca_alloc; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.alloca.get instruction. +class LLVM_LIBRARY_VISIBILITY CoroAllocaGetInst : public IntrinsicInst { + enum { AllocArg }; +public: + CoroAllocaAllocInst *getAlloc() const { + return cast<CoroAllocaAllocInst>(getArgOperand(AllocArg)); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_alloca_get; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.alloca.free instruction. +class LLVM_LIBRARY_VISIBILITY CoroAllocaFreeInst : public IntrinsicInst { + enum { AllocArg }; +public: + CoroAllocaAllocInst *getAlloc() const { + return cast<CoroAllocaAllocInst>(getArgOperand(AllocArg)); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_alloca_free; + } + static bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + } // End namespace llvm. 
#endif diff --git a/lib/Transforms/Coroutines/CoroInternal.h b/lib/Transforms/Coroutines/CoroInternal.h index 441c8a20f1f3..c151474316f9 100644 --- a/lib/Transforms/Coroutines/CoroInternal.h +++ b/lib/Transforms/Coroutines/CoroInternal.h @@ -12,6 +12,7 @@ #define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H #include "CoroInstr.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/Transforms/Coroutines.h" namespace llvm { @@ -61,37 +62,174 @@ struct LowererBase { Value *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt); }; +enum class ABI { + /// The "resume-switch" lowering, where there are separate resume and + /// destroy functions that are shared between all suspend points. The + /// coroutine frame implicitly stores the resume and destroy functions, + /// the current index, and any promise value. + Switch, + + /// The "returned-continuation" lowering, where each suspend point creates a + /// single continuation function that is used for both resuming and + /// destroying. Does not support promises. + Retcon, + + /// The "unique returned-continuation" lowering, where each suspend point + /// creates a single continuation function that is used for both resuming + /// and destroying. Does not support promises. The function is known to + /// suspend at most once during its execution, and the return value of + /// the continuation is void. + RetconOnce, +}; + // Holds structural Coroutine Intrinsics for a particular function and other // values used during CoroSplit pass. struct LLVM_LIBRARY_VISIBILITY Shape { CoroBeginInst *CoroBegin; SmallVector<CoroEndInst *, 4> CoroEnds; SmallVector<CoroSizeInst *, 2> CoroSizes; - SmallVector<CoroSuspendInst *, 4> CoroSuspends; - - // Field Indexes for known coroutine frame fields. - enum { - ResumeField, - DestroyField, - PromiseField, - IndexField, + SmallVector<AnyCoroSuspendInst *, 4> CoroSuspends; + SmallVector<CallInst*, 2> SwiftErrorOps; + + // Field indexes for special fields in the switch lowering. 
+ struct SwitchFieldIndex { + enum { + Resume, + Destroy, + Promise, + Index, + /// The index of the first spill field. + FirstSpill + }; }; + coro::ABI ABI; + StructType *FrameTy; Instruction *FramePtr; BasicBlock *AllocaSpillBlock; - SwitchInst *ResumeSwitch; - AllocaInst *PromiseAlloca; - bool HasFinalSuspend; + + struct SwitchLoweringStorage { + SwitchInst *ResumeSwitch; + AllocaInst *PromiseAlloca; + BasicBlock *ResumeEntryBlock; + bool HasFinalSuspend; + }; + + struct RetconLoweringStorage { + Function *ResumePrototype; + Function *Alloc; + Function *Dealloc; + BasicBlock *ReturnBlock; + bool IsFrameInlineInStorage; + }; + + union { + SwitchLoweringStorage SwitchLowering; + RetconLoweringStorage RetconLowering; + }; + + CoroIdInst *getSwitchCoroId() const { + assert(ABI == coro::ABI::Switch); + return cast<CoroIdInst>(CoroBegin->getId()); + } + + AnyCoroIdRetconInst *getRetconCoroId() const { + assert(ABI == coro::ABI::Retcon || + ABI == coro::ABI::RetconOnce); + return cast<AnyCoroIdRetconInst>(CoroBegin->getId()); + } IntegerType *getIndexType() const { + assert(ABI == coro::ABI::Switch); assert(FrameTy && "frame type not assigned"); - return cast<IntegerType>(FrameTy->getElementType(IndexField)); + return cast<IntegerType>(FrameTy->getElementType(SwitchFieldIndex::Index)); } ConstantInt *getIndex(uint64_t Value) const { return ConstantInt::get(getIndexType(), Value); } + PointerType *getSwitchResumePointerType() const { + assert(ABI == coro::ABI::Switch); + assert(FrameTy && "frame type not assigned"); + return cast<PointerType>(FrameTy->getElementType(SwitchFieldIndex::Resume)); + } + + FunctionType *getResumeFunctionType() const { + switch (ABI) { + case coro::ABI::Switch: { + auto *FnPtrTy = getSwitchResumePointerType(); + return cast<FunctionType>(FnPtrTy->getPointerElementType()); + } + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + return RetconLowering.ResumePrototype->getFunctionType(); + } + llvm_unreachable("Unknown coro::ABI enum"); + } 
+ + ArrayRef<Type*> getRetconResultTypes() const { + assert(ABI == coro::ABI::Retcon || + ABI == coro::ABI::RetconOnce); + auto FTy = CoroBegin->getFunction()->getFunctionType(); + + // The safety of all this is checked by checkWFRetconPrototype. + if (auto STy = dyn_cast<StructType>(FTy->getReturnType())) { + return STy->elements().slice(1); + } else { + return ArrayRef<Type*>(); + } + } + + ArrayRef<Type*> getRetconResumeTypes() const { + assert(ABI == coro::ABI::Retcon || + ABI == coro::ABI::RetconOnce); + + // The safety of all this is checked by checkWFRetconPrototype. + auto FTy = RetconLowering.ResumePrototype->getFunctionType(); + return FTy->params().slice(1); + } + + CallingConv::ID getResumeFunctionCC() const { + switch (ABI) { + case coro::ABI::Switch: + return CallingConv::Fast; + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + return RetconLowering.ResumePrototype->getCallingConv(); + } + llvm_unreachable("Unknown coro::ABI enum"); + } + + unsigned getFirstSpillFieldIndex() const { + switch (ABI) { + case coro::ABI::Switch: + return SwitchFieldIndex::FirstSpill; + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + return 0; + } + llvm_unreachable("Unknown coro::ABI enum"); + } + + AllocaInst *getPromiseAlloca() const { + if (ABI == coro::ABI::Switch) + return SwitchLowering.PromiseAlloca; + return nullptr; + } + + /// Allocate memory according to the rules of the active lowering. + /// + /// \param CG - if non-null, will be updated for the new call + Value *emitAlloc(IRBuilder<> &Builder, Value *Size, CallGraph *CG) const; + + /// Deallocate memory according to the rules of the active lowering. 
+ /// + /// \param CG - if non-null, will be updated for the new call + void emitDealloc(IRBuilder<> &Builder, Value *Ptr, CallGraph *CG) const; + Shape() = default; explicit Shape(Function &F) { buildFrom(F); } void buildFrom(Function &F); diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp index 5458e70ff16a..04723cbde417 100644 --- a/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/lib/Transforms/Coroutines/CoroSplit.cpp @@ -55,6 +55,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -70,9 +71,197 @@ using namespace llvm; #define DEBUG_TYPE "coro-split" +namespace { + +/// A little helper class for building +class CoroCloner { +public: + enum class Kind { + /// The shared resume function for a switch lowering. + SwitchResume, + + /// The shared unwind function for a switch lowering. + SwitchUnwind, + + /// The shared cleanup function for a switch lowering. + SwitchCleanup, + + /// An individual continuation function. + Continuation, + }; +private: + Function &OrigF; + Function *NewF; + const Twine &Suffix; + coro::Shape &Shape; + Kind FKind; + ValueToValueMapTy VMap; + IRBuilder<> Builder; + Value *NewFramePtr = nullptr; + Value *SwiftErrorSlot = nullptr; + + /// The active suspend instruction; meaningful only for continuation ABIs. + AnyCoroSuspendInst *ActiveSuspend = nullptr; + +public: + /// Create a cloner for a switch lowering. + CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape, + Kind FKind) + : OrigF(OrigF), NewF(nullptr), Suffix(Suffix), Shape(Shape), + FKind(FKind), Builder(OrigF.getContext()) { + assert(Shape.ABI == coro::ABI::Switch); + } + + /// Create a cloner for a continuation lowering. 
+ CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape, + Function *NewF, AnyCoroSuspendInst *ActiveSuspend) + : OrigF(OrigF), NewF(NewF), Suffix(Suffix), Shape(Shape), + FKind(Kind::Continuation), Builder(OrigF.getContext()), + ActiveSuspend(ActiveSuspend) { + assert(Shape.ABI == coro::ABI::Retcon || + Shape.ABI == coro::ABI::RetconOnce); + assert(NewF && "need existing function for continuation"); + assert(ActiveSuspend && "need active suspend point for continuation"); + } + + Function *getFunction() const { + assert(NewF != nullptr && "declaration not yet set"); + return NewF; + } + + void create(); + +private: + bool isSwitchDestroyFunction() { + switch (FKind) { + case Kind::Continuation: + case Kind::SwitchResume: + return false; + case Kind::SwitchUnwind: + case Kind::SwitchCleanup: + return true; + } + llvm_unreachable("Unknown CoroCloner::Kind enum"); + } + + void createDeclaration(); + void replaceEntryBlock(); + Value *deriveNewFramePointer(); + void replaceRetconSuspendUses(); + void replaceCoroSuspends(); + void replaceCoroEnds(); + void replaceSwiftErrorOps(); + void handleFinalSuspend(); + void maybeFreeContinuationStorage(); +}; + +} // end anonymous namespace + +static void maybeFreeRetconStorage(IRBuilder<> &Builder, coro::Shape &Shape, + Value *FramePtr, CallGraph *CG) { + assert(Shape.ABI == coro::ABI::Retcon || + Shape.ABI == coro::ABI::RetconOnce); + if (Shape.RetconLowering.IsFrameInlineInStorage) + return; + + Shape.emitDealloc(Builder, FramePtr, CG); +} + +/// Replace a non-unwind call to llvm.coro.end. +static void replaceFallthroughCoroEnd(CoroEndInst *End, coro::Shape &Shape, + Value *FramePtr, bool InResume, + CallGraph *CG) { + // Start inserting right before the coro.end. + IRBuilder<> Builder(End); + + // Create the return instruction. + switch (Shape.ABI) { + // The cloned functions in switch-lowering always return void. 
+ case coro::ABI::Switch: + // coro.end doesn't immediately end the coroutine in the main function + // in this lowering, because we need to deallocate the coroutine. + if (!InResume) + return; + Builder.CreateRetVoid(); + break; + + // In unique continuation lowering, the continuations always return void. + // But we may have implicitly allocated storage. + case coro::ABI::RetconOnce: + maybeFreeRetconStorage(Builder, Shape, FramePtr, CG); + Builder.CreateRetVoid(); + break; + + // In non-unique continuation lowering, we signal completion by returning + // a null continuation. + case coro::ABI::Retcon: { + maybeFreeRetconStorage(Builder, Shape, FramePtr, CG); + auto RetTy = Shape.getResumeFunctionType()->getReturnType(); + auto RetStructTy = dyn_cast<StructType>(RetTy); + PointerType *ContinuationTy = + cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy); + + Value *ReturnValue = ConstantPointerNull::get(ContinuationTy); + if (RetStructTy) { + ReturnValue = Builder.CreateInsertValue(UndefValue::get(RetStructTy), + ReturnValue, 0); + } + Builder.CreateRet(ReturnValue); + break; + } + } + + // Remove the rest of the block, by splitting it into an unreachable block. + auto *BB = End->getParent(); + BB->splitBasicBlock(End); + BB->getTerminator()->eraseFromParent(); +} + +/// Replace an unwind call to llvm.coro.end. +static void replaceUnwindCoroEnd(CoroEndInst *End, coro::Shape &Shape, + Value *FramePtr, bool InResume, CallGraph *CG){ + IRBuilder<> Builder(End); + + switch (Shape.ABI) { + // In switch-lowering, this does nothing in the main function. + case coro::ABI::Switch: + if (!InResume) + return; + break; + + // In continuation-lowering, this frees the continuation storage. + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + maybeFreeRetconStorage(Builder, Shape, FramePtr, CG); + break; + } + + // If coro.end has an associated bundle, add cleanupret instruction. 
+ if (auto Bundle = End->getOperandBundle(LLVMContext::OB_funclet)) { + auto *FromPad = cast<CleanupPadInst>(Bundle->Inputs[0]); + auto *CleanupRet = Builder.CreateCleanupRet(FromPad, nullptr); + End->getParent()->splitBasicBlock(End); + CleanupRet->getParent()->getTerminator()->eraseFromParent(); + } +} + +static void replaceCoroEnd(CoroEndInst *End, coro::Shape &Shape, + Value *FramePtr, bool InResume, CallGraph *CG) { + if (End->isUnwind()) + replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG); + else + replaceFallthroughCoroEnd(End, Shape, FramePtr, InResume, CG); + + auto &Context = End->getContext(); + End->replaceAllUsesWith(InResume ? ConstantInt::getTrue(Context) + : ConstantInt::getFalse(Context)); + End->eraseFromParent(); +} + // Create an entry block for a resume function with a switch that will jump to // suspend points. -static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) { +static void createResumeEntryBlock(Function &F, coro::Shape &Shape) { + assert(Shape.ABI == coro::ABI::Switch); LLVMContext &C = F.getContext(); // resume.entry: @@ -91,15 +280,16 @@ static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) { IRBuilder<> Builder(NewEntry); auto *FramePtr = Shape.FramePtr; auto *FrameTy = Shape.FrameTy; - auto *GepIndex = Builder.CreateConstInBoundsGEP2_32( - FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr"); + auto *GepIndex = Builder.CreateStructGEP( + FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr"); auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index"); auto *Switch = Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size()); - Shape.ResumeSwitch = Switch; + Shape.SwitchLowering.ResumeSwitch = Switch; size_t SuspendIndex = 0; - for (CoroSuspendInst *S : Shape.CoroSuspends) { + for (auto *AnyS : Shape.CoroSuspends) { + auto *S = cast<CoroSuspendInst>(AnyS); ConstantInt *IndexVal = Shape.getIndex(SuspendIndex); // Replace CoroSave with a 
store to Index: @@ -109,14 +299,15 @@ static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) { Builder.SetInsertPoint(Save); if (S->isFinal()) { // Final suspend point is represented by storing zero in ResumeFnAddr. - auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, - 0, "ResumeFn.addr"); + auto *GepIndex = Builder.CreateStructGEP(FrameTy, FramePtr, + coro::Shape::SwitchFieldIndex::Resume, + "ResumeFn.addr"); auto *NullPtr = ConstantPointerNull::get(cast<PointerType>( cast<PointerType>(GepIndex->getType())->getElementType())); Builder.CreateStore(NullPtr, GepIndex); } else { - auto *GepIndex = Builder.CreateConstInBoundsGEP2_32( - FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr"); + auto *GepIndex = Builder.CreateStructGEP( + FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr"); Builder.CreateStore(IndexVal, GepIndex); } Save->replaceAllUsesWith(ConstantTokenNone::get(C)); @@ -164,48 +355,9 @@ static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) { Builder.SetInsertPoint(UnreachBB); Builder.CreateUnreachable(); - return NewEntry; + Shape.SwitchLowering.ResumeEntryBlock = NewEntry; } -// In Resumers, we replace fallthrough coro.end with ret void and delete the -// rest of the block. -static void replaceFallthroughCoroEnd(IntrinsicInst *End, - ValueToValueMapTy &VMap) { - auto *NewE = cast<IntrinsicInst>(VMap[End]); - ReturnInst::Create(NewE->getContext(), nullptr, NewE); - - // Remove the rest of the block, by splitting it into an unreachable block. - auto *BB = NewE->getParent(); - BB->splitBasicBlock(NewE); - BB->getTerminator()->eraseFromParent(); -} - -// In Resumers, we replace unwind coro.end with True to force the immediate -// unwind to caller. 
-static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) { - if (Shape.CoroEnds.empty()) - return; - - LLVMContext &Context = Shape.CoroEnds.front()->getContext(); - auto *True = ConstantInt::getTrue(Context); - for (CoroEndInst *CE : Shape.CoroEnds) { - if (!CE->isUnwind()) - continue; - - auto *NewCE = cast<IntrinsicInst>(VMap[CE]); - - // If coro.end has an associated bundle, add cleanupret instruction. - if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) { - Value *FromPad = Bundle->Inputs[0]; - auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE); - NewCE->getParent()->splitBasicBlock(NewCE); - CleanupRet->getParent()->getTerminator()->eraseFromParent(); - } - - NewCE->replaceAllUsesWith(True); - NewCE->eraseFromParent(); - } -} // Rewrite final suspend point handling. We do not use suspend index to // represent the final suspend point. Instead we zero-out ResumeFnAddr in the @@ -216,83 +368,364 @@ static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) { // In the destroy function, we add a code sequence to check if ResumeFnAddress // is Null, and if so, jump to the appropriate label to handle cleanup from the // final suspend point. 
-static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr, - coro::Shape &Shape, SwitchInst *Switch, - bool IsDestroy) { - assert(Shape.HasFinalSuspend); +void CoroCloner::handleFinalSuspend() { + assert(Shape.ABI == coro::ABI::Switch && + Shape.SwitchLowering.HasFinalSuspend); + auto *Switch = cast<SwitchInst>(VMap[Shape.SwitchLowering.ResumeSwitch]); auto FinalCaseIt = std::prev(Switch->case_end()); BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor(); Switch->removeCase(FinalCaseIt); - if (IsDestroy) { + if (isSwitchDestroyFunction()) { BasicBlock *OldSwitchBB = Switch->getParent(); auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch"); Builder.SetInsertPoint(OldSwitchBB->getTerminator()); - auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr, - 0, 0, "ResumeFn.addr"); - auto *Load = Builder.CreateLoad( - Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex); - auto *NullPtr = - ConstantPointerNull::get(cast<PointerType>(Load->getType())); - auto *Cond = Builder.CreateICmpEQ(Load, NullPtr); + auto *GepIndex = Builder.CreateStructGEP(Shape.FrameTy, NewFramePtr, + coro::Shape::SwitchFieldIndex::Resume, + "ResumeFn.addr"); + auto *Load = Builder.CreateLoad(Shape.getSwitchResumePointerType(), + GepIndex); + auto *Cond = Builder.CreateIsNull(Load); Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB); OldSwitchBB->getTerminator()->eraseFromParent(); } } -// Create a resume clone by cloning the body of the original function, setting -// new entry block and replacing coro.suspend an appropriate value to force -// resume or cleanup pass for every suspend point. 
-static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, - BasicBlock *ResumeEntry, int8_t FnIndex) { - Module *M = F.getParent(); - auto *FrameTy = Shape.FrameTy; - auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0)); - auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType()); +static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape, + const Twine &Suffix, + Module::iterator InsertBefore) { + Module *M = OrigF.getParent(); + auto *FnTy = Shape.getResumeFunctionType(); Function *NewF = - Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage, - F.getName() + Suffix, M); + Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage, + OrigF.getName() + Suffix); NewF->addParamAttr(0, Attribute::NonNull); NewF->addParamAttr(0, Attribute::NoAlias); - ValueToValueMapTy VMap; + M->getFunctionList().insert(InsertBefore, NewF); + + return NewF; +} + +/// Replace uses of the active llvm.coro.suspend.retcon call with the +/// arguments to the continuation function. +/// +/// This assumes that the builder has a meaningful insertion point. +void CoroCloner::replaceRetconSuspendUses() { + assert(Shape.ABI == coro::ABI::Retcon || + Shape.ABI == coro::ABI::RetconOnce); + + auto NewS = VMap[ActiveSuspend]; + if (NewS->use_empty()) return; + + // Copy out all the continuation arguments after the buffer pointer into + // an easily-indexed data structure for convenience. + SmallVector<Value*, 8> Args; + for (auto I = std::next(NewF->arg_begin()), E = NewF->arg_end(); I != E; ++I) + Args.push_back(&*I); + + // If the suspend returns a single scalar value, we can just do a simple + // replacement. + if (!isa<StructType>(NewS->getType())) { + assert(Args.size() == 1); + NewS->replaceAllUsesWith(Args.front()); + return; + } + + // Try to peephole extracts of an aggregate return. 
+ for (auto UI = NewS->use_begin(), UE = NewS->use_end(); UI != UE; ) { + auto EVI = dyn_cast<ExtractValueInst>((UI++)->getUser()); + if (!EVI || EVI->getNumIndices() != 1) + continue; + + EVI->replaceAllUsesWith(Args[EVI->getIndices().front()]); + EVI->eraseFromParent(); + } + + // If we have no remaining uses, we're done. + if (NewS->use_empty()) return; + + // Otherwise, we need to create an aggregate. + Value *Agg = UndefValue::get(NewS->getType()); + for (size_t I = 0, E = Args.size(); I != E; ++I) + Agg = Builder.CreateInsertValue(Agg, Args[I], I); + + NewS->replaceAllUsesWith(Agg); +} + +void CoroCloner::replaceCoroSuspends() { + Value *SuspendResult; + + switch (Shape.ABI) { + // In switch lowering, replace coro.suspend with the appropriate value + // for the type of function we're extracting. + // Replacing coro.suspend with (0) will result in control flow proceeding to + // a resume label associated with a suspend point, replacing it with (1) will + // result in control flow proceeding to a cleanup label associated with this + // suspend point. + case coro::ABI::Switch: + SuspendResult = Builder.getInt8(isSwitchDestroyFunction() ? 1 : 0); + break; + + // In returned-continuation lowering, the arguments from earlier + // continuations are theoretically arbitrary, and they should have been + // spilled. + case coro::ABI::RetconOnce: + case coro::ABI::Retcon: + return; + } + + for (AnyCoroSuspendInst *CS : Shape.CoroSuspends) { + // The active suspend was handled earlier. + if (CS == ActiveSuspend) continue; + + auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[CS]); + MappedCS->replaceAllUsesWith(SuspendResult); + MappedCS->eraseFromParent(); + } +} + +void CoroCloner::replaceCoroEnds() { + for (CoroEndInst *CE : Shape.CoroEnds) { + // We use a null call graph because there's no call graph node for + // the cloned function yet. We'll just be rebuilding that later. 
+ auto NewCE = cast<CoroEndInst>(VMap[CE]); + replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr); + } +} + +static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape, + ValueToValueMapTy *VMap) { + Value *CachedSlot = nullptr; + auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * { + if (CachedSlot) { + assert(CachedSlot->getType()->getPointerElementType() == ValueTy && + "multiple swifterror slots in function with different types"); + return CachedSlot; + } + + // Check if the function has a swifterror argument. + for (auto &Arg : F.args()) { + if (Arg.isSwiftError()) { + CachedSlot = &Arg; + assert(Arg.getType()->getPointerElementType() == ValueTy && + "swifterror argument does not have expected type"); + return &Arg; + } + } + + // Create a swifterror alloca. + IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg()); + auto Alloca = Builder.CreateAlloca(ValueTy); + Alloca->setSwiftError(true); + + CachedSlot = Alloca; + return Alloca; + }; + + for (CallInst *Op : Shape.SwiftErrorOps) { + auto MappedOp = VMap ? cast<CallInst>((*VMap)[Op]) : Op; + IRBuilder<> Builder(MappedOp); + + // If there are no arguments, this is a 'get' operation. + Value *MappedResult; + if (Op->getNumArgOperands() == 0) { + auto ValueTy = Op->getType(); + auto Slot = getSwiftErrorSlot(ValueTy); + MappedResult = Builder.CreateLoad(ValueTy, Slot); + } else { + assert(Op->getNumArgOperands() == 1); + auto Value = MappedOp->getArgOperand(0); + auto ValueTy = Value->getType(); + auto Slot = getSwiftErrorSlot(ValueTy); + Builder.CreateStore(Value, Slot); + MappedResult = Slot; + } + + MappedOp->replaceAllUsesWith(MappedResult); + MappedOp->eraseFromParent(); + } + + // If we're updating the original function, we've invalidated SwiftErrorOps. 
+ if (VMap == nullptr) { + Shape.SwiftErrorOps.clear(); + } +} + +void CoroCloner::replaceSwiftErrorOps() { + ::replaceSwiftErrorOps(*NewF, Shape, &VMap); +} + +void CoroCloner::replaceEntryBlock() { + // In the original function, the AllocaSpillBlock is a block immediately + // following the allocation of the frame object which defines GEPs for + // all the allocas that have been moved into the frame, and it ends by + // branching to the original beginning of the coroutine. Make this + // the entry block of the cloned function. + auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]); + Entry->setName("entry" + Suffix); + Entry->moveBefore(&NewF->getEntryBlock()); + Entry->getTerminator()->eraseFromParent(); + + // Clear all predecessors of the new entry block. There should be + // exactly one predecessor, which we created when splitting out + // AllocaSpillBlock to begin with. + assert(Entry->hasOneUse()); + auto BranchToEntry = cast<BranchInst>(Entry->user_back()); + assert(BranchToEntry->isUnconditional()); + Builder.SetInsertPoint(BranchToEntry); + Builder.CreateUnreachable(); + BranchToEntry->eraseFromParent(); + + // TODO: move any allocas into Entry that weren't moved into the frame. + // (Currently we move all allocas into the frame.) + + // Branch from the entry to the appropriate place. + Builder.SetInsertPoint(Entry); + switch (Shape.ABI) { + case coro::ABI::Switch: { + // In switch-lowering, we built a resume-entry block in the original + // function. Make the entry block branch to this. + auto *SwitchBB = + cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]); + Builder.CreateBr(SwitchBB); + break; + } + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: { + // In continuation ABIs, we want to branch to immediately after the + // active suspend point. Earlier phases will have put the suspend in its + // own basic block, so just thread our jump directly to its successor. 
+ auto MappedCS = cast<CoroSuspendRetconInst>(VMap[ActiveSuspend]); + auto Branch = cast<BranchInst>(MappedCS->getNextNode()); + assert(Branch->isUnconditional()); + Builder.CreateBr(Branch->getSuccessor(0)); + break; + } + } +} + +/// Derive the value of the new frame pointer. +Value *CoroCloner::deriveNewFramePointer() { + // Builder should be inserting to the front of the new entry block. + + switch (Shape.ABI) { + // In switch-lowering, the argument is the frame pointer. + case coro::ABI::Switch: + return &*NewF->arg_begin(); + + // In continuation-lowering, the argument is the opaque storage. + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: { + Argument *NewStorage = &*NewF->arg_begin(); + auto FramePtrTy = Shape.FrameTy->getPointerTo(); + + // If the storage is inline, just bitcast to the storage to the frame type. + if (Shape.RetconLowering.IsFrameInlineInStorage) + return Builder.CreateBitCast(NewStorage, FramePtrTy); + + // Otherwise, load the real frame from the opaque storage. + auto FramePtrPtr = + Builder.CreateBitCast(NewStorage, FramePtrTy->getPointerTo()); + return Builder.CreateLoad(FramePtrPtr); + } + } + llvm_unreachable("bad ABI"); +} + +/// Clone the body of the original function into a resume function of +/// some sort. +void CoroCloner::create() { + // Create the new function if we don't already have one. + if (!NewF) { + NewF = createCloneDeclaration(OrigF, Shape, Suffix, + OrigF.getParent()->end()); + } + // Replace all args with undefs. The buildCoroutineFrame algorithm already // rewritten access to the args that occurs after suspend points with loads // and stores to/from the coroutine frame. 
- for (Argument &A : F.args()) + for (Argument &A : OrigF.args()) VMap[&A] = UndefValue::get(A.getType()); SmallVector<ReturnInst *, 4> Returns; - CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns); - NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage); + // Ignore attempts to change certain attributes of the function. + // TODO: maybe there should be a way to suppress this during cloning? + auto savedVisibility = NewF->getVisibility(); + auto savedUnnamedAddr = NewF->getUnnamedAddr(); + auto savedDLLStorageClass = NewF->getDLLStorageClass(); + + // NewF's linkage (which CloneFunctionInto does *not* change) might not + // be compatible with the visibility of OrigF (which it *does* change), + // so protect against that. + auto savedLinkage = NewF->getLinkage(); + NewF->setLinkage(llvm::GlobalValue::ExternalLinkage); + + CloneFunctionInto(NewF, &OrigF, VMap, /*ModuleLevelChanges=*/true, Returns); + + NewF->setLinkage(savedLinkage); + NewF->setVisibility(savedVisibility); + NewF->setUnnamedAddr(savedUnnamedAddr); + NewF->setDLLStorageClass(savedDLLStorageClass); + + auto &Context = NewF->getContext(); + + // Replace the attributes of the new function: + auto OrigAttrs = NewF->getAttributes(); + auto NewAttrs = AttributeList(); + + switch (Shape.ABI) { + case coro::ABI::Switch: + // Bootstrap attributes by copying function attributes from the + // original function. This should include optimization settings and so on. + NewAttrs = NewAttrs.addAttributes(Context, AttributeList::FunctionIndex, + OrigAttrs.getFnAttributes()); + break; + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + // If we have a continuation prototype, just use its attributes, + // full-stop. + NewAttrs = Shape.RetconLowering.ResumePrototype->getAttributes(); + break; + } - // Remove old returns. - for (ReturnInst *Return : Returns) - changeToUnreachable(Return, /*UseLLVMTrap=*/false); + // Make the frame parameter nonnull and noalias. 
+ NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NonNull); + NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NoAlias); + + switch (Shape.ABI) { + // In these ABIs, the cloned functions always return 'void', and the + // existing return sites are meaningless. Note that for unique + // continuations, this includes the returns associated with suspends; + // this is fine because we can't suspend twice. + case coro::ABI::Switch: + case coro::ABI::RetconOnce: + // Remove old returns. + for (ReturnInst *Return : Returns) + changeToUnreachable(Return, /*UseLLVMTrap=*/false); + break; + + // With multi-suspend continuations, we'll already have eliminated the + // original returns and inserted returns before all the suspend points, + // so we want to leave any returns in place. + case coro::ABI::Retcon: + break; + } - // Remove old return attributes. - NewF->removeAttributes( - AttributeList::ReturnIndex, - AttributeFuncs::typeIncompatible(NewF->getReturnType())); + NewF->setAttributes(NewAttrs); + NewF->setCallingConv(Shape.getResumeFunctionCC()); - // Make AllocaSpillBlock the new entry block. - auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]); - auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]); - Entry->moveBefore(&NewF->getEntryBlock()); - Entry->getTerminator()->eraseFromParent(); - BranchInst::Create(SwitchBB, Entry); - Entry->setName("entry" + Suffix); + // Set up the new entry block. + replaceEntryBlock(); - // Clear all predecessors of the new entry block. - auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]); - Entry->replaceAllUsesWith(Switch->getDefaultDest()); - - IRBuilder<> Builder(&NewF->getEntryBlock().front()); + Builder.SetInsertPoint(&NewF->getEntryBlock().front()); + NewFramePtr = deriveNewFramePointer(); // Remap frame pointer. 
- Argument *NewFramePtr = &*NewF->arg_begin(); - Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]); + Value *OldFramePtr = VMap[Shape.FramePtr]; NewFramePtr->takeName(OldFramePtr); OldFramePtr->replaceAllUsesWith(NewFramePtr); @@ -302,50 +735,55 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]); OldVFrame->replaceAllUsesWith(NewVFrame); - // Rewrite final suspend handling as it is not done via switch (allows to - // remove final case from the switch, since it is undefined behavior to resume - // the coroutine suspended at the final suspend point. - if (Shape.HasFinalSuspend) { - auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]); - bool IsDestroy = FnIndex != 0; - handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy); + switch (Shape.ABI) { + case coro::ABI::Switch: + // Rewrite final suspend handling as it is not done via switch (allows to + // remove final case from the switch, since it is undefined behavior to + // resume the coroutine suspended at the final suspend point. + if (Shape.SwitchLowering.HasFinalSuspend) + handleFinalSuspend(); + break; + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + // Replace uses of the active suspend with the corresponding + // continuation-function arguments. + assert(ActiveSuspend != nullptr && + "no active suspend when lowering a continuation-style coroutine"); + replaceRetconSuspendUses(); + break; } - // Replace coro suspend with the appropriate resume index. - // Replacing coro.suspend with (0) will result in control flow proceeding to - // a resume label associated with a suspend point, replacing it with (1) will - // result in control flow proceeding to a cleanup label associated with this - // suspend point. - auto *NewValue = Builder.getInt8(FnIndex ? 
1 : 0); - for (CoroSuspendInst *CS : Shape.CoroSuspends) { - auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]); - MappedCS->replaceAllUsesWith(NewValue); - MappedCS->eraseFromParent(); - } + // Handle suspends. + replaceCoroSuspends(); + + // Handle swifterror. + replaceSwiftErrorOps(); // Remove coro.end intrinsics. - replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap); - replaceUnwindCoroEnds(Shape, VMap); + replaceCoroEnds(); + // Eliminate coro.free from the clones, replacing it with 'null' in cleanup, // to suppress deallocation code. - coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]), - /*Elide=*/FnIndex == 2); - - NewF->setCallingConv(CallingConv::Fast); - - return NewF; + if (Shape.ABI == coro::ABI::Switch) + coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]), + /*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup); } -static void removeCoroEnds(coro::Shape &Shape) { - if (Shape.CoroEnds.empty()) - return; - - LLVMContext &Context = Shape.CoroEnds.front()->getContext(); - auto *False = ConstantInt::getFalse(Context); +// Create a resume clone by cloning the body of the original function, setting +// new entry block and replacing coro.suspend an appropriate value to force +// resume or cleanup pass for every suspend point. +static Function *createClone(Function &F, const Twine &Suffix, + coro::Shape &Shape, CoroCloner::Kind FKind) { + CoroCloner Cloner(F, Suffix, Shape, FKind); + Cloner.create(); + return Cloner.getFunction(); +} - for (CoroEndInst *CE : Shape.CoroEnds) { - CE->replaceAllUsesWith(False); - CE->eraseFromParent(); +/// Remove calls to llvm.coro.end in the original function. 
+static void removeCoroEnds(coro::Shape &Shape, CallGraph *CG) { + for (auto End : Shape.CoroEnds) { + replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, CG); } } @@ -377,8 +815,12 @@ static void replaceFrameSize(coro::Shape &Shape) { // i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*)) // // Assumes that all the functions have the same signature. -static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin, - std::initializer_list<Function *> Fns) { +static void setCoroInfo(Function &F, coro::Shape &Shape, + ArrayRef<Function *> Fns) { + // This only works under the switch-lowering ABI because coro elision + // only works on the switch-lowering ABI. + assert(Shape.ABI == coro::ABI::Switch); + SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end()); assert(!Args.empty()); Function *Part = *Fns.begin(); @@ -393,38 +835,45 @@ static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin, // Update coro.begin instruction to refer to this constant. LLVMContext &C = F.getContext(); auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C)); - CoroBegin->getId()->setInfo(BC); + Shape.getSwitchCoroId()->setInfo(BC); } // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame. 
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, Function *DestroyFn, Function *CleanupFn) { + assert(Shape.ABI == coro::ABI::Switch); + IRBuilder<> Builder(Shape.FramePtr->getNextNode()); - auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32( - Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField, + auto *ResumeAddr = Builder.CreateStructGEP( + Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume, "resume.addr"); Builder.CreateStore(ResumeFn, ResumeAddr); Value *DestroyOrCleanupFn = DestroyFn; - CoroIdInst *CoroId = Shape.CoroBegin->getId(); + CoroIdInst *CoroId = Shape.getSwitchCoroId(); if (CoroAllocInst *CA = CoroId->getCoroAlloc()) { // If there is a CoroAlloc and it returns false (meaning we elide the // allocation, use CleanupFn instead of DestroyFn). DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn); } - auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32( - Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField, + auto *DestroyAddr = Builder.CreateStructGEP( + Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy, "destroy.addr"); Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr); } static void postSplitCleanup(Function &F) { removeUnreachableBlocks(F); + + // For now, we do a mandatory verification step because we don't + // entirely trust this pass. Note that we don't want to add a verifier + // pass to FPM below because it will also verify all the global data. + verifyFunction(F); + legacy::FunctionPassManager FPM(F.getParent()); - FPM.add(createVerifierPass()); FPM.add(createSCCPPass()); FPM.add(createCFGSimplificationPass()); FPM.add(createEarlyCSEPass()); @@ -520,21 +969,34 @@ static void addMustTailToCoroResumes(Function &F) { // Coroutine has no suspend points. Remove heap allocation for the coroutine // frame if possible. 
-static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) { +static void handleNoSuspendCoroutine(coro::Shape &Shape) { + auto *CoroBegin = Shape.CoroBegin; auto *CoroId = CoroBegin->getId(); auto *AllocInst = CoroId->getCoroAlloc(); - coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr); - if (AllocInst) { - IRBuilder<> Builder(AllocInst); - // FIXME: Need to handle overaligned members. - auto *Frame = Builder.CreateAlloca(FrameTy); - auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy()); - AllocInst->replaceAllUsesWith(Builder.getFalse()); - AllocInst->eraseFromParent(); - CoroBegin->replaceAllUsesWith(VFrame); - } else { - CoroBegin->replaceAllUsesWith(CoroBegin->getMem()); + switch (Shape.ABI) { + case coro::ABI::Switch: { + auto SwitchId = cast<CoroIdInst>(CoroId); + coro::replaceCoroFree(SwitchId, /*Elide=*/AllocInst != nullptr); + if (AllocInst) { + IRBuilder<> Builder(AllocInst); + // FIXME: Need to handle overaligned members. + auto *Frame = Builder.CreateAlloca(Shape.FrameTy); + auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy()); + AllocInst->replaceAllUsesWith(Builder.getFalse()); + AllocInst->eraseFromParent(); + CoroBegin->replaceAllUsesWith(VFrame); + } else { + CoroBegin->replaceAllUsesWith(CoroBegin->getMem()); + } + break; + } + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + CoroBegin->replaceAllUsesWith(UndefValue::get(CoroBegin->getType())); + break; } + CoroBegin->eraseFromParent(); } @@ -670,12 +1132,16 @@ static bool simplifySuspendPoint(CoroSuspendInst *Suspend, // Remove suspend points that are simplified. static void simplifySuspendPoints(coro::Shape &Shape) { + // Currently, the only simplification we do is switch-lowering-specific. 
+ if (Shape.ABI != coro::ABI::Switch) + return; + auto &S = Shape.CoroSuspends; size_t I = 0, N = S.size(); if (N == 0) return; while (true) { - if (simplifySuspendPoint(S[I], Shape.CoroBegin)) { + if (simplifySuspendPoint(cast<CoroSuspendInst>(S[I]), Shape.CoroBegin)) { if (--N == I) break; std::swap(S[I], S[N]); @@ -687,142 +1153,227 @@ static void simplifySuspendPoints(coro::Shape &Shape) { S.resize(N); } -static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) { - // Collect all blocks that we need to look for instructions to relocate. - SmallPtrSet<BasicBlock *, 4> RelocBlocks; - SmallVector<BasicBlock *, 4> Work; - Work.push_back(CB->getParent()); +static void splitSwitchCoroutine(Function &F, coro::Shape &Shape, + SmallVectorImpl<Function *> &Clones) { + assert(Shape.ABI == coro::ABI::Switch); - do { - BasicBlock *Current = Work.pop_back_val(); - for (BasicBlock *BB : predecessors(Current)) - if (RelocBlocks.count(BB) == 0) { - RelocBlocks.insert(BB); - Work.push_back(BB); - } - } while (!Work.empty()); - return RelocBlocks; -} - -static SmallPtrSet<Instruction *, 8> -getNotRelocatableInstructions(CoroBeginInst *CoroBegin, - SmallPtrSetImpl<BasicBlock *> &RelocBlocks) { - SmallPtrSet<Instruction *, 8> DoNotRelocate; - // Collect all instructions that we should not relocate - SmallVector<Instruction *, 8> Work; - - // Start with CoroBegin and terminators of all preceding blocks. - Work.push_back(CoroBegin); - BasicBlock *CoroBeginBB = CoroBegin->getParent(); - for (BasicBlock *BB : RelocBlocks) - if (BB != CoroBeginBB) - Work.push_back(BB->getTerminator()); - - // For every instruction in the Work list, place its operands in DoNotRelocate - // set. 
- do { - Instruction *Current = Work.pop_back_val(); - LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n"); - DoNotRelocate.insert(Current); - for (Value *U : Current->operands()) { - auto *I = dyn_cast<Instruction>(U); - if (!I) - continue; + createResumeEntryBlock(F, Shape); + auto ResumeClone = createClone(F, ".resume", Shape, + CoroCloner::Kind::SwitchResume); + auto DestroyClone = createClone(F, ".destroy", Shape, + CoroCloner::Kind::SwitchUnwind); + auto CleanupClone = createClone(F, ".cleanup", Shape, + CoroCloner::Kind::SwitchCleanup); - if (auto *A = dyn_cast<AllocaInst>(I)) { - // Stores to alloca instructions that occur before the coroutine frame - // is allocated should not be moved; the stored values may be used by - // the coroutine frame allocator. The operands to those stores must also - // remain in place. - for (const auto &User : A->users()) - if (auto *SI = dyn_cast<llvm::StoreInst>(User)) - if (RelocBlocks.count(SI->getParent()) != 0 && - DoNotRelocate.count(SI) == 0) { - Work.push_back(SI); - DoNotRelocate.insert(SI); - } - continue; - } + postSplitCleanup(*ResumeClone); + postSplitCleanup(*DestroyClone); + postSplitCleanup(*CleanupClone); + + addMustTailToCoroResumes(*ResumeClone); + + // Store addresses resume/destroy/cleanup functions in the coroutine frame. + updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); + + assert(Clones.empty()); + Clones.push_back(ResumeClone); + Clones.push_back(DestroyClone); + Clones.push_back(CleanupClone); + + // Create a constant array referring to resume/destroy/clone functions pointed + // by the last argument of @llvm.coro.info, so that CoroElide pass can + // determined correct function to call. 
+ setCoroInfo(F, Shape, Clones); +} - if (DoNotRelocate.count(I) == 0) { - Work.push_back(I); - DoNotRelocate.insert(I); +static void splitRetconCoroutine(Function &F, coro::Shape &Shape, + SmallVectorImpl<Function *> &Clones) { + assert(Shape.ABI == coro::ABI::Retcon || + Shape.ABI == coro::ABI::RetconOnce); + assert(Clones.empty()); + + // Reset various things that the optimizer might have decided it + // "knows" about the coroutine function due to not seeing a return. + F.removeFnAttr(Attribute::NoReturn); + F.removeAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); + F.removeAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + + // Allocate the frame. + auto *Id = cast<AnyCoroIdRetconInst>(Shape.CoroBegin->getId()); + Value *RawFramePtr; + if (Shape.RetconLowering.IsFrameInlineInStorage) { + RawFramePtr = Id->getStorage(); + } else { + IRBuilder<> Builder(Id); + + // Determine the size of the frame. + const DataLayout &DL = F.getParent()->getDataLayout(); + auto Size = DL.getTypeAllocSize(Shape.FrameTy); + + // Allocate. We don't need to update the call graph node because we're + // going to recompute it from scratch after splitting. + RawFramePtr = Shape.emitAlloc(Builder, Builder.getInt64(Size), nullptr); + RawFramePtr = + Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType()); + + // Stash the allocated frame pointer in the continuation storage. + auto Dest = Builder.CreateBitCast(Id->getStorage(), + RawFramePtr->getType()->getPointerTo()); + Builder.CreateStore(RawFramePtr, Dest); + } + + // Map all uses of llvm.coro.begin to the allocated frame pointer. + { + // Make sure we don't invalidate Shape.FramePtr. + TrackingVH<Instruction> Handle(Shape.FramePtr); + Shape.CoroBegin->replaceAllUsesWith(RawFramePtr); + Shape.FramePtr = Handle.getValPtr(); + } + + // Create a unique return block. + BasicBlock *ReturnBB = nullptr; + SmallVector<PHINode *, 4> ReturnPHIs; + + // Create all the functions in order after the main function. 
+ auto NextF = std::next(F.getIterator()); + + // Create a continuation function for each of the suspend points. + Clones.reserve(Shape.CoroSuspends.size()); + for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) { + auto Suspend = cast<CoroSuspendRetconInst>(Shape.CoroSuspends[i]); + + // Create the clone declaration. + auto Continuation = + createCloneDeclaration(F, Shape, ".resume." + Twine(i), NextF); + Clones.push_back(Continuation); + + // Insert a branch to the unified return block immediately before + // the suspend point. + auto SuspendBB = Suspend->getParent(); + auto NewSuspendBB = SuspendBB->splitBasicBlock(Suspend); + auto Branch = cast<BranchInst>(SuspendBB->getTerminator()); + + // Create the unified return block. + if (!ReturnBB) { + // Place it before the first suspend. + ReturnBB = BasicBlock::Create(F.getContext(), "coro.return", &F, + NewSuspendBB); + Shape.RetconLowering.ReturnBlock = ReturnBB; + + IRBuilder<> Builder(ReturnBB); + + // Create PHIs for all the return values. + assert(ReturnPHIs.empty()); + + // First, the continuation. + ReturnPHIs.push_back(Builder.CreatePHI(Continuation->getType(), + Shape.CoroSuspends.size())); + + // Next, all the directly-yielded values. + for (auto ResultTy : Shape.getRetconResultTypes()) + ReturnPHIs.push_back(Builder.CreatePHI(ResultTy, + Shape.CoroSuspends.size())); + + // Build the return value. + auto RetTy = F.getReturnType(); + + // Cast the continuation value if necessary. + // We can't rely on the types matching up because that type would + // have to be infinite. + auto CastedContinuationTy = + (ReturnPHIs.size() == 1 ? 
RetTy : RetTy->getStructElementType(0)); + auto *CastedContinuation = + Builder.CreateBitCast(ReturnPHIs[0], CastedContinuationTy); + + Value *RetV; + if (ReturnPHIs.size() == 1) { + RetV = CastedContinuation; + } else { + RetV = UndefValue::get(RetTy); + RetV = Builder.CreateInsertValue(RetV, CastedContinuation, 0); + for (size_t I = 1, E = ReturnPHIs.size(); I != E; ++I) + RetV = Builder.CreateInsertValue(RetV, ReturnPHIs[I], I); } + + Builder.CreateRet(RetV); } - } while (!Work.empty()); - return DoNotRelocate; -} -static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) { - // Analyze which non-alloca instructions are needed for allocation and - // relocate the rest to after coro.begin. We need to do it, since some of the - // targets of those instructions may be placed into coroutine frame memory - // for which becomes available after coro.begin intrinsic. + // Branch to the return block. + Branch->setSuccessor(0, ReturnBB); + ReturnPHIs[0]->addIncoming(Continuation, SuspendBB); + size_t NextPHIIndex = 1; + for (auto &VUse : Suspend->value_operands()) + ReturnPHIs[NextPHIIndex++]->addIncoming(&*VUse, SuspendBB); + assert(NextPHIIndex == ReturnPHIs.size()); + } - auto BlockSet = getCoroBeginPredBlocks(CoroBegin); - auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet); + assert(Clones.size() == Shape.CoroSuspends.size()); + for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) { + auto Suspend = Shape.CoroSuspends[i]; + auto Clone = Clones[i]; - Instruction *InsertPt = CoroBegin->getNextNode(); - BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well. - for (auto B = BB.begin(), E = BB.end(); B != E;) { - Instruction &I = *B++; - if (isa<AllocaInst>(&I)) - continue; - if (&I == CoroBegin) - break; - if (DoNotRelocateSet.count(&I)) - continue; - I.moveBefore(InsertPt); + CoroCloner(F, "resume." 
+ Twine(i), Shape, Clone, Suspend).create(); + } +} + +namespace { + class PrettyStackTraceFunction : public PrettyStackTraceEntry { + Function &F; + public: + PrettyStackTraceFunction(Function &F) : F(F) {} + void print(raw_ostream &OS) const override { + OS << "While splitting coroutine "; + F.printAsOperand(OS, /*print type*/ false, F.getParent()); + OS << "\n"; + } + }; +} + +static void splitCoroutine(Function &F, coro::Shape &Shape, + SmallVectorImpl<Function *> &Clones) { + switch (Shape.ABI) { + case coro::ABI::Switch: + return splitSwitchCoroutine(F, Shape, Clones); + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: + return splitRetconCoroutine(F, Shape, Clones); } + llvm_unreachable("bad ABI kind"); } static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) { - EliminateUnreachableBlocks(F); + PrettyStackTraceFunction prettyStackTrace(F); + + // The suspend-crossing algorithm in buildCoroutineFrame get tripped + // up by uses in unreachable blocks, so remove them as a first pass. + removeUnreachableBlocks(F); coro::Shape Shape(F); if (!Shape.CoroBegin) return; simplifySuspendPoints(Shape); - relocateInstructionBefore(Shape.CoroBegin, F); buildCoroutineFrame(F, Shape); replaceFrameSize(Shape); + SmallVector<Function*, 4> Clones; + // If there are no suspend points, no split required, just remove // the allocation and deallocation blocks, they are not needed. 
if (Shape.CoroSuspends.empty()) { - handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy); - removeCoroEnds(Shape); - postSplitCleanup(F); - coro::updateCallGraph(F, {}, CG, SCC); - return; + handleNoSuspendCoroutine(Shape); + } else { + splitCoroutine(F, Shape, Clones); } - auto *ResumeEntry = createResumeEntryBlock(F, Shape); - auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0); - auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1); - auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2); - - // We no longer need coro.end in F. - removeCoroEnds(Shape); + // Replace all the swifterror operations in the original function. + // This invalidates SwiftErrorOps in the Shape. + replaceSwiftErrorOps(F, Shape, nullptr); + removeCoroEnds(Shape, &CG); postSplitCleanup(F); - postSplitCleanup(*ResumeClone); - postSplitCleanup(*DestroyClone); - postSplitCleanup(*CleanupClone); - - addMustTailToCoroResumes(*ResumeClone); - - // Store addresses resume/destroy/cleanup functions in the coroutine frame. - updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); - - // Create a constant array referring to resume/destroy/clone functions pointed - // by the last argument of @llvm.coro.info, so that CoroElide pass can - // determined correct function to call. - setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone}); // Update call graph and add the functions we created to the SCC. - coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC); + coro::updateCallGraph(F, Clones, CG, SCC); } // When we see the coroutine the first time, we insert an indirect call to a @@ -881,6 +1432,80 @@ static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) { SCC.initialize(Nodes); } +/// Replace a call to llvm.coro.prepare.retcon. 
+static void replacePrepare(CallInst *Prepare, CallGraph &CG) { + auto CastFn = Prepare->getArgOperand(0); // as an i8* + auto Fn = CastFn->stripPointerCasts(); // as its original type + + // Find call graph nodes for the preparation. + CallGraphNode *PrepareUserNode = nullptr, *FnNode = nullptr; + if (auto ConcreteFn = dyn_cast<Function>(Fn)) { + PrepareUserNode = CG[Prepare->getFunction()]; + FnNode = CG[ConcreteFn]; + } + + // Attempt to peephole this pattern: + // %0 = bitcast [[TYPE]] @some_function to i8* + // %1 = call @llvm.coro.prepare.retcon(i8* %0) + // %2 = bitcast %1 to [[TYPE]] + // ==> + // %2 = @some_function + for (auto UI = Prepare->use_begin(), UE = Prepare->use_end(); + UI != UE; ) { + // Look for bitcasts back to the original function type. + auto *Cast = dyn_cast<BitCastInst>((UI++)->getUser()); + if (!Cast || Cast->getType() != Fn->getType()) continue; + + // Check whether the replacement will introduce new direct calls. + // If so, we'll need to update the call graph. + if (PrepareUserNode) { + for (auto &Use : Cast->uses()) { + if (auto *CB = dyn_cast<CallBase>(Use.getUser())) { + if (!CB->isCallee(&Use)) + continue; + PrepareUserNode->removeCallEdgeFor(*CB); + PrepareUserNode->addCalledFunction(CB, FnNode); + } + } + } + + // Replace and remove the cast. + Cast->replaceAllUsesWith(Fn); + Cast->eraseFromParent(); + } + + // Replace any remaining uses with the function as an i8*. + // This can never directly be a callee, so we don't need to update CG. + Prepare->replaceAllUsesWith(CastFn); + Prepare->eraseFromParent(); + + // Kill dead bitcasts. + while (auto *Cast = dyn_cast<BitCastInst>(CastFn)) { + if (!Cast->use_empty()) break; + CastFn = Cast->getOperand(0); + Cast->eraseFromParent(); + } +} + +/// Remove calls to llvm.coro.prepare.retcon, a barrier meant to prevent +/// IPO from operating on calls to a retcon coroutine before it's been +/// split. This is only safe to do after we've split all retcon +/// coroutines in the module. 
We can do this in this pass because +/// this pass does promise to split all retcon coroutines (as opposed to +/// switch coroutines, which are lowered in multiple stages). +static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) { + bool Changed = false; + for (auto PI = PrepareFn->use_begin(), PE = PrepareFn->use_end(); + PI != PE; ) { + // Intrinsics can only be used in calls. + auto *Prepare = cast<CallInst>((PI++)->getUser()); + replacePrepare(Prepare, CG); + Changed = true; + } + + return Changed; +} + //===----------------------------------------------------------------------===// // Top Level Driver //===----------------------------------------------------------------------===// @@ -899,7 +1524,9 @@ struct CoroSplit : public CallGraphSCCPass { // A coroutine is identified by the presence of coro.begin intrinsic, if // we don't have any, this pass has nothing to do. bool doInitialization(CallGraph &CG) override { - Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"}); + Run = coro::declaresIntrinsics(CG.getModule(), + {"llvm.coro.begin", + "llvm.coro.prepare.retcon"}); return CallGraphSCCPass::doInitialization(CG); } @@ -907,6 +1534,12 @@ struct CoroSplit : public CallGraphSCCPass { if (!Run) return false; + // Check for uses of llvm.coro.prepare.retcon. + auto PrepareFn = + SCC.getCallGraph().getModule().getFunction("llvm.coro.prepare.retcon"); + if (PrepareFn && PrepareFn->use_empty()) + PrepareFn = nullptr; + // Find coroutines for processing. 
SmallVector<Function *, 4> Coroutines; for (CallGraphNode *CGN : SCC) @@ -914,12 +1547,17 @@ struct CoroSplit : public CallGraphSCCPass { if (F->hasFnAttribute(CORO_PRESPLIT_ATTR)) Coroutines.push_back(F); - if (Coroutines.empty()) + if (Coroutines.empty() && !PrepareFn) return false; CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + + if (Coroutines.empty()) + return replaceAllPrepares(PrepareFn, CG); + createDevirtTriggerFunc(CG, SCC); + // Split all the coroutines. for (Function *F : Coroutines) { Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR); StringRef Value = Attr.getValueAsString(); @@ -932,6 +1570,10 @@ struct CoroSplit : public CallGraphSCCPass { F->removeFnAttr(CORO_PRESPLIT_ATTR); splitCoroutine(*F, CG, SCC); } + + if (PrepareFn) + replaceAllPrepares(PrepareFn, CG); + return true; } diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp index a581d1d21169..f39483b27518 100644 --- a/lib/Transforms/Coroutines/Coroutines.cpp +++ b/lib/Transforms/Coroutines/Coroutines.cpp @@ -123,12 +123,26 @@ Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index, static bool isCoroutineIntrinsicName(StringRef Name) { // NOTE: Must be sorted! 
static const char *const CoroIntrinsics[] = { - "llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.destroy", - "llvm.coro.done", "llvm.coro.end", "llvm.coro.frame", - "llvm.coro.free", "llvm.coro.id", "llvm.coro.noop", - "llvm.coro.param", "llvm.coro.promise", "llvm.coro.resume", - "llvm.coro.save", "llvm.coro.size", "llvm.coro.subfn.addr", + "llvm.coro.alloc", + "llvm.coro.begin", + "llvm.coro.destroy", + "llvm.coro.done", + "llvm.coro.end", + "llvm.coro.frame", + "llvm.coro.free", + "llvm.coro.id", + "llvm.coro.id.retcon", + "llvm.coro.id.retcon.once", + "llvm.coro.noop", + "llvm.coro.param", + "llvm.coro.prepare.retcon", + "llvm.coro.promise", + "llvm.coro.resume", + "llvm.coro.save", + "llvm.coro.size", + "llvm.coro.subfn.addr", "llvm.coro.suspend", + "llvm.coro.suspend.retcon", }; return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1; } @@ -217,9 +231,6 @@ static void clear(coro::Shape &Shape) { Shape.FrameTy = nullptr; Shape.FramePtr = nullptr; Shape.AllocaSpillBlock = nullptr; - Shape.ResumeSwitch = nullptr; - Shape.PromiseAlloca = nullptr; - Shape.HasFinalSuspend = false; } static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin, @@ -235,6 +246,7 @@ static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin, // Collect "interesting" coroutine intrinsics. 
void coro::Shape::buildFrom(Function &F) { + bool HasFinalSuspend = false; size_t FinalSuspendIndex = 0; clear(*this); SmallVector<CoroFrameInst *, 8> CoroFrames; @@ -257,9 +269,15 @@ void coro::Shape::buildFrom(Function &F) { if (II->use_empty()) UnusedCoroSaves.push_back(cast<CoroSaveInst>(II)); break; - case Intrinsic::coro_suspend: - CoroSuspends.push_back(cast<CoroSuspendInst>(II)); - if (CoroSuspends.back()->isFinal()) { + case Intrinsic::coro_suspend_retcon: { + auto Suspend = cast<CoroSuspendRetconInst>(II); + CoroSuspends.push_back(Suspend); + break; + } + case Intrinsic::coro_suspend: { + auto Suspend = cast<CoroSuspendInst>(II); + CoroSuspends.push_back(Suspend); + if (Suspend->isFinal()) { if (HasFinalSuspend) report_fatal_error( "Only one suspend point can be marked as final"); @@ -267,18 +285,23 @@ void coro::Shape::buildFrom(Function &F) { FinalSuspendIndex = CoroSuspends.size() - 1; } break; + } case Intrinsic::coro_begin: { auto CB = cast<CoroBeginInst>(II); - if (CB->getId()->getInfo().isPreSplit()) { - if (CoroBegin) - report_fatal_error( + + // Ignore coro id's that aren't pre-split. + auto Id = dyn_cast<CoroIdInst>(CB->getId()); + if (Id && !Id->getInfo().isPreSplit()) + break; + + if (CoroBegin) + report_fatal_error( "coroutine should have exactly one defining @llvm.coro.begin"); - CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); - CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); - CB->removeAttribute(AttributeList::FunctionIndex, - Attribute::NoDuplicate); - CoroBegin = CB; - } + CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); + CB->removeAttribute(AttributeList::FunctionIndex, + Attribute::NoDuplicate); + CoroBegin = CB; break; } case Intrinsic::coro_end: @@ -310,7 +333,7 @@ void coro::Shape::buildFrom(Function &F) { // Replace all coro.suspend with undef and remove related coro.saves if // present. 
- for (CoroSuspendInst *CS : CoroSuspends) { + for (AnyCoroSuspendInst *CS : CoroSuspends) { CS->replaceAllUsesWith(UndefValue::get(CS->getType())); CS->eraseFromParent(); if (auto *CoroSave = CS->getCoroSave()) @@ -324,19 +347,136 @@ void coro::Shape::buildFrom(Function &F) { return; } + auto Id = CoroBegin->getId(); + switch (auto IdIntrinsic = Id->getIntrinsicID()) { + case Intrinsic::coro_id: { + auto SwitchId = cast<CoroIdInst>(Id); + this->ABI = coro::ABI::Switch; + this->SwitchLowering.HasFinalSuspend = HasFinalSuspend; + this->SwitchLowering.ResumeSwitch = nullptr; + this->SwitchLowering.PromiseAlloca = SwitchId->getPromise(); + this->SwitchLowering.ResumeEntryBlock = nullptr; + + for (auto AnySuspend : CoroSuspends) { + auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend); + if (!Suspend) { +#ifndef NDEBUG + AnySuspend->dump(); +#endif + report_fatal_error("coro.id must be paired with coro.suspend"); + } + + if (!Suspend->getCoroSave()) + createCoroSave(CoroBegin, Suspend); + } + break; + } + + case Intrinsic::coro_id_retcon: + case Intrinsic::coro_id_retcon_once: { + auto ContinuationId = cast<AnyCoroIdRetconInst>(Id); + ContinuationId->checkWellFormed(); + this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon + ? coro::ABI::Retcon + : coro::ABI::RetconOnce); + auto Prototype = ContinuationId->getPrototype(); + this->RetconLowering.ResumePrototype = Prototype; + this->RetconLowering.Alloc = ContinuationId->getAllocFunction(); + this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction(); + this->RetconLowering.ReturnBlock = nullptr; + this->RetconLowering.IsFrameInlineInStorage = false; + + // Determine the result value types, and make sure they match up with + // the values passed to the suspends. 
+ auto ResultTys = getRetconResultTypes(); + auto ResumeTys = getRetconResumeTypes(); + + for (auto AnySuspend : CoroSuspends) { + auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend); + if (!Suspend) { +#ifndef NDEBUG + AnySuspend->dump(); +#endif + report_fatal_error("coro.id.retcon.* must be paired with " + "coro.suspend.retcon"); + } + + // Check that the argument types of the suspend match the results. + auto SI = Suspend->value_begin(), SE = Suspend->value_end(); + auto RI = ResultTys.begin(), RE = ResultTys.end(); + for (; SI != SE && RI != RE; ++SI, ++RI) { + auto SrcTy = (*SI)->getType(); + if (SrcTy != *RI) { + // The optimizer likes to eliminate bitcasts leading into variadic + // calls, but that messes with our invariants. Re-insert the + // bitcast and ignore this type mismatch. + if (CastInst::isBitCastable(SrcTy, *RI)) { + auto BCI = new BitCastInst(*SI, *RI, "", Suspend); + SI->set(BCI); + continue; + } + +#ifndef NDEBUG + Suspend->dump(); + Prototype->getFunctionType()->dump(); +#endif + report_fatal_error("argument to coro.suspend.retcon does not " + "match corresponding prototype function result"); + } + } + if (SI != SE || RI != RE) { +#ifndef NDEBUG + Suspend->dump(); + Prototype->getFunctionType()->dump(); +#endif + report_fatal_error("wrong number of arguments to coro.suspend.retcon"); + } + + // Check that the result type of the suspend matches the resume types. 
+ Type *SResultTy = Suspend->getType(); + ArrayRef<Type*> SuspendResultTys; + if (SResultTy->isVoidTy()) { + // leave as empty array + } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) { + SuspendResultTys = SResultStructTy->elements(); + } else { + // forms an ArrayRef using SResultTy, be careful + SuspendResultTys = SResultTy; + } + if (SuspendResultTys.size() != ResumeTys.size()) { +#ifndef NDEBUG + Suspend->dump(); + Prototype->getFunctionType()->dump(); +#endif + report_fatal_error("wrong number of results from coro.suspend.retcon"); + } + for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) { + if (SuspendResultTys[I] != ResumeTys[I]) { +#ifndef NDEBUG + Suspend->dump(); + Prototype->getFunctionType()->dump(); +#endif + report_fatal_error("result from coro.suspend.retcon does not " + "match corresponding prototype function param"); + } + } + } + break; + } + + default: + llvm_unreachable("coro.begin is not dependent on a coro.id call"); + } + // The coro.free intrinsic is always lowered to the result of coro.begin. for (CoroFrameInst *CF : CoroFrames) { CF->replaceAllUsesWith(CoroBegin); CF->eraseFromParent(); } - // Canonicalize coro.suspend by inserting a coro.save if needed. - for (CoroSuspendInst *CS : CoroSuspends) - if (!CS->getCoroSave()) - createCoroSave(CoroBegin, CS); - // Move final suspend to be the last element in the CoroSuspends vector. - if (HasFinalSuspend && + if (ABI == coro::ABI::Switch && + SwitchLowering.HasFinalSuspend && FinalSuspendIndex != CoroSuspends.size() - 1) std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back()); @@ -345,6 +485,154 @@ void coro::Shape::buildFrom(Function &F) { CoroSave->eraseFromParent(); } +static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) { + Call->setCallingConv(Callee->getCallingConv()); + // TODO: attributes? 
+} + +static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){ + if (CG) + (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]); +} + +Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size, + CallGraph *CG) const { + switch (ABI) { + case coro::ABI::Switch: + llvm_unreachable("can't allocate memory in coro switch-lowering"); + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: { + auto Alloc = RetconLowering.Alloc; + Size = Builder.CreateIntCast(Size, + Alloc->getFunctionType()->getParamType(0), + /*is signed*/ false); + auto *Call = Builder.CreateCall(Alloc, Size); + propagateCallAttrsFromCallee(Call, Alloc); + addCallToCallGraph(CG, Call, Alloc); + return Call; + } + } + llvm_unreachable("Unknown coro::ABI enum"); +} + +void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr, + CallGraph *CG) const { + switch (ABI) { + case coro::ABI::Switch: + llvm_unreachable("can't allocate memory in coro switch-lowering"); + + case coro::ABI::Retcon: + case coro::ABI::RetconOnce: { + auto Dealloc = RetconLowering.Dealloc; + Ptr = Builder.CreateBitCast(Ptr, + Dealloc->getFunctionType()->getParamType(0)); + auto *Call = Builder.CreateCall(Dealloc, Ptr); + propagateCallAttrsFromCallee(Call, Dealloc); + addCallToCallGraph(CG, Call, Dealloc); + return; + } + } + llvm_unreachable("Unknown coro::ABI enum"); +} + +LLVM_ATTRIBUTE_NORETURN +static void fail(const Instruction *I, const char *Reason, Value *V) { +#ifndef NDEBUG + I->dump(); + if (V) { + errs() << " Value: "; + V->printAsOperand(llvm::errs()); + errs() << '\n'; + } +#endif + report_fatal_error(Reason); +} + +/// Check that the given value is a well-formed prototype for the +/// llvm.coro.id.retcon.* intrinsics. 
+static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) { + auto F = dyn_cast<Function>(V->stripPointerCasts()); + if (!F) + fail(I, "llvm.coro.id.retcon.* prototype not a Function", V); + + auto FT = F->getFunctionType(); + + if (isa<CoroIdRetconInst>(I)) { + bool ResultOkay; + if (FT->getReturnType()->isPointerTy()) { + ResultOkay = true; + } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) { + ResultOkay = (!SRetTy->isOpaque() && + SRetTy->getNumElements() > 0 && + SRetTy->getElementType(0)->isPointerTy()); + } else { + ResultOkay = false; + } + if (!ResultOkay) + fail(I, "llvm.coro.id.retcon prototype must return pointer as first " + "result", F); + + if (FT->getReturnType() != + I->getFunction()->getFunctionType()->getReturnType()) + fail(I, "llvm.coro.id.retcon prototype return type must be same as" + "current function return type", F); + } else { + // No meaningful validation to do here for llvm.coro.id.retcon.once. + } + + if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy()) + fail(I, "llvm.coro.id.retcon.* prototype must take pointer as " + "its first parameter", F); +} + +/// Check that the given value is a well-formed allocator. 
+static void checkWFDealloc(const Instruction *I, Value *V) { + auto F = dyn_cast<Function>(V->stripPointerCasts()); + if (!F) + fail(I, "llvm.coro.* deallocator not a Function", V); + + auto FT = F->getFunctionType(); + if (!FT->getReturnType()->isVoidTy()) + fail(I, "llvm.coro.* deallocator must return void", F); + + if (FT->getNumParams() != 1 || + !FT->getParamType(0)->isPointerTy()) + fail(I, "llvm.coro.* deallocator must take pointer as only param", F); +} + +static void checkConstantInt(const Instruction *I, Value *V, + const char *Reason) { + if (!isa<ConstantInt>(V)) { + fail(I, Reason, V); + } +} + +void AnyCoroIdRetconInst::checkWellFormed() const { + checkConstantInt(this, getArgOperand(SizeArg), + "size argument to coro.id.retcon.* must be constant"); + checkConstantInt(this, getArgOperand(AlignArg), + "alignment argument to coro.id.retcon.* must be constant"); + checkWFRetconPrototype(this, getArgOperand(PrototypeArg)); + checkWFAlloc(this, getArgOperand(AllocArg)); + checkWFDealloc(this, getArgOperand(DeallocArg)); +} + void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createCoroEarlyPass()); } diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index 95a9f31cced3..dd9f74a881ee 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -304,7 +304,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // of the previous load. LoadInst *newLoad = IRB.CreateLoad(OrigLoad->getType(), V, V->getName() + ".val"); - newLoad->setAlignment(OrigLoad->getAlignment()); + newLoad->setAlignment(MaybeAlign(OrigLoad->getAlignment())); // Transfer the AA info too. 
AAMDNodes AAInfo; OrigLoad->getAAMetadata(AAInfo); diff --git a/lib/Transforms/IPO/Attributor.cpp b/lib/Transforms/IPO/Attributor.cpp index 2a52c6b9b4ad..95f47345d8fd 100644 --- a/lib/Transforms/IPO/Attributor.cpp +++ b/lib/Transforms/IPO/Attributor.cpp @@ -16,11 +16,15 @@ #include "llvm/Transforms/IPO/Attributor.h" #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" @@ -30,6 +34,9 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" + #include <cassert> using namespace llvm; @@ -46,19 +53,50 @@ STATISTIC(NumAttributesValidFixpoint, "Number of abstract attributes in a valid fixpoint state"); STATISTIC(NumAttributesManifested, "Number of abstract attributes manifested in IR"); -STATISTIC(NumFnNoUnwind, "Number of functions marked nounwind"); - -STATISTIC(NumFnUniqueReturned, "Number of function with unique return"); -STATISTIC(NumFnKnownReturns, "Number of function with known return values"); -STATISTIC(NumFnArgumentReturned, - "Number of function arguments marked returned"); -STATISTIC(NumFnNoSync, "Number of functions marked nosync"); -STATISTIC(NumFnNoFree, "Number of functions marked nofree"); -STATISTIC(NumFnReturnedNonNull, - "Number of function return values marked nonnull"); -STATISTIC(NumFnArgumentNonNull, "Number of function arguments marked nonnull"); -STATISTIC(NumCSArgumentNonNull, "Number of call site arguments marked nonnull"); 
-STATISTIC(NumFnWillReturn, "Number of functions marked willreturn"); + +// Some helper macros to deal with statistics tracking. +// +// Usage: +// For simple IR attribute tracking overload trackStatistics in the abstract +// attribute and choose the right STATS_DECLTRACK_********* macro, +// e.g.,: +// void trackStatistics() const override { +// STATS_DECLTRACK_ARG_ATTR(returned) +// } +// If there is a single "increment" side one can use the macro +// STATS_DECLTRACK with a custom message. If there are multiple increment +// sides, STATS_DECL and STATS_TRACK can also be used separately. +// +#define BUILD_STAT_MSG_IR_ATTR(TYPE, NAME) \ + ("Number of " #TYPE " marked '" #NAME "'") +#define BUILD_STAT_NAME(NAME, TYPE) NumIR##TYPE##_##NAME +#define STATS_DECL_(NAME, MSG) STATISTIC(NAME, MSG); +#define STATS_DECL(NAME, TYPE, MSG) \ + STATS_DECL_(BUILD_STAT_NAME(NAME, TYPE), MSG); +#define STATS_TRACK(NAME, TYPE) ++(BUILD_STAT_NAME(NAME, TYPE)); +#define STATS_DECLTRACK(NAME, TYPE, MSG) \ + { \ + STATS_DECL(NAME, TYPE, MSG) \ + STATS_TRACK(NAME, TYPE) \ + } +#define STATS_DECLTRACK_ARG_ATTR(NAME) \ + STATS_DECLTRACK(NAME, Arguments, BUILD_STAT_MSG_IR_ATTR(arguments, NAME)) +#define STATS_DECLTRACK_CSARG_ATTR(NAME) \ + STATS_DECLTRACK(NAME, CSArguments, \ + BUILD_STAT_MSG_IR_ATTR(call site arguments, NAME)) +#define STATS_DECLTRACK_FN_ATTR(NAME) \ + STATS_DECLTRACK(NAME, Function, BUILD_STAT_MSG_IR_ATTR(functions, NAME)) +#define STATS_DECLTRACK_CS_ATTR(NAME) \ + STATS_DECLTRACK(NAME, CS, BUILD_STAT_MSG_IR_ATTR(call site, NAME)) +#define STATS_DECLTRACK_FNRET_ATTR(NAME) \ + STATS_DECLTRACK(NAME, FunctionReturn, \ + BUILD_STAT_MSG_IR_ATTR(function returns, NAME)) +#define STATS_DECLTRACK_CSRET_ATTR(NAME) \ + STATS_DECLTRACK(NAME, CSReturn, \ + BUILD_STAT_MSG_IR_ATTR(call site returns, NAME)) +#define STATS_DECLTRACK_FLOATING_ATTR(NAME) \ + STATS_DECLTRACK(NAME, Floating, \ + ("Number of floating values known to be '" #NAME "'")) // TODO: Determine a good default value. 
// @@ -72,18 +110,32 @@ static cl::opt<unsigned> MaxFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32)); +static cl::opt<bool> VerifyMaxFixpointIterations( + "attributor-max-iterations-verify", cl::Hidden, + cl::desc("Verify that max-iterations is a tight bound for a fixpoint"), + cl::init(false)); static cl::opt<bool> DisableAttributor( "attributor-disable", cl::Hidden, cl::desc("Disable the attributor inter-procedural deduction pass."), cl::init(true)); -static cl::opt<bool> VerifyAttributor( - "attributor-verify", cl::Hidden, - cl::desc("Verify the Attributor deduction and " - "manifestation of attributes -- may issue false-positive errors"), +static cl::opt<bool> ManifestInternal( + "attributor-manifest-internal", cl::Hidden, + cl::desc("Manifest Attributor internal string attributes."), cl::init(false)); +static cl::opt<unsigned> DepRecInterval( + "attributor-dependence-recompute-interval", cl::Hidden, + cl::desc("Number of iterations until dependences are recomputed."), + cl::init(4)); + +static cl::opt<bool> EnableHeapToStack("enable-heap-to-stack-conversion", + cl::init(true), cl::Hidden); + +static cl::opt<int> MaxHeapToStackSize("max-heap-to-stack-size", cl::init(128), + cl::Hidden); + /// Logic operators for the change status enum class. /// ///{ @@ -95,78 +147,30 @@ ChangeStatus llvm::operator&(ChangeStatus l, ChangeStatus r) { } ///} -/// Helper to adjust the statistics. 
-static void bookkeeping(AbstractAttribute::ManifestPosition MP, - const Attribute &Attr) { - if (!AreStatisticsEnabled()) - return; - - if (!Attr.isEnumAttribute()) - return; - switch (Attr.getKindAsEnum()) { - case Attribute::NoUnwind: - NumFnNoUnwind++; - return; - case Attribute::Returned: - NumFnArgumentReturned++; - return; - case Attribute::NoSync: - NumFnNoSync++; - break; - case Attribute::NoFree: - NumFnNoFree++; - break; - case Attribute::NonNull: - switch (MP) { - case AbstractAttribute::MP_RETURNED: - NumFnReturnedNonNull++; - break; - case AbstractAttribute::MP_ARGUMENT: - NumFnArgumentNonNull++; - break; - case AbstractAttribute::MP_CALL_SITE_ARGUMENT: - NumCSArgumentNonNull++; - break; - default: - break; - } - break; - case Attribute::WillReturn: - NumFnWillReturn++; - break; - default: - return; - } -} - -template <typename StateTy> -using followValueCB_t = std::function<bool(Value *, StateTy &State)>; -template <typename StateTy> -using visitValueCB_t = std::function<void(Value *, StateTy &State)>; - -/// Recursively visit all values that might become \p InitV at some point. This +/// Recursively visit all values that might become \p IRP at some point. This /// will be done by looking through cast instructions, selects, phis, and calls -/// with the "returned" attribute. The callback \p FollowValueCB is asked before -/// a potential origin value is looked at. If no \p FollowValueCB is passed, a -/// default one is used that will make sure we visit every value only once. Once -/// we cannot look through the value any further, the callback \p VisitValueCB -/// is invoked and passed the current value and the \p State. To limit how much -/// effort is invested, we will never visit more than \p MaxValues values. -template <typename StateTy> +/// with the "returned" attribute. 
Once we cannot look through the value any +/// further, the callback \p VisitValueCB is invoked and passed the current +/// value, the \p State, and a flag to indicate if we stripped anything. To +/// limit how much effort is invested, we will never visit more values than +/// specified by \p MaxValues. +template <typename AAType, typename StateTy> static bool genericValueTraversal( - Value *InitV, StateTy &State, visitValueCB_t<StateTy> &VisitValueCB, - followValueCB_t<StateTy> *FollowValueCB = nullptr, int MaxValues = 8) { - + Attributor &A, IRPosition IRP, const AAType &QueryingAA, StateTy &State, + const function_ref<bool(Value &, StateTy &, bool)> &VisitValueCB, + int MaxValues = 8) { + + const AAIsDead *LivenessAA = nullptr; + if (IRP.getAnchorScope()) + LivenessAA = &A.getAAFor<AAIsDead>( + QueryingAA, IRPosition::function(*IRP.getAnchorScope()), + /* TrackDependence */ false); + bool AnyDead = false; + + // TODO: Use Positions here to allow context sensitivity in VisitValueCB SmallPtrSet<Value *, 16> Visited; - followValueCB_t<bool> DefaultFollowValueCB = [&](Value *Val, bool &) { - return Visited.insert(Val).second; - }; - - if (!FollowValueCB) - FollowValueCB = &DefaultFollowValueCB; - SmallVector<Value *, 16> Worklist; - Worklist.push_back(InitV); + Worklist.push_back(&IRP.getAssociatedValue()); int Iteration = 0; do { @@ -174,7 +178,7 @@ static bool genericValueTraversal( // Check if we should process the current value. To prevent endless // recursion keep a record of the values we followed! - if (!(*FollowValueCB)(V, State)) + if (!Visited.insert(V).second) continue; // Make sure we limit the compile time for complex expressions. @@ -183,23 +187,23 @@ static bool genericValueTraversal( // Explicitly look through calls with a "returned" attribute if we do // not have a pointer as stripPointerCasts only works on them. 
+ Value *NewV = nullptr; if (V->getType()->isPointerTy()) { - V = V->stripPointerCasts(); + NewV = V->stripPointerCasts(); } else { CallSite CS(V); if (CS && CS.getCalledFunction()) { - Value *NewV = nullptr; for (Argument &Arg : CS.getCalledFunction()->args()) if (Arg.hasReturnedAttr()) { NewV = CS.getArgOperand(Arg.getArgNo()); break; } - if (NewV) { - Worklist.push_back(NewV); - continue; - } } } + if (NewV && NewV != V) { + Worklist.push_back(NewV); + continue; + } // Look through select instructions, visit both potential values. if (auto *SI = dyn_cast<SelectInst>(V)) { @@ -208,35 +212,34 @@ static bool genericValueTraversal( continue; } - // Look through phi nodes, visit all operands. + // Look through phi nodes, visit all live operands. if (auto *PHI = dyn_cast<PHINode>(V)) { - Worklist.append(PHI->op_begin(), PHI->op_end()); + assert(LivenessAA && + "Expected liveness in the presence of instructions!"); + for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) { + const BasicBlock *IncomingBB = PHI->getIncomingBlock(u); + if (LivenessAA->isAssumedDead(IncomingBB->getTerminator())) { + AnyDead = true; + continue; + } + Worklist.push_back(PHI->getIncomingValue(u)); + } continue; } // Once a leaf is reached we inform the user through the callback. - VisitValueCB(V, State); + if (!VisitValueCB(*V, State, Iteration > 1)) + return false; } while (!Worklist.empty()); + // If we actually used liveness information so we have to record a dependence. + if (AnyDead) + A.recordDependence(*LivenessAA, QueryingAA); + // All values have been visited. return true; } -/// Helper to identify the correct offset into an attribute list. 
-static unsigned getAttrIndex(AbstractAttribute::ManifestPosition MP, - unsigned ArgNo = 0) { - switch (MP) { - case AbstractAttribute::MP_ARGUMENT: - case AbstractAttribute::MP_CALL_SITE_ARGUMENT: - return ArgNo + AttributeList::FirstArgIndex; - case AbstractAttribute::MP_FUNCTION: - return AttributeList::FunctionIndex; - case AbstractAttribute::MP_RETURNED: - return AttributeList::ReturnIndex; - } - llvm_unreachable("Unknown manifest position!"); -} - /// Return true if \p New is equal or worse than \p Old. static bool isEqualOrWorse(const Attribute &New, const Attribute &Old) { if (!Old.isIntAttribute()) @@ -247,12 +250,9 @@ static bool isEqualOrWorse(const Attribute &New, const Attribute &Old) { /// Return true if the information provided by \p Attr was added to the /// attribute list \p Attrs. This is only the case if it was not already present -/// in \p Attrs at the position describe by \p MP and \p ArgNo. +/// in \p Attrs at the position describe by \p PK and \p AttrIdx. static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr, - AttributeList &Attrs, - AbstractAttribute::ManifestPosition MP, - unsigned ArgNo = 0) { - unsigned AttrIdx = getAttrIndex(MP, ArgNo); + AttributeList &Attrs, int AttrIdx) { if (Attr.isEnumAttribute()) { Attribute::AttrKind Kind = Attr.getKindAsEnum(); @@ -270,9 +270,47 @@ static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr, Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr); return true; } + if (Attr.isIntAttribute()) { + Attribute::AttrKind Kind = Attr.getKindAsEnum(); + if (Attrs.hasAttribute(AttrIdx, Kind)) + if (isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind))) + return false; + Attrs = Attrs.removeAttribute(Ctx, AttrIdx, Kind); + Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr); + return true; + } llvm_unreachable("Expected enum or string attribute!"); } +static const Value *getPointerOperand(const Instruction *I) { + if (auto *LI = dyn_cast<LoadInst>(I)) + if (!LI->isVolatile()) + return 
LI->getPointerOperand(); + + if (auto *SI = dyn_cast<StoreInst>(I)) + if (!SI->isVolatile()) + return SI->getPointerOperand(); + + if (auto *CXI = dyn_cast<AtomicCmpXchgInst>(I)) + if (!CXI->isVolatile()) + return CXI->getPointerOperand(); + + if (auto *RMWI = dyn_cast<AtomicRMWInst>(I)) + if (!RMWI->isVolatile()) + return RMWI->getPointerOperand(); + + return nullptr; +} +static const Value *getBasePointerOfAccessPointerOperand(const Instruction *I, + int64_t &BytesOffset, + const DataLayout &DL) { + const Value *Ptr = getPointerOperand(I); + if (!Ptr) + return nullptr; + + return GetPointerBaseWithConstantOffset(Ptr, BytesOffset, DL, + /*AllowNonInbounds*/ false); +} ChangeStatus AbstractAttribute::update(Attributor &A) { ChangeStatus HasChanged = ChangeStatus::UNCHANGED; @@ -289,143 +327,527 @@ ChangeStatus AbstractAttribute::update(Attributor &A) { return HasChanged; } -ChangeStatus AbstractAttribute::manifest(Attributor &A) { - assert(getState().isValidState() && - "Attempted to manifest an invalid state!"); - assert(getAssociatedValue() && - "Attempted to manifest an attribute without associated value!"); - - ChangeStatus HasChanged = ChangeStatus::UNCHANGED; - SmallVector<Attribute, 4> DeducedAttrs; - getDeducedAttributes(DeducedAttrs); - - Function &ScopeFn = getAnchorScope(); - LLVMContext &Ctx = ScopeFn.getContext(); - ManifestPosition MP = getManifestPosition(); - - AttributeList Attrs; - SmallVector<unsigned, 4> ArgNos; +ChangeStatus +IRAttributeManifest::manifestAttrs(Attributor &A, IRPosition &IRP, + const ArrayRef<Attribute> &DeducedAttrs) { + Function *ScopeFn = IRP.getAssociatedFunction(); + IRPosition::Kind PK = IRP.getPositionKind(); // In the following some generic code that will manifest attributes in // DeducedAttrs if they improve the current IR. Due to the different // annotation positions we use the underlying AttributeList interface. - // Note that MP_CALL_SITE_ARGUMENT can annotate multiple locations. 
- switch (MP) { - case MP_ARGUMENT: - ArgNos.push_back(cast<Argument>(getAssociatedValue())->getArgNo()); - Attrs = ScopeFn.getAttributes(); + AttributeList Attrs; + switch (PK) { + case IRPosition::IRP_INVALID: + case IRPosition::IRP_FLOAT: + return ChangeStatus::UNCHANGED; + case IRPosition::IRP_ARGUMENT: + case IRPosition::IRP_FUNCTION: + case IRPosition::IRP_RETURNED: + Attrs = ScopeFn->getAttributes(); break; - case MP_FUNCTION: - case MP_RETURNED: - ArgNos.push_back(0); - Attrs = ScopeFn.getAttributes(); + case IRPosition::IRP_CALL_SITE: + case IRPosition::IRP_CALL_SITE_RETURNED: + case IRPosition::IRP_CALL_SITE_ARGUMENT: + Attrs = ImmutableCallSite(&IRP.getAnchorValue()).getAttributes(); break; - case MP_CALL_SITE_ARGUMENT: { - CallSite CS(&getAnchoredValue()); - for (unsigned u = 0, e = CS.getNumArgOperands(); u != e; u++) - if (CS.getArgOperand(u) == getAssociatedValue()) - ArgNos.push_back(u); - Attrs = CS.getAttributes(); - } } + ChangeStatus HasChanged = ChangeStatus::UNCHANGED; + LLVMContext &Ctx = IRP.getAnchorValue().getContext(); for (const Attribute &Attr : DeducedAttrs) { - for (unsigned ArgNo : ArgNos) { - if (!addIfNotExistent(Ctx, Attr, Attrs, MP, ArgNo)) - continue; + if (!addIfNotExistent(Ctx, Attr, Attrs, IRP.getAttrIdx())) + continue; - HasChanged = ChangeStatus::CHANGED; - bookkeeping(MP, Attr); - } + HasChanged = ChangeStatus::CHANGED; } if (HasChanged == ChangeStatus::UNCHANGED) return HasChanged; - switch (MP) { - case MP_ARGUMENT: - case MP_FUNCTION: - case MP_RETURNED: - ScopeFn.setAttributes(Attrs); + switch (PK) { + case IRPosition::IRP_ARGUMENT: + case IRPosition::IRP_FUNCTION: + case IRPosition::IRP_RETURNED: + ScopeFn->setAttributes(Attrs); + break; + case IRPosition::IRP_CALL_SITE: + case IRPosition::IRP_CALL_SITE_RETURNED: + case IRPosition::IRP_CALL_SITE_ARGUMENT: + CallSite(&IRP.getAnchorValue()).setAttributes(Attrs); + break; + case IRPosition::IRP_INVALID: + case IRPosition::IRP_FLOAT: break; - case MP_CALL_SITE_ARGUMENT: - 
CallSite(&getAnchoredValue()).setAttributes(Attrs); } return HasChanged; } -Function &AbstractAttribute::getAnchorScope() { - Value &V = getAnchoredValue(); - if (isa<Function>(V)) - return cast<Function>(V); - if (isa<Argument>(V)) - return *cast<Argument>(V).getParent(); - if (isa<Instruction>(V)) - return *cast<Instruction>(V).getFunction(); - llvm_unreachable("No scope for anchored value found!"); +const IRPosition IRPosition::EmptyKey(255); +const IRPosition IRPosition::TombstoneKey(256); + +SubsumingPositionIterator::SubsumingPositionIterator(const IRPosition &IRP) { + IRPositions.emplace_back(IRP); + + ImmutableCallSite ICS(&IRP.getAnchorValue()); + switch (IRP.getPositionKind()) { + case IRPosition::IRP_INVALID: + case IRPosition::IRP_FLOAT: + case IRPosition::IRP_FUNCTION: + return; + case IRPosition::IRP_ARGUMENT: + case IRPosition::IRP_RETURNED: + IRPositions.emplace_back( + IRPosition::function(*IRP.getAssociatedFunction())); + return; + case IRPosition::IRP_CALL_SITE: + assert(ICS && "Expected call site!"); + // TODO: We need to look at the operand bundles similar to the redirection + // in CallBase. + if (!ICS.hasOperandBundles()) + if (const Function *Callee = ICS.getCalledFunction()) + IRPositions.emplace_back(IRPosition::function(*Callee)); + return; + case IRPosition::IRP_CALL_SITE_RETURNED: + assert(ICS && "Expected call site!"); + // TODO: We need to look at the operand bundles similar to the redirection + // in CallBase. 
+ if (!ICS.hasOperandBundles()) { + if (const Function *Callee = ICS.getCalledFunction()) { + IRPositions.emplace_back(IRPosition::returned(*Callee)); + IRPositions.emplace_back(IRPosition::function(*Callee)); + } + } + IRPositions.emplace_back( + IRPosition::callsite_function(cast<CallBase>(*ICS.getInstruction()))); + return; + case IRPosition::IRP_CALL_SITE_ARGUMENT: { + int ArgNo = IRP.getArgNo(); + assert(ICS && ArgNo >= 0 && "Expected call site!"); + // TODO: We need to look at the operand bundles similar to the redirection + // in CallBase. + if (!ICS.hasOperandBundles()) { + const Function *Callee = ICS.getCalledFunction(); + if (Callee && Callee->arg_size() > unsigned(ArgNo)) + IRPositions.emplace_back(IRPosition::argument(*Callee->getArg(ArgNo))); + if (Callee) + IRPositions.emplace_back(IRPosition::function(*Callee)); + } + IRPositions.emplace_back(IRPosition::value(IRP.getAssociatedValue())); + return; + } + } +} + +bool IRPosition::hasAttr(ArrayRef<Attribute::AttrKind> AKs, + bool IgnoreSubsumingPositions) const { + for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this)) { + for (Attribute::AttrKind AK : AKs) + if (EquivIRP.getAttr(AK).getKindAsEnum() == AK) + return true; + // The first position returned by the SubsumingPositionIterator is + // always the position itself. If we ignore subsuming positions we + // are done after the first iteration. 
+ if (IgnoreSubsumingPositions) + break; + } + return false; } -const Function &AbstractAttribute::getAnchorScope() const { - return const_cast<AbstractAttribute *>(this)->getAnchorScope(); +void IRPosition::getAttrs(ArrayRef<Attribute::AttrKind> AKs, + SmallVectorImpl<Attribute> &Attrs) const { + for (const IRPosition &EquivIRP : SubsumingPositionIterator(*this)) + for (Attribute::AttrKind AK : AKs) { + const Attribute &Attr = EquivIRP.getAttr(AK); + if (Attr.getKindAsEnum() == AK) + Attrs.push_back(Attr); + } } -/// -----------------------NoUnwind Function Attribute-------------------------- +void IRPosition::verify() { + switch (KindOrArgNo) { + default: + assert(KindOrArgNo >= 0 && "Expected argument or call site argument!"); + assert((isa<CallBase>(AnchorVal) || isa<Argument>(AnchorVal)) && + "Expected call base or argument for positive attribute index!"); + if (isa<Argument>(AnchorVal)) { + assert(cast<Argument>(AnchorVal)->getArgNo() == unsigned(getArgNo()) && + "Argument number mismatch!"); + assert(cast<Argument>(AnchorVal) == &getAssociatedValue() && + "Associated value mismatch!"); + } else { + assert(cast<CallBase>(*AnchorVal).arg_size() > unsigned(getArgNo()) && + "Call site argument number mismatch!"); + assert(cast<CallBase>(*AnchorVal).getArgOperand(getArgNo()) == + &getAssociatedValue() && + "Associated value mismatch!"); + } + break; + case IRP_INVALID: + assert(!AnchorVal && "Expected no value for an invalid position!"); + break; + case IRP_FLOAT: + assert((!isa<CallBase>(&getAssociatedValue()) && + !isa<Argument>(&getAssociatedValue())) && + "Expected specialized kind for call base and argument values!"); + break; + case IRP_RETURNED: + assert(isa<Function>(AnchorVal) && + "Expected function for a 'returned' position!"); + assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!"); + break; + case IRP_CALL_SITE_RETURNED: + assert((isa<CallBase>(AnchorVal)) && + "Expected call base for 'call site returned' position!"); + 
assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!"); + break; + case IRP_CALL_SITE: + assert((isa<CallBase>(AnchorVal)) && + "Expected call base for 'call site function' position!"); + assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!"); + break; + case IRP_FUNCTION: + assert(isa<Function>(AnchorVal) && + "Expected function for a 'function' position!"); + assert(AnchorVal == &getAssociatedValue() && "Associated value mismatch!"); + break; + } +} + +namespace { +/// Helper functions to clamp a state \p S of type \p StateType with the +/// information in \p R and indicate/return if \p S did change (as-in update is +/// required to be run again). +/// +///{ +template <typename StateType> +ChangeStatus clampStateAndIndicateChange(StateType &S, const StateType &R); + +template <> +ChangeStatus clampStateAndIndicateChange<IntegerState>(IntegerState &S, + const IntegerState &R) { + auto Assumed = S.getAssumed(); + S ^= R; + return Assumed == S.getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; +} -struct AANoUnwindFunction : AANoUnwind, BooleanState { +template <> +ChangeStatus clampStateAndIndicateChange<BooleanState>(BooleanState &S, + const BooleanState &R) { + return clampStateAndIndicateChange<IntegerState>(S, R); +} +///} - AANoUnwindFunction(Function &F, InformationCache &InfoCache) - : AANoUnwind(F, InfoCache) {} +/// Clamp the information known for all returned values of a function +/// (identified by \p QueryingAA) into \p S. 
+template <typename AAType, typename StateType = typename AAType::StateType> +static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA, + StateType &S) { + LLVM_DEBUG(dbgs() << "[Attributor] Clamp return value states for " + << static_cast<const AbstractAttribute &>(QueryingAA) + << " into " << S << "\n"); + + assert((QueryingAA.getIRPosition().getPositionKind() == + IRPosition::IRP_RETURNED || + QueryingAA.getIRPosition().getPositionKind() == + IRPosition::IRP_CALL_SITE_RETURNED) && + "Can only clamp returned value states for a function returned or call " + "site returned position!"); + + // Use an optional state as there might not be any return values and we want + // to join (IntegerState::operator&) the state of all there are. + Optional<StateType> T; + + // Callback for each possibly returned value. + auto CheckReturnValue = [&](Value &RV) -> bool { + const IRPosition &RVPos = IRPosition::value(RV); + const AAType &AA = A.getAAFor<AAType>(QueryingAA, RVPos); + LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr() + << " @ " << RVPos << "\n"); + const StateType &AAS = static_cast<const StateType &>(AA.getState()); + if (T.hasValue()) + *T &= AAS; + else + T = AAS; + LLVM_DEBUG(dbgs() << "[Attributor] AA State: " << AAS << " RV State: " << T + << "\n"); + return T->isValidState(); + }; - /// See AbstractAttribute::getState() - /// { - AbstractState &getState() override { return *this; } - const AbstractState &getState() const override { return *this; } - /// } + if (!A.checkForAllReturnedValues(CheckReturnValue, QueryingAA)) + S.indicatePessimisticFixpoint(); + else if (T.hasValue()) + S ^= *T; +} - /// See AbstractAttribute::getManifestPosition(). 
- ManifestPosition getManifestPosition() const override { return MP_FUNCTION; } +/// Helper class to compose two generic deduction +template <typename AAType, typename Base, typename StateType, + template <typename...> class F, template <typename...> class G> +struct AAComposeTwoGenericDeduction + : public F<AAType, G<AAType, Base, StateType>, StateType> { + AAComposeTwoGenericDeduction(const IRPosition &IRP) + : F<AAType, G<AAType, Base, StateType>, StateType>(IRP) {} - const std::string getAsStr() const override { - return getAssumed() ? "nounwind" : "may-unwind"; + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus ChangedF = F<AAType, G<AAType, Base, StateType>, StateType>::updateImpl(A); + ChangeStatus ChangedG = G<AAType, Base, StateType>::updateImpl(A); + return ChangedF | ChangedG; } +}; + +/// Helper class for generic deduction: return value -> returned position. +template <typename AAType, typename Base, + typename StateType = typename AAType::StateType> +struct AAReturnedFromReturnedValues : public Base { + AAReturnedFromReturnedValues(const IRPosition &IRP) : Base(IRP) {} /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override; + ChangeStatus updateImpl(Attributor &A) override { + StateType S; + clampReturnedValueStates<AAType, StateType>(A, *this, S); + // TODO: If we know we visited all returned values, thus no are assumed + // dead, we can take the known information from the state T. + return clampStateAndIndicateChange<StateType>(this->getState(), S); + } +}; - /// See AANoUnwind::isAssumedNoUnwind(). - bool isAssumedNoUnwind() const override { return getAssumed(); } +/// Clamp the information known at all call sites for a given argument +/// (identified by \p QueryingAA) into \p S. 
+template <typename AAType, typename StateType = typename AAType::StateType> +static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, + StateType &S) { + LLVM_DEBUG(dbgs() << "[Attributor] Clamp call site argument states for " + << static_cast<const AbstractAttribute &>(QueryingAA) + << " into " << S << "\n"); + + assert(QueryingAA.getIRPosition().getPositionKind() == + IRPosition::IRP_ARGUMENT && + "Can only clamp call site argument states for an argument position!"); + + // Use an optional state as there might not be any return values and we want + // to join (IntegerState::operator&) the state of all there are. + Optional<StateType> T; + + // The argument number which is also the call site argument number. + unsigned ArgNo = QueryingAA.getIRPosition().getArgNo(); + + auto CallSiteCheck = [&](AbstractCallSite ACS) { + const IRPosition &ACSArgPos = IRPosition::callsite_argument(ACS, ArgNo); + // Check if a coresponding argument was found or if it is on not associated + // (which can happen for callback calls). + if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID) + return false; - /// See AANoUnwind::isKnownNoUnwind(). - bool isKnownNoUnwind() const override { return getKnown(); } + const AAType &AA = A.getAAFor<AAType>(QueryingAA, ACSArgPos); + LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction() + << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n"); + const StateType &AAS = static_cast<const StateType &>(AA.getState()); + if (T.hasValue()) + *T &= AAS; + else + T = AAS; + LLVM_DEBUG(dbgs() << "[Attributor] AA State: " << AAS << " CSA State: " << T + << "\n"); + return T->isValidState(); + }; + + if (!A.checkForAllCallSites(CallSiteCheck, QueryingAA, true)) + S.indicatePessimisticFixpoint(); + else if (T.hasValue()) + S ^= *T; +} + +/// Helper class for generic deduction: call site argument -> argument position. 
+template <typename AAType, typename Base, + typename StateType = typename AAType::StateType> +struct AAArgumentFromCallSiteArguments : public Base { + AAArgumentFromCallSiteArguments(const IRPosition &IRP) : Base(IRP) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + StateType S; + clampCallSiteArgumentStates<AAType, StateType>(A, *this, S); + // TODO: If we know we visited all incoming values, thus no are assumed + // dead, we can take the known information from the state T. + return clampStateAndIndicateChange<StateType>(this->getState(), S); + } }; -ChangeStatus AANoUnwindFunction::updateImpl(Attributor &A) { - Function &F = getAnchorScope(); +/// Helper class for generic replication: function returned -> cs returned. +template <typename AAType, typename Base, + typename StateType = typename AAType::StateType> +struct AACallSiteReturnedFromReturned : public Base { + AACallSiteReturnedFromReturned(const IRPosition &IRP) : Base(IRP) {} - // The map from instruction opcodes to those instructions in the function. - auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); - auto Opcodes = { - (unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, - (unsigned)Instruction::Call, (unsigned)Instruction::CleanupRet, - (unsigned)Instruction::CatchSwitch, (unsigned)Instruction::Resume}; + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + assert(this->getIRPosition().getPositionKind() == + IRPosition::IRP_CALL_SITE_RETURNED && + "Can only wrap function returned positions for call site returned " + "positions!"); + auto &S = this->getState(); + + const Function *AssociatedFunction = + this->getIRPosition().getAssociatedFunction(); + if (!AssociatedFunction) + return S.indicatePessimisticFixpoint(); + + IRPosition FnPos = IRPosition::returned(*AssociatedFunction); + const AAType &AA = A.getAAFor<AAType>(*this, FnPos); + return clampStateAndIndicateChange( + S, static_cast<const typename AAType::StateType &>(AA.getState())); + } +}; - for (unsigned Opcode : Opcodes) { - for (Instruction *I : OpcodeInstMap[Opcode]) { - if (!I->mayThrow()) - continue; +/// Helper class for generic deduction using must-be-executed-context +/// Base class is required to have `followUse` method. - auto *NoUnwindAA = A.getAAFor<AANoUnwind>(*this, *I); +/// bool followUse(Attributor &A, const Use *U, const Instruction *I) +/// U - Underlying use. +/// I - The user of the \p U. +/// `followUse` returns true if the value should be tracked transitively. - if (!NoUnwindAA || !NoUnwindAA->isAssumedNoUnwind()) { - indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; +template <typename AAType, typename Base, + typename StateType = typename AAType::StateType> +struct AAFromMustBeExecutedContext : public Base { + AAFromMustBeExecutedContext(const IRPosition &IRP) : Base(IRP) {} + + void initialize(Attributor &A) override { + Base::initialize(A); + IRPosition &IRP = this->getIRPosition(); + Instruction *CtxI = IRP.getCtxI(); + + if (!CtxI) + return; + + for (const Use &U : IRP.getAssociatedValue().uses()) + Uses.insert(&U); + } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + auto BeforeState = this->getState(); + auto &S = this->getState(); + Instruction *CtxI = this->getIRPosition().getCtxI(); + if (!CtxI) + return ChangeStatus::UNCHANGED; + + MustBeExecutedContextExplorer &Explorer = + A.getInfoCache().getMustBeExecutedContextExplorer(); + + SetVector<const Use *> NextUses; + + for (const Use *U : Uses) { + if (const Instruction *UserI = dyn_cast<Instruction>(U->getUser())) { + auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI); + bool Found = EIt.count(UserI); + while (!Found && ++EIt != EEnd) + Found = EIt.getCurrentInst() == UserI; + if (Found && Base::followUse(A, U, UserI)) + for (const Use &Us : UserI->uses()) + NextUses.insert(&Us); } } + for (const Use *U : NextUses) + Uses.insert(U); + + return BeforeState == S ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } - return ChangeStatus::UNCHANGED; -} + +private: + /// Container for (transitive) uses of the associated value. + SetVector<const Use *> Uses; +}; + +template <typename AAType, typename Base, + typename StateType = typename AAType::StateType> +using AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext = + AAComposeTwoGenericDeduction<AAType, Base, StateType, + AAFromMustBeExecutedContext, + AAArgumentFromCallSiteArguments>; + +template <typename AAType, typename Base, + typename StateType = typename AAType::StateType> +using AACallSiteReturnedFromReturnedAndMustBeExecutedContext = + AAComposeTwoGenericDeduction<AAType, Base, StateType, + AAFromMustBeExecutedContext, + AACallSiteReturnedFromReturned>; + +/// -----------------------NoUnwind Function Attribute-------------------------- + +struct AANoUnwindImpl : AANoUnwind { + AANoUnwindImpl(const IRPosition &IRP) : AANoUnwind(IRP) {} + + const std::string getAsStr() const override { + return getAssumed() ? "nounwind" : "may-unwind"; + } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + auto Opcodes = { + (unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, + (unsigned)Instruction::Call, (unsigned)Instruction::CleanupRet, + (unsigned)Instruction::CatchSwitch, (unsigned)Instruction::Resume}; + + auto CheckForNoUnwind = [&](Instruction &I) { + if (!I.mayThrow()) + return true; + + if (ImmutableCallSite ICS = ImmutableCallSite(&I)) { + const auto &NoUnwindAA = + A.getAAFor<AANoUnwind>(*this, IRPosition::callsite_function(ICS)); + return NoUnwindAA.isAssumedNoUnwind(); + } + return false; + }; + + if (!A.checkForAllInstructions(CheckForNoUnwind, *this, Opcodes)) + return indicatePessimisticFixpoint(); + + return ChangeStatus::UNCHANGED; + } +}; + +struct AANoUnwindFunction final : public AANoUnwindImpl { + AANoUnwindFunction(const IRPosition &IRP) : AANoUnwindImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(nounwind) } +}; + +/// NoUnwind attribute deduction for a call sites. +struct AANoUnwindCallSite final : AANoUnwindImpl { + AANoUnwindCallSite(const IRPosition &IRP) : AANoUnwindImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoUnwindImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) + indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. 
+ Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AANoUnwind>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), + static_cast<const AANoUnwind::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); } +}; /// --------------------- Function Return Values ------------------------------- @@ -434,68 +856,48 @@ ChangeStatus AANoUnwindFunction::updateImpl(Attributor &A) { /// /// If there is a unique returned value R, the manifest method will: /// - mark R with the "returned" attribute, if R is an argument. -class AAReturnedValuesImpl final : public AAReturnedValues, AbstractState { +class AAReturnedValuesImpl : public AAReturnedValues, public AbstractState { /// Mapping of values potentially returned by the associated function to the /// return instructions that might return them. - DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> ReturnedValues; + MapVector<Value *, SmallSetVector<ReturnInst *, 4>> ReturnedValues; + + /// Mapping to remember the number of returned values for a call site such + /// that we can avoid updates if nothing changed. + DenseMap<const CallBase *, unsigned> NumReturnedValuesPerKnownAA; + + /// Set of unresolved calls returned by the associated function. + SmallSetVector<CallBase *, 4> UnresolvedCalls; /// State flags /// ///{ - bool IsFixed; - bool IsValidState; - bool HasOverdefinedReturnedCalls; + bool IsFixed = false; + bool IsValidState = true; ///} - /// Collect values that could become \p V in the set \p Values, each mapped to - /// \p ReturnInsts. 
- void collectValuesRecursively( - Attributor &A, Value *V, SmallPtrSetImpl<ReturnInst *> &ReturnInsts, - DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> &Values) { - - visitValueCB_t<bool> VisitValueCB = [&](Value *Val, bool &) { - assert(!isa<Instruction>(Val) || - &getAnchorScope() == cast<Instruction>(Val)->getFunction()); - Values[Val].insert(ReturnInsts.begin(), ReturnInsts.end()); - }; - - bool UnusedBool; - bool Success = genericValueTraversal(V, UnusedBool, VisitValueCB); - - // If we did abort the above traversal we haven't see all the values. - // Consequently, we cannot know if the information we would derive is - // accurate so we give up early. - if (!Success) - indicatePessimisticFixpoint(); - } - public: - /// See AbstractAttribute::AbstractAttribute(...). - AAReturnedValuesImpl(Function &F, InformationCache &InfoCache) - : AAReturnedValues(F, InfoCache) { - // We do not have an associated argument yet. - AssociatedVal = nullptr; - } + AAReturnedValuesImpl(const IRPosition &IRP) : AAReturnedValues(IRP) {} /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { // Reset the state. - AssociatedVal = nullptr; IsFixed = false; IsValidState = true; - HasOverdefinedReturnedCalls = false; ReturnedValues.clear(); - Function &F = cast<Function>(getAnchoredValue()); + Function *F = getAssociatedFunction(); + if (!F) { + indicatePessimisticFixpoint(); + return; + } // The map from instruction opcodes to those instructions in the function. - auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); + auto &OpcodeInstMap = A.getInfoCache().getOpcodeInstMapForFunction(*F); // Look through all arguments, if one is marked as returned we are done. 
- for (Argument &Arg : F.args()) { + for (Argument &Arg : F->args()) { if (Arg.hasReturnedAttr()) { - auto &ReturnInstSet = ReturnedValues[&Arg]; for (Instruction *RI : OpcodeInstMap[Instruction::Ret]) ReturnInstSet.insert(cast<ReturnInst>(RI)); @@ -505,13 +907,8 @@ public: } } - // If no argument was marked as returned we look at all return instructions - // and collect potentially returned values. - for (Instruction *RI : OpcodeInstMap[Instruction::Ret]) { - SmallPtrSet<ReturnInst *, 1> RISet({cast<ReturnInst>(RI)}); - collectValuesRecursively(A, cast<ReturnInst>(RI)->getReturnValue(), RISet, - ReturnedValues); - } + if (!F->hasExactDefinition()) + indicatePessimisticFixpoint(); } /// See AbstractAttribute::manifest(...). @@ -523,25 +920,35 @@ public: /// See AbstractAttribute::getState(...). const AbstractState &getState() const override { return *this; } - /// See AbstractAttribute::getManifestPosition(). - ManifestPosition getManifestPosition() const override { return MP_ARGUMENT; } - /// See AbstractAttribute::updateImpl(Attributor &A). ChangeStatus updateImpl(Attributor &A) override; + llvm::iterator_range<iterator> returned_values() override { + return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end()); + } + + llvm::iterator_range<const_iterator> returned_values() const override { + return llvm::make_range(ReturnedValues.begin(), ReturnedValues.end()); + } + + const SmallSetVector<CallBase *, 4> &getUnresolvedCalls() const override { + return UnresolvedCalls; + } + /// Return the number of potential return values, -1 if unknown. - size_t getNumReturnValues() const { + size_t getNumReturnValues() const override { return isValidState() ? ReturnedValues.size() : -1; } /// Return an assumed unique return value if a single candidate is found. If /// there cannot be one, return a nullptr. If it is not clear yet, return the /// Optional::NoneType. 
- Optional<Value *> getAssumedUniqueReturnValue() const; + Optional<Value *> getAssumedUniqueReturnValue(Attributor &A) const; - /// See AbstractState::checkForallReturnedValues(...). - bool - checkForallReturnedValues(std::function<bool(Value &)> &Pred) const override; + /// See AbstractState::checkForAllReturnedValues(...). + bool checkForAllReturnedValuesAndReturnInsts( + const function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> + &Pred) const override; /// Pretty print the attribute similar to the IR representation. const std::string getAsStr() const override; @@ -553,13 +960,15 @@ public: bool isValidState() const override { return IsValidState; } /// See AbstractState::indicateOptimisticFixpoint(...). - void indicateOptimisticFixpoint() override { + ChangeStatus indicateOptimisticFixpoint() override { IsFixed = true; - IsValidState &= true; + return ChangeStatus::UNCHANGED; } - void indicatePessimisticFixpoint() override { + + ChangeStatus indicatePessimisticFixpoint() override { IsFixed = true; IsValidState = false; + return ChangeStatus::CHANGED; } }; @@ -568,21 +977,52 @@ ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) { // Bookkeeping. assert(isValidState()); - NumFnKnownReturns++; + STATS_DECLTRACK(KnownReturnValues, FunctionReturn, + "Number of function with known return values"); // Check if we have an assumed unique return value that we could manifest. - Optional<Value *> UniqueRV = getAssumedUniqueReturnValue(); + Optional<Value *> UniqueRV = getAssumedUniqueReturnValue(A); if (!UniqueRV.hasValue() || !UniqueRV.getValue()) return Changed; // Bookkeeping. - NumFnUniqueReturned++; + STATS_DECLTRACK(UniqueReturnValue, FunctionReturn, + "Number of function with unique return"); + + // Callback to replace the uses of CB with the constant C. 
+ auto ReplaceCallSiteUsersWith = [](CallBase &CB, Constant &C) { + if (CB.getNumUses() == 0 || CB.isMustTailCall()) + return ChangeStatus::UNCHANGED; + CB.replaceAllUsesWith(&C); + return ChangeStatus::CHANGED; + }; // If the assumed unique return value is an argument, annotate it. if (auto *UniqueRVArg = dyn_cast<Argument>(UniqueRV.getValue())) { - AssociatedVal = UniqueRVArg; - Changed = AbstractAttribute::manifest(A) | Changed; + getIRPosition() = IRPosition::argument(*UniqueRVArg); + Changed = IRAttribute::manifest(A); + } else if (auto *RVC = dyn_cast<Constant>(UniqueRV.getValue())) { + // We can replace the returned value with the unique returned constant. + Value &AnchorValue = getAnchorValue(); + if (Function *F = dyn_cast<Function>(&AnchorValue)) { + for (const Use &U : F->uses()) + if (CallBase *CB = dyn_cast<CallBase>(U.getUser())) + if (CB->isCallee(&U)) { + Constant *RVCCast = + ConstantExpr::getTruncOrBitCast(RVC, CB->getType()); + Changed = ReplaceCallSiteUsersWith(*CB, *RVCCast) | Changed; + } + } else { + assert(isa<CallBase>(AnchorValue) && + "Expcected a function or call base anchor!"); + Constant *RVCCast = + ConstantExpr::getTruncOrBitCast(RVC, AnchorValue.getType()); + Changed = ReplaceCallSiteUsersWith(cast<CallBase>(AnchorValue), *RVCCast); + } + if (Changed == ChangeStatus::CHANGED) + STATS_DECLTRACK(UniqueConstantReturnValue, FunctionReturn, + "Number of function returns replaced by constant return"); } return Changed; @@ -590,18 +1030,20 @@ ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) { const std::string AAReturnedValuesImpl::getAsStr() const { return (isAtFixpoint() ? "returns(#" : "may-return(#") + - (isValidState() ? std::to_string(getNumReturnValues()) : "?") + ")"; + (isValidState() ? 
std::to_string(getNumReturnValues()) : "?") + + ")[#UC: " + std::to_string(UnresolvedCalls.size()) + "]"; } -Optional<Value *> AAReturnedValuesImpl::getAssumedUniqueReturnValue() const { - // If checkForallReturnedValues provides a unique value, ignoring potential +Optional<Value *> +AAReturnedValuesImpl::getAssumedUniqueReturnValue(Attributor &A) const { + // If checkForAllReturnedValues provides a unique value, ignoring potential // undef values that can also be present, it is assumed to be the actual // return value and forwarded to the caller of this method. If there are // multiple, a nullptr is returned indicating there cannot be a unique // returned value. Optional<Value *> UniqueRV; - std::function<bool(Value &)> Pred = [&](Value &RV) -> bool { + auto Pred = [&](Value &RV) -> bool { // If we found a second returned value and neither the current nor the saved // one is an undef, there is no unique returned value. Undefs are special // since we can pretend they have any value. @@ -618,14 +1060,15 @@ Optional<Value *> AAReturnedValuesImpl::getAssumedUniqueReturnValue() const { return true; }; - if (!checkForallReturnedValues(Pred)) + if (!A.checkForAllReturnedValues(Pred, *this)) UniqueRV = nullptr; return UniqueRV; } -bool AAReturnedValuesImpl::checkForallReturnedValues( - std::function<bool(Value &)> &Pred) const { +bool AAReturnedValuesImpl::checkForAllReturnedValuesAndReturnInsts( + const function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> + &Pred) const { if (!isValidState()) return false; @@ -634,11 +1077,11 @@ bool AAReturnedValuesImpl::checkForallReturnedValues( for (auto &It : ReturnedValues) { Value *RV = It.first; - ImmutableCallSite ICS(RV); - if (ICS && !HasOverdefinedReturnedCalls) + CallBase *CB = dyn_cast<CallBase>(RV); + if (CB && !UnresolvedCalls.count(CB)) continue; - if (!Pred(*RV)) + if (!Pred(*RV, It.second)) return false; } @@ -646,125 +1089,196 @@ bool AAReturnedValuesImpl::checkForallReturnedValues( } ChangeStatus 
AAReturnedValuesImpl::updateImpl(Attributor &A) { + size_t NumUnresolvedCalls = UnresolvedCalls.size(); + bool Changed = false; + + // State used in the value traversals starting in returned values. + struct RVState { + // The map in which we collect return values -> return instrs. + decltype(ReturnedValues) &RetValsMap; + // The flag to indicate a change. + bool &Changed; + // The return instrs we come from. + SmallSetVector<ReturnInst *, 4> RetInsts; + }; - // Check if we know of any values returned by the associated function, - // if not, we are done. - if (getNumReturnValues() == 0) { - indicateOptimisticFixpoint(); - return ChangeStatus::UNCHANGED; - } + // Callback for a leaf value returned by the associated function. + auto VisitValueCB = [](Value &Val, RVState &RVS, bool) -> bool { + auto Size = RVS.RetValsMap[&Val].size(); + RVS.RetValsMap[&Val].insert(RVS.RetInsts.begin(), RVS.RetInsts.end()); + bool Inserted = RVS.RetValsMap[&Val].size() != Size; + RVS.Changed |= Inserted; + LLVM_DEBUG({ + if (Inserted) + dbgs() << "[AAReturnedValues] 1 Add new returned value " << Val + << " => " << RVS.RetInsts.size() << "\n"; + }); + return true; + }; - // Check if any of the returned values is a call site we can refine. - decltype(ReturnedValues) AddRVs; - bool HasCallSite = false; + // Helper method to invoke the generic value traversal. + auto VisitReturnedValue = [&](Value &RV, RVState &RVS) { + IRPosition RetValPos = IRPosition::value(RV); + return genericValueTraversal<AAReturnedValues, RVState>(A, RetValPos, *this, + RVS, VisitValueCB); + }; - // Look at all returned call sites. - for (auto &It : ReturnedValues) { - SmallPtrSet<ReturnInst *, 2> &ReturnInsts = It.second; - Value *RV = It.first; - LLVM_DEBUG(dbgs() << "[AAReturnedValues] Potentially returned value " << *RV - << "\n"); + // Callback for all "return intructions" live in the associated function. 
+ auto CheckReturnInst = [this, &VisitReturnedValue, &Changed](Instruction &I) { + ReturnInst &Ret = cast<ReturnInst>(I); + RVState RVS({ReturnedValues, Changed, {}}); + RVS.RetInsts.insert(&Ret); + return VisitReturnedValue(*Ret.getReturnValue(), RVS); + }; - // Only call sites can change during an update, ignore the rest. - CallSite RetCS(RV); - if (!RetCS) + // Start by discovering returned values from all live returned instructions in + // the associated function. + if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret})) + return indicatePessimisticFixpoint(); + + // Once returned values "directly" present in the code are handled we try to + // resolve returned calls. + decltype(ReturnedValues) NewRVsMap; + for (auto &It : ReturnedValues) { + LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned value: " << *It.first + << " by #" << It.second.size() << " RIs\n"); + CallBase *CB = dyn_cast<CallBase>(It.first); + if (!CB || UnresolvedCalls.count(CB)) continue; - // For now, any call site we see will prevent us from directly fixing the - // state. However, if the information on the callees is fixed, the call - // sites will be removed and we will fix the information for this state. - HasCallSite = true; - - // Try to find a assumed unique return value for the called function. - auto *RetCSAA = A.getAAFor<AAReturnedValuesImpl>(*this, *RV); - if (!RetCSAA) { - HasOverdefinedReturnedCalls = true; - LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site (" << *RV - << ") with " << (RetCSAA ? "invalid" : "no") - << " associated state\n"); + if (!CB->getCalledFunction()) { + LLVM_DEBUG(dbgs() << "[AAReturnedValues] Unresolved call: " << *CB + << "\n"); + UnresolvedCalls.insert(CB); continue; } - // Try to find a assumed unique return value for the called function. - Optional<Value *> AssumedUniqueRV = RetCSAA->getAssumedUniqueReturnValue(); + // TODO: use the function scope once we have call site AAReturnedValues. 
+ const auto &RetValAA = A.getAAFor<AAReturnedValues>( + *this, IRPosition::function(*CB->getCalledFunction())); + LLVM_DEBUG(dbgs() << "[AAReturnedValues] Found another AAReturnedValues: " + << static_cast<const AbstractAttribute &>(RetValAA) + << "\n"); - // If no assumed unique return value was found due to the lack of - // candidates, we may need to resolve more calls (through more update - // iterations) or the called function will not return. Either way, we simply - // stick with the call sites as return values. Because there were not - // multiple possibilities, we do not treat it as overdefined. - if (!AssumedUniqueRV.hasValue()) + // Skip dead ends, thus if we do not know anything about the returned + // call we mark it as unresolved and it will stay that way. + if (!RetValAA.getState().isValidState()) { + LLVM_DEBUG(dbgs() << "[AAReturnedValues] Unresolved call: " << *CB + << "\n"); + UnresolvedCalls.insert(CB); continue; + } - // If multiple, non-refinable values were found, there cannot be a unique - // return value for the called function. The returned call is overdefined! - if (!AssumedUniqueRV.getValue()) { - HasOverdefinedReturnedCalls = true; - LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site has multiple " - "potentially returned values\n"); + // Do not try to learn partial information. If the callee has unresolved + // return values we will treat the call as unresolved/opaque. + auto &RetValAAUnresolvedCalls = RetValAA.getUnresolvedCalls(); + if (!RetValAAUnresolvedCalls.empty()) { + UnresolvedCalls.insert(CB); continue; } - LLVM_DEBUG({ - bool UniqueRVIsKnown = RetCSAA->isAtFixpoint(); - dbgs() << "[AAReturnedValues] Returned call site " - << (UniqueRVIsKnown ? "known" : "assumed") - << " unique return value: " << *AssumedUniqueRV << "\n"; - }); + // Now check if we can track transitively returned values. If possible, thus + // if all return value can be represented in the current scope, do so. 
+ bool Unresolved = false; + for (auto &RetValAAIt : RetValAA.returned_values()) { + Value *RetVal = RetValAAIt.first; + if (isa<Argument>(RetVal) || isa<CallBase>(RetVal) || + isa<Constant>(RetVal)) + continue; + // Anything that did not fit in the above categories cannot be resolved, + // mark the call as unresolved. + LLVM_DEBUG(dbgs() << "[AAReturnedValues] transitively returned value " + "cannot be translated: " + << *RetVal << "\n"); + UnresolvedCalls.insert(CB); + Unresolved = true; + break; + } - // The assumed unique return value. - Value *AssumedRetVal = AssumedUniqueRV.getValue(); - - // If the assumed unique return value is an argument, lookup the matching - // call site operand and recursively collect new returned values. - // If it is not an argument, it is just put into the set of returned values - // as we would have already looked through casts, phis, and similar values. - if (Argument *AssumedRetArg = dyn_cast<Argument>(AssumedRetVal)) - collectValuesRecursively(A, - RetCS.getArgOperand(AssumedRetArg->getArgNo()), - ReturnInsts, AddRVs); - else - AddRVs[AssumedRetVal].insert(ReturnInsts.begin(), ReturnInsts.end()); - } + if (Unresolved) + continue; - // Keep track of any change to trigger updates on dependent attributes. - ChangeStatus Changed = ChangeStatus::UNCHANGED; + // Now track transitively returned values. + unsigned &NumRetAA = NumReturnedValuesPerKnownAA[CB]; + if (NumRetAA == RetValAA.getNumReturnValues()) { + LLVM_DEBUG(dbgs() << "[AAReturnedValues] Skip call as it has not " + "changed since it was seen last\n"); + continue; + } + NumRetAA = RetValAA.getNumReturnValues(); + + for (auto &RetValAAIt : RetValAA.returned_values()) { + Value *RetVal = RetValAAIt.first; + if (Argument *Arg = dyn_cast<Argument>(RetVal)) { + // Arguments are mapped to call site operands and we begin the traversal + // again. 
+ bool Unused = false; + RVState RVS({NewRVsMap, Unused, RetValAAIt.second}); + VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS); + continue; + } else if (isa<CallBase>(RetVal)) { + // Call sites are resolved by the callee attribute over time, no need to + // do anything for us. + continue; + } else if (isa<Constant>(RetVal)) { + // Constants are valid everywhere, we can simply take them. + NewRVsMap[RetVal].insert(It.second.begin(), It.second.end()); + continue; + } + } + } - for (auto &It : AddRVs) { + // To avoid modifications to the ReturnedValues map while we iterate over it + // we kept record of potential new entries in a copy map, NewRVsMap. + for (auto &It : NewRVsMap) { assert(!It.second.empty() && "Entry does not add anything."); auto &ReturnInsts = ReturnedValues[It.first]; for (ReturnInst *RI : It.second) - if (ReturnInsts.insert(RI).second) { + if (ReturnInsts.insert(RI)) { LLVM_DEBUG(dbgs() << "[AAReturnedValues] Add new returned value " << *It.first << " => " << *RI << "\n"); - Changed = ChangeStatus::CHANGED; + Changed = true; } } - // If there is no call site in the returned values we are done. - if (!HasCallSite) { - indicateOptimisticFixpoint(); - return ChangeStatus::CHANGED; - } - - return Changed; + Changed |= (NumUnresolvedCalls != UnresolvedCalls.size()); + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; } -/// ------------------------ NoSync Function Attribute ------------------------- +struct AAReturnedValuesFunction final : public AAReturnedValuesImpl { + AAReturnedValuesFunction(const IRPosition &IRP) : AAReturnedValuesImpl(IRP) {} -struct AANoSyncFunction : AANoSync, BooleanState { + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(returned) } +}; - AANoSyncFunction(Function &F, InformationCache &InfoCache) - : AANoSync(F, InfoCache) {} +/// Returned values information for a call sites. 
+struct AAReturnedValuesCallSite final : AAReturnedValuesImpl { + AAReturnedValuesCallSite(const IRPosition &IRP) : AAReturnedValuesImpl(IRP) {} - /// See AbstractAttribute::getState() - /// { - AbstractState &getState() override { return *this; } - const AbstractState &getState() const override { return *this; } - /// } + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites instead of + // redirecting requests to the callee. + llvm_unreachable("Abstract attributes for returned values are not " + "supported for call sites yet!"); + } - /// See AbstractAttribute::getManifestPosition(). - ManifestPosition getManifestPosition() const override { return MP_FUNCTION; } + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + return indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} +}; + +/// ------------------------ NoSync Function Attribute ------------------------- + +struct AANoSyncImpl : AANoSync { + AANoSyncImpl(const IRPosition &IRP) : AANoSync(IRP) {} const std::string getAsStr() const override { return getAssumed() ? "nosync" : "may-sync"; @@ -773,12 +1287,6 @@ struct AANoSyncFunction : AANoSync, BooleanState { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override; - /// See AANoSync::isAssumedNoSync() - bool isAssumedNoSync() const override { return getAssumed(); } - - /// See AANoSync::isKnownNoSync() - bool isKnownNoSync() const override { return getKnown(); } - /// Helper function used to determine whether an instruction is non-relaxed /// atomic. 
In other words, if an atomic instruction does not have unordered /// or monotonic ordering @@ -792,7 +1300,7 @@ struct AANoSyncFunction : AANoSync, BooleanState { static bool isNoSyncIntrinsic(Instruction *I); }; -bool AANoSyncFunction::isNonRelaxedAtomic(Instruction *I) { +bool AANoSyncImpl::isNonRelaxedAtomic(Instruction *I) { if (!I->isAtomic()) return false; @@ -841,7 +1349,7 @@ bool AANoSyncFunction::isNonRelaxedAtomic(Instruction *I) { /// Checks if an intrinsic is nosync. Currently only checks mem* intrinsics. /// FIXME: We should ipmrove the handling of intrinsics. -bool AANoSyncFunction::isNoSyncIntrinsic(Instruction *I) { +bool AANoSyncImpl::isNoSyncIntrinsic(Instruction *I) { if (auto *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { /// Element wise atomic memory intrinsics are can only be unordered, @@ -863,7 +1371,7 @@ bool AANoSyncFunction::isNoSyncIntrinsic(Instruction *I) { return false; } -bool AANoSyncFunction::isVolatile(Instruction *I) { +bool AANoSyncImpl::isVolatile(Instruction *I) { assert(!ImmutableCallSite(I) && !isa<CallBase>(I) && "Calls should not be checked here"); @@ -881,482 +1389,3074 @@ bool AANoSyncFunction::isVolatile(Instruction *I) { } } -ChangeStatus AANoSyncFunction::updateImpl(Attributor &A) { - Function &F = getAnchorScope(); +ChangeStatus AANoSyncImpl::updateImpl(Attributor &A) { - /// We are looking for volatile instructions or Non-Relaxed atomics. - /// FIXME: We should ipmrove the handling of intrinsics. - for (Instruction *I : InfoCache.getReadOrWriteInstsForFunction(F)) { - ImmutableCallSite ICS(I); - auto *NoSyncAA = A.getAAFor<AANoSyncFunction>(*this, *I); + auto CheckRWInstForNoSync = [&](Instruction &I) { + /// We are looking for volatile instructions or Non-Relaxed atomics. + /// FIXME: We should ipmrove the handling of intrinsics. 
- if (isa<IntrinsicInst>(I) && isNoSyncIntrinsic(I)) - continue; + if (isa<IntrinsicInst>(&I) && isNoSyncIntrinsic(&I)) + return true; + + if (ImmutableCallSite ICS = ImmutableCallSite(&I)) { + if (ICS.hasFnAttr(Attribute::NoSync)) + return true; + + const auto &NoSyncAA = + A.getAAFor<AANoSync>(*this, IRPosition::callsite_function(ICS)); + if (NoSyncAA.isAssumedNoSync()) + return true; + return false; + } + + if (!isVolatile(&I) && !isNonRelaxedAtomic(&I)) + return true; + + return false; + }; - if (ICS && (!NoSyncAA || !NoSyncAA->isAssumedNoSync()) && - !ICS.hasFnAttr(Attribute::NoSync)) { + auto CheckForNoSync = [&](Instruction &I) { + // At this point we handled all read/write effects and they are all + // nosync, so they can be skipped. + if (I.mayReadOrWriteMemory()) + return true; + + // non-convergent and readnone imply nosync. + return !ImmutableCallSite(&I).isConvergent(); + }; + + if (!A.checkForAllReadWriteInstructions(CheckRWInstForNoSync, *this) || + !A.checkForAllCallLikeInstructions(CheckForNoSync, *this)) + return indicatePessimisticFixpoint(); + + return ChangeStatus::UNCHANGED; +} + +struct AANoSyncFunction final : public AANoSyncImpl { + AANoSyncFunction(const IRPosition &IRP) : AANoSyncImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(nosync) } +}; + +/// NoSync attribute deduction for a call sites. +struct AANoSyncCallSite final : AANoSyncImpl { + AANoSyncCallSite(const IRPosition &IRP) : AANoSyncImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoSyncImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AANoSync>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), static_cast<const AANoSync::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nosync); } +}; + +/// ------------------------ No-Free Attributes ---------------------------- + +struct AANoFreeImpl : public AANoFree { + AANoFreeImpl(const IRPosition &IRP) : AANoFree(IRP) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + auto CheckForNoFree = [&](Instruction &I) { + ImmutableCallSite ICS(&I); + if (ICS.hasFnAttr(Attribute::NoFree)) + return true; + + const auto &NoFreeAA = + A.getAAFor<AANoFree>(*this, IRPosition::callsite_function(ICS)); + return NoFreeAA.isAssumedNoFree(); + }; + + if (!A.checkForAllCallLikeInstructions(CheckForNoFree, *this)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return getAssumed() ? "nofree" : "may-free"; + } +}; + +struct AANoFreeFunction final : public AANoFreeImpl { + AANoFreeFunction(const IRPosition &IRP) : AANoFreeImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(nofree) } +}; + +/// NoFree attribute deduction for a call sites. 
+struct AANoFreeCallSite final : AANoFreeImpl { + AANoFreeCallSite(const IRPosition &IRP) : AANoFreeImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoFreeImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) + indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AANoFree>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), static_cast<const AANoFree::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nofree); } +}; + +/// ------------------------ NonNull Argument Attribute ------------------------ +static int64_t getKnownNonNullAndDerefBytesForUse( + Attributor &A, AbstractAttribute &QueryingAA, Value &AssociatedValue, + const Use *U, const Instruction *I, bool &IsNonNull, bool &TrackUse) { + TrackUse = false; + + const Value *UseV = U->get(); + if (!UseV->getType()->isPointerTy()) + return 0; + + Type *PtrTy = UseV->getType(); + const Function *F = I->getFunction(); + bool NullPointerIsDefined = + F ? 
llvm::NullPointerIsDefined(F, PtrTy->getPointerAddressSpace()) : true; + const DataLayout &DL = A.getInfoCache().getDL(); + if (ImmutableCallSite ICS = ImmutableCallSite(I)) { + if (ICS.isBundleOperand(U)) + return 0; + + if (ICS.isCallee(U)) { + IsNonNull |= !NullPointerIsDefined; + return 0; } - if (ICS) - continue; + unsigned ArgNo = ICS.getArgumentNo(U); + IRPosition IRP = IRPosition::callsite_argument(ICS, ArgNo); + auto &DerefAA = A.getAAFor<AADereferenceable>(QueryingAA, IRP); + IsNonNull |= DerefAA.isKnownNonNull(); + return DerefAA.getKnownDereferenceableBytes(); + } - if (!isVolatile(I) && !isNonRelaxedAtomic(I)) - continue; + int64_t Offset; + if (const Value *Base = getBasePointerOfAccessPointerOperand(I, Offset, DL)) { + if (Base == &AssociatedValue && getPointerOperand(I) == UseV) { + int64_t DerefBytes = + Offset + (int64_t)DL.getTypeStoreSize(PtrTy->getPointerElementType()); + + IsNonNull |= !NullPointerIsDefined; + return DerefBytes; + } + } + if (const Value *Base = + GetPointerBaseWithConstantOffset(UseV, Offset, DL, + /*AllowNonInbounds*/ false)) { + auto &DerefAA = + A.getAAFor<AADereferenceable>(QueryingAA, IRPosition::value(*Base)); + IsNonNull |= (!NullPointerIsDefined && DerefAA.isKnownNonNull()); + IsNonNull |= (!NullPointerIsDefined && (Offset != 0)); + int64_t DerefBytes = DerefAA.getKnownDereferenceableBytes(); + return std::max(int64_t(0), DerefBytes - Offset); + } + + return 0; +} + +struct AANonNullImpl : AANonNull { + AANonNullImpl(const IRPosition &IRP) + : AANonNull(IRP), + NullIsDefined(NullPointerIsDefined( + getAnchorScope(), + getAssociatedValue().getType()->getPointerAddressSpace())) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + if (!NullIsDefined && + hasAttr({Attribute::NonNull, Attribute::Dereferenceable})) + indicateOptimisticFixpoint(); + else + AANonNull::initialize(A); + } + + /// See AAFromMustBeExecutedContext + bool followUse(Attributor &A, const Use *U, const Instruction *I) { + bool IsNonNull = false; + bool TrackUse = false; + getKnownNonNullAndDerefBytesForUse(A, *this, getAssociatedValue(), U, I, + IsNonNull, TrackUse); + takeKnownMaximum(IsNonNull); + return TrackUse; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return getAssumed() ? "nonnull" : "may-null"; + } + + /// Flag to determine if the underlying value can be null and still allow + /// valid accesses. + const bool NullIsDefined; +}; + +/// NonNull attribute for a floating value. +struct AANonNullFloating + : AAFromMustBeExecutedContext<AANonNull, AANonNullImpl> { + using Base = AAFromMustBeExecutedContext<AANonNull, AANonNullImpl>; + AANonNullFloating(const IRPosition &IRP) : Base(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + Base::initialize(A); + + if (isAtFixpoint()) + return; + + const IRPosition &IRP = getIRPosition(); + const Value &V = IRP.getAssociatedValue(); + const DataLayout &DL = A.getDataLayout(); + + // TODO: This context sensitive query should be removed once we can do + // context sensitive queries in the genericValueTraversal below. + if (isKnownNonZero(&V, DL, 0, /* TODO: AC */ nullptr, IRP.getCtxI(), + /* TODO: DT */ nullptr)) + indicateOptimisticFixpoint(); + } + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Change = Base::updateImpl(A); + if (isKnownNonNull()) + return Change; + + if (!NullIsDefined) { + const auto &DerefAA = A.getAAFor<AADereferenceable>(*this, getIRPosition()); + if (DerefAA.getAssumedDereferenceableBytes()) + return Change; + } + + const DataLayout &DL = A.getDataLayout(); + + auto VisitValueCB = [&](Value &V, AAAlign::StateType &T, + bool Stripped) -> bool { + const auto &AA = A.getAAFor<AANonNull>(*this, IRPosition::value(V)); + if (!Stripped && this == &AA) { + if (!isKnownNonZero(&V, DL, 0, /* TODO: AC */ nullptr, + /* CtxI */ getCtxI(), + /* TODO: DT */ nullptr)) + T.indicatePessimisticFixpoint(); + } else { + // Use abstract attribute information. + const AANonNull::StateType &NS = + static_cast<const AANonNull::StateType &>(AA.getState()); + T ^= NS; + } + return T.isValidState(); + }; + + StateType T; + if (!genericValueTraversal<AANonNull, StateType>(A, getIRPosition(), *this, + T, VisitValueCB)) + return indicatePessimisticFixpoint(); + + return clampStateAndIndicateChange(getState(), T); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(nonnull) } +}; + +/// NonNull attribute for function return value. +struct AANonNullReturned final + : AAReturnedFromReturnedValues<AANonNull, AANonNullImpl> { + AANonNullReturned(const IRPosition &IRP) + : AAReturnedFromReturnedValues<AANonNull, AANonNullImpl>(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(nonnull) } +}; + +/// NonNull attribute for function argument. 
+struct AANonNullArgument final + : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<AANonNull, + AANonNullImpl> { + AANonNullArgument(const IRPosition &IRP) + : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<AANonNull, + AANonNullImpl>( + IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nonnull) } +}; + +struct AANonNullCallSiteArgument final : AANonNullFloating { + AANonNullCallSiteArgument(const IRPosition &IRP) : AANonNullFloating(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(nonnull) } +}; + +/// NonNull attribute for a call site return position. +struct AANonNullCallSiteReturned final + : AACallSiteReturnedFromReturnedAndMustBeExecutedContext<AANonNull, + AANonNullImpl> { + AANonNullCallSiteReturned(const IRPosition &IRP) + : AACallSiteReturnedFromReturnedAndMustBeExecutedContext<AANonNull, + AANonNullImpl>( + IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) } +}; + +/// ------------------------ No-Recurse Attributes ---------------------------- + +struct AANoRecurseImpl : public AANoRecurse { + AANoRecurseImpl(const IRPosition &IRP) : AANoRecurse(IRP) {} + + /// See AbstractAttribute::getAsStr() + const std::string getAsStr() const override { + return getAssumed() ? "norecurse" : "may-recurse"; + } +}; + +struct AANoRecurseFunction final : AANoRecurseImpl { + AANoRecurseFunction(const IRPosition &IRP) : AANoRecurseImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + AANoRecurseImpl::initialize(A); + if (const Function *F = getAnchorScope()) + if (A.getInfoCache().getSccSize(*F) == 1) + return; indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; } - auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); - auto Opcodes = {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, - (unsigned)Instruction::Call}; + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { - for (unsigned Opcode : Opcodes) { - for (Instruction *I : OpcodeInstMap[Opcode]) { - // At this point we handled all read/write effects and they are all - // nosync, so they can be skipped. - if (I->mayReadOrWriteMemory()) - continue; + auto CheckForNoRecurse = [&](Instruction &I) { + ImmutableCallSite ICS(&I); + if (ICS.hasFnAttr(Attribute::NoRecurse)) + return true; - ImmutableCallSite ICS(I); + const auto &NoRecurseAA = + A.getAAFor<AANoRecurse>(*this, IRPosition::callsite_function(ICS)); + if (!NoRecurseAA.isAssumedNoRecurse()) + return false; - // non-convergent and readnone imply nosync. - if (!ICS.isConvergent()) - continue; + // Recursion to the same function + if (ICS.getCalledFunction() == getAnchorScope()) + return false; + + return true; + }; + + if (!A.checkForAllCallLikeInstructions(CheckForNoRecurse, *this)) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; + } + + void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(norecurse) } +}; + +/// NoRecurse attribute deduction for a call sites. +struct AANoRecurseCallSite final : AANoRecurseImpl { + AANoRecurseCallSite(const IRPosition &IRP) : AANoRecurseImpl(IRP) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoRecurseImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AANoRecurse>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), + static_cast<const AANoRecurse::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(norecurse); } +}; + +/// ------------------------ Will-Return Attributes ---------------------------- + +// Helper function that checks whether a function has any cycle. +// TODO: Replace with more efficent code +static bool containsCycle(Function &F) { + SmallPtrSet<BasicBlock *, 32> Visited; + + // Traverse BB by dfs and check whether successor is already visited. + for (BasicBlock *BB : depth_first(&F)) { + Visited.insert(BB); + for (auto *SuccBB : successors(BB)) { + if (Visited.count(SuccBB)) + return true; } } + return false; +} - return ChangeStatus::UNCHANGED; +// Helper function that checks the function have a loop which might become an +// endless loop +// FIXME: Any cycle is regarded as endless loop for now. +// We have to allow some patterns. +static bool containsPossiblyEndlessLoop(Function *F) { + return !F || !F->hasExactDefinition() || containsCycle(*F); } -/// ------------------------ No-Free Attributes ---------------------------- +struct AAWillReturnImpl : public AAWillReturn { + AAWillReturnImpl(const IRPosition &IRP) : AAWillReturn(IRP) {} -struct AANoFreeFunction : AbstractAttribute, BooleanState { + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + AAWillReturn::initialize(A); - /// See AbstractAttribute::AbstractAttribute(...). - AANoFreeFunction(Function &F, InformationCache &InfoCache) - : AbstractAttribute(F, InfoCache) {} + Function *F = getAssociatedFunction(); + if (containsPossiblyEndlessLoop(F)) + indicatePessimisticFixpoint(); + } - /// See AbstractAttribute::getState() - ///{ - AbstractState &getState() override { return *this; } - const AbstractState &getState() const override { return *this; } - ///} + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + auto CheckForWillReturn = [&](Instruction &I) { + IRPosition IPos = IRPosition::callsite_function(ImmutableCallSite(&I)); + const auto &WillReturnAA = A.getAAFor<AAWillReturn>(*this, IPos); + if (WillReturnAA.isKnownWillReturn()) + return true; + if (!WillReturnAA.isAssumedWillReturn()) + return false; + const auto &NoRecurseAA = A.getAAFor<AANoRecurse>(*this, IPos); + return NoRecurseAA.isAssumedNoRecurse(); + }; + + if (!A.checkForAllCallLikeInstructions(CheckForWillReturn, *this)) + return indicatePessimisticFixpoint(); + + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::getAsStr() + const std::string getAsStr() const override { + return getAssumed() ? "willreturn" : "may-noreturn"; + } +}; + +struct AAWillReturnFunction final : AAWillReturnImpl { + AAWillReturnFunction(const IRPosition &IRP) : AAWillReturnImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(willreturn) } +}; + +/// WillReturn attribute deduction for a call sites. +struct AAWillReturnCallSite final : AAWillReturnImpl { + AAWillReturnCallSite(const IRPosition &IRP) : AAWillReturnImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + AAWillReturnImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) + indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AAWillReturn>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), + static_cast<const AAWillReturn::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(willreturn); } +}; + +/// ------------------------ NoAlias Argument Attribute ------------------------ + +struct AANoAliasImpl : AANoAlias { + AANoAliasImpl(const IRPosition &IRP) : AANoAlias(IRP) {} + + const std::string getAsStr() const override { + return getAssumed() ? "noalias" : "may-alias"; + } +}; + +/// NoAlias attribute for a floating value. +struct AANoAliasFloating final : AANoAliasImpl { + AANoAliasFloating(const IRPosition &IRP) : AANoAliasImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoAliasImpl::initialize(A); + Value &Val = getAssociatedValue(); + if (isa<AllocaInst>(Val)) + indicateOptimisticFixpoint(); + if (isa<ConstantPointerNull>(Val) && + Val.getType()->getPointerAddressSpace() == 0) + indicateOptimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Implement this. 
+ return indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(noalias) + } +}; + +/// NoAlias attribute for an argument. +struct AANoAliasArgument final + : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl> { + AANoAliasArgument(const IRPosition &IRP) + : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl>(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noalias) } +}; + +struct AANoAliasCallSiteArgument final : AANoAliasImpl { + AANoAliasCallSiteArgument(const IRPosition &IRP) : AANoAliasImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + // See callsite argument attribute and callee argument attribute. + ImmutableCallSite ICS(&getAnchorValue()); + if (ICS.paramHasAttr(getArgNo(), Attribute::NoAlias)) + indicateOptimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // We can deduce "noalias" if the following conditions hold. + // (i) Associated value is assumed to be noalias in the definition. + // (ii) Associated value is assumed to be no-capture in all the uses + // possibly executed before this callsite. + // (iii) There is no other pointer argument which could alias with the + // value. + + const Value &V = getAssociatedValue(); + const IRPosition IRP = IRPosition::value(V); + + // (i) Check whether noalias holds in the definition. + + auto &NoAliasAA = A.getAAFor<AANoAlias>(*this, IRP); + + if (!NoAliasAA.isAssumedNoAlias()) + return indicatePessimisticFixpoint(); + + LLVM_DEBUG(dbgs() << "[Attributor][AANoAliasCSArg] " << V + << " is assumed NoAlias in the definition\n"); + + // (ii) Check whether the value is captured in the scope using AANoCapture. 
+ // FIXME: This is conservative though, it is better to look at CFG and + // check only uses possibly executed before this callsite. - /// See AbstractAttribute::getManifestPosition(). - ManifestPosition getManifestPosition() const override { return MP_FUNCTION; } + auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, IRP); + if (!NoCaptureAA.isAssumedNoCaptureMaybeReturned()) { + LLVM_DEBUG( + dbgs() << "[Attributor][AANoAliasCSArg] " << V + << " cannot be noalias as it is potentially captured\n"); + return indicatePessimisticFixpoint(); + } + + // (iii) Check there is no other pointer argument which could alias with the + // value. + ImmutableCallSite ICS(&getAnchorValue()); + for (unsigned i = 0; i < ICS.getNumArgOperands(); i++) { + if (getArgNo() == (int)i) + continue; + const Value *ArgOp = ICS.getArgOperand(i); + if (!ArgOp->getType()->isPointerTy()) + continue; + + if (const Function *F = getAnchorScope()) { + if (AAResults *AAR = A.getInfoCache().getAAResultsForFunction(*F)) { + bool IsAliasing = AAR->isNoAlias(&getAssociatedValue(), ArgOp); + LLVM_DEBUG(dbgs() + << "[Attributor][NoAliasCSArg] Check alias between " + "callsite arguments " + << AAR->isNoAlias(&getAssociatedValue(), ArgOp) << " " + << getAssociatedValue() << " " << *ArgOp << " => " + << (IsAliasing ? "" : "no-") << "alias \n"); + + if (IsAliasing) + continue; + } + } + return indicatePessimisticFixpoint(); + } + + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(noalias) } +}; + +/// NoAlias attribute for function return value. +struct AANoAliasReturned final : AANoAliasImpl { + AANoAliasReturned(const IRPosition &IRP) : AANoAliasImpl(IRP) {} + + /// See AbstractAttribute::updateImpl(...). 
+ virtual ChangeStatus updateImpl(Attributor &A) override { + + auto CheckReturnValue = [&](Value &RV) -> bool { + if (Constant *C = dyn_cast<Constant>(&RV)) + if (C->isNullValue() || isa<UndefValue>(C)) + return true; + + /// For now, we can only deduce noalias if we have call sites. + /// FIXME: add more support. + ImmutableCallSite ICS(&RV); + if (!ICS) + return false; + + const IRPosition &RVPos = IRPosition::value(RV); + const auto &NoAliasAA = A.getAAFor<AANoAlias>(*this, RVPos); + if (!NoAliasAA.isAssumedNoAlias()) + return false; + + const auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, RVPos); + return NoCaptureAA.isAssumedNoCaptureMaybeReturned(); + }; + + if (!A.checkForAllReturnedValues(CheckReturnValue, *this)) + return indicatePessimisticFixpoint(); + + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noalias) } +}; + +/// NoAlias attribute deduction for a call site return value. +struct AANoAliasCallSiteReturned final : AANoAliasImpl { + AANoAliasCallSiteReturned(const IRPosition &IRP) : AANoAliasImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AANoAliasImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) + indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. 
+ Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::returned(*F); + auto &FnAA = A.getAAFor<AANoAlias>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), static_cast<const AANoAlias::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noalias); } +}; + +/// -------------------AAIsDead Function Attribute----------------------- + +struct AAIsDeadImpl : public AAIsDead { + AAIsDeadImpl(const IRPosition &IRP) : AAIsDead(IRP) {} + + void initialize(Attributor &A) override { + const Function *F = getAssociatedFunction(); + if (F && !F->isDeclaration()) + exploreFromEntry(A, F); + } + + void exploreFromEntry(Attributor &A, const Function *F) { + ToBeExploredPaths.insert(&(F->getEntryBlock().front())); + + for (size_t i = 0; i < ToBeExploredPaths.size(); ++i) + if (const Instruction *NextNoReturnI = + findNextNoReturn(A, ToBeExploredPaths[i])) + NoReturnCalls.insert(NextNoReturnI); + + // Mark the block live after we looked for no-return instructions. + assumeLive(A, F->getEntryBlock()); + } + + /// Find the next assumed noreturn instruction in the block of \p I starting + /// from, thus including, \p I. + /// + /// The caller is responsible to monitor the ToBeExploredPaths set as new + /// instructions discovered in other basic block will be placed in there. + /// + /// \returns The next assumed noreturn instructions in the block of \p I + /// starting from, thus including, \p I. + const Instruction *findNextNoReturn(Attributor &A, const Instruction *I); /// See AbstractAttribute::getAsStr(). const std::string getAsStr() const override { - return getAssumed() ? "nofree" : "may-free"; + return "Live[#BB " + std::to_string(AssumedLiveBlocks.size()) + "/" + + std::to_string(getAssociatedFunction()->size()) + "][#NRI " + + std::to_string(NoReturnCalls.size()) + "]"; + } + + /// See AbstractAttribute::manifest(...). 
+ ChangeStatus manifest(Attributor &A) override { + assert(getState().isValidState() && + "Attempted to manifest an invalid state!"); + + ChangeStatus HasChanged = ChangeStatus::UNCHANGED; + Function &F = *getAssociatedFunction(); + + if (AssumedLiveBlocks.empty()) { + A.deleteAfterManifest(F); + return ChangeStatus::CHANGED; + } + + // Flag to determine if we can change an invoke to a call assuming the + // callee is nounwind. This is not possible if the personality of the + // function allows to catch asynchronous exceptions. + bool Invoke2CallAllowed = !mayCatchAsynchronousExceptions(F); + + for (const Instruction *NRC : NoReturnCalls) { + Instruction *I = const_cast<Instruction *>(NRC); + BasicBlock *BB = I->getParent(); + Instruction *SplitPos = I->getNextNode(); + // TODO: mark stuff before unreachable instructions as dead. + + if (auto *II = dyn_cast<InvokeInst>(I)) { + // If we keep the invoke the split position is at the beginning of the + // normal desitination block (it invokes a noreturn function after all). + BasicBlock *NormalDestBB = II->getNormalDest(); + SplitPos = &NormalDestBB->front(); + + /// Invoke is replaced with a call and unreachable is placed after it if + /// the callee is nounwind and noreturn. Otherwise, we keep the invoke + /// and only place an unreachable in the normal successor. + if (Invoke2CallAllowed) { + if (II->getCalledFunction()) { + const IRPosition &IPos = IRPosition::callsite_function(*II); + const auto &AANoUnw = A.getAAFor<AANoUnwind>(*this, IPos); + if (AANoUnw.isAssumedNoUnwind()) { + LLVM_DEBUG(dbgs() + << "[AAIsDead] Replace invoke with call inst\n"); + // We do not need an invoke (II) but instead want a call followed + // by an unreachable. However, we do not remove II as other + // abstract attributes might have it cached as part of their + // results. Given that we modify the CFG anyway, we simply keep II + // around but in a new dead block. 
To avoid II being live through + // a different edge we have to ensure the block we place it in is + // only reached from the current block of II and then not reached + // at all when we insert the unreachable. + SplitBlockPredecessors(NormalDestBB, {BB}, ".i2c"); + CallInst *CI = createCallMatchingInvoke(II); + CI->insertBefore(II); + CI->takeName(II); + II->replaceAllUsesWith(CI); + SplitPos = CI->getNextNode(); + } + } + } + + if (SplitPos == &NormalDestBB->front()) { + // If this is an invoke of a noreturn function the edge to the normal + // destination block is dead but not necessarily the block itself. + // TODO: We need to move to an edge based system during deduction and + // also manifest. + assert(!NormalDestBB->isLandingPad() && + "Expected the normal destination not to be a landingpad!"); + if (NormalDestBB->getUniquePredecessor() == BB) { + assumeLive(A, *NormalDestBB); + } else { + BasicBlock *SplitBB = + SplitBlockPredecessors(NormalDestBB, {BB}, ".dead"); + // The split block is live even if it contains only an unreachable + // instruction at the end. + assumeLive(A, *SplitBB); + SplitPos = SplitBB->getTerminator(); + HasChanged = ChangeStatus::CHANGED; + } + } + } + + if (isa_and_nonnull<UnreachableInst>(SplitPos)) + continue; + + BB = SplitPos->getParent(); + SplitBlock(BB, SplitPos); + changeToUnreachable(BB->getTerminator(), /* UseLLVMTrap */ false); + HasChanged = ChangeStatus::CHANGED; + } + + for (BasicBlock &BB : F) + if (!AssumedLiveBlocks.count(&BB)) + A.deleteAfterManifest(BB); + + return HasChanged; } /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override; - /// See AbstractAttribute::getAttrKind(). - Attribute::AttrKind getAttrKind() const override { return ID; } + /// See AAIsDead::isAssumedDead(BasicBlock *). 
+ bool isAssumedDead(const BasicBlock *BB) const override { + assert(BB->getParent() == getAssociatedFunction() && + "BB must be in the same anchor scope function."); + + if (!getAssumed()) + return false; + return !AssumedLiveBlocks.count(BB); + } + + /// See AAIsDead::isKnownDead(BasicBlock *). + bool isKnownDead(const BasicBlock *BB) const override { + return getKnown() && isAssumedDead(BB); + } + + /// See AAIsDead::isAssumed(Instruction *I). + bool isAssumedDead(const Instruction *I) const override { + assert(I->getParent()->getParent() == getAssociatedFunction() && + "Instruction must be in the same anchor scope function."); + + if (!getAssumed()) + return false; + + // If it is not in AssumedLiveBlocks then it for sure dead. + // Otherwise, it can still be after noreturn call in a live block. + if (!AssumedLiveBlocks.count(I->getParent())) + return true; + + // If it is not after a noreturn call, than it is live. + return isAfterNoReturn(I); + } + + /// See AAIsDead::isKnownDead(Instruction *I). + bool isKnownDead(const Instruction *I) const override { + return getKnown() && isAssumedDead(I); + } + + /// Check if instruction is after noreturn call, in other words, assumed dead. + bool isAfterNoReturn(const Instruction *I) const; - /// Return true if "nofree" is assumed. - bool isAssumedNoFree() const { return getAssumed(); } + /// Determine if \p F might catch asynchronous exceptions. + static bool mayCatchAsynchronousExceptions(const Function &F) { + return F.hasPersonalityFn() && !canSimplifyInvokeNoUnwind(&F); + } + + /// Assume \p BB is (partially) live now and indicate to the Attributor \p A + /// that internal function called from \p BB should now be looked at. + void assumeLive(Attributor &A, const BasicBlock &BB) { + if (!AssumedLiveBlocks.insert(&BB).second) + return; + + // We assume that all of BB is (probably) live now and if there are calls to + // internal functions we will assume that those are now live as well. 
This + // is a performance optimization for blocks with calls to a lot of internal + // functions. It can however cause dead functions to be treated as live. + for (const Instruction &I : BB) + if (ImmutableCallSite ICS = ImmutableCallSite(&I)) + if (const Function *F = ICS.getCalledFunction()) + if (F->hasLocalLinkage()) + A.markLiveInternalFunction(*F); + } - /// Return true if "nofree" is known. - bool isKnownNoFree() const { return getKnown(); } + /// Collection of to be explored paths. + SmallSetVector<const Instruction *, 8> ToBeExploredPaths; - /// The identifier used by the Attributor for this class of attributes. - static constexpr Attribute::AttrKind ID = Attribute::NoFree; + /// Collection of all assumed live BasicBlocks. + DenseSet<const BasicBlock *> AssumedLiveBlocks; + + /// Collection of calls with noreturn attribute, assumed or knwon. + SmallSetVector<const Instruction *, 4> NoReturnCalls; }; -ChangeStatus AANoFreeFunction::updateImpl(Attributor &A) { - Function &F = getAnchorScope(); +struct AAIsDeadFunction final : public AAIsDeadImpl { + AAIsDeadFunction(const IRPosition &IRP) : AAIsDeadImpl(IRP) {} - // The map from instruction opcodes to those instructions in the function. 
- auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECL(PartiallyDeadBlocks, Function, + "Number of basic blocks classified as partially dead"); + BUILD_STAT_NAME(PartiallyDeadBlocks, Function) += NoReturnCalls.size(); + } +}; - for (unsigned Opcode : - {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, - (unsigned)Instruction::Call}) { - for (Instruction *I : OpcodeInstMap[Opcode]) { +bool AAIsDeadImpl::isAfterNoReturn(const Instruction *I) const { + const Instruction *PrevI = I->getPrevNode(); + while (PrevI) { + if (NoReturnCalls.count(PrevI)) + return true; + PrevI = PrevI->getPrevNode(); + } + return false; +} - auto ICS = ImmutableCallSite(I); - auto *NoFreeAA = A.getAAFor<AANoFreeFunction>(*this, *I); +const Instruction *AAIsDeadImpl::findNextNoReturn(Attributor &A, + const Instruction *I) { + const BasicBlock *BB = I->getParent(); + const Function &F = *BB->getParent(); - if ((!NoFreeAA || !NoFreeAA->isAssumedNoFree()) && - !ICS.hasFnAttr(Attribute::NoFree)) { - indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; + // Flag to determine if we can change an invoke to a call assuming the callee + // is nounwind. This is not possible if the personality of the function allows + // to catch asynchronous exceptions. + bool Invoke2CallAllowed = !mayCatchAsynchronousExceptions(F); + + // TODO: We should have a function that determines if an "edge" is dead. + // Edges could be from an instruction to the next or from a terminator + // to the successor. For now, we need to special case the unwind block + // of InvokeInst below. 
+ + while (I) { + ImmutableCallSite ICS(I); + + if (ICS) { + const IRPosition &IPos = IRPosition::callsite_function(ICS); + // Regarless of the no-return property of an invoke instruction we only + // learn that the regular successor is not reachable through this + // instruction but the unwind block might still be. + if (auto *Invoke = dyn_cast<InvokeInst>(I)) { + // Use nounwind to justify the unwind block is dead as well. + const auto &AANoUnw = A.getAAFor<AANoUnwind>(*this, IPos); + if (!Invoke2CallAllowed || !AANoUnw.isAssumedNoUnwind()) { + assumeLive(A, *Invoke->getUnwindDest()); + ToBeExploredPaths.insert(&Invoke->getUnwindDest()->front()); + } } + + const auto &NoReturnAA = A.getAAFor<AANoReturn>(*this, IPos); + if (NoReturnAA.isAssumedNoReturn()) + return I; } + + I = I->getNextNode(); } - return ChangeStatus::UNCHANGED; + + // get new paths (reachable blocks). + for (const BasicBlock *SuccBB : successors(BB)) { + assumeLive(A, *SuccBB); + ToBeExploredPaths.insert(&SuccBB->front()); + } + + // No noreturn instruction found. + return nullptr; } -/// ------------------------ NonNull Argument Attribute ------------------------ -struct AANonNullImpl : AANonNull, BooleanState { +ChangeStatus AAIsDeadImpl::updateImpl(Attributor &A) { + ChangeStatus Status = ChangeStatus::UNCHANGED; + + // Temporary collection to iterate over existing noreturn instructions. 
This + // will alow easier modification of NoReturnCalls collection + SmallVector<const Instruction *, 8> NoReturnChanged; + + for (const Instruction *I : NoReturnCalls) + NoReturnChanged.push_back(I); + + for (const Instruction *I : NoReturnChanged) { + size_t Size = ToBeExploredPaths.size(); + + const Instruction *NextNoReturnI = findNextNoReturn(A, I); + if (NextNoReturnI != I) { + Status = ChangeStatus::CHANGED; + NoReturnCalls.remove(I); + if (NextNoReturnI) + NoReturnCalls.insert(NextNoReturnI); + } - AANonNullImpl(Value &V, InformationCache &InfoCache) - : AANonNull(V, InfoCache) {} + // Explore new paths. + while (Size != ToBeExploredPaths.size()) { + Status = ChangeStatus::CHANGED; + if (const Instruction *NextNoReturnI = + findNextNoReturn(A, ToBeExploredPaths[Size++])) + NoReturnCalls.insert(NextNoReturnI); + } + } + + LLVM_DEBUG(dbgs() << "[AAIsDead] AssumedLiveBlocks: " + << AssumedLiveBlocks.size() << " Total number of blocks: " + << getAssociatedFunction()->size() << "\n"); - AANonNullImpl(Value *AssociatedVal, Value &AnchoredValue, - InformationCache &InfoCache) - : AANonNull(AssociatedVal, AnchoredValue, InfoCache) {} + // If we know everything is live there is no need to query for liveness. + if (NoReturnCalls.empty() && + getAssociatedFunction()->size() == AssumedLiveBlocks.size()) { + // Indicating a pessimistic fixpoint will cause the state to be "invalid" + // which will cause the Attributor to not return the AAIsDead on request, + // which will prevent us from querying isAssumedDead(). + indicatePessimisticFixpoint(); + assert(!isValidState() && "Expected an invalid state!"); + Status = ChangeStatus::CHANGED; + } + + return Status; +} + +/// Liveness information for a call sites. +struct AAIsDeadCallSite final : AAIsDeadImpl { + AAIsDeadCallSite(const IRPosition &IRP) : AAIsDeadImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites instead of + // redirecting requests to the callee. + llvm_unreachable("Abstract attributes for liveness are not " + "supported for call sites yet!"); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + return indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} +}; + +/// -------------------- Dereferenceable Argument Attribute -------------------- + +template <> +ChangeStatus clampStateAndIndicateChange<DerefState>(DerefState &S, + const DerefState &R) { + ChangeStatus CS0 = clampStateAndIndicateChange<IntegerState>( + S.DerefBytesState, R.DerefBytesState); + ChangeStatus CS1 = + clampStateAndIndicateChange<IntegerState>(S.GlobalState, R.GlobalState); + return CS0 | CS1; +} + +struct AADereferenceableImpl : AADereferenceable { + AADereferenceableImpl(const IRPosition &IRP) : AADereferenceable(IRP) {} + using StateType = DerefState; + + void initialize(Attributor &A) override { + SmallVector<Attribute, 4> Attrs; + getAttrs({Attribute::Dereferenceable, Attribute::DereferenceableOrNull}, + Attrs); + for (const Attribute &Attr : Attrs) + takeKnownDerefBytesMaximum(Attr.getValueAsInt()); + + NonNullAA = &A.getAAFor<AANonNull>(*this, getIRPosition()); + + const IRPosition &IRP = this->getIRPosition(); + bool IsFnInterface = IRP.isFnInterfaceKind(); + const Function *FnScope = IRP.getAnchorScope(); + if (IsFnInterface && (!FnScope || !FnScope->hasExactDefinition())) + indicatePessimisticFixpoint(); + } /// See AbstractAttribute::getState() /// { - AbstractState &getState() override { return *this; } - const AbstractState &getState() const override { return *this; } + StateType &getState() override { return *this; } 
+ const StateType &getState() const override { return *this; } /// } + /// See AAFromMustBeExecutedContext + bool followUse(Attributor &A, const Use *U, const Instruction *I) { + bool IsNonNull = false; + bool TrackUse = false; + int64_t DerefBytes = getKnownNonNullAndDerefBytesForUse( + A, *this, getAssociatedValue(), U, I, IsNonNull, TrackUse); + takeKnownDerefBytesMaximum(DerefBytes); + return TrackUse; + } + + void getDeducedAttributes(LLVMContext &Ctx, + SmallVectorImpl<Attribute> &Attrs) const override { + // TODO: Add *_globally support + if (isAssumedNonNull()) + Attrs.emplace_back(Attribute::getWithDereferenceableBytes( + Ctx, getAssumedDereferenceableBytes())); + else + Attrs.emplace_back(Attribute::getWithDereferenceableOrNullBytes( + Ctx, getAssumedDereferenceableBytes())); + } + /// See AbstractAttribute::getAsStr(). const std::string getAsStr() const override { - return getAssumed() ? "nonnull" : "may-null"; + if (!getAssumedDereferenceableBytes()) + return "unknown-dereferenceable"; + return std::string("dereferenceable") + + (isAssumedNonNull() ? "" : "_or_null") + + (isAssumedGlobal() ? "_globally" : "") + "<" + + std::to_string(getKnownDereferenceableBytes()) + "-" + + std::to_string(getAssumedDereferenceableBytes()) + ">"; } +}; + +/// Dereferenceable attribute for a floating value. +struct AADereferenceableFloating + : AAFromMustBeExecutedContext<AADereferenceable, AADereferenceableImpl> { + using Base = + AAFromMustBeExecutedContext<AADereferenceable, AADereferenceableImpl>; + AADereferenceableFloating(const IRPosition &IRP) : Base(IRP) {} + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Change = Base::updateImpl(A); + + const DataLayout &DL = A.getDataLayout(); + + auto VisitValueCB = [&](Value &V, DerefState &T, bool Stripped) -> bool { + unsigned IdxWidth = + DL.getIndexSizeInBits(V.getType()->getPointerAddressSpace()); + APInt Offset(IdxWidth, 0); + const Value *Base = + V.stripAndAccumulateInBoundsConstantOffsets(DL, Offset); + + const auto &AA = + A.getAAFor<AADereferenceable>(*this, IRPosition::value(*Base)); + int64_t DerefBytes = 0; + if (!Stripped && this == &AA) { + // Use IR information if we did not strip anything. + // TODO: track globally. + bool CanBeNull; + DerefBytes = Base->getPointerDereferenceableBytes(DL, CanBeNull); + T.GlobalState.indicatePessimisticFixpoint(); + } else { + const DerefState &DS = static_cast<const DerefState &>(AA.getState()); + DerefBytes = DS.DerefBytesState.getAssumed(); + T.GlobalState &= DS.GlobalState; + } + + // For now we do not try to "increase" dereferenceability due to negative + // indices as we first have to come up with code to deal with loops and + // for overflows of the dereferenceable bytes. + int64_t OffsetSExt = Offset.getSExtValue(); + if (OffsetSExt < 0) + OffsetSExt = 0; + + T.takeAssumedDerefBytesMinimum( + std::max(int64_t(0), DerefBytes - OffsetSExt)); + + if (this == &AA) { + if (!Stripped) { + // If nothing was stripped IR information is all we got. + T.takeKnownDerefBytesMaximum( + std::max(int64_t(0), DerefBytes - OffsetSExt)); + T.indicatePessimisticFixpoint(); + } else if (OffsetSExt > 0) { + // If something was stripped but there is circular reasoning we look + // for the offset. If it is positive we basically decrease the + // dereferenceable bytes in a circluar loop now, which will simply + // drive them down to the known value in a very slow way which we + // can accelerate. + T.indicatePessimisticFixpoint(); + } + } + + return T.isValidState(); + }; - /// See AANonNull::isAssumedNonNull(). 
- bool isAssumedNonNull() const override { return getAssumed(); } + DerefState T; + if (!genericValueTraversal<AADereferenceable, DerefState>( + A, getIRPosition(), *this, T, VisitValueCB)) + return indicatePessimisticFixpoint(); - /// See AANonNull::isKnownNonNull(). - bool isKnownNonNull() const override { return getKnown(); } + return Change | clampStateAndIndicateChange(getState(), T); + } - /// Generate a predicate that checks if a given value is assumed nonnull. - /// The generated function returns true if a value satisfies any of - /// following conditions. - /// (i) A value is known nonZero(=nonnull). - /// (ii) A value is associated with AANonNull and its isAssumedNonNull() is - /// true. - std::function<bool(Value &)> generatePredicate(Attributor &); + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(dereferenceable) + } }; -std::function<bool(Value &)> AANonNullImpl::generatePredicate(Attributor &A) { - // FIXME: The `AAReturnedValues` should provide the predicate with the - // `ReturnInst` vector as well such that we can use the control flow sensitive - // version of `isKnownNonZero`. This should fix `test11` in - // `test/Transforms/FunctionAttrs/nonnull.ll` +/// Dereferenceable attribute for a return value. 
+struct AADereferenceableReturned final + : AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl, + DerefState> { + AADereferenceableReturned(const IRPosition &IRP) + : AAReturnedFromReturnedValues<AADereferenceable, AADereferenceableImpl, + DerefState>(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FNRET_ATTR(dereferenceable) + } +}; - std::function<bool(Value &)> Pred = [&](Value &RV) -> bool { - if (isKnownNonZero(&RV, getAnchorScope().getParent()->getDataLayout())) - return true; +/// Dereferenceable attribute for an argument +struct AADereferenceableArgument final + : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext< + AADereferenceable, AADereferenceableImpl, DerefState> { + using Base = AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext< + AADereferenceable, AADereferenceableImpl, DerefState>; + AADereferenceableArgument(const IRPosition &IRP) : Base(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_ARG_ATTR(dereferenceable) + } +}; - auto *NonNullAA = A.getAAFor<AANonNull>(*this, RV); +/// Dereferenceable attribute for a call site argument. +struct AADereferenceableCallSiteArgument final : AADereferenceableFloating { + AADereferenceableCallSiteArgument(const IRPosition &IRP) + : AADereferenceableFloating(IRP) {} - ImmutableCallSite ICS(&RV); + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CSARG_ATTR(dereferenceable) + } +}; - if ((!NonNullAA || !NonNullAA->isAssumedNonNull()) && - (!ICS || !ICS.hasRetAttr(Attribute::NonNull))) - return false; +/// Dereferenceable attribute deduction for a call site return value. 
+struct AADereferenceableCallSiteReturned final + : AACallSiteReturnedFromReturnedAndMustBeExecutedContext< + AADereferenceable, AADereferenceableImpl> { + using Base = AACallSiteReturnedFromReturnedAndMustBeExecutedContext< + AADereferenceable, AADereferenceableImpl>; + AADereferenceableCallSiteReturned(const IRPosition &IRP) : Base(IRP) {} - return true; - }; + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + Base::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) + indicatePessimisticFixpoint(); + } - return Pred; -} + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + + ChangeStatus Change = Base::updateImpl(A); + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::returned(*F); + auto &FnAA = A.getAAFor<AADereferenceable>(*this, FnPos); + return Change | + clampStateAndIndicateChange( + getState(), static_cast<const DerefState &>(FnAA.getState())); + } -/// NonNull attribute for function return value. -struct AANonNullReturned : AANonNullImpl { + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CS_ATTR(dereferenceable); + } +}; - AANonNullReturned(Function &F, InformationCache &InfoCache) - : AANonNullImpl(F, InfoCache) {} +// ------------------------ Align Argument Attribute ------------------------ - /// See AbstractAttribute::getManifestPosition(). - ManifestPosition getManifestPosition() const override { return MP_RETURNED; } +struct AAAlignImpl : AAAlign { + AAAlignImpl(const IRPosition &IRP) : AAAlign(IRP) {} - /// See AbstractAttriubute::initialize(...). 
+ // Max alignemnt value allowed in IR + static const unsigned MAX_ALIGN = 1U << 29; + + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - Function &F = getAnchorScope(); + takeAssumedMinimum(MAX_ALIGN); - // Already nonnull. - if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex, - Attribute::NonNull)) - indicateOptimisticFixpoint(); + SmallVector<Attribute, 4> Attrs; + getAttrs({Attribute::Alignment}, Attrs); + for (const Attribute &Attr : Attrs) + takeKnownMaximum(Attr.getValueAsInt()); + + if (getIRPosition().isFnInterfaceKind() && + (!getAssociatedFunction() || + !getAssociatedFunction()->hasExactDefinition())) + indicatePessimisticFixpoint(); } + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + + // Check for users that allow alignment annotations. + Value &AnchorVal = getIRPosition().getAnchorValue(); + for (const Use &U : AnchorVal.uses()) { + if (auto *SI = dyn_cast<StoreInst>(U.getUser())) { + if (SI->getPointerOperand() == &AnchorVal) + if (SI->getAlignment() < getAssumedAlign()) { + STATS_DECLTRACK(AAAlign, Store, + "Number of times alignemnt added to a store"); + SI->setAlignment(Align(getAssumedAlign())); + Changed = ChangeStatus::CHANGED; + } + } else if (auto *LI = dyn_cast<LoadInst>(U.getUser())) { + if (LI->getPointerOperand() == &AnchorVal) + if (LI->getAlignment() < getAssumedAlign()) { + LI->setAlignment(Align(getAssumedAlign())); + STATS_DECLTRACK(AAAlign, Load, + "Number of times alignemnt added to a load"); + Changed = ChangeStatus::CHANGED; + } + } + } + + return AAAlign::manifest(A) | Changed; + } + + // TODO: Provide a helper to determine the implied ABI alignment and check in + // the existing manifest method and a new one for AAAlignImpl that value + // to avoid making the alignment explicit if it did not improve. 
+ + /// See AbstractAttribute::getDeducedAttributes + virtual void + getDeducedAttributes(LLVMContext &Ctx, + SmallVectorImpl<Attribute> &Attrs) const override { + if (getAssumedAlign() > 1) + Attrs.emplace_back( + Attribute::getWithAlignment(Ctx, Align(getAssumedAlign()))); + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return getAssumedAlign() ? ("align<" + std::to_string(getKnownAlign()) + + "-" + std::to_string(getAssumedAlign()) + ">") + : "unknown-align"; + } +}; + +/// Align attribute for a floating value. +struct AAAlignFloating : AAAlignImpl { + AAAlignFloating(const IRPosition &IRP) : AAAlignImpl(IRP) {} + /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override; + ChangeStatus updateImpl(Attributor &A) override { + const DataLayout &DL = A.getDataLayout(); + + auto VisitValueCB = [&](Value &V, AAAlign::StateType &T, + bool Stripped) -> bool { + const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V)); + if (!Stripped && this == &AA) { + // Use only IR information if we did not strip anything. + const MaybeAlign PA = V.getPointerAlignment(DL); + T.takeKnownMaximum(PA ? PA->value() : 0); + T.indicatePessimisticFixpoint(); + } else { + // Use abstract attribute information. + const AAAlign::StateType &DS = + static_cast<const AAAlign::StateType &>(AA.getState()); + T ^= DS; + } + return T.isValidState(); + }; + + StateType T; + if (!genericValueTraversal<AAAlign, StateType>(A, getIRPosition(), *this, T, + VisitValueCB)) + return indicatePessimisticFixpoint(); + + // TODO: If we know we visited all incoming values, thus no are assumed + // dead, we can take the known information from the state T. 
+ return clampStateAndIndicateChange(getState(), T); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FLOATING_ATTR(align) } }; -ChangeStatus AANonNullReturned::updateImpl(Attributor &A) { - Function &F = getAnchorScope(); +/// Align attribute for function return value. +struct AAAlignReturned final + : AAReturnedFromReturnedValues<AAAlign, AAAlignImpl> { + AAAlignReturned(const IRPosition &IRP) + : AAReturnedFromReturnedValues<AAAlign, AAAlignImpl>(IRP) {} - auto *AARetVal = A.getAAFor<AAReturnedValues>(*this, F); - if (!AARetVal) { - indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(aligned) } +}; + +/// Align attribute for function argument. +struct AAAlignArgument final + : AAArgumentFromCallSiteArguments<AAAlign, AAAlignImpl> { + AAAlignArgument(const IRPosition &IRP) + : AAArgumentFromCallSiteArguments<AAAlign, AAAlignImpl>(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(aligned) } +}; + +struct AAAlignCallSiteArgument final : AAAlignFloating { + AAAlignCallSiteArgument(const IRPosition &IRP) : AAAlignFloating(IRP) {} + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + return AAAlignImpl::manifest(A); } - std::function<bool(Value &)> Pred = this->generatePredicate(A); - if (!AARetVal->checkForallReturnedValues(Pred)) { - indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CSARG_ATTR(aligned) } +}; + +/// Align attribute deduction for a call site return value. +struct AAAlignCallSiteReturned final : AAAlignImpl { + AAAlignCallSiteReturned(const IRPosition &IRP) : AAAlignImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + AAAlignImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F) + indicatePessimisticFixpoint(); } - return ChangeStatus::UNCHANGED; -} -/// NonNull attribute for function argument. -struct AANonNullArgument : AANonNullImpl { + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::returned(*F); + auto &FnAA = A.getAAFor<AAAlign>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), static_cast<const AAAlign::StateType &>(FnAA.getState())); + } - AANonNullArgument(Argument &A, InformationCache &InfoCache) - : AANonNullImpl(A, InfoCache) {} + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(align); } +}; - /// See AbstractAttribute::getManifestPosition(). - ManifestPosition getManifestPosition() const override { return MP_ARGUMENT; } +/// ------------------ Function No-Return Attribute ---------------------------- +struct AANoReturnImpl : public AANoReturn { + AANoReturnImpl(const IRPosition &IRP) : AANoReturn(IRP) {} - /// See AbstractAttriubute::initialize(...). + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - Argument *Arg = cast<Argument>(getAssociatedValue()); - if (Arg->hasNonNullAttr()) - indicateOptimisticFixpoint(); + AANoReturn::initialize(A); + Function *F = getAssociatedFunction(); + if (!F || F->hasFnAttribute(Attribute::WillReturn)) + indicatePessimisticFixpoint(); } + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return getAssumed() ? 
"noreturn" : "may-return"; + } + + /// See AbstractAttribute::updateImpl(Attributor &A). + virtual ChangeStatus updateImpl(Attributor &A) override { + const auto &WillReturnAA = A.getAAFor<AAWillReturn>(*this, getIRPosition()); + if (WillReturnAA.isKnownWillReturn()) + return indicatePessimisticFixpoint(); + auto CheckForNoReturn = [](Instruction &) { return false; }; + if (!A.checkForAllInstructions(CheckForNoReturn, *this, + {(unsigned)Instruction::Ret})) + return indicatePessimisticFixpoint(); + return ChangeStatus::UNCHANGED; + } +}; + +struct AANoReturnFunction final : AANoReturnImpl { + AANoReturnFunction(const IRPosition &IRP) : AANoReturnImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(noreturn) } +}; + +/// NoReturn attribute deduction for a call sites. +struct AANoReturnCallSite final : AANoReturnImpl { + AANoReturnCallSite(const IRPosition &IRP) : AANoReturnImpl(IRP) {} + /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override; + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AANoReturn>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), + static_cast<const AANoReturn::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(noreturn); } }; -/// NonNull attribute for a call site argument. 
-struct AANonNullCallSiteArgument : AANonNullImpl { +/// ----------------------- Variable Capturing --------------------------------- - /// See AANonNullImpl::AANonNullImpl(...). - AANonNullCallSiteArgument(CallSite CS, unsigned ArgNo, - InformationCache &InfoCache) - : AANonNullImpl(CS.getArgOperand(ArgNo), *CS.getInstruction(), InfoCache), - ArgNo(ArgNo) {} +/// A class to hold the state of for no-capture attributes. +struct AANoCaptureImpl : public AANoCapture { + AANoCaptureImpl(const IRPosition &IRP) : AANoCapture(IRP) {} /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - CallSite CS(&getAnchoredValue()); - if (isKnownNonZero(getAssociatedValue(), - getAnchorScope().getParent()->getDataLayout()) || - CS.paramHasAttr(ArgNo, getAttrKind())) + AANoCapture::initialize(A); + + // You cannot "capture" null in the default address space. + if (isa<ConstantPointerNull>(getAssociatedValue()) && + getAssociatedValue().getType()->getPointerAddressSpace() == 0) { indicateOptimisticFixpoint(); + return; + } + + const IRPosition &IRP = getIRPosition(); + const Function *F = + getArgNo() >= 0 ? IRP.getAssociatedFunction() : IRP.getAnchorScope(); + + // Check what state the associated function can actually capture. + if (F) + determineFunctionCaptureCapabilities(IRP, *F, *this); + else + indicatePessimisticFixpoint(); } - /// See AbstractAttribute::updateImpl(Attributor &A). + /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override; - /// See AbstractAttribute::getManifestPosition(). - ManifestPosition getManifestPosition() const override { - return MP_CALL_SITE_ARGUMENT; - }; + /// see AbstractAttribute::isAssumedNoCaptureMaybeReturned(...). 
+ virtual void + getDeducedAttributes(LLVMContext &Ctx, + SmallVectorImpl<Attribute> &Attrs) const override { + if (!isAssumedNoCaptureMaybeReturned()) + return; + + if (getArgNo() >= 0) { + if (isAssumedNoCapture()) + Attrs.emplace_back(Attribute::get(Ctx, Attribute::NoCapture)); + else if (ManifestInternal) + Attrs.emplace_back(Attribute::get(Ctx, "no-capture-maybe-returned")); + } + } + + /// Set the NOT_CAPTURED_IN_MEM and NOT_CAPTURED_IN_RET bits in \p Known + /// depending on the ability of the function associated with \p IRP to capture + /// state in memory and through "returning/throwing", respectively. + static void determineFunctionCaptureCapabilities(const IRPosition &IRP, + const Function &F, + IntegerState &State) { + // TODO: Once we have memory behavior attributes we should use them here. + + // If we know we cannot communicate or write to memory, we do not care about + // ptr2int anymore. + if (F.onlyReadsMemory() && F.doesNotThrow() && + F.getReturnType()->isVoidTy()) { + State.addKnownBits(NO_CAPTURE); + return; + } + + // A function cannot capture state in memory if it only reads memory, it can + // however return/throw state and the state might be influenced by the + // pointer value, e.g., loading from a returned pointer might reveal a bit. + if (F.onlyReadsMemory()) + State.addKnownBits(NOT_CAPTURED_IN_MEM); + + // A function cannot communicate state back if it does not through + // exceptions and doesn not return values. + if (F.doesNotThrow() && F.getReturnType()->isVoidTy()) + State.addKnownBits(NOT_CAPTURED_IN_RET); + + // Check existing "returned" attributes. 
+ int ArgNo = IRP.getArgNo(); + if (F.doesNotThrow() && ArgNo >= 0) { + for (unsigned u = 0, e = F.arg_size(); u< e; ++u) + if (F.hasParamAttribute(u, Attribute::Returned)) { + if (u == unsigned(ArgNo)) + State.removeAssumedBits(NOT_CAPTURED_IN_RET); + else if (F.onlyReadsMemory()) + State.addKnownBits(NO_CAPTURE); + else + State.addKnownBits(NOT_CAPTURED_IN_RET); + break; + } + } + } - // Return argument index of associated value. - int getArgNo() const { return ArgNo; } + /// See AbstractState::getAsStr(). + const std::string getAsStr() const override { + if (isKnownNoCapture()) + return "known not-captured"; + if (isAssumedNoCapture()) + return "assumed not-captured"; + if (isKnownNoCaptureMaybeReturned()) + return "known not-captured-maybe-returned"; + if (isAssumedNoCaptureMaybeReturned()) + return "assumed not-captured-maybe-returned"; + return "assumed-captured"; + } +}; + +/// Attributor-aware capture tracker. +struct AACaptureUseTracker final : public CaptureTracker { + + /// Create a capture tracker that can lookup in-flight abstract attributes + /// through the Attributor \p A. + /// + /// If a use leads to a potential capture, \p CapturedInMemory is set and the + /// search is stopped. If a use leads to a return instruction, + /// \p CommunicatedBack is set to true and \p CapturedInMemory is not changed. + /// If a use leads to a ptr2int which may capture the value, + /// \p CapturedInInteger is set. If a use is found that is currently assumed + /// "no-capture-maybe-returned", the user is added to the \p PotentialCopies + /// set. All values in \p PotentialCopies are later tracked as well. For every + /// explored use we decrement \p RemainingUsesToExplore. Once it reaches 0, + /// the search is stopped with \p CapturedInMemory and \p CapturedInInteger + /// conservatively set to true. 
+ AACaptureUseTracker(Attributor &A, AANoCapture &NoCaptureAA, + const AAIsDead &IsDeadAA, IntegerState &State, + SmallVectorImpl<const Value *> &PotentialCopies, + unsigned &RemainingUsesToExplore) + : A(A), NoCaptureAA(NoCaptureAA), IsDeadAA(IsDeadAA), State(State), + PotentialCopies(PotentialCopies), + RemainingUsesToExplore(RemainingUsesToExplore) {} + + /// Determine if \p V maybe captured. *Also updates the state!* + bool valueMayBeCaptured(const Value *V) { + if (V->getType()->isPointerTy()) { + PointerMayBeCaptured(V, this); + } else { + State.indicatePessimisticFixpoint(); + } + return State.isAssumed(AANoCapture::NO_CAPTURE_MAYBE_RETURNED); + } + + /// See CaptureTracker::tooManyUses(). + void tooManyUses() override { + State.removeAssumedBits(AANoCapture::NO_CAPTURE); + } + + bool isDereferenceableOrNull(Value *O, const DataLayout &DL) override { + if (CaptureTracker::isDereferenceableOrNull(O, DL)) + return true; + const auto &DerefAA = + A.getAAFor<AADereferenceable>(NoCaptureAA, IRPosition::value(*O)); + return DerefAA.getAssumedDereferenceableBytes(); + } + + /// See CaptureTracker::captured(...). + bool captured(const Use *U) override { + Instruction *UInst = cast<Instruction>(U->getUser()); + LLVM_DEBUG(dbgs() << "Check use: " << *U->get() << " in " << *UInst + << "\n"); + + // Because we may reuse the tracker multiple times we keep track of the + // number of explored uses ourselves as well. + if (RemainingUsesToExplore-- == 0) { + LLVM_DEBUG(dbgs() << " - too many uses to explore!\n"); + return isCapturedIn(/* Memory */ true, /* Integer */ true, + /* Return */ true); + } + + // Deal with ptr2int by following uses. + if (isa<PtrToIntInst>(UInst)) { + LLVM_DEBUG(dbgs() << " - ptr2int assume the worst!\n"); + return valueMayBeCaptured(UInst); + } + + // Explicitly catch return instructions. 
+ if (isa<ReturnInst>(UInst)) + return isCapturedIn(/* Memory */ false, /* Integer */ false, + /* Return */ true); + + // For now we only use special logic for call sites. However, the tracker + // itself knows about a lot of other non-capturing cases already. + CallSite CS(UInst); + if (!CS || !CS.isArgOperand(U)) + return isCapturedIn(/* Memory */ true, /* Integer */ true, + /* Return */ true); + + unsigned ArgNo = CS.getArgumentNo(U); + const IRPosition &CSArgPos = IRPosition::callsite_argument(CS, ArgNo); + // If we have a abstract no-capture attribute for the argument we can use + // it to justify a non-capture attribute here. This allows recursion! + auto &ArgNoCaptureAA = A.getAAFor<AANoCapture>(NoCaptureAA, CSArgPos); + if (ArgNoCaptureAA.isAssumedNoCapture()) + return isCapturedIn(/* Memory */ false, /* Integer */ false, + /* Return */ false); + if (ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) { + addPotentialCopy(CS); + return isCapturedIn(/* Memory */ false, /* Integer */ false, + /* Return */ false); + } + + // Lastly, we could not find a reason no-capture can be assumed so we don't. + return isCapturedIn(/* Memory */ true, /* Integer */ true, + /* Return */ true); + } + + /// Register \p CS as potential copy of the value we are checking. + void addPotentialCopy(CallSite CS) { + PotentialCopies.push_back(CS.getInstruction()); + } + + /// See CaptureTracker::shouldExplore(...). + bool shouldExplore(const Use *U) override { + // Check liveness. + return !IsDeadAA.isAssumedDead(cast<Instruction>(U->getUser())); + } + + /// Update the state according to \p CapturedInMem, \p CapturedInInt, and + /// \p CapturedInRet, then return the appropriate value for use in the + /// CaptureTracker::captured() interface. 
+ bool isCapturedIn(bool CapturedInMem, bool CapturedInInt, + bool CapturedInRet) { + LLVM_DEBUG(dbgs() << " - captures [Mem " << CapturedInMem << "|Int " + << CapturedInInt << "|Ret " << CapturedInRet << "]\n"); + if (CapturedInMem) + State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_MEM); + if (CapturedInInt) + State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_INT); + if (CapturedInRet) + State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_RET); + return !State.isAssumed(AANoCapture::NO_CAPTURE_MAYBE_RETURNED); + } private: - unsigned ArgNo; + /// The attributor providing in-flight abstract attributes. + Attributor &A; + + /// The abstract attribute currently updated. + AANoCapture &NoCaptureAA; + + /// The abstract liveness state. + const AAIsDead &IsDeadAA; + + /// The state currently updated. + IntegerState &State; + + /// Set of potential copies of the tracked value. + SmallVectorImpl<const Value *> &PotentialCopies; + + /// Global counter to limit the number of explored uses. + unsigned &RemainingUsesToExplore; +}; + +ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { + const IRPosition &IRP = getIRPosition(); + const Value *V = + getArgNo() >= 0 ? IRP.getAssociatedArgument() : &IRP.getAssociatedValue(); + if (!V) + return indicatePessimisticFixpoint(); + + const Function *F = + getArgNo() >= 0 ? IRP.getAssociatedFunction() : IRP.getAnchorScope(); + assert(F && "Expected a function!"); + const IRPosition &FnPos = IRPosition::function(*F); + const auto &IsDeadAA = A.getAAFor<AAIsDead>(*this, FnPos); + + AANoCapture::StateType T; + + // Readonly means we cannot capture through memory. + const auto &FnMemAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos); + if (FnMemAA.isAssumedReadOnly()) { + T.addKnownBits(NOT_CAPTURED_IN_MEM); + if (FnMemAA.isKnownReadOnly()) + addKnownBits(NOT_CAPTURED_IN_MEM); + } + + // Make sure all returned values are different than the underlying value. 
+ // TODO: we could do this in a more sophisticated way inside + // AAReturnedValues, e.g., track all values that escape through returns + // directly somehow. + auto CheckReturnedArgs = [&](const AAReturnedValues &RVAA) { + bool SeenConstant = false; + for (auto &It : RVAA.returned_values()) { + if (isa<Constant>(It.first)) { + if (SeenConstant) + return false; + SeenConstant = true; + } else if (!isa<Argument>(It.first) || + It.first == getAssociatedArgument()) + return false; + } + return true; + }; + + const auto &NoUnwindAA = A.getAAFor<AANoUnwind>(*this, FnPos); + if (NoUnwindAA.isAssumedNoUnwind()) { + bool IsVoidTy = F->getReturnType()->isVoidTy(); + const AAReturnedValues *RVAA = + IsVoidTy ? nullptr : &A.getAAFor<AAReturnedValues>(*this, FnPos); + if (IsVoidTy || CheckReturnedArgs(*RVAA)) { + T.addKnownBits(NOT_CAPTURED_IN_RET); + if (T.isKnown(NOT_CAPTURED_IN_MEM)) + return ChangeStatus::UNCHANGED; + if (NoUnwindAA.isKnownNoUnwind() && + (IsVoidTy || RVAA->getState().isAtFixpoint())) { + addKnownBits(NOT_CAPTURED_IN_RET); + if (isKnown(NOT_CAPTURED_IN_MEM)) + return indicateOptimisticFixpoint(); + } + } + } + + // Use the CaptureTracker interface and logic with the specialized tracker, + // defined in AACaptureUseTracker, that can look at in-flight abstract + // attributes and directly updates the assumed state. + SmallVector<const Value *, 4> PotentialCopies; + unsigned RemainingUsesToExplore = DefaultMaxUsesToExplore; + AACaptureUseTracker Tracker(A, *this, IsDeadAA, T, PotentialCopies, + RemainingUsesToExplore); + + // Check all potential copies of the associated value until we can assume + // none will be captured or we have to assume at least one might be. 
+ unsigned Idx = 0; + PotentialCopies.push_back(V); + while (T.isAssumed(NO_CAPTURE_MAYBE_RETURNED) && Idx < PotentialCopies.size()) + Tracker.valueMayBeCaptured(PotentialCopies[Idx++]); + + AAAlign::StateType &S = getState(); + auto Assumed = S.getAssumed(); + S.intersectAssumedBits(T.getAssumed()); + return Assumed == S.getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; +} + +/// NoCapture attribute for function arguments. +struct AANoCaptureArgument final : AANoCaptureImpl { + AANoCaptureArgument(const IRPosition &IRP) : AANoCaptureImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nocapture) } +}; + +/// NoCapture attribute for call site arguments. +struct AANoCaptureCallSiteArgument final : AANoCaptureImpl { + AANoCaptureCallSiteArgument(const IRPosition &IRP) : AANoCaptureImpl(IRP) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Argument *Arg = getAssociatedArgument(); + if (!Arg) + return indicatePessimisticFixpoint(); + const IRPosition &ArgPos = IRPosition::argument(*Arg); + auto &ArgAA = A.getAAFor<AANoCapture>(*this, ArgPos); + return clampStateAndIndicateChange( + getState(), + static_cast<const AANoCapture::StateType &>(ArgAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override{STATS_DECLTRACK_CSARG_ATTR(nocapture)}; +}; + +/// NoCapture attribute for floating values. 
+struct AANoCaptureFloating final : AANoCaptureImpl { + AANoCaptureFloating(const IRPosition &IRP) : AANoCaptureImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(nocapture) + } +}; + +/// NoCapture attribute for function return value. +struct AANoCaptureReturned final : AANoCaptureImpl { + AANoCaptureReturned(const IRPosition &IRP) : AANoCaptureImpl(IRP) { + llvm_unreachable("NoCapture is not applicable to function returns!"); + } + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + llvm_unreachable("NoCapture is not applicable to function returns!"); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + llvm_unreachable("NoCapture is not applicable to function returns!"); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} +}; + +/// NoCapture attribute deduction for a call site return value. +struct AANoCaptureCallSiteReturned final : AANoCaptureImpl { + AANoCaptureCallSiteReturned(const IRPosition &IRP) : AANoCaptureImpl(IRP) {} + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CSRET_ATTR(nocapture) + } }; -ChangeStatus AANonNullArgument::updateImpl(Attributor &A) { - Function &F = getAnchorScope(); - Argument &Arg = cast<Argument>(getAnchoredValue()); - unsigned ArgNo = Arg.getArgNo(); +/// ------------------ Value Simplify Attribute ---------------------------- +struct AAValueSimplifyImpl : AAValueSimplify { + AAValueSimplifyImpl(const IRPosition &IRP) : AAValueSimplify(IRP) {} + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return getAssumed() ? (getKnown() ? 
"simplified" : "maybe-simple") + : "not-simple"; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} + + /// See AAValueSimplify::getAssumedSimplifiedValue() + Optional<Value *> getAssumedSimplifiedValue(Attributor &A) const override { + if (!getAssumed()) + return const_cast<Value *>(&getAssociatedValue()); + return SimplifiedAssociatedValue; + } + void initialize(Attributor &A) override {} + + /// Helper function for querying AAValueSimplify and updating candicate. + /// \param QueryingValue Value trying to unify with SimplifiedValue + /// \param AccumulatedSimplifiedValue Current simplification result. + static bool checkAndUpdate(Attributor &A, const AbstractAttribute &QueryingAA, + Value &QueryingValue, + Optional<Value *> &AccumulatedSimplifiedValue) { + // FIXME: Add a typecast support. + + auto &ValueSimpifyAA = A.getAAFor<AAValueSimplify>( + QueryingAA, IRPosition::value(QueryingValue)); - // Callback function - std::function<bool(CallSite)> CallSiteCheck = [&](CallSite CS) { - assert(CS && "Sanity check: Call site was not initialized properly!"); + Optional<Value *> QueryingValueSimplified = + ValueSimpifyAA.getAssumedSimplifiedValue(A); - auto *NonNullAA = A.getAAFor<AANonNull>(*this, *CS.getInstruction(), ArgNo); + if (!QueryingValueSimplified.hasValue()) + return true; - // Check that NonNullAA is AANonNullCallSiteArgument. 
- if (NonNullAA) { - ImmutableCallSite ICS(&NonNullAA->getAnchoredValue()); - if (ICS && CS.getInstruction() == ICS.getInstruction()) - return NonNullAA->isAssumedNonNull(); + if (!QueryingValueSimplified.getValue()) return false; + + Value &QueryingValueSimplifiedUnwrapped = + *QueryingValueSimplified.getValue(); + + if (isa<UndefValue>(QueryingValueSimplifiedUnwrapped)) + return true; + + if (AccumulatedSimplifiedValue.hasValue()) + return AccumulatedSimplifiedValue == QueryingValueSimplified; + + LLVM_DEBUG(dbgs() << "[Attributor][ValueSimplify] " << QueryingValue + << " is assumed to be " + << QueryingValueSimplifiedUnwrapped << "\n"); + + AccumulatedSimplifiedValue = QueryingValueSimplified; + return true; + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + + if (!SimplifiedAssociatedValue.hasValue() || + !SimplifiedAssociatedValue.getValue()) + return Changed; + + if (auto *C = dyn_cast<Constant>(SimplifiedAssociatedValue.getValue())) { + // We can replace the AssociatedValue with the constant. + Value &V = getAssociatedValue(); + if (!V.user_empty() && &V != C && V.getType() == C->getType()) { + LLVM_DEBUG(dbgs() << "[Attributor][ValueSimplify] " << V << " -> " << *C + << "\n"); + V.replaceAllUsesWith(C); + Changed = ChangeStatus::CHANGED; + } + } + + return Changed | AAValueSimplify::manifest(A); + } + +protected: + // An assumed simplified value. Initially, it is set to Optional::None, which + // means that the value is not clear under current assumption. If in the + // pessimistic state, getAssumedSimplifiedValue doesn't return this value but + // returns orignal associated value. + Optional<Value *> SimplifiedAssociatedValue; +}; + +struct AAValueSimplifyArgument final : AAValueSimplifyImpl { + AAValueSimplifyArgument(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {} + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + bool HasValueBefore = SimplifiedAssociatedValue.hasValue(); + + auto PredForCallSite = [&](AbstractCallSite ACS) { + // Check if we have an associated argument or not (which can happen for + // callback calls). + if (Value *ArgOp = ACS.getCallArgOperand(getArgNo())) + return checkAndUpdate(A, *this, *ArgOp, SimplifiedAssociatedValue); + return false; + }; + + if (!A.checkForAllCallSites(PredForCallSite, *this, true)) + return indicatePessimisticFixpoint(); + + // If a candicate was found in this update, return CHANGED. + return HasValueBefore == SimplifiedAssociatedValue.hasValue() + ? ChangeStatus::UNCHANGED + : ChangeStatus ::CHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_ARG_ATTR(value_simplify) + } +}; + +struct AAValueSimplifyReturned : AAValueSimplifyImpl { + AAValueSimplifyReturned(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {} + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + bool HasValueBefore = SimplifiedAssociatedValue.hasValue(); + + auto PredForReturned = [&](Value &V) { + return checkAndUpdate(A, *this, V, SimplifiedAssociatedValue); + }; + + if (!A.checkForAllReturnedValues(PredForReturned, *this)) + return indicatePessimisticFixpoint(); + + // If a candicate was found in this update, return CHANGED. + return HasValueBefore == SimplifiedAssociatedValue.hasValue() + ? ChangeStatus::UNCHANGED + : ChangeStatus ::CHANGED; + } + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FNRET_ATTR(value_simplify) + } +}; + +struct AAValueSimplifyFloating : AAValueSimplifyImpl { + AAValueSimplifyFloating(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + Value &V = getAnchorValue(); + + // TODO: add other stuffs + if (isa<Constant>(V) || isa<UndefValue>(V)) + indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + bool HasValueBefore = SimplifiedAssociatedValue.hasValue(); + + auto VisitValueCB = [&](Value &V, BooleanState, bool Stripped) -> bool { + auto &AA = A.getAAFor<AAValueSimplify>(*this, IRPosition::value(V)); + if (!Stripped && this == &AA) { + // TODO: Look the instruction and check recursively. + LLVM_DEBUG( + dbgs() << "[Attributor][ValueSimplify] Can't be stripped more : " + << V << "\n"); + indicatePessimisticFixpoint(); + return false; + } + return checkAndUpdate(A, *this, V, SimplifiedAssociatedValue); + }; + + if (!genericValueTraversal<AAValueSimplify, BooleanState>( + A, getIRPosition(), *this, static_cast<BooleanState &>(*this), + VisitValueCB)) + return indicatePessimisticFixpoint(); + + // If a candicate was found in this update, return CHANGED. + + return HasValueBefore == SimplifiedAssociatedValue.hasValue() + ? ChangeStatus::UNCHANGED + : ChangeStatus ::CHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(value_simplify) + } +}; + +struct AAValueSimplifyFunction : AAValueSimplifyImpl { + AAValueSimplifyFunction(const IRPosition &IRP) : AAValueSimplifyImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + SimplifiedAssociatedValue = &getAnchorValue(); + indicateOptimisticFixpoint(); + } + /// See AbstractAttribute::initialize(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + llvm_unreachable( + "AAValueSimplify(Function|CallSite)::updateImpl will not be called"); + } + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_FN_ATTR(value_simplify) + } +}; + +struct AAValueSimplifyCallSite : AAValueSimplifyFunction { + AAValueSimplifyCallSite(const IRPosition &IRP) + : AAValueSimplifyFunction(IRP) {} + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECLTRACK_CS_ATTR(value_simplify) + } +}; + +struct AAValueSimplifyCallSiteReturned : AAValueSimplifyReturned { + AAValueSimplifyCallSiteReturned(const IRPosition &IRP) + : AAValueSimplifyReturned(IRP) {} + + void trackStatistics() const override { + STATS_DECLTRACK_CSRET_ATTR(value_simplify) + } +}; +struct AAValueSimplifyCallSiteArgument : AAValueSimplifyFloating { + AAValueSimplifyCallSiteArgument(const IRPosition &IRP) + : AAValueSimplifyFloating(IRP) {} + + void trackStatistics() const override { + STATS_DECLTRACK_CSARG_ATTR(value_simplify) + } +}; + +/// ----------------------- Heap-To-Stack Conversion --------------------------- +struct AAHeapToStackImpl : public AAHeapToStack { + AAHeapToStackImpl(const IRPosition &IRP) : AAHeapToStack(IRP) {} + + const std::string getAsStr() const override { + return "[H2S] Mallocs: " + std::to_string(MallocCalls.size()); + } + + ChangeStatus manifest(Attributor &A) override { + assert(getState().isValidState() && + "Attempted to manifest an invalid state!"); + + ChangeStatus HasChanged = ChangeStatus::UNCHANGED; + Function *F = getAssociatedFunction(); + const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F); + + for (Instruction *MallocCall : MallocCalls) { + // This malloc cannot be replaced. 
+ if (BadMallocCalls.count(MallocCall)) + continue; + + for (Instruction *FreeCall : FreesForMalloc[MallocCall]) { + LLVM_DEBUG(dbgs() << "H2S: Removing free call: " << *FreeCall << "\n"); + A.deleteAfterManifest(*FreeCall); + HasChanged = ChangeStatus::CHANGED; + } + + LLVM_DEBUG(dbgs() << "H2S: Removing malloc call: " << *MallocCall + << "\n"); + + Constant *Size; + if (isCallocLikeFn(MallocCall, TLI)) { + auto *Num = cast<ConstantInt>(MallocCall->getOperand(0)); + auto *SizeT = dyn_cast<ConstantInt>(MallocCall->getOperand(1)); + APInt TotalSize = SizeT->getValue() * Num->getValue(); + Size = + ConstantInt::get(MallocCall->getOperand(0)->getType(), TotalSize); + } else { + Size = cast<ConstantInt>(MallocCall->getOperand(0)); + } + + unsigned AS = cast<PointerType>(MallocCall->getType())->getAddressSpace(); + Instruction *AI = new AllocaInst(Type::getInt8Ty(F->getContext()), AS, + Size, "", MallocCall->getNextNode()); + + if (AI->getType() != MallocCall->getType()) + AI = new BitCastInst(AI, MallocCall->getType(), "malloc_bc", + AI->getNextNode()); + + MallocCall->replaceAllUsesWith(AI); + + if (auto *II = dyn_cast<InvokeInst>(MallocCall)) { + auto *NBB = II->getNormalDest(); + BranchInst::Create(NBB, MallocCall->getParent()); + A.deleteAfterManifest(*MallocCall); + } else { + A.deleteAfterManifest(*MallocCall); + } + + if (isCallocLikeFn(MallocCall, TLI)) { + auto *BI = new BitCastInst(AI, MallocCall->getType(), "calloc_bc", + AI->getNextNode()); + Value *Ops[] = { + BI, ConstantInt::get(F->getContext(), APInt(8, 0, false)), Size, + ConstantInt::get(Type::getInt1Ty(F->getContext()), false)}; + + Type *Tys[] = {BI->getType(), MallocCall->getOperand(0)->getType()}; + Module *M = F->getParent(); + Function *Fn = Intrinsic::getDeclaration(M, Intrinsic::memset, Tys); + CallInst::Create(Fn, Ops, "", BI->getNextNode()); + } + HasChanged = ChangeStatus::CHANGED; } - if (CS.paramHasAttr(ArgNo, Attribute::NonNull)) + return HasChanged; + } + + /// Collection of all malloc 
calls in a function. + SmallSetVector<Instruction *, 4> MallocCalls; + + /// Collection of malloc calls that cannot be converted. + DenseSet<const Instruction *> BadMallocCalls; + + /// A map for each malloc call to the set of associated free calls. + DenseMap<Instruction *, SmallPtrSet<Instruction *, 4>> FreesForMalloc; + + ChangeStatus updateImpl(Attributor &A) override; +}; + +ChangeStatus AAHeapToStackImpl::updateImpl(Attributor &A) { + const Function *F = getAssociatedFunction(); + const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F); + + auto UsesCheck = [&](Instruction &I) { + SmallPtrSet<const Use *, 8> Visited; + SmallVector<const Use *, 8> Worklist; + + for (Use &U : I.uses()) + Worklist.push_back(&U); + + while (!Worklist.empty()) { + const Use *U = Worklist.pop_back_val(); + if (!Visited.insert(U).second) + continue; + + auto *UserI = U->getUser(); + + if (isa<LoadInst>(UserI)) + continue; + if (auto *SI = dyn_cast<StoreInst>(UserI)) { + if (SI->getValueOperand() == U->get()) { + LLVM_DEBUG(dbgs() << "[H2S] escaping store to memory: " << *UserI << "\n"); + return false; + } + // A store into the malloc'ed memory is fine. + continue; + } + + // NOTE: Right now, if a function that has malloc pointer as an argument + // frees memory, we assume that the malloc pointer is freed. + + // TODO: Add nofree callsite argument attribute to indicate that pointer + // argument is not freed. + if (auto *CB = dyn_cast<CallBase>(UserI)) { + if (!CB->isArgOperand(U)) + continue; + + if (CB->isLifetimeStartOrEnd()) + continue; + + // Record malloc. 
+ if (isFreeCall(UserI, TLI)) { + FreesForMalloc[&I].insert( + cast<Instruction>(const_cast<User *>(UserI))); + continue; + } + + // If a function does not free memory we are fine + const auto &NoFreeAA = + A.getAAFor<AANoFree>(*this, IRPosition::callsite_function(*CB)); + + unsigned ArgNo = U - CB->arg_begin(); + const auto &NoCaptureAA = A.getAAFor<AANoCapture>( + *this, IRPosition::callsite_argument(*CB, ArgNo)); + + if (!NoCaptureAA.isAssumedNoCapture() || !NoFreeAA.isAssumedNoFree()) { + LLVM_DEBUG(dbgs() << "[H2S] Bad user: " << *UserI << "\n"); + return false; + } + continue; + } + + if (isa<GetElementPtrInst>(UserI) || isa<BitCastInst>(UserI)) { + for (Use &U : UserI->uses()) + Worklist.push_back(&U); + continue; + } + + // Unknown user. + LLVM_DEBUG(dbgs() << "[H2S] Unknown user: " << *UserI << "\n"); + return false; + } + return true; + }; + + auto MallocCallocCheck = [&](Instruction &I) { + if (BadMallocCalls.count(&I)) return true; - Value *V = CS.getArgOperand(ArgNo); - if (isKnownNonZero(V, getAnchorScope().getParent()->getDataLayout())) + bool IsMalloc = isMallocLikeFn(&I, TLI); + bool IsCalloc = !IsMalloc && isCallocLikeFn(&I, TLI); + if (!IsMalloc && !IsCalloc) { + BadMallocCalls.insert(&I); return true; + } - return false; - }; - if (!A.checkForAllCallSites(F, CallSiteCheck, true)) { - indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; - } - return ChangeStatus::UNCHANGED; -} + if (IsMalloc) { + if (auto *Size = dyn_cast<ConstantInt>(I.getOperand(0))) + if (Size->getValue().sle(MaxHeapToStackSize)) + if (UsesCheck(I)) { + MallocCalls.insert(&I); + return true; + } + } else if (IsCalloc) { + bool Overflow = false; + if (auto *Num = dyn_cast<ConstantInt>(I.getOperand(0))) + if (auto *Size = dyn_cast<ConstantInt>(I.getOperand(1))) + if ((Size->getValue().umul_ov(Num->getValue(), Overflow)) + .sle(MaxHeapToStackSize)) + if (!Overflow && UsesCheck(I)) { + MallocCalls.insert(&I); + return true; + } + } -ChangeStatus 
AANonNullCallSiteArgument::updateImpl(Attributor &A) { - // NOTE: Never look at the argument of the callee in this method. - // If we do this, "nonnull" is always deduced because of the assumption. + BadMallocCalls.insert(&I); + return true; + }; - Value &V = *getAssociatedValue(); + size_t NumBadMallocs = BadMallocCalls.size(); - auto *NonNullAA = A.getAAFor<AANonNull>(*this, V); + A.checkForAllCallLikeInstructions(MallocCallocCheck, *this); - if (!NonNullAA || !NonNullAA->isAssumedNonNull()) { - indicatePessimisticFixpoint(); + if (NumBadMallocs != BadMallocCalls.size()) return ChangeStatus::CHANGED; - } return ChangeStatus::UNCHANGED; } -/// ------------------------ Will-Return Attributes ---------------------------- +struct AAHeapToStackFunction final : public AAHeapToStackImpl { + AAHeapToStackFunction(const IRPosition &IRP) : AAHeapToStackImpl(IRP) {} -struct AAWillReturnImpl : public AAWillReturn, BooleanState { + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + STATS_DECL(MallocCalls, Function, + "Number of MallocCalls converted to allocas"); + BUILD_STAT_NAME(MallocCalls, Function) += MallocCalls.size(); + } +}; + +/// -------------------- Memory Behavior Attributes ---------------------------- +/// Includes read-none, read-only, and write-only. +/// ---------------------------------------------------------------------------- +struct AAMemoryBehaviorImpl : public AAMemoryBehavior { + AAMemoryBehaviorImpl(const IRPosition &IRP) : AAMemoryBehavior(IRP) {} - /// See AbstractAttribute::AbstractAttribute(...). - AAWillReturnImpl(Function &F, InformationCache &InfoCache) - : AAWillReturn(F, InfoCache) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + intersectAssumedBits(BEST_STATE); + getKnownStateFromValue(getIRPosition(), getState()); + IRAttribute::initialize(A); + } - /// See AAWillReturn::isKnownWillReturn(). 
- bool isKnownWillReturn() const override { return getKnown(); } + /// Return the memory behavior information encoded in the IR for \p IRP. + static void getKnownStateFromValue(const IRPosition &IRP, + IntegerState &State) { + SmallVector<Attribute, 2> Attrs; + IRP.getAttrs(AttrKinds, Attrs); + for (const Attribute &Attr : Attrs) { + switch (Attr.getKindAsEnum()) { + case Attribute::ReadNone: + State.addKnownBits(NO_ACCESSES); + break; + case Attribute::ReadOnly: + State.addKnownBits(NO_WRITES); + break; + case Attribute::WriteOnly: + State.addKnownBits(NO_READS); + break; + default: + llvm_unreachable("Unexpcted attribute!"); + } + } - /// See AAWillReturn::isAssumedWillReturn(). - bool isAssumedWillReturn() const override { return getAssumed(); } + if (auto *I = dyn_cast<Instruction>(&IRP.getAnchorValue())) { + if (!I->mayReadFromMemory()) + State.addKnownBits(NO_READS); + if (!I->mayWriteToMemory()) + State.addKnownBits(NO_WRITES); + } + } - /// See AbstractAttribute::getState(...). - AbstractState &getState() override { return *this; } + /// See AbstractAttribute::getDeducedAttributes(...). + void getDeducedAttributes(LLVMContext &Ctx, + SmallVectorImpl<Attribute> &Attrs) const override { + assert(Attrs.size() == 0); + if (isAssumedReadNone()) + Attrs.push_back(Attribute::get(Ctx, Attribute::ReadNone)); + else if (isAssumedReadOnly()) + Attrs.push_back(Attribute::get(Ctx, Attribute::ReadOnly)); + else if (isAssumedWriteOnly()) + Attrs.push_back(Attribute::get(Ctx, Attribute::WriteOnly)); + assert(Attrs.size() <= 1); + } - /// See AbstractAttribute::getState(...). - const AbstractState &getState() const override { return *this; } + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + IRPosition &IRP = getIRPosition(); + + // Check if we would improve the existing attributes first. 
+ SmallVector<Attribute, 4> DeducedAttrs; + getDeducedAttributes(IRP.getAnchorValue().getContext(), DeducedAttrs); + if (llvm::all_of(DeducedAttrs, [&](const Attribute &Attr) { + return IRP.hasAttr(Attr.getKindAsEnum(), + /* IgnoreSubsumingPositions */ true); + })) + return ChangeStatus::UNCHANGED; + + // Clear existing attributes. + IRP.removeAttrs(AttrKinds); + + // Use the generic manifest method. + return IRAttribute::manifest(A); + } - /// See AbstractAttribute::getAsStr() + /// See AbstractState::getAsStr(). const std::string getAsStr() const override { - return getAssumed() ? "willreturn" : "may-noreturn"; + if (isAssumedReadNone()) + return "readnone"; + if (isAssumedReadOnly()) + return "readonly"; + if (isAssumedWriteOnly()) + return "writeonly"; + return "may-read/write"; } + + /// The set of IR attributes AAMemoryBehavior deals with. + static const Attribute::AttrKind AttrKinds[3]; }; -struct AAWillReturnFunction final : AAWillReturnImpl { +const Attribute::AttrKind AAMemoryBehaviorImpl::AttrKinds[] = { + Attribute::ReadNone, Attribute::ReadOnly, Attribute::WriteOnly}; - /// See AbstractAttribute::AbstractAttribute(...). - AAWillReturnFunction(Function &F, InformationCache &InfoCache) - : AAWillReturnImpl(F, InfoCache) {} +/// Memory behavior attribute for a floating value. +struct AAMemoryBehaviorFloating : AAMemoryBehaviorImpl { + AAMemoryBehaviorFloating(const IRPosition &IRP) : AAMemoryBehaviorImpl(IRP) {} - /// See AbstractAttribute::getManifestPosition(). - ManifestPosition getManifestPosition() const override { - return MP_FUNCTION; + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AAMemoryBehaviorImpl::initialize(A); + // Initialize the use vector with all direct uses of the associated value. + for (const Use &U : getAssociatedValue().uses()) + Uses.insert(&U); } + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override; + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + if (isAssumedReadNone()) + STATS_DECLTRACK_FLOATING_ATTR(readnone) + else if (isAssumedReadOnly()) + STATS_DECLTRACK_FLOATING_ATTR(readonly) + else if (isAssumedWriteOnly()) + STATS_DECLTRACK_FLOATING_ATTR(writeonly) + } + +private: + /// Return true if users of \p UserI might access the underlying + /// variable/location described by \p U and should therefore be analyzed. + bool followUsersOfUseIn(Attributor &A, const Use *U, + const Instruction *UserI); + + /// Update the state according to the effect of use \p U in \p UserI. + void analyzeUseIn(Attributor &A, const Use *U, const Instruction *UserI); + +protected: + /// Container for (transitive) uses of the associated argument. + SetVector<const Use *> Uses; +}; + +/// Memory behavior attribute for function argument. +struct AAMemoryBehaviorArgument : AAMemoryBehaviorFloating { + AAMemoryBehaviorArgument(const IRPosition &IRP) + : AAMemoryBehaviorFloating(IRP) {} + /// See AbstractAttribute::initialize(...). - void initialize(Attributor &A) override; + void initialize(Attributor &A) override { + AAMemoryBehaviorFloating::initialize(A); + + // Initialize the use vector with all direct uses of the associated value. 
+ Argument *Arg = getAssociatedArgument(); + if (!Arg || !Arg->getParent()->hasExactDefinition()) + indicatePessimisticFixpoint(); + } + + ChangeStatus manifest(Attributor &A) override { + // TODO: From readattrs.ll: "inalloca parameters are always + // considered written" + if (hasAttr({Attribute::InAlloca})) { + removeKnownBits(NO_WRITES); + removeAssumedBits(NO_WRITES); + } + return AAMemoryBehaviorFloating::manifest(A); + } + + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + if (isAssumedReadNone()) + STATS_DECLTRACK_ARG_ATTR(readnone) + else if (isAssumedReadOnly()) + STATS_DECLTRACK_ARG_ATTR(readonly) + else if (isAssumedWriteOnly()) + STATS_DECLTRACK_ARG_ATTR(writeonly) + } +}; + +struct AAMemoryBehaviorCallSiteArgument final : AAMemoryBehaviorArgument { + AAMemoryBehaviorCallSiteArgument(const IRPosition &IRP) + : AAMemoryBehaviorArgument(IRP) {} /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override; + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Argument *Arg = getAssociatedArgument(); + const IRPosition &ArgPos = IRPosition::argument(*Arg); + auto &ArgAA = A.getAAFor<AAMemoryBehavior>(*this, ArgPos); + return clampStateAndIndicateChange( + getState(), + static_cast<const AANoCapture::StateType &>(ArgAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + if (isAssumedReadNone()) + STATS_DECLTRACK_CSARG_ATTR(readnone) + else if (isAssumedReadOnly()) + STATS_DECLTRACK_CSARG_ATTR(readonly) + else if (isAssumedWriteOnly()) + STATS_DECLTRACK_CSARG_ATTR(writeonly) + } }; -// Helper function that checks whether a function has any cycle. 
-// TODO: Replace with more efficent code -bool containsCycle(Function &F) { - SmallPtrSet<BasicBlock *, 32> Visited; +/// Memory behavior attribute for a call site return position. +struct AAMemoryBehaviorCallSiteReturned final : AAMemoryBehaviorFloating { + AAMemoryBehaviorCallSiteReturned(const IRPosition &IRP) + : AAMemoryBehaviorFloating(IRP) {} - // Traverse BB by dfs and check whether successor is already visited. - for (BasicBlock *BB : depth_first(&F)) { - Visited.insert(BB); - for (auto *SuccBB : successors(BB)) { - if (Visited.count(SuccBB)) - return true; + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + // We do not annotate returned values. + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} +}; + +/// An AA to represent the memory behavior function attributes. +struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl { + AAMemoryBehaviorFunction(const IRPosition &IRP) : AAMemoryBehaviorImpl(IRP) {} + + /// See AbstractAttribute::updateImpl(Attributor &A). + virtual ChangeStatus updateImpl(Attributor &A) override; + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + Function &F = cast<Function>(getAnchorValue()); + if (isAssumedReadNone()) { + F.removeFnAttr(Attribute::ArgMemOnly); + F.removeFnAttr(Attribute::InaccessibleMemOnly); + F.removeFnAttr(Attribute::InaccessibleMemOrArgMemOnly); } + return AAMemoryBehaviorImpl::manifest(A); } - return false; + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + if (isAssumedReadNone()) + STATS_DECLTRACK_FN_ATTR(readnone) + else if (isAssumedReadOnly()) + STATS_DECLTRACK_FN_ATTR(readonly) + else if (isAssumedWriteOnly()) + STATS_DECLTRACK_FN_ATTR(writeonly) + } +}; + +/// AAMemoryBehavior attribute for call sites. 
+struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { + AAMemoryBehaviorCallSite(const IRPosition &IRP) : AAMemoryBehaviorImpl(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + AAMemoryBehaviorImpl::initialize(A); + Function *F = getAssociatedFunction(); + if (!F || !F->hasExactDefinition()) + indicatePessimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Function *F = getAssociatedFunction(); + const IRPosition &FnPos = IRPosition::function(*F); + auto &FnAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos); + return clampStateAndIndicateChange( + getState(), static_cast<const AAAlign::StateType &>(FnAA.getState())); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + if (isAssumedReadNone()) + STATS_DECLTRACK_CS_ATTR(readnone) + else if (isAssumedReadOnly()) + STATS_DECLTRACK_CS_ATTR(readonly) + else if (isAssumedWriteOnly()) + STATS_DECLTRACK_CS_ATTR(writeonly) + } +}; +} // namespace + +ChangeStatus AAMemoryBehaviorFunction::updateImpl(Attributor &A) { + + // The current assumed state used to determine a change. + auto AssumedState = getAssumed(); + + auto CheckRWInst = [&](Instruction &I) { + // If the instruction has an own memory behavior state, use it to restrict + // the local state. No further analysis is required as the other memory + // state is as optimistic as it gets. 
+ if (ImmutableCallSite ICS = ImmutableCallSite(&I)) { + const auto &MemBehaviorAA = A.getAAFor<AAMemoryBehavior>( + *this, IRPosition::callsite_function(ICS)); + intersectAssumedBits(MemBehaviorAA.getAssumed()); + return !isAtFixpoint(); + } + + // Remove access kind modifiers if necessary. + if (I.mayReadFromMemory()) + removeAssumedBits(NO_READS); + if (I.mayWriteToMemory()) + removeAssumedBits(NO_WRITES); + return !isAtFixpoint(); + }; + + if (!A.checkForAllReadWriteInstructions(CheckRWInst, *this)) + return indicatePessimisticFixpoint(); + + return (AssumedState != getAssumed()) ? ChangeStatus::CHANGED + : ChangeStatus::UNCHANGED; } -// Helper function that checks the function have a loop which might become an -// endless loop -// FIXME: Any cycle is regarded as endless loop for now. -// We have to allow some patterns. -bool containsPossiblyEndlessLoop(Function &F) { return containsCycle(F); } +ChangeStatus AAMemoryBehaviorFloating::updateImpl(Attributor &A) { -void AAWillReturnFunction::initialize(Attributor &A) { - Function &F = getAnchorScope(); + const IRPosition &IRP = getIRPosition(); + const IRPosition &FnPos = IRPosition::function_scope(IRP); + AAMemoryBehavior::StateType &S = getState(); - if (containsPossiblyEndlessLoop(F)) - indicatePessimisticFixpoint(); + // First, check the function scope. We take the known information and we avoid + // work if the assumed information implies the current assumed information for + // this attribute. + const auto &FnMemAA = A.getAAFor<AAMemoryBehavior>(*this, FnPos); + S.addKnownBits(FnMemAA.getKnown()); + if ((S.getAssumed() & FnMemAA.getAssumed()) == S.getAssumed()) + return ChangeStatus::UNCHANGED; + + // Make sure the value is not captured (except through "return"), if + // it is, any information derived would be irrelevant anyway as we cannot + // check the potential aliases introduced by the capture. However, no need + // to fall back to anythign less optimistic than the function state. 
+ const auto &ArgNoCaptureAA = A.getAAFor<AANoCapture>(*this, IRP); + if (!ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) { + S.intersectAssumedBits(FnMemAA.getAssumed()); + return ChangeStatus::CHANGED; + } + + // The current assumed state used to determine a change. + auto AssumedState = S.getAssumed(); + + // Liveness information to exclude dead users. + // TODO: Take the FnPos once we have call site specific liveness information. + const auto &LivenessAA = A.getAAFor<AAIsDead>( + *this, IRPosition::function(*IRP.getAssociatedFunction())); + + // Visit and expand uses until all are analyzed or a fixpoint is reached. + for (unsigned i = 0; i < Uses.size() && !isAtFixpoint(); i++) { + const Use *U = Uses[i]; + Instruction *UserI = cast<Instruction>(U->getUser()); + LLVM_DEBUG(dbgs() << "[AAMemoryBehavior] Use: " << **U << " in " << *UserI + << " [Dead: " << (LivenessAA.isAssumedDead(UserI)) + << "]\n"); + if (LivenessAA.isAssumedDead(UserI)) + continue; + + // Check if the users of UserI should also be visited. + if (followUsersOfUseIn(A, U, UserI)) + for (const Use &UserIUse : UserI->uses()) + Uses.insert(&UserIUse); + + // If UserI might touch memory we analyze the use in detail. + if (UserI->mayReadOrWriteMemory()) + analyzeUseIn(A, U, UserI); + } + + return (AssumedState != getAssumed()) ? ChangeStatus::CHANGED + : ChangeStatus::UNCHANGED; } -ChangeStatus AAWillReturnFunction::updateImpl(Attributor &A) { - Function &F = getAnchorScope(); +bool AAMemoryBehaviorFloating::followUsersOfUseIn(Attributor &A, const Use *U, + const Instruction *UserI) { + // The loaded value is unrelated to the pointer argument, no need to + // follow the users of the load. + if (isa<LoadInst>(UserI)) + return false; - // The map from instruction opcodes to those instructions in the function. 
- auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); + // By default we follow all uses assuming UserI might leak information on U, + // we have special handling for call sites operands though. + ImmutableCallSite ICS(UserI); + if (!ICS || !ICS.isArgOperand(U)) + return true; - for (unsigned Opcode : - {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, - (unsigned)Instruction::Call}) { - for (Instruction *I : OpcodeInstMap[Opcode]) { - auto ICS = ImmutableCallSite(I); + // If the use is a call argument known not to be captured, the users of + // the call do not need to be visited because they have to be unrelated to + // the input. Note that this check is not trivial even though we disallow + // general capturing of the underlying argument. The reason is that the + // call might the argument "through return", which we allow and for which we + // need to check call users. + unsigned ArgNo = ICS.getArgumentNo(U); + const auto &ArgNoCaptureAA = + A.getAAFor<AANoCapture>(*this, IRPosition::callsite_argument(ICS, ArgNo)); + return !ArgNoCaptureAA.isAssumedNoCapture(); +} - if (ICS.hasFnAttr(Attribute::WillReturn)) - continue; +void AAMemoryBehaviorFloating::analyzeUseIn(Attributor &A, const Use *U, + const Instruction *UserI) { + assert(UserI->mayReadOrWriteMemory()); - auto *WillReturnAA = A.getAAFor<AAWillReturn>(*this, *I); - if (!WillReturnAA || !WillReturnAA->isAssumedWillReturn()) { - indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; - } + switch (UserI->getOpcode()) { + default: + // TODO: Handle all atomics and other side-effect operations we know of. + break; + case Instruction::Load: + // Loads cause the NO_READS property to disappear. + removeAssumedBits(NO_READS); + return; - auto *NoRecurseAA = A.getAAFor<AANoRecurse>(*this, *I); + case Instruction::Store: + // Stores cause the NO_WRITES property to disappear if the use is the + // pointer operand. 
Note that we do assume that capturing was taken care of + // somewhere else. + if (cast<StoreInst>(UserI)->getPointerOperand() == U->get()) + removeAssumedBits(NO_WRITES); + return; - // FIXME: (i) Prohibit any recursion for now. - // (ii) AANoRecurse isn't implemented yet so currently any call is - // regarded as having recursion. - // Code below should be - // if ((!NoRecurseAA || !NoRecurseAA->isAssumedNoRecurse()) && - if (!NoRecurseAA && !ICS.hasFnAttr(Attribute::NoRecurse)) { - indicatePessimisticFixpoint(); - return ChangeStatus::CHANGED; - } + case Instruction::Call: + case Instruction::CallBr: + case Instruction::Invoke: { + // For call sites we look at the argument memory behavior attribute (this + // could be recursive!) in order to restrict our own state. + ImmutableCallSite ICS(UserI); + + // Give up on operand bundles. + if (ICS.isBundleOperand(U)) { + indicatePessimisticFixpoint(); + return; + } + + // Calling a function does read the function pointer, maybe write it if the + // function is self-modifying. + if (ICS.isCallee(U)) { + removeAssumedBits(NO_READS); + break; } + + // Adjust the possible access behavior based on the information on the + // argument. + unsigned ArgNo = ICS.getArgumentNo(U); + const IRPosition &ArgPos = IRPosition::callsite_argument(ICS, ArgNo); + const auto &MemBehaviorAA = A.getAAFor<AAMemoryBehavior>(*this, ArgPos); + // "assumed" has at most the same bits as the MemBehaviorAA assumed + // and at least "known". + intersectAssumedBits(MemBehaviorAA.getAssumed()); + return; } + }; - return ChangeStatus::UNCHANGED; + // Generally, look at the "may-properties" and adjust the assumed state if we + // did not trigger special handling before. 
+ if (UserI->mayReadFromMemory()) + removeAssumedBits(NO_READS); + if (UserI->mayWriteToMemory()) + removeAssumedBits(NO_WRITES); } /// ---------------------------------------------------------------------------- /// Attributor /// ---------------------------------------------------------------------------- -bool Attributor::checkForAllCallSites(Function &F, - std::function<bool(CallSite)> &Pred, - bool RequireAllCallSites) { +bool Attributor::isAssumedDead(const AbstractAttribute &AA, + const AAIsDead *LivenessAA) { + const Instruction *CtxI = AA.getIRPosition().getCtxI(); + if (!CtxI) + return false; + + if (!LivenessAA) + LivenessAA = + &getAAFor<AAIsDead>(AA, IRPosition::function(*CtxI->getFunction()), + /* TrackDependence */ false); + + // Don't check liveness for AAIsDead. + if (&AA == LivenessAA) + return false; + + if (!LivenessAA->isAssumedDead(CtxI)) + return false; + + // We actually used liveness information so we have to record a dependence. + recordDependence(*LivenessAA, AA); + + return true; +} + +bool Attributor::checkForAllCallSites( + const function_ref<bool(AbstractCallSite)> &Pred, + const AbstractAttribute &QueryingAA, bool RequireAllCallSites) { // We can try to determine information from // the call sites. However, this is only possible all call sites are known, // hence the function has internal linkage. 
- if (RequireAllCallSites && !F.hasInternalLinkage()) { + const IRPosition &IRP = QueryingAA.getIRPosition(); + const Function *AssociatedFunction = IRP.getAssociatedFunction(); + if (!AssociatedFunction) { + LLVM_DEBUG(dbgs() << "[Attributor] No function associated with " << IRP + << "\n"); + return false; + } + + return checkForAllCallSites(Pred, *AssociatedFunction, RequireAllCallSites, + &QueryingAA); +} + +bool Attributor::checkForAllCallSites( + const function_ref<bool(AbstractCallSite)> &Pred, const Function &Fn, + bool RequireAllCallSites, const AbstractAttribute *QueryingAA) { + if (RequireAllCallSites && !Fn.hasLocalLinkage()) { LLVM_DEBUG( dbgs() - << "Attributor: Function " << F.getName() + << "[Attributor] Function " << Fn.getName() << " has no internal linkage, hence not all call sites are known\n"); return false; } - for (const Use &U : F.uses()) { + for (const Use &U : Fn.uses()) { + AbstractCallSite ACS(&U); + if (!ACS) { + LLVM_DEBUG(dbgs() << "[Attributor] Function " + << Fn.getName() + << " has non call site use " << *U.get() << " in " + << *U.getUser() << "\n"); + return false; + } + + Instruction *I = ACS.getInstruction(); + Function *Caller = I->getFunction(); + + const auto *LivenessAA = + lookupAAFor<AAIsDead>(IRPosition::function(*Caller), QueryingAA, + /* TrackDependence */ false); + + // Skip dead calls. + if (LivenessAA && LivenessAA->isAssumedDead(I)) { + // We actually used liveness information so we have to record a + // dependence. + if (QueryingAA) + recordDependence(*LivenessAA, *QueryingAA); + continue; + } - CallSite CS(U.getUser()); - if (!CS || !CS.isCallee(&U) || !CS.getCaller()->hasExactDefinition()) { + const Use *EffectiveUse = + ACS.isCallbackCall() ? 
&ACS.getCalleeUseForCallback() : &U; + if (!ACS.isCallee(EffectiveUse)) { if (!RequireAllCallSites) continue; - - LLVM_DEBUG(dbgs() << "Attributor: User " << *U.getUser() - << " is an invalid use of " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "[Attributor] User " << EffectiveUse->getUser() + << " is an invalid use of " + << Fn.getName() << "\n"); return false; } - if (Pred(CS)) + if (Pred(ACS)) continue; - LLVM_DEBUG(dbgs() << "Attributor: Call site callback failed for " - << *CS.getInstruction() << "\n"); + LLVM_DEBUG(dbgs() << "[Attributor] Call site callback failed for " + << *ACS.getInstruction() << "\n"); return false; } return true; } -ChangeStatus Attributor::run() { - // Initialize all abstract attributes. - for (AbstractAttribute *AA : AllAbstractAttributes) - AA->initialize(*this); +bool Attributor::checkForAllReturnedValuesAndReturnInsts( + const function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> + &Pred, + const AbstractAttribute &QueryingAA) { + + const IRPosition &IRP = QueryingAA.getIRPosition(); + // Since we need to provide return instructions we have to have an exact + // definition. + const Function *AssociatedFunction = IRP.getAssociatedFunction(); + if (!AssociatedFunction) + return false; + // If this is a call site query we use the call site specific return values + // and liveness information. + // TODO: use the function scope once we have call site AAReturnedValues. 
+ const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); + const auto &AARetVal = getAAFor<AAReturnedValues>(QueryingAA, QueryIRP); + if (!AARetVal.getState().isValidState()) + return false; + + return AARetVal.checkForAllReturnedValuesAndReturnInsts(Pred); +} + +bool Attributor::checkForAllReturnedValues( + const function_ref<bool(Value &)> &Pred, + const AbstractAttribute &QueryingAA) { + + const IRPosition &IRP = QueryingAA.getIRPosition(); + const Function *AssociatedFunction = IRP.getAssociatedFunction(); + if (!AssociatedFunction) + return false; + + // TODO: use the function scope once we have call site AAReturnedValues. + const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); + const auto &AARetVal = getAAFor<AAReturnedValues>(QueryingAA, QueryIRP); + if (!AARetVal.getState().isValidState()) + return false; + + return AARetVal.checkForAllReturnedValuesAndReturnInsts( + [&](Value &RV, const SmallSetVector<ReturnInst *, 4> &) { + return Pred(RV); + }); +} + +static bool +checkForAllInstructionsImpl(InformationCache::OpcodeInstMapTy &OpcodeInstMap, + const function_ref<bool(Instruction &)> &Pred, + const AAIsDead *LivenessAA, bool &AnyDead, + const ArrayRef<unsigned> &Opcodes) { + for (unsigned Opcode : Opcodes) { + for (Instruction *I : OpcodeInstMap[Opcode]) { + // Skip dead instructions. + if (LivenessAA && LivenessAA->isAssumedDead(I)) { + AnyDead = true; + continue; + } + + if (!Pred(*I)) + return false; + } + } + return true; +} + +bool Attributor::checkForAllInstructions( + const llvm::function_ref<bool(Instruction &)> &Pred, + const AbstractAttribute &QueryingAA, const ArrayRef<unsigned> &Opcodes) { + + const IRPosition &IRP = QueryingAA.getIRPosition(); + // Since we need to provide instructions we have to have an exact definition. + const Function *AssociatedFunction = IRP.getAssociatedFunction(); + if (!AssociatedFunction) + return false; + + // TODO: use the function scope once we have call site AAReturnedValues. 
+ const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); + const auto &LivenessAA = + getAAFor<AAIsDead>(QueryingAA, QueryIRP, /* TrackDependence */ false); + bool AnyDead = false; + + auto &OpcodeInstMap = + InfoCache.getOpcodeInstMapForFunction(*AssociatedFunction); + if (!checkForAllInstructionsImpl(OpcodeInstMap, Pred, &LivenessAA, AnyDead, + Opcodes)) + return false; + + // If we actually used liveness information so we have to record a dependence. + if (AnyDead) + recordDependence(LivenessAA, QueryingAA); + + return true; +} + +bool Attributor::checkForAllReadWriteInstructions( + const llvm::function_ref<bool(Instruction &)> &Pred, + AbstractAttribute &QueryingAA) { + + const Function *AssociatedFunction = + QueryingAA.getIRPosition().getAssociatedFunction(); + if (!AssociatedFunction) + return false; + + // TODO: use the function scope once we have call site AAReturnedValues. + const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); + const auto &LivenessAA = + getAAFor<AAIsDead>(QueryingAA, QueryIRP, /* TrackDependence */ false); + bool AnyDead = false; + + for (Instruction *I : + InfoCache.getReadOrWriteInstsForFunction(*AssociatedFunction)) { + // Skip dead instructions. + if (LivenessAA.isAssumedDead(I)) { + AnyDead = true; + continue; + } + + if (!Pred(*I)) + return false; + } + + // If we actually used liveness information so we have to record a dependence. + if (AnyDead) + recordDependence(LivenessAA, QueryingAA); + + return true; +} + +ChangeStatus Attributor::run(Module &M) { LLVM_DEBUG(dbgs() << "[Attributor] Identified and initialized " << AllAbstractAttributes.size() << " abstract attributes.\n"); @@ -1370,10 +4470,25 @@ ChangeStatus Attributor::run() { SetVector<AbstractAttribute *> Worklist; Worklist.insert(AllAbstractAttributes.begin(), AllAbstractAttributes.end()); + bool RecomputeDependences = false; + do { + // Remember the size to determine new attributes. 
+ size_t NumAAs = AllAbstractAttributes.size(); LLVM_DEBUG(dbgs() << "\n\n[Attributor] #Iteration: " << IterationCounter << ", Worklist size: " << Worklist.size() << "\n"); + // If dependences (=QueryMap) are recomputed we have to look at all abstract + // attributes again, regardless of what changed in the last iteration. + if (RecomputeDependences) { + LLVM_DEBUG( + dbgs() << "[Attributor] Run all AAs to recompute dependences\n"); + QueryMap.clear(); + ChangedAAs.clear(); + Worklist.insert(AllAbstractAttributes.begin(), + AllAbstractAttributes.end()); + } + // Add all abstract attributes that are potentially dependent on one that // changed to the work list. for (AbstractAttribute *ChangedAA : ChangedAAs) { @@ -1381,27 +4496,42 @@ ChangeStatus Attributor::run() { Worklist.insert(QuerriedAAs.begin(), QuerriedAAs.end()); } + LLVM_DEBUG(dbgs() << "[Attributor] #Iteration: " << IterationCounter + << ", Worklist+Dependent size: " << Worklist.size() + << "\n"); + // Reset the changed set. ChangedAAs.clear(); // Update all abstract attribute in the work list and record the ones that // changed. for (AbstractAttribute *AA : Worklist) - if (AA->update(*this) == ChangeStatus::CHANGED) - ChangedAAs.push_back(AA); + if (!isAssumedDead(*AA, nullptr)) + if (AA->update(*this) == ChangeStatus::CHANGED) + ChangedAAs.push_back(AA); + + // Check if we recompute the dependences in the next iteration. + RecomputeDependences = (DepRecomputeInterval > 0 && + IterationCounter % DepRecomputeInterval == 0); + + // Add attributes to the changed set if they have been created in the last + // iteration. + ChangedAAs.append(AllAbstractAttributes.begin() + NumAAs, + AllAbstractAttributes.end()); // Reset the work list and repopulate with the changed abstract attributes. // Note that dependent ones are added above. 
Worklist.clear(); Worklist.insert(ChangedAAs.begin(), ChangedAAs.end()); - } while (!Worklist.empty() && ++IterationCounter < MaxFixpointIterations); + } while (!Worklist.empty() && (IterationCounter++ < MaxFixpointIterations || + VerifyMaxFixpointIterations)); LLVM_DEBUG(dbgs() << "\n[Attributor] Fixpoint iteration done after: " << IterationCounter << "/" << MaxFixpointIterations << " iterations\n"); - bool FinishedAtFixpoint = Worklist.empty(); + size_t NumFinalAAs = AllAbstractAttributes.size(); // Reset abstract arguments not settled in a sound fixpoint by now. This // happens when we stopped the fixpoint iteration early. Note that only the @@ -1448,8 +4578,14 @@ ChangeStatus Attributor::run() { if (!State.isValidState()) continue; + // Skip dead code. + if (isAssumedDead(*AA, nullptr)) + continue; // Manifest the state and record if we changed the IR. ChangeStatus LocalChange = AA->manifest(*this); + if (LocalChange == ChangeStatus::CHANGED && AreStatisticsEnabled()) + AA->trackStatistics(); + ManifestChange = ManifestChange | LocalChange; NumAtFixpoint++; @@ -1462,69 +4598,92 @@ ChangeStatus Attributor::run() { << " arguments while " << NumAtFixpoint << " were in a valid fixpoint state\n"); - // If verification is requested, we finished this run at a fixpoint, and the - // IR was changed, we re-run the whole fixpoint analysis, starting at - // re-initialization of the arguments. This re-run should not result in an IR - // change. Though, the (virtual) state of attributes at the end of the re-run - // might be more optimistic than the known state or the IR state if the better - // state cannot be manifested. - if (VerifyAttributor && FinishedAtFixpoint && - ManifestChange == ChangeStatus::CHANGED) { - VerifyAttributor = false; - ChangeStatus VerifyStatus = run(); - if (VerifyStatus != ChangeStatus::UNCHANGED) - llvm_unreachable( - "Attributor verification failed, re-run did result in an IR change " - "even after a fixpoint was reached in the original run. 
(False " - "positives possible!)"); - VerifyAttributor = true; - } - NumAttributesManifested += NumManifested; NumAttributesValidFixpoint += NumAtFixpoint; - return ManifestChange; -} - -void Attributor::identifyDefaultAbstractAttributes( - Function &F, InformationCache &InfoCache, - DenseSet</* Attribute::AttrKind */ unsigned> *Whitelist) { + (void)NumFinalAAs; + assert( + NumFinalAAs == AllAbstractAttributes.size() && + "Expected the final number of abstract attributes to remain unchanged!"); + + // Delete stuff at the end to avoid invalid references and a nice order. + { + LLVM_DEBUG(dbgs() << "\n[Attributor] Delete at least " + << ToBeDeletedFunctions.size() << " functions and " + << ToBeDeletedBlocks.size() << " blocks and " + << ToBeDeletedInsts.size() << " instructions\n"); + for (Instruction *I : ToBeDeletedInsts) { + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + } - // Every function can be nounwind. - registerAA(*new AANoUnwindFunction(F, InfoCache)); + if (unsigned NumDeadBlocks = ToBeDeletedBlocks.size()) { + SmallVector<BasicBlock *, 8> ToBeDeletedBBs; + ToBeDeletedBBs.reserve(NumDeadBlocks); + ToBeDeletedBBs.append(ToBeDeletedBlocks.begin(), ToBeDeletedBlocks.end()); + DeleteDeadBlocks(ToBeDeletedBBs); + STATS_DECLTRACK(AAIsDead, BasicBlock, + "Number of dead basic blocks deleted."); + } - // Every function might be marked "nosync" - registerAA(*new AANoSyncFunction(F, InfoCache)); + STATS_DECL(AAIsDead, Function, "Number of dead functions deleted."); + for (Function *Fn : ToBeDeletedFunctions) { + Fn->replaceAllUsesWith(UndefValue::get(Fn->getType())); + Fn->eraseFromParent(); + STATS_TRACK(AAIsDead, Function); + } - // Every function might be "no-free". - registerAA(*new AANoFreeFunction(F, InfoCache)); + // Identify dead internal functions and delete them. 
This happens outside + // the other fixpoint analysis as we might treat potentially dead functions + // as live to lower the number of iterations. If they happen to be dead, the + // below fixpoint loop will identify and eliminate them. + SmallVector<Function *, 8> InternalFns; + for (Function &F : M) + if (F.hasLocalLinkage()) + InternalFns.push_back(&F); + + bool FoundDeadFn = true; + while (FoundDeadFn) { + FoundDeadFn = false; + for (unsigned u = 0, e = InternalFns.size(); u < e; ++u) { + Function *F = InternalFns[u]; + if (!F) + continue; - // Return attributes are only appropriate if the return type is non void. - Type *ReturnType = F.getReturnType(); - if (!ReturnType->isVoidTy()) { - // Argument attribute "returned" --- Create only one per function even - // though it is an argument attribute. - if (!Whitelist || Whitelist->count(AAReturnedValues::ID)) - registerAA(*new AAReturnedValuesImpl(F, InfoCache)); + const auto *LivenessAA = + lookupAAFor<AAIsDead>(IRPosition::function(*F)); + if (LivenessAA && + !checkForAllCallSites([](AbstractCallSite ACS) { return false; }, + *LivenessAA, true)) + continue; - // Every function with pointer return type might be marked nonnull. - if (ReturnType->isPointerTy() && - (!Whitelist || Whitelist->count(AANonNullReturned::ID))) - registerAA(*new AANonNullReturned(F, InfoCache)); + STATS_TRACK(AAIsDead, Function); + F->replaceAllUsesWith(UndefValue::get(F->getType())); + F->eraseFromParent(); + InternalFns[u] = nullptr; + FoundDeadFn = true; + } + } } - // Every argument with pointer type might be marked nonnull. 
- for (Argument &Arg : F.args()) { - if (Arg.getType()->isPointerTy()) - registerAA(*new AANonNullArgument(Arg, InfoCache)); + if (VerifyMaxFixpointIterations && + IterationCounter != MaxFixpointIterations) { + errs() << "\n[Attributor] Fixpoint iteration done after: " + << IterationCounter << "/" << MaxFixpointIterations + << " iterations\n"; + llvm_unreachable("The fixpoint was not reached with exactly the number of " + "specified iterations!"); } - // Every function might be "will-return". - registerAA(*new AAWillReturnFunction(F, InfoCache)); + return ManifestChange; +} + +void Attributor::initializeInformationCache(Function &F) { - // Walk all instructions to find more attribute opportunities and also - // interesting instructions that might be queried by abstract attributes - // during their initialization or update. + // Walk all instructions to find interesting instructions that might be + // queried by abstract attributes during their initialization or update. + // This has to happen before we create attributes. auto &ReadOrWriteInsts = InfoCache.FuncRWInstsMap[&F]; auto &InstOpcodeMap = InfoCache.FuncInstOpcodeMap[&F]; @@ -1540,8 +4699,12 @@ void Attributor::identifyDefaultAbstractAttributes( default: assert((!ImmutableCallSite(&I)) && (!isa<CallBase>(&I)) && "New call site/base instruction type needs to be known int the " - "attributor."); + "Attributor."); break; + case Instruction::Load: + // The alignment of a pointer is interesting for loads. + case Instruction::Store: + // The alignment of a pointer is interesting for stores. 
case Instruction::Call: case Instruction::CallBr: case Instruction::Invoke: @@ -1555,18 +4718,154 @@ void Attributor::identifyDefaultAbstractAttributes( InstOpcodeMap[I.getOpcode()].push_back(&I); if (I.mayReadOrWriteMemory()) ReadOrWriteInsts.push_back(&I); + } +} + +void Attributor::identifyDefaultAbstractAttributes(Function &F) { + if (!VisitedFunctions.insert(&F).second) + return; + + IRPosition FPos = IRPosition::function(F); + + // Check for dead BasicBlocks in every function. + // We need dead instruction detection because we do not want to deal with + // broken IR in which SSA rules do not apply. + getOrCreateAAFor<AAIsDead>(FPos); + + // Every function might be "will-return". + getOrCreateAAFor<AAWillReturn>(FPos); + // Every function can be nounwind. + getOrCreateAAFor<AANoUnwind>(FPos); + + // Every function might be marked "nosync" + getOrCreateAAFor<AANoSync>(FPos); + + // Every function might be "no-free". + getOrCreateAAFor<AANoFree>(FPos); + + // Every function might be "no-return". + getOrCreateAAFor<AANoReturn>(FPos); + + // Every function might be "no-recurse". + getOrCreateAAFor<AANoRecurse>(FPos); + + // Every function might be "readnone/readonly/writeonly/...". + getOrCreateAAFor<AAMemoryBehavior>(FPos); + + // Every function might be applicable for Heap-To-Stack conversion. + if (EnableHeapToStack) + getOrCreateAAFor<AAHeapToStack>(FPos); + + // Return attributes are only appropriate if the return type is non void. + Type *ReturnType = F.getReturnType(); + if (!ReturnType->isVoidTy()) { + // Argument attribute "returned" --- Create only one per function even + // though it is an argument attribute. + getOrCreateAAFor<AAReturnedValues>(FPos); + + IRPosition RetPos = IRPosition::returned(F); + + // Every function might be simplified. + getOrCreateAAFor<AAValueSimplify>(RetPos); + + if (ReturnType->isPointerTy()) { + + // Every function with pointer return type might be marked align. 
+ getOrCreateAAFor<AAAlign>(RetPos); + + // Every function with pointer return type might be marked nonnull. + getOrCreateAAFor<AANonNull>(RetPos); + + // Every function with pointer return type might be marked noalias. + getOrCreateAAFor<AANoAlias>(RetPos); + + // Every function with pointer return type might be marked + // dereferenceable. + getOrCreateAAFor<AADereferenceable>(RetPos); + } + } + + for (Argument &Arg : F.args()) { + IRPosition ArgPos = IRPosition::argument(Arg); + + // Every argument might be simplified. + getOrCreateAAFor<AAValueSimplify>(ArgPos); + + if (Arg.getType()->isPointerTy()) { + // Every argument with pointer type might be marked nonnull. + getOrCreateAAFor<AANonNull>(ArgPos); + + // Every argument with pointer type might be marked noalias. + getOrCreateAAFor<AANoAlias>(ArgPos); + + // Every argument with pointer type might be marked dereferenceable. + getOrCreateAAFor<AADereferenceable>(ArgPos); + + // Every argument with pointer type might be marked align. + getOrCreateAAFor<AAAlign>(ArgPos); + + // Every argument with pointer type might be marked nocapture. + getOrCreateAAFor<AANoCapture>(ArgPos); + + // Every argument with pointer type might be marked + // "readnone/readonly/writeonly/..." + getOrCreateAAFor<AAMemoryBehavior>(ArgPos); + } + } + + auto CallSitePred = [&](Instruction &I) -> bool { CallSite CS(&I); - if (CS && CS.getCalledFunction()) { + if (CS.getCalledFunction()) { for (int i = 0, e = CS.getCalledFunction()->arg_size(); i < e; i++) { + + IRPosition CSArgPos = IRPosition::callsite_argument(CS, i); + + // Call site argument might be simplified. + getOrCreateAAFor<AAValueSimplify>(CSArgPos); + if (!CS.getArgument(i)->getType()->isPointerTy()) continue; // Call site argument attribute "non-null". - registerAA(*new AANonNullCallSiteArgument(CS, i, InfoCache), i); + getOrCreateAAFor<AANonNull>(CSArgPos); + + // Call site argument attribute "no-alias". 
+ getOrCreateAAFor<AANoAlias>(CSArgPos); + + // Call site argument attribute "dereferenceable". + getOrCreateAAFor<AADereferenceable>(CSArgPos); + + // Call site argument attribute "align". + getOrCreateAAFor<AAAlign>(CSArgPos); } } - } + return true; + }; + + auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); + bool Success, AnyDead = false; + Success = checkForAllInstructionsImpl( + OpcodeInstMap, CallSitePred, nullptr, AnyDead, + {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, + (unsigned)Instruction::Call}); + (void)Success; + assert(Success && !AnyDead && "Expected the check call to be successful!"); + + auto LoadStorePred = [&](Instruction &I) -> bool { + if (isa<LoadInst>(I)) + getOrCreateAAFor<AAAlign>( + IRPosition::value(*cast<LoadInst>(I).getPointerOperand())); + else + getOrCreateAAFor<AAAlign>( + IRPosition::value(*cast<StoreInst>(I).getPointerOperand())); + return true; + }; + Success = checkForAllInstructionsImpl( + OpcodeInstMap, LoadStorePred, nullptr, AnyDead, + {(unsigned)Instruction::Load, (unsigned)Instruction::Store}); + (void)Success; + assert(Success && !AnyDead && "Expected the check call to be successful!"); } /// Helpers to ease debugging through output streams and print calls. @@ -1576,21 +4875,39 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, ChangeStatus S) { return OS << (S == ChangeStatus::CHANGED ? 
"changed" : "unchanged"); } -raw_ostream &llvm::operator<<(raw_ostream &OS, - AbstractAttribute::ManifestPosition AP) { +raw_ostream &llvm::operator<<(raw_ostream &OS, IRPosition::Kind AP) { switch (AP) { - case AbstractAttribute::MP_ARGUMENT: + case IRPosition::IRP_INVALID: + return OS << "inv"; + case IRPosition::IRP_FLOAT: + return OS << "flt"; + case IRPosition::IRP_RETURNED: + return OS << "fn_ret"; + case IRPosition::IRP_CALL_SITE_RETURNED: + return OS << "cs_ret"; + case IRPosition::IRP_FUNCTION: + return OS << "fn"; + case IRPosition::IRP_CALL_SITE: + return OS << "cs"; + case IRPosition::IRP_ARGUMENT: return OS << "arg"; - case AbstractAttribute::MP_CALL_SITE_ARGUMENT: + case IRPosition::IRP_CALL_SITE_ARGUMENT: return OS << "cs_arg"; - case AbstractAttribute::MP_FUNCTION: - return OS << "fn"; - case AbstractAttribute::MP_RETURNED: - return OS << "fn_ret"; } llvm_unreachable("Unknown attribute position!"); } +raw_ostream &llvm::operator<<(raw_ostream &OS, const IRPosition &Pos) { + const Value &AV = Pos.getAssociatedValue(); + return OS << "{" << Pos.getPositionKind() << ":" << AV.getName() << " [" + << Pos.getAnchorValue().getName() << "@" << Pos.getArgNo() << "]}"; +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, const IntegerState &S) { + return OS << "(" << S.getKnown() << "-" << S.getAssumed() << ")" + << static_cast<const AbstractState &>(S); +} + raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractState &S) { return OS << (!S.isValidState() ? "top" : (S.isAtFixpoint() ? 
"fix" : "")); } @@ -1601,8 +4918,8 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractAttribute &AA) { } void AbstractAttribute::print(raw_ostream &OS) const { - OS << "[" << getManifestPosition() << "][" << getAsStr() << "][" - << AnchoredVal.getName() << "]"; + OS << "[P: " << getIRPosition() << "][" << getAsStr() << "][S: " << getState() + << "]"; } ///} @@ -1610,7 +4927,7 @@ void AbstractAttribute::print(raw_ostream &OS) const { /// Pass (Manager) Boilerplate /// ---------------------------------------------------------------------------- -static bool runAttributorOnModule(Module &M) { +static bool runAttributorOnModule(Module &M, AnalysisGetter &AG) { if (DisableAttributor) return false; @@ -1619,39 +4936,39 @@ static bool runAttributorOnModule(Module &M) { // Create an Attributor and initially empty information cache that is filled // while we identify default attribute opportunities. - Attributor A; - InformationCache InfoCache; + InformationCache InfoCache(M, AG); + Attributor A(InfoCache, DepRecInterval); + + for (Function &F : M) + A.initializeInformationCache(F); for (Function &F : M) { - // TODO: Not all attributes require an exact definition. Find a way to - // enable deduction for some but not all attributes in case the - // definition might be changed at runtime, see also - // http://lists.llvm.org/pipermail/llvm-dev/2018-February/121275.html. - // TODO: We could always determine abstract attributes and if sufficient - // information was found we could duplicate the functions that do not - // have an exact definition. - if (!F.hasExactDefinition()) { + if (F.hasExactDefinition()) + NumFnWithExactDefinition++; + else NumFnWithoutExactDefinition++; - continue; - } - // For now we ignore naked and optnone functions. 
- if (F.hasFnAttribute(Attribute::Naked) || - F.hasFnAttribute(Attribute::OptimizeNone)) - continue; - - NumFnWithExactDefinition++; + // We look at internal functions only on-demand but if any use is not a + // direct call, we have to do it eagerly. + if (F.hasLocalLinkage()) { + if (llvm::all_of(F.uses(), [](const Use &U) { + return ImmutableCallSite(U.getUser()) && + ImmutableCallSite(U.getUser()).isCallee(&U); + })) + continue; + } // Populate the Attributor with abstract attribute opportunities in the // function and the information cache with IR information. - A.identifyDefaultAbstractAttributes(F, InfoCache); + A.identifyDefaultAbstractAttributes(F); } - return A.run() == ChangeStatus::CHANGED; + return A.run(M) == ChangeStatus::CHANGED; } PreservedAnalyses AttributorPass::run(Module &M, ModuleAnalysisManager &AM) { - if (runAttributorOnModule(M)) { + AnalysisGetter AG(AM); + if (runAttributorOnModule(M, AG)) { // FIXME: Think about passes we will preserve and add them here. return PreservedAnalyses::none(); } @@ -1670,12 +4987,14 @@ struct AttributorLegacyPass : public ModulePass { bool runOnModule(Module &M) override { if (skipModule(M)) return false; - return runAttributorOnModule(M); + + AnalysisGetter AG; + return runAttributorOnModule(M, AG); } void getAnalysisUsage(AnalysisUsage &AU) const override { // FIXME: Think about passes we will preserve and add them here. 
- AU.setPreservesCFG(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; @@ -1684,7 +5003,147 @@ struct AttributorLegacyPass : public ModulePass { Pass *llvm::createAttributorLegacyPass() { return new AttributorLegacyPass(); } char AttributorLegacyPass::ID = 0; + +const char AAReturnedValues::ID = 0; +const char AANoUnwind::ID = 0; +const char AANoSync::ID = 0; +const char AANoFree::ID = 0; +const char AANonNull::ID = 0; +const char AANoRecurse::ID = 0; +const char AAWillReturn::ID = 0; +const char AANoAlias::ID = 0; +const char AANoReturn::ID = 0; +const char AAIsDead::ID = 0; +const char AADereferenceable::ID = 0; +const char AAAlign::ID = 0; +const char AANoCapture::ID = 0; +const char AAValueSimplify::ID = 0; +const char AAHeapToStack::ID = 0; +const char AAMemoryBehavior::ID = 0; + +// Macro magic to create the static generator function for attributes that +// follow the naming scheme. + +#define SWITCH_PK_INV(CLASS, PK, POS_NAME) \ + case IRPosition::PK: \ + llvm_unreachable("Cannot create " #CLASS " for a " POS_NAME " position!"); + +#define SWITCH_PK_CREATE(CLASS, IRP, PK, SUFFIX) \ + case IRPosition::PK: \ + AA = new CLASS##SUFFIX(IRP); \ + break; + +#define CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \ + CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \ + CLASS *AA = nullptr; \ + switch (IRP.getPositionKind()) { \ + SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \ + SWITCH_PK_INV(CLASS, IRP_FLOAT, "floating") \ + SWITCH_PK_INV(CLASS, IRP_ARGUMENT, "argument") \ + SWITCH_PK_INV(CLASS, IRP_RETURNED, "returned") \ + SWITCH_PK_INV(CLASS, IRP_CALL_SITE_RETURNED, "call site returned") \ + SWITCH_PK_INV(CLASS, IRP_CALL_SITE_ARGUMENT, "call site argument") \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE, CallSite) \ + } \ + return *AA; \ + } + +#define CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \ + CLASS &CLASS::createForPosition(const IRPosition &IRP, 
Attributor &A) { \ + CLASS *AA = nullptr; \ + switch (IRP.getPositionKind()) { \ + SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \ + SWITCH_PK_INV(CLASS, IRP_FUNCTION, "function") \ + SWITCH_PK_INV(CLASS, IRP_CALL_SITE, "call site") \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_FLOAT, Floating) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_ARGUMENT, Argument) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_RETURNED, Returned) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_RETURNED, CallSiteReturned) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_ARGUMENT, CallSiteArgument) \ + } \ + return *AA; \ + } + +#define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \ + CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \ + CLASS *AA = nullptr; \ + switch (IRP.getPositionKind()) { \ + SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE, CallSite) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_FLOAT, Floating) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_ARGUMENT, Argument) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_RETURNED, Returned) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_RETURNED, CallSiteReturned) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_ARGUMENT, CallSiteArgument) \ + } \ + return *AA; \ + } + +#define CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \ + CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \ + CLASS *AA = nullptr; \ + switch (IRP.getPositionKind()) { \ + SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \ + SWITCH_PK_INV(CLASS, IRP_ARGUMENT, "argument") \ + SWITCH_PK_INV(CLASS, IRP_FLOAT, "floating") \ + SWITCH_PK_INV(CLASS, IRP_RETURNED, "returned") \ + SWITCH_PK_INV(CLASS, IRP_CALL_SITE_RETURNED, "call site returned") \ + SWITCH_PK_INV(CLASS, IRP_CALL_SITE_ARGUMENT, "call site argument") \ + SWITCH_PK_INV(CLASS, IRP_CALL_SITE, "call site") \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \ + } \ + return *AA; \ + } + +#define 
CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \ + CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \ + CLASS *AA = nullptr; \ + switch (IRP.getPositionKind()) { \ + SWITCH_PK_INV(CLASS, IRP_INVALID, "invalid") \ + SWITCH_PK_INV(CLASS, IRP_RETURNED, "returned") \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_FUNCTION, Function) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE, CallSite) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_FLOAT, Floating) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_ARGUMENT, Argument) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_RETURNED, CallSiteReturned) \ + SWITCH_PK_CREATE(CLASS, IRP, IRP_CALL_SITE_ARGUMENT, CallSiteArgument) \ + } \ + return *AA; \ + } + +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUnwind) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoSync) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoRecurse) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAWillReturn) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoReturn) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead) +CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReturnedValues) + +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADereferenceable) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAlign) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoCapture) + +CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify) + +CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack) + +CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior) + +#undef CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION +#undef CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION +#undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION +#undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION +#undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION +#undef SWITCH_PK_CREATE +#undef 
SWITCH_PK_INV + INITIALIZE_PASS_BEGIN(AttributorLegacyPass, "attributor", "Deduce and propagate attributes", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(AttributorLegacyPass, "attributor", "Deduce and propagate attributes", false, false) diff --git a/lib/Transforms/IPO/BlockExtractor.cpp b/lib/Transforms/IPO/BlockExtractor.cpp index 6c365f3f3cbe..de80c88c1591 100644 --- a/lib/Transforms/IPO/BlockExtractor.cpp +++ b/lib/Transforms/IPO/BlockExtractor.cpp @@ -119,6 +119,8 @@ void BlockExtractor::loadFile() { /*KeepEmpty=*/false); if (LineSplit.empty()) continue; + if (LineSplit.size()!=2) + report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'"); SmallVector<StringRef, 4> BBNames; LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); @@ -204,7 +206,8 @@ bool BlockExtractor::runOnModule(Module &M) { ++NumExtracted; Changed = true; } - Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(); + CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent()); + Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC); if (F) LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() << "' in: " << F->getName() << '\n'); diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp index ad877ae1786c..3cf839e397f8 100644 --- a/lib/Transforms/IPO/ConstantMerge.cpp +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -48,7 +48,7 @@ static void FindUsedValues(GlobalVariable *LLVMUsed, ConstantArray *Inits = cast<ConstantArray>(LLVMUsed->getInitializer()); for (unsigned i = 0, e = Inits->getNumOperands(); i != e; ++i) { - Value *Operand = Inits->getOperand(i)->stripPointerCastsNoFollowAliases(); + Value *Operand = Inits->getOperand(i)->stripPointerCasts(); GlobalValue *GV = cast<GlobalValue>(Operand); UsedValues.insert(GV); } @@ -120,7 +120,7 @@ static void replace(Module &M, GlobalVariable *Old, GlobalVariable *New) { // Bump the 
alignment if necessary. if (Old->getAlignment() || New->getAlignment()) - New->setAlignment(std::max(getAlignment(Old), getAlignment(New))); + New->setAlignment(Align(std::max(getAlignment(Old), getAlignment(New)))); copyDebugLocMetadata(Old, New); Old->replaceAllUsesWith(NewConstant); diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp index e30b33aa4872..e20159ba0db5 100644 --- a/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -84,13 +84,9 @@ void CrossDSOCFI::buildCFICheck(Module &M) { for (GlobalObject &GO : M.global_objects()) { Types.clear(); GO.getMetadata(LLVMContext::MD_type, Types); - for (MDNode *Type : Types) { - // Sanity check. GO must not be a function declaration. - assert(!isa<Function>(&GO) || !cast<Function>(&GO)->isDeclaration()); - + for (MDNode *Type : Types) if (ConstantInt *TypeId = extractNumericTypeId(Type)) TypeIds.insert(TypeId->getZExtValue()); - } } NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions"); @@ -108,11 +104,11 @@ void CrossDSOCFI::buildCFICheck(Module &M) { FunctionCallee C = M.getOrInsertFunction( "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx), Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx)); - Function *F = dyn_cast<Function>(C.getCallee()); + Function *F = cast<Function>(C.getCallee()); // Take over the existing function. The frontend emits a weak stub so that the // linker knows about the symbol; this pass replaces the function body. 
F->deleteBody(); - F->setAlignment(4096); + F->setAlignment(Align(4096)); Triple T(M.getTargetTriple()); if (T.isARM() || T.isThumb()) diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 5ccd8bc4b0fb..b174c63a577b 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -78,11 +78,8 @@ STATISTIC(NumNoRecurse, "Number of functions marked as norecurse"); STATISTIC(NumNoUnwind, "Number of functions marked as nounwind"); STATISTIC(NumNoFree, "Number of functions marked as nofree"); -// FIXME: This is disabled by default to avoid exposing security vulnerabilities -// in C/C++ code compiled by clang: -// http://lists.llvm.org/pipermail/cfe-dev/2017-January/052066.html static cl::opt<bool> EnableNonnullArgPropagation( - "enable-nonnull-arg-prop", cl::Hidden, + "enable-nonnull-arg-prop", cl::init(true), cl::Hidden, cl::desc("Try to propagate nonnull argument attributes from callsites to " "caller functions.")); @@ -664,6 +661,25 @@ static bool addArgumentAttrsFromCallsites(Function &F) { return Changed; } +static bool addReadAttr(Argument *A, Attribute::AttrKind R) { + assert((R == Attribute::ReadOnly || R == Attribute::ReadNone) + && "Must be a Read attribute."); + assert(A && "Argument must not be null."); + + // If the argument already has the attribute, nothing needs to be done. + if (A->hasAttribute(R)) + return false; + + // Otherwise, remove potentially conflicting attribute, add the new one, + // and update statistics. + A->removeAttr(Attribute::WriteOnly); + A->removeAttr(Attribute::ReadOnly); + A->removeAttr(Attribute::ReadNone); + A->addAttr(R); + R == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg; + return true; +} + /// Deduce nocapture attributes for the SCC. 
static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { bool Changed = false; @@ -732,11 +748,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { SmallPtrSet<Argument *, 8> Self; Self.insert(&*A); Attribute::AttrKind R = determinePointerReadAttrs(&*A, Self); - if (R != Attribute::None) { - A->addAttr(R); - Changed = true; - R == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg; - } + if (R != Attribute::None) + Changed = addReadAttr(A, R); } } } @@ -833,12 +846,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { if (ReadAttr != Attribute::None) { for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) { Argument *A = ArgumentSCC[i]->Definition; - // Clear out existing readonly/readnone attributes - A->removeAttr(Attribute::ReadOnly); - A->removeAttr(Attribute::ReadNone); - A->addAttr(ReadAttr); - ReadAttr == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg; - Changed = true; + Changed = addReadAttr(A, ReadAttr); } } } diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp index 62c7fbd07223..3f5cc078d75f 100644 --- a/lib/Transforms/IPO/FunctionImport.cpp +++ b/lib/Transforms/IPO/FunctionImport.cpp @@ -450,7 +450,7 @@ static void computeImportForFunction( } else if (PrintImportFailures) { assert(!FailureInfo && "Expected no FailureInfo for newly rejected candidate"); - FailureInfo = llvm::make_unique<FunctionImporter::ImportFailureInfo>( + FailureInfo = std::make_unique<FunctionImporter::ImportFailureInfo>( VI, Edge.second.getHotness(), Reason, 1); } LLVM_DEBUG( @@ -764,7 +764,7 @@ void llvm::computeDeadSymbols( } // Make value live and add it to the worklist if it was not live before. - auto visit = [&](ValueInfo VI) { + auto visit = [&](ValueInfo VI, bool IsAliasee) { // FIXME: If we knew which edges were created for indirect call profiles, // we could skip them here. Any that are live should be reached via // other edges, e.g. reference edges. 
Otherwise, using a profile collected @@ -800,12 +800,15 @@ void llvm::computeDeadSymbols( Interposable = true; } - if (!KeepAliveLinkage) - return; + if (!IsAliasee) { + if (!KeepAliveLinkage) + return; - if (Interposable) - report_fatal_error( - "Interposable and available_externally/linkonce_odr/weak_odr symbol"); + if (Interposable) + report_fatal_error( + "Interposable and available_externally/linkonce_odr/weak_odr " + "symbol"); + } } for (auto &S : VI.getSummaryList()) @@ -821,16 +824,16 @@ void llvm::computeDeadSymbols( // If this is an alias, visit the aliasee VI to ensure that all copies // are marked live and it is added to the worklist for further // processing of its references. - visit(AS->getAliaseeVI()); + visit(AS->getAliaseeVI(), true); continue; } Summary->setLive(true); for (auto Ref : Summary->refs()) - visit(Ref); + visit(Ref, false); if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) for (auto Call : FS->calls()) - visit(Call.first); + visit(Call.first, false); } } Index.setWithGlobalValueDeadStripping(); @@ -892,7 +895,7 @@ std::error_code llvm::EmitImportsFiles( StringRef ModulePath, StringRef OutputFilename, const std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) { std::error_code EC; - raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::F_None); + raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::OF_None); if (EC) return EC; for (auto &ILI : ModuleToSummariesForIndex) @@ -948,23 +951,15 @@ void llvm::thinLTOResolvePrevailingInModule( auto NewLinkage = GS->second->linkage(); if (NewLinkage == GV.getLinkage()) return; - - // Switch the linkage to weakany if asked for, e.g. we do this for - // linker redefined symbols (via --wrap or --defsym). - // We record that the visibility should be changed here in `addThinLTO` - // as we need access to the resolution vectors for each input file in - // order to find which symbols have been redefined. 
- // We may consider reorganizing this code and moving the linkage recording - // somewhere else, e.g. in thinLTOResolvePrevailingInIndex. - if (NewLinkage == GlobalValue::WeakAnyLinkage) { - GV.setLinkage(NewLinkage); - return; - } - if (GlobalValue::isLocalLinkage(GV.getLinkage()) || + // Don't internalize anything here, because the code below + // lacks necessary correctness checks. Leave this job to + // LLVM 'internalize' pass. + GlobalValue::isLocalLinkage(NewLinkage) || // In case it was dead and already converted to declaration. GV.isDeclaration()) return; + // Check for a non-prevailing def that has interposable linkage // (e.g. non-odr weak or linkonce). In that case we can't simply // convert to available_externally, since it would lose the diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp index 86b7f3e49ee6..f010f7b703a6 100644 --- a/lib/Transforms/IPO/GlobalDCE.cpp +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -17,9 +17,11 @@ #include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" #include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/CtorUtils.h" @@ -29,10 +31,15 @@ using namespace llvm; #define DEBUG_TYPE "globaldce" +static cl::opt<bool> + ClEnableVFE("enable-vfe", cl::Hidden, cl::init(true), cl::ZeroOrMore, + cl::desc("Enable virtual function elimination")); + STATISTIC(NumAliases , "Number of global aliases removed"); STATISTIC(NumFunctions, "Number of functions removed"); STATISTIC(NumIFuncs, "Number of indirect functions removed"); STATISTIC(NumVariables, "Number of global variables removed"); +STATISTIC(NumVFuncs, "Number of virtual functions removed"); namespace { class GlobalDCELegacyPass : public ModulePass { @@ -118,6 +125,15 @@ void 
GlobalDCEPass::UpdateGVDependencies(GlobalValue &GV) { ComputeDependencies(User, Deps); Deps.erase(&GV); // Remove self-reference. for (GlobalValue *GVU : Deps) { + // If this is a dep from a vtable to a virtual function, and we have + // complete information about all virtual call sites which could call + // though this vtable, then skip it, because the call site information will + // be more precise. + if (VFESafeVTables.count(GVU) && isa<Function>(&GV)) { + LLVM_DEBUG(dbgs() << "Ignoring dep " << GVU->getName() << " -> " + << GV.getName() << "\n"); + continue; + } GVDependencies[GVU].insert(&GV); } } @@ -132,12 +148,133 @@ void GlobalDCEPass::MarkLive(GlobalValue &GV, if (Updates) Updates->push_back(&GV); if (Comdat *C = GV.getComdat()) { - for (auto &&CM : make_range(ComdatMembers.equal_range(C))) + for (auto &&CM : make_range(ComdatMembers.equal_range(C))) { MarkLive(*CM.second, Updates); // Recursion depth is only two because only // globals in the same comdat are visited. + } + } +} + +void GlobalDCEPass::ScanVTables(Module &M) { + SmallVector<MDNode *, 2> Types; + LLVM_DEBUG(dbgs() << "Building type info -> vtable map\n"); + + auto *LTOPostLinkMD = + cast_or_null<ConstantAsMetadata>(M.getModuleFlag("LTOPostLink")); + bool LTOPostLink = + LTOPostLinkMD && + (cast<ConstantInt>(LTOPostLinkMD->getValue())->getZExtValue() != 0); + + for (GlobalVariable &GV : M.globals()) { + Types.clear(); + GV.getMetadata(LLVMContext::MD_type, Types); + if (GV.isDeclaration() || Types.empty()) + continue; + + // Use the typeid metadata on the vtable to build a mapping from typeids to + // the list of (GV, offset) pairs which are the possible vtables for that + // typeid. 
+ for (MDNode *Type : Types) { + Metadata *TypeID = Type->getOperand(1).get(); + + uint64_t Offset = + cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); + + TypeIdMap[TypeID].insert(std::make_pair(&GV, Offset)); + } + + // If the type corresponding to the vtable is private to this translation + // unit, we know that we can see all virtual functions which might use it, + // so VFE is safe. + if (auto GO = dyn_cast<GlobalObject>(&GV)) { + GlobalObject::VCallVisibility TypeVis = GO->getVCallVisibility(); + if (TypeVis == GlobalObject::VCallVisibilityTranslationUnit || + (LTOPostLink && + TypeVis == GlobalObject::VCallVisibilityLinkageUnit)) { + LLVM_DEBUG(dbgs() << GV.getName() << " is safe for VFE\n"); + VFESafeVTables.insert(&GV); + } + } + } +} + +void GlobalDCEPass::ScanVTableLoad(Function *Caller, Metadata *TypeId, + uint64_t CallOffset) { + for (auto &VTableInfo : TypeIdMap[TypeId]) { + GlobalVariable *VTable = VTableInfo.first; + uint64_t VTableOffset = VTableInfo.second; + + Constant *Ptr = + getPointerAtOffset(VTable->getInitializer(), VTableOffset + CallOffset, + *Caller->getParent()); + if (!Ptr) { + LLVM_DEBUG(dbgs() << "can't find pointer in vtable!\n"); + VFESafeVTables.erase(VTable); + return; + } + + auto Callee = dyn_cast<Function>(Ptr->stripPointerCasts()); + if (!Callee) { + LLVM_DEBUG(dbgs() << "vtable entry is not function pointer!\n"); + VFESafeVTables.erase(VTable); + return; + } + + LLVM_DEBUG(dbgs() << "vfunc dep " << Caller->getName() << " -> " + << Callee->getName() << "\n"); + GVDependencies[Caller].insert(Callee); } } +void GlobalDCEPass::ScanTypeCheckedLoadIntrinsics(Module &M) { + LLVM_DEBUG(dbgs() << "Scanning type.checked.load intrinsics\n"); + Function *TypeCheckedLoadFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + + if (!TypeCheckedLoadFunc) + return; + + for (auto U : TypeCheckedLoadFunc->users()) { + auto CI = dyn_cast<CallInst>(U); + if (!CI) + continue; 
+ + auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + Value *TypeIdValue = CI->getArgOperand(2); + auto *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); + + if (Offset) { + ScanVTableLoad(CI->getFunction(), TypeId, Offset->getZExtValue()); + } else { + // type.checked.load with a non-constant offset, so assume every entry in + // every matching vtable is used. + for (auto &VTableInfo : TypeIdMap[TypeId]) { + VFESafeVTables.erase(VTableInfo.first); + } + } + } +} + +void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) { + if (!ClEnableVFE) + return; + + ScanVTables(M); + + if (VFESafeVTables.empty()) + return; + + ScanTypeCheckedLoadIntrinsics(M); + + LLVM_DEBUG( + dbgs() << "VFE safe vtables:\n"; + for (auto *VTable : VFESafeVTables) + dbgs() << " " << VTable->getName() << "\n"; + ); +} + PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { bool Changed = false; @@ -163,6 +300,10 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { if (Comdat *C = GA.getComdat()) ComdatMembers.insert(std::make_pair(C, &GA)); + // Add dependencies between virtual call sites and the virtual functions they + // might call, if we have that information. + AddVirtualFunctionDependencies(M); + // Loop over the module, adding globals which are obviously necessary. for (GlobalObject &GO : M.global_objects()) { Changed |= RemoveUnusedGlobalValue(GO); @@ -257,8 +398,17 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { }; NumFunctions += DeadFunctions.size(); - for (Function *F : DeadFunctions) + for (Function *F : DeadFunctions) { + if (!F->use_empty()) { + // Virtual functions might still be referenced by one or more vtables, + // but if we've proven them to be unused then it's safe to replace the + // virtual function pointers with null, allowing us to remove the + // function itself. 
+ ++NumVFuncs; + F->replaceNonMetadataUsesWith(ConstantPointerNull::get(F->getType())); + } EraseUnusedGlobalValue(F); + } NumVariables += DeadGlobalVars.size(); for (GlobalVariable *GV : DeadGlobalVars) @@ -277,6 +427,8 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { ConstantDependenciesCache.clear(); GVDependencies.clear(); ComdatMembers.clear(); + TypeIdMap.clear(); + VFESafeVTables.clear(); if (Changed) return PreservedAnalyses::none(); diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index c4fb3ce77f6e..819715b9f8da 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -155,7 +155,8 @@ static bool isLeakCheckerRoot(GlobalVariable *GV) { /// Given a value that is stored to a global but never read, determine whether /// it's safe to remove the store and the chain of computation that feeds the /// store. -static bool IsSafeComputationToRemove(Value *V, const TargetLibraryInfo *TLI) { +static bool IsSafeComputationToRemove( + Value *V, function_ref<TargetLibraryInfo &(Function &)> GetTLI) { do { if (isa<Constant>(V)) return true; @@ -164,7 +165,7 @@ static bool IsSafeComputationToRemove(Value *V, const TargetLibraryInfo *TLI) { if (isa<LoadInst>(V) || isa<InvokeInst>(V) || isa<Argument>(V) || isa<GlobalValue>(V)) return false; - if (isAllocationFn(V, TLI)) + if (isAllocationFn(V, GetTLI)) return true; Instruction *I = cast<Instruction>(V); @@ -184,8 +185,9 @@ static bool IsSafeComputationToRemove(Value *V, const TargetLibraryInfo *TLI) { /// This GV is a pointer root. Loop over all users of the global and clean up /// any that obviously don't assign the global a value that isn't dynamically /// allocated. -static bool CleanupPointerRootUsers(GlobalVariable *GV, - const TargetLibraryInfo *TLI) { +static bool +CleanupPointerRootUsers(GlobalVariable *GV, + function_ref<TargetLibraryInfo &(Function &)> GetTLI) { // A brief explanation of leak checkers. 
The goal is to find bugs where // pointers are forgotten, causing an accumulating growth in memory // usage over time. The common strategy for leak checkers is to whitelist the @@ -241,18 +243,18 @@ static bool CleanupPointerRootUsers(GlobalVariable *GV, C->destroyConstant(); // This could have invalidated UI, start over from scratch. Dead.clear(); - CleanupPointerRootUsers(GV, TLI); + CleanupPointerRootUsers(GV, GetTLI); return true; } } } for (int i = 0, e = Dead.size(); i != e; ++i) { - if (IsSafeComputationToRemove(Dead[i].first, TLI)) { + if (IsSafeComputationToRemove(Dead[i].first, GetTLI)) { Dead[i].second->eraseFromParent(); Instruction *I = Dead[i].first; do { - if (isAllocationFn(I, TLI)) + if (isAllocationFn(I, GetTLI)) break; Instruction *J = dyn_cast<Instruction>(I->getOperand(0)); if (!J) @@ -270,9 +272,9 @@ static bool CleanupPointerRootUsers(GlobalVariable *GV, /// We just marked GV constant. Loop over all users of the global, cleaning up /// the obvious ones. This is largely just a quick scan over the use list to /// clean up the easy and obvious cruft. This returns true if it made a change. -static bool CleanupConstantGlobalUsers(Value *V, Constant *Init, - const DataLayout &DL, - TargetLibraryInfo *TLI) { +static bool CleanupConstantGlobalUsers( + Value *V, Constant *Init, const DataLayout &DL, + function_ref<TargetLibraryInfo &(Function &)> GetTLI) { bool Changed = false; // Note that we need to use a weak value handle for the worklist items. 
When // we delete a constant array, we may also be holding pointer to one of its @@ -302,12 +304,12 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init, Constant *SubInit = nullptr; if (Init) SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE); - Changed |= CleanupConstantGlobalUsers(CE, SubInit, DL, TLI); + Changed |= CleanupConstantGlobalUsers(CE, SubInit, DL, GetTLI); } else if ((CE->getOpcode() == Instruction::BitCast && CE->getType()->isPointerTy()) || CE->getOpcode() == Instruction::AddrSpaceCast) { // Pointer cast, delete any stores and memsets to the global. - Changed |= CleanupConstantGlobalUsers(CE, nullptr, DL, TLI); + Changed |= CleanupConstantGlobalUsers(CE, nullptr, DL, GetTLI); } if (CE->use_empty()) { @@ -321,7 +323,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init, Constant *SubInit = nullptr; if (!isa<ConstantExpr>(GEP->getOperand(0))) { ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>( - ConstantFoldInstruction(GEP, DL, TLI)); + ConstantFoldInstruction(GEP, DL, &GetTLI(*GEP->getFunction()))); if (Init && CE && CE->getOpcode() == Instruction::GetElementPtr) SubInit = ConstantFoldLoadThroughGEPConstantExpr(Init, CE); @@ -331,7 +333,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init, if (Init && isa<ConstantAggregateZero>(Init) && GEP->isInBounds()) SubInit = Constant::getNullValue(GEP->getResultElementType()); } - Changed |= CleanupConstantGlobalUsers(GEP, SubInit, DL, TLI); + Changed |= CleanupConstantGlobalUsers(GEP, SubInit, DL, GetTLI); if (GEP->use_empty()) { GEP->eraseFromParent(); @@ -348,7 +350,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init, // us, and if they are all dead, nuke them without remorse. 
if (isSafeToDestroyConstant(C)) { C->destroyConstant(); - CleanupConstantGlobalUsers(V, Init, DL, TLI); + CleanupConstantGlobalUsers(V, Init, DL, GetTLI); return true; } } @@ -495,8 +497,8 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { // had 256 byte alignment for example, something might depend on that: // propagate info to each field. uint64_t FieldOffset = Layout.getElementOffset(i); - unsigned NewAlign = (unsigned)MinAlign(StartAlignment, FieldOffset); - if (NewAlign > DL.getABITypeAlignment(STy->getElementType(i))) + Align NewAlign(MinAlign(StartAlignment, FieldOffset)); + if (NewAlign > Align(DL.getABITypeAlignment(STy->getElementType(i)))) NGV->setAlignment(NewAlign); // Copy over the debug info for the variable. @@ -511,7 +513,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { NewGlobals.reserve(NumElements); auto ElTy = STy->getElementType(); uint64_t EltSize = DL.getTypeAllocSize(ElTy); - unsigned EltAlign = DL.getABITypeAlignment(ElTy); + Align EltAlign(DL.getABITypeAlignment(ElTy)); uint64_t FragmentSizeInBits = DL.getTypeAllocSizeInBits(ElTy); for (unsigned i = 0, e = NumElements; i != e; ++i) { Constant *In = Init->getAggregateElement(i); @@ -530,7 +532,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { // Calculate the known alignment of the field. If the original aggregate // had 256 byte alignment for example, something might depend on that: // propagate info to each field. 
- unsigned NewAlign = (unsigned)MinAlign(StartAlignment, EltSize*i); + Align NewAlign(MinAlign(StartAlignment, EltSize * i)); if (NewAlign > EltAlign) NGV->setAlignment(NewAlign); transferSRADebugInfo(GV, NGV, FragmentSizeInBits * i, FragmentSizeInBits, @@ -745,9 +747,9 @@ static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) { /// are uses of the loaded value that would trap if the loaded value is /// dynamically null, then we know that they cannot be reachable with a null /// optimize away the load. -static bool OptimizeAwayTrappingUsesOfLoads(GlobalVariable *GV, Constant *LV, - const DataLayout &DL, - TargetLibraryInfo *TLI) { +static bool OptimizeAwayTrappingUsesOfLoads( + GlobalVariable *GV, Constant *LV, const DataLayout &DL, + function_ref<TargetLibraryInfo &(Function &)> GetTLI) { bool Changed = false; // Keep track of whether we are able to remove all the uses of the global @@ -793,10 +795,10 @@ static bool OptimizeAwayTrappingUsesOfLoads(GlobalVariable *GV, Constant *LV, // nor is the global. if (AllNonStoreUsesGone) { if (isLeakCheckerRoot(GV)) { - Changed |= CleanupPointerRootUsers(GV, TLI); + Changed |= CleanupPointerRootUsers(GV, GetTLI); } else { Changed = true; - CleanupConstantGlobalUsers(GV, nullptr, DL, TLI); + CleanupConstantGlobalUsers(GV, nullptr, DL, GetTLI); } if (GV->use_empty()) { LLVM_DEBUG(dbgs() << " *** GLOBAL NOW DEAD!\n"); @@ -889,8 +891,8 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, while (!GV->use_empty()) { if (StoreInst *SI = dyn_cast<StoreInst>(GV->user_back())) { // The global is initialized when the store to it occurs. 
- new StoreInst(ConstantInt::getTrue(GV->getContext()), InitBool, false, 0, - SI->getOrdering(), SI->getSyncScopeID(), SI); + new StoreInst(ConstantInt::getTrue(GV->getContext()), InitBool, false, + None, SI->getOrdering(), SI->getSyncScopeID(), SI); SI->eraseFromParent(); continue; } @@ -907,7 +909,7 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, // Replace the cmp X, 0 with a use of the bool value. // Sink the load to where the compare was, if atomic rules allow us to. Value *LV = new LoadInst(InitBool->getValueType(), InitBool, - InitBool->getName() + ".val", false, 0, + InitBool->getName() + ".val", false, None, LI->getOrdering(), LI->getSyncScopeID(), LI->isUnordered() ? (Instruction *)ICI : LI); InitBoolUsed = true; @@ -1562,10 +1564,10 @@ static bool tryToOptimizeStoreOfMallocToGlobal(GlobalVariable *GV, CallInst *CI, // Try to optimize globals based on the knowledge that only one value (besides // its initializer) is ever stored to the global. -static bool optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, - AtomicOrdering Ordering, - const DataLayout &DL, - TargetLibraryInfo *TLI) { +static bool +optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, + AtomicOrdering Ordering, const DataLayout &DL, + function_ref<TargetLibraryInfo &(Function &)> GetTLI) { // Ignore no-op GEPs and bitcasts. StoredOnceVal = StoredOnceVal->stripPointerCasts(); @@ -1583,9 +1585,10 @@ static bool optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, SOVC = ConstantExpr::getBitCast(SOVC, GV->getInitializer()->getType()); // Optimize away any trapping uses of the loaded value. 
- if (OptimizeAwayTrappingUsesOfLoads(GV, SOVC, DL, TLI)) + if (OptimizeAwayTrappingUsesOfLoads(GV, SOVC, DL, GetTLI)) return true; - } else if (CallInst *CI = extractMallocCall(StoredOnceVal, TLI)) { + } else if (CallInst *CI = extractMallocCall(StoredOnceVal, GetTLI)) { + auto *TLI = &GetTLI(*CI->getFunction()); Type *MallocType = getMallocAllocatedType(CI, TLI); if (MallocType && tryToOptimizeStoreOfMallocToGlobal(GV, CI, MallocType, Ordering, DL, TLI)) @@ -1643,10 +1646,12 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { // instead of a select to synthesize the desired value. bool IsOneZero = false; bool EmitOneOrZero = true; - if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)){ + auto *CI = dyn_cast<ConstantInt>(OtherVal); + if (CI && CI->getValue().getActiveBits() <= 64) { IsOneZero = InitVal->isNullValue() && CI->isOne(); - if (ConstantInt *CIInit = dyn_cast<ConstantInt>(GV->getInitializer())){ + auto *CIInit = dyn_cast<ConstantInt>(GV->getInitializer()); + if (CIInit && CIInit->getValue().getActiveBits() <= 64) { uint64_t ValInit = CIInit->getZExtValue(); uint64_t ValOther = CI->getZExtValue(); uint64_t ValMinus = ValOther - ValInit; @@ -1711,7 +1716,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { assert(LI->getOperand(0) == GV && "Not a copy!"); // Insert a new load, to preserve the saved value. 
StoreVal = new LoadInst(NewGV->getValueType(), NewGV, - LI->getName() + ".b", false, 0, + LI->getName() + ".b", false, None, LI->getOrdering(), LI->getSyncScopeID(), LI); } else { assert((isa<CastInst>(StoredVal) || isa<SelectInst>(StoredVal)) && @@ -1721,15 +1726,15 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { } } StoreInst *NSI = - new StoreInst(StoreVal, NewGV, false, 0, SI->getOrdering(), + new StoreInst(StoreVal, NewGV, false, None, SI->getOrdering(), SI->getSyncScopeID(), SI); NSI->setDebugLoc(SI->getDebugLoc()); } else { // Change the load into a load of bool then a select. LoadInst *LI = cast<LoadInst>(UI); - LoadInst *NLI = - new LoadInst(NewGV->getValueType(), NewGV, LI->getName() + ".b", - false, 0, LI->getOrdering(), LI->getSyncScopeID(), LI); + LoadInst *NLI = new LoadInst(NewGV->getValueType(), NewGV, + LI->getName() + ".b", false, None, + LI->getOrdering(), LI->getSyncScopeID(), LI); Instruction *NSI; if (IsOneZero) NSI = new ZExtInst(NLI, LI->getType(), "", LI); @@ -1914,9 +1919,10 @@ static void makeAllConstantUsesInstructions(Constant *C) { /// Analyze the specified global variable and optimize /// it if possible. If we make a change, return true. -static bool processInternalGlobal( - GlobalVariable *GV, const GlobalStatus &GS, TargetLibraryInfo *TLI, - function_ref<DominatorTree &(Function &)> LookupDomTree) { +static bool +processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, + function_ref<TargetLibraryInfo &(Function &)> GetTLI, + function_ref<DominatorTree &(Function &)> LookupDomTree) { auto &DL = GV->getParent()->getDataLayout(); // If this is a first class global and has only one accessing function and // this function is non-recursive, we replace the global with a local alloca @@ -1963,11 +1969,12 @@ static bool processInternalGlobal( bool Changed; if (isLeakCheckerRoot(GV)) { // Delete any constant stores to the global. 
- Changed = CleanupPointerRootUsers(GV, TLI); + Changed = CleanupPointerRootUsers(GV, GetTLI); } else { // Delete any stores we can find to the global. We may not be able to // make it completely dead though. - Changed = CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI); + Changed = + CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI); } // If the global is dead now, delete it. @@ -1989,7 +1996,7 @@ static bool processInternalGlobal( GV->setConstant(true); // Clean up any obviously simplifiable users now. - CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI); + CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI); // If the global is dead now, just nuke it. if (GV->use_empty()) { @@ -2019,7 +2026,7 @@ static bool processInternalGlobal( GV->setInitializer(SOVConstant); // Clean up any obviously simplifiable users now. - CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI); + CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, GetTLI); if (GV->use_empty()) { LLVM_DEBUG(dbgs() << " *** Substituting initializer allowed us to " @@ -2033,7 +2040,8 @@ static bool processInternalGlobal( // Try to optimize globals based on the knowledge that only one value // (besides its initializer) is ever stored to the global. - if (optimizeOnceStoredGlobal(GV, GS.StoredOnceValue, GS.Ordering, DL, TLI)) + if (optimizeOnceStoredGlobal(GV, GS.StoredOnceValue, GS.Ordering, DL, + GetTLI)) return true; // Otherwise, if the global was not a boolean, we can shrink it to be a @@ -2054,7 +2062,8 @@ static bool processInternalGlobal( /// Analyze the specified global variable and optimize it if possible. If we /// make a change, return true. 
static bool -processGlobal(GlobalValue &GV, TargetLibraryInfo *TLI, +processGlobal(GlobalValue &GV, + function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree) { if (GV.getName().startswith("llvm.")) return false; @@ -2086,7 +2095,7 @@ processGlobal(GlobalValue &GV, TargetLibraryInfo *TLI, if (GVar->isConstant() || !GVar->hasInitializer()) return Changed; - return processInternalGlobal(GVar, GS, TLI, LookupDomTree) || Changed; + return processInternalGlobal(GVar, GS, GetTLI, LookupDomTree) || Changed; } /// Walk all of the direct calls of the specified function, changing them to @@ -2234,7 +2243,8 @@ hasOnlyColdCalls(Function &F, } static bool -OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, +OptimizeFunctions(Module &M, + function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<BlockFrequencyInfo &(Function &)> GetBFI, function_ref<DominatorTree &(Function &)> LookupDomTree, @@ -2275,17 +2285,13 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, // So, remove unreachable blocks from the function, because a) there's // no point in analyzing them and b) GlobalOpt should otherwise grow // some more complicated logic to break these cycles. - // Removing unreachable blocks might invalidate the dominator so we - // recalculate it. 
if (!F->isDeclaration()) { - if (removeUnreachableBlocks(*F)) { - auto &DT = LookupDomTree(*F); - DT.recalculate(*F); - Changed = true; - } + auto &DT = LookupDomTree(*F); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + Changed |= removeUnreachableBlocks(*F, &DTU); } - Changed |= processGlobal(*F, TLI, LookupDomTree); + Changed |= processGlobal(*F, GetTLI, LookupDomTree); if (!F->hasLocalLinkage()) continue; @@ -2342,7 +2348,8 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, } static bool -OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI, +OptimizeGlobalVars(Module &M, + function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree, SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) { bool Changed = false; @@ -2357,7 +2364,10 @@ OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI, if (GV->hasInitializer()) if (auto *C = dyn_cast<Constant>(GV->getInitializer())) { auto &DL = M.getDataLayout(); - Constant *New = ConstantFoldConstant(C, DL, TLI); + // TLI is not used in the case of a Constant, so use default nullptr + // for that optional parameter, since we don't have a Function to + // provide GetTLI anyway. 
+ Constant *New = ConstantFoldConstant(C, DL, /*TLI*/ nullptr); if (New && New != C) GV->setInitializer(New); } @@ -2367,7 +2377,7 @@ OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI, continue; } - Changed |= processGlobal(*GV, TLI, LookupDomTree); + Changed |= processGlobal(*GV, GetTLI, LookupDomTree); } return Changed; } @@ -2581,8 +2591,8 @@ static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL, } static int compareNames(Constant *const *A, Constant *const *B) { - Value *AStripped = (*A)->stripPointerCastsNoFollowAliases(); - Value *BStripped = (*B)->stripPointerCastsNoFollowAliases(); + Value *AStripped = (*A)->stripPointerCasts(); + Value *BStripped = (*B)->stripPointerCasts(); return AStripped->getName().compare(BStripped->getName()); } @@ -2809,7 +2819,14 @@ OptimizeGlobalAliases(Module &M, return Changed; } -static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) { +static Function * +FindCXAAtExit(Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) { + // Hack to get a default TLI before we have actual Function. + auto FuncIter = M.begin(); + if (FuncIter == M.end()) + return nullptr; + auto *TLI = &GetTLI(*FuncIter); + LibFunc F = LibFunc_cxa_atexit; if (!TLI->has(F)) return nullptr; @@ -2818,6 +2835,9 @@ static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) { if (!Fn) return nullptr; + // Now get the actual TLI for Fn. + TLI = &GetTLI(*Fn); + // Make sure that the function has the correct prototype. 
if (!TLI->getLibFunc(*Fn, F) || F != LibFunc_cxa_atexit) return nullptr; @@ -2889,7 +2909,8 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { } static bool optimizeGlobalsInModule( - Module &M, const DataLayout &DL, TargetLibraryInfo *TLI, + Module &M, const DataLayout &DL, + function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<BlockFrequencyInfo &(Function &)> GetBFI, function_ref<DominatorTree &(Function &)> LookupDomTree) { @@ -2914,24 +2935,24 @@ static bool optimizeGlobalsInModule( NotDiscardableComdats.insert(C); // Delete functions that are trivially dead, ccc -> fastcc - LocalChange |= OptimizeFunctions(M, TLI, GetTTI, GetBFI, LookupDomTree, + LocalChange |= OptimizeFunctions(M, GetTLI, GetTTI, GetBFI, LookupDomTree, NotDiscardableComdats); // Optimize global_ctors list. LocalChange |= optimizeGlobalCtorsList(M, [&](Function *F) { - return EvaluateStaticConstructor(F, DL, TLI); + return EvaluateStaticConstructor(F, DL, &GetTLI(*F)); }); // Optimize non-address-taken globals. - LocalChange |= OptimizeGlobalVars(M, TLI, LookupDomTree, - NotDiscardableComdats); + LocalChange |= + OptimizeGlobalVars(M, GetTLI, LookupDomTree, NotDiscardableComdats); // Resolve aliases, when possible. LocalChange |= OptimizeGlobalAliases(M, NotDiscardableComdats); // Try to remove trivial global destructors if they are not removed // already. 
- Function *CXAAtExitFn = FindCXAAtExit(M, TLI); + Function *CXAAtExitFn = FindCXAAtExit(M, GetTLI); if (CXAAtExitFn) LocalChange |= OptimizeEmptyGlobalCXXDtors(CXAAtExitFn); @@ -2946,12 +2967,14 @@ static bool optimizeGlobalsInModule( PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) { auto &DL = M.getDataLayout(); - auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupDomTree = [&FAM](Function &F) -> DominatorTree &{ return FAM.getResult<DominatorTreeAnalysis>(F); }; + auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { + return FAM.getResult<TargetLibraryAnalysis>(F); + }; auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & { return FAM.getResult<TargetIRAnalysis>(F); }; @@ -2960,7 +2983,7 @@ PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) { return FAM.getResult<BlockFrequencyAnalysis>(F); }; - if (!optimizeGlobalsInModule(M, DL, &TLI, GetTTI, GetBFI, LookupDomTree)) + if (!optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, LookupDomTree)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } @@ -2979,10 +3002,12 @@ struct GlobalOptLegacyPass : public ModulePass { return false; auto &DL = M.getDataLayout(); - auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); auto LookupDomTree = [this](Function &F) -> DominatorTree & { return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); }; + auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { + return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + }; auto GetTTI = [this](Function &F) -> TargetTransformInfo & { return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); }; @@ -2991,7 +3016,8 @@ struct GlobalOptLegacyPass : public ModulePass { return this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); }; - return optimizeGlobalsInModule(M, DL, TLI, GetTTI, GetBFI, LookupDomTree); + 
return optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, + LookupDomTree); } void getAnalysisUsage(AnalysisUsage &AU) const override { diff --git a/lib/Transforms/IPO/HotColdSplitting.cpp b/lib/Transforms/IPO/HotColdSplitting.cpp index ab1a9a79cad6..cfdcc8db7f50 100644 --- a/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/lib/Transforms/IPO/HotColdSplitting.cpp @@ -85,12 +85,6 @@ static cl::opt<int> "multiple of TCC_Basic)")); namespace { - -/// A sequence of basic blocks. -/// -/// A 0-sized SmallVector is slightly cheaper to move than a std::vector. -using BlockSequence = SmallVector<BasicBlock *, 0>; - // Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify // this function unless you modify the MBB version as well. // @@ -169,31 +163,6 @@ static bool markFunctionCold(Function &F, bool UpdateEntryCount = false) { return Changed; } -class HotColdSplitting { -public: - HotColdSplitting(ProfileSummaryInfo *ProfSI, - function_ref<BlockFrequencyInfo *(Function &)> GBFI, - function_ref<TargetTransformInfo &(Function &)> GTTI, - std::function<OptimizationRemarkEmitter &(Function &)> *GORE, - function_ref<AssumptionCache *(Function &)> LAC) - : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {} - bool run(Module &M); - -private: - bool isFunctionCold(const Function &F) const; - bool shouldOutlineFrom(const Function &F) const; - bool outlineColdRegions(Function &F, bool HasProfileSummary); - Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT, - BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, - OptimizationRemarkEmitter &ORE, - AssumptionCache *AC, unsigned Count); - ProfileSummaryInfo *PSI; - function_ref<BlockFrequencyInfo *(Function &)> GetBFI; - function_ref<TargetTransformInfo &(Function &)> GetTTI; - std::function<OptimizationRemarkEmitter &(Function &)> *GetORE; - function_ref<AssumptionCache *(Function &)> LookupAC; -}; - class HotColdSplittingLegacyPass : public ModulePass { public: static 
char ID; @@ -321,13 +290,10 @@ static int getOutliningPenalty(ArrayRef<BasicBlock *> Region, return Penalty; } -Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region, - DominatorTree &DT, - BlockFrequencyInfo *BFI, - TargetTransformInfo &TTI, - OptimizationRemarkEmitter &ORE, - AssumptionCache *AC, - unsigned Count) { +Function *HotColdSplitting::extractColdRegion( + const BlockSequence &Region, const CodeExtractorAnalysisCache &CEAC, + DominatorTree &DT, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, + OptimizationRemarkEmitter &ORE, AssumptionCache *AC, unsigned Count) { assert(!Region.empty()); // TODO: Pass BFI and BPI to update profile information. @@ -349,7 +315,7 @@ Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region, return nullptr; Function *OrigF = Region[0]->getParent(); - if (Function *OutF = CE.extractCodeRegion()) { + if (Function *OutF = CE.extractCodeRegion(CEAC)) { User *U = *OutF->user_begin(); CallInst *CI = cast<CallInst>(U); CallSite CS(CI); @@ -607,9 +573,9 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { }); if (!DT) - DT = make_unique<DominatorTree>(F); + DT = std::make_unique<DominatorTree>(F); if (!PDT) - PDT = make_unique<PostDominatorTree>(F); + PDT = std::make_unique<PostDominatorTree>(F); auto Regions = OutliningRegion::create(*BB, *DT, *PDT); for (OutliningRegion &Region : Regions) { @@ -637,9 +603,14 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { } } + if (OutliningWorklist.empty()) + return Changed; + // Outline single-entry cold regions, splitting up larger regions as needed. unsigned OutlinedFunctionID = 1; - while (!OutliningWorklist.empty()) { + // Cache and recycle the CodeExtractor analysis to avoid O(n^2) compile-time. 
+ CodeExtractorAnalysisCache CEAC(F); + do { OutliningRegion Region = OutliningWorklist.pop_back_val(); assert(!Region.empty() && "Empty outlining region in worklist"); do { @@ -650,14 +621,14 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { BB->dump(); }); - Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC, - OutlinedFunctionID); + Function *Outlined = extractColdRegion(SubRegion, CEAC, *DT, BFI, TTI, + ORE, AC, OutlinedFunctionID); if (Outlined) { ++OutlinedFunctionID; Changed = true; } } while (!Region.empty()); - } + } while (!OutliningWorklist.empty()); return Changed; } diff --git a/lib/Transforms/IPO/IPO.cpp b/lib/Transforms/IPO/IPO.cpp index 34db75dd8b03..bddf75211599 100644 --- a/lib/Transforms/IPO/IPO.cpp +++ b/lib/Transforms/IPO/IPO.cpp @@ -114,6 +114,10 @@ void LLVMAddIPSCCPPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createIPSCCPPass()); } +void LLVMAddMergeFunctionsPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createMergeFunctionsPass()); +} + void LLVMAddInternalizePass(LLVMPassManagerRef PM, unsigned AllButMain) { auto PreserveMain = [=](const GlobalValue &GV) { return AllButMain && GV.getName() == "main"; @@ -121,6 +125,15 @@ void LLVMAddInternalizePass(LLVMPassManagerRef PM, unsigned AllButMain) { unwrap(PM)->add(createInternalizePass(PreserveMain)); } +void LLVMAddInternalizePassWithMustPreservePredicate( + LLVMPassManagerRef PM, + void *Context, + LLVMBool (*Pred)(LLVMValueRef, void *)) { + unwrap(PM)->add(createInternalizePass([=](const GlobalValue &GV) { + return Pred(wrap(&GV), Context) == 0 ? 
false : true; + })); +} + void LLVMAddStripDeadPrototypesPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createStripDeadPrototypesPass()); } diff --git a/lib/Transforms/IPO/InferFunctionAttrs.cpp b/lib/Transforms/IPO/InferFunctionAttrs.cpp index 7f5511e008e1..d1a68b28bd33 100644 --- a/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -18,24 +18,28 @@ using namespace llvm; #define DEBUG_TYPE "inferattrs" -static bool inferAllPrototypeAttributes(Module &M, - const TargetLibraryInfo &TLI) { +static bool inferAllPrototypeAttributes( + Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) { bool Changed = false; for (Function &F : M.functions()) // We only infer things using the prototype and the name; we don't need // definitions. if (F.isDeclaration() && !F.hasOptNone()) - Changed |= inferLibFuncAttributes(F, TLI); + Changed |= inferLibFuncAttributes(F, GetTLI(F)); return Changed; } PreservedAnalyses InferFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) { - auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { + return FAM.getResult<TargetLibraryAnalysis>(F); + }; - if (!inferAllPrototypeAttributes(M, TLI)) + if (!inferAllPrototypeAttributes(M, GetTLI)) // If we didn't infer anything, preserve all analyses. 
return PreservedAnalyses::all(); @@ -60,8 +64,10 @@ struct InferFunctionAttrsLegacyPass : public ModulePass { if (skipModule(M)) return false; - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - return inferAllPrototypeAttributes(M, TLI); + auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { + return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + }; + return inferAllPrototypeAttributes(M, GetTLI); } }; } diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index 945f8affae6e..4b72261131c1 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -239,7 +239,7 @@ static void mergeInlinedArrayAllocas( } if (Align1 > Align2) - AvailableAlloca->setAlignment(AI->getAlignment()); + AvailableAlloca->setAlignment(MaybeAlign(AI->getAlignment())); } AI->eraseFromParent(); @@ -527,7 +527,8 @@ static void setInlineRemark(CallSite &CS, StringRef message) { static bool inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, std::function<AssumptionCache &(Function &)> GetAssumptionCache, - ProfileSummaryInfo *PSI, TargetLibraryInfo &TLI, + ProfileSummaryInfo *PSI, + std::function<TargetLibraryInfo &(Function &)> GetTLI, bool InsertLifetime, function_ref<InlineCost(CallSite CS)> GetInlineCost, function_ref<AAResults &(Function &)> AARGetter, @@ -626,7 +627,8 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, Instruction *Instr = CS.getInstruction(); - bool IsTriviallyDead = isInstructionTriviallyDead(Instr, &TLI); + bool IsTriviallyDead = + isInstructionTriviallyDead(Instr, &GetTLI(*Caller)); int InlineHistoryID; if (!IsTriviallyDead) { @@ -757,13 +759,16 @@ bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) { CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); ACT = &getAnalysis<AssumptionCacheTracker>(); PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto GetTLI = [&](Function &F) -> 
TargetLibraryInfo & { + return getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + }; auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; - return inlineCallsImpl(SCC, CG, GetAssumptionCache, PSI, TLI, InsertLifetime, - [this](CallSite CS) { return getInlineCost(CS); }, - LegacyAARGetter(*this), ImportedFunctionsStats); + return inlineCallsImpl( + SCC, CG, GetAssumptionCache, PSI, GetTLI, InsertLifetime, + [this](CallSite CS) { return getInlineCost(CS); }, LegacyAARGetter(*this), + ImportedFunctionsStats); } /// Remove now-dead linkonce functions at the end of @@ -879,7 +884,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (!ImportedFunctionsStats && InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) { ImportedFunctionsStats = - llvm::make_unique<ImportedFunctionsInliningStatistics>(); + std::make_unique<ImportedFunctionsInliningStatistics>(); ImportedFunctionsStats->setModuleInfo(M); } diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index 91c7b5f5f135..add2ae053735 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -141,10 +141,12 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { if (NumLoops == 0) return Changed; --NumLoops; AssumptionCache *AC = nullptr; + Function &Func = *L->getHeader()->getParent(); if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>()) - AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent()); + AC = ACT->lookupAssumptionCache(Func); + CodeExtractorAnalysisCache CEAC(Func); CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); - if (Extractor.extractCodeRegion() != nullptr) { + if (Extractor.extractCodeRegion(CEAC) != nullptr) { Changed = true; // After extraction, the loop is replaced by a function call, so // we shouldn't try to run any more loop passes on it. 
diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp index f7371284f47e..2dec366d70e2 100644 --- a/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/lib/Transforms/IPO/LowerTypeTests.cpp @@ -230,6 +230,16 @@ void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits, Bytes[AllocByteOffset + B] |= AllocMask; } +bool lowertypetests::isJumpTableCanonical(Function *F) { + if (F->isDeclarationForLinker()) + return false; + auto *CI = mdconst::extract_or_null<ConstantInt>( + F->getParent()->getModuleFlag("CFI Canonical Jump Tables")); + if (!CI || CI->getZExtValue() != 0) + return true; + return F->hasFnAttribute("cfi-canonical-jump-table"); +} + namespace { struct ByteArrayInfo { @@ -251,9 +261,12 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> { GlobalObject *GO; size_t NTypes; - // For functions: true if this is a definition (either in the merged module or - // in one of the thinlto modules). - bool IsDefinition; + // For functions: true if the jump table is canonical. This essentially means + // whether the canonical address (i.e. the symbol table entry) of the function + // is provided by the local jump table. This is normally the same as whether + // the function is defined locally, but if canonical jump tables are disabled + // by the user then the jump table never provides a canonical definition. + bool IsJumpTableCanonical; // For functions: true if this function is either defined or used in a thinlto // module and its jumptable entry needs to be exported to thinlto backends. 
@@ -263,13 +276,13 @@ class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> { public: static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO, - bool IsDefinition, bool IsExported, + bool IsJumpTableCanonical, bool IsExported, ArrayRef<MDNode *> Types) { auto *GTM = static_cast<GlobalTypeMember *>(Alloc.Allocate( totalSizeToAlloc<MDNode *>(Types.size()), alignof(GlobalTypeMember))); GTM->GO = GO; GTM->NTypes = Types.size(); - GTM->IsDefinition = IsDefinition; + GTM->IsJumpTableCanonical = IsJumpTableCanonical; GTM->IsExported = IsExported; std::uninitialized_copy(Types.begin(), Types.end(), GTM->getTrailingObjects<MDNode *>()); @@ -280,8 +293,8 @@ public: return GO; } - bool isDefinition() const { - return IsDefinition; + bool isJumpTableCanonical() const { + return IsJumpTableCanonical; } bool isExported() const { @@ -320,6 +333,49 @@ private: size_t NTargets; }; +struct ScopedSaveAliaseesAndUsed { + Module &M; + SmallPtrSet<GlobalValue *, 16> Used, CompilerUsed; + std::vector<std::pair<GlobalIndirectSymbol *, Function *>> FunctionAliases; + + ScopedSaveAliaseesAndUsed(Module &M) : M(M) { + // The users of this class want to replace all function references except + // for aliases and llvm.used/llvm.compiler.used with references to a jump + // table. We avoid replacing aliases in order to avoid introducing a double + // indirection (or an alias pointing to a declaration in ThinLTO mode), and + // we avoid replacing llvm.used/llvm.compiler.used because these global + // variables describe properties of the global, not the jump table (besides, + // offseted references to the jump table in llvm.used are invalid). + // Unfortunately, LLVM doesn't have a "RAUW except for these (possibly + // indirect) users", so what we do is save the list of globals referenced by + // llvm.used/llvm.compiler.used and aliases, erase the used lists, let RAUW + // replace the aliasees and then set them back to their original values at + // the end. 
+ if (GlobalVariable *GV = collectUsedGlobalVariables(M, Used, false)) + GV->eraseFromParent(); + if (GlobalVariable *GV = collectUsedGlobalVariables(M, CompilerUsed, true)) + GV->eraseFromParent(); + + for (auto &GIS : concat<GlobalIndirectSymbol>(M.aliases(), M.ifuncs())) { + // FIXME: This should look past all aliases not just interposable ones, + // see discussion on D65118. + if (auto *F = + dyn_cast<Function>(GIS.getIndirectSymbol()->stripPointerCasts())) + FunctionAliases.push_back({&GIS, F}); + } + } + + ~ScopedSaveAliaseesAndUsed() { + appendToUsed(M, std::vector<GlobalValue *>(Used.begin(), Used.end())); + appendToCompilerUsed(M, std::vector<GlobalValue *>(CompilerUsed.begin(), + CompilerUsed.end())); + + for (auto P : FunctionAliases) + P.first->setIndirectSymbol( + ConstantExpr::getBitCast(P.second, P.first->getType())); + } +}; + class LowerTypeTestsModule { Module &M; @@ -387,7 +443,8 @@ class LowerTypeTestsModule { uint8_t *exportTypeId(StringRef TypeId, const TypeIdLowering &TIL); TypeIdLowering importTypeId(StringRef TypeId); void importTypeTest(CallInst *CI); - void importFunction(Function *F, bool isDefinition); + void importFunction(Function *F, bool isJumpTableCanonical, + std::vector<GlobalAlias *> &AliasesToErase); BitSetInfo buildBitSet(Metadata *TypeId, @@ -421,7 +478,8 @@ class LowerTypeTestsModule { ArrayRef<GlobalTypeMember *> Globals, ArrayRef<ICallBranchFunnel *> ICallBranchFunnels); - void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT, bool IsDefinition); + void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT, + bool IsJumpTableCanonical); void moveInitializerToModuleConstructor(GlobalVariable *GV); void findGlobalVariableUsersOf(Constant *C, SmallSetVector<GlobalVariable *, 8> &Out); @@ -433,7 +491,7 @@ class LowerTypeTestsModule { /// the block. 'This's use list is expected to have at least one element. /// Unlike replaceAllUsesWith this function skips blockaddr and direct call /// uses. 
- void replaceCfiUses(Function *Old, Value *New, bool IsDefinition); + void replaceCfiUses(Function *Old, Value *New, bool IsJumpTableCanonical); /// replaceDirectCalls - Go through the uses list for this definition and /// replace each use, which is a direct function call. @@ -759,43 +817,50 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables( // Build a new global with the combined contents of the referenced globals. // This global is a struct whose even-indexed elements contain the original // contents of the referenced globals and whose odd-indexed elements contain - // any padding required to align the next element to the next power of 2. + // any padding required to align the next element to the next power of 2 plus + // any additional padding required to meet its alignment requirements. std::vector<Constant *> GlobalInits; const DataLayout &DL = M.getDataLayout(); + DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout; + Align MaxAlign; + uint64_t CurOffset = 0; + uint64_t DesiredPadding = 0; for (GlobalTypeMember *G : Globals) { - GlobalVariable *GV = cast<GlobalVariable>(G->getGlobal()); + auto *GV = cast<GlobalVariable>(G->getGlobal()); + MaybeAlign Alignment(GV->getAlignment()); + if (!Alignment) + Alignment = Align(DL.getABITypeAlignment(GV->getValueType())); + MaxAlign = std::max(MaxAlign, *Alignment); + uint64_t GVOffset = alignTo(CurOffset + DesiredPadding, *Alignment); + GlobalLayout[G] = GVOffset; + if (GVOffset != 0) { + uint64_t Padding = GVOffset - CurOffset; + GlobalInits.push_back( + ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding))); + } + GlobalInits.push_back(GV->getInitializer()); uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType()); + CurOffset = GVOffset + InitSize; - // Compute the amount of padding required. - uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize; + // Compute the amount of padding that we'd like for the next element. 
+ DesiredPadding = NextPowerOf2(InitSize - 1) - InitSize; // Experiments of different caps with Chromium on both x64 and ARM64 // have shown that the 32-byte cap generates the smallest binary on // both platforms while different caps yield similar performance. // (see https://lists.llvm.org/pipermail/llvm-dev/2018-July/124694.html) - if (Padding > 32) - Padding = alignTo(InitSize, 32) - InitSize; - - GlobalInits.push_back( - ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding))); + if (DesiredPadding > 32) + DesiredPadding = alignTo(InitSize, 32) - InitSize; } - if (!GlobalInits.empty()) - GlobalInits.pop_back(); + Constant *NewInit = ConstantStruct::getAnon(M.getContext(), GlobalInits); auto *CombinedGlobal = new GlobalVariable(M, NewInit->getType(), /*isConstant=*/true, GlobalValue::PrivateLinkage, NewInit); + CombinedGlobal->setAlignment(MaxAlign); StructType *NewTy = cast<StructType>(NewInit->getType()); - const StructLayout *CombinedGlobalLayout = DL.getStructLayout(NewTy); - - // Compute the offsets of the original globals within the new global. - DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout; - for (unsigned I = 0; I != Globals.size(); ++I) - // Multiply by 2 to account for padding elements. - GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2); - lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout); // Build aliases pointing to offsets into the combined global for each @@ -975,14 +1040,16 @@ void LowerTypeTestsModule::importTypeTest(CallInst *CI) { } // ThinLTO backend: the function F has a jump table entry; update this module -// accordingly. isDefinition describes the type of the jump table entry. -void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { +// accordingly. isJumpTableCanonical describes the type of the jump table entry. 
+void LowerTypeTestsModule::importFunction( + Function *F, bool isJumpTableCanonical, + std::vector<GlobalAlias *> &AliasesToErase) { assert(F->getType()->getAddressSpace() == 0); GlobalValue::VisibilityTypes Visibility = F->getVisibility(); std::string Name = F->getName(); - if (F->isDeclarationForLinker() && isDefinition) { + if (F->isDeclarationForLinker() && isJumpTableCanonical) { // Non-dso_local functions may be overriden at run time, // don't short curcuit them if (F->isDSOLocal()) { @@ -997,12 +1064,13 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { } Function *FDecl; - if (F->isDeclarationForLinker() && !isDefinition) { - // Declaration of an external function. + if (!isJumpTableCanonical) { + // Either a declaration of an external function or a reference to a locally + // defined jump table. FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, F->getAddressSpace(), Name + ".cfi_jt", &M); FDecl->setVisibility(GlobalValue::HiddenVisibility); - } else if (isDefinition) { + } else { F->setName(Name + ".cfi"); F->setLinkage(GlobalValue::ExternalLinkage); FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, @@ -1011,8 +1079,8 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { Visibility = GlobalValue::HiddenVisibility; // Delete aliases pointing to this function, they'll be re-created in the - // merged output - SmallVector<GlobalAlias*, 4> ToErase; + // merged output. Don't do it yet though because ScopedSaveAliaseesAndUsed + // will want to reset the aliasees first. 
for (auto &U : F->uses()) { if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) { Function *AliasDecl = Function::Create( @@ -1020,24 +1088,15 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { F->getAddressSpace(), "", &M); AliasDecl->takeName(A); A->replaceAllUsesWith(AliasDecl); - ToErase.push_back(A); + AliasesToErase.push_back(A); } } - for (auto *A : ToErase) - A->eraseFromParent(); - } else { - // Function definition without type metadata, where some other translation - // unit contained a declaration with type metadata. This normally happens - // during mixed CFI + non-CFI compilation. We do nothing with the function - // so that it is treated the same way as a function defined outside of the - // LTO unit. - return; } - if (F->isWeakForLinker()) - replaceWeakDeclarationWithJumpTablePtr(F, FDecl, isDefinition); + if (F->hasExternalWeakLinkage()) + replaceWeakDeclarationWithJumpTablePtr(F, FDecl, isJumpTableCanonical); else - replaceCfiUses(F, FDecl, isDefinition); + replaceCfiUses(F, FDecl, isJumpTableCanonical); // Set visibility late because it's used in replaceCfiUses() to determine // whether uses need to to be replaced. @@ -1225,7 +1284,7 @@ void LowerTypeTestsModule::findGlobalVariableUsersOf( // Replace all uses of F with (F ? JT : 0). void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( - Function *F, Constant *JT, bool IsDefinition) { + Function *F, Constant *JT, bool IsJumpTableCanonical) { // The target expression can not appear in a constant initializer on most // (all?) targets. Switch to a runtime initializer. 
SmallSetVector<GlobalVariable *, 8> GlobalVarUsers; @@ -1239,7 +1298,7 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( Function::Create(cast<FunctionType>(F->getValueType()), GlobalValue::ExternalWeakLinkage, F->getAddressSpace(), "", &M); - replaceCfiUses(F, PlaceholderFn, IsDefinition); + replaceCfiUses(F, PlaceholderFn, IsJumpTableCanonical); Constant *Target = ConstantExpr::getSelect( ConstantExpr::getICmp(CmpInst::ICMP_NE, F, @@ -1276,8 +1335,9 @@ selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions, unsigned ArmCount = 0, ThumbCount = 0; for (const auto GTM : Functions) { - if (!GTM->isDefinition()) { + if (!GTM->isJumpTableCanonical()) { // PLT stubs are always ARM. + // FIXME: This is the wrong heuristic for non-canonical jump tables. ++ArmCount; continue; } @@ -1303,7 +1363,7 @@ void LowerTypeTestsModule::createJumpTable( cast<Function>(Functions[I]->getGlobal())); // Align the whole table by entry size. - F->setAlignment(getJumpTableEntrySize()); + F->setAlignment(Align(getJumpTableEntrySize())); // Skip prologue. // Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3. // Luckily, this function does not get any prologue even without the @@ -1438,47 +1498,53 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout); - // Build aliases pointing to offsets into the jump table, and replace - // references to the original functions with references to the aliases. 
- for (unsigned I = 0; I != Functions.size(); ++I) { - Function *F = cast<Function>(Functions[I]->getGlobal()); - bool IsDefinition = Functions[I]->isDefinition(); - - Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast( - ConstantExpr::getInBoundsGetElementPtr( - JumpTableType, JumpTable, - ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), - ConstantInt::get(IntPtrTy, I)}), - F->getType()); - if (Functions[I]->isExported()) { - if (IsDefinition) { - ExportSummary->cfiFunctionDefs().insert(F->getName()); + { + ScopedSaveAliaseesAndUsed S(M); + + // Build aliases pointing to offsets into the jump table, and replace + // references to the original functions with references to the aliases. + for (unsigned I = 0; I != Functions.size(); ++I) { + Function *F = cast<Function>(Functions[I]->getGlobal()); + bool IsJumpTableCanonical = Functions[I]->isJumpTableCanonical(); + + Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast( + ConstantExpr::getInBoundsGetElementPtr( + JumpTableType, JumpTable, + ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), + ConstantInt::get(IntPtrTy, I)}), + F->getType()); + if (Functions[I]->isExported()) { + if (IsJumpTableCanonical) { + ExportSummary->cfiFunctionDefs().insert(F->getName()); + } else { + GlobalAlias *JtAlias = GlobalAlias::create( + F->getValueType(), 0, GlobalValue::ExternalLinkage, + F->getName() + ".cfi_jt", CombinedGlobalElemPtr, &M); + JtAlias->setVisibility(GlobalValue::HiddenVisibility); + ExportSummary->cfiFunctionDecls().insert(F->getName()); + } + } + if (!IsJumpTableCanonical) { + if (F->hasExternalWeakLinkage()) + replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr, + IsJumpTableCanonical); + else + replaceCfiUses(F, CombinedGlobalElemPtr, IsJumpTableCanonical); } else { - GlobalAlias *JtAlias = GlobalAlias::create( - F->getValueType(), 0, GlobalValue::ExternalLinkage, - F->getName() + ".cfi_jt", CombinedGlobalElemPtr, &M); - JtAlias->setVisibility(GlobalValue::HiddenVisibility); - 
ExportSummary->cfiFunctionDecls().insert(F->getName()); + assert(F->getType()->getAddressSpace() == 0); + + GlobalAlias *FAlias = + GlobalAlias::create(F->getValueType(), 0, F->getLinkage(), "", + CombinedGlobalElemPtr, &M); + FAlias->setVisibility(F->getVisibility()); + FAlias->takeName(F); + if (FAlias->hasName()) + F->setName(FAlias->getName() + ".cfi"); + replaceCfiUses(F, FAlias, IsJumpTableCanonical); + if (!F->hasLocalLinkage()) + F->setVisibility(GlobalVariable::HiddenVisibility); } } - if (!IsDefinition) { - if (F->isWeakForLinker()) - replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr, IsDefinition); - else - replaceCfiUses(F, CombinedGlobalElemPtr, IsDefinition); - } else { - assert(F->getType()->getAddressSpace() == 0); - - GlobalAlias *FAlias = GlobalAlias::create( - F->getValueType(), 0, F->getLinkage(), "", CombinedGlobalElemPtr, &M); - FAlias->setVisibility(F->getVisibility()); - FAlias->takeName(F); - if (FAlias->hasName()) - F->setName(FAlias->getName() + ".cfi"); - replaceCfiUses(F, FAlias, IsDefinition); - if (!F->hasLocalLinkage()) - F->setVisibility(GlobalVariable::HiddenVisibility); - } } createJumpTable(JumpTableFn, Functions); @@ -1623,7 +1689,7 @@ bool LowerTypeTestsModule::runForTesting(Module &M) { ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary + ": "); std::error_code EC; - raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); + raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text); ExitOnErr(errorCodeToError(EC)); yaml::Output Out(OS); @@ -1643,7 +1709,8 @@ static bool isDirectCall(Use& U) { return false; } -void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsDefinition) { +void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, + bool IsJumpTableCanonical) { SmallSetVector<Constant *, 4> Constants; auto UI = Old->use_begin(), E = Old->use_end(); for (; UI != E;) { @@ -1655,7 +1722,7 @@ void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, 
bool IsDefi continue; // Skip direct calls to externally defined or non-dso_local functions - if (isDirectCall(U) && (Old->isDSOLocal() || !IsDefinition)) + if (isDirectCall(U) && (Old->isDSOLocal() || !IsJumpTableCanonical)) continue; // Must handle Constants specially, we cannot call replaceUsesOfWith on a @@ -1678,16 +1745,7 @@ void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsDefi } void LowerTypeTestsModule::replaceDirectCalls(Value *Old, Value *New) { - auto UI = Old->use_begin(), E = Old->use_end(); - for (; UI != E;) { - Use &U = *UI; - ++UI; - - if (!isDirectCall(U)) - continue; - - U.set(New); - } + Old->replaceUsesWithIf(New, [](Use &U) { return isDirectCall(U); }); } bool LowerTypeTestsModule::lower() { @@ -1734,10 +1792,16 @@ bool LowerTypeTestsModule::lower() { Decls.push_back(&F); } - for (auto F : Defs) - importFunction(F, /*isDefinition*/ true); - for (auto F : Decls) - importFunction(F, /*isDefinition*/ false); + std::vector<GlobalAlias *> AliasesToErase; + { + ScopedSaveAliaseesAndUsed S(M); + for (auto F : Defs) + importFunction(F, /*isJumpTableCanonical*/ true, AliasesToErase); + for (auto F : Decls) + importFunction(F, /*isJumpTableCanonical*/ false, AliasesToErase); + } + for (GlobalAlias *GA : AliasesToErase) + GA->eraseFromParent(); return true; } @@ -1823,6 +1887,17 @@ bool LowerTypeTestsModule::lower() { CfiFunctionLinkage Linkage = P.second.Linkage; MDNode *FuncMD = P.second.FuncMD; Function *F = M.getFunction(FunctionName); + if (F && F->hasLocalLinkage()) { + // Locally defined function that happens to have the same name as a + // function defined in a ThinLTO module. Rename it to move it out of + // the way of the external reference that we're about to create. + // Note that setName will find a unique name for the function, so even + // if there is an existing function with the suffix there won't be a + // name collision. 
+ F->setName(F->getName() + ".1"); + F = nullptr; + } + if (!F) F = Function::Create( FunctionType::get(Type::getVoidTy(M.getContext()), false), @@ -1871,24 +1946,26 @@ bool LowerTypeTestsModule::lower() { Types.clear(); GO.getMetadata(LLVMContext::MD_type, Types); - bool IsDefinition = !GO.isDeclarationForLinker(); + bool IsJumpTableCanonical = false; bool IsExported = false; if (Function *F = dyn_cast<Function>(&GO)) { + IsJumpTableCanonical = isJumpTableCanonical(F); if (ExportedFunctions.count(F->getName())) { - IsDefinition |= ExportedFunctions[F->getName()].Linkage == CFL_Definition; + IsJumpTableCanonical |= + ExportedFunctions[F->getName()].Linkage == CFL_Definition; IsExported = true; // TODO: The logic here checks only that the function is address taken, // not that the address takers are live. This can be updated to check // their liveness and emit fewer jumptable entries once monolithic LTO // builds also emit summaries. } else if (!F->hasAddressTaken()) { - if (!CrossDsoCfi || !IsDefinition || F->hasLocalLinkage()) + if (!CrossDsoCfi || !IsJumpTableCanonical || F->hasLocalLinkage()) continue; } } - auto *GTM = - GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types); + auto *GTM = GlobalTypeMember::create(Alloc, &GO, IsJumpTableCanonical, + IsExported, Types); GlobalTypeMembers[&GO] = GTM; for (MDNode *Type : Types) { verifyTypeMDNode(&GO, Type); diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 3a08069dcd4a..8b9abaddc84c 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -769,7 +769,7 @@ void MergeFunctions::writeAlias(Function *F, Function *G) { PtrType->getElementType(), PtrType->getAddressSpace(), G->getLinkage(), "", BitcastF, G->getParent()); - F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); + F->setAlignment(MaybeAlign(std::max(F->getAlignment(), G->getAlignment()))); GA->takeName(G); 
GA->setVisibility(G->getVisibility()); GA->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); @@ -816,7 +816,7 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { removeUsers(F); F->replaceAllUsesWith(NewF); - unsigned MaxAlignment = std::max(G->getAlignment(), NewF->getAlignment()); + MaybeAlign MaxAlignment(std::max(G->getAlignment(), NewF->getAlignment())); writeThunkOrAlias(F, G); writeThunkOrAlias(F, NewF); diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index 733782e8764d..e193074884af 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -409,7 +409,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, return std::unique_ptr<FunctionOutliningMultiRegionInfo>(); std::unique_ptr<FunctionOutliningMultiRegionInfo> OutliningInfo = - llvm::make_unique<FunctionOutliningMultiRegionInfo>(); + std::make_unique<FunctionOutliningMultiRegionInfo>(); auto IsSingleEntry = [](SmallVectorImpl<BasicBlock *> &BlockList) { BasicBlock *Dom = BlockList.front(); @@ -589,7 +589,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { }; std::unique_ptr<FunctionOutliningInfo> OutliningInfo = - llvm::make_unique<FunctionOutliningInfo>(); + std::make_unique<FunctionOutliningInfo>(); BasicBlock *CurrEntry = EntryBlock; bool CandidateFound = false; @@ -966,7 +966,7 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE, function_ref<AssumptionCache *(Function &)> LookupAC) : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { - ClonedOI = llvm::make_unique<FunctionOutliningInfo>(); + ClonedOI = std::make_unique<FunctionOutliningInfo>(); // Clone the function, so that we can hack away on it. 
ValueToValueMapTy VMap; @@ -991,7 +991,7 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( OptimizationRemarkEmitter &ORE, function_ref<AssumptionCache *(Function &)> LookupAC) : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { - ClonedOMRI = llvm::make_unique<FunctionOutliningMultiRegionInfo>(); + ClonedOMRI = std::make_unique<FunctionOutliningMultiRegionInfo>(); // Clone the function, so that we can hack away on it. ValueToValueMapTy VMap; @@ -1122,6 +1122,9 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() { BranchProbabilityInfo BPI(*ClonedFunc, LI); ClonedFuncBFI.reset(new BlockFrequencyInfo(*ClonedFunc, BPI, LI)); + // Cache and recycle the CodeExtractor analysis to avoid O(n^2) compile-time. + CodeExtractorAnalysisCache CEAC(*ClonedFunc); + SetVector<Value *> Inputs, Outputs, Sinks; for (FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegionInfo : ClonedOMRI->ORI) { @@ -1148,7 +1151,7 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() { if (Outputs.size() > 0 && !ForceLiveExit) continue; - Function *OutlinedFunc = CE.extractCodeRegion(); + Function *OutlinedFunc = CE.extractCodeRegion(CEAC); if (OutlinedFunc) { CallSite OCS = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc); @@ -1210,11 +1213,12 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { } // Extract the body of the if. 
+ CodeExtractorAnalysisCache CEAC(*ClonedFunc); Function *OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc), /* AllowVarargs */ true) - .extractCodeRegion(); + .extractCodeRegion(CEAC); if (OutlinedFunc) { BasicBlock *OutliningCallBB = @@ -1264,7 +1268,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { if (PSI->isFunctionEntryCold(F)) return {false, nullptr}; - if (empty(F->users())) + if (F->users().empty()) return {false, nullptr}; OptimizationRemarkEmitter ORE(F); @@ -1370,7 +1374,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { return false; } - assert(empty(Cloner.OrigFunc->users()) && + assert(Cloner.OrigFunc->users().empty() && "F's users should all be replaced!"); std::vector<User *> Users(Cloner.ClonedFunc->user_begin(), diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index 3ea77f08fd3c..5314a8219b1e 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -654,6 +654,7 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createGlobalsAAWrapperPass()); MPM.add(createFloat2IntPass()); + MPM.add(createLowerConstantIntrinsicsPass()); addExtensionsToPM(EP_VectorizerStart, MPM); diff --git a/lib/Transforms/IPO/SCCP.cpp b/lib/Transforms/IPO/SCCP.cpp index 7be3608bd2ec..307690729b14 100644 --- a/lib/Transforms/IPO/SCCP.cpp +++ b/lib/Transforms/IPO/SCCP.cpp @@ -9,16 +9,18 @@ using namespace llvm; PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { const DataLayout &DL = M.getDataLayout(); - auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto GetTLI = [&FAM](Function &F) -> const TargetLibraryInfo & { + return FAM.getResult<TargetLibraryAnalysis>(F); + }; auto getAnalysis = [&FAM](Function &F) -> AnalysisResultsForFn { 
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); return { - make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)), + std::make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)), &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F)}; }; - if (!runIPSCCP(M, DL, &TLI, getAnalysis)) + if (!runIPSCCP(M, DL, GetTLI, getAnalysis)) return PreservedAnalyses::all(); PreservedAnalyses PA; @@ -47,14 +49,14 @@ public: if (skipModule(M)) return false; const DataLayout &DL = M.getDataLayout(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - + auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { + return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + }; auto getAnalysis = [this](Function &F) -> AnalysisResultsForFn { DominatorTree &DT = this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); return { - make_unique<PredicateInfo>( + std::make_unique<PredicateInfo>( F, DT, this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache( F)), @@ -62,7 +64,7 @@ public: nullptr}; // manager, so set them to nullptr. 
}; - return runIPSCCP(M, DL, TLI, getAnalysis); + return runIPSCCP(M, DL, GetTLI, getAnalysis); } void getAnalysisUsage(AnalysisUsage &AU) const override { diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index 877d20e72ffc..6184681db8a2 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -72,6 +72,7 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/MisExpect.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -79,6 +80,7 @@ #include <limits> #include <map> #include <memory> +#include <queue> #include <string> #include <system_error> #include <utility> @@ -128,6 +130,12 @@ static cl::opt<bool> ProfileSampleAccurate( "callsite and function as having 0 samples. Otherwise, treat " "un-sampled callsites and functions conservatively as unknown. ")); +static cl::opt<bool> ProfileAccurateForSymsInList( + "profile-accurate-for-symsinlist", cl::Hidden, cl::ZeroOrMore, + cl::init(true), + cl::desc("For symbols in profile symbol list, regard their profiles to " + "be accurate. It may be overriden by profile-sample-accurate. ")); + namespace { using BlockWeightMap = DenseMap<const BasicBlock *, uint64_t>; @@ -137,9 +145,11 @@ using EdgeWeightMap = DenseMap<Edge, uint64_t>; using BlockEdgeMap = DenseMap<const BasicBlock *, SmallVector<const BasicBlock *, 8>>; +class SampleProfileLoader; + class SampleCoverageTracker { public: - SampleCoverageTracker() = default; + SampleCoverageTracker(SampleProfileLoader &SPL) : SPLoader(SPL){}; bool markSamplesUsed(const FunctionSamples *FS, uint32_t LineOffset, uint32_t Discriminator, uint64_t Samples); @@ -185,6 +195,76 @@ private: /// keyed by FunctionSamples pointers, but these stats are cleared after /// every function, so we just need to keep a single counter. 
uint64_t TotalUsedSamples = 0; + + SampleProfileLoader &SPLoader; +}; + +class GUIDToFuncNameMapper { +public: + GUIDToFuncNameMapper(Module &M, SampleProfileReader &Reader, + DenseMap<uint64_t, StringRef> &GUIDToFuncNameMap) + : CurrentReader(Reader), CurrentModule(M), + CurrentGUIDToFuncNameMap(GUIDToFuncNameMap) { + if (CurrentReader.getFormat() != SPF_Compact_Binary) + return; + + for (const auto &F : CurrentModule) { + StringRef OrigName = F.getName(); + CurrentGUIDToFuncNameMap.insert( + {Function::getGUID(OrigName), OrigName}); + + // Local to global var promotion used by optimization like thinlto + // will rename the var and add suffix like ".llvm.xxx" to the + // original local name. In sample profile, the suffixes of function + // names are all stripped. Since it is possible that the mapper is + // built in post-thin-link phase and var promotion has been done, + // we need to add the substring of function name without the suffix + // into the GUIDToFuncNameMap. + StringRef CanonName = FunctionSamples::getCanonicalFnName(F); + if (CanonName != OrigName) + CurrentGUIDToFuncNameMap.insert( + {Function::getGUID(CanonName), CanonName}); + } + + // Update GUIDToFuncNameMap for each function including inlinees. + SetGUIDToFuncNameMapForAll(&CurrentGUIDToFuncNameMap); + } + + ~GUIDToFuncNameMapper() { + if (CurrentReader.getFormat() != SPF_Compact_Binary) + return; + + CurrentGUIDToFuncNameMap.clear(); + + // Reset GUIDToFuncNameMap for of each function as they're no + // longer valid at this point. 
+ SetGUIDToFuncNameMapForAll(nullptr); + } + +private: + void SetGUIDToFuncNameMapForAll(DenseMap<uint64_t, StringRef> *Map) { + std::queue<FunctionSamples *> FSToUpdate; + for (auto &IFS : CurrentReader.getProfiles()) { + FSToUpdate.push(&IFS.second); + } + + while (!FSToUpdate.empty()) { + FunctionSamples *FS = FSToUpdate.front(); + FSToUpdate.pop(); + FS->GUIDToFuncNameMap = Map; + for (const auto &ICS : FS->getCallsiteSamples()) { + const FunctionSamplesMap &FSMap = ICS.second; + for (auto &IFS : FSMap) { + FunctionSamples &FS = const_cast<FunctionSamples &>(IFS.second); + FSToUpdate.push(&FS); + } + } + } + } + + SampleProfileReader &CurrentReader; + Module &CurrentModule; + DenseMap<uint64_t, StringRef> &CurrentGUIDToFuncNameMap; }; /// Sample profile pass. @@ -199,8 +279,9 @@ public: std::function<AssumptionCache &(Function &)> GetAssumptionCache, std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo) : GetAC(std::move(GetAssumptionCache)), - GetTTI(std::move(GetTargetTransformInfo)), Filename(Name), - RemappingFilename(RemapName), IsThinLTOPreLink(IsThinLTOPreLink) {} + GetTTI(std::move(GetTargetTransformInfo)), CoverageTracker(*this), + Filename(Name), RemappingFilename(RemapName), + IsThinLTOPreLink(IsThinLTOPreLink) {} bool doInitialization(Module &M); bool runOnModule(Module &M, ModuleAnalysisManager *AM, @@ -209,6 +290,8 @@ public: void dump() { Reader->dump(); } protected: + friend class SampleCoverageTracker; + bool runOnFunction(Function &F, ModuleAnalysisManager *AM); unsigned getFunctionLoc(Function &F); bool emitAnnotations(Function &F); @@ -237,6 +320,8 @@ protected: bool propagateThroughEdges(Function &F, bool UpdateBlockCount); void computeDominanceAndLoopInfo(Function &F); void clearFunctionData(); + bool callsiteIsHot(const FunctionSamples *CallsiteFS, + ProfileSummaryInfo *PSI); /// Map basic blocks to their computed weights. /// @@ -310,6 +395,10 @@ protected: /// Profile Summary Info computed from sample profile. 
ProfileSummaryInfo *PSI = nullptr; + /// Profle Symbol list tells whether a function name appears in the binary + /// used to generate the current profile. + std::unique_ptr<ProfileSymbolList> PSL; + /// Total number of samples collected in this profile. /// /// This is the sum of all the samples collected in all the functions executed @@ -326,6 +415,21 @@ protected: uint64_t entryCount; }; DenseMap<Function *, NotInlinedProfileInfo> notInlinedCallInfo; + + // GUIDToFuncNameMap saves the mapping from GUID to the symbol name, for + // all the function symbols defined or declared in current module. + DenseMap<uint64_t, StringRef> GUIDToFuncNameMap; + + // All the Names used in FunctionSamples including outline function + // names, inline instance names and call target names. + StringSet<> NamesInProfile; + + // For symbol in profile symbol list, whether to regard their profiles + // to be accurate. It is mainly decided by existance of profile symbol + // list and -profile-accurate-for-symsinlist flag, but it can be + // overriden by -profile-sample-accurate or profile-sample-accurate + // attribute. + bool ProfAccForSymsInList; }; class SampleProfileLoaderLegacyPass : public ModulePass { @@ -381,14 +485,23 @@ private: /// To decide whether an inlined callsite is hot, we compare the callsite /// sample count with the hot cutoff computed by ProfileSummaryInfo, it is /// regarded as hot if the count is above the cutoff value. -static bool callsiteIsHot(const FunctionSamples *CallsiteFS, - ProfileSummaryInfo *PSI) { +/// +/// When ProfileAccurateForSymsInList is enabled and profile symbol list +/// is present, functions in the profile symbol list but without profile will +/// be regarded as cold and much less inlining will happen in CGSCC inlining +/// pass, so we tend to lower the hot criteria here to allow more early +/// inlining to happen for warm callsites and it is helpful for performance. 
+bool SampleProfileLoader::callsiteIsHot(const FunctionSamples *CallsiteFS, + ProfileSummaryInfo *PSI) { if (!CallsiteFS) return false; // The callsite was not inlined in the original binary. assert(PSI && "PSI is expected to be non null"); uint64_t CallsiteTotalSamples = CallsiteFS->getTotalSamples(); - return PSI->isHotCount(CallsiteTotalSamples); + if (ProfAccForSymsInList) + return !PSI->isColdCount(CallsiteTotalSamples); + else + return PSI->isHotCount(CallsiteTotalSamples); } /// Mark as used the sample record for the given function samples at @@ -425,7 +538,7 @@ SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS, for (const auto &I : FS->getCallsiteSamples()) for (const auto &J : I.second) { const FunctionSamples *CalleeSamples = &J.second; - if (callsiteIsHot(CalleeSamples, PSI)) + if (SPLoader.callsiteIsHot(CalleeSamples, PSI)) Count += countUsedRecords(CalleeSamples, PSI); } @@ -444,7 +557,7 @@ SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS, for (const auto &I : FS->getCallsiteSamples()) for (const auto &J : I.second) { const FunctionSamples *CalleeSamples = &J.second; - if (callsiteIsHot(CalleeSamples, PSI)) + if (SPLoader.callsiteIsHot(CalleeSamples, PSI)) Count += countBodyRecords(CalleeSamples, PSI); } @@ -465,7 +578,7 @@ SampleCoverageTracker::countBodySamples(const FunctionSamples *FS, for (const auto &I : FS->getCallsiteSamples()) for (const auto &J : I.second) { const FunctionSamples *CalleeSamples = &J.second; - if (callsiteIsHot(CalleeSamples, PSI)) + if (SPLoader.callsiteIsHot(CalleeSamples, PSI)) Total += countBodySamples(CalleeSamples, PSI); } @@ -788,6 +901,14 @@ bool SampleProfileLoader::inlineHotFunctions( Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs) { DenseSet<Instruction *> PromotedInsns; + // ProfAccForSymsInList is used in callsiteIsHot. The assertion makes sure + // Profile symbol list is ignored when profile-sample-accurate is on. 
+ assert((!ProfAccForSymsInList || + (!ProfileSampleAccurate && + !F.hasFnAttribute("profile-sample-accurate"))) && + "ProfAccForSymsInList should be false when profile-sample-accurate " + "is enabled"); + DenseMap<Instruction *, const FunctionSamples *> localNotInlinedCallSites; bool Changed = false; while (true) { @@ -1219,17 +1340,12 @@ void SampleProfileLoader::buildEdges(Function &F) { } /// Returns the sorted CallTargetMap \p M by count in descending order. -static SmallVector<InstrProfValueData, 2> SortCallTargets( - const SampleRecord::CallTargetMap &M) { +static SmallVector<InstrProfValueData, 2> GetSortedValueDataFromCallTargets( + const SampleRecord::CallTargetMap & M) { SmallVector<InstrProfValueData, 2> R; - for (auto I = M.begin(); I != M.end(); ++I) - R.push_back({FunctionSamples::getGUID(I->getKey()), I->getValue()}); - llvm::sort(R, [](const InstrProfValueData &L, const InstrProfValueData &R) { - if (L.Count == R.Count) - return L.Value > R.Value; - else - return L.Count > R.Count; - }); + for (const auto &I : SampleRecord::SortCallTargets(M)) { + R.emplace_back(InstrProfValueData{FunctionSamples::getGUID(I.first), I.second}); + } return R; } @@ -1324,7 +1440,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (!T || T.get().empty()) continue; SmallVector<InstrProfValueData, 2> SortedCallTargets = - SortCallTargets(T.get()); + GetSortedValueDataFromCallTargets(T.get()); uint64_t Sum; findIndirectCallFunctionSamples(I, Sum); annotateValueSite(*I.getParent()->getParent()->getParent(), I, @@ -1374,6 +1490,8 @@ void SampleProfileLoader::propagateWeights(Function &F) { } } + misexpect::verifyMisExpect(TI, Weights, TI->getContext()); + uint64_t TempWeight; // Only set weights if there is at least one non-zero weight. // In any other case, let the analyzer set weights. 
@@ -1557,30 +1675,29 @@ INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile", bool SampleProfileLoader::doInitialization(Module &M) { auto &Ctx = M.getContext(); - auto ReaderOrErr = SampleProfileReader::create(Filename, Ctx); + + std::unique_ptr<SampleProfileReaderItaniumRemapper> RemapReader; + auto ReaderOrErr = + SampleProfileReader::create(Filename, Ctx, RemappingFilename); if (std::error_code EC = ReaderOrErr.getError()) { std::string Msg = "Could not open profile: " + EC.message(); Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); return false; } Reader = std::move(ReaderOrErr.get()); - Reader->collectFuncsToUse(M); + Reader->collectFuncsFrom(M); ProfileIsValid = (Reader->read() == sampleprof_error::success); - - if (!RemappingFilename.empty()) { - // Apply profile remappings to the loaded profile data if requested. - // For now, we only support remapping symbols encoded using the Itanium - // C++ ABI's name mangling scheme. - ReaderOrErr = SampleProfileReaderItaniumRemapper::create( - RemappingFilename, Ctx, std::move(Reader)); - if (std::error_code EC = ReaderOrErr.getError()) { - std::string Msg = "Could not open profile remapping file: " + EC.message(); - Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); - return false; - } - Reader = std::move(ReaderOrErr.get()); - ProfileIsValid = (Reader->read() == sampleprof_error::success); + PSL = Reader->getProfileSymbolList(); + + // While profile-sample-accurate is on, ignore symbol list. 
+ ProfAccForSymsInList = + ProfileAccurateForSymsInList && PSL && !ProfileSampleAccurate; + if (ProfAccForSymsInList) { + NamesInProfile.clear(); + if (auto NameTable = Reader->getNameTable()) + NamesInProfile.insert(NameTable->begin(), NameTable->end()); } + return true; } @@ -1594,7 +1711,7 @@ ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) { bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, ProfileSummaryInfo *_PSI) { - FunctionSamples::GUIDToFuncNameMapper Mapper(M); + GUIDToFuncNameMapper Mapper(M, *Reader, GUIDToFuncNameMap); if (!ProfileIsValid) return false; @@ -1651,19 +1768,48 @@ bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { } bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) { - + DILocation2SampleMap.clear(); // By default the entry count is initialized to -1, which will be treated // conservatively by getEntryCount as the same as unknown (None). This is // to avoid newly added code to be treated as cold. If we have samples // this will be overwritten in emitAnnotations. - // If ProfileSampleAccurate is true or F has profile-sample-accurate - // attribute, initialize the entry count to 0 so callsites or functions - // unsampled will be treated as cold. - uint64_t initialEntryCount = - (ProfileSampleAccurate || F.hasFnAttribute("profile-sample-accurate")) - ? 0 - : -1; + uint64_t initialEntryCount = -1; + + ProfAccForSymsInList = ProfileAccurateForSymsInList && PSL; + if (ProfileSampleAccurate || F.hasFnAttribute("profile-sample-accurate")) { + // initialize all the function entry counts to 0. It means all the + // functions without profile will be regarded as cold. + initialEntryCount = 0; + // profile-sample-accurate is a user assertion which has a higher precedence + // than symbol list. When profile-sample-accurate is on, ignore symbol list. + ProfAccForSymsInList = false; + } + + // PSL -- profile symbol list include all the symbols in sampled binary. 
+ // If ProfileAccurateForSymsInList is enabled, PSL is used to treat + // old functions without samples being cold, without having to worry + // about new and hot functions being mistakenly treated as cold. + if (ProfAccForSymsInList) { + // Initialize the entry count to 0 for functions in the list. + if (PSL->contains(F.getName())) + initialEntryCount = 0; + + // Function in the symbol list but without sample will be regarded as + // cold. To minimize the potential negative performance impact it could + // have, we want to be a little conservative here saying if a function + // shows up in the profile, no matter as outline function, inline instance + // or call targets, treat the function as not being cold. This will handle + // the cases such as most callsites of a function are inlined in sampled + // binary but not inlined in current build (because of source code drift, + // imprecise debug information, or the callsites are all cold individually + // but not cold accumulatively...), so the outline function showing up as + // cold in sampled binary will actually not be cold after current build. 
+ StringRef CanonName = FunctionSamples::getCanonicalFnName(F); + if (NamesInProfile.count(CanonName)) + initialEntryCount = -1; + } + F.setEntryCount(ProfileCount(initialEntryCount, Function::PCT_Real)); std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; if (AM) { @@ -1672,7 +1818,7 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) .getManager(); ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); } else { - OwnedORE = make_unique<OptimizationRemarkEmitter>(&F); + OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F); ORE = OwnedORE.get(); } Samples = Reader->getSamplesFor(F); diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 24c476376c14..690b5e8bf49e 100644 --- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -24,6 +24,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/IPO/FunctionImport.h" +#include "llvm/Transforms/IPO/LowerTypeTests.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; @@ -218,10 +219,18 @@ void splitAndWriteThinLTOBitcode( promoteTypeIds(M, ModuleId); - // Returns whether a global has attached type metadata. Such globals may - // participate in CFI or whole-program devirtualization, so they need to - // appear in the merged module instead of the thin LTO module. + // Returns whether a global or its associated global has attached type + // metadata. The former may participate in CFI or whole-program + // devirtualization, so they need to appear in the merged module instead of + // the thin LTO module. Similarly, globals that are associated with globals + // with type metadata need to appear in the merged module because they will + // reference the global's section directly. 
auto HasTypeMetadata = [](const GlobalObject *GO) { + if (MDNode *MD = GO->getMetadata(LLVMContext::MD_associated)) + if (auto *AssocVM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(0))) + if (auto *AssocGO = dyn_cast<GlobalObject>(AssocVM->getValue())) + if (AssocGO->hasMetadata(LLVMContext::MD_type)) + return true; return GO->hasMetadata(LLVMContext::MD_type); }; @@ -315,9 +324,9 @@ void splitAndWriteThinLTOBitcode( SmallVector<Metadata *, 4> Elts; Elts.push_back(MDString::get(Ctx, F.getName())); CfiFunctionLinkage Linkage; - if (!F.isDeclarationForLinker()) + if (lowertypetests::isJumpTableCanonical(&F)) Linkage = CFL_Definition; - else if (F.isWeakForLinker()) + else if (F.hasExternalWeakLinkage()) Linkage = CFL_WeakDeclaration; else Linkage = CFL_Declaration; @@ -457,7 +466,7 @@ void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, // splitAndWriteThinLTOBitcode). Just always build it once via the // buildModuleSummaryIndex when Module(s) are ready. ProfileSummaryInfo PSI(M); - NewIndex = llvm::make_unique<ModuleSummaryIndex>( + NewIndex = std::make_unique<ModuleSummaryIndex>( buildModuleSummaryIndex(M, nullptr, &PSI)); Index = NewIndex.get(); } diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp index 6b6dd6194e17..f0cf5581ba8a 100644 --- a/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -24,12 +24,14 @@ // returns 0, or a single vtable's function returns 1, replace each virtual // call with a comparison of the vptr against that vtable's address. // -// This pass is intended to be used during the regular and thin LTO pipelines. +// This pass is intended to be used during the regular and thin LTO pipelines: +// // During regular LTO, the pass determines the best optimization for each // virtual call and applies the resolutions directly to virtual calls that are // eligible for virtual call optimization (i.e. 
calls that use either of the -// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During -// ThinLTO, the pass operates in two phases: +// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). +// +// During hybrid Regular/ThinLTO, the pass operates in two phases: // - Export phase: this is run during the thin link over a single merged module // that contains all vtables with !type metadata that participate in the link. // The pass computes a resolution for each virtual call and stores it in the @@ -38,6 +40,14 @@ // modules. The pass applies the resolutions previously computed during the // import phase to each eligible virtual call. // +// During ThinLTO, the pass operates in two phases: +// - Export phase: this is run during the thin link over the index which +// contains a summary of all vtables with !type metadata that participate in +// the link. It computes a resolution for each virtual call and stores it in +// the type identifier summary. Only single implementation devirtualization +// is supported. +// - Import phase: (same as with hybrid case above). +// //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/WholeProgramDevirt.h" @@ -117,6 +127,11 @@ static cl::opt<unsigned> cl::desc("Maximum number of call targets per " "call site to enable branch funnels")); +static cl::opt<bool> + PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden, + cl::init(false), cl::ZeroOrMore, + cl::desc("Print index-based devirtualization messages")); + // Find the minimum offset that we may store a value of size Size bits at. If // IsAfter is set, look for an offset before the object, otherwise look for an // offset after the object. 
@@ -265,6 +280,25 @@ template <> struct DenseMapInfo<VTableSlot> { } }; +template <> struct DenseMapInfo<VTableSlotSummary> { + static VTableSlotSummary getEmptyKey() { + return {DenseMapInfo<StringRef>::getEmptyKey(), + DenseMapInfo<uint64_t>::getEmptyKey()}; + } + static VTableSlotSummary getTombstoneKey() { + return {DenseMapInfo<StringRef>::getTombstoneKey(), + DenseMapInfo<uint64_t>::getTombstoneKey()}; + } + static unsigned getHashValue(const VTableSlotSummary &I) { + return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^ + DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); + } + static bool isEqual(const VTableSlotSummary &LHS, + const VTableSlotSummary &RHS) { + return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; + } +}; + } // end namespace llvm namespace { @@ -342,19 +376,21 @@ struct CallSiteInfo { /// pass the vector is non-empty, we will need to add a use of llvm.type.test /// to each of the function summaries in the vector. std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; + std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers; bool isExported() const { return SummaryHasTypeTestAssumeUsers || !SummaryTypeCheckedLoadUsers.empty(); } - void markSummaryHasTypeTestAssumeUsers() { - SummaryHasTypeTestAssumeUsers = true; + void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) { + SummaryTypeCheckedLoadUsers.push_back(FS); AllCallSitesDevirted = false; } - void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) { - SummaryTypeCheckedLoadUsers.push_back(FS); + void addSummaryTypeTestAssumeUser(FunctionSummary *FS) { + SummaryTypeTestAssumeUsers.push_back(FS); + SummaryHasTypeTestAssumeUsers = true; AllCallSitesDevirted = false; } @@ -456,7 +492,6 @@ struct DevirtModule { void buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); - Constant *getPointerAtOffset(Constant *I, uint64_t Offset); bool tryFindVirtualCallTargets(std::vector<VirtualCallTarget> 
&TargetsForSlot, const std::set<TypeMemberInfo> &TypeMemberInfos, @@ -464,7 +499,8 @@ struct DevirtModule { void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, bool &IsExported); - bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, + bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary, + MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res); @@ -542,6 +578,38 @@ struct DevirtModule { function_ref<DominatorTree &(Function &)> LookupDomTree); }; +struct DevirtIndex { + ModuleSummaryIndex &ExportSummary; + // The set in which to record GUIDs exported from their module by + // devirtualization, used by client to ensure they are not internalized. + std::set<GlobalValue::GUID> &ExportedGUIDs; + // A map in which to record the information necessary to locate the WPD + // resolution for local targets in case they are exported by cross module + // importing. + std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap; + + MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots; + + DevirtIndex( + ModuleSummaryIndex &ExportSummary, + std::set<GlobalValue::GUID> &ExportedGUIDs, + std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) + : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs), + LocalWPDTargetsMap(LocalWPDTargetsMap) {} + + bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot, + const TypeIdCompatibleVtableInfo TIdInfo, + uint64_t ByteOffset); + + bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, + VTableSlotSummary &SlotSummary, + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, + std::set<ValueInfo> &DevirtTargets); + + void run(); +}; + struct WholeProgramDevirt : public ModulePass { static char ID; @@ -572,7 +640,7 @@ struct WholeProgramDevirt : public ModulePass { // an optimization remark emitter on the fly, when we need it. 
std::unique_ptr<OptimizationRemarkEmitter> ORE; auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { - ORE = make_unique<OptimizationRemarkEmitter>(F); + ORE = std::make_unique<OptimizationRemarkEmitter>(F); return *ORE; }; @@ -632,6 +700,41 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return PreservedAnalyses::none(); } +namespace llvm { +void runWholeProgramDevirtOnIndex( + ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs, + std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { + DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run(); +} + +void updateIndexWPDForExports( + ModuleSummaryIndex &Summary, + function_ref<bool(StringRef, GlobalValue::GUID)> isExported, + std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { + for (auto &T : LocalWPDTargetsMap) { + auto &VI = T.first; + // This was enforced earlier during trySingleImplDevirt. + assert(VI.getSummaryList().size() == 1 && + "Devirt of local target has more than one copy"); + auto &S = VI.getSummaryList()[0]; + if (!isExported(S->modulePath(), VI.getGUID())) + continue; + + // It's been exported by a cross module import. 
+ for (auto &SlotSummary : T.second) { + auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID); + assert(TIdSum); + auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset); + assert(WPDRes != TIdSum->WPDRes.end()); + WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal( + WPDRes->second.SingleImplName, + Summary.getModuleHash(S->modulePath())); + } + } +} + +} // end namespace llvm + bool DevirtModule::runForTesting( Module &M, function_ref<AAResults &(Function &)> AARGetter, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, @@ -662,7 +765,7 @@ bool DevirtModule::runForTesting( ExitOnError ExitOnErr( "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); std::error_code EC; - raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); + raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_Text); ExitOnErr(errorCodeToError(EC)); yaml::Output Out(OS); @@ -706,38 +809,6 @@ void DevirtModule::buildTypeIdentifierMap( } } -Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) { - if (I->getType()->isPointerTy()) { - if (Offset == 0) - return I; - return nullptr; - } - - const DataLayout &DL = M.getDataLayout(); - - if (auto *C = dyn_cast<ConstantStruct>(I)) { - const StructLayout *SL = DL.getStructLayout(C->getType()); - if (Offset >= SL->getSizeInBytes()) - return nullptr; - - unsigned Op = SL->getElementContainingOffset(Offset); - return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), - Offset - SL->getElementOffset(Op)); - } - if (auto *C = dyn_cast<ConstantArray>(I)) { - ArrayType *VTableTy = C->getType(); - uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType()); - - unsigned Op = Offset / ElemSize; - if (Op >= C->getNumOperands()) - return nullptr; - - return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), - Offset % ElemSize); - } - return nullptr; -} - bool DevirtModule::tryFindVirtualCallTargets( std::vector<VirtualCallTarget> &TargetsForSlot, const std::set<TypeMemberInfo> 
&TypeMemberInfos, uint64_t ByteOffset) { @@ -746,7 +817,7 @@ bool DevirtModule::tryFindVirtualCallTargets( return false; Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), - TM.Offset + ByteOffset); + TM.Offset + ByteOffset, M); if (!Ptr) return false; @@ -766,6 +837,34 @@ bool DevirtModule::tryFindVirtualCallTargets( return !TargetsForSlot.empty(); } +bool DevirtIndex::tryFindVirtualCallTargets( + std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo, + uint64_t ByteOffset) { + for (const TypeIdOffsetVtableInfo P : TIdInfo) { + // VTable initializer should have only one summary, or all copies must be + // linkonce/weak ODR. + assert(P.VTableVI.getSummaryList().size() == 1 || + llvm::all_of( + P.VTableVI.getSummaryList(), + [&](const std::unique_ptr<GlobalValueSummary> &Summary) { + return GlobalValue::isLinkOnceODRLinkage(Summary->linkage()) || + GlobalValue::isWeakODRLinkage(Summary->linkage()); + })); + const auto *VS = cast<GlobalVarSummary>(P.VTableVI.getSummaryList()[0].get()); + if (!P.VTableVI.getSummaryList()[0]->isLive()) + continue; + for (auto VTP : VS->vTableFuncs()) { + if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset) + continue; + + TargetsForSlot.push_back(VTP.FuncVI); + } + } + + // Give up if we couldn't find any targets. + return !TargetsForSlot.empty(); +} + void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, bool &IsExported) { auto Apply = [&](CallSiteInfo &CSInfo) { @@ -788,9 +887,38 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, Apply(P.second); } +static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) { + // We can't add calls if we haven't seen a definition + if (Callee.getSummaryList().empty()) + return false; + + // Insert calls into the summary index so that the devirtualized targets + // are eligible for import. + // FIXME: Annotate type tests with hotness. 
For now, mark these as hot + // to better ensure we have the opportunity to inline them. + bool IsExported = false; + auto &S = Callee.getSummaryList()[0]; + CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0); + auto AddCalls = [&](CallSiteInfo &CSInfo) { + for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) { + FS->addCall({Callee, CI}); + IsExported |= S->modulePath() != FS->modulePath(); + } + for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) { + FS->addCall({Callee, CI}); + IsExported |= S->modulePath() != FS->modulePath(); + } + }; + AddCalls(SlotInfo.CSInfo); + for (auto &P : SlotInfo.ConstCSInfo) + AddCalls(P.second); + return IsExported; +} + bool DevirtModule::trySingleImplDevirt( - MutableArrayRef<VirtualCallTarget> TargetsForSlot, - VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) { + ModuleSummaryIndex *ExportSummary, + MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res) { // See if the program contains a single implementation of this virtual // function. Function *TheFn = TargetsForSlot[0].Fn; @@ -830,6 +958,10 @@ bool DevirtModule::trySingleImplDevirt( TheFn->setVisibility(GlobalValue::HiddenVisibility); TheFn->setName(NewName); } + if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID())) + // Any needed promotion of 'TheFn' has already been done during + // LTO unit split, so we can ignore return value of AddCalls. + AddCalls(SlotInfo, TheFnVI); Res->TheKind = WholeProgramDevirtResolution::SingleImpl; Res->SingleImplName = TheFn->getName(); @@ -837,6 +969,63 @@ bool DevirtModule::trySingleImplDevirt( return true; } +bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, + VTableSlotSummary &SlotSummary, + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, + std::set<ValueInfo> &DevirtTargets) { + // See if the program contains a single implementation of this virtual + // function. 
+ auto TheFn = TargetsForSlot[0]; + for (auto &&Target : TargetsForSlot) + if (TheFn != Target) + return false; + + // Don't devirtualize if we don't have target definition. + auto Size = TheFn.getSummaryList().size(); + if (!Size) + return false; + + // If the summary list contains multiple summaries where at least one is + // a local, give up, as we won't know which (possibly promoted) name to use. + for (auto &S : TheFn.getSummaryList()) + if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1) + return false; + + // Collect functions devirtualized at least for one call site for stats. + if (PrintSummaryDevirt) + DevirtTargets.insert(TheFn); + + auto &S = TheFn.getSummaryList()[0]; + bool IsExported = AddCalls(SlotInfo, TheFn); + if (IsExported) + ExportedGUIDs.insert(TheFn.getGUID()); + + // Record in summary for use in devirtualization during the ThinLTO import + // step. + Res->TheKind = WholeProgramDevirtResolution::SingleImpl; + if (GlobalValue::isLocalLinkage(S->linkage())) { + if (IsExported) + // If target is a local function and we are exporting it by + // devirtualizing a call in another module, we need to record the + // promoted name. + Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal( + TheFn.name(), ExportSummary.getModuleHash(S->modulePath())); + else { + LocalWPDTargetsMap[TheFn].push_back(SlotSummary); + Res->SingleImplName = TheFn.name(); + } + } else + Res->SingleImplName = TheFn.name(); + + // Name will be empty if this thin link driven off of serialized combined + // index (e.g. llvm-lto). However, WPD is not supported/invoked for the + // legacy LTO API anyway. 
+ assert(!Res->SingleImplName.empty()); + + return true; +} + void DevirtModule::tryICallBranchFunnel( MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res, VTableSlot Slot) { @@ -1302,10 +1491,13 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { if (B.Before.Bytes.empty() && B.After.Bytes.empty()) return; - // Align each byte array to pointer width. - unsigned PointerSize = M.getDataLayout().getPointerSize(); - B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize)); - B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize)); + // Align the before byte array to the global's minimum alignment so that we + // don't break any alignment requirements on the global. + MaybeAlign Alignment(B.GV->getAlignment()); + if (!Alignment) + Alignment = + Align(M.getDataLayout().getABITypeAlignment(B.GV->getValueType())); + B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment)); // Before was stored in reverse order; flip it now. for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I) @@ -1322,6 +1514,7 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { GlobalVariable::PrivateLinkage, NewInit, "", B.GV); NewGV->setSection(B.GV->getSection()); NewGV->setComdat(B.GV->getComdat()); + NewGV->setAlignment(MaybeAlign(B.GV->getAlignment())); // Copy the original vtable's metadata to the anonymous global, adjusting // offsets as required. 
@@ -1483,8 +1676,11 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { } void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { + auto *TypeId = dyn_cast<MDString>(Slot.TypeID); + if (!TypeId) + return; const TypeIdSummary *TidSummary = - ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString()); + ImportSummary->getTypeIdSummary(TypeId->getString()); if (!TidSummary) return; auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset); @@ -1493,6 +1689,7 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { const WholeProgramDevirtResolution &Res = ResI->second; if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { + assert(!Res.SingleImplName.empty()); // The type of the function in the declaration is irrelevant because every // call site will cast it to the correct type. Constant *SingleImpl = @@ -1627,8 +1824,7 @@ bool DevirtModule::run() { // FIXME: Only add live functions. for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { for (Metadata *MD : MetadataByGUID[VF.GUID]) { - CallSlots[{MD, VF.Offset}] - .CSInfo.markSummaryHasTypeTestAssumeUsers(); + CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS); } } for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { @@ -1641,7 +1837,7 @@ bool DevirtModule::run() { for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { CallSlots[{MD, VC.VFunc.Offset}] .ConstCSInfo[VC.Args] - .markSummaryHasTypeTestAssumeUsers(); + .addSummaryTypeTestAssumeUser(FS); } } for (const FunctionSummary::ConstVCall &VC : @@ -1673,7 +1869,7 @@ bool DevirtModule::run() { cast<MDString>(S.first.TypeID)->getString()) .WPDRes[S.first.ByteOffset]; - if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) { + if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) { DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first); @@ -1710,7 +1906,7 @@ bool 
DevirtModule::run() { using namespace ore; OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) << "devirtualized " - << NV("FunctionName", F->getName())); + << NV("FunctionName", DT.first)); } } @@ -1722,5 +1918,86 @@ bool DevirtModule::run() { for (VTableBits &B : Bits) rebuildGlobal(B); + // We have lowered or deleted the type checked load intrinsics, so we no + // longer have enough information to reason about the liveness of virtual + // function pointers in GlobalDCE. + for (GlobalVariable &GV : M.globals()) + GV.eraseMetadata(LLVMContext::MD_vcall_visibility); + return true; } + +void DevirtIndex::run() { + if (ExportSummary.typeIdCompatibleVtableMap().empty()) + return; + + DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID; + for (auto &P : ExportSummary.typeIdCompatibleVtableMap()) { + NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first); + } + + // Collect information from summary about which calls to try to devirtualize. + for (auto &P : ExportSummary) { + for (auto &S : P.second.SummaryList) { + auto *FS = dyn_cast<FunctionSummary>(S.get()); + if (!FS) + continue; + // FIXME: Only add live functions. 
+ for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { + for (StringRef Name : NameByGUID[VF.GUID]) { + CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS); + } + } + for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { + for (StringRef Name : NameByGUID[VF.GUID]) { + CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS); + } + } + for (const FunctionSummary::ConstVCall &VC : + FS->type_test_assume_const_vcalls()) { + for (StringRef Name : NameByGUID[VC.VFunc.GUID]) { + CallSlots[{Name, VC.VFunc.Offset}] + .ConstCSInfo[VC.Args] + .addSummaryTypeTestAssumeUser(FS); + } + } + for (const FunctionSummary::ConstVCall &VC : + FS->type_checked_load_const_vcalls()) { + for (StringRef Name : NameByGUID[VC.VFunc.GUID]) { + CallSlots[{Name, VC.VFunc.Offset}] + .ConstCSInfo[VC.Args] + .addSummaryTypeCheckedLoadUser(FS); + } + } + } + } + + std::set<ValueInfo> DevirtTargets; + // For each (type, offset) pair: + for (auto &S : CallSlots) { + // Search each of the members of the type identifier for the virtual + // function implementation at offset S.first.ByteOffset, and add to + // TargetsForSlot. + std::vector<ValueInfo> TargetsForSlot; + auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID); + assert(TidSummary); + if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary, + S.first.ByteOffset)) { + WholeProgramDevirtResolution *Res = + &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID) + .WPDRes[S.first.ByteOffset]; + + if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res, + DevirtTargets)) + continue; + } + } + + // Optionally have the thin link print message for each devirtualized + // function. 
+ if (PrintSummaryDevirt) + for (const auto &DT : DevirtTargets) + errs() << "Devirtualized call to " << DT << "\n"; + + return; +} diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index ba15b023f2a3..8bc34825f8a7 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1097,6 +1097,107 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { return nullptr; } +Instruction * +InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( + BinaryOperator &I) { + assert((I.getOpcode() == Instruction::Add || + I.getOpcode() == Instruction::Or || + I.getOpcode() == Instruction::Sub) && + "Expecting add/or/sub instruction"); + + // We have a subtraction/addition between a (potentially truncated) *logical* + // right-shift of X and a "select". + Value *X, *Select; + Instruction *LowBitsToSkip, *Extract; + if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd( + m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)), + m_Instruction(Extract))), + m_Value(Select)))) + return nullptr; + + // `add`/`or` is commutative; but for `sub`, "select" *must* be on RHS. + if (I.getOpcode() == Instruction::Sub && I.getOperand(1) != Select) + return nullptr; + + Type *XTy = X->getType(); + bool HadTrunc = I.getType() != XTy; + + // If there was a truncation of extracted value, then we'll need to produce + // one extra instruction, so we need to ensure one instruction will go away. + if (HadTrunc && !match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + return nullptr; + + // Extraction should extract high NBits bits, with shift amount calculated as: + // low bits to skip = shift bitwidth - high bits to extract + // The shift amount itself may be extended, and we need to look past zero-ext + // when matching NBits, that will matter for matching later. 
+ Constant *C; + Value *NBits; + if (!match( + LowBitsToSkip, + m_ZExtOrSelf(m_Sub(m_Constant(C), m_ZExtOrSelf(m_Value(NBits))))) || + !match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(C->getType()->getScalarSizeInBits(), + X->getType()->getScalarSizeInBits())))) + return nullptr; + + // Sign-extending value can be zero-extended if we `sub`tract it, + // or sign-extended otherwise. + auto SkipExtInMagic = [&I](Value *&V) { + if (I.getOpcode() == Instruction::Sub) + match(V, m_ZExtOrSelf(m_Value(V))); + else + match(V, m_SExtOrSelf(m_Value(V))); + }; + + // Now, finally validate the sign-extending magic. + // `select` itself may be appropriately extended, look past that. + SkipExtInMagic(Select); + + ICmpInst::Predicate Pred; + const APInt *Thr; + Value *SignExtendingValue, *Zero; + bool ShouldSignext; + // It must be a select between two values we will later establish to be a + // sign-extending value and a zero constant. The condition guarding the + // sign-extension must be based on a sign bit of the same X we had in `lshr`. + if (!match(Select, m_Select(m_ICmp(Pred, m_Specific(X), m_APInt(Thr)), + m_Value(SignExtendingValue), m_Value(Zero))) || + !isSignBitCheck(Pred, *Thr, ShouldSignext)) + return nullptr; + + // icmp-select pair is commutative. + if (!ShouldSignext) + std::swap(SignExtendingValue, Zero); + + // If we should not perform sign-extension then we must add/or/subtract zero. + if (!match(Zero, m_Zero())) + return nullptr; + // Otherwise, it should be some constant, left-shifted by the same NBits we + // had in `lshr`. Said left-shift can also be appropriately extended. + // Again, we must look past zero-ext when looking for NBits. + SkipExtInMagic(SignExtendingValue); + Constant *SignExtendingValueBaseConstant; + if (!match(SignExtendingValue, + m_Shl(m_Constant(SignExtendingValueBaseConstant), + m_ZExtOrSelf(m_Specific(NBits))))) + return nullptr; + // If we `sub`, then the constant should be one, else it should be all-ones. 
+ if (I.getOpcode() == Instruction::Sub + ? !match(SignExtendingValueBaseConstant, m_One()) + : !match(SignExtendingValueBaseConstant, m_AllOnes())) + return nullptr; + + auto *NewAShr = BinaryOperator::CreateAShr(X, LowBitsToSkip, + Extract->getName() + ".sext"); + NewAShr->copyIRFlags(Extract); // Preserve `exact`-ness. + if (!HadTrunc) + return NewAShr; + + Builder.Insert(NewAShr); + return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType()); +} + Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), @@ -1302,12 +1403,32 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Instruction *V = canonicalizeLowbitMask(I, Builder)) return V; + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I)) return SatAdd; return Changed ? &I : nullptr; } +/// Eliminate an op from a linear interpolation (lerp) pattern. +static Instruction *factorizeLerp(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y, *Z; + if (!match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_Value(Y), + m_OneUse(m_FSub(m_FPOne(), + m_Value(Z))))), + m_OneUse(m_c_FMul(m_Value(X), m_Deferred(Z)))))) + return nullptr; + + // (Y * (1.0 - Z)) + (X * Z) --> Y + Z * (X - Y) [8 commuted variants] + Value *XY = Builder.CreateFSubFMF(X, Y, &I); + Value *MulZ = Builder.CreateFMulFMF(Z, XY, &I); + return BinaryOperator::CreateFAddFMF(Y, MulZ, &I); +} + /// Factor a common operand out of fadd/fsub of fmul/fdiv. 
static Instruction *factorizeFAddFSub(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { @@ -1315,6 +1436,10 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I, I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub"); assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "FP factorization requires FMF"); + + if (Instruction *Lerp = factorizeLerp(I, Builder)) + return Lerp; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *X, *Y, *Z; bool IsFMul; @@ -1362,17 +1487,32 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I)) return FoldedFAdd; - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - Value *X; // (-X) + Y --> Y - X - if (match(LHS, m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFSubFMF(RHS, X, &I); - // Y + (-X) --> Y - X - if (match(RHS, m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFSubFMF(LHS, X, &I); + Value *X, *Y; + if (match(&I, m_c_FAdd(m_FNeg(m_Value(X)), m_Value(Y)))) + return BinaryOperator::CreateFSubFMF(Y, X, &I); + + // Similar to above, but look through fmul/fdiv for the negated term. + // (-X * Y) + Z --> Z - (X * Y) [4 commuted variants] + Value *Z; + if (match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))), + m_Value(Z)))) { + Value *XY = Builder.CreateFMulFMF(X, Y, &I); + return BinaryOperator::CreateFSubFMF(Z, XY, &I); + } + // (-X / Y) + Z --> Z - (X / Y) [2 commuted variants] + // (X / -Y) + Z --> Z - (X / Y) [2 commuted variants] + if (match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y))), + m_Value(Z))) || + match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))), + m_Value(Z)))) { + Value *XY = Builder.CreateFDivFMF(X, Y, &I); + return BinaryOperator::CreateFSubFMF(Z, XY, &I); + } // Check for (fadd double (sitofp x), y), see if we can merge this into an // integer add followed by a promotion. 
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) { Value *LHSIntVal = LHSConv->getOperand(0); Type *FPType = LHSConv->getType(); @@ -1631,37 +1771,50 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { const APInt *Op0C; if (match(Op0, m_APInt(Op0C))) { - unsigned BitWidth = I.getType()->getScalarSizeInBits(); - // -(X >>u 31) -> (X >>s 31) - // -(X >>s 31) -> (X >>u 31) if (Op0C->isNullValue()) { + Value *Op1Wide; + match(Op1, m_TruncOrSelf(m_Value(Op1Wide))); + bool HadTrunc = Op1Wide != Op1; + bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse(); + unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits(); + Value *X; const APInt *ShAmt; - if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) && + // -(X >>u 31) -> (X >>s 31) + if (NoTruncOrTruncIsOneUse && + match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) && *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); - return BinaryOperator::CreateAShr(X, ShAmtOp); + Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); + Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp); + NewShift->copyIRFlags(Op1Wide); + if (!HadTrunc) + return NewShift; + Builder.Insert(NewShift); + return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); } - if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) && + // -(X >>s 31) -> (X >>u 31) + if (NoTruncOrTruncIsOneUse && + match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) && *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); - return BinaryOperator::CreateLShr(X, ShAmtOp); + Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); + Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp); + NewShift->copyIRFlags(Op1Wide); + if (!HadTrunc) + return NewShift; + Builder.Insert(NewShift); + return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); } - if (Op1->hasOneUse()) { + if (!HadTrunc && Op1->hasOneUse()) 
{ Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor; if (SPF == SPF_ABS || SPF == SPF_NABS) { // This is a negate of an ABS/NABS pattern. Just swap the operands // of the select. - SelectInst *SI = cast<SelectInst>(Op1); - Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - SI->setTrueValue(FalseVal); - SI->setFalseValue(TrueVal); + cast<SelectInst>(Op1)->swapValues(); // Don't swap prof metadata, we didn't change the branch behavior. - return replaceInstUsesWith(I, SI); + return replaceInstUsesWith(I, Op1); } } } @@ -1686,6 +1839,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNeg(Y); } + // (sub (or A, B) (and A, B)) --> (xor A, B) + { + Value *A, *B; + if (match(Op1, m_And(m_Value(A), m_Value(B))) && + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateXor(A, B); + } + + // (sub (and A, B) (or A, B)) --> neg (xor A, B) + { + Value *A, *B; + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) && + (Op0->hasOneUse() || Op1->hasOneUse())) + return BinaryOperator::CreateNeg(Builder.CreateXor(A, B)); + } + // (sub (or A, B), (xor A, B)) --> (and A, B) { Value *A, *B; @@ -1694,6 +1864,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateAnd(A, B); } + // (sub (xor A, B) (or A, B)) --> neg (and A, B) + { + Value *A, *B; + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) && + (Op0->hasOneUse() || Op1->hasOneUse())) + return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B)); + } + { Value *Y; // ((X | Y) - X) --> (~X & Y) @@ -1778,7 +1957,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { std::swap(LHS, RHS); // LHS is now O above and expected to have at least 2 uses (the min/max) // NotA is epected to have 2 uses from the min/max and 1 from the sub. 
- if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && + if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && !NotA->hasNUsesOrMore(4)) { // Note: We don't generate the inverse max/min, just create the not of // it and let other folds do the rest. @@ -1826,6 +2005,10 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return SelectInst::Create(Cmp, Neg, A); } + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; @@ -1865,6 +2048,22 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { return nullptr; } +static Instruction *hoistFNegAboveFMulFDiv(Instruction &I, + InstCombiner::BuilderTy &Builder) { + Value *FNeg; + if (!match(&I, m_FNeg(m_Value(FNeg)))) + return nullptr; + + Value *X, *Y; + if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + + if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + + return nullptr; +} + Instruction *InstCombiner::visitFNeg(UnaryOperator &I) { Value *Op = I.getOperand(0); @@ -1882,6 +2081,9 @@ Instruction *InstCombiner::visitFNeg(UnaryOperator &I) { match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) return BinaryOperator::CreateFSubFMF(Y, X, &I); + if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) + return R; + return nullptr; } @@ -1903,6 +2105,9 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Instruction *X = foldFNegIntoConstant(I)) return X; + if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) + return R; + Value *X, *Y; Constant *C; @@ -1944,6 +2149,21 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + // Similar to above, but look through fmul/fdiv of the negated 
value: + // Op0 - (-X * Y) --> Op0 + (X * Y) + // Op0 - (Y * -X) --> Op0 + (X * Y) + if (match(Op1, m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))))) { + Value *FMul = Builder.CreateFMulFMF(X, Y, &I); + return BinaryOperator::CreateFAddFMF(Op0, FMul, &I); + } + // Op0 - (-X / Y) --> Op0 + (X / Y) + // Op0 - (X / -Y) --> Op0 + (X / Y) + if (match(Op1, m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y)))) || + match(Op1, m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))))) { + Value *FDiv = Builder.CreateFDivFMF(X, Y, &I); + return BinaryOperator::CreateFAddFMF(Op0, FDiv, &I); + } + // Handle special cases for FSub with selects feeding the operation if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 2b9859b602f4..4a30b60ca931 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -160,16 +160,14 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, } /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise -/// (V < Lo || V >= Hi). This method expects that Lo <= Hi. IsSigned indicates +/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates /// whether to treat V, Lo, and Hi as signed or not. Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside) { - assert((isSigned ? Lo.sle(Hi) : Lo.ule(Hi)) && - "Lo is not <= Hi in range emission code!"); + assert((isSigned ? Lo.slt(Hi) : Lo.ult(Hi)) && + "Lo is not < Hi in range emission code!"); Type *Ty = V->getType(); - if (Lo == Hi) - return Inside ? 
ConstantInt::getFalse(Ty) : ConstantInt::getTrue(Ty); // V >= Min && V < Hi --> V < Hi // V < Min || V >= Hi --> V >= Hi @@ -1051,9 +1049,103 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, return nullptr; } +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, + ICmpInst *UnsignedICmp, bool IsAnd, + const SimplifyQuery &Q, + InstCombiner::BuilderTy &Builder) { + Value *ZeroCmpOp; + ICmpInst::Predicate EqPred; + if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(ZeroCmpOp), m_Zero())) || + !ICmpInst::isEquality(EqPred)) + return nullptr; + + auto IsKnownNonZero = [&](Value *V) { + return isKnownNonZero(V, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + }; + + ICmpInst::Predicate UnsignedPred; + + Value *A, *B; + if (match(UnsignedICmp, + m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) && + match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) && + (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) { + if (UnsignedICmp->getOperand(0) != ZeroCmpOp) + UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); + + auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) { + if (!IsKnownNonZero(NonZero)) + std::swap(NonZero, Other); + return IsKnownNonZero(NonZero); + }; + + // Given ZeroCmpOp = (A + B) + // ZeroCmpOp <= A && ZeroCmpOp != 0 --> (0-B) < A + // ZeroCmpOp > A || ZeroCmpOp == 0 --> (0-B) >= A + // + // ZeroCmpOp < A && ZeroCmpOp != 0 --> (0-X) < Y iff + // ZeroCmpOp >= A || ZeroCmpOp == 0 --> (0-X) >= Y iff + // with X being the value (A/B) that is known to be non-zero, + // and Y being remaining value. 
+ if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE && + IsAnd) + return Builder.CreateICmpULT(Builder.CreateNeg(B), A); + if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE && + IsAnd && GetKnownNonZeroAndOther(B, A)) + return Builder.CreateICmpULT(Builder.CreateNeg(B), A); + if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ && + !IsAnd) + return Builder.CreateICmpUGE(Builder.CreateNeg(B), A); + if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_EQ && + !IsAnd && GetKnownNonZeroAndOther(B, A)) + return Builder.CreateICmpUGE(Builder.CreateNeg(B), A); + } + + Value *Base, *Offset; + if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset)))) + return nullptr; + + if (!match(UnsignedICmp, + m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) || + !ICmpInst::isUnsigned(UnsignedPred)) + return nullptr; + if (UnsignedICmp->getOperand(0) != Base) + UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); + + // Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset + // (no overflow and not null) + if ((UnsignedPred == ICmpInst::ICMP_UGE || + UnsignedPred == ICmpInst::ICMP_UGT) && + EqPred == ICmpInst::ICMP_NE && IsAnd) + return Builder.CreateICmpUGT(Base, Offset); + + // Base <=/< Offset || (Base - Offset) == 0 <--> Base <= Offset + // (overflow or null) + if ((UnsignedPred == ICmpInst::ICMP_ULE || + UnsignedPred == ICmpInst::ICMP_ULT) && + EqPred == ICmpInst::ICMP_EQ && !IsAnd) + return Builder.CreateICmpULE(Base, Offset); + + // Base <= Offset && (Base - Offset) != 0 --> Base < Offset + if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE && + IsAnd) + return Builder.CreateICmpULT(Base, Offset); + + // Base > Offset || (Base - Offset) == 0 --> Base >= Offset + if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ && + !IsAnd) + return Builder.CreateICmpUGE(Base, Offset); + + return nullptr; +} + /// Fold (icmp)&(icmp) if possible. 
Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI) { + const SimplifyQuery Q = SQ.getWithInstruction(&CxtI); + // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) // if K1 and K2 are a one-bit mask. if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI)) @@ -1096,6 +1188,13 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder)) return V; + if (Value *X = + foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/true, Q, Builder)) + return X; + if (Value *X = + foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder)) + return X; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); @@ -1196,16 +1295,22 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, default: llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_ULT: - if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13 + // (X != 13 & X u< 14) -> X < 13 + if (LHSC->getValue() == (RHSC->getValue() - 1)) return Builder.CreateICmpULT(LHS0, LHSC); - if (LHSC->isZero()) // (X != 0 & X u< 14) -> X-1 u< 13 + if (LHSC->isZero()) // (X != 0 & X u< C) -> X-1 u< C-1 return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); break; // (X != 13 & X u< 15) -> no change case ICmpInst::ICMP_SLT: - if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13 + // (X != 13 & X s< 14) -> X < 13 + if (LHSC->getValue() == (RHSC->getValue() - 1)) return Builder.CreateICmpSLT(LHS0, LHSC); - break; // (X != 13 & X s< 15) -> no change + // (X != INT_MIN & X s< C) -> X-(INT_MIN+1) u< (C-(INT_MIN+1)) + if (LHSC->isMinValue(true)) + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), + true, true); + break; // (X != 13 & X s< 15) -> no change case 
ICmpInst::ICMP_NE: // Potential folds for this case should already be handled. break; @@ -1216,10 +1321,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, default: llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_NE: - if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14 + // (X u> 13 & X != 14) -> X u> 14 + if (RHSC->getValue() == (LHSC->getValue() + 1)) return Builder.CreateICmp(PredL, LHS0, RHSC); + // X u> C & X != UINT_MAX -> (X-(C+1)) u< UINT_MAX-(C+1) + if (RHSC->isMaxValue(false)) + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), + false, true); break; // (X u> 13 & X != 15) -> no change - case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 + case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) u< 1 return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); } @@ -1229,10 +1339,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, default: llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_NE: - if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14 + // (X s> 13 & X != 14) -> X s> 14 + if (RHSC->getValue() == (LHSC->getValue() + 1)) return Builder.CreateICmp(PredL, LHS0, RHSC); + // X s> C & X != INT_MAX -> (X-(C+1)) u< INT_MAX-(C+1) + if (RHSC->isMaxValue(true)) + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), + true, true); break; // (X s> 13 & X != 15) -> no change - case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 + case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) u< 1 return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true, true); } @@ -1352,8 +1467,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, Value *A, *B; if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) && match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) && - !IsFreeToInvert(A, A->hasOneUse()) && - !IsFreeToInvert(B, B->hasOneUse())) { + !isFreeToInvert(A, 
A->hasOneUse()) && + !isFreeToInvert(B, B->hasOneUse())) { Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan"); return BinaryOperator::CreateNot(AndOr); } @@ -1770,13 +1885,13 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) - if (Op1->hasOneUse() || IsFreeToInvert(C, C->hasOneUse())) + if (Op1->hasOneUse() || isFreeToInvert(C, C->hasOneUse())) return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C)); // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) - if (Op0->hasOneUse() || IsFreeToInvert(C, C->hasOneUse())) + if (Op0->hasOneUse() || isFreeToInvert(C, C->hasOneUse())) return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C)); // (A | B) & ((~A) ^ B) -> (A & B) @@ -1844,6 +1959,20 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType())); + // and(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? X : 0. + { + Value *X, *Y; + const APInt *ShAmt; + Type *Ty = I.getType(); + if (match(&I, m_c_And(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)), + m_APInt(ShAmt))), + m_Deferred(X))) && + *ShAmt == Ty->getScalarSizeInBits() - 1) { + Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); + return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty)); + } + } + return nullptr; } @@ -2057,6 +2186,8 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B, /// Fold (icmp)|(icmp) if possible. 
Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI) { + const SimplifyQuery Q = SQ.getWithInstruction(&CxtI); + // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // if K1 and K2 are a one-bit mask. if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI)) @@ -2182,6 +2313,13 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldIsPowerOf2(LHS, RHS, false /* JoinedByAnd */, Builder)) return V; + if (Value *X = + foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/false, Q, Builder)) + return X; + if (Value *X = + foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder)) + return X; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSC || !RHSC) return nullptr; @@ -2251,8 +2389,19 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, case ICmpInst::ICMP_EQ: // Potential folds for this case should already be handled. break; - case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change - case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change + case ICmpInst::ICMP_UGT: + // (X == 0 || X u> C) -> (X-1) u>= C + if (LHSC->isMinValue(false)) + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1, + false, false); + // (X == 13 | X u> 14) -> no change + break; + case ICmpInst::ICMP_SGT: + // (X == INT_MIN || X s> C) -> (X-(INT_MIN+1)) u>= C-INT_MIN + if (LHSC->isMinValue(true)) + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1, + true, false); + // (X == 13 | X s> 14) -> no change break; } break; @@ -2261,6 +2410,10 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, default: llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change + // (X u< C || X == UINT_MAX) => (X-C) u>= UINT_MAX-C + if (RHSC->isMaxValue(false)) + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(), + false, false); break; case 
ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 assert(!RHSC->isMaxValue(false) && "Missed icmp simplification"); @@ -2272,9 +2425,14 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, switch (PredR) { default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change + case ICmpInst::ICMP_EQ: + // (X s< C || X == INT_MAX) => (X-C) u>= INT_MAX-C + if (RHSC->isMaxValue(true)) + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(), + true, false); + // (X s< 13 | X == 14) -> no change break; - case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 + case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) u> 2 assert(!RHSC->isMaxValue(true) && "Missed icmp simplification"); return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true, false); @@ -2552,6 +2710,25 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } + // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? -1 : X. 
+ { + Value *X, *Y; + const APInt *ShAmt; + Type *Ty = I.getType(); + if (match(&I, m_c_Or(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)), + m_APInt(ShAmt))), + m_Deferred(X))) && + *ShAmt == Ty->getScalarSizeInBits() - 1) { + Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); + return SelectInst::Create(NewICmpInst, ConstantInt::getAllOnesValue(Ty), + X); + } + } + + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + return nullptr; } @@ -2617,7 +2794,11 @@ static Instruction *foldXorToXor(BinaryOperator &I, return nullptr; } -Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { +Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, + BinaryOperator &I) { + assert(I.getOpcode() == Instruction::Xor && I.getOperand(0) == LHS && + I.getOperand(1) == RHS && "Should be 'xor' with these operands"); + if (predicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) @@ -2672,14 +2853,35 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { // TODO: If OrICmp is false, the whole thing is false (InstSimplify?). if (Value *AndICmp = SimplifyBinOp(Instruction::And, LHS, RHS, SQ)) { // TODO: Independently handle cases where the 'and' side is a constant. 
- if (OrICmp == LHS && AndICmp == RHS && RHS->hasOneUse()) { - // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS - RHS->setPredicate(RHS->getInversePredicate()); - return Builder.CreateAnd(LHS, RHS); + ICmpInst *X = nullptr, *Y = nullptr; + if (OrICmp == LHS && AndICmp == RHS) { + // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS --> X & !Y + X = LHS; + Y = RHS; } - if (OrICmp == RHS && AndICmp == LHS && LHS->hasOneUse()) { - // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS - LHS->setPredicate(LHS->getInversePredicate()); + if (OrICmp == RHS && AndICmp == LHS) { + // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS --> !Y & X + X = RHS; + Y = LHS; + } + if (X && Y && (Y->hasOneUse() || canFreelyInvertAllUsersOf(Y, &I))) { + // Invert the predicate of 'Y', thus inverting its output. + Y->setPredicate(Y->getInversePredicate()); + // So, are there other uses of Y? + if (!Y->hasOneUse()) { + // We need to adapt other uses of Y though. Get a value that matches + // the original value of Y before inversion. While this increases + // immediate instruction count, we have just ensured that all the + // users are freely-invertible, so that 'not' *will* get folded away. + BuilderTy::InsertPointGuard Guard(Builder); + // Set insertion point to right after the Y. + Builder.SetInsertPoint(Y->getParent(), ++(Y->getIterator())); + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + // Replace all uses of Y (excluding the one in NotY!) with NotY. + Y->replaceUsesWithIf(NotY, + [NotY](Use &U) { return U.getUser() != NotY; }); + } + // All done. return Builder.CreateAnd(LHS, RHS); } } @@ -2747,9 +2949,9 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I, return nullptr; // We only want to do the transform if it is free to do. - if (IsFreeToInvert(X, X->hasOneUse())) { + if (isFreeToInvert(X, X->hasOneUse())) { // Ok, good. 
- } else if (IsFreeToInvert(Y, Y->hasOneUse())) { + } else if (isFreeToInvert(Y, Y->hasOneUse())) { std::swap(X, Y); } else return nullptr; @@ -2827,9 +3029,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // Apply DeMorgan's Law when inverts are free: // ~(X & Y) --> (~X | ~Y) // ~(X | Y) --> (~X & ~Y) - if (IsFreeToInvert(NotVal->getOperand(0), + if (isFreeToInvert(NotVal->getOperand(0), NotVal->getOperand(0)->hasOneUse()) && - IsFreeToInvert(NotVal->getOperand(1), + isFreeToInvert(NotVal->getOperand(1), NotVal->getOperand(1)->hasOneUse())) { Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs"); Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs"); @@ -3004,7 +3206,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) - if (Value *V = foldXorOfICmps(LHS, RHS)) + if (Value *V = foldXorOfICmps(LHS, RHS, I)) return replaceInstUsesWith(I, V); if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) @@ -3052,7 +3254,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (SelectPatternResult::isMinOrMax(SPF)) { // It's possible we get here before the not has been simplified, so make // sure the input to the not isn't freely invertible. - if (match(LHS, m_Not(m_Value(X))) && !IsFreeToInvert(X, X->hasOneUse())) { + if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) { Value *NotY = Builder.CreateNot(RHS); return SelectInst::Create( Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); @@ -3060,7 +3262,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // It's possible we get here before the not has been simplified, so make // sure the input to the not isn't freely invertible. 
- if (match(RHS, m_Not(m_Value(Y))) && !IsFreeToInvert(Y, Y->hasOneUse())) { + if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) { Value *NotX = Builder.CreateNot(LHS); return SelectInst::Create( Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y); @@ -3068,8 +3270,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // If both sides are freely invertible, then we can get rid of the xor // completely. - if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && - IsFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) { + if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && + isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) { Value *NotLHS = Builder.CreateNot(LHS); Value *NotRHS = Builder.CreateNot(RHS); return SelectInst::Create( diff --git a/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index 5f37a00f56cf..825f4b468b0a 100644 --- a/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -124,7 +124,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { auto *SI = new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(), &RMWI); SI->setAtomic(Ordering, RMWI.getSyncScopeID()); - SI->setAlignment(DL.getABITypeAlignment(RMWI.getType())); + SI->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType()))); return eraseInstFromFunction(RMWI); } @@ -154,6 +154,6 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand()); Load->setAtomic(Ordering, RMWI.getSyncScopeID()); - Load->setAlignment(DL.getABITypeAlignment(RMWI.getType())); + Load->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType()))); return Load; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 4b3333affa72..c650d242cd50 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ 
b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -185,7 +185,8 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); LoadInst *L = Builder.CreateLoad(IntType, Src); // Alignment from the mem intrinsic will be better, so use it. - L->setAlignment(CopySrcAlign); + L->setAlignment( + MaybeAlign(CopySrcAlign)); // FIXME: Check if we can use Align instead. if (CopyMD) L->setMetadata(LLVMContext::MD_tbaa, CopyMD); MDNode *LoopMemParallelMD = @@ -198,7 +199,8 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { StoreInst *S = Builder.CreateStore(L, Dest); // Alignment from the mem intrinsic will be better, so use it. - S->setAlignment(CopyDstAlign); + S->setAlignment( + MaybeAlign(CopyDstAlign)); // FIXME: Check if we can use Align instead. if (CopyMD) S->setMetadata(LLVMContext::MD_tbaa, CopyMD); if (LoopMemParallelMD) @@ -223,9 +225,10 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { } Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { - unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); - if (MI->getDestAlignment() < Alignment) { - MI->setDestAlignment(Alignment); + const unsigned KnownAlignment = + getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); + if (MI->getDestAlignment() < KnownAlignment) { + MI->setDestAlignment(KnownAlignment); return MI; } @@ -243,13 +246,9 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { ConstantInt *FillC = dyn_cast<ConstantInt>(MI->getValue()); if (!LenC || !FillC || !FillC->getType()->isIntegerTy(8)) return nullptr; - uint64_t Len = LenC->getLimitedValue(); - Alignment = MI->getDestAlignment(); + const uint64_t Len = LenC->getLimitedValue(); assert(Len && "0-sized memory setting should be removed already."); - - // Alignment 0 is identity for alignment 1 for memset, but not store. 
- if (Alignment == 0) - Alignment = 1; + const Align Alignment = assumeAligned(MI->getDestAlignment()); // If it is an atomic and alignment is less than the size then we will // introduce the unaligned memory access which will be later transformed @@ -1060,9 +1059,9 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) { // If we can unconditionally load from this address, replace with a // load/select idiom. TODO: use DT for context sensitive query - if (isDereferenceableAndAlignedPointer(LoadPtr, II.getType(), Alignment, - II.getModule()->getDataLayout(), - &II, nullptr)) { + if (isDereferenceableAndAlignedPointer( + LoadPtr, II.getType(), MaybeAlign(Alignment), + II.getModule()->getDataLayout(), &II, nullptr)) { Value *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3)); @@ -1086,7 +1085,8 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) { // If the mask is all ones, this is a plain vector store of the 1st argument. if (ConstMask->isAllOnesValue()) { Value *StorePtr = II.getArgOperand(1); - unsigned Alignment = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue(); + MaybeAlign Alignment( + cast<ConstantInt>(II.getArgOperand(2))->getZExtValue()); return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); } @@ -2234,6 +2234,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, Add); } + // Try to simplify the underlying FMul. 
+ if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1), + II->getFastMathFlags(), + SQ.getWithInstruction(II))) { + auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); + FAdd->copyFastMathFlags(II); + return FAdd; + } + LLVM_FALLTHROUGH; } case Intrinsic::fma: { @@ -2258,9 +2267,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; } - // fma x, 1, z -> fadd x, z - if (match(Src1, m_FPOne())) { - auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2)); + // Try to simplify the underlying FMul. We can only apply simplifications + // that do not require rounding. + if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1), + II->getFastMathFlags(), + SQ.getWithInstruction(II))) { + auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); FAdd->copyFastMathFlags(II); return FAdd; } @@ -2331,7 +2343,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC VSX loads into normal loads. Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); - return new LoadInst(II->getType(), Ptr, Twine(""), false, 1); + return new LoadInst(II->getType(), Ptr, Twine(""), false, Align::None()); } case Intrinsic::ppc_altivec_stvx: case Intrinsic::ppc_altivec_stvxl: @@ -2349,7 +2361,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC VSX stores into normal stores. Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(II->getArgOperand(0), Ptr, false, 1); + return new StoreInst(II->getArgOperand(0), Ptr, false, Align::None()); } case Intrinsic::ppc_qpx_qvlfs: // Turn PPC QPX qvlfs -> load if the pointer is known aligned. @@ -3885,6 +3897,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Asan needs to poison memory to detect invalid access which is possible // even for empty lifetime range. 
if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) || + II->getFunction()->hasFnAttribute(Attribute::SanitizeMemory) || II->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress)) break; @@ -3950,10 +3963,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } case Intrinsic::experimental_gc_relocate: { + auto &GCR = *cast<GCRelocateInst>(II); + + // If we have two copies of the same pointer in the statepoint argument + // list, canonicalize to one. This may let us common gc.relocates. + if (GCR.getBasePtr() == GCR.getDerivedPtr() && + GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) { + auto *OpIntTy = GCR.getOperand(2)->getType(); + II->setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex())); + return II; + } + // Translate facts known about a pointer before relocating into // facts about the relocate value, while being careful to // preserve relocation semantics. - Value *DerivedPtr = cast<GCRelocateInst>(II)->getDerivedPtr(); + Value *DerivedPtr = GCR.getDerivedPtr(); // Remove the relocation if unused, note that this check is required // to prevent the cases below from looping forever. @@ -4177,10 +4201,58 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { return nullptr; } +static void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { + unsigned NumArgs = Call.getNumArgOperands(); + ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0)); + ConstantInt *Op1C = + (NumArgs == 1) ? nullptr : dyn_cast<ConstantInt>(Call.getOperand(1)); + // Bail out if the allocation size is zero. 
+ if ((Op0C && Op0C->isNullValue()) || (Op1C && Op1C->isNullValue())) + return; + + if (isMallocLikeFn(&Call, TLI) && Op0C) { + if (isOpNewLikeFn(&Call, TLI)) + Call.addAttribute(AttributeList::ReturnIndex, + Attribute::getWithDereferenceableBytes( + Call.getContext(), Op0C->getZExtValue())); + else + Call.addAttribute(AttributeList::ReturnIndex, + Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op0C->getZExtValue())); + } else if (isReallocLikeFn(&Call, TLI) && Op1C) { + Call.addAttribute(AttributeList::ReturnIndex, + Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op1C->getZExtValue())); + } else if (isCallocLikeFn(&Call, TLI) && Op0C && Op1C) { + bool Overflow; + const APInt &N = Op0C->getValue(); + APInt Size = N.umul_ov(Op1C->getValue(), Overflow); + if (!Overflow) + Call.addAttribute(AttributeList::ReturnIndex, + Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Size.getZExtValue())); + } else if (isStrdupLikeFn(&Call, TLI)) { + uint64_t Len = GetStringLength(Call.getOperand(0)); + if (Len) { + // strdup + if (NumArgs == 1) + Call.addAttribute(AttributeList::ReturnIndex, + Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Len)); + // strndup + else if (NumArgs == 2 && Op1C) + Call.addAttribute( + AttributeList::ReturnIndex, + Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1))); + } + } +} + /// Improvements for call, callbr and invoke instructions. Instruction *InstCombiner::visitCallBase(CallBase &Call) { - if (isAllocLikeFn(&Call, &TLI)) - return visitAllocSite(Call); + if (isAllocationFn(&Call, &TLI)) + annotateAnyAllocSite(Call, &TLI); bool Changed = false; @@ -4312,6 +4384,9 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { if (I) return eraseInstFromFunction(*I); } + if (isAllocLikeFn(&Call, &TLI)) + return visitAllocSite(Call); + return Changed ? 
&Call : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 2c9ba203fbf3..65aaef28d87a 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -140,7 +140,7 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, } AllocaInst *New = AllocaBuilder.CreateAlloca(CastElTy, Amt); - New->setAlignment(AI.getAlignment()); + New->setAlignment(MaybeAlign(AI.getAlignment())); New->takeName(&AI); New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); @@ -1531,16 +1531,16 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { // what we can and cannot do safely varies from operation to operation, and // is explained below in the various case statements. Type *Ty = FPT.getType(); - BinaryOperator *OpI = dyn_cast<BinaryOperator>(FPT.getOperand(0)); - if (OpI && OpI->hasOneUse()) { - Type *LHSMinType = getMinimumFPType(OpI->getOperand(0)); - Type *RHSMinType = getMinimumFPType(OpI->getOperand(1)); - unsigned OpWidth = OpI->getType()->getFPMantissaWidth(); + auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0)); + if (BO && BO->hasOneUse()) { + Type *LHSMinType = getMinimumFPType(BO->getOperand(0)); + Type *RHSMinType = getMinimumFPType(BO->getOperand(1)); + unsigned OpWidth = BO->getType()->getFPMantissaWidth(); unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); unsigned SrcWidth = std::max(LHSWidth, RHSWidth); unsigned DstWidth = Ty->getFPMantissaWidth(); - switch (OpI->getOpcode()) { + switch (BO->getOpcode()) { default: break; case Instruction::FAdd: case Instruction::FSub: @@ -1563,10 +1563,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { // could be tightened for those cases, but they are rare (the main // case of interest here is (float)((double)float + float)). 
if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) { - Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); - Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); - Instruction *RI = BinaryOperator::Create(OpI->getOpcode(), LHS, RHS); - RI->copyFastMathFlags(OpI); + Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty); + Instruction *RI = BinaryOperator::Create(BO->getOpcode(), LHS, RHS); + RI->copyFastMathFlags(BO); return RI; } break; @@ -1577,9 +1577,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { // rounding can possibly occur; we can safely perform the operation // in the destination format if it can represent both sources. if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) { - Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); - Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); - return BinaryOperator::CreateFMulFMF(LHS, RHS, OpI); + Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty); + return BinaryOperator::CreateFMulFMF(LHS, RHS, BO); } break; case Instruction::FDiv: @@ -1590,9 +1590,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { // condition used here is a good conservative first pass. // TODO: Tighten bound via rigorous analysis of the unbalanced case. 
if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) { - Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); - Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); - return BinaryOperator::CreateFDivFMF(LHS, RHS, OpI); + Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty); + return BinaryOperator::CreateFDivFMF(LHS, RHS, BO); } break; case Instruction::FRem: { @@ -1604,14 +1604,14 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { break; Value *LHS, *RHS; if (LHSWidth == SrcWidth) { - LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType); - RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType); + LHS = Builder.CreateFPTrunc(BO->getOperand(0), LHSMinType); + RHS = Builder.CreateFPTrunc(BO->getOperand(1), LHSMinType); } else { - LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType); - RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType); + LHS = Builder.CreateFPTrunc(BO->getOperand(0), RHSMinType); + RHS = Builder.CreateFPTrunc(BO->getOperand(1), RHSMinType); } - Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, OpI); + Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, BO); return CastInst::CreateFPCast(ExactResult, Ty); } } @@ -2338,8 +2338,23 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // If we found a path from the src to dest, create the getelementptr now. if (SrcElTy == DstElTy) { SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0)); - return GetElementPtrInst::CreateInBounds(SrcPTy->getElementType(), Src, - Idxs); + GetElementPtrInst *GEP = + GetElementPtrInst::Create(SrcPTy->getElementType(), Src, Idxs); + + // If the source pointer is dereferenceable, then assume it points to an + // allocated object and apply "inbounds" to the GEP. 
+ bool CanBeNull; + if (Src->getPointerDereferenceableBytes(DL, CanBeNull)) { + // In a non-default address space (not 0), a null pointer can not be + // assumed inbounds, so ignore that case (dereferenceable_or_null). + // The reason is that 'null' is not treated differently in these address + // spaces, and we consequently ignore the 'gep inbounds' special case + // for 'null' which allows 'inbounds' on 'null' if the indices are + // zeros. + if (SrcPTy->getAddressSpace() == 0 || !CanBeNull) + GEP->setIsInBounds(); + } + return GEP; } } @@ -2391,28 +2406,47 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } } - if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(Src)) { + if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) { // Okay, we have (bitcast (shuffle ..)). Check to see if this is // a bitcast to a vector with the same # elts. - if (SVI->hasOneUse() && DestTy->isVectorTy() && - DestTy->getVectorNumElements() == SVI->getType()->getNumElements() && - SVI->getType()->getNumElements() == - SVI->getOperand(0)->getType()->getVectorNumElements()) { + Value *ShufOp0 = Shuf->getOperand(0); + Value *ShufOp1 = Shuf->getOperand(1); + unsigned NumShufElts = Shuf->getType()->getVectorNumElements(); + unsigned NumSrcVecElts = ShufOp0->getType()->getVectorNumElements(); + if (Shuf->hasOneUse() && DestTy->isVectorTy() && + DestTy->getVectorNumElements() == NumShufElts && + NumShufElts == NumSrcVecElts) { BitCastInst *Tmp; // If either of the operands is a cast from CI.getType(), then // evaluating the shuffle in the casted destination's type will allow // us to eliminate at least one cast. 
- if (((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(0))) && + if (((Tmp = dyn_cast<BitCastInst>(ShufOp0)) && Tmp->getOperand(0)->getType() == DestTy) || - ((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(1))) && + ((Tmp = dyn_cast<BitCastInst>(ShufOp1)) && Tmp->getOperand(0)->getType() == DestTy)) { - Value *LHS = Builder.CreateBitCast(SVI->getOperand(0), DestTy); - Value *RHS = Builder.CreateBitCast(SVI->getOperand(1), DestTy); + Value *LHS = Builder.CreateBitCast(ShufOp0, DestTy); + Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy); // Return a new shuffle vector. Use the same element ID's, as we // know the vector types match #elts. - return new ShuffleVectorInst(LHS, RHS, SVI->getOperand(2)); + return new ShuffleVectorInst(LHS, RHS, Shuf->getOperand(2)); } } + + // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as + // a byte-swap: + // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X) + // TODO: We should match the related pattern for bitreverse. + if (DestTy->isIntegerTy() && + DL.isLegalInteger(DestTy->getScalarSizeInBits()) && + SrcTy->getScalarSizeInBits() == 8 && NumShufElts % 2 == 0 && + Shuf->hasOneUse() && Shuf->isReverse()) { + assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask"); + assert(isa<UndefValue>(ShufOp1) && "Unexpected shuffle op"); + Function *Bswap = + Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy); + Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy); + return IntrinsicInst::Create(Bswap, { ScalarX }); + } } // Handle the A->B->A cast, and there is an intervening PHI node. 
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 3a4283ae5406..a9f64feb600c 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -69,34 +69,6 @@ static bool hasBranchUse(ICmpInst &I) { return false; } -/// Given an exploded icmp instruction, return true if the comparison only -/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the -/// result of the comparison is true when the input value is signed. -static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, - bool &TrueIfSigned) { - switch (Pred) { - case ICmpInst::ICMP_SLT: // True if LHS s< 0 - TrueIfSigned = true; - return RHS.isNullValue(); - case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1 - TrueIfSigned = true; - return RHS.isAllOnesValue(); - case ICmpInst::ICMP_SGT: // True if LHS s> -1 - TrueIfSigned = false; - return RHS.isAllOnesValue(); - case ICmpInst::ICMP_UGT: - // True if LHS u> RHS and RHS == high-bit-mask - 1 - TrueIfSigned = true; - return RHS.isMaxSignedValue(); - case ICmpInst::ICMP_UGE: - // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) - TrueIfSigned = true; - return RHS.isSignMask(); - default: - return false; - } -} - /// Returns true if the exploded icmp can be expressed as a signed comparison /// to zero and updates the predicate accordingly. /// The signedness of the comparison is preserved. @@ -832,6 +804,10 @@ getAsConstantIndexedAddress(Value *V, const DataLayout &DL) { static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, const DataLayout &DL) { + // FIXME: Support vector of pointers. 
+ if (GEPLHS->getType()->isVectorTy()) + return nullptr; + if (!GEPLHS->hasAllConstantIndices()) return nullptr; @@ -882,7 +858,9 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, RHS = RHS->stripPointerCasts(); Value *PtrBase = GEPLHS->getOperand(0); - if (PtrBase == RHS && GEPLHS->isInBounds()) { + // FIXME: Support vector pointer GEPs. + if (PtrBase == RHS && GEPLHS->isInBounds() && + !GEPLHS->getType()->isVectorTy()) { // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). // This transformation (ignoring the base and scales) is valid because we // know pointers can't overflow since the gep is inbounds. See if we can @@ -894,6 +872,37 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Offset = EmitGEPOffset(GEPLHS); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, Constant::getNullValue(Offset->getType())); + } + + if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) && + isa<Constant>(RHS) && cast<Constant>(RHS)->isNullValue() && + !NullPointerIsDefined(I.getFunction(), + RHS->getType()->getPointerAddressSpace())) { + // For most address spaces, an allocation can't be placed at null, but null + // itself is treated as a 0 size allocation in the in bounds rules. Thus, + // the only valid inbounds address derived from null, is null itself. + // Thus, we have four cases to consider: + // 1) Base == nullptr, Offset == 0 -> inbounds, null + // 2) Base == nullptr, Offset != 0 -> poison as the result is out of bounds + // 3) Base != nullptr, Offset == (-base) -> poison (crossing allocations) + // 4) Base != nullptr, Offset != (-base) -> nonnull (and possibly poison) + // + // (Note if we're indexing a type of size 0, that simply collapses into one + // of the buckets above.) + // + // In general, we're allowed to make values less poison (i.e. remove + // sources of full UB), so in this case, we just select between the two + // non-poison cases (1 and 4 above). 
+ // + // For vectors, we apply the same reasoning on a per-lane basis. + auto *Base = GEPLHS->getPointerOperand(); + if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) { + int NumElts = GEPLHS->getType()->getVectorNumElements(); + Base = Builder.CreateVectorSplat(NumElts, Base); + } + return new ICmpInst(Cond, Base, + ConstantExpr::getPointerBitCastOrAddrSpaceCast( + cast<Constant>(RHS), Base->getType())); } else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) { // If the base pointers are different, but the indices are the same, just // compare the base pointer. @@ -916,11 +925,13 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If we're comparing GEPs with two base pointers that only differ in type // and both GEPs have only constant indices or just one use, then fold // the compare with the adjusted indices. + // FIXME: Support vector of pointers. if (GEPLHS->isInBounds() && GEPRHS->isInBounds() && (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) && PtrBase->stripPointerCasts() == - GEPRHS->getOperand(0)->stripPointerCasts()) { + GEPRHS->getOperand(0)->stripPointerCasts() && + !GEPLHS->getType()->isVectorTy()) { Value *LOffset = EmitGEPOffset(GEPLHS); Value *ROffset = EmitGEPOffset(GEPRHS); @@ -949,12 +960,14 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } // If one of the GEPs has all zero indices, recurse. - if (GEPLHS->hasAllZeroIndices()) + // FIXME: Handle vector of pointers. + if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices()) return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. - if (GEPRHS->hasAllZeroIndices()) + // FIXME: Handle vector of pointers. 
+ if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices()) return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); @@ -964,15 +977,20 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, unsigned DiffOperand = 0; // The operand that differs. for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { - if (GEPLHS->getOperand(i)->getType()->getPrimitiveSizeInBits() != - GEPRHS->getOperand(i)->getType()->getPrimitiveSizeInBits()) { + Type *LHSType = GEPLHS->getOperand(i)->getType(); + Type *RHSType = GEPRHS->getOperand(i)->getType(); + // FIXME: Better support for vector of pointers. + if (LHSType->getPrimitiveSizeInBits() != + RHSType->getPrimitiveSizeInBits() || + (GEPLHS->getType()->isVectorTy() && + (!LHSType->isVectorTy() || !RHSType->isVectorTy()))) { // Irreconcilable differences. NumDifferences = 2; break; - } else { - if (NumDifferences++) break; - DiffOperand = i; } + + if (NumDifferences++) break; + DiffOperand = i; } if (NumDifferences == 0) // SAME GEP? @@ -1317,6 +1335,59 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } +/// If we have: +/// icmp eq/ne (urem/srem %x, %y), 0 +/// iff %y is a power-of-two, we can replace this with a bit test: +/// icmp eq/ne (and %x, (add %y, -1)), 0 +Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) { + // This fold is only valid for equality predicates. + if (!I.isEquality()) + return nullptr; + ICmpInst::Predicate Pred; + Value *X, *Y, *Zero; + if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))), + m_CombineAnd(m_Zero(), m_Value(Zero))))) + return nullptr; + if (!isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, 0, &I)) + return nullptr; + // This may increase instruction count, we don't enforce that Y is a constant. 
+ Value *Mask = Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType())); + Value *Masked = Builder.CreateAnd(X, Mask); + return ICmpInst::Create(Instruction::ICmp, Pred, Masked, Zero); +} + +/// Fold equality-comparison between zero and any (maybe truncated) right-shift +/// by one-less-than-bitwidth into a sign test on the original value. +Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) { + Instruction *Val; + ICmpInst::Predicate Pred; + if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero()))) + return nullptr; + + Value *X; + Type *XTy; + + Constant *C; + if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) { + XTy = X->getType(); + unsigned XBitWidth = XTy->getScalarSizeInBits(); + if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(XBitWidth, XBitWidth - 1)))) + return nullptr; + } else if (isa<BinaryOperator>(Val) && + (X = reassociateShiftAmtsOfTwoSameDirectionShifts( + cast<BinaryOperator>(Val), SQ.getWithInstruction(Val), + /*AnalyzeForSignBitExtraction=*/true))) { + XTy = X->getType(); + } else + return nullptr; + + return ICmpInst::Create(Instruction::ICmp, + Pred == ICmpInst::ICMP_EQ ? 
ICmpInst::ICMP_SGE + : ICmpInst::ICMP_SLT, + X, ConstantInt::getNullValue(XTy)); +} + // Handle icmp pred X, 0 Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); @@ -1335,6 +1406,9 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { } } + if (Instruction *New = foldIRemByPowerOfTwoToBitTest(Cmp)) + return New; + // Given: // icmp eq/ne (urem %x, %y), 0 // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem': @@ -2179,6 +2253,44 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, return nullptr; } +Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp, + BinaryOperator *SRem, + const APInt &C) { + // Match an 'is positive' or 'is negative' comparison of remainder by a + // constant power-of-2 value: + // (X % pow2C) sgt/slt 0 + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT) + return nullptr; + + // TODO: The one-use check is standard because we do not typically want to + // create longer instruction sequences, but this might be a special-case + // because srem is not good for analysis or codegen. + if (!SRem->hasOneUse()) + return nullptr; + + const APInt *DivisorC; + if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC))) + return nullptr; + + // Mask off the sign bit and the modulo bits (low-bits). + Type *Ty = SRem->getType(); + APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits()); + Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1)); + Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC); + + // For 'is positive?' check that the sign-bit is clear and at least 1 masked + // bit is set. Example: + // (i8 X % 32) s> 0 --> (X & 159) s> 0 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SGT, And, ConstantInt::getNullValue(Ty)); + + // For 'is negative?' check that the sign-bit is set and at least 1 masked + // bit is set. 
Example: + // (i16 X % 4) s< 0 --> (X & 32771) u> 32768 + return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask)); +} + /// Fold icmp (udiv X, Y), C. Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, @@ -2387,6 +2499,11 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, const APInt *C2; APInt SubResult; + // icmp eq/ne (sub C, Y), C -> icmp eq/ne Y, 0 + if (match(X, m_APInt(C2)) && *C2 == C && Cmp.isEquality()) + return new ICmpInst(Cmp.getPredicate(), Y, + ConstantInt::get(Y->getType(), 0)); + // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C) if (match(X, m_APInt(C2)) && ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) || @@ -2509,20 +2626,49 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, // TODO: Generalize this to work with other comparison idioms or ensure // they get canonicalized into this form. - // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32 - // Greater), where Equal, Less and Greater are placeholders for any three - // constants. - ICmpInst::Predicate PredA, PredB; - if (match(SI->getTrueValue(), m_ConstantInt(Equal)) && - match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) && - PredA == ICmpInst::ICMP_EQ && - match(SI->getFalseValue(), - m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)), - m_ConstantInt(Less), m_ConstantInt(Greater))) && - PredB == ICmpInst::ICMP_SLT) { - return true; + // select i1 (a == b), + // i32 Equal, + // i32 (select i1 (a < b), i32 Less, i32 Greater) + // where Equal, Less and Greater are placeholders for any three constants. + ICmpInst::Predicate PredA; + if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) || + !ICmpInst::isEquality(PredA)) + return false; + Value *EqualVal = SI->getTrueValue(); + Value *UnequalVal = SI->getFalseValue(); + // We still can get non-canonical predicate here, so canonicalize. 
+ if (PredA == ICmpInst::ICMP_NE) + std::swap(EqualVal, UnequalVal); + if (!match(EqualVal, m_ConstantInt(Equal))) + return false; + ICmpInst::Predicate PredB; + Value *LHS2, *RHS2; + if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)), + m_ConstantInt(Less), m_ConstantInt(Greater)))) + return false; + // We can get predicate mismatch here, so canonicalize if possible: + // First, ensure that 'LHS' match. + if (LHS2 != LHS) { + // x sgt y <--> y slt x + std::swap(LHS2, RHS2); + PredB = ICmpInst::getSwappedPredicate(PredB); + } + if (LHS2 != LHS) + return false; + // We also need to canonicalize 'RHS'. + if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) { + // x sgt C-1 <--> x sge C <--> not(x slt C) + auto FlippedStrictness = + getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2)); + if (!FlippedStrictness) + return false; + assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && "Sanity check"); + RHS2 = FlippedStrictness->second; + // And kind-of perform the result swap. 
+ std::swap(Less, Greater); + PredB = ICmpInst::ICMP_SLT; } - return false; + return PredB == ICmpInst::ICMP_SLT && RHS == RHS2; } Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, @@ -2702,6 +2848,10 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) return I; break; + case Instruction::SRem: + if (Instruction *I = foldICmpSRemConstant(Cmp, BO, *C)) + return I; + break; case Instruction::UDiv: if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) return I; @@ -2926,6 +3076,28 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, } break; } + + case Intrinsic::uadd_sat: { + // uadd.sat(a, b) == 0 -> (a | b) == 0 + if (C.isNullValue()) { + Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); + return replaceInstUsesWith(Cmp, Builder.CreateICmp( + Cmp.getPredicate(), Or, Constant::getNullValue(Ty))); + + } + break; + } + + case Intrinsic::usub_sat: { + // usub.sat(a, b) == 0 -> a <= b + if (C.isNullValue()) { + ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return ICmpInst::Create(Instruction::ICmp, NewPred, + II->getArgOperand(0), II->getArgOperand(1)); + } + break; + } default: break; } @@ -3275,6 +3447,7 @@ foldICmpWithTruncSignExtendedVal(ICmpInst &I, // we should move shifts to the same hand of 'and', i.e. rewrite as // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x) // We are only interested in opposite logical shifts here. +// One of the shifts can be truncated. // If we can, we want to end up creating 'lshr' shift. 
static Value * foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, @@ -3284,55 +3457,215 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, return nullptr; auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value()); - auto m_AnyLShr = m_LShr(m_Value(), m_Value()); - - // Look for an 'and' of two (opposite) logical shifts. - // Pick the single-use shift as XShift. - Value *XShift, *YShift; - if (!match(I.getOperand(0), - m_c_And(m_OneUse(m_CombineAnd(m_AnyLogicalShift, m_Value(XShift))), - m_CombineAnd(m_AnyLogicalShift, m_Value(YShift))))) + + // Look for an 'and' of two logical shifts, one of which may be truncated. + // We use m_TruncOrSelf() on the RHS to correctly handle commutative case. + Instruction *XShift, *MaybeTruncation, *YShift; + if (!match( + I.getOperand(0), + m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)), + m_CombineAnd(m_TruncOrSelf(m_CombineAnd( + m_AnyLogicalShift, m_Instruction(YShift))), + m_Instruction(MaybeTruncation))))) return nullptr; - // If YShift is a single-use 'lshr', swap the shifts around. - if (match(YShift, m_OneUse(m_AnyLShr))) + // We potentially looked past 'trunc', but only when matching YShift, + // therefore YShift must have the widest type. + Instruction *WidestShift = YShift; + // Therefore XShift must have the shallowest type. + // Or they both have identical types if there was no truncation. + Instruction *NarrowestShift = XShift; + + Type *WidestTy = WidestShift->getType(); + assert(NarrowestShift->getType() == I.getOperand(0)->getType() && + "We did not look past any shifts while matching XShift though."); + bool HadTrunc = WidestTy != I.getOperand(0)->getType(); + + // If YShift is a 'lshr', swap the shifts around. + if (match(YShift, m_LShr(m_Value(), m_Value()))) std::swap(XShift, YShift); // The shifts must be in opposite directions. 
- Instruction::BinaryOps XShiftOpcode = - cast<BinaryOperator>(XShift)->getOpcode(); - if (XShiftOpcode == cast<BinaryOperator>(YShift)->getOpcode()) + auto XShiftOpcode = XShift->getOpcode(); + if (XShiftOpcode == YShift->getOpcode()) return nullptr; // Do not care about same-direction shifts here. Value *X, *XShAmt, *Y, *YShAmt; - match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt))); - match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt))); + match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt)))); + match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt)))); + + // If one of the values being shifted is a constant, then we will end with + // and+icmp, and [zext+]shift instrs will be constant-folded. If they are not, + // however, we will need to ensure that we won't increase instruction count. + if (!isa<Constant>(X) && !isa<Constant>(Y)) { + // At least one of the hands of the 'and' should be one-use shift. + if (!match(I.getOperand(0), + m_c_And(m_OneUse(m_AnyLogicalShift), m_Value()))) + return nullptr; + if (HadTrunc) { + // Due to the 'trunc', we will need to widen X. For that either the old + // 'trunc' or the shift amt in the non-truncated shift should be one-use. + if (!MaybeTruncation->hasOneUse() && + !NarrowestShift->getOperand(1)->hasOneUse()) + return nullptr; + } + } + + // We have two shift amounts from two different shifts. The types of those + // shift amounts may not match. If that's the case let's bailout now. + if (XShAmt->getType() != YShAmt->getType()) + return nullptr; // Can we fold (XShAmt+YShAmt) ? 
- Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, XShAmt, YShAmt, - SQ.getWithInstruction(&I)); + auto *NewShAmt = dyn_cast_or_null<Constant>( + SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false, + /*isNUW=*/false, SQ.getWithInstruction(&I))); if (!NewShAmt) return nullptr; + NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy); + unsigned WidestBitWidth = WidestTy->getScalarSizeInBits(); + // Is the new shift amount smaller than the bit width? // FIXME: could also rely on ConstantRange. - unsigned BitWidth = X->getType()->getScalarSizeInBits(); - if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, - APInt(BitWidth, BitWidth)))) + if (!match(NewShAmt, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, + APInt(WidestBitWidth, WidestBitWidth)))) return nullptr; - // All good, we can do this fold. The shift is the same that was for X. + + // An extra legality check is needed if we had trunc-of-lshr. + if (HadTrunc && match(WidestShift, m_LShr(m_Value(), m_Value()))) { + auto CanFold = [NewShAmt, WidestBitWidth, NarrowestShift, SQ, + WidestShift]() { + // It isn't obvious whether it's worth it to analyze non-constants here. + // Also, let's basically give up on non-splat cases, pessimizing vectors. + // If *any* of these preconditions matches we can perform the fold. + Constant *NewShAmtSplat = NewShAmt->getType()->isVectorTy() + ? NewShAmt->getSplatValue() + : NewShAmt; + // If it's edge-case shift (by 0 or by WidestBitWidth-1) we can fold. + if (NewShAmtSplat && + (NewShAmtSplat->isNullValue() || + NewShAmtSplat->getUniqueInteger() == WidestBitWidth - 1)) + return true; + // We consider *min* leading zeros so a single outlier + // blocks the transform as opposed to allowing it. + if (auto *C = dyn_cast<Constant>(NarrowestShift->getOperand(0))) { + KnownBits Known = computeKnownBits(C, SQ.DL); + unsigned MinLeadZero = Known.countMinLeadingZeros(); + // If the value being shifted has at most lowest bit set we can fold. 
+ unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero; + if (MaxActiveBits <= 1) + return true; + // Precondition: NewShAmt u<= countLeadingZeros(C) + if (NewShAmtSplat && NewShAmtSplat->getUniqueInteger().ule(MinLeadZero)) + return true; + } + if (auto *C = dyn_cast<Constant>(WidestShift->getOperand(0))) { + KnownBits Known = computeKnownBits(C, SQ.DL); + unsigned MinLeadZero = Known.countMinLeadingZeros(); + // If the value being shifted has at most lowest bit set we can fold. + unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero; + if (MaxActiveBits <= 1) + return true; + // Precondition: ((WidestBitWidth-1)-NewShAmt) u<= countLeadingZeros(C) + if (NewShAmtSplat) { + APInt AdjNewShAmt = + (WidestBitWidth - 1) - NewShAmtSplat->getUniqueInteger(); + if (AdjNewShAmt.ule(MinLeadZero)) + return true; + } + } + return false; // Can't tell if it's ok. + }; + if (!CanFold()) + return nullptr; + } + + // All good, we can do this fold. + X = Builder.CreateZExt(X, WidestTy); + Y = Builder.CreateZExt(Y, WidestTy); + // The shift is the same that was for X. Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr ? Builder.CreateLShr(X, NewShAmt) : Builder.CreateShl(X, NewShAmt); Value *T1 = Builder.CreateAnd(T0, Y); return Builder.CreateICmp(I.getPredicate(), T1, - Constant::getNullValue(X->getType())); + Constant::getNullValue(WidestTy)); +} + +/// Fold +/// (-1 u/ x) u< y +/// ((x * y) u/ x) != y +/// to +/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit +/// Note that the comparison is commutative, while inverted (u>=, ==) predicate +/// will mean that we are looking for the opposite answer. 
+Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { + ICmpInst::Predicate Pred; + Value *X, *Y; + Instruction *Mul; + bool NeedNegation; + // Look for: (-1 u/ x) u</u>= y + if (!I.isEquality() && + match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + m_Value(Y)))) { + Mul = nullptr; + // Canonicalize as-if y was on RHS. + if (I.getOperand(1) != Y) + Pred = I.getSwappedPredicate(); + + // Are we checking that overflow does not happen, or does happen? + switch (Pred) { + case ICmpInst::Predicate::ICMP_ULT: + NeedNegation = false; + break; // OK + case ICmpInst::Predicate::ICMP_UGE: + NeedNegation = true; + break; // OK + default: + return nullptr; // Wrong predicate. + } + } else // Look for: ((x * y) u/ x) !=/== y + if (I.isEquality() && + match(&I, m_c_ICmp(Pred, m_Value(Y), + m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), + m_Value(X)), + m_Instruction(Mul)), + m_Deferred(X)))))) { + NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ; + } else + return nullptr; + + BuilderTy::InsertPointGuard Guard(Builder); + // If the pattern included (x * y), we'll want to insert new instructions + // right before that original multiplication so that we can replace it. + bool MulHadOtherUses = Mul && !Mul->hasOneUse(); + if (MulHadOtherUses) + Builder.SetInsertPoint(Mul); + + Function *F = Intrinsic::getDeclaration( + I.getModule(), Intrinsic::umul_with_overflow, X->getType()); + CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul"); + + // If the multiplication was used elsewhere, to ensure that we don't leave + // "duplicate" instructions, replace uses of that original multiplication + // with the multiplication result from the with.overflow intrinsic. + if (MulHadOtherUses) + replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val")); + + Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov"); + if (NeedNegation) // This technically increases instruction count. 
+ Res = Builder.CreateNot(Res, "umul.not.ov"); + + return Res; } /// Try to fold icmp (binop), X or icmp X, (binop). /// TODO: A large part of this logic is duplicated in InstSimplify's /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code /// duplication. -Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { +Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { + const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // Special logic for binary operators. @@ -3345,13 +3678,13 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { Value *X; // Convert add-with-unsigned-overflow comparisons into a 'not' with compare. - // (Op1 + X) <u Op1 --> ~Op1 <u X - // Op0 >u (Op0 + X) --> X >u ~Op0 + // (Op1 + X) u</u>= Op1 --> ~Op1 u</u>= X if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) && - Pred == ICmpInst::ICMP_ULT) + (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) return new ICmpInst(Pred, Builder.CreateNot(Op1), X); + // Op0 u>/u<= (Op0 + X) --> X u>/u<= ~Op0 if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) && - Pred == ICmpInst::ICMP_UGT) + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) return new ICmpInst(Pred, X, Builder.CreateNot(Op0)); bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; @@ -3378,21 +3711,21 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { D = BO1->getOperand(1); } - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + // icmp (A+B), A -> icmp B, 0 for equalities or if there is no overflow. + // icmp (A+B), B -> icmp A, 0 for equalities or if there is no overflow. if ((A == Op1 || B == Op1) && NoOp0WrapProblem) return new ICmpInst(Pred, A == Op1 ? B : A, Constant::getNullValue(Op1->getType())); - // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + // icmp C, (C+D) -> icmp 0, D for equalities or if there is no overflow. 
+ // icmp D, (C+D) -> icmp 0, C for equalities or if there is no overflow. if ((C == Op0 || D == Op0) && NoOp1WrapProblem) return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), C == Op0 ? D : C); - // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + // icmp (A+B), (A+D) -> icmp B, D for equalities or if there is no overflow. if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem && - NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) { + NoOp1WrapProblem) { // Determine Y and Z in the form icmp (X+Y), (X+Z). Value *Y, *Z; if (A == C) { @@ -3416,39 +3749,39 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { return new ICmpInst(Pred, Y, Z); } - // icmp slt (X + -1), Y -> icmp sle X, Y + // icmp slt (A + -1), Op1 -> icmp sle A, Op1 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && match(B, m_AllOnes())) return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); - // icmp sge (X + -1), Y -> icmp sgt X, Y + // icmp sge (A + -1), Op1 -> icmp sgt A, Op1 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && match(B, m_AllOnes())) return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); - // icmp sle (X + 1), Y -> icmp slt X, Y + // icmp sle (A + 1), Op1 -> icmp slt A, Op1 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One())) return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); - // icmp sgt (X + 1), Y -> icmp sge X, Y + // icmp sgt (A + 1), Op1 -> icmp sge A, Op1 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One())) return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); - // icmp sgt X, (Y + -1) -> icmp sge X, Y + // icmp sgt Op0, (C + -1) -> icmp sge Op0, C if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && match(D, m_AllOnes())) return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); - // icmp sle X, (Y + -1) -> icmp slt X, Y + // icmp sle Op0, (C + -1) -> icmp slt Op0, C if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && 
match(D, m_AllOnes())) return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); - // icmp sge X, (Y + 1) -> icmp sgt X, Y + // icmp sge Op0, (C + 1) -> icmp sgt Op0, C if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One())) return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); - // icmp slt X, (Y + 1) -> icmp sle X, Y + // icmp slt Op0, (C + 1) -> icmp sle Op0, C if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One())) return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); @@ -3456,33 +3789,33 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { // canonicalization from (X -nuw 1) to (X + -1) means that the combinations // wouldn't happen even if they were implemented. // - // icmp ult (X - 1), Y -> icmp ule X, Y - // icmp uge (X - 1), Y -> icmp ugt X, Y - // icmp ugt X, (Y - 1) -> icmp uge X, Y - // icmp ule X, (Y - 1) -> icmp ult X, Y + // icmp ult (A - 1), Op1 -> icmp ule A, Op1 + // icmp uge (A - 1), Op1 -> icmp ugt A, Op1 + // icmp ugt Op0, (C - 1) -> icmp uge Op0, C + // icmp ule Op0, (C - 1) -> icmp ult Op0, C - // icmp ule (X + 1), Y -> icmp ult X, Y + // icmp ule (A + 1), Op0 -> icmp ult A, Op1 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One())) return new ICmpInst(CmpInst::ICMP_ULT, A, Op1); - // icmp ugt (X + 1), Y -> icmp uge X, Y + // icmp ugt (A + 1), Op0 -> icmp uge A, Op1 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One())) return new ICmpInst(CmpInst::ICMP_UGE, A, Op1); - // icmp uge X, (Y + 1) -> icmp ugt X, Y + // icmp uge Op0, (C + 1) -> icmp ugt Op0, C if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One())) return new ICmpInst(CmpInst::ICMP_UGT, Op0, C); - // icmp ult X, (Y + 1) -> icmp ule X, Y + // icmp ult Op0, (C + 1) -> icmp ule Op0, C if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One())) return new ICmpInst(CmpInst::ICMP_ULE, Op0, C); // if C1 has greater magnitude than C2: - // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y 
+ // icmp (A + C1), (C + C2) -> icmp (A + C3), C // s.t. C3 = C1 - C2 // // if C2 has greater magnitude than C1: - // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3) + // icmp (A + C1), (C + C2) -> icmp A, (C + C3) // s.t. C3 = C2 - C1 if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) @@ -3520,29 +3853,35 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { D = BO1->getOperand(1); } - // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. + // icmp (A-B), A -> icmp 0, B for equalities or if there is no overflow. if (A == Op1 && NoOp0WrapProblem) return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); - // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. + // icmp C, (C-D) -> icmp D, 0 for equalities or if there is no overflow. if (C == Op0 && NoOp1WrapProblem) return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); - // (A - B) >u A --> A <u B - if (A == Op1 && Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_ULT, A, B); - // C <u (C - D) --> C <u D - if (C == Op0 && Pred == ICmpInst::ICMP_ULT) - return new ICmpInst(ICmpInst::ICMP_ULT, C, D); - - // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. - if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) + // Convert sub-with-unsigned-overflow comparisons into a comparison of args. 
+ // (A - B) u>/u<= A --> B u>/u<= A + if (A == Op1 && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) + return new ICmpInst(Pred, B, A); + // C u</u>= (C - D) --> C u</u>= D + if (C == Op0 && (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) + return new ICmpInst(Pred, C, D); + // (A - B) u>=/u< A --> B u>/u<= A iff B != 0 + if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) && + isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A); + // C u<=/u> (C - D) --> C u</u>= D iff B != 0 + if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) && + isKnownNonZero(D, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D); + + // icmp (A-B), (C-B) -> icmp A, C for equalities or if there is no overflow. + if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem) return new ICmpInst(Pred, A, C); - // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. - if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) + + // icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow. + if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem) return new ICmpInst(Pred, D, B); // icmp (0-X) < cst --> x > -cst @@ -3677,6 +4016,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { } } + if (Value *V = foldUnsignedMultiplicationOverflowCheck(I)) + return replaceInstUsesWith(I, V); + if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) return replaceInstUsesWith(I, V); @@ -3953,125 +4295,140 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { return nullptr; } -/// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so -/// far. 
-Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { - const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); - Value *LHSCIOp = LHSCI->getOperand(0); - Type *SrcTy = LHSCIOp->getType(); - Type *DestTy = LHSCI->getType(); - - // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the - // integer type is the same size as the pointer type. - const auto& CompatibleSizes = [&](Type* SrcTy, Type* DestTy) -> bool { - if (isa<VectorType>(SrcTy)) { - SrcTy = cast<VectorType>(SrcTy)->getElementType(); - DestTy = cast<VectorType>(DestTy)->getElementType(); - } - return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); - }; - if (LHSCI->getOpcode() == Instruction::PtrToInt && - CompatibleSizes(SrcTy, DestTy)) { - Value *RHSOp = nullptr; - if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { - Value *RHSCIOp = RHSC->getOperand(0); - if (RHSCIOp->getType()->getPointerAddressSpace() == - LHSCIOp->getType()->getPointerAddressSpace()) { - RHSOp = RHSC->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (LHSCIOp->getType() != RHSOp->getType()) - RHSOp = Builder.CreateBitCast(RHSOp, LHSCIOp->getType()); - } - } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { - RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); - } - - if (RHSOp) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSOp); - } - - // The code below only handles extension cast instructions, so far. - // Enforce this. 
- if (LHSCI->getOpcode() != Instruction::ZExt && - LHSCI->getOpcode() != Instruction::SExt) +static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, + InstCombiner::BuilderTy &Builder) { + assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0"); + auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0)); + Value *X; + if (!match(CastOp0, m_ZExtOrSExt(m_Value(X)))) return nullptr; - bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; - bool isSignedCmp = ICmp.isSigned(); - - if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) { - // Not an extension from the same type? - Value *RHSCIOp = CI->getOperand(0); - if (RHSCIOp->getType() != LHSCIOp->getType()) - return nullptr; - + bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt; + bool IsSignedCmp = ICmp.isSigned(); + if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) { // If the signedness of the two casts doesn't agree (i.e. one is a sext // and the other is a zext), then we can't handle this. - if (CI->getOpcode() != LHSCI->getOpcode()) + // TODO: This is too strict. We can handle some predicates (equality?). + if (CastOp0->getOpcode() != CastOp1->getOpcode()) return nullptr; - // Deal with equality cases early. + // Not an extension from the same type? + Value *Y = CastOp1->getOperand(0); + Type *XTy = X->getType(), *YTy = Y->getType(); + if (XTy != YTy) { + // One of the casts must have one use because we are creating a new cast. + if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse()) + return nullptr; + // Extend the narrower operand to the type of the wider operand. 
+ if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits()) + X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy); + else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits()) + Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy); + else + return nullptr; + } + + // (zext X) == (zext Y) --> X == Y + // (sext X) == (sext Y) --> X == Y if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getPredicate(), X, Y); // A signed comparison of sign extended values simplifies into a // signed comparison. - if (isSignedCmp && isSignedExt) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); + if (IsSignedCmp && IsSignedExt) + return new ICmpInst(ICmp.getPredicate(), X, Y); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y); } - // If we aren't dealing with a constant on the RHS, exit early. + // Below here, we are only folding a compare with constant. auto *C = dyn_cast<Constant>(ICmp.getOperand(1)); if (!C) return nullptr; // Compute the constant that would happen if we truncated to SrcTy then // re-extended to DestTy. + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); - Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy); + Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy); // If the re-extended constant didn't change... if (Res2 == C) { - // Deal with equality cases early. if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res1); // A signed comparison of sign extended values simplifies into a // signed comparison. 
- if (isSignedExt && isSignedCmp) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); + if (IsSignedExt && IsSignedCmp) + return new ICmpInst(ICmp.getPredicate(), X, Res1); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1); } // The re-extended constant changed, partly changed (in the case of a vector), // or could not be determined to be equal (in the case of a constant // expression), so the constant cannot be represented in the shorter type. - // Consequently, we cannot emit a simple comparison. // All the cases that fold to true or false will have already been handled // by SimplifyICmpInst, so only deal with the tricky case. + if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C)) + return nullptr; + + // Is source op positive? + // icmp ult (sext X), C --> icmp sgt X, -1 + if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) + return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy)); + + // Is source op negative? + // icmp ugt (sext X), C --> icmp slt X, 0 + assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); + return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy)); +} - if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C)) +/// Handle icmp (cast x), (cast or constant). +Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { + auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0)); + if (!CastOp0) + return nullptr; + if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1))) return nullptr; - // Evaluate the comparison for LT (we invert for GT below). LE and GE cases - // should have been folded away previously and not enter in here. + Value *Op0Src = CastOp0->getOperand(0); + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); - // We're performing an unsigned comp with a sign extended value. 
- // This is true if the input is >= 0. [aka >s -1] - Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Value *Result = Builder.CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName()); + // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the + // integer type is the same size as the pointer type. + auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) { + if (isa<VectorType>(SrcTy)) { + SrcTy = cast<VectorType>(SrcTy)->getElementType(); + DestTy = cast<VectorType>(DestTy)->getElementType(); + } + return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); + }; + if (CastOp0->getOpcode() == Instruction::PtrToInt && + CompatibleSizes(SrcTy, DestTy)) { + Value *NewOp1 = nullptr; + if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { + Value *PtrSrc = PtrToIntOp1->getOperand(0); + if (PtrSrc->getType()->getPointerAddressSpace() == + Op0Src->getType()->getPointerAddressSpace()) { + NewOp1 = PtrToIntOp1->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (Op0Src->getType() != NewOp1->getType()) + NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType()); + } + } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { + NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } - // Finally, return the value computed. - if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) - return replaceInstUsesWith(ICmp, Result); + if (NewOp1) + return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); + } - assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); - return BinaryOperator::CreateNot(Result); + return foldICmpWithZextOrSext(ICmp, Builder); } static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { @@ -4791,41 +5148,35 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return nullptr; } -/// If we have an icmp le or icmp ge instruction with a constant operand, turn -/// it into the appropriate icmp lt or icmp gt instruction. 
This transform -/// allows them to be folded in visitICmpInst. -static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { - ICmpInst::Predicate Pred = I.getPredicate(); - if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGE && - Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_UGE) - return nullptr; +llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, + Constant *C) { + assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && + "Only for relational integer predicates."); - Value *Op0 = I.getOperand(0); - Value *Op1 = I.getOperand(1); - auto *Op1C = dyn_cast<Constant>(Op1); - if (!Op1C) - return nullptr; + Type *Type = C->getType(); + bool IsSigned = ICmpInst::isSigned(Pred); + + CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); + bool WillIncrement = + UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; - // Check if the constant operand can be safely incremented/decremented without - // overflowing/underflowing. For scalars, SimplifyICmpInst has already handled - // the edge cases for us, so we just assert on them. For vectors, we must - // handle the edge cases. - Type *Op1Type = Op1->getType(); - bool IsSigned = I.isSigned(); - bool IsLE = (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE); - auto *CI = dyn_cast<ConstantInt>(Op1C); - if (CI) { - // A <= MAX -> TRUE ; A >= MIN -> TRUE - assert(IsLE ? !CI->isMaxValue(IsSigned) : !CI->isMinValue(IsSigned)); - } else if (Op1Type->isVectorTy()) { - // TODO? If the edge cases for vectors were guaranteed to be handled as they - // are for scalar, we could remove the min/max checks. However, to do that, - // we would have to use insertelement/shufflevector to replace edge values. - unsigned NumElts = Op1Type->getVectorNumElements(); + // Check if the constant operand can be safely incremented/decremented + // without overflowing/underflowing. 
+ auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { + return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); + }; + + if (auto *CI = dyn_cast<ConstantInt>(C)) { + // Bail out if the constant can't be safely incremented/decremented. + if (!ConstantIsOk(CI)) + return llvm::None; + } else if (Type->isVectorTy()) { + unsigned NumElts = Type->getVectorNumElements(); for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = Op1C->getAggregateElement(i); + Constant *Elt = C->getAggregateElement(i); if (!Elt) - return nullptr; + return llvm::None; if (isa<UndefValue>(Elt)) continue; @@ -4833,20 +5184,43 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { // Bail out if we can't determine if this constant is min/max or if we // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); - if (!CI || (IsLE ? CI->isMaxValue(IsSigned) : CI->isMinValue(IsSigned))) - return nullptr; + if (!CI || !ConstantIsOk(CI)) + return llvm::None; } } else { // ConstantExpr? - return nullptr; + return llvm::None; } - // Increment or decrement the constant and set the new comparison predicate: - // ULE -> ULT ; UGE -> UGT ; SLE -> SLT ; SGE -> SGT - Constant *OneOrNegOne = ConstantInt::get(Op1Type, IsLE ? 1 : -1, true); - CmpInst::Predicate NewPred = IsLE ? ICmpInst::ICMP_ULT: ICmpInst::ICMP_UGT; - NewPred = IsSigned ? ICmpInst::getSignedPredicate(NewPred) : NewPred; - return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); + + // Increment or decrement the constant. + Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true); + Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); + + return std::make_pair(NewPred, NewC); +} + +/// If we have an icmp le or icmp ge instruction with a constant operand, turn +/// it into the appropriate icmp lt or icmp gt instruction. 
This transform +/// allows them to be folded in visitICmpInst. +static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { + ICmpInst::Predicate Pred = I.getPredicate(); + if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) || + isCanonicalPredicate(Pred)) + return nullptr; + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + auto *Op1C = dyn_cast<Constant>(Op1); + if (!Op1C) + return nullptr; + + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C); + if (!FlippedStrictness) + return nullptr; + + return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second); } /// Integer compare with boolean values can always be turned into bitwise ops. @@ -5002,6 +5376,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; + const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); unsigned Op0Cplxity = getComplexity(Op0); unsigned Op1Cplxity = getComplexity(Op1); @@ -5016,8 +5391,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, - SQ.getWithInstruction(&I))) + if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, Q)) return replaceInstUsesWith(I, V); // Comparing -val or val with non-zero is the same as just comparing val @@ -5050,6 +5424,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; + if (Instruction *Res = foldICmpBinOp(I, Q)) + return Res; + if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; @@ -5098,6 +5475,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpInstWithConstant(I)) return Res; + // Try to match comparison as a sign bit test. Intentionally do this after + // foldICmpInstWithConstant() to potentially let other folds to happen first. 
+ if (Instruction *New = foldSignBitTest(I)) + return New; + if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) return Res; @@ -5124,20 +5506,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpBitCast(I, Builder)) return Res; - if (isa<CastInst>(Op0)) { - // Handle the special case of: icmp (cast bool to X), <cst> - // This comes up when you have code like - // int X = A < B; - // if (X) ... - // For generality, we handle any zero-extension of any operand comparison - // with a constant or another cast from the same type. - if (isa<Constant>(Op1) || isa<CastInst>(Op1)) - if (Instruction *R = foldICmpWithCastAndCast(I)) - return R; - } - - if (Instruction *Res = foldICmpBinOp(I)) - return Res; + if (Instruction *R = foldICmpWithCastOp(I)) + return R; if (Instruction *Res = foldICmpWithMinMax(I)) return Res; diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 434b0d591215..1dbc06d92e7a 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -113,6 +113,48 @@ static inline bool isCanonicalPredicate(CmpInst::Predicate Pred) { } } +/// Given an exploded icmp instruction, return true if the comparison only +/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the +/// result of the comparison is true when the input value is signed. 
+inline bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, + bool &TrueIfSigned) { + switch (Pred) { + case ICmpInst::ICMP_SLT: // True if LHS s< 0 + TrueIfSigned = true; + return RHS.isNullValue(); + case ICmpInst::ICMP_SLE: // True if LHS s<= -1 + TrueIfSigned = true; + return RHS.isAllOnesValue(); + case ICmpInst::ICMP_SGT: // True if LHS s> -1 + TrueIfSigned = false; + return RHS.isAllOnesValue(); + case ICmpInst::ICMP_SGE: // True if LHS s>= 0 + TrueIfSigned = false; + return RHS.isNullValue(); + case ICmpInst::ICMP_UGT: + // True if LHS u> RHS and RHS == sign-bit-mask - 1 + TrueIfSigned = true; + return RHS.isMaxSignedValue(); + case ICmpInst::ICMP_UGE: + // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc) + TrueIfSigned = true; + return RHS.isMinSignedValue(); + case ICmpInst::ICMP_ULT: + // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc) + TrueIfSigned = false; + return RHS.isMinSignedValue(); + case ICmpInst::ICMP_ULE: + // True if LHS u<= RHS and RHS == sign-bit-mask - 1 + TrueIfSigned = false; + return RHS.isMaxSignedValue(); + default: + return false; + } +} + +llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C); + /// Return the source operand of a potentially bitcasted value while optionally /// checking if it has one use. If there is no bitcast or the one use check is /// not met, return the input value itself. @@ -139,32 +181,17 @@ static inline Constant *SubOne(Constant *C) { /// This happens in cases where the ~ can be eliminated. If WillInvertAllUses /// is true, work under the assumption that the caller intends to remove all /// uses of V and only keep uses of ~V. -static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { +/// +/// See also: canFreelyInvertAllUsersOf() +static inline bool isFreeToInvert(Value *V, bool WillInvertAllUses) { // ~(~(X)) -> X. 
if (match(V, m_Not(m_Value()))) return true; // Constants can be considered to be not'ed values. - if (isa<ConstantInt>(V)) + if (match(V, m_AnyIntegralConstant())) return true; - // A vector of constant integers can be inverted easily. - if (V->getType()->isVectorTy() && isa<Constant>(V)) { - unsigned NumElts = V->getType()->getVectorNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast<Constant>(V)->getAggregateElement(i); - if (!Elt) - return false; - - if (isa<UndefValue>(Elt)) - continue; - - if (!isa<ConstantInt>(Elt)) - return false; - } - return true; - } - // Compares can be inverted if all of their uses are being modified to use the // ~V. if (isa<CmpInst>(V)) @@ -185,6 +212,32 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return false; } +/// Given i1 V, can every user of V be freely adapted if V is changed to !V ? +/// +/// See also: isFreeToInvert() +static inline bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) { + // Look at every user of V. + for (User *U : V->users()) { + if (U == IgnoredUser) + continue; // Don't consider this user. + + auto *I = cast<Instruction>(U); + switch (I->getOpcode()) { + case Instruction::Select: + case Instruction::Br: + break; // Free to invert by swapping true/false values/destinations. + case Instruction::Xor: // Can invert 'xor' if it's a 'not', by ignoring it. + if (!match(I, m_Not(m_Value()))) + return false; // Not a 'not'. + break; + default: + return false; // Don't know, likely not freely invertible. + } + // So far all users were free to invert... + } + return true; // Can freely invert all users! +} + /// Some binary operators require special handling to avoid poison and undefined /// behavior. If a constant vector has undef elements, replace those undefs with /// identity constants if possible because those are always safe to execute. 
@@ -337,6 +390,13 @@ public: Instruction *visitOr(BinaryOperator &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); + Value *reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction = false); + Instruction *canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( + BinaryOperator &I); + Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract( + BinaryOperator &OldAShr); Instruction *visitAShr(BinaryOperator &I); Instruction *visitLShr(BinaryOperator &I); Instruction *commonShiftTransforms(BinaryOperator &I); @@ -541,6 +601,7 @@ private: Instruction *narrowMathIfNoOverflow(BinaryOperator &I); Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); + Instruction *matchSAddSubSat(SelectInst &MinMax1); /// Determine if a pair of casts can be replaced by a single cast. /// @@ -557,7 +618,7 @@ private: Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS); + Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &I); /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). /// NOTE: Unlike most of instcombine, this returns a Value which should @@ -725,7 +786,7 @@ public: Value *LHS, Value *RHS, Instruction *CxtI) const; /// Maximum size of array considered when transforming. - uint64_t MaxArraySizeForCombine; + uint64_t MaxArraySizeForCombine = 0; private: /// Performs a few simplifications for operators which are associative @@ -798,7 +859,8 @@ private: int DmaskIdx = -1); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, - APInt &UndefElts, unsigned Depth = 0); + APInt &UndefElts, unsigned Depth = 0, + bool AllowMultipleUsers = false); /// Canonicalize the position of binops relative to shufflevector. 
Instruction *foldVectorBinop(BinaryOperator &Inst); @@ -847,17 +909,21 @@ private: Constant *RHSC); Instruction *foldICmpAddOpConst(Value *X, const APInt &C, ICmpInst::Predicate Pred); - Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); + Instruction *foldICmpWithCastOp(ICmpInst &ICI); Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); Instruction *foldICmpWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); - Instruction *foldICmpBinOp(ICmpInst &Cmp); + Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); Instruction *foldICmpEquality(ICmpInst &Cmp); + Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); + Instruction *foldSignBitTest(ICmpInst &I); Instruction *foldICmpWithZero(ICmpInst &Cmp); + Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, @@ -874,6 +940,8 @@ private: const APInt &C); Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, const APInt &C); + Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv, + const APInt &C); Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, const APInt &C); Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 054fb7da09a2..3a0e05832fcb 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -175,7 +175,7 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI, uint64_t AllocaSize = DL.getTypeStoreSize(AI->getAllocatedType()); if (!AllocaSize) return false; - return 
isDereferenceableAndAlignedPointer(V, AI->getAlignment(), + return isDereferenceableAndAlignedPointer(V, Align(AI->getAlignment()), APInt(64, AllocaSize), DL); } @@ -197,7 +197,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { if (C->getValue().getActiveBits() <= 64) { Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); AllocaInst *New = IC.Builder.CreateAlloca(NewTy, nullptr, AI.getName()); - New->setAlignment(AI.getAlignment()); + New->setAlignment(MaybeAlign(AI.getAlignment())); // Scan to the end of the allocation instructions, to skip over a block of // allocas if possible...also skip interleaved debug info @@ -345,7 +345,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { if (AI.getAllocatedType()->isSized()) { // If the alignment is 0 (unspecified), assign it the preferred alignment. if (AI.getAlignment() == 0) - AI.setAlignment(DL.getPrefTypeAlignment(AI.getAllocatedType())); + AI.setAlignment( + MaybeAlign(DL.getPrefTypeAlignment(AI.getAllocatedType()))); // Move all alloca's of zero byte objects to the entry block and merge them // together. Note that we only do this for alloca's, because malloc should @@ -377,12 +378,12 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { // assign it the preferred alignment. if (EntryAI->getAlignment() == 0) EntryAI->setAlignment( - DL.getPrefTypeAlignment(EntryAI->getAllocatedType())); + MaybeAlign(DL.getPrefTypeAlignment(EntryAI->getAllocatedType()))); // Replace this zero-sized alloca with the one at the start of the entry // block after ensuring that the address will be aligned enough for both // types. 
- unsigned MaxAlign = std::max(EntryAI->getAlignment(), - AI.getAlignment()); + const MaybeAlign MaxAlign( + std::max(EntryAI->getAlignment(), AI.getAlignment())); EntryAI->setAlignment(MaxAlign); if (AI.getType() != EntryAI->getType()) return new BitCastInst(EntryAI, AI.getType()); @@ -455,9 +456,6 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT Value *Ptr = LI.getPointerOperand(); unsigned AS = LI.getPointerAddressSpace(); - SmallVector<std::pair<unsigned, MDNode *>, 8> MD; - LI.getAllMetadata(MD); - Value *NewPtr = nullptr; if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && NewPtr->getType()->getPointerElementType() == NewTy && @@ -467,48 +465,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT LoadInst *NewLoad = IC.Builder.CreateAlignedLoad( NewTy, NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); - MDBuilder MDB(NewLoad->getContext()); - for (const auto &MDPair : MD) { - unsigned ID = MDPair.first; - MDNode *N = MDPair.second; - // Note, essentially every kind of metadata should be preserved here! This - // routine is supposed to clone a load instruction changing *only its type*. - // The only metadata it makes sense to drop is metadata which is invalidated - // when the pointer type changes. This should essentially never be the case - // in LLVM, but we explicitly switch over only known metadata to be - // conservatively correct. If you are adding metadata to LLVM which pertains - // to loads, you almost certainly want to add it here. 
- switch (ID) { - case LLVMContext::MD_dbg: - case LLVMContext::MD_tbaa: - case LLVMContext::MD_prof: - case LLVMContext::MD_fpmath: - case LLVMContext::MD_tbaa_struct: - case LLVMContext::MD_invariant_load: - case LLVMContext::MD_alias_scope: - case LLVMContext::MD_noalias: - case LLVMContext::MD_nontemporal: - case LLVMContext::MD_mem_parallel_loop_access: - case LLVMContext::MD_access_group: - // All of these directly apply. - NewLoad->setMetadata(ID, N); - break; - - case LLVMContext::MD_nonnull: - copyNonnullMetadata(LI, N, *NewLoad); - break; - case LLVMContext::MD_align: - case LLVMContext::MD_dereferenceable: - case LLVMContext::MD_dereferenceable_or_null: - // These only directly apply if the new type is also a pointer. - if (NewTy->isPointerTy()) - NewLoad->setMetadata(ID, N); - break; - case LLVMContext::MD_range: - copyRangeMetadata(IC.getDataLayout(), LI, N, *NewLoad); - break; - } - } + copyMetadataForLoad(*NewLoad, LI); return NewLoad; } @@ -1004,9 +961,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { LoadAlign != 0 ? LoadAlign : DL.getABITypeAlignment(LI.getType()); if (KnownAlign > EffectiveLoadAlign) - LI.setAlignment(KnownAlign); + LI.setAlignment(MaybeAlign(KnownAlign)); else if (LoadAlign == 0) - LI.setAlignment(EffectiveLoadAlign); + LI.setAlignment(MaybeAlign(EffectiveLoadAlign)); // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) { @@ -1063,11 +1020,11 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // if (SelectInst *SI = dyn_cast<SelectInst>(Op)) { // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). 
- unsigned Align = LI.getAlignment(); - if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), Align, - DL, SI) && - isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), Align, - DL, SI)) { + const MaybeAlign Alignment(LI.getAlignment()); + if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), + Alignment, DL, SI) && + isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), + Alignment, DL, SI)) { LoadInst *V1 = Builder.CreateLoad(LI.getType(), SI->getOperand(1), SI->getOperand(1)->getName() + ".val"); @@ -1075,9 +1032,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { Builder.CreateLoad(LI.getType(), SI->getOperand(2), SI->getOperand(2)->getName() + ".val"); assert(LI.isUnordered() && "implied by above"); - V1->setAlignment(Align); + V1->setAlignment(Alignment); V1->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); - V2->setAlignment(Align); + V2->setAlignment(Alignment); V2->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); return SelectInst::Create(SI->getCondition(), V1, V2); } @@ -1399,15 +1356,15 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { return eraseInstFromFunction(SI); // Attempt to improve the alignment. - unsigned KnownAlign = getOrEnforceKnownAlignment( - Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT); - unsigned StoreAlign = SI.getAlignment(); - unsigned EffectiveStoreAlign = - StoreAlign != 0 ? StoreAlign : DL.getABITypeAlignment(Val->getType()); + const Align KnownAlign = Align(getOrEnforceKnownAlignment( + Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT)); + const MaybeAlign StoreAlign = MaybeAlign(SI.getAlignment()); + const Align EffectiveStoreAlign = + StoreAlign ? *StoreAlign : Align(DL.getABITypeAlignment(Val->getType())); if (KnownAlign > EffectiveStoreAlign) SI.setAlignment(KnownAlign); - else if (StoreAlign == 0) + else if (!StoreAlign) SI.setAlignment(EffectiveStoreAlign); // Try to canonicalize the stored type. 
@@ -1622,8 +1579,8 @@ bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) { // Advance to a place where it is safe to insert the new store and insert it. BBI = DestBB->getFirstInsertionPt(); - StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), - SI.isVolatile(), SI.getAlignment(), + StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), + MaybeAlign(SI.getAlignment()), SI.getOrdering(), SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); NewSI->setDebugLoc(MergedLoc); diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index cc753ce05313..0b9128a9f5a1 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -124,6 +124,50 @@ static Constant *getLogBase2(Type *Ty, Constant *C) { return ConstantVector::get(Elts); } +// TODO: This is a specific form of a much more general pattern. +// We could detect a select with any binop identity constant, or we +// could use SimplifyBinOp to see if either arm of the select reduces. +// But that needs to be done carefully and/or while removing potential +// reverse canonicalizations as in InstCombiner::foldSelectIntoOp(). 
+static Value *foldMulSelectToNegate(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *Cond, *OtherOp; + + // mul (select Cond, 1, -1), OtherOp --> select Cond, OtherOp, -OtherOp + // mul OtherOp, (select Cond, 1, -1) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())), + m_Value(OtherOp)))) + return Builder.CreateSelect(Cond, OtherOp, Builder.CreateNeg(OtherOp)); + + // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp + // mul OtherOp, (select Cond, -1, 1) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())), + m_Value(OtherOp)))) + return Builder.CreateSelect(Cond, Builder.CreateNeg(OtherOp), OtherOp); + + // fmul (select Cond, 1.0, -1.0), OtherOp --> select Cond, OtherOp, -OtherOp + // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(1.0), + m_SpecificFP(-1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, OtherOp, Builder.CreateFNeg(OtherOp)); + } + + // fmul (select Cond, -1.0, 1.0), OtherOp --> select Cond, -OtherOp, OtherOp + // fmul OtherOp, (select Cond, -1.0, 1.0) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(-1.0), + m_SpecificFP(1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, Builder.CreateFNeg(OtherOp), OtherOp); + } + + return nullptr; +} + Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) @@ -213,6 +257,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Instruction *FoldedMul = 
foldBinOpIntoSelectOrPhi(I)) return FoldedMul; + if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) + return replaceInstUsesWith(I, FoldedMul); + // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { // Canonicalize (X+C1)*CI -> X*CI+C1*CI. @@ -358,6 +405,9 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) return FoldedMul; + if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) + return replaceInstUsesWith(I, FoldedMul); + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_SpecificFP(-1.0))) @@ -373,16 +423,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); - // Sink negation: -X * Y --> -(X * Y) - // But don't transform constant expressions because there's an inverse fold. - if (match(Op0, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op0)) - return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I); - - // Sink negation: Y * -X --> -(X * Y) - // But don't transform constant expressions because there's an inverse fold. 
- if (match(Op1, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op1)) - return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I); - // fabs(X) * fabs(X) -> X * X if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) return BinaryOperator::CreateFMulFMF(X, X, &I); @@ -1211,8 +1251,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) && match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X))); - if ((IsTan || IsCot) && hasUnaryFloatFn(&TLI, I.getType(), LibFunc_tan, - LibFunc_tanf, LibFunc_tanl)) { + if ((IsTan || IsCot) && + hasFloatFn(&TLI, I.getType(), LibFunc_tan, LibFunc_tanf, LibFunc_tanl)) { IRBuilder<> B(&I); IRBuilder<>::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(I.getFastMathFlags()); @@ -1244,6 +1284,17 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { return &I; } + // X / fabs(X) -> copysign(1.0, X) + // fabs(X) / X -> copysign(1.0, X) + if (I.hasNoNaNs() && I.hasNoInfs() && + (match(&I, + m_FDiv(m_Value(X), m_Intrinsic<Intrinsic::fabs>(m_Deferred(X)))) || + match(&I, m_FDiv(m_Intrinsic<Intrinsic::fabs>(m_Value(X)), + m_Deferred(X))))) { + Value *V = Builder.CreateBinaryIntrinsic( + Intrinsic::copysign, ConstantFP::get(I.getType(), 1.0), X, &I); + return replaceInstUsesWith(I, V); + } return nullptr; } @@ -1309,6 +1360,8 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { + // This may increase instruction count, we don't enforce that Y is a + // constant. 
Constant *N1 = Constant::getAllOnesValue(Ty); Value *Add = Builder.CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 5820ab726637..e0376b7582f3 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -542,7 +542,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { // visitLoadInst will propagate an alignment onto the load when TD is around, // and if TD isn't around, we can't handle the mixed case. bool isVolatile = FirstLI->isVolatile(); - unsigned LoadAlignment = FirstLI->getAlignment(); + MaybeAlign LoadAlignment(FirstLI->getAlignment()); unsigned LoadAddrSpace = FirstLI->getPointerAddressSpace(); // We can't sink the load if the loaded value could be modified between the @@ -574,10 +574,10 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { // If some of the loads have an alignment specified but not all of them, // we can't do the transformation. - if ((LoadAlignment != 0) != (LI->getAlignment() != 0)) + if ((LoadAlignment.hasValue()) != (LI->getAlignment() != 0)) return nullptr; - LoadAlignment = std::min(LoadAlignment, LI->getAlignment()); + LoadAlignment = std::min(LoadAlignment, MaybeAlign(LI->getAlignment())); // If the PHI is of volatile loads and the load block has multiple // successors, sinking it would remove a load of the volatile value from diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index aefaf5af1750..9fc871e49b30 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -785,6 +785,41 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return nullptr; } +/// Fold the following code sequence: +/// \code +/// int a = ctlz(x & -x); +// x ? 
31 - a : a; +/// \endcode +/// +/// into: +/// cttz(x) +static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal, + Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits(); + if (!ICI->isEquality() || !match(ICI->getOperand(1), m_Zero())) + return nullptr; + + if (ICI->getPredicate() == ICmpInst::ICMP_NE) + std::swap(TrueVal, FalseVal); + + if (!match(FalseVal, + m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1)))) + return nullptr; + + if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>())) + return nullptr; + + Value *X = ICI->getOperand(0); + auto *II = cast<IntrinsicInst>(TrueVal); + if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X))))) + return nullptr; + + Function *F = Intrinsic::getDeclaration(II->getModule(), Intrinsic::cttz, + II->getType()); + return CallInst::Create(F, {X, II->getArgOperand(1)}); +} + /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single /// call to cttz/ctlz with flag 'is_zero_undef' cleared. /// @@ -973,8 +1008,7 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, // If we are swapping the select operands, swap the metadata too. assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS && "Unexpected results from matchSelectPattern"); - Sel.setTrueValue(LHS); - Sel.setFalseValue(RHS); + Sel.swapValues(); Sel.swapProfMetadata(); return &Sel; } @@ -1056,17 +1090,293 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, } // We are swapping the select operands, so swap the metadata too. - Sel.setTrueValue(FVal); - Sel.setFalseValue(TVal); + Sel.swapValues(); Sel.swapProfMetadata(); return &Sel; } +static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp, + const SimplifyQuery &Q) { + // If this is a binary operator, try to simplify it with the replaced op + // because we know Op and ReplaceOp are equivalent. 
+ // For example: V = X + 1, Op = X, ReplaceOp = 42 + // Simplifies as: add(42, 1) --> 43 + if (auto *BO = dyn_cast<BinaryOperator>(V)) { + if (BO->getOperand(0) == Op) + return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q); + if (BO->getOperand(1) == Op) + return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q); + } + + return nullptr; +} + +/// If we have a select with an equality comparison, then we know the value in +/// one of the arms of the select. See if substituting this value into an arm +/// and simplifying the result yields the same value as the other arm. +/// +/// To make this transform safe, we must drop poison-generating flags +/// (nsw, etc) if we simplified to a binop because the select may be guarding +/// that poison from propagating. If the existing binop already had no +/// poison-generating flags, then this transform can be done by instsimplify. +/// +/// Consider: +/// %cmp = icmp eq i32 %x, 2147483647 +/// %add = add nsw i32 %x, 1 +/// %sel = select i1 %cmp, i32 -2147483648, i32 %add +/// +/// We can't replace %sel with %add unless we strip away the flags. +/// TODO: Wrapping flags could be preserved in some cases with better analysis. +static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, + const SimplifyQuery &Q) { + if (!Cmp.isEquality()) + return nullptr; + + // Canonicalize the pattern to ICMP_EQ by swapping the select operands. + Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + std::swap(TrueVal, FalseVal); + + // Try each equivalence substitution possibility. + // We have an 'EQ' comparison, so the select's false value will propagate. + // Example: + // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 + // (X == 42) ? (X + 1) : 43 --> (X == 42) ? 
(42 + 1) : 43 --> 43 + Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); + if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal || + simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal || + simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal || + simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) { + if (auto *FalseInst = dyn_cast<Instruction>(FalseVal)) + FalseInst->dropPoisonGeneratingFlags(); + return FalseVal; + } + return nullptr; +} + +// See if this is a pattern like: +// %old_cmp1 = icmp slt i32 %x, C2 +// %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high +// %old_x_offseted = add i32 %x, C1 +// %old_cmp0 = icmp ult i32 %old_x_offseted, C0 +// %r = select i1 %old_cmp0, i32 %x, i32 %old_replacement +// This can be rewritten as more canonical pattern: +// %new_cmp1 = icmp slt i32 %x, -C1 +// %new_cmp2 = icmp sge i32 %x, C0-C1 +// %new_clamped_low = select i1 %new_cmp1, i32 %target_low, i32 %x +// %r = select i1 %new_cmp2, i32 %target_high, i32 %new_clamped_low +// Iff -C1 s<= C2 s<= C0-C1 +// Also ULT predicate can also be UGT iff C0 != -1 (+invert result) +// SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.) +static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, + InstCombiner::BuilderTy &Builder) { + Value *X = Sel0.getTrueValue(); + Value *Sel1 = Sel0.getFalseValue(); + + // First match the condition of the outermost select. + // Said condition must be one-use. + if (!Cmp0.hasOneUse()) + return nullptr; + Value *Cmp00 = Cmp0.getOperand(0); + Constant *C0; + if (!match(Cmp0.getOperand(1), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))) + return nullptr; + // Canonicalize Cmp0 into the form we expect. + // FIXME: we shouldn't care about lanes that are 'undef' in the end? + switch (Cmp0.getPredicate()) { + case ICmpInst::Predicate::ICMP_ULT: + break; // Great! 
+ case ICmpInst::Predicate::ICMP_ULE: + // We'd have to increment C0 by one, and for that it must not have all-ones + // element, but then it would have been canonicalized to 'ult' before + // we get here. So we can't do anything useful with 'ule'. + return nullptr; + case ICmpInst::Predicate::ICMP_UGT: + // We want to canonicalize it to 'ult', so we'll need to increment C0, + // which again means it must not have any all-ones elements. + if (!match(C0, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE, + APInt::getAllOnesValue( + C0->getType()->getScalarSizeInBits())))) + return nullptr; // Can't do, have all-ones element[s]. + C0 = AddOne(C0); + std::swap(X, Sel1); + break; + case ICmpInst::Predicate::ICMP_UGE: + // The only way we'd get this predicate is if this `icmp` has extra uses, + // but then we won't be able to do this fold. + return nullptr; + default: + return nullptr; // Unknown predicate. + } + + // Now that we've canonicalized the ICmp, we know the X we expect; + // the select, on the other hand, should be one-use. + if (!Sel1->hasOneUse()) + return nullptr; + + // We now can finish matching the condition of the outermost select: + // it should either be the X itself, or an addition of some constant to X. + Constant *C1; + if (Cmp00 == X) + C1 = ConstantInt::getNullValue(Sel0.getType()); + else if (!match(Cmp00, + m_Add(m_Specific(X), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1))))) + return nullptr; + + Value *Cmp1; + ICmpInst::Predicate Pred1; + Constant *C2; + Value *ReplacementLow, *ReplacementHigh; + if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow), + m_Value(ReplacementHigh))) || + !match(Cmp1, + m_ICmp(Pred1, m_Specific(X), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C2))))) + return nullptr; + + if (!Cmp1->hasOneUse() && (Cmp00 == X || !Cmp00->hasOneUse())) + return nullptr; // Not enough one-use instructions for the fold. 
+ // FIXME: this restriction could be relaxed if Cmp1 can be reused as one of + // two comparisons we'll need to build. + + // Canonicalize Cmp1 into the form we expect. + // FIXME: we shouldn't care about lanes that are 'undef' in the end? + switch (Pred1) { + case ICmpInst::Predicate::ICMP_SLT: + break; + case ICmpInst::Predicate::ICMP_SLE: + // We'd have to increment C2 by one, and for that it must not have signed + // max element, but then it would have been canonicalized to 'slt' before + // we get here. So we can't do anything useful with 'sle'. + return nullptr; + case ICmpInst::Predicate::ICMP_SGT: + // We want to canonicalize it to 'slt', so we'll need to increment C2, + // which again means it must not have any signed max elements. + if (!match(C2, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE, + APInt::getSignedMaxValue( + C2->getType()->getScalarSizeInBits())))) + return nullptr; // Can't do, have signed max element[s]. + C2 = AddOne(C2); + LLVM_FALLTHROUGH; + case ICmpInst::Predicate::ICMP_SGE: + // Also non-canonical, but here we don't need to change C2, + // so we don't have any restrictions on C2, so we can just handle it. + std::swap(ReplacementLow, ReplacementHigh); + break; + default: + return nullptr; // Unknown predicate. + } + + // The thresholds of this clamp-like pattern. + auto *ThresholdLowIncl = ConstantExpr::getNeg(C1); + auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1); + + // The fold has a precondition 1: C2 s>= ThresholdLow + auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2, + ThresholdLowIncl); + if (!match(Precond1, m_One())) + return nullptr; + // The fold has a precondition 2: C2 s<= ThresholdHigh + auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2, + ThresholdHighExcl); + if (!match(Precond2, m_One())) + return nullptr; + + // All good, finally emit the new pattern. 
+ Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl); + Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl); + Value *MaybeReplacedLow = + Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X); + Instruction *MaybeReplacedHigh = + SelectInst::Create(ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow); + + return MaybeReplacedHigh; +} + +// If we have +// %cmp = icmp [canonical predicate] i32 %x, C0 +// %r = select i1 %cmp, i32 %y, i32 C1 +// Where C0 != C1 and %x may be different from %y, see if the constant that we +// will have if we flip the strictness of the predicate (i.e. without changing +// the result) is identical to the C1 in select. If it matches we can change +// original comparison to one with swapped predicate, reuse the constant, +// and swap the hands of select. +static Instruction * +tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred; + Value *X; + Constant *C0; + if (!match(&Cmp, m_OneUse(m_ICmp( + Pred, m_Value(X), + m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))))) + return nullptr; + + // If comparison predicate is non-relational, we won't be able to do anything. + if (ICmpInst::isEquality(Pred)) + return nullptr; + + // If comparison predicate is non-canonical, then we certainly won't be able + // to make it canonical; canonicalizeCmpWithConstant() already tried. + if (!isCanonicalPredicate(Pred)) + return nullptr; + + // If the [input] type of comparison and select type are different, lets abort + // for now. We could try to compare constants with trunc/[zs]ext though. + if (C0->getType() != Sel.getType()) + return nullptr; + + // FIXME: are there any magic icmp predicate+constant pairs we must not touch? + + Value *SelVal0, *SelVal1; // We do not care which one is from where. 
+ match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1))); + // At least one of these values we are selecting between must be a constant + // else we'll never succeed. + if (!match(SelVal0, m_AnyIntegralConstant()) && + !match(SelVal1, m_AnyIntegralConstant())) + return nullptr; + + // Does this constant C match any of the `select` values? + auto MatchesSelectValue = [SelVal0, SelVal1](Constant *C) { + return C->isElementWiseEqual(SelVal0) || C->isElementWiseEqual(SelVal1); + }; + + // If C0 *already* matches true/false value of select, we are done. + if (MatchesSelectValue(C0)) + return nullptr; + + // Check the constant we'd have with flipped-strictness predicate. + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0); + if (!FlippedStrictness) + return nullptr; + + // If said constant doesn't match either, then there is no hope, + if (!MatchesSelectValue(FlippedStrictness->second)) + return nullptr; + + // It matched! Lets insert the new comparison just before select. + InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(&Sel); + + Pred = ICmpInst::getSwappedPredicate(Pred); // Yes, swapped. + Value *NewCmp = Builder.CreateICmp(Pred, X, FlippedStrictness->second, + Cmp.getName() + ".inv"); + Sel.setCondition(NewCmp); + Sel.swapValues(); + Sel.swapProfMetadata(); + + return &Sel; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. 
Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { - Value *TrueVal = SI.getTrueValue(); - Value *FalseVal = SI.getFalseValue(); + if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ)) + return replaceInstUsesWith(SI, V); if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; @@ -1074,12 +1384,21 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, Builder)) return NewAbs; + if (Instruction *NewAbs = canonicalizeClampLike(SI, *ICI, Builder)) + return NewAbs; + + if (Instruction *NewSel = + tryToReuseConstantFromSelectInComparison(SI, *ICI, Builder)) + return NewSel; + bool Changed = adjustMinMax(SI, *ICI); if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); // NOTE: if we wanted to, this is where to detect integer MIN/MAX + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); @@ -1149,6 +1468,9 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) return V; + if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder)) + return V; + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); @@ -1253,6 +1575,16 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, } } + // max(max(A, B), min(A, B)) --> max(A, B) + // min(min(A, B), max(A, B)) --> min(A, B) + // TODO: This could be done in instsimplify. 
+ if (SPF1 == SPF2 && + ((SPF1 == SPF_UMIN && match(C, m_c_UMax(m_Specific(A), m_Specific(B)))) || + (SPF1 == SPF_SMIN && match(C, m_c_SMax(m_Specific(A), m_Specific(B)))) || + (SPF1 == SPF_UMAX && match(C, m_c_UMin(m_Specific(A), m_Specific(B)))) || + (SPF1 == SPF_SMAX && match(C, m_c_SMin(m_Specific(A), m_Specific(B)))))) + return replaceInstUsesWith(Outer, Inner); + // ABS(ABS(X)) -> ABS(X) // NABS(NABS(X)) -> NABS(X) // TODO: This could be done in instsimplify. @@ -1280,7 +1612,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, return true; } - if (IsFreeToInvert(V, !V->hasNUsesOrMore(3))) { + if (isFreeToInvert(V, !V->hasNUsesOrMore(3))) { NotV = nullptr; return true; } @@ -1492,6 +1824,30 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { ConstantVector::get(Mask)); } +/// If we have a select of vectors with a scalar condition, try to convert that +/// to a vector select by splatting the condition. A splat may get folded with +/// other operations in IR and having all operands of a select be vector types +/// is likely better for vector codegen. +static Instruction *canonicalizeScalarSelectOfVecs( + SelectInst &Sel, InstCombiner::BuilderTy &Builder) { + Type *Ty = Sel.getType(); + if (!Ty->isVectorTy()) + return nullptr; + + // We can replace a single-use extract with constant index. + Value *Cond = Sel.getCondition(); + if (!match(Cond, m_OneUse(m_ExtractElement(m_Value(), m_ConstantInt())))) + return nullptr; + + // select (extelt V, Index), T, F --> select (splat V, Index), T, F + // Splatting the extracted condition reduces code (we could directly create a + // splat shuffle of the source vector to eliminate the intermediate step). 
+ unsigned NumElts = Ty->getVectorNumElements(); + Value *SplatCond = Builder.CreateVectorSplat(NumElts, Cond); + Sel.setCondition(SplatCond); + return &Sel; +} + /// Reuse bitcasted operands between a compare and select: /// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> /// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D)) @@ -1648,6 +2004,71 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X, return nullptr; } +/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. +Instruction *InstCombiner::matchSAddSubSat(SelectInst &MinMax1) { + Type *Ty = MinMax1.getType(); + + // We are looking for a tree of: + // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B)))) + // Where the min and max could be reversed + Instruction *MinMax2; + BinaryOperator *AddSub; + const APInt *MinValue, *MaxValue; + if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) { + if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue)))) + return nullptr; + } else if (match(&MinMax1, + m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) { + if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue)))) + return nullptr; + } else + return nullptr; + + // Check that the constants clamp a saturate, and that the new type would be + // sensible to convert to. + if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1) + return nullptr; + // In what bitwidth can this be treated as saturating arithmetics? + unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1; + // FIXME: This isn't quite right for vectors, but using the scalar type is a + // good first approximation for what should be done there. + if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth)) + return nullptr; + + // Also make sure that the number of uses is as expected. The "3"s are for the + // the two items of min/max (the compare and the select). 
+ if (MinMax2->hasNUsesOrMore(3) || AddSub->hasNUsesOrMore(3)) + return nullptr; + + // Create the new type (which can be a vector type) + Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth); + // Match the two extends from the add/sub + Value *A, *B; + if(!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B))))) + return nullptr; + // And check the incoming values are of a type smaller than or equal to the + // size of the saturation. Otherwise the higher bits can cause different + // results. + if (A->getType()->getScalarSizeInBits() > NewBitWidth || + B->getType()->getScalarSizeInBits() > NewBitWidth) + return nullptr; + + Intrinsic::ID IntrinsicID; + if (AddSub->getOpcode() == Instruction::Add) + IntrinsicID = Intrinsic::sadd_sat; + else if (AddSub->getOpcode() == Instruction::Sub) + IntrinsicID = Intrinsic::ssub_sat; + else + return nullptr; + + // Finally create and return the sat intrinsic, truncated to the new type + Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy); + Value *AT = Builder.CreateSExt(A, NewTy); + Value *BT = Builder.CreateSExt(B, NewTy); + Value *Sat = Builder.CreateCall(F, {AT, BT}); + return CastInst::Create(Instruction::SExt, Sat, Ty); +} + /// Reduce a sequence of min/max with a common operand. static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, Value *RHS, @@ -1788,6 +2209,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *I = canonicalizeSelectToShuffle(SI)) return I; + if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, Builder)) + return I; + // Canonicalize a one-use integer compare with a non-canonical predicate by // inverting the predicate and swapping the select operands. This matches a // compare canonicalization for conditional branches. 
@@ -2013,16 +2437,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (LHS->getType()->isFPOrFPVectorTy() && ((CmpLHS != LHS && CmpLHS != RHS) || (CmpRHS != LHS && CmpRHS != RHS)))) { - CmpInst::Predicate Pred = getMinMaxPred(SPF, SPR.Ordered); + CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; - if (CmpInst::isIntPredicate(Pred)) { - Cmp = Builder.CreateICmp(Pred, LHS, RHS); + if (CmpInst::isIntPredicate(MinMaxPred)) { + Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS); } else { IRBuilder<>::FastMathFlagGuard FMFG(Builder); - auto FMF = cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); + auto FMF = + cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); Builder.setFastMathFlags(FMF); - Cmp = Builder.CreateFCmp(Pred, LHS, RHS); + Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS); } Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); @@ -2040,9 +2465,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { Value *A; if (match(X, m_Not(m_Value(A))) && !X->hasNUsesOrMore(3) && - !IsFreeToInvert(A, A->hasOneUse()) && + !isFreeToInvert(A, A->hasOneUse()) && // Passing false to only consider m_Not and constants. 
- IsFreeToInvert(Y, false)) { + isFreeToInvert(Y, false)) { Value *B = Builder.CreateNot(Y); Value *NewMinMax = createMinMax(Builder, getInverseMinMaxFlavor(SPF), A, B); @@ -2070,6 +2495,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) return I; + if (Instruction *I = matchSAddSubSat(SI)) + return I; } } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index c821292400cd..64294838644f 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -25,50 +25,275 @@ using namespace PatternMatch; // we should rewrite it as // x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) // This is valid for any shift, but they must be identical. -static Instruction * -reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, - const SimplifyQuery &SQ) { - // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1 - Value *X, *ShAmt1, *ShAmt0; +// +// AnalyzeForSignBitExtraction indicates that we will only analyze whether this +// pattern has any 2 right-shifts that sum to 1 less than original bit width. +Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction) { + // Look for a shift of some instruction, ignore zext of shift amount if any. + Instruction *Sh0Op0; + Value *ShAmt0; + if (!match(Sh0, + m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0))))) + return nullptr; + + // If there is a truncation between the two shifts, we must make note of it + // and look through it. The truncation imposes additional constraints on the + // transform. 
Instruction *Sh1; - if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)), - m_Instruction(Sh1)), - m_Value(ShAmt0)))) + Value *Trunc = nullptr; + match(Sh0Op0, + m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)), + m_Instruction(Sh1))); + + // Inner shift: (x shiftopcode ShAmt1) + // Like with other shift, ignore zext of shift amount if any. + Value *X, *ShAmt1; + if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1))))) + return nullptr; + + // We have two shift amounts from two different shifts. The types of those + // shift amounts may not match. If that's the case let's bailout now.. + if (ShAmt0->getType() != ShAmt1->getType()) + return nullptr; + + // We are only looking for signbit extraction if we have two right shifts. + bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && + match(Sh1, m_Shr(m_Value(), m_Value())); + // ... and if it's not two right-shifts, we know the answer already. + if (AnalyzeForSignBitExtraction && !HadTwoRightShifts) return nullptr; - // The shift opcodes must be identical. + // The shift opcodes must be identical, unless we are just checking whether + // this pattern can be interpreted as a sign-bit-extraction. Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode(); - if (ShiftOpcode != Sh1->getOpcode()) + bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode(); + if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction) return nullptr; + + // If we saw truncation, we'll need to produce extra instruction, + // and for that one of the operands of the shift must be one-use, + // unless of course we don't actually plan to produce any instructions here. + if (Trunc && !AnalyzeForSignBitExtraction && + !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + return nullptr; + // Can we fold (ShAmt0+ShAmt1) ? 
- Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1, - SQ.getWithInstruction(Sh0)); + auto *NewShAmt = dyn_cast_or_null<Constant>( + SimplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false, + SQ.getWithInstruction(Sh0))); if (!NewShAmt) return nullptr; // Did not simplify. - // Is the new shift amount smaller than the bit width? - // FIXME: could also rely on ConstantRange. - unsigned BitWidth = X->getType()->getScalarSizeInBits(); + unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits(); + unsigned XBitWidth = X->getType()->getScalarSizeInBits(); + // Is the new shift amount smaller than the bit width of inner/new shift? if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, - APInt(BitWidth, BitWidth)))) - return nullptr; + APInt(NewShAmtBitWidth, XBitWidth)))) + return nullptr; // FIXME: could perform constant-folding. + + // If there was a truncation, and we have a right-shift, we can only fold if + // we are left with the original sign bit. Likewise, if we were just checking + // that this is a sighbit extraction, this is the place to check it. + // FIXME: zero shift amount is also legal here, but we can't *easily* check + // more than one predicate so it's not really worth it. + if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) { + // If it's not a sign bit extraction, then we're done. + if (!match(NewShAmt, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(NewShAmtBitWidth, XBitWidth - 1)))) + return nullptr; + // If it is, and that was the question, return the base value. + if (AnalyzeForSignBitExtraction) + return X; + } + + assert(IdenticalShOpcodes && "Should not get here with different shifts."); + // All good, we can do this fold. + NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType()); + BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt); - // If both of the original shifts had the same flag set, preserve the flag. 
- if (ShiftOpcode == Instruction::BinaryOps::Shl) { - NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() && - Sh1->hasNoUnsignedWrap()); - NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() && - Sh1->hasNoSignedWrap()); - } else { - NewShift->setIsExact(Sh0->isExact() && Sh1->isExact()); + + // The flags can only be propagated if there wasn't a trunc. + if (!Trunc) { + // If the pattern did not involve trunc, and both of the original shifts + // had the same flag set, preserve the flag. + if (ShiftOpcode == Instruction::BinaryOps::Shl) { + NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() && + Sh1->hasNoUnsignedWrap()); + NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() && + Sh1->hasNoSignedWrap()); + } else { + NewShift->setIsExact(Sh0->isExact() && Sh1->isExact()); + } + } + + Instruction *Ret = NewShift; + if (Trunc) { + Builder.Insert(NewShift); + Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType()); + } + + return Ret; +} + +// Try to replace `undef` constants in C with Replacement. +static Constant *replaceUndefsWith(Constant *C, Constant *Replacement) { + if (C && match(C, m_Undef())) + return Replacement; + + if (auto *CV = dyn_cast<ConstantVector>(C)) { + llvm::SmallVector<Constant *, 32> NewOps(CV->getNumOperands()); + for (unsigned i = 0, NumElts = NewOps.size(); i != NumElts; ++i) { + Constant *EltC = CV->getOperand(i); + NewOps[i] = EltC && match(EltC, m_Undef()) ? Replacement : EltC; + } + return ConstantVector::get(NewOps); + } + + // Don't know how to deal with this constant. + return C; +} + +// If we have some pattern that leaves only some low bits set, and then performs +// left-shift of those bits, if none of the bits that are left after the final +// shift are modified by the mask, we can omit the mask. 
+// +// There are many variants to this pattern: +// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt +// b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt +// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt +// d) (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt +// e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt +// f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt +// All these patterns can be simplified to just: +// x << ShiftShAmt +// iff: +// a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x) +// c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt) +static Instruction * +dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, + const SimplifyQuery &Q, + InstCombiner::BuilderTy &Builder) { + assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl && + "The input must be 'shl'!"); + + Value *Masked, *ShiftShAmt; + match(OuterShift, m_Shift(m_Value(Masked), m_Value(ShiftShAmt))); + + Type *NarrowestTy = OuterShift->getType(); + Type *WidestTy = Masked->getType(); + // The mask must be computed in a type twice as wide to ensure + // that no bits are lost if the sum-of-shifts is wider than the base type. + Type *ExtendedTy = WidestTy->getExtendedType(); + + Value *MaskShAmt; + + // ((1 << MaskShAmt) - 1) + auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); + // (~(-1 << maskNbits)) + auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes()); + // (-1 >> MaskShAmt) + auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt)); + // ((-1 << MaskShAmt) >> MaskShAmt) + auto MaskD = + m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); + + Value *X; + Constant *NewMask; + + if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) { + // Can we simplify (MaskShAmt+ShiftShAmt) ? + auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst( + MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); + if (!SumOfShAmts) + return nullptr; // Did not simplify. 
+ // In this pattern SumOfShAmts correlates with the number of low bits + // that shall remain in the root value (OuterShift). + + // An extend of an undef value becomes zero because the high bits are never + // completely unknown. Replace the the `undef` shift amounts with final + // shift bitwidth to ensure that the value remains undef when creating the + // subsequent shift op. + SumOfShAmts = replaceUndefsWith( + SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(), + ExtendedTy->getScalarSizeInBits())); + auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); + // And compute the mask as usual: ~(-1 << (SumOfShAmts)) + auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); + auto *ExtendedInvertedMask = + ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts); + NewMask = ConstantExpr::getNot(ExtendedInvertedMask); + } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) || + match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)), + m_Deferred(MaskShAmt)))) { + // Can we simplify (ShiftShAmt-MaskShAmt) ? + auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst( + ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); + if (!ShAmtsDiff) + return nullptr; // Did not simplify. + // In this pattern ShAmtsDiff correlates with the number of high bits that + // shall be unset in the root value (OuterShift). + + // An extend of an undef value becomes zero because the high bits are never + // completely unknown. Replace the the `undef` shift amounts with negated + // bitwidth of innermost shift to ensure that the value remains undef when + // creating the subsequent shift op. 
+ unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits(); + ShAmtsDiff = replaceUndefsWith( + ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), + -WidestTyBitWidth)); + auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( + ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), + WidestTyBitWidth, + /*isSigned=*/false), + ShAmtsDiff), + ExtendedTy); + // And compute the mask as usual: (-1 l>> (NumHighBitsToClear)) + auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); + NewMask = + ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear); + } else + return nullptr; // Don't know anything about this pattern. + + NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy); + + // Does this mask has any unset bits? If not then we can just not apply it. + bool NeedMask = !match(NewMask, m_AllOnes()); + + // If we need to apply a mask, there are several more restrictions we have. + if (NeedMask) { + // The old masking instruction must go away. + if (!Masked->hasOneUse()) + return nullptr; + // The original "masking" instruction must not have been`ashr`. + if (match(Masked, m_AShr(m_Value(), m_Value()))) + return nullptr; } - return NewShift; + + // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits. + auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X, + OuterShift->getOperand(1)); + + if (!NeedMask) + return NewShift; + + Builder.Insert(NewShift); + return BinaryOperator::Create(Instruction::And, NewShift, NewMask); } Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); + // If the shift amount is a one-use `sext`, we can demote it to `zext`. 
+ Value *Y; + if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) { + Value *NewExt = Builder.CreateZExt(Y, I.getType(), Op1->getName()); + return BinaryOperator::Create(I.getOpcode(), Op0, NewExt); + } + // See if we can fold away this shift. if (SimplifyDemandedInstructionBits(I)) return &I; @@ -83,8 +308,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; - if (Instruction *NewShift = - reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)) + if (auto *NewShift = cast_or_null<Instruction>( + reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))) return NewShift; // (C1 shift (A add C2)) -> (C1 shift C2) shift A) @@ -618,9 +843,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } Instruction *InstCombiner::visitShl(BinaryOperator &I) { + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), - I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q)) return replaceInstUsesWith(I, V); if (Instruction *X = foldVectorBinop(I)) @@ -629,6 +855,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Instruction *V = commonShiftTransforms(I)) return V; + if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder)) + return V; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); @@ -636,12 +865,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); - unsigned BitWidth = Ty->getScalarSizeInBits(); // shl (zext X), ShAmt --> zext (shl X, ShAmt) // This is only valid if X would have zeros shifted out. 
Value *X; - if (match(Op0, m_ZExt(m_Value(X)))) { + if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); if (ShAmt < SrcWidth && MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) @@ -719,6 +947,12 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { // (X * C2) << C1 --> X * (C2 << C1) if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); + + // shl (zext i1 X), C1 --> select (X, 1 << C1, 0) + if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1); + return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); + } } // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 @@ -859,6 +1093,75 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return nullptr; } +Instruction * +InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract( + BinaryOperator &OldAShr) { + assert(OldAShr.getOpcode() == Instruction::AShr && + "Must be called with arithmetic right-shift instruction only."); + + // Check that constant C is a splat of the element-wise bitwidth of V. + auto BitWidthSplat = [](Constant *C, Value *V) { + return match( + C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(C->getType()->getScalarSizeInBits(), + V->getType()->getScalarSizeInBits()))); + }; + + // It should look like variable-length sign-extension on the outside: + // (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits) + Value *NBits; + Instruction *MaybeTrunc; + Constant *C1, *C2; + if (!match(&OldAShr, + m_AShr(m_Shl(m_Instruction(MaybeTrunc), + m_ZExtOrSelf(m_Sub(m_Constant(C1), + m_ZExtOrSelf(m_Value(NBits))))), + m_ZExtOrSelf(m_Sub(m_Constant(C2), + m_ZExtOrSelf(m_Deferred(NBits)))))) || + !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr)) + return nullptr; + + // There may or may not be a truncation after outer two shifts. 
+ Instruction *HighBitExtract; + match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract))); + bool HadTrunc = MaybeTrunc != HighBitExtract; + + // And finally, the innermost part of the pattern must be a right-shift. + Value *X, *NumLowBitsToSkip; + if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip)))) + return nullptr; + + // Said right-shift must extract high NBits bits - C0 must be it's bitwidth. + Constant *C0; + if (!match(NumLowBitsToSkip, + m_ZExtOrSelf( + m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) || + !BitWidthSplat(C0, HighBitExtract)) + return nullptr; + + // Since the NBits is identical for all shifts, if the outermost and + // innermost shifts are identical, then outermost shifts are redundant. + // If we had truncation, do keep it though. + if (HighBitExtract->getOpcode() == OldAShr.getOpcode()) + return replaceInstUsesWith(OldAShr, MaybeTrunc); + + // Else, if there was a truncation, then we need to ensure that one + // instruction will go away. + if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + return nullptr; + + // Finally, bypass two innermost shifts, and perform the outermost shift on + // the operands of the innermost shift. + Instruction *NewAShr = + BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip); + NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness. + if (!HadTrunc) + return NewAShr; + + Builder.Insert(NewAShr); + return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType()); +} + Instruction *InstCombiner::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) @@ -933,6 +1236,9 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { } } + if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I)) + return R; + // See if we can turn a signed shr into an unsigned shr. 
if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) return BinaryOperator::CreateLShr(Op0, Op1); diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index e0d85c4b49ae..d30ab8001897 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -971,6 +971,13 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, APInt DemandedElts, int DMaskIdx) { + + // FIXME: Allow v3i16/v3f16 in buffer intrinsics when the types are fully supported. + if (DMaskIdx < 0 && + II->getType()->getScalarSizeInBits() != 32 && + DemandedElts.getActiveBits() == 3) + return nullptr; + unsigned VWidth = II->getType()->getVectorNumElements(); if (VWidth == 1) return nullptr; @@ -1067,16 +1074,22 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, } /// The specified value produces a vector with any number of elements. +/// This method analyzes which elements of the operand are undef and returns +/// that information in UndefElts. +/// /// DemandedElts contains the set of elements that are actually used by the -/// caller. This method analyzes which elements of the operand are undef and -/// returns that information in UndefElts. +/// caller, and by default (AllowMultipleUsers equals false) the value is +/// simplified only if it has a single caller. If AllowMultipleUsers is set +/// to true, DemandedElts refers to the union of sets of elements that are +/// used by all callers. /// /// If the information about demanded elements can be used to simplify the /// operation, the operation is simplified, then the resultant value is /// returned. This returns null if no change was made. 
Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, - unsigned Depth) { + unsigned Depth, + bool AllowMultipleUsers) { unsigned VWidth = V->getType()->getVectorNumElements(); APInt EltMask(APInt::getAllOnesValue(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); @@ -1130,19 +1143,21 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, if (Depth == 10) return nullptr; - // If multiple users are using the root value, proceed with - // simplification conservatively assuming that all elements - // are needed. - if (!V->hasOneUse()) { - // Quit if we find multiple users of a non-root value though. - // They'll be handled when it's their turn to be visited by - // the main instcombine process. - if (Depth != 0) - // TODO: Just compute the UndefElts information recursively. - return nullptr; + if (!AllowMultipleUsers) { + // If multiple users are using the root value, proceed with + // simplification conservatively assuming that all elements + // are needed. + if (!V->hasOneUse()) { + // Quit if we find multiple users of a non-root value though. + // They'll be handled when it's their turn to be visited by + // the main instcombine process. + if (Depth != 0) + // TODO: Just compute the UndefElts information recursively. + return nullptr; - // Conservatively assume that all elements are needed. - DemandedElts = EltMask; + // Conservatively assume that all elements are needed. 
+ DemandedElts = EltMask; + } } Instruction *I = dyn_cast<Instruction>(V); @@ -1674,8 +1689,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::amdgcn_buffer_load_format: case Intrinsic::amdgcn_raw_buffer_load: case Intrinsic::amdgcn_raw_buffer_load_format: + case Intrinsic::amdgcn_raw_tbuffer_load: case Intrinsic::amdgcn_struct_buffer_load: case Intrinsic::amdgcn_struct_buffer_load_format: + case Intrinsic::amdgcn_struct_tbuffer_load: + case Intrinsic::amdgcn_tbuffer_load: return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts); default: { if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID())) diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index dc9abdd7f47a..9c890748e5ab 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -253,6 +253,69 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, return nullptr; } +/// Find elements of V demanded by UserInstr. +static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + // Conservatively assume that all elements are needed. 
+ APInt UsedElts(APInt::getAllOnesValue(VWidth)); + + switch (UserInstr->getOpcode()) { + case Instruction::ExtractElement: { + ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr); + assert(EEI->getVectorOperand() == V); + ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand()); + if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) { + UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue()); + } + break; + } + case Instruction::ShuffleVector: { + ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr); + unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements(); + + UsedElts = APInt(VWidth, 0); + for (unsigned i = 0; i < MaskNumElts; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal == -1u || MaskVal >= 2 * VWidth) + continue; + if (Shuffle->getOperand(0) == V && (MaskVal < VWidth)) + UsedElts.setBit(MaskVal); + if (Shuffle->getOperand(1) == V && + ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth))) + UsedElts.setBit(MaskVal - VWidth); + } + break; + } + default: + break; + } + return UsedElts; +} + +/// Find union of elements of V demanded by all its users. +/// If it is known by querying findDemandedEltsBySingleUser that +/// no user demands an element of V, then the corresponding bit +/// remains unset in the returned value. 
+static APInt findDemandedEltsByAllUsers(Value *V) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + APInt UnionUsedElts(VWidth, 0); + for (const Use &U : V->uses()) { + if (Instruction *I = dyn_cast<Instruction>(U.getUser())) { + UnionUsedElts |= findDemandedEltsBySingleUser(V, I); + } else { + UnionUsedElts = APInt::getAllOnesValue(VWidth); + break; + } + + if (UnionUsedElts.isAllOnesValue()) + break; + } + + return UnionUsedElts; +} + Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Value *SrcVec = EI.getVectorOperand(); Value *Index = EI.getIndexOperand(); @@ -271,19 +334,35 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { return nullptr; // This instruction only demands the single element from the input vector. - // If the input vector has a single use, simplify it based on this use - // property. - if (SrcVec->hasOneUse() && NumElts != 1) { - APInt UndefElts(NumElts, 0); - APInt DemandedElts(NumElts, 0); - DemandedElts.setBit(IndexC->getZExtValue()); - if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts, - UndefElts)) { - EI.setOperand(0, V); - return &EI; + if (NumElts != 1) { + // If the input vector has a single use, simplify it based on this use + // property. + if (SrcVec->hasOneUse()) { + APInt UndefElts(NumElts, 0); + APInt DemandedElts(NumElts, 0); + DemandedElts.setBit(IndexC->getZExtValue()); + if (Value *V = + SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) { + EI.setOperand(0, V); + return &EI; + } + } else { + // If the input vector has multiple uses, simplify it based on a union + // of all elements used. 
+ APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); + if (!DemandedElts.isAllOnesValue()) { + APInt UndefElts(NumElts, 0); + if (Value *V = SimplifyDemandedVectorElts( + SrcVec, DemandedElts, UndefElts, 0 /* Depth */, + true /* AllowMultipleUsers */)) { + if (V != SrcVec) { + SrcVec->replaceAllUsesWith(V); + return &EI; + } + } + } } } - if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian())) return I; @@ -766,6 +845,55 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask); } +/// Try to fold an extract+insert element into an existing identity shuffle by +/// changing the shuffle's mask to include the index of this insert element. +static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { + // Check if the vector operand of this insert is an identity shuffle. + auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0)); + if (!Shuf || !isa<UndefValue>(Shuf->getOperand(1)) || + !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding())) + return nullptr; + + // Check for a constant insertion index. + uint64_t IdxC; + if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) + return nullptr; + + // Check if this insert's scalar op is extracted from the identity shuffle's + // input vector. + Value *Scalar = InsElt.getOperand(1); + Value *X = Shuf->getOperand(0); + if (!match(Scalar, m_ExtractElement(m_Specific(X), m_SpecificInt(IdxC)))) + return nullptr; + + // Replace the shuffle mask element at the index of this extract+insert with + // that same index value. 
+ // For example: + // inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask' + unsigned NumMaskElts = Shuf->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> NewMaskVec(NumMaskElts); + Type *I32Ty = IntegerType::getInt32Ty(Shuf->getContext()); + Constant *NewMaskEltC = ConstantInt::get(I32Ty, IdxC); + Constant *OldMask = Shuf->getMask(); + for (unsigned i = 0; i != NumMaskElts; ++i) { + if (i != IdxC) { + // All mask elements besides the inserted element remain the same. + NewMaskVec[i] = OldMask->getAggregateElement(i); + } else if (OldMask->getAggregateElement(i) == NewMaskEltC) { + // If the mask element was already set, there's nothing to do + // (demanded elements analysis may unset it later). + return nullptr; + } else { + assert(isa<UndefValue>(OldMask->getAggregateElement(i)) && + "Unexpected shuffle mask element for identity shuffle"); + NewMaskVec[i] = NewMaskEltC; + } + } + + Constant *NewMask = ConstantVector::get(NewMaskVec); + return new ShuffleVectorInst(X, Shuf->getOperand(1), NewMask); +} + /// If we have an insertelement instruction feeding into another insertelement /// and the 2nd is inserting a constant into the vector, canonicalize that /// constant insertion before the insertion of a variable: @@ -987,6 +1115,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *Splat = foldInsEltIntoSplat(IE)) return Splat; + if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE)) + return IdentityShuf; + return nullptr; } @@ -1009,17 +1140,23 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, if (Depth == 0) return false; switch (I->getOpcode()) { + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + // Propagating an undefined shuffle mask element to integer div/rem is not + // allowed because those opcodes can create immediate undefined behavior + // from an undefined element in an operand. 
+ if (llvm::any_of(Mask, [](int M){ return M == -1; })) + return false; + LLVM_FALLTHROUGH; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: case Instruction::FSub: case Instruction::Mul: case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: case Instruction::FRem: case Instruction::Shl: case Instruction::LShr: @@ -1040,9 +1177,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, case Instruction::FPExt: case Instruction::GetElementPtr: { // Bail out if we would create longer vector ops. We could allow creating - // longer vector ops, but that may result in more expensive codegen. We - // would also need to limit the transform to avoid undefined behavior for - // integer div/rem. + // longer vector ops, but that may result in more expensive codegen. Type *ITy = I->getType(); if (ITy->isVectorTy() && Mask.size() > ITy->getVectorNumElements()) return false; diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 385f4926b845..ecb486c544e0 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -200,8 +200,8 @@ bool InstCombiner::shouldChangeType(Type *From, Type *To) const { // where both B and C should be ConstantInts, results in a constant that does // not overflow. This function only handles the Add and Sub opcodes. For // all other opcodes, the function conservatively returns false. 
-static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) { - OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I); +static bool maintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) { + auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I); if (!OBO || !OBO->hasNoSignedWrap()) return false; @@ -224,10 +224,15 @@ static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) { } static bool hasNoUnsignedWrap(BinaryOperator &I) { - OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I); + auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I); return OBO && OBO->hasNoUnsignedWrap(); } +static bool hasNoSignedWrap(BinaryOperator &I) { + auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I); + return OBO && OBO->hasNoSignedWrap(); +} + /// Conservatively clears subclassOptionalData after a reassociation or /// commutation. We preserve fast-math flags when applicable as they can be /// preserved. @@ -332,22 +337,21 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // It simplifies to V. Form "A op V". I.setOperand(0, A); I.setOperand(1, V); - // Conservatively clear the optional flags, since they may not be - // preserved by the reassociation. bool IsNUW = hasNoUnsignedWrap(I) && hasNoUnsignedWrap(*Op0); - bool IsNSW = MaintainNoSignedWrap(I, B, C); + bool IsNSW = maintainNoSignedWrap(I, B, C) && hasNoSignedWrap(*Op0); + // Conservatively clear all optional flags since they may not be + // preserved by the reassociation. Reset nsw/nuw based on the above + // analysis. ClearSubclassDataAfterReassociation(I); + // Note: this is only valid because SimplifyBinOp doesn't look at + // the operands to Op0. if (IsNUW) I.setHasNoUnsignedWrap(true); - if (IsNSW && - (!Op0 || (isa<BinaryOperator>(Op0) && Op0->hasNoSignedWrap()))) { - // Note: this is only valid because SimplifyBinOp doesn't look at - // the operands to Op0. 
+ if (IsNSW) I.setHasNoSignedWrap(true); - } Changed = true; ++NumReassoc; @@ -610,7 +614,6 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, HasNUW &= ROBO->hasNoUnsignedWrap(); } - const APInt *CInt; if (TopLevelOpcode == Instruction::Add && InnerOpcode == Instruction::Mul) { // We can propagate 'nsw' if we know that @@ -620,6 +623,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, // %Z = mul nsw i16 %X, C+1 // // iff C+1 isn't INT_MIN + const APInt *CInt; if (match(V, m_APInt(CInt))) { if (!CInt->isMinSignedValue()) BO->setHasNoSignedWrap(HasNSW); @@ -763,12 +767,16 @@ Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) && match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) { bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse(); + + FastMathFlags FMF; BuilderTy::FastMathFlagGuard Guard(Builder); - if (isa<FPMathOperator>(&I)) - Builder.setFastMathFlags(I.getFastMathFlags()); + if (isa<FPMathOperator>(&I)) { + FMF = I.getFastMathFlags(); + Builder.setFastMathFlags(FMF); + } - Value *V1 = SimplifyBinOp(Opcode, C, E, SQ.getWithInstruction(&I)); - Value *V2 = SimplifyBinOp(Opcode, B, D, SQ.getWithInstruction(&I)); + Value *V1 = SimplifyBinOp(Opcode, C, E, FMF, SQ.getWithInstruction(&I)); + Value *V2 = SimplifyBinOp(Opcode, B, D, FMF, SQ.getWithInstruction(&I)); if (V1 && V2) SI = Builder.CreateSelect(A, V2, V1); else if (V2 && SelectsHaveOneUse) @@ -1659,7 +1667,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // to an index of zero, so replace it with zero if it is not zero already. 
Type *EltTy = GTI.getIndexedType(); if (EltTy->isSized() && DL.getTypeAllocSize(EltTy) == 0) - if (!isa<Constant>(*I) || !cast<Constant>(*I)->isNullValue()) { + if (!isa<Constant>(*I) || !match(I->get(), m_Zero())) { *I = Constant::getNullValue(NewIndexType); MadeChange = true; } @@ -2549,9 +2557,7 @@ Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // Change br (not X), label True, label False to: br X, label False, True Value *X = nullptr; - BasicBlock *TrueDest; - BasicBlock *FalseDest; - if (match(&BI, m_Br(m_Not(m_Value(X)), TrueDest, FalseDest)) && + if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) && !isa<Constant>(X)) { // Swap Destinations and condition... BI.setCondition(X); @@ -2569,8 +2575,8 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // Canonicalize, for example, icmp_ne -> icmp_eq or fcmp_one -> fcmp_oeq. CmpInst::Predicate Pred; - if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), TrueDest, - FalseDest)) && + if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), + m_BasicBlock(), m_BasicBlock())) && !isCanonicalPredicate(Pred)) { // Swap destinations and condition. CmpInst *Cond = cast<CmpInst>(BI.getCondition()); @@ -3156,6 +3162,21 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { findDbgUsers(DbgUsers, I); for (auto *DII : reverse(DbgUsers)) { if (DII->getParent() == SrcBlock) { + if (isa<DbgDeclareInst>(DII)) { + // A dbg.declare instruction should not be cloned, since there can only be + // one per variable fragment. It should be left in the original place since + // sunk instruction is not an alloca(otherwise we could not be here). + // But we need to update arguments of dbg.declare instruction, so that it + // would not point into sunk instruction. 
+ if (!isa<CastInst>(I)) + continue; // dbg.declare points at something it shouldn't + + DII->setOperand( + 0, MetadataAsValue::get(I->getContext(), + ValueAsMetadata::get(I->getOperand(0)))); + continue; + } + // dbg.value is in the same basic block as the sunk inst, see if we can // salvage it. Clone a new copy of the instruction: on success we need // both salvaged and unsalvaged copies. @@ -3580,7 +3601,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { // Required analyses. auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 6821e214e921..d92ee11c2e1a 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -129,6 +129,8 @@ static const uintptr_t kRetiredStackFrameMagic = 0x45E0360E; static const char *const kAsanModuleCtorName = "asan.module_ctor"; static const char *const kAsanModuleDtorName = "asan.module_dtor"; static const uint64_t kAsanCtorAndDtorPriority = 1; +// On Emscripten, the system needs more than one priorities for constructors. 
+static const uint64_t kAsanEmscriptenCtorAndDtorPriority = 50; static const char *const kAsanReportErrorTemplate = "__asan_report_"; static const char *const kAsanRegisterGlobalsName = "__asan_register_globals"; static const char *const kAsanUnregisterGlobalsName = @@ -191,6 +193,11 @@ static cl::opt<bool> ClRecover( cl::desc("Enable recovery mode (continue-after-error)."), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClInsertVersionCheck( + "asan-guard-against-version-mismatch", + cl::desc("Guard against compiler/runtime version mismatch."), + cl::Hidden, cl::init(true)); + // This flag may need to be replaced with -f[no-]asan-reads. static cl::opt<bool> ClInstrumentReads("asan-instrument-reads", cl::desc("instrument read instructions"), @@ -530,6 +537,14 @@ static size_t RedzoneSizeForScale(int MappingScale) { return std::max(32U, 1U << MappingScale); } +static uint64_t GetCtorAndDtorPriority(Triple &TargetTriple) { + if (TargetTriple.isOSEmscripten()) { + return kAsanEmscriptenCtorAndDtorPriority; + } else { + return kAsanCtorAndDtorPriority; + } +} + namespace { /// Module analysis for getting various metadata about the module. @@ -565,10 +580,10 @@ char ASanGlobalsMetadataWrapperPass::ID = 0; /// AddressSanitizer: instrument the code in module to find memory bugs. struct AddressSanitizer { - AddressSanitizer(Module &M, GlobalsMetadata &GlobalsMD, + AddressSanitizer(Module &M, const GlobalsMetadata *GlobalsMD, bool CompileKernel = false, bool Recover = false, bool UseAfterScope = false) - : UseAfterScope(UseAfterScope || ClUseAfterScope), GlobalsMD(GlobalsMD) { + : UseAfterScope(UseAfterScope || ClUseAfterScope), GlobalsMD(*GlobalsMD) { this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; this->CompileKernel = ClEnableKasan.getNumOccurrences() > 0 ? 
ClEnableKasan : CompileKernel; @@ -677,7 +692,7 @@ private: FunctionCallee AsanMemmove, AsanMemcpy, AsanMemset; InlineAsm *EmptyAsm; Value *LocalDynamicShadow = nullptr; - GlobalsMetadata GlobalsMD; + const GlobalsMetadata &GlobalsMD; DenseMap<const AllocaInst *, bool> ProcessedAllocas; }; @@ -706,8 +721,8 @@ public: GlobalsMetadata &GlobalsMD = getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD(); const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - AddressSanitizer ASan(*F.getParent(), GlobalsMD, CompileKernel, Recover, + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + AddressSanitizer ASan(*F.getParent(), &GlobalsMD, CompileKernel, Recover, UseAfterScope); return ASan.instrumentFunction(F, TLI); } @@ -720,10 +735,10 @@ private: class ModuleAddressSanitizer { public: - ModuleAddressSanitizer(Module &M, GlobalsMetadata &GlobalsMD, + ModuleAddressSanitizer(Module &M, const GlobalsMetadata *GlobalsMD, bool CompileKernel = false, bool Recover = false, bool UseGlobalsGC = true, bool UseOdrIndicator = false) - : GlobalsMD(GlobalsMD), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), + : GlobalsMD(*GlobalsMD), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), // Enable aliases as they should have no downside with ODR indicators. 
UsePrivateAlias(UseOdrIndicator || ClUsePrivateAlias), UseOdrIndicator(UseOdrIndicator || ClUseOdrIndicator), @@ -783,7 +798,7 @@ private: } int GetAsanVersion(const Module &M) const; - GlobalsMetadata GlobalsMD; + const GlobalsMetadata &GlobalsMD; bool CompileKernel; bool Recover; bool UseGlobalsGC; @@ -830,7 +845,7 @@ public: bool runOnModule(Module &M) override { GlobalsMetadata &GlobalsMD = getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD(); - ModuleAddressSanitizer ASanModule(M, GlobalsMD, CompileKernel, Recover, + ModuleAddressSanitizer ASanModule(M, &GlobalsMD, CompileKernel, Recover, UseGlobalGC, UseOdrIndicator); return ASanModule.instrumentModule(M); } @@ -1033,7 +1048,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { if (!II.isLifetimeStartOrEnd()) return; // Found lifetime intrinsic, add ASan instrumentation if necessary. - ConstantInt *Size = dyn_cast<ConstantInt>(II.getArgOperand(0)); + auto *Size = cast<ConstantInt>(II.getArgOperand(0)); // If size argument is undefined, don't do anything. 
if (Size->isMinusOne()) return; // Check that size doesn't saturate uint64_t and can @@ -1156,7 +1171,7 @@ PreservedAnalyses AddressSanitizerPass::run(Function &F, Module &M = *F.getParent(); if (auto *R = MAM.getCachedResult<ASanGlobalsMetadataAnalysis>(M)) { const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); - AddressSanitizer Sanitizer(M, *R, CompileKernel, Recover, UseAfterScope); + AddressSanitizer Sanitizer(M, R, CompileKernel, Recover, UseAfterScope); if (Sanitizer.instrumentFunction(F, TLI)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); @@ -1178,7 +1193,7 @@ ModuleAddressSanitizerPass::ModuleAddressSanitizerPass(bool CompileKernel, PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, AnalysisManager<Module> &AM) { GlobalsMetadata &GlobalsMD = AM.getResult<ASanGlobalsMetadataAnalysis>(M); - ModuleAddressSanitizer Sanitizer(M, GlobalsMD, CompileKernel, Recover, + ModuleAddressSanitizer Sanitizer(M, &GlobalsMD, CompileKernel, Recover, UseGlobalGC, UseOdrIndicator); if (Sanitizer.instrumentModule(M)) return PreservedAnalyses::none(); @@ -1331,7 +1346,7 @@ Value *AddressSanitizer::isInterestingMemoryAccess(Instruction *I, unsigned *Alignment, Value **MaybeMask) { // Skip memory accesses inserted by another instrumentation. - if (I->getMetadata("nosanitize")) return nullptr; + if (I->hasMetadata("nosanitize")) return nullptr; // Do not instrument the load fetching the dynamic shadow address. if (LocalDynamicShadow == I) @@ -1775,9 +1790,10 @@ void ModuleAddressSanitizer::createInitializerPoisonCalls( // Must have a function or null ptr. if (Function *F = dyn_cast<Function>(CS->getOperand(1))) { if (F->getName() == kAsanModuleCtorName) continue; - ConstantInt *Priority = dyn_cast<ConstantInt>(CS->getOperand(0)); + auto *Priority = cast<ConstantInt>(CS->getOperand(0)); // Don't instrument CTORs that will run before asan.module_ctor. 
- if (Priority->getLimitedValue() <= kAsanCtorAndDtorPriority) continue; + if (Priority->getLimitedValue() <= GetCtorAndDtorPriority(TargetTriple)) + continue; poisonOneInitializer(*F, ModuleName); } } @@ -1919,7 +1935,12 @@ StringRef ModuleAddressSanitizer::getGlobalMetadataSection() const { case Triple::COFF: return ".ASAN$GL"; case Triple::ELF: return "asan_globals"; case Triple::MachO: return "__DATA,__asan_globals,regular"; - default: break; + case Triple::Wasm: + case Triple::XCOFF: + report_fatal_error( + "ModuleAddressSanitizer not implemented for object file format."); + case Triple::UnknownObjectFormat: + break; } llvm_unreachable("unsupported object format"); } @@ -2033,7 +2054,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsCOFF( unsigned SizeOfGlobalStruct = DL.getTypeAllocSize(Initializer->getType()); assert(isPowerOf2_32(SizeOfGlobalStruct) && "global metadata will not be padded appropriately"); - Metadata->setAlignment(SizeOfGlobalStruct); + Metadata->setAlignment(assumeAligned(SizeOfGlobalStruct)); SetComdatForGlobalMetadata(G, Metadata, ""); } @@ -2170,7 +2191,7 @@ void ModuleAddressSanitizer::InstrumentGlobalsWithMetadataArray( M, ArrayOfGlobalStructTy, false, GlobalVariable::InternalLinkage, ConstantArray::get(ArrayOfGlobalStructTy, MetadataInitializers), ""); if (Mapping.Scale > 3) - AllGlobals->setAlignment(1ULL << Mapping.Scale); + AllGlobals->setAlignment(Align(1ULL << Mapping.Scale)); IRB.CreateCall(AsanRegisterGlobals, {IRB.CreatePointerCast(AllGlobals, IntptrTy), @@ -2270,7 +2291,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, "", G, G->getThreadLocalMode()); NewGlobal->copyAttributesFrom(G); NewGlobal->setComdat(G->getComdat()); - NewGlobal->setAlignment(MinRZ); + NewGlobal->setAlignment(MaybeAlign(MinRZ)); // Don't fold globals with redzones. 
ODR violation detector and redzone // poisoning implicitly creates a dependence on the global's address, so it // is no longer valid for it to be marked unnamed_addr. @@ -2338,7 +2359,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, // Set meaningful attributes for indicator symbol. ODRIndicatorSym->setVisibility(NewGlobal->getVisibility()); ODRIndicatorSym->setDLLStorageClass(NewGlobal->getDLLStorageClass()); - ODRIndicatorSym->setAlignment(1); + ODRIndicatorSym->setAlignment(Align::None()); ODRIndicator = ODRIndicatorSym; } @@ -2410,39 +2431,39 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) { // Create a module constructor. A destructor is created lazily because not all // platforms, and not all modules need it. + std::string AsanVersion = std::to_string(GetAsanVersion(M)); std::string VersionCheckName = - kAsanVersionCheckNamePrefix + std::to_string(GetAsanVersion(M)); + ClInsertVersionCheck ? (kAsanVersionCheckNamePrefix + AsanVersion) : ""; std::tie(AsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions( M, kAsanModuleCtorName, kAsanInitName, /*InitArgTypes=*/{}, /*InitArgs=*/{}, VersionCheckName); bool CtorComdat = true; - bool Changed = false; // TODO(glider): temporarily disabled globals instrumentation for KASan. if (ClGlobals) { IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator()); - Changed |= InstrumentGlobals(IRB, M, &CtorComdat); + InstrumentGlobals(IRB, M, &CtorComdat); } + const uint64_t Priority = GetCtorAndDtorPriority(TargetTriple); + // Put the constructor and destructor in comdat if both // (1) global instrumentation is not TU-specific // (2) target is ELF. 
if (UseCtorComdat && TargetTriple.isOSBinFormatELF() && CtorComdat) { AsanCtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleCtorName)); - appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority, - AsanCtorFunction); + appendToGlobalCtors(M, AsanCtorFunction, Priority, AsanCtorFunction); if (AsanDtorFunction) { AsanDtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleDtorName)); - appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority, - AsanDtorFunction); + appendToGlobalDtors(M, AsanDtorFunction, Priority, AsanDtorFunction); } } else { - appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority); + appendToGlobalCtors(M, AsanCtorFunction, Priority); if (AsanDtorFunction) - appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority); + appendToGlobalDtors(M, AsanDtorFunction, Priority); } - return Changed; + return true; } void AddressSanitizer::initializeCallbacks(Module &M) { @@ -2664,7 +2685,7 @@ bool AddressSanitizer::instrumentFunction(Function &F, if (CS) { // A call inside BB. TempsToInstrument.clear(); - if (CS.doesNotReturn() && !CS->getMetadata("nosanitize")) + if (CS.doesNotReturn() && !CS->hasMetadata("nosanitize")) NoReturnCalls.push_back(CS.getInstruction()); } if (CallInst *CI = dyn_cast<CallInst>(&Inst)) @@ -2877,18 +2898,19 @@ void FunctionStackPoisoner::copyArgsPassedByValToAllocas() { for (Argument &Arg : F.args()) { if (Arg.hasByValAttr()) { Type *Ty = Arg.getType()->getPointerElementType(); - unsigned Align = Arg.getParamAlignment(); - if (Align == 0) Align = DL.getABITypeAlignment(Ty); + unsigned Alignment = Arg.getParamAlignment(); + if (Alignment == 0) + Alignment = DL.getABITypeAlignment(Ty); AllocaInst *AI = IRB.CreateAlloca( Ty, nullptr, (Arg.hasName() ? 
Arg.getName() : "Arg" + Twine(Arg.getArgNo())) + ".byval"); - AI->setAlignment(Align); + AI->setAlignment(Align(Alignment)); Arg.replaceAllUsesWith(AI); uint64_t AllocSize = DL.getTypeAllocSize(Ty); - IRB.CreateMemCpy(AI, Align, &Arg, Align, AllocSize); + IRB.CreateMemCpy(AI, Alignment, &Arg, Alignment, AllocSize); } } } @@ -2919,7 +2941,7 @@ Value *FunctionStackPoisoner::createAllocaForLayout( } assert((ClRealignStack & (ClRealignStack - 1)) == 0); size_t FrameAlignment = std::max(L.FrameAlignment, (size_t)ClRealignStack); - Alloca->setAlignment(FrameAlignment); + Alloca->setAlignment(MaybeAlign(FrameAlignment)); return IRB.CreatePointerCast(Alloca, IntptrTy); } @@ -2928,7 +2950,7 @@ void FunctionStackPoisoner::createDynamicAllocasInitStorage() { IRBuilder<> IRB(dyn_cast<Instruction>(FirstBB.begin())); DynamicAllocaLayout = IRB.CreateAlloca(IntptrTy, nullptr); IRB.CreateStore(Constant::getNullValue(IntptrTy), DynamicAllocaLayout); - DynamicAllocaLayout->setAlignment(32); + DynamicAllocaLayout->setAlignment(Align(32)); } void FunctionStackPoisoner::processDynamicAllocas() { @@ -3275,7 +3297,7 @@ void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { // Insert new alloca with new NewSize and Align params. 
AllocaInst *NewAlloca = IRB.CreateAlloca(IRB.getInt8Ty(), NewSize); - NewAlloca->setAlignment(Align); + NewAlloca->setAlignment(MaybeAlign(Align)); // NewAddress = Address + Align Value *NewAddress = IRB.CreateAdd(IRB.CreatePtrToInt(NewAlloca, IntptrTy), diff --git a/lib/Transforms/Instrumentation/BoundsChecking.cpp b/lib/Transforms/Instrumentation/BoundsChecking.cpp index 4dc9b611c156..ae34be986537 100644 --- a/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -224,7 +224,7 @@ struct BoundsCheckingLegacyPass : public FunctionPass { } bool runOnFunction(Function &F) override { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); return addBoundsChecking(F, TLI, SE); } diff --git a/lib/Transforms/Instrumentation/CFGMST.h b/lib/Transforms/Instrumentation/CFGMST.h index 971e00041762..8bb6f47c4846 100644 --- a/lib/Transforms/Instrumentation/CFGMST.h +++ b/lib/Transforms/Instrumentation/CFGMST.h @@ -257,13 +257,13 @@ public: std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr)); if (Inserted) { // Newly inserted, update the real info. - Iter->second = std::move(llvm::make_unique<BBInfo>(Index)); + Iter->second = std::move(std::make_unique<BBInfo>(Index)); Index++; } std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr)); if (Inserted) // Newly inserted, update the real info. 
- Iter->second = std::move(llvm::make_unique<BBInfo>(Index)); + Iter->second = std::move(std::make_unique<BBInfo>(Index)); AllEdges.emplace_back(new Edge(Src, Dest, W)); return *AllEdges.back(); } diff --git a/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index 3f4f9bc7145d..55c64fa4b727 100644 --- a/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -512,30 +512,38 @@ static bool isHoistable(Instruction *I, DominatorTree &DT) { // first-region entry block) or the (hoistable or unhoistable) base values that // are defined outside (including the first-region entry block) of the // scope. The returned set doesn't include constants. -static std::set<Value *> getBaseValues(Value *V, - DominatorTree &DT) { +static std::set<Value *> getBaseValues( + Value *V, DominatorTree &DT, + DenseMap<Value *, std::set<Value *>> &Visited) { + if (Visited.count(V)) { + return Visited[V]; + } std::set<Value *> Result; if (auto *I = dyn_cast<Instruction>(V)) { // We don't stop at a block that's not in the Scope because we would miss some // instructions that are based on the same base values if we stop there. if (!isHoistable(I, DT)) { Result.insert(I); + Visited.insert(std::make_pair(V, Result)); return Result; } // I is hoistable above the Scope. for (Value *Op : I->operands()) { - std::set<Value *> OpResult = getBaseValues(Op, DT); + std::set<Value *> OpResult = getBaseValues(Op, DT, Visited); Result.insert(OpResult.begin(), OpResult.end()); } + Visited.insert(std::make_pair(V, Result)); return Result; } if (isa<Argument>(V)) { Result.insert(V); + Visited.insert(std::make_pair(V, Result)); return Result; } // We don't include others like constants because those won't lead to any // chance of folding of conditions (eg two bit checks merged into one check) // after CHR. 
+ Visited.insert(std::make_pair(V, Result)); return Result; // empty } @@ -1078,12 +1086,13 @@ static bool shouldSplit(Instruction *InsertPoint, if (!PrevConditionValues.empty() && !ConditionValues.empty()) { // Use std::set as DenseSet doesn't work with set_intersection. std::set<Value *> PrevBases, Bases; + DenseMap<Value *, std::set<Value *>> Visited; for (Value *V : PrevConditionValues) { - std::set<Value *> BaseValues = getBaseValues(V, DT); + std::set<Value *> BaseValues = getBaseValues(V, DT, Visited); PrevBases.insert(BaseValues.begin(), BaseValues.end()); } for (Value *V : ConditionValues) { - std::set<Value *> BaseValues = getBaseValues(V, DT); + std::set<Value *> BaseValues = getBaseValues(V, DT, Visited); Bases.insert(BaseValues.begin(), BaseValues.end()); } CHR_DEBUG( @@ -1538,10 +1547,7 @@ static bool negateICmpIfUsedByBranchOrSelectOnly(ICmpInst *ICmp, } if (auto *SI = dyn_cast<SelectInst>(U)) { // Swap operands - Value *TrueValue = SI->getTrueValue(); - Value *FalseValue = SI->getFalseValue(); - SI->setTrueValue(FalseValue); - SI->setFalseValue(TrueValue); + SI->swapValues(); SI->swapProfMetadata(); if (Scope->TrueBiasedSelects.count(SI)) { assert(Scope->FalseBiasedSelects.count(SI) == 0 && @@ -2073,7 +2079,7 @@ bool ControlHeightReductionLegacyPass::runOnFunction(Function &F) { getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); RegionInfo &RI = getAnalysis<RegionInfoPass>().getRegionInfo(); std::unique_ptr<OptimizationRemarkEmitter> OwnedORE = - llvm::make_unique<OptimizationRemarkEmitter>(&F); + std::make_unique<OptimizationRemarkEmitter>(&F); return CHR(F, BFI, DT, PSI, RI, *OwnedORE.get()).run(); } diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 2279c1bcb6a8..c0353cba0b2f 100644 --- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -1212,7 +1212,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, 
uint64_t Size, uint64_t Align, return DFS.ZeroShadow; case 1: { LoadInst *LI = new LoadInst(DFS.ShadowTy, ShadowAddr, "", Pos); - LI->setAlignment(ShadowAlign); + LI->setAlignment(MaybeAlign(ShadowAlign)); return LI; } case 2: { diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 59950ffc4e9a..ac6082441eae 100644 --- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -86,7 +86,9 @@ public: ReversedVersion[3] = Options.Version[0]; ReversedVersion[4] = '\0'; } - bool runOnModule(Module &M, const TargetLibraryInfo &TLI); + bool + runOnModule(Module &M, + std::function<const TargetLibraryInfo &(Function &F)> GetTLI); private: // Create the .gcno files for the Module based on DebugInfo. @@ -102,9 +104,9 @@ private: std::vector<Regex> &Regexes); // Get pointers to the functions in the runtime library. - FunctionCallee getStartFileFunc(); - FunctionCallee getEmitFunctionFunc(); - FunctionCallee getEmitArcsFunc(); + FunctionCallee getStartFileFunc(const TargetLibraryInfo *TLI); + FunctionCallee getEmitFunctionFunc(const TargetLibraryInfo *TLI); + FunctionCallee getEmitArcsFunc(const TargetLibraryInfo *TLI); FunctionCallee getSummaryInfoFunc(); FunctionCallee getEndFileFunc(); @@ -127,7 +129,7 @@ private: SmallVector<uint32_t, 4> FileChecksums; Module *M; - const TargetLibraryInfo *TLI; + std::function<const TargetLibraryInfo &(Function &F)> GetTLI; LLVMContext *Ctx; SmallVector<std::unique_ptr<GCOVFunction>, 16> Funcs; std::vector<Regex> FilterRe; @@ -147,8 +149,9 @@ public: StringRef getPassName() const override { return "GCOV Profiler"; } bool runOnModule(Module &M) override { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - return Profiler.runOnModule(M, TLI); + return Profiler.runOnModule(M, [this](Function &F) -> TargetLibraryInfo & { + return getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + }); } void 
getAnalysisUsage(AnalysisUsage &AU) const override { @@ -555,9 +558,10 @@ std::string GCOVProfiler::mangleName(const DICompileUnit *CU, return CurPath.str(); } -bool GCOVProfiler::runOnModule(Module &M, const TargetLibraryInfo &TLI) { +bool GCOVProfiler::runOnModule( + Module &M, std::function<const TargetLibraryInfo &(Function &F)> GetTLI) { this->M = &M; - this->TLI = &TLI; + this->GetTLI = std::move(GetTLI); Ctx = &M.getContext(); AddFlushBeforeForkAndExec(); @@ -574,9 +578,12 @@ PreservedAnalyses GCOVProfilerPass::run(Module &M, ModuleAnalysisManager &AM) { GCOVProfiler Profiler(GCOVOpts); + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); - if (!Profiler.runOnModule(M, TLI)) + if (!Profiler.runOnModule(M, [&](Function &F) -> TargetLibraryInfo & { + return FAM.getResult<TargetLibraryAnalysis>(F); + })) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -624,6 +631,7 @@ static bool shouldKeepInEntry(BasicBlock::iterator It) { void GCOVProfiler::AddFlushBeforeForkAndExec() { SmallVector<Instruction *, 2> ForkAndExecs; for (auto &F : M->functions()) { + auto *TLI = &GetTLI(F); for (auto &I : instructions(F)) { if (CallInst *CI = dyn_cast<CallInst>(&I)) { if (Function *Callee = CI->getCalledFunction()) { @@ -669,7 +677,8 @@ void GCOVProfiler::emitProfileNotes() { continue; std::error_code EC; - raw_fd_ostream out(mangleName(CU, GCovFileType::GCNO), EC, sys::fs::F_None); + raw_fd_ostream out(mangleName(CU, GCovFileType::GCNO), EC, + sys::fs::OF_None); if (EC) { Ctx->emitError(Twine("failed to open coverage notes file for writing: ") + EC.message()); @@ -695,7 +704,7 @@ void GCOVProfiler::emitProfileNotes() { ++It; EntryBlock.splitBasicBlock(It); - Funcs.push_back(make_unique<GCOVFunction>(SP, &F, &out, FunctionIdent++, + Funcs.push_back(std::make_unique<GCOVFunction>(SP, &F, &out, FunctionIdent++, Options.UseCfgChecksum, 
Options.ExitBlockBeforeBody)); GCOVFunction &Func = *Funcs.back(); @@ -873,7 +882,7 @@ bool GCOVProfiler::emitProfileArcs() { return Result; } -FunctionCallee GCOVProfiler::getStartFileFunc() { +FunctionCallee GCOVProfiler::getStartFileFunc(const TargetLibraryInfo *TLI) { Type *Args[] = { Type::getInt8PtrTy(*Ctx), // const char *orig_filename Type::getInt8PtrTy(*Ctx), // const char version[4] @@ -887,7 +896,7 @@ FunctionCallee GCOVProfiler::getStartFileFunc() { return Res; } -FunctionCallee GCOVProfiler::getEmitFunctionFunc() { +FunctionCallee GCOVProfiler::getEmitFunctionFunc(const TargetLibraryInfo *TLI) { Type *Args[] = { Type::getInt32Ty(*Ctx), // uint32_t ident Type::getInt8PtrTy(*Ctx), // const char *function_name @@ -906,7 +915,7 @@ FunctionCallee GCOVProfiler::getEmitFunctionFunc() { return M->getOrInsertFunction("llvm_gcda_emit_function", FTy); } -FunctionCallee GCOVProfiler::getEmitArcsFunc() { +FunctionCallee GCOVProfiler::getEmitArcsFunc(const TargetLibraryInfo *TLI) { Type *Args[] = { Type::getInt32Ty(*Ctx), // uint32_t num_counters Type::getInt64PtrTy(*Ctx), // uint64_t *counters @@ -943,9 +952,11 @@ Function *GCOVProfiler::insertCounterWriteout( BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", WriteoutF); IRBuilder<> Builder(BB); - FunctionCallee StartFile = getStartFileFunc(); - FunctionCallee EmitFunction = getEmitFunctionFunc(); - FunctionCallee EmitArcs = getEmitArcsFunc(); + auto *TLI = &GetTLI(*WriteoutF); + + FunctionCallee StartFile = getStartFileFunc(TLI); + FunctionCallee EmitFunction = getEmitFunctionFunc(TLI); + FunctionCallee EmitArcs = getEmitArcsFunc(TLI); FunctionCallee SummaryInfo = getSummaryInfoFunc(); FunctionCallee EndFile = getEndFileFunc(); diff --git a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 90a9f4955a4b..f87132ee4758 100644 --- a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ 
-12,10 +12,12 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Instrumentation/HWAddressSanitizer.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include "llvm/BinaryFormat/ELF.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -52,7 +54,10 @@ using namespace llvm; #define DEBUG_TYPE "hwasan" static const char *const kHwasanModuleCtorName = "hwasan.module_ctor"; +static const char *const kHwasanNoteName = "hwasan.note"; static const char *const kHwasanInitName = "__hwasan_init"; +static const char *const kHwasanPersonalityThunkName = + "__hwasan_personality_thunk"; static const char *const kHwasanShadowMemoryDynamicAddress = "__hwasan_shadow_memory_dynamic_address"; @@ -112,6 +117,9 @@ static cl::opt<bool> ClGenerateTagsWithCalls( cl::desc("generate new tags with runtime library calls"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClGlobals("hwasan-globals", cl::desc("Instrument globals"), + cl::Hidden, cl::init(false)); + static cl::opt<int> ClMatchAllTag( "hwasan-match-all-tag", cl::desc("don't report bad accesses via pointers with this tag"), @@ -155,8 +163,18 @@ static cl::opt<bool> static cl::opt<bool> ClInstrumentLandingPads("hwasan-instrument-landing-pads", - cl::desc("instrument landing pads"), cl::Hidden, - cl::init(true)); + cl::desc("instrument landing pads"), cl::Hidden, + cl::init(false), cl::ZeroOrMore); + +static cl::opt<bool> ClUseShortGranules( + "hwasan-use-short-granules", + cl::desc("use short granules in allocas and outlined checks"), cl::Hidden, + cl::init(false), cl::ZeroOrMore); + +static cl::opt<bool> ClInstrumentPersonalityFunctions( + "hwasan-instrument-personality-functions", + cl::desc("instrument personality functions"), cl::Hidden, cl::init(false), + cl::ZeroOrMore); static cl::opt<bool> 
ClInlineAllChecks("hwasan-inline-all-checks", cl::desc("inline all checks"), @@ -169,16 +187,16 @@ namespace { class HWAddressSanitizer { public: explicit HWAddressSanitizer(Module &M, bool CompileKernel = false, - bool Recover = false) { + bool Recover = false) : M(M) { this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; this->CompileKernel = ClEnableKhwasan.getNumOccurrences() > 0 ? ClEnableKhwasan : CompileKernel; - initializeModule(M); + initializeModule(); } bool sanitizeFunction(Function &F); - void initializeModule(Module &M); + void initializeModule(); void initializeCallbacks(Module &M); @@ -216,9 +234,14 @@ public: Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty); void emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord); + void instrumentGlobal(GlobalVariable *GV, uint8_t Tag); + void instrumentGlobals(); + + void instrumentPersonalityFunctions(); + private: LLVMContext *C; - std::string CurModuleUniqueId; + Module &M; Triple TargetTriple; FunctionCallee HWAsanMemmove, HWAsanMemcpy, HWAsanMemset; FunctionCallee HWAsanHandleVfork; @@ -238,17 +261,21 @@ private: bool InTls; void init(Triple &TargetTriple); - unsigned getAllocaAlignment() const { return 1U << Scale; } + unsigned getObjectAlignment() const { return 1U << Scale; } }; ShadowMapping Mapping; + Type *VoidTy = Type::getVoidTy(M.getContext()); Type *IntptrTy; Type *Int8PtrTy; Type *Int8Ty; Type *Int32Ty; + Type *Int64Ty = Type::getInt64Ty(M.getContext()); bool CompileKernel; bool Recover; + bool UseShortGranules; + bool InstrumentLandingPads; Function *HwasanCtorFunction; @@ -278,7 +305,7 @@ public: StringRef getPassName() const override { return "HWAddressSanitizer"; } bool doInitialization(Module &M) override { - HWASan = llvm::make_unique<HWAddressSanitizer>(M, CompileKernel, Recover); + HWASan = std::make_unique<HWAddressSanitizer>(M, CompileKernel, Recover); return true; } @@ -333,7 +360,7 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M, /// 
Module-level initialization. /// /// inserts a call to __hwasan_init to the module's constructor list. -void HWAddressSanitizer::initializeModule(Module &M) { +void HWAddressSanitizer::initializeModule() { LLVM_DEBUG(dbgs() << "Init " << M.getName() << "\n"); auto &DL = M.getDataLayout(); @@ -342,7 +369,6 @@ void HWAddressSanitizer::initializeModule(Module &M) { Mapping.init(TargetTriple); C = &(M.getContext()); - CurModuleUniqueId = getUniqueModuleId(&M); IRBuilder<> IRB(*C); IntptrTy = IRB.getIntPtrTy(DL); Int8PtrTy = IRB.getInt8PtrTy(); @@ -350,6 +376,21 @@ void HWAddressSanitizer::initializeModule(Module &M) { Int32Ty = IRB.getInt32Ty(); HwasanCtorFunction = nullptr; + + // Older versions of Android do not have the required runtime support for + // short granules, global or personality function instrumentation. On other + // platforms we currently require using the latest version of the runtime. + bool NewRuntime = + !TargetTriple.isAndroid() || !TargetTriple.isAndroidVersionLT(30); + + UseShortGranules = + ClUseShortGranules.getNumOccurrences() ? ClUseShortGranules : NewRuntime; + + // If we don't have personality function support, fall back to landing pads. + InstrumentLandingPads = ClInstrumentLandingPads.getNumOccurrences() + ? ClInstrumentLandingPads + : !NewRuntime; + if (!CompileKernel) { std::tie(HwasanCtorFunction, std::ignore) = getOrCreateSanitizerCtorAndInitFunctions( @@ -363,6 +404,18 @@ void HWAddressSanitizer::initializeModule(Module &M) { Ctor->setComdat(CtorComdat); appendToGlobalCtors(M, Ctor, 0, Ctor); }); + + bool InstrumentGlobals = + ClGlobals.getNumOccurrences() ? ClGlobals : NewRuntime; + if (InstrumentGlobals) + instrumentGlobals(); + + bool InstrumentPersonalityFunctions = + ClInstrumentPersonalityFunctions.getNumOccurrences() + ? 
ClInstrumentPersonalityFunctions + : NewRuntime; + if (InstrumentPersonalityFunctions) + instrumentPersonalityFunctions(); } if (!TargetTriple.isAndroid()) { @@ -456,7 +509,7 @@ Value *HWAddressSanitizer::isInterestingMemoryAccess(Instruction *I, unsigned *Alignment, Value **MaybeMask) { // Skip memory accesses inserted by another instrumentation. - if (I->getMetadata("nosanitize")) return nullptr; + if (I->hasMetadata("nosanitize")) return nullptr; // Do not instrument the load fetching the dynamic shadow address. if (LocalDynamicShadow == I) @@ -564,9 +617,11 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, TargetTriple.isOSBinFormatELF() && !Recover) { Module *M = IRB.GetInsertBlock()->getParent()->getParent(); Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy); - IRB.CreateCall( - Intrinsic::getDeclaration(M, Intrinsic::hwasan_check_memaccess), - {shadowBase(), Ptr, ConstantInt::get(Int32Ty, AccessInfo)}); + IRB.CreateCall(Intrinsic::getDeclaration( + M, UseShortGranules + ? 
Intrinsic::hwasan_check_memaccess_shortgranules + : Intrinsic::hwasan_check_memaccess), + {shadowBase(), Ptr, ConstantInt::get(Int32Ty, AccessInfo)}); return; } @@ -718,7 +773,9 @@ static uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size) { - size_t AlignedSize = alignTo(Size, Mapping.getAllocaAlignment()); + size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); + if (!UseShortGranules) + Size = AlignedSize; Value *JustTag = IRB.CreateTrunc(Tag, IRB.getInt8Ty()); if (ClInstrumentWithCalls) { @@ -738,7 +795,7 @@ bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, /*Align=*/1); if (Size != AlignedSize) { IRB.CreateStore( - ConstantInt::get(Int8Ty, Size % Mapping.getAllocaAlignment()), + ConstantInt::get(Int8Ty, Size % Mapping.getObjectAlignment()), IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize)); IRB.CreateStore(JustTag, IRB.CreateConstGEP1_32( Int8Ty, IRB.CreateBitCast(AI, Int8PtrTy), @@ -778,8 +835,9 @@ Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { // FIXME: use addressofreturnaddress (but implement it in aarch64 backend // first). 
Module *M = IRB.GetInsertBlock()->getParent()->getParent(); - auto GetStackPointerFn = - Intrinsic::getDeclaration(M, Intrinsic::frameaddress); + auto GetStackPointerFn = Intrinsic::getDeclaration( + M, Intrinsic::frameaddress, + IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); Value *StackPointer = IRB.CreateCall( GetStackPointerFn, {Constant::getNullValue(IRB.getInt32Ty())}); @@ -912,8 +970,10 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { PC = readRegister(IRB, "pc"); else PC = IRB.CreatePtrToInt(F, IntptrTy); - auto GetStackPointerFn = - Intrinsic::getDeclaration(F->getParent(), Intrinsic::frameaddress); + Module *M = F->getParent(); + auto GetStackPointerFn = Intrinsic::getDeclaration( + M, Intrinsic::frameaddress, + IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); Value *SP = IRB.CreatePtrToInt( IRB.CreateCall(GetStackPointerFn, {Constant::getNullValue(IRB.getInt32Ty())}), @@ -999,11 +1059,8 @@ bool HWAddressSanitizer::instrumentStack( AI->hasName() ? AI->getName().str() : "alloca." + itostr(N); Replacement->setName(Name + ".hwasan"); - for (auto UI = AI->use_begin(), UE = AI->use_end(); UI != UE;) { - Use &U = *UI++; - if (U.getUser() != AILong) - U.set(Replacement); - } + AI->replaceUsesWithIf(Replacement, + [AILong](Use &U) { return U.getUser() != AILong; }); for (auto *DDI : AllocaDeclareMap.lookup(AI)) { DIExpression *OldExpr = DDI->getExpression(); @@ -1020,7 +1077,7 @@ bool HWAddressSanitizer::instrumentStack( // Re-tag alloca memory with the special UAR tag. 
Value *Tag = getUARTag(IRB, StackTag); - tagAlloca(IRB, AI, Tag, alignTo(Size, Mapping.getAllocaAlignment())); + tagAlloca(IRB, AI, Tag, alignTo(Size, Mapping.getObjectAlignment())); } } @@ -1074,7 +1131,7 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { if (auto *Alloca = dyn_cast_or_null<AllocaInst>(DDI->getAddress())) AllocaDeclareMap[Alloca].push_back(DDI); - if (ClInstrumentLandingPads && isa<LandingPadInst>(Inst)) + if (InstrumentLandingPads && isa<LandingPadInst>(Inst)) LandingPadVec.push_back(&Inst); Value *MaybeMask = nullptr; @@ -1093,6 +1150,13 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { if (!LandingPadVec.empty()) instrumentLandingPads(LandingPadVec); + if (AllocasToInstrument.empty() && F.hasPersonalityFn() && + F.getPersonalityFn()->getName() == kHwasanPersonalityThunkName) { + // __hwasan_personality_thunk is a no-op for functions without an + // instrumented stack, so we can drop it. + F.setPersonalityFn(nullptr); + } + if (AllocasToInstrument.empty() && ToInstrument.empty()) return false; @@ -1118,8 +1182,9 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap; for (AllocaInst *AI : AllocasToInstrument) { uint64_t Size = getAllocaSizeInBytes(*AI); - uint64_t AlignedSize = alignTo(Size, Mapping.getAllocaAlignment()); - AI->setAlignment(std::max(AI->getAlignment(), 16u)); + uint64_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); + AI->setAlignment( + MaybeAlign(std::max(AI->getAlignment(), Mapping.getObjectAlignment()))); if (Size != AlignedSize) { Type *AllocatedType = AI->getAllocatedType(); if (AI->isArrayAllocation()) { @@ -1132,7 +1197,7 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { auto *NewAI = new AllocaInst( TypeWithPadding, AI->getType()->getAddressSpace(), nullptr, "", AI); NewAI->takeName(AI); - NewAI->setAlignment(AI->getAlignment()); + NewAI->setAlignment(MaybeAlign(AI->getAlignment())); 
NewAI->setUsedWithInAlloca(AI->isUsedWithInAlloca()); NewAI->setSwiftError(AI->isSwiftError()); NewAI->copyMetadata(*AI); @@ -1179,6 +1244,257 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F) { return Changed; } +void HWAddressSanitizer::instrumentGlobal(GlobalVariable *GV, uint8_t Tag) { + Constant *Initializer = GV->getInitializer(); + uint64_t SizeInBytes = + M.getDataLayout().getTypeAllocSize(Initializer->getType()); + uint64_t NewSize = alignTo(SizeInBytes, Mapping.getObjectAlignment()); + if (SizeInBytes != NewSize) { + // Pad the initializer out to the next multiple of 16 bytes and add the + // required short granule tag. + std::vector<uint8_t> Init(NewSize - SizeInBytes, 0); + Init.back() = Tag; + Constant *Padding = ConstantDataArray::get(*C, Init); + Initializer = ConstantStruct::getAnon({Initializer, Padding}); + } + + auto *NewGV = new GlobalVariable(M, Initializer->getType(), GV->isConstant(), + GlobalValue::ExternalLinkage, Initializer, + GV->getName() + ".hwasan"); + NewGV->copyAttributesFrom(GV); + NewGV->setLinkage(GlobalValue::PrivateLinkage); + NewGV->copyMetadata(GV, 0); + NewGV->setAlignment( + MaybeAlign(std::max(GV->getAlignment(), Mapping.getObjectAlignment()))); + + // It is invalid to ICF two globals that have different tags. In the case + // where the size of the global is a multiple of the tag granularity the + // contents of the globals may be the same but the tags (i.e. symbol values) + // may be different, and the symbols are not considered during ICF. In the + // case where the size is not a multiple of the granularity, the short granule + // tags would discriminate two globals with different tags, but there would + // otherwise be nothing stopping such a global from being incorrectly ICF'd + // with an uninstrumented (i.e. tag 0) global that happened to have the short + // granule tag in the last byte. 
+ NewGV->setUnnamedAddr(GlobalValue::UnnamedAddr::None); + + // Descriptor format (assuming little-endian): + // bytes 0-3: relative address of global + // bytes 4-6: size of global (16MB ought to be enough for anyone, but in case + // it isn't, we create multiple descriptors) + // byte 7: tag + auto *DescriptorTy = StructType::get(Int32Ty, Int32Ty); + const uint64_t MaxDescriptorSize = 0xfffff0; + for (uint64_t DescriptorPos = 0; DescriptorPos < SizeInBytes; + DescriptorPos += MaxDescriptorSize) { + auto *Descriptor = + new GlobalVariable(M, DescriptorTy, true, GlobalValue::PrivateLinkage, + nullptr, GV->getName() + ".hwasan.descriptor"); + auto *GVRelPtr = ConstantExpr::getTrunc( + ConstantExpr::getAdd( + ConstantExpr::getSub( + ConstantExpr::getPtrToInt(NewGV, Int64Ty), + ConstantExpr::getPtrToInt(Descriptor, Int64Ty)), + ConstantInt::get(Int64Ty, DescriptorPos)), + Int32Ty); + uint32_t Size = std::min(SizeInBytes - DescriptorPos, MaxDescriptorSize); + auto *SizeAndTag = ConstantInt::get(Int32Ty, Size | (uint32_t(Tag) << 24)); + Descriptor->setComdat(NewGV->getComdat()); + Descriptor->setInitializer(ConstantStruct::getAnon({GVRelPtr, SizeAndTag})); + Descriptor->setSection("hwasan_globals"); + Descriptor->setMetadata(LLVMContext::MD_associated, + MDNode::get(*C, ValueAsMetadata::get(NewGV))); + appendToCompilerUsed(M, Descriptor); + } + + Constant *Aliasee = ConstantExpr::getIntToPtr( + ConstantExpr::getAdd( + ConstantExpr::getPtrToInt(NewGV, Int64Ty), + ConstantInt::get(Int64Ty, uint64_t(Tag) << kPointerTagShift)), + GV->getType()); + auto *Alias = GlobalAlias::create(GV->getValueType(), GV->getAddressSpace(), + GV->getLinkage(), "", Aliasee, &M); + Alias->setVisibility(GV->getVisibility()); + Alias->takeName(GV); + GV->replaceAllUsesWith(Alias); + GV->eraseFromParent(); +} + +void HWAddressSanitizer::instrumentGlobals() { + // Start by creating a note that contains pointers to the list of global + // descriptors. 
Adding a note to the output file will cause the linker to + // create a PT_NOTE program header pointing to the note that we can use to + // find the descriptor list starting from the program headers. A function + // provided by the runtime initializes the shadow memory for the globals by + // accessing the descriptor list via the note. The dynamic loader needs to + // call this function whenever a library is loaded. + // + // The reason why we use a note for this instead of a more conventional + // approach of having a global constructor pass a descriptor list pointer to + // the runtime is because of an order of initialization problem. With + // constructors we can encounter the following problematic scenario: + // + // 1) library A depends on library B and also interposes one of B's symbols + // 2) B's constructors are called before A's (as required for correctness) + // 3) during construction, B accesses one of its "own" globals (actually + // interposed by A) and triggers a HWASAN failure due to the initialization + // for A not having happened yet + // + // Even without interposition it is possible to run into similar situations in + // cases where two libraries mutually depend on each other. + // + // We only need one note per binary, so put everything for the note in a + // comdat. + Comdat *NoteComdat = M.getOrInsertComdat(kHwasanNoteName); + + Type *Int8Arr0Ty = ArrayType::get(Int8Ty, 0); + auto Start = + new GlobalVariable(M, Int8Arr0Ty, true, GlobalVariable::ExternalLinkage, + nullptr, "__start_hwasan_globals"); + Start->setVisibility(GlobalValue::HiddenVisibility); + Start->setDSOLocal(true); + auto Stop = + new GlobalVariable(M, Int8Arr0Ty, true, GlobalVariable::ExternalLinkage, + nullptr, "__stop_hwasan_globals"); + Stop->setVisibility(GlobalValue::HiddenVisibility); + Stop->setDSOLocal(true); + + // Null-terminated so actually 8 bytes, which are required in order to align + // the note properly. 
+ auto *Name = ConstantDataArray::get(*C, "LLVM\0\0\0"); + + auto *NoteTy = StructType::get(Int32Ty, Int32Ty, Int32Ty, Name->getType(), + Int32Ty, Int32Ty); + auto *Note = + new GlobalVariable(M, NoteTy, /*isConstantGlobal=*/true, + GlobalValue::PrivateLinkage, nullptr, kHwasanNoteName); + Note->setSection(".note.hwasan.globals"); + Note->setComdat(NoteComdat); + Note->setAlignment(Align(4)); + Note->setDSOLocal(true); + + // The pointers in the note need to be relative so that the note ends up being + // placed in rodata, which is the standard location for notes. + auto CreateRelPtr = [&](Constant *Ptr) { + return ConstantExpr::getTrunc( + ConstantExpr::getSub(ConstantExpr::getPtrToInt(Ptr, Int64Ty), + ConstantExpr::getPtrToInt(Note, Int64Ty)), + Int32Ty); + }; + Note->setInitializer(ConstantStruct::getAnon( + {ConstantInt::get(Int32Ty, 8), // n_namesz + ConstantInt::get(Int32Ty, 8), // n_descsz + ConstantInt::get(Int32Ty, ELF::NT_LLVM_HWASAN_GLOBALS), // n_type + Name, CreateRelPtr(Start), CreateRelPtr(Stop)})); + appendToCompilerUsed(M, Note); + + // Create a zero-length global in hwasan_globals so that the linker will + // always create start and stop symbols. + auto Dummy = new GlobalVariable( + M, Int8Arr0Ty, /*isConstantGlobal*/ true, GlobalVariable::PrivateLinkage, + Constant::getNullValue(Int8Arr0Ty), "hwasan.dummy.global"); + Dummy->setSection("hwasan_globals"); + Dummy->setComdat(NoteComdat); + Dummy->setMetadata(LLVMContext::MD_associated, + MDNode::get(*C, ValueAsMetadata::get(Note))); + appendToCompilerUsed(M, Dummy); + + std::vector<GlobalVariable *> Globals; + for (GlobalVariable &GV : M.globals()) { + if (GV.isDeclarationForLinker() || GV.getName().startswith("llvm.") || + GV.isThreadLocal()) + continue; + + // Common symbols can't have aliases point to them, so they can't be tagged. 
+ if (GV.hasCommonLinkage()) + continue; + + // Globals with custom sections may be used in __start_/__stop_ enumeration, + // which would be broken both by adding tags and potentially by the extra + // padding/alignment that we insert. + if (GV.hasSection()) + continue; + + Globals.push_back(&GV); + } + + MD5 Hasher; + Hasher.update(M.getSourceFileName()); + MD5::MD5Result Hash; + Hasher.final(Hash); + uint8_t Tag = Hash[0]; + + for (GlobalVariable *GV : Globals) { + // Skip tag 0 in order to avoid collisions with untagged memory. + if (Tag == 0) + Tag = 1; + instrumentGlobal(GV, Tag++); + } +} + +void HWAddressSanitizer::instrumentPersonalityFunctions() { + // We need to untag stack frames as we unwind past them. That is the job of + // the personality function wrapper, which either wraps an existing + // personality function or acts as a personality function on its own. Each + // function that has a personality function or that can be unwound past has + // its personality function changed to a thunk that calls the personality + // function wrapper in the runtime. + MapVector<Constant *, std::vector<Function *>> PersonalityFns; + for (Function &F : M) { + if (F.isDeclaration() || !F.hasFnAttribute(Attribute::SanitizeHWAddress)) + continue; + + if (F.hasPersonalityFn()) { + PersonalityFns[F.getPersonalityFn()->stripPointerCasts()].push_back(&F); + } else if (!F.hasFnAttribute(Attribute::NoUnwind)) { + PersonalityFns[nullptr].push_back(&F); + } + } + + if (PersonalityFns.empty()) + return; + + FunctionCallee HwasanPersonalityWrapper = M.getOrInsertFunction( + "__hwasan_personality_wrapper", Int32Ty, Int32Ty, Int32Ty, Int64Ty, + Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy); + FunctionCallee UnwindGetGR = M.getOrInsertFunction("_Unwind_GetGR", VoidTy); + FunctionCallee UnwindGetCFA = M.getOrInsertFunction("_Unwind_GetCFA", VoidTy); + + for (auto &P : PersonalityFns) { + std::string ThunkName = kHwasanPersonalityThunkName; + if (P.first) + ThunkName += ("." 
+ P.first->getName()).str(); + FunctionType *ThunkFnTy = FunctionType::get( + Int32Ty, {Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int8PtrTy}, false); + bool IsLocal = P.first && (!isa<GlobalValue>(P.first) || + cast<GlobalValue>(P.first)->hasLocalLinkage()); + auto *ThunkFn = Function::Create(ThunkFnTy, + IsLocal ? GlobalValue::InternalLinkage + : GlobalValue::LinkOnceODRLinkage, + ThunkName, &M); + if (!IsLocal) { + ThunkFn->setVisibility(GlobalValue::HiddenVisibility); + ThunkFn->setComdat(M.getOrInsertComdat(ThunkName)); + } + + auto *BB = BasicBlock::Create(*C, "entry", ThunkFn); + IRBuilder<> IRB(BB); + CallInst *WrapperCall = IRB.CreateCall( + HwasanPersonalityWrapper, + {ThunkFn->getArg(0), ThunkFn->getArg(1), ThunkFn->getArg(2), + ThunkFn->getArg(3), ThunkFn->getArg(4), + P.first ? IRB.CreateBitCast(P.first, Int8PtrTy) + : Constant::getNullValue(Int8PtrTy), + IRB.CreateBitCast(UnwindGetGR.getCallee(), Int8PtrTy), + IRB.CreateBitCast(UnwindGetCFA.getCallee(), Int8PtrTy)}); + WrapperCall->setTailCall(); + IRB.CreateRet(WrapperCall); + + for (Function *F : P.second) + F->setPersonalityFn(ThunkFn); + } +} + void HWAddressSanitizer::ShadowMapping::init(Triple &TargetTriple) { Scale = kDefaultShadowScale; if (ClMappingOffset.getNumOccurrences() > 0) { diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index c7371f567ff3..74d6e76eceb6 100644 --- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -403,7 +403,7 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, AM->getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); } else { - OwnedORE = llvm::make_unique<OptimizationRemarkEmitter>(&F); + OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F); ORE = OwnedORE.get(); } diff --git 
a/lib/Transforms/Instrumentation/InstrOrderFile.cpp b/lib/Transforms/Instrumentation/InstrOrderFile.cpp index a2c1ddfd279e..93d3a8a14d5c 100644 --- a/lib/Transforms/Instrumentation/InstrOrderFile.cpp +++ b/lib/Transforms/Instrumentation/InstrOrderFile.cpp @@ -100,7 +100,8 @@ public: if (!ClOrderFileWriteMapping.empty()) { std::lock_guard<std::mutex> LogLock(MappingMutex); std::error_code EC; - llvm::raw_fd_ostream OS(ClOrderFileWriteMapping, EC, llvm::sys::fs::F_Append); + llvm::raw_fd_ostream OS(ClOrderFileWriteMapping, EC, + llvm::sys::fs::OF_Append); if (EC) { report_fatal_error(Twine("Failed to open ") + ClOrderFileWriteMapping + " to save mapping file for order file instrumentation\n"); diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp index 63c2b8078967..1f092a5f3103 100644 --- a/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -157,7 +157,10 @@ public: } bool runOnModule(Module &M) override { - return InstrProf.run(M, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI()); + auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { + return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + }; + return InstrProf.run(M, GetTLI); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -370,8 +373,12 @@ private: } // end anonymous namespace PreservedAnalyses InstrProfiling::run(Module &M, ModuleAnalysisManager &AM) { - auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); - if (!run(M, TLI)) + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { + return FAM.getResult<TargetLibraryAnalysis>(F); + }; + if (!run(M, GetTLI)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -441,7 +448,7 @@ void InstrProfiling::promoteCounterLoadStores(Function *F) { std::unique_ptr<BlockFrequencyInfo> BFI; if 
(Options.UseBFIInPromotion) { std::unique_ptr<BranchProbabilityInfo> BPI; - BPI.reset(new BranchProbabilityInfo(*F, LI, TLI)); + BPI.reset(new BranchProbabilityInfo(*F, LI, &GetTLI(*F))); BFI.reset(new BlockFrequencyInfo(*F, *BPI, LI)); } @@ -482,9 +489,10 @@ static bool containsProfilingIntrinsics(Module &M) { return false; } -bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { +bool InstrProfiling::run( + Module &M, std::function<const TargetLibraryInfo &(Function &F)> GetTLI) { this->M = &M; - this->TLI = &TLI; + this->GetTLI = std::move(GetTLI); NamesVar = nullptr; NamesSize = 0; ProfileDataMap.clear(); @@ -601,6 +609,7 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { bool IsRange = (Ind->getValueKind()->getZExtValue() == llvm::InstrProfValueKind::IPVK_MemOPSize); CallInst *Call = nullptr; + auto *TLI = &GetTLI(*Ind->getFunction()); if (!IsRange) { Value *Args[3] = {Ind->getTargetValue(), Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), @@ -731,9 +740,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { PD = It->second; } - // Match the linkage and visibility of the name global, except on COFF, where - // the linkage must be local and consequentially the visibility must be - // default. + // Match the linkage and visibility of the name global. COFF supports using + // comdats with internal symbols, so do that if we can. Function *Fn = Inc->getParent()->getParent(); GlobalValue::LinkageTypes Linkage = NamePtr->getLinkage(); GlobalValue::VisibilityTypes Visibility = NamePtr->getVisibility(); @@ -749,19 +757,21 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { // new comdat group for the counters and profiling data. If we use the comdat // of the parent function, that will result in relocations against discarded // sections. 
- Comdat *Cmdt = nullptr; - GlobalValue::LinkageTypes CounterLinkage = Linkage; - if (needsComdatForCounter(*Fn, *M)) { - StringRef CmdtPrefix = getInstrProfComdatPrefix(); + bool NeedComdat = needsComdatForCounter(*Fn, *M); + if (NeedComdat) { if (TT.isOSBinFormatCOFF()) { - // For COFF, the comdat group name must be the name of a symbol in the - // group. Use the counter variable name, and upgrade its linkage to - // something externally visible, like linkonce_odr. - CmdtPrefix = getInstrProfCountersVarPrefix(); - CounterLinkage = GlobalValue::LinkOnceODRLinkage; + // For COFF, put the counters, data, and values each into their own + // comdats. We can't use a group because the Visual C++ linker will + // report duplicate symbol errors if there are multiple external symbols + // with the same name marked IMAGE_COMDAT_SELECT_ASSOCIATIVE. + Linkage = GlobalValue::LinkOnceODRLinkage; + Visibility = GlobalValue::HiddenVisibility; } - Cmdt = M->getOrInsertComdat(getVarName(Inc, CmdtPrefix)); } + auto MaybeSetComdat = [=](GlobalVariable *GV) { + if (NeedComdat) + GV->setComdat(M->getOrInsertComdat(GV->getName())); + }; uint64_t NumCounters = Inc->getNumCounters()->getZExtValue(); LLVMContext &Ctx = M->getContext(); @@ -775,9 +785,9 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { CounterPtr->setVisibility(Visibility); CounterPtr->setSection( getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat())); - CounterPtr->setAlignment(8); - CounterPtr->setComdat(Cmdt); - CounterPtr->setLinkage(CounterLinkage); + CounterPtr->setAlignment(Align(8)); + MaybeSetComdat(CounterPtr); + CounterPtr->setLinkage(Linkage); auto *Int8PtrTy = Type::getInt8PtrTy(Ctx); // Allocate statically the array of pointers to value profile nodes for @@ -797,8 +807,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { ValuesVar->setVisibility(Visibility); ValuesVar->setSection( getInstrProfSectionName(IPSK_vals, TT.getObjectFormat())); - 
ValuesVar->setAlignment(8); - ValuesVar->setComdat(Cmdt); + ValuesVar->setAlignment(Align(8)); + MaybeSetComdat(ValuesVar); ValuesPtrExpr = ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx)); } @@ -830,8 +840,9 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { getVarName(Inc, getInstrProfDataVarPrefix())); Data->setVisibility(Visibility); Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat())); - Data->setAlignment(INSTR_PROF_DATA_ALIGNMENT); - Data->setComdat(Cmdt); + Data->setAlignment(Align(INSTR_PROF_DATA_ALIGNMENT)); + MaybeSetComdat(Data); + Data->setLinkage(Linkage); PD.RegionCounters = CounterPtr; PD.DataVar = Data; @@ -920,7 +931,7 @@ void InstrProfiling::emitNameData() { // On COFF, it's important to reduce the alignment down to 1 to prevent the // linker from inserting padding before the start of the names section or // between names entries. - NamesVar->setAlignment(1); + NamesVar->setAlignment(Align::None()); UsedVars.push_back(NamesVar); for (auto *NamePtr : ReferencedNames) diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp index f56a1bd91b89..a6c2c9b464b6 100644 --- a/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -68,7 +68,8 @@ GlobalVariable *llvm::createPrivateGlobalForString(Module &M, StringRef Str, GlobalValue::PrivateLinkage, StrConst, NamePrefix); if (AllowMerging) GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); - GV->setAlignment(1); // Strings may not be merged w/o setting align 1. + GV->setAlignment(Align::None()); // Strings may not be merged w/o setting + // alignment explicitly. 
return GV; } @@ -116,7 +117,7 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) { initializeMemorySanitizerLegacyPassPass(Registry); initializeHWAddressSanitizerLegacyPassPass(Registry); initializeThreadSanitizerLegacyPassPass(Registry); - initializeSanitizerCoverageModulePass(Registry); + initializeModuleSanitizerCoverageLegacyPassPass(Registry); initializeDataFlowSanitizerPass(Registry); } diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index b25cbed1bb02..69c9020e060b 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -462,16 +462,9 @@ namespace { /// the module. class MemorySanitizer { public: - MemorySanitizer(Module &M, MemorySanitizerOptions Options) { - this->CompileKernel = - ClEnableKmsan.getNumOccurrences() > 0 ? ClEnableKmsan : Options.Kernel; - if (ClTrackOrigins.getNumOccurrences() > 0) - this->TrackOrigins = ClTrackOrigins; - else - this->TrackOrigins = this->CompileKernel ? 2 : Options.TrackOrigins; - this->Recover = ClKeepGoing.getNumOccurrences() > 0 - ? ClKeepGoing - : (this->CompileKernel | Options.Recover); + MemorySanitizer(Module &M, MemorySanitizerOptions Options) + : CompileKernel(Options.Kernel), TrackOrigins(Options.TrackOrigins), + Recover(Options.Recover) { initializeModule(M); } @@ -594,10 +587,26 @@ private: /// An empty volatile inline asm that prevents callback merge. InlineAsm *EmptyAsm; - - Function *MsanCtorFunction; }; +void insertModuleCtor(Module &M) { + getOrCreateSanitizerCtorAndInitFunctions( + M, kMsanModuleCtorName, kMsanInitName, + /*InitArgTypes=*/{}, + /*InitArgs=*/{}, + // This callback is invoked when the functions are created the first + // time. 
Hook them into the global ctors list in that case: + [&](Function *Ctor, FunctionCallee) { + if (!ClWithComdat) { + appendToGlobalCtors(M, Ctor, 0); + return; + } + Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName); + Ctor->setComdat(MsanCtorComdat); + appendToGlobalCtors(M, Ctor, 0, Ctor); + }); +} + /// A legacy function pass for msan instrumentation. /// /// Instruments functions to detect unitialized reads. @@ -615,7 +624,7 @@ struct MemorySanitizerLegacyPass : public FunctionPass { bool runOnFunction(Function &F) override { return MSan->sanitizeFunction( - F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI()); + F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F)); } bool doInitialization(Module &M) override; @@ -623,8 +632,17 @@ struct MemorySanitizerLegacyPass : public FunctionPass { MemorySanitizerOptions Options; }; +template <class T> T getOptOrDefault(const cl::opt<T> &Opt, T Default) { + return (Opt.getNumOccurrences() > 0) ? Opt : Default; +} + } // end anonymous namespace +MemorySanitizerOptions::MemorySanitizerOptions(int TO, bool R, bool K) + : Kernel(getOptOrDefault(ClEnableKmsan, K)), + TrackOrigins(getOptOrDefault(ClTrackOrigins, Kernel ? 
2 : TO)), + Recover(getOptOrDefault(ClKeepGoing, Kernel || R)) {} + PreservedAnalyses MemorySanitizerPass::run(Function &F, FunctionAnalysisManager &FAM) { MemorySanitizer Msan(*F.getParent(), Options); @@ -633,6 +651,14 @@ PreservedAnalyses MemorySanitizerPass::run(Function &F, return PreservedAnalyses::all(); } +PreservedAnalyses MemorySanitizerPass::run(Module &M, + ModuleAnalysisManager &AM) { + if (Options.Kernel) + return PreservedAnalyses::all(); + insertModuleCtor(M); + return PreservedAnalyses::none(); +} + char MemorySanitizerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(MemorySanitizerLegacyPass, "msan", @@ -918,23 +944,6 @@ void MemorySanitizer::initializeModule(Module &M) { OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000); if (!CompileKernel) { - std::tie(MsanCtorFunction, std::ignore) = - getOrCreateSanitizerCtorAndInitFunctions( - M, kMsanModuleCtorName, kMsanInitName, - /*InitArgTypes=*/{}, - /*InitArgs=*/{}, - // This callback is invoked when the functions are created the first - // time. 
Hook them into the global ctors list in that case: - [&](Function *Ctor, FunctionCallee) { - if (!ClWithComdat) { - appendToGlobalCtors(M, Ctor, 0); - return; - } - Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName); - Ctor->setComdat(MsanCtorComdat); - appendToGlobalCtors(M, Ctor, 0, Ctor); - }); - if (TrackOrigins) M.getOrInsertGlobal("__msan_track_origins", IRB.getInt32Ty(), [&] { return new GlobalVariable( @@ -952,6 +961,8 @@ void MemorySanitizer::initializeModule(Module &M) { } bool MemorySanitizerLegacyPass::doInitialization(Module &M) { + if (!Options.Kernel) + insertModuleCtor(M); MSan.emplace(M, Options); return true; } @@ -2562,6 +2573,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return false; } + void handleInvariantGroup(IntrinsicInst &I) { + setShadow(&I, getShadow(&I, 0)); + setOrigin(&I, getOrigin(&I, 0)); + } + void handleLifetimeStart(IntrinsicInst &I) { if (!PoisonStack) return; @@ -2993,6 +3009,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { case Intrinsic::lifetime_start: handleLifetimeStart(I); break; + case Intrinsic::launder_invariant_group: + case Intrinsic::strip_invariant_group: + handleInvariantGroup(I); + break; case Intrinsic::bswap: handleBswap(I); break; @@ -3627,10 +3647,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { int getNumOutputArgs(InlineAsm *IA, CallBase *CB) { int NumRetOutputs = 0; int NumOutputs = 0; - Type *RetTy = dyn_cast<Value>(CB)->getType(); + Type *RetTy = cast<Value>(CB)->getType(); if (!RetTy->isVoidTy()) { // Register outputs are returned via the CallInst return value. 
- StructType *ST = dyn_cast_or_null<StructType>(RetTy); + auto *ST = dyn_cast<StructType>(RetTy); if (ST) NumRetOutputs = ST->getNumElements(); else @@ -3667,7 +3687,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // corresponding CallInst has nO+nI+1 operands (the last operand is the // function to be called). const DataLayout &DL = F.getParent()->getDataLayout(); - CallBase *CB = dyn_cast<CallBase>(&I); + CallBase *CB = cast<CallBase>(&I); IRBuilder<> IRB(&I); InlineAsm *IA = cast<InlineAsm>(CB->getCalledValue()); int OutputArgs = getNumOutputArgs(IA, CB); @@ -4567,8 +4587,9 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, } bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) { - if (!CompileKernel && (&F == MsanCtorFunction)) + if (!CompileKernel && F.getName() == kMsanModuleCtorName) return false; + MemorySanitizerVisitor Visitor(F, *this, TLI); // Clear out readonly/readnone attributes. diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 6fec3c9c79ee..ca1bb62389e9 100644 --- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -48,6 +48,7 @@ //===----------------------------------------------------------------------===// #include "CFGMST.h" +#include "ValueProfileCollector.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -61,7 +62,6 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" @@ -96,6 +96,7 @@ #include "llvm/ProfileData/InstrProf.h" #include "llvm/ProfileData/InstrProfReader.h" #include "llvm/Support/BranchProbability.h" 
+#include "llvm/Support/CRC.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/DOTGraphTraits.h" @@ -103,11 +104,11 @@ #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GraphWriter.h" -#include "llvm/Support/JamCRC.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/MisExpect.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -120,6 +121,7 @@ using namespace llvm; using ProfileCount = Function::ProfileCount; +using VPCandidateInfo = ValueProfileCollector::CandidateInfo; #define DEBUG_TYPE "pgo-instrumentation" @@ -286,6 +288,11 @@ static std::string getBranchCondString(Instruction *TI) { return result; } +static const char *ValueProfKindDescr[] = { +#define VALUE_PROF_KIND(Enumerator, Value, Descr) Descr, +#include "llvm/ProfileData/InstrProfData.inc" +}; + namespace { /// The select instruction visitor plays three roles specified @@ -348,50 +355,6 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { unsigned getNumOfSelectInsts() const { return NSIs; } }; -/// Instruction Visitor class to visit memory intrinsic calls. -struct MemIntrinsicVisitor : public InstVisitor<MemIntrinsicVisitor> { - Function &F; - unsigned NMemIs = 0; // Number of memIntrinsics instrumented. - VisitMode Mode = VM_counting; // Visiting mode. - unsigned CurCtrId = 0; // Current counter index. 
- unsigned TotalNumCtrs = 0; // Total number of counters - GlobalVariable *FuncNameVar = nullptr; - uint64_t FuncHash = 0; - PGOUseFunc *UseFunc = nullptr; - std::vector<Instruction *> Candidates; - - MemIntrinsicVisitor(Function &Func) : F(Func) {} - - void countMemIntrinsics(Function &Func) { - NMemIs = 0; - Mode = VM_counting; - visit(Func); - } - - void instrumentMemIntrinsics(Function &Func, unsigned TotalNC, - GlobalVariable *FNV, uint64_t FHash) { - Mode = VM_instrument; - TotalNumCtrs = TotalNC; - FuncHash = FHash; - FuncNameVar = FNV; - visit(Func); - } - - std::vector<Instruction *> findMemIntrinsics(Function &Func) { - Candidates.clear(); - Mode = VM_annotate; - visit(Func); - return Candidates; - } - - // Visit the IR stream and annotate all mem intrinsic call instructions. - void instrumentOneMemIntrinsic(MemIntrinsic &MI); - - // Visit \p MI instruction and perform tasks according to visit mode. - void visitMemIntrinsic(MemIntrinsic &SI); - - unsigned getNumOfMemIntrinsics() const { return NMemIs; } -}; class PGOInstrumentationGenLegacyPass : public ModulePass { public: @@ -563,13 +526,14 @@ private: // A map that stores the Comdat group in function F. 
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers; + ValueProfileCollector VPC; + void computeCFGHash(); void renameComdatFunction(); public: - std::vector<std::vector<Instruction *>> ValueSites; + std::vector<std::vector<VPCandidateInfo>> ValueSites; SelectInstVisitor SIVisitor; - MemIntrinsicVisitor MIVisitor; std::string FuncName; GlobalVariable *FuncNameVar; @@ -604,23 +568,21 @@ public: std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr, bool IsCS = false) - : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), - ValueSites(IPVK_Last + 1), SIVisitor(Func), MIVisitor(Func), - MST(F, BPI, BFI) { + : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), VPC(Func), + ValueSites(IPVK_Last + 1), SIVisitor(Func), MST(F, BPI, BFI) { // This should be done before CFG hash computation. SIVisitor.countSelects(Func); - MIVisitor.countMemIntrinsics(Func); + ValueSites[IPVK_MemOPSize] = VPC.get(IPVK_MemOPSize); if (!IsCS) { NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); - NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); + NumOfPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size(); NumOfPGOBB += MST.BBInfos.size(); - ValueSites[IPVK_IndirectCallTarget] = findIndirectCalls(Func); + ValueSites[IPVK_IndirectCallTarget] = VPC.get(IPVK_IndirectCallTarget); } else { NumOfCSPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); - NumOfCSPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); + NumOfCSPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size(); NumOfCSPGOBB += MST.BBInfos.size(); } - ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func); FuncName = getPGOFuncName(F); computeCFGHash(); @@ -647,7 +609,7 @@ public: // value of each BB in the CFG. The higher 32 bits record the number of edges. 
template <class Edge, class BBInfo> void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { - std::vector<char> Indexes; + std::vector<uint8_t> Indexes; JamCRC JC; for (auto &BB : F) { const Instruction *TI = BB.getTerminator(); @@ -658,7 +620,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { continue; uint32_t Index = BI->Index; for (int J = 0; J < 4; J++) - Indexes.push_back((char)(Index >> (J * 8))); + Indexes.push_back((uint8_t)(Index >> (J * 8))); } } JC.update(Indexes); @@ -874,28 +836,36 @@ static void instrumentOneFunc( if (DisableValueProfiling) return; - unsigned NumIndirectCalls = 0; - for (auto &I : FuncInfo.ValueSites[IPVK_IndirectCallTarget]) { - CallSite CS(I); - Value *Callee = CS.getCalledValue(); - LLVM_DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " - << NumIndirectCalls << "\n"); - IRBuilder<> Builder(I); - assert(Builder.GetInsertPoint() != I->getParent()->end() && - "Cannot get the Instrumentation point"); - Builder.CreateCall( - Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), - {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), - Builder.getInt64(FuncInfo.FunctionHash), - Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()), - Builder.getInt32(IPVK_IndirectCallTarget), - Builder.getInt32(NumIndirectCalls++)}); - } - NumOfPGOICall += NumIndirectCalls; + NumOfPGOICall += FuncInfo.ValueSites[IPVK_IndirectCallTarget].size(); - // Now instrument memop intrinsic calls. - FuncInfo.MIVisitor.instrumentMemIntrinsics( - F, NumCounters, FuncInfo.FuncNameVar, FuncInfo.FunctionHash); + // For each VP Kind, walk the VP candidates and instrument each one. 
+ for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) { + unsigned SiteIndex = 0; + if (Kind == IPVK_MemOPSize && !PGOInstrMemOP) + continue; + + for (VPCandidateInfo Cand : FuncInfo.ValueSites[Kind]) { + LLVM_DEBUG(dbgs() << "Instrument one VP " << ValueProfKindDescr[Kind] + << " site: CallSite Index = " << SiteIndex << "\n"); + + IRBuilder<> Builder(Cand.InsertPt); + assert(Builder.GetInsertPoint() != Cand.InsertPt->getParent()->end() && + "Cannot get the Instrumentation point"); + + Value *ToProfile = nullptr; + if (Cand.V->getType()->isIntegerTy()) + ToProfile = Builder.CreateZExtOrTrunc(Cand.V, Builder.getInt64Ty()); + else if (Cand.V->getType()->isPointerTy()) + ToProfile = Builder.CreatePtrToInt(Cand.V, Builder.getInt64Ty()); + assert(ToProfile && "value profiling Value is of unexpected type"); + + Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), + {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), + Builder.getInt64(FuncInfo.FunctionHash), ToProfile, + Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)}); + } + } // IPVK_First <= Kind <= IPVK_Last } namespace { @@ -984,9 +954,9 @@ class PGOUseFunc { public: PGOUseFunc(Function &Func, Module *Modu, std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, - BranchProbabilityInfo *BPI = nullptr, - BlockFrequencyInfo *BFIin = nullptr, bool IsCS = false) - : F(Func), M(Modu), BFI(BFIin), + BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFIin, + ProfileSummaryInfo *PSI, bool IsCS) + : F(Func), M(Modu), BFI(BFIin), PSI(PSI), FuncInfo(Func, ComdatMembers, false, BPI, BFIin, IsCS), FreqAttr(FFA_Normal), IsCS(IsCS) {} @@ -1041,6 +1011,7 @@ private: Function &F; Module *M; BlockFrequencyInfo *BFI; + ProfileSummaryInfo *PSI; // This member stores the shared information with class PGOGenFunc. 
FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo; @@ -1078,15 +1049,9 @@ private: // FIXME: This function should be removed once the functionality in // the inliner is implemented. void markFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) { - if (ProgramMaxCount == 0) - return; - // Threshold of the hot functions. - const BranchProbability HotFunctionThreshold(1, 100); - // Threshold of the cold functions. - const BranchProbability ColdFunctionThreshold(2, 10000); - if (EntryCount >= HotFunctionThreshold.scale(ProgramMaxCount)) + if (PSI->isHotCount(EntryCount)) FreqAttr = FFA_Hot; - else if (MaxCount <= ColdFunctionThreshold.scale(ProgramMaxCount)) + else if (PSI->isColdCount(MaxCount)) FreqAttr = FFA_Cold; } }; @@ -1433,43 +1398,6 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) { llvm_unreachable("Unknown visiting mode"); } -void MemIntrinsicVisitor::instrumentOneMemIntrinsic(MemIntrinsic &MI) { - Module *M = F.getParent(); - IRBuilder<> Builder(&MI); - Type *Int64Ty = Builder.getInt64Ty(); - Type *I8PtrTy = Builder.getInt8PtrTy(); - Value *Length = MI.getLength(); - assert(!isa<ConstantInt>(Length)); - Builder.CreateCall( - Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), - {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), - Builder.getInt64(FuncHash), Builder.CreateZExtOrTrunc(Length, Int64Ty), - Builder.getInt32(IPVK_MemOPSize), Builder.getInt32(CurCtrId)}); - ++CurCtrId; -} - -void MemIntrinsicVisitor::visitMemIntrinsic(MemIntrinsic &MI) { - if (!PGOInstrMemOP) - return; - Value *Length = MI.getLength(); - // Not instrument constant length calls. - if (dyn_cast<ConstantInt>(Length)) - return; - - switch (Mode) { - case VM_counting: - NMemIs++; - return; - case VM_instrument: - instrumentOneMemIntrinsic(MI); - return; - case VM_annotate: - Candidates.push_back(&MI); - return; - } - llvm_unreachable("Unknown visiting mode"); -} - // Traverse all valuesites and annotate the instructions for all value kind. 
void PGOUseFunc::annotateValueSites() { if (DisableValueProfiling) @@ -1482,11 +1410,6 @@ void PGOUseFunc::annotateValueSites() { annotateValueSites(Kind); } -static const char *ValueProfKindDescr[] = { -#define VALUE_PROF_KIND(Enumerator, Value, Descr) Descr, -#include "llvm/ProfileData/InstrProfData.inc" -}; - // Annotate the instructions for a specific value kind. void PGOUseFunc::annotateValueSites(uint32_t Kind) { assert(Kind <= IPVK_Last); @@ -1505,11 +1428,11 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) { return; } - for (auto &I : ValueSites) { + for (VPCandidateInfo &I : ValueSites) { LLVM_DEBUG(dbgs() << "Read one value site profile (kind = " << Kind << "): Index = " << ValueSiteIndex << " out of " << NumValueSites << "\n"); - annotateValueSite(*M, *I, ProfileRecord, + annotateValueSite(*M, *I.AnnotatedInst, ProfileRecord, static_cast<InstrProfValueKind>(Kind), ValueSiteIndex, Kind == IPVK_MemOPSize ? MaxNumMemOPAnnotations : MaxNumAnnotations); @@ -1595,7 +1518,8 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M, static bool annotateAllFunctions( Module &M, StringRef ProfileFileName, StringRef ProfileRemappingFileName, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, - function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) { + function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, + ProfileSummaryInfo *PSI, bool IsCS) { LLVM_DEBUG(dbgs() << "Read in profile counters: "); auto &Ctx = M.getContext(); // Read the counter array from file. @@ -1626,6 +1550,13 @@ static bool annotateAllFunctions( return false; } + // Add the profile summary (read from the header of the indexed summary) here + // so that we can use it below when reading counters (which checks if the + // function should be marked with a cold or inlinehint attribute). + M.setProfileSummary(PGOReader->getSummary(IsCS).getMD(M.getContext()), + IsCS ? 
ProfileSummary::PSK_CSInstr + : ProfileSummary::PSK_Instr); + std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; collectComdatMembers(M, ComdatMembers); std::vector<Function *> HotFunctions; @@ -1638,7 +1569,7 @@ static bool annotateAllFunctions( // Split indirectbr critical edges here before computing the MST rather than // later in getInstrBB() to avoid invalidating it. SplitIndirectBrCriticalEdges(F, BPI, BFI); - PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI, IsCS); + PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI, PSI, IsCS); bool AllZeros = false; if (!Func.readCounters(PGOReader.get(), AllZeros)) continue; @@ -1662,9 +1593,9 @@ static bool annotateAllFunctions( F.getName().equals(ViewBlockFreqFuncName))) { LoopInfo LI{DominatorTree(F)}; std::unique_ptr<BranchProbabilityInfo> NewBPI = - llvm::make_unique<BranchProbabilityInfo>(F, LI); + std::make_unique<BranchProbabilityInfo>(F, LI); std::unique_ptr<BlockFrequencyInfo> NewBFI = - llvm::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI); + std::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI); if (PGOViewCounts == PGOVCT_Graph) NewBFI->view(); else if (PGOViewCounts == PGOVCT_Text) { @@ -1686,9 +1617,6 @@ static bool annotateAllFunctions( } } } - M.setProfileSummary(PGOReader->getSummary(IsCS).getMD(M.getContext()), - IsCS ? ProfileSummary::PSK_CSInstr - : ProfileSummary::PSK_Instr); // Set function hotness attribute from the profile. 
// We have to apply these attributes at the end because their presence @@ -1730,8 +1658,10 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M, return &FAM.getResult<BlockFrequencyAnalysis>(F); }; + auto *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); + if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, - LookupBPI, LookupBFI, IsCS)) + LookupBPI, LookupBFI, PSI, IsCS)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -1748,7 +1678,8 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) { return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); }; - return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI, + auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI, PSI, IsCS); } @@ -1776,6 +1707,9 @@ void llvm::setProfMetadata(Module *M, Instruction *TI, : Weights) { dbgs() << W << " "; } dbgs() << "\n";); + + misexpect::verifyMisExpect(TI, Weights, TI->getContext()); + TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); if (EmitBranchProbability) { std::string BrCondStr = getBranchCondString(TI); diff --git a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index 188f95b4676b..9f81bb16d0a7 100644 --- a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -138,7 +138,7 @@ public: OptimizationRemarkEmitter &ORE, DominatorTree *DT) : Func(Func), BFI(BFI), ORE(ORE), DT(DT), Changed(false) { ValueDataArray = - llvm::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2); + std::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2); // Get the MemOPSize range information from option MemOPSizeRange, getMemOPSizeRangeFromOption(MemOPSizeRange, PreciseRangeStart, PreciseRangeLast); @@ -374,8 +374,8 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { Ctx, 
Twine("MemOP.Case.") + Twine(SizeId), &Func, DefaultBB); Instruction *NewInst = MI->clone(); // Fix the argument. - MemIntrinsic * MemI = dyn_cast<MemIntrinsic>(NewInst); - IntegerType *SizeType = dyn_cast<IntegerType>(MemI->getLength()->getType()); + auto *MemI = cast<MemIntrinsic>(NewInst); + auto *SizeType = dyn_cast<IntegerType>(MemI->getLength()->getType()); assert(SizeType && "Expected integer type size argument."); ConstantInt *CaseSizeId = ConstantInt::get(SizeType, SizeId); MemI->setLength(CaseSizeId); diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index ca0cb4bdbe84..f8fa9cad03b8 100644 --- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Instrumentation/SanitizerCoverage.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/EHPersonalities.h" @@ -176,24 +177,21 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { return Options; } -class SanitizerCoverageModule : public ModulePass { +using DomTreeCallback = function_ref<const DominatorTree *(Function &F)>; +using PostDomTreeCallback = + function_ref<const PostDominatorTree *(Function &F)>; + +class ModuleSanitizerCoverage { public: - SanitizerCoverageModule( + ModuleSanitizerCoverage( const SanitizerCoverageOptions &Options = SanitizerCoverageOptions()) - : ModulePass(ID), Options(OverrideFromCL(Options)) { - initializeSanitizerCoverageModulePass(*PassRegistry::getPassRegistry()); - } - bool runOnModule(Module &M) override; - bool runOnFunction(Function &F); - static char ID; // Pass identification, replacement for typeid - StringRef getPassName() const override { return "SanitizerCoverageModule"; } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - 
AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - } + : Options(OverrideFromCL(Options)) {} + bool instrumentModule(Module &M, DomTreeCallback DTCallback, + PostDomTreeCallback PDTCallback); private: + void instrumentFunction(Function &F, DomTreeCallback DTCallback, + PostDomTreeCallback PDTCallback); void InjectCoverageForIndirectCalls(Function &F, ArrayRef<Instruction *> IndirCalls); void InjectTraceForCmp(Function &F, ArrayRef<Instruction *> CmpTraceTargets); @@ -252,10 +250,57 @@ private: SanitizerCoverageOptions Options; }; +class ModuleSanitizerCoverageLegacyPass : public ModulePass { +public: + ModuleSanitizerCoverageLegacyPass( + const SanitizerCoverageOptions &Options = SanitizerCoverageOptions()) + : ModulePass(ID), Options(Options) { + initializeModuleSanitizerCoverageLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + bool runOnModule(Module &M) override { + ModuleSanitizerCoverage ModuleSancov(Options); + auto DTCallback = [this](Function &F) -> const DominatorTree * { + return &this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + }; + auto PDTCallback = [this](Function &F) -> const PostDominatorTree * { + return &this->getAnalysis<PostDominatorTreeWrapperPass>(F) + .getPostDomTree(); + }; + return ModuleSancov.instrumentModule(M, DTCallback, PDTCallback); + } + + static char ID; // Pass identification, replacement for typeid + StringRef getPassName() const override { return "ModuleSanitizerCoverage"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + } + +private: + SanitizerCoverageOptions Options; +}; + } // namespace +PreservedAnalyses ModuleSanitizerCoveragePass::run(Module &M, + ModuleAnalysisManager &MAM) { + ModuleSanitizerCoverage ModuleSancov(Options); + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto DTCallback = [&FAM](Function &F) 
-> const DominatorTree * { + return &FAM.getResult<DominatorTreeAnalysis>(F); + }; + auto PDTCallback = [&FAM](Function &F) -> const PostDominatorTree * { + return &FAM.getResult<PostDominatorTreeAnalysis>(F); + }; + if (ModuleSancov.instrumentModule(M, DTCallback, PDTCallback)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + std::pair<Value *, Value *> -SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section, +ModuleSanitizerCoverage::CreateSecStartEnd(Module &M, const char *Section, Type *Ty) { GlobalVariable *SecStart = new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, nullptr, @@ -278,7 +323,7 @@ SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section, return std::make_pair(IRB.CreatePointerCast(GEP, Ty), SecEndPtr); } -Function *SanitizerCoverageModule::CreateInitCallsForSections( +Function *ModuleSanitizerCoverage::CreateInitCallsForSections( Module &M, const char *CtorName, const char *InitFunctionName, Type *Ty, const char *Section) { auto SecStartEnd = CreateSecStartEnd(M, Section, Ty); @@ -310,7 +355,8 @@ Function *SanitizerCoverageModule::CreateInitCallsForSections( return CtorFunc; } -bool SanitizerCoverageModule::runOnModule(Module &M) { +bool ModuleSanitizerCoverage::instrumentModule( + Module &M, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback) { if (Options.CoverageType == SanitizerCoverageOptions::SCK_None) return false; C = &(M.getContext()); @@ -403,7 +449,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, Int32PtrTy); for (auto &F : M) - runOnFunction(F); + instrumentFunction(F, DTCallback, PDTCallback); Function *Ctor = nullptr; @@ -518,29 +564,30 @@ static bool IsInterestingCmp(ICmpInst *CMP, const DominatorTree *DT, return true; } -bool SanitizerCoverageModule::runOnFunction(Function &F) { +void ModuleSanitizerCoverage::instrumentFunction( + Function &F, DomTreeCallback DTCallback, 
PostDomTreeCallback PDTCallback) { if (F.empty()) - return false; + return; if (F.getName().find(".module_ctor") != std::string::npos) - return false; // Should not instrument sanitizer init functions. + return; // Should not instrument sanitizer init functions. if (F.getName().startswith("__sanitizer_")) - return false; // Don't instrument __sanitizer_* callbacks. + return; // Don't instrument __sanitizer_* callbacks. // Don't touch available_externally functions, their actual body is elewhere. if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) - return false; + return; // Don't instrument MSVC CRT configuration helpers. They may run before normal // initialization. if (F.getName() == "__local_stdio_printf_options" || F.getName() == "__local_stdio_scanf_options") - return false; + return; if (isa<UnreachableInst>(F.getEntryBlock().getTerminator())) - return false; + return; // Don't instrument functions using SEH for now. Splitting basic blocks like // we do for coverage breaks WinEHPrepare. // FIXME: Remove this when SEH no longer uses landingpad pattern matching. 
if (F.hasPersonalityFn() && isAsynchronousEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) - return false; + return; if (Options.CoverageType >= SanitizerCoverageOptions::SCK_Edge) SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions().setIgnoreUnreachableDests()); SmallVector<Instruction *, 8> IndirCalls; @@ -550,10 +597,8 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { SmallVector<BinaryOperator *, 8> DivTraceTargets; SmallVector<GetElementPtrInst *, 8> GepTraceTargets; - const DominatorTree *DT = - &getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); - const PostDominatorTree *PDT = - &getAnalysis<PostDominatorTreeWrapperPass>(F).getPostDomTree(); + const DominatorTree *DT = DTCallback(F); + const PostDominatorTree *PDT = PDTCallback(F); bool IsLeafFunc = true; for (auto &BB : F) { @@ -593,10 +638,9 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { InjectTraceForSwitch(F, SwitchTraceTargets); InjectTraceForDiv(F, DivTraceTargets); InjectTraceForGep(F, GepTraceTargets); - return true; } -GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection( +GlobalVariable *ModuleSanitizerCoverage::CreateFunctionLocalArrayInSection( size_t NumElements, Function &F, Type *Ty, const char *Section) { ArrayType *ArrayTy = ArrayType::get(Ty, NumElements); auto Array = new GlobalVariable( @@ -608,8 +652,9 @@ GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection( GetOrCreateFunctionComdat(F, TargetTriple, CurModuleUniqueId)) Array->setComdat(Comdat); Array->setSection(getSectionName(Section)); - Array->setAlignment(Ty->isPointerTy() ? DL->getPointerSize() - : Ty->getPrimitiveSizeInBits() / 8); + Array->setAlignment(Align(Ty->isPointerTy() + ? 
DL->getPointerSize() + : Ty->getPrimitiveSizeInBits() / 8)); GlobalsToAppendToUsed.push_back(Array); GlobalsToAppendToCompilerUsed.push_back(Array); MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F)); @@ -619,7 +664,7 @@ GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection( } GlobalVariable * -SanitizerCoverageModule::CreatePCArray(Function &F, +ModuleSanitizerCoverage::CreatePCArray(Function &F, ArrayRef<BasicBlock *> AllBlocks) { size_t N = AllBlocks.size(); assert(N); @@ -646,7 +691,7 @@ SanitizerCoverageModule::CreatePCArray(Function &F, return PCArray; } -void SanitizerCoverageModule::CreateFunctionLocalArrays( +void ModuleSanitizerCoverage::CreateFunctionLocalArrays( Function &F, ArrayRef<BasicBlock *> AllBlocks) { if (Options.TracePCGuard) FunctionGuardArray = CreateFunctionLocalArrayInSection( @@ -660,7 +705,7 @@ void SanitizerCoverageModule::CreateFunctionLocalArrays( FunctionPCsArray = CreatePCArray(F, AllBlocks); } -bool SanitizerCoverageModule::InjectCoverage(Function &F, +bool ModuleSanitizerCoverage::InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks, bool IsLeafFunc) { if (AllBlocks.empty()) return false; @@ -677,7 +722,7 @@ bool SanitizerCoverageModule::InjectCoverage(Function &F, // The cache is used to speed up recording the caller-callee pairs. // The address of the caller is passed implicitly via caller PC. // CacheSize is encoded in the name of the run-time function. -void SanitizerCoverageModule::InjectCoverageForIndirectCalls( +void ModuleSanitizerCoverage::InjectCoverageForIndirectCalls( Function &F, ArrayRef<Instruction *> IndirCalls) { if (IndirCalls.empty()) return; @@ -696,7 +741,7 @@ void SanitizerCoverageModule::InjectCoverageForIndirectCalls( // __sanitizer_cov_trace_switch(CondValue, // {NumCases, ValueSizeInBits, Case0Value, Case1Value, Case2Value, ... 
}) -void SanitizerCoverageModule::InjectTraceForSwitch( +void ModuleSanitizerCoverage::InjectTraceForSwitch( Function &, ArrayRef<Instruction *> SwitchTraceTargets) { for (auto I : SwitchTraceTargets) { if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { @@ -735,7 +780,7 @@ void SanitizerCoverageModule::InjectTraceForSwitch( } } -void SanitizerCoverageModule::InjectTraceForDiv( +void ModuleSanitizerCoverage::InjectTraceForDiv( Function &, ArrayRef<BinaryOperator *> DivTraceTargets) { for (auto BO : DivTraceTargets) { IRBuilder<> IRB(BO); @@ -753,7 +798,7 @@ void SanitizerCoverageModule::InjectTraceForDiv( } } -void SanitizerCoverageModule::InjectTraceForGep( +void ModuleSanitizerCoverage::InjectTraceForGep( Function &, ArrayRef<GetElementPtrInst *> GepTraceTargets) { for (auto GEP : GepTraceTargets) { IRBuilder<> IRB(GEP); @@ -764,7 +809,7 @@ void SanitizerCoverageModule::InjectTraceForGep( } } -void SanitizerCoverageModule::InjectTraceForCmp( +void ModuleSanitizerCoverage::InjectTraceForCmp( Function &, ArrayRef<Instruction *> CmpTraceTargets) { for (auto I : CmpTraceTargets) { if (ICmpInst *ICMP = dyn_cast<ICmpInst>(I)) { @@ -799,7 +844,7 @@ void SanitizerCoverageModule::InjectTraceForCmp( } } -void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, +void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx, bool IsLeafFunc) { BasicBlock::iterator IP = BB.getFirstInsertionPt(); @@ -842,8 +887,10 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, } if (Options.StackDepth && IsEntryBB && !IsLeafFunc) { // Check stack depth. If it's the deepest so far, record it. 
- Function *GetFrameAddr = - Intrinsic::getDeclaration(F.getParent(), Intrinsic::frameaddress); + Module *M = F.getParent(); + Function *GetFrameAddr = Intrinsic::getDeclaration( + M, Intrinsic::frameaddress, + IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); auto FrameAddrPtr = IRB.CreateCall(GetFrameAddr, {Constant::getNullValue(Int32Ty)}); auto FrameAddrInt = IRB.CreatePtrToInt(FrameAddrPtr, IntptrTy); @@ -858,7 +905,7 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, } std::string -SanitizerCoverageModule::getSectionName(const std::string &Section) const { +ModuleSanitizerCoverage::getSectionName(const std::string &Section) const { if (TargetTriple.isOSBinFormatCOFF()) { if (Section == SanCovCountersSectionName) return ".SCOV$CM"; @@ -872,32 +919,29 @@ SanitizerCoverageModule::getSectionName(const std::string &Section) const { } std::string -SanitizerCoverageModule::getSectionStart(const std::string &Section) const { +ModuleSanitizerCoverage::getSectionStart(const std::string &Section) const { if (TargetTriple.isOSBinFormatMachO()) return "\1section$start$__DATA$__" + Section; return "__start___" + Section; } std::string -SanitizerCoverageModule::getSectionEnd(const std::string &Section) const { +ModuleSanitizerCoverage::getSectionEnd(const std::string &Section) const { if (TargetTriple.isOSBinFormatMachO()) return "\1section$end$__DATA$__" + Section; return "__stop___" + Section; } - -char SanitizerCoverageModule::ID = 0; -INITIALIZE_PASS_BEGIN(SanitizerCoverageModule, "sancov", - "SanitizerCoverage: TODO." - "ModulePass", - false, false) +char ModuleSanitizerCoverageLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(ModuleSanitizerCoverageLegacyPass, "sancov", + "Pass for instrumenting coverage on functions", false, + false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_END(SanitizerCoverageModule, "sancov", - "SanitizerCoverage: TODO." 
- "ModulePass", - false, false) -ModulePass *llvm::createSanitizerCoverageModulePass( +INITIALIZE_PASS_END(ModuleSanitizerCoverageLegacyPass, "sancov", + "Pass for instrumenting coverage on functions", false, + false) +ModulePass *llvm::createModuleSanitizerCoverageLegacyPassPass( const SanitizerCoverageOptions &Options) { - return new SanitizerCoverageModule(Options); + return new ModuleSanitizerCoverageLegacyPass(Options); } diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 5be13fa745cb..ac274a155a80 100644 --- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -92,11 +92,10 @@ namespace { /// ensures the __tsan_init function is in the list of global constructors for /// the module. struct ThreadSanitizer { - ThreadSanitizer(Module &M); bool sanitizeFunction(Function &F, const TargetLibraryInfo &TLI); private: - void initializeCallbacks(Module &M); + void initialize(Module &M); bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL); bool instrumentAtomic(Instruction *I, const DataLayout &DL); bool instrumentMemIntrinsic(Instruction *I); @@ -108,8 +107,6 @@ private: void InsertRuntimeIgnores(Function &F); Type *IntptrTy; - IntegerType *OrdTy; - // Callbacks to run-time library are computed in doInitialization. 
FunctionCallee TsanFuncEntry; FunctionCallee TsanFuncExit; FunctionCallee TsanIgnoreBegin; @@ -130,7 +127,6 @@ private: FunctionCallee TsanVptrUpdate; FunctionCallee TsanVptrLoad; FunctionCallee MemmoveFn, MemcpyFn, MemsetFn; - Function *TsanCtorFunction; }; struct ThreadSanitizerLegacyPass : FunctionPass { @@ -143,16 +139,32 @@ struct ThreadSanitizerLegacyPass : FunctionPass { private: Optional<ThreadSanitizer> TSan; }; + +void insertModuleCtor(Module &M) { + getOrCreateSanitizerCtorAndInitFunctions( + M, kTsanModuleCtorName, kTsanInitName, /*InitArgTypes=*/{}, + /*InitArgs=*/{}, + // This callback is invoked when the functions are created the first + // time. Hook them into the global ctors list in that case: + [&](Function *Ctor, FunctionCallee) { appendToGlobalCtors(M, Ctor, 0); }); +} + } // namespace PreservedAnalyses ThreadSanitizerPass::run(Function &F, FunctionAnalysisManager &FAM) { - ThreadSanitizer TSan(*F.getParent()); + ThreadSanitizer TSan; if (TSan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F))) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } +PreservedAnalyses ThreadSanitizerPass::run(Module &M, + ModuleAnalysisManager &MAM) { + insertModuleCtor(M); + return PreservedAnalyses::none(); +} + char ThreadSanitizerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(ThreadSanitizerLegacyPass, "tsan", "ThreadSanitizer: detects data races.", false, false) @@ -169,12 +181,13 @@ void ThreadSanitizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { } bool ThreadSanitizerLegacyPass::doInitialization(Module &M) { - TSan.emplace(M); + insertModuleCtor(M); + TSan.emplace(); return true; } bool ThreadSanitizerLegacyPass::runOnFunction(Function &F) { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); TSan->sanitizeFunction(F, TLI); return true; } @@ -183,7 +196,10 @@ FunctionPass *llvm::createThreadSanitizerLegacyPassPass() { return new 
ThreadSanitizerLegacyPass(); } -void ThreadSanitizer::initializeCallbacks(Module &M) { +void ThreadSanitizer::initialize(Module &M) { + const DataLayout &DL = M.getDataLayout(); + IntptrTy = DL.getIntPtrType(M.getContext()); + IRBuilder<> IRB(M.getContext()); AttributeList Attr; Attr = Attr.addAttribute(M.getContext(), AttributeList::FunctionIndex, @@ -197,7 +213,7 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { IRB.getVoidTy()); TsanIgnoreEnd = M.getOrInsertFunction("__tsan_ignore_thread_end", Attr, IRB.getVoidTy()); - OrdTy = IRB.getInt32Ty(); + IntegerType *OrdTy = IRB.getInt32Ty(); for (size_t i = 0; i < kNumberOfAccessSizes; ++i) { const unsigned ByteSize = 1U << i; const unsigned BitSize = ByteSize * 8; @@ -280,20 +296,6 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); } -ThreadSanitizer::ThreadSanitizer(Module &M) { - const DataLayout &DL = M.getDataLayout(); - IntptrTy = DL.getIntPtrType(M.getContext()); - std::tie(TsanCtorFunction, std::ignore) = - getOrCreateSanitizerCtorAndInitFunctions( - M, kTsanModuleCtorName, kTsanInitName, /*InitArgTypes=*/{}, - /*InitArgs=*/{}, - // This callback is invoked when the functions are created the first - // time. Hook them into the global ctors list in that case: - [&](Function *Ctor, FunctionCallee) { - appendToGlobalCtors(M, Ctor, 0); - }); -} - static bool isVtableAccess(Instruction *I) { if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa)) return Tag->isTBAAVtableAccess(); @@ -436,9 +438,9 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, const TargetLibraryInfo &TLI) { // This is required to prevent instrumenting call to __tsan_init from within // the module constructor. 
- if (&F == TsanCtorFunction) + if (F.getName() == kTsanModuleCtorName) return false; - initializeCallbacks(*F.getParent()); + initialize(*F.getParent()); SmallVector<Instruction*, 8> AllLoadsAndStores; SmallVector<Instruction*, 8> LocalLoadsAndStores; SmallVector<Instruction*, 8> AtomicAccesses; diff --git a/lib/Transforms/Instrumentation/ValueProfileCollector.cpp b/lib/Transforms/Instrumentation/ValueProfileCollector.cpp new file mode 100644 index 000000000000..604726d4f40f --- /dev/null +++ b/lib/Transforms/Instrumentation/ValueProfileCollector.cpp @@ -0,0 +1,78 @@ +//===- ValueProfileCollector.cpp - determine what to value profile --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The implementation of the ValueProfileCollector via ValueProfileCollectorImpl +// +//===----------------------------------------------------------------------===// + +#include "ValueProfilePlugins.inc" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/InitializePasses.h" + +#include <cassert> + +using namespace llvm; + +namespace { + +/// A plugin-based class that takes an arbitrary number of Plugin types. +/// Each plugin type must satisfy the following API: +/// 1) the constructor must take a `Function &f`. Typically, the plugin would +/// scan the function looking for candidates. +/// 2) contain a member function with the following signature and name: +/// void run(std::vector<CandidateInfo> &Candidates); +/// such that the plugin would append its result into the vector parameter. +/// +/// Plugins are defined in ValueProfilePlugins.inc +template <class... 
Ts> class PluginChain; + +/// The type PluginChainFinal is the final chain of plugins that will be used by +/// ValueProfileCollectorImpl. +using PluginChainFinal = PluginChain<VP_PLUGIN_LIST>; + +template <> class PluginChain<> { +public: + PluginChain(Function &F) {} + void get(InstrProfValueKind K, std::vector<CandidateInfo> &Candidates) {} +}; + +template <class PluginT, class... Ts> +class PluginChain<PluginT, Ts...> : public PluginChain<Ts...> { + PluginT Plugin; + using Base = PluginChain<Ts...>; + +public: + PluginChain(Function &F) : PluginChain<Ts...>(F), Plugin(F) {} + + void get(InstrProfValueKind K, std::vector<CandidateInfo> &Candidates) { + if (K == PluginT::Kind) + Plugin.run(Candidates); + Base::get(K, Candidates); + } +}; + +} // end anonymous namespace + +/// ValueProfileCollectorImpl inherits the API of PluginChainFinal. +class ValueProfileCollector::ValueProfileCollectorImpl : public PluginChainFinal { +public: + using PluginChainFinal::PluginChainFinal; +}; + +ValueProfileCollector::ValueProfileCollector(Function &F) + : PImpl(new ValueProfileCollectorImpl(F)) {} + +ValueProfileCollector::~ValueProfileCollector() = default; + +std::vector<CandidateInfo> +ValueProfileCollector::get(InstrProfValueKind Kind) const { + std::vector<CandidateInfo> Result; + PImpl->get(Kind, Result); + return Result; +} diff --git a/lib/Transforms/Instrumentation/ValueProfileCollector.h b/lib/Transforms/Instrumentation/ValueProfileCollector.h new file mode 100644 index 000000000000..ff883c8d0c77 --- /dev/null +++ b/lib/Transforms/Instrumentation/ValueProfileCollector.h @@ -0,0 +1,79 @@ +//===- ValueProfileCollector.h - determine what to value profile ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a utility class, ValueProfileCollector, that is used to +// determine what kind of llvm::Value's are worth value-profiling, at which +// point in the program, and which instruction holds the Value Profile metadata. +// Currently, the only users of this utility is the PGOInstrumentation[Gen|Use] +// passes. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_PROFILE_GEN_ANALYSIS_H +#define LLVM_ANALYSIS_PROFILE_GEN_ANALYSIS_H + +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProf.h" + +namespace llvm { + +/// Utility analysis that determines what values are worth profiling. +/// The actual logic is inside the ValueProfileCollectorImpl, whose job is to +/// populate the Candidates vector. +/// +/// Value profiling an expression means to track the values that this expression +/// takes at runtime and the frequency of each value. +/// It is important to distinguish between two sets of value profiles for a +/// particular expression: +/// 1) The set of values at the point of evaluation. +/// 2) The set of values at the point of use. +/// In some cases, the two sets are identical, but it's not unusual for the two +/// to differ. 
+/// +/// To elaborate more, consider this C code, and focus on the expression `nn`: +/// void foo(int nn, bool b) { +/// if (b) memcpy(x, y, nn); +/// } +/// The point of evaluation can be as early as the start of the function, and +/// let's say the value profile for `nn` is: +/// total=100; (value,freq) set = {(8,10), (32,50)} +/// The point of use is right before we call memcpy, and since we execute the +/// memcpy conditionally, the value profile of `nn` can be: +/// total=15; (value,freq) set = {(8,10), (4,5)} +/// +/// For this reason, a plugin is responsible for computing the insertion point +/// for each value to be profiled. The `CandidateInfo` structure encapsulates +/// all the information needed for each value profile site. +class ValueProfileCollector { +public: + struct CandidateInfo { + Value *V; // The value to profile. + Instruction *InsertPt; // Insert the VP lib call before this instr. + Instruction *AnnotatedInst; // Where metadata is attached. + }; + + ValueProfileCollector(Function &Fn); + ValueProfileCollector(ValueProfileCollector &&) = delete; + ValueProfileCollector &operator=(ValueProfileCollector &&) = delete; + + ValueProfileCollector(const ValueProfileCollector &) = delete; + ValueProfileCollector &operator=(const ValueProfileCollector &) = delete; + ~ValueProfileCollector(); + + /// returns a list of value profiling candidates of the given kind + std::vector<CandidateInfo> get(InstrProfValueKind Kind) const; + +private: + class ValueProfileCollectorImpl; + std::unique_ptr<ValueProfileCollectorImpl> PImpl; +}; + +} // namespace llvm + +#endif diff --git a/lib/Transforms/Instrumentation/ValueProfilePlugins.inc b/lib/Transforms/Instrumentation/ValueProfilePlugins.inc new file mode 100644 index 000000000000..4cc4c6c848c3 --- /dev/null +++ b/lib/Transforms/Instrumentation/ValueProfilePlugins.inc @@ -0,0 +1,75 @@ +//=== ValueProfilePlugins.inc - set of plugins used by ValueProfileCollector =// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of plugin classes used in ValueProfileCollectorImpl. +// Each plugin is responsible for collecting Value Profiling candidates for a +// particular optimization. +// Each plugin must satisfy the interface described in ValueProfileCollector.cpp +// +//===----------------------------------------------------------------------===// + +#include "ValueProfileCollector.h" +#include "llvm/Analysis/IndirectCallVisitor.h" +#include "llvm/IR/InstVisitor.h" + +using namespace llvm; +using CandidateInfo = ValueProfileCollector::CandidateInfo; + +///--------------------------- MemIntrinsicPlugin ------------------------------ +class MemIntrinsicPlugin : public InstVisitor<MemIntrinsicPlugin> { + Function &F; + std::vector<CandidateInfo> *Candidates; + +public: + static constexpr InstrProfValueKind Kind = IPVK_MemOPSize; + + MemIntrinsicPlugin(Function &Fn) : F(Fn), Candidates(nullptr) {} + + void run(std::vector<CandidateInfo> &Cs) { + Candidates = &Cs; + visit(F); + Candidates = nullptr; + } + void visitMemIntrinsic(MemIntrinsic &MI) { + Value *Length = MI.getLength(); + // Not instrument constant length calls. 
+ if (dyn_cast<ConstantInt>(Length)) + return; + + Instruction *InsertPt = &MI; + Instruction *AnnotatedInst = &MI; + Candidates->emplace_back(CandidateInfo{Length, InsertPt, AnnotatedInst}); + } +}; + +///------------------------ IndirectCallPromotionPlugin ------------------------ +class IndirectCallPromotionPlugin { + Function &F; + +public: + static constexpr InstrProfValueKind Kind = IPVK_IndirectCallTarget; + + IndirectCallPromotionPlugin(Function &Fn) : F(Fn) {} + + void run(std::vector<CandidateInfo> &Candidates) { + std::vector<Instruction *> Result = findIndirectCalls(F); + for (Instruction *I : Result) { + Value *Callee = CallSite(I).getCalledValue(); + Instruction *InsertPt = I; + Instruction *AnnotatedInst = I; + Candidates.emplace_back(CandidateInfo{Callee, InsertPt, AnnotatedInst}); + } + } +}; + +///----------------------- Registration of the plugins ------------------------- +/// For now, registering a plugin with the ValueProfileCollector is done by +/// adding the plugin type to the VP_PLUGIN_LIST macro. 
+#define VP_PLUGIN_LIST \ + MemIntrinsicPlugin, \ + IndirectCallPromotionPlugin diff --git a/lib/Transforms/ObjCARC/PtrState.cpp b/lib/Transforms/ObjCARC/PtrState.cpp index 3243481dee0d..26dd416d6184 100644 --- a/lib/Transforms/ObjCARC/PtrState.cpp +++ b/lib/Transforms/ObjCARC/PtrState.cpp @@ -275,6 +275,10 @@ void BottomUpPtrState::HandlePotentialUse(BasicBlock *BB, Instruction *Inst, } else { InsertAfter = std::next(Inst->getIterator()); } + + if (InsertAfter != BB->end()) + InsertAfter = skipDebugIntrinsics(InsertAfter); + InsertReverseInsertPt(&*InsertAfter); }; diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index de9a62e88c27..0e9f03a06061 100644 --- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -93,9 +93,7 @@ static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV, const SCEV *AlignSCEV, ScalarEvolution *SE) { // DiffUnits = Diff % int64_t(Alignment) - const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV); - const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV); - const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV); + const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV); LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n"); @@ -323,7 +321,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { LI->getPointerOperand(), SE); if (NewAlignment > LI->getAlignment()) { - LI->setAlignment(NewAlignment); + LI->setAlignment(MaybeAlign(NewAlignment)); ++NumLoadAlignChanged; } } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) { @@ -331,7 +329,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { SI->getPointerOperand(), SE); if (NewAlignment > SI->getAlignment()) { - SI->setAlignment(NewAlignment); + SI->setAlignment(MaybeAlign(NewAlignment)); ++NumStoreAlignChanged; } } 
else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) { diff --git a/lib/Transforms/Scalar/CallSiteSplitting.cpp b/lib/Transforms/Scalar/CallSiteSplitting.cpp index 3519b000a33f..c3fba923104f 100644 --- a/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -562,7 +562,7 @@ struct CallSiteSplittingLegacyPass : public FunctionPass { if (skipFunction(F)) return false; - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); return doCallSiteSplitting(F, TLI, TTI, DT); diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp index 98243a23f1ef..9f340afbf7c2 100644 --- a/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -204,7 +204,7 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, /// set found in \p BBs. static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, BasicBlock *Entry, - SmallPtrSet<BasicBlock *, 8> &BBs) { + SetVector<BasicBlock *> &BBs) { assert(!BBs.count(Entry) && "Assume Entry is not in BBs"); // Nodes on the current path to the root. SmallPtrSet<BasicBlock *, 8> Path; @@ -257,7 +257,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, // Visit Orders in bottom-up order. using InsertPtsCostPair = - std::pair<SmallPtrSet<BasicBlock *, 16>, BlockFrequency>; + std::pair<SetVector<BasicBlock *>, BlockFrequency>; // InsertPtsMap is a map from a BB to the best insertion points for the // subtree of BB (subtree not including the BB itself). 
@@ -266,7 +266,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, for (auto RIt = Orders.rbegin(); RIt != Orders.rend(); RIt++) { BasicBlock *Node = *RIt; bool NodeInBBs = BBs.count(Node); - SmallPtrSet<BasicBlock *, 16> &InsertPts = InsertPtsMap[Node].first; + auto &InsertPts = InsertPtsMap[Node].first; BlockFrequency &InsertPtsFreq = InsertPtsMap[Node].second; // Return the optimal insert points in BBs. @@ -283,7 +283,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, BasicBlock *Parent = DT.getNode(Node)->getIDom()->getBlock(); // Initially, ParentInsertPts is empty and ParentPtsFreq is 0. Every child // will update its parent's ParentInsertPts and ParentPtsFreq. - SmallPtrSet<BasicBlock *, 16> &ParentInsertPts = InsertPtsMap[Parent].first; + auto &ParentInsertPts = InsertPtsMap[Parent].first; BlockFrequency &ParentPtsFreq = InsertPtsMap[Parent].second; // Choose to insert in Node or in subtree of Node. // Don't hoist to EHPad because we may not find a proper place to insert @@ -305,12 +305,12 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, } /// Find an insertion point that dominates all uses. -SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint( +SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint( const ConstantInfo &ConstInfo) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); // Collect all basic blocks. 
- SmallPtrSet<BasicBlock *, 8> BBs; - SmallPtrSet<Instruction *, 8> InsertPts; + SetVector<BasicBlock *> BBs; + SetVector<Instruction *> InsertPts; for (auto const &RCI : ConstInfo.RebasedConstants) for (auto const &U : RCI.Uses) BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent()); @@ -333,15 +333,13 @@ SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint( while (BBs.size() >= 2) { BasicBlock *BB, *BB1, *BB2; - BB1 = *BBs.begin(); - BB2 = *std::next(BBs.begin()); + BB1 = BBs.pop_back_val(); + BB2 = BBs.pop_back_val(); BB = DT->findNearestCommonDominator(BB1, BB2); if (BB == Entry) { InsertPts.insert(&Entry->front()); return InsertPts; } - BBs.erase(BB1); - BBs.erase(BB2); BBs.insert(BB); } assert((BBs.size() == 1) && "Expected only one element."); @@ -403,7 +401,7 @@ void ConstantHoistingPass::collectConstantCandidates( return; // Get offset from the base GV. - PointerType *GVPtrTy = dyn_cast<PointerType>(BaseGV->getType()); + PointerType *GVPtrTy = cast<PointerType>(BaseGV->getType()); IntegerType *PtrIntTy = DL->getIntPtrType(*Ctx, GVPtrTy->getAddressSpace()); APInt Offset(DL->getTypeSizeInBits(PtrIntTy), /*val*/0, /*isSigned*/true); auto *GEPO = cast<GEPOperator>(ConstExpr); @@ -830,7 +828,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec = BaseGV ? ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; for (auto const &ConstInfo : ConstInfoVec) { - SmallPtrSet<Instruction *, 8> IPSet = findConstantInsertionPoint(ConstInfo); + SetVector<Instruction *> IPSet = findConstantInsertionPoint(ConstInfo); // We can have an empty set if the function contains unreachable blocks. 
if (IPSet.empty()) continue; diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp index 770321c740a0..e9e6afe3fdd4 100644 --- a/lib/Transforms/Scalar/ConstantProp.cpp +++ b/lib/Transforms/Scalar/ConstantProp.cpp @@ -82,7 +82,7 @@ bool ConstantPropagation::runOnFunction(Function &F) { bool Changed = false; const DataLayout &DL = F.getParent()->getDataLayout(); TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); while (!WorkList.empty()) { SmallVector<Instruction*, 16> NewWorkListVec; diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 89497177524f..2ef85268df48 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -62,6 +62,23 @@ STATISTIC(NumSDivs, "Number of sdiv converted to udiv"); STATISTIC(NumUDivs, "Number of udivs whose width was decreased"); STATISTIC(NumAShrs, "Number of ashr converted to lshr"); STATISTIC(NumSRems, "Number of srem converted to urem"); +STATISTIC(NumSExt, "Number of sext converted to zext"); +STATISTIC(NumAnd, "Number of ands removed"); +STATISTIC(NumNW, "Number of no-wrap deductions"); +STATISTIC(NumNSW, "Number of no-signed-wrap deductions"); +STATISTIC(NumNUW, "Number of no-unsigned-wrap deductions"); +STATISTIC(NumAddNW, "Number of no-wrap deductions for add"); +STATISTIC(NumAddNSW, "Number of no-signed-wrap deductions for add"); +STATISTIC(NumAddNUW, "Number of no-unsigned-wrap deductions for add"); +STATISTIC(NumSubNW, "Number of no-wrap deductions for sub"); +STATISTIC(NumSubNSW, "Number of no-signed-wrap deductions for sub"); +STATISTIC(NumSubNUW, "Number of no-unsigned-wrap deductions for sub"); +STATISTIC(NumMulNW, "Number of no-wrap deductions for mul"); +STATISTIC(NumMulNSW, "Number of no-signed-wrap deductions for mul"); +STATISTIC(NumMulNUW, "Number of 
no-unsigned-wrap deductions for mul"); +STATISTIC(NumShlNW, "Number of no-wrap deductions for shl"); +STATISTIC(NumShlNSW, "Number of no-signed-wrap deductions for shl"); +STATISTIC(NumShlNUW, "Number of no-unsigned-wrap deductions for shl"); STATISTIC(NumOverflows, "Number of overflow checks removed"); STATISTIC(NumSaturating, "Number of saturating arithmetics converted to normal arithmetics"); @@ -85,6 +102,7 @@ namespace { AU.addRequired<LazyValueInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<LazyValueInfoWrapperPass>(); } }; @@ -416,37 +434,96 @@ static bool willNotOverflow(BinaryOpIntrinsic *BO, LazyValueInfo *LVI) { return NWRegion.contains(LRange); } -static void processOverflowIntrinsic(WithOverflowInst *WO) { - IRBuilder<> B(WO); - Value *NewOp = B.CreateBinOp( - WO->getBinaryOp(), WO->getLHS(), WO->getRHS(), WO->getName()); - // Constant-folding could have happened. - if (auto *Inst = dyn_cast<Instruction>(NewOp)) { - if (WO->isSigned()) +static void setDeducedOverflowingFlags(Value *V, Instruction::BinaryOps Opcode, + bool NewNSW, bool NewNUW) { + Statistic *OpcNW, *OpcNSW, *OpcNUW; + switch (Opcode) { + case Instruction::Add: + OpcNW = &NumAddNW; + OpcNSW = &NumAddNSW; + OpcNUW = &NumAddNUW; + break; + case Instruction::Sub: + OpcNW = &NumSubNW; + OpcNSW = &NumSubNSW; + OpcNUW = &NumSubNUW; + break; + case Instruction::Mul: + OpcNW = &NumMulNW; + OpcNSW = &NumMulNSW; + OpcNUW = &NumMulNUW; + break; + case Instruction::Shl: + OpcNW = &NumShlNW; + OpcNSW = &NumShlNSW; + OpcNUW = &NumShlNUW; + break; + default: + llvm_unreachable("Will not be called with other binops"); + } + + auto *Inst = dyn_cast<Instruction>(V); + if (NewNSW) { + ++NumNW; + ++*OpcNW; + ++NumNSW; + ++*OpcNSW; + if (Inst) Inst->setHasNoSignedWrap(); - else + } + if (NewNUW) { + ++NumNW; + ++*OpcNW; + ++NumNUW; + ++*OpcNUW; + if (Inst) Inst->setHasNoUnsignedWrap(); } +} - Value *NewI = 
B.CreateInsertValue(UndefValue::get(WO->getType()), NewOp, 0); - NewI = B.CreateInsertValue(NewI, ConstantInt::getFalse(WO->getContext()), 1); +static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI); + +// Rewrite this with.overflow intrinsic as non-overflowing. +static void processOverflowIntrinsic(WithOverflowInst *WO, LazyValueInfo *LVI) { + IRBuilder<> B(WO); + Instruction::BinaryOps Opcode = WO->getBinaryOp(); + bool NSW = WO->isSigned(); + bool NUW = !WO->isSigned(); + + Value *NewOp = + B.CreateBinOp(Opcode, WO->getLHS(), WO->getRHS(), WO->getName()); + setDeducedOverflowingFlags(NewOp, Opcode, NSW, NUW); + + StructType *ST = cast<StructType>(WO->getType()); + Constant *Struct = ConstantStruct::get(ST, + { UndefValue::get(ST->getElementType(0)), + ConstantInt::getFalse(ST->getElementType(1)) }); + Value *NewI = B.CreateInsertValue(Struct, NewOp, 0); WO->replaceAllUsesWith(NewI); WO->eraseFromParent(); ++NumOverflows; + + // See if we can infer the other no-wrap too. + if (auto *BO = dyn_cast<BinaryOperator>(NewOp)) + processBinOp(BO, LVI); } -static void processSaturatingInst(SaturatingInst *SI) { +static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) { + Instruction::BinaryOps Opcode = SI->getBinaryOp(); + bool NSW = SI->isSigned(); + bool NUW = !SI->isSigned(); BinaryOperator *BinOp = BinaryOperator::Create( - SI->getBinaryOp(), SI->getLHS(), SI->getRHS(), SI->getName(), SI); + Opcode, SI->getLHS(), SI->getRHS(), SI->getName(), SI); BinOp->setDebugLoc(SI->getDebugLoc()); - if (SI->isSigned()) - BinOp->setHasNoSignedWrap(); - else - BinOp->setHasNoUnsignedWrap(); + setDeducedOverflowingFlags(BinOp, Opcode, NSW, NUW); SI->replaceAllUsesWith(BinOp); SI->eraseFromParent(); ++NumSaturating; + + // See if we can infer the other no-wrap too. + if (auto *BO = dyn_cast<BinaryOperator>(BinOp)) + processBinOp(BO, LVI); } /// Infer nonnull attributes for the arguments at the specified callsite. 
@@ -456,14 +533,14 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { if (auto *WO = dyn_cast<WithOverflowInst>(CS.getInstruction())) { if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) { - processOverflowIntrinsic(WO); + processOverflowIntrinsic(WO, LVI); return true; } } if (auto *SI = dyn_cast<SaturatingInst>(CS.getInstruction())) { if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) { - processSaturatingInst(SI); + processSaturatingInst(SI, LVI); return true; } } @@ -632,6 +709,27 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { return true; } +static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) { + if (SDI->getType()->isVectorTy()) + return false; + + Value *Base = SDI->getOperand(0); + + Constant *Zero = ConstantInt::get(Base->getType(), 0); + if (LVI->getPredicateAt(ICmpInst::ICMP_SGE, Base, Zero, SDI) != + LazyValueInfo::True) + return false; + + ++NumSExt; + auto *ZExt = + CastInst::CreateZExtOrBitCast(Base, SDI->getType(), SDI->getName(), SDI); + ZExt->setDebugLoc(SDI->getDebugLoc()); + SDI->replaceAllUsesWith(ZExt); + SDI->eraseFromParent(); + + return true; +} + static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) { using OBO = OverflowingBinaryOperator; @@ -648,6 +746,7 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) { BasicBlock *BB = BinOp->getParent(); + Instruction::BinaryOps Opcode = BinOp->getOpcode(); Value *LHS = BinOp->getOperand(0); Value *RHS = BinOp->getOperand(1); @@ -655,24 +754,48 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) { ConstantRange RRange = LVI->getConstantRange(RHS, BB, BinOp); bool Changed = false; + bool NewNUW = false, NewNSW = false; if (!NUW) { ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( - BinOp->getOpcode(), RRange, OBO::NoUnsignedWrap); - bool NewNUW = NUWRange.contains(LRange); - BinOp->setHasNoUnsignedWrap(NewNUW); + Opcode, RRange, OBO::NoUnsignedWrap); + 
NewNUW = NUWRange.contains(LRange); Changed |= NewNUW; } if (!NSW) { ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( - BinOp->getOpcode(), RRange, OBO::NoSignedWrap); - bool NewNSW = NSWRange.contains(LRange); - BinOp->setHasNoSignedWrap(NewNSW); + Opcode, RRange, OBO::NoSignedWrap); + NewNSW = NSWRange.contains(LRange); Changed |= NewNSW; } + setDeducedOverflowingFlags(BinOp, Opcode, NewNSW, NewNUW); + return Changed; } +static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) { + if (BinOp->getType()->isVectorTy()) + return false; + + // Pattern match (and lhs, C) where C includes a superset of bits which might + // be set in lhs. This is a common truncation idiom created by instcombine. + BasicBlock *BB = BinOp->getParent(); + Value *LHS = BinOp->getOperand(0); + ConstantInt *RHS = dyn_cast<ConstantInt>(BinOp->getOperand(1)); + if (!RHS || !RHS->getValue().isMask()) + return false; + + ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp); + if (!LRange.getUnsignedMax().ule(RHS->getValue())) + return false; + + BinOp->replaceAllUsesWith(LHS); + BinOp->eraseFromParent(); + NumAnd++; + return true; +} + + static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { if (Constant *C = LVI->getConstant(V, At->getParent(), At)) return C; @@ -740,10 +863,18 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, case Instruction::AShr: BBChanged |= processAShr(cast<BinaryOperator>(II), LVI); break; + case Instruction::SExt: + BBChanged |= processSExt(cast<SExtInst>(II), LVI); + break; case Instruction::Add: case Instruction::Sub: + case Instruction::Mul: + case Instruction::Shl: BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI); break; + case Instruction::And: + BBChanged |= processAnd(cast<BinaryOperator>(II), LVI); + break; } } @@ -796,5 +927,6 @@ CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) { PreservedAnalyses PA; PA.preserve<GlobalsAA>(); 
PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LazyValueAnalysis>(); return PA; } diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp index 479e0ed74074..a79d775aa7f3 100644 --- a/lib/Transforms/Scalar/DCE.cpp +++ b/lib/Transforms/Scalar/DCE.cpp @@ -38,17 +38,19 @@ namespace { //===--------------------------------------------------------------------===// // DeadInstElimination pass implementation // - struct DeadInstElimination : public BasicBlockPass { - static char ID; // Pass identification, replacement for typeid - DeadInstElimination() : BasicBlockPass(ID) { - initializeDeadInstEliminationPass(*PassRegistry::getPassRegistry()); - } - bool runOnBasicBlock(BasicBlock &BB) override { - if (skipBasicBlock(BB)) - return false; - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI() : nullptr; - bool Changed = false; +struct DeadInstElimination : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + DeadInstElimination() : FunctionPass(ID) { + initializeDeadInstEliminationPass(*PassRegistry::getPassRegistry()); + } + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr; + + bool Changed = false; + for (auto &BB : F) { for (BasicBlock::iterator DI = BB.begin(); DI != BB.end(); ) { Instruction *Inst = &*DI++; if (isInstructionTriviallyDead(Inst, TLI)) { @@ -60,13 +62,14 @@ namespace { ++DIEEliminated; } } - return Changed; } + return Changed; + } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); } - }; +}; } char DeadInstElimination::ID = 0; @@ -154,7 +157,7 @@ struct DCELegacyPass : public FunctionPass { return false; auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - TargetLibraryInfo *TLI = TLIP ? 
&TLIP->getTLI() : nullptr; + TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr; return eliminateDeadCode(F, TLI); } diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index a81645745b48..685de82810ed 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -1254,8 +1254,9 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, auto *SI = new StoreInst( ConstantInt::get(Earlier->getValueOperand()->getType(), Merged), - Earlier->getPointerOperand(), false, Earlier->getAlignment(), - Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite); + Earlier->getPointerOperand(), false, + MaybeAlign(Earlier->getAlignment()), Earlier->getOrdering(), + Earlier->getSyncScopeID(), DepWrite); unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, @@ -1361,7 +1362,7 @@ public: MemoryDependenceResults *MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); return eliminateDeadStores(F, AA, MD, DT, TLI); } diff --git a/lib/Transforms/Scalar/DivRemPairs.cpp b/lib/Transforms/Scalar/DivRemPairs.cpp index 876681b4f9de..934853507478 100644 --- a/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/lib/Transforms/Scalar/DivRemPairs.cpp @@ -1,4 +1,4 @@ -//===- DivRemPairs.cpp - Hoist/decompose division and remainder -*- C++ -*-===// +//===- DivRemPairs.cpp - Hoist/[dr]ecompose division and remainder --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This pass hoists and/or decomposes integer division and remainder +// This pass hoists and/or decomposes/recomposes integer division and remainder // instructions to enable CFG improvements and better codegen. // //===----------------------------------------------------------------------===// @@ -19,37 +19,105 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Pass.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BypassSlowDivision.h" + using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "div-rem-pairs" STATISTIC(NumPairs, "Number of div/rem pairs"); +STATISTIC(NumRecomposed, "Number of instructions recomposed"); STATISTIC(NumHoisted, "Number of instructions hoisted"); STATISTIC(NumDecomposed, "Number of instructions decomposed"); DEBUG_COUNTER(DRPCounter, "div-rem-pairs-transform", "Controls transformations in div-rem-pairs pass"); -/// Find matching pairs of integer div/rem ops (they have the same numerator, -/// denominator, and signedness). If they exist in different basic blocks, bring -/// them together by hoisting or replace the common division operation that is -/// implicit in the remainder: -/// X % Y <--> X - ((X / Y) * Y). -/// -/// We can largely ignore the normal safety and cost constraints on speculation -/// of these ops when we find a matching pair. This is because we are already -/// guaranteed that any exceptions and most cost are already incurred by the -/// first member of the pair. -/// -/// Note: This transform could be an oddball enhancement to EarlyCSE, GVN, or -/// SimplifyCFG, but it's split off on its own because it's different enough -/// that it doesn't quite match the stated objectives of those passes. 
-static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, - const DominatorTree &DT) { - bool Changed = false; +namespace { +struct ExpandedMatch { + DivRemMapKey Key; + Instruction *Value; +}; +} // namespace + +/// See if we can match: (which is the form we expand into) +/// X - ((X ?/ Y) * Y) +/// which is equivalent to: +/// X ?% Y +static llvm::Optional<ExpandedMatch> matchExpandedRem(Instruction &I) { + Value *Dividend, *XroundedDownToMultipleOfY; + if (!match(&I, m_Sub(m_Value(Dividend), m_Value(XroundedDownToMultipleOfY)))) + return llvm::None; + + Value *Divisor; + Instruction *Div; + // Look for ((X / Y) * Y) + if (!match( + XroundedDownToMultipleOfY, + m_c_Mul(m_CombineAnd(m_IDiv(m_Specific(Dividend), m_Value(Divisor)), + m_Instruction(Div)), + m_Deferred(Divisor)))) + return llvm::None; + + ExpandedMatch M; + M.Key.SignedOp = Div->getOpcode() == Instruction::SDiv; + M.Key.Dividend = Dividend; + M.Key.Divisor = Divisor; + M.Value = &I; + return M; +} + +/// A thin wrapper to store two values that we matched as div-rem pair. +/// We want this extra indirection to avoid dealing with RAUW'ing the map keys. +struct DivRemPairWorklistEntry { + /// The actual udiv/sdiv instruction. Source of truth. + AssertingVH<Instruction> DivInst; + + /// The instruction that we have matched as a remainder instruction. + /// Should only be used as Value, don't introspect it. + AssertingVH<Instruction> RemInst; + + DivRemPairWorklistEntry(Instruction *DivInst_, Instruction *RemInst_) + : DivInst(DivInst_), RemInst(RemInst_) { + assert((DivInst->getOpcode() == Instruction::UDiv || + DivInst->getOpcode() == Instruction::SDiv) && + "Not a division."); + assert(DivInst->getType() == RemInst->getType() && "Types should match."); + // We can't check anything else about remainder instruction, + // it's not strictly required to be a urem/srem. + } + /// The type for this pair, identical for both the div and rem. 
+ Type *getType() const { return DivInst->getType(); } + + /// Is this pair signed or unsigned? + bool isSigned() const { return DivInst->getOpcode() == Instruction::SDiv; } + + /// In this pair, what are the divident and divisor? + Value *getDividend() const { return DivInst->getOperand(0); } + Value *getDivisor() const { return DivInst->getOperand(1); } + + bool isRemExpanded() const { + switch (RemInst->getOpcode()) { + case Instruction::SRem: + case Instruction::URem: + return false; // single 'rem' instruction - unexpanded form. + default: + return true; // anything else means we have remainder in expanded form. + } + } +}; +using DivRemWorklistTy = SmallVector<DivRemPairWorklistEntry, 4>; + +/// Find matching pairs of integer div/rem ops (they have the same numerator, +/// denominator, and signedness). Place those pairs into a worklist for further +/// processing. This indirection is needed because we have to use TrackingVH<> +/// because we will be doing RAUW, and if one of the rem instructions we change +/// happens to be an input to another div/rem in the maps, we'd have problems. +static DivRemWorklistTy getWorklist(Function &F) { // Insert all divide and remainder instructions into maps keyed by their // operands and opcode (signed or unsigned). DenseMap<DivRemMapKey, Instruction *> DivMap; @@ -66,9 +134,14 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, RemMap[DivRemMapKey(true, I.getOperand(0), I.getOperand(1))] = &I; else if (I.getOpcode() == Instruction::URem) RemMap[DivRemMapKey(false, I.getOperand(0), I.getOperand(1))] = &I; + else if (auto Match = matchExpandedRem(I)) + RemMap[Match->Key] = Match->Value; } } + // We'll accumulate the matching pairs of div-rem instructions here. + DivRemWorklistTy Worklist; + // We can iterate over either map because we are only looking for matched // pairs. Choose remainders for efficiency because they are usually even more // rare than division. 
@@ -78,12 +151,77 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, if (!DivInst) continue; - // We have a matching pair of div/rem instructions. If one dominates the - // other, hoist and/or replace one. + // We have a matching pair of div/rem instructions. NumPairs++; Instruction *RemInst = RemPair.second; - bool IsSigned = DivInst->getOpcode() == Instruction::SDiv; - bool HasDivRemOp = TTI.hasDivRemOp(DivInst->getType(), IsSigned); + + // Place it in the worklist. + Worklist.emplace_back(DivInst, RemInst); + } + + return Worklist; +} + +/// Find matching pairs of integer div/rem ops (they have the same numerator, +/// denominator, and signedness). If they exist in different basic blocks, bring +/// them together by hoisting or replace the common division operation that is +/// implicit in the remainder: +/// X % Y <--> X - ((X / Y) * Y). +/// +/// We can largely ignore the normal safety and cost constraints on speculation +/// of these ops when we find a matching pair. This is because we are already +/// guaranteed that any exceptions and most cost are already incurred by the +/// first member of the pair. +/// +/// Note: This transform could be an oddball enhancement to EarlyCSE, GVN, or +/// SimplifyCFG, but it's split off on its own because it's different enough +/// that it doesn't quite match the stated objectives of those passes. +static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, + const DominatorTree &DT) { + bool Changed = false; + + // Get the matching pairs of div-rem instructions. We want this extra + // indirection to avoid dealing with having to RAUW the keys of the maps. + DivRemWorklistTy Worklist = getWorklist(F); + + // Process each entry in the worklist. 
+ for (DivRemPairWorklistEntry &E : Worklist) { + if (!DebugCounter::shouldExecute(DRPCounter)) + continue; + + bool HasDivRemOp = TTI.hasDivRemOp(E.getType(), E.isSigned()); + + auto &DivInst = E.DivInst; + auto &RemInst = E.RemInst; + + const bool RemOriginallyWasInExpandedForm = E.isRemExpanded(); + (void)RemOriginallyWasInExpandedForm; // suppress unused variable warning + + if (HasDivRemOp && E.isRemExpanded()) { + // The target supports div+rem but the rem is expanded. + // We should recompose it first. + Value *X = E.getDividend(); + Value *Y = E.getDivisor(); + Instruction *RealRem = E.isSigned() ? BinaryOperator::CreateSRem(X, Y) + : BinaryOperator::CreateURem(X, Y); + // Note that we place it right next to the original expanded instruction, + // and letting further handling to move it if needed. + RealRem->setName(RemInst->getName() + ".recomposed"); + RealRem->insertAfter(RemInst); + Instruction *OrigRemInst = RemInst; + // Update AssertingVH<> with new instruction so it doesn't assert. + RemInst = RealRem; + // And replace the original instruction with the new one. + OrigRemInst->replaceAllUsesWith(RealRem); + OrigRemInst->eraseFromParent(); + NumRecomposed++; + // Note that we have left ((X / Y) * Y) around. + // If it had other uses we could rewrite it as X - X % Y + } + + assert((!E.isRemExpanded() || !HasDivRemOp) && + "*If* the target supports div-rem, then by now the RemInst *is* " + "Instruction::[US]Rem."); // If the target supports div+rem and the instructions are in the same block // already, there's nothing to do. The backend should handle this. If the @@ -92,10 +230,16 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, continue; bool DivDominates = DT.dominates(DivInst, RemInst); - if (!DivDominates && !DT.dominates(RemInst, DivInst)) + if (!DivDominates && !DT.dominates(RemInst, DivInst)) { + // We have matching div-rem pair, but they are in two different blocks, + // neither of which dominates one another. 
+ // FIXME: We could hoist both ops to the common predecessor block? continue; + } - if (!DebugCounter::shouldExecute(DRPCounter)) + // The target does not have a single div/rem operation, + // and the rem is already in expanded form. Nothing to do. + if (!HasDivRemOp && E.isRemExpanded()) continue; if (HasDivRemOp) { @@ -107,11 +251,17 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, DivInst->moveAfter(RemInst); NumHoisted++; } else { - // The target does not have a single div/rem operation. Decompose the - // remainder calculation as: + // The target does not have a single div/rem operation, + // and the rem is *not* in a already-expanded form. + // Decompose the remainder calculation as: // X % Y --> X - ((X / Y) * Y). - Value *X = RemInst->getOperand(0); - Value *Y = RemInst->getOperand(1); + + assert(!RemOriginallyWasInExpandedForm && + "We should not be expanding if the rem was in expanded form to " + "begin with."); + + Value *X = E.getDividend(); + Value *Y = E.getDivisor(); Instruction *Mul = BinaryOperator::CreateMul(DivInst, Y); Instruction *Sub = BinaryOperator::CreateSub(X, Mul); @@ -152,8 +302,13 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, // Now kill the explicit remainder. We have replaced it with: // (sub X, (mul (div X, Y), Y) - RemInst->replaceAllUsesWith(Sub); - RemInst->eraseFromParent(); + Sub->setName(RemInst->getName() + ".decomposed"); + Instruction *OrigRemInst = RemInst; + // Update AssertingVH<> with new instruction so it doesn't assert. + RemInst = Sub; + // And replace the original instruction with the new one. 
+ OrigRemInst->replaceAllUsesWith(Sub); + OrigRemInst->eraseFromParent(); NumDecomposed++; } Changed = true; @@ -188,7 +343,7 @@ struct DivRemPairsLegacyPass : public FunctionPass { return optimizeDivRem(F, TTI, DT); } }; -} +} // namespace char DivRemPairsLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(DivRemPairsLegacyPass, "div-rem-pairs", diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index f1f075257020..ce540683dae2 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -108,11 +108,12 @@ struct SimpleValue { // This can only handle non-void readnone functions. if (CallInst *CI = dyn_cast<CallInst>(Inst)) return CI->doesNotAccessMemory() && !CI->getType()->isVoidTy(); - return isa<CastInst>(Inst) || isa<BinaryOperator>(Inst) || - isa<GetElementPtrInst>(Inst) || isa<CmpInst>(Inst) || - isa<SelectInst>(Inst) || isa<ExtractElementInst>(Inst) || - isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) || - isa<ExtractValueInst>(Inst) || isa<InsertValueInst>(Inst); + return isa<CastInst>(Inst) || isa<UnaryOperator>(Inst) || + isa<BinaryOperator>(Inst) || isa<GetElementPtrInst>(Inst) || + isa<CmpInst>(Inst) || isa<SelectInst>(Inst) || + isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || + isa<ShuffleVectorInst>(Inst) || isa<ExtractValueInst>(Inst) || + isa<InsertValueInst>(Inst); } }; @@ -240,7 +241,7 @@ static unsigned getHashValueImpl(SimpleValue Val) { assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) || isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || - isa<ShuffleVectorInst>(Inst)) && + isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst)) && "Invalid/unknown instruction"); // Mix in the opcode. 
@@ -526,7 +527,7 @@ public: const TargetTransformInfo &TTI, DominatorTree &DT, AssumptionCache &AC, MemorySSA *MSSA) : TLI(TLI), TTI(TTI), DT(DT), AC(AC), SQ(DL, &TLI, &DT, &AC), MSSA(MSSA), - MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {} + MSSAUpdater(std::make_unique<MemorySSAUpdater>(MSSA)) {} bool run(); @@ -651,7 +652,7 @@ private: bool isInvariantLoad() const { if (auto *LI = dyn_cast<LoadInst>(Inst)) - return LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr; + return LI->hasMetadata(LLVMContext::MD_invariant_load); return false; } @@ -790,7 +791,7 @@ bool EarlyCSE::isOperatingOnInvariantMemAt(Instruction *I, unsigned GenAt) { // A location loaded from with an invariant_load is assumed to *never* change // within the visible scope of the compilation. if (auto *LI = dyn_cast<LoadInst>(I)) - if (LI->getMetadata(LLVMContext::MD_invariant_load)) + if (LI->hasMetadata(LLVMContext::MD_invariant_load)) return true; auto MemLocOpt = MemoryLocation::getOrNone(I); @@ -1359,7 +1360,7 @@ public: if (skipFunction(F)) return false; - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); @@ -1381,6 +1382,7 @@ public: AU.addPreserved<MemorySSAWrapperPass>(); } AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); AU.setPreservesCFG(); } }; diff --git a/lib/Transforms/Scalar/FlattenCFGPass.cpp b/lib/Transforms/Scalar/FlattenCFGPass.cpp index 31670b1464e4..e6abf1ceb026 100644 --- a/lib/Transforms/Scalar/FlattenCFGPass.cpp +++ b/lib/Transforms/Scalar/FlattenCFGPass.cpp @@ -11,10 +11,12 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/AliasAnalysis.h" -#include 
"llvm/Transforms/Utils/Local.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" + using namespace llvm; #define DEBUG_TYPE "flattencfg" @@ -52,15 +54,23 @@ FunctionPass *llvm::createFlattenCFGPass() { return new FlattenCFGPass(); } static bool iterativelyFlattenCFG(Function &F, AliasAnalysis *AA) { bool Changed = false; bool LocalChange = true; + + // Use block handles instead of iterating over function blocks directly + // to avoid using iterators invalidated by erasing blocks. + std::vector<WeakVH> Blocks; + Blocks.reserve(F.size()); + for (auto &BB : F) + Blocks.push_back(&BB); + while (LocalChange) { LocalChange = false; - // Loop over all of the basic blocks and remove them if they are unneeded... - // - for (Function::iterator BBIt = F.begin(); BBIt != F.end();) { - if (FlattenCFG(&*BBIt++, AA)) { - LocalChange = true; - } + // Loop over all of the basic blocks and try to flatten them. + for (WeakVH &BlockHandle : Blocks) { + // Skip blocks erased by FlattenCFG. + if (auto *BB = cast_or_null<BasicBlock>(BlockHandle)) + if (FlattenCFG(BB, AA)) + LocalChange = true; } Changed |= LocalChange; } diff --git a/lib/Transforms/Scalar/Float2Int.cpp b/lib/Transforms/Scalar/Float2Int.cpp index 4f83e869b303..4d2eac0451df 100644 --- a/lib/Transforms/Scalar/Float2Int.cpp +++ b/lib/Transforms/Scalar/Float2Int.cpp @@ -60,11 +60,13 @@ namespace { if (skipFunction(F)) return false; - return Impl.runImpl(F); + const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + return Impl.runImpl(F, DT); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } @@ -116,21 +118,29 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) { // Find the roots - instructions that convert from the FP domain to // integer domain. 
-void Float2IntPass::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) { - for (auto &I : instructions(F)) { - if (isa<VectorType>(I.getType())) +void Float2IntPass::findRoots(Function &F, const DominatorTree &DT, + SmallPtrSet<Instruction*,8> &Roots) { + for (BasicBlock &BB : F) { + // Unreachable code can take on strange forms that we are not prepared to + // handle. For example, an instruction may have itself as an operand. + if (!DT.isReachableFromEntry(&BB)) continue; - switch (I.getOpcode()) { - default: break; - case Instruction::FPToUI: - case Instruction::FPToSI: - Roots.insert(&I); - break; - case Instruction::FCmp: - if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) != - CmpInst::BAD_ICMP_PREDICATE) + + for (Instruction &I : BB) { + if (isa<VectorType>(I.getType())) + continue; + switch (I.getOpcode()) { + default: break; + case Instruction::FPToUI: + case Instruction::FPToSI: Roots.insert(&I); - break; + break; + case Instruction::FCmp: + if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) != + CmpInst::BAD_ICMP_PREDICATE) + Roots.insert(&I); + break; + } } } } @@ -503,7 +513,7 @@ void Float2IntPass::cleanup() { I.first->eraseFromParent(); } -bool Float2IntPass::runImpl(Function &F) { +bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) { LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n"); // Clear out all state. 
ECs = EquivalenceClasses<Instruction*>(); @@ -513,7 +523,7 @@ bool Float2IntPass::runImpl(Function &F) { Ctx = &F.getParent()->getContext(); - findRoots(F, Roots); + findRoots(F, DT, Roots); walkBackwards(Roots); walkForwards(); @@ -527,8 +537,9 @@ bool Float2IntPass::runImpl(Function &F) { namespace llvm { FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); } -PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &) { - if (!runImpl(F)) +PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) { + const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + if (!runImpl(F, DT)) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp index 1a02e9d33f49..743353eaea22 100644 --- a/lib/Transforms/Scalar/GVN.cpp +++ b/lib/Transforms/Scalar/GVN.cpp @@ -70,6 +70,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" @@ -626,6 +627,8 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<DominatorTreeAnalysis>(); PA.preserve<GlobalsAA>(); PA.preserve<TargetLibraryAnalysis>(); + if (LI) + PA.preserve<LoopAnalysis>(); return PA; } @@ -1161,15 +1164,30 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // Do PHI translation to get its value in the predecessor if necessary. The // returned pointer (if non-null) is guaranteed to dominate UnavailablePred. + // We do the translation for each edge we skipped by going from LI's block + // to LoadBB, otherwise we might miss pieces needing translation. 
// If all preds have a single successor, then we know it is safe to insert // the load on the pred (?!?), so we can insert code to materialize the // pointer if it is not available. - PHITransAddr Address(LI->getPointerOperand(), DL, AC); - Value *LoadPtr = nullptr; - LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, - *DT, NewInsts); + Value *LoadPtr = LI->getPointerOperand(); + BasicBlock *Cur = LI->getParent(); + while (Cur != LoadBB) { + PHITransAddr Address(LoadPtr, DL, AC); + LoadPtr = Address.PHITranslateWithInsertion( + Cur, Cur->getSinglePredecessor(), *DT, NewInsts); + if (!LoadPtr) { + CanDoPRE = false; + break; + } + Cur = Cur->getSinglePredecessor(); + } + if (LoadPtr) { + PHITransAddr Address(LoadPtr, DL, AC); + LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, *DT, + NewInsts); + } // If we couldn't find or insert a computation of this phi translated value, // we fail PRE. if (!LoadPtr) { @@ -1184,8 +1202,12 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, if (!CanDoPRE) { while (!NewInsts.empty()) { - Instruction *I = NewInsts.pop_back_val(); - markInstructionForDeletion(I); + // Erase instructions generated by the failed PHI translation before + // trying to number them. PHI translation might insert instructions + // in basic blocks other than the current one, and we delete them + // directly, as markInstructionForDeletion only allows removing from the + // current basic block. + NewInsts.pop_back_val()->eraseFromParent(); } // HINT: Don't revert the edge-splitting as following transformation may // also need to split these critical edges. 
@@ -1219,10 +1241,10 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, BasicBlock *UnavailablePred = PredLoad.first; Value *LoadPtr = PredLoad.second; - auto *NewLoad = - new LoadInst(LI->getType(), LoadPtr, LI->getName() + ".pre", - LI->isVolatile(), LI->getAlignment(), LI->getOrdering(), - LI->getSyncScopeID(), UnavailablePred->getTerminator()); + auto *NewLoad = new LoadInst( + LI->getType(), LoadPtr, LI->getName() + ".pre", LI->isVolatile(), + MaybeAlign(LI->getAlignment()), LI->getOrdering(), LI->getSyncScopeID(), + UnavailablePred->getTerminator()); NewLoad->setDebugLoc(LI->getDebugLoc()); // Transfer the old load's AA tags to the new load. @@ -1365,6 +1387,14 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks); } +static bool hasUsersIn(Value *V, BasicBlock *BB) { + for (User *U : V->users()) + if (isa<Instruction>(U) && + cast<Instruction>(U)->getParent() == BB) + return true; + return false; +} + bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { assert(IntrinsicI->getIntrinsicID() == Intrinsic::assume && "This function can only be called with llvm.assume intrinsic"); @@ -1403,12 +1433,23 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { // We can replace assume value with true, which covers cases like this: // call void @llvm.assume(i1 %cmp) // br i1 %cmp, label %bb1, label %bb2 ; will change %cmp to true - ReplaceWithConstMap[V] = True; - - // If one of *cmp *eq operand is const, adding it to map will cover this: + ReplaceOperandsWithMap[V] = True; + + // If we find an equality fact, canonicalize all dominated uses in this block + // to one of the two values. We heuristically choice the "oldest" of the + // two where age is determined by value number. (Note that propagateEquality + // above handles the cross block case.) 
+ // + // Key case to cover are: + // 1) // %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen // call void @llvm.assume(i1 %cmp) // ret float %0 ; will change it to ret float 3.000000e+00 + // 2) + // %load = load float, float* %addr + // %cmp = fcmp oeq float %load, %0 + // call void @llvm.assume(i1 %cmp) + // ret float %load ; will change it to ret float %0 if (auto *CmpI = dyn_cast<CmpInst>(V)) { if (CmpI->getPredicate() == CmpInst::Predicate::ICMP_EQ || CmpI->getPredicate() == CmpInst::Predicate::FCMP_OEQ || @@ -1416,13 +1457,50 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { CmpI->getFastMathFlags().noNaNs())) { Value *CmpLHS = CmpI->getOperand(0); Value *CmpRHS = CmpI->getOperand(1); - if (isa<Constant>(CmpLHS)) + // Heuristically pick the better replacement -- the choice of heuristic + // isn't terribly important here, but the fact we canonicalize on some + // replacement is for exposing other simplifications. + // TODO: pull this out as a helper function and reuse w/existing + // (slightly different) logic. + if (isa<Constant>(CmpLHS) && !isa<Constant>(CmpRHS)) std::swap(CmpLHS, CmpRHS); - auto *RHSConst = dyn_cast<Constant>(CmpRHS); + if (!isa<Instruction>(CmpLHS) && isa<Instruction>(CmpRHS)) + std::swap(CmpLHS, CmpRHS); + if ((isa<Argument>(CmpLHS) && isa<Argument>(CmpRHS)) || + (isa<Instruction>(CmpLHS) && isa<Instruction>(CmpRHS))) { + // Move the 'oldest' value to the right-hand side, using the value + // number as a proxy for age. + uint32_t LVN = VN.lookupOrAdd(CmpLHS); + uint32_t RVN = VN.lookupOrAdd(CmpRHS); + if (LVN < RVN) + std::swap(CmpLHS, CmpRHS); + } - // If only one operand is constant. - if (RHSConst != nullptr && !isa<Constant>(CmpLHS)) - ReplaceWithConstMap[CmpLHS] = RHSConst; + // Handle degenerate case where we either haven't pruned a dead path or a + // removed a trivial assume yet. 
+ if (isa<Constant>(CmpLHS) && isa<Constant>(CmpRHS)) + return Changed; + + // +0.0 and -0.0 compare equal, but do not imply equivalence. Unless we + // can prove equivalence, bail. + if (CmpRHS->getType()->isFloatTy() && + (!isa<ConstantFP>(CmpRHS) || cast<ConstantFP>(CmpRHS)->isZero())) + return Changed; + + LLVM_DEBUG(dbgs() << "Replacing dominated uses of " + << *CmpLHS << " with " + << *CmpRHS << " in block " + << IntrinsicI->getParent()->getName() << "\n"); + + + // Setup the replacement map - this handles uses within the same block + if (hasUsersIn(CmpLHS, IntrinsicI->getParent())) + ReplaceOperandsWithMap[CmpLHS] = CmpRHS; + + // NOTE: The non-block local cases are handled by the call to + // propagateEquality above; this block is just about handling the block + // local cases. TODO: There's a bunch of logic in propagateEqualiy which + // isn't duplicated for the block local case, can we share it somehow? } } return Changed; @@ -1522,6 +1600,41 @@ uint32_t GVN::ValueTable::phiTranslate(const BasicBlock *Pred, return NewNum; } +// Return true if the value number \p Num and NewNum have equal value. +// Return false if the result is unknown. +bool GVN::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum, + const BasicBlock *Pred, + const BasicBlock *PhiBlock, GVN &Gvn) { + CallInst *Call = nullptr; + LeaderTableEntry *Vals = &Gvn.LeaderTable[Num]; + while (Vals) { + Call = dyn_cast<CallInst>(Vals->Val); + if (Call && Call->getParent() == PhiBlock) + break; + Vals = Vals->Next; + } + + if (AA->doesNotAccessMemory(Call)) + return true; + + if (!MD || !AA->onlyReadsMemory(Call)) + return false; + + MemDepResult local_dep = MD->getDependency(Call); + if (!local_dep.isNonLocal()) + return false; + + const MemoryDependenceResults::NonLocalDepInfo &deps = + MD->getNonLocalCallDependency(Call); + + // Check to see if the Call has no function local clobber. 
+ for (unsigned i = 0; i < deps.size(); i++) { + if (deps[i].getResult().isNonFuncLocal()) + return true; + } + return false; +} + /// Translate value number \p Num using phis, so that it has the values of /// the phis in BB. uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred, @@ -1568,8 +1681,11 @@ uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred, } } - if (uint32_t NewNum = expressionNumbering[Exp]) + if (uint32_t NewNum = expressionNumbering[Exp]) { + if (Exp.opcode == Instruction::Call && NewNum != Num) + return areCallValsEqual(Num, NewNum, Pred, PhiBlock, Gvn) ? NewNum : Num; return NewNum; + } return Num; } @@ -1637,16 +1753,12 @@ void GVN::assignBlockRPONumber(Function &F) { InvalidBlockRPONumbers = false; } -// Tries to replace instruction with const, using information from -// ReplaceWithConstMap. -bool GVN::replaceOperandsWithConsts(Instruction *Instr) const { +bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const { bool Changed = false; for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) { - Value *Operand = Instr->getOperand(OpNum); - auto it = ReplaceWithConstMap.find(Operand); - if (it != ReplaceWithConstMap.end()) { - assert(!isa<Constant>(Operand) && - "Replacing constants with constants is invalid"); + Value *Operand = Instr->getOperand(OpNum); + auto it = ReplaceOperandsWithMap.find(Operand); + if (it != ReplaceOperandsWithMap.end()) { LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with " << *it->second << " in instruction " << *Instr << '\n'); Instr->setOperand(OpNum, it->second); @@ -1976,6 +2088,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, MD = RunMD; ImplicitControlFlowTracking ImplicitCFT(DT); ICF = &ImplicitCFT; + this->LI = LI; VN.setMemDep(MD); ORE = RunORE; InvalidBlockRPONumbers = true; @@ -2037,13 +2150,13 @@ bool GVN::processBlock(BasicBlock *BB) { return false; // Clearing map before every BB because it can be used only for 
single BB. - ReplaceWithConstMap.clear(); + ReplaceOperandsWithMap.clear(); bool ChangedFunction = false; for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { - if (!ReplaceWithConstMap.empty()) - ChangedFunction |= replaceOperandsWithConsts(&*BI); + if (!ReplaceOperandsWithMap.empty()) + ChangedFunction |= replaceOperandsForInBlockEquality(&*BI); ChangedFunction |= processInstruction(&*BI); if (InstrsToErase.empty()) { @@ -2335,7 +2448,7 @@ bool GVN::performPRE(Function &F) { /// the block inserted to the critical edge. BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) { BasicBlock *BB = - SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT)); + SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT, LI)); if (MD) MD->invalidateCachedPredecessors(); InvalidBlockRPONumbers = true; @@ -2350,7 +2463,7 @@ bool GVN::splitCriticalEdges() { do { std::pair<Instruction *, unsigned> Edge = toSplit.pop_back_val(); SplitCriticalEdge(Edge.first, Edge.second, - CriticalEdgeSplittingOptions(DT)); + CriticalEdgeSplittingOptions(DT, LI)); } while (!toSplit.empty()); if (MD) MD->invalidateCachedPredecessors(); InvalidBlockRPONumbers = true; @@ -2456,18 +2569,26 @@ void GVN::addDeadBlock(BasicBlock *BB) { if (DeadBlocks.count(B)) continue; + // First, split the critical edges. This might also create additional blocks + // to preserve LoopSimplify form and adjust edges accordingly. 
SmallVector<BasicBlock *, 4> Preds(pred_begin(B), pred_end(B)); for (BasicBlock *P : Preds) { if (!DeadBlocks.count(P)) continue; - if (isCriticalEdge(P->getTerminator(), GetSuccessorNumber(P, B))) { + if (llvm::any_of(successors(P), + [B](BasicBlock *Succ) { return Succ == B; }) && + isCriticalEdge(P->getTerminator(), B)) { if (BasicBlock *S = splitCriticalEdges(P, B)) DeadBlocks.insert(P = S); } + } - for (BasicBlock::iterator II = B->begin(); isa<PHINode>(II); ++II) { - PHINode &Phi = cast<PHINode>(*II); + // Now undef the incoming values from the dead predecessors. + for (BasicBlock *P : predecessors(B)) { + if (!DeadBlocks.count(P)) + continue; + for (PHINode &Phi : B->phis()) { Phi.setIncomingValueForBlock(P, UndefValue::get(Phi.getType())); if (MD) MD->invalidateCachedPointerInfo(&Phi); @@ -2544,10 +2665,11 @@ public: return Impl.runImpl( F, getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), getAnalysis<AAResultsWrapperPass>().getAAResults(), - NoMemDepAnalysis ? nullptr - : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(), + NoMemDepAnalysis + ? nullptr + : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(), LIWP ? 
&LIWP->getLoopInfo() : nullptr, &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE()); } @@ -2556,6 +2678,7 @@ public: AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); if (!NoMemDepAnalysis) AU.addRequired<MemoryDependenceWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); @@ -2563,6 +2686,8 @@ public: AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addPreservedID(LoopSimplifyID); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp index 7614599653c4..c87e41484b13 100644 --- a/lib/Transforms/Scalar/GVNHoist.cpp +++ b/lib/Transforms/Scalar/GVNHoist.cpp @@ -257,7 +257,7 @@ public: GVNHoist(DominatorTree *DT, PostDominatorTree *PDT, AliasAnalysis *AA, MemoryDependenceResults *MD, MemorySSA *MSSA) : DT(DT), PDT(PDT), AA(AA), MD(MD), MSSA(MSSA), - MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {} + MSSAUpdater(std::make_unique<MemorySSAUpdater>(MSSA)) {} bool run(Function &F) { NumFuncArgs = F.arg_size(); @@ -539,7 +539,7 @@ private: // Check for unsafe hoistings due to side effects. 
if (K == InsKind::Store) { - if (hasEHOrLoadsOnPath(NewPt, dyn_cast<MemoryDef>(U), NBBsOnAllPaths)) + if (hasEHOrLoadsOnPath(NewPt, cast<MemoryDef>(U), NBBsOnAllPaths)) return false; } else if (hasEHOnPath(NewBB, OldBB, NBBsOnAllPaths)) return false; @@ -889,19 +889,18 @@ private: void updateAlignment(Instruction *I, Instruction *Repl) { if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) { - ReplacementLoad->setAlignment( - std::min(ReplacementLoad->getAlignment(), - cast<LoadInst>(I)->getAlignment())); + ReplacementLoad->setAlignment(MaybeAlign(std::min( + ReplacementLoad->getAlignment(), cast<LoadInst>(I)->getAlignment()))); ++NumLoadsRemoved; } else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) { ReplacementStore->setAlignment( - std::min(ReplacementStore->getAlignment(), - cast<StoreInst>(I)->getAlignment())); + MaybeAlign(std::min(ReplacementStore->getAlignment(), + cast<StoreInst>(I)->getAlignment()))); ++NumStoresRemoved; } else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) { ReplacementAlloca->setAlignment( - std::max(ReplacementAlloca->getAlignment(), - cast<AllocaInst>(I)->getAlignment())); + MaybeAlign(std::max(ReplacementAlloca->getAlignment(), + cast<AllocaInst>(I)->getAlignment()))); } else if (isa<CallInst>(Repl)) { ++NumCallsRemoved; } diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp index e14f44bb7069..2697d7809568 100644 --- a/lib/Transforms/Scalar/GuardWidening.cpp +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -591,7 +591,7 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, else Result = RC.getCheckInst(); } - + assert(Result && "Failed to find result value"); Result->setName("wide.chk"); } return true; diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index f9fc698a4a9b..5519a00c12c9 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -124,6 +124,11 @@ 
static cl::opt<bool> DisableLFTR("disable-lftr", cl::Hidden, cl::init(false), cl::desc("Disable Linear Function Test Replace optimization")); +static cl::opt<bool> +LoopPredication("indvars-predicate-loops", cl::Hidden, cl::init(false), + cl::desc("Predicate conditions in read only loops")); + + namespace { struct RewritePhi; @@ -144,7 +149,11 @@ class IndVarSimplify { bool rewriteNonIntegerIVs(Loop *L); bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); - bool optimizeLoopExits(Loop *L); + /// Try to eliminate loop exits based on analyzeable exit counts + bool optimizeLoopExits(Loop *L, SCEVExpander &Rewriter); + /// Try to form loop invariant tests for loop exits by changing how many + /// iterations of the loop run when that is unobservable. + bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter); bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet); bool rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); @@ -628,12 +637,30 @@ bool IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { // Okay, this instruction has a user outside of the current loop // and varies predictably *inside* the loop. Evaluate the value it - // contains when the loop exits, if possible. + // contains when the loop exits, if possible. We prefer to start with + // expressions which are true for all exits (so as to maximize + // expression reuse by the SCEVExpander), but resort to per-exit + // evaluation if that fails. const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop()); - if (!SE->isLoopInvariant(ExitValue, L) || - !isSafeToExpand(ExitValue, *SE)) - continue; - + if (isa<SCEVCouldNotCompute>(ExitValue) || + !SE->isLoopInvariant(ExitValue, L) || + !isSafeToExpand(ExitValue, *SE)) { + // TODO: This should probably be sunk into SCEV in some way; maybe a + // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for + // most SCEV expressions and other recurrence types (e.g. shift + // recurrences). 
Is there existing code we can reuse? + const SCEV *ExitCount = SE->getExitCount(L, PN->getIncomingBlock(i)); + if (isa<SCEVCouldNotCompute>(ExitCount)) + continue; + if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Inst))) + if (AddRec->getLoop() == L) + ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE); + if (isa<SCEVCouldNotCompute>(ExitValue) || + !SE->isLoopInvariant(ExitValue, L) || + !isSafeToExpand(ExitValue, *SE)) + continue; + } + // Computing the value outside of the loop brings no benefit if it is // definitely used inside the loop in a way which can not be optimized // away. Avoid doing so unless we know we have a value which computes @@ -804,7 +831,7 @@ bool IndVarSimplify::canLoopBeDeleted( L->getExitingBlocks(ExitingBlocks); SmallVector<BasicBlock *, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); - if (ExitBlocks.size() > 1 || ExitingBlocks.size() > 1) + if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1) return false; BasicBlock *ExitBlock = ExitBlocks[0]; @@ -1654,6 +1681,10 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { return nullptr; } + // if we reached this point then we are going to replace + // DU.NarrowUse with WideUse. Reattach DbgValue then. + replaceAllDbgUsesWith(*DU.NarrowUse, *WideUse, *WideUse, *DT); + ExtendKindMap[DU.NarrowUse] = WideAddRec.second; // Returning WideUse pushes it on the worklist. return WideUse; @@ -1779,14 +1810,9 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { DeadInsts.emplace_back(DU.NarrowDef); } - // Attach any debug information to the new PHI. Since OrigPhi and WidePHI - // evaluate the same recurrence, we can just copy the debug info over. - SmallVector<DbgValueInst *, 1> DbgValues; - llvm::findDbgValues(DbgValues, OrigPhi); - auto *MDPhi = MetadataAsValue::get(WidePhi->getContext(), - ValueAsMetadata::get(WidePhi)); - for (auto &DbgValue : DbgValues) - DbgValue->setOperand(0, MDPhi); + // Attach any debug information to the new PHI. 
+ replaceAllDbgUsesWith(*OrigPhi, *WidePhi, *WidePhi, *DT); + return WidePhi; } @@ -1817,8 +1843,8 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef, auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS)); auto CmpConstrainedLHSRange = ConstantRange::makeAllowedICmpRegion(P, CmpRHSRange); - auto NarrowDefRange = - CmpConstrainedLHSRange.addWithNoSignedWrap(*NarrowDefRHS); + auto NarrowDefRange = CmpConstrainedLHSRange.addWithNoWrap( + *NarrowDefRHS, OverflowingBinaryOperator::NoSignedWrap); updatePostIncRangeInfo(NarrowDef, NarrowUser, NarrowDefRange); }; @@ -2242,8 +2268,8 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy()) continue; - const auto *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Phi)); - + const auto *AR = cast<SCEVAddRecExpr>(SE->getSCEV(Phi)); + // AR may be a pointer type, while BECount is an integer type. // AR may be wider than BECount. With eq/ne tests overflow is immaterial. // AR may not be a narrower type, or we may never exit. @@ -2624,74 +2650,125 @@ bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { return MadeAnyChanges; } -bool IndVarSimplify::optimizeLoopExits(Loop *L) { +/// Return a symbolic upper bound for the backedge taken count of the loop. +/// This is more general than getConstantMaxBackedgeTakenCount as it returns +/// an arbitrary expression as opposed to only constants. +/// TODO: Move into the ScalarEvolution class. +static const SCEV* getMaxBackedgeTakenCount(ScalarEvolution &SE, + DominatorTree &DT, Loop *L) { SmallVector<BasicBlock*, 16> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); // Form an expression for the maximum exit count possible for this loop. We // merge the max and exact information to approximate a version of - // getMaxBackedgeTakenInfo which isn't restricted to just constants. - // TODO: factor this out as a version of getMaxBackedgeTakenCount which - // isn't guaranteed to return a constant. 
+ // getConstantMaxBackedgeTakenCount which isn't restricted to just constants. SmallVector<const SCEV*, 4> ExitCounts; - const SCEV *MaxConstEC = SE->getMaxBackedgeTakenCount(L); + const SCEV *MaxConstEC = SE.getConstantMaxBackedgeTakenCount(L); if (!isa<SCEVCouldNotCompute>(MaxConstEC)) ExitCounts.push_back(MaxConstEC); for (BasicBlock *ExitingBB : ExitingBlocks) { - const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); + const SCEV *ExitCount = SE.getExitCount(L, ExitingBB); if (!isa<SCEVCouldNotCompute>(ExitCount)) { - assert(DT->dominates(ExitingBB, L->getLoopLatch()) && + assert(DT.dominates(ExitingBB, L->getLoopLatch()) && "We should only have known counts for exiting blocks that " "dominate latch!"); ExitCounts.push_back(ExitCount); } } if (ExitCounts.empty()) - return false; - const SCEV *MaxExitCount = SE->getUMinFromMismatchedTypes(ExitCounts); + return SE.getCouldNotCompute(); + return SE.getUMinFromMismatchedTypes(ExitCounts); +} - bool Changed = false; - for (BasicBlock *ExitingBB : ExitingBlocks) { +bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { + SmallVector<BasicBlock*, 16> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + // Remove all exits which aren't both rewriteable and analyzeable. + auto NewEnd = llvm::remove_if(ExitingBlocks, + [&](BasicBlock *ExitingBB) { // If our exitting block exits multiple loops, we can only rewrite the // innermost one. Otherwise, we're changing how many times the innermost // loop runs before it exits. if (LI->getLoopFor(ExitingBB) != L) - continue; + return true; // Can't rewrite non-branch yet. BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); if (!BI) - continue; + return true; // If already constant, nothing to do. 
if (isa<Constant>(BI->getCondition())) - continue; + return true; const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); if (isa<SCEVCouldNotCompute>(ExitCount)) - continue; + return true; + return false; + }); + ExitingBlocks.erase(NewEnd, ExitingBlocks.end()); + + if (ExitingBlocks.empty()) + return false; + + // Get a symbolic upper bound on the loop backedge taken count. + const SCEV *MaxExitCount = getMaxBackedgeTakenCount(*SE, *DT, L); + if (isa<SCEVCouldNotCompute>(MaxExitCount)) + return false; + + // Visit our exit blocks in order of dominance. We know from the fact that + // all exits (left) are analyzeable that the must be a total dominance order + // between them as each must dominate the latch. The visit order only + // matters for the provably equal case. + llvm::sort(ExitingBlocks, + [&](BasicBlock *A, BasicBlock *B) { + // std::sort sorts in ascending order, so we want the inverse of + // the normal dominance relation. + if (DT->properlyDominates(A, B)) return true; + if (DT->properlyDominates(B, A)) return false; + llvm_unreachable("expected total dominance order!"); + }); +#ifdef ASSERT + for (unsigned i = 1; i < ExitingBlocks.size(); i++) { + assert(DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i])); + } +#endif + + auto FoldExit = [&](BasicBlock *ExitingBB, bool IsTaken) { + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); + bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); + auto *OldCond = BI->getCondition(); + auto *NewCond = ConstantInt::get(OldCond->getType(), + IsTaken ? ExitIfTrue : !ExitIfTrue); + BI->setCondition(NewCond); + if (OldCond->use_empty()) + DeadInsts.push_back(OldCond); + }; + bool Changed = false; + SmallSet<const SCEV*, 8> DominatingExitCounts; + for (BasicBlock *ExitingBB : ExitingBlocks) { + const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); + assert(!isa<SCEVCouldNotCompute>(ExitCount) && "checked above"); + // If we know we'd exit on the first iteration, rewrite the exit to // reflect this. 
This does not imply the loop must exit through this // exit; there may be an earlier one taken on the first iteration. // TODO: Given we know the backedge can't be taken, we should go ahead // and break it. Or at least, kill all the header phis and simplify. if (ExitCount->isZero()) { - bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); - auto *OldCond = BI->getCondition(); - auto *NewCond = ExitIfTrue ? ConstantInt::getTrue(OldCond->getType()) : - ConstantInt::getFalse(OldCond->getType()); - BI->setCondition(NewCond); - if (OldCond->use_empty()) - DeadInsts.push_back(OldCond); + FoldExit(ExitingBB, true); Changed = true; continue; } - // If we end up with a pointer exit count, bail. + // If we end up with a pointer exit count, bail. Note that we can end up + // with a pointer exit count for one exiting block, and not for another in + // the same loop. if (!ExitCount->getType()->isIntegerTy() || !MaxExitCount->getType()->isIntegerTy()) - return false; + continue; Type *WiderType = SE->getWiderType(MaxExitCount->getType(), ExitCount->getType()); @@ -2700,35 +2777,198 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L) { assert(MaxExitCount->getType() == ExitCount->getType()); // Can we prove that some other exit must be taken strictly before this - // one? TODO: handle cases where ule is known, and equality is covered - // by a dominating exit + // one? if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT, MaxExitCount, ExitCount)) { - bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); - auto *OldCond = BI->getCondition(); - auto *NewCond = ExitIfTrue ? 
ConstantInt::getFalse(OldCond->getType()) : - ConstantInt::getTrue(OldCond->getType()); - BI->setCondition(NewCond); - if (OldCond->use_empty()) - DeadInsts.push_back(OldCond); + FoldExit(ExitingBB, false); Changed = true; continue; } - // TODO: If we can prove that the exiting iteration is equal to the exit - // count for this exit and that no previous exit oppurtunities exist within - // the loop, then we can discharge all other exits. (May fall out of - // previous TODO.) - - // TODO: If we can't prove any relation between our exit count and the - // loops exit count, but taking this exit doesn't require actually running - // the loop (i.e. no side effects, no computed values used in exit), then - // we can replace the exit test with a loop invariant test which exits on - // the first iteration. + // As we run, keep track of which exit counts we've encountered. If we + // find a duplicate, we've found an exit which would have exited on the + // exiting iteration, but (from the visit order) strictly follows another + // which does the same and is thus dead. + if (!DominatingExitCounts.insert(ExitCount).second) { + FoldExit(ExitingBB, false); + Changed = true; + continue; + } + + // TODO: There might be another oppurtunity to leverage SCEV's reasoning + // here. If we kept track of the min of dominanting exits so far, we could + // discharge exits with EC >= MDEC. This is less powerful than the existing + // transform (since later exits aren't considered), but potentially more + // powerful for any case where SCEV can prove a >=u b, but neither a == b + // or a >u b. Such a case is not currently known. } return Changed; } +bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { + SmallVector<BasicBlock*, 16> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + bool Changed = false; + + // Finally, see if we can rewrite our exit conditions into a loop invariant + // form. 
If we have a read-only loop, and we can tell that we must exit down + // a path which does not need any of the values computed within the loop, we + // can rewrite the loop to exit on the first iteration. Note that this + // doesn't either a) tell us the loop exits on the first iteration (unless + // *all* exits are predicateable) or b) tell us *which* exit might be taken. + // This transformation looks a lot like a restricted form of dead loop + // elimination, but restricted to read-only loops and without neccesssarily + // needing to kill the loop entirely. + if (!LoopPredication) + return Changed; + + if (!SE->hasLoopInvariantBackedgeTakenCount(L)) + return Changed; + + // Note: ExactBTC is the exact backedge taken count *iff* the loop exits + // through *explicit* control flow. We have to eliminate the possibility of + // implicit exits (see below) before we know it's truly exact. + const SCEV *ExactBTC = SE->getBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(ExactBTC) || + !SE->isLoopInvariant(ExactBTC, L) || + !isSafeToExpand(ExactBTC, *SE)) + return Changed; + + auto BadExit = [&](BasicBlock *ExitingBB) { + // If our exiting block exits multiple loops, we can only rewrite the + // innermost one. Otherwise, we're changing how many times the innermost + // loop runs before it exits. + if (LI->getLoopFor(ExitingBB) != L) + return true; + + // Can't rewrite non-branch yet. + BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); + if (!BI) + return true; + + // If already constant, nothing to do. + if (isa<Constant>(BI->getCondition())) + return true; + + // If the exit block has phis, we need to be able to compute the values + // within the loop which contains them. This assumes trivially lcssa phis + // have already been removed; TODO: generalize + BasicBlock *ExitBlock = + BI->getSuccessor(L->contains(BI->getSuccessor(0)) ? 
1 : 0); + if (!ExitBlock->phis().empty()) + return true; + + const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); + assert(!isa<SCEVCouldNotCompute>(ExactBTC) && "implied by having exact trip count"); + if (!SE->isLoopInvariant(ExitCount, L) || + !isSafeToExpand(ExitCount, *SE)) + return true; + + return false; + }; + + // If we have any exits which can't be predicated themselves, than we can't + // predicate any exit which isn't guaranteed to execute before it. Consider + // two exits (a) and (b) which would both exit on the same iteration. If we + // can predicate (b), but not (a), and (a) preceeds (b) along some path, then + // we could convert a loop from exiting through (a) to one exiting through + // (b). Note that this problem exists only for exits with the same exit + // count, and we could be more aggressive when exit counts are known inequal. + llvm::sort(ExitingBlocks, + [&](BasicBlock *A, BasicBlock *B) { + // std::sort sorts in ascending order, so we want the inverse of + // the normal dominance relation, plus a tie breaker for blocks + // unordered by dominance. + if (DT->properlyDominates(A, B)) return true; + if (DT->properlyDominates(B, A)) return false; + return A->getName() < B->getName(); + }); + // Check to see if our exit blocks are a total order (i.e. a linear chain of + // exits before the backedge). If they aren't, reasoning about reachability + // is complicated and we choose not to for now. + for (unsigned i = 1; i < ExitingBlocks.size(); i++) + if (!DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i])) + return Changed; + + // Given our sorted total order, we know that exit[j] must be evaluated + // after all exit[i] such j > i. + for (unsigned i = 0, e = ExitingBlocks.size(); i < e; i++) + if (BadExit(ExitingBlocks[i])) { + ExitingBlocks.resize(i); + break; + } + + if (ExitingBlocks.empty()) + return Changed; + + // We rely on not being able to reach an exiting block on a later iteration + // then it's statically compute exit count. 
The implementaton of + // getExitCount currently has this invariant, but assert it here so that + // breakage is obvious if this ever changes.. + assert(llvm::all_of(ExitingBlocks, [&](BasicBlock *ExitingBB) { + return DT->dominates(ExitingBB, L->getLoopLatch()); + })); + + // At this point, ExitingBlocks consists of only those blocks which are + // predicatable. Given that, we know we have at least one exit we can + // predicate if the loop is doesn't have side effects and doesn't have any + // implicit exits (because then our exact BTC isn't actually exact). + // @Reviewers - As structured, this is O(I^2) for loop nests. Any + // suggestions on how to improve this? I can obviously bail out for outer + // loops, but that seems less than ideal. MemorySSA can find memory writes, + // is that enough for *all* side effects? + for (BasicBlock *BB : L->blocks()) + for (auto &I : *BB) + // TODO:isGuaranteedToTransfer + if (I.mayHaveSideEffects() || I.mayThrow()) + return Changed; + + // Finally, do the actual predication for all predicatable blocks. A couple + // of notes here: + // 1) We don't bother to constant fold dominated exits with identical exit + // counts; that's simply a form of CSE/equality propagation and we leave + // it for dedicated passes. + // 2) We insert the comparison at the branch. Hoisting introduces additional + // legality constraints and we leave that to dedicated logic. We want to + // predicate even if we can't insert a loop invariant expression as + // peeling or unrolling will likely reduce the cost of the otherwise loop + // varying check. 
+ Rewriter.setInsertPoint(L->getLoopPreheader()->getTerminator()); + IRBuilder<> B(L->getLoopPreheader()->getTerminator()); + Value *ExactBTCV = nullptr; //lazy generated if needed + for (BasicBlock *ExitingBB : ExitingBlocks) { + const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); + + auto *BI = cast<BranchInst>(ExitingBB->getTerminator()); + Value *NewCond; + if (ExitCount == ExactBTC) { + NewCond = L->contains(BI->getSuccessor(0)) ? + B.getFalse() : B.getTrue(); + } else { + Value *ECV = Rewriter.expandCodeFor(ExitCount); + if (!ExactBTCV) + ExactBTCV = Rewriter.expandCodeFor(ExactBTC); + Value *RHS = ExactBTCV; + if (ECV->getType() != RHS->getType()) { + Type *WiderTy = SE->getWiderType(ECV->getType(), RHS->getType()); + ECV = B.CreateZExt(ECV, WiderTy); + RHS = B.CreateZExt(RHS, WiderTy); + } + auto Pred = L->contains(BI->getSuccessor(0)) ? + ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; + NewCond = B.CreateICmp(Pred, ECV, RHS); + } + Value *OldCond = BI->getCondition(); + BI->setCondition(NewCond); + if (OldCond->use_empty()) + DeadInsts.push_back(OldCond); + Changed = true; + } + + return Changed; +} + //===----------------------------------------------------------------------===// // IndVarSimplify driver. Manage several subpasses of IV simplification. //===----------------------------------------------------------------------===// @@ -2755,7 +2995,10 @@ bool IndVarSimplify::run(Loop *L) { // transform them to use integer recurrences. Changed |= rewriteNonIntegerIVs(L); +#ifndef NDEBUG + // Used below for a consistency check only const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); +#endif // Create a rewriter object which we'll use to transform the code with. SCEVExpander Rewriter(*SE, DL, "indvars"); @@ -2772,20 +3015,22 @@ bool IndVarSimplify::run(Loop *L) { Rewriter.disableCanonicalMode(); Changed |= simplifyAndExtend(L, Rewriter, LI); - // Check to see if this loop has a computable loop-invariant execution count. 
- // If so, this means that we can compute the final value of any expressions + // Check to see if we can compute the final value of any expressions // that are recurrent in the loop, and substitute the exit values from the - // loop into any instructions outside of the loop that use the final values of - // the current expressions. - // - if (ReplaceExitValue != NeverRepl && - !isa<SCEVCouldNotCompute>(BackedgeTakenCount)) + // loop into any instructions outside of the loop that use the final values + // of the current expressions. + if (ReplaceExitValue != NeverRepl) Changed |= rewriteLoopExitValues(L, Rewriter); // Eliminate redundant IV cycles. NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); - Changed |= optimizeLoopExits(L); + // Try to eliminate loop exits based on analyzeable exit counts + Changed |= optimizeLoopExits(L, Rewriter); + + // Try to form loop invariant tests for loop exits by changing how many + // iterations of the loop run when that is unobservable. + Changed |= predicateLoopExits(L, Rewriter); // If we have a trip count expression, rewrite the loop's exit condition // using it. @@ -2825,7 +3070,7 @@ bool IndVarSimplify::run(Loop *L) { // that our definition of "high cost" is not exactly principled. if (Rewriter.isHighCostExpansion(ExitCount, L)) continue; - + // Check preconditions for proper SCEVExpander operation. SCEV does not // express SCEVExpander's dependencies, such as LoopSimplify. Instead // any pass that uses the SCEVExpander must do it. This does not work @@ -2924,7 +3169,7 @@ struct IndVarSimplifyLegacyPass : public LoopPass { auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - auto *TLI = TLIP ? &TLIP->getTLI() : nullptr; + auto *TLI = TLIP ? 
&TLIP->getTLI(*L->getHeader()->getParent()) : nullptr; auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp index 5f0e2001c73d..e7e73a132fbe 100644 --- a/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -141,6 +141,8 @@ using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>; /// InferAddressSpaces class InferAddressSpaces : public FunctionPass { + const TargetTransformInfo *TTI; + /// Target specific address space which uses of should be replaced if /// possible. unsigned FlatAddrSpace; @@ -264,17 +266,6 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, Module *M = II->getParent()->getParent()->getParent(); switch (II->getIntrinsicID()) { - case Intrinsic::amdgcn_atomic_inc: - case Intrinsic::amdgcn_atomic_dec: - case Intrinsic::amdgcn_ds_fadd: - case Intrinsic::amdgcn_ds_fmin: - case Intrinsic::amdgcn_ds_fmax: { - const ConstantInt *IsVolatile = dyn_cast<ConstantInt>(II->getArgOperand(4)); - if (!IsVolatile || !IsVolatile->isZero()) - return false; - - LLVM_FALLTHROUGH; - } case Intrinsic::objectsize: { Type *DestTy = II->getType(); Type *SrcTy = NewV->getType(); @@ -285,25 +276,27 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, return true; } default: - return false; + return TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); } } -// TODO: Move logic to TTI? 
void InferAddressSpaces::collectRewritableIntrinsicOperands( IntrinsicInst *II, std::vector<std::pair<Value *, bool>> &PostorderStack, DenseSet<Value *> &Visited) const { - switch (II->getIntrinsicID()) { + auto IID = II->getIntrinsicID(); + switch (IID) { case Intrinsic::objectsize: - case Intrinsic::amdgcn_atomic_inc: - case Intrinsic::amdgcn_atomic_dec: - case Intrinsic::amdgcn_ds_fadd: - case Intrinsic::amdgcn_ds_fmin: - case Intrinsic::amdgcn_ds_fmax: appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), PostorderStack, Visited); break; default: + SmallVector<int, 2> OpIndexes; + if (TTI->collectFlatAddressOperands(OpIndexes, IID)) { + for (int Idx : OpIndexes) { + appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(Idx), + PostorderStack, Visited); + } + } break; } } @@ -631,11 +624,10 @@ bool InferAddressSpaces::runOnFunction(Function &F) { if (skipFunction(F)) return false; - const TargetTransformInfo &TTI = - getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); if (FlatAddrSpace == UninitializedAddressSpace) { - FlatAddrSpace = TTI.getFlatAddressSpace(); + FlatAddrSpace = TTI->getFlatAddressSpace(); if (FlatAddrSpace == UninitializedAddressSpace) return false; } @@ -650,7 +642,7 @@ bool InferAddressSpaces::runOnFunction(Function &F) { // Changes the address spaces of the flat address expressions who are inferred // to point to a specific address space. 
- return rewriteWithNewAddressSpaces(TTI, Postorder, InferredAddrSpace, &F); + return rewriteWithNewAddressSpaces(*TTI, Postorder, InferredAddrSpace, &F); } // Constants need to be tracked through RAUW to handle cases with nested diff --git a/lib/Transforms/Scalar/InstSimplifyPass.cpp b/lib/Transforms/Scalar/InstSimplifyPass.cpp index 6616364ab203..ec28f790f252 100644 --- a/lib/Transforms/Scalar/InstSimplifyPass.cpp +++ b/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -33,37 +33,39 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ, bool Changed = false; do { - for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { - // Here be subtlety: the iterator must be incremented before the loop - // body (not sure why), so a range-for loop won't work here. - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { - Instruction *I = &*BI++; - // The first time through the loop ToSimplify is empty and we try to - // simplify all instructions. On later iterations ToSimplify is not + for (BasicBlock &BB : F) { + // Unreachable code can take on strange forms that we are not prepared to + // handle. For example, an instruction may have itself as an operand. + if (!SQ.DT->isReachableFromEntry(&BB)) + continue; + + SmallVector<Instruction *, 8> DeadInstsInBB; + for (Instruction &I : BB) { + // The first time through the loop, ToSimplify is empty and we try to + // simplify all instructions. On later iterations, ToSimplify is not // empty and we only bother simplifying instructions that are in it. - if (!ToSimplify->empty() && !ToSimplify->count(I)) + if (!ToSimplify->empty() && !ToSimplify->count(&I)) continue; - // Don't waste time simplifying unused instructions. - if (!I->use_empty()) { - if (Value *V = SimplifyInstruction(I, SQ, ORE)) { + // Don't waste time simplifying dead/unused instructions. 
+ if (isInstructionTriviallyDead(&I)) { + DeadInstsInBB.push_back(&I); + Changed = true; + } else if (!I.use_empty()) { + if (Value *V = SimplifyInstruction(&I, SQ, ORE)) { // Mark all uses for resimplification next time round the loop. - for (User *U : I->users()) + for (User *U : I.users()) Next->insert(cast<Instruction>(U)); - I->replaceAllUsesWith(V); + I.replaceAllUsesWith(V); ++NumSimplified; Changed = true; + // A call can get simplified, but it may not be trivially dead. + if (isInstructionTriviallyDead(&I)) + DeadInstsInBB.push_back(&I); } } - if (RecursivelyDeleteTriviallyDeadInstructions(I, SQ.TLI)) { - // RecursivelyDeleteTriviallyDeadInstruction can remove more than one - // instruction, so simply incrementing the iterator does not work. - // When instructions get deleted re-iterate instead. - BI = BB->begin(); - BE = BB->end(); - Changed = true; - } } + RecursivelyDeleteTriviallyDeadInstructions(DeadInstsInBB, SQ.TLI); } // Place the list of instructions to simplify on the next loop iteration @@ -90,7 +92,7 @@ struct InstSimplifyLegacyPass : public FunctionPass { AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } - /// runOnFunction - Remove instructions that simplify. + /// Remove instructions that simplify. 
bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; @@ -98,7 +100,7 @@ struct InstSimplifyLegacyPass : public FunctionPass { const DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); OptimizationRemarkEmitter *ORE = diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index b86bf2fefbe5..0cf00baaa24a 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -224,13 +224,21 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { BasicBlock *PhiBB) -> std::pair<BasicBlock *, BasicBlock *> { auto *PredBB = IncomingBB; auto *SuccBB = PhiBB; + SmallPtrSet<BasicBlock *, 16> Visited; while (true) { BranchInst *PredBr = dyn_cast<BranchInst>(PredBB->getTerminator()); if (PredBr && PredBr->isConditional()) return {PredBB, SuccBB}; + Visited.insert(PredBB); auto *SinglePredBB = PredBB->getSinglePredecessor(); if (!SinglePredBB) return {nullptr, nullptr}; + + // Stop searching when SinglePredBB has been visited. It means we see + // an unreachable loop. + if (Visited.count(SinglePredBB)) + return {nullptr, nullptr}; + SuccBB = PredBB; PredBB = SinglePredBB; } @@ -253,7 +261,9 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { return; BasicBlock *PredBB = PredOutEdge.first; - BranchInst *PredBr = cast<BranchInst>(PredBB->getTerminator()); + BranchInst *PredBr = dyn_cast<BranchInst>(PredBB->getTerminator()); + if (!PredBr) + return; uint64_t PredTrueWeight, PredFalseWeight; // FIXME: We currently only set the profile data when it is missing. 
@@ -286,7 +296,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { bool JumpThreading::runOnFunction(Function &F) { if (skipFunction(F)) return false; - auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); // Get DT analysis before LVI. When LVI is initialized it conditionally adds // DT if it's available. auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -1461,7 +1471,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { "Can't handle critical edge here!"); LoadInst *NewVal = new LoadInst( LoadI->getType(), LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), - LoadI->getName() + ".pr", false, LoadI->getAlignment(), + LoadI->getName() + ".pr", false, MaybeAlign(LoadI->getAlignment()), LoadI->getOrdering(), LoadI->getSyncScopeID(), UnavailablePred->getTerminator()); NewVal->setDebugLoc(LoadI->getDebugLoc()); @@ -2423,7 +2433,7 @@ void JumpThreadingPass::UnfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB, // |----- // v // BB - BranchInst *PredTerm = dyn_cast<BranchInst>(Pred->getTerminator()); + BranchInst *PredTerm = cast<BranchInst>(Pred->getTerminator()); BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "select.unfold", BB->getParent(), BB); // Move the unconditional branch to NewBB. 
diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index d9dda4cef2d2..6ce4831a7359 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -220,7 +220,8 @@ struct LegacyLICMPass : public LoopPass { &getAnalysis<AAResultsWrapperPass>().getAAResults(), &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( + *L->getHeader()->getParent()), &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( *L->getHeader()->getParent()), SE ? &SE->getSE() : nullptr, MSSA, &ORE, false); @@ -294,7 +295,7 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<DominatorTreeAnalysis>(); PA.preserve<LoopAnalysis>(); - if (EnableMSSALoopDependency) + if (AR.MSSA) PA.preserve<MemorySSAAnalysis>(); return PA; @@ -330,6 +331,12 @@ bool LoopInvariantCodeMotion::runOnLoop( assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); + // If this loop has metadata indicating that LICM is not to be performed then + // just exit. + if (hasDisableLICMTransformsHint(L)) { + return false; + } + std::unique_ptr<AliasSetTracker> CurAST; std::unique_ptr<MemorySSAUpdater> MSSAU; bool NoOfMemAccTooLarge = false; @@ -340,7 +347,7 @@ bool LoopInvariantCodeMotion::runOnLoop( CurAST = collectAliasInfoForLoop(L, LI, AA); } else { LLVM_DEBUG(dbgs() << "LICM: Using MemorySSA.\n"); - MSSAU = make_unique<MemorySSAUpdater>(MSSA); + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); unsigned AccessCapCount = 0; for (auto *BB : L->getBlocks()) { @@ -956,7 +963,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // Now that we've finished hoisting make sure that LI and DT are still // valid. 
-#ifndef NDEBUG +#ifdef EXPENSIVE_CHECKS if (Changed) { assert(DT->verify(DominatorTree::VerificationLevel::Fast) && "Dominator tree verification failed"); @@ -1026,7 +1033,8 @@ namespace { bool isHoistableAndSinkableInst(Instruction &I) { // Only these instructions are hoistable/sinkable. return (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<CallInst>(I) || - isa<FenceInst>(I) || isa<BinaryOperator>(I) || isa<CastInst>(I) || + isa<FenceInst>(I) || isa<CastInst>(I) || + isa<UnaryOperator>(I) || isa<BinaryOperator>(I) || isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || isa<CmpInst>(I) || isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) || isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) || @@ -1092,7 +1100,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // in the same alias set as something that ends up being modified. if (AA->pointsToConstantMemory(LI->getOperand(0))) return true; - if (LI->getMetadata(LLVMContext::MD_invariant_load)) + if (LI->hasMetadata(LLVMContext::MD_invariant_load)) return true; if (LI->isAtomic() && !TargetExecutesOncePerLoop) @@ -1240,12 +1248,22 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // FIXME: More precise: no Uses that alias SI. if (!Flags->IsSink && !MSSA->dominates(SIMD, MU)) return false; - } else if (const auto *MD = dyn_cast<MemoryDef>(&MA)) + } else if (const auto *MD = dyn_cast<MemoryDef>(&MA)) { if (auto *LI = dyn_cast<LoadInst>(MD->getMemoryInst())) { (void)LI; // Silence warning. assert(!LI->isUnordered() && "Expected unordered load"); return false; } + // Any call, while it may not be clobbering SI, it may be a use. + if (auto *CI = dyn_cast<CallInst>(MD->getMemoryInst())) { + // Check if the call may read from the memory locattion written + // to by SI. Check CI's attributes and arguments; the number of + // such checks performed is limited above by NoOfMemAccTooLarge. 
+ ModRefInfo MRI = AA->getModRefInfo(CI, MemoryLocation::get(SI)); + if (isModOrRefSet(MRI)) + return false; + } + } } auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI); @@ -1375,8 +1393,7 @@ static Instruction *CloneInstructionInExitBlock( if (!I.getName().empty()) New->setName(I.getName() + ".le"); - MemoryAccess *OldMemAcc; - if (MSSAU && (OldMemAcc = MSSAU->getMemorySSA()->getMemoryAccess(&I))) { + if (MSSAU && MSSAU->getMemorySSA()->getMemoryAccess(&I)) { // Create a new MemoryAccess and let MemorySSA set its defining access. MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB( New, nullptr, New->getParent(), MemorySSA::Beginning); @@ -1385,7 +1402,7 @@ static Instruction *CloneInstructionInExitBlock( MSSAU->insertDef(MemDef, /*RenameUses=*/true); else { auto *MemUse = cast<MemoryUse>(NewMemAcc); - MSSAU->insertUse(MemUse); + MSSAU->insertUse(MemUse, /*RenameUses=*/true); } } } @@ -1783,7 +1800,7 @@ public: StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); if (UnorderedAtomic) NewSI->setOrdering(AtomicOrdering::Unordered); - NewSI->setAlignment(Alignment); + NewSI->setAlignment(MaybeAlign(Alignment)); NewSI->setDebugLoc(DL); if (AATags) NewSI->setAAMetadata(AATags); @@ -2016,7 +2033,8 @@ bool llvm::promoteLoopAccessesToScalars( if (!DereferenceableInPH) { DereferenceableInPH = isDereferenceableAndAlignedPointer( Store->getPointerOperand(), Store->getValueOperand()->getType(), - Store->getAlignment(), MDL, Preheader->getTerminator(), DT); + MaybeAlign(Store->getAlignment()), MDL, + Preheader->getTerminator(), DT); } } else return false; // Not a load or store. 
@@ -2101,20 +2119,21 @@ bool llvm::promoteLoopAccessesToScalars( SomePtr->getName() + ".promoted", Preheader->getTerminator()); if (SawUnorderedAtomic) PreheaderLoad->setOrdering(AtomicOrdering::Unordered); - PreheaderLoad->setAlignment(Alignment); + PreheaderLoad->setAlignment(MaybeAlign(Alignment)); PreheaderLoad->setDebugLoc(DL); if (AATags) PreheaderLoad->setAAMetadata(AATags); SSA.AddAvailableValue(Preheader, PreheaderLoad); - MemoryAccess *PreheaderLoadMemoryAccess; if (MSSAU) { - PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB( + MemoryAccess *PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB( PreheaderLoad, nullptr, PreheaderLoad->getParent(), MemorySSA::End); MemoryUse *NewMemUse = cast<MemoryUse>(PreheaderLoadMemoryAccess); - MSSAU->insertUse(NewMemUse); + MSSAU->insertUse(NewMemUse, /*RenameUses=*/true); } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); // Rewrite all the loads in the loop and remember all the definitions from // stores in the loop. Promoter.run(LoopUses); @@ -2161,7 +2180,7 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, LoopToAliasSetMap.erase(MapI); } if (!CurAST) - CurAST = make_unique<AliasSetTracker>(*AA); + CurAST = std::make_unique<AliasSetTracker>(*AA); // Add everything from the sub loops that are no longer directly available. 
for (Loop *InnerL : RecomputeLoops) @@ -2180,7 +2199,7 @@ std::unique_ptr<AliasSetTracker> LoopInvariantCodeMotion::collectAliasInfoForLoopWithMSSA( Loop *L, AliasAnalysis *AA, MemorySSAUpdater *MSSAU) { auto *MSSA = MSSAU->getMemorySSA(); - auto CurAST = make_unique<AliasSetTracker>(*AA, MSSA, L); + auto CurAST = std::make_unique<AliasSetTracker>(*AA, MSSA, L); CurAST->addAllInstructionsInLoopUsingMSSA(); return CurAST; } diff --git a/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 1fcf1315a177..a972d6fa2fcd 100644 --- a/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -312,8 +312,8 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { IRBuilder<> Builder(MemI); Module *M = BB->getParent()->getParent(); Type *I32 = Type::getInt32Ty(BB->getContext()); - Function *PrefetchFunc = - Intrinsic::getDeclaration(M, Intrinsic::prefetch); + Function *PrefetchFunc = Intrinsic::getDeclaration( + M, Intrinsic::prefetch, PrefPtrValue->getType()); Builder.CreateCall( PrefetchFunc, {PrefPtrValue, diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp index 8371367e24e7..cee197cf8354 100644 --- a/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -191,7 +191,7 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, // Don't remove loops for which we can't solve the trip count. // They could be infinite, in which case we'd be changing program behavior. - const SCEV *S = SE.getMaxBackedgeTakenCount(L); + const SCEV *S = SE.getConstantMaxBackedgeTakenCount(L); if (isa<SCEVCouldNotCompute>(S)) { LLVM_DEBUG(dbgs() << "Could not compute SCEV MaxBackedgeTakenCount.\n"); return Changed ? 
LoopDeletionResult::Modified diff --git a/lib/Transforms/Scalar/LoopFuse.cpp b/lib/Transforms/Scalar/LoopFuse.cpp index 0bc2bcff2ae1..9f93c68e6128 100644 --- a/lib/Transforms/Scalar/LoopFuse.cpp +++ b/lib/Transforms/Scalar/LoopFuse.cpp @@ -66,7 +66,7 @@ using namespace llvm; #define DEBUG_TYPE "loop-fusion" -STATISTIC(FuseCounter, "Count number of loop fusions performed"); +STATISTIC(FuseCounter, "Loops fused"); STATISTIC(NumFusionCandidates, "Number of candidates for loop fusion"); STATISTIC(InvalidPreheader, "Loop has invalid preheader"); STATISTIC(InvalidHeader, "Loop has invalid header"); @@ -79,12 +79,15 @@ STATISTIC(MayThrowException, "Loop may throw an exception"); STATISTIC(ContainsVolatileAccess, "Loop contains a volatile access"); STATISTIC(NotSimplifiedForm, "Loop is not in simplified form"); STATISTIC(InvalidDependencies, "Dependencies prevent fusion"); -STATISTIC(InvalidTripCount, - "Loop does not have invariant backedge taken count"); +STATISTIC(UnknownTripCount, "Loop has unknown trip count"); STATISTIC(UncomputableTripCount, "SCEV cannot compute trip count of loop"); -STATISTIC(NonEqualTripCount, "Candidate trip counts are not the same"); -STATISTIC(NonAdjacent, "Candidates are not adjacent"); -STATISTIC(NonEmptyPreheader, "Candidate has a non-empty preheader"); +STATISTIC(NonEqualTripCount, "Loop trip counts are not the same"); +STATISTIC(NonAdjacent, "Loops are not adjacent"); +STATISTIC(NonEmptyPreheader, "Loop has a non-empty preheader"); +STATISTIC(FusionNotBeneficial, "Fusion is not beneficial"); +STATISTIC(NonIdenticalGuards, "Candidates have different guards"); +STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block"); +STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block"); enum FusionDependenceAnalysisChoice { FUSION_DEPENDENCE_ANALYSIS_SCEV, @@ -110,6 +113,7 @@ static cl::opt<bool> cl::Hidden, cl::init(false), cl::ZeroOrMore); #endif +namespace { /// This class is used to represent a candidate for loop fusion. 
When it is /// constructed, it checks the conditions for loop fusion to ensure that it /// represents a valid candidate. It caches several parts of a loop that are @@ -143,6 +147,8 @@ struct FusionCandidate { SmallVector<Instruction *, 16> MemWrites; /// Are all of the members of this fusion candidate still valid bool Valid; + /// Guard branch of the loop, if it exists + BranchInst *GuardBranch; /// Dominator and PostDominator trees are needed for the /// FusionCandidateCompare function, required by FusionCandidateSet to @@ -151,11 +157,20 @@ struct FusionCandidate { const DominatorTree *DT; const PostDominatorTree *PDT; + OptimizationRemarkEmitter &ORE; + FusionCandidate(Loop *L, const DominatorTree *DT, - const PostDominatorTree *PDT) + const PostDominatorTree *PDT, OptimizationRemarkEmitter &ORE) : Preheader(L->getLoopPreheader()), Header(L->getHeader()), ExitingBlock(L->getExitingBlock()), ExitBlock(L->getExitBlock()), - Latch(L->getLoopLatch()), L(L), Valid(true), DT(DT), PDT(PDT) { + Latch(L->getLoopLatch()), L(L), Valid(true), GuardBranch(nullptr), + DT(DT), PDT(PDT), ORE(ORE) { + + // TODO: This is temporary while we fuse both rotated and non-rotated + // loops. Once we switch to only fusing rotated loops, the initialization of + // GuardBranch can be moved into the initialization list above. + if (isRotated()) + GuardBranch = L->getLoopGuardBranch(); // Walk over all blocks in the loop and check for conditions that may // prevent fusion. For each block, walk over all instructions and collect @@ -163,28 +178,28 @@ struct FusionCandidate { // found, invalidate this object and return. 
for (BasicBlock *BB : L->blocks()) { if (BB->hasAddressTaken()) { - AddressTakenBB++; invalidate(); + reportInvalidCandidate(AddressTakenBB); return; } for (Instruction &I : *BB) { if (I.mayThrow()) { - MayThrowException++; invalidate(); + reportInvalidCandidate(MayThrowException); return; } if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { if (SI->isVolatile()) { - ContainsVolatileAccess++; invalidate(); + reportInvalidCandidate(ContainsVolatileAccess); return; } } if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { if (LI->isVolatile()) { - ContainsVolatileAccess++; invalidate(); + reportInvalidCandidate(ContainsVolatileAccess); return; } } @@ -214,19 +229,96 @@ struct FusionCandidate { assert(Latch == L->getLoopLatch() && "Latch is out of sync"); } + /// Get the entry block for this fusion candidate. + /// + /// If this fusion candidate represents a guarded loop, the entry block is the + /// loop guard block. If it represents an unguarded loop, the entry block is + /// the preheader of the loop. + BasicBlock *getEntryBlock() const { + if (GuardBranch) + return GuardBranch->getParent(); + else + return Preheader; + } + + /// Given a guarded loop, get the successor of the guard that is not in the + /// loop. + /// + /// This method returns the successor of the loop guard that is not located + /// within the loop (i.e., the successor of the guard that is not the + /// preheader). + /// This method is only valid for guarded loops. + BasicBlock *getNonLoopBlock() const { + assert(GuardBranch && "Only valid on guarded loops."); + assert(GuardBranch->isConditional() && + "Expecting guard to be a conditional branch."); + return (GuardBranch->getSuccessor(0) == Preheader) + ? 
GuardBranch->getSuccessor(1) + : GuardBranch->getSuccessor(0); + } + + bool isRotated() const { + assert(L && "Expecting loop to be valid."); + assert(Latch && "Expecting latch to be valid."); + return L->isLoopExiting(Latch); + } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void dump() const { - dbgs() << "\tPreheader: " << (Preheader ? Preheader->getName() : "nullptr") + dbgs() << "\tGuardBranch: " + << (GuardBranch ? GuardBranch->getName() : "nullptr") << "\n" + << "\tPreheader: " << (Preheader ? Preheader->getName() : "nullptr") << "\n" << "\tHeader: " << (Header ? Header->getName() : "nullptr") << "\n" << "\tExitingBB: " << (ExitingBlock ? ExitingBlock->getName() : "nullptr") << "\n" << "\tExitBB: " << (ExitBlock ? ExitBlock->getName() : "nullptr") << "\n" - << "\tLatch: " << (Latch ? Latch->getName() : "nullptr") << "\n"; + << "\tLatch: " << (Latch ? Latch->getName() : "nullptr") << "\n" + << "\tEntryBlock: " + << (getEntryBlock() ? getEntryBlock()->getName() : "nullptr") + << "\n"; } #endif + /// Determine if a fusion candidate (representing a loop) is eligible for + /// fusion. Note that this only checks whether a single loop can be fused - it + /// does not check whether it is *legal* to fuse two loops together. + bool isEligibleForFusion(ScalarEvolution &SE) const { + if (!isValid()) { + LLVM_DEBUG(dbgs() << "FC has invalid CFG requirements!\n"); + if (!Preheader) + ++InvalidPreheader; + if (!Header) + ++InvalidHeader; + if (!ExitingBlock) + ++InvalidExitingBlock; + if (!ExitBlock) + ++InvalidExitBlock; + if (!Latch) + ++InvalidLatch; + if (L->isInvalid()) + ++InvalidLoop; + + return false; + } + + // Require ScalarEvolution to be able to determine a trip count. 
+ if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { + LLVM_DEBUG(dbgs() << "Loop " << L->getName() + << " trip count not computable!\n"); + return reportInvalidCandidate(UnknownTripCount); + } + + if (!L->isLoopSimplifyForm()) { + LLVM_DEBUG(dbgs() << "Loop " << L->getName() + << " is not in simplified form!\n"); + return reportInvalidCandidate(NotSimplifiedForm); + } + + return true; + } + private: // This is only used internally for now, to clear the MemWrites and MemReads // list and setting Valid to false. I can't envision other uses of this right @@ -239,17 +331,18 @@ private: MemReads.clear(); Valid = false; } -}; -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, - const FusionCandidate &FC) { - if (FC.isValid()) - OS << FC.Preheader->getName(); - else - OS << "<Invalid>"; - - return OS; -} + bool reportInvalidCandidate(llvm::Statistic &Stat) const { + using namespace ore; + assert(L && Preheader && "Fusion candidate not initialized properly!"); + ++Stat; + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, Stat.getName(), + L->getStartLoc(), Preheader) + << "[" << Preheader->getParent()->getName() << "]: " + << "Loop is not a candidate for fusion: " << Stat.getDesc()); + return false; + } +}; struct FusionCandidateCompare { /// Comparison functor to sort two Control Flow Equivalent fusion candidates @@ -260,21 +353,24 @@ struct FusionCandidateCompare { const FusionCandidate &RHS) const { const DominatorTree *DT = LHS.DT; + BasicBlock *LHSEntryBlock = LHS.getEntryBlock(); + BasicBlock *RHSEntryBlock = RHS.getEntryBlock(); + // Do not save PDT to local variable as it is only used in asserts and thus // will trigger an unused variable warning if building without asserts. assert(DT && LHS.PDT && "Expecting valid dominator tree"); // Do this compare first so if LHS == RHS, function returns false. 
- if (DT->dominates(RHS.Preheader, LHS.Preheader)) { + if (DT->dominates(RHSEntryBlock, LHSEntryBlock)) { // RHS dominates LHS // Verify LHS post-dominates RHS - assert(LHS.PDT->dominates(LHS.Preheader, RHS.Preheader)); + assert(LHS.PDT->dominates(LHSEntryBlock, RHSEntryBlock)); return false; } - if (DT->dominates(LHS.Preheader, RHS.Preheader)) { + if (DT->dominates(LHSEntryBlock, RHSEntryBlock)) { // Verify RHS Postdominates LHS - assert(LHS.PDT->dominates(RHS.Preheader, LHS.Preheader)); + assert(LHS.PDT->dominates(RHSEntryBlock, LHSEntryBlock)); return true; } @@ -286,7 +382,6 @@ struct FusionCandidateCompare { } }; -namespace { using LoopVector = SmallVector<Loop *, 4>; // Set of Control Flow Equivalent (CFE) Fusion Candidates, sorted in dominance @@ -301,17 +396,26 @@ using LoopVector = SmallVector<Loop *, 4>; // keeps the FusionCandidateSet sorted will also simplify the implementation. using FusionCandidateSet = std::set<FusionCandidate, FusionCandidateCompare>; using FusionCandidateCollection = SmallVector<FusionCandidateSet, 4>; -} // namespace -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, +#if !defined(NDEBUG) +static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + const FusionCandidate &FC) { + if (FC.isValid()) + OS << FC.Preheader->getName(); + else + OS << "<Invalid>"; + + return OS; +} + +static llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const FusionCandidateSet &CandSet) { - for (auto IT : CandSet) - OS << IT << "\n"; + for (const FusionCandidate &FC : CandSet) + OS << FC << '\n'; return OS; } -#if !defined(NDEBUG) static void printFusionCandidates(const FusionCandidateCollection &FusionCandidates) { dbgs() << "Fusion Candidates: \n"; @@ -391,16 +495,6 @@ static void printLoopVector(const LoopVector &LV) { } #endif -static void reportLoopFusion(const FusionCandidate &FC0, - const FusionCandidate &FC1, - OptimizationRemarkEmitter &ORE) { - using namespace ore; - ORE.emit( - OptimizationRemark(DEBUG_TYPE, "LoopFusion", 
FC0.Preheader->getParent()) - << "Fused " << NV("Cand1", StringRef(FC0.Preheader->getName())) - << " with " << NV("Cand2", StringRef(FC1.Preheader->getName()))); -} - struct LoopFuser { private: // Sets of control flow equivalent fusion candidates for a given nest level. @@ -497,53 +591,16 @@ private: const FusionCandidate &FC1) const { assert(FC0.Preheader && FC1.Preheader && "Expecting valid preheaders"); - if (DT.dominates(FC0.Preheader, FC1.Preheader)) - return PDT.dominates(FC1.Preheader, FC0.Preheader); + BasicBlock *FC0EntryBlock = FC0.getEntryBlock(); + BasicBlock *FC1EntryBlock = FC1.getEntryBlock(); - if (DT.dominates(FC1.Preheader, FC0.Preheader)) - return PDT.dominates(FC0.Preheader, FC1.Preheader); + if (DT.dominates(FC0EntryBlock, FC1EntryBlock)) + return PDT.dominates(FC1EntryBlock, FC0EntryBlock); - return false; - } - - /// Determine if a fusion candidate (representing a loop) is eligible for - /// fusion. Note that this only checks whether a single loop can be fused - it - /// does not check whether it is *legal* to fuse two loops together. - bool eligibleForFusion(const FusionCandidate &FC) const { - if (!FC.isValid()) { - LLVM_DEBUG(dbgs() << "FC " << FC << " has invalid CFG requirements!\n"); - if (!FC.Preheader) - InvalidPreheader++; - if (!FC.Header) - InvalidHeader++; - if (!FC.ExitingBlock) - InvalidExitingBlock++; - if (!FC.ExitBlock) - InvalidExitBlock++; - if (!FC.Latch) - InvalidLatch++; - if (FC.L->isInvalid()) - InvalidLoop++; + if (DT.dominates(FC1EntryBlock, FC0EntryBlock)) + return PDT.dominates(FC0EntryBlock, FC1EntryBlock); - return false; - } - - // Require ScalarEvolution to be able to determine a trip count. 
- if (!SE.hasLoopInvariantBackedgeTakenCount(FC.L)) { - LLVM_DEBUG(dbgs() << "Loop " << FC.L->getName() - << " trip count not computable!\n"); - InvalidTripCount++; - return false; - } - - if (!FC.L->isLoopSimplifyForm()) { - LLVM_DEBUG(dbgs() << "Loop " << FC.L->getName() - << " is not in simplified form!\n"); - NotSimplifiedForm++; - return false; - } - - return true; + return false; } /// Iterate over all loops in the given loop set and identify the loops that @@ -551,8 +608,8 @@ private: /// Flow Equivalent sets, sorted by dominance. void collectFusionCandidates(const LoopVector &LV) { for (Loop *L : LV) { - FusionCandidate CurrCand(L, &DT, &PDT); - if (!eligibleForFusion(CurrCand)) + FusionCandidate CurrCand(L, &DT, &PDT, ORE); + if (!CurrCand.isEligibleForFusion(SE)) continue; // Go through each list in FusionCandidates and determine if L is control @@ -664,31 +721,64 @@ private: if (!identicalTripCounts(*FC0, *FC1)) { LLVM_DEBUG(dbgs() << "Fusion candidates do not have identical trip " "counts. Not fusing.\n"); - NonEqualTripCount++; + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonEqualTripCount); continue; } if (!isAdjacent(*FC0, *FC1)) { LLVM_DEBUG(dbgs() << "Fusion candidates are not adjacent. Not fusing.\n"); - NonAdjacent++; + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, NonAdjacent); continue; } - // For now we skip fusing if the second candidate has any instructions - // in the preheader. This is done because we currently do not have the - // safety checks to determine if it is save to move the preheader of - // the second candidate past the body of the first candidate. Once - // these checks are added, this condition can be removed. + // Ensure that FC0 and FC1 have identical guards. + // If one (or both) are not guarded, this check is not necessary. + if (FC0->GuardBranch && FC1->GuardBranch && + !haveIdenticalGuards(*FC0, *FC1)) { + LLVM_DEBUG(dbgs() << "Fusion candidates do not have identical " + "guards. 
Not Fusing.\n"); + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonIdenticalGuards); + continue; + } + + // The following three checks look for empty blocks in FC0 and FC1. If + // any of these blocks are non-empty, we do not fuse. This is done + // because we currently do not have the safety checks to determine if + // it is safe to move the blocks past other blocks in the loop. Once + // these checks are added, these conditions can be relaxed. if (!isEmptyPreheader(*FC1)) { LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty " "preheader. Not fusing.\n"); - NonEmptyPreheader++; + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonEmptyPreheader); + continue; + } + + if (FC0->GuardBranch && !isEmptyExitBlock(*FC0)) { + LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty exit " + "block. Not fusing.\n"); + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonEmptyExitBlock); + continue; + } + + if (FC1->GuardBranch && !isEmptyGuardBlock(*FC1)) { + LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty guard " + "block. Not fusing.\n"); + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonEmptyGuardBlock); continue; } + // Check the dependencies across the loops and do not fuse if it would + // violate them. if (!dependencesAllowFusion(*FC0, *FC1)) { LLVM_DEBUG(dbgs() << "Memory dependencies do not allow fusion!\n"); + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + InvalidDependencies); continue; } @@ -696,9 +786,11 @@ private: LLVM_DEBUG(dbgs() << "\tFusion appears to be " << (BeneficialToFuse ? "" : "un") << "profitable!\n"); - if (!BeneficialToFuse) + if (!BeneficialToFuse) { + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + FusionNotBeneficial); continue; - + } // All analysis has completed and has determined that fusion is legal // and profitable. At this point, start transforming the code and // perform fusion. 
@@ -710,15 +802,14 @@ private: // Note this needs to be done *before* performFusion because // performFusion will change the original loops, making it not // possible to identify them after fusion is complete. - reportLoopFusion(*FC0, *FC1, ORE); + reportLoopFusion<OptimizationRemark>(*FC0, *FC1, FuseCounter); - FusionCandidate FusedCand(performFusion(*FC0, *FC1), &DT, &PDT); + FusionCandidate FusedCand(performFusion(*FC0, *FC1), &DT, &PDT, ORE); FusedCand.verify(); - assert(eligibleForFusion(FusedCand) && + assert(FusedCand.isEligibleForFusion(SE) && "Fused candidate should be eligible for fusion!"); // Notify the loop-depth-tree that these loops are not valid objects - // anymore. LDT.removeLoop(FC1->L); CandidateSet.erase(FC0); @@ -889,7 +980,7 @@ private: LLVM_DEBUG(dbgs() << "Check if " << FC0 << " can be fused with " << FC1 << "\n"); assert(FC0.L->getLoopDepth() == FC1.L->getLoopDepth()); - assert(DT.dominates(FC0.Preheader, FC1.Preheader)); + assert(DT.dominates(FC0.getEntryBlock(), FC1.getEntryBlock())); for (Instruction *WriteL0 : FC0.MemWrites) { for (Instruction *WriteL1 : FC1.MemWrites) @@ -939,18 +1030,89 @@ private: return true; } - /// Determine if the exit block of \p FC0 is the preheader of \p FC1. In this - /// case, there is no code in between the two fusion candidates, thus making - /// them adjacent. + /// Determine if two fusion candidates are adjacent in the CFG. + /// + /// This method will determine if there are additional basic blocks in the CFG + /// between the exit of \p FC0 and the entry of \p FC1. + /// If the two candidates are guarded loops, then it checks whether the + /// non-loop successor of the \p FC0 guard branch is the entry block of \p + /// FC1. If not, then the loops are not adjacent. If the two candidates are + /// not guarded loops, then it checks whether the exit block of \p FC0 is the + /// preheader of \p FC1. 
bool isAdjacent(const FusionCandidate &FC0, const FusionCandidate &FC1) const { - return FC0.ExitBlock == FC1.Preheader; + // If the successor of the guard branch is FC1, then the loops are adjacent + if (FC0.GuardBranch) + return FC0.getNonLoopBlock() == FC1.getEntryBlock(); + else + return FC0.ExitBlock == FC1.getEntryBlock(); + } + + /// Determine if two fusion candidates have identical guards + /// + /// This method will determine if two fusion candidates have the same guards. + /// The guards are considered the same if: + /// 1. The instructions to compute the condition used in the compare are + /// identical. + /// 2. The successors of the guard have the same flow into/around the loop. + /// If the compare instructions are identical, then the first successor of the + /// guard must go to the same place (either the preheader of the loop or the + /// NonLoopBlock). In other words, the the first successor of both loops must + /// both go into the loop (i.e., the preheader) or go around the loop (i.e., + /// the NonLoopBlock). The same must be true for the second successor. + bool haveIdenticalGuards(const FusionCandidate &FC0, + const FusionCandidate &FC1) const { + assert(FC0.GuardBranch && FC1.GuardBranch && + "Expecting FC0 and FC1 to be guarded loops."); + + if (auto FC0CmpInst = + dyn_cast<Instruction>(FC0.GuardBranch->getCondition())) + if (auto FC1CmpInst = + dyn_cast<Instruction>(FC1.GuardBranch->getCondition())) + if (!FC0CmpInst->isIdenticalTo(FC1CmpInst)) + return false; + + // The compare instructions are identical. + // Now make sure the successor of the guards have the same flow into/around + // the loop + if (FC0.GuardBranch->getSuccessor(0) == FC0.Preheader) + return (FC1.GuardBranch->getSuccessor(0) == FC1.Preheader); + else + return (FC1.GuardBranch->getSuccessor(1) == FC1.Preheader); + } + + /// Check that the guard for \p FC *only* contains the cmp/branch for the + /// guard. 
+ /// Once we are able to handle intervening code, any code in the guard block + /// for FC1 will need to be treated as intervening code and checked whether + /// it can safely move around the loops. + bool isEmptyGuardBlock(const FusionCandidate &FC) const { + assert(FC.GuardBranch && "Expecting a fusion candidate with guard branch."); + if (auto *CmpInst = dyn_cast<Instruction>(FC.GuardBranch->getCondition())) { + auto *GuardBlock = FC.GuardBranch->getParent(); + // If the generation of the cmp value is in GuardBlock, then the size of + // the guard block should be 2 (cmp + branch). If the generation of the + // cmp value is in a different block, then the size of the guard block + // should only be 1. + if (CmpInst->getParent() == GuardBlock) + return GuardBlock->size() == 2; + else + return GuardBlock->size() == 1; + } + + return false; } bool isEmptyPreheader(const FusionCandidate &FC) const { + assert(FC.Preheader && "Expecting a valid preheader"); return FC.Preheader->size() == 1; } + bool isEmptyExitBlock(const FusionCandidate &FC) const { + assert(FC.ExitBlock && "Expecting a valid exit block"); + return FC.ExitBlock->size() == 1; + } + /// Fuse two fusion candidates, creating a new fused loop. /// /// This method contains the mechanics of fusing two loops, represented by \p @@ -987,6 +1149,12 @@ private: LLVM_DEBUG(dbgs() << "Fusion Candidate 0: \n"; FC0.dump(); dbgs() << "Fusion Candidate 1: \n"; FC1.dump();); + // Fusing guarded loops is handled slightly differently than non-guarded + // loops and has been broken out into a separate method instead of trying to + // intersperse the logic within a single method. 
+ if (FC0.GuardBranch) + return fuseGuardedLoops(FC0, FC1); + assert(FC1.Preheader == FC0.ExitBlock); assert(FC1.Preheader->size() == 1 && FC1.Preheader->getSingleSuccessor() == FC1.Header); @@ -1131,7 +1299,258 @@ private: SE.verify(); #endif - FuseCounter++; + LLVM_DEBUG(dbgs() << "Fusion done:\n"); + + return FC0.L; + } + + /// Report details on loop fusion opportunities. + /// + /// This template function can be used to report both successful and missed + /// loop fusion opportunities, based on the RemarkKind. The RemarkKind should + /// be one of: + /// - OptimizationRemarkMissed to report when loop fusion is unsuccessful + /// given two valid fusion candidates. + /// - OptimizationRemark to report successful fusion of two fusion + /// candidates. + /// The remarks will be printed using the form: + /// <path/filename>:<line number>:<column number>: [<function name>]: + /// <Cand1 Preheader> and <Cand2 Preheader>: <Stat Description> + template <typename RemarkKind> + void reportLoopFusion(const FusionCandidate &FC0, const FusionCandidate &FC1, + llvm::Statistic &Stat) { + assert(FC0.Preheader && FC1.Preheader && + "Expecting valid fusion candidates"); + using namespace ore; + ++Stat; + ORE.emit(RemarkKind(DEBUG_TYPE, Stat.getName(), FC0.L->getStartLoc(), + FC0.Preheader) + << "[" << FC0.Preheader->getParent()->getName() + << "]: " << NV("Cand1", StringRef(FC0.Preheader->getName())) + << " and " << NV("Cand2", StringRef(FC1.Preheader->getName())) + << ": " << Stat.getDesc()); + } + + /// Fuse two guarded fusion candidates, creating a new fused loop. + /// + /// Fusing guarded loops is handled much the same way as fusing non-guarded + /// loops. The rewiring of the CFG is slightly different though, because of + /// the presence of the guards around the loops and the exit blocks after the + /// loop body. As such, the new loop is rewired as follows: + /// 1. Keep the guard branch from FC0 and use the non-loop block target + /// from the FC1 guard branch. + /// 2. 
Remove the exit block from FC0 (this exit block should be empty + /// right now). + /// 3. Remove the guard branch for FC1 + /// 4. Remove the preheader for FC1. + /// The exit block successor for the latch of FC0 is updated to be the header + /// of FC1 and the non-exit block successor of the latch of FC1 is updated to + /// be the header of FC0, thus creating the fused loop. + Loop *fuseGuardedLoops(const FusionCandidate &FC0, + const FusionCandidate &FC1) { + assert(FC0.GuardBranch && FC1.GuardBranch && "Expecting guarded loops"); + + BasicBlock *FC0GuardBlock = FC0.GuardBranch->getParent(); + BasicBlock *FC1GuardBlock = FC1.GuardBranch->getParent(); + BasicBlock *FC0NonLoopBlock = FC0.getNonLoopBlock(); + BasicBlock *FC1NonLoopBlock = FC1.getNonLoopBlock(); + + assert(FC0NonLoopBlock == FC1GuardBlock && "Loops are not adjacent"); + + SmallVector<DominatorTree::UpdateType, 8> TreeUpdates; + + //////////////////////////////////////////////////////////////////////////// + // Update the Loop Guard + //////////////////////////////////////////////////////////////////////////// + // The guard for FC0 is updated to guard both FC0 and FC1. This is done by + // changing the NonLoopGuardBlock for FC0 to the NonLoopGuardBlock for FC1. + // Thus, one path from the guard goes to the preheader for FC0 (and thus + // executes the new fused loop) and the other path goes to the NonLoopBlock + // for FC1 (where FC1 guard would have gone if FC1 was not executed). + FC0.GuardBranch->replaceUsesOfWith(FC0NonLoopBlock, FC1NonLoopBlock); + FC0.ExitBlock->getTerminator()->replaceUsesOfWith(FC1GuardBlock, + FC1.Header); + + // The guard of FC1 is not necessary anymore. 
+ FC1.GuardBranch->eraseFromParent(); + new UnreachableInst(FC1GuardBlock->getContext(), FC1GuardBlock); + + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Delete, FC1GuardBlock, FC1.Preheader)); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Delete, FC1GuardBlock, FC1NonLoopBlock)); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Delete, FC0GuardBlock, FC1GuardBlock)); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Insert, FC0GuardBlock, FC1NonLoopBlock)); + + assert(pred_begin(FC1GuardBlock) == pred_end(FC1GuardBlock) && + "Expecting guard block to have no predecessors"); + assert(succ_begin(FC1GuardBlock) == succ_end(FC1GuardBlock) && + "Expecting guard block to have no successors"); + + // Remember the phi nodes originally in the header of FC0 in order to rewire + // them later. However, this is only necessary if the new loop carried + // values might not dominate the exiting branch. While we do not generally + // test if this is the case but simply insert intermediate phi nodes, we + // need to make sure these intermediate phi nodes have different + // predecessors. To this end, we filter the special case where the exiting + // block is the latch block of the first loop. Nothing needs to be done + // anyway as all loop carried values dominate the latch and thereby also the + // exiting branch. + // KB: This is no longer necessary because FC0.ExitingBlock == FC0.Latch + // (because the loops are rotated. Thus, nothing will ever be added to + // OriginalFC0PHIs. + SmallVector<PHINode *, 8> OriginalFC0PHIs; + if (FC0.ExitingBlock != FC0.Latch) + for (PHINode &PHI : FC0.Header->phis()) + OriginalFC0PHIs.push_back(&PHI); + + assert(OriginalFC0PHIs.empty() && "Expecting OriginalFC0PHIs to be empty!"); + + // Replace incoming blocks for header PHIs first. 
+ FC1.Preheader->replaceSuccessorsPhiUsesWith(FC0.Preheader); + FC0.Latch->replaceSuccessorsPhiUsesWith(FC1.Latch); + + // The old exiting block of the first loop (FC0) has to jump to the header + // of the second as we need to execute the code in the second header block + // regardless of the trip count. That is, if the trip count is 0, so the + // back edge is never taken, we still have to execute both loop headers, + // especially (but not only!) if the second is a do-while style loop. + // However, doing so might invalidate the phi nodes of the first loop as + // the new values do only need to dominate their latch and not the exiting + // predicate. To remedy this potential problem we always introduce phi + // nodes in the header of the second loop later that select the loop carried + // value, if the second header was reached through an old latch of the + // first, or undef otherwise. This is sound as exiting the first implies the + // second will exit too, __without__ taking the back-edge (their + // trip-counts are equal after all). + FC0.ExitingBlock->getTerminator()->replaceUsesOfWith(FC0.ExitBlock, + FC1.Header); + + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Delete, FC0.ExitingBlock, FC0.ExitBlock)); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Insert, FC0.ExitingBlock, FC1.Header)); + + // Remove FC0 Exit Block + // The exit block for FC0 is no longer needed since control will flow + // directly to the header of FC1. Since it is an empty block, it can be + // removed at this point. + // TODO: In the future, we can handle non-empty exit blocks my merging any + // instructions from FC0 exit block into FC1 exit block prior to removing + // the block. 
+ assert(pred_begin(FC0.ExitBlock) == pred_end(FC0.ExitBlock) && + "Expecting exit block to be empty"); + FC0.ExitBlock->getTerminator()->eraseFromParent(); + new UnreachableInst(FC0.ExitBlock->getContext(), FC0.ExitBlock); + + // Remove FC1 Preheader + // The pre-header of L1 is not necessary anymore. + assert(pred_begin(FC1.Preheader) == pred_end(FC1.Preheader)); + FC1.Preheader->getTerminator()->eraseFromParent(); + new UnreachableInst(FC1.Preheader->getContext(), FC1.Preheader); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Delete, FC1.Preheader, FC1.Header)); + + // Moves the phi nodes from the second to the first loops header block. + while (PHINode *PHI = dyn_cast<PHINode>(&FC1.Header->front())) { + if (SE.isSCEVable(PHI->getType())) + SE.forgetValue(PHI); + if (PHI->hasNUsesOrMore(1)) + PHI->moveBefore(&*FC0.Header->getFirstInsertionPt()); + else + PHI->eraseFromParent(); + } + + // Introduce new phi nodes in the second loop header to ensure + // exiting the first and jumping to the header of the second does not break + // the SSA property of the phis originally in the first loop. See also the + // comment above. + Instruction *L1HeaderIP = &FC1.Header->front(); + for (PHINode *LCPHI : OriginalFC0PHIs) { + int L1LatchBBIdx = LCPHI->getBasicBlockIndex(FC1.Latch); + assert(L1LatchBBIdx >= 0 && + "Expected loop carried value to be rewired at this point!"); + + Value *LCV = LCPHI->getIncomingValue(L1LatchBBIdx); + + PHINode *L1HeaderPHI = PHINode::Create( + LCV->getType(), 2, LCPHI->getName() + ".afterFC0", L1HeaderIP); + L1HeaderPHI->addIncoming(LCV, FC0.Latch); + L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()), + FC0.ExitingBlock); + + LCPHI->setIncomingValue(L1LatchBBIdx, L1HeaderPHI); + } + + // Update the latches + + // Replace latch terminator destinations. 
+ FC0.Latch->getTerminator()->replaceUsesOfWith(FC0.Header, FC1.Header); + FC1.Latch->getTerminator()->replaceUsesOfWith(FC1.Header, FC0.Header); + + // If FC0.Latch and FC0.ExitingBlock are the same then we have already + // performed the updates above. + if (FC0.Latch != FC0.ExitingBlock) + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Insert, FC0.Latch, FC1.Header)); + + TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Delete, + FC0.Latch, FC0.Header)); + TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Insert, + FC1.Latch, FC0.Header)); + TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Delete, + FC1.Latch, FC1.Header)); + + // All done + // Apply the updates to the Dominator Tree and cleanup. + + assert(succ_begin(FC1GuardBlock) == succ_end(FC1GuardBlock) && + "FC1GuardBlock has successors!!"); + assert(pred_begin(FC1GuardBlock) == pred_end(FC1GuardBlock) && + "FC1GuardBlock has predecessors!!"); + + // Update DT/PDT + DTU.applyUpdates(TreeUpdates); + + LI.removeBlock(FC1.Preheader); + DTU.deleteBB(FC1.Preheader); + DTU.deleteBB(FC0.ExitBlock); + DTU.flush(); + + // Is there a way to keep SE up-to-date so we don't need to forget the loops + // and rebuild the information in subsequent passes of fusion? + SE.forgetLoop(FC1.L); + SE.forgetLoop(FC0.L); + + // Merge the loops. + SmallVector<BasicBlock *, 8> Blocks(FC1.L->block_begin(), + FC1.L->block_end()); + for (BasicBlock *BB : Blocks) { + FC0.L->addBlockEntry(BB); + FC1.L->removeBlockFromLoop(BB); + if (LI.getLoopFor(BB) != FC1.L) + continue; + LI.changeLoopFor(BB, FC0.L); + } + while (!FC1.L->empty()) { + const auto &ChildLoopIt = FC1.L->begin(); + Loop *ChildLoop = *ChildLoopIt; + FC1.L->removeChildLoop(ChildLoopIt); + FC0.L->addChildLoop(ChildLoop); + } + + // Delete the now empty loop L1. 
+ LI.erase(FC1.L); + +#ifndef NDEBUG + assert(!verifyFunction(*FC0.Header->getParent(), &errs())); + assert(DT.verify(DominatorTree::VerificationLevel::Fast)); + assert(PDT.verify()); + LI.verify(DT); + SE.verify(); +#endif LLVM_DEBUG(dbgs() << "Fusion done:\n"); @@ -1177,6 +1596,7 @@ struct LoopFuseLegacy : public FunctionPass { return LF.fuseLoops(F); } }; +} // namespace PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) { auto &LI = AM.getResult<LoopAnalysis>(F); diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index e561494f19cf..dd477e800693 100644 --- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -41,6 +41,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -77,16 +78,20 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/IR/Verifier.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -102,6 +107,7 @@ using namespace llvm; STATISTIC(NumMemSet, "Number of memset's formed from loop stores"); STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores"); +STATISTIC(NumBCmp, "Number of memcmp's formed from loop 2xload+eq-compare"); static 
cl::opt<bool> UseLIRCodeSizeHeurs( "use-lir-code-size-heurs", @@ -111,6 +117,26 @@ static cl::opt<bool> UseLIRCodeSizeHeurs( namespace { +// FIXME: reinventing the wheel much? Is there a cleaner solution? +struct PMAbstraction { + virtual void markLoopAsDeleted(Loop *L) = 0; + virtual ~PMAbstraction() = default; +}; +struct LegacyPMAbstraction : PMAbstraction { + LPPassManager &LPM; + LegacyPMAbstraction(LPPassManager &LPM) : LPM(LPM) {} + virtual ~LegacyPMAbstraction() = default; + void markLoopAsDeleted(Loop *L) override { LPM.markLoopAsDeleted(*L); } +}; +struct NewPMAbstraction : PMAbstraction { + LPMUpdater &Updater; + NewPMAbstraction(LPMUpdater &Updater) : Updater(Updater) {} + virtual ~NewPMAbstraction() = default; + void markLoopAsDeleted(Loop *L) override { + Updater.markLoopAsDeleted(*L, L->getName()); + } +}; + class LoopIdiomRecognize { Loop *CurLoop = nullptr; AliasAnalysis *AA; @@ -120,6 +146,7 @@ class LoopIdiomRecognize { TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; const DataLayout *DL; + PMAbstraction &LoopDeleter; OptimizationRemarkEmitter &ORE; bool ApplyCodeSizeHeuristics; @@ -128,9 +155,10 @@ public: LoopInfo *LI, ScalarEvolution *SE, TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, - const DataLayout *DL, + const DataLayout *DL, PMAbstraction &LoopDeleter, OptimizationRemarkEmitter &ORE) - : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {} + : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), + LoopDeleter(LoopDeleter), ORE(ORE) {} bool runOnLoop(Loop *L); @@ -144,6 +172,8 @@ private: bool HasMemset; bool HasMemsetPattern; bool HasMemcpy; + bool HasMemCmp; + bool HasBCmp; /// Return code for isLegalStore() enum LegalStoreKind { @@ -186,6 +216,32 @@ private: bool runOnNoncountableLoop(); + struct CmpLoopStructure { + Value *BCmpValue, *LatchCmpValue; + BasicBlock *HeaderBrEqualBB, *HeaderBrUnequalBB; + BasicBlock *LatchBrFinishBB, *LatchBrContinueBB; + }; + bool 
matchBCmpLoopStructure(CmpLoopStructure &CmpLoop) const; + struct CmpOfLoads { + ICmpInst::Predicate BCmpPred; + Value *LoadSrcA, *LoadSrcB; + Value *LoadA, *LoadB; + }; + bool matchBCmpOfLoads(Value *BCmpValue, CmpOfLoads &CmpOfLoads) const; + bool recognizeBCmpLoopControlFlow(const CmpOfLoads &CmpOfLoads, + CmpLoopStructure &CmpLoop) const; + bool recognizeBCmpLoopSCEV(uint64_t BCmpTyBytes, CmpOfLoads &CmpOfLoads, + const SCEV *&SrcA, const SCEV *&SrcB, + const SCEV *&Iterations) const; + bool detectBCmpIdiom(ICmpInst *&BCmpInst, CmpInst *&LatchCmpInst, + LoadInst *&LoadA, LoadInst *&LoadB, const SCEV *&SrcA, + const SCEV *&SrcB, const SCEV *&NBytes) const; + BasicBlock *transformBCmpControlFlow(ICmpInst *ComparedEqual); + void transformLoopToBCmp(ICmpInst *BCmpInst, CmpInst *LatchCmpInst, + LoadInst *LoadA, LoadInst *LoadB, const SCEV *SrcA, + const SCEV *SrcB, const SCEV *NBytes); + bool recognizeBCmp(); + bool recognizePopcount(); void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst, PHINode *CntPhi, Value *Var); @@ -217,18 +273,20 @@ public: LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( + *L->getHeader()->getParent()); const TargetTransformInfo *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( *L->getHeader()->getParent()); const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout(); + LegacyPMAbstraction LoopDeleter(LPM); // For the old PM, we can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations // but ORE cannot be preserved (see comment before the pass definition). 
OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); - LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, ORE); + LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, LoopDeleter, ORE); return LIR.runOnLoop(L); } @@ -247,7 +305,7 @@ char LoopIdiomRecognizeLegacyPass::ID = 0; PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, - LPMUpdater &) { + LPMUpdater &Updater) { const auto *DL = &L.getHeader()->getModule()->getDataLayout(); const auto &FAM = @@ -261,8 +319,9 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, "LoopIdiomRecognizePass: OptimizationRemarkEmitterAnalysis not cached " "at a higher level"); + NewPMAbstraction LoopDeleter(Updater); LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL, - *ORE); + LoopDeleter, *ORE); if (!LIR.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -299,7 +358,8 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) { // Disable loop idiom recognition if the function's name is a common idiom. StringRef Name = L->getHeader()->getParent()->getName(); - if (Name == "memset" || Name == "memcpy") + if (Name == "memset" || Name == "memcpy" || Name == "memcmp" || + Name == "bcmp") return false; // Determine if code size heuristics need to be applied. 
@@ -309,8 +369,10 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) { HasMemset = TLI->has(LibFunc_memset); HasMemsetPattern = TLI->has(LibFunc_memset_pattern16); HasMemcpy = TLI->has(LibFunc_memcpy); + HasMemCmp = TLI->has(LibFunc_memcmp); + HasBCmp = TLI->has(LibFunc_bcmp); - if (HasMemset || HasMemsetPattern || HasMemcpy) + if (HasMemset || HasMemsetPattern || HasMemcpy || HasMemCmp || HasBCmp) if (SE->hasLoopInvariantBackedgeTakenCount(L)) return runOnCountableLoop(); @@ -961,7 +1023,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( GlobalValue::PrivateLinkage, PatternValue, ".memset_pattern"); GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these. - GV->setAlignment(16); + GV->setAlignment(Align(16)); Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy); NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); } @@ -1149,7 +1211,7 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() { << "] Noncountable Loop %" << CurLoop->getHeader()->getName() << "\n"); - return recognizePopcount() || recognizeAndInsertFFS(); + return recognizeBCmp() || recognizePopcount() || recognizeAndInsertFFS(); } /// Check if the given conditional branch is based on the comparison between @@ -1823,3 +1885,811 @@ void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB, // loop. The loop would otherwise not be deleted even if it becomes empty. SE->forgetLoop(CurLoop); } + +bool LoopIdiomRecognize::matchBCmpLoopStructure( + CmpLoopStructure &CmpLoop) const { + ICmpInst::Predicate BCmpPred; + + // We are looking for the following basic layout: + // PreheaderBB: <preheader> ; preds = ??? 
+ // <...> + // br label %LoopHeaderBB + // LoopHeaderBB: <header,exiting> ; preds = %PreheaderBB,%LoopLatchBB + // <...> + // %BCmpValue = icmp <...> + // br i1 %BCmpValue, label %LoopLatchBB, label %Successor0 + // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB + // <...> + // %LatchCmpValue = <are we done, or do next iteration?> + // br i1 %LatchCmpValue, label %Successor1, label %LoopHeaderBB + // Successor0: <exit> ; preds = %LoopHeaderBB + // <...> + // Successor1: <exit> ; preds = %LoopLatchBB + // <...> + // + // Successor0 and Successor1 may or may not be the same basic block. + + // Match basic frame-work of this supposedly-comparison loop. + using namespace PatternMatch; + if (!match(CurLoop->getHeader()->getTerminator(), + m_Br(m_CombineAnd(m_ICmp(BCmpPred, m_Value(), m_Value()), + m_Value(CmpLoop.BCmpValue)), + CmpLoop.HeaderBrEqualBB, CmpLoop.HeaderBrUnequalBB)) || + !match(CurLoop->getLoopLatch()->getTerminator(), + m_Br(m_CombineAnd(m_Cmp(), m_Value(CmpLoop.LatchCmpValue)), + CmpLoop.LatchBrFinishBB, CmpLoop.LatchBrContinueBB))) { + LLVM_DEBUG(dbgs() << "Basic control-flow layout unrecognized.\n"); + return false; + } + LLVM_DEBUG(dbgs() << "Recognized basic control-flow layout.\n"); + return true; +} + +bool LoopIdiomRecognize::matchBCmpOfLoads(Value *BCmpValue, + CmpOfLoads &CmpOfLoads) const { + using namespace PatternMatch; + LLVM_DEBUG(dbgs() << "Analyzing header icmp " << *BCmpValue + << " as bcmp pattern.\n"); + + // Match bcmp-style loop header cmp. It must be an eq-icmp of loads. Example: + // %v0 = load <...>, <...>* %LoadSrcA + // %v1 = load <...>, <...>* %LoadSrcB + // %CmpLoop.BCmpValue = icmp eq <...> %v0, %v1 + // There won't be any no-op bitcasts between load and icmp, + // they would have been transformed into a load of bitcast. + // FIXME: {b,mem}cmp() calls have the same semantics as icmp. Match them too. 
+ if (!match(BCmpValue, + m_ICmp(CmpOfLoads.BCmpPred, + m_CombineAnd(m_Load(m_Value(CmpOfLoads.LoadSrcA)), + m_Value(CmpOfLoads.LoadA)), + m_CombineAnd(m_Load(m_Value(CmpOfLoads.LoadSrcB)), + m_Value(CmpOfLoads.LoadB)))) || + !ICmpInst::isEquality(CmpOfLoads.BCmpPred)) { + LLVM_DEBUG(dbgs() << "Loop header icmp did not match bcmp pattern.\n"); + return false; + } + LLVM_DEBUG(dbgs() << "Recognized header icmp as bcmp pattern with loads:\n\t" + << *CmpOfLoads.LoadA << "\n\t" << *CmpOfLoads.LoadB + << "\n"); + // FIXME: handle memcmp pattern? + return true; +} + +bool LoopIdiomRecognize::recognizeBCmpLoopControlFlow( + const CmpOfLoads &CmpOfLoads, CmpLoopStructure &CmpLoop) const { + BasicBlock *LoopHeaderBB = CurLoop->getHeader(); + BasicBlock *LoopLatchBB = CurLoop->getLoopLatch(); + + // Be wary, comparisons can be inverted, canonicalize order. + // If this 'element' comparison passed, we expect to proceed to the next elt. + if (CmpOfLoads.BCmpPred != ICmpInst::Predicate::ICMP_EQ) + std::swap(CmpLoop.HeaderBrEqualBB, CmpLoop.HeaderBrUnequalBB); + // The predicate on loop latch does not matter, just canonicalize some order. + if (CmpLoop.LatchBrContinueBB != LoopHeaderBB) + std::swap(CmpLoop.LatchBrFinishBB, CmpLoop.LatchBrContinueBB); + + SmallVector<BasicBlock *, 2> ExitBlocks; + + CurLoop->getUniqueExitBlocks(ExitBlocks); + assert(ExitBlocks.size() <= 2U && "Can't have more than two exit blocks."); + + // Check that control-flow between blocks is as expected. 
+ if (CmpLoop.HeaderBrEqualBB != LoopLatchBB || + CmpLoop.LatchBrContinueBB != LoopHeaderBB || + !is_contained(ExitBlocks, CmpLoop.HeaderBrUnequalBB) || + !is_contained(ExitBlocks, CmpLoop.LatchBrFinishBB)) { + LLVM_DEBUG(dbgs() << "Loop control-flow not recognized.\n"); + return false; + } + + assert(!is_contained(ExitBlocks, CmpLoop.HeaderBrEqualBB) && + !is_contained(ExitBlocks, CmpLoop.LatchBrContinueBB) && + "Unexpected exit edges."); + + LLVM_DEBUG(dbgs() << "Recognized loop control-flow.\n"); + + LLVM_DEBUG(dbgs() << "Performing side-effect analysis on the loop.\n"); + assert(CurLoop->isLCSSAForm(*DT) && "Should only get LCSSA-form loops here."); + // No loop instructions must be used outside of the loop. Since we are in + // LCSSA form, we only need to check successor block's PHI nodes's incoming + // values for incoming blocks that are the loop basic blocks. + for (const BasicBlock *ExitBB : ExitBlocks) { + for (const PHINode &PHI : ExitBB->phis()) { + for (const BasicBlock *LoopBB : + make_filter_range(PHI.blocks(), [this](BasicBlock *PredecessorBB) { + return CurLoop->contains(PredecessorBB); + })) { + const auto *I = + dyn_cast<Instruction>(PHI.getIncomingValueForBlock(LoopBB)); + if (I && CurLoop->contains(I)) { + LLVM_DEBUG(dbgs() + << "Loop contains instruction " << *I + << " which is used outside of the loop in basic block " + << ExitBB->getName() << " in phi node " << PHI << "\n"); + return false; + } + } + } + } + // Similarly, the loop should not have any other observable side-effects + // other than the final comparison result. + for (BasicBlock *LoopBB : CurLoop->blocks()) { + for (Instruction &I : *LoopBB) { + if (isa<DbgInfoIntrinsic>(I)) // Ignore dbginfo. + continue; // FIXME: anything else? lifetime info? 
+ if ((I.mayHaveSideEffects() || I.isAtomic() || I.isFenceLike()) && + &I != CmpOfLoads.LoadA && &I != CmpOfLoads.LoadB) { + LLVM_DEBUG( + dbgs() << "Loop contains instruction with potential side-effects: " + << I << "\n"); + return false; + } + } + } + LLVM_DEBUG(dbgs() << "No loop instructions deemed to have side-effects.\n"); + return true; +} + +bool LoopIdiomRecognize::recognizeBCmpLoopSCEV(uint64_t BCmpTyBytes, + CmpOfLoads &CmpOfLoads, + const SCEV *&SrcA, + const SCEV *&SrcB, + const SCEV *&Iterations) const { + // Try to compute SCEV of the loads, for this loop's scope. + const auto *ScevForSrcA = dyn_cast<SCEVAddRecExpr>( + SE->getSCEVAtScope(CmpOfLoads.LoadSrcA, CurLoop)); + const auto *ScevForSrcB = dyn_cast<SCEVAddRecExpr>( + SE->getSCEVAtScope(CmpOfLoads.LoadSrcB, CurLoop)); + if (!ScevForSrcA || !ScevForSrcB) { + LLVM_DEBUG(dbgs() << "Failed to get SCEV expressions for load sources.\n"); + return false; + } + + LLVM_DEBUG(dbgs() << "Got SCEV expressions (at loop scope) for loads:\n\t" + << *ScevForSrcA << "\n\t" << *ScevForSrcB << "\n"); + + // Loads must have folloving SCEV exprs: {%ptr,+,BCmpTyBytes}<%LoopHeaderBB> + const SCEV *RecStepForA = ScevForSrcA->getStepRecurrence(*SE); + const SCEV *RecStepForB = ScevForSrcB->getStepRecurrence(*SE); + if (!ScevForSrcA->isAffine() || !ScevForSrcB->isAffine() || + ScevForSrcA->getLoop() != CurLoop || ScevForSrcB->getLoop() != CurLoop || + RecStepForA != RecStepForB || !isa<SCEVConstant>(RecStepForA) || + cast<SCEVConstant>(RecStepForA)->getAPInt() != BCmpTyBytes) { + LLVM_DEBUG(dbgs() << "Unsupported SCEV expressions for loads. Only support " + "affine SCEV expressions originating in the loop we " + "are analysing with identical constant positive step, " + "equal to the count of bytes compared. Got:\n\t" + << *RecStepForA << "\n\t" << *RecStepForB << "\n"); + return false; + // FIXME: can support BCmpTyBytes > Step. + // But will need to account for the extra bytes compared at the end. 
+ } + + SrcA = ScevForSrcA->getStart(); + SrcB = ScevForSrcB->getStart(); + LLVM_DEBUG(dbgs() << "Got SCEV expressions for load sources:\n\t" << *SrcA + << "\n\t" << *SrcB << "\n"); + + // The load sources must be loop-invants that dominate the loop header. + if (SrcA == SE->getCouldNotCompute() || SrcB == SE->getCouldNotCompute() || + !SE->isAvailableAtLoopEntry(SrcA, CurLoop) || + !SE->isAvailableAtLoopEntry(SrcB, CurLoop)) { + LLVM_DEBUG(dbgs() << "Unsupported SCEV expressions for loads, unavaliable " + "prior to loop header.\n"); + return false; + } + + LLVM_DEBUG(dbgs() << "SCEV expressions for loads are acceptable.\n"); + + // bcmp / memcmp take length argument as size_t, so let's conservatively + // assume that the iteration count should be not wider than that. + Type *CmpFuncSizeTy = DL->getIntPtrType(SE->getContext()); + + // For how many iterations is loop guaranteed not to exit via LoopLatch? + // This is one less than the maximal number of comparisons,and is: n + -1 + const SCEV *LoopExitCount = + SE->getExitCount(CurLoop, CurLoop->getLoopLatch()); + LLVM_DEBUG(dbgs() << "Got SCEV expression for loop latch exit count: " + << *LoopExitCount << "\n"); + // Exit count, similarly, must be loop-invant that dominates the loop header. + if (LoopExitCount == SE->getCouldNotCompute() || + !LoopExitCount->getType()->isIntOrPtrTy() || + LoopExitCount->getType()->getScalarSizeInBits() > + CmpFuncSizeTy->getScalarSizeInBits() || + !SE->isAvailableAtLoopEntry(LoopExitCount, CurLoop)) { + LLVM_DEBUG(dbgs() << "Unsupported SCEV expression for loop latch exit.\n"); + return false; + } + + // LoopExitCount is always one less than the actual count of iterations. 
+ // Do this before cast, else we will be stuck with 1 + zext(-1 + n) + Iterations = SE->getAddExpr( + LoopExitCount, SE->getOne(LoopExitCount->getType()), SCEV::FlagNUW); + assert(Iterations != SE->getCouldNotCompute() && + "Shouldn't fail to increment by one."); + + LLVM_DEBUG(dbgs() << "Computed iteration count: " << *Iterations << "\n"); + return true; +} + +/// Return true iff the bcmp idiom is detected in the loop. +/// +/// Additionally: +/// 1) \p BCmpInst is set to the root byte-comparison instruction. +/// 2) \p LatchCmpInst is set to the comparison that controls the latch. +/// 3) \p LoadA is set to the first LoadInst. +/// 4) \p LoadB is set to the second LoadInst. +/// 5) \p SrcA is set to the first source location that is being compared. +/// 6) \p SrcB is set to the second source location that is being compared. +/// 7) \p NBytes is set to the number of bytes to compare. +bool LoopIdiomRecognize::detectBCmpIdiom(ICmpInst *&BCmpInst, + CmpInst *&LatchCmpInst, + LoadInst *&LoadA, LoadInst *&LoadB, + const SCEV *&SrcA, const SCEV *&SrcB, + const SCEV *&NBytes) const { + LLVM_DEBUG(dbgs() << "Recognizing bcmp idiom\n"); + + // Give up if the loop is not in normal form, or has more than 2 blocks. + if (!CurLoop->isLoopSimplifyForm() || CurLoop->getNumBlocks() > 2) { + LLVM_DEBUG(dbgs() << "Basic loop structure unrecognized.\n"); + return false; + } + LLVM_DEBUG(dbgs() << "Recognized basic loop structure.\n"); + + CmpLoopStructure CmpLoop; + if (!matchBCmpLoopStructure(CmpLoop)) + return false; + + CmpOfLoads CmpOfLoads; + if (!matchBCmpOfLoads(CmpLoop.BCmpValue, CmpOfLoads)) + return false; + + if (!recognizeBCmpLoopControlFlow(CmpOfLoads, CmpLoop)) + return false; + + BCmpInst = cast<ICmpInst>(CmpLoop.BCmpValue); // FIXME: is there no + LatchCmpInst = cast<CmpInst>(CmpLoop.LatchCmpValue); // way to combine + LoadA = cast<LoadInst>(CmpOfLoads.LoadA); // these cast with + LoadB = cast<LoadInst>(CmpOfLoads.LoadB); // m_Value() matcher? 
+ + Type *BCmpValTy = BCmpInst->getOperand(0)->getType(); + LLVMContext &Context = BCmpValTy->getContext(); + uint64_t BCmpTyBits = DL->getTypeSizeInBits(BCmpValTy); + static constexpr uint64_t ByteTyBits = 8; + + LLVM_DEBUG(dbgs() << "Got comparison between values of type " << *BCmpValTy + << " of size " << BCmpTyBits + << " bits (while byte = " << ByteTyBits << " bits).\n"); + // bcmp()/memcmp() minimal unit of work is a byte. Therefore we must check + // that we are dealing with a multiple of a byte here. + if (BCmpTyBits % ByteTyBits != 0) { + LLVM_DEBUG(dbgs() << "Value size is not a multiple of byte.\n"); + return false; + // FIXME: could still be done under a run-time check that the total bit + // count is a multiple of a byte i guess? Or handle remainder separately? + } + + // Each comparison is done on this many bytes. + uint64_t BCmpTyBytes = BCmpTyBits / ByteTyBits; + LLVM_DEBUG(dbgs() << "Size is exactly " << BCmpTyBytes + << " bytes, eligible for bcmp conversion.\n"); + + const SCEV *Iterations; + if (!recognizeBCmpLoopSCEV(BCmpTyBytes, CmpOfLoads, SrcA, SrcB, Iterations)) + return false; + + // bcmp / memcmp take length argument as size_t, do promotion now. + Type *CmpFuncSizeTy = DL->getIntPtrType(Context); + Iterations = SE->getNoopOrZeroExtend(Iterations, CmpFuncSizeTy); + assert(Iterations != SE->getCouldNotCompute() && "Promotion failed."); + // Note that it didn't do ptrtoint cast, we will need to do it manually. + + // We will be comparing *bytes*, not BCmpTy, we need to recalculate size. + // It's a multiplication, and it *could* overflow. But for it to overflow + // we'd want to compare more bytes than could be represented by size_t, But + // allocation functions also take size_t. So how'd you produce such buffer? + // FIXME: we likely need to actually check that we know this won't overflow, + // via llvm::computeOverflowForUnsignedMul(). 
+ NBytes = SE->getMulExpr( + Iterations, SE->getConstant(CmpFuncSizeTy, BCmpTyBytes), SCEV::FlagNUW); + assert(NBytes != SE->getCouldNotCompute() && + "Shouldn't fail to increment by one."); + + LLVM_DEBUG(dbgs() << "Computed total byte count: " << *NBytes << "\n"); + + if (LoadA->getPointerAddressSpace() != LoadB->getPointerAddressSpace() || + LoadA->getPointerAddressSpace() != 0 || !LoadA->isSimple() || + !LoadB->isSimple()) { + StringLiteral L("Unsupported loads in idiom - only support identical, " + "simple loads from address space 0.\n"); + LLVM_DEBUG(dbgs() << L); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "BCmpIdiomUnsupportedLoads", + BCmpInst->getDebugLoc(), + CurLoop->getHeader()) + << L; + }); + return false; // FIXME: support non-simple loads. + } + + LLVM_DEBUG(dbgs() << "Recognized bcmp idiom\n"); + ORE.emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "RecognizedBCmpIdiom", + CurLoop->getStartLoc(), + CurLoop->getHeader()) + << "Loop recognized as a bcmp idiom"; + }); + + return true; +} + +BasicBlock * +LoopIdiomRecognize::transformBCmpControlFlow(ICmpInst *ComparedEqual) { + LLVM_DEBUG(dbgs() << "Transforming control-flow.\n"); + SmallVector<DominatorTree::UpdateType, 8> DTUpdates; + + BasicBlock *PreheaderBB = CurLoop->getLoopPreheader(); + BasicBlock *HeaderBB = CurLoop->getHeader(); + BasicBlock *LoopLatchBB = CurLoop->getLoopLatch(); + SmallString<32> LoopName = CurLoop->getName(); + Function *Func = PreheaderBB->getParent(); + LLVMContext &Context = Func->getContext(); + + // Before doing anything, drop SCEV info. + SE->forgetLoop(CurLoop); + + // Here we start with: (0/6) + // PreheaderBB: <preheader> ; preds = ??? 
+ // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // br label %LoopHeaderBB + // LoopHeaderBB: <header,exiting> ; preds = %PreheaderBB,%LoopLatchBB + // <...> + // br i1 %<...>, label %LoopLatchBB, label %Successor0BB + // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB + // <...> + // br i1 %<...>, label %Successor1BB, label %LoopHeaderBB + // Successor0BB: <exit> ; preds = %LoopHeaderBB + // %S0PHI = phi <...> [ <...>, %LoopHeaderBB ] + // <...> + // Successor1BB: <exit> ; preds = %LoopLatchBB + // %S1PHI = phi <...> [ <...>, %LoopLatchBB ] + // <...> + // + // Successor0 and Successor1 may or may not be the same basic block. + + // Decouple the edge between loop preheader basic block and loop header basic + // block. Thus the loop has become unreachable. + assert(cast<BranchInst>(PreheaderBB->getTerminator())->isUnconditional() && + PreheaderBB->getTerminator()->getSuccessor(0) == HeaderBB && + "Preheader bb must end with an unconditional branch to header bb."); + PreheaderBB->getTerminator()->eraseFromParent(); + DTUpdates.push_back({DominatorTree::Delete, PreheaderBB, HeaderBB}); + + // Create a new preheader basic block before loop header basic block. + auto *PhonyPreheaderBB = BasicBlock::Create( + Context, LoopName + ".phonypreheaderbb", Func, HeaderBB); + // And insert an unconditional branch from phony preheader basic block to + // loop header basic block. + IRBuilder<>(PhonyPreheaderBB).CreateBr(HeaderBB); + DTUpdates.push_back({DominatorTree::Insert, PhonyPreheaderBB, HeaderBB}); + + // Create a *single* new empty block that we will substitute as a + // successor basic block for the loop's exits. This one is temporary. + // Much like phony preheader basic block, it is not connected. 
+ auto *PhonySuccessorBB = + BasicBlock::Create(Context, LoopName + ".phonysuccessorbb", Func, + LoopLatchBB->getNextNode()); + // That block must have *some* non-PHI instruction, or else deleteDeadLoop() + // will mess up cleanup of dbginfo, and verifier will complain. + IRBuilder<>(PhonySuccessorBB).CreateUnreachable(); + + // Create two new empty blocks that we will use to preserve the original + // loop exit control-flow, and preserve the incoming values in the PHI nodes + // in loop's successor exit blocks. These will live one. + auto *ComparedUnequalBB = + BasicBlock::Create(Context, ComparedEqual->getName() + ".unequalbb", Func, + PhonySuccessorBB->getNextNode()); + auto *ComparedEqualBB = + BasicBlock::Create(Context, ComparedEqual->getName() + ".equalbb", Func, + PhonySuccessorBB->getNextNode()); + + // By now we have: (1/6) + // PreheaderBB: ; preds = ??? + // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // [no terminator instruction!] + // PhonyPreheaderBB: <preheader> ; No preds, UNREACHABLE! + // br label %LoopHeaderBB + // LoopHeaderBB: <header,exiting> ; preds = %PhonyPreheaderBB, %LoopLatchBB + // <...> + // br i1 %<...>, label %LoopLatchBB, label %Successor0BB + // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB + // <...> + // br i1 %<...>, label %Successor1BB, label %LoopHeaderBB + // PhonySuccessorBB: ; No preds, UNREACHABLE! + // unreachable + // EqualBB: ; No preds, UNREACHABLE! + // [no terminator instruction!] + // UnequalBB: ; No preds, UNREACHABLE! + // [no terminator instruction!] + // Successor0BB: <exit> ; preds = %LoopHeaderBB + // %S0PHI = phi <...> [ <...>, %LoopHeaderBB ] + // <...> + // Successor1BB: <exit> ; preds = %LoopLatchBB + // %S1PHI = phi <...> [ <...>, %LoopLatchBB ] + // <...> + + // What is the mapping/replacement basic block for exiting out of the loop + // from either of old's loop basic blocks? 
+ auto GetReplacementBB = [this, ComparedEqualBB, + ComparedUnequalBB](const BasicBlock *OldBB) { + assert(CurLoop->contains(OldBB) && "Only for loop's basic blocks."); + if (OldBB == CurLoop->getLoopLatch()) // "all elements compared equal". + return ComparedEqualBB; + if (OldBB == CurLoop->getHeader()) // "element compared unequal". + return ComparedUnequalBB; + llvm_unreachable("Only had two basic blocks in loop."); + }; + + // What are the exits out of this loop? + SmallVector<Loop::Edge, 2> LoopExitEdges; + CurLoop->getExitEdges(LoopExitEdges); + assert(LoopExitEdges.size() == 2 && "Should have only to two exit edges."); + + // Populate new basic blocks, update the exiting control-flow, PHI nodes. + for (const Loop::Edge &Edge : LoopExitEdges) { + auto *OldLoopBB = const_cast<BasicBlock *>(Edge.first); + auto *SuccessorBB = const_cast<BasicBlock *>(Edge.second); + assert(CurLoop->contains(OldLoopBB) && !CurLoop->contains(SuccessorBB) && + "Unexpected edge."); + + // If we would exit the loop from this loop's basic block, + // what semantically would that mean? Did comparison succeed or fail? + BasicBlock *NewBB = GetReplacementBB(OldLoopBB); + assert(NewBB->empty() && "Should not get same new basic block here twice."); + IRBuilder<> Builder(NewBB); + Builder.SetCurrentDebugLocation(OldLoopBB->getTerminator()->getDebugLoc()); + Builder.CreateBr(SuccessorBB); + DTUpdates.push_back({DominatorTree::Insert, NewBB, SuccessorBB}); + // Also, be *REALLY* careful with PHI nodes in successor basic block, + // update them to recieve the same input value, but not from current loop's + // basic block, but from new basic block instead. + SuccessorBB->replacePhiUsesWith(OldLoopBB, NewBB); + // Also, change loop control-flow. This loop's basic block shall no longer + // exit from the loop to it's original successor basic block, but to our new + // phony successor basic block. Note that new successor will be unique exit. 
+ OldLoopBB->getTerminator()->replaceSuccessorWith(SuccessorBB, + PhonySuccessorBB); + DTUpdates.push_back({DominatorTree::Delete, OldLoopBB, SuccessorBB}); + DTUpdates.push_back({DominatorTree::Insert, OldLoopBB, PhonySuccessorBB}); + } + + // Inform DomTree about edge changes. Note that LoopInfo is still out-of-date. + assert(DTUpdates.size() == 8 && "Update count prediction failed."); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + DTU.applyUpdates(DTUpdates); + DTUpdates.clear(); + + // By now we have: (2/6) + // PreheaderBB: ; preds = ??? + // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // [no terminator instruction!] + // PhonyPreheaderBB: <preheader> ; No preds, UNREACHABLE! + // br label %LoopHeaderBB + // LoopHeaderBB: <header,exiting> ; preds = %PhonyPreheaderBB, %LoopLatchBB + // <...> + // br i1 %<...>, label %LoopLatchBB, label %PhonySuccessorBB + // LoopLatchBB: <latch,exiting> ; preds = %LoopHeaderBB + // <...> + // br i1 %<...>, label %PhonySuccessorBB, label %LoopHeaderBB + // PhonySuccessorBB: <uniq. exit> ; preds = %LoopHeaderBB, %LoopLatchBB + // unreachable + // EqualBB: ; No preds, UNREACHABLE! + // br label %Successor1BB + // UnequalBB: ; No preds, UNREACHABLE! + // br label %Successor0BB + // Successor0BB: ; preds = %UnequalBB + // %S0PHI = phi <...> [ <...>, %UnequalBB ] + // <...> + // Successor1BB: ; preds = %EqualBB + // %S0PHI = phi <...> [ <...>, %EqualBB ] + // <...> + + // *Finally*, zap the original loop. Record it's parent loop though. + Loop *ParentLoop = CurLoop->getParentLoop(); + LLVM_DEBUG(dbgs() << "Deleting old loop.\n"); + LoopDeleter.markLoopAsDeleted(CurLoop); // Mark as deleted *BEFORE* deleting! + deleteDeadLoop(CurLoop, DT, SE, LI); // And actually delete the loop. + CurLoop = nullptr; + + // By now we have: (3/6) + // PreheaderBB: ; preds = ??? 
+ // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // [no terminator instruction!] + // PhonyPreheaderBB: ; No preds, UNREACHABLE! + // br label %PhonySuccessorBB + // PhonySuccessorBB: ; preds = %PhonyPreheaderBB + // unreachable + // EqualBB: ; No preds, UNREACHABLE! + // br label %Successor1BB + // UnequalBB: ; No preds, UNREACHABLE! + // br label %Successor0BB + // Successor0BB: ; preds = %UnequalBB + // %S0PHI = phi <...> [ <...>, %UnequalBB ] + // <...> + // Successor1BB: ; preds = %EqualBB + // %S0PHI = phi <...> [ <...>, %EqualBB ] + // <...> + + // Now, actually restore the CFG. + + // Insert an unconditional branch from an actual preheader basic block to + // phony preheader basic block. + IRBuilder<>(PreheaderBB).CreateBr(PhonyPreheaderBB); + DTUpdates.push_back({DominatorTree::Insert, PhonyPreheaderBB, HeaderBB}); + // Insert proper conditional branch from phony successor basic block to the + // "dispatch" basic blocks, which were used to preserve incoming values in + // original loop's successor basic blocks. + assert(isa<UnreachableInst>(PhonySuccessorBB->getTerminator()) && + "Yep, that's the one we created to keep deleteDeadLoop() happy."); + PhonySuccessorBB->getTerminator()->eraseFromParent(); + { + IRBuilder<> Builder(PhonySuccessorBB); + Builder.SetCurrentDebugLocation(ComparedEqual->getDebugLoc()); + Builder.CreateCondBr(ComparedEqual, ComparedEqualBB, ComparedUnequalBB); + } + DTUpdates.push_back( + {DominatorTree::Insert, PhonySuccessorBB, ComparedEqualBB}); + DTUpdates.push_back( + {DominatorTree::Insert, PhonySuccessorBB, ComparedUnequalBB}); + + BasicBlock *DispatchBB = PhonySuccessorBB; + DispatchBB->setName(LoopName + ".bcmpdispatchbb"); + + assert(DTUpdates.size() == 3 && "Update count prediction failed."); + DTU.applyUpdates(DTUpdates); + DTUpdates.clear(); + + // By now we have: (4/6) + // PreheaderBB: ; preds = ??? 
+ // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // br label %PhonyPreheaderBB + // PhonyPreheaderBB: ; preds = %PreheaderBB + // br label %DispatchBB + // DispatchBB: ; preds = %PhonyPreheaderBB + // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB + // EqualBB: ; preds = %DispatchBB + // br label %Successor1BB + // UnequalBB: ; preds = %DispatchBB + // br label %Successor0BB + // Successor0BB: ; preds = %UnequalBB + // %S0PHI = phi <...> [ <...>, %UnequalBB ] + // <...> + // Successor1BB: ; preds = %EqualBB + // %S0PHI = phi <...> [ <...>, %EqualBB ] + // <...> + + // The basic CFG has been restored! Now let's merge redundant basic blocks. + + // Merge phony successor basic block into it's only predecessor, + // phony preheader basic block. It is fully pointlessly redundant. + MergeBasicBlockIntoOnlyPred(DispatchBB, &DTU); + + // By now we have: (5/6) + // PreheaderBB: ; preds = ??? + // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // br label %DispatchBB + // DispatchBB: ; preds = %PreheaderBB + // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB + // EqualBB: ; preds = %DispatchBB + // br label %Successor1BB + // UnequalBB: ; preds = %DispatchBB + // br label %Successor0BB + // Successor0BB: ; preds = %UnequalBB + // %S0PHI = phi <...> [ <...>, %UnequalBB ] + // <...> + // Successor1BB: ; preds = %EqualBB + // %S0PHI = phi <...> [ <...>, %EqualBB ] + // <...> + + // Was this loop nested? + if (!ParentLoop) { + // If the loop was *NOT* nested, then let's also merge phony successor + // basic block into it's only predecessor, preheader basic block. + // Also, here we need to update LoopInfo. + LI->removeBlock(PreheaderBB); + MergeBasicBlockIntoOnlyPred(DispatchBB, &DTU); + + // By now we have: (6/6) + // DispatchBB: ; preds = ??? 
+ // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB + // EqualBB: ; preds = %DispatchBB + // br label %Successor1BB + // UnequalBB: ; preds = %DispatchBB + // br label %Successor0BB + // Successor0BB: ; preds = %UnequalBB + // %S0PHI = phi <...> [ <...>, %UnequalBB ] + // <...> + // Successor1BB: ; preds = %EqualBB + // %S0PHI = phi <...> [ <...>, %EqualBB ] + // <...> + + return DispatchBB; + } + + // Otherwise, we need to "preserve" the LoopSimplify form of the deleted loop. + // To achieve that, we shall keep the preheader basic block (mainly so that + // the loop header block will be guaranteed to have a predecessor outside of + // the loop), and create a phony loop with all these new three basic blocks. + Loop *PhonyLoop = LI->AllocateLoop(); + ParentLoop->addChildLoop(PhonyLoop); + PhonyLoop->addBasicBlockToLoop(DispatchBB, *LI); + PhonyLoop->addBasicBlockToLoop(ComparedEqualBB, *LI); + PhonyLoop->addBasicBlockToLoop(ComparedUnequalBB, *LI); + + // But we only have a preheader basic block, a header basic block block and + // two exiting basic blocks. For a proper loop we also need a backedge from + // non-header basic block to header bb. + // Let's just add a never-taken branch from both of the exiting basic blocks. + for (BasicBlock *BB : {ComparedEqualBB, ComparedUnequalBB}) { + BranchInst *OldTerminator = cast<BranchInst>(BB->getTerminator()); + assert(OldTerminator->isUnconditional() && "That's the one we created."); + BasicBlock *SuccessorBB = OldTerminator->getSuccessor(0); + + IRBuilder<> Builder(OldTerminator); + Builder.SetCurrentDebugLocation(OldTerminator->getDebugLoc()); + Builder.CreateCondBr(ConstantInt::getTrue(Context), SuccessorBB, + DispatchBB); + OldTerminator->eraseFromParent(); + // Yes, the backedge will never be taken. The control-flow is redundant. 
+ // If it can be simplified further, other passes will take care. + DTUpdates.push_back({DominatorTree::Delete, BB, SuccessorBB}); + DTUpdates.push_back({DominatorTree::Insert, BB, SuccessorBB}); + DTUpdates.push_back({DominatorTree::Insert, BB, DispatchBB}); + } + assert(DTUpdates.size() == 6 && "Update count prediction failed."); + DTU.applyUpdates(DTUpdates); + DTUpdates.clear(); + + // By now we have: (6/6) + // PreheaderBB: <preheader> ; preds = ??? + // <...> + // %memcmp = call i32 @memcmp(i8* %LoadSrcA, i8* %LoadSrcB, i64 %Nbytes) + // %ComparedEqual = icmp eq <...> %memcmp, 0 + // br label %BCmpDispatchBB + // BCmpDispatchBB: <header> ; preds = %PreheaderBB + // br i1 %ComparedEqual, label %EqualBB, label %UnequalBB + // EqualBB: <latch,exiting> ; preds = %BCmpDispatchBB + // br i1 %true, label %Successor1BB, label %BCmpDispatchBB + // UnequalBB: <latch,exiting> ; preds = %BCmpDispatchBB + // br i1 %true, label %Successor0BB, label %BCmpDispatchBB + // Successor0BB: ; preds = %UnequalBB + // %S0PHI = phi <...> [ <...>, %UnequalBB ] + // <...> + // Successor1BB: ; preds = %EqualBB + // %S0PHI = phi <...> [ <...>, %EqualBB ] + // <...> + + // Finally fully DONE! + return DispatchBB; +} + +void LoopIdiomRecognize::transformLoopToBCmp(ICmpInst *BCmpInst, + CmpInst *LatchCmpInst, + LoadInst *LoadA, LoadInst *LoadB, + const SCEV *SrcA, const SCEV *SrcB, + const SCEV *NBytes) { + // We will be inserting before the terminator instruction of preheader block. 
+ IRBuilder<> Builder(CurLoop->getLoopPreheader()->getTerminator()); + + LLVM_DEBUG(dbgs() << "Transforming bcmp loop idiom into a call.\n"); + LLVM_DEBUG(dbgs() << "Emitting new instructions.\n"); + + // Expand the SCEV expressions for both sources to compare, and produce value + // for the byte len (beware of Iterations potentially being a pointer, and + // account for element size being BCmpTyBytes bytes, which may be not 1 byte) + Value *PtrA, *PtrB, *Len; + { + SCEVExpander SExp(*SE, *DL, "LoopToBCmp"); + SExp.setInsertPoint(&*Builder.GetInsertPoint()); + + auto HandlePtr = [&SExp](LoadInst *Load, const SCEV *Src) { + SExp.SetCurrentDebugLocation(DebugLoc()); + // If the pointer operand of original load had dbgloc - use it. + if (const auto *I = dyn_cast<Instruction>(Load->getPointerOperand())) + SExp.SetCurrentDebugLocation(I->getDebugLoc()); + return SExp.expandCodeFor(Src); + }; + PtrA = HandlePtr(LoadA, SrcA); + PtrB = HandlePtr(LoadB, SrcB); + + // For len calculation let's use dbgloc for the loop's latch condition. + Builder.SetCurrentDebugLocation(LatchCmpInst->getDebugLoc()); + SExp.SetCurrentDebugLocation(LatchCmpInst->getDebugLoc()); + Len = SExp.expandCodeFor(NBytes); + + Type *CmpFuncSizeTy = DL->getIntPtrType(Builder.getContext()); + assert(SE->getTypeSizeInBits(Len->getType()) == + DL->getTypeSizeInBits(CmpFuncSizeTy) && + "Len should already have the correct size."); + + // Make sure that iteration count is a number, insert ptrtoint cast if not. + if (Len->getType()->isPointerTy()) + Len = Builder.CreatePtrToInt(Len, CmpFuncSizeTy); + assert(Len->getType() == CmpFuncSizeTy && "Should have correct type now."); + + Len->setName(Len->getName() + ".bytecount"); + + // There is no legality check needed. We want to compare that the memory + // regions [PtrA, PtrA+Len) and [PtrB, PtrB+Len) are fully identical, equal. + // For them to be fully equal, they must match bit-by-bit. 
And likewise, + // for them to *NOT* be fully equal, they have to differ just by one bit. + // The step of comparison (bits compared at once) simply does not matter. + } + + // For the rest of new instructions, dbgloc should point at the value cmp. + Builder.SetCurrentDebugLocation(BCmpInst->getDebugLoc()); + + // Emit the comparison itself. + auto *CmpCall = + cast<CallInst>(HasBCmp ? emitBCmp(PtrA, PtrB, Len, Builder, *DL, TLI) + : emitMemCmp(PtrA, PtrB, Len, Builder, *DL, TLI)); + // FIXME: add {B,Mem}CmpInst with MemoryCompareInst + // (based on MemIntrinsicBase) as base? + // FIXME: propagate metadata from loads? (alignments, AS, TBAA, ...) + + // {b,mem}cmp returned 0 if they were equal, or non-zero if not equal. + auto *ComparedEqual = cast<ICmpInst>(Builder.CreateICmpEQ( + CmpCall, ConstantInt::get(CmpCall->getType(), 0), + PtrA->getName() + ".vs." + PtrB->getName() + ".eqcmp")); + + BasicBlock *BB = transformBCmpControlFlow(ComparedEqual); + Builder.ClearInsertionPoint(); + + // We're done. + LLVM_DEBUG(dbgs() << "Transformed loop bcmp idiom into a call.\n"); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "TransformedBCmpIdiomToCall", + CmpCall->getDebugLoc(), BB) + << "Transformed bcmp idiom into a call to " + << ore::NV("NewFunction", CmpCall->getCalledFunction()) + << "() function"; + }); + ++NumBCmp; +} + +/// Recognizes a bcmp idiom in a non-countable loop. +/// +/// If detected, transforms the relevant code to issue the bcmp (or memcmp) +/// intrinsic function call, and returns true; otherwise, returns false. 
+bool LoopIdiomRecognize::recognizeBCmp() { + if (!HasMemCmp && !HasBCmp) + return false; + + ICmpInst *BCmpInst; + CmpInst *LatchCmpInst; + LoadInst *LoadA, *LoadB; + const SCEV *SrcA, *SrcB, *NBytes; + if (!detectBCmpIdiom(BCmpInst, LatchCmpInst, LoadA, LoadB, SrcA, SrcB, + NBytes)) { + LLVM_DEBUG(dbgs() << "bcmp idiom recognition failed.\n"); + return false; + } + + transformLoopToBCmp(BCmpInst, LatchCmpInst, LoadA, LoadB, SrcA, SrcB, NBytes); + return true; +} diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index 31191b52895c..368b9d4e8df1 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -192,7 +192,8 @@ public: getAnalysis<AssumptionCacheTracker>().getAssumptionCache( *L->getHeader()->getParent()); const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( + *L->getHeader()->getParent()); MemorySSA *MSSA = nullptr; Optional<MemorySSAUpdater> MSSAU; if (EnableMSSALoopDependency) { @@ -233,7 +234,7 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, auto PA = getLoopPassPreservedAnalyses(); PA.preserveSet<CFGAnalyses>(); - if (EnableMSSALoopDependency) + if (AR.MSSA) PA.preserve<MemorySSAAnalysis>(); return PA; } diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp index 9a42365adc1b..1af4b21b432e 100644 --- a/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/lib/Transforms/Scalar/LoopInterchange.cpp @@ -410,8 +410,6 @@ public: void removeChildLoop(Loop *OuterLoop, Loop *InnerLoop); private: - void splitInnerLoopLatch(Instruction *); - void splitInnerLoopHeader(); bool adjustLoopLinks(); void adjustLoopPreheaders(); bool adjustLoopBranches(); @@ -1226,7 +1224,7 @@ bool LoopInterchangeTransform::transform() { if (InnerLoop->getSubLoops().empty()) { BasicBlock *InnerLoopPreHeader = 
InnerLoop->getLoopPreheader(); - LLVM_DEBUG(dbgs() << "Calling Split Inner Loop\n"); + LLVM_DEBUG(dbgs() << "Splitting the inner loop latch\n"); PHINode *InductionPHI = getInductionVariable(InnerLoop, SE); if (!InductionPHI) { LLVM_DEBUG(dbgs() << "Failed to find the point to split loop latch \n"); @@ -1242,11 +1240,55 @@ bool LoopInterchangeTransform::transform() { if (&InductionPHI->getParent()->front() != InductionPHI) InductionPHI->moveBefore(&InductionPHI->getParent()->front()); - // Split at the place were the induction variable is - // incremented/decremented. - // TODO: This splitting logic may not work always. Fix this. - splitInnerLoopLatch(InnerIndexVar); - LLVM_DEBUG(dbgs() << "splitInnerLoopLatch done\n"); + // Create a new latch block for the inner loop. We split at the + // current latch's terminator and then move the condition and all + // operands that are not either loop-invariant or the induction PHI into the + // new latch block. + BasicBlock *NewLatch = + SplitBlock(InnerLoop->getLoopLatch(), + InnerLoop->getLoopLatch()->getTerminator(), DT, LI); + + SmallSetVector<Instruction *, 4> WorkList; + unsigned i = 0; + auto MoveInstructions = [&i, &WorkList, this, InductionPHI, NewLatch]() { + for (; i < WorkList.size(); i++) { + // Duplicate instruction and move it the new latch. Update uses that + // have been moved. + Instruction *NewI = WorkList[i]->clone(); + NewI->insertBefore(NewLatch->getFirstNonPHI()); + assert(!NewI->mayHaveSideEffects() && + "Moving instructions with side-effects may change behavior of " + "the loop nest!"); + for (auto UI = WorkList[i]->use_begin(), UE = WorkList[i]->use_end(); + UI != UE;) { + Use &U = *UI++; + Instruction *UserI = cast<Instruction>(U.getUser()); + if (!InnerLoop->contains(UserI->getParent()) || + UserI->getParent() == NewLatch || UserI == InductionPHI) + U.set(NewI); + } + // Add operands of moved instruction to the worklist, except if they are + // outside the inner loop or are the induction PHI. 
+ for (Value *Op : WorkList[i]->operands()) { + Instruction *OpI = dyn_cast<Instruction>(Op); + if (!OpI || + this->LI->getLoopFor(OpI->getParent()) != this->InnerLoop || + OpI == InductionPHI) + continue; + WorkList.insert(OpI); + } + } + }; + + // FIXME: Should we interchange when we have a constant condition? + Instruction *CondI = dyn_cast<Instruction>( + cast<BranchInst>(InnerLoop->getLoopLatch()->getTerminator()) + ->getCondition()); + if (CondI) + WorkList.insert(CondI); + MoveInstructions(); + WorkList.insert(cast<Instruction>(InnerIndexVar)); + MoveInstructions(); // Splits the inner loops phi nodes out into a separate basic block. BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); @@ -1263,10 +1305,6 @@ bool LoopInterchangeTransform::transform() { return true; } -void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) { - SplitBlock(InnerLoop->getLoopLatch(), Inc, DT, LI); -} - /// \brief Move all instructions except the terminator from FromBB right before /// InsertBefore static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp index 2b3d5e0ce9b7..e8dc879a184b 100644 --- a/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -435,7 +435,8 @@ public: PH->getTerminator()); Value *Initial = new LoadInst( Cand.Load->getType(), InitialPtr, "load_initial", - /* isVolatile */ false, Cand.Load->getAlignment(), PH->getTerminator()); + /* isVolatile */ false, MaybeAlign(Cand.Load->getAlignment()), + PH->getTerminator()); PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", &L->getHeader()->front()); diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp index 507a1e251ca6..885c0e8f4b8b 100644 --- a/lib/Transforms/Scalar/LoopPredication.cpp +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -543,7 +543,7 @@ bool 
LoopPredication::isLoopInvariantValue(const SCEV* S) { if (const auto *LI = dyn_cast<LoadInst>(U->getValue())) if (LI->isUnordered() && L->hasLoopInvariantOperands(LI)) if (AA->pointsToConstantMemory(LI->getOperand(0)) || - LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr) + LI->hasMetadata(LLVMContext::MD_invariant_load)) return true; return false; } diff --git a/lib/Transforms/Scalar/LoopRerollPass.cpp b/lib/Transforms/Scalar/LoopRerollPass.cpp index 166b57f20b43..96e2c2a3ac6b 100644 --- a/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -1644,7 +1644,8 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( + *L->getHeader()->getParent()); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index e009947690af..94517996df39 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -55,7 +55,7 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, AR.MSSA->verifyMemorySSA(); auto PA = getLoopPassPreservedAnalyses(); - if (EnableMSSALoopDependency) + if (AR.MSSA) PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -94,17 +94,15 @@ public: auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - auto *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; - auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); - auto *SE = SEWP ? &SEWP->getSE() : nullptr; + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); Optional<MemorySSAUpdater> MSSAU; if (EnableMSSALoopDependency) { MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); MSSAU = MemorySSAUpdater(MSSA); } - return LoopRotation(L, LI, TTI, AC, DT, SE, + return LoopRotation(L, LI, TTI, AC, &DT, &SE, MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, SQ, false, MaxHeaderSize, false); } diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 046f4c8af492..299f3fc5fb19 100644 --- a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -690,7 +690,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &LPMU) { Optional<MemorySSAUpdater> MSSAU; - if (EnableMSSALoopDependency && AR.MSSA) + if (AR.MSSA) MSSAU = MemorySSAUpdater(AR.MSSA); bool DeleteCurrentLoop = false; if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE, @@ -702,7 +702,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, LPMU.markLoopAsDeleted(L, "loop-simplifycfg"); auto PA = getLoopPassPreservedAnalyses(); - if (EnableMSSALoopDependency) + if (AR.MSSA) PA.preserve<MemorySSAAnalysis>(); return PA; } diff --git a/lib/Transforms/Scalar/LoopSink.cpp b/lib/Transforms/Scalar/LoopSink.cpp index 975452e13f09..65e0dee0225a 100644 --- a/lib/Transforms/Scalar/LoopSink.cpp +++ b/lib/Transforms/Scalar/LoopSink.cpp @@ -230,12 +230,9 @@ static bool sinkInstruction(Loop &L, Instruction &I, IC->setName(I.getName()); IC->insertBefore(&*N->getFirstInsertionPt()); // Replaces uses of I with IC in N - for (Value::use_iterator UI = I.use_begin(), UE = 
I.use_end(); UI != UE;) { - Use &U = *UI++; - auto *I = cast<Instruction>(U.getUser()); - if (I->getParent() == N) - U.set(IC); - } + I.replaceUsesWithIf(IC, [N](Use &U) { + return cast<Instruction>(U.getUser())->getParent() == N; + }); // Replaces uses of I with IC in blocks dominated by N replaceDominatedUsesWith(&I, IC, DT, N); LLVM_DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName() diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 59a387a186b8..7f119175c4a8 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -1386,7 +1386,9 @@ void Cost::RateFormula(const Formula &F, // Treat every new register that exceeds TTI.getNumberOfRegisters() - 1 as // additional instruction (at least fill). - unsigned TTIRegNum = TTI->getNumberOfRegisters(false) - 1; + // TODO: Need distinguish register class? + unsigned TTIRegNum = TTI->getNumberOfRegisters( + TTI->getRegisterClassForType(false, F.getType())) - 1; if (C.NumRegs > TTIRegNum) { // Cost already exceeded TTIRegNum, then only newly added register can add // new instructions. @@ -3165,6 +3167,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, LLVM_DEBUG(dbgs() << "Concealed chain head: " << *Head.UserInst << "\n"); return; } + assert(IVSrc && "Failed to find IV chain source"); LLVM_DEBUG(dbgs() << "Generate chain at: " << *IVSrc << "\n"); Type *IVTy = IVSrc->getType(); @@ -3265,12 +3268,12 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // requirements for both N and i at the same time. Limiting this code to // equality icmps is not a problem because all interesting loops use // equality icmps, thanks to IndVarSimplify. 
- if (ICmpInst *CI = dyn_cast<ICmpInst>(UserInst)) + if (ICmpInst *CI = dyn_cast<ICmpInst>(UserInst)) { + // If CI can be saved in some target, like replaced inside hardware loop + // in PowerPC, no need to generate initial formulae for it. + if (SaveCmp && CI == dyn_cast<ICmpInst>(ExitBranch->getCondition())) + continue; if (CI->isEquality()) { - // If CI can be saved in some target, like replaced inside hardware loop - // in PowerPC, no need to generate initial formulae for it. - if (SaveCmp && CI == dyn_cast<ICmpInst>(ExitBranch->getCondition())) - continue; // Swap the operands if needed to put the OperandValToReplace on the // left, for consistency. Value *NV = CI->getOperand(1); @@ -3298,6 +3301,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { Factors.insert(-(uint64_t)Factors[i]); Factors.insert(-1); } + } // Get or create an LSRUse. std::pair<size_t, int64_t> P = getUse(S, Kind, AccessTy); @@ -4834,6 +4838,7 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { } } } + assert(Best && "Failed to find best LSRUse candidate"); LLVM_DEBUG(dbgs() << "Narrowing the search space by assuming " << *Best << " will yield profitable reuse.\n"); @@ -5740,7 +5745,8 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { *L->getHeader()->getParent()); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( *L->getHeader()->getParent()); - auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( + *L->getHeader()->getParent()); return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, LibInfo); } diff --git a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 86891eb451bb..8d88be420314 100644 --- a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -166,7 +166,7 @@ static bool computeUnrollAndJamCount( bool UseUpperBound = false; bool 
ExplicitUnroll = computeUnrollCount( L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount, - OuterTripMultiple, OuterLoopSize, UP, UseUpperBound); + /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, UseUpperBound); if (ExplicitUnroll || UseUpperBound) { // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it // for the unroller instead. @@ -293,9 +293,9 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, if (Latch != Exit || SubLoopLatch != SubLoopExit) return LoopUnrollResult::Unmodified; - TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( - L, SE, TTI, nullptr, nullptr, OptLevel, - None, None, None, None, None, None); + TargetTransformInfo::UnrollingPreferences UP = + gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, OptLevel, None, + None, None, None, None, None, None, None); if (AllowUnrollAndJam.getNumOccurrences() > 0) UP.UnrollAndJam = AllowUnrollAndJam; if (UnrollAndJamThreshold.getNumOccurrences() > 0) diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index 2fa7436213dd..a6d4164c3645 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -178,7 +178,9 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, int OptLevel, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, - Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling) { + Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling, + Optional<bool> UserAllowProfileBasedPeeling, + Optional<unsigned> UserFullUnrollMaxCount) { TargetTransformInfo::UnrollingPreferences UP; // Set up the defaults @@ -202,6 +204,7 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.UpperBound = false; UP.AllowPeeling = true; UP.UnrollAndJam = false; + 
UP.PeelProfiledIterations = true; UP.UnrollAndJamInnerLoopThreshold = 60; // Override with any target specific settings @@ -257,6 +260,10 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.UpperBound = *UserUpperBound; if (UserAllowPeeling.hasValue()) UP.AllowPeeling = *UserAllowPeeling; + if (UserAllowProfileBasedPeeling.hasValue()) + UP.PeelProfiledIterations = *UserAllowProfileBasedPeeling; + if (UserFullUnrollMaxCount.hasValue()) + UP.FullUnrollMaxCount = *UserFullUnrollMaxCount; return UP; } @@ -730,7 +737,7 @@ bool llvm::computeUnrollCount( Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount, - unsigned &TripMultiple, unsigned LoopSize, + bool MaxOrZero, unsigned &TripMultiple, unsigned LoopSize, TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) { // Check for explicit Count. @@ -781,18 +788,34 @@ bool llvm::computeUnrollCount( // Also we need to check if we exceed FullUnrollMaxCount. // If using the upper bound to unroll, TripMultiple should be set to 1 because // we do not know when loop may exit. - // MaxTripCount and ExactTripCount cannot both be non zero since we only + + // We can unroll by the upper bound amount if it's generally allowed or if + // we know that the loop is executed either the upper bound or zero times. + // (MaxOrZero unrolling keeps only the first loop test, so the number of + // loop tests remains the same compared to the non-unrolled version, whereas + // the generic upper bound unrolling keeps all but the last loop test so the + // number of loop tests goes up which may end up being worse on targets with + // constrained branch predictor resources so is controlled by an option.) + // In addition we only unroll small upper bounds. 
+ unsigned FullUnrollMaxTripCount = MaxTripCount; + if (!(UP.UpperBound || MaxOrZero) || + FullUnrollMaxTripCount > UnrollMaxUpperBound) + FullUnrollMaxTripCount = 0; + + // UnrollByMaxCount and ExactTripCount cannot both be non zero since we only // compute the former when the latter is zero. unsigned ExactTripCount = TripCount; - assert((ExactTripCount == 0 || MaxTripCount == 0) && - "ExtractTripCount and MaxTripCount cannot both be non zero."); - unsigned FullUnrollTripCount = ExactTripCount ? ExactTripCount : MaxTripCount; + assert((ExactTripCount == 0 || FullUnrollMaxTripCount == 0) && + "ExtractTripCount and UnrollByMaxCount cannot both be non zero."); + + unsigned FullUnrollTripCount = + ExactTripCount ? ExactTripCount : FullUnrollMaxTripCount; UP.Count = FullUnrollTripCount; if (FullUnrollTripCount && FullUnrollTripCount <= UP.FullUnrollMaxCount) { // When computing the unrolled size, note that BEInsns are not replicated // like the rest of the loop body. if (getUnrolledLoopSize(LoopSize, UP) < UP.Threshold) { - UseUpperBound = (MaxTripCount == FullUnrollTripCount); + UseUpperBound = (FullUnrollMaxTripCount == FullUnrollTripCount); TripCount = FullUnrollTripCount; TripMultiple = UP.UpperBound ? 1 : TripMultiple; return ExplicitUnroll; @@ -806,7 +829,7 @@ bool llvm::computeUnrollCount( unsigned Boost = getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost); if (Cost->UnrolledCost < UP.Threshold * Boost / 100) { - UseUpperBound = (MaxTripCount == FullUnrollTripCount); + UseUpperBound = (FullUnrollMaxTripCount == FullUnrollTripCount); TripCount = FullUnrollTripCount; TripMultiple = UP.UpperBound ? 
1 : TripMultiple; return ExplicitUnroll; @@ -882,6 +905,8 @@ bool llvm::computeUnrollCount( "because " "unrolled size is too large."; }); + LLVM_DEBUG(dbgs() << " partially unrolling with count: " << UP.Count + << "\n"); return ExplicitUnroll; } assert(TripCount == 0 && @@ -903,6 +928,12 @@ bool llvm::computeUnrollCount( return false; } + // Don't unroll a small upper bound loop unless user or TTI asked to do so. + if (MaxTripCount && !UP.Force && MaxTripCount < UnrollMaxUpperBound) { + UP.Count = 0; + return false; + } + // Check if the runtime trip count is too small when profile is available. if (L->getHeader()->getParent()->hasProfileData()) { if (auto ProfileTripCount = getLoopEstimatedTripCount(L)) { @@ -966,7 +997,11 @@ bool llvm::computeUnrollCount( if (UP.Count > UP.MaxCount) UP.Count = UP.MaxCount; - LLVM_DEBUG(dbgs() << " partially unrolling with count: " << UP.Count + + if (MaxTripCount && UP.Count > MaxTripCount) + UP.Count = MaxTripCount; + + LLVM_DEBUG(dbgs() << " runtime unrolling with count: " << UP.Count << "\n"); if (UP.Count < 2) UP.Count = 0; @@ -976,13 +1011,14 @@ bool llvm::computeUnrollCount( static LoopUnrollResult tryToUnrollLoop( Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const TargetTransformInfo &TTI, AssumptionCache &AC, - OptimizationRemarkEmitter &ORE, - BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, - bool PreserveLCSSA, int OptLevel, + OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, bool PreserveLCSSA, int OptLevel, bool OnlyWhenForced, bool ForgetAllSCEV, Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial, Optional<bool> ProvidedRuntime, Optional<bool> ProvidedUpperBound, - Optional<bool> ProvidedAllowPeeling) { + Optional<bool> ProvidedAllowPeeling, + Optional<bool> ProvidedAllowProfileBasedPeeling, + Optional<unsigned> ProvidedFullUnrollMaxCount) { LLVM_DEBUG(dbgs() << "Loop Unroll: F[" << 
L->getHeader()->getParent()->getName() << "] Loop %" << L->getHeader()->getName() << "\n"); @@ -1007,7 +1043,8 @@ static LoopUnrollResult tryToUnrollLoop( TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( L, SE, TTI, BFI, PSI, OptLevel, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, - ProvidedAllowPeeling); + ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling, + ProvidedFullUnrollMaxCount); // Exit early if unrolling is disabled. For OptForSize, we pick the loop size // as threshold later on. @@ -1028,10 +1065,10 @@ static LoopUnrollResult tryToUnrollLoop( return LoopUnrollResult::Unmodified; } - // When optimizing for size, use LoopSize as threshold, to (fully) unroll - // loops, if it does not increase code size. + // When optimizing for size, use LoopSize + 1 as threshold (we use < Threshold + // later), to (fully) unroll loops, if it does not increase code size. if (OptForSize) - UP.Threshold = std::max(UP.Threshold, LoopSize); + UP.Threshold = std::max(UP.Threshold, LoopSize + 1); if (NumInlineCandidates != 0) { LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); @@ -1040,7 +1077,6 @@ static LoopUnrollResult tryToUnrollLoop( // Find trip count and trip multiple if count is not available unsigned TripCount = 0; - unsigned MaxTripCount = 0; unsigned TripMultiple = 1; // If there are multiple exiting blocks but one of them is the latch, use the // latch for the trip count estimation. Otherwise insist on a single exiting @@ -1070,28 +1106,18 @@ static LoopUnrollResult tryToUnrollLoop( // Try to find the trip count upper bound if we cannot find the exact trip // count. 
+ unsigned MaxTripCount = 0; bool MaxOrZero = false; if (!TripCount) { MaxTripCount = SE.getSmallConstantMaxTripCount(L); MaxOrZero = SE.isBackedgeTakenCountMaxOrZero(L); - // We can unroll by the upper bound amount if it's generally allowed or if - // we know that the loop is executed either the upper bound or zero times. - // (MaxOrZero unrolling keeps only the first loop test, so the number of - // loop tests remains the same compared to the non-unrolled version, whereas - // the generic upper bound unrolling keeps all but the last loop test so the - // number of loop tests goes up which may end up being worse on targets with - // constrained branch predictor resources so is controlled by an option.) - // In addition we only unroll small upper bounds. - if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) { - MaxTripCount = 0; - } } // computeUnrollCount() decides whether it is beneficial to use upper bound to // fully unroll the loop. bool UseUpperBound = false; bool IsCountSetExplicitly = computeUnrollCount( - L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, + L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero, TripMultiple, LoopSize, UP, UseUpperBound); if (!UP.Count) return LoopUnrollResult::Unmodified; @@ -1139,7 +1165,7 @@ static LoopUnrollResult tryToUnrollLoop( // If the loop was peeled, we already "used up" the profile information // we had, so we don't want to unroll or peel again. 
if (UnrollResult != LoopUnrollResult::FullyUnrolled && - (IsCountSetExplicitly || UP.PeelCount)) + (IsCountSetExplicitly || (UP.PeelProfiledIterations && UP.PeelCount))) L->setLoopAlreadyUnrolled(); return UnrollResult; @@ -1169,18 +1195,24 @@ public: Optional<bool> ProvidedRuntime; Optional<bool> ProvidedUpperBound; Optional<bool> ProvidedAllowPeeling; + Optional<bool> ProvidedAllowProfileBasedPeeling; + Optional<unsigned> ProvidedFullUnrollMaxCount; LoopUnroll(int OptLevel = 2, bool OnlyWhenForced = false, bool ForgetAllSCEV = false, Optional<unsigned> Threshold = None, Optional<unsigned> Count = None, Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, Optional<bool> UpperBound = None, - Optional<bool> AllowPeeling = None) + Optional<bool> AllowPeeling = None, + Optional<bool> AllowProfileBasedPeeling = None, + Optional<unsigned> ProvidedFullUnrollMaxCount = None) : LoopPass(ID), OptLevel(OptLevel), OnlyWhenForced(OnlyWhenForced), ForgetAllSCEV(ForgetAllSCEV), ProvidedCount(std::move(Count)), ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound), - ProvidedAllowPeeling(AllowPeeling) { + ProvidedAllowPeeling(AllowPeeling), + ProvidedAllowProfileBasedPeeling(AllowProfileBasedPeeling), + ProvidedFullUnrollMaxCount(ProvidedFullUnrollMaxCount) { initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); } @@ -1203,10 +1235,11 @@ public: bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); LoopUnrollResult Result = tryToUnrollLoop( - L, DT, LI, SE, TTI, AC, ORE, nullptr, nullptr, - PreserveLCSSA, OptLevel, OnlyWhenForced, - ForgetAllSCEV, ProvidedCount, ProvidedThreshold, ProvidedAllowPartial, - ProvidedRuntime, ProvidedUpperBound, ProvidedAllowPeeling); + L, DT, LI, SE, TTI, AC, ORE, nullptr, nullptr, PreserveLCSSA, OptLevel, + OnlyWhenForced, ForgetAllSCEV, ProvidedCount, ProvidedThreshold, + ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, + ProvidedAllowPeeling, 
ProvidedAllowProfileBasedPeeling, + ProvidedFullUnrollMaxCount); if (Result == LoopUnrollResult::FullyUnrolled) LPM.markLoopAsDeleted(*L); @@ -1283,14 +1316,16 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, std::string LoopName = L.getName(); - bool Changed = - tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, - /*BFI*/ nullptr, /*PSI*/ nullptr, - /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced, - ForgetSCEV, /*Count*/ None, - /*Threshold*/ None, /*AllowPartial*/ false, - /*Runtime*/ false, /*UpperBound*/ false, - /*AllowPeeling*/ false) != LoopUnrollResult::Unmodified; + bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, + /*BFI*/ nullptr, /*PSI*/ nullptr, + /*PreserveLCSSA*/ true, OptLevel, + OnlyWhenForced, ForgetSCEV, /*Count*/ None, + /*Threshold*/ None, /*AllowPartial*/ false, + /*Runtime*/ false, /*UpperBound*/ false, + /*AllowPeeling*/ false, + /*AllowProfileBasedPeeling*/ false, + /*FullUnrollMaxCount*/ None) != + LoopUnrollResult::Unmodified; if (!Changed) return PreservedAnalyses::all(); @@ -1430,7 +1465,8 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, UnrollOpts.OnlyWhenForced, UnrollOpts.ForgetSCEV, /*Count*/ None, /*Threshold*/ None, UnrollOpts.AllowPartial, UnrollOpts.AllowRuntime, - UnrollOpts.AllowUpperBound, LocalAllowPeeling); + UnrollOpts.AllowUpperBound, LocalAllowPeeling, + UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount); Changed |= Result != LoopUnrollResult::Unmodified; // The parent must not be damaged by unrolling! diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index b5b8e720069c..b410df0c5f68 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -420,7 +420,8 @@ enum OperatorChain { /// cost of creating an entirely new loop. 
static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, OperatorChain &ParentChain, - DenseMap<Value *, Value *> &Cache) { + DenseMap<Value *, Value *> &Cache, + MemorySSAUpdater *MSSAU) { auto CacheIt = Cache.find(Cond); if (CacheIt != Cache.end()) return CacheIt->second; @@ -438,7 +439,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, // TODO: Handle: br (VARIANT|INVARIANT). // Hoist simple values out. - if (L->makeLoopInvariant(Cond, Changed)) { + if (L->makeLoopInvariant(Cond, Changed, nullptr, MSSAU)) { Cache[Cond] = Cond; return Cond; } @@ -478,7 +479,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, // which will cause the branch to go away in one loop and the condition to // simplify in the other one. if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, - ParentChain, Cache)) { + ParentChain, Cache, MSSAU)) { Cache[Cond] = LHS; return LHS; } @@ -486,7 +487,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, // operand(1). ParentChain = NewChain; if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, - ParentChain, Cache)) { + ParentChain, Cache, MSSAU)) { Cache[Cond] = RHS; return RHS; } @@ -500,12 +501,12 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, /// Cond is a condition that occurs in L. If it is invariant in the loop, or has /// an invariant piece, return the invariant along with the operator chain type. /// Otherwise, return null. 
-static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond, - Loop *L, - bool &Changed) { +static std::pair<Value *, OperatorChain> +FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + MemorySSAUpdater *MSSAU) { DenseMap<Value *, Value *> Cache; OperatorChain OpChain = OC_OpChainNone; - Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache); + Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU); // In case we do find a LIV, it can not be obtained by walking up a mixed // operator chain. @@ -525,7 +526,7 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); if (EnableMSSALoopDependency) { MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSAU = make_unique<MemorySSAUpdater>(MSSA); + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); assert(DT && "Cannot update MemorySSA without a valid DomTree."); } currentLoop = L; @@ -694,8 +695,9 @@ bool LoopUnswitch::processCurrentLoop() { } for (IntrinsicInst *Guard : Guards) { - Value *LoopCond = - FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first; + Value *LoopCond = FindLIVLoopCondition(Guard->getOperand(0), currentLoop, + Changed, MSSAU.get()) + .first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { // NB! Unswitching (if successful) could have erased some of the @@ -735,8 +737,9 @@ bool LoopUnswitch::processCurrentLoop() { if (BI->isConditional()) { // See if this, or some part of it, is loop invariant. If so, we can // unswitch on it if we desire. 
- Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed).first; + Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, + Changed, MSSAU.get()) + .first; if (LoopCond && !EqualityPropUnSafe(*LoopCond) && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; @@ -748,7 +751,7 @@ bool LoopUnswitch::processCurrentLoop() { Value *LoopCond; OperatorChain OpChain; std::tie(LoopCond, OpChain) = - FindLIVLoopCondition(SC, currentLoop, Changed); + FindLIVLoopCondition(SC, currentLoop, Changed, MSSAU.get()); unsigned NumCases = SI->getNumCases(); if (LoopCond && NumCases) { @@ -808,8 +811,9 @@ bool LoopUnswitch::processCurrentLoop() { for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); BBI != E; ++BBI) if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed).first; + Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop, + Changed, MSSAU.get()) + .first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { ++NumSelects; @@ -1123,8 +1127,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { if (!BI->isConditional()) return false; - Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed).first; + Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, + Changed, MSSAU.get()) + .first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -1157,8 +1162,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { return true; } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { // If this isn't switching on an invariant condition, we can't unswitch it. 
- Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed).first; + Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop, + Changed, MSSAU.get()) + .first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -1240,6 +1246,9 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, LoopBlocks.clear(); NewBlocks.clear(); + if (MSSAU && VerifyMemorySSA) + MSSA->verifyMemorySSA(); + // First step, split the preheader and exit blocks, and add these blocks to // the LoopBlocks list. BasicBlock *NewPreheader = @@ -1607,36 +1616,30 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { // If BI's parent is the only pred of the successor, fold the two blocks // together. BasicBlock *Pred = BI->getParent(); + (void)Pred; BasicBlock *Succ = BI->getSuccessor(0); BasicBlock *SinglePred = Succ->getSinglePredecessor(); if (!SinglePred) continue; // Nothing to do. assert(SinglePred == Pred && "CFG broken"); - LLVM_DEBUG(dbgs() << "Merging blocks: " << Pred->getName() << " <- " - << Succ->getName() << "\n"); - - // Resolve any single entry PHI nodes in Succ. - while (PHINode *PN = dyn_cast<PHINode>(Succ->begin())) - ReplaceUsesOfWith(PN, PN->getIncomingValue(0), Worklist, L, LPM, - MSSAU.get()); - - // If Succ has any successors with PHI nodes, update them to have - // entries coming from Pred instead of Succ. - Succ->replaceAllUsesWith(Pred); - - // Move all of the successor contents from Succ to Pred. - Pred->getInstList().splice(BI->getIterator(), Succ->getInstList(), - Succ->begin(), Succ->end()); - if (MSSAU) - MSSAU->moveAllAfterMergeBlocks(Succ, Pred, BI); + // Make the LPM and Worklist updates specific to LoopUnswitch. LPM->deleteSimpleAnalysisValue(BI, L); RemoveFromWorklist(BI, Worklist); - BI->eraseFromParent(); - - // Remove Succ from the loop tree. 
- LI->removeBlock(Succ); LPM->deleteSimpleAnalysisValue(Succ, L); - Succ->eraseFromParent(); + auto SuccIt = Succ->begin(); + while (PHINode *PN = dyn_cast<PHINode>(SuccIt++)) { + for (unsigned It = 0, E = PN->getNumOperands(); It != E; ++It) + if (Instruction *Use = dyn_cast<Instruction>(PN->getOperand(It))) + Worklist.push_back(Use); + for (User *U : PN->users()) + Worklist.push_back(cast<Instruction>(U)); + LPM->deleteSimpleAnalysisValue(PN, L); + RemoveFromWorklist(PN, Worklist); + ++NumSimplify; + } + // Merge the block and make the remaining analyses updates. + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + MergeBlockIntoPredecessor(Succ, &DTU, LI, MSSAU.get()); ++NumSimplify; continue; } diff --git a/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 896dd8bcb922..2ccb7cae3079 100644 --- a/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -112,37 +112,6 @@ static cl::opt<unsigned> LVLoopDepthThreshold( "LoopVersioningLICM's threshold for maximum allowed loop nest/depth"), cl::init(2), cl::Hidden); -/// Create MDNode for input string. -static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) { - LLVMContext &Context = TheLoop->getHeader()->getContext(); - Metadata *MDs[] = { - MDString::get(Context, Name), - ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))}; - return MDNode::get(Context, MDs); -} - -/// Set input string into loop metadata by keeping other values intact. -void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *MDString, - unsigned V) { - SmallVector<Metadata *, 4> MDs(1); - // If the loop already has metadata, retain it. - MDNode *LoopID = TheLoop->getLoopID(); - if (LoopID) { - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); - MDs.push_back(Node); - } - } - // Add new metadata. 
- MDs.push_back(createStringMetadata(TheLoop, MDString, V)); - // Replace current metadata node with new one. - LLVMContext &Context = TheLoop->getHeader()->getContext(); - MDNode *NewLoopID = MDNode::get(Context, MDs); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - TheLoop->setLoopID(NewLoopID); -} - namespace { struct LoopVersioningLICM : public LoopPass { diff --git a/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp new file mode 100644 index 000000000000..d0fcf38b5a7b --- /dev/null +++ b/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -0,0 +1,170 @@ +//===- LowerConstantIntrinsics.cpp - Lower constant intrinsic calls -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers all remaining 'objectsize' 'is.constant' intrinsic calls +// and provides constant propagation and basic CFG cleanup on the result. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "lower-is-constant-intrinsic" + +STATISTIC(IsConstantIntrinsicsHandled, + "Number of 'is.constant' intrinsic calls handled"); +STATISTIC(ObjectSizeIntrinsicsHandled, + "Number of 'objectsize' intrinsic calls handled"); + +static Value *lowerIsConstantIntrinsic(IntrinsicInst *II) { + Value *Op = II->getOperand(0); + + return isa<Constant>(Op) ? 
ConstantInt::getTrue(II->getType()) + : ConstantInt::getFalse(II->getType()); +} + +static bool replaceConditionalBranchesOnConstant(Instruction *II, + Value *NewValue) { + bool HasDeadBlocks = false; + SmallSetVector<Instruction *, 8> Worklist; + replaceAndRecursivelySimplify(II, NewValue, nullptr, nullptr, nullptr, + &Worklist); + for (auto I : Worklist) { + BranchInst *BI = dyn_cast<BranchInst>(I); + if (!BI) + continue; + if (BI->isUnconditional()) + continue; + + BasicBlock *Target, *Other; + if (match(BI->getOperand(0), m_Zero())) { + Target = BI->getSuccessor(1); + Other = BI->getSuccessor(0); + } else if (match(BI->getOperand(0), m_One())) { + Target = BI->getSuccessor(0); + Other = BI->getSuccessor(1); + } else { + Target = nullptr; + Other = nullptr; + } + if (Target && Target != Other) { + BasicBlock *Source = BI->getParent(); + Other->removePredecessor(Source); + BI->eraseFromParent(); + BranchInst::Create(Target, Source); + if (pred_begin(Other) == pred_end(Other)) + HasDeadBlocks = true; + } + } + return HasDeadBlocks; +} + +static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo *TLI) { + bool HasDeadBlocks = false; + const auto &DL = F.getParent()->getDataLayout(); + SmallVector<WeakTrackingVH, 8> Worklist; + + ReversePostOrderTraversal<Function *> RPOT(&F); + for (BasicBlock *BB : RPOT) { + for (Instruction &I: *BB) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); + if (!II) + continue; + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::is_constant: + case Intrinsic::objectsize: + Worklist.push_back(WeakTrackingVH(&I)); + break; + } + } + } + for (WeakTrackingVH &VH: Worklist) { + // Items on the worklist can be mutated by earlier recursive replaces. + // This can remove the intrinsic as dead (VH == null), but also replace + // the intrinsic in place. 
+ if (!VH) + continue; + IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*VH); + if (!II) + continue; + Value *NewValue; + switch (II->getIntrinsicID()) { + default: + continue; + case Intrinsic::is_constant: + NewValue = lowerIsConstantIntrinsic(II); + IsConstantIntrinsicsHandled++; + break; + case Intrinsic::objectsize: + NewValue = lowerObjectSizeCall(II, DL, TLI, true); + ObjectSizeIntrinsicsHandled++; + break; + } + HasDeadBlocks |= replaceConditionalBranchesOnConstant(II, NewValue); + } + if (HasDeadBlocks) + removeUnreachableBlocks(F); + return !Worklist.empty(); +} + +PreservedAnalyses +LowerConstantIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) { + if (lowerConstantIntrinsics(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) + return PreservedAnalyses::none(); + + return PreservedAnalyses::all(); +} + +namespace { +/// Legacy pass for lowering is.constant intrinsics out of the IR. +/// +/// When this pass is run over a function it converts is.constant intrinsics +/// into 'true' or 'false'. This is completements the normal constand folding +/// to 'true' as part of Instruction Simplify passes. +class LowerConstantIntrinsics : public FunctionPass { +public: + static char ID; + LowerConstantIntrinsics() : FunctionPass(ID) { + initializeLowerConstantIntrinsicsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + const TargetLibraryInfo *TLI = TLIP ? 
&TLIP->getTLI(F) : nullptr; + return lowerConstantIntrinsics(F, TLI); + } +}; +} // namespace + +char LowerConstantIntrinsics::ID = 0; +INITIALIZE_PASS(LowerConstantIntrinsics, "lower-constant-intrinsics", + "Lower constant intrinsics", false, false) + +FunctionPass *llvm::createLowerConstantIntrinsicsPass() { + return new LowerConstantIntrinsics(); +} diff --git a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 0d67c0d740ec..d85f20b3f80c 100644 --- a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -26,6 +26,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/MisExpect.h" using namespace llvm; @@ -71,15 +72,20 @@ static bool handleSwitchExpect(SwitchInst &SI) { unsigned n = SI.getNumCases(); // +1 for default case. SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight); - if (Case == *SI.case_default()) - Weights[0] = LikelyBranchWeight; - else - Weights[Case.getCaseIndex() + 1] = LikelyBranchWeight; + uint64_t Index = (Case == *SI.case_default()) ? 0 : Case.getCaseIndex() + 1; + Weights[Index] = LikelyBranchWeight; + + SI.setMetadata( + LLVMContext::MD_misexpect, + MDBuilder(CI->getContext()) + .createMisExpect(Index, LikelyBranchWeight, UnlikelyBranchWeight)); + + SI.setCondition(ArgValue); + misexpect::checkFrontendInstrumentation(SI); SI.setMetadata(LLVMContext::MD_prof, MDBuilder(CI->getContext()).createBranchWeights(Weights)); - SI.setCondition(ArgValue); return true; } @@ -155,7 +161,7 @@ static void handlePhiDef(CallInst *Expect) { return Result; }; - auto *PhiDef = dyn_cast<PHINode>(V); + auto *PhiDef = cast<PHINode>(V); // Get the first dominating conditional branch of the operand // i's incoming block. 
@@ -280,19 +286,28 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { MDBuilder MDB(CI->getContext()); MDNode *Node; + MDNode *ExpNode; if ((ExpectedValue->getZExtValue() == ValueComparedTo) == - (Predicate == CmpInst::ICMP_EQ)) + (Predicate == CmpInst::ICMP_EQ)) { Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight); - else + ExpNode = MDB.createMisExpect(0, LikelyBranchWeight, UnlikelyBranchWeight); + } else { Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight); + ExpNode = MDB.createMisExpect(1, LikelyBranchWeight, UnlikelyBranchWeight); + } - BSI.setMetadata(LLVMContext::MD_prof, Node); + BSI.setMetadata(LLVMContext::MD_misexpect, ExpNode); if (CmpI) CmpI->setOperand(0, ArgValue); else BSI.setCondition(ArgValue); + + misexpect::checkFrontendInstrumentation(BSI); + + BSI.setMetadata(LLVMContext::MD_prof, Node); + return true; } diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 5a055139be4f..2364748efb05 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -69,90 +69,6 @@ STATISTIC(NumMemSetInfer, "Number of memsets inferred"); STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy"); STATISTIC(NumCpyToSet, "Number of memcpys converted to memset"); -static int64_t GetOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, - bool &VariableIdxFound, - const DataLayout &DL) { - // Skip over the first indices. - gep_type_iterator GTI = gep_type_begin(GEP); - for (unsigned i = 1; i != Idx; ++i, ++GTI) - /*skip along*/; - - // Compute the offset implied by the rest of the indices. - int64_t Offset = 0; - for (unsigned i = Idx, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { - ConstantInt *OpC = dyn_cast<ConstantInt>(GEP->getOperand(i)); - if (!OpC) - return VariableIdxFound = true; - if (OpC->isZero()) continue; // No offset. 
- - // Handle struct indices, which add their field offset to the pointer. - if (StructType *STy = GTI.getStructTypeOrNull()) { - Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); - continue; - } - - // Otherwise, we have a sequential type like an array or vector. Multiply - // the index by the ElementSize. - uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); - Offset += Size*OpC->getSExtValue(); - } - - return Offset; -} - -/// Return true if Ptr1 is provably equal to Ptr2 plus a constant offset, and -/// return that constant offset. For example, Ptr1 might be &A[42], and Ptr2 -/// might be &A[40]. In this case offset would be -8. -static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset, - const DataLayout &DL) { - Ptr1 = Ptr1->stripPointerCasts(); - Ptr2 = Ptr2->stripPointerCasts(); - - // Handle the trivial case first. - if (Ptr1 == Ptr2) { - Offset = 0; - return true; - } - - GEPOperator *GEP1 = dyn_cast<GEPOperator>(Ptr1); - GEPOperator *GEP2 = dyn_cast<GEPOperator>(Ptr2); - - bool VariableIdxFound = false; - - // If one pointer is a GEP and the other isn't, then see if the GEP is a - // constant offset from the base, as in "P" and "gep P, 1". - if (GEP1 && !GEP2 && GEP1->getOperand(0)->stripPointerCasts() == Ptr2) { - Offset = -GetOffsetFromIndex(GEP1, 1, VariableIdxFound, DL); - return !VariableIdxFound; - } - - if (GEP2 && !GEP1 && GEP2->getOperand(0)->stripPointerCasts() == Ptr1) { - Offset = GetOffsetFromIndex(GEP2, 1, VariableIdxFound, DL); - return !VariableIdxFound; - } - - // Right now we handle the case when Ptr1/Ptr2 are both GEPs with an identical - // base. After that base, they may have some number of common (and - // potentially variable) indices. After that they handle some constant - // offset, which determines their offset from each other. At this point, we - // handle no other case. 
- if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0)) - return false; - - // Skip any common indices and track the GEP types. - unsigned Idx = 1; - for (; Idx != GEP1->getNumOperands() && Idx != GEP2->getNumOperands(); ++Idx) - if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx)) - break; - - int64_t Offset1 = GetOffsetFromIndex(GEP1, Idx, VariableIdxFound, DL); - int64_t Offset2 = GetOffsetFromIndex(GEP2, Idx, VariableIdxFound, DL); - if (VariableIdxFound) return false; - - Offset = Offset2-Offset1; - return true; -} - namespace { /// Represents a range of memset'd bytes with the ByteVal value. @@ -419,12 +335,12 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, break; // Check to see if this store is to a constant offset from the start ptr. - int64_t Offset; - if (!IsPointerOffset(StartPtr, NextStore->getPointerOperand(), Offset, - DL)) + Optional<int64_t> Offset = + isPointerOffset(StartPtr, NextStore->getPointerOperand(), DL); + if (!Offset) break; - Ranges.addStore(Offset, NextStore); + Ranges.addStore(*Offset, NextStore); } else { MemSetInst *MSI = cast<MemSetInst>(BI); @@ -433,11 +349,11 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, break; // Check to see if this store is to a constant offset from the start ptr. 
- int64_t Offset; - if (!IsPointerOffset(StartPtr, MSI->getDest(), Offset, DL)) + Optional<int64_t> Offset = isPointerOffset(StartPtr, MSI->getDest(), DL); + if (!Offset) break; - Ranges.addMemSet(Offset, MSI); + Ranges.addMemSet(*Offset, MSI); } } @@ -597,9 +513,13 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, ToLift.push_back(C); for (unsigned k = 0, e = C->getNumOperands(); k != e; ++k) - if (auto *A = dyn_cast<Instruction>(C->getOperand(k))) - if (A->getParent() == SI->getParent()) + if (auto *A = dyn_cast<Instruction>(C->getOperand(k))) { + if (A->getParent() == SI->getParent()) { + // Cannot hoist user of P above P + if(A == P) return false; Args.insert(A); + } + } } // We made it, we need to lift @@ -979,7 +899,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, // If the destination wasn't sufficiently aligned then increase its alignment. if (!isDestSufficientlyAligned) { assert(isa<AllocaInst>(cpyDest) && "Can only increase alloca alignment!"); - cast<AllocaInst>(cpyDest)->setAlignment(srcAlign); + cast<AllocaInst>(cpyDest)->setAlignment(MaybeAlign(srcAlign)); } // Drop any cached information about the call, because we may have changed @@ -1516,7 +1436,7 @@ bool MemCpyOptLegacyPass::runOnFunction(Function &F) { return false; auto *MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); - auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto LookupAliasAnalysis = [this]() -> AliasAnalysis & { return getAnalysis<AAResultsWrapperPass>().getAAResults(); diff --git a/lib/Transforms/Scalar/MergeICmps.cpp b/lib/Transforms/Scalar/MergeICmps.cpp index 3d047a193267..98a45b391319 100644 --- a/lib/Transforms/Scalar/MergeICmps.cpp +++ b/lib/Transforms/Scalar/MergeICmps.cpp @@ -897,7 +897,7 @@ public: bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; - const auto &TLI = 
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); // MergeICmps does not need the DominatorTree, but we update it if it's // already available. diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 30645f4400e3..9799ea7960ec 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -14,9 +14,11 @@ // diamond (hammock) and merges them into a single load in the header. Similar // it sinks and merges two stores to the tail block (footer). The algorithm // iterates over the instructions of one side of the diamond and attempts to -// find a matching load/store on the other side. It hoists / sinks when it -// thinks it safe to do so. This optimization helps with eg. hiding load -// latencies, triggering if-conversion, and reducing static code size. +// find a matching load/store on the other side. New tail/footer block may be +// insterted if the tail/footer block has more predecessors (not only the two +// predecessors that are forming the diamond). It hoists / sinks when it thinks +// it safe to do so. This optimization helps with eg. hiding load latencies, +// triggering if-conversion, and reducing static code size. // // NOTE: This code no longer performs load hoisting, it is subsumed by GVNHoist. // @@ -103,7 +105,9 @@ class MergedLoadStoreMotion { // Control is enforced by the check Size0 * Size1 < MagicCompileTimeControl. 
const int MagicCompileTimeControl = 250; + const bool SplitFooterBB; public: + MergedLoadStoreMotion(bool SplitFooterBB) : SplitFooterBB(SplitFooterBB) {} bool run(Function &F, AliasAnalysis &AA); private: @@ -114,7 +118,9 @@ private: PHINode *getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1); bool isStoreSinkBarrierInRange(const Instruction &Start, const Instruction &End, MemoryLocation Loc); - bool sinkStore(BasicBlock *BB, StoreInst *SinkCand, StoreInst *ElseInst); + bool canSinkStoresAndGEPs(StoreInst *S0, StoreInst *S1) const; + void sinkStoresAndGEPs(BasicBlock *BB, StoreInst *SinkCand, + StoreInst *ElseInst); bool mergeStores(BasicBlock *BB); }; } // end anonymous namespace @@ -217,74 +223,82 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, } /// +/// Check if 2 stores can be sunk together with corresponding GEPs +/// +bool MergedLoadStoreMotion::canSinkStoresAndGEPs(StoreInst *S0, + StoreInst *S1) const { + auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand()); + auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand()); + return A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() && + (A0->getParent() == S0->getParent()) && A1->hasOneUse() && + (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0); +} + +/// /// Merge two stores to same address and sink into \p BB /// /// Also sinks GEP instruction computing the store address /// -bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, - StoreInst *S1) { +void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0, + StoreInst *S1) { // Only one definition? 
auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand()); auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand()); - if (A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() && - (A0->getParent() == S0->getParent()) && A1->hasOneUse() && - (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0)) { - LLVM_DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump(); - dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n"; - dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n"); - // Hoist the instruction. - BasicBlock::iterator InsertPt = BB->getFirstInsertionPt(); - // Intersect optional metadata. - S0->andIRFlags(S1); - S0->dropUnknownNonDebugMetadata(); - - // Create the new store to be inserted at the join point. - StoreInst *SNew = cast<StoreInst>(S0->clone()); - Instruction *ANew = A0->clone(); - SNew->insertBefore(&*InsertPt); - ANew->insertBefore(SNew); - - assert(S0->getParent() == A0->getParent()); - assert(S1->getParent() == A1->getParent()); - - // New PHI operand? Use it. - if (PHINode *NewPN = getPHIOperand(BB, S0, S1)) - SNew->setOperand(0, NewPN); - S0->eraseFromParent(); - S1->eraseFromParent(); - A0->replaceAllUsesWith(ANew); - A0->eraseFromParent(); - A1->replaceAllUsesWith(ANew); - A1->eraseFromParent(); - return true; - } - return false; + LLVM_DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump(); + dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n"; + dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n"); + // Hoist the instruction. + BasicBlock::iterator InsertPt = BB->getFirstInsertionPt(); + // Intersect optional metadata. + S0->andIRFlags(S1); + S0->dropUnknownNonDebugMetadata(); + + // Create the new store to be inserted at the join point. 
+ StoreInst *SNew = cast<StoreInst>(S0->clone()); + Instruction *ANew = A0->clone(); + SNew->insertBefore(&*InsertPt); + ANew->insertBefore(SNew); + + assert(S0->getParent() == A0->getParent()); + assert(S1->getParent() == A1->getParent()); + + // New PHI operand? Use it. + if (PHINode *NewPN = getPHIOperand(BB, S0, S1)) + SNew->setOperand(0, NewPN); + S0->eraseFromParent(); + S1->eraseFromParent(); + A0->replaceAllUsesWith(ANew); + A0->eraseFromParent(); + A1->replaceAllUsesWith(ANew); + A1->eraseFromParent(); } /// /// True when two stores are equivalent and can sink into the footer /// -/// Starting from a diamond tail block, iterate over the instructions in one -/// predecessor block and try to match a store in the second predecessor. +/// Starting from a diamond head block, iterate over the instructions in one +/// successor block and try to match a store in the second successor. /// -bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { +bool MergedLoadStoreMotion::mergeStores(BasicBlock *HeadBB) { bool MergedStores = false; - assert(T && "Footer of a diamond cannot be empty"); - - pred_iterator PI = pred_begin(T), E = pred_end(T); - assert(PI != E); - BasicBlock *Pred0 = *PI; - ++PI; - BasicBlock *Pred1 = *PI; - ++PI; + BasicBlock *TailBB = getDiamondTail(HeadBB); + BasicBlock *SinkBB = TailBB; + assert(SinkBB && "Footer of a diamond cannot be empty"); + + succ_iterator SI = succ_begin(HeadBB); + assert(SI != succ_end(HeadBB) && "Diamond head cannot have zero successors"); + BasicBlock *Pred0 = *SI; + ++SI; + assert(SI != succ_end(HeadBB) && "Diamond head cannot have single successor"); + BasicBlock *Pred1 = *SI; // tail block of a diamond/hammock? if (Pred0 == Pred1) return false; // No. - if (PI != E) - return false; // No. More than 2 predecessors. 
- - // #Instructions in Succ1 for Compile Time Control + // bail out early if we can not merge into the footer BB + if (!SplitFooterBB && TailBB->hasNPredecessorsOrMore(3)) + return false; + // #Instructions in Pred1 for Compile Time Control auto InstsNoDbg = Pred1->instructionsWithoutDebug(); int Size1 = std::distance(InstsNoDbg.begin(), InstsNoDbg.end()); int NStores = 0; @@ -304,14 +318,23 @@ bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { if (NStores * Size1 >= MagicCompileTimeControl) break; if (StoreInst *S1 = canSinkFromBlock(Pred1, S0)) { - bool Res = sinkStore(T, S0, S1); - MergedStores |= Res; - // Don't attempt to sink below stores that had to stick around - // But after removal of a store and some of its feeding - // instruction search again from the beginning since the iterator - // is likely stale at this point. - if (!Res) + if (!canSinkStoresAndGEPs(S0, S1)) + // Don't attempt to sink below stores that had to stick around + // But after removal of a store and some of its feeding + // instruction search again from the beginning since the iterator + // is likely stale at this point. break; + + if (SinkBB == TailBB && TailBB->hasNPredecessorsOrMore(3)) { + // We have more than 2 predecessors. Insert a new block + // postdominating 2 predecessors we're going to sink from. + SinkBB = SplitBlockPredecessors(TailBB, {Pred0, Pred1}, ".sink.split"); + if (!SinkBB) + break; + } + + MergedStores = true; + sinkStoresAndGEPs(SinkBB, S0, S1); RBI = Pred0->rbegin(); RBE = Pred0->rend(); LLVM_DEBUG(dbgs() << "Search again\n"; Instruction *I = &*RBI; I->dump()); @@ -328,13 +351,15 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) { // Merge unconditional branches, allowing PRE to catch more // optimization opportunities. + // This loop doesn't care about newly inserted/split blocks + // since they never will be diamond heads. 
for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE;) { BasicBlock *BB = &*FI++; // Hoist equivalent loads and sink stores // outside diamonds when possible if (isDiamondHead(BB)) { - Changed |= mergeStores(getDiamondTail(BB)); + Changed |= mergeStores(BB); } } return Changed; @@ -342,9 +367,11 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) { namespace { class MergedLoadStoreMotionLegacyPass : public FunctionPass { + const bool SplitFooterBB; public: static char ID; // Pass identification, replacement for typeid - MergedLoadStoreMotionLegacyPass() : FunctionPass(ID) { + MergedLoadStoreMotionLegacyPass(bool SplitFooterBB = false) + : FunctionPass(ID), SplitFooterBB(SplitFooterBB) { initializeMergedLoadStoreMotionLegacyPassPass( *PassRegistry::getPassRegistry()); } @@ -355,13 +382,14 @@ public: bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; - MergedLoadStoreMotion Impl; + MergedLoadStoreMotion Impl(SplitFooterBB); return Impl.run(F, getAnalysis<AAResultsWrapperPass>().getAAResults()); } private: void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); + if (!SplitFooterBB) + AU.setPreservesCFG(); AU.addRequired<AAResultsWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } @@ -373,8 +401,8 @@ char MergedLoadStoreMotionLegacyPass::ID = 0; /// /// createMergedLoadStoreMotionPass - The public interface to this file. 
/// -FunctionPass *llvm::createMergedLoadStoreMotionPass() { - return new MergedLoadStoreMotionLegacyPass(); +FunctionPass *llvm::createMergedLoadStoreMotionPass(bool SplitFooterBB) { + return new MergedLoadStoreMotionLegacyPass(SplitFooterBB); } INITIALIZE_PASS_BEGIN(MergedLoadStoreMotionLegacyPass, "mldst-motion", @@ -385,13 +413,14 @@ INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion", PreservedAnalyses MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) { - MergedLoadStoreMotion Impl; + MergedLoadStoreMotion Impl(Options.SplitFooterBB); auto &AA = AM.getResult<AAManager>(F); if (!Impl.run(F, AA)) return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); + if (!Options.SplitFooterBB) + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); return PA; } diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp index 94436b55752a..1260bd39cdee 100644 --- a/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/lib/Transforms/Scalar/NaryReassociate.cpp @@ -170,7 +170,7 @@ bool NaryReassociateLegacyPass::runOnFunction(Function &F) { auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); return Impl.runImpl(F, AC, DT, SE, TLI, TTI); diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp index 08ac2b666fce..b213264de557 100644 --- a/lib/Transforms/Scalar/NewGVN.cpp +++ b/lib/Transforms/Scalar/NewGVN.cpp @@ -89,6 +89,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include 
"llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -122,6 +123,7 @@ using namespace llvm; using namespace llvm::GVNExpression; using namespace llvm::VNCoercion; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "newgvn" @@ -656,7 +658,7 @@ public: TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA, const DataLayout &DL) : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL), - PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)), + PredInfo(std::make_unique<PredicateInfo>(F, *DT, *AC)), SQ(DL, TLI, DT, AC, /*CtxI=*/nullptr, /*UseInstrInfo=*/false) {} bool runGVN(); @@ -1332,7 +1334,7 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp, E->setOpcode(0); E->op_push_back(PointerOp); if (LI) - E->setAlignment(LI->getAlignment()); + E->setAlignment(MaybeAlign(LI->getAlignment())); // TODO: Value number heap versions. We may be able to discover // things alias analysis can't on it's own (IE that a store and a @@ -1637,8 +1639,11 @@ const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) const { if (AA->doesNotAccessMemory(CI)) { return createCallExpression(CI, TOPClass->getMemoryLeader()); } else if (AA->onlyReadsMemory(CI)) { - MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI); - return createCallExpression(CI, DefiningAccess); + if (auto *MA = MSSA->getMemoryAccess(CI)) { + auto *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(MA); + return createCallExpression(CI, DefiningAccess); + } else // MSSA determined that CI does not access memory. + return createCallExpression(CI, TOPClass->getMemoryLeader()); } return nullptr; } @@ -1754,7 +1759,7 @@ NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps, return true; }); // If we are left with no operands, it's dead. - if (empty(Filtered)) { + if (Filtered.empty()) { // If it has undef at this point, it means there are no-non-undef arguments, // and thus, the value of the phi node must be undef. 
if (HasUndef) { @@ -2464,9 +2469,9 @@ Value *NewGVN::findConditionEquivalence(Value *Cond) const { // Process the outgoing edges of a block for reachability. void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) { // Evaluate reachability of terminator instruction. - BranchInst *BR; - if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) { - Value *Cond = BR->getCondition(); + Value *Cond; + BasicBlock *TrueSucc, *FalseSucc; + if (match(TI, m_Br(m_Value(Cond), TrueSucc, FalseSucc))) { Value *CondEvaluated = findConditionEquivalence(Cond); if (!CondEvaluated) { if (auto *I = dyn_cast<Instruction>(Cond)) { @@ -2479,8 +2484,6 @@ void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) { } } ConstantInt *CI; - BasicBlock *TrueSucc = BR->getSuccessor(0); - BasicBlock *FalseSucc = BR->getSuccessor(1); if (CondEvaluated && (CI = dyn_cast<ConstantInt>(CondEvaluated))) { if (CI->isOne()) { LLVM_DEBUG(dbgs() << "Condition for Terminator " << *TI @@ -4196,7 +4199,7 @@ bool NewGVNLegacyPass::runOnFunction(Function &F) { return false; return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), &getAnalysis<AAResultsWrapperPass>().getAAResults(), &getAnalysis<MemorySSAWrapperPass>().getMSSA(), F.getParent()->getDataLayout()) diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 039123218544..68a0f5151ad5 100644 --- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -161,7 +161,7 @@ public: return false; TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); const TargetTransformInfo *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); return 
runPartiallyInlineLibCalls(F, TLI, TTI); diff --git a/lib/Transforms/Scalar/PlaceSafepoints.cpp b/lib/Transforms/Scalar/PlaceSafepoints.cpp index b544f0a39ea8..beb299272ed8 100644 --- a/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -131,7 +131,7 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); for (Loop *I : *LI) { runOnLoopAndSubLoops(I); } @@ -240,7 +240,7 @@ static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header, static bool mustBeFiniteCountedLoop(Loop *L, ScalarEvolution *SE, BasicBlock *Pred) { // A conservative bound on the loop as a whole. - const SCEV *MaxTrips = SE->getMaxBackedgeTakenCount(L); + const SCEV *MaxTrips = SE->getConstantMaxBackedgeTakenCount(L); if (MaxTrips != SE->getCouldNotCompute() && SE->getUnsignedRange(MaxTrips).getUnsignedMax().isIntN( CountedLoopTripWidth)) @@ -478,7 +478,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { return false; const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); bool Modified = false; diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index fa8c9e2a5fe4..124f625ef7b6 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -861,7 +861,7 @@ static Value *NegateValue(Value *V, Instruction *BI, // this use. We do this by moving it to the entry block (if it is a // non-instruction value) or right after the definition. These negates will // be zapped by reassociate later, so we don't need much finesse here. 
- BinaryOperator *TheNeg = cast<BinaryOperator>(U); + Instruction *TheNeg = cast<Instruction>(U); // Verify that the negate is in this function, V might be a constant expr. if (TheNeg->getParent()->getParent() != BI->getParent()->getParent()) @@ -1938,88 +1938,132 @@ void ReassociatePass::EraseInst(Instruction *I) { MadeChange = true; } -// Canonicalize expressions of the following form: -// x + (-Constant * y) -> x - (Constant * y) -// x - (-Constant * y) -> x + (Constant * y) -Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { - if (!I->hasOneUse() || I->getType()->isVectorTy()) - return nullptr; - - // Must be a fmul or fdiv instruction. - unsigned Opcode = I->getOpcode(); - if (Opcode != Instruction::FMul && Opcode != Instruction::FDiv) - return nullptr; - - auto *C0 = dyn_cast<ConstantFP>(I->getOperand(0)); - auto *C1 = dyn_cast<ConstantFP>(I->getOperand(1)); - - // Both operands are constant, let it get constant folded away. - if (C0 && C1) - return nullptr; - - ConstantFP *CF = C0 ? C0 : C1; - - // Must have one constant operand. - if (!CF) - return nullptr; +/// Recursively analyze an expression to build a list of instructions that have +/// negative floating-point constant operands. The caller can then transform +/// the list to create positive constants for better reassociation and CSE. +static void getNegatibleInsts(Value *V, + SmallVectorImpl<Instruction *> &Candidates) { + // Handle only one-use instructions. Combining negations does not justify + // replicating instructions. + Instruction *I; + if (!match(V, m_OneUse(m_Instruction(I)))) + return; - // Must be a negative ConstantFP. - if (!CF->isNegative()) - return nullptr; + // Handle expressions of multiplications and divisions. + // TODO: This could look through floating-point casts. + const APFloat *C; + switch (I->getOpcode()) { + case Instruction::FMul: + // Not expecting non-canonical code here. Bail out and wait. 
+ if (match(I->getOperand(0), m_Constant())) + break; - // User must be a binary operator with one or more uses. - Instruction *User = I->user_back(); - if (!isa<BinaryOperator>(User) || User->use_empty()) - return nullptr; + if (match(I->getOperand(1), m_APFloat(C)) && C->isNegative()) { + Candidates.push_back(I); + LLVM_DEBUG(dbgs() << "FMul with negative constant: " << *I << '\n'); + } + getNegatibleInsts(I->getOperand(0), Candidates); + getNegatibleInsts(I->getOperand(1), Candidates); + break; + case Instruction::FDiv: + // Not expecting non-canonical code here. Bail out and wait. + if (match(I->getOperand(0), m_Constant()) && + match(I->getOperand(1), m_Constant())) + break; - unsigned UserOpcode = User->getOpcode(); - if (UserOpcode != Instruction::FAdd && UserOpcode != Instruction::FSub) - return nullptr; + if ((match(I->getOperand(0), m_APFloat(C)) && C->isNegative()) || + (match(I->getOperand(1), m_APFloat(C)) && C->isNegative())) { + Candidates.push_back(I); + LLVM_DEBUG(dbgs() << "FDiv with negative constant: " << *I << '\n'); + } + getNegatibleInsts(I->getOperand(0), Candidates); + getNegatibleInsts(I->getOperand(1), Candidates); + break; + default: + break; + } +} - // Subtraction is not commutative. Explicitly, the following transform is - // not valid: (-Constant * y) - x -> x + (Constant * y) - if (!User->isCommutative() && User->getOperand(1) != I) +/// Given an fadd/fsub with an operand that is a one-use instruction +/// (the fadd/fsub), try to change negative floating-point constants into +/// positive constants to increase potential for reassociation and CSE. +Instruction *ReassociatePass::canonicalizeNegFPConstantsForOp(Instruction *I, + Instruction *Op, + Value *OtherOp) { + assert((I->getOpcode() == Instruction::FAdd || + I->getOpcode() == Instruction::FSub) && "Expected fadd/fsub"); + + // Collect instructions with negative FP constants from the subtree that ends + // in Op. 
+ SmallVector<Instruction *, 4> Candidates; + getNegatibleInsts(Op, Candidates); + if (Candidates.empty()) return nullptr; // Don't canonicalize x + (-Constant * y) -> x - (Constant * y), if the // resulting subtract will be broken up later. This can get us into an // infinite loop during reassociation. - if (UserOpcode == Instruction::FAdd && ShouldBreakUpSubtract(User)) + bool IsFSub = I->getOpcode() == Instruction::FSub; + bool NeedsSubtract = !IsFSub && Candidates.size() % 2 == 1; + if (NeedsSubtract && ShouldBreakUpSubtract(I)) return nullptr; - // Change the sign of the constant. - APFloat Val = CF->getValueAPF(); - Val.changeSign(); - I->setOperand(C0 ? 0 : 1, ConstantFP::get(CF->getContext(), Val)); - - // Canonicalize I to RHS to simplify the next bit of logic. E.g., - // ((-Const*y) + x) -> (x + (-Const*y)). - if (User->getOperand(0) == I && User->isCommutative()) - cast<BinaryOperator>(User)->swapOperands(); - - Value *Op0 = User->getOperand(0); - Value *Op1 = User->getOperand(1); - BinaryOperator *NI; - switch (UserOpcode) { - default: - llvm_unreachable("Unexpected Opcode!"); - case Instruction::FAdd: - NI = BinaryOperator::CreateFSub(Op0, Op1); - NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags()); - break; - case Instruction::FSub: - NI = BinaryOperator::CreateFAdd(Op0, Op1); - NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags()); - break; + for (Instruction *Negatible : Candidates) { + const APFloat *C; + if (match(Negatible->getOperand(0), m_APFloat(C))) { + assert(!match(Negatible->getOperand(1), m_Constant()) && + "Expecting only 1 constant operand"); + assert(C->isNegative() && "Expected negative FP constant"); + Negatible->setOperand(0, ConstantFP::get(Negatible->getType(), abs(*C))); + MadeChange = true; + } + if (match(Negatible->getOperand(1), m_APFloat(C))) { + assert(!match(Negatible->getOperand(0), m_Constant()) && + "Expecting only 1 constant operand"); + assert(C->isNegative() && "Expected negative FP 
constant"); + Negatible->setOperand(1, ConstantFP::get(Negatible->getType(), abs(*C))); + MadeChange = true; + } } + assert(MadeChange == true && "Negative constant candidate was not changed"); - NI->insertBefore(User); - NI->setName(User->getName()); - User->replaceAllUsesWith(NI); - NI->setDebugLoc(I->getDebugLoc()); + // Negations cancelled out. + if (Candidates.size() % 2 == 0) + return I; + + // Negate the final operand in the expression by flipping the opcode of this + // fadd/fsub. + assert(Candidates.size() % 2 == 1 && "Expected odd number"); + IRBuilder<> Builder(I); + Value *NewInst = IsFSub ? Builder.CreateFAddFMF(OtherOp, Op, I) + : Builder.CreateFSubFMF(OtherOp, Op, I); + I->replaceAllUsesWith(NewInst); RedoInsts.insert(I); - MadeChange = true; - return NI; + return dyn_cast<Instruction>(NewInst); +} + +/// Canonicalize expressions that contain a negative floating-point constant +/// of the following form: +/// OtherOp + (subtree) -> OtherOp {+/-} (canonical subtree) +/// (subtree) + OtherOp -> OtherOp {+/-} (canonical subtree) +/// OtherOp - (subtree) -> OtherOp {+/-} (canonical subtree) +/// +/// The fadd/fsub opcode may be switched to allow folding a negation into the +/// input instruction. +Instruction *ReassociatePass::canonicalizeNegFPConstants(Instruction *I) { + LLVM_DEBUG(dbgs() << "Combine negations for: " << *I << '\n'); + Value *X; + Instruction *Op; + if (match(I, m_FAdd(m_Value(X), m_OneUse(m_Instruction(Op))))) + if (Instruction *R = canonicalizeNegFPConstantsForOp(I, Op, X)) + I = R; + if (match(I, m_FAdd(m_OneUse(m_Instruction(Op)), m_Value(X)))) + if (Instruction *R = canonicalizeNegFPConstantsForOp(I, Op, X)) + I = R; + if (match(I, m_FSub(m_Value(X), m_OneUse(m_Instruction(Op))))) + if (Instruction *R = canonicalizeNegFPConstantsForOp(I, Op, X)) + I = R; + return I; } /// Inspect and optimize the given instruction. 
Note that erasing @@ -2042,16 +2086,16 @@ void ReassociatePass::OptimizeInst(Instruction *I) { I = NI; } - // Canonicalize negative constants out of expressions. - if (Instruction *Res = canonicalizeNegConstExpr(I)) - I = Res; - // Commute binary operators, to canonicalize the order of their operands. // This can potentially expose more CSE opportunities, and makes writing other // transformations simpler. if (I->isCommutative()) canonicalizeOperands(I); + // Canonicalize negative constants out of expressions. + if (Instruction *Res = canonicalizeNegFPConstants(I)) + I = Res; + // Don't optimize floating-point instructions unless they are 'fast'. if (I->getType()->isFPOrFPVectorTy() && !I->isFast()) return; diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index c358258d24cf..48bbdd8d1b33 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -172,8 +172,6 @@ public: bool runOnModule(Module &M) override { bool Changed = false; - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); for (Function &F : M) { // Nothing to do for declarations. if (F.isDeclaration() || F.empty()) @@ -186,6 +184,8 @@ public: TargetTransformInfo &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); Changed |= Impl.runOnFunction(F, DT, TTI, TLI); @@ -2530,7 +2530,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, // statepoints surviving this pass. This makes testing easier and the // resulting IR less confusing to human readers. DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); - bool MadeChange = removeUnreachableBlocks(F, nullptr, &DTU); + bool MadeChange = removeUnreachableBlocks(F, &DTU); // Flush the Dominator Tree. 
DTU.getDomTree(); diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index 4093e50ce899..10fbdc8aacd2 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -191,7 +191,7 @@ public: /// class SCCPSolver : public InstVisitor<SCCPSolver> { const DataLayout &DL; - const TargetLibraryInfo *TLI; + std::function<const TargetLibraryInfo &(Function &)> GetTLI; SmallPtrSet<BasicBlock *, 8> BBExecutable; // The BBs that are executable. DenseMap<Value *, LatticeVal> ValueState; // The state each value is in. // The state each parameter is in. @@ -268,8 +268,9 @@ public: return {A->second.DT, A->second.PDT, DomTreeUpdater::UpdateStrategy::Lazy}; } - SCCPSolver(const DataLayout &DL, const TargetLibraryInfo *tli) - : DL(DL), TLI(tli) {} + SCCPSolver(const DataLayout &DL, + std::function<const TargetLibraryInfo &(Function &)> GetTLI) + : DL(DL), GetTLI(std::move(GetTLI)) {} /// MarkBlockExecutable - This method can be used by clients to mark all of /// the blocks that are known to be intrinsically live in the processed unit. @@ -1290,7 +1291,7 @@ CallOverdefined: // If we can constant fold this, mark the result of the call as a // constant. if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), F, - Operands, TLI)) { + Operands, &GetTLI(*F))) { // call -> undef. if (isa<UndefValue>(C)) return; @@ -1465,7 +1466,24 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { } LatticeVal &LV = getValueState(&I); - if (!LV.isUnknown()) continue; + if (!LV.isUnknown()) + continue; + + // There are two reasons a call can have an undef result + // 1. It could be tracked. + // 2. It could be constant-foldable. + // Because of the way we solve return values, tracked calls must + // never be marked overdefined in ResolvedUndefsIn. 
+ if (CallSite CS = CallSite(&I)) { + if (Function *F = CS.getCalledFunction()) + if (TrackedRetVals.count(F)) + continue; + + // If the call is constant-foldable, we mark it overdefined because + // we do not know what return values are valid. + markOverdefined(&I); + return true; + } // extractvalue is safe; check here because the argument is a struct. if (isa<ExtractValueInst>(I)) @@ -1638,19 +1656,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::Call: case Instruction::Invoke: case Instruction::CallBr: - // There are two reasons a call can have an undef result - // 1. It could be tracked. - // 2. It could be constant-foldable. - // Because of the way we solve return values, tracked calls must - // never be marked overdefined in ResolvedUndefsIn. - if (Function *F = CallSite(&I).getCalledFunction()) - if (TrackedRetVals.count(F)) - break; - - // If the call is constant-foldable, we mark it overdefined because - // we do not know what return values are valid. - markOverdefined(&I); - return true; + llvm_unreachable("Call-like instructions should have be handled early"); default: // If we don't know what should happen here, conservatively mark it // overdefined. 
@@ -1751,7 +1757,7 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { [](const LatticeVal &LV) { return LV.isOverdefined(); })) return false; std::vector<Constant *> ConstVals; - auto *ST = dyn_cast<StructType>(V->getType()); + auto *ST = cast<StructType>(V->getType()); for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { LatticeVal V = IVs[i]; ConstVals.push_back(V.isConstant() @@ -1796,7 +1802,8 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { static bool runSCCP(Function &F, const DataLayout &DL, const TargetLibraryInfo *TLI) { LLVM_DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); - SCCPSolver Solver(DL, TLI); + SCCPSolver Solver( + DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; }); // Mark the first block of the function as being executable. Solver.MarkBlockExecutable(&F.front()); @@ -1891,7 +1898,7 @@ public: return false; const DataLayout &DL = F.getParent()->getDataLayout(); const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); return runSCCP(F, DL, TLI); } }; @@ -1924,6 +1931,27 @@ static void findReturnsToZap(Function &F, return; } + assert( + all_of(F.users(), + [&Solver](User *U) { + if (isa<Instruction>(U) && + !Solver.isBlockExecutable(cast<Instruction>(U)->getParent())) + return true; + // Non-callsite uses are not impacted by zapping. Also, constant + // uses (like blockaddresses) could stuck around, without being + // used in the underlying IR, meaning we do not have lattice + // values for them. 
+ if (!CallSite(U)) + return true; + if (U->getType()->isStructTy()) { + return all_of( + Solver.getStructLatticeValueFor(U), + [](const LatticeVal &LV) { return !LV.isOverdefined(); }); + } + return !Solver.getLatticeValueFor(U).isOverdefined(); + }) && + "We can only zap functions where all live users have a concrete value"); + for (BasicBlock &BB : F) { if (CallInst *CI = BB.getTerminatingMustTailCall()) { LLVM_DEBUG(dbgs() << "Can't zap return of the block due to present " @@ -1974,9 +2002,10 @@ static void forceIndeterminateEdge(Instruction* I, SCCPSolver &Solver) { } bool llvm::runIPSCCP( - Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI, + Module &M, const DataLayout &DL, + std::function<const TargetLibraryInfo &(Function &)> GetTLI, function_ref<AnalysisResultsForFn(Function &)> getAnalysis) { - SCCPSolver Solver(DL, TLI); + SCCPSolver Solver(DL, GetTLI); // Loop over all functions, marking arguments to those with their addresses // taken or that are external as overdefined. diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index 33f90d0b01e4..74b8ff913050 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -959,14 +959,16 @@ private: std::tie(UsedI, I) = Uses.pop_back_val(); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - Size = std::max(Size, DL.getTypeStoreSize(LI->getType())); + Size = std::max(Size, + DL.getTypeStoreSize(LI->getType()).getFixedSize()); continue; } if (StoreInst *SI = dyn_cast<StoreInst>(I)) { Value *Op = SI->getOperand(0); if (Op == UsedI) return SI; - Size = std::max(Size, DL.getTypeStoreSize(Op->getType())); + Size = std::max(Size, + DL.getTypeStoreSize(Op->getType()).getFixedSize()); continue; } @@ -1197,7 +1199,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { // TODO: Allow recursive phi users. // TODO: Allow stores. 
BasicBlock *BB = PN.getParent(); - unsigned MaxAlign = 0; + MaybeAlign MaxAlign; uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType()); APInt MaxSize(APWidth, 0); bool HaveLoad = false; @@ -1218,8 +1220,8 @@ static bool isSafePHIToSpeculate(PHINode &PN) { if (BBI->mayWriteToMemory()) return false; - uint64_t Size = DL.getTypeStoreSizeInBits(LI->getType()); - MaxAlign = std::max(MaxAlign, LI->getAlignment()); + uint64_t Size = DL.getTypeStoreSize(LI->getType()); + MaxAlign = std::max(MaxAlign, MaybeAlign(LI->getAlignment())); MaxSize = MaxSize.ult(Size) ? APInt(APWidth, Size) : MaxSize; HaveLoad = true; } @@ -1266,11 +1268,11 @@ static void speculatePHINodeLoads(PHINode &PN) { PHINode *NewPN = PHIBuilder.CreatePHI(LoadTy, PN.getNumIncomingValues(), PN.getName() + ".sroa.speculated"); - // Get the AA tags and alignment to use from one of the loads. It doesn't + // Get the AA tags and alignment to use from one of the loads. It does not // matter which one we get and if any differ. AAMDNodes AATags; SomeLoad->getAAMetadata(AATags); - unsigned Align = SomeLoad->getAlignment(); + const MaybeAlign Align = MaybeAlign(SomeLoad->getAlignment()); // Rewrite all loads of the PN to use the new PHI. while (!PN.use_empty()) { @@ -1338,11 +1340,11 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) { // Both operands to the select need to be dereferenceable, either // absolutely (e.g. allocas) or at this point because we can see other // accesses to it. 
- if (!isSafeToLoadUnconditionally(TValue, LI->getType(), LI->getAlignment(), - DL, LI)) + if (!isSafeToLoadUnconditionally(TValue, LI->getType(), + MaybeAlign(LI->getAlignment()), DL, LI)) return false; - if (!isSafeToLoadUnconditionally(FValue, LI->getType(), LI->getAlignment(), - DL, LI)) + if (!isSafeToLoadUnconditionally(FValue, LI->getType(), + MaybeAlign(LI->getAlignment()), DL, LI)) return false; } @@ -1368,8 +1370,8 @@ static void speculateSelectInstLoads(SelectInst &SI) { NumLoadsSpeculated += 2; // Transfer alignment and AA info if present. - TL->setAlignment(LI->getAlignment()); - FL->setAlignment(LI->getAlignment()); + TL->setAlignment(MaybeAlign(LI->getAlignment())); + FL->setAlignment(MaybeAlign(LI->getAlignment())); AAMDNodes Tags; LI->getAAMetadata(Tags); @@ -1888,6 +1890,14 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { bool HaveCommonEltTy = true; auto CheckCandidateType = [&](Type *Ty) { if (auto *VTy = dyn_cast<VectorType>(Ty)) { + // Return if bitcast to vectors is different for total size in bits. 
+ if (!CandidateTys.empty()) { + VectorType *V = CandidateTys[0]; + if (DL.getTypeSizeInBits(VTy) != DL.getTypeSizeInBits(V)) { + CandidateTys.clear(); + return; + } + } CandidateTys.push_back(VTy); if (!CommonEltTy) CommonEltTy = VTy->getElementType(); @@ -3110,7 +3120,7 @@ private: unsigned LoadAlign = LI->getAlignment(); if (!LoadAlign) LoadAlign = DL.getABITypeAlignment(LI->getType()); - LI->setAlignment(std::min(LoadAlign, getSliceAlign())); + LI->setAlignment(MaybeAlign(std::min(LoadAlign, getSliceAlign()))); continue; } if (StoreInst *SI = dyn_cast<StoreInst>(I)) { @@ -3119,7 +3129,7 @@ private: Value *Op = SI->getOperand(0); StoreAlign = DL.getABITypeAlignment(Op->getType()); } - SI->setAlignment(std::min(StoreAlign, getSliceAlign())); + SI->setAlignment(MaybeAlign(std::min(StoreAlign, getSliceAlign()))); continue; } diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index 869cf00e0a89..1d2e40bf62be 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -79,6 +79,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoopVersioningLICMPass(Registry); initializeLoopIdiomRecognizeLegacyPassPass(Registry); initializeLowerAtomicLegacyPassPass(Registry); + initializeLowerConstantIntrinsicsPass(Registry); initializeLowerExpectIntrinsicPass(Registry); initializeLowerGuardIntrinsicLegacyPassPass(Registry); initializeLowerWidenableConditionLegacyPassPass(Registry); @@ -123,6 +124,10 @@ void LLVMAddAggressiveDCEPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createAggressiveDCEPass()); } +void LLVMAddDCEPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createDeadCodeEliminationPass()); +} + void LLVMAddBitTrackingDCEPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createBitTrackingDCEPass()); } @@ -280,6 +285,10 @@ void LLVMAddBasicAliasAnalysisPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createBasicAAWrapperPass()); } +void LLVMAddLowerConstantIntrinsicsPass(LLVMPassManagerRef PM) { + 
unwrap(PM)->add(createLowerConstantIntrinsicsPass()); +} + void LLVMAddLowerExpectIntrinsicPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLowerExpectIntrinsicPass()); } diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index f6a12fb13142..41554fccdf08 100644 --- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -1121,7 +1121,7 @@ bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) { DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); bool Changed = false; for (BasicBlock &B : F) { for (BasicBlock::iterator I = B.begin(), IE = B.end(); I != IE;) diff --git a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index aeac6f548b32..ac832b9b4567 100644 --- a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -1909,7 +1909,7 @@ static void unswitchNontrivialInvariants( // We can only unswitch switches, conditional branches with an invariant // condition, or combining invariant conditions with an instruction. - assert((SI || BI->isConditional()) && + assert((SI || (BI && BI->isConditional())) && "Can only unswitch switches and conditional branch!"); bool FullUnswitch = SI || BI->getCondition() == Invariants[0]; if (FullUnswitch) @@ -2141,17 +2141,21 @@ static void unswitchNontrivialInvariants( buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction, *ClonedPH, *LoopPH); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); + + if (MSSAU) { + DT.applyUpdates(DTUpdates); + DTUpdates.clear(); + + // Perform MSSA cloning updates. 
+ for (auto &VMap : VMaps) + MSSAU->updateForClonedLoop(LBRPO, ExitBlocks, *VMap, + /*IgnoreIncomingWithNoClones=*/true); + MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMaps, DT); + } } // Apply the updates accumulated above to get an up-to-date dominator tree. DT.applyUpdates(DTUpdates); - if (!FullUnswitch && MSSAU) { - // Update MSSA for partial unswitch, after DT update. - SmallVector<CFGUpdate, 1> Updates; - Updates.push_back( - {cfg::UpdateKind::Insert, SplitBB, ClonedPHs.begin()->second}); - MSSAU->applyInsertUpdates(Updates, DT); - } // Now that we have an accurate dominator tree, first delete the dead cloned // blocks so that we can accurately build any cloned loops. It is important to @@ -2720,7 +2724,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, return Cost * (SuccessorsCount - 1); }; Instruction *BestUnswitchTI = nullptr; - int BestUnswitchCost; + int BestUnswitchCost = 0; ArrayRef<Value *> BestUnswitchInvariants; for (auto &TerminatorAndInvariants : UnswitchCandidates) { Instruction &TI = *TerminatorAndInvariants.first; @@ -2752,6 +2756,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, BestUnswitchInvariants = Invariants; } } + assert(BestUnswitchTI && "Failed to find loop unswitch candidate"); if (BestUnswitchCost >= UnswitchThreshold) { LLVM_DEBUG(dbgs() << "Cannot unswitch, lowest cost found: " @@ -2880,7 +2885,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); auto PA = getLoopPassPreservedAnalyses(); - if (EnableMSSALoopDependency) + if (AR.MSSA) PA.preserve<MemorySSAAnalysis>(); return PA; } diff --git a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp index c13fb3e04516..e6db11f47ead 100644 --- a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp +++ b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -777,8 +777,10 @@ static bool 
tryToSpeculatePHIs(SmallVectorImpl<PHINode *> &PNs, // speculation if the predecessor is an invoke. This doesn't seem // fundamental and we should probably be splitting critical edges // differently. - if (isa<IndirectBrInst>(PredBB->getTerminator()) || - isa<InvokeInst>(PredBB->getTerminator())) { + const auto *TermInst = PredBB->getTerminator(); + if (isa<IndirectBrInst>(TermInst) || + isa<InvokeInst>(TermInst) || + isa<CallBrInst>(TermInst)) { LLVM_DEBUG(dbgs() << " Invalid: predecessor terminator: " << PredBB->getName() << "\n"); return false; diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp index e5400676c7e8..9791cf41f621 100644 --- a/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -65,7 +65,7 @@ static cl::opt<bool> ForceSkipUniformRegions( static cl::opt<bool> RelaxedUniformRegions("structurizecfg-relaxed-uniform-regions", cl::Hidden, cl::desc("Allow relaxed uniform region checks"), - cl::init(false)); + cl::init(true)); // Definition of the complex types used in this pass. 
diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp index f0b79079d817..b27a36b67d62 100644 --- a/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -341,7 +341,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { const DataLayout &DL = L->getModule()->getDataLayout(); if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) || !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(), - L->getAlignment(), DL, L)) + MaybeAlign(L->getAlignment()), DL, L)) return false; } } diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 5fa371377c85..d85cc40c372a 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -170,7 +170,8 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) { bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, LoopInfo *LI, MemorySSAUpdater *MSSAU, - MemoryDependenceResults *MemDep) { + MemoryDependenceResults *MemDep, + bool PredecessorWithTwoSuccessors) { if (BB->hasAddressTaken()) return false; @@ -185,9 +186,24 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, return false; // Can't merge if there are multiple distinct successors. - if (PredBB->getUniqueSuccessor() != BB) + if (!PredecessorWithTwoSuccessors && PredBB->getUniqueSuccessor() != BB) return false; + // Currently only allow PredBB to have two predecessors, one being BB. + // Update BI to branch to BB's only successor instead of BB. 
+ BranchInst *PredBB_BI; + BasicBlock *NewSucc = nullptr; + unsigned FallThruPath; + if (PredecessorWithTwoSuccessors) { + if (!(PredBB_BI = dyn_cast<BranchInst>(PredBB->getTerminator()))) + return false; + BranchInst *BB_JmpI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BB_JmpI || !BB_JmpI->isUnconditional()) + return false; + NewSucc = BB_JmpI->getSuccessor(0); + FallThruPath = PredBB_BI->getSuccessor(0) == BB ? 0 : 1; + } + // Can't merge if there is PHI loop. for (PHINode &PN : BB->phis()) for (Value *IncValue : PN.incoming_values()) @@ -227,18 +243,39 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, Updates.push_back({DominatorTree::Delete, PredBB, BB}); } - if (MSSAU) - MSSAU->moveAllAfterMergeBlocks(BB, PredBB, &*(BB->begin())); + Instruction *PTI = PredBB->getTerminator(); + Instruction *STI = BB->getTerminator(); + Instruction *Start = &*BB->begin(); + // If there's nothing to move, mark the starting instruction as the last + // instruction in the block. + if (Start == STI) + Start = PTI; + + // Move all definitions in the successor to the predecessor... + PredBB->getInstList().splice(PTI->getIterator(), BB->getInstList(), + BB->begin(), STI->getIterator()); - // Delete the unconditional branch from the predecessor... - PredBB->getInstList().pop_back(); + if (MSSAU) + MSSAU->moveAllAfterMergeBlocks(BB, PredBB, Start); // Make all PHI nodes that referred to BB now refer to Pred as their // source... BB->replaceAllUsesWith(PredBB); - // Move all definitions in the successor to the predecessor... - PredBB->getInstList().splice(PredBB->end(), BB->getInstList()); + if (PredecessorWithTwoSuccessors) { + // Delete the unconditional branch from BB. + BB->getInstList().pop_back(); + + // Update branch in the predecessor. + PredBB_BI->setSuccessor(FallThruPath, NewSucc); + } else { + // Delete the unconditional branch from the predecessor. + PredBB->getInstList().pop_back(); + + // Move terminator instruction. 
+ PredBB->getInstList().splice(PredBB->end(), BB->getInstList()); + } + // Add unreachable to now empty BB. new UnreachableInst(BB->getContext(), BB); // Eliminate duplicate dbg.values describing the entry PHI node post-splice. @@ -274,11 +311,10 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, "applying corresponding DTU updates."); DTU->applyUpdatesPermissive(Updates); DTU->deleteBB(BB); - } - - else { + } else { BB->eraseFromParent(); // Nuke BB if DTU is nullptr. } + return true; } @@ -365,11 +401,13 @@ llvm::SplitAllCriticalEdges(Function &F, BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI, - MemorySSAUpdater *MSSAU) { + MemorySSAUpdater *MSSAU, const Twine &BBName) { BasicBlock::iterator SplitIt = SplitPt->getIterator(); while (isa<PHINode>(SplitIt) || SplitIt->isEHPad()) ++SplitIt; - BasicBlock *New = Old->splitBasicBlock(SplitIt, Old->getName()+".split"); + std::string Name = BBName.str(); + BasicBlock *New = Old->splitBasicBlock( + SplitIt, Name.empty() ? Old->getName() + ".split" : Name); // The new block lives in whichever loop the old one did. This preserves // LCSSA as well, because we force the split point to be after any PHI nodes. 
diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index 27f110e24f9c..71316ce8f758 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -88,6 +88,14 @@ static bool setDoesNotCapture(Function &F, unsigned ArgNo) { return true; } +static bool setDoesNotAlias(Function &F, unsigned ArgNo) { + if (F.hasParamAttribute(ArgNo, Attribute::NoAlias)) + return false; + F.addParamAttr(ArgNo, Attribute::NoAlias); + ++NumNoAlias; + return true; +} + static bool setOnlyReadsMemory(Function &F, unsigned ArgNo) { if (F.hasParamAttribute(ArgNo, Attribute::ReadOnly)) return false; @@ -175,6 +183,9 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { return Changed; case LibFunc_strcpy: case LibFunc_strncpy: + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotAlias(F, 1); + LLVM_FALLTHROUGH; case LibFunc_strcat: case LibFunc_strncat: Changed |= setReturnedArg(F, 0); @@ -249,12 +260,14 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { case LibFunc_sprintf: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 0); + Changed |= setDoesNotAlias(F, 0); Changed |= setDoesNotCapture(F, 1); Changed |= setOnlyReadsMemory(F, 1); return Changed; case LibFunc_snprintf: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 0); + Changed |= setDoesNotAlias(F, 0); Changed |= setDoesNotCapture(F, 2); Changed |= setOnlyReadsMemory(F, 2); return Changed; @@ -291,11 +304,23 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setDoesNotCapture(F, 1); return Changed; case LibFunc_memcpy: + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotAlias(F, 1); + Changed |= setReturnedArg(F, 0); + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; case LibFunc_memmove: Changed |= setReturnedArg(F, 0); - LLVM_FALLTHROUGH; + 
Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; case LibFunc_mempcpy: case LibFunc_memccpy: + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotAlias(F, 1); Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); Changed |= setOnlyReadsMemory(F, 1); @@ -760,9 +785,8 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { } } -bool llvm::hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, - LibFunc DoubleFn, LibFunc FloatFn, - LibFunc LongDoubleFn) { +bool llvm::hasFloatFn(const TargetLibraryInfo *TLI, Type *Ty, + LibFunc DoubleFn, LibFunc FloatFn, LibFunc LongDoubleFn) { switch (Ty->getTypeID()) { case Type::HalfTyID: return false; @@ -775,10 +799,10 @@ bool llvm::hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, } } -StringRef llvm::getUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, - LibFunc DoubleFn, LibFunc FloatFn, - LibFunc LongDoubleFn) { - assert(hasUnaryFloatFn(TLI, Ty, DoubleFn, FloatFn, LongDoubleFn) && +StringRef llvm::getFloatFnName(const TargetLibraryInfo *TLI, Type *Ty, + LibFunc DoubleFn, LibFunc FloatFn, + LibFunc LongDoubleFn) { + assert(hasFloatFn(TLI, Ty, DoubleFn, FloatFn, LongDoubleFn) && "Cannot get name for unavailable function!"); switch (Ty->getTypeID()) { @@ -827,6 +851,12 @@ Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, B.getInt8PtrTy(), castToCStr(Ptr, B), B, TLI); } +Value *llvm::emitStrDup(Value *Ptr, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + return emitLibCall(LibFunc_strdup, B.getInt8PtrTy(), B.getInt8PtrTy(), + castToCStr(Ptr, B), B, TLI); +} + Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B, const TargetLibraryInfo *TLI) { Type *I8Ptr = B.getInt8PtrTy(); @@ -1045,24 +1075,28 @@ Value *llvm::emitUnaryFloatFnCall(Value *Op, const TargetLibraryInfo *TLI, LibFunc LongDoubleFn, IRBuilder<> &B, const AttributeList &Attrs) { // Get the name of the 
function according to TLI. - StringRef Name = getUnaryFloatFn(TLI, Op->getType(), - DoubleFn, FloatFn, LongDoubleFn); + StringRef Name = getFloatFnName(TLI, Op->getType(), + DoubleFn, FloatFn, LongDoubleFn); return emitUnaryFloatFnCallHelper(Op, Name, B, Attrs); } -Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, - IRBuilder<> &B, const AttributeList &Attrs) { +static Value *emitBinaryFloatFnCallHelper(Value *Op1, Value *Op2, + StringRef Name, IRBuilder<> &B, + const AttributeList &Attrs) { assert((Name != "") && "Must specify Name to emitBinaryFloatFnCall"); - SmallString<20> NameBuffer; - appendTypeSuffix(Op1, Name, NameBuffer); - Module *M = B.GetInsertBlock()->getModule(); - FunctionCallee Callee = M->getOrInsertFunction( - Name, Op1->getType(), Op1->getType(), Op2->getType()); - CallInst *CI = B.CreateCall(Callee, {Op1, Op2}, Name); - CI->setAttributes(Attrs); + FunctionCallee Callee = M->getOrInsertFunction(Name, Op1->getType(), + Op1->getType(), Op2->getType()); + CallInst *CI = B.CreateCall(Callee, { Op1, Op2 }, Name); + + // The incoming attribute set may have come from a speculatable intrinsic, but + // is being replaced with a library call which is not allowed to be + // speculatable. 
+ CI->setAttributes(Attrs.removeAttribute(B.getContext(), + AttributeList::FunctionIndex, + Attribute::Speculatable)); if (const Function *F = dyn_cast<Function>(Callee.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1070,6 +1104,28 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, return CI; } +Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, + IRBuilder<> &B, const AttributeList &Attrs) { + assert((Name != "") && "Must specify Name to emitBinaryFloatFnCall"); + + SmallString<20> NameBuffer; + appendTypeSuffix(Op1, Name, NameBuffer); + + return emitBinaryFloatFnCallHelper(Op1, Op2, Name, B, Attrs); +} + +Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, + const TargetLibraryInfo *TLI, + LibFunc DoubleFn, LibFunc FloatFn, + LibFunc LongDoubleFn, IRBuilder<> &B, + const AttributeList &Attrs) { + // Get the name of the function according to TLI. + StringRef Name = getFloatFnName(TLI, Op1->getType(), + DoubleFn, FloatFn, LongDoubleFn); + + return emitBinaryFloatFnCallHelper(Op1, Op2, Name, B, Attrs); +} + Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc_putchar)) diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp b/lib/Transforms/Utils/BypassSlowDivision.cpp index df299f673f65..9a6761040bd8 100644 --- a/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -448,13 +448,17 @@ bool llvm::bypassSlowDivision(BasicBlock *BB, DivCacheTy PerBBDivCache; bool MadeChange = false; - Instruction* Next = &*BB->begin(); + Instruction *Next = &*BB->begin(); while (Next != nullptr) { // We may add instructions immediately after I, but we want to skip over // them. - Instruction* I = Next; + Instruction *I = Next; Next = Next->getNextNode(); + // Ignore dead code to save time and avoid bugs. 
+ if (I->hasNUses(0)) + continue; + FastDivInsertionTask Task(I, BypassWidths); if (Value *Replacement = Task.getReplacement(PerBBDivCache)) { I->replaceAllUsesWith(Replacement); diff --git a/lib/Transforms/Utils/CanonicalizeAliases.cpp b/lib/Transforms/Utils/CanonicalizeAliases.cpp index 455fcbb1cf98..3c7c8d872595 100644 --- a/lib/Transforms/Utils/CanonicalizeAliases.cpp +++ b/lib/Transforms/Utils/CanonicalizeAliases.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" using namespace llvm; diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 1026c9d37038..75e8963303c2 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -210,6 +210,21 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, RemapInstruction(&II, VMap, ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, TypeMapper, Materializer); + + // Register all DICompileUnits of the old parent module in the new parent module + auto* OldModule = OldFunc->getParent(); + auto* NewModule = NewFunc->getParent(); + if (OldModule && NewModule && OldModule != NewModule && DIFinder.compile_unit_count()) { + auto* NMD = NewModule->getOrInsertNamedMetadata("llvm.dbg.cu"); + // Avoid multiple insertions of the same DICompileUnit to NMD. 
+ SmallPtrSet<const void*, 8> Visited; + for (auto* Operand : NMD->operands()) + Visited.insert(Operand); + for (auto* Unit : DIFinder.compile_units()) + // VMap.MD()[Unit] == Unit + if (Visited.insert(Unit).second) + NMD->addOperand(Unit); + } } /// Return a copy of the specified function and add it to that function's diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index 7ddf59becba9..2c8c3abb2922 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -48,7 +48,7 @@ std::unique_ptr<Module> llvm::CloneModule( function_ref<bool(const GlobalValue *)> ShouldCloneDefinition) { // First off, we need to create the new module. std::unique_ptr<Module> New = - llvm::make_unique<Module>(M.getModuleIdentifier(), M.getContext()); + std::make_unique<Module>(M.getModuleIdentifier(), M.getContext()); New->setSourceFileName(M.getSourceFileName()); New->setDataLayout(M.getDataLayout()); New->setTargetTriple(M.getTargetTriple()); @@ -181,13 +181,25 @@ std::unique_ptr<Module> llvm::CloneModule( } // And named metadata.... + const auto* LLVM_DBG_CU = M.getNamedMetadata("llvm.dbg.cu"); for (Module::const_named_metadata_iterator I = M.named_metadata_begin(), E = M.named_metadata_end(); I != E; ++I) { const NamedMDNode &NMD = *I; NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName()); - for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i) - NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap)); + if (&NMD == LLVM_DBG_CU) { + // Do not insert duplicate operands. 
+ SmallPtrSet<const void*, 8> Visited; + for (const auto* Operand : NewNMD->operands()) + Visited.insert(Operand); + for (const auto* Operand : NMD.operands()) { + auto* MappedOperand = MapMetadata(Operand, VMap); + if (Visited.insert(MappedOperand).second) + NewNMD->addOperand(MappedOperand); + } + } else + for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i) + NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap)); } return New; diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index fa6d3f8ae873..0298ff9a395f 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -293,10 +293,8 @@ static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) { CommonExitBlock = Succ; continue; } - if (CommonExitBlock == Succ) - continue; - - return true; + if (CommonExitBlock != Succ) + return true; } return false; }; @@ -307,52 +305,79 @@ static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) { return CommonExitBlock; } -bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( - Instruction *Addr) const { - AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets()); - Function *Func = (*Blocks.begin())->getParent(); - for (BasicBlock &BB : *Func) { - if (Blocks.count(&BB)) - continue; - for (Instruction &II : BB) { - if (isa<DbgInfoIntrinsic>(II)) - continue; +CodeExtractorAnalysisCache::CodeExtractorAnalysisCache(Function &F) { + for (BasicBlock &BB : F) { + for (Instruction &II : BB.instructionsWithoutDebug()) + if (auto *AI = dyn_cast<AllocaInst>(&II)) + Allocas.push_back(AI); - unsigned Opcode = II.getOpcode(); - Value *MemAddr = nullptr; - switch (Opcode) { - case Instruction::Store: - case Instruction::Load: { - if (Opcode == Instruction::Store) { - StoreInst *SI = cast<StoreInst>(&II); - MemAddr = SI->getPointerOperand(); - } else { - LoadInst *LI = cast<LoadInst>(&II); - MemAddr = LI->getPointerOperand(); - } - // Global 
variable can not be aliased with locals. - if (dyn_cast<Constant>(MemAddr)) - break; - Value *Base = MemAddr->stripInBoundsConstantOffsets(); - if (!isa<AllocaInst>(Base) || Base == AI) - return false; + findSideEffectInfoForBlock(BB); + } +} + +void CodeExtractorAnalysisCache::findSideEffectInfoForBlock(BasicBlock &BB) { + for (Instruction &II : BB.instructionsWithoutDebug()) { + unsigned Opcode = II.getOpcode(); + Value *MemAddr = nullptr; + switch (Opcode) { + case Instruction::Store: + case Instruction::Load: { + if (Opcode == Instruction::Store) { + StoreInst *SI = cast<StoreInst>(&II); + MemAddr = SI->getPointerOperand(); + } else { + LoadInst *LI = cast<LoadInst>(&II); + MemAddr = LI->getPointerOperand(); + } + // Global variable can not be aliased with locals. + if (dyn_cast<Constant>(MemAddr)) break; + Value *Base = MemAddr->stripInBoundsConstantOffsets(); + if (!isa<AllocaInst>(Base)) { + SideEffectingBlocks.insert(&BB); + return; } - default: { - IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II); - if (IntrInst) { - if (IntrInst->isLifetimeStartOrEnd()) - break; - return false; - } - // Treat all the other cases conservatively if it has side effects. - if (II.mayHaveSideEffects()) - return false; + BaseMemAddrs[&BB].insert(Base); + break; + } + default: { + IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II); + if (IntrInst) { + if (IntrInst->isLifetimeStartOrEnd()) + break; + SideEffectingBlocks.insert(&BB); + return; } + // Treat all the other cases conservatively if it has side effects. 
+ if (II.mayHaveSideEffects()) { + SideEffectingBlocks.insert(&BB); + return; } } + } } +} +bool CodeExtractorAnalysisCache::doesBlockContainClobberOfAddr( + BasicBlock &BB, AllocaInst *Addr) const { + if (SideEffectingBlocks.count(&BB)) + return true; + auto It = BaseMemAddrs.find(&BB); + if (It != BaseMemAddrs.end()) + return It->second.count(Addr); + return false; +} + +bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( + const CodeExtractorAnalysisCache &CEAC, Instruction *Addr) const { + AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets()); + Function *Func = (*Blocks.begin())->getParent(); + for (BasicBlock &BB : *Func) { + if (Blocks.count(&BB)) + continue; + if (CEAC.doesBlockContainClobberOfAddr(BB, AI)) + return false; + } return true; } @@ -415,7 +440,8 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { // outline region. If there are not other untracked uses of the address, return // the pair of markers if found; otherwise return a pair of nullptr. CodeExtractor::LifetimeMarkerInfo -CodeExtractor::getLifetimeMarkers(Instruction *Addr, +CodeExtractor::getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC, + Instruction *Addr, BasicBlock *ExitBlock) const { LifetimeMarkerInfo Info; @@ -447,7 +473,7 @@ CodeExtractor::getLifetimeMarkers(Instruction *Addr, Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd); // Do legality check. if ((Info.SinkLifeStart || Info.HoistLifeEnd) && - !isLegalToShrinkwrapLifetimeMarkers(Addr)) + !isLegalToShrinkwrapLifetimeMarkers(CEAC, Addr)) return {}; // Check to see if we have a place to do hoisting, if not, bail. 
@@ -457,7 +483,8 @@ CodeExtractor::getLifetimeMarkers(Instruction *Addr, return Info; } -void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, +void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC, + ValueSet &SinkCands, ValueSet &HoistCands, BasicBlock *&ExitBlock) const { Function *Func = (*Blocks.begin())->getParent(); ExitBlock = getCommonExitBlock(Blocks); @@ -478,74 +505,104 @@ void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, return true; }; - for (BasicBlock &BB : *Func) { - if (Blocks.count(&BB)) + // Look up allocas in the original function in CodeExtractorAnalysisCache, as + // this is much faster than walking all the instructions. + for (AllocaInst *AI : CEAC.getAllocas()) { + BasicBlock *BB = AI->getParent(); + if (Blocks.count(BB)) continue; - for (Instruction &II : BB) { - auto *AI = dyn_cast<AllocaInst>(&II); - if (!AI) - continue; - LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(AI, ExitBlock); - bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo); - if (Moved) { - LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n"); - SinkCands.insert(AI); - continue; - } + // As a prior call to extractCodeRegion() may have shrinkwrapped the alloca, + // check whether it is actually still in the original function. + Function *AIFunc = BB->getParent(); + if (AIFunc != Func) + continue; - // Follow any bitcasts. 
- SmallVector<Instruction *, 2> Bitcasts; - SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo; - for (User *U : AI->users()) { - if (U->stripInBoundsConstantOffsets() == AI) { - Instruction *Bitcast = cast<Instruction>(U); - LifetimeMarkerInfo LMI = getLifetimeMarkers(Bitcast, ExitBlock); - if (LMI.LifeStart) { - Bitcasts.push_back(Bitcast); - BitcastLifetimeInfo.push_back(LMI); - continue; - } - } + LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(CEAC, AI, ExitBlock); + bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo); + if (Moved) { + LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n"); + SinkCands.insert(AI); + continue; + } - // Found unknown use of AI. - if (!definedInRegion(Blocks, U)) { - Bitcasts.clear(); - break; + // Follow any bitcasts. + SmallVector<Instruction *, 2> Bitcasts; + SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo; + for (User *U : AI->users()) { + if (U->stripInBoundsConstantOffsets() == AI) { + Instruction *Bitcast = cast<Instruction>(U); + LifetimeMarkerInfo LMI = getLifetimeMarkers(CEAC, Bitcast, ExitBlock); + if (LMI.LifeStart) { + Bitcasts.push_back(Bitcast); + BitcastLifetimeInfo.push_back(LMI); + continue; } } - // Either no bitcasts reference the alloca or there are unknown uses. - if (Bitcasts.empty()) - continue; + // Found unknown use of AI. 
+ if (!definedInRegion(Blocks, U)) { + Bitcasts.clear(); + break; + } + } - LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n"); - SinkCands.insert(AI); - for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) { - Instruction *BitcastAddr = Bitcasts[I]; - const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I]; - assert(LMI.LifeStart && - "Unsafe to sink bitcast without lifetime markers"); - moveOrIgnoreLifetimeMarkers(LMI); - if (!definedInRegion(Blocks, BitcastAddr)) { - LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr - << "\n"); - SinkCands.insert(BitcastAddr); - } + // Either no bitcasts reference the alloca or there are unknown uses. + if (Bitcasts.empty()) + continue; + + LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n"); + SinkCands.insert(AI); + for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) { + Instruction *BitcastAddr = Bitcasts[I]; + const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I]; + assert(LMI.LifeStart && + "Unsafe to sink bitcast without lifetime markers"); + moveOrIgnoreLifetimeMarkers(LMI); + if (!definedInRegion(Blocks, BitcastAddr)) { + LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr + << "\n"); + SinkCands.insert(BitcastAddr); } } } } +bool CodeExtractor::isEligible() const { + if (Blocks.empty()) + return false; + BasicBlock *Header = *Blocks.begin(); + Function *F = Header->getParent(); + + // For functions with varargs, check that varargs handling is only done in the + // outlined function, i.e vastart and vaend are only used in outlined blocks. 
+ if (AllowVarArgs && F->getFunctionType()->isVarArg()) { + auto containsVarArgIntrinsic = [](const Instruction &I) { + if (const CallInst *CI = dyn_cast<CallInst>(&I)) + if (const Function *Callee = CI->getCalledFunction()) + return Callee->getIntrinsicID() == Intrinsic::vastart || + Callee->getIntrinsicID() == Intrinsic::vaend; + return false; + }; + + for (auto &BB : *F) { + if (Blocks.count(&BB)) + continue; + if (llvm::any_of(BB, containsVarArgIntrinsic)) + return false; + } + } + return true; +} + void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs, const ValueSet &SinkCands) const { for (BasicBlock *BB : Blocks) { // If a used value is defined outside the region, it's an input. If an // instruction is used outside the region, it's an output. for (Instruction &II : *BB) { - for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE; - ++OI) { - Value *V = *OI; + for (auto &OI : II.operands()) { + Value *V = OI; if (!SinkCands.count(V) && definedInCaller(Blocks, V)) Inputs.insert(V); } @@ -904,12 +961,12 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // within the new function. This must be done before we lose track of which // blocks were originally in the code region. 
std::vector<User *> Users(header->user_begin(), header->user_end()); - for (unsigned i = 0, e = Users.size(); i != e; ++i) + for (auto &U : Users) // The BasicBlock which contains the branch is not in the region // modify the branch target to a new block - if (Instruction *I = dyn_cast<Instruction>(Users[i])) - if (I->isTerminator() && !Blocks.count(I->getParent()) && - I->getParent()->getParent() == oldFunction) + if (Instruction *I = dyn_cast<Instruction>(U)) + if (I->isTerminator() && I->getFunction() == oldFunction && + !Blocks.count(I->getParent())) I->replaceUsesOfWith(header, newHeader); return newFunction; @@ -1277,13 +1334,6 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { // Insert this basic block into the new function newBlocks.push_back(Block); - - // Remove @llvm.assume calls that were moved to the new function from the - // old function's assumption cache. - if (AC) - for (auto &I : *Block) - if (match(&I, m_Intrinsic<Intrinsic::assume>())) - AC->unregisterAssumption(cast<CallInst>(&I)); } } @@ -1332,7 +1382,8 @@ void CodeExtractor::calculateNewCallTerminatorWeights( MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); } -Function *CodeExtractor::extractCodeRegion() { +Function * +CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) { if (!isEligible()) return nullptr; @@ -1341,27 +1392,6 @@ Function *CodeExtractor::extractCodeRegion() { BasicBlock *header = *Blocks.begin(); Function *oldFunction = header->getParent(); - // For functions with varargs, check that varargs handling is only done in the - // outlined function, i.e vastart and vaend are only used in outlined blocks. 
- if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) { - auto containsVarArgIntrinsic = [](Instruction &I) { - if (const CallInst *CI = dyn_cast<CallInst>(&I)) - if (const Function *F = CI->getCalledFunction()) - return F->getIntrinsicID() == Intrinsic::vastart || - F->getIntrinsicID() == Intrinsic::vaend; - return false; - }; - - for (auto &BB : *oldFunction) { - if (Blocks.count(&BB)) - continue; - if (llvm::any_of(BB, containsVarArgIntrinsic)) - return nullptr; - } - } - ValueSet inputs, outputs, SinkingCands, HoistingCands; - BasicBlock *CommonExit = nullptr; - // Calculate the entry frequency of the new function before we change the root // block. BlockFrequency EntryFreq; @@ -1375,6 +1405,15 @@ Function *CodeExtractor::extractCodeRegion() { } } + if (AC) { + // Remove @llvm.assume calls that were moved to the new function from the + // old function's assumption cache. + for (BasicBlock *Block : Blocks) + for (auto &I : *Block) + if (match(&I, m_Intrinsic<Intrinsic::assume>())) + AC->unregisterAssumption(cast<CallInst>(&I)); + } + // If we have any return instructions in the region, split those blocks so // that the return is not in the region. splitReturnBlocks(); @@ -1428,7 +1467,9 @@ Function *CodeExtractor::extractCodeRegion() { } newFuncRoot->getInstList().push_back(BranchI); - findAllocas(SinkingCands, HoistingCands, CommonExit); + ValueSet inputs, outputs, SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; + findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); assert(HoistingCands.empty() || CommonExit); // Find inputs to, outputs from the code region. 
@@ -1563,5 +1604,17 @@ Function *CodeExtractor::extractCodeRegion() { }); LLVM_DEBUG(if (verifyFunction(*oldFunction)) report_fatal_error("verification of oldFunction failed!")); + LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, AC)) + report_fatal_error("Stale Asumption cache for old Function!")); return newFunction; } + +bool CodeExtractor::verifyAssumptionCache(const Function& F, + AssumptionCache *AC) { + for (auto AssumeVH : AC->assumptions()) { + CallInst *I = cast<CallInst>(AssumeVH); + if (I->getFunction() != &F) + return true; + } + return false; +} diff --git a/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 4aa40eeadda4..57e2ff0251a9 100644 --- a/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -24,7 +24,7 @@ static void insertCall(Function &CurFn, StringRef Func, if (Func == "mcount" || Func == ".mcount" || - Func == "\01__gnu_mcount_nc" || + Func == "llvm.arm.gnu.eabi.mcount" || Func == "\01_mcount" || Func == "\01mcount" || Func == "__mcount" || diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp index 0e203f4e075d..ad36790b8c6a 100644 --- a/lib/Transforms/Utils/Evaluator.cpp +++ b/lib/Transforms/Utils/Evaluator.cpp @@ -469,7 +469,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, return false; // Cannot handle array allocs. } Type *Ty = AI->getAllocatedType(); - AllocaTmps.push_back(llvm::make_unique<GlobalVariable>( + AllocaTmps.push_back(std::make_unique<GlobalVariable>( Ty, false, GlobalValue::InternalLinkage, UndefValue::get(Ty), AI->getName(), /*TLMode=*/GlobalValue::NotThreadLocal, AI->getType()->getPointerAddressSpace())); diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index 0c52e6f3703b..893f23eb6048 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -67,7 +67,7 @@ public: /// Before: /// ...... 
/// %cmp10 = fcmp une float %tmp1, %tmp2 -/// br i1 %cmp1, label %if.then, label %lor.rhs +/// br i1 %cmp10, label %if.then, label %lor.rhs /// /// lor.rhs: /// ...... @@ -251,8 +251,8 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { bool EverChanged = false; for (; CurrBlock != FirstCondBlock; CurrBlock = CurrBlock->getSinglePredecessor()) { - BranchInst *BI = dyn_cast<BranchInst>(CurrBlock->getTerminator()); - CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition()); + auto *BI = cast<BranchInst>(CurrBlock->getTerminator()); + auto *CI = dyn_cast<CmpInst>(BI->getCondition()); if (!CI) continue; @@ -278,7 +278,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { // Do the transformation. BasicBlock *CB; - BranchInst *PBI = dyn_cast<BranchInst>(FirstCondBlock->getTerminator()); + BranchInst *PBI = cast<BranchInst>(FirstCondBlock->getTerminator()); bool Iteration = true; IRBuilder<>::InsertPointGuard Guard(Builder); Value *PC = PBI->getCondition(); @@ -444,7 +444,7 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { FirstEntryBlock->getInstList().pop_back(); FirstEntryBlock->getInstList() .splice(FirstEntryBlock->end(), SecondEntryBlock->getInstList()); - BranchInst *PBI = dyn_cast<BranchInst>(FirstEntryBlock->getTerminator()); + BranchInst *PBI = cast<BranchInst>(FirstEntryBlock->getTerminator()); Value *CC = PBI->getCondition(); BasicBlock *SaveInsertBB = Builder.GetInsertBlock(); BasicBlock::iterator SaveInsertPt = Builder.GetInsertPoint(); @@ -453,6 +453,16 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { PBI->replaceUsesOfWith(CC, NC); Builder.SetInsertPoint(SaveInsertBB, SaveInsertPt); + // Handle PHI node to replace its predecessors to FirstEntryBlock. 
+ for (BasicBlock *Succ : successors(PBI)) { + for (PHINode &Phi : Succ->phis()) { + for (unsigned i = 0, e = Phi.getNumIncomingValues(); i != e; ++i) { + if (Phi.getIncomingBlock(i) == SecondEntryBlock) + Phi.setIncomingBlock(i, FirstEntryBlock); + } + } + } + // Remove IfTrue1 if (IfTrue1 != FirstEntryBlock) { IfTrue1->dropAllReferences(); diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp index c9cc0990f237..76b4635ad501 100644 --- a/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -210,7 +210,7 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { if (Function *F = dyn_cast<Function>(&GV)) { if (!F->isDeclaration()) { for (auto &S : VI.getSummaryList()) { - FunctionSummary *FS = dyn_cast<FunctionSummary>(S->getBaseObject()); + auto *FS = cast<FunctionSummary>(S->getBaseObject()); if (FS->modulePath() == M.getModuleIdentifier()) { F->setEntryCount(Function::ProfileCount(FS->entryCount(), Function::PCT_Synthetic)); diff --git a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp index 8041e66e6c4c..ea93f99d69e3 100644 --- a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp +++ b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp @@ -25,8 +25,8 @@ ImportedFunctionsInliningStatistics::createInlineGraphNode(const Function &F) { auto &ValueLookup = NodesMap[F.getName()]; if (!ValueLookup) { - ValueLookup = llvm::make_unique<InlineGraphNode>(); - ValueLookup->Imported = F.getMetadata("thinlto_src_module") != nullptr; + ValueLookup = std::make_unique<InlineGraphNode>(); + ValueLookup->Imported = F.hasMetadata("thinlto_src_module"); } return *ValueLookup; } @@ -64,7 +64,7 @@ void ImportedFunctionsInliningStatistics::setModuleInfo(const Module &M) { if (F.isDeclaration()) continue; AllFunctions++; - ImportedFunctions += 
int(F.getMetadata("thinlto_src_module") != nullptr); + ImportedFunctions += int(F.hasMetadata("thinlto_src_module")); } } static std::string getStatString(const char *Msg, int32_t Fraction, int32_t All, diff --git a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 8c67d1dc6eb3..ed28fffc22b5 100644 --- a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -533,7 +533,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI, } bool LibCallsShrinkWrapLegacyPass::runOnFunction(Function &F) { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; return runImpl(F, TLI, DT); diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index 39b6b889f91c..5bcd05757ec1 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -324,8 +324,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, Value *Address = IBI->getAddress(); IBI->eraseFromParent(); if (DeleteDeadConditions) + // Delete pointer cast instructions. RecursivelyDeleteTriviallyDeadInstructions(Address, TLI); + // Also zap the blockaddress constant if there are no users remaining, + // otherwise the destination is still marked as having its address taken. + if (BA->use_empty()) + BA->destroyConstant(); + // If we didn't find our destination in the IBI successor list, then we // have undefined behavior. Replace the unconditional branch with an // 'unreachable' instruction. @@ -633,17 +639,6 @@ bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB, // Control Flow Graph Restructuring. // -/// RemovePredecessorAndSimplify - Like BasicBlock::removePredecessor, this -/// method is called when we're about to delete Pred as a predecessor of BB. 
If -/// BB contains any PHI nodes, this drops the entries in the PHI nodes for Pred. -/// -/// Unlike the removePredecessor method, this attempts to simplify uses of PHI -/// nodes that collapse into identity values. For example, if we have: -/// x = phi(1, 0, 0, 0) -/// y = and x, z -/// -/// .. and delete the predecessor corresponding to the '1', this will attempt to -/// recursively fold the and to 0. void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred, DomTreeUpdater *DTU) { // This only adjusts blocks with PHI nodes. @@ -672,10 +667,6 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred, DTU->applyUpdatesPermissive({{DominatorTree::Delete, Pred, BB}}); } -/// MergeBasicBlockIntoOnlyPred - DestBB is a block with one predecessor and its -/// predecessor is known to have one successor (DestBB!). Eliminate the edge -/// between them, moving the instructions in the predecessor into DestBB and -/// deleting the predecessor block. void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DomTreeUpdater *DTU) { @@ -755,15 +746,14 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, } } -/// CanMergeValues - Return true if we can choose one of these values to use -/// in place of the other. Note that we will always choose the non-undef -/// value to keep. +/// Return true if we can choose one of these values to use in place of the +/// other. Note that we will always choose the non-undef value to keep. static bool CanMergeValues(Value *First, Value *Second) { return First == Second || isa<UndefValue>(First) || isa<UndefValue>(Second); } -/// CanPropagatePredecessorsForPHIs - Return true if we can fold BB, an -/// almost-empty BB ending in an unconditional branch to Succ, into Succ. +/// Return true if we can fold BB, an almost-empty BB ending in an unconditional +/// branch to Succ, into Succ. /// /// Assumption: Succ is the single successor for BB. 
static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { @@ -956,11 +946,6 @@ static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB, replaceUndefValuesInPhi(PN, IncomingValues); } -/// TryToSimplifyUncondBranchFromEmptyBlock - BB is known to contain an -/// unconditional branch, and contains no instructions other than PHI nodes, -/// potential side-effect free intrinsics and the branch. If possible, -/// eliminate BB by rewriting all the predecessors to branch to the successor -/// block and return true. If we can't transform, return false. bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, DomTreeUpdater *DTU) { assert(BB != &BB->getParent()->getEntryBlock() && @@ -1088,10 +1073,6 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, return true; } -/// EliminateDuplicatePHINodes - Check for and eliminate duplicate PHI -/// nodes in this block. This doesn't try to be clever about PHI nodes -/// which differ only in the order of the incoming values, but instcombine -/// orders them so it usually won't matter. bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { // This implementation doesn't currently consider undef operands // specially. Theoretically, two phis which are identical except for @@ -1151,10 +1132,10 @@ bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { /// often possible though. If alignment is important, a more reliable approach /// is to simply align all global variables and allocation instructions to /// their preferred alignment from the beginning. 
-static unsigned enforceKnownAlignment(Value *V, unsigned Align, +static unsigned enforceKnownAlignment(Value *V, unsigned Alignment, unsigned PrefAlign, const DataLayout &DL) { - assert(PrefAlign > Align); + assert(PrefAlign > Alignment); V = V->stripPointerCasts(); @@ -1165,36 +1146,36 @@ static unsigned enforceKnownAlignment(Value *V, unsigned Align, // stripPointerCasts recurses through infinite layers of bitcasts, // while computeKnownBits is not allowed to traverse more than 6 // levels. - Align = std::max(AI->getAlignment(), Align); - if (PrefAlign <= Align) - return Align; + Alignment = std::max(AI->getAlignment(), Alignment); + if (PrefAlign <= Alignment) + return Alignment; // If the preferred alignment is greater than the natural stack alignment // then don't round up. This avoids dynamic stack realignment. - if (DL.exceedsNaturalStackAlignment(PrefAlign)) - return Align; - AI->setAlignment(PrefAlign); + if (DL.exceedsNaturalStackAlignment(Align(PrefAlign))) + return Alignment; + AI->setAlignment(MaybeAlign(PrefAlign)); return PrefAlign; } if (auto *GO = dyn_cast<GlobalObject>(V)) { // TODO: as above, this shouldn't be necessary. - Align = std::max(GO->getAlignment(), Align); - if (PrefAlign <= Align) - return Align; + Alignment = std::max(GO->getAlignment(), Alignment); + if (PrefAlign <= Alignment) + return Alignment; // If there is a large requested alignment and we can, bump up the alignment // of the global. If the memory we set aside for the global may not be the // memory used by the final program then it is impossible for us to reliably // enforce the preferred alignment. 
if (!GO->canIncreaseAlignment()) - return Align; + return Alignment; - GO->setAlignment(PrefAlign); + GO->setAlignment(MaybeAlign(PrefAlign)); return PrefAlign; } - return Align; + return Alignment; } unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, @@ -1397,7 +1378,12 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, /// Determine whether this alloca is either a VLA or an array. static bool isArray(AllocaInst *AI) { return AI->isArrayAllocation() || - AI->getType()->getElementType()->isArrayTy(); + (AI->getAllocatedType() && AI->getAllocatedType()->isArrayTy()); +} + +/// Determine whether this alloca is a structure. +static bool isStructure(AllocaInst *AI) { + return AI->getAllocatedType() && AI->getAllocatedType()->isStructTy(); } /// LowerDbgDeclare - Lowers llvm.dbg.declare intrinsics into appropriate set @@ -1422,7 +1408,7 @@ bool llvm::LowerDbgDeclare(Function &F) { // stored on the stack, while the dbg.declare can only describe // the stack slot (and at a lexical-scope granularity). Later // passes will attempt to elide the stack slot. - if (!AI || isArray(AI)) + if (!AI || isArray(AI) || isStructure(AI)) continue; // A volatile load/store means that the alloca can't be elided anyway. @@ -1591,15 +1577,10 @@ static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, DIExpr->getElement(0) != dwarf::DW_OP_deref) return; - // Insert the offset immediately after the first deref. + // Insert the offset before the first deref. // We could just change the offset argument of dbg.value, but it's unsigned... 
- if (Offset) { - SmallVector<uint64_t, 4> Ops; - Ops.push_back(dwarf::DW_OP_deref); - DIExpression::appendOffset(Ops, Offset); - Ops.append(DIExpr->elements_begin() + 1, DIExpr->elements_end()); - DIExpr = Builder.createExpression(Ops); - } + if (Offset) + DIExpr = DIExpression::prepend(DIExpr, 0, Offset); Builder.insertDbgValueIntrinsic(NewAddress, DIVar, DIExpr, Loc, DVI); DVI->eraseFromParent(); @@ -1957,18 +1938,24 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, return NumInstrsRemoved; } -/// changeToCall - Convert the specified invoke into a normal call. -static void changeToCall(InvokeInst *II, DomTreeUpdater *DTU = nullptr) { - SmallVector<Value*, 8> Args(II->arg_begin(), II->arg_end()); +CallInst *llvm::createCallMatchingInvoke(InvokeInst *II) { + SmallVector<Value *, 8> Args(II->arg_begin(), II->arg_end()); SmallVector<OperandBundleDef, 1> OpBundles; II->getOperandBundlesAsDefs(OpBundles); - CallInst *NewCall = CallInst::Create( - II->getFunctionType(), II->getCalledValue(), Args, OpBundles, "", II); - NewCall->takeName(II); + CallInst *NewCall = CallInst::Create(II->getFunctionType(), + II->getCalledValue(), Args, OpBundles); NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); NewCall->setDebugLoc(II->getDebugLoc()); NewCall->copyMetadata(*II); + return NewCall; +} + +/// changeToCall - Convert the specified invoke into a normal call. +void llvm::changeToCall(InvokeInst *II, DomTreeUpdater *DTU) { + CallInst *NewCall = createCallMatchingInvoke(II); + NewCall->takeName(II); + NewCall->insertBefore(II); II->replaceAllUsesWith(NewCall); // Follow the call by a branch to the normal destination. @@ -2223,12 +2210,10 @@ void llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) { /// removeUnreachableBlocks - Remove blocks that are not reachable, even /// if they are in a dead cycle. Return true if a change was made, false -/// otherwise. 
If `LVI` is passed, this function preserves LazyValueInfo -/// after modifying the CFG. -bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, - DomTreeUpdater *DTU, +/// otherwise. +bool llvm::removeUnreachableBlocks(Function &F, DomTreeUpdater *DTU, MemorySSAUpdater *MSSAU) { - SmallPtrSet<BasicBlock*, 16> Reachable; + SmallPtrSet<BasicBlock *, 16> Reachable; bool Changed = markAliveBlocks(F, Reachable, DTU); // If there are unreachable blocks in the CFG... @@ -2236,21 +2221,21 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, return Changed; assert(Reachable.size() < F.size()); - NumRemoved += F.size()-Reachable.size(); + NumRemoved += F.size() - Reachable.size(); SmallSetVector<BasicBlock *, 8> DeadBlockSet; - for (Function::iterator I = ++F.begin(), E = F.end(); I != E; ++I) { - auto *BB = &*I; - if (Reachable.count(BB)) + for (BasicBlock &BB : F) { + // Skip reachable basic blocks + if (Reachable.find(&BB) != Reachable.end()) continue; - DeadBlockSet.insert(BB); + DeadBlockSet.insert(&BB); } if (MSSAU) MSSAU->removeBlocks(DeadBlockSet); // Loop over all of the basic blocks that are not reachable, dropping all of - // their internal references. Update DTU and LVI if available. + // their internal references. Update DTU if available. std::vector<DominatorTree::UpdateType> Updates; for (auto *BB : DeadBlockSet) { for (BasicBlock *Successor : successors(BB)) { @@ -2259,26 +2244,18 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, if (DTU) Updates.push_back({DominatorTree::Delete, BB, Successor}); } - if (LVI) - LVI->eraseBlock(BB); BB->dropAllReferences(); - } - for (Function::iterator I = ++F.begin(); I != F.end();) { - auto *BB = &*I; - if (Reachable.count(BB)) { - ++I; - continue; - } if (DTU) { - // Remove the terminator of BB to clear the successor list of BB. 
- if (BB->getTerminator()) - BB->getInstList().pop_back(); + Instruction *TI = BB->getTerminator(); + assert(TI && "Basic block should have a terminator"); + // Terminators like invoke can have users. We have to replace their users, + // before removing them. + if (!TI->use_empty()) + TI->replaceAllUsesWith(UndefValue::get(TI->getType())); + TI->eraseFromParent(); new UnreachableInst(BB->getContext(), BB); assert(succ_empty(BB) && "The successor list of BB isn't empty before " "applying corresponding DTU updates."); - ++I; - } else { - I = F.getBasicBlockList().erase(I); } } @@ -2294,7 +2271,11 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, } if (!Deleted) return false; + } else { + for (auto *BB : DeadBlockSet) + BB->eraseFromParent(); } + return true; } @@ -2363,6 +2344,9 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, K->setMetadata(Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); break; + case LLVMContext::MD_preserve_access_index: + // Preserve !preserve.access.index in K. + break; } } // Set !invariant.group from J if J has it. If both instructions have it @@ -2385,10 +2369,61 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J, LLVMContext::MD_invariant_group, LLVMContext::MD_align, LLVMContext::MD_dereferenceable, LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_access_group}; + LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index}; combineMetadata(K, J, KnownIDs, KDominatesJ); } +void llvm::copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source) { + SmallVector<std::pair<unsigned, MDNode *>, 8> MD; + Source.getAllMetadata(MD); + MDBuilder MDB(Dest.getContext()); + Type *NewType = Dest.getType(); + const DataLayout &DL = Source.getModule()->getDataLayout(); + for (const auto &MDPair : MD) { + unsigned ID = MDPair.first; + MDNode *N = MDPair.second; + // Note, essentially every kind of metadata should be preserved here! 
This + // routine is supposed to clone a load instruction changing *only its type*. + // The only metadata it makes sense to drop is metadata which is invalidated + // when the pointer type changes. This should essentially never be the case + // in LLVM, but we explicitly switch over only known metadata to be + // conservatively correct. If you are adding metadata to LLVM which pertains + // to loads, you almost certainly want to add it here. + switch (ID) { + case LLVMContext::MD_dbg: + case LLVMContext::MD_tbaa: + case LLVMContext::MD_prof: + case LLVMContext::MD_fpmath: + case LLVMContext::MD_tbaa_struct: + case LLVMContext::MD_invariant_load: + case LLVMContext::MD_alias_scope: + case LLVMContext::MD_noalias: + case LLVMContext::MD_nontemporal: + case LLVMContext::MD_mem_parallel_loop_access: + case LLVMContext::MD_access_group: + // All of these directly apply. + Dest.setMetadata(ID, N); + break; + + case LLVMContext::MD_nonnull: + copyNonnullMetadata(Source, N, Dest); + break; + + case LLVMContext::MD_align: + case LLVMContext::MD_dereferenceable: + case LLVMContext::MD_dereferenceable_or_null: + // These only directly apply if the new type is also a pointer. 
+ if (NewType->isPointerTy()) + Dest.setMetadata(ID, N); + break; + + case LLVMContext::MD_range: + copyRangeMetadata(DL, Source, N, Dest); + break; + } + } +} + void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) { auto *ReplInst = dyn_cast<Instruction>(Repl); if (!ReplInst) @@ -2417,7 +2452,7 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) { LLVMContext::MD_noalias, LLVMContext::MD_range, LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, LLVMContext::MD_invariant_group, LLVMContext::MD_nonnull, - LLVMContext::MD_access_group}; + LLVMContext::MD_access_group, LLVMContext::MD_preserve_access_index}; combineMetadata(ReplInst, I, KnownIDs, false); } diff --git a/lib/Transforms/Utils/LoopRotationUtils.cpp b/lib/Transforms/Utils/LoopRotationUtils.cpp index 37389a695b45..889ea5ca9970 100644 --- a/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -615,30 +615,9 @@ bool LoopRotate::simplifyLoopLatch(Loop *L) { LLVM_DEBUG(dbgs() << "Folding loop latch " << Latch->getName() << " into " << LastExit->getName() << "\n"); - // Hoist the instructions from Latch into LastExit. - Instruction *FirstLatchInst = &*(Latch->begin()); - LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(), - Latch->begin(), Jmp->getIterator()); - - // Update MemorySSA - if (MSSAU) - MSSAU->moveAllAfterMergeBlocks(Latch, LastExit, FirstLatchInst); - - unsigned FallThruPath = BI->getSuccessor(0) == Latch ? 0 : 1; - BasicBlock *Header = Jmp->getSuccessor(0); - assert(Header == L->getHeader() && "expected a backward branch"); - - // Remove Latch from the CFG so that LastExit becomes the new Latch. - BI->setSuccessor(FallThruPath, Header); - Latch->replaceSuccessorsPhiUsesWith(LastExit); - Jmp->eraseFromParent(); - - // Nuke the Latch block. 
- assert(Latch->empty() && "unable to evacuate Latch"); - LI->removeBlock(Latch); - if (DT) - DT->eraseNode(Latch); - Latch->eraseFromParent(); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + MergeBlockIntoPredecessor(Latch, &DTU, LI, MSSAU, nullptr, + /*PredecessorWithTwoSuccessors=*/true); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index 7e6da02d5707..d0f89dc54bfb 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -808,7 +808,7 @@ bool LoopSimplify::runOnFunction(Function &F) { auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); if (MSSAAnalysis) { MSSA = &MSSAAnalysis->getMSSA(); - MSSAU = make_unique<MemorySSAUpdater>(MSSA); + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); } } @@ -835,12 +835,19 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F, DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); ScalarEvolution *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F); AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); + auto *MSSAAnalysis = AM.getCachedResult<MemorySSAAnalysis>(F); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (MSSAAnalysis) { + auto *MSSA = &MSSAAnalysis->getMSSA(); + MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); + } + // Note that we don't preserve LCSSA in the new PM, if you need it run LCSSA - // after simplifying the loops. MemorySSA is not preserved either. + // after simplifying the loops. MemorySSA is preserved if it exists. 
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) Changed |= - simplifyLoop(*I, DT, LI, SE, AC, nullptr, /*PreserveLCSSA*/ false); + simplifyLoop(*I, DT, LI, SE, AC, MSSAU.get(), /*PreserveLCSSA*/ false); if (!Changed) return PreservedAnalyses::all(); @@ -853,6 +860,8 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F, PA.preserve<SCEVAA>(); PA.preserve<ScalarEvolutionAnalysis>(); PA.preserve<DependenceAnalysis>(); + if (MSSAAnalysis) + PA.preserve<MemorySSAAnalysis>(); // BPI maps conditional terminators to probabilities, LoopSimplify can insert // blocks, but it does so only by splitting existing blocks and edges. This // results in the interesting property that all new terminators inserted are diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index e39ade523714..a7590fc32545 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -711,7 +711,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, auto setDest = [LoopExit, ContinueOnTrue](BasicBlock *Src, BasicBlock *Dest, ArrayRef<BasicBlock *> NextBlocks, - BasicBlock *CurrentHeader, + BasicBlock *BlockInLoop, bool NeedConditional) { auto *Term = cast<BranchInst>(Src->getTerminator()); if (NeedConditional) { @@ -723,7 +723,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (Dest != LoopExit) { BasicBlock *BB = Src; for (BasicBlock *Succ : successors(BB)) { - if (Succ == CurrentHeader) + // Preserve the incoming value from BB if we are jumping to the block + // in the current loop. + if (Succ == BlockInLoop) continue; for (PHINode &Phi : Succ->phis()) Phi.removeIncomingValue(BB, false); @@ -794,7 +796,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, // unconditional branch for some iterations. 
NeedConditional = false; - setDest(Headers[i], Dest, Headers, Headers[i], NeedConditional); + setDest(Headers[i], Dest, Headers, HeaderSucc[i], NeedConditional); } // Set up latches to branch to the new header in the unrolled iterations or @@ -868,7 +870,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, assert(!DT || !UnrollVerifyDomtree || DT->verify(DominatorTree::VerificationLevel::Fast)); - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); // Merge adjacent basic blocks, if possible. for (BasicBlock *Latch : Latches) { BranchInst *Term = dyn_cast<BranchInst>(Latch->getTerminator()); @@ -888,6 +890,8 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, } } } + // Apply updates to the DomTree. + DT = &DTU.getDomTree(); // At this point, the code is well formed. We now simplify the unrolled loop, // doing constant propagation and dead code elimination as we go. diff --git a/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/lib/Transforms/Utils/LoopUnrollAndJam.cpp index ff49d83f25c5..bf2e87b0d49f 100644 --- a/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -517,6 +517,7 @@ LoopUnrollResult llvm::UnrollAndJamLoop( movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]); } + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the // new ones required. if (Count != 1) { @@ -530,7 +531,7 @@ LoopUnrollResult llvm::UnrollAndJamLoop( ForeBlocksLast.back(), SubLoopBlocksFirst[0]); DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert, SubLoopBlocksLast.back(), AftBlocksFirst[0]); - DT->applyUpdates(DTUpdates); + DTU.applyUpdatesPermissive(DTUpdates); } // Merge adjacent basic blocks, if possible. 
@@ -538,7 +539,6 @@ LoopUnrollResult llvm::UnrollAndJamLoop( MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end()); MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end()); MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end()); - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); while (!MergeBlocks.empty()) { BasicBlock *BB = *MergeBlocks.begin(); BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator()); @@ -555,6 +555,8 @@ LoopUnrollResult llvm::UnrollAndJamLoop( } else MergeBlocks.erase(BB); } + // Apply updates to the DomTree. + DT = &DTU.getDomTree(); // At this point, the code is well formed. We now do a quick sweep over the // inserted code, doing constant propagation and dead code elimination as we diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp index 005306cf1898..58e42074f963 100644 --- a/lib/Transforms/Utils/LoopUnrollPeel.cpp +++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp @@ -62,9 +62,11 @@ static cl::opt<unsigned> UnrollForcePeelCount( cl::desc("Force a peel count regardless of profiling information.")); static cl::opt<bool> UnrollPeelMultiDeoptExit( - "unroll-peel-multi-deopt-exit", cl::init(false), cl::Hidden, + "unroll-peel-multi-deopt-exit", cl::init(true), cl::Hidden, cl::desc("Allow peeling of loops with multiple deopt exits.")); +static const char *PeeledCountMetaData = "llvm.loop.peeled.count"; + // Designates that a Phi is estimated to become invariant after an "infinite" // number of loop iterations (i.e. only may become an invariant if the loop is // fully unrolled). 
@@ -275,6 +277,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, LLVM_DEBUG(dbgs() << "Force-peeling first " << UnrollForcePeelCount << " iterations.\n"); UP.PeelCount = UnrollForcePeelCount; + UP.PeelProfiledIterations = true; return; } @@ -282,6 +285,13 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, if (!UP.AllowPeeling) return; + unsigned AlreadyPeeled = 0; + if (auto Peeled = getOptionalIntLoopAttribute(L, PeeledCountMetaData)) + AlreadyPeeled = *Peeled; + // Stop if we already peeled off the maximum number of iterations. + if (AlreadyPeeled >= UnrollPeelMaxCount) + return; + // Here we try to get rid of Phis which become invariants after 1, 2, ..., N // iterations of the loop. For this we compute the number for iterations after // which every Phi is guaranteed to become an invariant, and try to peel the @@ -317,11 +327,14 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount); // Consider max peel count limitation. assert(DesiredPeelCount > 0 && "Wrong loop size estimation?"); - LLVM_DEBUG(dbgs() << "Peel " << DesiredPeelCount - << " iteration(s) to turn" - << " some Phis into invariants.\n"); - UP.PeelCount = DesiredPeelCount; - return; + if (DesiredPeelCount + AlreadyPeeled <= UnrollPeelMaxCount) { + LLVM_DEBUG(dbgs() << "Peel " << DesiredPeelCount + << " iteration(s) to turn" + << " some Phis into invariants.\n"); + UP.PeelCount = DesiredPeelCount; + UP.PeelProfiledIterations = false; + return; + } } } @@ -330,6 +343,9 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, if (TripCount) return; + // Do not apply profile base peeling if it is disabled. + if (!UP.PeelProfiledIterations) + return; // If we don't know the trip count, but have reason to believe the average // trip count is low, peeling should be beneficial, since we will usually // hit the peeled section. 
@@ -344,7 +360,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, << "\n"); if (*PeelCount) { - if ((*PeelCount <= UnrollPeelMaxCount) && + if ((*PeelCount + AlreadyPeeled <= UnrollPeelMaxCount) && (LoopSize * (*PeelCount + 1) <= UP.Threshold)) { LLVM_DEBUG(dbgs() << "Peeling first " << *PeelCount << " iterations.\n"); @@ -352,6 +368,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, return; } LLVM_DEBUG(dbgs() << "Requested peel count: " << *PeelCount << "\n"); + LLVM_DEBUG(dbgs() << "Already peel count: " << AlreadyPeeled << "\n"); LLVM_DEBUG(dbgs() << "Max peel count: " << UnrollPeelMaxCount << "\n"); LLVM_DEBUG(dbgs() << "Peel cost: " << LoopSize * (*PeelCount + 1) << "\n"); @@ -364,88 +381,77 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, /// iteration. /// This sets the branch weights for the latch of the recently peeled off loop /// iteration correctly. -/// Our goal is to make sure that: -/// a) The total weight of all the copies of the loop body is preserved. -/// b) The total weight of the loop exit is preserved. -/// c) The body weight is reasonably distributed between the peeled iterations. +/// Let F is a weight of the edge from latch to header. +/// Let E is a weight of the edge from latch to exit. +/// F/(F+E) is a probability to go to loop and E/(F+E) is a probability to +/// go to exit. +/// Then, Estimated TripCount = F / E. +/// For I-th (counting from 0) peeled off iteration we set the the weights for +/// the peeled latch as (TC - I, 1). It gives us reasonable distribution, +/// The probability to go to exit 1/(TC-I) increases. At the same time +/// the estimated trip count of remaining loop reduces by I. +/// To avoid dealing with division rounding we can just multiple both part +/// of weights to E and use weight as (F - I * E, E). /// /// \param Header The copy of the header block that belongs to next iteration. /// \param LatchBR The copy of the latch branch that belongs to this iteration. 
-/// \param IterNumber The serial number of the iteration that was just -/// peeled off. -/// \param AvgIters The average number of iterations we expect the loop to have. -/// \param[in,out] PeeledHeaderWeight The total number of dynamic loop -/// iterations that are unaccounted for. As an input, it represents the number -/// of times we expect to enter the header of the iteration currently being -/// peeled off. The output is the number of times we expect to enter the -/// header of the next iteration. +/// \param[in,out] FallThroughWeight The weight of the edge from latch to +/// header before peeling (in) and after peeled off one iteration (out). static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - unsigned IterNumber, unsigned AvgIters, - uint64_t &PeeledHeaderWeight) { - if (!PeeledHeaderWeight) + uint64_t ExitWeight, + uint64_t &FallThroughWeight) { + // FallThroughWeight is 0 means that there is no branch weights on original + // latch block or estimated trip count is zero. + if (!FallThroughWeight) return; - // FIXME: Pick a more realistic distribution. - // Currently the proportion of weight we assign to the fall-through - // side of the branch drops linearly with the iteration number, and we use - // a 0.9 fudge factor to make the drop-off less sharp... - uint64_t FallThruWeight = - PeeledHeaderWeight * ((float)(AvgIters - IterNumber) / AvgIters * 0.9); - uint64_t ExitWeight = PeeledHeaderWeight - FallThruWeight; - PeeledHeaderWeight -= ExitWeight; unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); MDBuilder MDB(LatchBR->getContext()); MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThruWeight) - : MDB.createBranchWeights(FallThruWeight, ExitWeight); + HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight) + : MDB.createBranchWeights(FallThroughWeight, ExitWeight); LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); + FallThroughWeight = + FallThroughWeight > ExitWeight ? 
FallThroughWeight - ExitWeight : 1; } /// Initialize the weights. /// /// \param Header The header block. /// \param LatchBR The latch branch. -/// \param AvgIters The average number of iterations we expect the loop to have. -/// \param[out] ExitWeight The # of times the edge from Latch to Exit is taken. -/// \param[out] CurHeaderWeight The # of times the header is executed. +/// \param[out] ExitWeight The weight of the edge from Latch to Exit. +/// \param[out] FallThroughWeight The weight of the edge from Latch to Header. static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - unsigned AvgIters, uint64_t &ExitWeight, - uint64_t &CurHeaderWeight) { + uint64_t &ExitWeight, + uint64_t &FallThroughWeight) { uint64_t TrueWeight, FalseWeight; if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) return; unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; ExitWeight = HeaderIdx ? TrueWeight : FalseWeight; - // The # of times the loop body executes is the sum of the exit block - // is taken and the # of times the backedges are taken. - CurHeaderWeight = TrueWeight + FalseWeight; + FallThroughWeight = HeaderIdx ? FalseWeight : TrueWeight; } /// Update the weights of original Latch block after peeling off all iterations. /// /// \param Header The header block. /// \param LatchBR The latch branch. -/// \param ExitWeight The weight of the edge from Latch to Exit block. -/// \param CurHeaderWeight The # of time the header is executed. +/// \param ExitWeight The weight of the edge from Latch to Exit. +/// \param FallThroughWeight The weight of the edge from Latch to Header. static void fixupBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - uint64_t ExitWeight, uint64_t CurHeaderWeight) { - // Adjust the branch weights on the loop exit. - if (!ExitWeight) + uint64_t ExitWeight, + uint64_t FallThroughWeight) { + // FallThroughWeight is 0 means that there is no branch weights on original + // latch block or estimated trip count is zero. 
+ if (!FallThroughWeight) return; - // The backedge count is the difference of current header weight and - // current loop exit weight. If the current header weight is smaller than - // the current loop exit weight, we mark the loop backedge weight as 1. - uint64_t BackEdgeWeight = 0; - if (ExitWeight < CurHeaderWeight) - BackEdgeWeight = CurHeaderWeight - ExitWeight; - else - BackEdgeWeight = 1; + // Sets the branch weights on the loop exit. MDBuilder MDB(LatchBR->getContext()); unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) - : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight) + : MDB.createBranchWeights(FallThroughWeight, ExitWeight); LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); } @@ -586,11 +592,30 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, DenseMap<BasicBlock *, BasicBlock *> ExitIDom; if (DT) { + // We'd like to determine the idom of exit block after peeling one + // iteration. + // Let Exit is exit block. + // Let ExitingSet - is a set of predecessors of Exit block. They are exiting + // blocks. + // Let Latch' and ExitingSet' are copies after a peeling. + // We'd like to find an idom'(Exit) - idom of Exit after peeling. + // It is an evident that idom'(Exit) will be the nearest common dominator + // of ExitingSet and ExitingSet'. + // idom(Exit) is a nearest common dominator of ExitingSet. + // idom(Exit)' is a nearest common dominator of ExitingSet'. + // Taking into account that we have a single Latch, Latch' will dominate + // Header and idom(Exit). + // So the idom'(Exit) is nearest common dominator of idom(Exit)' and Latch'. + // All these basic blocks are in the same loop, so what we find is + // (nearest common dominator of idom(Exit) and Latch)'. 
+ // In the loop below we remember nearest common dominator of idom(Exit) and + // Latch to update idom of Exit later. assert(L->hasDedicatedExits() && "No dedicated exits?"); for (auto Edge : ExitEdges) { if (ExitIDom.count(Edge.second)) continue; - BasicBlock *BB = DT->getNode(Edge.second)->getIDom()->getBlock(); + BasicBlock *BB = DT->findNearestCommonDominator( + DT->getNode(Edge.second)->getIDom()->getBlock(), Latch); assert(L->contains(BB) && "IDom is not in a loop"); ExitIDom[Edge.second] = BB; } @@ -659,23 +684,14 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, // newly created branches. BranchInst *LatchBR = cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator()); - uint64_t ExitWeight = 0, CurHeaderWeight = 0; - initBranchWeights(Header, LatchBR, PeelCount, ExitWeight, CurHeaderWeight); + uint64_t ExitWeight = 0, FallThroughWeight = 0; + initBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight); // For each peeled-off iteration, make a copy of the loop. for (unsigned Iter = 0; Iter < PeelCount; ++Iter) { SmallVector<BasicBlock *, 8> NewBlocks; ValueToValueMapTy VMap; - // Subtract the exit weight from the current header weight -- the exit - // weight is exactly the weight of the previous iteration's header. - // FIXME: due to the way the distribution is constructed, we need a - // guard here to make sure we don't end up with non-positive weights. 
- if (ExitWeight < CurHeaderWeight) - CurHeaderWeight -= ExitWeight; - else - CurHeaderWeight = 1; - cloneLoopBlocks(L, Iter, InsertTop, InsertBot, ExitEdges, NewBlocks, LoopBlocks, VMap, LVMap, DT, LI); @@ -697,8 +713,7 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, } auto *LatchBRCopy = cast<BranchInst>(VMap[LatchBR]); - updateBranchWeights(InsertBot, LatchBRCopy, Iter, - PeelCount, ExitWeight); + updateBranchWeights(InsertBot, LatchBRCopy, ExitWeight, FallThroughWeight); // Remove Loop metadata from the latch branch instruction // because it is not the Loop's latch branch anymore. LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr); @@ -724,7 +739,13 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, PHI->setIncomingValueForBlock(NewPreHeader, NewVal); } - fixupBranchWeights(Header, LatchBR, ExitWeight, CurHeaderWeight); + fixupBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight); + + // Update Metadata for count of peeled off iterations. 
+ unsigned AlreadyPeeled = 0; + if (auto Peeled = getOptionalIntLoopAttribute(L, PeeledCountMetaData)) + AlreadyPeeled = *Peeled; + addStringMetadataToLoop(L, PeeledCountMetaData, AlreadyPeeled + PeelCount); if (Loop *ParentLoop = L->getParentLoop()) L = ParentLoop; diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp index ec226e65f650..b4d7f35d2d9a 100644 --- a/lib/Transforms/Utils/LoopUtils.cpp +++ b/lib/Transforms/Utils/LoopUtils.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -45,6 +46,7 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "loop-utils" static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced"; +static const char *LLVMLoopDisableLICM = "llvm.licm.disable"; bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, MemorySSAUpdater *MSSAU, @@ -169,6 +171,8 @@ void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) { AU.addPreserved<SCEVAAWrapperPass>(); AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addPreserved<ScalarEvolutionWrapperPass>(); + // FIXME: When all loop passes preserve MemorySSA, it can be required and + // preserved here instead of the individual handling in each pass. } /// Manually defined generic "LoopPass" dependency initialization. This is used @@ -189,6 +193,54 @@ void llvm::initializeLoopPassPass(PassRegistry &Registry) { INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) + INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +} + +/// Create MDNode for input string. 
+static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) { + LLVMContext &Context = TheLoop->getHeader()->getContext(); + Metadata *MDs[] = { + MDString::get(Context, Name), + ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))}; + return MDNode::get(Context, MDs); +} + +/// Set input string into loop metadata by keeping other values intact. +/// If the string is already in loop metadata update value if it is +/// different. +void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD, + unsigned V) { + SmallVector<Metadata *, 4> MDs(1); + // If the loop already has metadata, retain it. + MDNode *LoopID = TheLoop->getLoopID(); + if (LoopID) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); + // If it is of form key = value, try to parse it. + if (Node->getNumOperands() == 2) { + MDString *S = dyn_cast<MDString>(Node->getOperand(0)); + if (S && S->getString().equals(StringMD)) { + ConstantInt *IntMD = + mdconst::extract_or_null<ConstantInt>(Node->getOperand(1)); + if (IntMD && IntMD->getSExtValue() == V) + // It is already in place. Do nothing. + return; + // We need to update the value, so just skip it here and it will + // be added after copying other existed nodes. + continue; + } + } + MDs.push_back(Node); + } + } + // Add new metadata. + MDs.push_back(createStringMetadata(TheLoop, StringMD, V)); + // Replace current metadata node with new one. + LLVMContext &Context = TheLoop->getHeader()->getContext(); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. 
+ NewLoopID->replaceOperandWith(0, NewLoopID); + TheLoop->setLoopID(NewLoopID); } /// Find string metadata for loop @@ -332,6 +384,10 @@ bool llvm::hasDisableAllTransformsHint(const Loop *L) { return getBooleanLoopAttribute(L, LLVMLoopDisableNonforced); } +bool llvm::hasDisableLICMTransformsHint(const Loop *L) { + return getBooleanLoopAttribute(L, LLVMLoopDisableLICM); +} + TransformationMode llvm::hasUnrollTransformation(Loop *L) { if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable")) return TM_SuppressedByUser; diff --git a/lib/Transforms/Utils/LoopVersioning.cpp b/lib/Transforms/Utils/LoopVersioning.cpp index a9a480a4b7f9..5d7759056c7d 100644 --- a/lib/Transforms/Utils/LoopVersioning.cpp +++ b/lib/Transforms/Utils/LoopVersioning.cpp @@ -92,8 +92,8 @@ void LoopVersioning::versionLoop( // Create empty preheader for the loop (and after cloning for the // non-versioned loop). BasicBlock *PH = - SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI); - PH->setName(VersionedLoop->getHeader()->getName() + ".ph"); + SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI, + nullptr, VersionedLoop->getHeader()->getName() + ".ph"); // Clone the loop including the preheader. // diff --git a/lib/Transforms/Utils/MetaRenamer.cpp b/lib/Transforms/Utils/MetaRenamer.cpp index c0b7edc547fd..60bb2775a194 100644 --- a/lib/Transforms/Utils/MetaRenamer.cpp +++ b/lib/Transforms/Utils/MetaRenamer.cpp @@ -121,15 +121,14 @@ namespace { } // Rename all functions - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); for (auto &F : M) { StringRef Name = F.getName(); LibFunc Tmp; // Leave library functions alone because their presence or absence could // affect the behavior of other passes. if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || - TLI.getLibFunc(F, Tmp)) + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F).getLibFunc( + F, Tmp)) continue; // Leave @main alone. 
The output of -metarenamer might be passed to diff --git a/lib/Transforms/Utils/MisExpect.cpp b/lib/Transforms/Utils/MisExpect.cpp new file mode 100644 index 000000000000..26d3402bd279 --- /dev/null +++ b/lib/Transforms/Utils/MisExpect.cpp @@ -0,0 +1,177 @@ +//===--- MisExpect.cpp - Check the use of llvm.expect with PGO data -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This contains code to emit warnings for potentially incorrect usage of the +// llvm.expect intrinsic. This utility extracts the threshold values from +// metadata associated with the instrumented Branch or Switch instruction. The +// threshold values are then used to determine if a warning should be emmited. +// +// MisExpect metadata is generated when llvm.expect intrinsics are lowered see +// LowerExpectIntrinsic.cpp +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/MisExpect.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include <cstdint> +#include <functional> +#include <numeric> + +#define DEBUG_TYPE "misexpect" + +using namespace llvm; +using namespace misexpect; + +namespace llvm { + +// Command line option to enable/disable the warning when profile data suggests +// a mismatch with the use of the llvm.expect intrinsic +static cl::opt<bool> PGOWarnMisExpect( + "pgo-warn-misexpect", cl::init(false), cl::Hidden, + cl::desc("Use this 
option to turn on/off " + "warnings about incorrect usage of llvm.expect intrinsics.")); + +} // namespace llvm + +namespace { + +Instruction *getOprndOrInst(Instruction *I) { + assert(I != nullptr && "MisExpect target Instruction cannot be nullptr"); + Instruction *Ret = nullptr; + if (auto *B = dyn_cast<BranchInst>(I)) { + Ret = dyn_cast<Instruction>(B->getCondition()); + } + // TODO: Find a way to resolve condition location for switches + // Using the condition of the switch seems to often resolve to an earlier + // point in the program, i.e. the calculation of the switch condition, rather + // than the switches location in the source code. Thus, we should use the + // instruction to get source code locations rather than the condition to + // improve diagnostic output, such as the caret. If the same problem exists + // for branch instructions, then we should remove this function and directly + // use the instruction + // + // else if (auto S = dyn_cast<SwitchInst>(I)) { + // Ret = I; + //} + return Ret ? 
Ret : I; +} + +void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx, + uint64_t ProfCount, uint64_t TotalCount) { + double PercentageCorrect = (double)ProfCount / TotalCount; + auto PerString = + formatv("{0:P} ({1} / {2})", PercentageCorrect, ProfCount, TotalCount); + auto RemStr = formatv( + "Potential performance regression from use of the llvm.expect intrinsic: " + "Annotation was correct on {0} of profiled executions.", + PerString); + Twine Msg(PerString); + Instruction *Cond = getOprndOrInst(I); + if (PGOWarnMisExpect) + Ctx.diagnose(DiagnosticInfoMisExpect(Cond, Msg)); + OptimizationRemarkEmitter ORE(I->getParent()->getParent()); + ORE.emit(OptimizationRemark(DEBUG_TYPE, "misexpect", Cond) << RemStr.str()); +} + +} // namespace + +namespace llvm { +namespace misexpect { + +void verifyMisExpect(Instruction *I, const SmallVector<uint32_t, 4> &Weights, + LLVMContext &Ctx) { + if (auto *MisExpectData = I->getMetadata(LLVMContext::MD_misexpect)) { + auto *MisExpectDataName = dyn_cast<MDString>(MisExpectData->getOperand(0)); + if (MisExpectDataName && + MisExpectDataName->getString().equals("misexpect")) { + LLVM_DEBUG(llvm::dbgs() << "------------------\n"); + LLVM_DEBUG(llvm::dbgs() + << "Function: " << I->getFunction()->getName() << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Instruction: " << *I << ":\n"); + LLVM_DEBUG(for (int Idx = 0, Size = Weights.size(); Idx < Size; ++Idx) { + llvm::dbgs() << "Weights[" << Idx << "] = " << Weights[Idx] << "\n"; + }); + + // extract values from misexpect metadata + const auto *IndexCint = + mdconst::dyn_extract<ConstantInt>(MisExpectData->getOperand(1)); + const auto *LikelyCInt = + mdconst::dyn_extract<ConstantInt>(MisExpectData->getOperand(2)); + const auto *UnlikelyCInt = + mdconst::dyn_extract<ConstantInt>(MisExpectData->getOperand(3)); + + if (!IndexCint || !LikelyCInt || !UnlikelyCInt) + return; + + const uint64_t Index = IndexCint->getZExtValue(); + const uint64_t LikelyBranchWeight = LikelyCInt->getZExtValue(); 
+ const uint64_t UnlikelyBranchWeight = UnlikelyCInt->getZExtValue(); + const uint64_t ProfileCount = Weights[Index]; + const uint64_t CaseTotal = std::accumulate( + Weights.begin(), Weights.end(), (uint64_t)0, std::plus<uint64_t>()); + const uint64_t NumUnlikelyTargets = Weights.size() - 1; + + const uint64_t TotalBranchWeight = + LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets); + + const llvm::BranchProbability LikelyThreshold(LikelyBranchWeight, + TotalBranchWeight); + uint64_t ScaledThreshold = LikelyThreshold.scale(CaseTotal); + + LLVM_DEBUG(llvm::dbgs() + << "Unlikely Targets: " << NumUnlikelyTargets << ":\n"); + LLVM_DEBUG(llvm::dbgs() << "Profile Count: " << ProfileCount << ":\n"); + LLVM_DEBUG(llvm::dbgs() + << "Scaled Threshold: " << ScaledThreshold << ":\n"); + LLVM_DEBUG(llvm::dbgs() << "------------------\n"); + if (ProfileCount < ScaledThreshold) + emitMisexpectDiagnostic(I, Ctx, ProfileCount, CaseTotal); + } + } +} + +void checkFrontendInstrumentation(Instruction &I) { + if (auto *MD = I.getMetadata(LLVMContext::MD_prof)) { + unsigned NOps = MD->getNumOperands(); + + // Only emit misexpect diagnostics if at least 2 branch weights are present. + // Less than 2 branch weights means that the profiling metadata is: + // 1) incorrect/corrupted + // 2) not branch weight metadata + // 3) completely deterministic + // In these cases we should not emit any diagnostic related to misexpect. 
+ if (NOps < 3) + return; + + // Operand 0 is a string tag "branch_weights" + if (MDString *Tag = cast<MDString>(MD->getOperand(0))) { + if (Tag->getString().equals("branch_weights")) { + SmallVector<uint32_t, 4> RealWeights(NOps - 1); + for (unsigned i = 1; i < NOps; i++) { + ConstantInt *Value = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(i)); + RealWeights[i - 1] = Value->getZExtValue(); + } + verifyMisExpect(&I, RealWeights, I.getContext()); + } + } + } +} + +} // namespace misexpect +} // namespace llvm +#undef DEBUG_TYPE diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index c84beceee191..1ef3757017a8 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -73,7 +73,7 @@ static void appendToUsedList(Module &M, StringRef Name, ArrayRef<GlobalValue *> SmallPtrSet<Constant *, 16> InitAsSet; SmallVector<Constant *, 16> Init; if (GV) { - ConstantArray *CA = dyn_cast<ConstantArray>(GV->getInitializer()); + auto *CA = cast<ConstantArray>(GV->getInitializer()); for (auto &Op : CA->operands()) { Constant *C = cast_or_null<Constant>(Op); if (InitAsSet.insert(C).second) diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp index bdf24d80bd17..44859eafb9c1 100644 --- a/lib/Transforms/Utils/PredicateInfo.cpp +++ b/lib/Transforms/Utils/PredicateInfo.cpp @@ -125,8 +125,10 @@ static bool valueComesBefore(OrderedInstructions &OI, const Value *A, // necessary to compare uses/defs in the same block. Doing so allows us to walk // the minimum number of instructions necessary to compute our def/use ordering. 
struct ValueDFS_Compare { + DominatorTree &DT; OrderedInstructions &OI; - ValueDFS_Compare(OrderedInstructions &OI) : OI(OI) {} + ValueDFS_Compare(DominatorTree &DT, OrderedInstructions &OI) + : DT(DT), OI(OI) {} bool operator()(const ValueDFS &A, const ValueDFS &B) const { if (&A == &B) @@ -136,7 +138,9 @@ struct ValueDFS_Compare { // comesbefore to see what the real ordering is, because they are in the // same basic block. - bool SameBlock = std::tie(A.DFSIn, A.DFSOut) == std::tie(B.DFSIn, B.DFSOut); + assert((A.DFSIn != B.DFSIn || A.DFSOut == B.DFSOut) && + "Equal DFS-in numbers imply equal out numbers"); + bool SameBlock = A.DFSIn == B.DFSIn; // We want to put the def that will get used for a given set of phi uses, // before those phi uses. @@ -145,9 +149,11 @@ struct ValueDFS_Compare { if (SameBlock && A.LocalNum == LN_Last && B.LocalNum == LN_Last) return comparePHIRelated(A, B); + bool isADef = A.Def; + bool isBDef = B.Def; if (!SameBlock || A.LocalNum != LN_Middle || B.LocalNum != LN_Middle) - return std::tie(A.DFSIn, A.DFSOut, A.LocalNum, A.Def, A.U) < - std::tie(B.DFSIn, B.DFSOut, B.LocalNum, B.Def, B.U); + return std::tie(A.DFSIn, A.LocalNum, isADef) < + std::tie(B.DFSIn, B.LocalNum, isBDef); return localComesBefore(A, B); } @@ -164,10 +170,35 @@ struct ValueDFS_Compare { // For two phi related values, return the ordering. bool comparePHIRelated(const ValueDFS &A, const ValueDFS &B) const { - auto &ABlockEdge = getBlockEdge(A); - auto &BBlockEdge = getBlockEdge(B); - // Now sort by block edge and then defs before uses. - return std::tie(ABlockEdge, A.Def, A.U) < std::tie(BBlockEdge, B.Def, B.U); + BasicBlock *ASrc, *ADest, *BSrc, *BDest; + std::tie(ASrc, ADest) = getBlockEdge(A); + std::tie(BSrc, BDest) = getBlockEdge(B); + +#ifndef NDEBUG + // This function should only be used for values in the same BB, check that. 
+ DomTreeNode *DomASrc = DT.getNode(ASrc); + DomTreeNode *DomBSrc = DT.getNode(BSrc); + assert(DomASrc->getDFSNumIn() == (unsigned)A.DFSIn && + "DFS numbers for A should match the ones of the source block"); + assert(DomBSrc->getDFSNumIn() == (unsigned)B.DFSIn && + "DFS numbers for B should match the ones of the source block"); + assert(A.DFSIn == B.DFSIn && "Values must be in the same block"); +#endif + (void)ASrc; + (void)BSrc; + + // Use DFS numbers to compare destination blocks, to guarantee a + // deterministic order. + DomTreeNode *DomADest = DT.getNode(ADest); + DomTreeNode *DomBDest = DT.getNode(BDest); + unsigned AIn = DomADest->getDFSNumIn(); + unsigned BIn = DomBDest->getDFSNumIn(); + bool isADef = A.Def; + bool isBDef = B.Def; + assert((!A.Def || !A.U) && (!B.Def || !B.U) && + "Def and U cannot be set at the same time"); + // Now sort by edge destination and then defs before uses. + return std::tie(AIn, isADef) < std::tie(BIn, isBDef); } // Get the definition of an instruction that occurs in the middle of a block. @@ -306,10 +337,11 @@ void collectCmpOps(CmpInst *Comparison, SmallVectorImpl<Value *> &CmpOperands) { } // Add Op, PB to the list of value infos for Op, and mark Op to be renamed. -void PredicateInfo::addInfoFor(SmallPtrSetImpl<Value *> &OpsToRename, Value *Op, +void PredicateInfo::addInfoFor(SmallVectorImpl<Value *> &OpsToRename, Value *Op, PredicateBase *PB) { - OpsToRename.insert(Op); auto &OperandInfo = getOrCreateValueInfo(Op); + if (OperandInfo.Infos.empty()) + OpsToRename.push_back(Op); AllInfos.push_back(PB); OperandInfo.Infos.push_back(PB); } @@ -317,7 +349,7 @@ void PredicateInfo::addInfoFor(SmallPtrSetImpl<Value *> &OpsToRename, Value *Op, // Process an assume instruction and place relevant operations we want to rename // into OpsToRename. 
void PredicateInfo::processAssume(IntrinsicInst *II, BasicBlock *AssumeBB, - SmallPtrSetImpl<Value *> &OpsToRename) { + SmallVectorImpl<Value *> &OpsToRename) { // See if we have a comparison we support SmallVector<Value *, 8> CmpOperands; SmallVector<Value *, 2> ConditionsToProcess; @@ -357,7 +389,7 @@ void PredicateInfo::processAssume(IntrinsicInst *II, BasicBlock *AssumeBB, // Process a block terminating branch, and place relevant operations to be // renamed into OpsToRename. void PredicateInfo::processBranch(BranchInst *BI, BasicBlock *BranchBB, - SmallPtrSetImpl<Value *> &OpsToRename) { + SmallVectorImpl<Value *> &OpsToRename) { BasicBlock *FirstBB = BI->getSuccessor(0); BasicBlock *SecondBB = BI->getSuccessor(1); SmallVector<BasicBlock *, 2> SuccsToProcess; @@ -427,7 +459,7 @@ void PredicateInfo::processBranch(BranchInst *BI, BasicBlock *BranchBB, // Process a block terminating switch, and place relevant operations to be // renamed into OpsToRename. void PredicateInfo::processSwitch(SwitchInst *SI, BasicBlock *BranchBB, - SmallPtrSetImpl<Value *> &OpsToRename) { + SmallVectorImpl<Value *> &OpsToRename) { Value *Op = SI->getCondition(); if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse()) return; @@ -457,7 +489,7 @@ void PredicateInfo::buildPredicateInfo() { DT.updateDFSNumbers(); // Collect operands to rename from all conditional branch terminators, as well // as assume statements. 
- SmallPtrSet<Value *, 8> OpsToRename; + SmallVector<Value *, 8> OpsToRename; for (auto DTN : depth_first(DT.getRootNode())) { BasicBlock *BranchBB = DTN->getBlock(); if (auto *BI = dyn_cast<BranchInst>(BranchBB->getTerminator())) { @@ -524,7 +556,7 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter, if (isa<PredicateWithEdge>(ValInfo)) { IRBuilder<> B(getBranchTerminator(ValInfo)); Function *IF = getCopyDeclaration(F.getParent(), Op->getType()); - if (empty(IF->users())) + if (IF->users().empty()) CreatedDeclarations.insert(IF); CallInst *PIC = B.CreateCall(IF, Op, Op->getName() + "." + Twine(Counter++)); @@ -536,7 +568,7 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter, "Should not have gotten here without it being an assume"); IRBuilder<> B(PAssume->AssumeInst); Function *IF = getCopyDeclaration(F.getParent(), Op->getType()); - if (empty(IF->users())) + if (IF->users().empty()) CreatedDeclarations.insert(IF); CallInst *PIC = B.CreateCall(IF, Op); PredicateMap.insert({PIC, ValInfo}); @@ -565,14 +597,8 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter, // // TODO: Use this algorithm to perform fast single-variable renaming in // promotememtoreg and memoryssa. -void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { - // Sort OpsToRename since we are going to iterate it. - SmallVector<Value *, 8> OpsToRename(OpSet.begin(), OpSet.end()); - auto Comparator = [&](const Value *A, const Value *B) { - return valueComesBefore(OI, A, B); - }; - llvm::sort(OpsToRename, Comparator); - ValueDFS_Compare Compare(OI); +void PredicateInfo::renameUses(SmallVectorImpl<Value *> &OpsToRename) { + ValueDFS_Compare Compare(DT, OI); // Compute liveness, and rename in O(uses) per Op. 
for (auto *Op : OpsToRename) { LLVM_DEBUG(dbgs() << "Visiting " << *Op << "\n"); @@ -772,7 +798,7 @@ static void replaceCreatedSSACopys(PredicateInfo &PredInfo, Function &F) { bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) { auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto PredInfo = make_unique<PredicateInfo>(F, DT, AC); + auto PredInfo = std::make_unique<PredicateInfo>(F, DT, AC); PredInfo->print(dbgs()); if (VerifyPredicateInfo) PredInfo->verifyPredicateInfo(); @@ -786,7 +812,7 @@ PreservedAnalyses PredicateInfoPrinterPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); OS << "PredicateInfo for function: " << F.getName() << "\n"; - auto PredInfo = make_unique<PredicateInfo>(F, DT, AC); + auto PredInfo = std::make_unique<PredicateInfo>(F, DT, AC); PredInfo->print(OS); replaceCreatedSSACopys(*PredInfo, F); @@ -845,7 +871,7 @@ PreservedAnalyses PredicateInfoVerifierPass::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - make_unique<PredicateInfo>(F, DT, AC)->verifyPredicateInfo(); + std::make_unique<PredicateInfo>(F, DT, AC)->verifyPredicateInfo(); return PreservedAnalyses::all(); } diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 11651d040dc0..3a5e3293ed4f 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -94,6 +94,12 @@ static cl::opt<unsigned> PHINodeFoldingThreshold( cl::desc( "Control the amount of phi node folding to perform (default = 2)")); +static cl::opt<unsigned> TwoEntryPHINodeFoldingThreshold( + "two-entry-phi-node-folding-threshold", cl::Hidden, cl::init(4), + cl::desc("Control the maximal total instruction cost that we are willing " + "to speculatively execute to fold a 2-entry PHI 
node into a " + "select (default = 4)")); + static cl::opt<bool> DupRet( "simplifycfg-dup-ret", cl::Hidden, cl::init(false), cl::desc("Duplicate return instructions into unconditional branches")); @@ -332,7 +338,7 @@ static unsigned ComputeSpeculationCost(const User *I, /// CostRemaining, false is returned and CostRemaining is undefined. static bool DominatesMergePoint(Value *V, BasicBlock *BB, SmallPtrSetImpl<Instruction *> &AggressiveInsts, - unsigned &CostRemaining, + int &BudgetRemaining, const TargetTransformInfo &TTI, unsigned Depth = 0) { // It is possible to hit a zero-cost cycle (phi/gep instructions for example), @@ -375,7 +381,7 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB, if (!isSafeToSpeculativelyExecute(I)) return false; - unsigned Cost = ComputeSpeculationCost(I, TTI); + BudgetRemaining -= ComputeSpeculationCost(I, TTI); // Allow exactly one instruction to be speculated regardless of its cost // (as long as it is safe to do so). @@ -383,17 +389,14 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB, // or other expensive operation. The speculation of an expensive instruction // is expected to be undone in CodeGenPrepare if the speculation has not // enabled further IR optimizations. - if (Cost > CostRemaining && + if (BudgetRemaining < 0 && (!SpeculateOneExpensiveInst || !AggressiveInsts.empty() || Depth > 0)) return false; - // Avoid unsigned wrap. - CostRemaining = (Cost > CostRemaining) ? 0 : CostRemaining - Cost; - // Okay, we can only really hoist these out if their operands do // not take us over the cost threshold. for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i) - if (!DominatesMergePoint(*i, BB, AggressiveInsts, CostRemaining, TTI, + if (!DominatesMergePoint(*i, BB, AggressiveInsts, BudgetRemaining, TTI, Depth + 1)) return false; // Okay, it's safe to do this! Remember this instruction. @@ -629,8 +632,7 @@ private: /// vector. /// One "Extra" case is allowed to differ from the other. 
void gather(Value *V) { - Instruction *I = dyn_cast<Instruction>(V); - bool isEQ = (I->getOpcode() == Instruction::Or); + bool isEQ = (cast<Instruction>(V)->getOpcode() == Instruction::Or); // Keep a stack (SmallVector for efficiency) for depth-first traversal SmallVector<Value *, 8> DFT; @@ -1313,7 +1315,8 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, LLVMContext::MD_dereferenceable, LLVMContext::MD_dereferenceable_or_null, LLVMContext::MD_mem_parallel_loop_access, - LLVMContext::MD_access_group}; + LLVMContext::MD_access_group, + LLVMContext::MD_preserve_access_index}; combineMetadata(I1, I2, KnownIDs, true); // I1 and I2 are being combined into a single instruction. Its debug @@ -1420,6 +1423,20 @@ HoistTerminator: return true; } +// Check lifetime markers. +static bool isLifeTimeMarker(const Instruction *I) { + if (auto II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + return true; + } + } + return false; +} + // All instructions in Insts belong to different blocks that all unconditionally // branch to a common successor. Analyze each instruction and return true if it // would be possible to sink them into their successor, creating one common @@ -1474,20 +1491,25 @@ static bool canSinkInstructions( return false; } - // Because SROA can't handle speculating stores of selects, try not - // to sink loads or stores of allocas when we'd have to create a PHI for - // the address operand. Also, because it is likely that loads or stores - // of allocas will disappear when Mem2Reg/SROA is run, don't sink them. + // Because SROA can't handle speculating stores of selects, try not to sink + // loads, stores or lifetime markers of allocas when we'd have to create a + // PHI for the address operand. Also, because it is likely that loads or + // stores of allocas will disappear when Mem2Reg/SROA is run, don't sink + // them. 
// This can cause code churn which can have unintended consequences down // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244. // FIXME: This is a workaround for a deficiency in SROA - see // https://llvm.org/bugs/show_bug.cgi?id=30188 if (isa<StoreInst>(I0) && any_of(Insts, [](const Instruction *I) { - return isa<AllocaInst>(I->getOperand(1)); + return isa<AllocaInst>(I->getOperand(1)->stripPointerCasts()); })) return false; if (isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) { - return isa<AllocaInst>(I->getOperand(0)); + return isa<AllocaInst>(I->getOperand(0)->stripPointerCasts()); + })) + return false; + if (isLifeTimeMarker(I0) && any_of(Insts, [](const Instruction *I) { + return isa<AllocaInst>(I->getOperand(1)->stripPointerCasts()); })) return false; @@ -1959,7 +1981,7 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, SmallVector<Instruction *, 4> SpeculatedDbgIntrinsics; - unsigned SpeculationCost = 0; + unsigned SpeculatedInstructions = 0; Value *SpeculatedStoreValue = nullptr; StoreInst *SpeculatedStore = nullptr; for (BasicBlock::iterator BBI = ThenBB->begin(), @@ -1974,8 +1996,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Only speculatively execute a single instruction (not counting the // terminator) for now. - ++SpeculationCost; - if (SpeculationCost > 1) + ++SpeculatedInstructions; + if (SpeculatedInstructions > 1) return false; // Don't hoist the instruction if it's unsafe or expensive. @@ -2012,8 +2034,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, E = SinkCandidateUseCounts.end(); I != E; ++I) if (I->first->hasNUses(I->second)) { - ++SpeculationCost; - if (SpeculationCost > 1) + ++SpeculatedInstructions; + if (SpeculatedInstructions > 1) return false; } @@ -2053,8 +2075,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // getting expanded into Instructions. 
// FIXME: This doesn't account for how many operations are combined in the // constant expression. - ++SpeculationCost; - if (SpeculationCost > 1) + ++SpeculatedInstructions; + if (SpeculatedInstructions > 1) return false; } @@ -2302,10 +2324,8 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // instructions. While we are at it, keep track of the instructions // that need to be moved to the dominating block. SmallPtrSet<Instruction *, 4> AggressiveInsts; - unsigned MaxCostVal0 = PHINodeFoldingThreshold, - MaxCostVal1 = PHINodeFoldingThreshold; - MaxCostVal0 *= TargetTransformInfo::TCC_Basic; - MaxCostVal1 *= TargetTransformInfo::TCC_Basic; + int BudgetRemaining = + TwoEntryPHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic; for (BasicBlock::iterator II = BB->begin(); isa<PHINode>(II);) { PHINode *PN = cast<PHINode>(II++); @@ -2316,9 +2336,9 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, } if (!DominatesMergePoint(PN->getIncomingValue(0), BB, AggressiveInsts, - MaxCostVal0, TTI) || + BudgetRemaining, TTI) || !DominatesMergePoint(PN->getIncomingValue(1), BB, AggressiveInsts, - MaxCostVal1, TTI)) + BudgetRemaining, TTI)) return false; } @@ -2328,12 +2348,24 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, if (!PN) return true; - // Don't fold i1 branches on PHIs which contain binary operators. These can - // often be turned into switches and other things. + // Return true if at least one of these is a 'not', and another is either + // a 'not' too, or a constant. 
+ auto CanHoistNotFromBothValues = [](Value *V0, Value *V1) { + if (!match(V0, m_Not(m_Value()))) + std::swap(V0, V1); + auto Invertible = m_CombineOr(m_Not(m_Value()), m_AnyIntegralConstant()); + return match(V0, m_Not(m_Value())) && match(V1, Invertible); + }; + + // Don't fold i1 branches on PHIs which contain binary operators, unless one + // of the incoming values is an 'not' and another one is freely invertible. + // These can often be turned into switches and other things. if (PN->getType()->isIntegerTy(1) && (isa<BinaryOperator>(PN->getIncomingValue(0)) || isa<BinaryOperator>(PN->getIncomingValue(1)) || - isa<BinaryOperator>(IfCond))) + isa<BinaryOperator>(IfCond)) && + !CanHoistNotFromBothValues(PN->getIncomingValue(0), + PN->getIncomingValue(1))) return false; // If all PHI nodes are promotable, check to make sure that all instructions @@ -2368,6 +2400,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, return false; } } + assert(DomBlock && "Failed to find root DomBlock"); LLVM_DEBUG(dbgs() << "FOUND IF CONDITION! " << *IfCond << " T: " << IfTrue->getName() @@ -2913,42 +2946,8 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, BasicBlock *QTB, BasicBlock *QFB, BasicBlock *PostBB, Value *Address, bool InvertPCond, bool InvertQCond, - const DataLayout &DL) { - auto IsaBitcastOfPointerType = [](const Instruction &I) { - return Operator::getOpcode(&I) == Instruction::BitCast && - I.getType()->isPointerTy(); - }; - - // If we're not in aggressive mode, we only optimize if we have some - // confidence that by optimizing we'll allow P and/or Q to be if-converted. - auto IsWorthwhile = [&](BasicBlock *BB) { - if (!BB) - return true; - // Heuristic: if the block can be if-converted/phi-folded and the - // instructions inside are all cheap (arithmetic/GEPs), it's worthwhile to - // thread this store. 
- unsigned N = 0; - for (auto &I : BB->instructionsWithoutDebug()) { - // Cheap instructions viable for folding. - if (isa<BinaryOperator>(I) || isa<GetElementPtrInst>(I) || - isa<StoreInst>(I)) - ++N; - // Free instructions. - else if (I.isTerminator() || IsaBitcastOfPointerType(I)) - continue; - else - return false; - } - // The store we want to merge is counted in N, so add 1 to make sure - // we're counting the instructions that would be left. - return N <= (PHINodeFoldingThreshold + 1); - }; - - if (!MergeCondStoresAggressively && - (!IsWorthwhile(PTB) || !IsWorthwhile(PFB) || !IsWorthwhile(QTB) || - !IsWorthwhile(QFB))) - return false; - + const DataLayout &DL, + const TargetTransformInfo &TTI) { // For every pointer, there must be exactly two stores, one coming from // PTB or PFB, and the other from QTB or QFB. We don't support more than one // store (to any address) in PTB,PFB or QTB,QFB. @@ -2989,6 +2988,46 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, if (&*I != PStore && I->mayReadOrWriteMemory()) return false; + // If we're not in aggressive mode, we only optimize if we have some + // confidence that by optimizing we'll allow P and/or Q to be if-converted. + auto IsWorthwhile = [&](BasicBlock *BB, ArrayRef<StoreInst *> FreeStores) { + if (!BB) + return true; + // Heuristic: if the block can be if-converted/phi-folded and the + // instructions inside are all cheap (arithmetic/GEPs), it's worthwhile to + // thread this store. + int BudgetRemaining = + PHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic; + for (auto &I : BB->instructionsWithoutDebug()) { + // Consider terminator instruction to be free. + if (I.isTerminator()) + continue; + // If this is one the stores that we want to speculate out of this BB, + // then don't count it's cost, consider it to be free. 
+ if (auto *S = dyn_cast<StoreInst>(&I)) + if (llvm::find(FreeStores, S)) + continue; + // Else, we have a white-list of instructions that we are ak speculating. + if (!isa<BinaryOperator>(I) && !isa<GetElementPtrInst>(I)) + return false; // Not in white-list - not worthwhile folding. + // And finally, if this is a non-free instruction that we are okay + // speculating, ensure that we consider the speculation budget. + BudgetRemaining -= TTI.getUserCost(&I); + if (BudgetRemaining < 0) + return false; // Eagerly refuse to fold as soon as we're out of budget. + } + assert(BudgetRemaining >= 0 && + "When we run out of budget we will eagerly return from within the " + "per-instruction loop."); + return true; + }; + + const SmallVector<StoreInst *, 2> FreeStores = {PStore, QStore}; + if (!MergeCondStoresAggressively && + (!IsWorthwhile(PTB, FreeStores) || !IsWorthwhile(PFB, FreeStores) || + !IsWorthwhile(QTB, FreeStores) || !IsWorthwhile(QFB, FreeStores))) + return false; + // If PostBB has more than two predecessors, we need to split it so we can // sink the store. if (std::next(pred_begin(PostBB), 2) != pred_end(PostBB)) { @@ -3048,15 +3087,15 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, // store that doesn't execute. if (MinAlignment != 0) { // Choose the minimum of all non-zero alignments. - SI->setAlignment(MinAlignment); + SI->setAlignment(Align(MinAlignment)); } else if (MaxAlignment != 0) { // Choose the minimal alignment between the non-zero alignment and the ABI // default alignment for the type of the stored value. - SI->setAlignment(std::min(MaxAlignment, TypeAlignment)); + SI->setAlignment(Align(std::min(MaxAlignment, TypeAlignment))); } else { // If both alignments are zero, use ABI default alignment for the type of // the stored value. 
- SI->setAlignment(TypeAlignment); + SI->setAlignment(Align(TypeAlignment)); } QStore->eraseFromParent(); @@ -3066,7 +3105,8 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, } static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI, - const DataLayout &DL) { + const DataLayout &DL, + const TargetTransformInfo &TTI) { // The intention here is to find diamonds or triangles (see below) where each // conditional block contains a store to the same address. Both of these // stores are conditional, so they can't be unconditionally sunk. But it may @@ -3168,7 +3208,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI, bool Changed = false; for (auto *Address : CommonAddresses) Changed |= mergeConditionalStoreToAddress( - PTB, PFB, QTB, QFB, PostBB, Address, InvertPCond, InvertQCond, DL); + PTB, PFB, QTB, QFB, PostBB, Address, InvertPCond, InvertQCond, DL, TTI); return Changed; } @@ -3177,7 +3217,8 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI, /// that PBI and BI are both conditional branches, and BI is in one of the /// successor blocks of PBI - PBI branches to BI. static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, - const DataLayout &DL) { + const DataLayout &DL, + const TargetTransformInfo &TTI) { assert(PBI->isConditional() && BI->isConditional()); BasicBlock *BB = BI->getParent(); @@ -3233,7 +3274,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // If both branches are conditional and both contain stores to the same // address, remove the stores from the conditionals and create a conditional // merged store at the end. 
- if (MergeCondStores && mergeConditionalStores(PBI, BI, DL)) + if (MergeCondStores && mergeConditionalStores(PBI, BI, DL, TTI)) return true; // If this is a conditional branch in an empty block, and if any @@ -3697,12 +3738,17 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, BasicBlock *BB = BI->getParent(); + // MSAN does not like undefs as branch condition which can be introduced + // with "explicit branch". + if (ExtraCase && BB->getParent()->hasFnAttribute(Attribute::SanitizeMemory)) + return false; + LLVM_DEBUG(dbgs() << "Converting 'icmp' chain with " << Values.size() << " cases into SWITCH. BB is:\n" << *BB); // If there are any extra values that couldn't be folded into the switch - // then we evaluate them with an explicit branch first. Split the block + // then we evaluate them with an explicit branch first. Split the block // right before the condbr to handle it. if (ExtraCase) { BasicBlock *NewBB = @@ -3851,7 +3897,7 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) { // Simplify resume that is only used by a single (non-phi) landing pad. 
bool SimplifyCFGOpt::SimplifySingleResume(ResumeInst *RI) { BasicBlock *BB = RI->getParent(); - LandingPadInst *LPInst = dyn_cast<LandingPadInst>(BB->getFirstNonPHI()); + auto *LPInst = cast<LandingPadInst>(BB->getFirstNonPHI()); assert(RI->getValue() == LPInst && "Resume must unwind the exception that caused control to here"); @@ -4178,23 +4224,22 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { IRBuilder<> Builder(TI); if (auto *BI = dyn_cast<BranchInst>(TI)) { if (BI->isUnconditional()) { - if (BI->getSuccessor(0) == BB) { - new UnreachableInst(TI->getContext(), TI); - TI->eraseFromParent(); - Changed = true; - } + assert(BI->getSuccessor(0) == BB && "Incorrect CFG"); + new UnreachableInst(TI->getContext(), TI); + TI->eraseFromParent(); + Changed = true; } else { Value* Cond = BI->getCondition(); if (BI->getSuccessor(0) == BB) { Builder.CreateAssumption(Builder.CreateNot(Cond)); Builder.CreateBr(BI->getSuccessor(1)); - EraseTerminatorAndDCECond(BI); - } else if (BI->getSuccessor(1) == BB) { + } else { + assert(BI->getSuccessor(1) == BB && "Incorrect CFG"); Builder.CreateAssumption(Cond); Builder.CreateBr(BI->getSuccessor(0)); - EraseTerminatorAndDCECond(BI); - Changed = true; } + EraseTerminatorAndDCECond(BI); + Changed = true; } } else if (auto *SI = dyn_cast<SwitchInst>(TI)) { SwitchInstProfUpdateWrapper SU(*SI); @@ -4276,6 +4321,17 @@ static bool CasesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) { return true; } +static void createUnreachableSwitchDefault(SwitchInst *Switch) { + LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); + BasicBlock *NewDefaultBlock = + SplitBlockPredecessors(Switch->getDefaultDest(), Switch->getParent(), ""); + Switch->setDefaultDest(&*NewDefaultBlock); + SplitBlock(&*NewDefaultBlock, &NewDefaultBlock->front()); + auto *NewTerminator = NewDefaultBlock->getTerminator(); + new UnreachableInst(Switch->getContext(), NewTerminator); + EraseTerminatorAndDCECond(NewTerminator); +} + /// Turn a switch 
with two reachable destinations into an integer range /// comparison and branch. static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { @@ -4384,6 +4440,11 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { cast<PHINode>(BBI)->removeIncomingValue(SI->getParent()); } + // Clean up the default block - it may have phis or other instructions before + // the unreachable terminator. + if (!HasDefault) + createUnreachableSwitchDefault(SI); + // Drop the switch. SI->eraseFromParent(); @@ -4428,14 +4489,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, if (HasDefault && DeadCases.empty() && NumUnknownBits < 64 /* avoid overflow */ && SI->getNumCases() == (1ULL << NumUnknownBits)) { - LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); - BasicBlock *NewDefault = - SplitBlockPredecessors(SI->getDefaultDest(), SI->getParent(), ""); - SI->setDefaultDest(&*NewDefault); - SplitBlock(&*NewDefault, &NewDefault->front()); - auto *OldTI = NewDefault->getTerminator(); - new UnreachableInst(SI->getContext(), OldTI); - EraseTerminatorAndDCECond(OldTI); + createUnreachableSwitchDefault(SI); return true; } @@ -5031,7 +5085,7 @@ SwitchLookupTable::SwitchLookupTable( Array->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Set the alignment to that of an array items. We will be only loading one // value out of it. - Array->setAlignment(DL.getPrefTypeAlignment(ValueType)); + Array->setAlignment(Align(DL.getPrefTypeAlignment(ValueType))); Kind = ArrayKind; } @@ -5260,7 +5314,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Figure out the corresponding result for each case value and phi node in the // common destination, as well as the min and max case values. 
- assert(!empty(SI->cases())); + assert(!SI->cases().empty()); SwitchInst::CaseIt CI = SI->case_begin(); ConstantInt *MinCaseVal = CI->getCaseValue(); ConstantInt *MaxCaseVal = CI->getCaseValue(); @@ -5892,7 +5946,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) if (PBI != BI && PBI->isConditional()) - if (SimplifyCondBranchToCondBranch(PBI, BI, DL)) + if (SimplifyCondBranchToCondBranch(PBI, BI, DL, TTI)) return requestResimplify(); // Look for diamond patterns. @@ -5900,7 +5954,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BasicBlock *PrevBB = allPredecessorsComeFromSameSource(BB)) if (BranchInst *PBI = dyn_cast<BranchInst>(PrevBB->getTerminator())) if (PBI != BI && PBI->isConditional()) - if (mergeConditionalStores(PBI, BI, DL)) + if (mergeConditionalStores(PBI, BI, DL, TTI)) return requestResimplify(); return false; diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index e0def81d5eee..0324993a8203 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/SizeOpts.h" @@ -47,7 +48,6 @@ static cl::opt<bool> cl::desc("Enable unsafe double to float " "shrinking for math lib calls")); - //===----------------------------------------------------------------------===// // Helper Functions //===----------------------------------------------------------------------===// @@ -177,7 +177,8 @@ static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len, if (!isOnlyUsedInComparisonWithZero(CI)) return false; - if 
(!isDereferenceableAndAlignedPointer(Str, 1, APInt(64, Len), DL)) + if (!isDereferenceableAndAlignedPointer(Str, Align::None(), APInt(64, Len), + DL)) return false; if (CI->getFunction()->hasFnAttribute(Attribute::SanitizeMemory)) @@ -186,6 +187,67 @@ static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len, return true; } +static void annotateDereferenceableBytes(CallInst *CI, + ArrayRef<unsigned> ArgNos, + uint64_t DereferenceableBytes) { + const Function *F = CI->getCaller(); + if (!F) + return; + for (unsigned ArgNo : ArgNos) { + uint64_t DerefBytes = DereferenceableBytes; + unsigned AS = CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace(); + if (!llvm::NullPointerIsDefined(F, AS) || + CI->paramHasAttr(ArgNo, Attribute::NonNull)) + DerefBytes = std::max(CI->getDereferenceableOrNullBytes( + ArgNo + AttributeList::FirstArgIndex), + DereferenceableBytes); + + if (CI->getDereferenceableBytes(ArgNo + AttributeList::FirstArgIndex) < + DerefBytes) { + CI->removeParamAttr(ArgNo, Attribute::Dereferenceable); + if (!llvm::NullPointerIsDefined(F, AS) || + CI->paramHasAttr(ArgNo, Attribute::NonNull)) + CI->removeParamAttr(ArgNo, Attribute::DereferenceableOrNull); + CI->addParamAttr(ArgNo, Attribute::getWithDereferenceableBytes( + CI->getContext(), DerefBytes)); + } + } +} + +static void annotateNonNullBasedOnAccess(CallInst *CI, + ArrayRef<unsigned> ArgNos) { + Function *F = CI->getCaller(); + if (!F) + return; + + for (unsigned ArgNo : ArgNos) { + if (CI->paramHasAttr(ArgNo, Attribute::NonNull)) + continue; + unsigned AS = CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace(); + if (llvm::NullPointerIsDefined(F, AS)) + continue; + + CI->addParamAttr(ArgNo, Attribute::NonNull); + annotateDereferenceableBytes(CI, ArgNo, 1); + } +} + +static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> ArgNos, + Value *Size, const DataLayout &DL) { + if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size)) { + 
annotateNonNullBasedOnAccess(CI, ArgNos); + annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue()); + } else if (isKnownNonZero(Size, DL)) { + annotateNonNullBasedOnAccess(CI, ArgNos); + const APInt *X, *Y; + uint64_t DerefMin = 1; + if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) { + DerefMin = std::min(X->getZExtValue(), Y->getZExtValue()); + annotateDereferenceableBytes(CI, ArgNos, DerefMin); + } + } +} + //===----------------------------------------------------------------------===// // String and Memory Library Call Optimizations //===----------------------------------------------------------------------===// @@ -194,10 +256,13 @@ Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilder<> &B) { // Extract some information from the instruction Value *Dst = CI->getArgOperand(0); Value *Src = CI->getArgOperand(1); + annotateNonNullBasedOnAccess(CI, {0, 1}); // See if we can get the length of the input string. uint64_t Len = GetStringLength(Src); - if (Len == 0) + if (Len) + annotateDereferenceableBytes(CI, 1, Len); + else return nullptr; --Len; // Unbias length. @@ -232,24 +297,34 @@ Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) { // Extract some information from the instruction. Value *Dst = CI->getArgOperand(0); Value *Src = CI->getArgOperand(1); + Value *Size = CI->getArgOperand(2); uint64_t Len; + annotateNonNullBasedOnAccess(CI, 0); + if (isKnownNonZero(Size, DL)) + annotateNonNullBasedOnAccess(CI, 1); // We don't do anything if length is not constant. - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size); + if (LengthArg) { Len = LengthArg->getZExtValue(); - else + // strncat(x, c, 0) -> x + if (!Len) + return Dst; + } else { return nullptr; + } // See if we can get the length of the input string. 
uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) + if (SrcLen) { + annotateDereferenceableBytes(CI, 1, SrcLen); + --SrcLen; // Unbias length. + } else { return nullptr; - --SrcLen; // Unbias length. + } - // Handle the simple, do-nothing cases: // strncat(x, "", c) -> x - // strncat(x, c, 0) -> x - if (SrcLen == 0 || Len == 0) + if (SrcLen == 0) return Dst; // We don't optimize this case. @@ -265,13 +340,18 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); Value *SrcStr = CI->getArgOperand(0); + annotateNonNullBasedOnAccess(CI, 0); // If the second operand is non-constant, see if we can compute the length // of the input string and turn this into memchr. ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); if (!CharC) { uint64_t Len = GetStringLength(SrcStr); - if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. + if (Len) + annotateDereferenceableBytes(CI, 0, Len); + else + return nullptr; + if (!FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. return nullptr; return emitMemChr(SrcStr, CI->getArgOperand(1), // include nul. @@ -304,6 +384,7 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) { Value *SrcStr = CI->getArgOperand(0); ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + annotateNonNullBasedOnAccess(CI, 0); // Cannot fold anything if we're not looking for a constant. 
if (!CharC) @@ -351,7 +432,12 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { // strcmp(P, "x") -> memcmp(P, "x", 2) uint64_t Len1 = GetStringLength(Str1P); + if (Len1) + annotateDereferenceableBytes(CI, 0, Len1); uint64_t Len2 = GetStringLength(Str2P); + if (Len2) + annotateDereferenceableBytes(CI, 1, Len2); + if (Len1 && Len2) { return emitMemCmp(Str1P, Str2P, ConstantInt::get(DL.getIntPtrType(CI->getContext()), @@ -374,17 +460,22 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { TLI); } + annotateNonNullBasedOnAccess(CI, {0, 1}); return nullptr; } Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + Value *Str1P = CI->getArgOperand(0); + Value *Str2P = CI->getArgOperand(1); + Value *Size = CI->getArgOperand(2); if (Str1P == Str2P) // strncmp(x,x,n) -> 0 return ConstantInt::get(CI->getType(), 0); + if (isKnownNonZero(Size, DL)) + annotateNonNullBasedOnAccess(CI, {0, 1}); // Get the length argument if it is constant. 
uint64_t Length; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size)) Length = LengthArg->getZExtValue(); else return nullptr; @@ -393,7 +484,7 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { return ConstantInt::get(CI->getType(), 0); if (Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) - return emitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); + return emitMemCmp(Str1P, Str2P, Size, B, DL, TLI); StringRef Str1, Str2; bool HasStr1 = getConstantStringInfo(Str1P, Str1); @@ -415,7 +506,11 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { CI->getType()); uint64_t Len1 = GetStringLength(Str1P); + if (Len1) + annotateDereferenceableBytes(CI, 0, Len1); uint64_t Len2 = GetStringLength(Str2P); + if (Len2) + annotateDereferenceableBytes(CI, 1, Len2); // strncmp to memcmp if (!HasStr1 && HasStr2) { @@ -437,20 +532,38 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { return nullptr; } +Value *LibCallSimplifier::optimizeStrNDup(CallInst *CI, IRBuilder<> &B) { + Value *Src = CI->getArgOperand(0); + ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + uint64_t SrcLen = GetStringLength(Src); + if (SrcLen && Size) { + annotateDereferenceableBytes(CI, 0, SrcLen); + if (SrcLen <= Size->getZExtValue() + 1) + return emitStrDup(Src, B, TLI); + } + + return nullptr; +} + Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) { Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); if (Dst == Src) // strcpy(x,x) -> x return Src; - + + annotateNonNullBasedOnAccess(CI, {0, 1}); // See if we can get the length of the input string. uint64_t Len = GetStringLength(Src); - if (Len == 0) + if (Len) + annotateDereferenceableBytes(CI, 1, Len); + else return nullptr; // We have enough information to now generate the memcpy call to do the // copy for us. 
Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, 1, Src, 1, - ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len)); + CallInst *NewCI = + B.CreateMemCpy(Dst, 1, Src, 1, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len)); + NewCI->setAttributes(CI->getAttributes()); return Dst; } @@ -464,7 +577,9 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { // See if we can get the length of the input string. uint64_t Len = GetStringLength(Src); - if (Len == 0) + if (Len) + annotateDereferenceableBytes(CI, 1, Len); + else return nullptr; Type *PT = Callee->getFunctionType()->getParamType(0); @@ -474,7 +589,8 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { // We have enough information to now generate the memcpy call to do the // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, 1, Src, 1, LenV); + CallInst *NewCI = B.CreateMemCpy(Dst, 1, Src, 1, LenV); + NewCI->setAttributes(CI->getAttributes()); return DstEnd; } @@ -482,37 +598,47 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); Value *Dst = CI->getArgOperand(0); Value *Src = CI->getArgOperand(1); - Value *LenOp = CI->getArgOperand(2); + Value *Size = CI->getArgOperand(2); + annotateNonNullBasedOnAccess(CI, 0); + if (isKnownNonZero(Size, DL)) + annotateNonNullBasedOnAccess(CI, 1); + + uint64_t Len; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size)) + Len = LengthArg->getZExtValue(); + else + return nullptr; + + // strncpy(x, y, 0) -> x + if (Len == 0) + return Dst; // See if we can get the length of the input string. uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) + if (SrcLen) { + annotateDereferenceableBytes(CI, 1, SrcLen); + --SrcLen; // Unbias length. 
+ } else { return nullptr; - --SrcLen; + } if (SrcLen == 0) { // strncpy(x, "", y) -> memset(align 1 x, '\0', y) - B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); + CallInst *NewCI = B.CreateMemSet(Dst, B.getInt8('\0'), Size, 1); + AttrBuilder ArgAttrs(CI->getAttributes().getParamAttributes(0)); + NewCI->setAttributes(NewCI->getAttributes().addParamAttributes( + CI->getContext(), 0, ArgAttrs)); return Dst; } - uint64_t Len; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) - Len = LengthArg->getZExtValue(); - else - return nullptr; - - if (Len == 0) - return Dst; // strncpy(x, y, 0) -> x - // Let strncpy handle the zero padding if (Len > SrcLen + 1) return nullptr; Type *PT = Callee->getFunctionType()->getParamType(0); // strncpy(x, s, c) -> memcpy(align 1 x, align 1 s, c) [s and c are constant] - B.CreateMemCpy(Dst, 1, Src, 1, ConstantInt::get(DL.getIntPtrType(PT), Len)); - + CallInst *NewCI = B.CreateMemCpy(Dst, 1, Src, 1, ConstantInt::get(DL.getIntPtrType(PT), Len)); + NewCI->setAttributes(CI->getAttributes()); return Dst; } @@ -608,7 +734,10 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilder<> &B, } Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { - return optimizeStringLength(CI, B, 8); + if (Value *V = optimizeStringLength(CI, B, 8)) + return V; + annotateNonNullBasedOnAccess(CI, 0); + return nullptr; } Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilder<> &B) { @@ -756,21 +885,35 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) { Value *StrChr = emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI); return StrChr ? 
B.CreateBitCast(StrChr, CI->getType()) : nullptr; } + + annotateNonNullBasedOnAccess(CI, {0, 1}); + return nullptr; +} + +Value *LibCallSimplifier::optimizeMemRChr(CallInst *CI, IRBuilder<> &B) { + if (isKnownNonZero(CI->getOperand(2), DL)) + annotateNonNullBasedOnAccess(CI, 0); return nullptr; } Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) { Value *SrcStr = CI->getArgOperand(0); + Value *Size = CI->getArgOperand(2); + annotateNonNullAndDereferenceable(CI, 0, Size, DL); ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); + ConstantInt *LenC = dyn_cast<ConstantInt>(Size); // memchr(x, y, 0) -> null - if (LenC && LenC->isZero()) - return Constant::getNullValue(CI->getType()); + if (LenC) { + if (LenC->isZero()) + return Constant::getNullValue(CI->getType()); + } else { + // From now on we need at least constant length and string. + return nullptr; + } - // From now on we need at least constant length and string. StringRef Str; - if (!LenC || !getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) + if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) return nullptr; // Truncate the string to LenC. If Str is smaller than LenC we will still only @@ -913,6 +1056,7 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, Ret = 1; return ConstantInt::get(CI->getType(), Ret); } + return nullptr; } @@ -925,12 +1069,19 @@ Value *LibCallSimplifier::optimizeMemCmpBCmpCommon(CallInst *CI, if (LHS == RHS) // memcmp(s,s,x) -> 0 return Constant::getNullValue(CI->getType()); + annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL); // Handle constant lengths. 
- if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size)) - if (Value *Res = optimizeMemCmpConstantSize(CI, LHS, RHS, - LenC->getZExtValue(), B, DL)) - return Res; + ConstantInt *LenC = dyn_cast<ConstantInt>(Size); + if (!LenC) + return nullptr; + // memcmp(d,s,0) -> 0 + if (LenC->getZExtValue() == 0) + return Constant::getNullValue(CI->getType()); + + if (Value *Res = + optimizeMemCmpConstantSize(CI, LHS, RHS, LenC->getZExtValue(), B, DL)) + return Res; return nullptr; } @@ -939,9 +1090,9 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { return V; // memcmp(x, y, Len) == 0 -> bcmp(x, y, Len) == 0 - // `bcmp` can be more efficient than memcmp because it only has to know that - // there is a difference, not where it is. - if (isOnlyUsedInZeroEqualityComparison(CI) && TLI->has(LibFunc_bcmp)) { + // bcmp can be more efficient than memcmp because it only has to know that + // there is a difference, not how different one is to the other. + if (TLI->has(LibFunc_bcmp) && isOnlyUsedInZeroEqualityComparison(CI)) { Value *LHS = CI->getArgOperand(0); Value *RHS = CI->getArgOperand(1); Value *Size = CI->getArgOperand(2); @@ -956,16 +1107,37 @@ Value *LibCallSimplifier::optimizeBCmp(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { + Value *Size = CI->getArgOperand(2); + annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL); + if (isa<IntrinsicInst>(CI)) + return nullptr; + // memcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n) - B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, - CI->getArgOperand(2)); + CallInst *NewCI = + B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, Size); + NewCI->setAttributes(CI->getAttributes()); return CI->getArgOperand(0); } +Value *LibCallSimplifier::optimizeMemPCpy(CallInst *CI, IRBuilder<> &B) { + Value *Dst = CI->getArgOperand(0); + Value *N = CI->getArgOperand(2); + // mempcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n), x 
+ n + CallInst *NewCI = B.CreateMemCpy(Dst, 1, CI->getArgOperand(1), 1, N); + NewCI->setAttributes(CI->getAttributes()); + return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, N); +} + Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) { + Value *Size = CI->getArgOperand(2); + annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL); + if (isa<IntrinsicInst>(CI)) + return nullptr; + // memmove(x, y, n) -> llvm.memmove(align 1 x, align 1 y, n) - B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, - CI->getArgOperand(2)); + CallInst *NewCI = + B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, Size); + NewCI->setAttributes(CI->getAttributes()); return CI->getArgOperand(0); } @@ -1003,25 +1175,29 @@ Value *LibCallSimplifier::foldMallocMemset(CallInst *Memset, IRBuilder<> &B) { B.SetInsertPoint(Malloc->getParent(), ++Malloc->getIterator()); const DataLayout &DL = Malloc->getModule()->getDataLayout(); IntegerType *SizeType = DL.getIntPtrType(B.GetInsertBlock()->getContext()); - Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1), - Malloc->getArgOperand(0), Malloc->getAttributes(), - B, *TLI); - if (!Calloc) - return nullptr; - - Malloc->replaceAllUsesWith(Calloc); - eraseFromParent(Malloc); + if (Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1), + Malloc->getArgOperand(0), + Malloc->getAttributes(), B, *TLI)) { + substituteInParent(Malloc, Calloc); + return Calloc; + } - return Calloc; + return nullptr; } Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { + Value *Size = CI->getArgOperand(2); + annotateNonNullAndDereferenceable(CI, 0, Size, DL); + if (isa<IntrinsicInst>(CI)) + return nullptr; + if (auto *Calloc = foldMallocMemset(CI, B)) return Calloc; // memset(p, v, n) -> llvm.memset(align 1 p, v, n) Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + CallInst *NewCI = 
B.CreateMemSet(CI->getArgOperand(0), Val, Size, 1); + NewCI->setAttributes(CI->getAttributes()); return CI->getArgOperand(0); } @@ -1096,21 +1272,18 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, if (!V[0] || (isBinary && !V[1])) return nullptr; - StringRef CalleeNm = CalleeFn->getName(); - AttributeList CalleeAt = CalleeFn->getAttributes(); - bool CalleeIn = CalleeFn->isIntrinsic(); - // If call isn't an intrinsic, check that it isn't within a function with the // same name as the float version of this call, otherwise the result is an // infinite loop. For example, from MinGW-w64: // // float expf(float val) { return (float) exp((double) val); } - if (!CalleeIn) { - const Function *Fn = CI->getFunction(); - StringRef FnName = Fn->getName(); - if (FnName.back() == 'f' && - FnName.size() == (CalleeNm.size() + 1) && - FnName.startswith(CalleeNm)) + StringRef CalleeName = CalleeFn->getName(); + bool IsIntrinsic = CalleeFn->isIntrinsic(); + if (!IsIntrinsic) { + StringRef CallerName = CI->getFunction()->getName(); + if (!CallerName.empty() && CallerName.back() == 'f' && + CallerName.size() == (CalleeName.size() + 1) && + CallerName.startswith(CalleeName)) return nullptr; } @@ -1120,16 +1293,16 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, // g((double) float) -> (double) gf(float) Value *R; - if (CalleeIn) { + if (IsIntrinsic) { Module *M = CI->getModule(); Intrinsic::ID IID = CalleeFn->getIntrinsicID(); Function *Fn = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]); + } else { + AttributeList CalleeAttrs = CalleeFn->getAttributes(); + R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], CalleeName, B, CalleeAttrs) + : emitUnaryFloatFnCall(V[0], CalleeName, B, CalleeAttrs); } - else - R = isBinary ? 
emitBinaryFloatFnCall(V[0], V[1], CalleeNm, B, CalleeAt) - : emitUnaryFloatFnCall(V[0], CalleeNm, B, CalleeAt); - return B.CreateFPExt(R, B.getDoubleTy()); } @@ -1234,9 +1407,25 @@ static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) { return InnerChain[Exp]; } +// Return a properly extended 32-bit integer if the operation is an itofp. +static Value *getIntToFPVal(Value *I2F, IRBuilder<> &B) { + if (isa<SIToFPInst>(I2F) || isa<UIToFPInst>(I2F)) { + Value *Op = cast<Instruction>(I2F)->getOperand(0); + // Make sure that the exponent fits inside an int32_t, + // thus avoiding any range issues that FP has not. + unsigned BitWidth = Op->getType()->getPrimitiveSizeInBits(); + if (BitWidth < 32 || + (BitWidth == 32 && isa<SIToFPInst>(I2F))) + return isa<SIToFPInst>(I2F) ? B.CreateSExt(Op, B.getInt32Ty()) + : B.CreateZExt(Op, B.getInt32Ty()); + } + + return nullptr; +} + /// Use exp{,2}(x * y) for pow(exp{,2}(x), y); -/// exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x); -/// exp2(log2(n) * x) for pow(n, x). +/// ldexp(1.0, x) for pow(2.0, itofp(x)); exp2(n * x) for pow(2.0 ** n, x); +/// exp10(x) for pow(10.0, x); exp2(log2(n) * x) for pow(n, x). Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); @@ -1269,9 +1458,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { StringRef ExpName; Intrinsic::ID ID; Value *ExpFn; - LibFunc LibFnFloat; - LibFunc LibFnDouble; - LibFunc LibFnLongDouble; + LibFunc LibFnFloat, LibFnDouble, LibFnLongDouble; switch (LibFn) { default: @@ -1305,9 +1492,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { // elimination cannot be trusted to remove it, since it may have side // effects (e.g., errno). When the only consumer for the original // exp{,2}() is pow(), then it has to be explicitly erased. 
- BaseFn->replaceAllUsesWith(ExpFn); - eraseFromParent(BaseFn); - + substituteInParent(BaseFn, ExpFn); return ExpFn; } } @@ -1318,8 +1503,18 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { if (!match(Pow->getArgOperand(0), m_APFloat(BaseF))) return nullptr; + // pow(2.0, itofp(x)) -> ldexp(1.0, x) + if (match(Base, m_SpecificFP(2.0)) && + (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) && + hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { + if (Value *ExpoI = getIntToFPVal(Expo, B)) + return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), ExpoI, TLI, + LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl, + B, Attrs); + } + // pow(2.0 ** n, x) -> exp2(n * x) - if (hasUnaryFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { + if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { APFloat BaseR = APFloat(1.0); BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored); BaseR = BaseR / *BaseF; @@ -1344,7 +1539,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { // pow(10.0, x) -> exp10(x) // TODO: There is no exp10() intrinsic yet, but some day there shall be one. 
if (match(Base, m_SpecificFP(10.0)) && - hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + hasFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) return emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l, B, Attrs); @@ -1359,17 +1554,15 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { if (Log) { Value *FMul = B.CreateFMul(Log, Expo, "mul"); - if (Pow->doesNotAccessMemory()) { + if (Pow->doesNotAccessMemory()) return B.CreateCall(Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty), FMul, "exp2"); - } else { - if (hasUnaryFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, - LibFunc_exp2l)) - return emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, - LibFunc_exp2l, B, Attrs); - } + else if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) + return emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l, B, Attrs); } } + return nullptr; } @@ -1384,8 +1577,7 @@ static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, } // Otherwise, use the libcall for sqrt(). - if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, - LibFunc_sqrtl)) + if (hasFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) // TODO: We also should check that the target can in fact lower the sqrt() // libcall. We currently have no way to ask this question, so we ask if // the target has a sqrt() libcall, which is not exactly the same. @@ -1452,7 +1644,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { bool Ignored; // Bail out if simplifying libcalls to pow() is disabled. - if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl)) + if (!hasFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl)) return nullptr; // Propagate the math semantics from the call to any created instructions. 
@@ -1480,8 +1672,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { if (match(Expo, m_SpecificFP(-1.0))) return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal"); - // pow(x, 0.0) -> 1.0 - if (match(Expo, m_SpecificFP(0.0))) + // pow(x, +/-0.0) -> 1.0 + if (match(Expo, m_AnyZeroFP())) return ConstantFP::get(Ty, 1.0); // pow(x, 1.0) -> x @@ -1558,16 +1750,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { // powf(x, itofp(y)) -> powi(x, y) if (AllowApprox && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo))) { - Value *IntExpo = cast<Instruction>(Expo)->getOperand(0); - Value *NewExpo = nullptr; - unsigned BitWidth = IntExpo->getType()->getPrimitiveSizeInBits(); - if (isa<SIToFPInst>(Expo) && BitWidth == 32) - NewExpo = IntExpo; - else if (BitWidth < 32) - NewExpo = isa<SIToFPInst>(Expo) ? B.CreateSExt(IntExpo, B.getInt32Ty()) - : B.CreateZExt(IntExpo, B.getInt32Ty()); - if (NewExpo) - return createPowWithIntegerExponent(Base, NewExpo, M, B); + if (Value *ExpoI = getIntToFPVal(Expo, B)) + return createPowWithIntegerExponent(Base, ExpoI, M, B); } return Shrunk; @@ -1575,45 +1759,25 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - Value *Ret = nullptr; StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "exp2" && hasFloatVersion(Name)) + Value *Ret = nullptr; + if (UnsafeFPShrink && Name == TLI->getName(LibFunc_exp2) && + hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); + Type *Ty = CI->getType(); Value *Op = CI->getArgOperand(0); + // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 - LibFunc LdExp = LibFunc_ldexpl; - if (Op->getType()->isFloatTy()) - LdExp = LibFunc_ldexpf; - else if (Op->getType()->isDoubleTy()) - LdExp = LibFunc_ldexp; - - if 
(TLI->has(LdExp)) { - Value *LdExpArg = nullptr; - if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) - LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty()); - } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) - LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty()); - } - - if (LdExpArg) { - Constant *One = ConstantFP::get(CI->getContext(), APFloat(1.0f)); - if (!Op->getType()->isFloatTy()) - One = ConstantExpr::getFPExtend(One, Op->getType()); - - Module *M = CI->getModule(); - FunctionCallee NewCallee = M->getOrInsertFunction( - TLI->getName(LdExp), Op->getType(), Op->getType(), B.getInt32Ty()); - CallInst *CI = B.CreateCall(NewCallee, {One, LdExpArg}); - if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); - - return CI; - } + if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) && + hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { + if (Value *Exp = getIntToFPVal(Op, B)) + return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI, + LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl, + B, CI->getCalledFunction()->getAttributes()); } + return Ret; } @@ -1644,48 +1808,155 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { return B.CreateCall(F, { CI->getArgOperand(0), CI->getArgOperand(1) }); } -Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); +Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilder<> &B) { + Function *LogFn = Log->getCalledFunction(); + AttributeList Attrs = LogFn->getAttributes(); + StringRef LogNm = LogFn->getName(); + Intrinsic::ID LogID = LogFn->getIntrinsicID(); + Module *Mod = Log->getModule(); + Type *Ty = Log->getType(); Value *Ret = nullptr; - StringRef Name = Callee->getName(); - if (UnsafeFPShrink && 
hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); - if (!CI->isFast()) - return Ret; - Value *Op1 = CI->getArgOperand(0); - auto *OpC = dyn_cast<CallInst>(Op1); + if (UnsafeFPShrink && hasFloatVersion(LogNm)) + Ret = optimizeUnaryDoubleFP(Log, B, true); // The earlier call must also be 'fast' in order to do these transforms. - if (!OpC || !OpC->isFast()) + CallInst *Arg = dyn_cast<CallInst>(Log->getArgOperand(0)); + if (!Log->isFast() || !Arg || !Arg->isFast() || !Arg->hasOneUse()) return Ret; - // log(pow(x,y)) -> y*log(x) - // This is only applicable to log, log2, log10. - if (Name != "log" && Name != "log2" && Name != "log10") + LibFunc LogLb, ExpLb, Exp2Lb, Exp10Lb, PowLb; + + // This is only applicable to log(), log2(), log10(). + if (TLI->getLibFunc(LogNm, LogLb)) + switch (LogLb) { + case LibFunc_logf: + LogID = Intrinsic::log; + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + break; + case LibFunc_log: + LogID = Intrinsic::log; + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + break; + case LibFunc_logl: + LogID = Intrinsic::log; + ExpLb = LibFunc_expl; + Exp2Lb = LibFunc_exp2l; + Exp10Lb = LibFunc_exp10l; + PowLb = LibFunc_powl; + break; + case LibFunc_log2f: + LogID = Intrinsic::log2; + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + break; + case LibFunc_log2: + LogID = Intrinsic::log2; + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + break; + case LibFunc_log2l: + LogID = Intrinsic::log2; + ExpLb = LibFunc_expl; + Exp2Lb = LibFunc_exp2l; + Exp10Lb = LibFunc_exp10l; + PowLb = LibFunc_powl; + break; + case LibFunc_log10f: + LogID = Intrinsic::log10; + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + break; + case LibFunc_log10: + LogID = Intrinsic::log10; + ExpLb = LibFunc_exp; + 
Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + break; + case LibFunc_log10l: + LogID = Intrinsic::log10; + ExpLb = LibFunc_expl; + Exp2Lb = LibFunc_exp2l; + Exp10Lb = LibFunc_exp10l; + PowLb = LibFunc_powl; + break; + default: + return Ret; + } + else if (LogID == Intrinsic::log || LogID == Intrinsic::log2 || + LogID == Intrinsic::log10) { + if (Ty->getScalarType()->isFloatTy()) { + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + PowLb = LibFunc_powf; + } else if (Ty->getScalarType()->isDoubleTy()) { + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + PowLb = LibFunc_pow; + } else + return Ret; + } else return Ret; IRBuilder<>::FastMathFlagGuard Guard(B); - FastMathFlags FMF; - FMF.setFast(); - B.setFastMathFlags(FMF); + B.setFastMathFlags(FastMathFlags::getFast()); + + Intrinsic::ID ArgID = Arg->getIntrinsicID(); + LibFunc ArgLb = NotLibFunc; + TLI->getLibFunc(Arg, ArgLb); + + // log(pow(x,y)) -> y*log(x) + if (ArgLb == PowLb || ArgID == Intrinsic::pow) { + Value *LogX = + Log->doesNotAccessMemory() + ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), + Arg->getOperand(0), "log") + : emitUnaryFloatFnCall(Arg->getOperand(0), LogNm, B, Attrs); + Value *MulY = B.CreateFMul(Arg->getArgOperand(1), LogX, "mul"); + // Since pow() may have side effects, e.g. errno, + // dead code elimination may not be trusted to remove it. + substituteInParent(Arg, MulY); + return MulY; + } + + // log(exp{,2,10}(y)) -> y*log({e,2,10}) + // TODO: There is no exp10() intrinsic yet. + if (ArgLb == ExpLb || ArgLb == Exp2Lb || ArgLb == Exp10Lb || + ArgID == Intrinsic::exp || ArgID == Intrinsic::exp2) { + Constant *Eul; + if (ArgLb == ExpLb || ArgID == Intrinsic::exp) + // FIXME: Add more precise value of e for long double. 
+ Eul = ConstantFP::get(Log->getType(), numbers::e); + else if (ArgLb == Exp2Lb || ArgID == Intrinsic::exp2) + Eul = ConstantFP::get(Log->getType(), 2.0); + else + Eul = ConstantFP::get(Log->getType(), 10.0); + Value *LogE = Log->doesNotAccessMemory() + ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), + Eul, "log") + : emitUnaryFloatFnCall(Eul, LogNm, B, Attrs); + Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul"); + // Since exp() may have side effects, e.g. errno, + // dead code elimination may not be trusted to remove it. + substituteInParent(Arg, MulY); + return MulY; + } - LibFunc Func; - Function *F = OpC->getCalledFunction(); - if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && - Func == LibFunc_pow) || F->getIntrinsicID() == Intrinsic::pow)) - return B.CreateFMul(OpC->getArgOperand(1), - emitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B, - Callee->getAttributes()), "mul"); - - // log(exp2(y)) -> y*log(2) - if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) && - TLI->has(Func) && Func == LibFunc_exp2) - return B.CreateFMul( - OpC->getArgOperand(0), - emitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0), - Callee->getName(), B, Callee->getAttributes()), - "logmul"); return Ret; } @@ -2137,6 +2408,7 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) { return New; } + annotateNonNullBasedOnAccess(CI, 0); return nullptr; } @@ -2231,21 +2503,21 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { return New; } + annotateNonNullBasedOnAccess(CI, {0, 1}); return nullptr; } Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B) { - // Check for a fixed format string. 
- StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr)) - return nullptr; - // Check for size ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1)); if (!Size) return nullptr; uint64_t N = Size->getZExtValue(); + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr)) + return nullptr; // If we just have a format string (nothing else crazy) transform it. if (CI->getNumArgOperands() == 3) { @@ -2318,6 +2590,8 @@ Value *LibCallSimplifier::optimizeSnPrintF(CallInst *CI, IRBuilder<> &B) { return V; } + if (isKnownNonZero(CI->getOperand(1), DL)) + annotateNonNullBasedOnAccess(CI, 0); return nullptr; } @@ -2503,6 +2777,7 @@ Value *LibCallSimplifier::optimizeFRead(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { + annotateNonNullBasedOnAccess(CI, 0); if (!CI->use_empty()) return nullptr; @@ -2515,6 +2790,12 @@ Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { return nullptr; } +Value *LibCallSimplifier::optimizeBCopy(CallInst *CI, IRBuilder<> &B) { + // bcopy(src, dst, n) -> llvm.memmove(dst, src, n) + return B.CreateMemMove(CI->getArgOperand(1), 1, CI->getArgOperand(0), 1, + CI->getArgOperand(2)); +} + bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { LibFunc Func; SmallString<20> FloatFuncName = FuncName; @@ -2557,6 +2838,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeStrLen(CI, Builder); case LibFunc_strpbrk: return optimizeStrPBrk(CI, Builder); + case LibFunc_strndup: + return optimizeStrNDup(CI, Builder); case LibFunc_strtol: case LibFunc_strtod: case LibFunc_strtof: @@ -2573,12 +2856,16 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeStrStr(CI, Builder); case LibFunc_memchr: return optimizeMemChr(CI, Builder); + case LibFunc_memrchr: + return optimizeMemRChr(CI, Builder); case LibFunc_bcmp: return 
optimizeBCmp(CI, Builder); case LibFunc_memcmp: return optimizeMemCmp(CI, Builder); case LibFunc_memcpy: return optimizeMemCpy(CI, Builder); + case LibFunc_mempcpy: + return optimizeMemPCpy(CI, Builder); case LibFunc_memmove: return optimizeMemMove(CI, Builder); case LibFunc_memset: @@ -2587,6 +2874,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeRealloc(CI, Builder); case LibFunc_wcslen: return optimizeWcslen(CI, Builder); + case LibFunc_bcopy: + return optimizeBCopy(CI, Builder); default: break; } @@ -2626,11 +2915,21 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_sqrt: case LibFunc_sqrtl: return optimizeSqrt(CI, Builder); + case LibFunc_logf: case LibFunc_log: + case LibFunc_logl: + case LibFunc_log10f: case LibFunc_log10: + case LibFunc_log10l: + case LibFunc_log1pf: case LibFunc_log1p: + case LibFunc_log1pl: + case LibFunc_log2f: case LibFunc_log2: + case LibFunc_log2l: + case LibFunc_logbf: case LibFunc_logb: + case LibFunc_logbl: return optimizeLog(CI, Builder); case LibFunc_tan: case LibFunc_tanf: @@ -2721,10 +3020,18 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { case Intrinsic::exp2: return optimizeExp2(CI, Builder); case Intrinsic::log: + case Intrinsic::log2: + case Intrinsic::log10: return optimizeLog(CI, Builder); case Intrinsic::sqrt: return optimizeSqrt(CI, Builder); // TODO: Use foldMallocMemset() with memset intrinsic. + case Intrinsic::memset: + return optimizeMemSet(CI, Builder); + case Intrinsic::memcpy: + return optimizeMemCpy(CI, Builder); + case Intrinsic::memmove: + return optimizeMemMove(CI, Builder); default: return nullptr; } @@ -2740,8 +3047,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { IRBuilder<> TmpBuilder(SimplifiedCI); if (Value *V = optimizeStringMemoryLibCall(SimplifiedCI, TmpBuilder)) { // If we were able to further simplify, remove the now redundant call. 
- SimplifiedCI->replaceAllUsesWith(V); - eraseFromParent(SimplifiedCI); + substituteInParent(SimplifiedCI, V); return V; } } @@ -2898,7 +3204,9 @@ FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, uint64_t Len = GetStringLength(CI->getArgOperand(*StrOp)); // If the length is 0 we don't know how long it is and so we can't // remove the check. - if (Len == 0) + if (Len) + annotateDereferenceableBytes(CI, *StrOp, Len); + else return false; return ObjSizeCI->getZExtValue() >= Len; } @@ -2915,8 +3223,9 @@ FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) { if (isFortifiedCallFoldable(CI, 3, 2)) { - B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, - CI->getArgOperand(2)); + CallInst *NewCI = B.CreateMemCpy( + CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, CI->getArgOperand(2)); + NewCI->setAttributes(CI->getAttributes()); return CI->getArgOperand(0); } return nullptr; @@ -2925,8 +3234,9 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) { if (isFortifiedCallFoldable(CI, 3, 2)) { - B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, - CI->getArgOperand(2)); + CallInst *NewCI = B.CreateMemMove( + CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, CI->getArgOperand(2)); + NewCI->setAttributes(CI->getAttributes()); return CI->getArgOperand(0); } return nullptr; @@ -2938,7 +3248,9 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, if (isFortifiedCallFoldable(CI, 3, 2)) { Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + CallInst *NewCI = + B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + NewCI->setAttributes(CI->getAttributes()); return CI->getArgOperand(0); } return nullptr; @@ -2974,7 +3286,9 @@ 
Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, // Maybe we can stil fold __st[rp]cpy_chk to __memcpy_chk. uint64_t Len = GetStringLength(Src); - if (Len == 0) + if (Len) + annotateDereferenceableBytes(CI, 1, Len); + else return nullptr; Type *SizeTTy = DL.getIntPtrType(CI->getContext()); diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp index 456724779b43..5d380dcf231c 100644 --- a/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/lib/Transforms/Utils/SymbolRewriter.cpp @@ -380,11 +380,11 @@ parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, // TODO see if there is a more elegant solution to selecting the rewrite // descriptor type if (!Target.empty()) - DL->push_back(llvm::make_unique<ExplicitRewriteFunctionDescriptor>( + DL->push_back(std::make_unique<ExplicitRewriteFunctionDescriptor>( Source, Target, Naked)); else DL->push_back( - llvm::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform)); + std::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform)); return true; } @@ -442,11 +442,11 @@ parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } if (!Target.empty()) - DL->push_back(llvm::make_unique<ExplicitRewriteGlobalVariableDescriptor>( + DL->push_back(std::make_unique<ExplicitRewriteGlobalVariableDescriptor>( Source, Target, /*Naked*/ false)); else - DL->push_back(llvm::make_unique<PatternRewriteGlobalVariableDescriptor>( + DL->push_back(std::make_unique<PatternRewriteGlobalVariableDescriptor>( Source, Transform)); return true; @@ -505,11 +505,11 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } if (!Target.empty()) - DL->push_back(llvm::make_unique<ExplicitRewriteNamedAliasDescriptor>( + DL->push_back(std::make_unique<ExplicitRewriteNamedAliasDescriptor>( Source, Target, /*Naked*/ false)); else - DL->push_back(llvm::make_unique<PatternRewriteNamedAliasDescriptor>( + 
DL->push_back(std::make_unique<PatternRewriteNamedAliasDescriptor>( Source, Transform)); return true; diff --git a/lib/Transforms/Utils/VNCoercion.cpp b/lib/Transforms/Utils/VNCoercion.cpp index a77bf50fe10b..591e1fd2dbee 100644 --- a/lib/Transforms/Utils/VNCoercion.cpp +++ b/lib/Transforms/Utils/VNCoercion.cpp @@ -431,7 +431,7 @@ Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy, PtrVal = Builder.CreateBitCast(PtrVal, DestPTy); LoadInst *NewLoad = Builder.CreateLoad(DestTy, PtrVal); NewLoad->takeName(SrcVal); - NewLoad->setAlignment(SrcVal->getAlignment()); + NewLoad->setAlignment(MaybeAlign(SrcVal->getAlignment())); LLVM_DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n"); LLVM_DEBUG(dbgs() << "TO: " << *NewLoad << "\n"); diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index fbc3407c301f..da68d3713b40 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -27,8 +27,8 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalObject.h" +#include "llvm/IR/GlobalIndirectSymbol.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instruction.h" @@ -66,7 +66,7 @@ struct WorklistEntry { enum EntryKind { MapGlobalInit, MapAppendingVar, - MapGlobalAliasee, + MapGlobalIndirectSymbol, RemapFunction }; struct GVInitTy { @@ -77,9 +77,9 @@ struct WorklistEntry { GlobalVariable *GV; Constant *InitPrefix; }; - struct GlobalAliaseeTy { - GlobalAlias *GA; - Constant *Aliasee; + struct GlobalIndirectSymbolTy { + GlobalIndirectSymbol *GIS; + Constant *Target; }; unsigned Kind : 2; @@ -89,7 +89,7 @@ struct WorklistEntry { union { GVInitTy GVInit; AppendingGVTy AppendingGV; - GlobalAliaseeTy GlobalAliasee; + GlobalIndirectSymbolTy GlobalIndirectSymbol; Function *RemapF; } Data; }; @@ -161,8 +161,8 @@ public: bool IsOldCtorDtor, 
ArrayRef<Constant *> NewMembers, unsigned MCID); - void scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee, - unsigned MCID); + void scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, Constant &Target, + unsigned MCID); void scheduleRemapFunction(Function &F, unsigned MCID); void flush(); @@ -172,7 +172,7 @@ private: void mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, bool IsOldCtorDtor, ArrayRef<Constant *> NewMembers); - void mapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee); + void mapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, Constant &Target); void remapFunction(Function &F, ValueToValueMapTy &VM); ValueToValueMapTy &getVM() { return *MCs[CurrentMCID].VM; } @@ -774,20 +774,6 @@ Metadata *MDNodeMapper::mapTopLevelUniquedNode(const MDNode &FirstN) { return *getMappedOp(&FirstN); } -namespace { - -struct MapMetadataDisabler { - ValueToValueMapTy &VM; - - MapMetadataDisabler(ValueToValueMapTy &VM) : VM(VM) { - VM.disableMapMetadata(); - } - - ~MapMetadataDisabler() { VM.enableMapMetadata(); } -}; - -} // end anonymous namespace - Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) { // If the value already exists in the map, use it. if (Optional<Metadata *> NewMD = getVM().getMappedMD(MD)) @@ -802,9 +788,6 @@ Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) { return const_cast<Metadata *>(MD); if (auto *CMD = dyn_cast<ConstantAsMetadata>(MD)) { - // Disallow recursion into metadata mapping through mapValue. - MapMetadataDisabler MMD(getVM()); - // Don't memoize ConstantAsMetadata. Instead of lasting until the // LLVMContext is destroyed, they can be deleted when the GlobalValue they // reference is destructed. 
These aren't super common, so the extra @@ -846,9 +829,9 @@ void Mapper::flush() { AppendingInits.resize(PrefixSize); break; } - case WorklistEntry::MapGlobalAliasee: - E.Data.GlobalAliasee.GA->setAliasee( - mapConstant(E.Data.GlobalAliasee.Aliasee)); + case WorklistEntry::MapGlobalIndirectSymbol: + E.Data.GlobalIndirectSymbol.GIS->setIndirectSymbol( + mapConstant(E.Data.GlobalIndirectSymbol.Target)); break; case WorklistEntry::RemapFunction: remapFunction(*E.Data.RemapF); @@ -1041,16 +1024,16 @@ void Mapper::scheduleMapAppendingVariable(GlobalVariable &GV, AppendingInits.append(NewMembers.begin(), NewMembers.end()); } -void Mapper::scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee, - unsigned MCID) { - assert(AlreadyScheduled.insert(&GA).second && "Should not reschedule"); +void Mapper::scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, + Constant &Target, unsigned MCID) { + assert(AlreadyScheduled.insert(&GIS).second && "Should not reschedule"); assert(MCID < MCs.size() && "Invalid mapping context"); WorklistEntry WE; - WE.Kind = WorklistEntry::MapGlobalAliasee; + WE.Kind = WorklistEntry::MapGlobalIndirectSymbol; WE.MCID = MCID; - WE.Data.GlobalAliasee.GA = &GA; - WE.Data.GlobalAliasee.Aliasee = &Aliasee; + WE.Data.GlobalIndirectSymbol.GIS = &GIS; + WE.Data.GlobalIndirectSymbol.Target = &Target; Worklist.push_back(WE); } @@ -1147,9 +1130,10 @@ void ValueMapper::scheduleMapAppendingVariable(GlobalVariable &GV, GV, InitPrefix, IsOldCtorDtor, NewMembers, MCID); } -void ValueMapper::scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee, - unsigned MCID) { - getAsMapper(pImpl)->scheduleMapGlobalAliasee(GA, Aliasee, MCID); +void ValueMapper::scheduleMapGlobalIndirectSymbol(GlobalIndirectSymbol &GIS, + Constant &Target, + unsigned MCID) { + getAsMapper(pImpl)->scheduleMapGlobalIndirectSymbol(GIS, Target, MCID); } void ValueMapper::scheduleRemapFunction(Function &F, unsigned MCID) { diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp 
b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 4273080ddd91..f44976c723ec 100644 --- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -147,7 +147,7 @@ private: static const unsigned MaxDepth = 3; bool isConsecutiveAccess(Value *A, Value *B); - bool areConsecutivePointers(Value *PtrA, Value *PtrB, const APInt &PtrDelta, + bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta, unsigned Depth = 0) const; bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta, unsigned Depth) const; @@ -336,14 +336,29 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { } bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, - const APInt &PtrDelta, - unsigned Depth) const { + APInt PtrDelta, unsigned Depth) const { unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType()); APInt OffsetA(PtrBitWidth, 0); APInt OffsetB(PtrBitWidth, 0); PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); + unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType()); + + if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType())) + return false; + + // In case if we have to shrink the pointer + // stripAndAccumulateInBoundsConstantOffsets should properly handle a + // possible overflow and the value should fit into a smallest data type + // used in the cast/gep chain. + assert(OffsetA.getMinSignedBits() <= NewPtrBitWidth && + OffsetB.getMinSignedBits() <= NewPtrBitWidth); + + OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth); + OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth); + PtrDelta = PtrDelta.sextOrTrunc(NewPtrBitWidth); + APInt OffsetDelta = OffsetB - OffsetA; // Check if they are based on the same pointer. 
That makes the offsets @@ -650,7 +665,7 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { // We can ignore the alias if the we have a load store pair and the load // is known to be invariant. The load cannot be clobbered by the store. auto IsInvariantLoad = [](const LoadInst *LI) -> bool { - return LI->getMetadata(LLVMContext::MD_invariant_load); + return LI->hasMetadata(LLVMContext::MD_invariant_load); }; // We can ignore the alias as long as the load comes before the store, @@ -1077,7 +1092,7 @@ bool Vectorizer::vectorizeLoadChain( LoadInst *L0 = cast<LoadInst>(Chain[0]); // If the vector has an int element, default to int for the whole load. - Type *LoadTy; + Type *LoadTy = nullptr; for (const auto &V : Chain) { LoadTy = cast<LoadInst>(V)->getType(); if (LoadTy->isIntOrIntVectorTy()) @@ -1089,6 +1104,7 @@ bool Vectorizer::vectorizeLoadChain( break; } } + assert(LoadTy && "Can't determine LoadInst type from chain"); unsigned Sz = DL.getTypeSizeInBits(LoadTy); unsigned AS = L0->getPointerAddressSpace(); diff --git a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 6ef8dc2d3cd7..f43842be5357 100644 --- a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -13,7 +13,10 @@ // pass. It should be easy to create an analysis pass around it if there // is a need (but D45420 needs to happen first). 
// +#include "llvm/Transforms/Vectorize/LoopVectorize.h" #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IntrinsicInst.h" @@ -47,38 +50,6 @@ static const unsigned MaxInterleaveFactor = 16; namespace llvm { -#ifndef NDEBUG -static void debugVectorizationFailure(const StringRef DebugMsg, - Instruction *I) { - dbgs() << "LV: Not vectorizing: " << DebugMsg; - if (I != nullptr) - dbgs() << " " << *I; - else - dbgs() << '.'; - dbgs() << '\n'; -} -#endif - -OptimizationRemarkAnalysis createLVMissedAnalysis(const char *PassName, - StringRef RemarkName, - Loop *TheLoop, - Instruction *I) { - Value *CodeRegion = TheLoop->getHeader(); - DebugLoc DL = TheLoop->getStartLoc(); - - if (I) { - CodeRegion = I->getParent(); - // If there is no debug location attached to the instruction, revert back to - // using the loop's. - if (I->getDebugLoc()) - DL = I->getDebugLoc(); - } - - OptimizationRemarkAnalysis R(PassName, RemarkName, DL, CodeRegion); - R << "loop not vectorized: "; - return R; -} - bool LoopVectorizeHints::Hint::validate(unsigned Val) { switch (Kind) { case HK_WIDTH: @@ -88,6 +59,7 @@ bool LoopVectorizeHints::Hint::validate(unsigned Val) { case HK_FORCE: return (Val <= 1); case HK_ISVECTORIZED: + case HK_PREDICATE: return (Val == 0 || Val == 1); } return false; @@ -99,7 +71,9 @@ LoopVectorizeHints::LoopVectorizeHints(const Loop *L, : Width("vectorize.width", VectorizerParams::VectorizationFactor, HK_WIDTH), Interleave("interleave.count", InterleaveOnlyWhenForced, HK_UNROLL), Force("vectorize.enable", FK_Undefined, HK_FORCE), - IsVectorized("isvectorized", 0, HK_ISVECTORIZED), TheLoop(L), ORE(ORE) { + IsVectorized("isvectorized", 0, HK_ISVECTORIZED), + Predicate("vectorize.predicate.enable", 0, HK_PREDICATE), TheLoop(L), + ORE(ORE) { // Populate values with existing loop metadata. 
getHintsFromMetadata(); @@ -250,7 +224,7 @@ void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) { return; unsigned Val = C->getZExtValue(); - Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized}; + Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized, &Predicate}; for (auto H : Hints) { if (Name == H->Name) { if (H->validate(Val)) @@ -435,7 +409,8 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { const ValueToValueMap &Strides = getSymbolicStrides() ? *getSymbolicStrides() : ValueToValueMap(); - int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, true, false); + bool CanAddPredicate = !TheLoop->getHeader()->getParent()->hasOptSize(); + int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, CanAddPredicate, false); if (Stride == 1 || Stride == -1) return Stride; return 0; @@ -445,14 +420,6 @@ bool LoopVectorizationLegality::isUniform(Value *V) { return LAI->isUniform(V); } -void LoopVectorizationLegality::reportVectorizationFailure( - const StringRef DebugMsg, const StringRef OREMsg, - const StringRef ORETag, Instruction *I) const { - LLVM_DEBUG(debugVectorizationFailure(DebugMsg, I)); - ORE->emit(createLVMissedAnalysis(Hints->vectorizeAnalysisPassName(), - ORETag, TheLoop, I) << OREMsg); -} - bool LoopVectorizationLegality::canVectorizeOuterLoop() { assert(!TheLoop->empty() && "We are not vectorizing an outer loop."); // Store the result and return it at the end instead of exiting early, in case @@ -467,7 +434,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { if (!Br) { reportVectorizationFailure("Unsupported basic block terminator", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -486,7 +453,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { !LI->isLoopHeader(Br->getSuccessor(1))) { reportVectorizationFailure("Unsupported conditional branch", "loop control flow is not understood by 
vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -500,7 +467,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { TheLoop /*context outer loop*/)) { reportVectorizationFailure("Outer loop contains divergent loops", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -511,7 +478,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { if (!setupOuterLoopInductions()) { reportVectorizationFailure("Unsupported outer loop Phi(s)", "Unsupported outer loop Phi(s)", - "UnsupportedPhi"); + "UnsupportedPhi", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -618,7 +585,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { !PhiTy->isPointerTy()) { reportVectorizationFailure("Found a non-int non-pointer PHI", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); return false; } @@ -631,6 +598,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Unsafe cyclic dependencies with header phis are identified during // legalization for reduction, induction and first order // recurrences. 
+ AllowedExit.insert(&I); continue; } @@ -638,7 +606,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (Phi->getNumIncomingValues() != 2) { reportVectorizationFailure("Found an invalid PHI", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood", Phi); + "CFGNotUnderstood", ORE, TheLoop, Phi); return false; } @@ -690,7 +658,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { reportVectorizationFailure("Found an unidentified PHI", "value that could not be identified as " "reduction is used outside the loop", - "NonReductionValueUsedOutsideLoop", Phi); + "NonReductionValueUsedOutsideLoop", ORE, TheLoop, Phi); return false; } // end of PHI handling @@ -721,11 +689,11 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { "library call cannot be vectorized. " "Try compiling with -fno-math-errno, -ffast-math, " "or similar flags", - "CantVectorizeLibcall", CI); + "CantVectorizeLibcall", ORE, TheLoop, CI); } else { reportVectorizationFailure("Found a non-intrinsic callsite", "call instruction cannot be vectorized", - "CantVectorizeLibcall", CI); + "CantVectorizeLibcall", ORE, TheLoop, CI); } return false; } @@ -740,7 +708,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(i)), TheLoop)) { reportVectorizationFailure("Found unvectorizable intrinsic", "intrinsic instruction cannot be vectorized", - "CantVectorizeIntrinsic", CI); + "CantVectorizeIntrinsic", ORE, TheLoop, CI); return false; } } @@ -753,7 +721,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { isa<ExtractElementInst>(I)) { reportVectorizationFailure("Found unvectorizable type", "instruction return type cannot be vectorized", - "CantVectorizeInstructionReturnType", &I); + "CantVectorizeInstructionReturnType", ORE, TheLoop, &I); return false; } @@ -763,7 +731,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (!VectorType::isValidElementType(T)) { reportVectorizationFailure("Store 
instruction cannot be vectorized", "store instruction cannot be vectorized", - "CantVectorizeStore", ST); + "CantVectorizeStore", ORE, TheLoop, ST); return false; } @@ -773,12 +741,13 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Arbitrarily try a vector of 2 elements. Type *VecTy = VectorType::get(T, /*NumElements=*/2); assert(VecTy && "did not find vectorized version of stored type"); - unsigned Alignment = getLoadStoreAlignment(ST); - if (!TTI->isLegalNTStore(VecTy, Alignment)) { + const MaybeAlign Alignment = getLoadStoreAlignment(ST); + assert(Alignment && "Alignment should be set"); + if (!TTI->isLegalNTStore(VecTy, *Alignment)) { reportVectorizationFailure( "nontemporal store instruction cannot be vectorized", "nontemporal store instruction cannot be vectorized", - "CantVectorizeNontemporalStore", ST); + "CantVectorizeNontemporalStore", ORE, TheLoop, ST); return false; } } @@ -789,12 +758,13 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // supported on the target (arbitrarily try a vector of 2 elements). Type *VecTy = VectorType::get(I.getType(), /*NumElements=*/2); assert(VecTy && "did not find vectorized version of load type"); - unsigned Alignment = getLoadStoreAlignment(LD); - if (!TTI->isLegalNTLoad(VecTy, Alignment)) { + const MaybeAlign Alignment = getLoadStoreAlignment(LD); + assert(Alignment && "Alignment should be set"); + if (!TTI->isLegalNTLoad(VecTy, *Alignment)) { reportVectorizationFailure( "nontemporal load instruction cannot be vectorized", "nontemporal load instruction cannot be vectorized", - "CantVectorizeNontemporalLoad", LD); + "CantVectorizeNontemporalLoad", ORE, TheLoop, LD); return false; } } @@ -823,7 +793,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } reportVectorizationFailure("Value cannot be used outside the loop", "value cannot be used outside the loop", - "ValueUsedOutsideLoop", &I); + "ValueUsedOutsideLoop", ORE, TheLoop, &I); return false; } } // next instr. 
@@ -833,12 +803,12 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (Inductions.empty()) { reportVectorizationFailure("Did not find one integer induction var", "loop induction variable could not be identified", - "NoInductionVariable"); + "NoInductionVariable", ORE, TheLoop); return false; } else if (!WidestIndTy) { reportVectorizationFailure("Did not find one integer induction var", "integer loop induction variable could not be identified", - "NoIntegerInductionVariable"); + "NoIntegerInductionVariable", ORE, TheLoop); return false; } else { LLVM_DEBUG(dbgs() << "LV: Did not find one integer induction var.\n"); @@ -869,7 +839,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() { if (LAI->hasDependenceInvolvingLoopInvariantAddress()) { reportVectorizationFailure("Stores to a uniform address", "write to a loop invariant address could not be vectorized", - "CantVectorizeStoreToLoopInvariantAddress"); + "CantVectorizeStoreToLoopInvariantAddress", ORE, TheLoop); return false; } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); @@ -905,7 +875,7 @@ bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) { } bool LoopVectorizationLegality::blockCanBePredicated( - BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs) { + BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs, bool PreserveGuards) { const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel(); for (Instruction &I : *BB) { @@ -924,7 +894,7 @@ bool LoopVectorizationLegality::blockCanBePredicated( // !llvm.mem.parallel_loop_access implies if-conversion safety. // Otherwise, record that the load needs (real or emulated) masking // and let the cost model decide. 
- if (!IsAnnotatedParallel) + if (!IsAnnotatedParallel || PreserveGuards) MaskedOp.insert(LI); continue; } @@ -953,23 +923,41 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { if (!EnableIfConversion) { reportVectorizationFailure("If-conversion is disabled", "if-conversion is disabled", - "IfConversionDisabled"); + "IfConversionDisabled", + ORE, TheLoop); return false; } assert(TheLoop->getNumBlocks() > 1 && "Single block loops are vectorizable"); - // A list of pointers that we can safely read and write to. + // A list of pointers which are known to be dereferenceable within scope of + // the loop body for each iteration of the loop which executes. That is, + // the memory pointed to can be dereferenced (with the access size implied by + // the value's type) unconditionally within the loop header without + // introducing a new fault. SmallPtrSet<Value *, 8> SafePointes; // Collect safe addresses. for (BasicBlock *BB : TheLoop->blocks()) { - if (blockNeedsPredication(BB)) + if (!blockNeedsPredication(BB)) { + for (Instruction &I : *BB) + if (auto *Ptr = getLoadStorePointerOperand(&I)) + SafePointes.insert(Ptr); continue; + } - for (Instruction &I : *BB) - if (auto *Ptr = getLoadStorePointerOperand(&I)) - SafePointes.insert(Ptr); + // For a block which requires predication, a address may be safe to access + // in the loop w/o predication if we can prove dereferenceability facts + // sufficient to ensure it'll never fault within the loop. For the moment, + // we restrict this to loads; stores are more complicated due to + // concurrency restrictions. + ScalarEvolution &SE = *PSE.getSE(); + for (Instruction &I : *BB) { + LoadInst *LI = dyn_cast<LoadInst>(&I); + if (LI && !mustSuppressSpeculation(*LI) && + isDereferenceableAndAlignedInLoop(LI, TheLoop, SE, *DT)) + SafePointes.insert(LI->getPointerOperand()); + } } // Collect the blocks that need predication. 
@@ -979,7 +967,8 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { if (!isa<BranchInst>(BB->getTerminator())) { reportVectorizationFailure("Loop contains a switch statement", "loop contains a switch statement", - "LoopContainsSwitch", BB->getTerminator()); + "LoopContainsSwitch", ORE, TheLoop, + BB->getTerminator()); return false; } @@ -989,14 +978,16 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { reportVectorizationFailure( "Control flow cannot be substituted for a select", "control flow cannot be substituted for a select", - "NoCFGForSelect", BB->getTerminator()); + "NoCFGForSelect", ORE, TheLoop, + BB->getTerminator()); return false; } } else if (BB != Header && !canIfConvertPHINodes(BB)) { reportVectorizationFailure( "Control flow cannot be substituted for a select", "control flow cannot be substituted for a select", - "NoCFGForSelect", BB->getTerminator()); + "NoCFGForSelect", ORE, TheLoop, + BB->getTerminator()); return false; } } @@ -1026,7 +1017,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, if (!Lp->getLoopPreheader()) { reportVectorizationFailure("Loop doesn't have a legal pre-header", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -1037,7 +1028,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, if (Lp->getNumBackEdges() != 1) { reportVectorizationFailure("The loop must have a single backedge", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -1048,7 +1039,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, if (!Lp->getExitingBlock()) { reportVectorizationFailure("The loop must have an exiting block", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ 
-1061,7 +1052,7 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, if (Lp->getExitingBlock() != Lp->getLoopLatch()) { reportVectorizationFailure("The exiting block is not the loop latch", "loop control flow is not understood by vectorizer", - "CFGNotUnderstood"); + "CFGNotUnderstood", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -1124,7 +1115,8 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { if (!canVectorizeOuterLoop()) { reportVectorizationFailure("Unsupported outer loop", "unsupported outer loop", - "UnsupportedOuterLoop"); + "UnsupportedOuterLoop", + ORE, TheLoop); // TODO: Implement DoExtraAnalysis when subsequent legal checks support // outer loops. return false; @@ -1176,7 +1168,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { reportVectorizationFailure("Too many SCEV checks needed", "Too many SCEV assumptions need to be made and checked at runtime", - "TooManySCEVRunTimeChecks"); + "TooManySCEVRunTimeChecks", ORE, TheLoop); if (DoExtraAnalysis) Result = false; else @@ -1190,7 +1182,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { return Result; } -bool LoopVectorizationLegality::canFoldTailByMasking() { +bool LoopVectorizationLegality::prepareToFoldTailByMasking() { LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n"); @@ -1199,22 +1191,21 @@ bool LoopVectorizationLegality::canFoldTailByMasking() { "No primary induction, cannot fold tail by masking", "Missing a primary induction variable in the loop, which is " "needed in order to fold tail by masking as required.", - "NoPrimaryInduction"); + "NoPrimaryInduction", ORE, TheLoop); return false; } - // TODO: handle reductions when tail is folded by masking. 
- if (!Reductions.empty()) { - reportVectorizationFailure( - "Loop has reductions, cannot fold tail by masking", - "Cannot fold tail by masking in the presence of reductions.", - "ReductionFoldingTailByMasking"); - return false; - } + SmallPtrSet<const Value *, 8> ReductionLiveOuts; - // TODO: handle outside users when tail is folded by masking. + for (auto &Reduction : *getReductionVars()) + ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr()); + + // TODO: handle non-reduction outside users when tail is folded by masking. for (auto *AE : AllowedExit) { - // Check that all users of allowed exit values are inside the loop. + // Check that all users of allowed exit values are inside the loop or + // are the live-out of a reduction. + if (ReductionLiveOuts.count(AE)) + continue; for (User *U : AE->users()) { Instruction *UI = cast<Instruction>(U); if (TheLoop->contains(UI)) @@ -1222,7 +1213,7 @@ bool LoopVectorizationLegality::canFoldTailByMasking() { reportVectorizationFailure( "Cannot fold tail by masking, loop has an outside user for", "Cannot fold tail by masking in the presence of live outs.", - "LiveOutFoldingTailByMasking", UI); + "LiveOutFoldingTailByMasking", ORE, TheLoop, UI); return false; } } @@ -1233,11 +1224,12 @@ bool LoopVectorizationLegality::canFoldTailByMasking() { // Check and mark all blocks for predication, including those that ordinarily // do not need predication such as the header block. 
for (BasicBlock *BB : TheLoop->blocks()) { - if (!blockCanBePredicated(BB, SafePointers)) { + if (!blockCanBePredicated(BB, SafePointers, /* MaskAllLoads= */ true)) { reportVectorizationFailure( "Cannot fold tail by masking as required", "control flow cannot be substituted for a select", - "NoCFGForSelect", BB->getTerminator()); + "NoCFGForSelect", ORE, TheLoop, + BB->getTerminator()); return false; } } diff --git a/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 97077cce83e3..a5e85f27fabf 100644 --- a/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -228,11 +228,11 @@ public: /// Plan how to best vectorize, return the best VF and its cost, or None if /// vectorization and interleaving should be avoided up front. - Optional<VectorizationFactor> plan(bool OptForSize, unsigned UserVF); + Optional<VectorizationFactor> plan(unsigned UserVF); /// Use the VPlan-native path to plan how to best vectorize, return the best /// VF and its cost. - VectorizationFactor planInVPlanNativePath(bool OptForSize, unsigned UserVF); + VectorizationFactor planInVPlanNativePath(unsigned UserVF); /// Finalize the best decision and dispose of all other VPlans. void setBestPlan(unsigned VF, unsigned UF); diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 46265e3f3e13..8f0bf70f873c 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -177,6 +177,14 @@ static cl::opt<unsigned> TinyTripCountVectorThreshold( "value are vectorized only if no scalar iteration overheads " "are incurred.")); +// Indicates that an epilogue is undesired, predication is preferred. +// This means that the vectorizer will try to fold the loop-tail (epilogue) +// into the loop and predicate the loop body accordingly. 
+static cl::opt<bool> PreferPredicateOverEpilog( + "prefer-predicate-over-epilog", cl::init(false), cl::Hidden, + cl::desc("Indicate that an epilogue is undesired, predication should be " + "used instead.")); + static cl::opt<bool> MaximizeBandwidth( "vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, cl::desc("Maximize bandwidth when selecting vectorization factor which " @@ -347,6 +355,29 @@ static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) { : ConstantFP::get(Ty, C); } +/// Returns "best known" trip count for the specified loop \p L as defined by +/// the following procedure: +/// 1) Returns exact trip count if it is known. +/// 2) Returns expected trip count according to profile data if any. +/// 3) Returns upper bound estimate if it is known. +/// 4) Returns None if all of the above failed. +static Optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE, Loop *L) { + // Check if exact trip count is known. + if (unsigned ExpectedTC = SE.getSmallConstantTripCount(L)) + return ExpectedTC; + + // Check if there is an expected trip count available from profile data. + if (LoopVectorizeWithBlockFrequency) + if (auto EstimatedTC = getLoopEstimatedTripCount(L)) + return EstimatedTC; + + // Check if upper bound estimate is known. + if (unsigned ExpectedTC = SE.getSmallConstantMaxTripCount(L)) + return ExpectedTC; + + return None; +} + namespace llvm { /// InnerLoopVectorizer vectorizes loops which contain only one basic @@ -795,6 +826,59 @@ void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) B.SetCurrentDebugLocation(DebugLoc()); } +/// Write a record \p DebugMsg about vectorization failure to the debug +/// output stream. If \p I is passed, it is an instruction that prevents +/// vectorization. 
+#ifndef NDEBUG +static void debugVectorizationFailure(const StringRef DebugMsg, + Instruction *I) { + dbgs() << "LV: Not vectorizing: " << DebugMsg; + if (I != nullptr) + dbgs() << " " << *I; + else + dbgs() << '.'; + dbgs() << '\n'; +} +#endif + +/// Create an analysis remark that explains why vectorization failed +/// +/// \p PassName is the name of the pass (e.g. can be AlwaysPrint). \p +/// RemarkName is the identifier for the remark. If \p I is passed it is an +/// instruction that prevents vectorization. Otherwise \p TheLoop is used for +/// the location of the remark. \return the remark object that can be +/// streamed to. +static OptimizationRemarkAnalysis createLVAnalysis(const char *PassName, + StringRef RemarkName, Loop *TheLoop, Instruction *I) { + Value *CodeRegion = TheLoop->getHeader(); + DebugLoc DL = TheLoop->getStartLoc(); + + if (I) { + CodeRegion = I->getParent(); + // If there is no debug location attached to the instruction, revert back to + // using the loop's. + if (I->getDebugLoc()) + DL = I->getDebugLoc(); + } + + OptimizationRemarkAnalysis R(PassName, RemarkName, DL, CodeRegion); + R << "loop not vectorized: "; + return R; +} + +namespace llvm { + +void reportVectorizationFailure(const StringRef DebugMsg, + const StringRef OREMsg, const StringRef ORETag, + OptimizationRemarkEmitter *ORE, Loop *TheLoop, Instruction *I) { + LLVM_DEBUG(debugVectorizationFailure(DebugMsg, I)); + LoopVectorizeHints Hints(TheLoop, true /* doesn't matter */, *ORE); + ORE->emit(createLVAnalysis(Hints.vectorizeAnalysisPassName(), + ORETag, TheLoop, I) << OREMsg); +} + +} // end namespace llvm + #ifndef NDEBUG /// \return string containing a file name and a line # for the given loop. static std::string getDebugLocString(const Loop *L) { @@ -836,6 +920,26 @@ void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To, namespace llvm { +// Loop vectorization cost-model hints how the scalar epilogue loop should be +// lowered. 
+enum ScalarEpilogueLowering { + + // The default: allowing scalar epilogues. + CM_ScalarEpilogueAllowed, + + // Vectorization with OptForSize: don't allow epilogues. + CM_ScalarEpilogueNotAllowedOptSize, + + // A special case of vectorisation with OptForSize: loops with a very small + // trip count are considered for vectorization under OptForSize, thereby + // making sure the cost of their loop body is dominant, free of runtime + // guards and scalar iteration overheads. + CM_ScalarEpilogueNotAllowedLowTripLoop, + + // Loop hint predicate indicating an epilogue is undesired. + CM_ScalarEpilogueNotNeededUsePredicate +}; + /// LoopVectorizationCostModel - estimates the expected speedups due to /// vectorization. /// In many cases vectorization is not profitable. This can happen because of @@ -845,20 +949,26 @@ namespace llvm { /// different operations. class LoopVectorizationCostModel { public: - LoopVectorizationCostModel(Loop *L, PredicatedScalarEvolution &PSE, - LoopInfo *LI, LoopVectorizationLegality *Legal, + LoopVectorizationCostModel(ScalarEpilogueLowering SEL, Loop *L, + PredicatedScalarEvolution &PSE, LoopInfo *LI, + LoopVectorizationLegality *Legal, const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI, DemandedBits *DB, AssumptionCache *AC, OptimizationRemarkEmitter *ORE, const Function *F, const LoopVectorizeHints *Hints, InterleavedAccessInfo &IAI) - : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), - AC(AC), ORE(ORE), TheFunction(F), Hints(Hints), InterleaveInfo(IAI) {} + : ScalarEpilogueStatus(SEL), TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), + TTI(TTI), TLI(TLI), DB(DB), AC(AC), ORE(ORE), TheFunction(F), + Hints(Hints), InterleaveInfo(IAI) {} /// \return An upper bound for the vectorization factor, or None if /// vectorization and interleaving should be avoided up front. 
- Optional<unsigned> computeMaxVF(bool OptForSize); + Optional<unsigned> computeMaxVF(); + + /// \return True if runtime checks are required for vectorization, and false + /// otherwise. + bool runtimeChecksRequired(); /// \return The most profitable vectorization factor and the cost of that VF. /// This method checks every power of two up to MaxVF. If UserVF is not ZERO @@ -881,8 +991,7 @@ public: /// If interleave count has been specified by metadata it will be returned. /// Otherwise, the interleave count is computed and returned. VF and LoopCost /// are the selected vectorization factor and the cost of the selected VF. - unsigned selectInterleaveCount(bool OptForSize, unsigned VF, - unsigned LoopCost); + unsigned selectInterleaveCount(unsigned VF, unsigned LoopCost); /// Memory access instruction may be vectorized in more than one way. /// Form of instruction after vectorization depends on cost. @@ -897,10 +1006,11 @@ public: /// of a loop. struct RegisterUsage { /// Holds the number of loop invariant values that are used in the loop. - unsigned LoopInvariantRegs; - + /// The key is ClassID of target-provided register class. + SmallMapVector<unsigned, unsigned, 4> LoopInvariantRegs; /// Holds the maximum number of concurrent live intervals in the loop. - unsigned MaxLocalUsers; + /// The key is ClassID of target-provided register class. + SmallMapVector<unsigned, unsigned, 4> MaxLocalUsers; }; /// \return Returns information about the register usages of the loop for the @@ -1080,14 +1190,16 @@ public: /// Returns true if the target machine supports masked store operation /// for the given \p DataType and kind of access to \p Ptr. 
- bool isLegalMaskedStore(Type *DataType, Value *Ptr) { - return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedStore(DataType); + bool isLegalMaskedStore(Type *DataType, Value *Ptr, MaybeAlign Alignment) { + return Legal->isConsecutivePtr(Ptr) && + TTI.isLegalMaskedStore(DataType, Alignment); } /// Returns true if the target machine supports masked load operation /// for the given \p DataType and kind of access to \p Ptr. - bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { - return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedLoad(DataType); + bool isLegalMaskedLoad(Type *DataType, Value *Ptr, MaybeAlign Alignment) { + return Legal->isConsecutivePtr(Ptr) && + TTI.isLegalMaskedLoad(DataType, Alignment); } /// Returns true if the target machine supports masked scatter operation @@ -1157,11 +1269,14 @@ public: /// to handle accesses with gaps, and there is nothing preventing us from /// creating a scalar epilogue. bool requiresScalarEpilogue() const { - return IsScalarEpilogueAllowed && InterleaveInfo.requiresScalarEpilogue(); + return isScalarEpilogueAllowed() && InterleaveInfo.requiresScalarEpilogue(); } - /// Returns true if a scalar epilogue is not allowed due to optsize. - bool isScalarEpilogueAllowed() const { return IsScalarEpilogueAllowed; } + /// Returns true if a scalar epilogue is not allowed due to optsize or a + /// loop hint annotation. + bool isScalarEpilogueAllowed() const { + return ScalarEpilogueStatus == CM_ScalarEpilogueAllowed; + } /// Returns true if all loop blocks should be masked to fold tail loop. bool foldTailByMasking() const { return FoldTailByMasking; } @@ -1187,7 +1302,7 @@ private: /// \return An upper bound for the vectorization factor, larger than zero. /// One is returned if vectorization should best be avoided due to cost. 
- unsigned computeFeasibleMaxVF(bool OptForSize, unsigned ConstTripCount); + unsigned computeFeasibleMaxVF(unsigned ConstTripCount); /// The vectorization cost is a combination of the cost itself and a boolean /// indicating whether any of the contributing operations will actually @@ -1246,15 +1361,6 @@ private: /// should be used. bool useEmulatedMaskMemRefHack(Instruction *I); - /// Create an analysis remark that explains why vectorization failed - /// - /// \p RemarkName is the identifier for the remark. \return the remark object - /// that can be streamed to. - OptimizationRemarkAnalysis createMissedAnalysis(StringRef RemarkName) { - return createLVMissedAnalysis(Hints->vectorizeAnalysisPassName(), - RemarkName, TheLoop); - } - /// Map of scalar integer values to the smallest bitwidth they can be legally /// represented as. The vector equivalents of these values should be truncated /// to this type. @@ -1270,13 +1376,13 @@ private: SmallPtrSet<BasicBlock *, 4> PredicatedBBsAfterVectorization; /// Records whether it is allowed to have the original scalar loop execute at - /// least once. This may be needed as a fallback loop in case runtime + /// least once. This may be needed as a fallback loop in case runtime /// aliasing/dependence checks fail, or to handle the tail/remainder /// iterations when the trip count is unknown or doesn't divide by the VF, /// or as a peel-loop to handle gaps in interleave-groups. /// Under optsize and when the trip count is very small we don't allow any /// iterations to execute in the scalar loop. - bool IsScalarEpilogueAllowed = true; + ScalarEpilogueLowering ScalarEpilogueStatus = CM_ScalarEpilogueAllowed; /// All blocks of loop are to be masked to fold tail of scalar iterations. 
bool FoldTailByMasking = false; @@ -1496,7 +1602,7 @@ struct LoopVectorize : public FunctionPass { auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - auto *TLI = TLIP ? &TLIP->getTLI() : nullptr; + auto *TLI = TLIP ? &TLIP->getTLI(F) : nullptr; auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); @@ -2253,12 +2359,11 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, Type *ScalarDataTy = getMemInstValueType(Instr); Type *DataTy = VectorType::get(ScalarDataTy, VF); Value *Ptr = getLoadStorePointerOperand(Instr); - unsigned Alignment = getLoadStoreAlignment(Instr); // An alignment of 0 means target abi alignment. We need to use the scalar's // target abi alignment in such a case. const DataLayout &DL = Instr->getModule()->getDataLayout(); - if (!Alignment) - Alignment = DL.getABITypeAlignment(ScalarDataTy); + const Align Alignment = + DL.getValueOrABITypeAlignment(getLoadStoreAlignment(Instr), ScalarDataTy); unsigned AddressSpace = getLoadStoreAddressSpace(Instr); // Determine if the pointer operand of the access is either consecutive or @@ -2322,8 +2427,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, if (CreateGatherScatter) { Value *MaskPart = isMaskRequired ? 
Mask[Part] : nullptr; Value *VectorGep = getOrCreateVectorValue(Ptr, Part); - NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, - MaskPart); + NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, + Alignment.value(), MaskPart); } else { if (Reverse) { // If we store to reverse consecutive memory locations, then we need @@ -2334,10 +2439,11 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, } auto *VecPtr = CreateVecPtr(Part, Ptr); if (isMaskRequired) - NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment, - Mask[Part]); + NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, + Alignment.value(), Mask[Part]); else - NewSI = Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment); + NewSI = + Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment.value()); } addMetadata(NewSI, SI); } @@ -2352,18 +2458,18 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, if (CreateGatherScatter) { Value *MaskPart = isMaskRequired ? Mask[Part] : nullptr; Value *VectorGep = getOrCreateVectorValue(Ptr, Part); - NewLI = Builder.CreateMaskedGather(VectorGep, Alignment, MaskPart, + NewLI = Builder.CreateMaskedGather(VectorGep, Alignment.value(), MaskPart, nullptr, "wide.masked.gather"); addMetadata(NewLI, LI); } else { auto *VecPtr = CreateVecPtr(Part, Ptr); if (isMaskRequired) - NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part], + NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment.value(), Mask[Part], UndefValue::get(DataTy), "wide.masked.load"); else - NewLI = - Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment, "wide.load"); + NewLI = Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment.value(), + "wide.load"); // Add metadata to the load, but setVectorValue to the reverse shuffle. 
addMetadata(NewLI, LI); @@ -2615,8 +2721,9 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { if (C->isZero()) return; - assert(!Cost->foldTailByMasking() && - "Cannot SCEV check stride or overflow when folding tail"); + assert(!BB->getParent()->hasOptSize() && + "Cannot SCEV check stride or overflow when optimizing for size"); + // Create a new block containing the stride check. BB->setName("vector.scevcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); @@ -2649,7 +2756,20 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { if (!MemRuntimeCheck) return; - assert(!Cost->foldTailByMasking() && "Cannot check memory when folding tail"); + if (BB->getParent()->hasOptSize()) { + assert(Cost->Hints->getForce() == LoopVectorizeHints::FK_Enabled && + "Cannot emit memory checks when optimizing for size, unless forced " + "to vectorize."); + ORE->emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "VectorizationCodeSize", + L->getStartLoc(), L->getHeader()) + << "Code-size may be reduced by not forcing " + "vectorization, or by source-code modifications " + "eliminating the need for runtime checks " + "(e.g., adding 'restrict')."; + }); + } + // Create a new block containing the memory check. BB->setName("vector.memcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); @@ -2666,7 +2786,7 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { // We currently don't use LoopVersioning for the actual loop cloning but we // still use it to add the noalias metadata. 
- LVer = llvm::make_unique<LoopVersioning>(*Legal->getLAI(), OrigLoop, LI, DT, + LVer = std::make_unique<LoopVersioning>(*Legal->getLAI(), OrigLoop, LI, DT, PSE.getSE()); LVer->prepareNoAliasMetadata(); } @@ -3598,6 +3718,26 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { setDebugLocFromInst(Builder, LoopExitInst); + // If tail is folded by masking, the vector value to leave the loop should be + // a Select choosing between the vectorized LoopExitInst and vectorized Phi, + // instead of the former. + if (Cost->foldTailByMasking()) { + for (unsigned Part = 0; Part < UF; ++Part) { + Value *VecLoopExitInst = + VectorLoopValueMap.getVectorValue(LoopExitInst, Part); + Value *Sel = nullptr; + for (User *U : VecLoopExitInst->users()) { + if (isa<SelectInst>(U)) { + assert(!Sel && "Reduction exit feeding two selects"); + Sel = U; + } else + assert(isa<PHINode>(U) && "Reduction exit must feed Phi's or select"); + } + assert(Sel && "Reduction exit feeds no select"); + VectorLoopValueMap.resetVectorValue(LoopExitInst, Part, Sel); + } + } + // If the vector reduction can be performed in a smaller type, we truncate // then extend the loop exit value to enable InstCombine to evaluate the // entire expression in the smaller type. @@ -4064,7 +4204,7 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) { case Instruction::FCmp: { // Widen compares. Generate vector compares. bool FCmp = (I.getOpcode() == Instruction::FCmp); - auto *Cmp = dyn_cast<CmpInst>(&I); + auto *Cmp = cast<CmpInst>(&I); setDebugLocFromInst(Builder, Cmp); for (unsigned Part = 0; Part < UF; ++Part) { Value *A = getOrCreateVectorValue(Cmp->getOperand(0), Part); @@ -4097,7 +4237,7 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) { case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - auto *CI = dyn_cast<CastInst>(&I); + auto *CI = cast<CastInst>(&I); setDebugLocFromInst(Builder, CI); /// Vectorize casts. 
@@ -4421,9 +4561,10 @@ bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I, unsigne "Widening decision should be ready at this moment"); return WideningDecision == CM_Scalarize; } + const MaybeAlign Alignment = getLoadStoreAlignment(I); return isa<LoadInst>(I) ? - !(isLegalMaskedLoad(Ty, Ptr) || isLegalMaskedGather(Ty)) - : !(isLegalMaskedStore(Ty, Ptr) || isLegalMaskedScatter(Ty)); + !(isLegalMaskedLoad(Ty, Ptr, Alignment) || isLegalMaskedGather(Ty)) + : !(isLegalMaskedStore(Ty, Ptr, Alignment) || isLegalMaskedScatter(Ty)); } case Instruction::UDiv: case Instruction::SDiv: @@ -4452,10 +4593,10 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(Instruction *I, // Check if masking is required. // A Group may need masking for one of two reasons: it resides in a block that // needs predication, or it was decided to use masking to deal with gaps. - bool PredicatedAccessRequiresMasking = + bool PredicatedAccessRequiresMasking = Legal->blockNeedsPredication(I->getParent()) && Legal->isMaskRequired(I); - bool AccessWithGapsRequiresMasking = - Group->requiresScalarEpilogue() && !IsScalarEpilogueAllowed; + bool AccessWithGapsRequiresMasking = + Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed(); if (!PredicatedAccessRequiresMasking && !AccessWithGapsRequiresMasking) return true; @@ -4466,8 +4607,9 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(Instruction *I, "Masked interleave-groups for predicated accesses are not enabled."); auto *Ty = getMemInstValueType(I); - return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty) - : TTI.isLegalMaskedStore(Ty); + const MaybeAlign Alignment = getLoadStoreAlignment(I); + return isa<LoadInst>(I) ? 
TTI.isLegalMaskedLoad(Ty, Alignment) + : TTI.isLegalMaskedStore(Ty, Alignment); } bool LoopVectorizationCostModel::memoryInstructionCanBeWidened(Instruction *I, @@ -4675,82 +4817,96 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { Uniforms[VF].insert(Worklist.begin(), Worklist.end()); } -Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { - if (Legal->getRuntimePointerChecking()->Need && TTI.hasBranchDivergence()) { - // TODO: It may by useful to do since it's still likely to be dynamically - // uniform if the target can skip. - LLVM_DEBUG( - dbgs() << "LV: Not inserting runtime ptr check for divergent target"); - - ORE->emit( - createMissedAnalysis("CantVersionLoopWithDivergentTarget") - << "runtime pointer checks needed. Not enabled for divergent target"); - - return None; - } - - unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); - if (!OptForSize) // Remaining checks deal with scalar loop when OptForSize. - return computeFeasibleMaxVF(OptForSize, TC); +bool LoopVectorizationCostModel::runtimeChecksRequired() { + LLVM_DEBUG(dbgs() << "LV: Performing code size checks.\n"); if (Legal->getRuntimePointerChecking()->Need) { - ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") - << "runtime pointer checks needed. Enable vectorization of this " - "loop with '#pragma clang loop vectorize(enable)' when " - "compiling with -Os/-Oz"); - LLVM_DEBUG( - dbgs() - << "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); - return None; + reportVectorizationFailure("Runtime ptr check is required with -Os/-Oz", + "runtime pointer checks needed. Enable vectorization of this " + "loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz", + "CantVersionLoopWithOptForSize", ORE, TheLoop); + return true; } if (!PSE.getUnionPredicate().getPredicates().empty()) { - ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") - << "runtime SCEV checks needed. 
Enable vectorization of this " - "loop with '#pragma clang loop vectorize(enable)' when " - "compiling with -Os/-Oz"); - LLVM_DEBUG( - dbgs() - << "LV: Aborting. Runtime SCEV check is required with -Os/-Oz.\n"); - return None; + reportVectorizationFailure("Runtime SCEV check is required with -Os/-Oz", + "runtime SCEV checks needed. Enable vectorization of this " + "loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz", + "CantVersionLoopWithOptForSize", ORE, TheLoop); + return true; } // FIXME: Avoid specializing for stride==1 instead of bailing out. if (!Legal->getLAI()->getSymbolicStrides().empty()) { - ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") - << "runtime stride == 1 checks needed. Enable vectorization of " - "this loop with '#pragma clang loop vectorize(enable)' when " - "compiling with -Os/-Oz"); - LLVM_DEBUG( - dbgs() - << "LV: Aborting. Runtime stride check is required with -Os/-Oz.\n"); + reportVectorizationFailure("Runtime stride check is required with -Os/-Oz", + "runtime stride == 1 checks needed. Enable vectorization of " + "this loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz", + "CantVersionLoopWithOptForSize", ORE, TheLoop); + return true; + } + + return false; +} + +Optional<unsigned> LoopVectorizationCostModel::computeMaxVF() { + if (Legal->getRuntimePointerChecking()->Need && TTI.hasBranchDivergence()) { + // TODO: It may by useful to do since it's still likely to be dynamically + // uniform if the target can skip. + reportVectorizationFailure( + "Not inserting runtime ptr check for divergent target", + "runtime pointer checks needed. Not enabled for divergent target", + "CantVersionLoopWithDivergentTarget", ORE, TheLoop); return None; } - // If we optimize the program for size, avoid creating the tail loop. 
+ unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); - if (TC == 1) { - ORE->emit(createMissedAnalysis("SingleIterationLoop") - << "loop trip count is one, irrelevant for vectorization"); - LLVM_DEBUG(dbgs() << "LV: Aborting, single iteration (non) loop.\n"); + reportVectorizationFailure("Single iteration (non) loop", + "loop trip count is one, irrelevant for vectorization", + "SingleIterationLoop", ORE, TheLoop); return None; } - // Record that scalar epilogue is not allowed. - LLVM_DEBUG(dbgs() << "LV: Not allowing scalar epilogue due to -Os/-Oz.\n"); + switch (ScalarEpilogueStatus) { + case CM_ScalarEpilogueAllowed: + return computeFeasibleMaxVF(TC); + case CM_ScalarEpilogueNotNeededUsePredicate: + LLVM_DEBUG( + dbgs() << "LV: vector predicate hint/switch found.\n" + << "LV: Not allowing scalar epilogue, creating predicated " + << "vector loop.\n"); + break; + case CM_ScalarEpilogueNotAllowedLowTripLoop: + // fallthrough as a special case of OptForSize + case CM_ScalarEpilogueNotAllowedOptSize: + if (ScalarEpilogueStatus == CM_ScalarEpilogueNotAllowedOptSize) + LLVM_DEBUG( + dbgs() << "LV: Not allowing scalar epilogue due to -Os/-Oz.\n"); + else + LLVM_DEBUG(dbgs() << "LV: Not allowing scalar epilogue due to low trip " + << "count.\n"); + + // Bail if runtime checks are required, which are not good when optimising + // for size. + if (runtimeChecksRequired()) + return None; + break; + } - IsScalarEpilogueAllowed = !OptForSize; + // Now try the tail folding - // We don't create an epilogue when optimizing for size. // Invalidate interleave groups that require an epilogue if we can't mask // the interleave-group. 
- if (!useMaskedInterleavedAccesses(TTI)) + if (!useMaskedInterleavedAccesses(TTI)) InterleaveInfo.invalidateGroupsRequiringScalarEpilogue(); - unsigned MaxVF = computeFeasibleMaxVF(OptForSize, TC); - + unsigned MaxVF = computeFeasibleMaxVF(TC); if (TC > 0 && TC % MaxVF == 0) { + // Accept MaxVF if we do not have a tail. LLVM_DEBUG(dbgs() << "LV: No tail will remain for any chosen VF.\n"); return MaxVF; } @@ -4759,28 +4915,30 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { // found modulo the vectorization factor is not zero, try to fold the tail // by masking. // FIXME: look for a smaller MaxVF that does divide TC rather than masking. - if (Legal->canFoldTailByMasking()) { + if (Legal->prepareToFoldTailByMasking()) { FoldTailByMasking = true; return MaxVF; } if (TC == 0) { - ORE->emit( - createMissedAnalysis("UnknownLoopCountComplexCFG") - << "unable to calculate the loop count due to complex control flow"); + reportVectorizationFailure( + "Unable to calculate the loop count due to complex control flow", + "unable to calculate the loop count due to complex control flow", + "UnknownLoopCountComplexCFG", ORE, TheLoop); return None; } - ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") - << "cannot optimize for size and vectorize at the same time. " - "Enable vectorization of this loop with '#pragma clang loop " - "vectorize(enable)' when compiling with -Os/-Oz"); + reportVectorizationFailure( + "Cannot optimize for size and vectorize at the same time.", + "cannot optimize for size and vectorize at the same time. 
" + "Enable vectorization of this loop with '#pragma clang loop " + "vectorize(enable)' when compiling with -Os/-Oz", + "NoTailLoopWithOptForSize", ORE, TheLoop); return None; } unsigned -LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize, - unsigned ConstTripCount) { +LoopVectorizationCostModel::computeFeasibleMaxVF(unsigned ConstTripCount) { MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); unsigned SmallestType, WidestType; std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); @@ -4818,8 +4976,8 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize, } unsigned MaxVF = MaxVectorSize; - if (TTI.shouldMaximizeVectorBandwidth(OptForSize) || - (MaximizeBandwidth && !OptForSize)) { + if (TTI.shouldMaximizeVectorBandwidth(!isScalarEpilogueAllowed()) || + (MaximizeBandwidth && isScalarEpilogueAllowed())) { // Collect all viable vectorization factors larger than the default MaxVF // (i.e. MaxVectorSize). SmallVector<unsigned, 8> VFs; @@ -4832,9 +4990,14 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize, // Select the largest VF which doesn't require more registers than existing // ones. - unsigned TargetNumRegisters = TTI.getNumberOfRegisters(true); for (int i = RUs.size() - 1; i >= 0; --i) { - if (RUs[i].MaxLocalUsers <= TargetNumRegisters) { + bool Selected = true; + for (auto& pair : RUs[i].MaxLocalUsers) { + unsigned TargetNumRegisters = TTI.getNumberOfRegisters(pair.first); + if (pair.second > TargetNumRegisters) + Selected = false; + } + if (Selected) { MaxVF = VFs[i]; break; } @@ -4886,10 +5049,9 @@ LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) { } if (!EnableCondStoresVectorization && NumPredStores) { - ORE->emit(createMissedAnalysis("ConditionalStore") - << "store that is conditionally executed prevents vectorization"); - LLVM_DEBUG( - dbgs() << "LV: No vectorization. 
There are conditional stores.\n"); + reportVectorizationFailure("There are conditional stores.", + "store that is conditionally executed prevents vectorization", + "ConditionalStore", ORE, TheLoop); Width = 1; Cost = ScalarCost; } @@ -4958,8 +5120,7 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { return {MinWidth, MaxWidth}; } -unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, - unsigned VF, +unsigned LoopVectorizationCostModel::selectInterleaveCount(unsigned VF, unsigned LoopCost) { // -- The interleave heuristics -- // We interleave the loop in order to expose ILP and reduce the loop overhead. @@ -4975,8 +5136,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // 3. We don't interleave if we think that we will spill registers to memory // due to the increased register pressure. - // When we optimize for size, we don't interleave. - if (OptForSize) + if (!isScalarEpilogueAllowed()) return 1; // We used the distance for the interleave count. @@ -4988,22 +5148,12 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, if (TC > 1 && TC < TinyTripCountInterleaveThreshold) return 1; - unsigned TargetNumRegisters = TTI.getNumberOfRegisters(VF > 1); - LLVM_DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters - << " registers\n"); - - if (VF == 1) { - if (ForceTargetNumScalarRegs.getNumOccurrences() > 0) - TargetNumRegisters = ForceTargetNumScalarRegs; - } else { - if (ForceTargetNumVectorRegs.getNumOccurrences() > 0) - TargetNumRegisters = ForceTargetNumVectorRegs; - } - RegisterUsage R = calculateRegisterUsage({VF})[0]; // We divide by these constants so assume that we have at least one // instruction that uses at least one register. - R.MaxLocalUsers = std::max(R.MaxLocalUsers, 1U); + for (auto& pair : R.MaxLocalUsers) { + pair.second = std::max(pair.second, 1U); + } // We calculate the interleave count using the following formula. 
// Subtract the number of loop invariants from the number of available @@ -5016,13 +5166,35 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // We also want power of two interleave counts to ensure that the induction // variable of the vector loop wraps to zero, when tail is folded by masking; // this currently happens when OptForSize, in which case IC is set to 1 above. - unsigned IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs) / - R.MaxLocalUsers); + unsigned IC = UINT_MAX; - // Don't count the induction variable as interleaved. - if (EnableIndVarRegisterHeur) - IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs - 1) / - std::max(1U, (R.MaxLocalUsers - 1))); + for (auto& pair : R.MaxLocalUsers) { + unsigned TargetNumRegisters = TTI.getNumberOfRegisters(pair.first); + LLVM_DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters + << " registers of " + << TTI.getRegisterClassName(pair.first) << " register class\n"); + if (VF == 1) { + if (ForceTargetNumScalarRegs.getNumOccurrences() > 0) + TargetNumRegisters = ForceTargetNumScalarRegs; + } else { + if (ForceTargetNumVectorRegs.getNumOccurrences() > 0) + TargetNumRegisters = ForceTargetNumVectorRegs; + } + unsigned MaxLocalUsers = pair.second; + unsigned LoopInvariantRegs = 0; + if (R.LoopInvariantRegs.find(pair.first) != R.LoopInvariantRegs.end()) + LoopInvariantRegs = R.LoopInvariantRegs[pair.first]; + + unsigned TmpIC = PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs) / MaxLocalUsers); + // Don't count the induction variable as interleaved. + if (EnableIndVarRegisterHeur) { + TmpIC = + PowerOf2Floor((TargetNumRegisters - LoopInvariantRegs - 1) / + std::max(1U, (MaxLocalUsers - 1))); + } + + IC = std::min(IC, TmpIC); + } // Clamp the interleave ranges to reasonable counts. 
unsigned MaxInterleaveCount = TTI.getMaxInterleaveFactor(VF); @@ -5036,6 +5208,14 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, MaxInterleaveCount = ForceTargetMaxVectorInterleaveFactor; } + // If the trip count is constant, limit the interleave count to be less than + // the trip count divided by VF. + if (TC > 0) { + assert(TC >= VF && "VF exceeds trip count?"); + if ((TC / VF) < MaxInterleaveCount) + MaxInterleaveCount = (TC / VF); + } + // If we did not calculate the cost for VF (because the user selected the VF) // then we calculate the cost of VF here. if (LoopCost == 0) @@ -5044,7 +5224,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, assert(LoopCost && "Non-zero loop cost expected"); // Clamp the calculated IC to be between the 1 and the max interleave count - // that the target allows. + // that the target and trip count allows. if (IC > MaxInterleaveCount) IC = MaxInterleaveCount; else if (IC < 1) @@ -5196,7 +5376,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { const DataLayout &DL = TheFunction->getParent()->getDataLayout(); SmallVector<RegisterUsage, 8> RUs(VFs.size()); - SmallVector<unsigned, 8> MaxUsages(VFs.size(), 0); + SmallVector<SmallMapVector<unsigned, unsigned, 4>, 8> MaxUsages(VFs.size()); LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n"); @@ -5226,21 +5406,45 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { // For each VF find the maximum usage of registers. for (unsigned j = 0, e = VFs.size(); j < e; ++j) { + // Count the number of live intervals. 
+ SmallMapVector<unsigned, unsigned, 4> RegUsage; + if (VFs[j] == 1) { - MaxUsages[j] = std::max(MaxUsages[j], OpenIntervals.size()); - continue; + for (auto Inst : OpenIntervals) { + unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType()); + if (RegUsage.find(ClassID) == RegUsage.end()) + RegUsage[ClassID] = 1; + else + RegUsage[ClassID] += 1; + } + } else { + collectUniformsAndScalars(VFs[j]); + for (auto Inst : OpenIntervals) { + // Skip ignored values for VF > 1. + if (VecValuesToIgnore.find(Inst) != VecValuesToIgnore.end()) + continue; + if (isScalarAfterVectorization(Inst, VFs[j])) { + unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType()); + if (RegUsage.find(ClassID) == RegUsage.end()) + RegUsage[ClassID] = 1; + else + RegUsage[ClassID] += 1; + } else { + unsigned ClassID = TTI.getRegisterClassForType(true, Inst->getType()); + if (RegUsage.find(ClassID) == RegUsage.end()) + RegUsage[ClassID] = GetRegUsage(Inst->getType(), VFs[j]); + else + RegUsage[ClassID] += GetRegUsage(Inst->getType(), VFs[j]); + } + } } - collectUniformsAndScalars(VFs[j]); - // Count the number of live intervals. - unsigned RegUsage = 0; - for (auto Inst : OpenIntervals) { - // Skip ignored values for VF > 1. 
- if (VecValuesToIgnore.find(Inst) != VecValuesToIgnore.end() || - isScalarAfterVectorization(Inst, VFs[j])) - continue; - RegUsage += GetRegUsage(Inst->getType(), VFs[j]); + + for (auto& pair : RegUsage) { + if (MaxUsages[j].find(pair.first) != MaxUsages[j].end()) + MaxUsages[j][pair.first] = std::max(MaxUsages[j][pair.first], pair.second); + else + MaxUsages[j][pair.first] = pair.second; } - MaxUsages[j] = std::max(MaxUsages[j], RegUsage); } LLVM_DEBUG(dbgs() << "LV(REG): At #" << i << " Interval # " @@ -5251,18 +5455,34 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { } for (unsigned i = 0, e = VFs.size(); i < e; ++i) { - unsigned Invariant = 0; - if (VFs[i] == 1) - Invariant = LoopInvariants.size(); - else { - for (auto Inst : LoopInvariants) - Invariant += GetRegUsage(Inst->getType(), VFs[i]); + SmallMapVector<unsigned, unsigned, 4> Invariant; + + for (auto Inst : LoopInvariants) { + unsigned Usage = VFs[i] == 1 ? 1 : GetRegUsage(Inst->getType(), VFs[i]); + unsigned ClassID = TTI.getRegisterClassForType(VFs[i] > 1, Inst->getType()); + if (Invariant.find(ClassID) == Invariant.end()) + Invariant[ClassID] = Usage; + else + Invariant[ClassID] += Usage; } - LLVM_DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n'); - LLVM_DEBUG(dbgs() << "LV(REG): Found max usage: " << MaxUsages[i] << '\n'); - LLVM_DEBUG(dbgs() << "LV(REG): Found invariant usage: " << Invariant - << '\n'); + LLVM_DEBUG({ + dbgs() << "LV(REG): VF = " << VFs[i] << '\n'; + dbgs() << "LV(REG): Found max usage: " << MaxUsages[i].size() + << " item\n"; + for (const auto &pair : MaxUsages[i]) { + dbgs() << "LV(REG): RegisterClass: " + << TTI.getRegisterClassName(pair.first) << ", " << pair.second + << " registers\n"; + } + dbgs() << "LV(REG): Found invariant usage: " << Invariant.size() + << " item\n"; + for (const auto &pair : Invariant) { + dbgs() << "LV(REG): RegisterClass: " + << TTI.getRegisterClassName(pair.first) << ", " << pair.second + << " registers\n"; + } + }); 
RU.LoopInvariantRegs = Invariant; RU.MaxLocalUsers = MaxUsages[i]; @@ -5511,7 +5731,6 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, Type *ValTy = getMemInstValueType(I); auto SE = PSE.getSE(); - unsigned Alignment = getLoadStoreAlignment(I); unsigned AS = getLoadStoreAddressSpace(I); Value *Ptr = getLoadStorePointerOperand(I); Type *PtrTy = ToVectorTy(Ptr->getType(), VF); @@ -5525,9 +5744,9 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, // Don't pass *I here, since it is scalar but will actually be part of a // vectorized loop where the user of it is a vectorized instruction. - Cost += VF * - TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment, - AS); + const MaybeAlign Alignment = getLoadStoreAlignment(I); + Cost += VF * TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), + Alignment ? Alignment->value() : 0, AS); // Get the overhead of the extractelement and insertelement instructions // we might create due to scalarization. @@ -5552,18 +5771,20 @@ unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, unsigned VF) { Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); - unsigned Alignment = getLoadStoreAlignment(I); Value *Ptr = getLoadStorePointerOperand(I); unsigned AS = getLoadStoreAddressSpace(I); int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && "Stride should be 1 or -1 for consecutive memory access"); + const MaybeAlign Alignment = getLoadStoreAlignment(I); unsigned Cost = 0; if (Legal->isMaskRequired(I)) - Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); + Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, + Alignment ? Alignment->value() : 0, AS); else - Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS, I); + Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, + Alignment ? 
Alignment->value() : 0, AS, I); bool Reverse = ConsecutiveStride < 0; if (Reverse) @@ -5575,33 +5796,37 @@ unsigned LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I, unsigned VF) { Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); - unsigned Alignment = getLoadStoreAlignment(I); + const MaybeAlign Alignment = getLoadStoreAlignment(I); unsigned AS = getLoadStoreAddressSpace(I); if (isa<LoadInst>(I)) { return TTI.getAddressComputationCost(ValTy) + - TTI.getMemoryOpCost(Instruction::Load, ValTy, Alignment, AS) + + TTI.getMemoryOpCost(Instruction::Load, ValTy, + Alignment ? Alignment->value() : 0, AS) + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VectorTy); } StoreInst *SI = cast<StoreInst>(I); bool isLoopInvariantStoreValue = Legal->isUniform(SI->getValueOperand()); return TTI.getAddressComputationCost(ValTy) + - TTI.getMemoryOpCost(Instruction::Store, ValTy, Alignment, AS) + - (isLoopInvariantStoreValue ? 0 : TTI.getVectorInstrCost( - Instruction::ExtractElement, - VectorTy, VF - 1)); + TTI.getMemoryOpCost(Instruction::Store, ValTy, + Alignment ? Alignment->value() : 0, AS) + + (isLoopInvariantStoreValue + ? 0 + : TTI.getVectorInstrCost(Instruction::ExtractElement, VectorTy, + VF - 1)); } unsigned LoopVectorizationCostModel::getGatherScatterCost(Instruction *I, unsigned VF) { Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); - unsigned Alignment = getLoadStoreAlignment(I); + const MaybeAlign Alignment = getLoadStoreAlignment(I); Value *Ptr = getLoadStorePointerOperand(I); return TTI.getAddressComputationCost(VectorTy) + TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr, - Legal->isMaskRequired(I), Alignment); + Legal->isMaskRequired(I), + Alignment ? 
Alignment->value() : 0); } unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, @@ -5626,8 +5851,8 @@ unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, } // Calculate the cost of the whole interleaved group. - bool UseMaskForGaps = - Group->requiresScalarEpilogue() && !IsScalarEpilogueAllowed; + bool UseMaskForGaps = + Group->requiresScalarEpilogue() && !isScalarEpilogueAllowed(); unsigned Cost = TTI.getInterleavedMemoryOpCost( I->getOpcode(), WideVecTy, Group->getFactor(), Indices, Group->getAlignment(), AS, Legal->isMaskRequired(I), UseMaskForGaps); @@ -5648,11 +5873,12 @@ unsigned LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I, // moment. if (VF == 1) { Type *ValTy = getMemInstValueType(I); - unsigned Alignment = getLoadStoreAlignment(I); + const MaybeAlign Alignment = getLoadStoreAlignment(I); unsigned AS = getLoadStoreAddressSpace(I); return TTI.getAddressComputationCost(ValTy) + - TTI.getMemoryOpCost(I->getOpcode(), ValTy, Alignment, AS, I); + TTI.getMemoryOpCost(I->getOpcode(), ValTy, + Alignment ? Alignment->value() : 0, AS, I); } return getWideningCost(I, VF); } @@ -6167,8 +6393,7 @@ static unsigned determineVPlanVF(const unsigned WidestVectorRegBits, } VectorizationFactor -LoopVectorizationPlanner::planInVPlanNativePath(bool OptForSize, - unsigned UserVF) { +LoopVectorizationPlanner::planInVPlanNativePath(unsigned UserVF) { unsigned VF = UserVF; // Outer loop handling: They may require CFG and instruction level // transformations before even evaluating whether vectorization is profitable. 
@@ -6207,10 +6432,9 @@ LoopVectorizationPlanner::planInVPlanNativePath(bool OptForSize, return VectorizationFactor::Disabled(); } -Optional<VectorizationFactor> LoopVectorizationPlanner::plan(bool OptForSize, - unsigned UserVF) { +Optional<VectorizationFactor> LoopVectorizationPlanner::plan(unsigned UserVF) { assert(OrigLoop->empty() && "Inner loop expected."); - Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(OptForSize); + Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(); if (!MaybeMaxVF) // Cases that should not to be vectorized nor interleaved. return None; @@ -6840,8 +7064,15 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(unsigned MinVF, // If the tail is to be folded by masking, the primary induction variable // needs to be represented in VPlan for it to model early-exit masking. - if (CM.foldTailByMasking()) + // Also, both the Phi and the live-out instruction of each reduction are + // required in order to introduce a select between them in VPlan. + if (CM.foldTailByMasking()) { NeedDef.insert(Legal->getPrimaryInduction()); + for (auto &Reduction : *Legal->getReductionVars()) { + NeedDef.insert(Reduction.first); + NeedDef.insert(Reduction.second.getLoopExitInstr()); + } + } // Collect instructions from the original loop that will become trivially dead // in the vectorized loop. We don't need to vectorize these instructions. For @@ -6873,7 +7104,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // Create a dummy pre-entry VPBasicBlock to start building the VPlan. VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry"); - auto Plan = llvm::make_unique<VPlan>(VPBB); + auto Plan = std::make_unique<VPlan>(VPBB); VPRecipeBuilder RecipeBuilder(OrigLoop, TLI, Legal, CM, Builder); // Represent values that will have defs inside VPlan. 
@@ -6968,6 +7199,18 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( VPBlockUtils::disconnectBlocks(PreEntry, Entry); delete PreEntry; + // Finally, if tail is folded by masking, introduce selects between the phi + // and the live-out instruction of each reduction, at the end of the latch. + if (CM.foldTailByMasking()) { + Builder.setInsertPoint(VPBB); + auto *Cond = RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan); + for (auto &Reduction : *Legal->getReductionVars()) { + VPValue *Phi = Plan->getVPValue(Reduction.first); + VPValue *Red = Plan->getVPValue(Reduction.second.getLoopExitInstr()); + Builder.createNaryOp(Instruction::Select, {Cond, Red, Phi}); + } + } + std::string PlanName; raw_string_ostream RSO(PlanName); unsigned VF = Range.Start; @@ -6993,7 +7236,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { assert(EnableVPlanNativePath && "VPlan-native path is not enabled."); // Create new empty VPlan - auto Plan = llvm::make_unique<VPlan>(); + auto Plan = std::make_unique<VPlan>(); // Build hierarchical CFG VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan); @@ -7199,6 +7442,20 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { State.ILV->vectorizeMemoryInstruction(&Instr, &MaskValues); } +static ScalarEpilogueLowering +getScalarEpilogueLowering(Function *F, Loop *L, LoopVectorizeHints &Hints, + ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) { + ScalarEpilogueLowering SEL = CM_ScalarEpilogueAllowed; + if (Hints.getForce() != LoopVectorizeHints::FK_Enabled && + (F->hasOptSize() || + llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI))) + SEL = CM_ScalarEpilogueNotAllowedOptSize; + else if (PreferPredicateOverEpilog || Hints.getPredicate()) + SEL = CM_ScalarEpilogueNotNeededUsePredicate; + + return SEL; +} + // Process the loop in the VPlan-native vectorization path. 
This path builds // VPlan upfront in the vectorization pipeline, which allows to apply // VPlan-to-VPlan transformations from the very beginning without modifying the @@ -7213,7 +7470,9 @@ static bool processLoopInVPlanNativePath( assert(EnableVPlanNativePath && "VPlan-native path is disabled."); Function *F = L->getHeader()->getParent(); InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL->getLAI()); - LoopVectorizationCostModel CM(L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F, + ScalarEpilogueLowering SEL = getScalarEpilogueLowering(F, L, Hints, PSI, BFI); + + LoopVectorizationCostModel CM(SEL, L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F, &Hints, IAI); // Use the planner for outer loop vectorization. // TODO: CM is not used at this point inside the planner. Turn CM into an @@ -7223,15 +7482,8 @@ static bool processLoopInVPlanNativePath( // Get user vectorization factor. const unsigned UserVF = Hints.getWidth(); - // Check the function attributes and profiles to find out if this function - // should be optimized for size. - bool OptForSize = - Hints.getForce() != LoopVectorizeHints::FK_Enabled && - (F->hasOptSize() || - llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI)); - // Plan how to best vectorize, return the best VF and its cost. - const VectorizationFactor VF = LVP.planInVPlanNativePath(OptForSize, UserVF); + const VectorizationFactor VF = LVP.planInVPlanNativePath(UserVF); // If we are stress testing VPlan builds, do not attempt to generate vector // code. Masked vector code generation support will follow soon. @@ -7310,10 +7562,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Check the function attributes and profiles to find out if this function // should be optimized for size. 
- bool OptForSize = - Hints.getForce() != LoopVectorizeHints::FK_Enabled && - (F->hasOptSize() || - llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI)); + ScalarEpilogueLowering SEL = getScalarEpilogueLowering(F, L, Hints, PSI, BFI); // Entrance to the VPlan-native vectorization path. Outer loops are processed // here. They may require CFG and instruction level transformations before @@ -7325,36 +7574,11 @@ bool LoopVectorizePass::processLoop(Loop *L) { ORE, BFI, PSI, Hints); assert(L->empty() && "Inner loop expected."); + // Check the loop for a trip count threshold: vectorize loops with a tiny trip // count by optimizing for size, to minimize overheads. - // Prefer constant trip counts over profile data, over upper bound estimate. - unsigned ExpectedTC = 0; - bool HasExpectedTC = false; - if (const SCEVConstant *ConstExits = - dyn_cast<SCEVConstant>(SE->getBackedgeTakenCount(L))) { - const APInt &ExitsCount = ConstExits->getAPInt(); - // We are interested in small values for ExpectedTC. Skip over those that - // can't fit an unsigned. - if (ExitsCount.ult(std::numeric_limits<unsigned>::max())) { - ExpectedTC = static_cast<unsigned>(ExitsCount.getZExtValue()) + 1; - HasExpectedTC = true; - } - } - // ExpectedTC may be large because it's bound by a variable. Check - // profiling information to validate we should vectorize. - if (!HasExpectedTC && LoopVectorizeWithBlockFrequency) { - auto EstimatedTC = getLoopEstimatedTripCount(L); - if (EstimatedTC) { - ExpectedTC = *EstimatedTC; - HasExpectedTC = true; - } - } - if (!HasExpectedTC) { - ExpectedTC = SE->getSmallConstantMaxTripCount(L); - HasExpectedTC = (ExpectedTC > 0); - } - - if (HasExpectedTC && ExpectedTC < TinyTripCountVectorThreshold) { + auto ExpectedTC = getSmallBestKnownTC(*SE, L); + if (ExpectedTC && *ExpectedTC < TinyTripCountVectorThreshold) { LLVM_DEBUG(dbgs() << "LV: Found a loop with a very small trip count. 
" << "This loop is worth vectorizing only if no scalar " << "iteration overheads are incurred."); @@ -7362,10 +7586,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { LLVM_DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); else { LLVM_DEBUG(dbgs() << "\n"); - // Loops with a very small trip count are considered for vectorization - // under OptForSize, thereby making sure the cost of their loop body is - // dominant, free of runtime guards and scalar iteration overheads. - OptForSize = true; + SEL = CM_ScalarEpilogueNotAllowedLowTripLoop; } } @@ -7374,11 +7595,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { // an integer loop and the vector instructions selected are purely integer // vector instructions? if (F->hasFnAttribute(Attribute::NoImplicitFloat)) { - LLVM_DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat" - "attribute is used.\n"); - ORE->emit(createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(), - "NoImplicitFloat", L) - << "loop not vectorized due to NoImplicitFloat attribute"); + reportVectorizationFailure( + "Can't vectorize when the NoImplicitFloat attribute is used", + "loop not vectorized due to NoImplicitFloat attribute", + "NoImplicitFloat", ORE, L); Hints.emitRemarkWithHints(); return false; } @@ -7389,11 +7609,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { // additional fp-math flags can help. if (Hints.isPotentiallyUnsafe() && TTI->isFPVectorizationPotentiallyUnsafe()) { - LLVM_DEBUG( - dbgs() << "LV: Potentially unsafe FP op prevents vectorization.\n"); - ORE->emit( - createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(), "UnsafeFP", L) - << "loop not vectorized due to unsafe FP support."); + reportVectorizationFailure( + "Potentially unsafe FP op prevents vectorization", + "loop not vectorized due to unsafe FP support.", + "UnsafeFP", ORE, L); Hints.emitRemarkWithHints(); return false; } @@ -7411,8 +7630,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { } // Use the cost model. 
- LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, - &Hints, IAI); + LoopVectorizationCostModel CM(SEL, L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, + F, &Hints, IAI); CM.collectValuesToIgnore(); // Use the planner for vectorization. @@ -7422,7 +7641,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { unsigned UserVF = Hints.getWidth(); // Plan how to best vectorize, return the best VF and its cost. - Optional<VectorizationFactor> MaybeVF = LVP.plan(OptForSize, UserVF); + Optional<VectorizationFactor> MaybeVF = LVP.plan(UserVF); VectorizationFactor VF = VectorizationFactor::Disabled(); unsigned IC = 1; @@ -7431,7 +7650,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (MaybeVF) { VF = *MaybeVF; // Select the interleave count. - IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); + IC = CM.selectInterleaveCount(VF.Width, VF.Cost); } // Identify the diagnostic messages that should be produced. @@ -7609,7 +7828,8 @@ bool LoopVectorizePass::runImpl( // The second condition is necessary because, even if the target has no // vector registers, loop vectorization may still enable scalar // interleaving. - if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2) + if (!TTI->getNumberOfRegisters(TTI->getRegisterClassForType(true)) && + TTI->getMaxInterleaveFactor(1) < 2) return false; bool Changed = false; diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index 27a86c0bca91..974eff9974d9 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -194,10 +194,13 @@ static bool allSameBlock(ArrayRef<Value *> VL) { return true; } -/// \returns True if all of the values in \p VL are constants. +/// \returns True if all of the values in \p VL are constants (but not +/// globals/constant expressions). 
static bool allConstant(ArrayRef<Value *> VL) { + // Constant expressions and globals can't be vectorized like normal integer/FP + // constants. for (Value *i : VL) - if (!isa<Constant>(i)) + if (!isa<Constant>(i) || isa<ConstantExpr>(i) || isa<GlobalValue>(i)) return false; return true; } @@ -486,6 +489,7 @@ namespace slpvectorizer { /// Bottom Up SLP Vectorizer. class BoUpSLP { struct TreeEntry; + struct ScheduleData; public: using ValueList = SmallVector<Value *, 8>; @@ -614,6 +618,15 @@ public: /// vectorizable. We do not vectorize such trees. bool isTreeTinyAndNotFullyVectorizable() const; + /// Assume that a legal-sized 'or'-reduction of shifted/zexted loaded values + /// can be load combined in the backend. Load combining may not be allowed in + /// the IR optimizer, so we do not want to alter the pattern. For example, + /// partially transforming a scalar bswap() pattern into vector code is + /// effectively impossible for the backend to undo. + /// TODO: If load combining is allowed in the IR optimizer, this analysis + /// may not be necessary. + bool isLoadCombineReductionCandidate(unsigned ReductionOpcode) const; + OptimizationRemarkEmitter *getORE() { return ORE; } /// This structure holds any data we need about the edges being traversed @@ -1117,6 +1130,14 @@ public: #endif }; + /// Checks if the instruction is marked for deletion. + bool isDeleted(Instruction *I) const { return DeletedInstructions.count(I); } + + /// Marks values operands for later deletion by replacing them with Undefs. + void eraseInstructions(ArrayRef<Value *> AV); + + ~BoUpSLP(); + private: /// Checks if all users of \p I are the part of the vectorization tree. 
bool areAllUsersVectorized(Instruction *I) const; @@ -1153,8 +1174,7 @@ private: /// Set the Builder insert point to one after the last instruction in /// the bundle - void setInsertPointAfterBundle(ArrayRef<Value *> VL, - const InstructionsState &S); + void setInsertPointAfterBundle(TreeEntry *E); /// \returns a vector from a collection of scalars in \p VL. Value *Gather(ArrayRef<Value *> VL, VectorType *Ty); @@ -1220,27 +1240,37 @@ private: /// reordering of operands during buildTree_rec() and vectorizeTree(). SmallVector<ValueList, 2> Operands; + /// The main/alternate instruction. + Instruction *MainOp = nullptr; + Instruction *AltOp = nullptr; + public: /// Set this bundle's \p OpIdx'th operand to \p OpVL. - void setOperand(unsigned OpIdx, ArrayRef<Value *> OpVL, - ArrayRef<unsigned> ReuseShuffleIndices) { + void setOperand(unsigned OpIdx, ArrayRef<Value *> OpVL) { if (Operands.size() < OpIdx + 1) Operands.resize(OpIdx + 1); assert(Operands[OpIdx].size() == 0 && "Already resized?"); Operands[OpIdx].resize(Scalars.size()); for (unsigned Lane = 0, E = Scalars.size(); Lane != E; ++Lane) - Operands[OpIdx][Lane] = (!ReuseShuffleIndices.empty()) - ? OpVL[ReuseShuffleIndices[Lane]] - : OpVL[Lane]; - } - - /// If there is a user TreeEntry, then set its operand. - void trySetUserTEOperand(const EdgeInfo &UserTreeIdx, - ArrayRef<Value *> OpVL, - ArrayRef<unsigned> ReuseShuffleIndices) { - if (UserTreeIdx.UserTE) - UserTreeIdx.UserTE->setOperand(UserTreeIdx.EdgeIdx, OpVL, - ReuseShuffleIndices); + Operands[OpIdx][Lane] = OpVL[Lane]; + } + + /// Set the operands of this bundle in their original order. 
+ void setOperandsInOrder() { + assert(Operands.empty() && "Already initialized?"); + auto *I0 = cast<Instruction>(Scalars[0]); + Operands.resize(I0->getNumOperands()); + unsigned NumLanes = Scalars.size(); + for (unsigned OpIdx = 0, NumOperands = I0->getNumOperands(); + OpIdx != NumOperands; ++OpIdx) { + Operands[OpIdx].resize(NumLanes); + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + auto *I = cast<Instruction>(Scalars[Lane]); + assert(I->getNumOperands() == NumOperands && + "Expected same number of operands"); + Operands[OpIdx][Lane] = I->getOperand(OpIdx); + } + } } /// \returns the \p OpIdx operand of this TreeEntry. @@ -1249,6 +1279,9 @@ private: return Operands[OpIdx]; } + /// \returns the number of operands. + unsigned getNumOperands() const { return Operands.size(); } + /// \return the single \p OpIdx operand. Value *getSingleOperand(unsigned OpIdx) const { assert(OpIdx < Operands.size() && "Off bounds"); @@ -1256,6 +1289,58 @@ private: return Operands[OpIdx][0]; } + /// Some of the instructions in the list have alternate opcodes. + bool isAltShuffle() const { + return getOpcode() != getAltOpcode(); + } + + bool isOpcodeOrAlt(Instruction *I) const { + unsigned CheckedOpcode = I->getOpcode(); + return (getOpcode() == CheckedOpcode || + getAltOpcode() == CheckedOpcode); + } + + /// Chooses the correct key for scheduling data. If \p Op has the same (or + /// alternate) opcode as \p OpValue, the key is \p Op. Otherwise the key is + /// \p OpValue. + Value *isOneOf(Value *Op) const { + auto *I = dyn_cast<Instruction>(Op); + if (I && isOpcodeOrAlt(I)) + return Op; + return MainOp; + } + + void setOperations(const InstructionsState &S) { + MainOp = S.MainOp; + AltOp = S.AltOp; + } + + Instruction *getMainOp() const { + return MainOp; + } + + Instruction *getAltOp() const { + return AltOp; + } + + /// The main/alternate opcodes for the list of instructions. + unsigned getOpcode() const { + return MainOp ? 
MainOp->getOpcode() : 0; + } + + unsigned getAltOpcode() const { + return AltOp ? AltOp->getOpcode() : 0; + } + + /// Update operations state of this entry if reorder occurred. + bool updateStateIfReorder() { + if (ReorderIndices.empty()) + return false; + InstructionsState S = getSameOpcode(Scalars, ReorderIndices.front()); + setOperations(S); + return true; + } + #ifndef NDEBUG /// Debug printer. LLVM_DUMP_METHOD void dump() const { @@ -1269,6 +1354,8 @@ private: for (Value *V : Scalars) dbgs().indent(2) << *V << "\n"; dbgs() << "NeedToGather: " << NeedToGather << "\n"; + dbgs() << "MainOp: " << *MainOp << "\n"; + dbgs() << "AltOp: " << *AltOp << "\n"; dbgs() << "VectorizedValue: "; if (VectorizedValue) dbgs() << *VectorizedValue; @@ -1279,12 +1366,12 @@ private: if (ReuseShuffleIndices.empty()) dbgs() << "Emtpy"; else - for (unsigned Idx : ReuseShuffleIndices) - dbgs() << Idx << ", "; + for (unsigned ReuseIdx : ReuseShuffleIndices) + dbgs() << ReuseIdx << ", "; dbgs() << "\n"; dbgs() << "ReorderIndices: "; - for (unsigned Idx : ReorderIndices) - dbgs() << Idx << ", "; + for (unsigned ReorderIdx : ReorderIndices) + dbgs() << ReorderIdx << ", "; dbgs() << "\n"; dbgs() << "UserTreeIndices: "; for (const auto &EInfo : UserTreeIndices) @@ -1295,11 +1382,13 @@ private: }; /// Create a new VectorizableTree entry. 
- TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, + TreeEntry *newTreeEntry(ArrayRef<Value *> VL, Optional<ScheduleData *> Bundle, + const InstructionsState &S, const EdgeInfo &UserTreeIdx, ArrayRef<unsigned> ReuseShuffleIndices = None, ArrayRef<unsigned> ReorderIndices = None) { - VectorizableTree.push_back(llvm::make_unique<TreeEntry>(VectorizableTree)); + bool Vectorized = (bool)Bundle; + VectorizableTree.push_back(std::make_unique<TreeEntry>(VectorizableTree)); TreeEntry *Last = VectorizableTree.back().get(); Last->Idx = VectorizableTree.size() - 1; Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); @@ -1307,11 +1396,22 @@ private: Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(), ReuseShuffleIndices.end()); Last->ReorderIndices = ReorderIndices; + Last->setOperations(S); if (Vectorized) { for (int i = 0, e = VL.size(); i != e; ++i) { assert(!getTreeEntry(VL[i]) && "Scalar already in tree!"); - ScalarToTreeEntry[VL[i]] = Last->Idx; - } + ScalarToTreeEntry[VL[i]] = Last; + } + // Update the scheduler bundle to point to this TreeEntry. 
+ unsigned Lane = 0; + for (ScheduleData *BundleMember = Bundle.getValue(); BundleMember; + BundleMember = BundleMember->NextInBundle) { + BundleMember->TE = Last; + BundleMember->Lane = Lane; + ++Lane; + } + assert((!Bundle.getValue() || Lane == VL.size()) && + "Bundle and VL out of sync"); } else { MustGather.insert(VL.begin(), VL.end()); } @@ -1319,7 +1419,6 @@ private: if (UserTreeIdx.UserTE) Last->UserTreeIndices.push_back(UserTreeIdx); - Last->trySetUserTEOperand(UserTreeIdx, VL, ReuseShuffleIndices); return Last; } @@ -1340,19 +1439,19 @@ private: TreeEntry *getTreeEntry(Value *V) { auto I = ScalarToTreeEntry.find(V); if (I != ScalarToTreeEntry.end()) - return VectorizableTree[I->second].get(); + return I->second; return nullptr; } const TreeEntry *getTreeEntry(Value *V) const { auto I = ScalarToTreeEntry.find(V); if (I != ScalarToTreeEntry.end()) - return VectorizableTree[I->second].get(); + return I->second; return nullptr; } /// Maps a specific scalar to its tree entry. - SmallDenseMap<Value*, int> ScalarToTreeEntry; + SmallDenseMap<Value*, TreeEntry *> ScalarToTreeEntry; /// A list of scalars that we found that we need to keep as scalars. ValueSet MustGather; @@ -1408,15 +1507,14 @@ private: /// This is required to ensure that there are no incorrect collisions in the /// AliasCache, which can happen if a new instruction is allocated at the /// same address as a previously deleted instruction. - void eraseInstruction(Instruction *I) { - I->removeFromParent(); - I->dropAllReferences(); - DeletedInstructions.emplace_back(I); + void eraseInstruction(Instruction *I, bool ReplaceOpsWithUndef = false) { + auto It = DeletedInstructions.try_emplace(I, ReplaceOpsWithUndef).first; + It->getSecond() = It->getSecond() && ReplaceOpsWithUndef; } /// Temporary store for deleted instructions. Instructions will be deleted /// eventually when the BoUpSLP is destructed. 
- SmallVector<unique_value, 8> DeletedInstructions; + DenseMap<Instruction *, bool> DeletedInstructions; /// A list of values that need to extracted out of the tree. /// This list holds pairs of (Internal Scalar : External User). External User @@ -1453,6 +1551,8 @@ private: UnscheduledDepsInBundle = UnscheduledDeps; clearDependencies(); OpValue = OpVal; + TE = nullptr; + Lane = -1; } /// Returns true if the dependency information has been calculated. @@ -1559,6 +1659,12 @@ private: /// Opcode of the current instruction in the schedule data. Value *OpValue = nullptr; + + /// The TreeEntry that this instruction corresponds to. + TreeEntry *TE = nullptr; + + /// The lane of this node in the TreeEntry. + int Lane = -1; }; #ifndef NDEBUG @@ -1633,10 +1739,9 @@ private: continue; } // Handle the def-use chain dependencies. - for (Use &U : BundleMember->Inst->operands()) { - auto *I = dyn_cast<Instruction>(U.get()); - if (!I) - continue; + + // Decrement the unscheduled counter and insert to ready list if ready. + auto &&DecrUnsched = [this, &ReadyList](Instruction *I) { doForAllOpcodes(I, [&ReadyList](ScheduleData *OpDef) { if (OpDef && OpDef->hasValidDependencies() && OpDef->incrementUnscheduledDeps(-1) == 0) { @@ -1651,6 +1756,24 @@ private: << "SLP: gets ready (def): " << *DepBundle << "\n"); } }); + }; + + // If BundleMember is a vector bundle, its operands may have been + // reordered duiring buildTree(). We therefore need to get its operands + // through the TreeEntry. + if (TreeEntry *TE = BundleMember->TE) { + int Lane = BundleMember->Lane; + assert(Lane >= 0 && "Lane not set"); + for (unsigned OpIdx = 0, NumOperands = TE->getNumOperands(); + OpIdx != NumOperands; ++OpIdx) + if (auto *I = dyn_cast<Instruction>(TE->getOperand(OpIdx)[Lane])) + DecrUnsched(I); + } else { + // If BundleMember is a stand-alone instruction, no operand reordering + // has taken place, so we directly access its operands. 
+ for (Use &U : BundleMember->Inst->operands()) + if (auto *I = dyn_cast<Instruction>(U.get())) + DecrUnsched(I); } // Handle the memory dependencies. for (ScheduleData *MemoryDepSD : BundleMember->MemoryDependencies) { @@ -1697,8 +1820,11 @@ private: /// Checks if a bundle of instructions can be scheduled, i.e. has no /// cyclic dependencies. This is only a dry-run, no instructions are /// actually moved at this stage. - bool tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, - const InstructionsState &S); + /// \returns the scheduling bundle. The returned Optional value is non-None + /// if \p VL is allowed to be scheduled. + Optional<ScheduleData *> + tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, + const InstructionsState &S); /// Un-bundles a group of instructions. void cancelScheduling(ArrayRef<Value *> VL, Value *OpValue); @@ -1945,6 +2071,30 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { } // end namespace llvm +BoUpSLP::~BoUpSLP() { + for (const auto &Pair : DeletedInstructions) { + // Replace operands of ignored instructions with Undefs in case if they were + // marked for deletion. 
+ if (Pair.getSecond()) { + Value *Undef = UndefValue::get(Pair.getFirst()->getType()); + Pair.getFirst()->replaceAllUsesWith(Undef); + } + Pair.getFirst()->dropAllReferences(); + } + for (const auto &Pair : DeletedInstructions) { + assert(Pair.getFirst()->use_empty() && + "trying to erase instruction with users."); + Pair.getFirst()->eraseFromParent(); + } +} + +void BoUpSLP::eraseInstructions(ArrayRef<Value *> AV) { + for (auto *V : AV) { + if (auto *I = dyn_cast<Instruction>(V)) + eraseInstruction(I, /*ReplaceWithUndef=*/true); + }; +} + void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst) { ExtraValueToDebugLocsMap ExternallyUsedValues; @@ -2026,28 +2176,28 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, InstructionsState S = getSameOpcode(VL); if (Depth == RecursionMaxDepth) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } // Don't handle vectors. if (S.OpValue->getType()->isVectorTy()) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue)) if (SI->getValueOperand()->getType()->isVectorTy()) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } // If all of the operands are identical or constant we have a simple solution. if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.getOpcode()) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } @@ -2055,11 +2205,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // the same block. 
// Don't vectorize ephemeral values. - for (unsigned i = 0, e = VL.size(); i != e; ++i) { - if (EphValues.count(VL[i])) { - LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] + for (Value *V : VL) { + if (EphValues.count(V)) { + LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V << ") is ephemeral.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } } @@ -2069,7 +2219,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.OpValue << ".\n"); if (!E->isSame(VL)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } // Record the reuse of the tree node. FIXME, currently this is only used to @@ -2077,19 +2227,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, E->UserTreeIndices.push_back(UserTreeIdx); LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue << ".\n"); - E->trySetUserTEOperand(UserTreeIdx, VL, None); return; } // Check that none of the instructions in the bundle are already in the tree. - for (unsigned i = 0, e = VL.size(); i != e; ++i) { - auto *I = dyn_cast<Instruction>(VL[i]); + for (Value *V : VL) { + auto *I = dyn_cast<Instruction>(V); if (!I) continue; if (getTreeEntry(I)) { - LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] + LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V << ") is already in tree.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } } @@ -2097,10 +2246,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // If any of the scalars is marked as a value that needs to stay scalar, then // we need to gather the scalars. // The reduction nodes (stored in UserIgnoreList) also should stay scalar. 
- for (unsigned i = 0, e = VL.size(); i != e; ++i) { - if (MustGather.count(VL[i]) || is_contained(UserIgnoreList, VL[i])) { + for (Value *V : VL) { + if (MustGather.count(V) || is_contained(UserIgnoreList, V)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } } @@ -2114,7 +2263,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Don't go into unreachable blocks. They may contain instructions with // dependency cycles which confuse the final scheduling. LLVM_DEBUG(dbgs() << "SLP: bundle in unreachable block.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } @@ -2128,13 +2277,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (Res.second) UniqueValues.emplace_back(V); } - if (UniqueValues.size() == VL.size()) { + size_t NumUniqueScalarValues = UniqueValues.size(); + if (NumUniqueScalarValues == VL.size()) { ReuseShuffleIndicies.clear(); } else { LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n"); - if (UniqueValues.size() <= 1 || !llvm::isPowerOf2_32(UniqueValues.size())) { + if (NumUniqueScalarValues <= 1 || + !llvm::isPowerOf2_32(NumUniqueScalarValues)) { LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); return; } VL = UniqueValues; @@ -2142,16 +2293,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, auto &BSRef = BlocksSchedules[BB]; if (!BSRef) - BSRef = llvm::make_unique<BlockScheduling>(BB); + BSRef = std::make_unique<BlockScheduling>(BB); BlockScheduling &BS = *BSRef.get(); - if (!BS.tryScheduleBundle(VL, this, S)) { + Optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S); + if (!Bundle) { LLVM_DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n"); 
assert((!BS.getScheduleData(VL0) || !BS.getScheduleData(VL0)->isPartOfBundle()) && "tryScheduleBundle should cancelScheduling on failure"); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } LLVM_DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); @@ -2160,7 +2313,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, (unsigned) Instruction::ShuffleVector : S.getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: { - PHINode *PH = dyn_cast<PHINode>(VL0); + auto *PH = cast<PHINode>(VL0); // Check for terminator values (e.g. invoke). for (unsigned j = 0; j < VL.size(); ++j) @@ -2172,23 +2325,29 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (terminator use).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } } - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = + newTreeEntry(VL, Bundle, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); + // Keeps the reordered operands to avoid code duplication. + SmallVector<ValueList, 2> OperandsVec; for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. 
for (Value *j : VL) Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock( PH->getIncomingBlock(i))); - - buildTree_rec(Operands, Depth + 1, {TE, i}); + TE->setOperand(i, Operands); + OperandsVec.push_back(Operands); } + for (unsigned OpIdx = 0, OpE = OperandsVec.size(); OpIdx != OpE; ++OpIdx) + buildTree_rec(OperandsVec[OpIdx], Depth + 1, {TE, OpIdx}); return; } case Instruction::ExtractValue: @@ -2198,13 +2357,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (Reuse) { LLVM_DEBUG(dbgs() << "SLP: Reusing or shuffling extract sequence.\n"); ++NumOpsWantToKeepOriginalOrder; - newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, + newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); // This is a special case, as it does not gather, but at the same time // we are not extending buildTree_rec() towards the operands. ValueList Op0; Op0.assign(VL.size(), VL0->getOperand(0)); - VectorizableTree.back()->setOperand(0, Op0, ReuseShuffleIndicies); + VectorizableTree.back()->setOperand(0, Op0); return; } if (!CurrentOrder.empty()) { @@ -2220,17 +2379,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, auto StoredCurrentOrderAndNum = NumOpsWantToKeepOrder.try_emplace(CurrentOrder).first; ++StoredCurrentOrderAndNum->getSecond(); - newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, ReuseShuffleIndicies, + newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies, StoredCurrentOrderAndNum->getFirst()); // This is a special case, as it does not gather, but at the same time // we are not extending buildTree_rec() towards the operands. 
ValueList Op0; Op0.assign(VL.size(), VL0->getOperand(0)); - VectorizableTree.back()->setOperand(0, Op0, ReuseShuffleIndicies); + VectorizableTree.back()->setOperand(0, Op0); return; } LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n"); - newTreeEntry(VL, /*Vectorized=*/false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); BS.cancelScheduling(VL, VL0); return; } @@ -2246,7 +2407,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); return; } @@ -2259,7 +2421,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, auto *L = cast<LoadInst>(V); if (!L->isSimple()) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); return; } @@ -2289,15 +2452,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (CurrentOrder.empty()) { // Original loads are consecutive and does not require reordering. ++NumOpsWantToKeepOriginalOrder; - newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, - ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, + UserTreeIdx, ReuseShuffleIndicies); + TE->setOperandsInOrder(); LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); } else { // Need to reorder. 
auto I = NumOpsWantToKeepOrder.try_emplace(CurrentOrder).first; ++I->getSecond(); - newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, - ReuseShuffleIndicies, I->getFirst()); + TreeEntry *TE = + newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies, I->getFirst()); + TE->setOperandsInOrder(); LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n"); } return; @@ -2306,7 +2472,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } case Instruction::ZExt: @@ -2322,24 +2489,27 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::FPTrunc: case Instruction::BitCast: { Type *SrcTy = VL0->getOperand(0)->getType(); - for (unsigned i = 0; i < VL.size(); ++i) { - Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType(); + for (Value *V : VL) { + Type *Ty = cast<Instruction>(V)->getOperand(0)->getType(); if (Ty != SrcTy || !isValidElementType(Ty)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n"); return; } } - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n"); + TE->setOperandsInOrder(); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. 
- for (Value *j : VL) - Operands.push_back(cast<Instruction>(j)->getOperand(i)); + for (Value *V : VL) + Operands.push_back(cast<Instruction>(V)->getOperand(i)); buildTree_rec(Operands, Depth + 1, {TE, i}); } @@ -2351,19 +2521,21 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); CmpInst::Predicate SwapP0 = CmpInst::getSwappedPredicate(P0); Type *ComparedTy = VL0->getOperand(0)->getType(); - for (unsigned i = 1, e = VL.size(); i < e; ++i) { - CmpInst *Cmp = cast<CmpInst>(VL[i]); + for (Value *V : VL) { + CmpInst *Cmp = cast<CmpInst>(V); if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) || Cmp->getOperand(0)->getType() != ComparedTy) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n"); return; } } - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n"); ValueList Left, Right; @@ -2384,7 +2556,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Right.push_back(RHS); } } - + TE->setOperand(0, Left); + TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); buildTree_rec(Right, Depth + 1, {TE, 1}); return; @@ -2409,7 +2582,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::And: case Instruction::Or: case Instruction::Xor: { - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of un/bin op.\n"); // Sort operands of the instructions so that each side is more likely to @@ -2417,11 +2591,14 @@ 
void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) { ValueList Left, Right; reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE); + TE->setOperand(0, Left); + TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); buildTree_rec(Right, Depth + 1, {TE, 1}); return; } + TE->setOperandsInOrder(); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. @@ -2434,11 +2611,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } case Instruction::GetElementPtr: { // We don't combine GEPs with complicated (nested) indexing. - for (unsigned j = 0; j < VL.size(); ++j) { - if (cast<Instruction>(VL[j])->getNumOperands() != 2) { + for (Value *V : VL) { + if (cast<Instruction>(V)->getNumOperands() != 2) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } } @@ -2446,58 +2624,64 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // We can't combine several GEPs into one vector if they operate on // different types. Type *Ty0 = VL0->getOperand(0)->getType(); - for (unsigned j = 0; j < VL.size(); ++j) { - Type *CurTy = cast<Instruction>(VL[j])->getOperand(0)->getType(); + for (Value *V : VL) { + Type *CurTy = cast<Instruction>(V)->getOperand(0)->getType(); if (Ty0 != CurTy) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } } // We don't combine GEPs with non-constant indexes. 
- for (unsigned j = 0; j < VL.size(); ++j) { - auto Op = cast<Instruction>(VL[j])->getOperand(1); + for (Value *V : VL) { + auto Op = cast<Instruction>(V)->getOperand(1); if (!isa<ConstantInt>(Op)) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); return; } } - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); + TE->setOperandsInOrder(); for (unsigned i = 0, e = 2; i < e; ++i) { ValueList Operands; // Prepare the operand vector. - for (Value *j : VL) - Operands.push_back(cast<Instruction>(j)->getOperand(i)); + for (Value *V : VL) + Operands.push_back(cast<Instruction>(V)->getOperand(i)); buildTree_rec(Operands, Depth + 1, {TE, i}); } return; } case Instruction::Store: { - // Check if the stores are consecutive or of we need to swizzle them. + // Check if the stores are consecutive or if we need to swizzle them. 
for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); return; } - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); ValueList Operands; - for (Value *j : VL) - Operands.push_back(cast<Instruction>(j)->getOperand(0)); - + for (Value *V : VL) + Operands.push_back(cast<Instruction>(V)->getOperand(0)); + TE->setOperandsInOrder(); buildTree_rec(Operands, Depth + 1, {TE, 0}); return; } @@ -2509,7 +2693,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); if (!isTriviallyVectorizable(ID)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); return; } @@ -2519,14 +2704,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (unsigned j = 0; j != NumArgs; ++j) if (hasVectorInstrinsicScalarOpd(ID, j)) ScalarArgs[j] = CI->getArgOperand(j); - for (unsigned i = 1, e = VL.size(); i != e; ++i) { - CallInst *CI2 = dyn_cast<CallInst>(VL[i]); + for (Value *V : VL) { + CallInst *CI2 = dyn_cast<CallInst>(V); if (!CI2 || CI2->getCalledFunction() != Int || getVectorIntrinsicIDForCall(CI2, TLI) != ID || !CI->hasIdenticalOperandBundleSchema(*CI2)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i] + newTreeEntry(VL, None /*not vectorized*/, S, 
UserTreeIdx, + ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *V << "\n"); return; } @@ -2537,7 +2723,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Value *A1J = CI2->getArgOperand(j); if (ScalarArgs[j] != A1J) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI << " argument " << ScalarArgs[j] << "!=" << A1J << "\n"); @@ -2551,19 +2738,22 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CI->op_begin() + CI->getBundleOperandsEndIndex(), CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" - << *CI << "!=" << *VL[i] << '\n'); + << *CI << "!=" << *V << '\n'); return; } } - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); + TE->setOperandsInOrder(); for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { ValueList Operands; // Prepare the operand vector. - for (Value *j : VL) { - CallInst *CI2 = dyn_cast<CallInst>(j); + for (Value *V : VL) { + auto *CI2 = cast<CallInst>(V); Operands.push_back(CI2->getArgOperand(i)); } buildTree_rec(Operands, Depth + 1, {TE, i}); @@ -2575,27 +2765,32 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // then do not vectorize this instruction. 
if (!S.isAltShuffle()) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); return; } - auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); // Reorder operands if reordering would enable vectorization. if (isa<BinaryOperator>(VL0)) { ValueList Left, Right; reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE); + TE->setOperand(0, Left); + TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); buildTree_rec(Right, Depth + 1, {TE, 1}); return; } + TE->setOperandsInOrder(); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. - for (Value *j : VL) - Operands.push_back(cast<Instruction>(j)->getOperand(i)); + for (Value *V : VL) + Operands.push_back(cast<Instruction>(V)->getOperand(i)); buildTree_rec(Operands, Depth + 1, {TE, i}); } @@ -2603,7 +2798,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } default: BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); return; } @@ -2738,7 +2934,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { return ReuseShuffleCost + TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0); } - if (getSameOpcode(VL).getOpcode() == Instruction::ExtractElement && + if (E->getOpcode() == Instruction::ExtractElement && allSameType(VL) && allSameBlock(VL)) { Optional<TargetTransformInfo::ShuffleKind> ShuffleKind = isShuffle(VL); if (ShuffleKind.hasValue()) { @@ -2761,11 +2957,10 @@ int 
BoUpSLP::getEntryCost(TreeEntry *E) { } return ReuseShuffleCost + getGatherCost(VL); } - InstructionsState S = getSameOpcode(VL); - assert(S.getOpcode() && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); - Instruction *VL0 = cast<Instruction>(S.OpValue); - unsigned ShuffleOrOp = S.isAltShuffle() ? - (unsigned) Instruction::ShuffleVector : S.getOpcode(); + assert(E->getOpcode() && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); + Instruction *VL0 = E->getMainOp(); + unsigned ShuffleOrOp = + E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: return 0; @@ -2851,7 +3046,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { case Instruction::BitCast: { Type *SrcTy = VL0->getOperand(0)->getType(); int ScalarEltCost = - TTI->getCastInstrCost(S.getOpcode(), ScalarTy, SrcTy, VL0); + TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, VL0); if (NeedToShuffleReuses) { ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; } @@ -2864,7 +3059,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { // Check if the values are candidates to demote. if (!MinBWs.count(VL0) || VecTy != SrcVecTy) { VecCost = ReuseShuffleCost + - TTI->getCastInstrCost(S.getOpcode(), VecTy, SrcVecTy, VL0); + TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy, VL0); } return VecCost - ScalarCost; } @@ -2872,14 +3067,14 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { case Instruction::ICmp: case Instruction::Select: { // Calculate the cost of this instruction. 
- int ScalarEltCost = TTI->getCmpSelInstrCost(S.getOpcode(), ScalarTy, + int ScalarEltCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, Builder.getInt1Ty(), VL0); if (NeedToShuffleReuses) { ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; } VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size()); int ScalarCost = VecTy->getNumElements() * ScalarEltCost; - int VecCost = TTI->getCmpSelInstrCost(S.getOpcode(), VecTy, MaskTy, VL0); + int VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, VL0); return ReuseShuffleCost + VecCost - ScalarCost; } case Instruction::FNeg: @@ -2940,12 +3135,12 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { SmallVector<const Value *, 4> Operands(VL0->operand_values()); int ScalarEltCost = TTI->getArithmeticInstrCost( - S.getOpcode(), ScalarTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands); + E->getOpcode(), ScalarTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands); if (NeedToShuffleReuses) { ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; } int ScalarCost = VecTy->getNumElements() * ScalarEltCost; - int VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy, Op1VK, + int VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands); return ReuseShuffleCost + VecCost - ScalarCost; } @@ -3027,11 +3222,11 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { return ReuseShuffleCost + VecCallCost - ScalarCallCost; } case Instruction::ShuffleVector: { - assert(S.isAltShuffle() && - ((Instruction::isBinaryOp(S.getOpcode()) && - Instruction::isBinaryOp(S.getAltOpcode())) || - (Instruction::isCast(S.getOpcode()) && - Instruction::isCast(S.getAltOpcode()))) && + assert(E->isAltShuffle() && + ((Instruction::isBinaryOp(E->getOpcode()) && + Instruction::isBinaryOp(E->getAltOpcode())) || + (Instruction::isCast(E->getOpcode()) && + Instruction::isCast(E->getAltOpcode()))) && "Invalid Shuffle Vector Operand"); int ScalarCost = 0; if (NeedToShuffleReuses) { @@ -3046,25 
+3241,25 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { I, TargetTransformInfo::TCK_RecipThroughput); } } - for (Value *i : VL) { - Instruction *I = cast<Instruction>(i); - assert(S.isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); + for (Value *V : VL) { + Instruction *I = cast<Instruction>(V); + assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); ScalarCost += TTI->getInstructionCost( I, TargetTransformInfo::TCK_RecipThroughput); } // VecCost is equal to sum of the cost of creating 2 vectors // and the cost of creating shuffle. int VecCost = 0; - if (Instruction::isBinaryOp(S.getOpcode())) { - VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy); - VecCost += TTI->getArithmeticInstrCost(S.getAltOpcode(), VecTy); + if (Instruction::isBinaryOp(E->getOpcode())) { + VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy); + VecCost += TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy); } else { - Type *Src0SclTy = S.MainOp->getOperand(0)->getType(); - Type *Src1SclTy = S.AltOp->getOperand(0)->getType(); + Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType(); + Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType(); VectorType *Src0Ty = VectorType::get(Src0SclTy, VL.size()); VectorType *Src1Ty = VectorType::get(Src1SclTy, VL.size()); - VecCost = TTI->getCastInstrCost(S.getOpcode(), VecTy, Src0Ty); - VecCost += TTI->getCastInstrCost(S.getAltOpcode(), VecTy, Src1Ty); + VecCost = TTI->getCastInstrCost(E->getOpcode(), VecTy, Src0Ty); + VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty); } VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0); return ReuseShuffleCost + VecCost - ScalarCost; @@ -3098,6 +3293,43 @@ bool BoUpSLP::isFullyVectorizableTinyTree() const { return true; } +bool BoUpSLP::isLoadCombineReductionCandidate(unsigned RdxOpcode) const { + if (RdxOpcode != Instruction::Or) + return false; + + unsigned NumElts = VectorizableTree[0]->Scalars.size(); + Value *FirstReduced = 
VectorizableTree[0]->Scalars[0]; + + // Look past the reduction to find a source value. Arbitrarily follow the + // path through operand 0 of any 'or'. Also, peek through optional + // shift-left-by-constant. + Value *ZextLoad = FirstReduced; + while (match(ZextLoad, m_Or(m_Value(), m_Value())) || + match(ZextLoad, m_Shl(m_Value(), m_Constant()))) + ZextLoad = cast<BinaryOperator>(ZextLoad)->getOperand(0); + + // Check if the input to the reduction is an extended load. + Value *LoadPtr; + if (!match(ZextLoad, m_ZExt(m_Load(m_Value(LoadPtr))))) + return false; + + // Require that the total load bit width is a legal integer type. + // For example, <8 x i8> --> i64 is a legal integer on a 64-bit target. + // But <16 x i8> --> i128 is not, so the backend probably can't reduce it. + Type *SrcTy = LoadPtr->getType()->getPointerElementType(); + unsigned LoadBitWidth = SrcTy->getIntegerBitWidth() * NumElts; + LLVMContext &Context = FirstReduced->getContext(); + if (!TTI->isTypeLegal(IntegerType::get(Context, LoadBitWidth))) + return false; + + // Everything matched - assume that we can fold the whole sequence using + // load combining. + LLVM_DEBUG(dbgs() << "SLP: Assume load combining for scalar reduction of " + << *(cast<Instruction>(FirstReduced)) << "\n"); + + return true; +} + bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const { // We can vectorize the tree if its size is greater than or equal to the // minimum size specified by the MinTreeSize command line option. @@ -3319,16 +3551,16 @@ void BoUpSLP::reorderInputsAccordingToOpcode( Right = Ops.getVL(1); } -void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, - const InstructionsState &S) { +void BoUpSLP::setInsertPointAfterBundle(TreeEntry *E) { // Get the basic block this bundle is in. All instructions in the bundle // should be in this block. 
- auto *Front = cast<Instruction>(S.OpValue); + auto *Front = E->getMainOp(); auto *BB = Front->getParent(); - assert(llvm::all_of(make_range(VL.begin(), VL.end()), [=](Value *V) -> bool { - auto *I = cast<Instruction>(V); - return !S.isOpcodeOrAlt(I) || I->getParent() == BB; - })); + assert(llvm::all_of(make_range(E->Scalars.begin(), E->Scalars.end()), + [=](Value *V) -> bool { + auto *I = cast<Instruction>(V); + return !E->isOpcodeOrAlt(I) || I->getParent() == BB; + })); // The last instruction in the bundle in program order. Instruction *LastInst = nullptr; @@ -3339,7 +3571,7 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, // bundle. The end of the bundle is marked by null ScheduleData. if (BlocksSchedules.count(BB)) { auto *Bundle = - BlocksSchedules[BB]->getScheduleData(isOneOf(S, VL.back())); + BlocksSchedules[BB]->getScheduleData(E->isOneOf(E->Scalars.back())); if (Bundle && Bundle->isPartOfBundle()) for (; Bundle; Bundle = Bundle->NextInBundle) if (Bundle->OpValue == Bundle->Inst) @@ -3365,14 +3597,15 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, // we both exit early from buildTree_rec and that the bundle be out-of-order // (causing us to iterate all the way to the end of the block). if (!LastInst) { - SmallPtrSet<Value *, 16> Bundle(VL.begin(), VL.end()); + SmallPtrSet<Value *, 16> Bundle(E->Scalars.begin(), E->Scalars.end()); for (auto &I : make_range(BasicBlock::iterator(Front), BB->end())) { - if (Bundle.erase(&I) && S.isOpcodeOrAlt(&I)) + if (Bundle.erase(&I) && E->isOpcodeOrAlt(&I)) LastInst = &I; if (Bundle.empty()) break; } } + assert(LastInst && "Failed to find last instruction in bundle"); // Set the insertion point after the last instruction in the bundle. Set the // debug location to Front. @@ -3385,7 +3618,7 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) { // Generate the 'InsertElement' instruction. 
for (unsigned i = 0; i < Ty->getNumElements(); ++i) { Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i)); - if (Instruction *Insrt = dyn_cast<Instruction>(Vec)) { + if (auto *Insrt = dyn_cast<InsertElementInst>(Vec)) { GatherSeq.insert(Insrt); CSEBlocks.insert(Insrt->getParent()); @@ -3494,8 +3727,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return E->VectorizedValue; } - InstructionsState S = getSameOpcode(E->Scalars); - Instruction *VL0 = cast<Instruction>(S.OpValue); + Instruction *VL0 = E->getMainOp(); Type *ScalarTy = VL0->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL0)) ScalarTy = SI->getValueOperand()->getType(); @@ -3504,7 +3736,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); if (E->NeedToGather) { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); auto *V = Gather(E->Scalars, VecTy); if (NeedToShuffleReuses) { V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), @@ -3518,11 +3750,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } - unsigned ShuffleOrOp = S.isAltShuffle() ? - (unsigned) Instruction::ShuffleVector : S.getOpcode(); + unsigned ShuffleOrOp = + E->isAltShuffle() ? 
(unsigned)Instruction::ShuffleVector : E->getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: { - PHINode *PH = dyn_cast<PHINode>(VL0); + auto *PH = cast<PHINode>(VL0); Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues()); @@ -3577,7 +3809,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { E->VectorizedValue = V; return V; } - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); auto *V = Gather(E->Scalars, VecTy); if (NeedToShuffleReuses) { V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), @@ -3612,7 +3844,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { E->VectorizedValue = NewV; return NewV; } - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); auto *V = Gather(E->Scalars, VecTy); if (NeedToShuffleReuses) { V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), @@ -3637,7 +3869,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); Value *InVec = vectorizeTree(E->getOperand(0)); @@ -3646,7 +3878,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return E->VectorizedValue; } - CastInst *CI = dyn_cast<CastInst>(VL0); + auto *CI = cast<CastInst>(VL0); Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy); if (NeedToShuffleReuses) { V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), @@ -3658,7 +3890,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::FCmp: case Instruction::ICmp: { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); Value *L = vectorizeTree(E->getOperand(0)); Value *R = vectorizeTree(E->getOperand(1)); @@ -3670,7 +3902,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); Value *V; - if 
(S.getOpcode() == Instruction::FCmp) + if (E->getOpcode() == Instruction::FCmp) V = Builder.CreateFCmp(P0, L, R); else V = Builder.CreateICmp(P0, L, R); @@ -3685,7 +3917,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::Select: { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); Value *Cond = vectorizeTree(E->getOperand(0)); Value *True = vectorizeTree(E->getOperand(1)); @@ -3706,7 +3938,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::FNeg: { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); Value *Op = vectorizeTree(E->getOperand(0)); @@ -3716,7 +3948,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *V = Builder.CreateUnOp( - static_cast<Instruction::UnaryOps>(S.getOpcode()), Op); + static_cast<Instruction::UnaryOps>(E->getOpcode()), Op); propagateIRFlags(V, E->Scalars, VL0); if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); @@ -3748,7 +3980,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::And: case Instruction::Or: case Instruction::Xor: { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); Value *LHS = vectorizeTree(E->getOperand(0)); Value *RHS = vectorizeTree(E->getOperand(1)); @@ -3759,7 +3991,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *V = Builder.CreateBinOp( - static_cast<Instruction::BinaryOps>(S.getOpcode()), LHS, RHS); + static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, + RHS); propagateIRFlags(V, E->Scalars, VL0); if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); @@ -3776,12 +4009,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Load: { // Loads are inserted at the head of the tree because we don't want to // sink them all the way down past store instructions. 
- bool IsReorder = !E->ReorderIndices.empty(); - if (IsReorder) { - S = getSameOpcode(E->Scalars, E->ReorderIndices.front()); - VL0 = cast<Instruction>(S.OpValue); - } - setInsertPointAfterBundle(E->Scalars, S); + bool IsReorder = E->updateStateIfReorder(); + if (IsReorder) + VL0 = E->getMainOp(); + setInsertPointAfterBundle(E); LoadInst *LI = cast<LoadInst>(VL0); Type *ScalarLoadTy = LI->getType(); @@ -3797,11 +4028,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (getTreeEntry(PO)) ExternalUses.push_back(ExternalUser(PO, cast<User>(VecPtr), 0)); - unsigned Alignment = LI->getAlignment(); + MaybeAlign Alignment = MaybeAlign(LI->getAlignment()); LI = Builder.CreateLoad(VecTy, VecPtr); - if (!Alignment) { - Alignment = DL->getABITypeAlignment(ScalarLoadTy); - } + if (!Alignment) + Alignment = MaybeAlign(DL->getABITypeAlignment(ScalarLoadTy)); LI->setAlignment(Alignment); Value *V = propagateMetadata(LI, E->Scalars); if (IsReorder) { @@ -3824,7 +4054,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { unsigned Alignment = SI->getAlignment(); unsigned AS = SI->getPointerAddressSpace(); - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); Value *VecValue = vectorizeTree(E->getOperand(0)); Value *ScalarPtr = SI->getPointerOperand(); @@ -3840,7 +4070,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (!Alignment) Alignment = DL->getABITypeAlignment(SI->getValueOperand()->getType()); - ST->setAlignment(Alignment); + ST->setAlignment(Align(Alignment)); Value *V = propagateMetadata(ST, E->Scalars); if (NeedToShuffleReuses) { V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), @@ -3851,7 +4081,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::GetElementPtr: { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); Value *Op0 = vectorizeTree(E->getOperand(0)); @@ -3878,13 +4108,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::Call: { CallInst *CI = 
cast<CallInst>(VL0); - setInsertPointAfterBundle(E->Scalars, S); - Function *FI; + setInsertPointAfterBundle(E); + Intrinsic::ID IID = Intrinsic::not_intrinsic; - Value *ScalarArg = nullptr; - if (CI && (FI = CI->getCalledFunction())) { + if (Function *FI = CI->getCalledFunction()) IID = FI->getIntrinsicID(); - } + + Value *ScalarArg = nullptr; std::vector<Value *> OpVecs; for (int j = 0, e = CI->getNumArgOperands(); j < e; ++j) { ValueList OpVL; @@ -3926,20 +4156,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::ShuffleVector: { - assert(S.isAltShuffle() && - ((Instruction::isBinaryOp(S.getOpcode()) && - Instruction::isBinaryOp(S.getAltOpcode())) || - (Instruction::isCast(S.getOpcode()) && - Instruction::isCast(S.getAltOpcode()))) && + assert(E->isAltShuffle() && + ((Instruction::isBinaryOp(E->getOpcode()) && + Instruction::isBinaryOp(E->getAltOpcode())) || + (Instruction::isCast(E->getOpcode()) && + Instruction::isCast(E->getAltOpcode()))) && "Invalid Shuffle Vector Operand"); - Value *LHS, *RHS; - if (Instruction::isBinaryOp(S.getOpcode())) { - setInsertPointAfterBundle(E->Scalars, S); + Value *LHS = nullptr, *RHS = nullptr; + if (Instruction::isBinaryOp(E->getOpcode())) { + setInsertPointAfterBundle(E); LHS = vectorizeTree(E->getOperand(0)); RHS = vectorizeTree(E->getOperand(1)); } else { - setInsertPointAfterBundle(E->Scalars, S); + setInsertPointAfterBundle(E); LHS = vectorizeTree(E->getOperand(0)); } @@ -3949,16 +4179,16 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *V0, *V1; - if (Instruction::isBinaryOp(S.getOpcode())) { + if (Instruction::isBinaryOp(E->getOpcode())) { V0 = Builder.CreateBinOp( - static_cast<Instruction::BinaryOps>(S.getOpcode()), LHS, RHS); + static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS); V1 = Builder.CreateBinOp( - static_cast<Instruction::BinaryOps>(S.getAltOpcode()), LHS, RHS); + static_cast<Instruction::BinaryOps>(E->getAltOpcode()), LHS, RHS); } else { V0 = 
Builder.CreateCast( - static_cast<Instruction::CastOps>(S.getOpcode()), LHS, VecTy); + static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy); V1 = Builder.CreateCast( - static_cast<Instruction::CastOps>(S.getAltOpcode()), LHS, VecTy); + static_cast<Instruction::CastOps>(E->getAltOpcode()), LHS, VecTy); } // Create shuffle to take alternate operations from the vector. @@ -3969,8 +4199,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { SmallVector<Constant *, 8> Mask(e); for (unsigned i = 0; i < e; ++i) { auto *OpInst = cast<Instruction>(E->Scalars[i]); - assert(S.isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode"); - if (OpInst->getOpcode() == S.getAltOpcode()) { + assert(E->isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode"); + if (OpInst->getOpcode() == E->getAltOpcode()) { Mask[i] = Builder.getInt32(e + i); AltScalars.push_back(E->Scalars[i]); } else { @@ -4136,20 +4366,18 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { Value *Scalar = Entry->Scalars[Lane]; +#ifndef NDEBUG Type *Ty = Scalar->getType(); if (!Ty->isVoidTy()) { -#ifndef NDEBUG for (User *U : Scalar->users()) { LLVM_DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n"); - // It is legal to replace users in the ignorelist by undef. + // It is legal to delete users in the ignorelist. assert((getTreeEntry(U) || is_contained(UserIgnoreList, U)) && - "Replacing out-of-tree value with undef"); + "Deleting out-of-tree value"); } -#endif - Value *Undef = UndefValue::get(Ty); - Scalar->replaceAllUsesWith(Undef); } +#endif LLVM_DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n"); eraseInstruction(cast<Instruction>(Scalar)); } @@ -4165,7 +4393,7 @@ void BoUpSLP::optimizeGatherSequence() { << " gather sequences instructions.\n"); // LICM InsertElementInst sequences. 
for (Instruction *I : GatherSeq) { - if (!isa<InsertElementInst>(I) && !isa<ShuffleVectorInst>(I)) + if (isDeleted(I)) continue; // Check if this block is inside a loop. @@ -4219,6 +4447,8 @@ void BoUpSLP::optimizeGatherSequence() { // For all instructions in blocks containing gather sequences: for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e;) { Instruction *In = &*it++; + if (isDeleted(In)) + continue; if (!isa<InsertElementInst>(In) && !isa<ExtractElementInst>(In)) continue; @@ -4245,11 +4475,11 @@ void BoUpSLP::optimizeGatherSequence() { // Groups the instructions to a bundle (which is then a single scheduling entity) // and schedules instructions until the bundle gets ready. -bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, - BoUpSLP *SLP, - const InstructionsState &S) { +Optional<BoUpSLP::ScheduleData *> +BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, + const InstructionsState &S) { if (isa<PHINode>(S.OpValue)) - return true; + return nullptr; // Initialize the instruction bundle. Instruction *OldScheduleEnd = ScheduleEnd; @@ -4262,7 +4492,7 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, // instructions of the bundle. 
for (Value *V : VL) { if (!extendSchedulingRegion(V, S)) - return false; + return None; } for (Value *V : VL) { @@ -4308,6 +4538,7 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, resetSchedule(); initialFillReadyList(ReadyInsts); } + assert(Bundle && "Failed to find schedule bundle"); LLVM_DEBUG(dbgs() << "SLP: try schedule bundle " << *Bundle << " in block " << BB->getName() << "\n"); @@ -4329,9 +4560,9 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, } if (!Bundle->isReady()) { cancelScheduling(VL, S.OpValue); - return false; + return None; } - return true; + return Bundle; } void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, @@ -4364,7 +4595,7 @@ void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, BoUpSLP::ScheduleData *BoUpSLP::BlockScheduling::allocateScheduleDataChunks() { // Allocate a new ScheduleData for the instruction. if (ChunkPos >= ChunkSize) { - ScheduleDataChunks.push_back(llvm::make_unique<ScheduleData[]>(ChunkSize)); + ScheduleDataChunks.push_back(std::make_unique<ScheduleData[]>(ChunkSize)); ChunkPos = 0; } return &(ScheduleDataChunks.back()[ChunkPos++]); @@ -4977,7 +5208,7 @@ struct SLPVectorizer : public FunctionPass { auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - auto *TLI = TLIP ? &TLIP->getTLI() : nullptr; + auto *TLI = TLIP ? &TLIP->getTLI(F) : nullptr; auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -5052,7 +5283,7 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, // If the target claims to have no vector registers don't attempt // vectorization. 
- if (!TTI->getNumberOfRegisters(true)) + if (!TTI->getNumberOfRegisters(TTI->getRegisterClassForType(true))) return false; // Don't vectorize when the attribute NoImplicitFloat is used. @@ -5100,19 +5331,6 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, return Changed; } -/// Check that the Values in the slice in VL array are still existent in -/// the WeakTrackingVH array. -/// Vectorization of part of the VL array may cause later values in the VL array -/// to become invalid. We track when this has happened in the WeakTrackingVH -/// array. -static bool hasValueBeenRAUWed(ArrayRef<Value *> VL, - ArrayRef<WeakTrackingVH> VH, unsigned SliceBegin, - unsigned SliceSize) { - VL = VL.slice(SliceBegin, SliceSize); - VH = VH.slice(SliceBegin, SliceSize); - return !std::equal(VL.begin(), VL.end(), VH.begin()); -} - bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, unsigned VecRegSize) { const unsigned ChainLen = Chain.size(); @@ -5124,20 +5342,20 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, if (!isPowerOf2_32(Sz) || VF < 2) return false; - // Keep track of values that were deleted by vectorizing in the loop below. - const SmallVector<WeakTrackingVH, 8> TrackValues(Chain.begin(), Chain.end()); - bool Changed = false; // Look for profitable vectorizable trees at all offsets, starting at zero. for (unsigned i = 0, e = ChainLen; i + VF <= e; ++i) { + ArrayRef<Value *> Operands = Chain.slice(i, VF); // Check that a previous iteration of this loop did not delete the Value. 
- if (hasValueBeenRAUWed(Chain, TrackValues, i, VF)) + if (llvm::any_of(Operands, [&R](Value *V) { + auto *I = dyn_cast<Instruction>(V); + return I && R.isDeleted(I); + })) continue; LLVM_DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i << "\n"); - ArrayRef<Value *> Operands = Chain.slice(i, VF); R.buildTree(Operands); if (R.isTreeTinyAndNotFullyVectorizable()) @@ -5329,12 +5547,8 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, bool CandidateFound = false; int MinCost = SLPCostThreshold; - // Keep track of values that were deleted by vectorizing in the loop below. - SmallVector<WeakTrackingVH, 8> TrackValues(VL.begin(), VL.end()); - unsigned NextInst = 0, MaxInst = VL.size(); - for (unsigned VF = MaxVF; NextInst + 1 < MaxInst && VF >= MinVF; - VF /= 2) { + for (unsigned VF = MaxVF; NextInst + 1 < MaxInst && VF >= MinVF; VF /= 2) { // No actual vectorization should happen, if number of parts is the same as // provided vectorization factor (i.e. the scalar type is used for vector // code during codegen). @@ -5352,13 +5566,16 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, if (!isPowerOf2_32(OpsWidth) || OpsWidth < 2) break; + ArrayRef<Value *> Ops = VL.slice(I, OpsWidth); // Check that a previous iteration of this loop did not delete the Value. 
- if (hasValueBeenRAUWed(VL, TrackValues, I, OpsWidth)) + if (llvm::any_of(Ops, [&R](Value *V) { + auto *I = dyn_cast<Instruction>(V); + return I && R.isDeleted(I); + })) continue; LLVM_DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations " << "\n"); - ArrayRef<Value *> Ops = VL.slice(I, OpsWidth); R.buildTree(Ops); Optional<ArrayRef<unsigned>> Order = R.bestOrder(); @@ -5571,7 +5788,7 @@ class HorizontalReduction { Value *createOp(IRBuilder<> &Builder, const Twine &Name) const { assert(isVectorizable() && "Expected add|fadd or min/max reduction operation."); - Value *Cmp; + Value *Cmp = nullptr; switch (Kind) { case RK_Arithmetic: return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, @@ -5579,23 +5796,23 @@ class HorizontalReduction { case RK_Min: Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSLT(LHS, RHS) : Builder.CreateFCmpOLT(LHS, RHS); - break; + return Builder.CreateSelect(Cmp, LHS, RHS, Name); case RK_Max: Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS) : Builder.CreateFCmpOGT(LHS, RHS); - break; + return Builder.CreateSelect(Cmp, LHS, RHS, Name); case RK_UMin: assert(Opcode == Instruction::ICmp && "Expected integer types."); Cmp = Builder.CreateICmpULT(LHS, RHS); - break; + return Builder.CreateSelect(Cmp, LHS, RHS, Name); case RK_UMax: assert(Opcode == Instruction::ICmp && "Expected integer types."); Cmp = Builder.CreateICmpUGT(LHS, RHS); - break; + return Builder.CreateSelect(Cmp, LHS, RHS, Name); case RK_None: - llvm_unreachable("Unknown reduction operation."); + break; } - return Builder.CreateSelect(Cmp, LHS, RHS, Name); + llvm_unreachable("Unknown reduction operation."); } public: @@ -6203,6 +6420,8 @@ public: } if (V.isTreeTinyAndNotFullyVectorizable()) break; + if (V.isLoadCombineReductionCandidate(ReductionData.getOpcode())) + break; V.computeMinimumValueSizes(); @@ -6275,6 +6494,9 @@ public: } // Update users. 
ReductionRoot->replaceAllUsesWith(VectorizedTree); + // Mark all scalar reduction ops for deletion, they are replaced by the + // vector reductions. + V.eraseInstructions(IgnoreList); } return VectorizedTree != nullptr; } @@ -6323,7 +6545,7 @@ private: IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost; - int ScalarReduxCost; + int ScalarReduxCost = 0; switch (ReductionData.getKind()) { case RK_Arithmetic: ScalarReduxCost = @@ -6429,10 +6651,9 @@ static bool findBuildVector(InsertElementInst *LastInsertElem, /// \return true if it matches. static bool findBuildAggregate(InsertValueInst *IV, SmallVectorImpl<Value *> &BuildVectorOpds) { - Value *V; do { BuildVectorOpds.push_back(IV->getInsertedValueOperand()); - V = IV->getAggregateOperand(); + Value *V = IV->getAggregateOperand(); if (isa<UndefValue>(V)) break; IV = dyn_cast<InsertValueInst>(V); @@ -6530,18 +6751,13 @@ static bool tryToVectorizeHorReductionOrInstOperands( // horizontal reduction. // Interrupt the process if the Root instruction itself was vectorized or all // sub-trees not higher that RecursionMaxDepth were analyzed/vectorized. 
- SmallVector<std::pair<WeakTrackingVH, unsigned>, 8> Stack(1, {Root, 0}); + SmallVector<std::pair<Instruction *, unsigned>, 8> Stack(1, {Root, 0}); SmallPtrSet<Value *, 8> VisitedInstrs; bool Res = false; while (!Stack.empty()) { - Value *V; + Instruction *Inst; unsigned Level; - std::tie(V, Level) = Stack.pop_back_val(); - if (!V) - continue; - auto *Inst = dyn_cast<Instruction>(V); - if (!Inst) - continue; + std::tie(Inst, Level) = Stack.pop_back_val(); auto *BI = dyn_cast<BinaryOperator>(Inst); auto *SI = dyn_cast<SelectInst>(Inst); if (BI || SI) { @@ -6582,8 +6798,8 @@ static bool tryToVectorizeHorReductionOrInstOperands( for (auto *Op : Inst->operand_values()) if (VisitedInstrs.insert(Op).second) if (auto *I = dyn_cast<Instruction>(Op)) - if (!isa<PHINode>(I) && I->getParent() == BB) - Stack.emplace_back(Op, Level); + if (!isa<PHINode>(I) && !R.isDeleted(I) && I->getParent() == BB) + Stack.emplace_back(I, Level); } return Res; } @@ -6652,11 +6868,10 @@ bool SLPVectorizerPass::vectorizeCmpInst(CmpInst *CI, BasicBlock *BB, } bool SLPVectorizerPass::vectorizeSimpleInstructions( - SmallVectorImpl<WeakVH> &Instructions, BasicBlock *BB, BoUpSLP &R) { + SmallVectorImpl<Instruction *> &Instructions, BasicBlock *BB, BoUpSLP &R) { bool OpsChanged = false; - for (auto &VH : reverse(Instructions)) { - auto *I = dyn_cast_or_null<Instruction>(VH); - if (!I) + for (auto *I : reverse(Instructions)) { + if (R.isDeleted(I)) continue; if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I)) OpsChanged |= vectorizeInsertValueInst(LastInsertValue, BB, R); @@ -6685,7 +6900,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (!P) break; - if (!VisitedInstrs.count(P)) + if (!VisitedInstrs.count(P) && !R.isDeleted(P)) Incoming.push_back(P); } @@ -6729,9 +6944,12 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { VisitedInstrs.clear(); - SmallVector<WeakVH, 8> PostProcessInstructions; + SmallVector<Instruction *, 8> 
PostProcessInstructions; SmallDenseSet<Instruction *, 4> KeyNodes; for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + // Skip instructions marked for the deletion. + if (R.isDeleted(&*it)) + continue; // We may go through BB multiple times so skip the one we have checked. if (!VisitedInstrs.insert(&*it).second) { if (it->use_empty() && KeyNodes.count(&*it) > 0 && @@ -6811,10 +7029,16 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) { LLVM_DEBUG(dbgs() << "SLP: Analyzing a getelementptr list of length " << Entry.second.size() << ".\n"); - // We process the getelementptr list in chunks of 16 (like we do for - // stores) to minimize compile-time. - for (unsigned BI = 0, BE = Entry.second.size(); BI < BE; BI += 16) { - auto Len = std::min<unsigned>(BE - BI, 16); + // Process the GEP list in chunks suitable for the target's supported + // vector size. If a vector register can't hold 1 element, we are done. + unsigned MaxVecRegSize = R.getMaxVecRegSize(); + unsigned EltSize = R.getVectorElementSize(Entry.second[0]); + if (MaxVecRegSize < EltSize) + continue; + + unsigned MaxElts = MaxVecRegSize / EltSize; + for (unsigned BI = 0, BE = Entry.second.size(); BI < BE; BI += MaxElts) { + auto Len = std::min<unsigned>(BE - BI, MaxElts); auto GEPList = makeArrayRef(&Entry.second[BI], Len); // Initialize a set a candidate getelementptrs. Note that we use a @@ -6824,10 +7048,10 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) { SetVector<Value *> Candidates(GEPList.begin(), GEPList.end()); // Some of the candidates may have already been vectorized after we - // initially collected them. If so, the WeakTrackingVHs will have - // nullified the - // values, so remove them from the set of candidates. - Candidates.remove(nullptr); + // initially collected them. If so, they are marked as deleted, so remove + // them from the set of candidates. 
+ Candidates.remove_if( + [&R](Value *I) { return R.isDeleted(cast<Instruction>(I)); }); // Remove from the set of candidates all pairs of getelementptrs with // constant differences. Such getelementptrs are likely not good @@ -6835,18 +7059,18 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) { // computed from the other. We also ensure all candidate getelementptr // indices are unique. for (int I = 0, E = GEPList.size(); I < E && Candidates.size() > 1; ++I) { - auto *GEPI = cast<GetElementPtrInst>(GEPList[I]); + auto *GEPI = GEPList[I]; if (!Candidates.count(GEPI)) continue; auto *SCEVI = SE->getSCEV(GEPList[I]); for (int J = I + 1; J < E && Candidates.size() > 1; ++J) { - auto *GEPJ = cast<GetElementPtrInst>(GEPList[J]); + auto *GEPJ = GEPList[J]; auto *SCEVJ = SE->getSCEV(GEPList[J]); if (isa<SCEVConstant>(SE->getMinusSCEV(SCEVI, SCEVJ))) { - Candidates.remove(GEPList[I]); - Candidates.remove(GEPList[J]); + Candidates.remove(GEPI); + Candidates.remove(GEPJ); } else if (GEPI->idx_begin()->get() == GEPJ->idx_begin()->get()) { - Candidates.remove(GEPList[J]); + Candidates.remove(GEPJ); } } } diff --git a/lib/Transforms/Vectorize/VPlan.cpp b/lib/Transforms/Vectorize/VPlan.cpp index 517d759d7bfc..4b80d1fb20aa 100644 --- a/lib/Transforms/Vectorize/VPlan.cpp +++ b/lib/Transforms/Vectorize/VPlan.cpp @@ -283,6 +283,12 @@ iplist<VPRecipeBase>::iterator VPRecipeBase::eraseFromParent() { return getParent()->getRecipeList().erase(getIterator()); } +void VPRecipeBase::moveAfter(VPRecipeBase *InsertPos) { + InsertPos->getParent()->getRecipeList().splice( + std::next(InsertPos->getIterator()), getParent()->getRecipeList(), + getIterator()); +} + void VPInstruction::generateInstruction(VPTransformState &State, unsigned Part) { IRBuilder<> &Builder = State.Builder; @@ -309,6 +315,14 @@ void VPInstruction::generateInstruction(VPTransformState &State, State.set(this, V, Part); break; } + case Instruction::Select: { + Value *Cond = State.get(getOperand(0), 
Part); + Value *Op1 = State.get(getOperand(1), Part); + Value *Op2 = State.get(getOperand(2), Part); + Value *V = Builder.CreateSelect(Cond, Op1, Op2); + State.set(this, V, Part); + break; + } default: llvm_unreachable("Unsupported opcode for instruction"); } @@ -728,7 +742,7 @@ void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New, auto NewIGIter = Old2New.find(IG); if (NewIGIter == Old2New.end()) Old2New[IG] = new InterleaveGroup<VPInstruction>( - IG->getFactor(), IG->isReverse(), IG->getAlignment()); + IG->getFactor(), IG->isReverse(), Align(IG->getAlignment())); if (Inst == IG->getInsertPos()) Old2New[IG]->setInsertPos(VPInst); @@ -736,7 +750,8 @@ void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New, InterleaveGroupMap[VPInst] = Old2New[IG]; InterleaveGroupMap[VPInst]->insertMember( VPInst, IG->getIndex(Inst), - IG->isReverse() ? (-1) * int(IG->getFactor()) : IG->getFactor()); + Align(IG->isReverse() ? (-1) * int(IG->getFactor()) + : IG->getFactor())); } } else if (VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) visitRegion(Region, Old2New, IAI); diff --git a/lib/Transforms/Vectorize/VPlan.h b/lib/Transforms/Vectorize/VPlan.h index 8a06412ad590..44d8a198f27e 100644 --- a/lib/Transforms/Vectorize/VPlan.h +++ b/lib/Transforms/Vectorize/VPlan.h @@ -615,6 +615,10 @@ public: /// the specified recipe. void insertBefore(VPRecipeBase *InsertPos); + /// Unlink this recipe from its current VPBasicBlock and insert it into + /// the VPBasicBlock that MovePos lives in, right after MovePos. + void moveAfter(VPRecipeBase *MovePos); + /// This method unlinks 'this' from the containing basic block and deletes it. 
/// /// \returns an iterator pointing to the element after the erased one diff --git a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp index 7ed7d21b6caa..b22d3190d654 100644 --- a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp +++ b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp @@ -21,7 +21,7 @@ void VPlanHCFGTransforms::VPInstructionsToVPRecipes( LoopVectorizationLegality::InductionList *Inductions, SmallPtrSetImpl<Instruction *> &DeadInstructions) { - VPRegionBlock *TopRegion = dyn_cast<VPRegionBlock>(Plan->getEntry()); + auto *TopRegion = cast<VPRegionBlock>(Plan->getEntry()); ReversePostOrderTraversal<VPBlockBase *> RPOT(TopRegion->getEntry()); // Condition bit VPValues get deleted during transformation to VPRecipes. diff --git a/lib/Transforms/Vectorize/VPlanSLP.cpp b/lib/Transforms/Vectorize/VPlanSLP.cpp index e5ab24e52df6..9019ed15ec5f 100644 --- a/lib/Transforms/Vectorize/VPlanSLP.cpp +++ b/lib/Transforms/Vectorize/VPlanSLP.cpp @@ -346,11 +346,14 @@ SmallVector<VPlanSlp::MultiNodeOpTy, 4> VPlanSlp::reorderMultiNodeOps() { void VPlanSlp::dumpBundle(ArrayRef<VPValue *> Values) { dbgs() << " Ops: "; - for (auto Op : Values) - if (auto *Instr = cast_or_null<VPInstruction>(Op)->getUnderlyingInstr()) - dbgs() << *Instr << " | "; - else - dbgs() << " nullptr | "; + for (auto Op : Values) { + if (auto *VPInstr = cast_or_null<VPInstruction>(Op)) + if (auto *Instr = VPInstr->getUnderlyingInstr()) { + dbgs() << *Instr << " | "; + continue; + } + dbgs() << " nullptr | "; + } dbgs() << "\n"; } |